MemSaveLayerNorm

class memsave_torch.nn.MemSaveLayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None)

Memory-saving variant of LayerNorm.

Initializes a LayerNorm layer with the given parameters.

Parameters:
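  • normalized_shape – shape of the trailing input dimensions to normalize over (an int or a list/tuple of ints), as in torch.nn.LayerNorm

  • eps – value added to the variance for numerical stability

  • elementwise_affine – whether the layer has learnable per-element affine parameters

  • bias – whether to learn an additive bias when elementwise_affine is True (introduced in torch v2.1)

  • device – device on which the parameters are created

  • dtype – data type of the parameters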

forward(x)

Forward pass.

Parameters:

x – Input tensor, e.g. [B, C, H, W]; its trailing dimensions must match normalized_shape

Returns:

Output tensor with the same shape as the input, e.g. [B, C, H, W]

Return type:

torch.Tensor
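
As a rough illustration of what a LayerNorm forward pass computes, the following plain-Python sketch normalizes each vector over its last dimension. It follows the standard LayerNorm definition, not this library's memory-saving implementation; the helper name `layer_norm_last_dim` is hypothetical and exists only for this example.

```python
import math


def layer_norm_last_dim(row, eps=1e-5, weight=None, bias=None):
    """Normalize one vector the way LayerNorm treats the last dimension.

    Hypothetical helper for illustration only; not part of memsave_torch.
    """
    n = len(row)
    mean = sum(row) / n
    # LayerNorm uses the biased variance (divide by n, not n - 1).
    var = sum((v - mean) ** 2 for v in row) / n
    # eps keeps the division stable when the variance is zero.
    inv_std = 1.0 / math.sqrt(var + eps)
    out = [(v - mean) * inv_std for v in row]
    # Optional elementwise affine transform (elementwise_affine=True).
    if weight is not None:
        out = [o * w for o, w in zip(out, weight)]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


# A "batch" of two vectors; each one is normalized independently.
batch = [[1.0, 2.0, 3.0, 4.0], [10.0, 10.0, 10.0, 10.0]]
normalized = [layer_norm_last_dim(row) for row in batch]
```

The output has the same shape as the input, and a constant vector maps to zeros because eps prevents a division by zero.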

classmethod from_nn_LayerNorm(ln: LayerNorm)

Converts a nn.LayerNorm layer to MemSaveLayerNorm.

Parameters:

ln – The nn.LayerNorm layer

Returns:

The MemSaveLayerNorm object

Return type:

MemSaveLayerNorm

Hint

The usage is the same as torch.nn.LayerNorm.

For usage examples, refer to the torch.nn.LayerNorm documentation.