The goal of this post is to scare users away from using Metropolis-Hastings, in favor of the NUTS sampler. I also want to provide a concrete example of a model that fails with a Metropolis sampler.
I am using the model from a previous post. The details are interesting but not that important, except that this is a hierarchical model (a hedged sketch of what it might look like follows the list below). We will concentrate on just two of the model's variables, `pooled_rate` and `state_rate`.
The takeaway points are:
- The NUTS sampler generates (effective) samples about 10,000 times as fast as Metropolis
- What a bad trace actually looks like
- How far from 1 a bad Gelman-Rubin statistic can get
- How easy it is to compare two models
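For reference, here is a minimal, hypothetical sketch of what such a model might look like. The real `car_model` lives in the previous post; everything below except the names `car_model`, `miles_e6`, `fatalities`, `pooled_rate`, and `state_rate` is an assumption, chosen only so the later cells have something runnable behind them.

```python
import pymc3 as pm

def car_model(miles_e6, fatalities):
    # Hypothetical stand-in for the model from the previous post: fatalities
    # are Poisson counts, with per-state rates partially pooled toward a
    # shared rate. The priors here are assumptions, not the original ones.
    with pm.Model() as model:
        pooled_rate = pm.Exponential('pooled_rate', 1.0)
        state_rate = pm.Gamma('state_rate',
                              alpha=100 * pooled_rate, beta=100,
                              shape=len(fatalities))
        pm.Poisson('fatalities', mu=state_rate * miles_e6,
                   observed=fatalities)
    return model
```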
```python
varnames = ('pooled_rate', 'state_rate')

with car_model(miles_e6, fatalities) as nuts_model:
    nuts_trace = pm.sample(10000, njobs=4)
```
```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
100%|██████████| 10500/10500 [00:39<00:00, 267.26it/s]
```
```python
with car_model(miles_e6, fatalities) as metropolis_model:
    metropolis_trace = pm.sample(10000, step=pm.Metropolis(), njobs=4)
```

```
100%|██████████| 10500/10500 [00:28<00:00, 372.62it/s]
```
## How fast did we sample?
On my machine, the NUTS sampler took around 40 seconds to generate 10,000 samples, and the Metropolis sampler took around 30 seconds. But really, we should be looking at the effective sample size.
```python
pm.effective_n(nuts_trace, varnames=varnames)
```

```
{'pooled_rate': 40000.0,
 'state_rate': array([ 40000., 40000., 40000., 40000., 40000., 40000., 40000.,
         40000., 40000., 40000., 40000., 40000., 40000., 40000.,
         40000., 40000., 40000., 40000., 40000., 40000., 40000.,
         40000., 40000., 40000., 40000., 40000., 40000., 40000.,
         40000., 40000., 40000., 40000., 40000., 40000., 40000.,
         40000., 40000., 40000., 40000., 40000., 40000., 40000.,
         40000., 40000., 40000., 40000., 40000., 40000., 40000.,
         40000., 40000.])}
```
```python
pm.effective_n(metropolis_trace, varnames=varnames)
```

```
{'pooled_rate': 3.0,
 'state_rate': array([ 13.,  2.,  6.,  5., 29.,  8.,  7.,  6.,  2., 45., 48.,
         12.,  6., 28.,  7.,  2.,  2.,  6., 17.,  2.,  4., 17.,
         10., 11., 12., 13.,  2., 12.,  2.,  7.,  5.,  4., 35.,
         31.,  2., 21.,  5.,  2., 25.,  2.,  9.,  2., 26., 24.,
          2.,  2.,  2.,  8.,  2., 10.,  2.])}
```
More is better here. Every one of the 40,000 NUTS draws (4 chains of 10,000 samples each) counts as an effective sample, but the Metropolis sampler generated only a handful of effective samples per variable. Intuitively, this is because the Metropolis sampler produces draws that are highly correlated, and does not explore the space as efficiently as the NUTS sampler.
Most of what follows is morally related to this: any valid MCMC sampler will eventually produce samples according to the pdf, but “eventually” here is in the true mathematical sense, in that we may need to sample forever. In this example, the Metropolis sampler just did not generate many independent samples, so it is heavily biased towards its starting position.
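Not part of the original post, but a quick way to see that correlation directly is to compute the lag-1 autocorrelation of each trace's draws (indexing a PyMC3 trace by variable name returns the combined draws):

```python
import numpy as np

def lag1_autocorr(samples):
    """Correlation between each draw and the one after it."""
    s = np.asarray(samples) - np.mean(samples)
    return np.dot(s[:-1], s[1:]) / np.dot(s, s)

# NUTS draws are close to independent (autocorrelation near 0); Metropolis
# draws barely move from one step to the next (autocorrelation near 1).
print(lag1_autocorr(nuts_trace['pooled_rate']))
print(lag1_autocorr(metropolis_trace['pooled_rate']))
```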
## Checking the trace
Now that we have some samples, calling `pm.traceplot` gives a picture of our histograms on the left, and a timeline of the samples drawn on the right. The NUTS trace looks pretty good, and the Metropolis trace looks pretty bad.
```python
pm.traceplot(nuts_trace, varnames=varnames);
pm.traceplot(metropolis_trace, varnames=varnames);
```
## Using statistical tests
In case we are not convinced yet that the NUTS trace did a great job of sampling while the Metropolis one would give us wildly incorrect results, we can use the Gelman-Rubin statistic, which compares the variance between chains to the variance within each chain. Intuitively, if each chain looks like every other chain, the sampler has probably mixed well, so the Gelman-Rubin statistic should be near 1. Here's a bar plot of the statistic for `state_rate`, first for the NUTS trace and then for the Metropolis trace.
The NUTS plot is, in the words of Abraham Simpson, like “a haircut you could set a watch to”. The Metropolis trace is, conversely, an apogee of sculpted sartorium, all over the place, and often nowhere close to 1.
```python
nuts_gr = pm.gelman_rubin(nuts_trace, varnames=varnames)
metropolis_gr = pm.gelman_rubin(metropolis_trace, varnames=varnames)

# One bar per state: Gelman-Rubin statistic for the NUTS trace on top,
# Metropolis below.
fig, axs = plt.subplots(2, 1)
axs[0].bar(x=df.State, height=nuts_gr['state_rate'])
axs[1].bar(x=df.State, height=metropolis_gr['state_rate'])
# axs[1].set_yscale('log')  # helpful if the Metropolis bars dwarf everything
fig.set_size_inches(18.5, 10.5)
```
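To make the statistic concrete, here is a hand-rolled version for a single variable. This is a sketch of the standard formula, not `pm.gelman_rubin`'s exact implementation; `get_values(..., combine=False)` returns one array of draws per chain.

```python
import numpy as np

def rhat(chains):
    """Potential scale reduction factor for one scalar variable.

    `chains` has shape (number of chains, draws per chain).
    """
    chains = np.asarray(chains)
    n = chains.shape[1]
    W = chains.var(axis=1, ddof=1).mean()    # average within-chain variance
    B = n * chains.mean(axis=1).var(ddof=1)  # between-chain variance
    var_hat = (n - 1) / n * W + B / n        # pooled variance estimate
    return np.sqrt(var_hat / W)

rhat(np.array(metropolis_trace.get_values('pooled_rate', combine=False)))
```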
Finally, there is a nice `compare` function that reports how well two models have fit the data. It is meant for subtler work than this, but roughly, a lower WAIC is better. The first row here is the trace from the NUTS model.
```python
pm.compare([nuts_trace, metropolis_trace], [nuts_model, metropolis_model])
```
|   | WAIC | pWAIC | dWAIC | weight | SE | dSE | warning |
|---|------|-------|-------|--------|----|-----|---------|
| 0 | 498.99 | 25.5 | 0 | 1 | 12.19 | 0 | 1 |
| 1 | 3.1742e+09 | 1.5871e+09 | 3.1742e+09 | 0 | 1.83667e+09 | 1.83667e+09 | 1 |
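If you only want the number for a single model, `pm.waic` computes the WAIC for one trace/model pair, which is roughly what `compare` is tabulating for each row:

```python
# WAIC for the NUTS model alone.
pm.waic(nuts_trace, nuts_model)
```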