Animated MCMC with Matplotlib
This post was generated from a working notebook, which is available here.
1. Write down an interesting distribution
A mixture of Gaussians is different from a sum of Gaussians: a sum of independent Gaussians is itself Gaussian, but a mixture generally is not. It is visually interesting, can be difficult to generate independent samples from, and knows many secrets.
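To make the difference concrete, here is a quick sketch (not from the original notebook) that draws from both. Adding independent N(-2, 1) and N(2, 1) draws gives another Gaussian, N(0, 2). Flipping a fair coin to keep one of the two draws gives the equal-weight mixture, which has mean 0 and variance 1 + 2² = 5 and is bimodal. The seed, sample size, and the "near zero" window are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)  # arbitrary seed
n = 100_000

left = rng.normal(-2, 1, size=n)
right = rng.normal(2, 1, size=n)

# Sum of Gaussians: add the two draws -> another Gaussian, N(0, 2)
sum_draws = left + right

# Mixture of Gaussians: flip a fair coin per draw and keep one of the two
coin = rng.random(n) < 0.5
mixture_draws = np.where(coin, left, right)


def near_zero(draws, width=0.25):
    """Fraction of draws within `width` of 0 (the valley between the mixture's modes)."""
    return np.mean(np.abs(draws) < width)


print(f"sum:     var={sum_draws.var():.2f}")      # should be close to 2
print(f"mixture: var={mixture_draws.var():.2f}")  # should be close to 5
# The sum piles up near 0; the mixture has a dip there
print(near_zero(sum_draws), near_zero(mixture_draws))
```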
Here is an implementation that mostly follows the API of `scipy.stats`, in that it provides an `.rvs` method that returns random samples. If the `rvs` function looks a little complicated, it is because shapes can be hard in high dimensions, ok?

We can use the `.rvs` method to view the density of this distribution.
```python
import numpy as np
import scipy.stats as st
import arviz as az

FIGSIZE = (12, 4)  # assumed value; the original notebook defines its own figure size


class MixtureOfGaussians:
    """Two standard normal distributions, centered at +2 and -2."""

    def __init__(self):
        self.components = [st.norm(-2, 1), st.norm(2, 1)]
        self.weights = np.array([0.5, 0.5])

    def pdf(self, x):
        # Weighted sum of the component densities
        return self.weights.dot([component.pdf(x) for component in self.components])

    def rvs(self, size=1):
        # Pick a component for each draw (equal weights), then sample from it
        idxs = np.random.randint(0, 2, size=size)
        result = np.empty(size)
        for idx, component in enumerate(self.components):
            spots, = np.where(idxs == idx)
            result[spots] = component.rvs(size=spots.shape)
        return result


az.plot_kde(MixtureOfGaussians().rvs(10_000), figsize=FIGSIZE);
```
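The `rvs` above hard-codes two equally weighted components through `np.random.randint(0, 2, ...)`. The same two-step recipe (choose a component index for each draw, then sample from that component) extends to arbitrary weights. Here is a minimal sketch assuming normal components; the class name `WeightedNormalMixture` is hypothetical, and it uses NumPy's `Generator` API instead of the global random state.

```python
import numpy as np
import scipy.stats as st


class WeightedNormalMixture:
    """Hypothetical generalization: normal components with arbitrary weights."""

    def __init__(self, locs, scales, weights, seed=None):
        self.locs = np.asarray(locs, dtype=float)
        self.scales = np.asarray(scales, dtype=float)
        weights = np.asarray(weights, dtype=float)
        self.weights = weights / weights.sum()  # normalize so weights sum to 1
        self.rng = np.random.default_rng(seed)

    def pdf(self, x):
        # One row of densities per component, then a weighted sum over rows
        densities = np.array([st.norm.pdf(x, loc, scale)
                              for loc, scale in zip(self.locs, self.scales)])
        return self.weights @ densities

    def rvs(self, size=1):
        # Step 1: pick a component for every draw, with probability given by the weights
        idxs = self.rng.choice(len(self.weights), size=size, p=self.weights)
        # Step 2: draw from each chosen component (vectorized via fancy indexing)
        return self.rng.normal(self.locs[idxs], self.scales[idxs])


mix = WeightedNormalMixture(locs=[-2, 2], scales=[1, 1], weights=[0.25, 0.75], seed=0)
draws = mix.rvs(50_000)
print(draws.mean())  # should be close to 0.25 * -2 + 0.75 * 2 = 1.0
```

With `weights=[0.5, 0.5]` this describes the same distribution as `MixtureOfGaussians`.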
2. Write down MCMC
You should look up the Metropolis algorithm if you are not familiar with it! It is beautiful and important. Also, don't use it.
In general, it lets you generate draws from a probability distribution given only access to the probability density function. Because it only uses ratios of densities, the density does not even need to be normalized. So we will pretend we did not implement `.rvs` above, and generate samples using only the `.pdf` method.
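```python
def metropolis_sample(pdf, *, steps, step_size, init=0.):
    """Metropolis sampler with a normal proposal."""
    point = init
    samples = []
    for _ in range(steps):
        # Propose a move from a normal distribution centered at the current point
        proposed = st.norm(point, step_size).rvs()
        # Accept with probability min(1, pdf(proposed) / pdf(point))
        if np.random.rand() < pdf(proposed) / pdf(point):
            point = proposed
        # On rejection, the current point is recorded again
        samples.append(point)
    return np.array(samples)
```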
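Dividing densities can underflow when `pdf` values get tiny (far in the tails, or in higher dimensions). A common variant, sketched here as an alternative rather than what this post uses, works with log densities and accepts when log(u) < log p(proposed) - log p(current). The names `metropolis_sample_log` and `mixture_logpdf` are hypothetical, and the randomness comes from a seeded NumPy `Generator` rather than the global state.

```python
import numpy as np


def metropolis_sample_log(logpdf, *, steps, step_size, init=0.0, seed=None):
    """Metropolis sampler with a normal proposal, working with log densities."""
    rng = np.random.default_rng(seed)
    point = init
    current_logp = logpdf(point)  # cache so each step evaluates logpdf only once
    samples = np.empty(steps)
    for i in range(steps):
        proposed = rng.normal(point, step_size)
        proposed_logp = logpdf(proposed)
        # Accept with probability min(1, exp(proposed_logp - current_logp))
        if np.log(rng.random()) < proposed_logp - current_logp:
            point, current_logp = proposed, proposed_logp
        samples[i] = point
    return samples


def mixture_logpdf(x):
    """Log density of the equal-weight mixture of N(-2, 1) and N(2, 1)."""
    # logaddexp computes log(exp(a) + exp(b)) without underflow
    return (np.logaddexp(-0.5 * (x + 2) ** 2, -0.5 * (x - 2) ** 2)
            + np.log(0.5) - 0.5 * np.log(2 * np.pi))


draws = metropolis_sample_log(mixture_logpdf, steps=50_000, step_size=1.0, seed=0)
```

Only the difference of log densities enters the acceptance test, so, as with the ratio version, the target only needs to be known up to a constant.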
3. Find a visually pleasing set of draws
This is more art than science, but the animation looks nice if the draws:
- Are correlated
- Switch between modes pretty often
- End up with a histogram that is “close” to the true one
- Have about 3,000 draws (the animation ends up being ~30s long)
I found a random seed that did all this by looking at the trace plot. The seed was 0, but I was ready to do some real work on it.
```python
seed = 0
np.random.seed(seed)  # scipy.stats draws use NumPy's global random state
samples = metropolis_sample(MixtureOfGaussians().pdf, steps=3_000, step_size=0.4)
az.plot_trace(samples);
```
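The checklist above was judged by eye from the trace plot. If you would rather screen many seeds automatically, two rough numbers capture the first two criteria: the lag-1 autocorrelation (how correlated consecutive draws are) and the number of times the chain crosses 0, which counts mode switches because the modes sit at -2 and +2. These helpers are hypothetical, not part of the original notebook, and any thresholds you pick for them are a matter of taste.

```python
import numpy as np


def lag1_autocorr(draws):
    """Lag-1 autocorrelation; values near 1 mean strongly correlated draws."""
    centered = np.asarray(draws, dtype=float) - np.mean(draws)
    return float(np.dot(centered[:-1], centered[1:]) / np.dot(centered, centered))


def mode_switches(draws, boundary=0.0):
    """Count how often the chain crosses `boundary` (0 separates the modes at -2 and +2)."""
    side = np.asarray(draws) > boundary
    return int(np.count_nonzero(side[1:] != side[:-1]))


# Toy chain that hops from the left mode to the right and back once
toy = np.array([-2.0, -1.5, 1.8, 2.1, -2.2, -1.9])
print(mode_switches(toy))  # 2
print(lag1_autocorr(toy))
```

Applied to the `samples` from the cell above, these give numbers that can be compared across seeds before looking at any plots.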
4. Prepare the static plot
Influenced by Bret Beheim’s visualizations with tweenr, I was looking for a plot with a similar aesthetic.
To do that, I have to
- Bucket the data into discrete bins (using `np.digitize`).
- Set a y-value for each data point. I just count upwards from 1 within each bin, in the order the draws arrive, then divide by the overall maximum count, so every y-value lies between 0 and 1.
There is also a bunch of matplotlib styling at the bottom to make everything look beautiful. I use the `viridis` color map to show which draw I am on: early draws are purple, and later draws are yellower.
```python
import matplotlib.pyplot as plt

hi, lo = samples.max(), samples.min()
x = np.linspace(lo, hi, 100)
bins = np.digitize(samples, x, right=True)  # bin index (0 to 99) for each draw

counter = np.zeros_like(bins)  # y values
counts = np.zeros_like(x)  # keep track of how many points are already in each bin
for idx, bin_ in enumerate(bins):
    counts[bin_] += 1
    counter[idx] = counts[bin_]
counter = counter / counter.max()  # rescale so every y value lies in (0, 1]

# Mess with plot styles here, since it is cheaper than animating
fig, ax = plt.subplots(figsize=FIGSIZE)
ax.set_ylim(0, 1)
ax.set_xlim(bins.min(), bins.max())
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.set_xticks([])
ax.get_yaxis().set_visible(False)

cmap = plt.get_cmap('viridis')
colors = cmap(np.linspace(0, 1, bins.shape[0]))  # one RGBA color per draw, purple to yellow
ax.scatter(bins, counter, marker='.', facecolors=colors);
```
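The loop above gives each draw its position within its bin, in the order the draws arrive. If it gets slow for much longer chains, the same heights can be computed without a Python loop: stable-sort the draws by bin, then number each run of equal bins 1, 2, 3, and so on. This is an alternative sketch, not the notebook's code; `stack_heights` and `stack_heights_loop` are hypothetical names, the latter being the loop from above repackaged for comparison.

```python
import numpy as np


def stack_heights(bins):
    """Position of each draw within its bin (1 = first draw to land there)."""
    bins = np.asarray(bins)
    # A stable sort keeps draws in arrival order within each bin
    order = np.argsort(bins, kind="stable")
    sorted_bins = bins[order]
    # Index in sorted_bins where each run of identical bin labels starts
    starts = np.r_[0, np.flatnonzero(np.diff(sorted_bins)) + 1]
    run_lengths = np.diff(np.r_[starts, len(bins)])
    # 1-based position of each element within its run
    within = np.arange(len(bins)) - np.repeat(starts, run_lengths) + 1
    # Undo the sort so heights line up with the original draws
    heights = np.empty(len(bins), dtype=int)
    heights[order] = within
    return heights


def stack_heights_loop(bins):
    """The loop from the post, for comparison."""
    counts = {}
    heights = np.empty(len(bins), dtype=int)
    for idx, bin_ in enumerate(bins):
        counts[bin_] = counts.get(bin_, 0) + 1
        heights[idx] = counts[bin_]
    return heights


example = np.array([3, 1, 3, 3, 1])
print(stack_heights(example))       # [1 1 2 3 2]
print(stack_heights_loop(example))  # [1 1 2 3 2]
```

Normalizing is unchanged: divide the result by its maximum to get y-values in (0, 1].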
5. Make the animation
This is taken pretty directly from the matplotlib animation docs, but I am using `scatter` instead of `plot` so that I can change the colors of points that are already plotted. This means that in the `update` step I use `set_offsets` (along with `set_facecolors`) instead of `set_data`.
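For reference, here is a minimal, non-animated sketch of the two scatter methods the update step relies on. `set_offsets` takes an (N, 2) array of x, y positions, and `set_facecolors` takes one RGBA row per point, which is exactly what calling a colormap on an array returns. The toy data are made up for illustration, and the Agg backend is only there so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
scat = ax.scatter([], [], marker=".")  # start empty, as in the animation

# Toy data: five points whose positions and colors are set after creation
xs = np.arange(5)
ys = np.linspace(0.2, 1.0, 5)
colors = plt.get_cmap("viridis")(np.linspace(0, 1, len(xs)))  # (5, 4) RGBA array

scat.set_offsets(np.column_stack([xs, ys]))  # positions as an (N, 2) array
scat.set_facecolors(colors)                  # one color per point

print(scat.get_offsets().shape)    # (5, 2)
print(scat.get_facecolor().shape)  # (5, 4)
```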
The falling animation is done with the `offset` below. The y-axis goes from 0 to 1, and each step I add a new particle. If each particle moves down by δ each step, then after 10 steps the y positions of the first 10 particles will be:
```
y0 -> 1 - 10δ
y1 -> 1 - 9δ
y2 -> 1 - 8δ
...
```
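and so on, with each particle falling until it reaches its true `y` position. If you scribble on some paper, you can convince yourself this is equivalent to computing `1 + (np.arange(n) - n) * δ`, where `n` is the number of particles so far, and then taking the maximum of that and the true position. With δ = 1/50 and `n = idx + 1`, this is the `(np.arange(idx + 1) - idx + 49) / 50` used in the code.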
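A quick numerical check of that equivalence, with δ = 1/50 as in the animation code: starting from y = 1, the falling height is `1 + (np.arange(n) - n) * δ`, and after 10 frames (`idx = 9`) it agrees with the `(np.arange(idx + 1) - idx + 49) / 50` expression in `update`, with the first particle at 1 - 10δ. The helper `falling_y` and the constant `DELTA` are hypothetical names for illustration.

```python
import numpy as np

DELTA = 1 / 50  # fall per frame, matching the /50 in the update step


def falling_y(true_y, delta=DELTA):
    """Height of each particle after len(true_y) frames: still falling, or landed."""
    n = len(true_y)
    # Particle j sits (n - j) * delta below y = 1
    offset = 1 + (np.arange(n) - n) * delta
    # Once the falling height drops below the true height, the particle stays put
    return np.maximum(true_y, offset)


idx = 9  # frame index, so 10 particles are on screen
n = idx + 1
update_offset = (np.arange(idx + 1) - idx + 49) / 50  # expression from the update step
derived_offset = 1 + (np.arange(n) - n) * DELTA
print(np.allclose(update_offset, derived_offset))  # True
print(derived_offset[:3])  # corresponds to 1 - 10δ, 1 - 9δ, 1 - 8δ

# 30 particles that all belong at y = 0.5: the oldest has landed, the newest is near the top
print(falling_y(np.full(30, 0.5)))
```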
```python
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

fig, ax = plt.subplots(figsize=FIGSIZE)
ln = ax.scatter([], [], marker='.', animated=True)
cmap = plt.get_cmap('viridis')


def init():
    ax.set_xlim(bins.min(), bins.max())
    ax.set_ylim(0, 1)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.set_xticks([])
    ax.get_yaxis().set_visible(False)
    return ln,


def update(idx):
    # Slice instead of appending, so rendering the animation twice does not duplicate points
    xdata = bins[:idx + 1]
    ydata = counter[:idx + 1]
    colors = cmap(np.linspace(0, 1, len(xdata)))
    # Each particle starts just below y=1 and falls 1/50 per frame...
    offset = (np.arange(idx + 1) - idx + 49) / 50
    # ...until it lands on its true stacked height
    y = np.maximum(ydata, offset)
    ln.set_offsets(np.array([xdata, y]).T)
    ln.set_facecolors(colors)
    return ln,


anim = FuncAnimation(fig, update, frames=np.arange(bins.shape[0]),
                     init_func=init, blit=True, interval=20)
HTML(anim.to_html5_video())
```
6. Now implement a tiny fire hose animation