Example

This package lets you define your inference problem in one of three ways:

  • with the PyMC probabilistic programming language

  • by coding up the posterior by hand and using Jax to automatically compute the necessary posterior gradients

  • by specifying the posterior and its gradients completely by hand

We’ll start with PyMC, since it is the easiest. First, load up the relevant packages:

[1]:
%pylab inline
from scipy import stats
import pymc as pm
from ttictoc import tic, toc
from muse_inference.pymc import PyMCMuseProblem
Populating the interactive namespace from numpy and matplotlib

As an example, consider the following hierarchical problem, which has the classic Neal’s funnel embedded in it. Neal’s funnel is a standard example of a non-Gaussian latent space that HMC struggles to sample efficiently without extra tricks. Specifically, we consider the model defined by:

\[\begin{aligned} \theta &\sim {\rm Normal}(0,3) \\ z_i &\sim {\rm Normal}(0,\exp(\theta/2)) \\ x_i &\sim {\rm Normal}(z_i, 1) \end{aligned}\]

for \(i=1,\ldots,10000\), where the second argument of each Normal is its standard deviation. This problem can be described by the following PyMC model:

[3]:
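def gen_funnel(x=None, θ=None):
    with pm.Model() as funnel:
        # θ gets its Normal(0, 3) prior unless a fixed value is passed in
        θ = pm.Normal("θ", mu=0, sigma=3) if θ is None else θ
        # 10000 latent z_i with standard deviation exp(θ/2)
        z = pm.Normal("z", mu=0, sigma=pm.math.exp(θ/2), size=10000)
        # x_i is observed if data is given, otherwise it is simulated
        x = pm.Normal("x", mu=z, sigma=1, observed=x)
    return funnel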
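For comparison with the third option listed at the top (specifying the posterior and its gradients completely by hand), a minimal NumPy sketch of the joint log density \(\log \mathcal{P}(x, z, \theta)\) of the same model and its analytic gradients might look like the following. The names `log_joint` and `grad_log_joint` are hypothetical, and this is not the interface the package’s NumPy or Jax problem classes expect; it only shows the math such an interface would need. The \(\theta\)-gradient follows from \(\log {\rm Normal}(z_i\,|\,0, e^{\theta/2}) = -\tfrac{1}{2} z_i^2 e^{-\theta} - \tfrac{\theta}{2} - \tfrac{1}{2}\log 2\pi\).

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)

def log_joint(theta, z, x):
    """log P(x, z, θ) for the noisy funnel; theta is a scalar, z and x are 1-D arrays."""
    # θ ~ Normal(0, 3)
    lp_theta = -0.5 * (theta / 3.0) ** 2 - np.log(3.0) - 0.5 * LOG_2PI
    # z_i ~ Normal(0, exp(θ/2)), i.e. variance exp(θ) and log-std θ/2
    lp_z = np.sum(-0.5 * z**2 * np.exp(-theta) - 0.5 * theta - 0.5 * LOG_2PI)
    # x_i ~ Normal(z_i, 1)
    lp_x = np.sum(-0.5 * (x - z) ** 2 - 0.5 * LOG_2PI)
    return lp_theta + lp_z + lp_x

def grad_log_joint(theta, z, x):
    """Analytic gradients (∂/∂θ, ∂/∂z) of log_joint."""
    d_theta = -theta / 9.0 + np.sum(0.5 * z**2 * np.exp(-theta) - 0.5)
    d_z = -z * np.exp(-theta) + (x - z)
    return d_theta, d_z

# Central finite-difference check of the θ-gradient on a small random problem
rng = np.random.default_rng(0)
z_chk, x_chk = rng.standard_normal(5), rng.standard_normal(5)
h = 1e-6
fd = (log_joint(0.3 + h, z_chk, x_chk) - log_joint(0.3 - h, z_chk, x_chk)) / (2 * h)
print(fd, grad_log_joint(0.3, z_chk, x_chk)[0])  # the two values should agree closely
```

A finite-difference check like the one at the end is a cheap way to catch sign or factor errors before handing hand-written gradients to any gradient-based method.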

Next, let’s choose a true value of \(\theta=0\) and generate some simulated data, \(x\), which we’ll use as “observations”:

[4]:
with gen_funnel(θ=0):
    x_obs = pm.sample_prior_predictive(1, random_seed=0).prior.x[0,0]
model = gen_funnel(x=x_obs)
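The same data-generating process can also be written directly in NumPy, which may help when moving to the Jax or NumPy interfaces. This is a sketch with a hypothetical helper `simulate_funnel`; it uses NumPy’s own random generator, so it will not reproduce the exact `x_obs` drawn by PyMC above. NumPy’s `scale` argument, like PyMC’s `sigma`, is a standard deviation.

```python
import numpy as np

def simulate_funnel(theta, n=10000, seed=0):
    """Draw (z, x) from the noisy funnel model at a fixed θ (sketch)."""
    rng = np.random.default_rng(seed)
    # z_i ~ Normal(0, exp(θ/2)); `scale` is the standard deviation
    z = rng.normal(loc=0.0, scale=np.exp(theta / 2), size=n)
    # x_i ~ Normal(z_i, 1), one draw per latent z_i
    x = rng.normal(loc=z, scale=1.0)
    return z, x

# "Observations" at the true value θ = 0
z_true, x_sim = simulate_funnel(theta=0.0)
print(x_sim.shape, x_sim.var())
```

Marginalizing over \(z_i\) gives \(x_i \sim {\rm Normal}(0, \sqrt{e^\theta + 1})\), so at \(\theta=0\) the sample variance of `x_sim` should be close to 2, a quick sanity check on the simulator.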

We can run HMC on the problem to compute the “true” answer to compare against:

[6]:
with model:
    tic()
    # seed the sampler directly so the run is reproducible
    chain = pm.sample(500, tune=500, chains=1, discard_tuned_samples=False, random_seed=1)
    t_hmc = toc()
100.00% [1000/1000 00:39<00:00 Sampling chain 0, 0 divergences]

We next compute the MUSE estimate for the same problem. To reach the same Monte Carlo error as HMC, the number of MUSE simulations should equal the effective sample size (ESS) of \(\theta\) in the chain we just ran, which is:

[7]:
nsims = int(pm.ess(chain)["θ"])
nsims
[7]:
234
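The reasoning behind matching `nsims` to the ESS is that the Monte Carlo error on a posterior mean estimated from a chain is roughly \(\sigma_\theta/\sqrt{\rm ESS}\). ArviZ’s `ess` (which `pm.ess` re-exports) defaults to a rank-normalized, split-chain estimator; the sketch below is a simpler single-chain version that truncates the autocorrelation sum \(\tau = 1 + 2\sum_{t\ge1}\rho_t\) with Geyer’s initial positive sequence and returns \({\rm ESS} = N/\tau\). The helpers `autocorr` and `ess` are hypothetical and will not exactly reproduce the value printed above.

```python
import numpy as np

def autocorr(chain):
    """Normalized autocorrelation ρ_t of a 1-D chain for lags t = 0..N-1, via FFT."""
    x = np.asarray(chain, dtype=float)
    x = x - x.mean()
    n = len(x)
    # Zero-pad to length 2N so the circular FFT correlation has no wrap-around
    f = np.fft.rfft(x, n=2 * n)
    acov = np.fft.irfft(f * np.conjugate(f))[:n]
    return acov / acov[0]

def ess(chain):
    """Single-chain ESS = N / τ, with τ = 1 + 2 Σ_{t≥1} ρ_t truncated by
    Geyer's initial positive sequence (stop at the first negative pair sum)."""
    rho = autocorr(chain)
    n = len(rho)
    tau = -1.0  # -1 + 2(ρ_0 + ρ_1 + ...) equals 1 + 2(ρ_1 + ρ_2 + ...) since ρ_0 = 1
    for t in range(0, n - 1, 2):
        pair = rho[t] + rho[t + 1]
        if pair < 0:
            break
        tau += 2.0 * pair
    return n / tau

# Usage on independent draws, where ESS should be close to the chain length
rng = np.random.default_rng(0)
print(ess(rng.standard_normal(5000)))
```

For an AR(1) chain with coefficient \(\phi\), \(\tau = (1+\phi)/(1-\phi)\), so for example \(\phi = 0.9\) should give an ESS near \(N/19\).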

Running the MUSE estimate:

[8]:
prob = PyMCMuseProblem(model)
rng = np.random.SeedSequence(1)
tic()
result = prob.solve(θ_start=0, nsims=nsims, rng=rng, progress=True, save_MAP_history=True)
t_muse = toc()
MUSE: 100%|██████████| 11750/11750 [00:19<00:00, 598.36it/s]
get_H: 100%|██████████| 23/23 [00:01<00:00, 18.55it/s]
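For intuition about what `prob.solve` is estimating, here is a simplified NumPy sketch of the MUSE estimating equation specialized to this model. MUSE looks for the \(\theta\) at which the score of the data at the MAP latent, \(s(\theta, x) = \partial_\theta \log \mathcal{P}(x, \hat z \,|\, \theta)\) with \(\hat z = {\rm argmax}_z \log \mathcal{P}(x, z\,|\,\theta)\), equals its average over simulations drawn at that same \(\theta\), here with the prior gradient added. Because this funnel has a closed-form MAP, \(\hat z_i = x_i e^\theta/(1+e^\theta)\), no optimizer is needed. This is our illustration, not the package’s implementation: `muse_funnel_sketch` is a hypothetical name, it reuses fixed random draws across \(\theta\) so that a bracketing root finder (`scipy.optimize.brentq`) can be used, and it omits the covariance \(\Sigma\) that `result.Σ` provides.

```python
import numpy as np
from scipy.optimize import brentq

def muse_funnel_sketch(x_obs, nsims=50, seed=1, bracket=(-3.0, 3.0)):
    """Simplified MUSE point estimate of θ for the noisy funnel (illustration only)."""
    x_obs = np.asarray(x_obs, dtype=float)
    n = x_obs.size
    rng = np.random.default_rng(seed)
    # Fixed standard-normal draws reused at every θ (common random numbers),
    # so the simulated part of the equation changes smoothly with θ
    u = rng.standard_normal((nsims, n))
    eps = rng.standard_normal((nsims, n))

    def z_map(theta, x):
        # argmax_z log P(x, z | θ), closed form for this model: x e^θ / (1 + e^θ)
        return x * np.exp(theta) / (1.0 + np.exp(theta))

    def score(theta, x):
        # ∂θ log P(x, z | θ) at z = z_map(θ, x); only the z-prior depends on θ.
        # Sums over the last axis, so x may be one dataset or a stack of sims.
        z = z_map(theta, x)
        return np.sum(0.5 * z**2 * np.exp(-theta) - 0.5, axis=-1)

    def muse_equation(theta):
        x_sims = np.exp(theta / 2) * u + eps  # nsims datasets drawn from P(x | θ)
        # data score minus mean simulated score, plus ∂θ log Normal(θ | 0, 3)
        return score(theta, x_obs) - score(theta, x_sims).mean() - theta / 9.0

    # Root-find within `bracket`; brentq raises ValueError if there is no sign change
    return brentq(muse_equation, *bracket)

rng = np.random.default_rng(0)
x_demo = rng.standard_normal(10000) + rng.standard_normal(10000)  # data at θ = 0
print(muse_funnel_sketch(x_demo))
```

For this particular model the equation reduces to \(e^\theta = \tfrac{1}{10000}\sum_i x_i^2 - 1\), up to simulation noise and the small prior term, which is the exact marginal maximum-likelihood estimate; this is one way to see why MUSE does well here.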

Let’s also try mean-field variational inference (MFVI) to compare against another approximate method.

[9]:
with model:
    tic()
    mfvi = pm.fit(10000, method="advi", obj_n_mc=10, tf_n_mc=10)
    t_mfvi = toc()
100.00% [10000/10000 01:43<00:00 Average Loss = 17,685]

Now let’s plot the different estimates. In this case, MUSE gives a nearly perfect answer using only a fraction of the posterior gradient calls. By comparison, MFVI is both slower and less accurate.

[10]:
figure(figsize=(6,5))
axvline(0, c="k", ls="--", alpha=0.5)
ncalls_hmc = sum(chain.sample_stats["n_steps"]) + sum(chain.warmup_sample_stats["n_steps"])
hist(
    chain["posterior"]["θ"].to_series(),
    bins=30, density=True, alpha=0.5, color="C0",
    label="HMC (%.2fs, %i ∇logP calls)"%(t_hmc, ncalls_hmc)
)
θs = linspace(*xlim())
ncalls_muse = sum(
    [s.nfev for h in result.history for s in [h["MAP_history_dat"]]+h["MAP_history_sims"]]
)
plot(
    θs, stats.norm(result.θ["θ"], sqrt(result.Σ[0,0])).pdf(θs),
    color="C1", label="MUSE (%.2fs,  %i ∇logP calls)"%(t_muse, ncalls_muse)
)
hist(
    mfvi.sample(1000)["posterior"]["θ"].to_series(),
    bins=30, density=True, alpha=0.5, color="C2",
    label="MFVI (%.2fs)"%t_mfvi
)
ylim(0, ylim()[1]*1.2)
legend(frameon=True)
xlabel(r"$\theta$")
ylabel(r"$\mathcal{P}(\theta\,|\,x)$")
title("10000-dimensional noisy funnel");
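[Figure: posterior on \(\theta\) for the 10000-dimensional noisy funnel from HMC (histogram), MUSE (Gaussian curve), and MFVI (histogram), with the true value \(\theta=0\) marked by a dashed line]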

Note that due to PyMC overhead, the timing difference between HMC and MUSE is less drastic than the number of gradient calls would imply. For more realistic and expensive posterior functions, this overhead becomes negligible and you can see speedups of 10-100X or more (depending on the problem and the dimensionality of the latent space). If you are working with small problems and need less overhead, the Jax or NumPy interfaces will be faster (or consider the Julia package MuseInference.jl, which is the fastest of all options).