Weight decay

This subpackage contains weight decay modules.

Classes:

  • DirectWeightDecay

    Directly applies weight decay to parameters.

  • RandomReinitialize

    On each step, with probability p_reinit, triggers reinitialization, whereby each weight is reset to its initial value with probability p_weights.

  • RelativeWeightDecay

    Weight decay relative to the mean absolute value of the update, gradient or parameters, depending on the value of the norm_input argument.

  • WeightDecay

    Weight decay.

Functions:

  • decay_weights_

    Directly decays weights in-place.

DirectWeightDecay

Bases: torchzero.core.module.Module

Directly applies weight decay to parameters.

Parameters:

  • weight_decay (float) –

    weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
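For intuition, here is a minimal plain-PyTorch sketch of decaying parameters directly. It assumes the decay term is p for ord=2 and sign(p) for ord=1, and it rejects other orders. direct_weight_decay_ is a hypothetical helper, not part of torchzero.

```python
import torch


@torch.no_grad()
def direct_weight_decay_(params, weight_decay: float, ord: int = 2):
    """Hypothetical sketch: shrink each parameter in-place (not torchzero's implementation)."""
    for p in params:
        if ord == 2:
            # L2: p <- p - weight_decay * p, i.e. multiply by (1 - weight_decay)
            p.mul_(1 - weight_decay)
        elif ord == 1:
            # L1: p <- p - weight_decay * sign(p), a constant-size pull toward zero
            p.sub_(weight_decay * p.sign())
        else:
            raise ValueError(f"this sketch only handles ord 1 or 2, got {ord}")


w = torch.tensor([1.0, -2.0])
direct_weight_decay_([w], 0.1, ord=2)
# w is now approximately [0.9, -1.8]
print(w)
```

Because the parameters are modified directly, the decay does not pass through any other module in the chain.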

Source code in torchzero/modules/weight_decay/weight_decay.py
class DirectWeightDecay(Module):
    """Directly applies weight decay to parameters.

    Args:
        weight_decay (float): weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
    """
    def __init__(self, weight_decay: float, ord: int = 2):
        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective):
        weight_decay = self.get_settings(objective.params, 'weight_decay', cls=NumberList)
        ord = self.defaults['ord']

        decay_weights_(objective.params, weight_decay, ord)
        return objective

RandomReinitialize

Bases: torchzero.core.module.Module

On each step, with probability p_reinit, triggers reinitialization, whereby each weight is independently reset to its initial value with probability p_weights.

This modifies the parameters directly. Place it as the first module.
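The reinitialization step described above can be sketched in plain PyTorch as follows. This is a simplified stand-in rather than torchzero's implementation. maybe_reinitialize_ is a hypothetical name, initial is assumed to hold copies of the parameters taken earlier, and one torch.Generator drives both the per-step trigger and the per-weight mask.

```python
import torch


@torch.no_grad()
def maybe_reinitialize_(params, initial, p_reinit=0.01, p_weights=0.1, generator=None):
    """Hypothetical sketch of one RandomReinitialize step; returns True if it fired."""
    device = params[0].device
    # one draw per step decides whether reinitialization happens at all
    if torch.rand(1, generator=generator, device=device).item() >= p_reinit:
        return False
    for p, p0 in zip(params, initial):
        # each weight is selected independently with probability p_weights
        mask = torch.rand(p.shape, generator=generator, device=p.device) < p_weights
        # selected weights are reset to their stored initial values
        p[mask] = p0[mask]
    return True


gen = torch.Generator().manual_seed(0)
weights = [torch.zeros(10_000)]
initial = [torch.ones(10_000)]
fired = maybe_reinitialize_(weights, initial, p_reinit=1.0, p_weights=0.1, generator=gen)
# fired is True and roughly 10% of the weights are now back at 1.0
print(fired, weights[0].mean().item())
```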

Parameters:

  • p_reinit (float, default: 0.01 ) –

    probability to trigger reinitialization on each step. Defaults to 0.01.

  • p_weights (float, default: 0.1 ) –

    probability for each weight to be set to initial value when reinitialization is triggered. Defaults to 0.1.

  • store_every (int | None, default: None ) –

    if set, stores new initial values every store_every steps. Defaults to None.

  • beta (float, default: 0 ) –

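    whenever store_every is triggered, the stored initial values are linearly interpolated toward the current parameters as p_init = beta * p_init + (1 - beta) * params, so beta=0 simply copies the parameters. With store_every=1, setting beta close to 1 (e.g. 0.999) makes reinitialization reset weights to a slow exponential moving average of the parameters. Defaults to 0.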

  • reset (bool, default: False ) –

    whether to reset states of other modules on reinitialization. Defaults to False.

  • seed (int | None, default: None ) –

    random seed.
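The store_every and beta options refresh the stored initial values. The sketch below mirrors the lerp_ call in the source, so each refresh computes p_init = beta * p_init + (1 - beta) * params. maybe_store_initial_ and its explicit step argument are hypothetical.

```python
import torch


@torch.no_grad()
def maybe_store_initial_(initial, params, step, store_every=None, beta=0.0):
    """Hypothetical sketch: refresh stored initial values every store_every steps."""
    if store_every is None or step % store_every != 0:
        return
    for p0, p in zip(initial, params):
        # p0 <- p0 + (1 - beta) * (p - p0) = beta * p0 + (1 - beta) * p
        p0.lerp_(p, 1 - beta)


initial = [torch.zeros(3)]
params = [torch.ones(3)]
for step in range(3):
    # with store_every=1 and beta=0.5, initial moves halfway toward params each step
    maybe_store_initial_(initial, params, step, store_every=1, beta=0.5)
# each entry is 0.875 after three refreshes
print(initial[0])
```

With the default beta=0, a refresh simply overwrites the stored values with the current parameters.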

Source code in torchzero/modules/weight_decay/reinit.py
class RandomReinitialize(Module):
    """On each step with probability ``p_reinit`` trigger reinitialization,
    whereby ``p_weights`` weights are reset to their initial values.

    This modifies the parameters directly. Place it as the first module.

    Args:
        p_reinit (float, optional): probability to trigger reinitialization on each step. Defaults to 0.01.
        p_weights (float, optional): probability for each weight to be set to initial value when reinitialization is triggered. Defaults to 0.1.
        store_every (int | None, optional): if set, stores new initial values every this many steps. Defaults to None.
        beta (float, optional):
            whenever ``store_every`` is triggered, uses linear interpolation with this beta.
            If ``store_every=1``, this can be set to some value close to 1 such as 0.999
            to reinitialize to slow parameter EMA. Defaults to 0.
        reset (bool, optional): whether to reset states of other modules on reinitialization. Defaults to False.
        seed (int | None, optional): random seed.
    """

    def __init__(
        self,
        p_reinit: float = 0.01,
        p_weights: float = 0.1,
        store_every: int | None = None,
        beta: float = 0,
        reset: bool = False,
        seed: int | None = None,
    ):
        defaults = dict(p_weights=p_weights, p_reinit=p_reinit, store_every=store_every, beta=beta, reset=reset, seed=seed)
        super().__init__(defaults)

    def update(self, objective):
        # this stores initial values to per-parameter states
        p_init = self.get_state(objective.params, "p_init", init="params", cls=TensorList)

        # store new params every store_every steps
        step = self.global_state.get("step", 0)
        self.global_state["step"] = step + 1

        store_every = self.defaults["store_every"]
        if (store_every is not None and step % store_every == 0):
            beta = self.get_settings(objective.params, "beta", cls=NumberList)
            p_init.lerp_(objective.params, weight=(1 - beta))

    @torch.no_grad
    def apply(self, objective):
        p_reinit = self.defaults["p_reinit"]
        device = objective.params[0].device
        generator = self.get_generator(device, self.defaults["seed"])

        # determine whether to trigger reinitialization
        reinitialize = torch.rand(1, generator=generator, device=device) < p_reinit

        # reinitialize
        if reinitialize:
            params = TensorList(objective.params)
            p_init = self.get_state(params, "p_init", init=params)


            # mask with p_weights entries being True
            p_weights = self.get_settings(params, "p_weights")
            mask = params.bernoulli_like(p_weights, generator=generator).as_bool()

            # set weights at mask to their initialization
            params.masked_set_(mask, p_init)

            # reset
            if self.defaults["reset"]:
                objective.post_step_hooks.append(partial(_reset_except_self, self=self))

        return objective

RelativeWeightDecay

Bases: torchzero.core.transform.TensorTransform

Weight decay relative to the mean absolute value of the update, gradient or parameters, depending on the value of the norm_input argument.

Parameters:

  • weight_decay (float, default: 0.1 ) –

    relative weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

  • norm_input (str, default: 'update' ) –

    determines what the weight decay is relative to: "update", "grad" or "params". Defaults to "update".

  • metric (Ords, default: 'mad' ) –

    metric (norm, etc.) that the weight decay is relative to. Defaults to 'mad' (mean absolute deviation).

Examples:

Adam with non-decoupled relative weight decay

opt = tz.Optimizer(
    model.parameters(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)

Adam with decoupled relative weight decay

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.LR(1e-3)
)
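A minimal sketch of what the module computes for ord=2, assuming the metric is the mean absolute value over every entry of the update, as the summary describes. The exact definition of torchzero's 'mad' metric may differ. relative_weight_decay is a hypothetical name. The scale is computed once across all tensors, mirroring the single global_metric call in the source.

```python
import torch


def relative_weight_decay(updates, params, weight_decay=0.1):
    """Hypothetical sketch: L2 weight decay scaled by the global mean absolute update."""
    # one scale shared by all tensors: mean of |update| over every entry
    total = sum(u.abs().sum() for u in updates)
    count = sum(u.numel() for u in updates)
    scale = total / count
    # add the scaled decay term (params for ord=2) to each update tensor
    return [u + weight_decay * scale * p for u, p in zip(updates, params)]


updates = [torch.tensor([1.0, -3.0])]
params = [torch.tensor([2.0, 4.0])]
# mean |update| is 2, so the added term is 0.1 * 2 * params, giving about [1.4, -2.2]
print(relative_weight_decay(updates, params))
```

Because the scale follows the magnitude of the chosen input, the added decay term grows and shrinks with it instead of being a fixed multiple of the parameters.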

Source code in torchzero/modules/weight_decay/weight_decay.py
class RelativeWeightDecay(TensorTransform):
    """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of ``norm_input`` argument.

    Args:
        weight_decay (float): relative weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
        norm_input (str, optional):
            determines what should weight decay be relative to. "update", "grad" or "params".
            Defaults to "update".
        metric (Ords, optional):
            metric (norm, etc) that weight decay should be relative to.
            defaults to 'mad' (mean absolute deviation).

    ### Examples:

    Adam with non-decoupled relative weight decay
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.RelativeWeightDecay(1e-1),
        tz.m.Adam(),
        tz.m.LR(1e-3)
    )
    ```

    Adam with decoupled relative weight decay
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.RelativeWeightDecay(1e-1),
        tz.m.LR(1e-3)
    )
    ```
    """
    def __init__(
        self,
        weight_decay: float = 0.1,
        ord: int = 2,
        norm_input: Literal["update", "grad", "params"] = "update",
        metric: Metrics = 'mad',
    ):
        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric)
        super().__init__(defaults, uses_grad=norm_input == 'grad')

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)

        ord = settings[0]['ord']
        norm_input = settings[0]['norm_input']
        metric = settings[0]['metric']

        if norm_input == 'update': src = TensorList(tensors)
        elif norm_input == 'grad':
            assert grads is not None
            src = TensorList(grads)
        elif norm_input == 'params':
            src = TensorList(params)
        else:
            raise ValueError(norm_input)

        norm = src.global_metric(metric)
        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * norm, ord)

WeightDecay

Bases: torchzero.core.transform.TensorTransform

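Weight decay. Adds a decay term computed from the parameters, scaled by weight_decay, to the update. Its position in the module chain determines which later modules (such as Adam or LR) also transform the decay term.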

Parameters:

  • weight_decay (float) –

    weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

Examples:

Adam with non-decoupled weight decay

opt = tz.Optimizer(
    model.parameters(),
    tz.m.WeightDecay(1e-3),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)

Adam with decoupled weight decay that still scales with the learning rate

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.WeightDecay(1e-3),
    tz.m.LR(1e-3)
)

Adam with fully decoupled weight decay that doesn't scale with the learning rate

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
    tz.m.WeightDecay(1e-6)
)
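To see how the three placements differ, the sketch below follows one scalar parameter through a single step of each chain. It assumes the final update is subtracted from the parameters. It replaces Adam with sign(update), which is approximately what bias-corrected Adam produces on its first step (m_hat / (sqrt(v_hat) + eps) = g / (|g| + eps)). All function names are hypothetical.

```python
import torch


def adam_first_step(u):
    # stand-in for Adam: on its first step bias-corrected Adam gives u / (|u| + eps), i.e. about sign(u)
    return torch.sign(u)


def non_decoupled(p, g, lr, wd):
    # WeightDecay -> Adam -> LR: the decay term is normalized together with the gradient
    return p - lr * adam_first_step(g + wd * p)


def decoupled(p, g, lr, wd):
    # Adam -> WeightDecay -> LR: the decay skips Adam but is still multiplied by lr
    return p - lr * (adam_first_step(g) + wd * p)


def fully_decoupled(p, g, lr, wd):
    # Adam -> LR -> WeightDecay: the decay is added after the learning rate is applied
    return p - (lr * adam_first_step(g) + wd * p)


p, g = torch.tensor(1.0), torch.tensor(0.5)
for chain in (non_decoupled, decoupled, fully_decoupled):
    # approximately 0.9, 0.899 and 0.89 respectively
    print(chain.__name__, chain(p, g, lr=0.1, wd=0.01).item())
```

In the third chain, weight_decay is not multiplied by the learning rate. It therefore needs to be roughly lr times smaller to give a comparable per-step decay, which is consistent with the 1e-6 (1e-3 × 1e-3) used in the last example. In the first chain the decay is absorbed by the normalization here, since sign(0.51) equals sign(0.5).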

Source code in torchzero/modules/weight_decay/weight_decay.py
class WeightDecay(TensorTransform):
    """Weight decay.

    Args:
        weight_decay (float): weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

    ### Examples:

    Adam with non-decoupled weight decay
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.WeightDecay(1e-3),
        tz.m.Adam(),
        tz.m.LR(1e-3)
    )
    ```

    Adam with decoupled weight decay that still scales with learning rate
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.WeightDecay(1e-3),
        tz.m.LR(1e-3)
    )
    ```

    Adam with fully decoupled weight decay that doesn't scale with learning rate
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-3),
        tz.m.WeightDecay(1e-6)
    )
    ```

    """
    def __init__(self, weight_decay: float, ord: int = 2):

        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)
        ord = settings[0]['ord']

        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)

decay_weights_

decay_weights_(params: Iterable[Tensor], weight_decay: float | NumberList, ord: int = 2)

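Directly decays weights in-place by calling weight_decay_ with the parameters as both the update and the parameters, and with weight_decay negated.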

Source code in torchzero/modules/weight_decay/weight_decay.py
@torch.no_grad
def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberList, ord:int=2):
    """directly decays weights in-place"""
    params = TensorList(params)
    weight_decay_(params, params, -weight_decay, ord)