jax>=0.2.13
jaxlib>=0.1.62
tensorflow
tqdm>=4.48.2
tensorflow-probability[jax]
numpy
scipy
matplotlib
