Advanced Examples

PyMC

The Example page introduces how to use muse-inference for a problem defined with PyMC. Here we consider a more complex problem to highlight additional features. In particular:

  • We can estimate any number of parameters with any shapes. Here we have a length-2 array \(\mu\) and a scalar \(\sigma\). Note that by default, muse-inference considers any variables which do not depend on others as “parameters” (i.e. the “leaves” of the probabilistic graph). However, the algorithm is not limited to such parameters, and any choice can be selected by passing a list of params to the PyMCMuseProblem constructor (see the sketch below, after the model is constructed).

  • We can work with distributions with limited domain support. For example, below we use the \(\rm Beta\) distribution with support on \((0,1)\) and the \(\rm LogNormal\) distribution with support on \((0,\infty)\). All necessary transformations are handled internally.

  • The data and latent space can include any number of variables, with any shapes. Below we demonstrate an \(x\) and \(z\) which are 2-dimensional arrays.

First, load the relevant packages:

[2]:
%pylab inline
import pymc as pm
from muse_inference.pymc import PyMCMuseProblem
Populating the interactive namespace from numpy and matplotlib

Then define the problem,

[3]:
def gen_funnel(x=None, σ=None, μ=None):
    with pm.Model() as model:
        # parameters (sampled from their priors unless fixed values are passed in)
        μ = pm.Beta("μ", 2, 5, size=2) if μ is None else μ
        σ = pm.Normal("σ", 0, 3) if σ is None else σ
        # latent variables and observed data
        z = pm.LogNormal("z", μ, np.exp(σ/2), size=(100, 2))
        x = pm.Normal("x", z, 1, observed=x)
    return model

generate the model and some data, given some chosen true values of parameters,

[4]:
θ_true = dict(μ=[0.3, 0.7], σ=1)
with gen_funnel(**θ_true):
    x_obs = pm.sample_prior_predictive(1, random_seed=0).prior.x[0,0]
model = gen_funnel(x=x_obs)
prob = PyMCMuseProblem(model)
Sampling: [x, z]
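
As an aside, if you wanted to explicitly control which variables are treated as parameters, you could pass a list of params when constructing the problem. For example, the following sketch is equivalent to the default behavior here, since μ and σ are the only leaves of the graph:

prob = PyMCMuseProblem(model, params=["μ", "σ"])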

and finally, run MUSE:

[5]:
θ_start = dict(μ=[0.5, 0.5], σ=0)
result = prob.solve(θ_start=θ_start, progress=True)
MUSE: 100%|██████████| 5050/5050 [00:17<00:00, 290.65it/s]
get_H: 100%|██████████| 10/10 [00:03<00:00,  3.03it/s]

When there are multiple parameters, the starting guess should be specified as a dictionary, as above.

The parameter estimate is returned as a dictionary,

[6]:
result.θ
[6]:
{'μ': array([0.38661822, 0.40076772]), 'σ': array(0.93624234)}

and the covariance as a matrix, with parameters concatenated in the order in which they appear in the model (or in the order specified in params, if that was used):

[7]:
result.Σ
[7]:
array([[ 0.01211905, -0.01119086,  0.00169359],
       [-0.01119086,  0.02619911, -0.0032245 ],
       [ 0.00169359, -0.0032245 ,  0.00265886]])
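
Since the ordering in Σ matches the ravelled parameter vector (here μ[0], μ[1], σ), individual entries can be read off directly. For example, the standard deviation of σ is the square root of the last diagonal entry, which matches the σ entry of the unravelled version below:

np.sqrt(result.Σ[2, 2])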

The result.ravel and result.unravel functions can be used to convert between dictionary and vector representations of the parameters. For example, to compute the standard deviation for each parameter (the square root of the diagonal of the covariance):

[8]:
result.unravel(np.sqrt(np.diag(result.Σ)))
[8]:
{'μ': array([0.11008654, 0.16186139]), 'σ': array(0.05156411)}

or to convert the mean parameters to a vector:

[9]:
result.ravel(result.θ)
[9]:
array([0.38661822, 0.40076772, 0.93624234])

Jax

We can also use Jax to define the problem. In this case we write our own functions to generate forward samples and to compute the posterior, and Jax provides the necessary gradients for free. To use Jax, load the necessary packages:

[10]:
from functools import partial
import jax
import jax.numpy as jnp
from muse_inference.jax import JaxMuseProblem

Let’s implement the noisy funnel problem from the Example page. To do so, extend JaxMuseProblem and define sample_x_z, logLike, and logPrior.

[11]:
class JaxFunnelMuseProblem(JaxMuseProblem):

    def __init__(self, N, **kwargs):
        super().__init__(**kwargs)
        self.N = N

    def sample_x_z(self, key, θ):
        keys = jax.random.split(key, 2)
        z = jax.random.normal(keys[0], (self.N,)) * jnp.exp(θ/2)
        x = z + jax.random.normal(keys[1], (self.N,))
        return (x, z)

    def logLike(self, x, z, θ):
        # joint log posterior, up to θ-independent constants (self.N, rather
        # than a hardcoded size, keeps the normalization consistent with
        # whatever N the problem is constructed with)
        return -(jnp.sum((x - z)**2) + jnp.sum(z**2) / jnp.exp(θ) + self.N*θ) / 2

    def logPrior(self, θ):
        return -θ**2 / (2*3**2)
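
For reference, these functions implement (up to θ-independent constants) the joint posterior of the noisy funnel model,

\[z_i \sim \mathcal{N}(0, e^{\theta}), \quad x_i \sim \mathcal{N}(z_i, 1), \quad \theta \sim \mathcal{N}(0, 3^2),\]

for which

\[\log\mathcal{P}(x, z \,|\, \theta) = -\frac{1}{2}\left[\sum_{i=1}^{N}(x_i - z_i)^2 + e^{-\theta}\sum_{i=1}^{N} z_i^2 + N\theta\right] + {\rm const.}\]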

Note that the super-class JaxMuseProblem will automatically take care of JIT compiling these functions, so you do not need to manually decorate them with @jit. However, if your functions contain code which cannot be JIT compiled, you can pass jit=False to the super constructor in your __init__ function.

The Jax MUSE interface also contains an option to use implicit differentiation to compute the \(H\) matrix (paper in prep). This is more numerically stable and faster than finite differences, although it requires second-order automatic differentiation to work through your posterior. It is enabled by default, but can be disabled with super().__init__(implicit_diff=False).
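
Both options can also be toggled per-instance, since our constructor forwards keyword arguments to the super-class. For example, to switch both off:

prob = JaxFunnelMuseProblem(10000, jit=False, implicit_diff=False)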

With the problem defined, we now generate some simulated data and save it to the problem with set_x. Note also the use of Jax's PRNGKey (rather than the NumPy RandomState used with PyMC) for random number generation.

[12]:
prob = JaxFunnelMuseProblem(10000, implicit_diff=True)
key = jax.random.PRNGKey(0)
(x, z) = prob.sample_x_z(key, 0)
prob.set_x(x)
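
Because Jax PRNG keys are explicit, the same key always reproduces the same simulation, which can be handy for debugging:

(x2, z2) = prob.sample_x_z(key, 0)
assert jnp.allclose(x, x2)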

And finally, run MUSE:

[14]:
result = prob.solve(θ_start=0., rng=jax.random.PRNGKey(1), progress=True)
MUSE: 100%|██████████| 5050/5050 [00:00<00:00, 5388.64it/s]
get_H: 100%|██████████| 10/10 [00:00<00:00, 1243.46it/s]

Note that the solution here is obtained around 10X faster than the PyMC version of this problem in the Example page (the cloud machines which build these docs don't always achieve the full 10X, but you will see it if you run these examples locally). The Jax interface has much lower overhead, which is noticeable for very fast posteriors like the one above.

One convenient aspect of using Jax is that the parameters, θ, and latent space, z, can be any pytree, i.e. tuples, dictionaries, nested combinations of them, etc. (there is no requirement on the data format of the x variable). To demonstrate, consider a problem which is just two copies of the noisy funnel problem:

[15]:
class JaxPyTreeFunnelMuseProblem(JaxMuseProblem):

    def __init__(self, N):
        super().__init__()
        self.N = N

    def sample_x_z(self, key, θ):
        (θ1, θ2) = (θ["θ1"], θ["θ2"])
        keys = jax.random.split(key, 4)
        z1 = jax.random.normal(keys[0], (self.N,)) * jnp.exp(θ1/2)
        z2 = jax.random.normal(keys[1], (self.N,)) * jnp.exp(θ2/2)
        x1 = z1 + jax.random.normal(keys[2], (self.N,))
        x2 = z2 + jax.random.normal(keys[3], (self.N,))
        return ({"x1":x1, "x2":x2}, {"z1":z1, "z2":z2})

    def logLike(self, x, z, θ):
        # the sum of two independent copies of the funnel log posterior
        return (
            -(jnp.sum((x["x1"] - z["z1"])**2) + jnp.sum(z["z1"]**2) / jnp.exp(θ["θ1"]) + self.N*θ["θ1"]) / 2
            -(jnp.sum((x["x2"] - z["z2"])**2) + jnp.sum(z["z2"]**2) / jnp.exp(θ["θ2"]) + self.N*θ["θ2"]) / 2
        )

    def logPrior(self, θ):
        return - θ["θ1"]**2 / (2*3**2) - θ["θ2"]**2 / (2*3**2)

Here, x, θ, and z are all dictionaries. We generate the problem as usual, passing in parameters as dictionaries,

[16]:
θ_true = dict(θ1=-1., θ2=2.)
θ_start = dict(θ1=0., θ2=0.)
[17]:
prob = JaxPyTreeFunnelMuseProblem(10000)
key = jax.random.PRNGKey(0)
(x, z) = prob.sample_x_z(key, θ_true)
prob.set_x(x)

and run MUSE:

[19]:
result = prob.solve(θ_start=θ_start, rng=jax.random.PRNGKey(0), progress=True)
MUSE: 100%|██████████| 5050/5050 [00:10<00:00, 496.91it/s]
get_H: 100%|██████████| 10/10 [00:00<00:00, 570.62it/s]

The result is returned as a pytree:

[20]:
result.θ
[20]:
{'θ1': DeviceArray(-1.000121, dtype=float32),
 'θ2': DeviceArray(2.0271356, dtype=float32)}

and the covariance as a matrix:

[21]:
result.Σ
[21]:
array([[ 8.616477  , -0.02479646],
       [-0.02479644,  7.92195   ]], dtype=float32)

The result.ravel and result.unravel functions can be used to convert between pytree and vector representations of the parameters. For example, to compute the standard deviation for each parameter (the square root of the diagonal of the covariance):

[22]:
result.unravel(np.sqrt(np.diag(result.Σ)))
[22]:
{'θ1': DeviceArray(2.9353836, dtype=float32),
 'θ2': DeviceArray(2.814596, dtype=float32)}

or to convert the mean parameters to a vector:

[23]:
result.ravel(result.θ)
[23]:
DeviceArray([-1.000121 ,  2.0271356], dtype=float32)
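
The two functions are inverses of one another, so converting to a vector and back should recover the original pytree:

θ_vec  = result.ravel(result.θ)
θ_tree = result.unravel(θ_vec)
assert jnp.allclose(θ_tree["θ1"], result.θ["θ1"])
assert jnp.allclose(θ_tree["θ2"], result.θ["θ2"])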