JMP is a Mixed Precision library for JAX.
Mixed precision training [0] is a technique that mixes the use of full and half precision floating point numbers during training to reduce the memory bandwidth requirements and improve the computational efficiency of a given model.
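The idea can be illustrated in plain JAX, without JMP. The sketch below keeps a float32 master copy of the parameters, casts them and the input to float16 for the forward computation, and casts the result back to float32. It also shows that the half-precision copy needs half the bytes. This does not use JMP's API. `cast_tree` and `forward` are hypothetical helpers, and the toy linear layer is an assumption chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Full-precision (float32) master copy of the parameters of a toy linear layer.
params = {"w": jnp.ones((512, 512), dtype=jnp.float32),
          "b": jnp.zeros((512,), dtype=jnp.float32)}

def cast_tree(tree, dtype):
    """Hypothetical helper: cast every leaf of a pytree to `dtype`."""
    return jax.tree_util.tree_map(lambda x: x.astype(dtype), tree)

def forward(params, x):
    # Run the computation in half precision to reduce memory traffic...
    p16 = cast_tree(params, jnp.float16)
    y = jnp.dot(x.astype(jnp.float16), p16["w"]) + p16["b"]
    # ...then hand the result back in full precision.
    return y.astype(jnp.float32)

x = jnp.ones((8, 512), dtype=jnp.float32)
y = forward(params, x)

half = cast_tree(params, jnp.float16)
print(params["w"].nbytes, half["w"].nbytes)  # 1048576 524288
```

Only the working copy is cast down. The float32 parameters stay untouched, so updates can still be applied at full precision.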