memsave_torch

Lowering PyTorch’s Memory Consumption for Selective Differentiation


This package offers drop-in replacements for PyTorch torch.nn.Modules. They are as fast as their built-in equivalents, but more memory-efficient whenever you want to compute gradients for only a subset of parameters (i.e. some have requires_grad=False).
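To make the "subset of parameters" setting concrete, here is a minimal sketch that uses only built-in PyTorch layers. The model, shapes, and choice of frozen layer are made up for illustration; the point is only that parameters with requires_grad=False receive no gradient.

```python
import torch
from torch import nn

# A small illustrative model (architecture chosen arbitrarily for this sketch).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)

# Freeze the first convolution: its parameters are excluded from differentiation.
for p in model[0].parameters():
    p.requires_grad_(False)

x = torch.randn(2, 3, 16, 16)
loss = model(x).sum()
loss.backward()

# Frozen parameters have no gradient; the trainable ones do.
frozen_grads = [p.grad for p in model[0].parameters()]
trainable_grads = [p.grad for p in model[2].parameters()]
```

This is the situation in which the memory-saving layers are meant to help, since the built-in layers may still store tensors that are only needed for the gradients of the frozen parameters.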

Currently it supports the following layers:

Each layer also has a .from_nn_<layername>(layer) function that converts a single built-in layer into its memory-saving equivalent (e.g. MemSaveConv2d.from_nn_Conv2d).
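A conversion of a single layer might look like the sketch below. The import path memsave_torch.nn is an assumption not stated on this page, so the import is guarded and the sketch falls back to the built-in layer if the package or that path is unavailable.

```python
import torch

try:
    # Assumed import path; adjust if the package exposes the layer elsewhere.
    from memsave_torch.nn import MemSaveConv2d
except ImportError:
    MemSaveConv2d = None  # package not installed; fall back to the built-in layer

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

# Convert the built-in layer into its memory-saving equivalent when available.
layer = MemSaveConv2d.from_nn_Conv2d(conv) if MemSaveConv2d is not None else conv

# The converted layer is used like the original one.
x = torch.randn(2, 3, 16, 16)
out = layer(x)
```

Because the layers are meant as drop-in replacements, the call signature and output shape should match those of the original torch.nn.Conv2d.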

Installation

Normal installation:

pip install git+https://github.com/plutonium-239/memsave_torch

Install (editable):

pip install -e git+https://github.com/plutonium-239/memsave_torch

Usage

Please refer to Installation / Quickstart.

Further reading

Notes on PyTorch integration

The ideal solution to this problem would be at a lower level (i.e. CPU C++ functions, GPU CUDA kernels, etc.). It would involve changing the function signature of torch.ops.aten.convolution_backward so that it no longer always requires both saved tensors (the layer input and the weights) as input.

However, that would require a change in all the backends, which is not realistic for us to do and requires considerable design decisions from the PyTorch team itself. So, we implement these layers at the higher Python level, which makes them platform-independent and easier to maintain, at the cost of a slight performance hit.
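The Python-level idea of storing only what the backward pass needs can be sketched with a custom torch.autograd.Function. The example below is an illustration for a bias-free linear operation, not the package's actual implementation; SelectiveLinearFn is a hypothetical name introduced here.

```python
import torch


class SelectiveLinearFn(torch.autograd.Function):
    """Hypothetical linear op that saves only the tensors its backward pass needs."""

    @staticmethod
    def forward(ctx, x, weight):
        # The gradient w.r.t. x needs `weight`; the gradient w.r.t. `weight` needs x.
        ctx.x_needs_grad, ctx.w_needs_grad = ctx.needs_input_grad
        saved = []
        if ctx.x_needs_grad:
            saved.append(weight)
        if ctx.w_needs_grad:
            saved.append(x)
        ctx.save_for_backward(*saved)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        saved = list(ctx.saved_tensors)
        grad_x = grad_w = None
        if ctx.x_needs_grad:
            weight = saved.pop(0)
            grad_x = grad_out @ weight
        if ctx.w_needs_grad:
            x = saved.pop(0)
            grad_w = grad_out.t() @ x
        return grad_x, grad_w


x = torch.randn(4, 3, requires_grad=True)
weight = torch.randn(5, 3)  # frozen: requires_grad is False

out = SelectiveLinearFn.apply(x, weight)
num_saved = len(out.grad_fn.saved_tensors)  # the input x is not stored here
out.sum().backward()
```

With the weight frozen, only the weight is kept for the backward pass, so the (typically much larger) layer input does not have to be stored.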

How to cite

If this package has benefited you at some point, consider citing

@inproceedings{
   bhatia2024lowering,
   title={Lowering PyTorch's Memory Consumption for Selective Differentiation},
   author={Samarth Bhatia and Felix Dangel},
   booktitle={2nd Workshop on Advancing Neural Network Training: Computational Efficiency, Scalability, and Resource Optimization (WANT@ICML 2024)},
   year={2024},
   url={https://openreview.net/forum?id=KsUUzxUK7N}
}

Contributors
