Installation / Quickstart
pip install git+https://github.com/plutonium-239/memsave_torch
Replace all (valid) layers with MemSave layers
The convert_to_memory_saving function from the memsave_torch.nn module replaces every supported layer of the model passed to it with its memory-saving counterpart.
import torch
from torchvision.models import resnet18
from memsave_torch.nn import convert_to_memory_saving
x = torch.rand(2, 3, 224, 224)
rn18 = resnet18()
rn18 = convert_to_memory_saving(rn18)
# Set input to be differentiable and model weights to be non-differentiable
x.requires_grad = True
rn18.requires_grad_(False)
y = rn18(x)
loss = torch.nn.MSELoss()(y, torch.rand_like(y))
loss.backward()
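The gradient setup used above (differentiable input, frozen weights) can be checked in plain PyTorch without memsave_torch. The following is a minimal sketch under stated assumptions: the small nn.Sequential model and the 8x8 input size are hypothetical stand-ins for ResNet-18, chosen only to keep the example fast. It shows that after the backward pass only the input carries a gradient.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in model (assumption: not the ResNet-18 used above)
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 5),
)

x = torch.rand(2, 3, 8, 8)

# Same setup as the quickstart: differentiable input, frozen weights
x.requires_grad = True
model.requires_grad_(False)

y = model(x)
loss = nn.MSELoss()(y, torch.rand_like(y))
loss.backward()

# Only the input should have received a gradient
input_has_grad = x.grad is not None
weights_without_grad = all(p.grad is None for p in model.parameters())
print(input_has_grad, weights_without_grad)
```

The same two checks should apply unchanged to a model returned by convert_to_memory_saving, since they rely only on standard autograd attributes.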
Attention
You can’t use the original model in the same Python run after calling convert_to_memory_saving on it: by default, weights are referenced rather than copied, to avoid extra memory consumption.
If you need to use both models together, pass clone_params=True to convert_to_memory_saving; the model weights will then be copied rather than just referenced.
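The difference between referenced and copied weights can be illustrated in plain PyTorch. This is only a sketch of the general concept, not memsave_torch's actual implementation: the layers named original, shared, and cloned are hypothetical, and copy.deepcopy stands in for whatever copying clone_params=True performs internally.

```python
import copy

import torch
from torch import nn

original = nn.Linear(4, 2)

# Referenced weights: the second layer holds the very same Parameter objects
shared = nn.Linear(4, 2)
shared.weight = original.weight
shared.bias = original.bias

# Copied weights: an independent duplicate that costs extra memory
cloned = copy.deepcopy(original)

# Freezing the layer that references the weights also freezes the original
shared.requires_grad_(False)
original_was_frozen = not original.weight.requires_grad

# An in-place update of the original is visible through the reference only
with torch.no_grad():
    original.weight.add_(1.0)

shared_follows = torch.equal(shared.weight, original.weight)
clone_is_independent = not torch.equal(cloned.weight, original.weight)
print(original_was_frozen, shared_follows, clone_is_independent)
```

When parameters are shared, any change made through one model (freezing, in-place updates) shows up in the other, which is why keeping both models usable requires a copy and the additional memory that comes with it.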