A JAX package for approximate curvature estimation and optimization using K-FAC (Kronecker-Factored Approximate Curvature).
conda install conda-forge::kfac-jax
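To illustrate the Kronecker-factored idea behind K-FAC, here is a minimal sketch in plain JAX for a single dense layer `y = x @ W` with no bias. It does not use the kfac-jax API; `kfac_factors` and `kfac_precondition` are hypothetical helper names chosen for illustration. The sketch approximates the layer's Fisher block as `A ⊗ G`, where `A = E[x xᵀ]` is the input covariance and `G = E[g gᵀ]` is the covariance of the loss gradients with respect to the layer outputs. The preconditioned update is then `(A + λI)⁻¹ ∇W (G + λI)⁻¹`. Several simplifications are assumed: a toy squared-error loss, the observed per-example gradients used in place of model-sampled ones (an empirical-Fisher variant), and one damping term added to each factor.

```python
import jax
import jax.numpy as jnp


def kfac_factors(x, g):
    """Kronecker factors for one dense layer y = x @ W (hypothetical helper).

    x: layer inputs, shape (batch, d_in)
    g: per-example gradients of the loss w.r.t. the outputs y, shape (batch, d_out)
    """
    n = x.shape[0]
    a = x.T @ x / n   # A = E[x x^T], shape (d_in, d_in)
    gg = g.T @ g / n  # G = E[g g^T], shape (d_out, d_out)
    return a, gg


def kfac_precondition(grad_w, a, gg, damping=1e-3):
    """Approximate F^{-1} grad_w with F ~= A (x) G, damping each factor."""
    a_d = a + damping * jnp.eye(a.shape[0])
    g_d = gg + damping * jnp.eye(gg.shape[0])
    # For symmetric A, G and row-major vec: (A (x) G)^{-1} vec(M) = vec(A^{-1} M G^{-1}).
    left = jnp.linalg.solve(a_d, grad_w)   # A^{-1} M
    return jnp.linalg.solve(g_d, left.T).T  # (A^{-1} M) G^{-1}


# Toy problem (assumed data, for illustration only).
key_x, key_w, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (32, 4))
w = jax.random.normal(key_w, (4, 3))
targets = jax.random.normal(key_t, (32, 3))


def loss_fn(w):
    # Mean over the batch of the per-example loss 0.5 * ||x_i @ w - t_i||^2.
    return 0.5 * jnp.mean(jnp.sum((x @ w - targets) ** 2, axis=-1))


grad_w = jax.grad(loss_fn)(w)
g = x @ w - targets  # per-example gradient of the loss w.r.t. the outputs y_i
a, gg = kfac_factors(x, g)
step = kfac_precondition(grad_w, a, gg, damping=1e-2)
```

Working with the two small factors means only `d_in × d_in` and `d_out × d_out` systems need to be solved, instead of the full `(d_in·d_out) × (d_in·d_out)` Fisher block. That saving is why Kronecker factorization makes curvature-based updates tractable for wide layers.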