A dead simple Python package for creating custom JAX pytree objects.
- Strives to be minimal: the implementation is just ~100 lines of code
- Has no dependencies other than JAX
- Is compatible with both dataclasses and regular classes
- Has no intention of supporting neural network use cases (e.g. partitioning)
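
The general idea can be sketched with JAX's own `jax.tree_util.register_pytree_node`. This is a minimal sketch, not this package's API. The `Pytree` base class and its `_tree_flatten`/`_tree_unflatten` helpers are hypothetical names. The class uses `__init_subclass__` to register every subclass as a pytree, so the same base works for a `@dataclasses.dataclass` subclass and for a regular class. As a simplifying assumption, every instance attribute is treated as a child, and there is no support for static (non-leaf) fields.

```python
import dataclasses

import jax
import jax.numpy as jnp


class Pytree:
    """Hypothetical base class: each subclass is registered as a JAX pytree.

    Assumption of this sketch: all instance attributes are children
    (leaves or nested pytrees); static fields are not supported.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs at class creation, before a @dataclass decorator is applied;
        # dataclass returns the same class object, so the registration holds.
        jax.tree_util.register_pytree_node(
            cls, cls._tree_flatten, cls._tree_unflatten
        )

    def _tree_flatten(self):
        attrs = vars(self)
        # Sort names so the flattening order is deterministic.
        names = tuple(sorted(attrs))
        children = tuple(attrs[name] for name in names)
        # The names are the hashable aux data used to rebuild the object.
        return children, names

    @classmethod
    def _tree_unflatten(cls, names, children):
        # Bypass __init__ so any constructor signature can be rebuilt.
        obj = object.__new__(cls)
        obj.__dict__.update(zip(names, children))
        return obj


@dataclasses.dataclass
class Point(Pytree):
    x: jax.Array
    y: jax.Array


class Linear(Pytree):
    def __init__(self, w, b):
        self.w = w
        self.b = b


@jax.jit
def apply(layer, x):
    # The Linear instance is passed through jit as an ordinary pytree argument.
    return x @ layer.w + layer.b


p = Point(jnp.array(1.0), jnp.array(2.0))
p2 = jax.tree_util.tree_map(lambda v: v * 10, p)

layer = Linear(jnp.ones((2, 3)), jnp.zeros(3))
out = apply(layer, jnp.ones((4, 2)))
```

Registering in `__init_subclass__` keeps the user-facing side to a single base class. A fuller design would also need a way to mark some fields as static so that they go into the aux data instead of being traced.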