Adaptive

This subpackage contains adaptive methods such as Adam, RMSprop, and SOAP.

Classes:

  • AEGD

    AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.

  • ASAM

    Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52

  • AdaHessian

    AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)

  • Adagrad

    Adagrad, divides by sum of past squares of gradients.

  • AdagradNorm

    Adagrad-Norm, divides by sum of past means of squares of gradients.

  • Adam

    Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.

  • Adan

    Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677

  • AdaptiveHeavyBall

    Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.

  • BacktrackOnSignChange

    Negates or undoes update for parameters where gradient or update sign changes.

  • DualNormCorrection

    Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).

  • ESGD

    Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

  • FullMatrixAdagrad

    Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

  • GGT

    GGT method from https://arxiv.org/pdf/1806.02958

  • Lion

    Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

  • MARSCorrection

    MARS variance reduction correction.

  • MSAM

    Momentum-SAM from https://arxiv.org/pdf/2401.12033.

  • MSAMMomentum

    Momentum-SAM from https://arxiv.org/pdf/2401.12033.

  • MatrixMomentum

    Second order momentum method.

  • MuonAdjustLR

    LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).

  • NaturalGradient

    Natural gradient approximated via empirical fisher information matrix.

  • OrthoGrad

    Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

  • Orthogonalize

    Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.

  • PSGDDenseNewton

    Dense hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

  • PSGDKronNewton

    Kron hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

  • PSGDKronWhiten

    Kron whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

  • PSGDLRANewton

    Low rank hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

  • PSGDLRAWhiten

    Low rank whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

  • RMSprop

    Divides gradient by EMA of gradient squares.

  • Rprop

    Resilient propagation. The update magnitude gets multiplied by nplus if the gradient didn't change sign, and by nminus if it did.

  • SAM

    Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

  • SOAP

    SOAP (ShampoO with Adam in the Preconditioner's eigenbasis) from https://arxiv.org/abs/2409.11321.

  • ScaleLRBySignChange

    Learning rate gets multiplied by nplus if the ascent/gradient didn't change sign, and by nminus if it did.

  • Shampoo

    Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

  • SignConsistencyLRs

    Outputs per-weight learning rates based on consecutive sign consistency.

  • SignConsistencyMask

    Outputs a mask of sign consistency of current and previous inputs.

  • SophiaH

    SophiaH optimizer from https://arxiv.org/abs/2305.14342

Functions:

  • orthogonalize_grads_

    Computes the zeroth power / orthogonalization of gradients of an iterable of parameters.

  • orthograd_

    Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

AEGD

Bases: torchzero.core.transform.TensorTransform

AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.

Note

AEGD has a learning rate hyperparameter that can't really be removed from the update rule. To avoid compounding learning rate modifications, remove the tz.m.LR module if you had it.

Parameters:

  • lr (float, default: 0.1 ) –

    learning rate (default: 0.1)

  • c (float, default: 1 ) –

    term added to the original objective function (default: 1)
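
Example:

Because AEGD applies its own learning rate, a minimal sketch (following the tz.Optimizer chaining pattern used elsewhere on this page) chains it without a tz.m.LR module:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.AEGD(lr=0.1),
)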

Reference

Liu, Hailiang, and Xuping Tian. "AEGD: Adaptive gradient descent with energy." arXiv preprint arXiv:2010.05109 (2020).

Source code in torchzero/modules/adaptive/aegd.py
class AEGD(TensorTransform):
    """AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.

    Note:
        AEGD has a learning rate hyperparameter that can't really be removed from the update rule.
        To avoid compounding learning rate mofications, remove the ``tz.m.LR`` module if you had it.

    Args:
        lr (float, optional): learning rate (default: 0.1)
        c (float, optional): term added to the original objective function (default: 1)

    Reference:
        [Liu, Hailiang, and Xuping Tian. "AEGD: Adaptive gradient descent with energy." arXiv preprint arXiv:2010.05109 (2020).](https://arxiv.org/pdf/2010.05109)
    """
    def __init__(
        self,
        lr: float = 0.1,
        c: float = 1,
    ):
        defaults = dict(c=c, lr=lr)
        super().__init__(defaults, uses_loss=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert loss is not None
        tensors = TensorList(tensors)

        c, lr = unpack_dicts(settings, 'c', 'lr', cls=NumberList)
        r = unpack_states(states, tensors, 'r', init=lambda t: torch.full_like(t, float(loss + c[0])**0.5), cls=TensorList)

        update = aegd_(
            f=loss,
            g=tensors,
            r_=r,
            c=c,
            eta=lr,
        )

        return update

ASAM

Bases: torchzero.modules.adaptive.sam.SAM

Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52

SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value. It performs two forward and backward passes per step.

This implementation modifies the closure to return loss and calculate gradients of the SAM objective. All modules after this will use the modified objective.

Note

This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients at two points on each step.

Parameters:

  • rho (float, default: 0.5 ) –

    Neighborhood size. Defaults to 0.5.

  • p (float, default: 2 ) –

    norm of the SAM objective. Defaults to 2.

Examples:

ASAM-SGD:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.ASAM(),
    tz.m.LR(1e-2)
)

ASAM-Adam:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.ASAM(),
    tz.m.Adam(),
    tz.m.LR(1e-2)
)
References: Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR.

Source code in torchzero/modules/adaptive/sam.py
class ASAM(SAM):
    """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52

    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
    It performs two forward and backward passes per step.

    This implementation modifies the closure to return loss and calculate gradients
    of the SAM objective. All modules after this will use the modified objective.

    Note:
        This module requires a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients at two points on each step.

    Args:
        rho (float, optional): Neighborhood size. Defaults to 0.05.
        p (float, optional): norm of the SAM objective. Defaults to 2.

    ### Examples:

    ASAM-SGD:

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ASAM(),
        tz.m.LR(1e-2)
    )
    ```

    ASAM-Adam:

    ```
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ASAM(),
        tz.m.Adam(),
        tz.m.LR(1e-2)
    )
    ```
    References:
        [Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR.](https://arxiv.org/abs/2102.11600)
    """
    def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
        super().__init__(rho=rho, p=p, eps=eps, asam=True)

AdaHessian

Bases: torchzero.core.transform.Transform

AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)

This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.

Notes
  • In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply AdaHessian preconditioning to another module's output.

  • This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

Parameters:

  • beta1 (float, default: 0.9 ) –

    first momentum. Defaults to 0.9.

  • beta2 (float, default: 0.999 ) –

    second momentum for squared hessian diagonal estimates. Defaults to 0.999.

  • averaging (bool, default: True ) –

    whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions. This can be set per-parameter in param groups.

  • block_size (int, default: None ) –

    size of block in the block-diagonal averaging.

  • update_freq (int, default: 1 ) –

    frequency of updating hessian diagonal estimate via a hessian-vector product. This value can be increased to reduce computational cost. Defaults to 1.

  • eps (float, default: 1e-08 ) –

    division stability epsilon. Defaults to 1e-8.

  • hvp_method (str, default: 'autograd' ) –

    Determines how hessian-vector products are computed.

    • "batched_autograd" - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to "autograd". Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than "batched_autograd" but uses less memory.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    Defaults to "autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of hessian-vector products with random vectors to evaluate each time when updating the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.

  • seed (int | None, default: None ) –

    seed for random vectors. Defaults to None.

  • inner (Chainable | None) –

    Inner module. If this is specified, operations are performed in the following order. 1. compute hessian diagonal estimate. 2. pass inputs to inner. 3. momentum and preconditioning are applied to the outputs of inner.

Examples:

Using AdaHessian:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.AdaHessian(),
    tz.m.LR(0.1)
)

AdaHessian preconditioner can be applied to any other module by passing it to the inner argument. Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying AdaHessian preconditioning to nesterov momentum (tz.m.NAG):

opt = tz.Optimizer(
    model.parameters(),
    tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
    tz.m.LR(0.1)
)

Source code in torchzero/modules/adaptive/adahessian.py
class AdaHessian(Transform):
    """AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)

    This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.

    Notes:
        - In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply AdaHessian preconditioning to another module's output.

        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        beta1 (float, optional): first momentum. Defaults to 0.9.
        beta2 (float, optional): second momentum for squared hessian diagonal estimates. Defaults to 0.999.
        averaging (bool, optional):
            whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions.
            This can be set per-parameter in param groups.
        block_size (int, optional):
            size of block in the block-diagonal averaging.
        update_freq (int, optional):
            frequency of updating hessian diagonal estimate via a hessian-vector product.
            This value can be increased to reduce computational cost. Defaults to 1.
        eps (float, optional):
            division stability epsilon. Defaults to 1e-8.
        hvp_method (str, optional):
            Determines how hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
        seed (int | None, optional): seed for random vectors. Defaults to None.
        inner (Chainable | None, optional):
            Inner module. If this is specified, operations are performed in the following order.
            1. compute hessian diagonal estimate.
            2. pass inputs to ``inner``.
            3. momentum and preconditioning are applied to the ouputs of ``inner``.

    ## Examples:

    Using AdaHessian:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.AdaHessian(),
        tz.m.LR(0.1)
    )
    ```

    AdaHessian preconditioner can be applied to any other module by passing it to the ``inner`` argument.
    Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
    AdaHessian preconditioning to nesterov momentum (``tz.m.NAG``):
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
        tz.m.LR(0.1)
    )
    ```

    """
    def __init__(
        self,
        beta1: float = 0.9,
        beta2: float = 0.999,
        averaging: bool = True,
        block_size: int | None = None,
        update_freq: int = 1,
        eps: float = 1e-8,
        hessian_power: float = 1,
        distribution: Distributions = 'rademacher',
        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        n_samples = 1,
        zHz: bool = True,
        debias: bool = True,
        seed: int | None = None,

        exp_avg_tfm: Chainable | None = None,
        D_exp_avg_sq_tfm: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults["exp_avg_tfm"], defaults["D_exp_avg_sq_tfm"]
        super().__init__(defaults)

        self.set_child('exp_avg', exp_avg_tfm)
        self.set_child('D_exp_avg_sq', D_exp_avg_sq_tfm)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        params = objective.params

        beta1, beta2, averaging, block_size = unpack_dicts(settings, 'beta1', 'beta2', 'averaging', 'block_size', cls=NumberList)

        exp_avg, D_exp_avg_sq = unpack_states(states, params, 'exp_avg', 'D_exp_avg_sq', cls=TensorList)

        # ---------------------------- hutchinson hessian ---------------------------- #
        fs = settings[0]
        step = self.increment_counter("step", start=0) # 0 on 1st update
        update_freq = fs['update_freq']

        if step % update_freq == 0:
            self.increment_counter("num_Ds", start=1)

            D, _ = objective.hutchinson_hessian(
                rgrad = None,
                at_x0 = True,
                n_samples = fs['n_samples'],
                distribution = fs['distribution'],
                hvp_method = fs['hvp_method'],
                h = fs['h'],
                zHz = fs["zHz"],
                generator = self.get_generator(params[0].device, fs["seed"]),
            )

            D = TensorList(D).zipmap_args(_block_average, block_size, averaging)
            D_exp_avg_sq.mul_(beta2).addcmul_(D, D, value=1-beta2)

        # --------------------------------- momentum --------------------------------- #
        tensors = objective.get_updates() # do this after hutchinson to not disturb autograd
        exp_avg.lerp_(tensors, 1-beta1)


    @torch.no_grad
    def apply_states(self, objective, states, settings):
        params = objective.params

        beta1, beta2, eps, hessian_power = unpack_dicts(settings, 'beta1', 'beta2', 'eps', 'hessian_power', cls=NumberList)
        exp_avg, D_exp_avg_sq = unpack_states(states, params, 'exp_avg', 'D_exp_avg_sq', cls=TensorList)

        # ---------------------------------- debias ---------------------------------- #
        if settings[0]["debias"]:
            bias_correction1 = 1.0 - (beta1 ** (self.global_state["step"] + 1))
            bias_correction2 = 1.0 - (beta2 ** self.global_state["num_Ds"])
            exp_avg = exp_avg / bias_correction1
            D_exp_avg_sq = D_exp_avg_sq / bias_correction2


        # -------------------------------- transforms -------------------------------- #
        exp_avg = TensorList(self.inner_step_tensors(
            "exp_avg", tensors=exp_avg, clone=True, objective=objective, must_exist=False))

        D_exp_avg_sq = TensorList(self.inner_step_tensors(
            "D_exp_avg_sq", tensors=D_exp_avg_sq, clone=True, objective=objective, must_exist=False))

        # ------------------------------ compute update ------------------------------ #
        denom = D_exp_avg_sq.lazy_pow(hessian_power / 2) + eps
        objective.updates = exp_avg / denom
        return objective

Adagrad

Bases: torchzero.core.transform.TensorTransform

Adagrad, divides by sum of past squares of gradients.

This implementation is identical to torch.optim.Adagrad.

Parameters:

  • lr_decay (float, default: 0 ) –

    learning rate decay. Defaults to 0.

  • initial_accumulator_value (float, default: 0 ) –

    initial value of the sum of squares of gradients. Defaults to 0.

  • eps (float, default: 1e-10 ) –

    division epsilon. Defaults to 1e-10.

  • alpha (float, default: 1 ) –

    step size. Defaults to 1.

  • pow (float) –

    power for gradients and accumulator root. Defaults to 2.

  • use_sqrt (bool) –

    whether to take the root of the accumulator. Defaults to True.

  • inner (Chainable | None, default: None ) –

    Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
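
Example:

A minimal usage sketch, following the same tz.Optimizer chaining pattern as the other examples on this page:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adagrad(),
    tz.m.LR(1e-2),
)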

Source code in torchzero/modules/adaptive/adagrad.py
class Adagrad(TensorTransform):
    """Adagrad, divides by sum of past squares of gradients.

    This implementation is identical to ``torch.optim.Adagrad``.

    Args:
        lr_decay (float, optional): learning rate decay. Defaults to 0.
        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
        eps (float, optional): division epsilon. Defaults to 1e-10.
        alpha (float, optional): step size. Defaults to 1.
        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
    """
    def __init__(
        self,

        # hyperparams
        lr_decay: float = 0,
        initial_accumulator_value: float = 0,
        eps: float = 1e-10,
        alpha: float = 1,

        # tfms
        inner: Chainable | None = None,
        accumulator_tfm: Chainable | None = None
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["accumulator_tfm"]
        super().__init__(defaults=defaults, inner=inner)

        self.set_child('accumulator', accumulator_tfm)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        state["accumulator"] = torch.full_like(tensor, fill_value=setting["initial_accumulator_value"])

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        torch._foreach_addcmul_([state["accumulator"] for state in states], tensors, tensors)
        self.increment_counter("step", start=0)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensors_ = TensorList(tensors)
        step = self.global_state["step"] # 0 on first apply
        eps, alpha, lr_decay = unpack_dicts(settings, "eps", "alpha", "lr_decay", cls=NumberList)

        accumulator = [state["accumulator"] for state in states]
        accumulator = TensorList(self.inner_step_tensors(
            "accumulator", tensors=accumulator, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

        denom = accumulator.sqrt().add_(eps)
        tensors_ /= denom

        clr = alpha / (1 + step * lr_decay)
        tensors_.lazy_mul_(clr)

        return tensors_

AdagradNorm

Bases: torchzero.core.transform.TensorTransform

Adagrad-Norm, divides by sum of past means of squares of gradients.

Parameters:

  • lr_decay (float, default: 0 ) –

    learning rate decay. Defaults to 0.

  • initial_accumulator_value (float, default: 0 ) –

    initial value of the sum of squares of gradients. Defaults to 0.

  • eps (float, default: 1e-10 ) –

    division epsilon. Defaults to 1e-10.

  • alpha (float, default: 1 ) –

    step size. Defaults to 1.

  • use_sqrt (bool, default: True ) –

    whether to take the root of the accumulator. Defaults to True.

  • inner (Chainable | None, default: None ) –

    Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
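
Example:

A minimal usage sketch, following the same tz.Optimizer chaining pattern as the other examples on this page:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.AdagradNorm(),
    tz.m.LR(1e-2),
)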

Source code in torchzero/modules/adaptive/adagrad.py
class AdagradNorm(TensorTransform):
    """Adagrad-Norm, divides by sum of past means of squares of gradients.

    Args:
        lr_decay (float, optional): learning rate decay. Defaults to 0.
        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
        eps (float, optional): division epsilon. Defaults to 1e-10.
        alpha (float, optional): step size. Defaults to 1.
        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
    """
    def __init__(
        self,
        lr_decay: float = 0,
        initial_accumulator_value: float = 0,
        eps: float = 1e-10,
        beta:float | None = None,
        beta_debias: bool = True,
        layerwise: bool = True,
        use_sqrt: bool = True,
        alpha: float = 1,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner']
        super().__init__(defaults=defaults, inner=inner)

    @torch.no_grad
    def multi_tensor_initialize(self, tensors, params, grads, loss, states, settings):

        # layerwise initialize in each state
        if settings[0]["layerwise"]:
            for tensor, state, setting in zip(tensors, states, settings):

                initial_accumulator_value = setting["initial_accumulator_value"]
                state["accumulator"] = torch.tensor(initial_accumulator_value, device=tensor.device, dtype=tensor.dtype)

        # global initialize in global state
        else:
            initial_accumulator_value = settings[0]["initial_accumulator_value"]
            tensor = tensors[0]
            self.global_state["accumulator"] = torch.tensor(initial_accumulator_value, device=tensor.device, dtype=tensor.dtype)

    def _get_accumulator(self, states, settings) -> torch.Tensor | TensorList:
        layerwise = settings[0]["layerwise"]
        if layerwise:
            return TensorList(s["accumulator"] for s in states)

        return self.global_state["accumulator"]

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        accumulator = self._get_accumulator(states, settings)
        self.increment_counter("step", start=0)

        # compute squared gradient norm (gg)
        if isinstance(accumulator, TensorList): gg = tensors.tensorwise_dot(tensors)
        else: gg = tensors.dot(tensors)

        # update the accumulator
        beta = settings[0]["beta"]
        if beta is None: accumulator.add_(gg) # pyright:ignore[reportArgumentType]
        else: accumulator.lerp_(gg, weight=1-beta) # pyright:ignore[reportArgumentType, reportCallIssue]

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        accumulator = self._get_accumulator(states, settings)
        eps, alpha, lr_decay = unpack_dicts(settings, "eps", "alpha", "lr_decay", cls=NumberList)
        step = self.global_state["step"] # 0 on 1st step
        fs = settings[0]
        beta = fs["beta"]

        # ------------------------ debias if beta is not None ------------------------ #
        if fs["beta_debias"] and beta is not None:
            accumulator = accumulator / (1 - beta ** (step + 1))


        # ---------------------------- compute denominator --------------------------- #
        if fs["use_sqrt"]:
            denom = accumulator.sqrt().add_(eps) # pyright:ignore[reportArgumentType]
        else:
            denom = accumulator + eps # pyright:ignore[reportOperatorIssue]


        # ---------------------------- compute the update ---------------------------- #
        tensors /= denom
        clr = alpha / (1 + step * lr_decay) # lr decay
        tensors.lazy_mul_(clr)

        return tensors

Adam

Bases: torchzero.core.transform.TensorTransform

Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.

This implementation is identical to torch.optim.Adam.

Parameters:

  • beta1 (float, default: 0.9 ) –

    momentum. Defaults to 0.9.

  • beta2 (float, default: 0.999 ) –

    second momentum. Defaults to 0.999.

  • eps (float, default: 1e-08 ) –

    epsilon. Defaults to 1e-8.

  • alpha (float, default: 1.0 ) –

    learning rate. Defaults to 1.

  • amsgrad (bool, default: False ) –

    Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.

  • pow (float) –

    power used in second momentum power and root. Defaults to 2.

  • debias (bool, default: True ) –

    whether to apply debiasing to momentums based on current step. Defaults to True.
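
Example:

A minimal usage sketch, following the same tz.Optimizer chaining pattern as the other examples on this page:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
)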

Source code in torchzero/modules/adaptive/adam.py
class Adam(TensorTransform):
    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.

    This implementation is identical to :code:`torch.optim.Adam`.

    Args:
        beta1 (float, optional): momentum. Defaults to 0.9.
        beta2 (float, optional): second momentum. Defaults to 0.999.
        eps (float, optional): epsilon. Defaults to 1e-8.
        alpha (float, optional): learning rate. Defaults to 1.
        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
        pow (float, optional): power used in second momentum power and root. Defaults to 2.
        debias (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
    """
    def __init__(
        self,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        amsgrad: bool = False,
        alpha: float = 1.,
        debias: bool = True,

        exp_avg_tfm: Chainable | None = None,
        exp_avg_sq_tfm: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults["exp_avg_tfm"], defaults["exp_avg_sq_tfm"]
        super().__init__(defaults)

        self.set_child('exp_avg', exp_avg_tfm)
        self.set_child('exp_avg_sq', exp_avg_sq_tfm)

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        self.increment_counter("step", start=0)
        beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)

        # ----------------------------- initialize states ---------------------------- #
        if settings[0]["amsgrad"]:
            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(
                states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        # ------------------------------ update moments ------------------------------ #
        exp_avg.lerp_(tensors, weight=1-beta1)
        exp_avg_sq.mul_(beta2).addcmul_(tensors, tensors, value=1-beta2)

        if max_exp_avg_sq is not None:
            max_exp_avg_sq.maximum_(exp_avg_sq)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state["step"] # 0 on 1st step
        fs = settings[0]

        if fs["amsgrad"]: key = "max_exp_avg_sq"
        else: key = "exp_avg_sq"
        exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', key, cls=TensorList)
        beta1, beta2, alpha, eps = unpack_dicts(settings, 'beta1', 'beta2', 'alpha', 'eps', cls=NumberList)

        # -------------------------------- transforms -------------------------------- #
        exp_avg = TensorList(self.inner_step_tensors(
            "exp_avg", tensors=exp_avg, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

        exp_avg_sq = TensorList(self.inner_step_tensors(
            "exp_avg_sq", tensors=exp_avg_sq, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

        # ---------------------------------- debias ---------------------------------- #
        if fs["debias"]:
            alpha = debiased_step_size((step + 1), beta1=beta1, beta2=beta2, alpha=alpha)
            exp_avg = exp_avg * alpha

        # ---------------------------------- update ---------------------------------- #
        return exp_avg / exp_avg_sq.sqrt().add_(eps)

Adan

Bases: torchzero.core.transform.TensorTransform

Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677

Parameters:

  • beta1 (float, default: 0.98 ) –

    momentum. Defaults to 0.98.

  • beta2 (float, default: 0.92 ) –

    momentum for gradient differences. Defaults to 0.92.

  • beta3 (float, default: 0.99 ) –

    third (squared) momentum. Defaults to 0.99.

  • eps (float, default: 1e-08 ) –

    epsilon. Defaults to 1e-8.

Example:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adan(),
    tz.m.LR(1e-3),
)
Reference: Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence.

Source code in torchzero/modules/adaptive/adan.py
class Adan(TensorTransform):
    """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677

    Args:
        beta1 (float, optional): momentum. Defaults to 0.98.
        beta2 (float, optional): momentum for gradient differences. Defaults to 0.92.
        beta3 (float, optional): thrid (squared) momentum. Defaults to 0.99.
        eps (float, optional): epsilon. Defaults to 1e-8.

    Example:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adan(),
        tz.m.LR(1e-3),
    )
    ```
    Reference:
        [Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence](https://arxiv.org/abs/2208.06677).
    """
    def __init__(
        self,
        beta1: float = 0.98,
        beta2: float = 0.92,
        beta3: float = 0.99,
        eps: float = 1e-8,

        m_tfm: Chainable | None = None,
        v_tfm: Chainable | None = None,
        n_tfm: Chainable | None = None,
    ):
        defaults=dict(beta1=beta1, beta2=beta2, beta3=beta3, eps=eps)
        super().__init__(defaults, uses_grad=False)

        self.set_child("m", m_tfm)
        self.set_child("v", v_tfm)
        self.set_child("n", n_tfm)

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        step = self.increment_counter("step", start=0)

        beta1, beta2, beta3 = unpack_dicts(settings, 'beta1','beta2','beta3', cls=NumberList)
        g_prev, m, v, n = unpack_states(states, tensors, 'g_prev', 'm', 'v', 'n', cls=TensorList)

        adan_update_(g=tensors, g_prev_=g_prev, m_=m, v_=v, n_=n, beta1=beta1, beta2=beta2, beta3=beta3, step=step+1)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        step = self.global_state["step"] # 0 on 1st step

        beta1, beta2, beta3, eps = unpack_dicts(settings, 'beta1','beta2','beta3', 'eps', cls=NumberList)
        m, v, n = unpack_states(states, tensors, 'm', 'v', 'n')

        # -------------------------------- transforms -------------------------------- #
        m = TensorList(self.inner_step_tensors("m", m, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
        v = TensorList(self.inner_step_tensors("v", v, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
        n = TensorList(self.inner_step_tensors("n", n, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

        # ---------------------------------- update ---------------------------------- #
        return adan_apply_(m_=m, v_=v, n_=n, beta1=beta1, beta2=beta2, beta3=beta3, eps=eps, step=step+1)

AdaptiveHeavyBall

Bases: torchzero.core.transform.TensorTransform

Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.

Suitable for quadratic objectives with known f* (loss at minimum).

Note

The step size is determined by the algorithm, so learning rate modules shouldn't be used.

Parameters:

  • f_star (float, default: 0 ) –

    (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
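
Example:

A minimal sketch for a quadratic objective with known f* = 0; per the note above, no learning rate module is chained:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.AdaptiveHeavyBall(f_star=0),
)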

Source code in torchzero/modules/adaptive/adaptive_heavyball.py
class AdaptiveHeavyBall(TensorTransform):
    """Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.

    Suitable for quadratic objectives with known f* (loss at minimum).

    note:
        The step size is determined by the algorithm, so learning rate modules shouldn't be used.

    Args:
        f_star (int, optional):
            (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
    """
    def __init__(self, f_star: float = 0):
        defaults = dict(f_star=f_star)
        super().__init__(defaults, uses_loss=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert loss is not None
        tensors = TensorList(tensors)
        f_star = settings[0]['f_star']

        f_prev = self.global_state.get('f_prev', None)
        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params,tensors], cls=TensorList)

        # -------------------------------- first step -------------------------------- #
        if f_prev is None:
            self.global_state['f_prev'] = loss
            h = 2*(loss - f_star) / tensors.dot(tensors)
            return h * tensors

        # ------------------------------- further steps ------------------------------ #
        update = adaptive_heavy_ball(
            f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)

        # --------------------------- store previous values -------------------------- #
        self.global_state['f_prev'] = loss
        p_prev.copy_(params)
        g_prev.copy_(tensors)

        return update

BacktrackOnSignChange

Bases: torchzero.core.transform.TensorTransform

Negates or undoes update for parameters where gradient or update sign changes.

This is part of RProp update rule.

Parameters:

  • use_grad (bool, default: False ) –

    if True, tracks sign change of the gradient, otherwise tracks sign change of the update. Defaults to False.

  • backtrack (bool, default: True ) –

    if True, undoes the update when sign changes, otherwise negates it. Defaults to True.
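
Example:

A hypothetical sketch placing the module after a momentum transform; the surrounding modules and their order are illustrative assumptions, not taken from the source:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.EMA(0.9),
    tz.m.BacktrackOnSignChange(),
    tz.m.LR(1e-2),
)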

Source code in torchzero/modules/adaptive/rprop.py
class BacktrackOnSignChange(TensorTransform):
    """Negates or undoes update for parameters where where gradient or update sign changes.

    This is part of RProp update rule.

    Args:
        use_grad (bool, optional):
            if True, tracks sign change of the gradient,
            otherwise track sign change of the update. Defaults to True.
        backtrack (bool, optional):
            if True, undoes the update when sign changes, otherwise negates it.
            Defaults to True.

    """
    def __init__(self, use_grad = False, backtrack = True):
        defaults = dict(use_grad=use_grad, backtrack=backtrack)
        super().__init__(defaults, uses_grad=use_grad)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        tensors = TensorList(tensors)
        backtrack = settings[0]['backtrack']

        if self._uses_grad:
            assert grads is not None
            cur = TensorList(grads)
        else: cur = tensors

        tensors = backtrack_on_sign_change_(
            tensors_ = tensors,
            cur = cur,
            prev_ = unpack_states(states, tensors, 'prev', cls=TensorList),
            backtrack = backtrack,
            step = step,
        )

        return tensors

DualNormCorrection

Bases: torchzero.core.transform.TensorTransform

Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon). Orthogonalize already has this built in with the dual_norm_correction setting.

Source code in torchzero/modules/adaptive/muon.py
class DualNormCorrection(TensorTransform):
    """Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).
    Orthogonalize already has this built in with the `dual_norm_correction` setting."""
    def __init__(self, channel_first: bool = True):
        defaults = dict(channel_first=channel_first)
        super().__init__(defaults)

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        assert grad is not None
        if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
            return _dual_norm_correction(tensor, grad, channel_first=setting["channel_first"])
        return tensor

ESGD

Bases: torchzero.core.transform.Transform

Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

This is similar to Adagrad, but it accumulates squared randomized hessian diagonal estimates instead of squared gradients.

Notes
  • In most cases ESGD should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply ESGD preconditioning to another module's output.

  • This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

Parameters:

  • damping (float, default: 0.0001 ) –

    added to denominator for stability. Defaults to 1e-4.

  • update_freq (int, default: 20 ) –

    frequency of updating hessian diagonal estimate via a hessian-vector product. This value can be increased to reduce computational cost. Defaults to 20.

  • hvp_method (str, default: 'autograd' ) –

    Determines how hessian-vector products are computed.

    • "batched_autograd" - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to "autograd". Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than "batched_autograd" but uses less memory.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    Defaults to "autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of hessian-vector products with random vectors to evaluate each time when updating the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.

  • seed (int | None, default: None ) –

    seed for random vectors. Defaults to None.

  • inner (Chainable | None, default: None ) –

    Inner module. If this is specified, operations are performed in the following order. 1. compute hessian diagonal estimate. 2. pass inputs to inner. 3. momentum and preconditioning are applied to the outputs of inner.

Examples:

Using ESGD:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.ESGD(),
    tz.m.LR(0.1)
)

ESGD preconditioner can be applied to any other module by passing it to the inner argument. Here is an example of applying ESGD preconditioning to nesterov momentum (tz.m.NAG):

opt = tz.Optimizer(
    model.parameters(),
    tz.m.ESGD(inner=tz.m.NAG(0.9)),
    tz.m.LR(0.1)
)
Source code in torchzero/modules/adaptive/esgd.py
class ESGD(Transform):
    """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

    This is similar to Adagrad, but the accumulates squared randomized hessian diagonal estimates instead of squared gradients.

    Notes:
        - In most cases ESGD should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply ESGD preconditioning to another module's output.

        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        damping (float, optional): added to denominator for stability. Defaults to 1e-4.
        update_freq (int, optional):
            frequency of updating hessian diagonal estimate via a hessian-vector product.
            This value can be increased to reduce computational cost. Defaults to 20.
        hvp_method (str, optional):
            Determines how hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
        seed (int | None, optional): seed for random vectors. Defaults to None.
        inner (Chainable | None, optional):
            Inner module. If this is specified, operations are performed in the following order.
            1. compute hessian diagonal estimate.
            2. pass inputs to :code:`inner`.
            3. momentum and preconditioning are applied to the ouputs of :code:`inner`.

    ### Examples:

    Using ESGD:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ESGD(),
        tz.m.LR(0.1)
    )
    ```

    ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
    ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ESGD(inner=tz.m.NAG(0.9)),
        tz.m.LR(0.1)
    )
    ```

    """
    def __init__(
        self,
        damping: float = 1e-4,
        update_freq: int = 20,
        distribution: Distributions = 'gaussian',
        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        n_samples = 1,
        zHz: bool = False,
        seed: int | None = None,
        beta: float | None = None,
        beta_debias: bool = True,

        inner: Chainable | None = None,
        Hz_sq_acc_tfm: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["Hz_sq_acc_tfm"]
        super().__init__(defaults, inner=inner)

        self.set_child("Hz_sq_acc", Hz_sq_acc_tfm)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        params = objective.params

        fs = settings[0]
        update_freq = fs['update_freq']

        # ------------------------------- accumulate Hz ------------------------------ #
        step = self.increment_counter("step", start=0)

        if step % update_freq == 0:
            self.increment_counter("num_Hzs", start=1)

            Hz, _ = objective.hutchinson_hessian(
                rgrad = None,
                at_x0 = True,
                n_samples = fs['n_samples'],
                distribution = fs['distribution'],
                hvp_method = fs['hvp_method'],
                h = fs['h'],
                zHz = fs["zHz"], # default is False, so it returns Hz, not z⊙Hz
                generator = self.get_generator(params[0].device, fs["seed"]),
            )

            Hz = TensorList(Hz)
            Hz_sq_acc = unpack_states(states, params, 'Hz_sq_acc', cls=TensorList)

            beta = fs["beta"]
            if beta is None:
                Hz_sq_acc.addcmul_(Hz, Hz)

            else:
                Hz_sq_acc.mul_(beta).addcmul_(Hz, Hz, value=1-beta)

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        tensors = TensorList(objective.get_updates())
        Hz_sq_acc = unpack_states(states, tensors, 'Hz_sq_acc', cls=TensorList)
        num_Hzs = self.global_state["num_Hzs"]
        fs = settings[0]

        # ---------------------------------- debias ---------------------------------- #
        beta = fs["beta"]
        beta_debias = fs["beta_debias"]

        if beta_debias and beta is not None:
            bias_correction = 1.0 - beta ** num_Hzs
            Hz_sq_acc = Hz_sq_acc / bias_correction

        else:
            Hz_sq_acc = Hz_sq_acc / num_Hzs

        # ---------------------------------- update ---------------------------------- #
        damping = [s["damping"] for s in settings]

        denom = (Hz_sq_acc / num_Hzs).sqrt_().add_(damping)

        objective.updates = tensors.div_(denom)
        return objective

FullMatrixAdagrad

Bases: torchzero.core.transform.TensorTransform

Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

Note

A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in tz.m.GGT.

Parameters:

  • reg (float, default: 1e-12 ) –

    regularization, scale of identity matrix added to accumulator. Defaults to 1e-12.

  • precond_freq (int, default: 1 ) –

    frequency of updating the inverse square root of the accumulator. Defaults to 1.

  • beta (float | None, default: None ) –

    momentum for gradient outer product accumulators. if None, uses sum. Defaults to None.

  • beta_debias (bool, default: True ) –

    whether to use debiasing, only has effect when beta is not None. Defaults to True.

  • init (Literal[str], default: 'identity' ) –

    how to initialize the accumulator.

    • "identity" - with identity matrix (default).
    • "zeros" - with zero matrix.
    • "GGT" - with the first gradient outer product.

  • matrix_power (float, default: -0.5 ) –

    accumulator matrix power. Defaults to -1/2.

  • concat_params (bool, default: True ) –

    if False, each parameter will have it's own accumulator. Defaults to True.

  • inner (Chainable | None, default: None ) –

    inner modules to apply preconditioning to. Defaults to None.

Examples:

Plain full-matrix adagrad

opt = tz.Optimizer(
    model.parameters(),
    tz.m.FullMatrixAdagrad(),
    tz.m.LR(1e-2),
)

Full-matrix RMSprop

opt = tz.Optimizer(
    model.parameters(),
    tz.m.FullMatrixAdagrad(beta=0.99),
    tz.m.LR(1e-2),
)

Full-matrix Adam

opt = tz.Optimizer(
    model.parameters(),
    tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
    tz.m.Debias(0.9, 0.999),
    tz.m.LR(1e-2),
)

Source code in torchzero/modules/adaptive/adagrad.py
class FullMatrixAdagrad(TensorTransform):
    """Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

    Note:
        A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.GGT``.

    Args:
        reg (float, optional): regularization, scale of identity matrix added to accumulator. Defaults to 1e-12.
        precond_freq (int, optional): frequency of updating the inverse square root of the accumulator. Defaults to 1.
        beta (float | None, optional): momentum for gradient outer product accumulators. if None, uses sum. Defaults to None.
        beta_debias (bool, optional): whether to use debiasing, only has effect when ``beta`` is not ``None``. Defaults to True.
        init (Literal[str], optional):
            how to initialize the accumulator.
            - "identity" - with identity matrix (default).
            - "zeros" - with zero matrix.
            - "ones" - with matrix of ones.
             -"GGT" - with the first outer product
        matrix_power (float, optional): accumulator matrix power. Defaults to -1/2.
        concat_params (bool, optional): if False, each parameter will have it's own accumulator. Defaults to True.
        inner (Chainable | None, optional): inner modules to apply preconditioning to. Defaults to None.

    ## Examples:

    Plain full-matrix adagrad
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.FullMatrixAdagrad(),
        tz.m.LR(1e-2),
    )
    ```

    Full-matrix RMSprop
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.FullMatrixAdagrad(beta=0.99),
        tz.m.LR(1e-2),
    )
    ```

    Full-matrix Adam
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
        tz.m.Debias(0.9, 0.999),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(
        self,
        reg: float = 1e-12,
        precond_freq: int = 1,
        beta: float | None = None,
        beta_debias: bool=True,
        init: Literal["identity", "zeros", "GGT"] = "identity",
        matrix_power: float = -1/2,
        matrix_power_method: MatrixPowerMethod = "eigh_abs",
        concat_params=True,

        inner: Chainable | None = None,
        accumulator_tfm: Chainable | None = None
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["concat_params"], defaults["accumulator_tfm"]
        super().__init__(defaults=defaults, inner=inner, concat_params=concat_params)

        self.set_child("accumulator", accumulator_tfm)

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):

        G = tensor.ravel()
        GGT = torch.outer(G, G)

        # initialize
        if "accumulator" not in state:
            init = setting['init']
            if init == 'identity': state['accumulator'] = torch.eye(GGT.size(0), device=GGT.device, dtype=GGT.dtype)
            elif init == 'zeros': state['accumulator'] =  torch.zeros_like(GGT)
            elif init == 'GGT': state['accumulator'] = GGT.clone()
            else: raise ValueError(init)

        # update
        beta = setting['beta']
        accumulator: torch.Tensor = state["accumulator"]

        if beta is None: accumulator.add_(GGT)
        else: accumulator.lerp_(GGT, 1-beta)

        # update number of GGᵀ in accumulator for divide
        state['num_GGTs'] = state.get('num_GGTs', 0) + 1

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        step = state.get('step', 0)
        state['step'] = step + 1

        accumulator: torch.Tensor = state['accumulator']
        accumulator = self.inner_step_tensors("accumulator", [accumulator], clone=True, must_exist=False)[0]

        precond_freq = setting['precond_freq']
        reg = setting['reg']
        beta = setting["beta"]

        # add regularizer
        if reg != 0:
            device = accumulator.device; dtype = accumulator.dtype
            accumulator = accumulator + torch.eye(accumulator.size(0), device=device, dtype=dtype).mul_(reg)

        # for single value use sqrt
        if tensor.numel() == 1:
            dir = tensor.mul_(accumulator.squeeze() ** setting["matrix_power"])

        # otherwise use matrix inverse square root
        else:

            # compute inverse square root and store to state
            try:
                if "B" not in state or step % precond_freq == 0:
                    B = state["B"] = _matrix_power(accumulator, setting["matrix_power"], method=setting["matrix_power_method"])
                else:
                    B = state["B"]

                dir = (B @ tensor.ravel()).view_as(tensor)

            # fallback to diagonal Adagrad on fail
            except torch.linalg.LinAlgError:
                dir = tensor.mul_(accumulator.diagonal() ** setting["matrix_power"])

        # debias
        if setting["beta_debias"] and beta is not None:
            num_GGTs = state.get('num_GGTs', 1)
            bias_correction = 1 - beta ** num_GGTs
            dir *= bias_correction ** 0.5

        return dir
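
For intuition, with `beta=None`, `reg=0` and the default `matrix_power=-1/2`, the module above implements the textbook full-matrix Adagrad rule. A minimal standalone sketch of that rule (illustrative only, not using torchzero's API):

```python
import torch

def full_matrix_adagrad_step(x, g, A, lr=1e-2, eps=1e-12):
    """One full-matrix Adagrad step on a flat parameter vector x with gradient g."""
    A += torch.outer(g, g)                                 # accumulate G G^T
    eye = torch.eye(len(x), dtype=x.dtype, device=x.device)
    L, Q = torch.linalg.eigh(A + eps * eye)                # eigendecompose the accumulator
    P = (Q * L.clamp(min=eps).rsqrt()) @ Q.T               # A^{-1/2}
    return x - lr * (P @ g)

# usage sketch on a toy quadratic
x = torch.zeros(3)
A = torch.zeros(3, 3)
for _ in range(10):
    g = 2 * (x - torch.tensor([1.0, -2.0, 3.0]))           # gradient of ||x - c||^2
    x = full_matrix_adagrad_step(x, g, A)
```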

GGT

Bases: torchzero.core.transform.TensorTransform

GGT method from https://arxiv.org/pdf/1806.02958

The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate update as U S^-1 Uᵀg. But it uses eigendecomposition on MᵀM to get U and S^2 because that is faster when you don't need V.

This is equivalent to full-matrix Adagrad on recent gradients.
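
A standalone sketch of that trick, assuming gradients are stored as the columns of M (illustrative only; the actual `ggt_update` additionally handles damping, truncation, and the component of g outside the span of the history):

```python
import torch

def ggt_whiten(M: torch.Tensor, g: torch.Tensor, tol: float = 1e-7) -> torch.Tensor:
    """M: (ndim, history) recent gradients as columns, g: (ndim,) current gradient."""
    S2, V = torch.linalg.eigh(M.T @ M)     # small (history x history) eigenproblem
    keep = S2 > S2.max() * tol             # drop tiny / numerically negative eigenvalues
    S = S2[keep].sqrt()
    U = (M @ V[:, keep]) / S               # recover left singular vectors of M
    return U @ ((U.T @ g) / S)             # U S^-1 U^T g
```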

Parameters:

  • history_size (int, default: 100 ) –

    number of past gradients to store. Defaults to 100.

  • beta (float) –

    beta for momentum maintained in whitened space. Defaults to 0.0.

  • update_freq (int, default: 1 ) –

    frequency of updating the preconditioner (U and S). Defaults to 1.

  • eig_tol (float, default: 1e-07 ) –

    removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.

  • truncate (int, default: None ) –

    number of largest eigenvalues to keep. None to disable. Defaults to None.

  • damping (float, default: 0.0001 ) –

    damping value. Defaults to 1e-4.

  • rdamping (float, default: 0 ) –

    value of damping relative to largest eigenvalue. Defaults to 0.

  • concat_params (bool, default: True ) –

    if True, treats all parameters as a single vector. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioner will be applied to output of this module. Defaults to None.

Examples:

Limited-memory Adagrad

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.GGT(),
    tz.m.LR(0.1)
)
Adam with L-Adagrad preconditioner (for debiasing, the second beta is set to 0.999 arbitrarily)

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.GGT(inner=tz.m.EMA()),
    tz.m.Debias(0.9, 0.999),
    tz.m.LR(0.01)
)

Stable Adam with L-Adagrad preconditioner (this is what I would recommend)

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.GGT(inner=tz.m.EMA()),
    tz.m.Debias(0.9, 0.999),
    tz.m.ClipNormByEMA(max_ema_growth=1.2),
    tz.m.LR(0.01)
)
Reference: Agarwal N. et al. Efficient full-matrix adaptive regularization // International Conference on Machine Learning. PMLR, 2019. pp. 102-110.

Source code in torchzero/modules/adaptive/ggt.py
class GGT(TensorTransform):
    """
    GGT method from https://arxiv.org/pdf/1806.02958

    The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate update as U S^-1 Uᵀg.
    But it uses eigendecomposition on MᵀM to get U and S^2 because that is faster when you don't need V.

    This is equivalent to full-matrix Adagrad on recent gradients.

    Args:
        history_size (int, optional): number of past gradients to store. Defaults to 100.
        beta (float, optional): beta for momentum maintained in whitened space. Defaults to 0.0.
        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
        eig_tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
        truncate (int, optional): number of largest eigenvalues to keep. None to disable. Defaults to None.
        damping (float, optional): damping value. Defaults to 1e-4.
        rdamping (float, optional): value of damping relative to largest eigenvalue. Defaults to 0.
        concat_params (bool, optional): if True, treats all parameters as a single vector. Defaults to True.
        inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.

    ## Examples:

    Limited-memory Adagrad

    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.GGT(),
        tz.m.LR(0.1)
    )
    ```
    Adam with L-Adagrad preconditioner (for debiasing, the second beta is set to 0.999 arbitrarily)

    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.GGT(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.LR(0.01)
    )
    ```

    Stable Adam with L-Adagrad preconditioner (this is what I would recommend)

    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.GGT(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.ClipNormByEMA(max_ema_growth=1.2),
        tz.m.LR(0.01)
    )
    ```
    Reference:
        Agarwal N. et al. Efficient full-matrix adaptive regularization // International Conference on Machine Learning. PMLR, 2019. pp. 102-110.
    """

    def __init__(
        self,
        history_size: int = 100,
        update_freq: int = 1,
        eig_tol: float = 1e-7,
        truncate: int | None = None,
        damping: float = 1e-4,
        rdamping: float = 0,
        eigenbasis_optimizer: LREOptimizerBase | None = None,
        concat_params: bool = True,

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults['concat_params']

        super().__init__(defaults, concat_params=concat_params, inner=inner)

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        history_size = setting['history_size']
        update_freq = setting['update_freq']

        if 'history' not in state: state['history'] = deque(maxlen=history_size)
        history = state['history']

        t = tensor.clone().view(-1)
        history.append(t)

        step = state.get('step', 0)
        state['step'] = step + 1

        if step % update_freq == 0 :

            # compute new factors
            L = state.get("L", None)
            U = state.get("U", None)

            L_new, U_new = ggt_update(
                history,
                damping=setting["damping"],
                rdamping=setting["rdamping"],
                truncate=setting["truncate"],
                eig_tol=setting["eig_tol"],
            )

            # reproject eigenbasis optimizer
            eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
            if eigenbasis_optimizer is not None:
                if (L is not None) and (U is not None) and (L_new is not None) and (U_new is not None):
                    eigenbasis_state = state["eigenbasis_state"]
                    eigenbasis_optimizer.reproject(L_old=L, Q_old=U, L_new=L_new, Q_new=U_new, state=eigenbasis_state)


            # store new factors
            if L_new is not None: state["L"] = L_new
            if U_new is not None: state["U"] = U_new


    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        g = tensor.view(-1)
        U = state.get('U', None)

        if U is None:
            # fallback to element-wise preconditioning
            history = torch.stack(tuple(state["history"]), 0)
            g /= history.square().mean(0).sqrt().add(1e-8)
            return g.view_as(tensor)

        L = state['L']

        # step with eigenbasis optimizer
        eigenbasis_optimizer: LREOptimizerBase | None = setting["eigenbasis_optimizer"]
        if eigenbasis_optimizer is not None:

            if "eigenbasis_state" not in state: state["eigenbasis_state"] = {}
            eigenbasis_state = state["eigenbasis_state"]

            update = eigenbasis_optimizer.step(g, L=L, Q=U, state=eigenbasis_state)
            return update.view_as(tensor)

        # or just whiten
        z = U.T @ g
        update = (U * L.rsqrt()) @ z
        return update.view_as(tensor)

Lion

Bases: torchzero.core.transform.TensorTransform

Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

Parameters:

  • beta1 (float, default: 0.9 ) –

    dampening for momentum. Defaults to 0.9.

  • beta2 (float, default: 0.99 ) –

    momentum factor. Defaults to 0.99.
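
For reference, the update rule from the paper is the sign of an interpolation between the momentum buffer and the gradient, with the buffer updated afterwards. A minimal sketch (`lion_` in the source below is assumed to implement the same computation; the learning rate and weight decay are handled by separate modules such as tz.m.LR and tz.m.WeightDecay):

```python
import torch

def lion_direction(g: torch.Tensor, exp_avg: torch.Tensor, beta1: float = 0.9, beta2: float = 0.99) -> torch.Tensor:
    """Returns the Lion update direction and updates the momentum buffer in-place."""
    direction = torch.sign(exp_avg * beta1 + g * (1 - beta1))
    exp_avg.mul_(beta2).add_(g, alpha=1 - beta2)
    return direction
```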

Source code in torchzero/modules/adaptive/lion.py
class Lion(TensorTransform):
    """Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

    Args:
        beta1 (float, optional): dampening for momentum. Defaults to 0.9.
        beta2 (float, optional): momentum factor. Defaults to 0.99.
    """

    def __init__(self, beta1: float = 0.9, beta2: float = 0.99):
        defaults = dict(beta1=beta1, beta2=beta2)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
        exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
        return lion_(TensorList(tensors), exp_avg, beta1, beta2)

MARSCorrection

Bases: torchzero.core.transform.TensorTransform

MARS variance reduction correction.

Place any other momentum-based optimizer after this module, and make sure its beta parameter matches the momentum used in that optimizer.
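
A rough sketch of what the correction computes, based on the approximate MARS formulation (the exact behaviour of `mars_correction_` in the source below may differ in details):

```python
import torch

def mars_correct(g: torch.Tensor, g_prev: torch.Tensor, beta: float = 0.9,
                 scaling: float = 0.025, max_norm: float | None = 1.0) -> torch.Tensor:
    """Variance-reduced gradient: add a scaled difference of consecutive gradients, then clip."""
    c = g + scaling * (beta / (1 - beta)) * (g - g_prev)
    if max_norm is not None:
        norm = c.norm()
        if norm > max_norm:
            c = c * (max_norm / norm)
    return c
```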

Parameters:

  • beta (float, default: 0.9 ) –

    use the same beta as you use in the momentum module. Defaults to 0.9.

  • scaling (float, default: 0.025 ) –

    controls the scale of gradient correction in variance reduction. Defaults to 0.025.

  • max_norm (float, default: 1 ) –

    clips norm of corrected gradients, None to disable. Defaults to 1.

Examples:

Mars-AdamW

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.MARSCorrection(beta=0.95),
    tz.m.Adam(beta1=0.95, beta2=0.99),
    tz.m.WeightDecay(1e-3),
    tz.m.LR(0.1)
)

Mars-Lion

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.MARSCorrection(beta=0.9),
    tz.m.Lion(beta1=0.9),
    tz.m.LR(0.1)
)

Source code in torchzero/modules/adaptive/mars.py
class MARSCorrection(TensorTransform):
    """MARS variance reduction correction.

    Place any other momentum-based optimizer after this module,
    and make sure the ``beta`` parameter matches the momentum used in that optimizer.

    Args:
        beta (float, optional): use the same beta as you use in the momentum module. Defaults to 0.9.
        scaling (float, optional): controls the scale of gradient correction in variance reduction. Defaults to 0.025.
        max_norm (float, optional): clips norm of corrected gradients, None to disable. Defaults to 1.

    ## Examples:

    Mars-AdamW
    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.MARSCorrection(beta=0.95),
        tz.m.Adam(beta1=0.95, beta2=0.99),
        tz.m.WeightDecay(1e-3),
        tz.m.LR(0.1)
    )
    ```

    Mars-Lion
    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.MARSCorrection(beta=0.9),
        tz.m.Lion(beta1=0.9),
        tz.m.LR(0.1)
    )
    ```

    """
    def __init__(
        self,
        beta: float = 0.9,
        scaling: float = 0.025,
        max_norm: float | None = 1,
    ):
        defaults = dict(beta=beta, scaling=scaling, max_norm=max_norm)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
        beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
        max_norm = settings[0]['max_norm']

        return mars_correction_(
            tensors_=TensorList(tensors),
            prev_=prev,
            beta=beta,
            scaling=scaling,
            max_norm=max_norm,
        )

MSAM

Bases: torchzero.core.transform.Transform

Momentum-SAM from https://arxiv.org/pdf/2401.12033.

Note

Please make sure to place tz.m.LR inside the modules argument. For example, tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)]). Putting LR after MSAM will lead to an incorrect update rule.

Parameters:

  • modules (Chainable) –

    modules that will optimize the MSAM objective. Make sure tz.m.LR is one of them.

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • rho (float, default: 0.3 ) –

    perturbation strength. Defaults to 0.3.

  • nesterov (bool, default: False ) –

    whether to use nesterov momentum formula. Defaults to False.

  • lerp (bool, default: False ) –

    whether to use linear interpolation, if True, MSAM momentum becomes similar to exponential moving average. Defaults to False.

Examples:

AdamW-MSAM

opt = tz.Optimizer(
    bench.parameters(),
    tz.m.MSAMObjective(
        [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
        rho=1.
    )
)
Source code in torchzero/modules/adaptive/msam.py
class MSAM(Transform):
    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

    Note:
        Please make sure to place ``tz.m.LR`` inside the ``modules`` argument. For example,
        ``tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])``. Putting LR after MSAM will lead
        to an incorrect update rule.

    Args:
        modules (Chainable): modules that will optimize the MSAM objective. Make sure ``tz.m.LR`` is one of them.
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        rho (float, optional): perturbation strength. Defaults to 0.3.
        weight_decay (float, optional):
            weight decay. It is applied to the perturbed parameters, so it is different
            from applying ``tz.m.WeightDecay`` after MSAM. Defaults to 0.
        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
        lerp (bool, optional):
            whether to use linear interpolation, if True, MSAM momentum becomes similar to exponential moving average.
            Defaults to False.

    Examples:
    AdamW-MSAM

    ```py
    opt = tz.Optimizer(
        bench.parameters(),
        tz.m.MSAMObjective(
            [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
            rho=1.
        )
    )
    ```
    """
    def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
        defaults = dict(momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
        super().__init__(defaults)

        self.set_child('modules', modules)


    @torch.no_grad
    def apply_states(self, objective, states, settings):
        velocity = unpack_states(states, objective.params, 'velocity', cls=TensorList)
        fs = settings[0]

        momentum, rho, weight_decay = unpack_dicts(settings, 'momentum', 'rho', 'weight_decay', cls=NumberList)

        return msam_(
            TensorList(objective.get_updates()),
            params=TensorList(objective.params),
            velocity_=velocity,
            momentum=momentum,
            lr=None,
            rho=rho,
            weight_decay=weight_decay,
            nesterov=fs['nesterov'],
            lerp=fs['lerp'],

            # inner args
            inner=self.children["modules"],
            objective=objective,
        )

MSAMMomentum

Bases: torchzero.core.transform.TensorTransform

Momentum-SAM from https://arxiv.org/pdf/2401.12033.

This implementation expresses the update rule as function of gradient. This way it can be used as a drop-in replacement for momentum strategies in other optimizers.

To combine MSAM with other optimizers in the way done in the official implementation, e.g. to make Adam_MSAM, use tz.m.MSAMObjective module.

Note: MSAM has a learning rate hyperparameter that can't really be removed from the update rule. To avoid compounding learning rate modifications, remove the tz.m.LR module if you had it.

Parameters:

  • lr (float) –

    learning rate. Adding this module adds support for learning rate schedulers.

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • rho (float, default: 0.3 ) –

    perturbation strength. Defaults to 0.3.

  • weight_decay (float, default: 0 ) –

    weight decay. It is applied to perturbed parameters, so it is different from applying tz.m.WeightDecay after MSAM. Defaults to 0.

  • nesterov (bool, default: False ) –

    whether to use nesterov momentum formula. Defaults to False.

  • lerp (bool, default: False ) –

    whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.

Examples:

MSAM

opt = tz.Optimizer(
    model.parameters(),
    tz.m.MSAM(1e-3)
)

Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM. To make Adam_MSAM and such, use the tz.m.MSAMObjective module.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
    tz.m.Debias(0.9, 0.999),
)
Source code in torchzero/modules/adaptive/msam.py
class MSAMMomentum(TensorTransform):
    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

    This implementation expresses the update rule as function of gradient. This way it can be used as a drop-in
    replacement for momentum strategies in other optimizers.

    To combine MSAM with other optimizers in the way done in the official implementation,
    e.g. to make Adam_MSAM, use ``tz.m.MSAMObjective`` module.

    Note:
        MSAM has a learning rate hyperparameter that can't really be removed from the update rule.
        To avoid compounding learning rate modifications, remove the ``tz.m.LR`` module if you had it.

    Args:
        lr (float): learning rate. Adding this module adds support for learning rate schedulers.
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        rho (float, optional): perturbation strength. Defaults to 0.3.
        weight_decay (float, optional):
            weight decay. It is applied to perturbed parameters, so it is different
            from applying :code:`tz.m.WeightDecay` after MSAM. Defaults to 0.
        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
        lerp (bool, optional):
            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.

    ### Examples:

    MSAM

    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.MSAM(1e-3)
    )
    ```

    Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
    To make Adam_MSAM and such, use the ``tz.m.MSAMObjective`` module.

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
        tz.m.Debias(0.9, 0.999),
    )
    ```
    """

    def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3,  weight_decay:float=0, nesterov=False, lerp=False,):
        defaults = dict(lr = lr, momentum=momentum, rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
        fs = settings[0]

        lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)

        return msam_(
            TensorList(tensors),
            params=TensorList(params),
            velocity_=velocity,
            momentum=momentum,
            lr=lr,
            rho=rho,
            weight_decay=weight_decay,
            nesterov=fs['nesterov'],
            lerp=fs['lerp'],

            # inner args
            inner=None,
            objective=None,
        )

MatrixMomentum

Bases: torchzero.core.transform.Transform

Second order momentum method.

Matrix momentum is useful for convex objectives; for some reason it also generalizes very well on elastic-net logistic regression.

Notes
  • mu needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable. I have devised an adaptive version of this - tz.m.AdaptiveMatrixMomentum, and it works well without having to tune mu, however the adaptive version doesn't work on stochastic objectives.

  • In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.

  • This module requires a closure to be passed to the optimizer's step method, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (see the sketch below this list).
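
Since this section has no usage example, here is a hedged sketch of a typical setup (model, loss_fn, X, and y are placeholders; the mu value is illustrative and, as noted above, has to be tuned):

```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.MatrixMomentum(lr=1e-2, mu=0.01),
)

def closure(backward=True):
    loss = loss_fn(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    loss = opt.step(closure)
```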

Parameters:

  • mu (float, default: 0.1 ) –

    this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.

  • hvp_method (str, default: 'autograd' ) –

    Determines how hessian-vector products are computed.

    • "batched_autograd" - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to "autograd". Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than "batched_autograd" but uses less memory.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    Defaults to "autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • inner (Chainable | None, default: None ) –

    the update rule is applied to the output of this module. Defaults to None.

Reference

Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).

Source code in torchzero/modules/adaptive/matrix_momentum.py
class MatrixMomentum(Transform):
    """Second order momentum method.

    Matrix momentum is useful for convex objectives; for some reason it also generalizes very well on elastic-net logistic regression.

    Notes:
        - ``mu`` needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable. I have devised an adaptive version of this - ``tz.m.AdaptiveMatrixMomentum``, and it works well without having to tune ``mu``, however the adaptive version doesn't work on stochastic objectives.

        - In most cases ``MatrixMomentum`` should be the first module in the chain because it relies on autograd.

        - This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument.

    Args:
        mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
        hvp_method (str, optional):
            Determines how hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        inner (Chainable | None, optional): the update rule is applied to the output of this module. Defaults to None.

    Reference:
        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
    """

    def __init__(
        self,
        lr:float,
        mu=0.1,
        hvp_method: HVPMethod = "autograd",
        h: float = 1e-3,
        adaptive:bool = False,
        adapt_freq: int | None = None,

        inner: Chainable | None = None,
    ):
        defaults = dict(lr=lr, mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
        super().__init__(defaults, inner=inner)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('p_prev')

    @torch.no_grad
    def update_states(self, objective, states, settings):
        step = self.increment_counter("step", 0)
        p = TensorList(objective.params)
        p_prev = unpack_states(states, p, 'p_prev', init=p)

        fs = settings[0]
        hvp_method = fs['hvp_method']
        h = fs['h']

        if step > 0:
            s = p - p_prev

            Hs, _ = objective.hessian_vector_product(s, at_x0=True, rgrad=None, hvp_method=hvp_method, h=h, retain_graph=False)
            Hs = [t.detach() for t in Hs]

            self.store(p, ("Hs", "s"), (Hs, s))

            # -------------------------------- adaptive mu ------------------------------- #
            if fs["adaptive"]:
                g = TensorList(objective.get_grads())

                if fs["adapt_freq"] is None:
                    # ---------------------------- deterministic case ---------------------------- #
                    g_prev = unpack_states(states, p, "g_prev", cls=TensorList)
                    y = g - g_prev
                    g_prev.copy_(g)
                    denom = y.global_vector_norm()
                    denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
                    self.global_state["mu_mul"] = s.global_vector_norm() / denom

                else:
                    # -------------------------------- stochastic -------------------------------- #
                    adapt_freq = self.defaults["adapt_freq"]

                    # we start on step 1, and want to adapt on the first step, so use (step - 1)
                    if (step - 1) % adapt_freq == 0:
                        assert objective.closure is not None
                        params = TensorList(objective.params)
                        p_cur = params.clone()

                        # move to previous params and evaluate p_prev with current mini-batch
                        params.copy_(unpack_states(states, p, 'p_prev'))
                        with torch.enable_grad():
                            objective.closure()
                        g_prev = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                        y = g - g_prev

                        # move back to current params
                        params.copy_(p_cur)

                        denom = y.global_vector_norm()
                        denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
                        self.global_state["mu_mul"] = s.global_vector_norm() / denom

        torch._foreach_copy_(p_prev, objective.params)

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        update = TensorList(objective.get_updates())
        lr, mu = unpack_dicts(settings, "lr", 'mu', cls=NumberList)

        if "mu_mul" in self.global_state:
            mu = mu * self.global_state["mu_mul"]

        # --------------------------------- 1st step --------------------------------- #
        # p_prev is not available so make a small step
        step = self.global_state["step"]
        if step == 1:
            if self.defaults["adaptive"]:
                # initialize
                unpack_states(states, objective.params, "g_prev", init=objective.get_grads())

            update.mul_(lr) # separate so that initial_step_size can clip correctly
            update.mul_(initial_step_size(update, 1e-7))
            return objective

        # -------------------------- matrix momentum update -------------------------- #
        s, Hs = unpack_states(states, objective.params, 's', 'Hs', cls=TensorList)

        update.mul_(lr).sub_(s).add_(Hs*mu)
        objective.updates = update
        return objective

MuonAdjustLR

Bases: torchzero.core.transform.Transform

LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master). Orthogonalize already has this built in with the adjust_lr setting, however you might want to move this to be later in the chain.

Source code in torchzero/modules/adaptive/muon.py
class MuonAdjustLR(Transform):
    """LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).
    Orthogonalize already has this built in with the ``adjust_lr`` setting, however you might want to move this to be later in the chain."""
    def __init__(self, channel_first: bool = True, alpha: float = 1):
        defaults = dict(channel_first=channel_first, alpha=alpha)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        alphas = [s['alpha'] for s in settings]
        channel_first = [s["channel_first=channel_first"] for s in settings]
        tensors_alphas = [
            (t, adjust_lr_for_muon(a, t.shape, cf)) for t, a, cf in zip(tensors, alphas, channel_first) if _is_at_least_2d(t, channel_first=cf)
        ]
        tensors = [i[0] for i in tensors_alphas]
        a = [i[1] for i in tensors_alphas]
        torch._foreach_mul_(tensors, a)
        return tensors

NaturalGradient

Bases: torchzero.core.transform.Transform

Natural gradient approximated via empirical fisher information matrix.

To use this, either pass a vector of per-sample losses to the step method, or make sure the closure returns one. Gradients will be calculated via batched autograd within this module, so you don't need to implement the backward pass. When using a closure, please add the backward argument; it will always be False, but it is required. See below for an example.

Note

The empirical Fisher information matrix may give a really bad approximation in some cases. If that is the case, set sqrt to True to perform whitening instead, which is much more robust.
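
For reference, with the per-sample gradient matrix G of shape (n_samples, ndim), the default sqrt=False path preconditions with the empirical Fisher GᵀG, solved through the small n_samples × n_samples Gram matrix exactly as in the source below. A standalone sketch:

```python
import torch

def efim_precondition(G: torch.Tensor, reg: float = 1e-8) -> torch.Tensor:
    """Returns v ≈ (G^T G)^-1 g for g = G^T 1, via the kernel trick on the Gram matrix."""
    GGt = G @ G.T + reg * torch.eye(G.size(0), dtype=G.dtype, device=G.device)
    z = torch.linalg.solve(GGt, torch.ones(G.size(0), dtype=G.dtype, device=G.device))
    return G.T @ z  # flat update of shape (ndim,)
```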

Parameters:

  • reg (float, default: 1e-08 ) –

    regularization parameter. Defaults to 1e-8.

  • sqrt (bool, default: False ) –

    if True, uses the square root of the empirical Fisher information matrix. Both EFIM and its square root can be calculated and stored efficiently without ndim^2 memory. Square root whitens the gradient and often performs much better, especially when you try to use NGD with a vector that isn't strictly per-sample gradients, but rather for example different losses.

  • gn_grad (bool, default: False ) –

    if True, uses the Gauss-Newton gradient G^T @ f, which is effectively a sum weighted by the loss values and is equivalent to squaring the values. That makes the kernel-trick solver incorrect, but for some reason it still works. If False, uses the sum of per-sample gradients. This has an effect when sqrt=False, and affects the grad attribute. Defaults to False.

  • batched (bool, default: True ) –

    whether to use vmapping. Defaults to True.

Examples:

training a neural network:

X = torch.randn(64, 20)
y = torch.randn(64, 10)

model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
opt = tz.Optimizer(
    model.parameters(),
    tz.m.NaturalGradient(),
    tz.m.LR(3e-2)
)

for i in range(100):
    y_hat = model(X) # (64, 10)
    losses = (y_hat - y).pow(2).mean(0) # (10, )
    opt.step(loss=losses)
    if i % 10 == 0:
        print(f'{losses.mean() = }')

training a neural network - closure version

X = torch.randn(64, 20)
y = torch.randn(64, 10)

model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
opt = tz.Optimizer(
    model.parameters(),
    tz.m.NaturalGradient(),
    tz.m.LR(3e-2)
)

def closure(backward=True):
    y_hat = model(X) # (64, 10)
    return (y_hat - y).pow(2).mean(0) # (10, )

for i in range(100):
    losses = opt.step(closure)
    if i % 10 == 0:
        print(f'{losses.mean() = }')

minimizing the rosenbrock function with a mix of natural gradient, whitening and gauss-newton:

def rosenbrock(X):
    x1, x2 = X
    return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])

X = torch.tensor([-1.1, 2.5], requires_grad=True)
opt = tz.Optimizer([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))

for iter in range(200):
    losses = rosenbrock(X)
    opt.step(loss=losses)
    if iter % 20 == 0:
        print(f'{losses.mean() = }')

Source code in torchzero/modules/adaptive/natural_gradient.py
class NaturalGradient(Transform):
    """Natural gradient approximated via empirical fisher information matrix.

    To use this, either pass a vector of per-sample losses to the step method, or make sure
    the closure returns one. Gradients will be calculated via batched autograd within this module,
    so you don't need to implement the backward pass. When using a closure, please add the ``backward``
    argument; it will always be False, but it is required. See below for an example.

    Note:
        The empirical Fisher information matrix may give a really bad approximation in some cases.
        If that is the case, set ``sqrt`` to True to perform whitening instead, which is much more robust.

    Args:
        reg (float, optional): regularization parameter. Defaults to 1e-8.
        sqrt (bool, optional):
            if True, uses the square root of the empirical Fisher information matrix. Both EFIM and its square
            root can be calculated and stored efficiently without ndim^2 memory. Square root
            whitens the gradient and often performs much better, especially when you try to use NGD
            with a vector that isn't strictly per-sample gradients, but rather for example different losses.
        gn_grad (bool, optional):
            if True, uses the Gauss-Newton gradient G^T @ f, which is effectively a sum weighted by the loss values
            and is equivalent to squaring the values. That makes the kernel trick solver incorrect, but for
            some reason it still works. If False, uses sum of per-sample gradients.
            This has an effect when ``sqrt=False``, and affects the ``grad`` attribute.
            Defaults to False.
        batched (bool, optional): whether to use vmapping. Defaults to True.

    Examples:

    training a neural network:
    ```python
    X = torch.randn(64, 20)
    y = torch.randn(64, 10)

    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NaturalGradient(),
        tz.m.LR(3e-2)
    )

    for i in range(100):
        y_hat = model(X) # (64, 10)
        losses = (y_hat - y).pow(2).mean(0) # (10, )
        opt.step(loss=losses)
        if i % 10 == 0:
            print(f'{losses.mean() = }')
    ```

    training a neural network - closure version
    ```python
    X = torch.randn(64, 20)
    y = torch.randn(64, 10)

    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NaturalGradient(),
        tz.m.LR(3e-2)
    )

    def closure(backward=True):
        y_hat = model(X) # (64, 10)
        return (y_hat - y).pow(2).mean(0) # (10, )

    for i in range(100):
        losses = opt.step(closure)
        if i % 10 == 0:
            print(f'{losses.mean() = }')
    ```

    minimizing the rosenbrock function with a mix of natural gradient, whitening and gauss-newton:
    ```python
    def rosenbrock(X):
        x1, x2 = X
        return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])

    X = torch.tensor([-1.1, 2.5], requires_grad=True)
    opt = tz.Optimizer([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))

    for iter in range(200):
        losses = rosenbrock(X)
        opt.step(loss=losses)
        if iter % 20 == 0:
            print(f'{losses.mean() = }')
    ```
    """
    def __init__(self, reg:float = 1e-8, sqrt:bool=False, gn_grad:bool=False, batched:bool=True, ):
        super().__init__(defaults=dict(batched=batched, reg=reg, sqrt=sqrt, gn_grad=gn_grad))

    @torch.no_grad
    def update_states(self, objective, states, settings):
        params = objective.params
        closure = objective.closure
        fs = settings[0]
        batched = fs['batched']
        gn_grad = fs['gn_grad']

        # compute per-sample losses
        f = objective.loss
        if f is None:
            assert closure is not None
            with torch.enable_grad():
                f = objective.get_loss(backward=False) # n_out
                assert isinstance(f, torch.Tensor)

        # compute per-sample gradients
        with torch.enable_grad():
            G_list = jacobian_wrt([f.ravel()], params, batched=batched)

        # set scalar loss and its grad on the objective
        objective.loss = f.sum()
        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)

        if gn_grad:
            g = self.global_state["g"] = G.H @ f.detach()

        else:
            g = self.global_state["g"] = G.sum(0)

        objective.grads = vec_to_tensors(g, params)

        # set closure to calculate scalar value for line searches etc
        if closure is not None:

            def ngd_closure(backward=True):

                if backward:
                    objective.zero_grad()
                    with torch.enable_grad():
                        loss = closure(False)
                        if gn_grad: loss = loss.pow(2)
                        loss = loss.sum()
                        loss.backward()
                    return loss

                loss = closure(False)
                if gn_grad: loss = loss.pow(2)
                return loss.sum()

            objective.closure = ngd_closure

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        params = objective.params
        fs = settings[0]
        reg = fs['reg']
        sqrt = fs['sqrt']

        G: torch.Tensor = self.global_state['G'] # (n_samples, n_dim)

        if sqrt:
            # this computes U, S <- SVD(M), then calculates the update as U S^-1 Uᵀg,
            # but it does so through an eigendecomposition
            L, U = ggt_update(G.H, damping=reg, rdamping=1e-16, truncate=0, eig_tol=1e-12)

            if U is None or L is None:

                # fallback to element-wise
                g = self.global_state["g"]
                g /= G.square().mean(0).sqrt().add(reg)
                objective.updates = vec_to_tensors(g, params)
                return objective

            # whiten
            z = U.T @ self.global_state["g"]
            v = (U * L.rsqrt()) @ z
            objective.updates = vec_to_tensors(v, params)
            return objective

        # we need to solve (G^T G) v = g, where g = G^T 1 (sum of per-sample gradients).
        # via the kernel trick this reduces to solving (G G^T) z = 1 and setting v = G^T z
        GGt = G @ G.H # (n_samples, n_samples)

        if reg != 0:
            GGt.add_(torch.eye(GGt.size(0), device=GGt.device, dtype=GGt.dtype).mul_(reg))

        z, _ = torch.linalg.solve_ex(GGt, torch.ones_like(GGt[0])) # pylint:disable=not-callable
        v = G.H @ z

        objective.updates = vec_to_tensors(v, params)
        return objective


    def get_H(self, objective=...):
        if "G" not in self.global_state: return linear_operator.ScaledIdentity()
        G = self.global_state['G']
        return linear_operator.AtA(G)

OrthoGrad

Bases: torchzero.core.transform.TensorTransform

Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

Parameters:

  • eps (float, default: 1e-08 ) –

    epsilon added to the denominator for numerical stability (default: 1e-8)

  • renormalize (bool, default: True ) –

    whether to graft projected gradient to original gradient norm. Defaults to True.
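
The projection itself is a single Gram-Schmidt step of the gradient against the weights. The module below computes the dot products jointly across all parameters; this per-tensor sketch is for illustration only:

```python
import torch

def orthograd(w: torch.Tensor, g: torch.Tensor, eps: float = 1e-8, renormalize: bool = True) -> torch.Tensor:
    """Remove from g its component along w, optionally rescaling back to g's original norm."""
    scale = torch.dot(w.ravel(), g.ravel()) / (torch.dot(w.ravel(), w.ravel()) + eps)
    g_perp = g - scale * w
    if renormalize:
        g_perp = g_perp * (g.norm() / g_perp.norm())
    return g_perp
```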

Source code in torchzero/modules/adaptive/orthograd.py
class OrthoGrad(TensorTransform):
    """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

    Args:
        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-8)
        renormalize (bool, optional): whether to graft projected gradient to original gradient norm. Defaults to True.
    """
    def __init__(self, eps: float = 1e-8, renormalize=True):
        defaults = dict(eps=eps, renormalize=renormalize)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        eps = settings[0]['eps']
        renormalize = settings[0]['renormalize']

        params = TensorList(params)
        target = TensorList(tensors)

        scale = params.dot(target)/(params.dot(params) + eps)
        if renormalize:
            norm = target.global_vector_norm()
            target -= params * scale
            target *= (norm / target.global_vector_norm())
            return target

        target -= params * scale
        return target

Orthogonalize

Bases: torchzero.core.transform.TensorTransform

Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.

To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False. The Muon page says that embeddings and classifier heads should not be orthogonalized. Usually only matrix parameters that are directly used in matmuls should be orthogonalized.

To make Muon, use Split with Adam on 1d params

Parameters:

  • adjust_lr (bool, default: False ) –

    Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.

  • dual_norm_correction (bool, default: False ) –

    enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.

  • method (str, default: 'newtonschulz' ) –

    Newton-Schulz is very fast, SVD is slow but can be more precise.

  • channel_first (bool, default: True ) –

    if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions are considered batch dimensions.

Examples:

standard Muon with Adam fallback

opt = tz.Optimizer(
    model.head.parameters(),
    tz.m.Split(
        # apply muon only to 2D+ parameters
        filter = lambda t: t.ndim >= 2,
        true = [
            tz.m.HeavyBall(),
            tz.m.Orthogonalize(),
            tz.m.LR(1e-2),
        ],
        false = tz.m.Adam()
    ),
    tz.m.LR(1e-2)
)

Reference

Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
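
For background, the reference Muon implementation performs the orthogonalization with a quintic Newton-Schulz iteration. A sketch of that iteration for a single 2D tensor (coefficients taken from the reference repository above; torchzero's internal routine may differ in details such as dtype handling):

```python
import torch

def newton_schulz_zeropower(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximates U V^T for G = U S V^T, i.e. the 'zeroth power' of G."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + eps)               # scale so singular values are <= 1
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T                            # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X
```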

Source code in torchzero/modules/adaptive/muon.py
class Orthogonalize(TensorTransform):
    """Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.

    To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False.
    The Muon page says that embeddings and classifier heads should not be orthogonalized.
    Usually only matrix parameters that are directly used in matmuls should be orthogonalized.

    To make Muon, use Split with Adam on 1d params

    Args:
        adjust_lr (bool, optional):
            Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.
        dual_norm_correction (bool, optional):
            enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
        method (str, optional):
            Newton-Schulz is very fast, SVD is slow but can be more precise.
        channel_first (bool, optional):
            if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions
            are considered batch dimensions.

    ## Examples:

    standard Muon with Adam fallback
    ```py
    opt = tz.Optimizer(
        model.head.parameters(),
        tz.m.Split(
            # apply muon only to 2D+ parameters
            filter = lambda t: t.ndim >= 2,
            true = [
                tz.m.HeavyBall(),
                tz.m.Orthogonalize(),
                tz.m.LR(1e-2),
            ],
            false = tz.m.Adam()
        ),
        tz.m.LR(1e-2)
    )
    ```

    Reference:
        Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
    """
    def __init__(self, adjust_lr=False, dual_norm_correction=False,
                 method: OrthogonalizeMethod = 'newtonschulz', channel_first:bool=True):
        defaults = dict(orthogonalize=True, dual_norm_correction=dual_norm_correction, adjust_lr=adjust_lr, method=method.lower(), channel_first=channel_first)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        orthogonalize, dual_norm_correction, adjust_lr, method, channel_first = itemgetter(
            'orthogonalize', 'dual_norm_correction', 'adjust_lr', 'method', 'channel_first')(setting)

        if not orthogonalize: return tensor

        if _is_at_least_2d(tensor, channel_first=channel_first):

            X = _orthogonalize_format(tensor, method, channel_first=channel_first)

            if dual_norm_correction:
                X = _dual_norm_correction(X, tensor, channel_first=channel_first)

            if adjust_lr:
                X.mul_(adjust_lr_for_muon(1, param.shape, channel_first=channel_first))

            return X.view_as(param)

        return tensor

PSGDDenseNewton

Bases: torchzero.core.transform.Transform

Dense hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

Parameters:

  • init_scale (float | None, default: None ) –

    initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.

  • lr_preconditioner (float, default: 0.1 ) –

    learning rate of the preconditioner. Defaults to 0.1.

  • betaL (float, default: 0.9 ) –

    EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.

  • damping (float, default: 1e-09 ) –

    adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.

  • grad_clip_max_norm (float, default: inf ) –

    clips norm of the update. Defaults to float("inf").

  • update_probability (float, default: 1.0 ) –

    probability of updating preconditioner on each step. Defaults to 1.0.

  • dQ (str, default: 'Q0.5EQ1.5' ) –

    geometry for preconditioner update. Defaults to "Q0.5EQ1.5".

  • hvp_method (Literal, default: 'autograd' ) –

    how to compute hessian-vector products. Defaults to 'autograd'.

  • h (float, default: 0.001 ) –

    if hvp_method is "fd_central" or "fd_forward", controls finite difference step size. Defaults to 1e-3.

  • distribution (Literal, default: 'normal' ) –

    distribution for random vectors for hessian-vector products. Defaults to 'normal'.

  • inner (Chainable | None, default: None ) –

    preconditioning will be applied to output of this module. Defaults to None.

Examples:

Pure Dense Newton PSGD:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.DenseNewton(),
    tz.m.LR(1e-3),
)

Applying preconditioner to momentum:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.DenseNewton(inner=tz.m.EMA(0.9)),
    tz.m.LR(1e-3),
)

Source code in torchzero/modules/adaptive/psgd/psgd_dense_newton.py
class PSGDDenseNewton(Transform):
    """Dense hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional):
            adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_norm (float, optional): clips norm of the update. Defaults to float("inf").
        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
        dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
        hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
        h (float, optional):
            if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
            Defaults to 1e-3.
        distribution (Distributions, optional):
            distribution for random vectors for hessian-vector products. Defaults to 'normal'.

        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.

    ### Examples:

    Pure Dense Newton PSGD:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.DenseNewton(),
        tz.m.LR(1e-3),
    )
    ```

    Applying preconditioner to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.DenseNewton(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```
    """
    def __init__(
        self,
        init_scale: float | None = None,
        lr_preconditioner=0.1,
        betaL=0.9,
        damping=1e-9,
        grad_clip_max_norm=float("inf"),
        update_probability=1.0,
        dQ: Literal["QUAD4P", "QUAD", "QEP", "EQ", "QEQ", "Q0p5EQ1p5", "Q0.5EQ1.5"] = "Q0.5EQ1.5",

        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        distribution: Distributions = 'normal',

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)


    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        # -------------------------------- initialize -------------------------------- #
        if "Q" not in self.global_state:

            p = objective.params[0]
            dQ = fs["dQ"]
            init_scale = fs["init_scale"]

            if init_scale is None:
                self.global_state["Q"] = None

            else:
                n = sum(p.numel() for p in objective.params)
                if dQ == "QUAD4P":
                    init_scale *= init_scale
                self.global_state["Q"] = torch.eye(n, dtype=p.dtype, device=p.device) * init_scale

            self.global_state["L"] = lift2single(torch.zeros([], dtype=p.dtype, device=p.device)) # Lipschitz smoothness constant estimation for the psgd criterion

            if dQ == "QUAD4P":
                self.global_state["update_precond"] = update_precond_dense_quad4p
                self.global_state["precond_grad"] = lambda Q, g: Q @ g
                assert torch.finfo(p.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"

            elif dQ == "QUAD":
                self.global_state["update_precond"] = update_precond_dense_quad
                self.global_state["precond_grad"] = lambda Q, g: Q @ (Q @ g) # Q is symmetric; just save one transpose

            else:
                self.global_state["precond_grad"] = lambda Q, g: Q.T @ (Q @ g)
                if dQ == "QEP":
                    self.global_state["update_precond"] = update_precond_dense_qep
                elif dQ == "EQ":
                    self.global_state["update_precond"] = update_precond_dense_eq
                elif dQ == "QEQ":
                    self.global_state["update_precond"] = update_precond_dense_qeq
                else:
                    assert (dQ == "Q0p5EQ1p5") or (dQ == "Q0.5EQ1.5"), f"Invalid choice for dQ: '{dQ}'"
                    self.global_state["update_precond"] = update_precond_dense_q0p5eq1p5

        # ---------------------------------- update ---------------------------------- #
        Q = self.global_state["Q"]
        if (torch.rand([]) < fs["update_probability"]) or Q is None:

            # hessian-vector product
            vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
            Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])

            v = torch.cat([t.ravel() for t in vs]).unsqueeze(1)
            h = torch.cat([t.ravel() for t in Hvs]).unsqueeze(1)

            # initialize on the fly
            if Q is None:
                scale = (torch.mean(v*v))**(1/4) * (torch.mean(h**4) + fs["damping"]**4)**(-1/8)
                if fs["dQ"] == "QUAD4P": # Q actually is P in this case
                    scale *= scale
                Q = self.global_state["Q"] = torch.eye(len(v), dtype=v.dtype, device=v.device) * scale

            # update preconditioner
            self.global_state["update_precond"](
                Q=Q,
                L=self.global_state["L"],
                v=v,
                h=h,
                lr=fs["lr_preconditioner"],
                betaL=fs["betaL"],
                damping=fs["damping"],
            )

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()

        # cat grads
        g = torch.cat([t.ravel() for t in updates]).unsqueeze(1) # column vec
        pre_grad = self.global_state["precond_grad"](self.global_state["Q"], g)

        # norm clipping
        grad_clip_max_norm = settings[0]["grad_clip_max_norm"]
        if grad_clip_max_norm < float("inf"): # clip preconditioned gradient
            grad_norm = torch.linalg.vector_norm(pre_grad)
            if grad_norm > grad_clip_max_norm:
                pre_grad *= grad_clip_max_norm / grad_norm

        vec_to_tensors_(pre_grad, updates)
        return objective

PSGDKronNewton

Bases: torchzero.core.transform.Transform

Kron hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

Parameters:

  • max_dim (int, default: 10000 ) –

    dimensions with size larger than this use diagonal preconditioner. Defaults to 10_000.

  • max_skew (float, default: 1.0 ) –

    if memory used by full preconditioner (dim^2) is larger than total number of elements in a parameter times max_skew, it uses a diagonal preconditioner. Defaults to 1.0.

  • init_scale (float | None, default: None ) –

    initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.

  • lr_preconditioner (float, default: 0.1 ) –

    learning rate of the preconditioner. Defaults to 0.1.

  • betaL (float, default: 0.9 ) –

    EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.

  • damping (float, default: 1e-09 ) –

    adds small noise to gradient when updating the preconditioner. Defaults to 1e-9.

  • grad_clip_max_amp (float, default: inf ) –

    clips amplitude of the update. Defaults to float("inf").

  • update_probability (float, default: 1.0 ) –

    probability of updating preconditioner on each step. Defaults to 1.0.

  • dQ (str, default: 'Q0.5EQ1.5' ) –

    geometry for preconditioner update. Defaults to "Q0.5EQ1.5".

  • balance_probability (float, default: 0.01 ) –

    probability of balancing the dynamic ranges of the factors of Q to avoid over/under-flow on each step. Defaults to 0.01.

  • hvp_method (Literal, default: 'autograd' ) –

    how to compute hessian-vector products. Defaults to 'autograd'.

  • h (float, default: 0.001 ) –

    if hvp_method is "fd_central" or "fd_forward", controls finite difference step size. Defaults to 1e-3.

  • distribution (Literal, default: 'normal' ) –

    distribution for random vectors for hessian-vector products. Defaults to 'normal'.

  • inner (Chainable | None, default: None ) –

    preconditioning will be applied to output of this module. Defaults to None.

Examples:

Pure PSGD Kron Newton:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.KronNewton(),
    tz.m.LR(1e-3),
)

Applying preconditioner to momentum:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.KronNewton(inner=tz.m.EMA(0.9)),
    tz.m.LR(1e-3),
)

Source code in torchzero/modules/adaptive/psgd/psgd_kron_newton.py
class PSGDKronNewton(Transform):
    """Kron hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        max_dim (int, optional): dimensions with size larger than this use diagonal preconditioner. Defaults to 10_000.
        max_skew (float, optional):
            if memory used by full preconditioner (dim^2) is larger than total number of elements in a parameter times ``max_skew``, it uses a diagonal preconditioner. Defaults to 1.0.
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional): adds small noise to gradient when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_amp (float, optional): clips amplitude of the update. Defaults to float("inf").
        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
        dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
        balance_probability (float, optional):
            probability of balancing the dynamic ranges of the factors of Q to avoid over/under-flow on each step. Defaults to 0.01.
        hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
        h (float, optional):
            if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
            Defaults to 1e-3.
        distribution (Distributions, optional):
            distribution for random vectors for hessian-vector products. Defaults to 'normal'.
        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.


    ###Examples:

    Pure PSGD Kron Newton:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronNewton(),
        tz.m.LR(1e-3),
    )
    ```

    Applying preconditioner to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronNewton(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```
    """
    def __init__(
        self,
        max_dim: int = 10_000,
        max_skew: float = 1.0,
        init_scale: float | None = None,
        lr_preconditioner: float = 0.1,
        betaL: float = 0.9,
        damping: float = 1e-9,
        grad_clip_max_amp: float = float("inf"),
        update_probability: float= 1.0,
        dQ: Literal["QEP", "EQ", "QEQ", "QUAD",  "Q0.5EQ1.5", "Q0p5EQ1p5", "QUAD4P"] = "Q0.5EQ1.5",
        balance_probability: float = 0.01,

        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        distribution: Distributions = 'normal',

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)


    def _initialize_state(self, param, state, setting):
        assert "initialized" not in state
        state["initialized"] = True

        # initialize preconditioners
        if setting["init_scale"] is None:
            warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
            state["QLs_exprs"] = None
        else:
            state["QLs_exprs"] = init_kron(
                param.squeeze(),
                Scale=setting["init_scale"],
                max_size=setting["max_dim"],
                max_skew=setting["max_skew"],
                dQ=setting["dQ"],
            )

        dQ = setting["dQ"]
        if dQ == "QUAD4P":
            assert torch.finfo(param.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"
            state["update_precond"] = update_precond_kron_newton_quad4p
            state["precond_grad"] = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)

        else:
            state["precond_grad"] = precond_grad_kron
            if dQ == "QEP":
                state["update_precond"] = update_precond_kron_newton_quad
            elif dQ == "EQ":
                state["update_precond"] = update_precond_kron_newton_qep
            elif dQ == "QEQ":
                state["update_precond"] = update_precond_kron_newton_eq
            elif dQ == "QUAD":
                state["update_precond"] = update_precond_kron_newton_qeq
            else:
                assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), f"Invalid choice for dQ: '{dQ}'"
                state["update_precond"] = update_precond_kron_newton_q0p5eq1p5

    @torch.no_grad
    def update_states(self, objective, states, settings):

        # initialize states
        for param, state, setting in zip(objective.params, states, settings):
            if "initialized" not in state:
                self._initialize_state(param, state, setting)

        fs = settings[0]

        uninitialized = any(state["QLs_exprs"] is None for state in states)
        if (torch.rand([]) < fs["update_probability"]) or uninitialized:

            # hessian-vector product
            vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
            Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])

            # initialize on the fly (why does it use vs?)
            if uninitialized:

                scale = (sum([torch.sum(torch.abs(v)**2) for v in vs])/sum([v.numel() for v in vs])) ** (1/4) # (mean(|v|^2))^(1/4)

                scale = scale * (max([torch.mean((torch.abs(h))**4) for h in Hvs]) + fs["damping"]**4) ** (-1/8) # (mean(|v|^2))^(1/4) * (mean(|h|^4))^(-1/8)

                for h, state, setting in zip(Hvs, states, settings):
                    if state["QLs_exprs"] is None:
                        state["QLs_exprs"] = init_kron(
                            h.squeeze(),
                            Scale=scale,
                            max_size=setting["max_dim"],
                            max_skew=setting["max_skew"],
                            dQ=setting["dQ"],
                        )

            # update preconditioner
            for v, h, state, setting in zip(vs, Hvs, states, settings):
                state["update_precond"](
                    *state["QLs_exprs"],
                    v.squeeze(),
                    h.squeeze(),
                    lr=setting["lr_preconditioner"],
                    betaL=setting["betaL"],
                    damping=setting["damping"],
                    balance_prob=setting["balance_probability"]
                )

    @torch.no_grad
    def apply_states(self, objective, states, settings):

        params = objective.params
        tensors = objective.get_updates()
        pre_tensors = []

        # precondition
        for param, tensor, state in zip(params, tensors, states):
            t = state["precond_grad"](
                *state["QLs_exprs"],
                tensor.squeeze(),
            )
            pre_tensors.append(t.view_as(param))

        # norm clipping
        grad_clip_max_amp = settings[0]["grad_clip_max_amp"]
        if grad_clip_max_amp < math.inf:
            pre_tensors = TensorList(pre_tensors)
            num_params = sum(t.numel() for t in pre_tensors)

            avg_amp = pre_tensors.dot(pre_tensors.conj()).div(num_params).sqrt()

            if avg_amp > grad_clip_max_amp:
                torch._foreach_mul_(pre_tensors, grad_clip_max_amp / avg_amp)

        objective.updates = pre_tensors
        return objective

PSGDKronWhiten

Bases: torchzero.core.transform.TensorTransform

Kron whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

Parameters:

  • max_dim (int, default: 10000 ) –

    dimensions with size larger than this use diagonal preconditioner. Defaults to 10_000.

  • max_skew (float, default: 1.0 ) –

    if memory used by full preconditioner (dim^2) is larger than total number of elements in a parameter times max_skew, it uses a diagonal preconditioner. Defaults to 1.0.

  • init_scale (float | None, default: None ) –

    initial scale of the preconditioner. If None, determined from magnitude of the first gradient. Defaults to None.

  • lr_preconditioner (float, default: 0.1 ) –

    learning rate of the preconditioner. Defaults to 0.1.

  • betaL (float, default: 0.9 ) –

    EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.

  • damping (float, default: 1e-09 ) –

    adds small noise to gradient when updating the preconditioner. Defaults to 1e-9.

  • grad_clip_max_amp (float, default: inf ) –

    clips amplitude of the update. Defaults to float("inf").

  • update_probability (float, default: 1.0 ) –

    probability of updating preconditioner on each step. Defaults to 1.0.

  • dQ (str, default: 'Q0.5EQ1.5' ) –

    geometry for preconditioner update. Defaults to "Q0.5EQ1.5".

  • balance_probability (float, default: 0.01 ) –

    probability of balancing the dynamic ranges of the factors of Q to avoid over/under-flow on each step. Defaults to 0.01.

  • inner (Chainable | None, default: None ) –

    preconditioning will be applied to output of this module. Defaults to None.

Examples:

Pure PSGD Kron:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.KronWhiten(),
    tz.m.LR(1e-3),
)

Momentum into preconditioner (whitens momentum):

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.EMA(0.9),
    tz.m.KronWhiten(),
    tz.m.LR(1e-3),
)

Updating the preconditioner from gradients and applying it to momentum:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.KronWhiten(inner=tz.m.EMA(0.9)),
    tz.m.LR(1e-3),
)

Source code in torchzero/modules/adaptive/psgd/psgd_kron_whiten.py
class PSGDKronWhiten(TensorTransform):
    """Kron whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        max_dim (int, optional): dimensions with size larger than this use diagonal preconditioner. Defaults to 10_000.
        max_skew (float, optional):
            if memory used by full preconditioner (dim^2) is larger than total number of elements in a parameter times ``max_skew``, it uses a diagonal preconditioner. Defaults to 1.0.
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, determined from magnitude of the first gradient. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional): adds small noise to gradient when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_amp (float, optional): clips amplitude of the update. Defaults to float("inf").
        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
        dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
        balance_probability (float, optional):
            probability of balancing the dynamic ranges of the factors of Q to avoid over/under-flow on each step. Defaults to 0.01.

        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.

    ###Examples:

    Pure PSGD Kron:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronWhiten(),
        tz.m.LR(1e-3),
    )
    ```

    Momentum into preconditioner (whitens momentum):
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.EMA(0.9),
        tz.m.KronWhiten(),
        tz.m.LR(1e-3),
    )
    ```

    Updating the preconditioner from gradients and applying it to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronWhiten(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```

    """
    def __init__(
        self,
        max_dim: int = 10_000,
        max_skew: float = 1.0,
        init_scale: float | None = None,
        lr_preconditioner: float = 0.1,
        betaL: float = 0.9,
        damping: float = 1e-9,
        grad_clip_max_amp: float = float("inf"),
        update_probability: float= 1.0,
        dQ: Literal["QEP", "EQ", "QEQ", "QUAD",  "Q0.5EQ1.5", "Q0p5EQ1p5", "QUAD4P"] = "Q0.5EQ1.5",
        balance_probability: float = 0.01,

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        # initialize preconditioners
        if setting["init_scale"] is None:
            # warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
            state["QLs_exprs"] = None
        else:
            state["QLs_exprs"] = init_kron(
                param.squeeze(),
                Scale=setting["init_scale"],
                max_size=setting["max_dim"],
                max_skew=setting["max_skew"],
                dQ=setting["dQ"],
            )

        dQ = setting["dQ"]
        if dQ == "QUAD4P":
            assert torch.finfo(param.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"
            state["update_precond"] = update_precond_kron_whiten_quad4p
            state["precond_grad"] = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)

        else:
            state["precond_grad"] = precond_grad_kron
            if dQ == "QEP":
                state["update_precond"] = update_precond_kron_whiten_qep
            elif dQ == "EQ":
                state["update_precond"] = update_precond_kron_whiten_eq
            elif dQ == "QEQ":
                state["update_precond"] = update_precond_kron_whiten_qeq
            elif dQ == "QUAD":
                state["update_precond"] = update_precond_kron_whiten_quad
            else:
                assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), f"Invalid choice for dQ: '{dQ}'"
                state["update_precond"] = update_precond_kron_whiten_q0p5eq1p5

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):

        # initialize on the fly if not initialized
        if any(state["QLs_exprs"] is None for state in states):

            scale = max([torch.mean((torch.abs(g))**4) for g in tensors])
            scale = (scale + settings[0]["damping"]**4)**(-1/8)

            for param, state, setting in zip(params, states, settings):
                if state["QLs_exprs"] is None:
                    state["QLs_exprs"] = init_kron(
                        param.squeeze(),
                        Scale=scale,
                        max_size=setting["max_dim"],
                        max_skew=setting["max_skew"],
                        dQ=setting["dQ"],
                    )


        # update preconditioners
        # (could also try per-parameter probability)
        if torch.rand([]) < settings[0]["update_probability"]: # update Q
            for tensor, state, setting in zip(tensors, states, settings):
                state["update_precond"](
                    *state["QLs_exprs"],
                    tensor.squeeze(),
                    lr=setting["lr_preconditioner"],
                    betaL=setting["betaL"],
                    damping=setting["damping"],
                    balance_prob=setting["balance_probability"]
                )

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):

        pre_tensors = []

        # precondition
        for param, tensor, state in zip(params, tensors, states):
            t = state["precond_grad"](
                *state["QLs_exprs"],
                tensor.squeeze(),
            )
            pre_tensors.append(t.view_as(param))

        # norm clipping
        grad_clip_max_amp = settings[0]["grad_clip_max_amp"]
        if grad_clip_max_amp < math.inf:
            pre_tensors = TensorList(pre_tensors)
            num_params = sum(t.numel() for t in pre_tensors)

            avg_amp = pre_tensors.dot(pre_tensors.conj()).div(num_params).sqrt()

            if avg_amp > grad_clip_max_amp:
                torch._foreach_mul_(pre_tensors, grad_clip_max_amp / avg_amp)

        return pre_tensors

PSGDLRANewton

Bases: torchzero.core.transform.Transform

Low rank hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

Parameters:

  • rank (int, default: 10 ) –

    Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.

  • init_scale (float | None, default: None ) –

    initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.

  • lr_preconditioner (float, default: 0.1 ) –

    learning rate of the preconditioner. Defaults to 0.1.

  • betaL (float, default: 0.9 ) –

    EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.

  • damping (float, default: 1e-09 ) –

    adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.

  • grad_clip_max_norm (float, default: inf ) –

    clips norm of the update. Defaults to float("inf").

  • update_probability (float, default: 1.0 ) –

    probability of updating preconditioner on each step. Defaults to 1.0.

  • hvp_method (Literal, default: 'autograd' ) –

    how to compute hessian-vector products. Defaults to 'autograd'.

  • h (float, default: 0.001 ) –

    if hvp_method is "fd_central" or "fd_forward", controls finite difference step size. Defaults to 1e-3.

  • distribution (Literal, default: 'normal' ) –

    distribution for random vectors for hessian-vector products. Defaults to 'normal'.

  • inner (Chainable | None, default: None ) –

    preconditioning will be applied to output of this module. Defaults to None.

Examples:

Pure LRA Newton PSGD:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.LRANewton(),
    tz.m.LR(1e-3),
)

Applying preconditioner to momentum:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.LRANewton(inner=tz.m.EMA(0.9)),
    tz.m.LR(1e-3),
)

Source code in torchzero/modules/adaptive/psgd/psgd_lra_newton.py
class PSGDLRANewton(Transform):
    """Low rank hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        rank (int, optional):
            Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional):
            adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_norm (float, optional): clips norm of the update. Defaults to float("inf").
        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
        hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
        h (float, optional):
            if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
            Defaults to 1e-3.
        distribution (Distributions, optional):
            distribution for random vectors for hessian-vector products. Defaults to 'normal'.

        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.

    ###Examples:

    Pure LRA Newton PSGD:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.LRANewton(),
        tz.m.LR(1e-3),
    )
    ```

    Applying preconditioner to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.LRANewton(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```
    """
    def __init__(
        self,
        rank: int = 10,
        init_scale: float | None = None,
        lr_preconditioner=0.1,
        betaL=0.9,
        damping=1e-9,
        grad_clip_max_norm=float("inf"),
        update_probability=1.0,

        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        distribution: Distributions = 'normal',

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        # initialize
        if "UVd" not in self.global_state:
            p = torch.cat([t.ravel() for t in objective.params])
            _initialize_lra_state_(p, self.global_state, fs)

        UVd = self.global_state["UVd"]
        if (torch.rand([]) < fs["update_probability"]) or (UVd[2] is None):

            # hessian-vector product
            vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
            Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])

            v = torch.cat([t.ravel() for t in vs]).unsqueeze(1)
            h = torch.cat([t.ravel() for t in Hvs]).unsqueeze(1)

            if UVd[2] is None:
                UVd[2] = (torch.mean(v*v))**(1/4) * (torch.mean(h**4) + fs["damping"]**4)**(-1/8) * torch.ones_like(v)

            # update preconditioner
            update_precond_lra_newton(UVd=UVd, Luvd=self.global_state["Luvd"], v=v, h=h, lr=fs["lr_preconditioner"], betaL=fs["betaL"], damping=fs["damping"])


    @torch.no_grad
    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()

        g = torch.cat([t.ravel() for t in updates]).unsqueeze(1) # column vec
        pre_grad = precond_grad_lra(UVd=self.global_state["UVd"], g=g)

        # norm clipping
        grad_clip_max_norm = settings[0]["grad_clip_max_norm"]
        if grad_clip_max_norm < float("inf"): # clip preconditioned gradient
            grad_norm = torch.linalg.vector_norm(pre_grad)
            if grad_norm > grad_clip_max_norm:
                pre_grad *= grad_clip_max_norm / grad_norm

        vec_to_tensors_(pre_grad, updates)
        return objective

PSGDLRAWhiten

Bases: torchzero.core.transform.TensorTransform

Low rank whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

Parameters:

  • rank (int, default: 10 ) –

    Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.

  • init_scale (float | None, default: None ) –

    initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.

  • lr_preconditioner (float, default: 0.1 ) –

    learning rate of the preconditioner. Defaults to 0.1.

  • betaL (float, default: 0.9 ) –

    EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.

  • damping (float, default: 1e-09 ) –

    adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.

  • grad_clip_max_amp (float, default: inf ) –

    clips amplitude of the update. Defaults to float("inf").

  • update_probability (float, default: 1.0 ) –

    probability of updating preconditioner on each step. Defaults to 1.0.

  • concat_params (bool, default: True ) –

    if True, treats all parameters as concatenated to a single vector. If False, each parameter is preconditioned separately. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioning will be applied to output of this module. Defaults to None.

Examples:

Pure PSGD LRA:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.LRAWhiten(),
    tz.m.LR(1e-3),
)

Momentum into preconditioner (whitens momentum):

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.EMA(0.9),
    tz.m.LRAWhiten(),
    tz.m.LR(1e-3),
)

Updating the preconditioner from gradients and applying it to momentum:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.LRAWhiten(inner=tz.m.EMA(0.9)),
    tz.m.LR(1e-3),
)

Source code in torchzero/modules/adaptive/psgd/psgd_lra_whiten.py
class PSGDLRAWhiten(TensorTransform):
    """Low rank whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        rank (int, optional):
            Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional):
            adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_amp (float, optional): clips amplitude of the update. Defaults to float("inf").
        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
        concat_params (bool, optional):
            if True, treats all parameters as concatenated to a single vector.
            If False, each parameter is preconditioned separately. Defaults to True.
        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.

    ###Examples:

    Pure PSGD LRA:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.LRAWhiten(),
        tz.m.LR(1e-3),
    )
    ```

    Momentum into preconditioner (whitens momentum):
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.EMA(0.9),
        tz.m.LRAWhiten(),
        tz.m.LR(1e-3),
    )
    ```

    Updating the preconditioner from gradients and applying it to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.LRAWhiten(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```

    """
    def __init__(
        self,
        rank: int = 10,
        init_scale: float | None = None,
        lr_preconditioner=0.1,
        betaL=0.9,
        damping=1e-9,
        grad_clip_max_amp=float("inf"),
        update_probability=1.0,

        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, concat_params=concat_params, inner=inner)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        _initialize_lra_state_(tensor, state, setting)

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):

        g = tensor.ravel().unsqueeze(1) # column vector

        UVd = state["UVd"]
        if UVd[2] is None: # initialize d on the fly
            UVd[2] = (torch.mean(g**4) + setting["damping"]**4)**(-1/8) * torch.ones_like(g)

        if torch.rand([]) < setting["update_probability"]:  # update preconditioner
            update_precond_lra_whiten(
                UVd=UVd,
                Luvd=state["Luvd"],
                g=g,
                lr=setting["lr_preconditioner"],
                betaL=setting["betaL"],
                damping=setting["damping"],
            )

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):

        g = tensor.ravel().unsqueeze(1)
        pre_grad = precond_grad_lra(UVd=state["UVd"], g=g)

        # norm clipping
        grad_clip_max_amp = setting["grad_clip_max_amp"]
        if grad_clip_max_amp < float("inf"): # clip preconditioned gradient
            amp = torch.sqrt(torch.mean(pre_grad * pre_grad))
            if amp > grad_clip_max_amp:
                pre_grad *= grad_clip_max_amp/amp

        return pre_grad.view_as(tensor)

RMSprop

Bases: torchzero.core.transform.TensorTransform

Divides gradient by EMA of gradient squares.

This implementation is identical to torch.optim.RMSprop.

Parameters:

  • smoothing (float, default: 0.99 ) –

    beta for exponential moving average of gradient squares. Defaults to 0.99.

  • eps (float, default: 1e-08 ) –

    epsilon for division. Defaults to 1e-8.

  • centered (bool, default: False ) –

    whether to center EMA of gradient squares using an additional EMA. Defaults to False.

  • debias (bool, default: False ) –

    applies Adam debiasing. Defaults to False.

  • amsgrad (bool, default: False ) –

    Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.

  • pow (float) –

    power used in second momentum power and root. Defaults to 2.

  • init (str, default: 'zeros' ) –

    how to initialize the EMA: either "update" to use the first update, or "zeros". Defaults to "zeros".

  • inner (Chainable | None, default: None ) –

    Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
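
Examples:

RMSprop as the update rule of an optimizer (not part of the original docstring; a minimal sketch following the same tz.Optimizer / tz.m.LR pattern used by the other modules on this page, with illustrative hyperparameters):

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.RMSprop(smoothing=0.99, eps=1e-8),
    tz.m.LR(1e-3),
)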

Source code in torchzero/modules/adaptive/rmsprop.py
class RMSprop(TensorTransform):
    """Divides graient by EMA of gradient squares.

    This implementation is identical to :code:`torch.optim.RMSprop`.

    Args:
        smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
        eps (float, optional): epsilon for division. Defaults to 1e-8.
        centered (bool, optional): whether to center EMA of gradient squares using an additional EMA. Defaults to False.
        debias (bool, optional): applies Adam debiasing. Defaults to False.
        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
        pow (float, optional): power used in second momentum power and root. Defaults to 2.
        init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "zeros".
        inner (Chainable | None, optional):
            Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
    """
    def __init__(
        self,
        smoothing: float = 0.99,
        eps: float = 1e-8,
        centered: bool = False,
        debias: bool = False,
        amsgrad: bool = False,
        init: Literal["zeros", "update"] = "zeros",

        inner: Chainable | None = None,
        exp_avg_sq_tfm: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults["inner"], defaults["exp_avg_sq_tfm"]
        super().__init__(defaults, inner=inner)

        self.set_child('exp_avg_sq', exp_avg_sq_tfm)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        if setting["init"] == "zeros":
            state["exp_avg_sq"] = torch.zeros_like(tensor)
            if setting["centered"]: state["exp_avg"] = torch.zeros_like(tensor)
            if setting["amsgrad"]: state["amsgrad"] = torch.zeros_like(tensor)

        else:
            state["exp_avg_sq"] = tensor ** 2
            if setting["centered"]: state["exp_avg"] = tensor.clone()
            if setting["amsgrad"]: state["amsgrad"] = tensor ** 2

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        self.increment_counter("step", start = 0)
        fs = settings[0]

        exp_avg_sq = unpack_states(states, tensors, "exp_avg_sq", cls=TensorList)

        # update exponential average
        smoothing = NumberList(s["smoothing"] for s in settings)
        exp_avg_sq.mul_(smoothing).addcmul_(tensors, tensors, value=1-smoothing)

        # update mean estimate if centered
        if fs["centered"]:
            exp_avg = unpack_states(states, tensors, "exp_avg", cls=TensorList)
            exp_avg.lerp_(tensors, 1-smoothing)

        # amsgrad
        if fs["amsgrad"]:
            exp_avg_sq_max = unpack_states(states, tensors, "exp_avg_sq_max", cls=TensorList)
            exp_avg_sq_max.maximum_(exp_avg_sq)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        step = self.global_state["step"] # 0 on 1st step
        eps = NumberList(s["eps"] for s in settings)
        fs = settings[0]

        if fs["amsgrad"]: key = "max_exp_avg_sq"
        else: key = "exp_avg_sq"
        exp_avg_sq = TensorList(s[key] for s in states)

        # load mean estimate if centered
        exp_avg = None
        if fs['centered']:
            exp_avg = TensorList(s["exp_avg"] for s in states)

        # debias exp_avg_sq and exp_avg
        if fs["debias"]:
            smoothing = NumberList(s["smoothing"] for s in settings)
            bias_correction = 1 - (smoothing ** (step + 1))
            exp_avg_sq = exp_avg_sq / bias_correction

            if fs['centered']:
                assert exp_avg is not None
                exp_avg = exp_avg / bias_correction

        # apply transform to potentially debiased exp_avg_sq
        exp_avg_sq = TensorList(self.inner_step_tensors(
            "exp_avg_sq", exp_avg_sq, params=params, grads=grads, loss=loss, clone=True, must_exist=False
        ))

        # center
        if fs["centered"]:
            assert exp_avg is not None
            exp_avg_sq = exp_avg_sq.addcmul(exp_avg, exp_avg, value=-1)

        return tensors.div_(exp_avg_sq.sqrt().add_(eps))

Rprop

Bases: torchzero.core.transform.TensorTransform

Resilient propagation. The update magnitude gets multiplied by nplus if gradient didn't change the sign, or nminus if it did. Then the update is applied with the sign of the current gradient.

Additionally, if gradient changes sign, the update for that weight is reverted. Next step, magnitude for that weight won't change.

Compared to the PyTorch version, this also implements a backtracking update when the sign changes.

This implementation is identical to torch.optim.Rprop if backtrack is set to False.

Parameters:

  • nplus (float, default: 1.2 ) –

    multiplicative increase factor for when ascent didn't change sign (default: 1.2).

  • nminus (float, default: 0.5 ) –

    multiplicative decrease factor for when ascent changed sign (default: 0.5).

  • lb (float, default: 1e-06 ) –

    minimum step size, can be None (default: 1e-6)

  • ub (float, default: 50 ) –

    maximum step size, can be None (default: 50)

  • backtrack (bool, default: True ) –

    if True, when ascent sign changes, undoes last weight update, otherwise sets update to 0. When this is False, this exactly matches pytorch Rprop. (default: True)

  • alpha (float, default: 1 ) –

    initial per-parameter learning rate (default: 1).

Reference: Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning: The RPROP algorithm. In IEEE International Conference on Neural Networks (pp. 586-591). IEEE.
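
Examples:

Rprop as the update rule (not part of the original docstring; a minimal sketch mirroring the usage pattern of the other modules on this page — since Rprop maintains its own per-parameter step magnitudes via alpha, the trailing LR acts only as a global scale):

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.Rprop(),
    tz.m.LR(1),
)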

Source code in torchzero/modules/adaptive/rprop.py
class Rprop(TensorTransform):
    """
    Resilient propagation. The update magnitude gets multiplied by `nplus` if gradient didn't change the sign,
    or `nminus` if it did. Then the update is applied with the sign of the current gradient.

    Additionally, if gradient changes sign, the update for that weight is reverted.
    Next step, magnitude for that weight won't change.

    Compared to pytorch this also implements backtracking update when sign changes.

    This implementation is identical to :code:`torch.optim.Rprop` if :code:`backtrack` is set to False.

    Args:
        nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
        nminus (float): multiplicative decrease factor for when ascent changed sign (default: 0.5).
        lb (float): minimum step size, can be None (default: 1e-6)
        ub (float): maximum step size, can be None (default: 50)
        backtrack (bool):
            if True, when ascent sign changes, undoes last weight update, otherwise sets update to 0.
            When this is False, this exactly matches pytorch Rprop. (default: True)
        alpha (float): initial per-parameter learning rate (default: 1).

    reference
        *Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning:
        The RPROP algorithm. In IEEE international conference on neural networks (pp. 586-591). IEEE.*
    """
    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb: float = 1e-6,
        ub: float = 50,
        backtrack=True,
        alpha: float = 1,
    ):
        defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, backtrack=backtrack)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        nplus, nminus, lb, ub, alpha = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', 'alpha', cls=NumberList)
        prev, allowed, magnitudes = unpack_states(
            states, tensors,
            'prev','allowed','magnitudes',
            init=[torch.zeros_like, _bool_ones_like, torch.zeros_like],
            cls = TensorList,
        )

        tensors = rprop_(
            tensors_ = TensorList(tensors),
            prev_ = prev,
            allowed_ = allowed,
            magnitudes_ = magnitudes,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            alpha = alpha,
            backtrack=settings[0]['backtrack'],
            step=step,
        )

        return tensors

SAM

Bases: torchzero.core.transform.Transform

Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value. It performs two forward and backward passes per step.

This implementation modifies the closure to return loss and calculate gradients of the SAM objective. All modules after this will use the modified objective.

Note: This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients at two points on each step.

Parameters:

  • rho (float, default: 0.05 ) –

    Neighborhood size. Defaults to 0.05.

  • p (float, default: 2 ) –

    norm of the SAM objective. Defaults to 2.

  • asam (bool, default: False ) –

    enables the ASAM variant, which makes the perturbation relative to weight magnitudes. ASAM requires a much larger rho, like 0.5 or 1. The tz.m.ASAM class is identical to setting this argument to True, but it uses a larger rho by default.
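
For reference, the perturbation that the modified closure applies (as can be read from the source below) is, with q the Hölder conjugate of p (1/p + 1/q = 1):

    e = rho * sign(g) * |g|^(q-1) / (sum(|g|^q))^(1/p)

For the default p = 2 this reduces to the familiar SAM perturbation e = rho * g / ||g||_2. The ASAM variant additionally rescales this computation by the parameter magnitudes, as shown in the source.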

Examples:

SAM-SGD:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SAM(),
    tz.m.LR(1e-2)
)

SAM-Adam:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SAM(),
    tz.m.Adam(),
    tz.m.LR(1e-2)
)

References

Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412.
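
Since the loss and gradients are re-evaluated twice per step, the optimizer must be stepped with a closure. The sketch below is not from the original docs; it assumes the closure convention the source below relies on — closure() computes the loss and populates gradients via backward, while closure(False) only returns the loss — and model, criterion, inputs and targets are hypothetical placeholders:

def closure(backward=True):
    # compute the loss; model/criterion/inputs/targets are placeholders
    loss = criterion(model(inputs), targets)
    if backward:
        opt.zero_grad()  # harmless here: the SAM wrapper also calls zero_grad before invoking the closure
        loss.backward()
    return loss

loss = opt.step(closure)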

Source code in torchzero/modules/adaptive/sam.py
class SAM(Transform):
    """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
    It performs two forward and backward passes per step.

    This implementation modifies the closure to return loss and calculate gradients
    of the SAM objective. All modules after this will use the modified objective.

    .. note::
        This module requires a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients at two points on each step.

    Args:
        rho (float, optional): Neighborhood size. Defaults to 0.05.
        p (float, optional): norm of the SAM objective. Defaults to 2.
        asam (bool, optional):
            enables ASAM variant which makes perturbation relative to weight magnitudes.
            ASAM requires a much larger ``rho``, like 0.5 or 1.
            The ``tz.m.ASAM`` class is identical to setting this argument to True, but
            it has larger ``rho`` by default.

    ### Examples:

    SAM-SGD:

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SAM(),
        tz.m.LR(1e-2)
    )
    ```

    SAM-Adam:

    ```
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SAM(),
        tz.m.Adam(),
        tz.m.LR(1e-2)
    )
    ```

    References:
        [Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412.](https://arxiv.org/abs/2010.01412#page=3.16)
    """
    def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
        defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
        super().__init__(defaults)

    @torch.no_grad
    def update_states(self, objective, states, settings):

        params = objective.params
        closure = objective.closure
        zero_grad = objective.zero_grad
        if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
        p, rho = unpack_dicts(settings, 'p', 'rho', cls=NumberList)
        fs = settings[0]
        eps = fs['eps']
        asam = fs['asam']

        # 1/p + 1/q = 1
        # okay, authors of SAM paper, I will manually solve your equation
        # so q = -p/(1-p)
        q = -p / (1-p)
        # as a validation for 2 it is -2 / -1 = 2

        @torch.no_grad
        def sam_closure(backward=True):
            orig_grads = None
            if not backward:
                # if backward is False, make sure this doesn't modify gradients
                # to avoid issues
                orig_grads = [p.grad for p in params]

            # gradient at initial parameters
            zero_grad()
            with torch.enable_grad():
                closure()

            grad = TensorList(p.grad if p.grad is not None else torch.zeros_like(p) for p in params)
            grad_abs = grad.abs()

            # compute e
            term1 = grad.sign().mul_(rho)
            term2 = grad_abs.pow(q-1)

            if asam:
                grad_abs.mul_(torch._foreach_abs(params))

            denom = grad_abs.pow_(q).sum().pow(1/p)

            e = term1.mul_(term2).div_(denom.clip(min=eps))

            if asam:
                e.mul_(torch._foreach_pow(params, 2))

            # calculate loss and gradient approximation of inner problem
            torch._foreach_add_(params, e)
            if backward:
                zero_grad()
                with torch.enable_grad():
                    # this sets .grad attributes
                    sam_loss = closure()

            else:
                sam_loss = closure(False)

            # and restore initial parameters
            torch._foreach_sub_(params, e)

            if orig_grads is not None:
                for param,orig_grad in zip(params, orig_grads):
                    param.grad = orig_grad

            return sam_loss

        objective.closure = sam_closure

SOAP

Bases: torchzero.core.transform.TensorTransform

SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).

Parameters:

  • beta1 (float, default: 0.95 ) –

    beta for first momentum. Defaults to 0.95.

  • beta2 (float, default: 0.95 ) –

    beta for second momentum. Defaults to 0.95.

  • shampoo_beta (float | None, default: 0.95 ) –

    beta for covariance matrices accumulators. Can be None, then it just sums them like Adagrad (which works worse). Defaults to 0.95.

  • precond_freq (int, default: 10 ) –

    How often to update the preconditioner. Defaults to 10.

  • merge_small (bool, default: True ) –

    Whether to merge small dims. Defaults to True.

  • max_dim (int, default: 4096 ) –

    Won't precondition dims larger than this. Defaults to 4096.

  • precondition_1d (bool, default: True ) –

    Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.

  • eps (float, default: 1e-08 ) –

    epsilon for dividing first momentum by second. Defaults to 1e-8.

  • debias (bool, default: True ) –

    enables adam bias correction. Defaults to True.

  • proj_exp_avg (bool, default: True ) –

    if True, maintains the exponential average of gradients (momentum) in projected space. If False, it is maintained in the original space. Defaults to True.

  • alpha (float, default: 1 ) –

    learning rate. Defaults to 1.

  • inner (Chainable | None, default: None ) –

    output of this module is projected and Adam will run on it, but preconditioners are updated from original gradients.

Examples:

SOAP:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SOAP(),
    tz.m.LR(1e-3)
)

Stabilized SOAP:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SOAP(),
    tz.m.NormalizeByEMA(max_ema_growth=1.2),
    tz.m.LR(1e-2)
)

Source code in torchzero/modules/adaptive/soap.py
class SOAP(TensorTransform):
    """SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).

    Args:
        beta1 (float, optional): beta for first momentum. Defaults to 0.95.
        beta2 (float, optional): beta for second momentum. Defaults to 0.95.
        shampoo_beta (float | None, optional):
            beta for covariance matrices accumulators. Can be None, then it just sums them like Adagrad (which works worse). Defaults to 0.95.
        precond_freq (int, optional): How often to update the preconditioner. Defaults to 10.
        merge_small (bool, optional): Whether to merge small dims. Defaults to True.
        max_dim (int, optional): Won't precondition dims larger than this. Defaults to 4096.
        precondition_1d (bool, optional):
            Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.
        eps (float, optional):
            epsilon for dividing first momentum by second. Defaults to 1e-8.
        debias (bool, optional):
            enables adam bias correction. Defaults to True.
        proj_exp_avg (bool, optional):
            if True, maintains the exponential average of gradients (momentum) in projected space.
            If False, it is maintained in the original space. Defaults to True.
        alpha (float, optional):
            learning rate. Defaults to 1.
        inner (Chainable | None, optional):
            output of this module is projected and Adam will run on it, but preconditioners are updated
            from original gradients.

    ### Examples:
    SOAP:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SOAP(),
        tz.m.LR(1e-3)
    )
    ```
    Stabilized SOAP:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SOAP(),
        tz.m.NormalizeByEMA(max_ema_growth=1.2),
        tz.m.LR(1e-2)
    )
    ```
    """
    def __init__(
        self,
        beta1: float = 0.95,
        beta2: float = 0.95,
        shampoo_beta: float | None = 0.95,
        precond_freq: int = 10,
        merge_small: bool = True,
        max_dim: int = 4096,
        precondition_1d: bool = True,
        eps: float = 1e-8,
        debias: bool = True,
        proj_exp_avg: bool = True,
        alpha: float = 1,

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults["inner"]

        super().__init__(defaults)
        self.set_child("inner", inner)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        if setting["merge_small"]:
            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])

        state["exp_avg_proj"] = torch.zeros_like(tensor)
        state["exp_avg_sq_proj"] = torch.zeros_like(tensor)

        if tensor.ndim <= 1 and not setting["precondition_1d"]:
            state['GG'] = []

        else:
            max_dim = setting["max_dim"]
            state['GG'] = [
                torch.zeros(s, s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
            ]

        # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
        if all(i is None for i in state['GG']):
            state['GG'] = None

        # first covariance accumulation
        if state['GG'] is not None:
            update_soap_covariances_(tensor, GGs_=state['GG'], beta=setting["shampoo_beta"])

            # get projection matrix with first gradients with eigh
            try: state['Q'] = get_orthogonal_matrix(state['GG'])
            except torch.linalg.LinAlgError as e:
                warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
                state["GG"] = None

        state['step'] = 0


    # no update to avoid running merge_dims twice

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # note
        # do not modify tensors in-place
        # because they are used to update preconditioner at the end

        steps = [s["step"] for s in states]
        if any(s == 0 for s in steps):
            # skip the 1st update to avoid using the current gradient in the projection
            # I scale it instead to avoid issues with further modules
            for s in states: s["step"] += 1
            return TensorList(tensors).clamp(-0.1, 0.1)
            # return TensorList(tensors).zero_()


        fs = settings[0]
        merged = []
        projected = []
        # ---------------------------------- project --------------------------------- #

        for tensor, state, setting in zip(tensors, states, settings):
            if setting["merge_small"]:
                tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])

            merged.append(tensor)

            if state['GG'] is not None:
                tensor = project(tensor, state['Q'])

            projected.append(tensor)

        # ------------------------ run adam in projected space ----------------------- #
        exp_avg_proj, exp_avg_sq_proj = unpack_states(states, tensors, "exp_avg_proj", "exp_avg_sq_proj", must_exist=True, cls=TensorList)
        alpha, beta1, beta2, eps = unpack_dicts(settings, "alpha", "beta1", "beta2", "eps", cls=NumberList)

        # lerp exp_avg in projected space
        if fs["proj_exp_avg"]:
            exp_avg_proj.lerp_(projected, weight=1-beta1)

        # or lerp in original space and project
        else:
            exp_avg = exp_avg_proj
            exp_avg.lerp_(merged, weight=1-beta1)
            exp_avg_proj = []
            for t, state, setting in zip(exp_avg, states, settings):
                if state['GG'] is not None:
                    t = project(t, state["Q"])
                exp_avg_proj.append(t)

        exp_avg_sq_proj.mul_(beta2).addcmul_(projected, projected, value=1-beta2)

        denom = exp_avg_sq_proj.sqrt().add_(eps)
        dirs_proj = exp_avg_proj / denom

        # ------------------------------- project back ------------------------------- #
        dirs: list[torch.Tensor] = []
        for dir, state, setting in zip(dirs_proj, states, settings):
            if state['GG'] is not None:
                dir = project_back(dir, state['Q'])

            if setting["merge_small"]:
                dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])

            dirs.append(dir)


        # -------------------------------- inner step -------------------------------- #
        if "inner" in self.children:
            tensors = self.inner_step_tensors("inner", tensors, clone=False,
                                              params=params, grads=grads,loss=loss)

            # we now have to re-merge small dims on updated tensors
            merged = []
            for tensor, state, setting in zip(tensors, states, settings):
                if setting["merge_small"]:
                    tensor, _, _ = _merge_small_dims(tensor, setting["max_dim"])
                merged.append(tensor)

        # -------------------------- update preconditioners -------------------------- #
        # Update is done after the gradient step to avoid using current gradients in the projection.

        for tensor, state, setting in zip(merged, states, settings):
            if state['GG'] is not None:

                # lerp covariances
                update_soap_covariances_(tensor, state['GG'], beta=setting["shampoo_beta"])

                # (state['step'] - 1) since we start updating on 2nd step
                if (state['step'] - 1) % setting['precond_freq'] == 0:

                    # unproject exp_avg before updating if it is maintained projected
                    exp_avg = None
                    if fs["proj_exp_avg"]:
                        exp_avg = project_back(state["exp_avg_proj"], state["Q"])

                    # update projection matrix and exp_avg_sq_proj
                    try:
                        state['Q'], state['exp_avg_sq_proj'] = get_orthogonal_matrix_QR(
                            state["exp_avg_sq_proj"], state['GG'], state['Q'])

                        # re-project exp_avg if it is maintained projected
                        if fs["proj_exp_avg"]:
                            assert exp_avg is not None
                            state["exp_avg_proj"] = project(exp_avg, state["Q"])

                    except torch.linalg.LinAlgError:
                        pass

            state["step"] += 1


        # ------------------------- bias-corrected step size ------------------------- #
        if fs["debias"]:
            steps1 = [s+1 for s in steps]
            bias_correction1 = 1.0 - beta1 ** steps1
            bias_correction2 = 1.0 - beta2 ** steps1
            alpha = alpha * (bias_correction2 ** .5) / bias_correction1

        torch._foreach_mul_(dirs, alpha)
        return dirs

ScaleLRBySignChange

Bases: torchzero.core.transform.TensorTransform

learning rate gets multiplied by nplus if ascent/gradient didn't change the sign, or nminus if it did.

This is part of RProp update rule.

Parameters:

  • nplus (float, default: 1.2 ) –

    learning rate gets multiplied by nplus if ascent/gradient didn't change the sign

  • nminus (float, default: 0.5 ) –

    learning rate gets multiplied by nminus if ascent/gradient changed the sign

  • lb (float, default: 1e-06 ) –

    lower bound for lr.

  • ub (float, default: 50.0 ) –

    upper bound for lr.

  • alpha (float, default: 1.0 ) –

    initial learning rate.
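
Example (not from the original docs; a minimal sketch mirroring the usage pattern of the other modules on this page, where the adaptive per-parameter step sizes maintained by this module are then scaled globally by LR):

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.ScaleLRBySignChange(),
    tz.m.LR(1e-2),
)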

Source code in torchzero/modules/adaptive/rprop.py
class ScaleLRBySignChange(TensorTransform):
    """
    learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign,
    or `nminus` if it did.

    This is part of RProp update rule.

    Args:
        nplus (float): learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign
        nminus (float): learning rate gets multiplied by `nminus` if ascent/gradient changed the sign
        lb (float): lower bound for lr.
        ub (float): upper bound for lr.
        alpha (float): initial learning rate.

    """

    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb=1e-6,
        ub=50.0,
        alpha=1.0,
        use_grad=False,
    ):
        defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
        super().__init__(defaults, uses_grad=use_grad)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        tensors = TensorList(tensors)
        if self._uses_grad:
            assert grads is not None
            cur = TensorList(grads)
        else: cur = tensors

        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)

        if step == 0:
            lrs.set_(tensors.full_like([s['alpha'] for s in settings]))

        tensors = scale_by_sign_change_(
            tensors_ = tensors,
            cur = cur,
            prev_ = prev,
            lrs_ = lrs,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            step = step,
        )
        return tensors

Shampoo

Bases: torchzero.core.transform.TensorTransform

Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

Notes

Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.

Shampoo is a very computationally expensive optimizer; increase precond_freq if it is too slow.

The SOAP optimizer usually outperforms Shampoo and is less computationally expensive. A SOAP implementation is available as tz.m.SOAP.
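
For example, if preconditioner updates dominate the step time, the update interval can be raised while keeping the grafting pattern from the Examples below (a sketch; the best precond_freq value is problem-dependent):

opt = tz.Optimizer(
    model.parameters(),
    tz.m.GraftModules(
        direction = tz.m.Shampoo(precond_freq=50),
        magnitude = tz.m.Adam(),
    ),
    tz.m.LR(1e-3)
)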

Parameters:

  • precond_freq (int, default: 10 ) –

    preconditioner update frequency. Defaults to 10.

  • matrix_power (float | None, default: None ) –

    overrides matrix exponent. By default uses -1/grad.ndim. Defaults to None.

  • merge_small (bool, default: True ) –

    whether to merge small dims on tensors. Defaults to True.

  • max_dim (int, default: 10000 ) –

    maximum dimension size for preconditioning. Defaults to 10_000.

  • precondition_1d (bool, default: True ) –

    whether to precondition 1d tensors. Defaults to True.

  • adagrad_eps (float, default: 1e-08 ) –

    epsilon for adagrad division for tensors where shampoo can't be applied. Defaults to 1e-8.

  • matrix_power_method (Literal, default: 'eigh_abs' ) –

    how to compute matrix power.

  • beta (float | None, default: None ) –

    if None calculates sum as in standard Shampoo, otherwise uses EMA of preconditioners. Defaults to None.

  • inner (Chainable | None, default: None ) –

    module applied after updating preconditioners and before applying preconditioning. For example if beta≈0.999 and inner=tz.m.EMA(0.9), this becomes Adam with shampoo preconditioner (ignoring debiasing). Defaults to None.

Examples:

Shampoo grafted to Adam

opt = tz.Optimizer(
    model.parameters(),
    tz.m.GraftModules(
        direction = tz.m.Shampoo(),
        magnitude = tz.m.Adam(),
    ),
    tz.m.LR(1e-3)
)

Adam with Shampoo preconditioner

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
    tz.m.Debias(0.9, 0.999),
    tz.m.LR(1e-3)
)
Source code in torchzero/modules/adaptive/shampoo.py
class Shampoo(TensorTransform):
    """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

    Notes:
        Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.

        Shampoo is a very computationally expensive optimizer; increase ``precond_freq`` if it is too slow.

        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as ``tz.m.SOAP``.

    Args:
        precond_freq (int, optional): preconditioner update frequency. Defaults to 10.
        matrix_power (float | None, optional): overrides matrix exponent. By default uses ``-1/grad.ndim``. Defaults to None.
        merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 10_000.
        precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
        adagrad_eps (float, optional): epsilon for adagrad division for tensors where shampoo can't be applied. Defaults to 1e-8.
        matrix_power_method (MatrixPowerMethod, optional): how to compute matrix power.
        beta (float | None, optional):
            if None calculates sum as in standard Shampoo, otherwise uses EMA of preconditioners. Defaults to None.
        inner (Chainable | None, optional):
            module applied after updating preconditioners and before applying preconditioning.
            For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
            Defaults to None.

    Examples:
    Shampoo grafted to Adam

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.GraftModules(
            direction = tz.m.Shampoo(),
            magnitude = tz.m.Adam(),
        ),
        tz.m.LR(1e-3)
    )
    ```

    Adam with Shampoo preconditioner

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
        tz.m.Debias(0.9, 0.999),
        tz.m.LR(1e-3)
    )
    ```
    """
    def __init__(
        self,
        reg: float = 1e-12,
        precond_freq: int = 10,
        matrix_power: float | None = None,
        merge_small: bool = True,
        max_dim: int = 10_000,
        precondition_1d: bool = True,
        adagrad_eps: float = 1e-8,
        matrix_power_method: MatrixPowerMethod = "eigh_abs",
        beta: float | None = None,
        beta_debias: bool = True,

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults["inner"]

        super().__init__(defaults, inner=inner)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        if setting["merge_small"]:
            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])

        if tensor.ndim <= 1 and not setting["precondition_1d"]:
            state["accumulators"] = []

        else:
            max_dim = setting["max_dim"]
            state['accumulators'] = [
                torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
            ]
            state['preconditioners'] = [
                torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
            ]

        # either scalar parameter, 1d with precondition_1d=False, or too big, then diagonal preconditioner is used.
        if not any(i is not None for i in state['accumulators']):
            state['diagonal_accumulator'] = torch.zeros_like(tensor)

        state['step'] = 0
        state["num_GTG"] = 0

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        if setting["merge_small"]:
            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])

        if 'diagonal_accumulator' in state:
            update_diagonal_(tensor, state['diagonal_accumulator'], beta=setting["beta"])
        else:
            update_shampoo_preconditioner_(
                tensor,
                accumulators_=state['accumulators'],
                preconditioners_=state['preconditioners'],
                step=state['step'],
                precond_freq=setting["precond_freq"],
                matrix_power=setting["matrix_power"],
                beta=setting["beta"],
                reg=setting["reg"],
                matrix_power_method=setting["matrix_power_method"],
            )

        if state["step"] % setting["precond_freq"] == 0:
            state["num_GTG"] += 1

        state["step"] += 1

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        if setting["merge_small"]:
            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])

        if 'diagonal_accumulator' in state:
            dir = apply_diagonal_(tensor, state['diagonal_accumulator'], eps=setting["adagrad_eps"])
        else:
            dir = apply_shampoo_preconditioner(tensor, preconditioners_=state['preconditioners'])

        if setting["merge_small"]:
            dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])

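        # optional debiasing when an EMA of the accumulators is used (beta is not None):
        # scale the output by sqrt(1 - beta ** num_GTG)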
        if setting['beta_debias'] and setting["beta"] is not None:
            bias_correction = 1 - (setting["beta"] ** state["num_GTG"])
            dir *= bias_correction ** 0.5

        return dir

SignConsistencyLRs

Bases: torchzero.core.transform.TensorTransform

Outputs per-weight learning rates based on consecutive sign consistency.

The learning rate for a weight is multiplied by nplus when two consecutive update signs are the same, otherwise it is multiplied by nminus. The learning rates are bounded to the (lb, ub) range.

Examples:

GD scaled by consecutive gradient sign consistency

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Mul(tz.m.SignConsistencyLRs()),
    tz.m.LR(1e-2)
)
Source code in torchzero/modules/adaptive/rprop.py
class SignConsistencyLRs(TensorTransform):
    """Outputs per-weight learning rates based on consecutive sign consistency.

    The learning rate for a weight is multiplied by ``nplus`` when two consecutive update signs are the same, otherwise it is multiplied by ``nminus``. The learning rates are bounded to the ``(lb, ub)`` range.

    ### Examples:

    GD scaled by consecutive gradient sign consistency

    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Mul(tz.m.SignConsistencyLRs()),
        tz.m.LR(1e-2)
    )
    ```

"""
    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb: float | None = 1e-6,
        ub: float | None = 50,
        alpha: float = 1,
    ):
        defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        target = TensorList(tensors)
        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)

        if step == 0:
            lrs.set_(target.full_like([s['alpha'] for s in settings]))

        target = sign_consistency_lrs_(
            tensors = target,
            prev_ = prev,
            lrs_ = lrs,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            step = step,
        )
        return target.clone()

SignConsistencyMask

Bases: torchzero.core.transform.TensorTransform

Outputs a mask of sign consistency of current and previous inputs.

The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.

Examples:

GD that skips update for weights where gradient sign changed compared to previous gradient.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Mul(tz.m.SignConsistencyMask()),
    tz.m.LR(1e-2)
)
Source code in torchzero/modules/adaptive/rprop.py
class SignConsistencyMask(TensorTransform):
    """
    Outputs a mask of sign consistency of current and previous inputs.

    The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.

    ### Examples:

    GD that skips update for weights where gradient sign changed compared to previous gradient.

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Mul(tz.m.SignConsistencyMask()),
        tz.m.LR(1e-2)
    )
    ```

    """
    def __init__(self):
        super().__init__()

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', cls=TensorList)
        mask = prev.mul_(tensors).gt_(0)
        prev.copy_(tensors)
        return mask

SophiaH

Bases: torchzero.core.transform.Transform

SophiaH optimizer from https://arxiv.org/abs/2305.14342

This is similar to Adam, but the second momentum is replaced by an exponential moving average of randomized hessian diagonal estimates, and the update is aggressively clipped.

Notes
  • In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply SophiaH preconditioning to another module's output.

  • This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

Parameters:

  • beta1 (float, default: 0.96 ) –

    first momentum. Defaults to 0.96.

  • beta2 (float, default: 0.99 ) –

    momentum for hessian diagonal estimate. Defaults to 0.99.

  • update_freq (int, default: 10 ) –

    frequency of updating hessian diagonal estimate via a hessian-vector product. Defaults to 10.

  • precond_scale (float, default: 1 ) –

    scale of the preconditioner. Defaults to 1.

  • clip (float, default: 1 ) –

    clips update to (-clip, clip). Defaults to 1.

  • eps (float, default: 1e-12 ) –

    clips the hessian diagonal estimate to be no less than this value. Defaults to 1e-12.

  • hvp_method (str, default: 'autograd' ) –

    Determines how Hessian-vector products are computed.

    • "batched_autograd" - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to "autograd". Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than "batched_autograd" but uses less memory.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    Defaults to "autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of hessian-vector products with random vectors to evaluate each time when updating the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.

  • seed (int | None, default: None ) –

    seed for random vectors. Defaults to None.

  • inner (Chainable | None) –

    preconditioning is applied to the output of this module. Defaults to None.

Examples:

Using SophiaH:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SophiaH(),
    tz.m.LR(0.1)
)

SophiaH preconditioner can be applied to any other module by passing it to the inner argument. Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying SophiaH preconditioning to nesterov momentum (tz.m.NAG):

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
    tz.m.LR(0.1)
)
Source code in torchzero/modules/adaptive/sophia_h.py
class SophiaH(Transform):
    """SophiaH optimizer from https://arxiv.org/abs/2305.14342

    This is similar to Adam, but the second momentum is replaced by an exponential moving average of randomized hessian diagonal estimates, and the update is aggressively clipped.

    Notes:
        - In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply SophiaH preconditioning to another module's output.

        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        beta1 (float, optional): first momentum. Defaults to 0.96.
        beta2 (float, optional): momentum for hessian diagonal estimate. Defaults to 0.99.
        update_freq (int, optional):
            frequency of updating hessian diagonal estimate via a hessian-vector product. Defaults to 10.
        precond_scale (float, optional):
            scale of the preconditioner. Defaults to 1.
        clip (float, optional):
            clips update to (-clip, clip). Defaults to 1.
        eps (float, optional):
            clips the hessian diagonal estimate to be no less than this value. Defaults to 1e-12.
        hvp_method (str, optional):
            Determines how Hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
        seed (int | None, optional): seed for random vectors. Defaults to None.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    ### Examples:

    Using SophiaH:

    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SophiaH(),
        tz.m.LR(0.1)
    )
    ```

    SophiaH preconditioner can be applied to any other module by passing it to the ``inner`` argument.
    Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying
    SophiaH preconditioning to nesterov momentum (``tz.m.NAG``):

    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
        tz.m.LR(0.1)
    )
    ```
    """
    def __init__(
        self,
        beta1: float = 0.96,
        beta2: float = 0.99,
        update_freq: int = 10,
        precond_scale: float = 1,
        clip: float = 1,
        eps: float = 1e-12,
        hvp_method: HVPMethod = 'autograd',
        distribution: Distributions = 'gaussian',
        h: float = 1e-3,
        n_samples = 1,
        zHz: bool = True,
        debias: bool = False,
        seed: int | None = None,

        exp_avg_tfm: Chainable | None = None,
        D_exp_avg_tfm: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['exp_avg_tfm'], defaults["D_exp_avg_tfm"]
        super().__init__(defaults)

        self.set_child('exp_avg', exp_avg_tfm)
        self.set_child('D_exp_avg', D_exp_avg_tfm)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        params = objective.params

        beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)

        exp_avg, D_exp_avg = unpack_states(states, params, 'exp_avg', 'D_exp_avg', cls=TensorList)

        step = self.increment_counter("step", start=0) # 0 on 1st update

        # ---------------------------- hutchinson hessian ---------------------------- #
        fs = settings[0]
        update_freq = fs['update_freq']

        if step % update_freq == 0:
            self.increment_counter("num_Ds", start=1)

            D, _ = objective.hutchinson_hessian(
                rgrad = None,
                at_x0 = True,
                n_samples = fs['n_samples'],
                distribution = fs['distribution'],
                hvp_method = fs['hvp_method'],
                h = fs['h'],
                zHz = fs["zHz"],
                generator = self.get_generator(params[0].device, fs["seed"]),
            )

            D_exp_avg.lerp_(D, weight=1-beta2)

        # --------------------------------- momentum --------------------------------- #
        tensors = objective.get_updates() # do this after hutchinson to not disturb autograd
        exp_avg.lerp_(tensors, 1-beta1)


    @torch.no_grad
    def apply_states(self, objective, states, settings):
        params = objective.params

        beta1, beta2, eps, precond_scale, clip = unpack_dicts(
            settings, 'beta1', 'beta2', 'eps', 'precond_scale', 'clip', cls=NumberList)

        exp_avg, D_exp_avg = unpack_states(states, params, 'exp_avg', 'D_exp_avg')

        # ---------------------------------- debias ---------------------------------- #
        if settings[0]["debias"]:
            bias_correction1 = 1.0 - (beta1 ** (self.global_state["step"] + 1))
            bias_correction2 = 1.0 - (beta2 ** self.global_state["num_Ds"])

            exp_avg = exp_avg / bias_correction1
            D_exp_avg = D_exp_avg / bias_correction2

        # -------------------------------- transforms -------------------------------- #
        exp_avg = TensorList(self.inner_step_tensors(
            "exp_avg", tensors=exp_avg, clone=True, objective=objective, must_exist=False))

        D_exp_avg = TensorList(self.inner_step_tensors(
            "D_exp_avg", tensors=D_exp_avg, clone=True, objective=objective, must_exist=False))

        # ------------------------------ compute update ------------------------------ #
        denom = D_exp_avg.lazy_mul(precond_scale).clip(min=eps)
        objective.updates = (exp_avg / denom).clip_(-clip, clip)
        return objective

orthogonalize_grads_

orthogonalize_grads_(params: Iterable[Tensor], dual_norm_correction=False, method: Literal['newtonschulz', 'svd', 'qr', 'eigh'] = 'newtonschulz', channel_first: bool = True)

Computes the zeroth power / orthogonalization of gradients of an iterable of parameters.

This sets gradients in-place. Applies along first 2 dims (expected to be out_channels, in_channels).

Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.

Parameters:

  • params (Iterable[Tensor]) –

    parameters that hold gradients to orthogonalize.

  • dual_norm_correction (bool, default: False ) –

    enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.

  • method (str, default: 'newtonschulz' ) –

    how to compute the orthogonalization. Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.

  • channel_first (bool, default: True ) –

    if True, orthogonalizes along the first two dimensions, otherwise along the last two. Other dimensions are considered batch dimensions.
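
A minimal usage sketch inside a training step (model, loss_fn, batch, and optimizer are assumed to be defined elsewhere):

loss = loss_fn(model(batch))
loss.backward()
orthogonalize_grads_(model.parameters(), method="newtonschulz")
optimizer.step()
optimizer.zero_grad()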

Source code in torchzero/modules/adaptive/muon.py
def orthogonalize_grads_(
    params: Iterable[torch.Tensor],
    dual_norm_correction=False,
    method: OrthogonalizeMethod = "newtonschulz",
    channel_first:bool=True,
):
    """Computes the zeroth power / orthogonalization of gradients of an iterable of parameters.

    This sets gradients in-place. Applies along first 2 dims (expected to be `out_channels, in_channels`).

    Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.
    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
        dual_norm_correction (bool, optional):
            enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
        method (str, optional):
            Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.
        channel_first (bool, optional):
            if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions
            are considered batch dimensions.
    """
    for p in params:
        if (p.grad is not None) and _is_at_least_2d(p.grad, channel_first=channel_first):
            X = _orthogonalize_format(p.grad, method=method, channel_first=channel_first)
            if dual_norm_correction: X = _dual_norm_correction(X, p.grad, channel_first=False)
            p.grad.set_(X.view_as(p)) # pyright:ignore[reportArgumentType]

orthograd_

orthograd_(params: Iterable[Tensor], eps: float = 1e-30)

Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

Parameters:

  • params (Iterable[Tensor]) –

    parameters that hold gradients to apply ⟂Grad to.

  • eps (float, default: 1e-30 ) –

    epsilon added to the denominator for numerical stability (default: 1e-30)

Reference: https://arxiv.org/abs/2501.04697
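
A minimal usage sketch of applying ⟂Grad before the optimizer step (model, loss_fn, batch, and optimizer are assumed to be defined elsewhere):

loss = loss_fn(model(batch))
loss.backward()
orthograd_(model.parameters())
optimizer.step()
optimizer.zero_grad()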

Source code in torchzero/modules/adaptive/orthograd.py
def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
    """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to apply ⟂Grad to.
        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)

    reference
        https://arxiv.org/abs/2501.04697
    """
    params = TensorList(params).with_grad()
    grad = params.grad
    grad -= (params.dot(grad)/(params.dot(params) + eps)) * params