In [1]:
import numpy as np

In [2]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.scipy.stats as jspstats

In [3]:
import numpyro
import numpyro.distributions as dists
from numpyro.infer import MCMC, NUTS

In [4]:
def model(samples=None, shape=(60, 60)):
    """Hierarchical probit model for a binary 2-D grid with row and column effects.

    Each cell ``(i, j)`` is modelled as
    ``samples[i, j] ~ Bernoulli(Phi(mu_tot + draws_s[i] + draws_b[j]))``
    where ``Phi`` is the standard normal CDF, ``draws_s`` are per-row effects with
    scale ``sigma_s`` and ``draws_b`` are per-column effects with scale ``sigma_b``.

    Args:
        samples: Observed 0/1 grid, or ``None`` to draw from the prior predictive.
        shape: ``(n_rows, n_cols)`` of the grid. Only used when ``samples`` is
            ``None``; otherwise the shape is taken from ``samples`` so the latent
            vectors always line up with the observations.
    """
    # Derive the grid shape from the data when we have it, so the default
    # ``shape`` can never silently disagree with the observations. ``tuple``
    # also accepts list-valued shapes such as ``grid_size = [60, 60]``.
    shape = tuple(jnp.shape(samples)) if samples is not None else tuple(shape)

    # Scale parameters must be positive. A Normal prior lets NUTS propose
    # negative scales, for which Normal(0, sigma) has no valid density and the
    # sign of sigma is non-identifiable. HalfNormal keeps the same width.
    sigma_s = numpyro.sample('sigma_s', dists.HalfNormal(1_000.))
    sigma_b = numpyro.sample('sigma_b', dists.HalfNormal(1_000.))
    mu_tot = numpyro.sample('mu_tot', dists.Normal(0., 50.))

    # One effect per row (leading axes) and one per column (last axis).
    draws_s = numpyro.sample('draws_s', dists.Normal(0., sigma_s), sample_shape=shape[:-1])
    draws_b = numpyro.sample('draws_b', dists.Normal(0., sigma_b), sample_shape=shape[-1:])
    # Broadcast row and column effects into a (n_rows, n_cols) grid of latent values.
    draws_bias = mu_tot + draws_s[:, np.newaxis] + draws_b[np.newaxis, :]
    # Probit link: success probability is the standard normal CDF of the latent value.
    numpyro.sample('samples', dists.Bernoulli(jspstats.norm.cdf(draws_bias)), obs=samples)

In [5]:
# Ground-truth parameters used to simulate the observed grid.
s_s = 2.5   # std of the per-row effects
s_b = 0.6   # std of the per-column effects
m_t = -1.   # global offset added before thresholding

# A dedicated, seeded generator makes "Restart & Run All" reproduce the same data.
data_rng = np.random.default_rng(0)

grid_size = [60, 60]
row_effects = data_rng.normal(size=grid_size[0]) * s_s
col_effects = data_rng.normal(size=grid_size[1]) * s_b
# Unit-variance cell noise: thresholding latent + N(0, 1) at zero is exactly a
# probit model, matching the norm.cdf link used in `model`.
noise = data_rng.normal(size=grid_size)
grid = row_effects[:, np.newaxis] + col_effects[np.newaxis, :] + noise
# Binarise: 1.0 where the shifted latent value is positive, 0.0 otherwise.
thresholded_grid = (grid + m_t > 0).astype(float)

In [6]:
# Split off a fresh key for this run; `rng_key` is kept for any later cells.
rng_key = jrandom.PRNGKey(0)
rng_key, rng_key_ = jrandom.split(rng_key)

num_warmup = 1000   # NUTS adaptation steps, not included in get_samples()
num_samples = 2000  # post-warmup draws kept

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
# Pass the grid shape explicitly so the model's latent row/column vectors match
# the observed data even if `grid_size` differs from the model's default shape.
mcmc.run(
    rng_key_, samples=thresholded_grid, shape=thresholded_grid.shape
)
mcmc.print_summary()
samples_1 = mcmc.get_samples()

I0000 00:00:1700507813.984381 1419943 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
sample: 100%|█| 3000/3000 [00:42<00:00, 70.76it/s, 31 steps of size 1.05e-01. ac



                 mean       std    median      5.0%     95.0%     n_eff     r_hat
 draws_b[0]      0.54      0.27      0.54      0.07      0.95   1726.16      1.00
 draws_b[1]      0.32      0.25      0.31     -0.09      0.71   1369.21      1.00
 draws_b[2]      0.68      0.26      0.69      0.25      1.09   1518.55      1.00
 draws_b[3]     -0.22      0.27     -0.22     -0.65      0.22   1533.48      1.00
 draws_b[4]      0.23      0.26      0.23     -0.20      0.65   1078.14      1.00
 draws_b[5]      0.10      0.27      0.10     -0.35      0.52   1411.39      1.00
 draws_b[6]     -1.09      0.31     -1.09     -1.59     -0.58   1922.28      1.00
 draws_b[7]      0.58      0.25      0.58      0.19      1.00   1254.93      1.00
 draws_b[8]     -0.22      0.27     -0.22     -0.67      0.19   1412.75      1.00
 draws_b[9]      0.55      0.25      0.55      0.13      0.96   1191.63      1.00
draws_b[10]      0.26      0.27      0.26     -0.18      0.68   1384.75      1.00
draws_b[11]    