jax[cuda]<0.8.0,>=0.4.0
chex>=0.1.0
tabulate>=0.9.0
numpy>=2.2.0
rich

[dev]
pytest>=7.0.0
