MemSaveLinear

class memsave_torch.nn.MemSaveLinear(in_features, out_features, bias=True, device=None, dtype=None)

A linear layer with the same interface as torch.nn.Linear.

Initializes a MemSaveLinear layer with the given parameters.

Parameters:
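  • in_features – size of each input sample (F_in)

  • out_features – size of each output sample (F_out)

  • bias – if True, the layer learns an additive bias (default: True)

  • device – device on which the parameters are created (default: None)

  • dtype – data type of the parameters (default: None)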

forward(x)

Forward pass.

Parameters:

x – Input tensor of shape [B, F_in]

Returns:

Output tensor of shape [B, F_out]

Return type:

torch.Tensor
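The shape contract of forward can be illustrated with a minimal sketch. The sizes B, F_in and F_out below are arbitrary illustrative values, not defaults of the library. The import fallback to torch.nn.Linear is an assumption added only so the snippet still runs when memsave_torch is not installed; it relies on the constructor signature shown above matching torch.nn.Linear.

```python
import torch
from torch import nn

try:
    from memsave_torch.nn import MemSaveLinear
except ImportError:
    # Assumption for illustration only: fall back to nn.Linear, which has
    # the same constructor signature, if memsave_torch is unavailable.
    MemSaveLinear = nn.Linear

# Illustrative sizes (not library defaults).
B, F_in, F_out = 4, 8, 3

# Construct the layer exactly as one would construct nn.Linear.
layer = MemSaveLinear(F_in, F_out, bias=True)

# Input of shape [B, F_in]; calling the module invokes forward(x).
x = torch.randn(B, F_in)
y = layer(x)

# The output is a torch.Tensor of shape [B, F_out].
print(type(y).__name__, tuple(y.shape))
```

Calling the module directly (layer(x)) rather than layer.forward(x) is the usual PyTorch convention, since it also runs any registered hooks.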

classmethod from_nn_Linear(linear: Linear)

Converts an nn.Linear layer to a MemSaveLinear layer.

Parameters:

linear – The nn.Linear layer to convert

Returns:

The MemSaveLinear object

Return type:

MemSaveLinear
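A conversion might look like the following sketch. The layer sizes are illustrative. Whether the converted layer keeps the original layer's weight and bias is not stated above, so the snippet only reports whether the two outputs agree instead of assuming it. The fallback branch is an assumption added so the snippet runs without memsave_torch installed; in that case the original layer is used as a stand-in.

```python
import torch
from torch import nn

try:
    from memsave_torch.nn import MemSaveLinear
except ImportError:
    MemSaveLinear = None  # memsave_torch is not installed

# An existing nn.Linear layer with illustrative sizes.
linear = nn.Linear(8, 3)

if MemSaveLinear is not None:
    # Classmethod documented above: build a MemSaveLinear from nn.Linear.
    converted = MemSaveLinear.from_nn_Linear(linear)
else:
    # Stand-in for illustration only, so the rest of the sketch runs.
    converted = linear

x = torch.randn(4, 8)
out = converted(x)

# The converted layer accepts the same input and yields the same shape.
print(tuple(out.shape))

# Report (rather than assume) whether both layers give matching outputs.
print(torch.allclose(out, linear(x), atol=1e-6))
```

This pattern is convenient when swapping layers inside an existing model, because the original nn.Linear does not have to be rebuilt by hand.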

Hint

The usage is the same as torch.nn.Linear.

For usage examples, refer to the torch.nn.Linear documentation.