numpy>=1.15
jax>=0.3.0
tqdm