
simple-pytree


A dead simple Python package for creating custom JAX pytree objects

Installation

To install this package, run the following:

Conda
$ conda install ryanvolz::simple-pytree


Description

A dead simple Python package for creating custom JAX pytree objects.

  • Strives to be minimal: the implementation is just ~100 lines of code
  • Has no dependencies other than JAX
  • Is compatible with both dataclasses and regular classes
  • Has no intention of supporting neural network use cases (e.g. partitioning)
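To show what "custom JAX pytree objects" means in practice, here is a minimal sketch built directly on JAX's `jax.tree_util.register_pytree_node`. It is not simple-pytree's API or implementation. The base class `PytreeBase` and the example classes `Point` and `Params` and the function `loss` are hypothetical names chosen for illustration. The sketch assumes every instance attribute is an array leaf, and it registers each subclass automatically via `__init_subclass__`, so the same base works for a regular class and for a `@dataclass`.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


class PytreeBase:
    """Hypothetical base class (NOT simple-pytree's API): each subclass is
    registered with JAX, and every instance attribute becomes a pytree leaf."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register the subclass as soon as it is defined.
        jax.tree_util.register_pytree_node(cls, cls._flatten, cls._unflatten)

    def _flatten(self):
        # Sort attribute names so the leaf order is deterministic;
        # the names are the (hashable) auxiliary data.
        names = tuple(sorted(vars(self)))
        children = tuple(getattr(self, n) for n in names)
        return children, names

    @classmethod
    def _unflatten(cls, names, children):
        # Bypass __init__ so any constructor signature (or a frozen
        # dataclass) can be rebuilt from its leaves.
        obj = object.__new__(cls)
        for name, value in zip(names, children):
            object.__setattr__(obj, name, value)
        return obj


class Point(PytreeBase):
    # A regular class: its leaves are whatever __init__ stores on self.
    def __init__(self, x, y):
        self.x = x
        self.y = y


@dataclasses.dataclass
class Params(PytreeBase):
    # @dataclass returns the same class object, which is already registered.
    w: Any
    b: Any


p = Point(jnp.array(1.0), jnp.array(2.0))
doubled = jax.tree_util.tree_map(lambda v: v * 2, p)
print(float(doubled.x), float(doubled.y))  # 2.0 4.0

params = Params(w=jnp.ones(3), b=jnp.zeros(()))
leaves = jax.tree_util.tree_leaves(params)
print(len(leaves))  # 2


@jax.jit
def loss(params, x):
    # Pytree objects pass straight through jit and grad.
    return jnp.sum(params.w * x + params.b)


print(float(loss(params, jnp.arange(3.0))))  # 3.0
```

Because unflattening skips `__init__`, JAX can rebuild the object with tracers or gradients in place of the original arrays; for example, `jax.grad(loss)(params, x)` returns a `Params` instance holding the gradients.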

About


Information Last Updated

Mar 25, 2025 at 16:25

License

MIT

Total Downloads

42

Platforms

noarch (version 0.2.2)