Advanced Patterns

emcee is generally pretty simple, but it has a few key features that make it easier to use in real problems. Here are a few examples of things that you might find useful.

Incrementally saving progress

It is often useful to incrementally save the state of the chain to a file. This makes it easier to monitor the chain’s progress and it makes things a little less disastrous if your code/computer crashes somewhere in the middle of an expensive MCMC run. If you just want to append the walker positions to the end of a file, you could do something like:

# Clear the output file before starting the run.
f = open("chain.dat", "w")
f.close()

for result in sampler.sample(pos0, iterations=500, storechain=False):
    position = result[0]
    # Re-open and close the file at every step so that the samples are
    # flushed to disk even if the run crashes part-way through.
    f = open("chain.dat", "a")
    for k in range(position.shape[0]):
        f.write("{0:4d} {1:s}\n".format(k, " ".join(map(str, position[k]))))
    f.close()
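
If you later want to read these samples back in for post-processing, a minimal sketch using numpy might look like this (it assumes the whitespace-separated format written by the snippet above):

import numpy as np

# Load the file written above: the first column is the walker index and
# the remaining columns are the parameter values at that step.
data = np.loadtxt("chain.dat")
walker_index = data[:, 0].astype(int)
samples = data[:, 1:]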

Printing the sampler’s progress

You might want to monitor the progress of the sampler in your terminal while it runs. There are several modules out there that can help you make shiny progress bars (e.g., progressbar and clint), but it’s straightforward to implement a simple progress counter yourself.

The solution here is very similar to the incremental saving snippet. For example, to display the current percentage:

nsteps = 5000
for i, result in enumerate(sampler.sample(p0, iterations=nsteps)):
    # Report the fraction of steps completed every 100 iterations.
    if (i + 1) % 100 == 0:
        print("{0:5.1%}".format(float(i + 1) / nsteps))

Or, to display a rudimentary progress bar that updates itself on a single line:

import sys

nsteps = 5000
width = 30
for i, result in enumerate(sampler.sample(p0, iterations=nsteps)):
    # Overwrite the same terminal line with an updated bar at each step.
    n = int((width + 1) * float(i) / nsteps)
    sys.stdout.write("\r[{0}{1}]".format('#' * n, ' ' * (width - n)))
sys.stdout.write("\n")

Multiprocessing

In principle, running emcee in parallel is as simple as instantiating an EnsembleSampler object with the threads argument set to an integer greater than 1:

sampler = emcee.EnsembleSampler(nwalkers, ndim, lnpostfn, threads=15)

In practice, the parallelization is implemented using the built-in Python multiprocessing module, which brings a few constraints. In particular, both lnpostfn and args must be pickleable. The exceptions thrown while using multiprocessing can be quite cryptic and, even though we’ve tried to make this feature as user-friendly as possible, it can sometimes cause headaches. One useful debugging tactic is to try running with 1 thread if your processes start to crash; this will generally give much more illuminating error messages than the parallel case. Note that the parallelized EnsembleSampler object is not pickleable, so if it (or an object that contains it) is passed to lnpostfn when multiprocessing is turned on, the code will fail.

It is also important to note that the multiprocessing module works by spawning a number of new Python processes and running the code in isolation within those processes. This means that there is a significant amount of overhead at each step of the parallelization. With this in mind, it is not surprising that running a simple problem like the quickstart example in parallel will be much slower than the equivalent serial code. If your log-probability function takes a significant amount of time (roughly a second or more) to compute, then using the parallel sampler provides significant speed gains.
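
To make the pickleability constraint concrete, here is a minimal sketch (the function, argument names, and numbers are only illustrative). The log-probability function is defined at module level, so multiprocessing can pickle it, and the extra arguments passed via args are plain numpy arrays, which are also pickleable:

import numpy as np
import emcee

# Defined at module level so it can be pickled; a lambda or a function
# defined inside another function would make multiprocessing fail.
def lnpostfn(p, mu, icov):
    diff = p - mu
    return -0.5 * np.dot(diff, np.dot(icov, diff))

ndim, nwalkers = 10, 100
mu = np.zeros(ndim)
icov = np.eye(ndim)
p0 = np.random.randn(nwalkers, ndim)

# The objects passed through `args` must be pickleable as well.
sampler = emcee.EnsembleSampler(nwalkers, ndim, lnpostfn,
                                args=(mu, icov), threads=4)
sampler.run_mcmc(p0, 100)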

Arbitrary metadata blobs

Added in version 1.1.0

Imagine that your log-probability function involves an extremely computationally expensive numerical simulation starting from initial conditions parameterized by the position of the walker in parameter space. Then you have to compare the results of your simulation by projecting into data space (predicting your data) and computing something like a chi-squared scalar in this space. After you run MCMC, you might want to visualize the draws from your probability function in data space by over-plotting samples on your data points. It is obviously unreasonable to recompute all the simulations for all the initial conditions that you want to display as part of your post-processing, especially since you already computed all of them before! Instead, it would be ideal to store the realization associated with each step in the MCMC and then just display those after the fact. This is possible using the “arbitrary blob” pattern.

To use blobs, you just need to modify your log-probability function to return a second argument (this can be any arbitrary Python object). Then, the sampler object will have an attribute (called EnsembleSampler.blobs) that is a list (of length niterations) of lists (of length nwalkers) containing all the accepted blobs associated with the walker positions in EnsembleSampler.chain.

As an absolutely trivial example, let’s say that we wanted to store the sum of cubes of the input parameters as a string at each position in the chain. To do this we could simply sample a function like:

def lnprobfn(p):
    # Return the log-probability and, as a blob, the sum of cubes as a string.
    return -0.5 * np.sum(p ** 2), str(np.sum(p ** 3))

It is important to note that by returning two values from our log-probability function, we also change the output of EnsembleSampler.sample() and EnsembleSampler.run_mcmc(): they now return 4 values (position, probability, random number generator state, and blobs) instead of just the first three.
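
For example, with the lnprobfn above you could run the sampler and then inspect the stored blobs like this (a sketch assuming sampler and p0 are already set up as usual):

pos, prob, state, blobs = sampler.run_mcmc(p0, 500)

# sampler.blobs is a list with one entry per iteration, each of which is a
# list with one entry per walker holding whatever lnprobfn returned.
print(len(sampler.blobs), len(sampler.blobs[0]))
print(sampler.blobs[-1][0])  # the blob for walker 0 at the final step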

Using MPI to distribute the computations

Added in version 1.2.0

The standard implementation of emcee relies on the multiprocessing module to parallelize tasks. This works well on a single machine with multiple cores but it is sometimes useful to distribute the computation across a larger cluster. To do this, we need to do something a little bit more sophisticated using the mpi4py module. Below, we’ll implement an example similar to the quickstart using MPI but first you’ll need to install mpi4py.

The utils.MPIPool object provides most of the needed functionality so we’ll start by importing that and the other needed modules:

import sys
import numpy as np
import emcee
from emcee.utils import MPIPool

This time, we’ll just sample a simple isotropic Gaussian (remember that the emcee algorithm doesn’t care about covariances between parameters because it is affine-invariant):

ndim = 50
nwalkers = 250
p0 = [np.random.rand(ndim) for i in range(nwalkers)]

def lnprob(x):
    return -0.5 * np.sum(x ** 2)

Now, this is where things start to change:

pool = MPIPool()
if not pool.is_master():
    pool.wait()
    sys.exit(0)

First, we’re initializing the pool object and then—if the process isn’t running as master—we wait for instructions and then exit. Then, we can set up the sampler providing this pool object to do the parallelization:

sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, pool=pool)

and then run and analyse as usual. The key here is that only the master process should directly interact with the sampler; the other processes should only wait for instructions.
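
For instance, the sampling step on the master process might look something like this (100 steps is just a placeholder):

# Only the master process reaches this point; the worker processes are
# blocked in pool.wait() and evaluate lnprob on demand.
sampler.run_mcmc(p0, 100)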

Note: don’t forget to close the pool if you don’t want the processes to hang forever:

pool.close()

The full source code for this example is available on Github.

If we save this script to the file mpi.py, we can then run this example with the command:

mpirun -np 2 python mpi.py

for local testing.

Load balancing in parallel runs

Added in version 2.1.0

When emcee is being used in a multiprocessing mode (multiprocessing or mpi4py), the log-probability evaluations need to be distributed evenly over all the available cores. emcee uses a map function to distribute these jobs. In the case of multiprocessing, the built-in map function dynamically schedules the tasks. To get similar dynamic scheduling when using utils.MPIPool, use the following invocation:

pool = MPIPool(loadbalance=True)

By default, loadbalance is set to False. If your jobs have a lot of variance in run-time, then setting the loadbalance option will improve the overall run-time.

If your problem is such that the run-time of each invocation of the log-probability function scales with one or more of the parameters, then you can improve load balancing even further. By sorting the jobs in decreasing order of (expected) run-time, the longest jobs get run simultaneously and you only have to wait for the duration of the longest job. In the following example, the first parameter strongly determines the run-time: the larger the first parameter, the longer the run-time. The sort_on_runtime function returns the re-ordered list of positions and the corresponding indices.

def sort_on_runtime(pos):
    p = np.atleast_2d(pos)
    # Sort the walker positions in decreasing order of the first parameter,
    # which (in this example) is a proxy for the expected run-time.
    idx = np.argsort(p[:, 0])[::-1]
    return p[idx], idx

In order to use this function, you will have to instantiate an EnsembleSampler object with:

sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, pool=pool,
                                runtime_sortingfn=sort_on_runtime)

Such a sort_on_runtime function can be used with both the multiprocessing and mpi4py parallelization modes of emcee. You can see a benchmarking routine using the mpi4py module on GitHub.