MPI support for JAX
conda install conda-forge::mpi4jax
Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