Installation / Quickstart
pip install git+https://github.com/plutonium-239/memsave_torch
Replace all (valid) layers with MemSave layers
The convert_to_memory_saving() function from the memsave_torch.nn module replaces every supported layer of the model passed to it with its memory-saving counterpart.
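import torch
from memsave_torch.nn import convert_to_memory_saving

# Any torch.nn.Module works here; this small stack is only a placeholder.
my_torch_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
# Swap supported layers for their memory-saving counterparts.
memsave_torch_model = convert_to_memory_saving(my_torch_model)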
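To see which layers were actually swapped, one option is to compare the layer class names before and after conversion. The sketch below is not part of memsave_torch: the helper names layer_type_counts and replaced_layer_types are hypothetical, and the only assumption about the model is that it exposes the standard torch.nn.Module.modules() iterator.

```python
from collections import Counter


def layer_type_counts(model):
    """Count layer class names in a model.

    Assumes `model` exposes `modules()` as torch.nn.Module does
    (the iterator yields the model itself and all of its submodules).
    """
    return Counter(type(m).__name__ for m in model.modules())


def replaced_layer_types(before, after):
    """Return class names (with counts) present before conversion but not after.

    Counter subtraction keeps only positive counts, so layer types that
    are unchanged or newly introduced do not appear in the result.
    """
    return dict(before - after)
```

Take the "before" snapshot prior to calling convert_to_memory_saving(), in case the model is modified in place:

```python
# before = layer_type_counts(my_torch_model)
# memsave_torch_model = convert_to_memory_saving(my_torch_model)
# after = layer_type_counts(memsave_torch_model)
# print(replaced_layer_types(before, after))
```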