{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro-overview",
   "metadata": {},
   "source": [
    "# Probit model with crossed row and column effects\n",
    "\n",
    "Simulate a binary grid in which the probability of a 1 in cell $(i, j)$ is $\\Phi(\\mu + s_i + b_j)$ (global offset, row effect, column effect), then fit the matching model with NUTS in NumPyro.\n",
    "\n",
    "Both the simulation and the sampler are seeded, so Restart Kernel and Run All reproduces the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33c24830-7601-4c52-8fa0-66330801d6a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import jax.numpy as jnp\n",
    "import jax.random as jrandom\n",
    "import jax.scipy.stats as jspstats\n",
    "\n",
    "import numpyro\n",
    "import numpyro.distributions as dists\n",
    "from numpyro.infer import MCMC, NUTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config-constants",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single seed shared by the NumPy data simulation and the JAX PRNG key.\n",
    "SEED = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "section-model",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b4c0d3e-2889-4bf3-9871-6648b5de3508",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(samples=None, shape=(60, 60)):\n",
    "    \"\"\"Probit model with additive row and column random effects.\n",
    "\n",
    "    Cell (i, j) of the grid is Bernoulli with probability\n",
    "    Phi(mu_tot + draws_s[i] + draws_b[j]), where Phi is the standard normal CDF.\n",
    "\n",
    "    Args:\n",
    "        samples: observed 0/1 grid matching `shape`, or None to draw the\n",
    "            'samples' site from the model instead of conditioning on it.\n",
    "        shape: (n_rows, n_cols) of the grid; fixes the number of row effects\n",
    "            ('draws_s') and column effects ('draws_b').\n",
    "    \"\"\"\n",
    "    # sigma_s / sigma_b are standard deviations and must be positive. A plain\n",
    "    # Normal prior puts half of its mass on negative values, which are invalid\n",
    "    # as the scale of Normal(0., sigma); HalfNormal keeps the same weak prior\n",
    "    # on the valid region only.\n",
    "    sigma_s = numpyro.sample('sigma_s', dists.HalfNormal(1_000.))\n",
    "    sigma_b = numpyro.sample('sigma_b', dists.HalfNormal(1_000.))\n",
    "    mu_tot = numpyro.sample('mu_tot', dists.Normal(0., 50.))\n",
    "\n",
    "    # One effect per row and one per column.\n",
    "    draws_s = numpyro.sample('draws_s', dists.Normal(0., sigma_s), sample_shape=shape[:-1])\n",
    "    draws_b = numpyro.sample('draws_b', dists.Normal(0., sigma_b), sample_shape=shape[-1:])\n",
    "\n",
    "    # Broadcast the two effect vectors into the full (n_rows, n_cols) grid.\n",
    "    draws_bias = mu_tot + draws_s[:, np.newaxis] + draws_b[np.newaxis, :]\n",
    "    numpyro.sample('samples', dists.Bernoulli(probs=jspstats.norm.cdf(draws_bias)), obs=samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "section-simulate",
   "metadata": {},
   "source": [
    "## Simulate data\n",
    "\n",
    "A latent Gaussian grid is thresholded at 0. Because the per-cell noise has unit variance, the probability of a 1 is $\\Phi(\\mu + s_i + b_j)$, which is exactly the likelihood used in `model`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caaf0f54-205e-4d13-84b6-464c1583d432",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth parameters used to simulate the data.\n",
    "sigma_s = 2.5\n",
    "sigma_b = 0.6\n",
    "mu_tot = -1.\n",
    "\n",
    "grid_size = [60, 60]\n",
    "\n",
    "# Seeded generator so that a fresh kernel always produces the same grid.\n",
    "rng = np.random.default_rng(SEED)\n",
    "row_effects = rng.normal(size=grid_size[0]) * sigma_s\n",
    "col_effects = rng.normal(size=grid_size[1]) * sigma_b\n",
    "grid = (\n",
    "    mu_tot\n",
    "    + row_effects[:, np.newaxis]\n",
    "    + col_effects[np.newaxis, :]\n",
    "    + rng.normal(size=tuple(grid_size))  # unit-variance noise -> probit link\n",
    ")\n",
    "thresholded_grid = 1. * (grid > 0)\n",
    "\n",
    "assert thresholded_grid.shape == tuple(grid_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "section-inference",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b670cf-73bb-40fd-b152-b111b13b9d50",
   "metadata": {},
   "outputs": [],
   "source": [
    "rng_key = jrandom.PRNGKey(SEED)\n",
    "rng_key, rng_key_ = jrandom.split(rng_key)\n",
    "\n",
    "kernel = NUTS(model)\n",
    "num_samples = 2000\n",
    "mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)\n",
    "mcmc.run(\n",
    "    rng_key_,\n",
    "    samples=thresholded_grid,\n",
    "    # Pass the data shape explicitly instead of relying on the model default\n",
    "    # happening to match grid_size.\n",
    "    shape=thresholded_grid.shape,\n",
    ")\n",
    "mcmc.print_summary()\n",
    "samples_1 = mcmc.get_samples()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}