
Miscellaneous

This subpackage contains miscellaneous modules that don't fit other categories, notably gradient accumulation, module switching, automatic resetting, and random restarts.

Classes:

  • Alternate

    Alternates between stepping with `modules`.

  • DivByLoss

    Divides the update by the loss times `alpha`.

  • Dropout

    Applies dropout to the update.

  • EscapeAnnealing

    If parameters stop changing, this runs a backward annealing random search

  • ExpHomotopy
  • FillLoss

    Outputs tensors filled with the loss value times `alpha`.

  • GradSign

    Copies gradient sign to update.

  • GradientAccumulation

    Uses `n` steps to accumulate gradients; after `n` gradients have been accumulated, they are passed to `modules` and the parameters are updated.

  • GraftGradToUpdate

    Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update.

  • GraftToGrad

    Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient.

  • GraftToParams

    Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than `eps`.

  • HpuEstimate

    Returns `y/||s||`, where `y` is the difference between the current and previous update (gradient), and `s` is the difference between the current and previous parameters. The returned tensors are a finite-difference approximation to the Hessian times the previous update.

  • LambdaHomotopy
  • LastAbsoluteRatio

    Outputs the ratio between the absolute values of the past two updates; the numerator is determined by the `numerator` argument.

  • LastDifference

    Outputs difference between past two updates.

  • LastGradDifference

    Outputs difference between past two gradients.

  • LastProduct

    Outputs the product of the past two updates.

  • LastRatio

    Outputs the ratio between the past two updates; the numerator is determined by the `numerator` argument.

  • LogHomotopy
  • MulByLoss

    Multiplies the update by the loss times `alpha`.

  • Multistep

    Performs `steps` inner steps with `module` on each step.

  • NegateOnLossIncrease

    Uses an extra forward pass to evaluate the loss at `parameters + update`; negates or zeroes the update if the loss increased.

  • NoiseSign

    Outputs random tensors with sign copied from the update.

  • Online

    Allows certain modules to be used for mini-batch optimization.

  • PerturbWeights

    Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.

  • Previous

    Maintains an update from `n` steps back; for example, if `n=1`, returns the previous update.

  • PrintLoss

    Prints var.get_loss().

  • PrintParams

    Prints current parameters.

  • PrintShape

    Prints shapes of the update.

  • PrintUpdate

    Prints current update.

  • RandomHvp

    Returns a Hessian-vector product with a random vector.

  • Relative

    Multiplies the update by absolute parameter values to make it relative to their magnitude; `min_value` is the minimum allowed value, to avoid getting stuck at 0.

  • SaveBest

    Saves the best parameters found so far, the ones with the lowest loss. Put this as the last module.

  • Sequential

    On each step, this sequentially steps with `modules`, `steps` times.

  • Split

    Applies `true` modules to all parameters selected by `filter`, and applies `false` modules to all other parameters.

  • SqrtHomotopy
  • SquareHomotopy
  • Switch

    After `steps` steps, switches to the next module.

  • UpdateSign

    Outputs gradient with sign copied from the update.

  • WeightDropout

    Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.

Alternate

Bases: torchzero.core.module.Module

Alternates between stepping with `modules`.

That is, the first step is performed with the first module, the second step with the second module, and so on.

Parameters:

  • steps (int | Iterable[int], default: 1 ) –

    number of steps to perform with each module. Defaults to 1.

Examples:

Alternate between Adam, SignSGD and RMSprop

.. code-block:: python

opt = tz.Modular(
    model.parameters(),
    tz.m.Alternate(
        tz.m.Adam(),
        [tz.m.SignSGD(), tz.m.Mul(0.5)],
        tz.m.RMSprop(),
    ),
    tz.m.LR(1e-3),
)
Source code in torchzero/modules/misc/switch.py
class Alternate(Module):
    """Alternates between stepping with :code:`modules`.

    That is, first step is performed with 1st module, second step with second module, etc.

    Args:
        steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.

    Examples:
        Alternate between Adam, SignSGD and RMSprop

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Alternate(
                    tz.m.Adam(),
                    [tz.m.SignSGD(), tz.m.Mul(0.5)],
                    tz.m.RMSprop(),
                ),
                tz.m.LR(1e-3),
            )
    """
    LOOP = True
    def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
        if isinstance(steps, Iterable):
            steps = list(steps)
            if len(steps) != len(modules):
                raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")

        defaults = dict(steps=steps)
        super().__init__(defaults)

        self.set_children_sequence(modules)
        self.global_state['current_module_idx'] = 0
        self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps

    @torch.no_grad
    def step(self, var):
        # get current module
        current_module_idx = self.global_state.setdefault('current_module_idx', 0)
        module = self.children[f'module_{current_module_idx}']

        # step
        var = module.step(var.clone(clone_update=False))

        # number of steps until next module
        steps = self.defaults['steps']
        if isinstance(steps, int): steps = [steps]*len(self.children)

        if 'steps_to_next' not in self.global_state:
            self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps

        self.global_state['steps_to_next'] -= 1

        # switch to next module
        if self.global_state['steps_to_next'] == 0:
            self.global_state['current_module_idx'] += 1

            # loop to first module (or keep using last module on Switch)
            if self.global_state['current_module_idx'] > len(self.children) - 1:
                if self.LOOP: self.global_state['current_module_idx'] = 0
                else: self.global_state['current_module_idx'] = len(self.children) - 1

            self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]

        return var

LOOP class-attribute

LOOP = True


DivByLoss

Bases: torchzero.core.module.Module

Divides the update by the loss times `alpha`.

Source code in torchzero/modules/misc/misc.py
class DivByLoss(Module):
    """Divides update by loss times :code:`alpha`"""
    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
        loss = var.get_loss(backward=self.defaults['backward'])
        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
        torch._foreach_div_(var.update, mul)
        return var
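
A minimal usage sketch (not from the library's own examples; it assumes `import torchzero as tz` and a `model`, as in the other examples on this page):

opt = tz.Modular(
    model.parameters(),
    tz.m.DivByLoss(),  # divide the gradient by max(loss * alpha, min_value)
    tz.m.LR(1e-2),
)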

Dropout

Bases: torchzero.core.transform.Transform

Applies dropout to the update.

For each weight, the update to that weight has probability `p` of being set to 0. This can be used to implement gradient dropout or update dropout, depending on placement.

Parameters:

  • p (float, default: 0.5 ) –

    probability that update for a weight is replaced with 0. Defaults to 0.5.

  • graft (bool, default: False ) –

    if True, update after dropout is rescaled to have the same norm as before dropout. Defaults to False.

  • target (Literal, default: 'update' ) –

    what to set on var, refer to documentation. Defaults to 'update'.

Examples:

Gradient dropout.

.. code-block:: python

opt = tz.Modular(
    model.parameters(),
    tz.m.Dropout(0.5),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)

Update dropout.

.. code-block:: python

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.Dropout(0.5),
    tz.m.LR(1e-3)
)
Source code in torchzero/modules/misc/regularization.py
class Dropout(Transform):
    """Applies dropout to the update.

    For each weight the update to that weight has :code:`p` probability to be set to 0.
    This can be used to implement gradient dropout or update dropout depending on placement.

    Args:
        p (float, optional): probability that update for a weight is replaced with 0. Defaults to 0.5.
        graft (bool, optional):
            if True, update after dropout is rescaled to have the same norm as before dropout. Defaults to False.
        target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.


    Examples:
        Gradient dropout.

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Dropout(0.5),
                tz.m.Adam(),
                tz.m.LR(1e-3)
            )

        Update dropout.

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Adam(),
                tz.m.Dropout(0.5),
                tz.m.LR(1e-3)
            )

    """
    def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
        defaults = dict(p=p, graft=graft)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        p = NumberList(s['p'] for s in settings)
        graft = settings[0]['graft']

        if graft:
            target_norm = tensors.global_vector_norm()
            tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
            return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft

        return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))

EscapeAnnealing

Bases: torchzero.core.module.Module

If parameters stop changing, this runs a backward annealing random search

Source code in torchzero/modules/misc/escape.py
class EscapeAnnealing(Module):
    """If parameters stop changing, this runs a backward annealing random search"""
    def __init__(self, max_region:float = 1, max_iter:int = 1000, tol=1e-6, n_tol: int = 10):
        defaults = dict(max_region=max_region, max_iter=max_iter, tol=tol, n_tol=n_tol)
        super().__init__(defaults)


    @torch.no_grad
    def step(self, var):
        closure = var.closure
        if closure is None: raise RuntimeError("Escape requries closure")

        params = TensorList(var.params)
        settings = self.settings[params[0]]
        max_region = self.get_settings(params, 'max_region', cls=NumberList)
        max_iter = settings['max_iter']
        tol = settings['tol']
        n_tol = settings['n_tol']

        n_bad = self.global_state.get('n_bad', 0)

        prev_params = self.get_state(params, 'prev_params', cls=TensorList)
        diff = params-prev_params
        prev_params.copy_(params)

        if diff.abs().global_max() <= tol:
            n_bad += 1

        else:
            n_bad = 0

        self.global_state['n_bad'] = n_bad

        # no progress
        f_0 = var.get_loss(False)
        if n_bad >= n_tol:
            for i in range(1, max_iter+1):
                alpha = max_region * (i / max_iter)
                pert = params.sphere_like(radius=alpha)

                params.add_(pert)
                f_star = closure(False)

                if math.isfinite(f_star) and f_star < f_0-1e-12:
                    var.update = None
                    var.stop = True
                    var.skip_update = True
                    return var

                params.sub_(pert)

            self.global_state['n_bad'] = 0
        return var
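
A usage sketch (an assumption, not from the library docs): place EscapeAnnealing after the main modules so the random search only kicks in once parameters stagnate. A closure is required, as noted in the source above.

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
    tz.m.EscapeAnnealing(max_region=1.0, n_tol=10),
)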

ExpHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class ExpHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return loss.exp()

FillLoss

Bases: torchzero.core.module.Module

Outputs tensors filled with the loss value times `alpha`.

Source code in torchzero/modules/misc/misc.py
class FillLoss(Module):
    """Outputs tensors filled with loss value times :code:`alpha`"""
    def __init__(self, alpha: float = 1, backward: bool = True):
        defaults = dict(alpha=alpha, backward=backward)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        alpha = self.get_settings(var.params, 'alpha')
        loss = var.get_loss(backward=self.defaults['backward'])
        var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
        return var

GradSign

Bases: torchzero.core.transform.Transform

Copies gradient sign to update.

Source code in torchzero/modules/misc/misc.py
class GradSign(Transform):
    """Copies gradient sign to update."""
    def __init__(self, target: Target = 'update'):
        super().__init__({}, uses_grad=True, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        return [t.copysign_(g) for t,g in zip(tensors, grads)]
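
For example, GradSign can combine Adam's step magnitudes with the raw gradient's sign. A sketch (assumptions as in the other examples on this page):

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),      # produces step magnitudes
    tz.m.GradSign(),  # overwrite the sign with the gradient's sign
    tz.m.LR(1e-3),
)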

GradientAccumulation

Bases: torchzero.core.module.Module

Uses `n` steps to accumulate gradients; after `n` gradients have been accumulated, they are passed to `modules` and the parameters are updated.

Accumulating gradients for n steps is equivalent to increasing batch size by n. Increasing the batch size is more computationally efficient, but sometimes it is not feasible due to memory constraints.

Note

Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.

Parameters:

  • n (int) –

    number of gradients to accumulate.

  • mean (bool, default: True ) –

    if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.

  • stop (bool, default: True ) –

    this module prevents the next modules from stepping until `n` gradients have been accumulated. Setting this argument to False disables that. Defaults to True.

Examples:

Adam with gradients accumulated for 16 batches.

opt = tz.Modular(
    model.parameters(),
    tz.m.GradientAccumulation(16),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)
Source code in torchzero/modules/misc/gradient_accumulation.py
class GradientAccumulation(Module):
    """Uses ``n`` steps to accumulate gradients, after ``n`` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.

    Accumulating gradients for ``n`` steps is equivalent to increasing batch size by ``n``. Increasing the batch size
    is more computationally efficient, but sometimes it is not feasible due to memory constraints.

    Note:
        Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.

    Args:
        n (int): number of gradients to accumulate.
        mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
        stop (bool, optional):
            this module prevents next modules from stepping unless ``n`` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.

    ## Examples:

    Adam with gradients accumulated for 16 batches.

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.GradientAccumulation(16),
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(self, n: int, mean=True, stop=True):
        defaults = dict(n=n, mean=mean, stop=stop)
        super().__init__(defaults)


    @torch.no_grad
    def step(self, var):
        accumulator = self.get_state(var.params, 'accumulator')
        settings = self.defaults
        n = settings['n']; mean = settings['mean']; stop = settings['stop']
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        # add update to accumulator
        torch._foreach_add_(accumulator, var.get_update())

        # step with accumulated updates
        if step % n == 0:
            if mean:
                torch._foreach_div_(accumulator, n)

            var.update = accumulator

            # zero accumulator
            self.clear_state_keys('accumulator')

        else:
            # prevent update
            if stop:
                var.update = None
                var.stop=True
                var.skip_update=True

        return var

GraftGradToUpdate

Bases: torchzero.core.transform.Transform

Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update.

Source code in torchzero/modules/misc/misc.py
class GraftGradToUpdate(Transform):
    """Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update."""
    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, uses_grad=True, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
        return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)

GraftToGrad

Bases: torchzero.core.transform.Transform

Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient.

Source code in torchzero/modules/misc/misc.py
class GraftToGrad(Transform):
    """Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient."""
    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, uses_grad=True, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
        return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
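
A sketch of norm grafting (assumptions as in the other examples): keep the Adam direction but rescale it to the gradient's norm before applying the learning rate.

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),         # direction
    tz.m.GraftToGrad(),  # rescale to the gradient's norm
    tz.m.LR(1e-2),
)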

GraftToParams

Bases: torchzero.core.transform.Transform

Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than `eps`.

Source code in torchzero/modules/misc/misc.py
class GraftToParams(Transform):
    """Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than :code:`eps`."""
    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-4, target: Target = 'update'):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
        return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)

HpuEstimate

Bases: torchzero.core.transform.Transform

Returns `y/||s||`, where `y` is the difference between the current and previous update (gradient), and `s` is the difference between the current and previous parameters. The returned tensors are a finite-difference approximation to the Hessian times the previous update.

Source code in torchzero/modules/misc/misc.py
class HpuEstimate(Transform):
    """returns ``y/||s||``, where ``y`` is difference between current and previous update (gradient), ``s`` is difference between current and previous parameters. The returned tensors are a finite difference approximation to hessian times previous update."""
    def __init__(self):
        defaults = dict()
        super().__init__(defaults, uses_grad=False)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_params', 'prev_update')

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        prev_params, prev_update = self.get_state(params, 'prev_params', 'prev_update') # initialized to 0
        s = torch._foreach_sub(params, prev_params)
        y = torch._foreach_sub(tensors, prev_update)
        for p, c in zip(prev_params, params): p.copy_(c)
        for p, c in zip(prev_update, tensors): p.copy_(c)
        torch._foreach_div_(y, torch.linalg.norm(torch.cat([t.ravel() for t in s])).clip(min=1e-8)) # pylint:disable=not-callable
        self.store(params, 'y', y)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        return [self.state[p]['y'] for p in params]

LambdaHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class LambdaHomotopy(HomotopyBase):
    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        defaults = dict(fn=fn)
        super().__init__(defaults)

    def loss_transform(self, loss): return self.defaults['fn'](loss)

LastAbsoluteRatio

Bases: torchzero.core.transform.Transform

Outputs the ratio between the absolute values of the past two updates; the numerator is determined by the `numerator` argument.

Source code in torchzero/modules/misc/misc.py
class LastAbsoluteRatio(Transform):
    """Outputs ratio between absolute values of past two updates the numerator is determined by :code:`numerator` argument."""
    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
        defaults = dict(numerator=numerator, eps=eps)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
        numerator = settings[0]['numerator']
        eps = NumberList(s['eps'] for s in settings)

        torch._foreach_abs_(tensors)
        torch._foreach_clamp_min_(prev, eps)

        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
        else: ratio = torch._foreach_div(prev, tensors)
        for p, c in zip(prev, tensors): p.set_(c)
        return ratio

LastDifference

Bases: torchzero.core.transform.Transform

Outputs difference between past two updates.

Source code in torchzero/modules/misc/misc.py
class LastDifference(Transform):
    """Outputs difference between past two updates."""
    def __init__(self,target: Target = 'update'):
        super().__init__({}, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prev_tensors = unpack_states(states, tensors, 'prev_tensors') # initialized to 0
        difference = torch._foreach_sub(tensors, prev_tensors)
        for p, c in zip(prev_tensors, tensors): p.set_(c)
        return difference

LastGradDifference

Bases: torchzero.core.module.Module

Outputs difference between past two gradients.

Source code in torchzero/modules/misc/misc.py
class LastGradDifference(Module):
    """Outputs difference between past two gradients."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def step(self, var):
        grad = var.get_grad()
        prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
        difference = torch._foreach_sub(grad, prev_grad)
        for p, c in zip(prev_grad, grad): p.copy_(c)
        var.update = list(difference)
        return var

LastProduct

Bases: torchzero.core.transform.Transform

Outputs the product of the past two updates.

Source code in torchzero/modules/misc/misc.py
class LastProduct(Transform):
    """Outputs difference between past two updates."""
    def __init__(self,target: Target = 'update'):
        super().__init__({}, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
        prod = torch._foreach_mul(tensors, prev)
        for p, c in zip(prev, tensors): p.set_(c)
        return prod

LastRatio

Bases: torchzero.core.transform.Transform

Outputs the ratio between the past two updates; the numerator is determined by the `numerator` argument.

Source code in torchzero/modules/misc/misc.py
class LastRatio(Transform):
    """Outputs ratio between past two updates, the numerator is determined by :code:`numerator` argument."""
    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
        defaults = dict(numerator=numerator)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
        numerator = settings[0]['numerator']
        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
        else: ratio = torch._foreach_div(prev, tensors)
        for p, c in zip(prev, tensors): p.set_(c)
        return ratio

LogHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class LogHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return (loss+1e-12).log()

MulByLoss

Bases: torchzero.core.module.Module

Multiplies the update by the loss times `alpha`.

Source code in torchzero/modules/misc/misc.py
class MulByLoss(Module):
    """Multiplies update by loss times :code:`alpha`"""
    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
        loss = var.get_loss(backward=self.defaults['backward'])
        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
        torch._foreach_mul_(var.update, mul)
        return var

Multistep

Bases: torchzero.core.module.Module

Performs `steps` inner steps with `module` on each step.

The update is taken to be the parameter difference between parameters before and after the inner loop.

Source code in torchzero/modules/misc/multistep.py
class Multistep(Module):
    """Performs :code:`steps` inner steps with :code:`module` per each step.

    The update is taken to be the parameter difference between parameters before and after the inner loop."""
    def __init__(self, module: Chainable, steps: int):
        defaults = dict(steps=steps)
        super().__init__(defaults)
        self.set_child('module', module)

    @torch.no_grad
    def step(self, var):
        return _sequential_step(self, var, sequential=False)
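
A hedged sketch (assuming a module list is accepted as a Chainable, as in the Alternate example above): perform 5 inner Adam steps per optimizer step and report the resulting parameter difference as the update.

opt = tz.Modular(
    model.parameters(),
    tz.m.Multistep([tz.m.Adam(), tz.m.LR(1e-3)], steps=5),
)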

NegateOnLossIncrease

Bases: torchzero.core.module.Module

Uses an extra forward pass to evaluate the loss at `parameters + update`; if the loss is larger than at `parameters`, the update is set to 0 if `backtrack=False`, and to `-update` otherwise.

Source code in torchzero/modules/misc/multistep.py
class NegateOnLossIncrease(Module):
    """Uses an extra forward pass to evaluate loss at :code:`parameters+update`,
    if loss is larger than at :code:`parameters`,
    the update is set to 0 if :code:`backtrack=False` and to :code:`-update` otherwise"""
    def __init__(self, backtrack=False):
        defaults = dict(backtrack=backtrack)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def step(self, var):
        closure = var.closure
        if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
        backtrack = self.defaults['backtrack']

        update = var.get_update()
        f_0 = var.get_loss(backward=False)

        torch._foreach_sub_(var.params, update)
        f_1 = closure(False)

        if f_1 <= f_0:
            if var.is_last and var.last_module_lrs is None:
                var.stop = True
                var.skip_update = True
                return var

            torch._foreach_add_(var.params, update)
            return var

        torch._foreach_add_(var.params, update)
        if backtrack:
            torch._foreach_neg_(var.update)
        else:
            torch._foreach_zero_(var.update)
        return var
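
Because this module re-evaluates the loss, the optimizer has to be stepped with a closure. A minimal sketch where `criterion`, `inputs` and `targets` are placeholders; the `backward` flag follows the `closure(False)` convention visible in the source above:

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-2),
    tz.m.NegateOnLossIncrease(backtrack=True),
)

def closure(backward=True):
    loss = criterion(model(inputs), targets)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)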

NoiseSign

Bases: torchzero.core.transform.Transform

Outputs random tensors with sign copied from the update.

Source code in torchzero/modules/misc/misc.py
class NoiseSign(Transform):
    """Outputs random tensors with sign copied from the update."""
    def __init__(self, distribution:Distributions = 'normal', variance:float | None = None):
        defaults = dict(distribution=distribution, variance=variance)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        variance = unpack_dicts(settings, 'variance')
        return TensorList(tensors).sample_like(settings[0]['distribution'], variance=variance).copysign_(tensors)
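
A sketch (an assumption, not an official recipe): random magnitudes with the update's sign give a noisy, sign-based step.

opt = tz.Modular(
    model.parameters(),
    tz.m.NoiseSign(distribution='normal'),
    tz.m.LR(1e-2),
)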

Online

Bases: torchzero.core.module.Module

Allows certain modules to be used for mini-batch optimization.

Examples:

Online L-BFGS with Backtracking line search

opt = tz.Modular(
    model.parameters(),
    tz.m.Online(tz.m.LBFGS()),
    tz.m.Backtracking()
)

Online L-BFGS trust region

opt = tz.Modular(
    model.parameters(),
    tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
)

Source code in torchzero/modules/misc/multistep.py
class Online(Module):
    """Allows certain modules to be used for mini-batch optimization.

    Examples:

    Online L-BFGS with Backtracking line search
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Online(tz.m.LBFGS()),
        tz.m.Backtracking()
    )
    ```

    Online L-BFGS trust region
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
    )
    ```

    """
    def __init__(self, *modules: Module,):
        super().__init__()

        self.set_child('module', modules)

    @torch.no_grad
    def update(self, var):
        closure = var.closure
        if closure is None: raise ValueError("Closure must be passed for Online")

        step = self.global_state.get('step', 0) + 1
        self.global_state['step'] = step

        params = TensorList(var.params)
        p_cur = params.clone()
        p_prev = self.get_state(params, 'p_prev', cls=TensorList)

        module = self.children['module']
        var_c = var.clone(clone_update=False)

        # on 1st step just step and store previous params
        if step == 1:
            p_prev.copy_(params)

            module.update(var_c)
            var.update_attrs_from_clone_(var_c)
            return

        # restore previous params and update
        var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
        params.set_(p_prev)
        module.reset_for_online()
        module.update(var_prev)

        # restore current params and update
        params.set_(p_cur)
        p_prev.copy_(params)
        module.update(var_c)
        var.update_attrs_from_clone_(var_c)

    @torch.no_grad
    def apply(self, var):
        module = self.children['module']
        return module.apply(var.clone(clone_update=False))

    def get_H(self, var):
        return self.children['module'].get_H(var)

PerturbWeights

Bases: torchzero.core.module.Module

Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.

Can be disabled for a parameter by setting `perturb=False` in the corresponding parameter group.

Parameters:

  • alpha (float, default: 0.1 ) –

    multiplier for perturbation magnitude. Defaults to 0.1.

  • relative (bool, default: True ) –

    whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.

  • distribution (Distributions, default: 'normal' ) –

    distribution of the random perturbation. Defaults to 'normal'.

Source code in torchzero/modules/misc/regularization.py
class PerturbWeights(Module):
    """
    Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.

    Can be disabled for a parameter by setting :code:`perturb=False` in corresponding parameter group.

    Args:
        alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
        relative (bool, optional): whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.
        distribution (Distributions, optional):
            distribution of the random perturbation. Defaults to 'normal'.
    """
    def __init__(self, alpha: float = 0.1, relative:bool=True, distribution:Distributions = 'normal'):
        defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        closure = var.closure
        if closure is None: raise RuntimeError('PerturbWeights requires closure')
        params = TensorList(var.params)

        # create perturbations
        perts = []
        for p in params:
            settings = self.settings[p]
            if not settings['perturb']:
                perts.append(torch.zeros_like(p))
                continue

            alpha = settings['alpha']
            if settings['relative']:
                alpha *= p.abs().mean()

            distribution = self.settings[p]['distribution'].lower()
            if distribution in ('normal', 'gaussian'):
                perts.append(torch.randn_like(p).mul_(alpha))
            elif distribution == 'uniform':
                perts.append(torch.empty_like(p).uniform_(-alpha,alpha))
            elif distribution == 'sphere':
                r = torch.randn_like(p)
                perts.append((r * alpha) / torch.linalg.vector_norm(r)) # pylint:disable=not-callable
            else:
                raise ValueError(distribution)

        @torch.no_grad
        def perturbed_closure(backward=True):
            params.add_(perts)
            if backward:
                with torch.enable_grad(): loss = closure()
            else:
                loss = closure(False)
            params.sub_(perts)
            return loss

        var.closure = perturbed_closure
        return var
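
A usage sketch (assumptions as in the other examples; a closure-based step is required since the module wraps the closure): gradients are evaluated at randomly perturbed weights.

opt = tz.Modular(
    model.parameters(),
    tz.m.PerturbWeights(alpha=0.01, relative=True),
    tz.m.Adam(),
    tz.m.LR(1e-3),
)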

Previous

Bases: torchzero.core.transform.TensorwiseTransform

Maintains an update from `n` steps back; for example, if `n=1`, returns the previous update.

Source code in torchzero/modules/misc/misc.py
class Previous(TensorwiseTransform):
    """Maintains an update from n steps back, for example if n=1, returns previous update"""
    def __init__(self, n=1, target: Target = 'update'):
        defaults = dict(n=n)
        super().__init__(uses_grad=False, defaults=defaults, target=target)


    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        n = setting['n']

        if 'history' not in state:
            state['history'] = deque(maxlen=n+1)

        state['history'].append(tensor)

        return state['history'][0]

PrintLoss

Bases: torchzero.core.module.Module

Prints var.get_loss().

Source code in torchzero/modules/misc/debug.py
class PrintLoss(Module):
    """Prints var.get_loss()."""
    def __init__(self, text = 'loss = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def step(self, var):
        self.defaults["print_fn"](f'{self.defaults["text"]}{var.get_loss(False)}')
        return var

PrintParams

Bases: torchzero.core.module.Module

Prints current parameters.

Source code in torchzero/modules/misc/debug.py
class PrintParams(Module):
    """Prints current update."""
    def __init__(self, text = 'params = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def step(self, var):
        self.defaults["print_fn"](f'{self.defaults["text"]}{var.params}')
        return var

PrintShape

Bases: torchzero.core.module.Module

Prints shapes of the update.

Source code in torchzero/modules/misc/debug.py
class PrintShape(Module):
    """Prints shapes of the update."""
    def __init__(self, text = 'shapes = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def step(self, var):
        shapes = [u.shape for u in var.update] if var.update is not None else None
        self.defaults["print_fn"](f'{self.defaults["text"]}{shapes}')
        return var

PrintUpdate

Bases: torchzero.core.module.Module

Prints current update.

Source code in torchzero/modules/misc/debug.py
class PrintUpdate(Module):
    """Prints current update."""
    def __init__(self, text = 'update = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def step(self, var):
        self.defaults["print_fn"](f'{self.defaults["text"]}{var.update}')
        return var
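
The Print* modules are mainly useful for inspecting a pipeline; their placement determines what gets printed. A sketch (assumptions as in the other examples):

opt = tz.Modular(
    model.parameters(),
    tz.m.PrintShape(),   # shapes of the incoming gradient
    tz.m.Adam(),
    tz.m.PrintUpdate(),  # update after Adam, before the learning rate
    tz.m.LR(1e-3),
)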

RandomHvp

Bases: torchzero.core.module.Module

Returns a Hessian-vector product with a random vector.

Source code in torchzero/modules/misc/misc.py
class RandomHvp(Module):
    """Returns a hessian-vector product with a random vector"""

    def __init__(
        self,
        n_samples: int = 1,
        distribution: Distributions = "normal",
        update_freq: int = 1,
        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
        h=1e-3,
    ):
        defaults = dict(n_samples=n_samples, distribution=distribution, hvp_method=hvp_method, h=h, update_freq=update_freq)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        params = TensorList(var.params)
        settings = self.settings[params[0]]
        n_samples = settings['n_samples']
        distribution = settings['distribution']
        hvp_method = settings['hvp_method']
        h = settings['h']
        update_freq = settings['update_freq']

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        D = None
        if step % update_freq == 0:

            rgrad = None
            for i in range(n_samples):
                u = params.sample_like(distribution=distribution, variance=1)

                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
                                    h=h, normalize=True, retain_grad=i < n_samples-1)

                if D is None: D = Hvp
                else: torch._foreach_add_(D, Hvp)

            if n_samples > 1: torch._foreach_div_(D, n_samples)
            if update_freq != 1:
                assert D is not None
                D_buf = self.get_state(params, "D", cls=TensorList)
                D_buf.set_(D)

        if D is None:
            D = self.get_state(params, "D", cls=TensorList)

        var.update = list(D)
        return var
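
A sketch that uses the randomized Hessian-vector product as the update direction (an assumption, not an official recipe; a closure is typically needed for the extra gradient computations):

opt = tz.Modular(
    model.parameters(),
    tz.m.RandomHvp(n_samples=2, hvp_method='autograd'),
    tz.m.LR(1e-3),
)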

Relative

Bases: torchzero.core.transform.Transform

Multiplies the update by absolute parameter values to make it relative to their magnitude; `min_value` is the minimum allowed value, to avoid getting stuck at 0.

Source code in torchzero/modules/misc/misc.py
class Relative(Transform):
    """Multiplies update by absolute parameter values to make it relative to their magnitude, :code:`min_value` is minimum allowed value to avoid getting stuck at 0."""
    def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
        defaults = dict(min_value=min_value)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
        torch._foreach_mul_(tensors, mul)
        return tensors
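
A sketch (assumptions as in the other examples): each weight's step becomes proportional to that weight's current magnitude.

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.Relative(min_value=1e-4),
    tz.m.LR(1e-2),
)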

SaveBest

Bases: torchzero.core.module.Module

Saves the best parameters found so far, the ones with the lowest loss. Put this as the last module.

Adds the following attrs:

  • best_params - a list of tensors with best parameters.
  • best_loss - loss value with best_params.
  • load_best_params - a function that sets parameters to the best parameters.

Examples

def rosenbrock(x, y):
    return (1 - x)**2 + (100 * (y - x**2))**2

xy = torch.tensor((-1.1, 2.5), requires_grad=True)
opt = tz.Modular(
    [xy],
    tz.m.NAG(0.999),
    tz.m.LR(1e-6),
    tz.m.SaveBest()
)

# optimize for 1000 steps
for i in range(1000):
    loss = rosenbrock(*xy)
    opt.zero_grad()
    loss.backward()
    opt.step(loss=loss) # SaveBest needs closure or loss

# NAG overshot, but we saved the best params
print(f'{rosenbrock(*xy) = }') # >> 3.6583
print(f"{opt.attrs['best_loss'] = }") # >> 0.000627

# load best parameters
opt.attrs['load_best_params']()
print(f'{rosenbrock(*xy) = }') # >> 0.000627

Source code in torchzero/modules/misc/misc.py
class SaveBest(Module):
    """Saves best parameters found so far, ones that have lowest loss. Put this as the last module.

    Adds the following attrs:

    - ``best_params`` - a list of tensors with best parameters.
    - ``best_loss`` - loss value with ``best_params``.
    - ``load_best_parameters`` - a function that sets parameters to the best parameters./

    ## Examples
    ```python
    def rosenbrock(x, y):
        return (1 - x)**2 + (100 * (y - x**2))**2

    xy = torch.tensor((-1.1, 2.5), requires_grad=True)
    opt = tz.Modular(
        [xy],
        tz.m.NAG(0.999),
        tz.m.LR(1e-6),
        tz.m.SaveBest()
    )

    # optimize for 1000 steps
    for i in range(1000):
        loss = rosenbrock(*xy)
        opt.zero_grad()
        loss.backward()
        opt.step(loss=loss) # SaveBest needs closure or loss

    # NAG overshot, but we saved the best params
    print(f'{rosenbrock(*xy) = }') # >> 3.6583
    print(f"{opt.attrs['best_loss'] = }") # >> 0.000627

    # load best parameters
    opt.attrs['load_best_params']()
    print(f'{rosenbrock(*xy) = }') # >> 0.000627
    """
    def __init__(self):
        super().__init__()

    @torch.no_grad
    def step(self, var):
        loss = tofloat(var.get_loss(False))
        lowest_loss = self.global_state.get('lowest_loss', float("inf"))

        if loss < lowest_loss:
            self.global_state['lowest_loss'] = loss
            best_params = var.attrs['best_params'] = [p.clone() for p in var.params]
            var.attrs['best_loss'] = loss
            var.attrs['load_best_params'] = partial(_load_best_parameters, params=var.params, best_params=best_params)

        return var

Sequential

Bases: torchzero.core.module.Module

On each step, this sequentially steps with `modules`, `steps` times.

The update is taken to be the parameter difference between parameters before and after the inner loop.

Source code in torchzero/modules/misc/multistep.py
class Sequential(Module):
    """On each step, this sequentially steps with :code:`modules` :code:`steps` times.

    The update is taken to be the parameter difference between parameters before and after the inner loop."""
    def __init__(self, modules: Iterable[Chainable], steps: int=1):
        defaults = dict(steps=steps)
        super().__init__(defaults)
        self.set_children_sequence(modules)

    @torch.no_grad
    def step(self, var):
        return _sequential_step(self, var, sequential=True)

Split

Bases: torchzero.core.module.Module

Applies `true` modules to all parameters selected by `filter`, and applies `false` modules to all other parameters.

Parameters:

  • filter (Filter) –

    a filter that selects tensors to be optimized by true:

    - tensor or iterable of tensors (e.g. encoder.parameters()).
    - function that takes in a tensor and outputs a bool (e.g. lambda x: x.ndim >= 2).
    - a sequence of the above (acts as "or", so returns true if any of them is true).

  • true (Chainable | None) –

    modules that are applied to tensors where filter is True.

  • false (Chainable | None) –

    modules that are applied to tensors where filter is False.

Examples:

Muon with Adam fallback using same hyperparams as https://github.com/KellerJordan/Muon

opt = tz.Modular(
    model.parameters(),
    tz.m.NAG(0.95),
    tz.m.Split(
        lambda p: p.ndim >= 2,
        true = tz.m.Orthogonalize(),
        false = [tz.m.Adam(0.9, 0.95), tz.m.Mul(1/66)],
    ),
    tz.m.LR(1e-2),
)
Source code in torchzero/modules/misc/split.py
class Split(Module):
    """Apply ``true`` modules to all parameters filtered by ``filter``, apply ``false`` modules to all other parameters.

    Args:
        filter (Filter, bool]):
            a filter that selects tensors to be optimized by ``true``.
            - tensor or iterable of tensors (e.g. ``encoder.parameters()``).
            - function that takes in tensor and outputs a bool (e.g. ``lambda x: x.ndim >= 2``).
            - a sequence of above (acts as "or", so returns true if any of them is true).

        true (Chainable | None): modules that are applied to tensors where ``filter`` is ``True``.
        false (Chainable | None): modules that are applied to tensors where ``filter`` is ``False``.

    ### Examples:

    Muon with Adam fallback using same hyperparams as https://github.com/KellerJordan/Muon

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.NAG(0.95),
        tz.m.Split(
            lambda p: p.ndim >= 2,
            true = tz.m.Orthogonalize(),
            false = [tz.m.Adam(0.9, 0.95), tz.m.Mul(1/66)],
        ),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(self, filter: Filter, true: Chainable | None, false: Chainable | None):
        defaults = dict(filter=filter)
        super().__init__(defaults)

        if true is not None: self.set_child('true', true)
        if false is not None: self.set_child('false', false)

    def step(self, var):

        params = var.params
        filter = _make_filter(self.settings[params[0]]['filter'])

        true_idxs = []
        false_idxs = []
        for i,p in enumerate(params):
            if filter(p): true_idxs.append(i)
            else: false_idxs.append(i)

        if 'true' in self.children and len(true_idxs) > 0:
            true = self.children['true']
            var = _split(true, idxs=true_idxs, params=params, var=var)

        if 'false' in self.children and len(false_idxs) > 0:
            false = self.children['false']
            var = _split(false, idxs=false_idxs, params=params, var=var)

        return var

SqrtHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class SqrtHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return (loss+1e-12).sqrt()

SquareHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class SquareHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return loss.square().copysign(loss)

Switch

Bases: torchzero.modules.misc.switch.Alternate

After `steps` steps, switches to the next module.

Parameters:

  • steps (int | Iterable[int]) –

    Number of steps to perform with each module.

Examples:

Start with Adam, switch to L-BFGS after the 1000th step, and to Truncated Newton after the 2000th step.

.. code-block:: python

opt = tz.Modular(
    model.parameters(),
    tz.m.Switch(
        [tz.m.Adam(), tz.m.LR(1e-3)],
        [tz.m.LBFGS(), tz.m.Backtracking()],
        [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
        steps = (1000, 2000)
    )
)
Source code in torchzero/modules/misc/switch.py
class Switch(Alternate):
    """After :code:`steps` steps switches to the next module.

    Args:
        steps (int | Iterable[int]): Number of steps to perform with each module.

    Examples:
        Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Switch(
                    [tz.m.Adam(), tz.m.LR(1e-3)],
                    [tz.m.LBFGS(), tz.m.Backtracking()],
                    [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
                    steps = (1000, 2000)
                )
            )
    """

    LOOP = False
    def __init__(self, *modules: Chainable, steps: int | Iterable[int]):

        if isinstance(steps, Iterable):
            steps = list(steps)
            if len(steps) != len(modules) - 1:
                raise ValueError(f"steps must be the same length as modules minus 1, got {len(modules) = }, {len(steps) = }")

            steps.append(1)

        super().__init__(*modules, steps=steps)

LOOP class-attribute

LOOP = False


UpdateSign

Bases: torchzero.core.transform.Transform

Outputs gradient with sign copied from the update.

Source code in torchzero/modules/misc/misc.py
class UpdateSign(Transform):
    """Outputs gradient with sign copied from the update."""
    def __init__(self, target: Target = 'update'):
        super().__init__({}, uses_grad=True, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place

WeightDropout

Bases: torchzero.core.module.Module

Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.

Dropout can be disabled for a parameter by setting `use_dropout=False` in the corresponding parameter group.

Parameters:

  • p (float, default: 0.5 ) –

    probability that any weight is replaced with 0. Defaults to 0.5.

  • graft (bool, default: True ) –

    if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to True.

Source code in torchzero/modules/misc/regularization.py
class WeightDropout(Module):
    """
    Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.

    Dropout can be disabled for a parameter by setting :code:`use_dropout=False` in corresponding parameter group.

    Args:
        p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
        graft (bool, optional):
            if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to True.
    """
    def __init__(self, p: float = 0.5, graft: bool = True):
        defaults = dict(p=p, graft=graft, use_dropout=True)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        closure = var.closure
        if closure is None: raise RuntimeError('WeightDropout requires closure')
        params = TensorList(var.params)
        p = NumberList(self.settings[p]['p'] for p in params)

        # create masks
        mask = []
        for p in params:
            prob = self.settings[p]['p']
            use_dropout = self.settings[p]['use_dropout']
            if use_dropout: mask.append(_bernoulli_like(p, prob))
            else: mask.append(torch.ones_like(p))

        @torch.no_grad
        def dropout_closure(backward=True):
            orig_params = params.clone()
            params.mul_(mask)
            if backward:
                with torch.enable_grad(): loss = closure()
            else:
                loss = closure(False)
            params.copy_(orig_params)
            return loss

        var.closure = dropout_closure
        return var
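
A usage sketch (assumptions as in the other examples; a closure-based step is required because the module wraps the closure):

opt = tz.Modular(
    model.parameters(),
    tz.m.WeightDropout(p=0.2),
    tz.m.Adam(),
    tz.m.LR(1e-3),
)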