You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
263 lines
15 KiB
Plaintext
263 lines
15 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "33c24830-7601-4c52-8fa0-66330801d6a9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "985106e5-7a80-4d4d-b1c9-54a873f0b429",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import jax.numpy as jnp\n",
|
|
"import jax.random as jrandom\n",
|
|
"import jax.scipy.stats as jspstats"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "8a8cc9de-9802-4761-9bba-7b77300d0bb8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpyro\n",
|
|
"import numpyro.distributions as dists\n",
|
|
"from numpyro.infer import MCMC, NUTS"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "9b4c0d3e-2889-4bf3-9871-6648b5de3508",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def model(samples=None, shape=(60, 60)):
    """Hierarchical probit model for a 2-D grid of binary outcomes.

    Cell (i, j) is Bernoulli with success probability
    ``Phi(mu_tot + draws_s[i] + draws_b[j])``, where ``Phi`` is the standard
    normal CDF, ``draws_s`` holds one effect per row (scale ``sigma_s``) and
    ``draws_b`` holds one effect per column (scale ``sigma_b``).

    Args:
        samples: Observed 2-D array of 0/1 values, or ``None`` to draw from the
            prior predictive.
        shape: ``(n_rows, n_cols)`` of the grid, used when ``samples`` is
            ``None``. When ``samples`` is given its own shape is used instead,
            so the latent effects always line up with the observations.
    """
    if samples is not None:
        shape = np.shape(samples)
    n_rows, n_cols = shape

    # Scale parameters must be positive. A plain Normal prior lets NUTS propose
    # negative scales, for which Normal(0., sigma) below is not a valid
    # distribution. HalfNormal is the same Normal(0, 1000) restricted to
    # (0, inf), and NUTS samples it in an unconstrained (log) space instead.
    sigma_s = numpyro.sample('sigma_s', dists.HalfNormal(1_000.))
    sigma_b = numpyro.sample('sigma_b', dists.HalfNormal(1_000.))
    mu_tot = numpyro.sample('mu_tot', dists.Normal(0., 50.))

    draws_s = numpyro.sample('draws_s', dists.Normal(0., sigma_s), sample_shape=(n_rows,))
    draws_b = numpyro.sample('draws_b', dists.Normal(0., sigma_b), sample_shape=(n_cols,))
    # NOTE(review): mu_tot is only weakly identified against the means of
    # draws_s / draws_b (a constant can shift between them, held in place only
    # by their zero-mean priors). This likely contributes to its low n_eff in
    # the MCMC summary; a non-centred or sum-to-zero parameterisation may mix
    # better.

    # Row effects broadcast across columns and column effects across rows.
    draws_bias = mu_tot + draws_s[:, np.newaxis] + draws_b[np.newaxis, :]
    # Probit link: P(cell = 1) = Phi(linear predictor).
    numpyro.sample('samples', dists.Bernoulli(probs=jspstats.norm.cdf(draws_bias)), obs=samples)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "e1be367c-e3e9-4b61-ab47-35b75e3a3ecf",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def generate_grid(sigma_s, sigma_b, mu_tot, grid_size=(60, 60), rng=None):
    """Simulate a binary grid from the probit process that ``model`` assumes.

    The latent value of cell (i, j) is ``mu_tot + s_i + b_j + eps_ij``, with
    ``s_i ~ N(0, sigma_s**2)`` shared along row i, ``b_j ~ N(0, sigma_b**2)``
    shared along column j, and ``eps_ij ~ N(0, 1)``. The cell is 1 when the
    latent value is positive, i.e. with probability ``Phi(mu_tot + s_i + b_j)``.

    Args:
        sigma_s: Standard deviation of the per-row effects.
        sigma_b: Standard deviation of the per-column effects.
        mu_tot: Global offset added to every cell.
        grid_size: ``(n_rows, n_cols)`` of the output grid.
        rng: Optional ``numpy.random.Generator``, or any object with a
            compatible ``normal(size=...)`` method. Defaults to the legacy
            global ``numpy.random`` state (the previous behaviour). Pass a
            seeded generator for a reproducible grid.

    Returns:
        Integer array of shape ``grid_size`` containing only 0s and 1s.
    """
    if rng is None:
        rng = np.random
    n_rows, n_cols = grid_size
    # Draw order (rows, columns, noise) matches the original implementation, so
    # the default global-state path produces the same stream as before.
    row_effects = rng.normal(size=n_rows) * sigma_s
    col_effects = rng.normal(size=n_cols) * sigma_b
    noise = rng.normal(size=(n_rows, n_cols))
    latent = mu_tot + row_effects[:, np.newaxis] + col_effects[np.newaxis, :] + noise
    return (latent > 0).astype(int)


