Weight decay

This subpackage contains weight decay modules.

Classes:

  • DirectWeightDecay

    Directly applies weight decay to parameters.

  • RelativeWeightDecay

    Weight decay relative to the mean absolute value of the update, gradient, or parameters, depending on the norm_input argument.

  • WeightDecay

    Weight decay.

Functions:

  • decay_weights_

    Directly decays weights in-place.

DirectWeightDecay

Bases: torchzero.core.module.Module

Directly applies weight decay to parameters.

Parameters:

  • weight_decay (float) –

    weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

Source code in torchzero/modules/weight_decay/weight_decay.py
class DirectWeightDecay(Module):
    """Directly applies weight decay to parameters.

    Args:
        weight_decay (float): weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
    """
    def __init__(self, weight_decay: float, ord: int = 2,):
        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
        ord = self.defaults['ord']

        decay_weights_(var.params, weight_decay, ord)
        return var
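
Unlike WeightDecay below, DirectWeightDecay shrinks the parameters themselves inside step and leaves the update passing through the chain untouched. A minimal usage sketch, assuming the module is exposed as tz.m.DirectWeightDecay (matching the tz.m.* naming used elsewhere on this page) and that tz is the torchzero package:

```python
import torch.nn as nn
import torchzero as tz  # assumed import behind the ``tz`` alias used in the examples

model = nn.Linear(10, 2)  # any torch.nn.Module

# hypothetical placement: the parameters are decayed in-place on every step,
# independently of the Adam update produced by the rest of the chain
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
    tz.m.DirectWeightDecay(1e-4),
)
```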

RelativeWeightDecay

Bases: torchzero.core.transform.Transform

Weight decay relative to the mean absolute value of the update, gradient, or parameters, depending on the norm_input argument.

Parameters:

  • weight_decay (float, default: 0.1 ) –

    relative weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

  • norm_input (str, default: 'update' ) –

    determines what the weight decay should be relative to: "update", "grad" or "params". Defaults to "update".

  • metric (Metrics, default: 'mad' ) –

    metric (norm, etc.) that the weight decay should be relative to. Defaults to 'mad' (mean absolute deviation).

  • target (Literal, default: 'update' ) –

    what to set on var. Defaults to 'update'.

Examples:

Adam with non-decoupled relative weight decay

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)
```

Adam with decoupled relative weight decay

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.LR(1e-3)
)
```

Source code in torchzero/modules/weight_decay/weight_decay.py
class RelativeWeightDecay(Transform):
    """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of ``norm_input`` argument.

    Args:
        weight_decay (float): relative weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
        norm_input (str, optional):
            determines what should weight decay be relative to. "update", "grad" or "params".
            Defaults to "update".
        metric (Ords, optional):
            metric (norm, etc) that weight decay should be relative to.
            defaults to 'mad' (mean absolute deviation).
        target (Target, optional): what to set on var. Defaults to 'update'.

    ### Examples:

    Adam with non-decoupled relative weight decay
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.RelativeWeightDecay(1e-1),
        tz.m.Adam(),
        tz.m.LR(1e-3)
    )
    ```

    Adam with decoupled relative weight decay
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.RelativeWeightDecay(1e-1),
        tz.m.LR(1e-3)
    )
    ```
    """
    def __init__(
        self,
        weight_decay: float = 0.1,
        ord: int  = 2,
        norm_input: Literal["update", "grad", "params"] = "update",
        metric: Metrics = 'mad',
        target: Target = "update",
    ):
        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric)
        super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)

        ord = settings[0]['ord']
        norm_input = settings[0]['norm_input']
        metric = settings[0]['metric']

        if norm_input == 'update': src = TensorList(tensors)
        elif norm_input == 'grad':
            assert grads is not None
            src = TensorList(grads)
        elif norm_input == 'params':
            src = TensorList(params)
        else:
            raise ValueError(norm_input)

        norm = src.global_metric(metric)
        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * norm, ord)
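
For intuition, here is a rough plain-PyTorch sketch of what this transform computes with ord=2 and the default metric='mad', assuming global_metric('mad') is the mean absolute value over all elements and that weight_decay_ adds the scaled penalty gradient weight_decay * p to each update tensor; this is a sketch, not the library's exact implementation:

```python
import torch

def relative_l2_decay_sketch(updates, params, weight_decay=0.1):
    # scale factor: mean absolute value over every element of every update tensor
    total = sum(u.abs().sum() for u in updates)
    numel = sum(u.numel() for u in updates)
    scale = total / numel
    # ord=2 penalty gradient, scaled so the decay is relative to the update magnitude
    return [u + weight_decay * scale * p for u, p in zip(updates, params)]
```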

WeightDecay

Bases: torchzero.core.transform.Transform

Weight decay.

Parameters:

  • weight_decay (float) –

    weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

  • target (Literal, default: 'update' ) –

    what to set on var. Defaults to 'update'.

Examples:

Adam with non-decoupled weight decay

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.WeightDecay(1e-3),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)
```

Adam with decoupled weight decay that still scales with learning rate

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.WeightDecay(1e-3),
    tz.m.LR(1e-3)
)
```

Adam with fully decoupled weight decay that doesn't scale with learning rate

```python
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
    tz.m.WeightDecay(1e-6)
)
```

Source code in torchzero/modules/weight_decay/weight_decay.py
class WeightDecay(Transform):
    """Weight decay.

    Args:
        weight_decay (float): weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
        target (Target, optional): what to set on var. Defaults to 'update'.

    ### Examples:

    Adam with non-decoupled weight decay
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.WeightDecay(1e-3),
        tz.m.Adam(),
        tz.m.LR(1e-3)
    )
    ```

    Adam with decoupled weight decay that still scales with learning rate
    ```python

    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.WeightDecay(1e-3),
        tz.m.LR(1e-3)
    )
    ```

    Adam with fully decoupled weight decay that doesn't scale with learning rate
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-3),
        tz.m.WeightDecay(1e-6)
    )
    ```

    """
    def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):

        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)
        ord = settings[0]['ord']

        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)
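
weight_decay_ itself is not reproduced on this page. As a hedged sketch of the convention it presumably follows, ord=2 adds the L2 penalty gradient weight_decay * p to the update and ord=1 adds the L1 penalty gradient weight_decay * sign(p); the real function works in-place (hence the trailing underscore), while this sketch is written functionally for clarity:

```python
import torch

def weight_decay_sketch(update: torch.Tensor, param: torch.Tensor,
                        weight_decay: float, ord: int = 2) -> torch.Tensor:
    # ord=2: gradient of (weight_decay / 2) * ||p||_2^2 is weight_decay * p
    if ord == 2:
        return update + weight_decay * param
    # ord=1: gradient of weight_decay * ||p||_1 is weight_decay * sign(p)
    if ord == 1:
        return update + weight_decay * param.sign()
    raise ValueError(f"unsupported ord: {ord}")
```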

decay_weights_

decay_weights_(params: Iterable[Tensor], weight_decay: float | NumberList, ord: int = 2)

Directly decays weights in-place.

Source code in torchzero/modules/weight_decay/weight_decay.py
@torch.no_grad
def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberList, ord:int=2):
    """directly decays weights in-place"""
    params = TensorList(params)
    weight_decay_(params, params, -weight_decay, ord)
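
Because decay_weights_ passes params as both the tensors to modify and the penalty source, with a negated weight_decay, the ord=2 case reduces to the familiar multiplicative shrinkage p -> (1 - weight_decay) * p. A sketch under that assumption:

```python
import torch

@torch.no_grad()
def decay_weights_sketch(params, weight_decay: float):
    # ord=2 case: subtracting weight_decay * p from p equals scaling p by (1 - weight_decay)
    for p in params:
        p.mul_(1.0 - weight_decay)
```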