From cee076e4695c6782f5fd6d3b5c6156dd102626cf Mon Sep 17 00:00:00 2001 From: imperator Date: Mon, 20 Nov 2023 20:21:11 +0100 Subject: [PATCH] Initial commit --- mcmc.ipynb | 256 +++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 4 + 2 files changed, 260 insertions(+) create mode 100644 mcmc.ipynb create mode 100644 requirements.txt diff --git a/mcmc.ipynb b/mcmc.ipynb new file mode 100644 index 0000000..2440776 --- /dev/null +++ b/mcmc.ipynb @@ -0,0 +1,256 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "33c24830-7601-4c52-8fa0-66330801d6a9", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "985106e5-7a80-4d4d-b1c9-54a873f0b429", + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "import jax.random as jrandom\n", + "import jax.scipy.stats as jspstats" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8a8cc9de-9802-4761-9bba-7b77300d0bb8", + "metadata": {}, + "outputs": [], + "source": [ + "import numpyro\n", + "import numpyro.distributions as dists\n", + "from numpyro.infer import MCMC, NUTS" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9b4c0d3e-2889-4bf3-9871-6648b5de3508", + "metadata": {}, + "outputs": [], + "source": [ + "def model(samples=None, shape=(60, 60)):\n", + " sigma_s = numpyro.sample('sigma_s', dists.Normal(0., 1_000.))\n", + " sigma_b = numpyro.sample('sigma_b', dists.Normal(0., 1_000.))\n", + " mu_tot = numpyro.sample('mu_tot', dists.Normal(0., 50.))\n", + " \n", + " draws_s = numpyro.sample('draws_s', dists.Normal(0., sigma_s), sample_shape=shape[:-1])\n", + " draws_b = numpyro.sample('draws_b', dists.Normal(0., sigma_b), sample_shape=shape[-1:])\n", + " draws_bias = mu_tot + draws_s[:, np.newaxis] + draws_b[np.newaxis, :]\n", + " numpyro.sample('samples', dists.Bernoulli(jspstats.norm.cdf(draws_bias)), obs=samples)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "caaf0f54-205e-4d13-84b6-464c1583d432", + "metadata": {}, + "outputs": [], + "source": [ + "s_s = 2.5\n", + "s_b = 0.6\n", + "m_t = -1.\n", + "\n", + "grid_size = [60, 60]\n", + "grid = (np.random.normal(size=grid_size[0]) * s_s)[:, np.newaxis] + (np.random.normal(size=grid_size[1]) * s_b)[np.newaxis, :] + np.random.normal(size=grid_size)\n", + "thresholded_grid = 1. * (grid + m_t > 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "06b670cf-73bb-40fd-b152-b111b13b9d50", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "I0000 00:00:1700507813.984381 1419943 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n", + "sample: 100%|█| 3000/3000 [00:42<00:00, 70.76it/s, 31 steps of size 1.05e-01. ac\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " mean std median 5.0% 95.0% n_eff r_hat\n", + " draws_b[0] 0.54 0.27 0.54 0.07 0.95 1726.16 1.00\n", + " draws_b[1] 0.32 0.25 0.31 -0.09 0.71 1369.21 1.00\n", + " draws_b[2] 0.68 0.26 0.69 0.25 1.09 1518.55 1.00\n", + " draws_b[3] -0.22 0.27 -0.22 -0.65 0.22 1533.48 1.00\n", + " draws_b[4] 0.23 0.26 0.23 -0.20 0.65 1078.14 1.00\n", + " draws_b[5] 0.10 0.27 0.10 -0.35 0.52 1411.39 1.00\n", + " draws_b[6] -1.09 0.31 -1.09 -1.59 -0.58 1922.28 1.00\n", + " draws_b[7] 0.58 0.25 0.58 0.19 1.00 1254.93 1.00\n", + " draws_b[8] -0.22 0.27 -0.22 -0.67 0.19 1412.75 1.00\n", + " draws_b[9] 0.55 0.25 0.55 0.13 0.96 1191.63 1.00\n", + "draws_b[10] 0.26 0.27 0.26 -0.18 0.68 1384.75 1.00\n", + "draws_b[11] 0.34 0.25 0.34 -0.06 0.75 1880.85 1.00\n", + "draws_b[12] 0.02 0.26 0.02 -0.42 0.43 1484.88 1.00\n", + "draws_b[13] 0.10 0.28 0.10 -0.34 0.55 1793.79 1.00\n", + "draws_b[14] 0.00 0.27 0.00 -0.43 0.46 1575.12 1.00\n", + "draws_b[15] 0.10 0.25 0.11 -0.28 0.54 1463.69 1.00\n", + "draws_b[16] 0.37 0.25 0.37 -0.03 0.78 1831.50 1.00\n", + "draws_b[17] 0.79 0.25 0.79 0.40 1.22 1244.41 1.00\n", + "draws_b[18] -0.19 0.27 -0.18 -0.62 0.27 1036.48 1.00\n", + "draws_b[19] 0.49 0.25 0.49 0.11 0.95 1760.80 1.00\n", + "draws_b[20] -0.79 0.29 -0.78 -1.26 -0.33 1345.88 1.00\n", + "draws_b[21] 0.01 0.26 0.01 -0.41 0.45 1503.33 1.00\n", + "draws_b[22] 0.79 0.25 0.79 0.40 1.20 1519.65 1.00\n", + "draws_b[23] -1.59 0.35 -1.58 -2.20 -1.05 1248.30 1.00\n", + "draws_b[24] 1.19 0.24 1.20 0.78 1.58 1363.44 1.00\n", + "draws_b[25] 0.06 0.26 0.05 -0.34 0.51 1344.98 1.00\n", + "draws_b[26] 0.11 0.26 0.11 -0.31 0.53 1329.94 1.00\n", + "draws_b[27] 0.13 0.25 0.13 -0.30 0.52 1586.63 1.00\n", + "draws_b[28] -0.70 0.29 -0.69 -1.15 -0.21 1783.14 1.00\n", + "draws_b[29] 1.32 0.23 1.33 0.94 1.71 1445.95 1.00\n", + "draws_b[30] 0.04 0.26 0.04 -0.38 0.46 1281.60 1.00\n", + "draws_b[31] -0.46 0.28 -0.46 -0.91 -0.01 1579.55 1.00\n", + "draws_b[32] -1.09 0.31 -1.08 -1.57 -0.57 1433.98 1.00\n", + "draws_b[33] -1.19 0.29 -1.18 -1.67 -0.74 1349.72 1.00\n", + "draws_b[34] 0.57 0.25 0.57 0.19 1.01 1176.79 1.00\n", + "draws_b[35] 0.45 0.25 0.45 0.06 0.89 1484.68 1.00\n", + "draws_b[36] -0.20 0.26 -0.20 -0.65 0.20 1629.53 1.00\n", + "draws_b[37] 0.34 0.26 0.35 -0.13 0.74 1505.92 1.00\n", + "draws_b[38] 0.15 0.25 0.15 -0.23 0.59 1443.92 1.00\n", + "draws_b[39] 0.34 0.26 0.34 -0.07 0.79 1775.66 1.00\n", + "draws_b[40] 0.23 0.26 0.24 -0.17 0.68 1457.14 1.00\n", + "draws_b[41] -0.58 0.27 -0.57 -1.01 -0.15 1313.66 1.00\n", + "draws_b[42] -0.49 0.27 -0.48 -0.94 -0.04 1490.91 1.00\n", + "draws_b[43] -0.57 0.27 -0.56 -1.01 -0.11 1483.20 1.00\n", + "draws_b[44] 0.23 0.25 0.23 -0.15 0.66 981.31 1.00\n", + "draws_b[45] -0.31 0.26 -0.31 -0.70 0.15 1435.06 1.00\n", + "draws_b[46] -0.33 0.26 -0.33 -0.77 0.10 1224.35 1.00\n", + "draws_b[47] -0.62 0.27 -0.61 -1.05 -0.19 1631.06 1.00\n", + "draws_b[48] -0.48 0.27 -0.48 -0.92 -0.05 1106.08 1.00\n", + "draws_b[49] 1.16 0.26 1.16 0.74 1.58 1318.62 1.00\n", + "draws_b[50] -0.33 0.27 -0.33 -0.77 0.10 1336.15 1.00\n", + "draws_b[51] 0.35 0.25 0.34 -0.07 0.78 1383.73 1.00\n", + "draws_b[52] -0.32 0.26 -0.32 -0.76 0.09 1547.82 1.00\n", + "draws_b[53] 0.43 0.25 0.43 0.02 0.81 1491.33 1.00\n", + "draws_b[54] 0.10 0.26 0.10 -0.34 0.52 1530.15 1.00\n", + "draws_b[55] -0.58 0.28 -0.58 -1.07 -0.16 1241.15 1.00\n", + "draws_b[56] -0.48 0.28 -0.47 -0.95 -0.02 2004.81 1.00\n", + "draws_b[57] -0.45 0.28 -0.45 -0.91 -0.00 1592.65 1.00\n", + "draws_b[58] -0.22 0.29 -0.23 -0.68 0.24 1622.34 1.00\n", + "draws_b[59] -0.37 0.26 -0.37 -0.82 0.03 1411.27 1.00\n", + " draws_s[0] 1.59 0.41 1.58 0.99 2.31 91.66 1.01\n", + " draws_s[1] 2.00 0.42 1.97 1.32 2.67 92.51 1.01\n", + " draws_s[2] -0.81 0.47 -0.82 -1.60 -0.08 117.14 1.01\n", + " draws_s[3] 1.44 0.41 1.43 0.79 2.09 93.06 1.01\n", + " draws_s[4] 1.30 0.42 1.28 0.69 2.03 86.22 1.01\n", + " draws_s[5] -3.25 1.59 -2.94 -5.69 -1.06 801.05 1.00\n", + " draws_s[6] -0.52 0.45 -0.54 -1.30 0.16 108.95 1.01\n", + " draws_s[7] 1.96 0.42 1.93 1.29 2.62 94.58 1.01\n", + " draws_s[8] 2.37 0.43 2.35 1.69 3.07 98.58 1.01\n", + " draws_s[9] -3.24 1.55 -2.93 -5.45 -0.80 899.89 1.00\n", + "draws_s[10] 5.06 1.33 4.78 3.02 7.12 436.17 1.00\n", + "draws_s[11] -0.62 0.46 -0.66 -1.36 0.10 111.43 1.01\n", + "draws_s[12] -0.62 0.47 -0.65 -1.35 0.17 120.13 1.01\n", + "draws_s[13] -3.21 1.47 -2.95 -5.44 -0.99 1044.95 1.00\n", + "draws_s[14] -3.16 1.49 -2.89 -5.45 -0.95 954.54 1.00\n", + "draws_s[15] 5.03 1.29 4.83 3.12 7.02 603.50 1.00\n", + "draws_s[16] -0.83 0.48 -0.83 -1.62 -0.08 126.28 1.01\n", + "draws_s[17] -0.27 0.44 -0.28 -1.04 0.37 101.52 1.01\n", + "draws_s[18] -0.75 0.48 -0.77 -1.50 0.03 117.77 1.01\n", + "draws_s[19] -3.20 1.49 -2.97 -5.40 -0.87 866.56 1.00\n", + "draws_s[20] 1.54 0.41 1.52 0.90 2.21 90.24 1.01\n", + "draws_s[21] 0.01 0.43 -0.01 -0.73 0.65 96.74 1.01\n", + "draws_s[22] -1.05 0.51 -1.05 -1.86 -0.19 140.02 1.00\n", + "draws_s[23] -3.26 1.60 -2.94 -5.74 -0.85 956.12 1.00\n", + "draws_s[24] 0.65 0.41 0.63 0.02 1.36 92.65 1.01\n", + "draws_s[25] 5.03 1.29 4.79 3.24 7.05 699.49 1.00\n", + "draws_s[26] -0.10 0.43 -0.12 -0.75 0.62 97.31 1.01\n", + "draws_s[27] -0.64 0.46 -0.65 -1.39 0.10 116.88 1.01\n", + "draws_s[28] 0.40 0.42 0.38 -0.31 1.05 89.63 1.01\n", + "draws_s[29] 3.80 0.62 3.76 2.83 4.79 201.10 1.00\n", + "draws_s[30] 1.73 0.41 1.72 1.04 2.37 90.90 1.01\n", + "draws_s[31] -3.22 1.52 -2.91 -5.73 -1.11 1008.01 1.00\n", + "draws_s[32] -0.80 0.49 -0.80 -1.61 -0.05 125.12 1.01\n", + "draws_s[33] 1.91 0.42 1.89 1.23 2.57 97.31 1.01\n", + "draws_s[34] 1.09 0.41 1.06 0.41 1.74 90.62 1.01\n", + "draws_s[35] 1.04 0.41 1.01 0.35 1.69 90.74 1.01\n", + "draws_s[36] -1.53 0.61 -1.52 -2.50 -0.47 220.52 1.00\n", + "draws_s[37] -0.86 0.50 -0.87 -1.64 -0.06 124.72 1.01\n", + "draws_s[38] -1.41 0.58 -1.40 -2.40 -0.50 165.96 1.01\n", + "draws_s[39] -1.00 0.50 -1.00 -1.86 -0.26 134.26 1.01\n", + "draws_s[40] 0.43 0.42 0.40 -0.32 1.03 91.86 1.01\n", + "draws_s[41] -3.17 1.51 -2.88 -5.51 -0.94 802.41 1.00\n", + "draws_s[42] -3.17 1.51 -2.90 -5.58 -1.03 703.27 1.01\n", + "draws_s[43] -3.23 1.51 -2.94 -5.49 -1.06 1047.16 1.00\n", + "draws_s[44] -0.69 0.47 -0.70 -1.49 0.04 111.69 1.01\n", + "draws_s[45] -3.24 1.51 -2.96 -5.43 -0.87 809.34 1.00\n", + "draws_s[46] -1.14 0.52 -1.14 -1.96 -0.27 133.51 1.01\n", + "draws_s[47] 1.50 0.41 1.48 0.86 2.15 88.51 1.01\n", + "draws_s[48] -3.26 1.54 -2.99 -5.62 -0.90 776.73 1.00\n", + "draws_s[49] 2.81 0.45 2.79 2.06 3.48 110.23 1.01\n", + "draws_s[50] -1.50 0.60 -1.49 -2.48 -0.50 198.81 1.00\n", + "draws_s[51] 3.75 0.62 3.70 2.78 4.75 191.96 1.00\n", + "draws_s[52] 1.32 0.41 1.30 0.63 1.96 89.34 1.01\n", + "draws_s[53] -0.64 0.47 -0.64 -1.42 0.06 117.17 1.01\n", + "draws_s[54] -3.25 1.57 -2.90 -5.81 -1.11 894.61 1.00\n", + "draws_s[55] 5.01 1.26 4.76 3.33 6.99 449.27 1.00\n", + "draws_s[56] 5.05 1.32 4.82 3.05 6.98 642.23 1.00\n", + "draws_s[57] 4.99 1.27 4.72 3.12 7.02 663.59 1.00\n", + "draws_s[58] -3.14 1.44 -2.88 -5.37 -1.10 808.01 1.00\n", + "draws_s[59] 1.50 0.42 1.48 0.84 2.18 89.42 1.01\n", + " mu_tot -1.15 0.38 -1.13 -1.74 -0.53 76.53 1.01\n", + " sigma_b 0.65 0.07 0.64 0.53 0.77 935.24 1.00\n", + " sigma_s 2.75 0.37 2.70 2.18 3.36 425.85 1.00\n", + "\n", + "Number of divergences: 0\n" + ] + } + ], + "source": [ + "rng_key = jrandom.PRNGKey(0)\n", + "rng_key, rng_key_ = jrandom.split(rng_key)\n", + "\n", + "kernel = NUTS(model)\n", + "num_samples = 2000\n", + "mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)\n", + "mcmc.run(\n", + " rng_key_, samples=thresholded_grid\n", + ")\n", + "mcmc.print_summary()\n", + "samples_1 = mcmc.get_samples()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..230ecde --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +jax==0.4.16 +numpy==1.25.2 +numpy==1.21.5 +numpyro==0.11.0