# Ground-truth parameters of the simulation; compare the posterior against them.
TRUE_SIGMA_S, TRUE_SIGMA_B, TRUE_MU_TOT = 2.4, 0.6, -1.5
GRID_SEED = 0  # fixed so Restart & Run All reproduces the same grid

grid = generate_grid(
    sigma_s=TRUE_SIGMA_S,
    sigma_b=TRUE_SIGMA_B,
    mu_tot=TRUE_MU_TOT,
    rng=np.random.default_rng(GRID_SEED),
)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "06b670cf-73bb-40fd-b152-b111b13b9d50",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
|
|
"I0000 00:00:1700516044.602393 1486344 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n",
|
|
"2023-11-20 22:34:04.630229: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 12620922880\n",
|
|
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
|
|
"sample: 100%|█| 3000/3000 [00:19<00:00, 153.28it/s, 31 steps of size 1.17e-01. a\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
" mean std median 5.0% 95.0% n_eff r_hat\n",
|
|
" draws_b[0] 0.41 0.24 0.41 0.02 0.80 2476.09 1.00\n",
|
|
" draws_b[1] -0.03 0.23 -0.04 -0.43 0.32 1978.01 1.00\n",
|
|
" draws_b[2] -0.71 0.28 -0.69 -1.15 -0.26 2136.96 1.00\n",
|
|
" draws_b[3] 0.20 0.24 0.20 -0.23 0.56 2107.37 1.00\n",
|
|
" draws_b[4] 0.22 0.24 0.22 -0.17 0.59 1564.28 1.00\n",
|
|
" draws_b[5] -0.60 0.27 -0.59 -1.09 -0.21 1970.86 1.00\n",
|
|
" draws_b[6] 0.66 0.23 0.66 0.28 1.03 2108.18 1.00\n",
|
|
" draws_b[7] -0.70 0.27 -0.69 -1.14 -0.27 2018.21 1.00\n",
|
|
" draws_b[8] -0.80 0.28 -0.80 -1.25 -0.36 2037.66 1.00\n",
|
|
" draws_b[9] -0.08 0.25 -0.08 -0.51 0.30 1846.97 1.00\n",
|
|
"draws_b[10] 0.08 0.25 0.09 -0.36 0.46 2327.05 1.00\n",
|
|
"draws_b[11] 0.28 0.24 0.28 -0.13 0.65 2285.25 1.00\n",
|
|
"draws_b[12] -0.07 0.25 -0.07 -0.50 0.31 2415.74 1.00\n",
|
|
"draws_b[13] 0.12 0.25 0.12 -0.29 0.53 2349.12 1.00\n",
|
|
"draws_b[14] -0.16 0.25 -0.16 -0.52 0.28 2589.76 1.00\n",
|
|
"draws_b[15] 0.31 0.23 0.31 -0.06 0.67 2138.49 1.00\n",
|
|
"draws_b[16] -0.71 0.27 -0.70 -1.14 -0.24 2389.35 1.00\n",
|
|
"draws_b[17] 0.38 0.24 0.38 0.01 0.79 1920.50 1.00\n",
|
|
"draws_b[18] 0.03 0.25 0.02 -0.38 0.41 1551.87 1.00\n",
|
|
"draws_b[19] -0.13 0.25 -0.13 -0.55 0.28 2579.31 1.00\n",
|
|
"draws_b[20] 0.74 0.24 0.74 0.33 1.13 1878.62 1.00\n",
|
|
"draws_b[21] -0.77 0.28 -0.76 -1.22 -0.33 2172.10 1.00\n",
|
|
"draws_b[22] -0.57 0.26 -0.57 -0.96 -0.12 2198.03 1.00\n",
|
|
"draws_b[23] -0.38 0.27 -0.37 -0.82 0.05 2556.15 1.00\n",
|
|
"draws_b[24] 0.48 0.23 0.48 0.09 0.86 2054.13 1.00\n",
|
|
"draws_b[25] -0.59 0.27 -0.58 -1.02 -0.14 2313.31 1.00\n",
|
|
"draws_b[26] 0.24 0.23 0.24 -0.16 0.60 1828.17 1.00\n",
|
|
"draws_b[27] 0.10 0.24 0.10 -0.33 0.45 1965.00 1.00\n",
|
|
"draws_b[28] 0.19 0.23 0.19 -0.17 0.57 2146.14 1.00\n",
|
|
"draws_b[29] -0.38 0.24 -0.37 -0.80 0.01 2472.86 1.00\n",
|
|
"draws_b[30] 0.89 0.24 0.88 0.49 1.25 1775.39 1.00\n",
|
|
"draws_b[31] 0.56 0.23 0.55 0.16 0.93 2322.01 1.00\n",
|
|
"draws_b[32] 0.91 0.23 0.91 0.53 1.29 1971.15 1.00\n",
|
|
"draws_b[33] -0.07 0.24 -0.07 -0.48 0.29 2625.08 1.00\n",
|
|
"draws_b[34] -0.69 0.28 -0.69 -1.13 -0.23 2386.87 1.00\n",
|
|
"draws_b[35] 0.20 0.23 0.20 -0.18 0.58 2416.61 1.00\n",
|
|
"draws_b[36] 0.05 0.24 0.05 -0.38 0.40 2416.72 1.00\n",
|
|
"draws_b[37] 0.31 0.23 0.31 -0.07 0.68 2396.79 1.00\n",
|
|
"draws_b[38] 0.30 0.23 0.30 -0.07 0.70 2798.46 1.00\n",
|
|
"draws_b[39] 0.52 0.23 0.52 0.14 0.90 2162.20 1.00\n",
|
|
"draws_b[40] -0.45 0.25 -0.44 -0.83 -0.02 1940.75 1.00\n",
|
|
"draws_b[41] 0.12 0.23 0.11 -0.24 0.50 2306.02 1.00\n",
|
|
"draws_b[42] -0.18 0.24 -0.17 -0.56 0.23 2451.13 1.00\n",
|
|
"draws_b[43] 0.03 0.24 0.03 -0.38 0.41 2145.35 1.00\n",
|
|
"draws_b[44] 0.02 0.23 0.02 -0.35 0.41 1418.17 1.00\n",
|
|
"draws_b[45] -0.70 0.26 -0.71 -1.12 -0.26 2235.37 1.00\n",
|
|
"draws_b[46] -0.16 0.24 -0.15 -0.55 0.23 2297.93 1.00\n",
|
|
"draws_b[47] -0.08 0.24 -0.08 -0.46 0.31 2919.42 1.00\n",
|
|
"draws_b[48] -0.47 0.26 -0.47 -0.85 -0.01 2107.38 1.00\n",
|
|
"draws_b[49] 0.21 0.25 0.21 -0.22 0.62 2278.84 1.00\n",
|
|
"draws_b[50] 0.42 0.24 0.42 0.03 0.82 2249.94 1.00\n",
|
|
"draws_b[51] 0.13 0.24 0.13 -0.26 0.52 2023.52 1.00\n",
|
|
"draws_b[52] -0.18 0.25 -0.18 -0.59 0.22 2162.22 1.00\n",
|
|
"draws_b[53] 0.64 0.23 0.64 0.28 1.02 2077.93 1.00\n",
|
|
"draws_b[54] -0.93 0.28 -0.93 -1.36 -0.44 1897.25 1.00\n",
|
|
"draws_b[55] 0.95 0.23 0.95 0.56 1.30 1707.46 1.00\n",
|
|
"draws_b[56] -0.27 0.26 -0.26 -0.71 0.16 2793.76 1.00\n",
|
|
"draws_b[57] 0.10 0.25 0.10 -0.33 0.47 2099.51 1.00\n",
|
|
"draws_b[58] -0.25 0.26 -0.25 -0.65 0.18 2107.67 1.00\n",
|
|
"draws_b[59] 0.14 0.23 0.13 -0.23 0.50 2137.52 1.00\n",
|
|
" draws_s[0] 0.24 0.37 0.24 -0.32 0.89 163.41 1.02\n",
|
|
" draws_s[1] -2.24 1.33 -2.03 -4.17 -0.11 1355.16 1.00\n",
|
|
" draws_s[2] 2.44 0.34 2.44 1.89 3.00 137.57 1.02\n",
|
|
" draws_s[3] 0.13 0.38 0.13 -0.54 0.72 190.29 1.01\n",
|
|
" draws_s[4] 1.26 0.35 1.27 0.72 1.87 124.68 1.02\n",
|
|
" draws_s[5] 2.50 0.34 2.49 1.95 3.04 127.95 1.02\n",
|
|
" draws_s[6] -2.14 1.23 -1.94 -4.22 -0.37 1116.20 1.00\n",
|
|
" draws_s[7] -2.17 1.26 -1.97 -3.95 -0.21 1188.19 1.00\n",
|
|
" draws_s[8] -2.15 1.24 -1.96 -3.97 -0.23 1114.07 1.00\n",
|
|
" draws_s[9] -2.21 1.31 -2.01 -4.19 -0.20 1299.27 1.00\n",
|
|
"draws_s[10] 0.69 0.36 0.69 0.07 1.25 142.47 1.02\n",
|
|
"draws_s[11] 5.05 0.96 4.91 3.60 6.44 435.17 1.01\n",
|
|
"draws_s[12] 2.90 0.36 2.90 2.26 3.44 152.54 1.02\n",
|
|
"draws_s[13] 1.21 0.34 1.21 0.61 1.76 135.16 1.02\n",
|
|
"draws_s[14] 5.02 0.90 4.90 3.68 6.42 486.24 1.01\n",
|
|
"draws_s[15] 1.95 0.34 1.95 1.38 2.46 124.29 1.02\n",
|
|
"draws_s[16] -2.19 1.28 -1.97 -4.22 -0.31 1242.63 1.00\n",
|
|
"draws_s[17] 1.58 0.33 1.58 1.02 2.10 126.45 1.02\n",
|
|
"draws_s[18] -2.13 1.21 -1.96 -3.97 -0.31 1757.34 1.00\n",
|
|
"draws_s[19] 2.58 0.34 2.57 1.96 3.11 131.96 1.02\n",
|
|
"draws_s[20] 2.04 0.34 2.04 1.43 2.55 123.61 1.02\n",
|
|
"draws_s[21] -2.21 1.35 -1.96 -4.20 -0.23 1317.17 1.00\n",
|
|
"draws_s[22] -2.15 1.26 -1.94 -4.07 -0.28 1177.37 1.00\n",
|
|
"draws_s[23] 1.17 0.34 1.16 0.58 1.71 140.12 1.02\n",
|
|
"draws_s[24] -2.16 1.27 -1.93 -4.12 -0.22 1235.15 1.00\n",
|
|
"draws_s[25] -0.66 0.53 -0.63 -1.48 0.22 386.61 1.01\n",
|
|
"draws_s[26] 2.57 0.34 2.58 1.99 3.11 139.20 1.02\n",
|
|
"draws_s[27] 0.73 0.35 0.73 0.17 1.29 157.69 1.02\n",
|
|
"draws_s[28] -2.11 1.23 -1.88 -3.97 -0.34 949.62 1.00\n",
|
|
"draws_s[29] -2.22 1.33 -2.01 -4.10 -0.09 1101.53 1.00\n",
|
|
"draws_s[30] 1.62 0.34 1.62 1.06 2.15 124.17 1.02\n",
|
|
"draws_s[31] -2.22 1.31 -1.98 -4.27 -0.28 1029.67 1.00\n",
|
|
"draws_s[32] -0.58 0.52 -0.56 -1.54 0.20 287.87 1.01\n",
|
|
"draws_s[33] 1.61 0.33 1.61 1.06 2.16 134.07 1.02\n",
|
|
"draws_s[34] 1.52 0.34 1.52 0.98 2.10 127.97 1.02\n",
|
|
"draws_s[35] -0.05 0.41 -0.04 -0.70 0.67 194.24 1.01\n",
|
|
"draws_s[36] -0.67 0.53 -0.64 -1.55 0.19 387.18 1.00\n",
|
|
"draws_s[37] 1.47 0.34 1.47 0.86 1.98 132.48 1.02\n",
|
|
"draws_s[38] -2.21 1.28 -1.98 -4.21 -0.28 877.00 1.00\n",
|
|
"draws_s[39] -2.17 1.25 -1.98 -4.01 -0.19 1869.96 1.00\n",
|
|
"draws_s[40] -0.23 0.45 -0.22 -0.94 0.50 248.15 1.01\n",
|
|
"draws_s[41] -2.17 1.26 -1.95 -4.10 -0.28 1188.23 1.00\n",
|
|
"draws_s[42] 2.04 0.33 2.04 1.48 2.57 127.20 1.02\n",
|
|
"draws_s[43] 3.32 0.38 3.32 2.67 3.91 162.59 1.02\n",
|
|
"draws_s[44] -0.72 0.54 -0.68 -1.56 0.18 324.04 1.01\n",
|
|
"draws_s[45] -2.22 1.29 -2.01 -4.37 -0.37 1211.64 1.00\n",
|
|
"draws_s[46] -2.13 1.21 -1.92 -3.91 -0.27 1580.13 1.00\n",
|
|
"draws_s[47] 1.52 0.33 1.52 0.95 2.03 122.64 1.02\n",
|
|
"draws_s[48] -0.72 0.53 -0.69 -1.64 0.06 295.95 1.00\n",
|
|
"draws_s[49] 0.22 0.38 0.23 -0.44 0.80 167.38 1.01\n",
|
|
"draws_s[50] -0.66 0.52 -0.64 -1.46 0.23 331.43 1.01\n",
|
|
"draws_s[51] 0.46 0.37 0.46 -0.10 1.10 157.09 1.01\n",
|
|
"draws_s[52] -0.09 0.42 -0.08 -0.79 0.60 213.74 1.01\n",
|
|
"draws_s[53] -2.12 1.26 -1.90 -3.98 -0.15 1200.39 1.00\n",
|
|
"draws_s[54] 1.71 0.34 1.70 1.13 2.23 132.50 1.02\n",
|
|
"draws_s[55] -2.15 1.26 -1.93 -4.13 -0.36 925.27 1.00\n",
|
|
"draws_s[56] -0.73 0.57 -0.68 -1.65 0.18 380.68 1.00\n",
|
|
"draws_s[57] 0.36 0.37 0.35 -0.27 0.95 154.13 1.02\n",
|
|
"draws_s[58] 1.90 0.34 1.90 1.36 2.46 128.04 1.02\n",
|
|
"draws_s[59] -0.72 0.55 -0.69 -1.62 0.16 374.05 1.01\n",
|
|
" mu_tot -1.82 0.30 -1.82 -2.29 -1.34 101.86 1.03\n",
|
|
" sigma_b 0.53 0.06 0.53 0.43 0.63 993.53 1.00\n",
|
|
" sigma_s 2.17 0.28 2.15 1.68 2.59 442.90 1.01\n",
|
|
"\n",
|
|
"Number of divergences: 0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# Fit `model` to the simulated `grid` with NUTS: 1000 warm-up iterations,
# then `num_samples` retained posterior draws.
num_samples = 2000
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)

# Split the root key: the second half drives this run, and the first half is
# kept in `rng_key` (presumably for later cells).
rng_key, rng_key_ = jrandom.split(jrandom.PRNGKey(0))

mcmc.run(rng_key_, samples=grid)
mcmc.print_summary()
samples_1 = mcmc.get_samples()
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|