memsave_torch
Lowering PyTorch’s Memory Consumption for Selective Differentiation
This package offers drop-in implementations of PyTorch torch.nn.Modules.
They are as fast as their built-in equivalents, but more memory-efficient whenever you want to compute gradients for only a subset of parameters (i.e. some have requires_grad=False).
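To make "gradients for only a subset of parameters" concrete, here is a minimal sketch in plain PyTorch (no memsave_torch involved). The toy model, layer sizes, and variable names are illustrative assumptions, not taken from the package.

```python
import torch
from torch import nn

# Toy two-convolution model; sizes are illustrative only.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)

# Freeze the first convolution: its parameters get requires_grad=False.
for p in model[0].parameters():
    p.requires_grad_(False)

x = torch.randn(2, 3, 16, 16)
loss = model(x).sum()
loss.backward()

# Only the second convolution receives gradients.
frozen_has_grad = model[0].weight.grad is not None
trainable_has_grad = model[2].weight.grad is not None
print(frozen_has_grad, trainable_has_grad)  # False True
```

This mixed setting (some parameters frozen, others trainable) is the one in which the memory-saving layers are meant to help.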
Currently it supports the following layers:
Each layer also has a .from_nn_<layername>(layer) function that converts a single layer into its memory-saving equivalent (e.g. MemSaveConv2d.from_nn_Conv2d).
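A single-layer conversion might look like the sketch below. MemSaveConv2d.from_nn_Conv2d is the function named above; the import path memsave_torch.nn is an assumption and should be checked against the package. The helper convert_conv2d is hypothetical and falls back to the original layer when the package is not installed, so the snippet runs either way.

```python
import torch
from torch import nn


def convert_conv2d(layer: nn.Conv2d) -> nn.Module:
    """Return the memory-saving equivalent of `layer` if memsave_torch is available.

    Hypothetical helper: returns `layer` unchanged when the package is missing.
    """
    try:
        # Assumed import path; verify against the installed package.
        from memsave_torch.nn import MemSaveConv2d
    except ImportError:
        return layer
    return MemSaveConv2d.from_nn_Conv2d(layer)


conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
memsave_conv = convert_conv2d(conv)

# The converted layer is used exactly like the original one.
x = torch.randn(2, 3, 16, 16)
out = memsave_conv(x)
print(type(memsave_conv).__name__, tuple(out.shape))
```

Because the converted layer is a drop-in replacement, the rest of the model and training loop should not need to change.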
Installation
Normal installation:
pip install git+https://github.com/plutonium-239/memsave_torch
Install (editable):
pip install -e git+https://github.com/plutonium-239/memsave_torch
Usage
Please refer to Installation / Quickstart.
Further reading
- This explains the basic ideas around MemSave without diving into too many details.
- It is also available on arXiv.
How to cite
If this package has benefited you at some point, consider citing:
@inproceedings{
bhatia2024lowering,
title={Lowering PyTorch's Memory Consumption for Selective Differentiation},
author={Samarth Bhatia and Felix Dangel},
booktitle={2nd Workshop on Advancing Neural Network Training: Computational Efficiency, Scalability, and Resource Optimization (WANT@ICML 2024)},
year={2024},
url={https://openreview.net/forum?id=KsUUzxUK7N}
}