jax
jaxlib>=0.3
equinox
tqdm
optax
numpy<=1.22.4

[dev]
pytest
