Momentum

This subpackage contains momentum methods and exponential moving averages.

Classes:

  • Averaging

    Average of past history_size updates.

  • Cautious

    Negates the update for parameters where the update and gradient signs are inconsistent.

  • EMA

    Maintains an exponential moving average of the update.

  • HeavyBall

    Polyak's momentum (heavy-ball method).

  • IntermoduleCautious

    Negates the update of the main module where its sign doesn't match the output of the compare module.

  • MedianAveraging

    Median of past history_size updates.

  • NAG

    Nesterov accelerated gradient method (nesterov momentum).

  • ScaleByGradCosineSimilarity

    Multiplies the update by its cosine similarity with the gradient.

  • ScaleModulesByCosineSimilarity

    Scales the output of the main module by its cosine similarity to the output of the compare module.

  • UpdateGradientSignConsistency

    Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

  • WeightedAveraging

    Weighted average of past len(weights) updates.

Averaging

Bases: torchzero.core.transform.TensorwiseTransform

Average of past history_size updates.

Parameters:

  • history_size (int) –

    Number of past updates to average

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/momentum/averaging.py
class Averaging(TensorwiseTransform):
    """Average of past ``history_size`` updates.

    Args:
        history_size (int): Number of past updates to average
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, history_size: int, target: Target = 'update'):
        defaults = dict(history_size=history_size)
        super().__init__(uses_grad=False, defaults=defaults, target=target)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        history_size = setting['history_size']
        if 'history' not in state:
            state['history'] = deque(maxlen=history_size)
            state['average'] = torch.zeros_like(tensor)

        history = state['history']; average = state['average']
        if len(history) == history_size: average -= history[0]
        history.append(tensor)
        average += tensor

        return average / len(history)
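
Example (a minimal sketch, not from the library docs): averaging the last 10 updates before applying a learning rate, assuming Averaging is exposed under tz.m like the other modules shown on this page and that model is a hypothetical torch module.

opt = tz.Modular(
    model.parameters(),  # hypothetical model
    tz.m.Averaging(history_size=10),
    tz.m.LR(1e-2)
)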

Cautious

Bases: torchzero.core.transform.Transform

Negates the update for parameters where the update and gradient signs are inconsistent. Optionally normalizes the update by the number of parameters that are not masked. This is meant to be used after momentum-based modules.

Parameters:

  • normalize (bool, default: False ) –

    Renormalize the update after masking. Only has an effect when mode is 'zero'. Defaults to False.

  • eps (float, default: 1e-06 ) –

    epsilon for normalization. Defaults to 1e-6.

  • mode (str, default: 'zero' ) –

    what to do with updates with inconsistent signs:
    - "zero" - set them to zero (as in paper)
    - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
    - "backtrack" - negate them

Examples:

Cautious Adam

opt = tz.Modular(
    bench.parameters(),
    tz.m.Adam(),
    tz.m.Cautious(),
    tz.m.LR(1e-2)
)
References

Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu

Source code in torchzero/modules/momentum/cautious.py
class Cautious(Transform):
    """Negates update for parameters where update and gradient sign is inconsistent.
    Optionally normalizes the update by the number of parameters that are not masked.
    This is meant to be used after any momentum-based modules.

    Args:
        normalize (bool, optional):
            renormalize update after masking.
            only has effect when mode is 'zero'. Defaults to False.
        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
        mode (str, optional):
            what to do with updates with inconsistent signs.
            - "zero" - set them to zero (as in paper)
            - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
            - "backtrack" - negate them

    ## Examples:

    Cautious Adam

    ```python
    opt = tz.Modular(
        bench.parameters(),
        tz.m.Adam(),
        tz.m.Cautious(),
        tz.m.LR(1e-2)
    )
    ```

    References:
        Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
    """

    def __init__(
        self,
        normalize=False,
        eps=1e-6,
        mode: Literal["zero", "grad", "backtrack"] = "zero",
    ):
        defaults = dict(normalize=normalize, eps=eps, mode=mode)
        super().__init__(defaults, uses_grad=True)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
        return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)

EMA

Bases: torchzero.core.transform.Transform

Maintains an exponential moving average of the update.

Parameters:

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • dampening (float, default: 0 ) –

    momentum dampening. Defaults to 0.

  • debiased (bool, default: False ) –

    whether to debias the EMA like in Adam. Defaults to False.

  • lerp (bool, default: True ) –

    whether to use linear interpolation. Defaults to True.

  • ema_init (str, default: 'zeros' ) –

    initial values for the EMA, "zeros" or "update".

  • target (Literal, default: 'update' ) –

    target to apply EMA to. Defaults to 'update'.

Source code in torchzero/modules/momentum/momentum.py
class EMA(Transform):
    """Maintains an exponential moving average of update.

    Args:
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        dampening (float, optional): momentum dampening. Defaults to 0.
        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
        lerp (bool, optional): whether to use linear interpolation. Defaults to True.
        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
        target (Target, optional): target to apply EMA to. Defaults to 'update'.
    """
    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
        defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])

        exp_avg = unpack_states(states, tensors, 'exp_avg',
                                init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
        momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)

        exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)

        if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
        else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
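
Example (a minimal sketch, assuming tz.m.EMA and the Modular chaining shown elsewhere on this page; model is a hypothetical torch module): a debiased EMA of the update followed by a learning rate.

opt = tz.Modular(
    model.parameters(),  # hypothetical model
    tz.m.EMA(momentum=0.9, debiased=True),
    tz.m.LR(1e-2)
)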

HeavyBall

Bases: torchzero.modules.momentum.momentum.EMA

Polyak's momentum (heavy-ball method).

Parameters:

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • dampening (float, default: 0 ) –

    momentum dampening. Defaults to 0.

  • debiased (bool, default: False ) –

    whether to debias the EMA like in Adam. Defaults to False.

  • lerp (bool, default: False ) –

    whether to use linear interpolation; if True, this becomes an exponential moving average. Defaults to False.

  • ema_init (str, default: 'update' ) –

    initial values for the EMA, "zeros" or "update".

  • target (Literal, default: 'update' ) –

    target to apply EMA to. Defaults to 'update'.

Source code in torchzero/modules/momentum/momentum.py
class HeavyBall(EMA):
    """Polyak's momentum (heavy-ball method).

    Args:
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        dampening (float, optional): momentum dampening. Defaults to 0.
        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
        lerp (bool, optional):
            whether to use linear interpolation, if True, this becomes exponential moving average. Defaults to False.
        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
        target (Target, optional): target to apply EMA to. Defaults to 'update'.
    """
    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update', target: Target = 'update'):
        super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init, target=target)
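
Example (a minimal sketch under the same assumptions as the other examples on this page): SGD with heavy-ball momentum.

opt = tz.Modular(
    model.parameters(),  # hypothetical model
    tz.m.HeavyBall(momentum=0.9),
    tz.m.LR(1e-2)
)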

IntermoduleCautious

Bases: torchzero.core.module.Module

Negates the update of the main module where its sign doesn't match the output of the compare module.

Parameters:

  • main (Chainable) –

    main module or sequence of modules whose update will be cautioned.

  • compare (Chainable) –

    module or sequence of modules to compare the sign to.

  • normalize (bool, default: False ) –

    renormalize update after masking. Defaults to False.

  • eps (float, default: 1e-06 ) –

    epsilon for normalization. Defaults to 1e-6.

  • mode (str, default: 'zero' ) –

    what to do with updates with inconsistent signs:
    - "zero" - set them to zero (as in paper)
    - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
    - "backtrack" - negate them

Source code in torchzero/modules/momentum/cautious.py
class IntermoduleCautious(Module):
    """Negaties update on :code:`main` module where it's sign doesn't match with output of :code:`compare` module.

    Args:
        main (Chainable): main module or sequence of modules whose update will be cautioned.
        compare (Chainable): modules or sequence of modules to compare the sign to.
        normalize (bool, optional):
            renormalize update after masking. Defaults to False.
        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
        mode (str, optional):
            what to do with updates with inconsistent signs.
            - "zero" - set them to zero (as in paper)
            - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
            - "backtrack" - negate them
    """
    def __init__(
        self,
        main: Chainable,
        compare: Chainable,
        normalize=False,
        eps=1e-6,
        mode: Literal["zero", "grad", "backtrack"] = "zero",
    ):

        defaults = dict(normalize=normalize, eps=eps, mode=mode)
        super().__init__(defaults)

        self.set_child('main', main)
        self.set_child('compare', compare)

    @torch.no_grad
    def step(self, var):
        main = self.children['main']
        compare = self.children['compare']

        main_var = main.step(var.clone(clone_update=True))
        var.update_attrs_from_clone_(main_var)

        compare_var = compare.step(var.clone(clone_update=True))
        var.update_attrs_from_clone_(compare_var)

        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.defaults)
        var.update = cautious_(
            TensorList(main_var.get_update()),
            TensorList(compare_var.get_update()),
            normalize=normalize,
            mode=mode,
            eps=eps,
        )

        return var
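
Example (a minimal sketch, assuming both modules can be passed as Chainable arguments as the signature above indicates; model is a hypothetical torch module): caution Adam's update against a heavy-ball update.

opt = tz.Modular(
    model.parameters(),  # hypothetical model
    tz.m.IntermoduleCautious(
        main=tz.m.Adam(),
        compare=tz.m.HeavyBall(),
    ),
    tz.m.LR(1e-2)
)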

MedianAveraging

Bases: torchzero.core.transform.TensorwiseTransform

Median of past history_size updates.

Parameters:

  • history_size (int) –

    Number of past updates to take the median over

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/momentum/averaging.py
class MedianAveraging(TensorwiseTransform):
    """Median of past ``history_size`` updates.

    Args:
        history_size (int): Number of past updates to take the median over
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, history_size: int, target: Target = 'update'):
        defaults = dict(history_size = history_size)
        super().__init__(uses_grad=False, defaults=defaults, target=target)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        history_size = setting['history_size']

        if 'history' not in state:
            state['history'] = deque(maxlen=history_size)

        history = state['history']
        history.append(tensor)

        stacked = torch.stack(tuple(history), 0)
        return torch.quantile(stacked, 0.5, dim = 0)
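
Example (a minimal sketch under the same assumptions as the other examples on this page): taking the elementwise median of the last 5 updates, which can make steps more robust to occasional outlier gradients.

opt = tz.Modular(
    model.parameters(),  # hypothetical model
    tz.m.MedianAveraging(history_size=5),
    tz.m.LR(1e-2)
)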

NAG

Bases: torchzero.core.transform.Transform

Nesterov accelerated gradient method (nesterov momentum).

Parameters:

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • dampening (float, default: 0 ) –

    momentum dampening. Defaults to 0.

  • lerp (bool, default: False ) –

    whether to use linear interpolation; if True, this becomes similar to an exponential moving average. Defaults to False.

  • target (Literal, default: 'update' ) –

    target to apply EMA to. Defaults to 'update'.

Source code in torchzero/modules/momentum/momentum.py
class NAG(Transform):
    """Nesterov accelerated gradient method (nesterov momentum).

    Args:
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        dampening (float, optional): momentum dampening. Defaults to 0.
        lerp (bool, optional):
            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
        target (Target, optional): target to apply EMA to. Defaults to 'update'.
    """
    def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False, target: Target = 'update'):
        defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
        lerp = self.settings[params[0]]['lerp']

        momentum,dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
        return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
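
Example (a minimal sketch under the same assumptions as the other examples on this page): SGD with Nesterov momentum.

opt = tz.Modular(
    model.parameters(),  # hypothetical model
    tz.m.NAG(momentum=0.9),
    tz.m.LR(1e-2)
)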

ScaleByGradCosineSimilarity

Bases: torchzero.core.transform.Transform

Multiplies the update by its cosine similarity with the gradient. If the cosine similarity is negative, the update is negated as well.

Parameters:

  • eps (float, default: 1e-06 ) –

    epsilon for division. Defaults to 1e-6.

Examples:

Scaled Adam

opt = tz.Modular(
    bench.parameters(),
    tz.m.Adam(),
    tz.m.ScaleByGradCosineSimilarity(),
    tz.m.LR(1e-2)
)

Source code in torchzero/modules/momentum/cautious.py
class ScaleByGradCosineSimilarity(Transform):
    """Multiplies the update by cosine similarity with gradient.
    If cosine similarity is negative, naturally the update will be negated as well.

    Args:
        eps (float, optional): epsilon for division. Defaults to 1e-6.

    ## Examples:

    Scaled Adam
    ```python
    opt = tz.Modular(
        bench.parameters(),
        tz.m.Adam(),
        tz.m.ScaleByGradCosineSimilarity(),
        tz.m.LR(1e-2)
    )
    ```
    """
    def __init__(
        self,
        eps: float = 1e-6,
    ):
        defaults = dict(eps=eps)
        super().__init__(defaults, uses_grad=True)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        eps = settings[0]['eps']
        tensors = TensorList(tensors)
        grads = TensorList(grads)
        cos_sim = tensors.dot(grads) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)

        return tensors.mul_(cos_sim)

ScaleModulesByCosineSimilarity

Bases: torchzero.core.module.Module

Scales the output of the main module by its cosine similarity to the output of the compare module.

Parameters:

  • main (Chainable) –

    main module or sequence of modules whose update will be scaled.

  • compare (Chainable) –

    module or sequence of modules to compare to

  • eps (float, default: 1e-06 ) –

    epsilon for division. Defaults to 1e-6.

Examples:

Adam scaled by similarity to RMSprop

opt = tz.Modular(
    bench.parameters(),
    tz.m.ScaleModulesByCosineSimilarity(
        main = tz.m.Adam(),
        compare = tz.m.RMSprop(0.999, debiased=True),
    ),
    tz.m.LR(1e-2)
)

Source code in torchzero/modules/momentum/cautious.py
class ScaleModulesByCosineSimilarity(Module):
    """Scales the output of :code:`main` module by it's cosine similarity to the output
    of :code:`compare` module.

    Args:
        main (Chainable): main module or sequence of modules whose update will be scaled.
        compare (Chainable): module or sequence of modules to compare to
        eps (float, optional): epsilon for division. Defaults to 1e-6.

    ## Examples:

    Adam scaled by similarity to RMSprop
    ```python
    opt = tz.Modular(
        bench.parameters(),
        tz.m.ScaleModulesByCosineSimilarity(
            main = tz.m.Adam(),
            compare = tz.m.RMSprop(0.999, debiased=True),
        ),
        tz.m.LR(1e-2)
    )
    ```
    """
    def __init__(
        self,
        main: Chainable,
        compare: Chainable,
        eps=1e-6,
    ):
        defaults = dict(eps=eps)
        super().__init__(defaults)

        self.set_child('main', main)
        self.set_child('compare', compare)

    @torch.no_grad
    def step(self, var):
        main = self.children['main']
        compare = self.children['compare']

        main_var = main.step(var.clone(clone_update=True))
        var.update_attrs_from_clone_(main_var)

        compare_var = compare.step(var.clone(clone_update=True))
        var.update_attrs_from_clone_(compare_var)

        m = TensorList(main_var.get_update())
        c = TensorList(compare_var.get_update())
        eps = self.defaults['eps']

        cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)

        var.update = m.mul_(cos_sim)
        return var

UpdateGradientSignConsistency

Bases: torchzero.core.transform.Transform

Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

Parameters:

  • normalize (bool, default: False ) –

    renormalize update after masking. Defaults to False.

  • eps (float, default: 1e-06 ) –

    epsilon for normalization. Defaults to 1e-6.

Source code in torchzero/modules/momentum/cautious.py
class UpdateGradientSignConsistency(Transform):
    """Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

    Args:
        normalize (bool, optional):
            renormalize update after masking. Defaults to False.
        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
    """
    def __init__(self, normalize = False, eps=1e-6):

        defaults = dict(normalize=normalize, eps=eps)
        super().__init__(defaults, uses_grad=True)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        normalize, eps = itemgetter('normalize', 'eps')(settings[0])

        mask = (TensorList(tensors).mul_(grads)).gt_(0)
        if normalize: mask = mask / mask.global_mean().clip(min = eps) # pyright: ignore[reportOperatorIssue]

        return mask
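
To illustrate the mask this module produces, here is a small self-contained sketch of the same elementwise rule (multiply, then compare to zero) using plain torch rather than the module itself:

import torch

update = torch.tensor([0.5, -1.0, 2.0])
grad = torch.tensor([1.0, 1.0, -3.0])
mask = (update * grad) > 0  # True where update and gradient signs match
print(mask.float())  # tensor([1., 0., 0.])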

WeightedAveraging

Bases: torchzero.core.transform.TensorwiseTransform

Weighted average of past len(weights) updates.

Parameters:

  • weights (Sequence[float]) –

    a sequence of weights from oldest to newest.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/momentum/averaging.py
class WeightedAveraging(TensorwiseTransform):
    """Weighted average of past ``len(weights)`` updates.

    Args:
        weights (Sequence[float]): a sequence of weights from oldest to newest.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, weights: Sequence[float] | torch.Tensor | Any, target: Target = 'update'):
        defaults = dict(weights = tolist(weights))
        super().__init__(uses_grad=False, defaults=defaults, target=target)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        weights = setting['weights']

        if 'history' not in state:
            state['history'] = deque(maxlen=len(weights))

        history = state['history']
        history.append(tensor)
        if len(history) != len(weights):
            weights = weights[-len(history):]

        average = None
        for i, (h, w) in enumerate(zip(history, weights)):
            if average is None: average = h * (w / len(history))
            else:
                if w == 0: continue
                average += h * (w / len(history))

        assert average is not None
        return average
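
Example (a minimal sketch under the same assumptions as the other examples on this page): weighting recent updates more heavily than older ones, with weights given from oldest to newest.

opt = tz.Modular(
    model.parameters(),  # hypothetical model
    tz.m.WeightedAveraging(weights=[0.1, 0.3, 0.6]),  # oldest to newest
    tz.m.LR(1e-2)
)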