estimate.py

This script measures, for a given model, the peak memory after a forward pass and the total time taken up to and including the backward pass.

Estimate possible speed-up when randomizing the weight VJP of convolutions.

We take a CNN and answer the following questions:

Q1) What is the relative run time consumed by the weight VJP for convolutions?

Q2) Assuming we achieve a speed-up x by randomizing the weight VJP, what would be the speed-up for one optimization step (forward+backward)?

Q3) The same as Q1) and Q2) but in terms of memory consumption.
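
The relation behind Q2 can be sketched with the usual Amdahl-style argument: if the weight VJP accounts for a fraction f of one step's run time and is sped up by a factor x, the step takes (1 - f) + f / x of its original time. The snippet below is a minimal illustration under that assumption; the function name overall_speedup and its parameter names are hypothetical and are not part of estimate.py.

```python
def overall_speedup(weight_vjp_fraction: float, weight_vjp_speedup: float) -> float:
    """Speed-up of one full step when only the weight VJP gets faster.

    weight_vjp_fraction: share f of the step's run time spent in the weight VJP (Q1).
    weight_vjp_speedup: assumed speed-up x of the weight VJP itself (Q2).
    """
    if not 0.0 <= weight_vjp_fraction <= 1.0:
        raise ValueError("weight_vjp_fraction must lie in [0, 1]")
    if weight_vjp_speedup <= 0.0:
        raise ValueError("weight_vjp_speedup must be positive")
    # Time of the accelerated step, relative to the original step.
    relative_time = (1.0 - weight_vjp_fraction) + weight_vjp_fraction / weight_vjp_speedup
    return 1.0 / relative_time


# Illustrative numbers only: 40% of the step in the weight VJP, made 2x faster.
print(overall_speedup(0.4, 2.0))
```

The same formula applies to Q3 if f is read as the share of memory attributed to the weight VJP and x as the factor by which that share shrinks, provided the memory contributions are additive.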

experiments.util.estimate.estimate_mem_savings(model_fn: Callable[[], Module], loss_fn: Callable[[], Module], x: Tensor, y: Tensor, targets: List[Dict[str, Tensor]] | None, architecture: str, dev: device, case: List[str], results_dir: str, return_val: bool = False)

Print an estimate of the overall memory savings that result from reducing the memory used by the weight VJP.

Parameters:
  • model_fn – Function that sets up the neural network.

  • loss_fn – Function that sets up the loss function.

  • x – Input to the model.

  • y – Labels of the input.

  • targets – Targets, used when the model is a detection model.

  • architecture – linear or conv

  • dev – Device to run the computation on.

  • case – List of strings indicating which gradients to compute.

  • results_dir – See args.results_dir

  • return_val – Whether to return the value instead of saving it (default: False, i.e. save).

Returns:

The required estimate (only returned if return_val is True)

Return type:

result

experiments.util.estimate.estimate_speedup(model_fn: Callable[[], Module], loss_fn: Callable[[], Module], x: Tensor, y: Tensor, targets: List[Dict[str, Tensor]] | None, architecture: str, dev: device, case: List[str], results_dir: str, return_val: bool = False)

Save an estimate of total training speed-up caused by a weight VJP speed-up.

Parameters:
  • model_fn – Function that sets up the neural network.

  • loss_fn – Function that sets up the loss function.

  • x – Input to the model.

  • y – Labels of the input.

  • targets – Targets, used when the model is a detection model.

  • architecture – linear or conv

  • dev – Device to run the computation on.

  • case – List of strings indicating which gradients to compute.

  • results_dir – See args.results_dir

  • return_val – Whether to return the value instead of saving it (default: False, i.e. save).

Returns:

The required estimate (only returned if return_val is True)

Return type:

result

experiments.util.estimate.parse_case(case: List[str] | None) Dict[str, bool]

Small helper function that converts cases into keyword arguments for the measurements.

Parameters:

case (Optional[List[str]]) – List of all cases

Returns:

Dictionary whose keys are the allowed cases present in the input (excluding those that start with no_).

Return type:

Dict[str, bool]
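
The described behavior can be illustrated with a small stand-in. This is a sketch, not the actual implementation: the contents of ALLOWED_CASES are hypothetical (only grad_input and no_grad_embed_weights appear in this documentation, and grad_norm_weights is an invented instance of the grad_norm_* pattern), and mapping every kept key to True is an assumption.

```python
from typing import Dict, List, Optional

# Hypothetical list of allowed cases, for illustration only.
ALLOWED_CASES = ["grad_input", "grad_norm_weights", "no_grad_embed_weights"]


def parse_case_sketch(case: Optional[List[str]]) -> Dict[str, bool]:
    """Keep allowed cases from the input that do not start with 'no_'."""
    if case is None:
        return {}
    return {
        name: True  # assumption: present cases are switched on
        for name in case
        if name in ALLOWED_CASES and not name.startswith("no_")
    }


print(parse_case_sketch(["grad_input", "no_grad_embed_weights"]))
```

The resulting dictionary could then be unpacked with ** into a measurement call, which is presumably why the helper returns keyword arguments rather than a list.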

experiments.util.estimate.skip_case_check(args: Namespace) bool

Decide whether to skip the case:

  1. when case has grad_norm_* but model does not have any normalization layers

  2. when case has no_grad_embed_weights but no grad_input: there is a backward error (no input requires_grad)

Parameters:

args (argparse.Namespace) – args

Returns:

Whether to skip or not

Return type:

bool
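
The two skip rules can be sketched as follows. This is an illustration under stated assumptions rather than the real function: it assumes the cases are stored in args.case, and it takes a hypothetical extra argument has_norm_layers in place of inspecting the model, whereas the documented function receives only args.

```python
from argparse import Namespace


def skip_case_check_sketch(args: Namespace, has_norm_layers: bool) -> bool:
    """Return True if the case described by args should be skipped."""
    case = args.case or []  # assumption: cases live in args.case
    # Rule 1: gradients for normalization layers requested, but the model has none.
    wants_norm_grads = any(name.startswith("grad_norm_") for name in case)
    if wants_norm_grads and not has_norm_layers:
        return True
    # Rule 2: embedding weights frozen and no input gradient requested,
    # so the backward pass would find no input that requires grad.
    if "no_grad_embed_weights" in case and "grad_input" not in case:
        return True
    return False


args = Namespace(case=["no_grad_embed_weights"])
print(skip_case_check_sketch(args, has_norm_layers=True))
```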