jax==0.4.16 jaxlib==0.4.16 numpy==1.25.2 numpyro==0.11.0