Neural Networks in JAX.
conda install conda-forge::equinox
Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.