TensorDict is a tensor container dedicated to PyTorch.
TensorDict is a dictionary-like class that inherits properties from tensors,
such as indexing, shape operations, casting to device, and point-to-point communication
in distributed settings.
The main purpose of TensorDict is to make code-bases more readable and modular by abstracting away tailored operations:
    for i, tensordict in enumerate(dataset):
        tensordict = model(tensordict)
        loss = loss_module(tensordict)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
With this level of abstraction, one can reuse a training loop across highly heterogeneous tasks. Each individual step of the training loop (data collection and transformation, model prediction, loss computation, etc.) can be tailored to the use case at hand without affecting the others. For instance, the example above can be used with little or no change across classification and segmentation tasks, among many others.
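The idea of one loop serving several tasks can be illustrated with a dependency-free sketch. Plain Python dicts stand in for TensorDict here, and every name (`run_epoch`, `classify`, `segment`, the keys, the toy data) is hypothetical. The gradient and optimizer steps are left out because nothing in this toy is differentiable.

```python
# Plain dicts stand in for TensorDict; all names below are hypothetical.

def run_epoch(dataset, model, loss_module):
    """Generic loop: it only hands a dict-like sample from step to step."""
    total = 0.0
    for sample in dataset:
        sample = model(sample)        # the model reads and writes keys
        total += loss_module(sample)  # the loss reads the keys it needs
    return total / len(dataset)

# Task 1: classification (one label per sample).
def classify(sample):
    mean = sum(sample["pixels"]) / len(sample["pixels"])
    sample["pred"] = int(mean > 0.5)
    return sample

def zero_one_loss(sample):
    return float(sample["pred"] != sample["label"])

# Task 2: segmentation (one label per pixel).
def segment(sample):
    sample["pred_mask"] = [int(p > 0.5) for p in sample["pixels"]]
    return sample

def pixel_error(sample):
    wrong = sum(p != m for p, m in zip(sample["pred_mask"], sample["mask"]))
    return wrong / len(sample["mask"])

cls_data = [
    {"pixels": [0.9, 0.8], "label": 1},
    {"pixels": [0.1, 0.2], "label": 0},
]
seg_data = [
    {"pixels": [0.9, 0.1], "mask": [1, 0]},
    {"pixels": [0.7, 0.6], "mask": [1, 0]},
]

# The same loop runs both tasks; only the plugged-in steps differ.
cls_loss = run_epoch(cls_data, classify, zero_one_loss)
seg_loss = run_epoch(seg_data, segment, pixel_error)
```

The loop never inspects which keys a sample holds, so swapping the model and loss is enough to change the task.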
PyPI: https://pypi.org/project/tensordict/
:fire: The conda-forge recipe was generated with Conda-Forger App.