functorch provides JAX-like composable function transforms for PyTorch.
conda install conda-forge::functorch
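
To illustrate what "composable function transforms" means, here is a minimal sketch that composes two transforms, `grad` and `vmap`, to compute per-example gradients. The function `loss` and the tensors `w` and `xs` are hypothetical, made up for this illustration. The sketch assumes a functorch version that exports `grad` and `vmap` at the top level; recent PyTorch releases expose the same transforms under `torch.func`.

```python
import torch
from functorch import grad, vmap

# Hypothetical scalar-valued function of weights w for a single example x.
def loss(w, x):
    return torch.sin(w @ x)

w = torch.tensor([1.0, 2.0])                              # shared weights
xs = torch.tensor([[0.5, 0.0], [0.0, 0.25], [1.0, 1.0]])  # batch of 3 examples

# grad(loss) differentiates with respect to the first argument (w).
# vmap maps that gradient function over the leading dimension of xs,
# while in_dims=(None, 0) keeps w shared across the batch.
per_example_grads = vmap(grad(loss), in_dims=(None, 0))(w, xs)

print(per_example_grads.shape)
```

Because each transform takes a function and returns a function, they can be nested in either order; here the result has one gradient row per example in `xs`.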