List of all modules

A somewhat categorized list of modules is also available in Modules.

Classes:

  • AEGD

    AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.

  • ASAM

    Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52

  • Abs

    Returns :code:abs(input)

  • AccumulateMaximum

    Accumulates maximum of all past updates.

  • AccumulateMean

    Accumulates mean of all past updates.

  • AccumulateMinimum

    Accumulates minimum of all past updates.

  • AccumulateProduct

    Accumulates product of all past updates.

  • AccumulateSum

    Accumulates sum of all past updates.

  • AdGD

    AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)

  • AdaHessian

    AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)

  • Adagrad

    Adagrad, divides by sum of past squares of gradients.

  • AdagradNorm

    Adagrad-Norm, divides by sum of past means of squares of gradients.

  • Adam

    Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.

  • Adan

    Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677

  • AdaptiveBacktracking

    Adaptive backtracking line search. After each line search procedure, a new initial step size is set

  • AdaptiveHeavyBall

    Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.

  • AdaptiveTracking

    A line search that evaluates previous step size, if value increased, backtracks until the value stops decreasing,

  • Add

    Add :code:other to tensors. :code:other can be a number or a module.

  • Alternate

    Alternates between stepping with :code:modules.

  • Averaging

    Average of past history_size updates.

  • BBStab

    Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).

  • BFGS

    Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-Newton method.

  • BacktrackOnSignChange

    Negates or undoes the update for parameters where the gradient or update sign changes.

  • Backtracking

    Backtracking line search.

  • BarzilaiBorwein

    Barzilai-Borwein step size method.

  • BinaryOperationBase

    Base class for operations that use the update as the first operand. This is an abstract class; subclass it and override the transform method to use it.

  • BirginMartinezRestart

    The restart criterion for conjugate gradient methods designed by Birgin and Martinez.

  • BroydenBad

    Broyden's "bad" Quasi-Newton method.

  • BroydenGood

    Broyden's "good" Quasi-Newton method.

  • CCD

    Cumulative coordinate descent. This updates one gradient coordinate at a time and accumulates it

  • CCDLS

    CCD with line search instead of adaptive step size.

  • CD

    Coordinate descent. Proposes a descent direction along a single coordinate.

  • Cautious

    Negates the update for parameters where the signs of the update and the gradient are inconsistent.

  • CenteredEMASquared

    Maintains a centered exponential moving average of squared updates. This also maintains an additional

  • CenteredSqrtEMASquared

    Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.

  • Centralize

    Centralizes the update.

  • Clip

    Clips tensors to be in the :code:(min, max) range. :code:min and :code:max can be None, numbers or modules.

  • ClipModules

    Calculates :code:input(tensors).clip(min, max). :code:min and :code:max can be numbers or modules.

  • ClipNorm

    Clips update norm to be no larger than value.

  • ClipNormByEMA

    Clips norm to be no larger than the norm of an exponential moving average of past updates.

  • ClipNormGrowth

    Clips update norm growth.

  • ClipValue

    Clips update magnitude to be within (-value, value) range.

  • ClipValueByEMA

    Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.

  • ClipValueGrowth

    Clips update value magnitude growth.

  • Clone

    Clones the input. This may be useful for storing an intermediate result so that it isn't affected by in-place operations.

  • ConjugateDescent

    Conjugate Descent (CD).

  • CopyMagnitude

    Returns :code:other(tensors) with sign copied from tensors.

  • CopySign

    Returns tensors with sign copied from :code:other(tensors).

  • CubicRegularization

    Cubic regularization.

  • CustomUnaryOperation

    Applies :code:getattr(tensor, name) to each tensor

  • DFP

    Davidon–Fletcher–Powell Quasi-Newton method.

  • DNRTR

    Diagonal quasi-Newton method.

  • DYHS

    Dai-Yuan - Hestenes–Stiefel hybrid conjugate gradient method.

  • DaiYuan

    Dai–Yuan nonlinear conjugate gradient method.

  • Debias

    Multiplies the update by an Adam debiasing term based on the first and/or second momentum.

  • Debias2

    Multiplies the update by an Adam debiasing term based on the second momentum.

  • DiagonalBFGS

    Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful.

  • DiagonalQuasiCauchi

    Diagonal quasi-Cauchy method.

  • DiagonalSR1

    Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful.

  • DiagonalWeightedQuasiCauchi

    Diagonal weighted quasi-Cauchy method.

  • DirectWeightDecay

    Directly applies weight decay to parameters.

  • Div

    Divide tensors by :code:other. :code:other can be a number or a module.

  • DivByLoss

    Divides update by loss times :code:alpha

  • DivModules

    Calculates :code:input / other. :code:input and :code:other can be numbers or modules.

  • Dogleg

    Dogleg trust region algorithm.

  • Dropout

    Applies dropout to the update.

  • DualNormCorrection

    Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).

  • EMA

    Maintains an exponential moving average of update.

  • EMASquared

    Maintains an exponential moving average of squared updates.

  • ESGD

    Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

  • EscapeAnnealing

    If parameters stop changing, this runs a backward annealing random search

  • Exp

    Returns :code:exp(input)

  • ExpHomotopy
  • FDM

    Approximate gradients via finite difference method.

  • Fill

    Outputs tensors filled with :code:value

  • FillLoss

    Outputs tensors filled with loss value times :code:alpha

  • FletcherReeves

    Fletcher–Reeves nonlinear conjugate gradient method.

  • FletcherVMM

    Fletcher's variable metric Quasi-Newton method.

  • ForwardGradient

    Forward gradient method.

  • FullMatrixAdagrad

    Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

  • GaussNewton

    Gauss-Newton method.

  • GaussianSmoothing

    Gradient approximation via Gaussian smoothing method.

  • Grad

    Outputs the gradient

  • GradApproximator

    Base class for gradient approximations.

  • GradSign

    Copies gradient sign to update.

  • GradToNone

    Sets :code:grad attribute to None on :code:var.

  • GradientAccumulation

    Uses n steps to accumulate gradients; after n gradients have been accumulated, they are passed to :code:modules and the parameters are updated.

  • GradientCorrection

    Estimates the gradient at the minimum along the search direction, assuming the function is quadratic.

  • GradientSampling

    Samples and aggregates gradients and values at perturbed points.

  • Graft

    Outputs tensors rescaled to have the same norm as :code:magnitude(tensors).

  • GraftGradToUpdate

    Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update.

  • GraftModules

    Outputs :code:direction output rescaled to have the same norm as :code:magnitude output.

  • GraftToGrad

    Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient.

  • GraftToParams

    Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than :code:eps.

  • GraftToUpdate

    Outputs :code:magnitude(tensors) rescaled to have the same norm as tensors

  • GramSchimdt

    Outputs tensors made orthogonal to :code:other(tensors) via Gram-Schmidt.

  • Greenstadt1

    Greenstadt's first Quasi-Newton method.

  • Greenstadt2

    Greenstadt's second Quasi-Newton method.

  • HagerZhang

    Hager-Zhang nonlinear conjugate gradient method.

  • HeavyBall

    Polyak's momentum (heavy-ball method).

  • HestenesStiefel

    Hestenes–Stiefel nonlinear conjugate gradient method.

  • HigherOrderNewton

    A basic arbitrary-order Newton's method with optional trust region and proximal penalty.

  • Horisho

    Horisho's variable metric Quasi-Newton method.

  • HpuEstimate

    Returns y/||s||, where y is the difference between the current and previous update (gradient) and s is the difference between the current and previous parameters. The returned tensors are a finite-difference approximation to the Hessian times the previous update.

  • ICUM

    Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods

  • Identity

    Identity operator that is argument-insensitive. It can also be used as the identity Hessian for trust region methods.

  • IntermoduleCautious

    Negates the update of the :code:main module where its sign doesn't match the output of the :code:compare module.

  • InverseFreeNewton

    Inverse-free Newton's method.

  • LBFGS

    Limited-memory BFGS algorithm. A line search or trust region is recommended.

  • LMAdagrad

    Limited-memory full matrix Adagrad.

  • LR

    Learning rate. Adding this module also adds support for LR schedulers.

  • LSR1

    Limited-memory SR1 algorithm. A line search or trust region is recommended.

  • LambdaHomotopy
  • LaplacianSmoothing

    Applies Laplacian smoothing via a fast Fourier transform solver, which can improve generalization.

  • LastAbsoluteRatio

    Outputs the ratio between the absolute values of the past two updates; the numerator is determined by the :code:numerator argument.

  • LastDifference

    Outputs difference between past two updates.

  • LastGradDifference

    Outputs difference between past two gradients.

  • LastProduct

    Outputs product of past two updates.

  • LastRatio

    Outputs ratio between past two updates, the numerator is determined by :code:numerator argument.

  • LerpModules

    Does a linear interpolation of :code:input(tensors) and :code:end(tensors) based on a scalar :code:weight.

  • LevenbergMarquardt

    Levenberg-Marquardt trust region algorithm.

  • LineSearchBase

    Base class for line searches.

  • Lion

    Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

  • LiuStorey

    Liu-Storey nonlinear conjugate gradient method.

  • LogHomotopy
  • MARSCorrection

    MARS variance reduction correction.

  • MSAM

    Momentum-SAM from https://arxiv.org/pdf/2401.12033.

  • MSAMObjective

    Momentum-SAM from https://arxiv.org/pdf/2401.12033.

  • MatrixMomentum

    Second order momentum method.

  • Maximum

    Outputs :code:maximum(tensors, other(tensors))

  • MaximumModules

    Outputs elementwise maximum of :code:inputs that can be modules or numbers.

  • McCormick

    McCormick's Quasi-Newton method.

  • MeZO

    Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

  • Mean

    Outputs a mean of :code:inputs that can be modules or numbers.

  • MedianAveraging

    Median of past history_size updates.

  • Minimum

    Outputs :code:minimum(tensors, other(tensors))

  • MinimumModules

    Outputs elementwise minimum of :code:inputs that can be modules or numbers.

  • Mul

    Multiply tensors by :code:other. :code:other can be a number or a module.

  • MulByLoss

    Multiplies update by loss times :code:alpha

  • MultiOperationBase

    Base class for operations that use operands. This is an abstract class; subclass it and override the transform method to use it.

  • Multistep

    Performs :code:steps inner steps with :code:module per each step.

  • MuonAdjustLR

    LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).

  • NAG

    Nesterov accelerated gradient method (nesterov momentum).

  • NanToNum

    Convert nan, inf and -inf to numbers.

  • NaturalGradient

    Natural gradient approximated via empirical fisher information matrix.

  • Negate

    Returns :code:- input

  • NegateOnLossIncrease

    Uses an extra forward pass to evaluate loss at :code:parameters+update,

  • NewDQN

    Diagonal quasi-Newton method.

  • NewSSM

    Self-scaling Quasi-Newton method.

  • Newton

    Exact Newton's method via autograd.

  • NewtonCG

    Newton's method with a matrix-free conjugate gradient or minimal-residual solver.

  • NewtonCGSteihaug

    Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.

  • NoiseSign

    Outputs random tensors with sign copied from the update.

  • Noop

    Identity operator that is argument-insensitive. It can also be used as the identity Hessian for trust region methods.

  • Normalize

    Normalizes the update.

  • NormalizeByEMA

    Sets norm of the update to be the same as the norm of an exponential moving average of past updates.

  • NystromPCG

    Newton's method with a Nyström-preconditioned conjugate gradient solver.

  • NystromSketchAndSolve

    Newton's method with a Nyström sketch-and-solve solver.

  • Ones

    Outputs ones

  • Online

    Allows certain modules to be used for mini-batch optimization.

  • OrthoGrad

    Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

  • Orthogonalize

    Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.

  • PSB

    Powell's Symmetric Broyden Quasi-Newton method.

  • Params

    Outputs parameters

  • Pearson

    Pearson's Quasi-Newton method.

  • PerturbWeights

    Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.

  • PolakRibiere

    Polak-Ribière-Polyak nonlinear conjugate gradient method.

  • PolyakStepSize

    Polyak's subgradient method with known or unknown f*.

  • Pow

    Take tensors to the power of :code:exponent. :code:exponent can be a number or a module.

  • PowModules

    Calculates :code:input ** exponent. :code:input and :code:exponent can be numbers or modules.

  • PowellRestart

    Powell's two restarting criteria for conjugate gradient methods.

  • Previous

    Maintains the update from n steps back; for example, if n=1, returns the previous update.

  • PrintLoss

    Prints var.get_loss().

  • PrintParams

    Prints current parameters.

  • PrintShape

    Prints shapes of the update.

  • PrintUpdate

    Prints current update.

  • Prod

    Outputs product of :code:inputs that can be modules or numbers.

  • ProjectedGradientMethod

    Projected gradient method. Directly projects the gradient onto subspace conjugate to past directions.

  • ProjectedNewtonRaphson

    Projected Newton Raphson method.

  • ProjectionBase

    Base class for projections.

  • RCopySign

    Returns :code:other(tensors) with sign copied from tensors.

  • RDSA

    Gradient approximation via Random-direction stochastic approximation (RDSA) method.

  • RDiv

    Divide :code:other by tensors. :code:other can be a number or a module.

  • RGraft

    Outputs :code:magnitude(tensors) rescaled to have the same norm as tensors

  • RMSprop

    Divides gradient by EMA of gradient squares.

  • RPow

    Take :code:other to the power of tensors. :code:other can be a number or a module.

  • RSub

    Subtract tensors from :code:other. :code:other can be a number or a module.

  • Randn

    Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1.

  • RandomHvp

    Returns a Hessian-vector product with a random vector.

  • RandomSample

    Outputs tensors filled with random numbers from distribution depending on value of :code:distribution.

  • RandomStepSize

    Uses random global or layer-wise step size from low to high.

  • RandomizedFDM

    Gradient approximation via a randomized finite-difference method.

  • Reciprocal

    Returns :code:1 / input

  • ReduceOperationBase

    Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class; subclass it and override the transform method to use it.

  • Relative

    Multiplies the update by absolute parameter values to make it relative to their magnitude. :code:min_value is the minimum allowed value, which avoids getting stuck at 0.

  • RelativeWeightDecay

    Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of norm_input argument.

  • RestartEvery

    Resets the state every n steps

  • RestartOnStuck

    Resets the state when update (difference in parameters) is zero for multiple steps in a row.

  • RestartStrategyBase

    Base class for restart strategies.

  • Rprop

    Resilient propagation. The update magnitude gets multiplied by nplus if gradient didn't change the sign,

  • SAM

    Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

  • SOAP

    SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).

  • SPSA

    Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.

  • SR1

    Symmetric Rank 1. This works best with a trust region:

  • SSVM

    Self-scaling variable metric Quasi-Newton method.

  • SVRG

    Stochastic variance reduced gradient method (SVRG).

  • SaveBest

    Saves best parameters found so far, ones that have lowest loss. Put this as the last module.

  • ScalarProjection

    Projection that splits all parameters into individual scalars.

  • ScaleByGradCosineSimilarity

    Multiplies the update by cosine similarity with gradient.

  • ScaleLRBySignChange

    Learning rate gets multiplied by nplus if ascent/gradient didn't change the sign,

  • ScaleModulesByCosineSimilarity

    Scales the output of :code:main module by its cosine similarity to the output

  • ScipyMinimizeScalar

    Line search via :code:scipy.optimize.minimize_scalar, which implements the Brent, golden-section search and bounded Brent methods.

  • Sequential

    On each step, this sequentially steps with :code:modules :code:steps times.

  • Shampoo

    Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

  • ShorR

    Shor’s r-algorithm.

  • Sign

    Returns :code:sign(input)

  • SignConsistencyLRs

    Outputs per-weight learning rates based on consecutive sign consistency.

  • SignConsistencyMask

    Outputs a mask of sign consistency of current and previous inputs.

  • SixthOrder3P

    Sixth-order iterative method.

  • SixthOrder3PM2

    Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45.

  • SixthOrder5P

    Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139.

  • SophiaH

    SophiaH optimizer from https://arxiv.org/abs/2305.14342

  • Split

    Applies true modules to all parameters selected by filter, and false modules to all other parameters.

  • Sqrt

    Returns :code:sqrt(input)

  • SqrtEMASquared

    Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

  • SqrtHomotopy
  • SquareHomotopy
  • StepSize

    This is exactly the same as LR, except that the lr parameter can be renamed to any other name to avoid clashes.

  • StrongWolfe

    Interpolation line search satisfying the strong Wolfe conditions.

  • Sub

    Subtract :code:other from tensors. :code:other can be a number or a module.

  • SubModules

    Calculates :code:input - other. :code:input and :code:other can be numbers or modules.

  • Sum

    Outputs sum of :code:inputs that can be modules or numbers.

  • SumOfSquares

    Sets loss to be the sum of squares of values returned by the closure.

  • Switch

    After :code:steps steps switches to the next module.

  • TerminateAfterNEvaluations
  • TerminateAfterNSeconds
  • TerminateAfterNSteps
  • TerminateAll
  • TerminateAny
  • TerminateByGradientNorm
  • TerminateByUpdateNorm

    The update is calculated as the parameter difference.

  • TerminateNever
  • TerminateOnLossReached
  • TerminateOnNoImprovement
  • TerminationCriteriaBase
  • ThomasOptimalMethod

    Thomas's "optimal" Quasi-Newton method.

  • Threshold

    Outputs tensors thresholded such that values above :code:threshold are set to :code:value.

  • To

    Casts modules to the specified device and dtype.

  • TrustCG

    Trust region via Steihaug-Toint Conjugate Gradient method.

  • TrustRegionBase
  • TwoPointNewton

    Two-point Newton method with frozen derivative and third-order convergence.

  • UnaryLambda

    Applies :code:fn to input tensors.

  • UnaryParameterwiseLambda

    Applies :code:fn to each input tensor.

  • Uniform

    Outputs tensors filled with random numbers from uniform distribution between :code:low and :code:high.

  • UpdateGradientSignConsistency

    Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

  • UpdateSign

    Outputs gradient with sign copied from the update.

  • UpdateToNone

    Sets :code:update attribute to None on :code:var.

  • VectorProjection

    Projection that concatenates all parameters into a vector.

  • ViewAsReal

    View complex tensors as real tensors. Doesn't affect tensors that are already real.

  • Warmup

    Learning rate warmup, linearly increases learning rate multiplier from :code:start_lr to :code:end_lr over :code:steps steps.

  • WarmupNormClip

    Warmup via clipping of the update norm.

  • WeightDecay

    Weight decay.

  • WeightDropout

    Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.

  • WeightedAveraging

    Weighted average of past len(weights) updates.

  • WeightedMean

    Outputs weighted mean of :code:inputs that can be modules or numbers.

  • WeightedSum
  • Wrap

    Wraps a pytorch optimizer to use it as a module.

  • Zeros

    Outputs zeros

Functions:

  • clip_grad_norm_

    Clips gradient of an iterable of parameters to specified norm value.

  • clip_grad_value_

    Clips gradient of an iterable of parameters at specified value.

  • decay_weights_

    Directly decays weights in-place.

  • normalize_grads_

    Normalizes gradient of an iterable of parameters to specified norm value.

  • orthogonalize_grads_

    Uses Newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.

  • orthograd_

    Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

AEGD

Bases: torchzero.core.transform.Transform

AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.

Note

AEGD has a learning rate hyperparameter that can't really be removed from the update rule. To avoid compounding learning rate modifications, remove the tz.m.LR module if you had it.

Parameters:

  • lr (float, default: 0.1 ) –

    step size (η in the paper). Defaults to 0.1.

  • c (float, default: 1 ) –

    constant added to the loss before taking the square root; the energy variable is initialized to sqrt(loss + c), so loss + c must be positive. Defaults to 1.

Source code in torchzero/modules/adaptive/aegd.py
class AEGD(Transform):
    """AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.

    Note:
        AEGD has a learning rate hyperparameter that can't really be removed from the update rule.
        To avoid compounding learning rate modifications, remove the ``tz.m.LR`` module if you had it.

    Args:
        lr (float, optional): step size (eta in the paper). Defaults to 0.1.
        c (float, optional): constant added to the loss so that the energy is initialized
            to sqrt(loss + c); loss + c must be positive. Defaults to 1.
    """
    def __init__(
        self,
        lr: float = 0.1,
        c: float = 1,
    ):
        defaults=dict(c=c,lr=lr)
        super().__init__(defaults, uses_loss=True)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert loss is not None
        tensors = TensorList(tensors)

        c,lr=unpack_dicts(settings, 'c','lr', cls=NumberList)
        r = unpack_states(states, tensors, 'r', init=lambda t: torch.full_like(t, float(loss+c[0])**0.5), cls=TensorList)

        update = aegd_(
            f=loss,
            g=tensors,
            r_=r,
            c=c,
            eta=lr,
        )

        return update
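For intuition, the following dependency-free sketch implements the AEGD update rule as stated in the paper: with v = g / (2·sqrt(f + c)), the energy is updated as r ← r / (1 + 2η·v²) and the parameters as θ ← θ − 2η·r·v, with r starting at sqrt(f(θ₀) + c) for every coordinate (matching the per-element initialization in the listing above). It is a sketch under that assumption, not torchzero's `aegd_` helper; `eta` plays the role of the module's `lr`, and `aegd_minimize`, `quadratic` and `quadratic_grad` are hypothetical names.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def aegd_minimize(
    f: Callable[[Vector], float],
    grad: Callable[[Vector], Vector],
    x: Vector,
    eta: float = 0.1,
    c: float = 1.0,
    steps: int = 200,
) -> Tuple[Vector, List[Vector]]:
    """Sketch of the AEGD update rule from the paper (hypothetical helper)."""
    x = list(x)
    # One energy value per coordinate, initialised to sqrt(f(x0) + c); requires f + c > 0.
    r = [math.sqrt(f(x) + c)] * len(x)
    r_history = [list(r)]
    for _ in range(steps):
        g = grad(x)
        s = math.sqrt(f(x) + c)
        v = [gi / (2.0 * s) for gi in g]  # v = g / (2 sqrt(f + c))
        # Energy update: dividing by 1 + 2*eta*v^2 >= 1 means r never increases.
        r = [ri / (1.0 + 2.0 * eta * vi * vi) for ri, vi in zip(r, v)]
        # Parameter update uses the freshly updated energy.
        x = [xi - 2.0 * eta * ri * vi for xi, ri, vi in zip(x, r, v)]
        r_history.append(list(r))
    return x, r_history


# Usage on a simple quadratic f(x) = 0.5 * ||x||^2 (hypothetical test problem).
def quadratic(x: Vector) -> float:
    return 0.5 * sum(xi * xi for xi in x)


def quadratic_grad(x: Vector) -> Vector:
    return list(x)


x0 = [3.0, -2.0]
x_final, r_hist = aegd_minimize(quadratic, quadratic_grad, x0)
print(quadratic(x0), quadratic(x_final))
```

Because the energy can only shrink, the effective per-coordinate step size η·r/sqrt(f + c) decays on its own, which is why stacking an extra LR schedule on top compounds the step size control.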

ASAM

Bases: torchzero.modules.adaptive.sam.SAM

Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52

SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value. It performs two forward and backward passes per step.

This implementation modifies the closure to return loss and calculate gradients of the SAM objective. All modules after this will use the modified objective.

Note

This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients at two points on each step.

Parameters:

  • rho (float, default: 0.5 ) –

    Neighborhood size. Defaults to 0.5.

  • p (float, default: 2 ) –

    norm of the SAM objective. Defaults to 2.

Examples:

ASAM-Adam:

import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)  # placeholder model for illustration

opt = tz.Modular(
    model.parameters(),
    tz.m.ASAM(),
    tz.m.Adam(),
    tz.m.LR(1e-2)
)

References

Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR. https://arxiv.org/abs/2102.11600

Source code in torchzero/modules/adaptive/sam.py
class ASAM(SAM):
    """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52

    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
    It performs two forward and backward passes per step.

    This implementation modifies the closure to return loss and calculate gradients
    of the SAM objective. All modules after this will use the modified objective.

    .. note::
        This module requires a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients at two points on each step.

    Args:
        rho (float, optional): Neighborhood size. Defaults to 0.5.
        p (float, optional): norm of the SAM objective. Defaults to 2.

    Examples:
        ASAM-Adam:

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.ASAM(),
                tz.m.Adam(),
                tz.m.LR(1e-2)
            )

    References:
        Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR. https://arxiv.org/abs/2102.11600
    """
    def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
        super().__init__(rho=rho, p=p, eps=eps, asam=True)
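The perturbation step behind SAM and ASAM can be illustrated with a small dependency-free sketch. It assumes the p = 2 formulation from the ASAM paper, where the weight-dependent operator is T_w = |w| + eta applied elementwise and the ascent step is ε = ρ·T_w²·g / ‖T_w·g‖₂; with T_w = 1 this reduces to plain SAM, ε = ρ·g / ‖g‖₂. Here `eta` is a small stabilizing constant, presumably similar in role to the module's `eps` argument (an assumption), and `sam_perturbation` and `sam_step` are hypothetical helpers, not torchzero's SAM implementation.

```python
import math
from typing import Callable, List

Vector = List[float]


def sam_perturbation(w: Vector, g: Vector, rho: float = 0.5,
                     adaptive: bool = False, eta: float = 1e-10) -> Vector:
    """Ascent step epsilon for SAM (adaptive=False) or ASAM (adaptive=True), p = 2."""
    # Elementwise operator T_w: |w| + eta for ASAM, identity for SAM.
    t = [abs(wi) + eta if adaptive else 1.0 for wi in w]
    tg = [ti * gi for ti, gi in zip(t, g)]
    norm = math.sqrt(sum(v * v for v in tg))
    if norm == 0.0:
        return [0.0] * len(w)  # zero gradient: no perturbation
    # epsilon = rho * T_w^2 g / ||T_w g||_2
    return [rho * ti * tgi / norm for ti, tgi in zip(t, tg)]


def sam_step(w: Vector, grad: Callable[[Vector], Vector], lr: float = 0.1,
             rho: float = 0.5, adaptive: bool = True) -> Vector:
    g = grad(w)                                   # first gradient evaluation
    eps = sam_perturbation(w, g, rho, adaptive)
    w_adv = [wi + ei for wi, ei in zip(w, eps)]
    g_adv = grad(w_adv)                           # second evaluation, at w + epsilon
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


w = [1.0, -2.0, 0.5]
g = [0.3, 0.1, -0.4]
eps_sam = sam_perturbation(w, g, rho=0.5)
eps_asam = sam_perturbation(w, g, rho=0.5, adaptive=True)
w_next = sam_step(w, lambda x: list(x))  # gradient of 0.5 * ||x||^2
```

The second gradient evaluation at w + ε is the reason the module needs a closure. In the plain SAM case ‖ε‖₂ equals ρ, while in the adaptive case it is ‖T_w⁻¹·ε‖₂ that equals ρ, so the neighborhood is measured relative to the weight magnitudes.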

Abs

Bases: torchzero.core.transform.Transform

Returns :code:abs(input)

Source code in torchzero/modules/ops/unary.py
class Abs(Transform):
    """Returns :code:`abs(input)`"""
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        torch._foreach_abs_(tensors)
        return tensors

AccumulateMaximum

Bases: torchzero.core.transform.Transform

Accumulates maximum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMaximum(Transform):
    """Accumulates maximum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target: Target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return maximum.maximum_(tensors).lazy_mul(decay, clone=True)
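A list-based sketch of what this module computes: an elementwise running maximum over every update seen so far, with the returned value scaled by (1 - decay) as in the `lazy_mul(decay, clone=True)` call above. The listing does not show how the stored maximum is initialized, so the sketch simply seeds it with the first update (an assumption), and it reads `clone=True` as meaning that only the returned copy is scaled. `accumulate_maximum` is a hypothetical helper, not part of torchzero.

```python
from typing import List, Optional

Vector = List[float]


def accumulate_maximum(updates: List[Vector], decay: float = 0.0) -> List[Vector]:
    """Return the module-style output after each step (hypothetical helper)."""
    maximum: Optional[Vector] = None
    outputs: List[Vector] = []
    for u in updates:
        # Seed with the first update (assumption), then keep the elementwise maximum.
        if maximum is None:
            maximum = list(u)
        else:
            maximum = [max(m, ui) for m, ui in zip(maximum, u)]
        # Per our reading of lazy_mul(..., clone=True), only the returned copy is scaled.
        outputs.append([m * (1.0 - decay) for m in maximum])
    return outputs


history = [[1.0, -3.0], [0.5, 2.0], [-1.0, 1.0]]
running_max = accumulate_maximum(history)
```

AccumulateMinimum follows the same pattern with an elementwise minimum in place of the maximum.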

AccumulateMean

Bases: torchzero.core.transform.Transform

Accumulates mean of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMean(Transform):
    """Accumulates mean of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target: Target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
        mean = unpack_states(states, tensors, 'mean', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return mean.add_(tensors).lazy_mul(decay, clone=True).div_(step)
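Despite the state key being called `mean`, the listing stores a running sum and divides it by a global step counter on the way out, after scaling by (1 - decay). The sketch below mirrors that arithmetic on plain lists, assuming the running sum starts at zero (the state initialization is not shown in the listing); `accumulate_mean` is a hypothetical helper, not part of torchzero.

```python
from typing import List

Vector = List[float]


def accumulate_mean(updates: List[Vector], decay: float = 0.0) -> List[Vector]:
    """Mirror the arithmetic of AccumulateMean.apply_tensors (hypothetical helper)."""
    total: List[float] = []
    outputs: List[Vector] = []
    for step, u in enumerate(updates, start=1):
        if not total:
            total = [0.0] * len(u)  # assumed zero initialisation of the running sum
        total = [t + ui for t, ui in zip(total, u)]  # corresponds to mean.add_(tensors)
        # corresponds to lazy_mul(1 - decay) followed by div_(step)
        outputs.append([t * (1.0 - decay) / step for t in total])
    return outputs


history = [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]]
running_mean = accumulate_mean(history)
```

With decay=0 this is the ordinary running mean of all updates so far; a nonzero decay only shrinks the returned value, not the stored sum.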

AccumulateMinimum

Bases: torchzero.core.transform.Transform

Accumulates minimum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMinimum(Transform):
    """Accumulates minimum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target: Target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return minimum.minimum_(tensors).lazy_mul(decay, clone=True)

AccumulateProduct

Bases: torchzero.core.transform.Transform

Accumulates product of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateProduct(Transform):
    """Accumulates product of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target: Target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prod = unpack_states(states, tensors, 'prod', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return prod.mul_(tensors).lazy_mul(decay, clone=True)

AccumulateSum

Bases: torchzero.core.transform.Transform

Accumulates sum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.
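
With the default decay=0, the accumulators in this group keep an elementwise running statistic of every update seen so far. The pure-Python sketch below illustrates those documented semantics on plain lists. `running` is a hypothetical helper, not a torchzero function. It seeds each statistic with the first update, which is an assumption: the listings do not show how `unpack_states` initializes the stored state.

```python
import operator

def running(updates, op):
    """Fold each new update into the accumulator elementwise with `op`.

    Hypothetical helper: returns the accumulator value after every update.
    """
    acc = None
    history = []
    for u in updates:
        # assumption: the first update initializes the accumulator
        acc = list(u) if acc is None else [op(a, b) for a, b in zip(acc, u)]
        history.append(list(acc))
    return history

updates = [[1.0, -2.0], [3.0, 0.5], [-1.0, 4.0]]
sums = running(updates, operator.add)   # AccumulateSum
mins = running(updates, min)            # AccumulateMinimum
maxs = running(updates, max)            # AccumulateMaximum
prods = running(updates, operator.mul)  # AccumulateProduct
# AccumulateMean divides the running sum by the step count
means = [[s / (k + 1) for s in row] for k, row in enumerate(sums)]
```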

Source code in torchzero/modules/ops/accumulate.py
class AccumulateSum(Transform):
    """Accumulates sum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target: Target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        sum = unpack_states(states, tensors, 'sum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return sum.add_(tensors).lazy_mul(decay, clone=True)

AdGD

Bases: torchzero.core.transform.Transform

AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)
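
AdGD needs no learning rate: it estimates the local Lipschitz constant of the gradient as L = ||g_k - g_{k-1}|| / ||x_k - x_{k-1}|| and sets α_k = min(sqrt(1 + θ) α_{k-1}, 1 / (sqrt(2) L)), where θ = α_{k-1} / α_{k-2}. The sketch below mirrors the variant-1 branch of the source listing on plain Python lists (θ starts at 0, and `use_sqrt=False` uses 1 / (2 L) instead). The function name and the test problem are illustrative, and the listing's online-reset and step-size safeguards are omitted.

```python
import math

def _norm(v):
    return math.sqrt(sum(x * x for x in v))

def adgd_variant1(grad, x0, alpha0=1e-7, steps=100, use_sqrt=True):
    """Gradient descent with the AdGD (variant 1) step-size rule."""
    x_prev, g_prev = list(x0), grad(x0)
    alpha, theta = alpha0, 0.0  # variant 1 starts with theta = 0
    # the very first step uses the (tiny) initial step size
    x = [xi - alpha * gi for xi, gi in zip(x_prev, g_prev)]
    for _ in range(steps):
        g = grad(x)
        dx = _norm([a - b for a, b in zip(x, x_prev)])
        dg = _norm([a - b for a, b in zip(g, g_prev)])
        L = dg / dx if dx > 0 else 0.0  # local Lipschitz estimate
        a1 = math.sqrt(1 + theta) * alpha  # limits how fast alpha may grow
        a2 = 1 / ((math.sqrt(2) if use_sqrt else 2) * L) if L > 0 else math.inf
        new_alpha = min(a1, a2)
        theta, alpha = new_alpha / alpha, new_alpha
        x_prev, g_prev = x, g
        x = [xi - alpha * gi for xi, gi in zip(x, g)]
    return x, alpha

# f(x) = 0.5 * 4 * x**2, so the exact curvature is 4
x, alpha = adgd_variant1(lambda v: [4.0 * v[0]], [1.0])
```

On a one-dimensional quadratic the curvature estimate is exact, so after a warm-up phase in which α grows geometrically, it settles at 1 / (sqrt(2) · 4).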

Source code in torchzero/modules/step_size/adaptive.py
class AdGD(Transform):
    """AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
    def __init__(self, variant:Literal[1,2]=2, alpha_0:float = 1e-7, sqrt:bool=True, use_grad=True, inner: Chainable | None = None,):
        defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)
        super().__init__(defaults, uses_grad=use_grad, inner=inner,)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_g')
        self.global_state['reset'] = True

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        variant = settings[0]['variant']
        theta_0 = 0 if variant == 1 else 1/3
        theta = self.global_state.get('theta', theta_0)

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        p = TensorList(params)
        g = grads if self._uses_grad else tensors
        assert g is not None
        g = TensorList(g)

        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)

        # online
        if self.global_state.get('reset', False):
            del self.global_state['reset']
            prev_p.copy_(p)
            prev_g.copy_(g)
            return

        if step == 0:
            alpha_0 = settings[0]['alpha_0']
            if alpha_0 is None: alpha_0 = epsilon_step_size(g)
            self.global_state['alpha']  = alpha_0
            prev_p.copy_(p)
            prev_g.copy_(g)
            return

        sqrt = settings[0]['sqrt']
        alpha = self.global_state.get('alpha', math.inf)
        L = (g - prev_g).global_vector_norm() / (p - prev_p).global_vector_norm()
        eps = torch.finfo(L.dtype).tiny * 2

        if variant == 1:
            a1 = math.sqrt(1 + theta)*alpha
            val = math.sqrt(2) if sqrt else 2
            if L > eps: a2 = 1 / (val*L)
            else: a2 = math.inf

        elif variant == 2:
            a1 = math.sqrt(2/3 + theta)*alpha
            a2 = alpha / math.sqrt(max(eps, 2 * alpha**2 * L**2 - 1))

        else:
            raise ValueError(variant)

        alpha_new = min(a1, a2)
        if alpha_new < 0: alpha_new = max(a1, a2)
        if alpha_new > eps:
            self.global_state['theta'] = alpha_new/alpha
            self.global_state['alpha'] = alpha_new

        prev_p.copy_(p)
        prev_g.copy_(g)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', None)

        if not _acceptable_alpha(alpha, tensors[0]):
            # alpha isn't None on 1st step
            self.state.clear()
            self.global_state.clear()
            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])

        torch._foreach_mul_(tensors, alpha)
        return tensors

    def get_H(self, var):
        return _get_H(self, var)

AdaHessian

Bases: torchzero.core.module.Module

AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)

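This is similar to Adam, but the second momentum (the EMA of squared gradients) is replaced by an exponential moving average of squared Hessian diagonal estimates computed from random Hessian-vector products, and the update is divided by its square root.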

Notes
  • In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply AdaHessian preconditioning to another module's output.

  • If you are using gradient estimators or reformulations, set hvp_method to "forward" or "central".

  • This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

Parameters:

  • beta1 (float, default: 0.9 ) –

    first momentum. Defaults to 0.9.

  • beta2 (float, default: 0.999 ) –

    second momentum for squared hessian diagonal estimates. Defaults to 0.999.

  • averaging (bool, default: True ) –

    whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions. This can be set per-parameter in param groups.

  • block_size (int, default: None ) –

    size of block in the block-diagonal averaging.

  • update_freq (int, default: 1 ) –

    frequency of updating hessian diagonal estimate via a hessian-vector product. This value can be increased to reduce computational cost. Defaults to 1.

  • eps (float, default: 1e-08 ) –

    division stability epsilon. Defaults to 1e-8.

  • hvp_method (str, default: 'autograd' ) –

    Determines how Hessian-vector products are evaluated.

    • "autograd": Use PyTorch's autograd to calculate exact HVPs. This requires creating a graph for the gradient.
    • "forward": Use a forward finite difference formula to approximate the HVP. This requires one extra gradient evaluation.
    • "central": Use a central finite difference formula for a more accurate HVP approximation. This requires two extra gradient evaluations.

    Defaults to "autograd".

  • fd_h (float, default: 0.001 ) –

    finite difference step size if hvp_method is "forward" or "central". Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of Hessian-vector products with random vectors to evaluate each time the preconditioner is updated. Larger values may lead to a better Hessian diagonal estimate. Defaults to 1.

  • seed (int | None, default: None ) –

    seed for random vectors. Defaults to None.

  • inner (Chainable | None, default: None ) –

    Inner module. If this is specified, operations are performed in the following order:

    1. compute the Hessian diagonal estimate.
    2. pass inputs to inner.
    3. apply momentum and preconditioning to the outputs of inner.

Examples:

Using AdaHessian:

opt = tz.Modular(
    model.parameters(),
    tz.m.AdaHessian(),
    tz.m.LR(0.1)
)

The AdaHessian preconditioner can be applied to any other module by passing that module to the inner argument. Set beta1=0 to turn off AdaHessian's first momentum and get just the preconditioning. Here is an example of applying AdaHessian preconditioning to Nesterov momentum (tz.m.NAG):

opt = tz.Modular(
    model.parameters(),
    tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
    tz.m.LR(0.1)
)
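
The Hessian diagonal that AdaHessian relies on is typically estimated with Hutchinson's method: for a Rademacher vector z (entries ±1), the expectation of z ⊙ Hz equals diag(H). The pure-Python sketch below uses an explicit matrix in place of autograd Hessian-vector products and pairs the estimate with an Adam-style update whose second moment tracks the squared diagonal estimate, as described in the AdaHessian paper. The helper names are hypothetical; block averaging, update_freq and hessian_power are omitted, and it does not reproduce the library's `adahessian` function, which is not shown here.

```python
import random

def hvp(H, v):
    # Hessian-vector product with an explicit (small) Hessian matrix
    return [sum(h_ij * v_j for h_ij, v_j in zip(row, v)) for row in H]

def hutchinson_diag(H, n_samples=1, seed=None):
    """Estimate diag(H) as the average of z * (H z) over Rademacher vectors z."""
    rng = random.Random(seed)
    D = [0.0] * len(H)
    for _ in range(n_samples):
        z = [rng.choice((-1.0, 1.0)) for _ in range(len(H))]
        Hz = hvp(H, z)
        D = [d + zi * hzi for d, zi, hzi in zip(D, z, Hz)]
    return [d / n_samples for d in D]

def adahessian_update(g, D, state, beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam-style update with squared Hessian diagonal estimates in the denominator."""
    state["step"] += 1
    t = state["step"]
    state["m"] = [beta1 * m + (1 - beta1) * gi for m, gi in zip(state["m"], g)]
    state["v"] = [beta2 * v + (1 - beta2) * di * di for v, di in zip(state["v"], D)]
    bc1, bc2 = 1 - beta1 ** t, 1 - beta2 ** t  # bias corrections
    return [(m / bc1) / ((v / bc2) ** 0.5 + eps)
            for m, v in zip(state["m"], state["v"])]

H = [[2.0, 1.0], [1.0, 3.0]]
D = hutchinson_diag(H, n_samples=10000, seed=1)  # close to [2.0, 3.0]
state = {"step": 0, "m": [0.0, 0.0], "v": [0.0, 0.0]}
u = adahessian_update([1.0, -2.0], [2.0, 4.0], state)
```

For a diagonal H a single sample is exact, because z_i² = 1; off-diagonal entries only add zero-mean noise that averages out over samples.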

Source code in torchzero/modules/adaptive/adahessian.py
class AdaHessian(Module):
    """AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)

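    This is similar to Adam, but the second momentum (the EMA of squared gradients) is replaced by an exponential moving average of squared Hessian diagonal estimates computed from random Hessian-vector products, and the update is divided by its square root.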

    Notes:
        - In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply AdaHessian preconditioning to another module's output.

        - If you are using gradient estimators or reformulations, set ``hvp_method`` to "forward" or "central".

        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        beta1 (float, optional): first momentum. Defaults to 0.9.
        beta2 (float, optional): second momentum for squared hessian diagonal estimates. Defaults to 0.999.
        averaging (bool, optional):
            whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions.
            This can be set per-parameter in param groups.
        block_size (int, optional):
            size of block in the block-diagonal averaging.
        update_freq (int, optional):
            frequency of updating hessian diagonal estimate via a hessian-vector product.
            This value can be increased to reduce computational cost. Defaults to 1.
        eps (float, optional):
            division stability epsilon. Defaults to 1e-8.
        hvp_method (str, optional):
            Determines how Hessian-vector products are evaluated.

            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
              This requires creating a graph for the gradient.
            - ``"forward"``: Use a forward finite difference formula to
              approximate the HVP. This requires one extra gradient evaluation.
            - ``"central"``: Use a central finite difference formula for a
              more accurate HVP approximation. This requires two extra
              gradient evaluations.
            Defaults to "autograd".
        fd_h (float, optional): finite difference step size if ``hvp_method`` is "forward" or "central". Defaults to 1e-3.
        n_samples (int, optional):
            number of Hessian-vector products with random vectors to evaluate each time the preconditioner
            is updated. Larger values may lead to a better Hessian diagonal estimate. Defaults to 1.
        seed (int | None, optional): seed for random vectors. Defaults to None.
        inner (Chainable | None, optional):
            Inner module. If this is specified, operations are performed in the following order.
            1. compute hessian diagonal estimate.
            2. pass inputs to ``inner``.
            3. momentum and preconditioning are applied to the outputs of ``inner``.

    ## Examples:

    Using AdaHessian:

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.AdaHessian(),
        tz.m.LR(0.1)
    )
    ```

    The AdaHessian preconditioner can be applied to any other module by passing that module to the ``inner`` argument.
    Set ``beta1=0`` to turn off AdaHessian's first momentum and get just the preconditioning. Here is an example of applying
    AdaHessian preconditioning to Nesterov momentum (``tz.m.NAG``):
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
        tz.m.LR(0.1)
    )
    ```

    """
    def __init__(
        self,
        beta1: float = 0.9,
        beta2: float = 0.999,
        averaging: bool = True,
        block_size: int | None = None,
        update_freq: int = 1,
        eps: float = 1e-8,
        hessian_power: float = 1,
        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
        fd_h: float = 1e-3,
        n_samples = 1,
        seed: int | None = None,
        inner: Chainable | None = None
    ):
        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hessian_power=hessian_power, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
        super().__init__(defaults)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def step(self, var):
        params = var.params
        settings = self.settings[params[0]]
        hvp_method = settings['hvp_method']
        fd_h = settings['fd_h']
        update_freq = settings['update_freq']
        n_samples = settings['n_samples']

        seed = settings['seed']
        generator = self.get_generator(params[0].device, seed)

        beta1, beta2, eps, averaging, block_size, hessian_power = self.get_settings(params,
            'beta1', 'beta2', 'eps', 'averaging', 'block_size', "hessian_power", cls=NumberList)

        exp_avg, D_exp_avg_sq = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        closure = var.closure
        assert closure is not None

        D = None
        if step % update_freq == 0:

            rgrad=None
            for i in range(n_samples):
                u = [_rademacher_like(p, generator=generator) for p in params]

                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
                                     h=fd_h, normalize=True, retain_grad=i < n_samples-1)
                Hvp = tuple(Hvp)

                if D is None: D = Hvp
                else: torch._foreach_add_(D, Hvp)

            assert D is not None
            if n_samples > 1: torch._foreach_div_(D, n_samples)

            D = TensorList(D).zipmap_args(_block_average, block_size, averaging)

        update = var.get_update()
        if 'inner' in self.children:
            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)

        var.update = adahessian(
            tensors=TensorList(update),
            D=TensorList(D) if D is not None else None,
            exp_avg_=exp_avg,
            D_exp_avg_sq_=D_exp_avg_sq,
            beta1=beta1,
            beta2=beta2,
            update_freq=update_freq,
            eps=eps,
            hessian_power=hessian_power,
            step=step,
        )
        return var

Adagrad

Bases: torchzero.core.transform.Transform

Adagrad, divides by sum of past squares of gradients.

This implementation is identical to torch.optim.Adagrad.

Parameters:

  • lr_decay (float, default: 0 ) –

    learning rate decay. Defaults to 0.

  • initial_accumulator_value (float, default: 0 ) –

    initial value of the sum of squares of gradients. Defaults to 0.

  • eps (float, default: 1e-10 ) –

    division epsilon. Defaults to 1e-10.

  • alpha (float, default: 1 ) –

    step size. Defaults to 1.

  • pow (float, default: 2 ) –

    power for gradients and accumulator root. Defaults to 2.

  • use_sqrt (bool, default: True ) –

    whether to take the root of the accumulator. Defaults to True.

  • inner (Chainable | None, default: None ) –

    Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
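
Since the module is documented as identical to torch.optim.Adagrad, its default behaviour (pow=2, use_sqrt=True) can be sketched with the torch.optim.Adagrad update rule on plain Python lists: the accumulator adds squared gradients, the step size decays as alpha / (1 + (t - 1) * lr_decay), and the update divides by the root of the accumulator plus eps. The function name and the state layout are illustrative assumptions.

```python
import math

def adagrad_step(g, state, alpha=1.0, lr_decay=0.0, eps=1e-10,
                 initial_accumulator_value=0.0):
    """Return the Adagrad update for gradient g (a list of floats)."""
    if "sq_sum" not in state:
        state["step"] = 0
        state["sq_sum"] = [initial_accumulator_value] * len(g)
    state["step"] += 1
    t = state["step"]
    clr = alpha / (1 + (t - 1) * lr_decay)  # decayed step size
    # accumulate squared gradients elementwise
    state["sq_sum"] = [s + gi * gi for s, gi in zip(state["sq_sum"], g)]
    return [clr * gi / (math.sqrt(s) + eps) for gi, s in zip(g, state["sq_sum"])]

state = {}
u1 = adagrad_step([3.0], state)  # sqrt(9) = 3, so the update is about 1
u2 = adagrad_step([4.0], state)  # sqrt(9 + 16) = 5, so the update is about 0.8
```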

Source code in torchzero/modules/adaptive/adagrad.py
class Adagrad(Transform):
    """Adagrad, divides by sum of past squares of gradients.

    This implementation is identical to ``torch.optim.Adagrad``.

    Args:
        lr_decay (float, optional): learning rate decay. Defaults to 0.
        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
        eps (float, optional): division epsilon. Defaults to 1e-10.
        alpha (float, optional): step size. Defaults to 1.
        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
    """
    def __init__(
        self,
        lr_decay: float = 0,
        initial_accumulator_value: float = 0,
        eps: float = 1e-10,
        alpha: float = 1,
        pow: float = 2,
        use_sqrt: bool = True,
        divide: bool=False,
        beta:float | None = None,
        decay: float | None = None,
        inner: Chainable | None = None,
    ):
        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide, beta=beta, decay=decay)
        super().__init__(defaults=defaults, uses_grad=False)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)

        pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])

        sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)

        # initialize accumulator on 1st step
        if step == 1:
            sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))

        return adagrad_(
            tensors,
            sq_sum_=sq_sum,
            alpha=alpha,
            lr_decay=lr_decay,
            eps=eps,
            step=step,
            pow=pow,
            use_sqrt=use_sqrt,
            divide=divide,

            beta = self.defaults["beta"],
            decay = self.defaults["decay"],
            # inner args
            inner=self.children.get("inner", None),
            params=params,
            grads=grads,
        )

AdagradNorm

Bases: torchzero.core.transform.Transform

Adagrad-Norm, divides by sum of past means of squares of gradients.

Parameters:

  • lr_decay (float, default: 0 ) –

    learning rate decay. Defaults to 0.

  • initial_accumulator_value (float, default: 0 ) –

    initial value of the sum of squares of gradients. Defaults to 0.

  • eps (float, default: 1e-10 ) –

    division epsilon. Defaults to 1e-10.

  • alpha (float, default: 1 ) –

    step size. Defaults to 1.

  • pow (float, default: 2 ) –

    power for gradients and accumulator root. Defaults to 2.

  • use_sqrt (bool, default: True ) –

    whether to take the root of the accumulator. Defaults to True.

  • inner (Chainable | None, default: None ) –

    Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
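
Unlike Adagrad, the listing keeps a single scalar accumulator for the whole module (global_state["accumulator"]) instead of one per parameter entry. Going by the description "divides by sum of past means of squares of gradients", the default behaviour can be sketched as follows. What exactly `adagrad_norm_` computes is an assumption here (its body is not shown), the function name is hypothetical, and lr_decay is omitted.

```python
import math

def adagrad_norm_step(g, state, alpha=1.0, eps=1e-10, initial_accumulator_value=0.0):
    """Scale every entry of g by one shared, scalar accumulator."""
    acc = state.get("accumulator", initial_accumulator_value)
    # assumption: add the mean of the squared gradient entries
    acc += sum(gi * gi for gi in g) / len(g)
    state["accumulator"] = acc
    return [alpha * gi / (math.sqrt(acc) + eps) for gi in g]

state = {}
u = adagrad_norm_step([3.0, 4.0], state)  # mean of squares is 12.5
```

Because every entry is divided by the same scalar, the update keeps the direction of the gradient and only rescales its length.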

Source code in torchzero/modules/adaptive/adagrad.py
class AdagradNorm(Transform):
    """Adagrad-Norm, divides by sum of past means of squares of gradients.

    Args:
        lr_decay (float, optional): learning rate decay. Defaults to 0.
        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
        eps (float, optional): division epsilon. Defaults to 1e-10.
        alpha (float, optional): step size. Defaults to 1.
        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
    """
    def __init__(
        self,
        lr_decay: float = 0,
        initial_accumulator_value: float = 0,
        eps: float = 1e-10,
        alpha: float = 1,
        pow: float = 2,
        use_sqrt: bool = True,
        divide: bool=False,
        beta:float | None = None,
        decay: float | None = None,
        inner: Chainable | None = None,
    ):
        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide, beta=beta, decay=decay)
        super().__init__(defaults=defaults, uses_grad=False)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
        lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)

        use_sqrt, divide, initial_accumulator_value = itemgetter('use_sqrt', 'divide', "initial_accumulator_value")(settings[0])

        accumulator = self.global_state.get("accumulator", initial_accumulator_value)

        d, self.global_state["accumulator"] = adagrad_norm_(
            tensors,
            accumulator=accumulator,
            alpha=alpha,
            lr_decay=lr_decay,
            eps=eps,
            step=step,
            use_sqrt=use_sqrt,
            divide=divide,

            beta = self.defaults["beta"],
            decay = self.defaults["decay"],
            # inner args
            inner=self.children.get("inner", None),
            params=params,
            grads=grads,
        )

        return d

Adam

Bases: torchzero.core.transform.Transform

Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.

This implementation is identical to torch.optim.Adam.

Parameters:

  • beta1 (float, default: 0.9 ) –

    momentum. Defaults to 0.9.

  • beta2 (float, default: 0.999 ) –

    second momentum. Defaults to 0.999.

  • eps (float, default: 1e-08 ) –

    epsilon. Defaults to 1e-8.

  • alpha (float, default: 1.0 ) –

    learning rate. Defaults to 1.

  • amsgrad (bool, default: False ) –

    Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.

  • pow (float, default: 2 ) –

    power applied to gradients in the second momentum; the root of the same order is taken in the denominator. Defaults to 2.

  • debiased (bool, default: True ) –

    whether to apply debiasing to momentums based on current step. Defaults to True.
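
For the default settings (pow=2, amsgrad=False, debiased=True) the documented behaviour matches the torch.optim.Adam update, sketched here on plain Python lists. How other pow values are handled (here: |g|**pow in the second momentum and a pow-th root of its debiased value) is an assumption based on the parameter description, and the function name is hypothetical.

```python
import math

def adam_step(g, state, beta1=0.9, beta2=0.999, eps=1e-8, alpha=1.0,
              pow=2, debiased=True):
    """Return the Adam update (to be subtracted from the parameters)."""
    if "step" not in state:
        state.update(step=0, exp_avg=[0.0] * len(g), exp_avg_sq=[0.0] * len(g))
    state["step"] += 1
    t = state["step"]
    # first and second momentum EMAs
    state["exp_avg"] = [beta1 * m + (1 - beta1) * gi
                        for m, gi in zip(state["exp_avg"], g)]
    state["exp_avg_sq"] = [beta2 * v + (1 - beta2) * abs(gi) ** pow
                           for v, gi in zip(state["exp_avg_sq"], g)]
    # debiasing corrects the zero initialization of both EMAs
    bc1 = 1 - beta1 ** t if debiased else 1.0
    bc2 = 1 - beta2 ** t if debiased else 1.0
    return [alpha * (m / bc1) / ((v / bc2) ** (1 / pow) + eps)
            for m, v in zip(state["exp_avg"], state["exp_avg_sq"])]

state = {}
u = adam_step([0.5, -2.0], state)
```

On the first step the debiased EMAs equal g and g², so the update is approximately sign(g), regardless of the gradient's magnitude.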

Source code in torchzero/modules/adaptive/adam.py
class Adam(Transform):
    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.

    This implementation is identical to :code:`torch.optim.Adam`.

    Args:
        beta1 (float, optional): momentum. Defaults to 0.9.
        beta2 (float, optional): second momentum. Defaults to 0.999.
        eps (float, optional): epsilon. Defaults to 1e-8.
        alpha (float, optional): learning rate. Defaults to 1.
        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
        pow (float, optional): power applied to gradients in the second momentum; the root of the same order is taken in the denominator. Defaults to 2.
        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
    """
    def __init__(
        self,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        amsgrad: bool = False,
        alpha: float = 1.,
        pow: float = 2,
        debiased: bool = True,
        inner: Chainable | None = None
    ):
        defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
        super().__init__(defaults, uses_grad=False)

        if inner is not None: self.set_child('inner', inner)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
        amsgrad,pow,debiased = itemgetter('amsgrad','pow','debiased')(settings[0])

        if amsgrad:
            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None


        return adam_(
            tensors=TensorList(tensors),
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            alpha=alpha,
            beta1=beta1,
            beta2=beta2,
            eps=eps,
            step=step,
            pow=pow,
            debiased=debiased,
            max_exp_avg_sq_=max_exp_avg_sq,

            # inner args
            inner=self.children.get("inner", None),
            params=params,
            grads=grads,

        )

Adan

Bases: torchzero.core.transform.Transform

Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677

Parameters:

  • beta1 (float, default: 0.98 ) –

    momentum. Defaults to 0.98.

  • beta2 (float, default: 0.92 ) –

    momentum for gradient differences. Defaults to 0.92.

  • beta3 (float, default: 0.99 ) –

    third (squared) momentum. Defaults to 0.99.

  • eps (float, default: 1e-08 ) –

    epsilon. Defaults to 1e-8.

Example:

opt = tz.Modular(
    model.parameters(),
    tz.m.Adan(),
    tz.m.LR(1e-3),
)

Reference:

Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive Nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://arxiv.org/abs/2208.06677
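
The Adan recurrences from the referenced paper (without weight decay or restarts) can be sketched on plain Python lists as below. The paper states its betas in the opposite convention (β1 = 0.02, β2 = 0.08, β3 = 0.01); the sketch assumes the module's defaults 0.98, 0.92 and 0.99 are one minus those values. It also treats the first gradient difference as zero and applies no bias correction. It illustrates the paper's update, not the library's `adan_` function, whose body is not shown; the function name is hypothetical.

```python
import math

def adan_step(g, state, beta1=0.98, beta2=0.92, beta3=0.99, eps=1e-8):
    """Return the Adan update for gradient g (a list of floats)."""
    if "m" not in state:
        zeros = [0.0] * len(g)
        # assumption: the first gradient difference is zero
        state.update(m=zeros, v=list(zeros), n=list(zeros), g_prev=list(g))
    diff = [gi - gp for gi, gp in zip(g, state["g_prev"])]
    # EMA of gradients
    state["m"] = [beta1 * m + (1 - beta1) * gi for m, gi in zip(state["m"], g)]
    # EMA of gradient differences
    state["v"] = [beta2 * v + (1 - beta2) * d for v, d in zip(state["v"], diff)]
    # EMA of squared Nesterov-style gradient estimates g + beta2 * diff
    state["n"] = [beta3 * n + (1 - beta3) * (gi + beta2 * d) ** 2
                  for n, gi, d in zip(state["n"], g, diff)]
    state["g_prev"] = list(g)
    return [(m + beta2 * v) / math.sqrt(n + eps)
            for m, v, n in zip(state["m"], state["v"], state["n"])]

state = {}
u1 = adan_step([2.0], state)  # 0.02 * 2 / sqrt(0.01 * 4), about 0.2
```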

Source code in torchzero/modules/adaptive/adan.py
class Adan(Transform):
    """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677

    Args:
        beta1 (float, optional): momentum. Defaults to 0.98.
        beta2 (float, optional): momentum for gradient differences. Defaults to 0.92.
        beta3 (float, optional): third (squared) momentum. Defaults to 0.99.
        eps (float, optional): epsilon. Defaults to 1e-8.

    Example:
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adan(),
        tz.m.LR(1e-3),
    )
    ```

    Reference:
        Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://arxiv.org/abs/2208.06677
    """
    def __init__(
        self,
        beta1: float = 0.98,
        beta2: float = 0.92,
        beta3: float = 0.99,
        eps: float = 1e-8,
    ):
        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        beta1,beta2,beta3,eps=unpack_dicts(settings, 'beta1','beta2','beta3','eps', cls=NumberList)
        g_prev, m, v, n = unpack_states(states, tensors, 'g_prev','m','v','n', cls=TensorList)

        update = adan_(
            g=tensors,
            g_prev_=g_prev,
            m_=m,
            v_=v,
            n_=n,
            beta1=beta1,
            beta2=beta2,
            beta3=beta3,
            eps=eps,
            step=step,
        )

        return update

AdaptiveBacktracking

Bases: torchzero.modules.line_search.line_search.LineSearchBase

Adaptive backtracking line search. After each line search procedure, a new initial step size is set such that optimal step size in the procedure would be found on the second line search iteration.

Parameters:

  • init (float, default: 1.0 ) –

    initial step size. Defaults to 1.0.

  • beta (float, default: 0.5 ) –

    multiplies each consecutive step size by this value. Defaults to 0.5.

  • c (float, default: 0.0001 ) –

    sufficient decrease condition. Defaults to 1e-4.

  • condition (Literal, default: 'armijo' ) –

    termination condition; only conditions that do not use the gradient at f(x+a*d) can be specified.

    • "armijo" - sufficient decrease condition.
    • "decrease" - any decrease in objective function value satisfies the condition.

    "goldstein" can technically be specified, but it doesn't make sense because there is no zoom stage. Defaults to 'armijo'.

  • maxiter (int, default: 20 ) –

    maximum number of function evaluations per step. Defaults to 20.

  • target_iters (int, default: 1 ) –

    sets next step size such that this number of iterations are expected to be performed until optimal step size is found. Defaults to 1.

  • nplus (float, default: 2.0 ) –

    if initial step size is optimal, it is multiplied by this value. Defaults to 2.0.

  • scale_beta (float, default: 0.0 ) –

    momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
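
The rule for choosing the next initial step size can be sketched in pure Python as follows. This is a simplified reading with target_iters=1 and scale_beta=0 that follows the intent stated in the docstring (the next search starts at step / beta, so it accepts on its second trial) rather than reproducing the listing's arithmetic exactly; the beta_scale adjustment on failure is omitted, and the function names are hypothetical. phi(a) is the objective along the search direction and dphi0 its directional derivative at a = 0.

```python
def backtracking(phi, phi0, dphi0, init=1.0, beta=0.5, c=1e-4, maxiter=20):
    """Armijo backtracking: shrink the step until sufficient decrease holds."""
    a = init
    for _ in range(maxiter):
        if phi(a) <= phi0 + c * a * dphi0:  # Armijo condition
            return a
        a *= beta
    return None  # no acceptable step found

def adaptive_backtracking(phi, phi0, dphi0, scale, init=1.0, beta=0.5, nplus=2.0):
    """One line search; returns the step size and the updated initial-step scale."""
    start = init * scale
    step = backtracking(phi, phi0, dphi0, init=start, beta=beta)
    if step is None:
        return 0.0, scale
    if step == start:
        scale *= nplus  # the first trial was accepted: try a larger step next time
    else:
        # start the next search at step / beta so it accepts on its second trial
        scale = (step / beta) / init
    return step, scale

# f(x) = (x - 3)**2 at x = 0, stepping x - a * g with g = f'(0) = -6
phi = lambda a: (6 * a - 3) ** 2
step, scale = adaptive_backtracking(phi, 9.0, -36.0, scale=1.0, init=4.0)
```

Starting from 4, the first search needs four trials (4, 2, 1, 0.5) to accept 0.5; the updated scale makes the next search start at 1.0 and accept on its second trial.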

Source code in torchzero/modules/line_search/backtracking.py
class AdaptiveBacktracking(LineSearchBase):
    """Adaptive backtracking line search. After each line search procedure, a new initial step size is set
    such that optimal step size in the procedure would be found on the second line search iteration.

    Args:
        init (float, optional): initial step size. Defaults to 1.0.
        beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
        condition (TerminationCondition, optional):
            termination condition; only conditions that do not use the gradient at f(x+a*d) can be specified.
            - "armijo" - sufficient decrease condition.
            - "decrease" - any decrease in objective function value satisfies the condition.

            "goldstein" can technically be specified, but it doesn't make sense because there is no zoom stage.
            Defaults to 'armijo'.
        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 20.
        target_iters (int, optional):
            sets next step size such that this number of iterations are expected
            to be performed until optimal step size is found. Defaults to 1.
        nplus (float, optional):
            if initial step size is optimal, it is multiplied by this value. Defaults to 2.0.
        scale_beta (float, optional):
            momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
    """
    def __init__(
        self,
        init: float = 1.0,
        beta: float = 0.5,
        c: float = 1e-4,
        condition: TerminationCondition = 'armijo',
        maxiter: int = 20,
        target_iters = 1,
        nplus = 2.0,
        scale_beta = 0.0,
    ):
        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta)
        super().__init__(defaults=defaults)

        self.global_state['beta_scale'] = 1.0
        self.global_state['initial_scale'] = 1.0

    def reset(self):
        super().reset()
        self.global_state['beta_scale'] = 1.0
        self.global_state['initial_scale'] = 1.0

    @torch.no_grad
    def search(self, update, var):
        init, beta, c,condition, maxiter, target_iters, nplus, scale_beta=itemgetter(
            'init','beta','c','condition', 'maxiter','target_iters','nplus','scale_beta')(self.defaults)

        objective = self.make_objective(var=var)

        # directional derivative (0 if c = 0 because it is not needed)
        if c == 0: d = 0
        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))

        # scale beta
        beta = beta * self.global_state['beta_scale']

        # scale step size so that decrease is expected at target_iters
        init = init * self.global_state['initial_scale']

        step_size = backtracking_line_search(objective, d, init=init, beta=beta, c=c, condition=condition, maxiter=maxiter)

        # found an alpha that reduces loss
        if step_size is not None:

            # update initial_scale
            # initial step size satisfied conditions, increase initial_scale by nplus
            if step_size == init and target_iters > 0:
                self.global_state['initial_scale'] *= nplus ** target_iters

                # clip by maximum possible value to avoid overflow exception
                self.global_state['initial_scale'] = min(
                    self.global_state['initial_scale'],
                    torch.finfo(var.params[0].dtype).max / 2,
                )

            else:
                # otherwise make initial_scale such that target_iters iterations will satisfy armijo
                init_target = step_size
                for _ in range(target_iters):
                    init_target = init_target / beta

                self.global_state['initial_scale'] = _lerp(
                    self.global_state['initial_scale'], init_target / init, 1-scale_beta
                )

            # revert beta_scale
            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))

            return step_size

        # on fail reduce beta scale value
        self.global_state['beta_scale'] /= 1.5
        return 0
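The idea behind the adaptive initial step size can be shown on a one-dimensional function phi(a) = f(x + a*p). The sketch below is a simplified illustration, not torchzero's code. `armijo_backtracking` and `next_initial_step` are hypothetical helpers. The module's own state handling (`initial_scale`, `beta_scale`, overflow clipping and `scale_beta` momentum) is omitted. The sketch also assumes that the `target_iters` loop in `search` is meant to divide by `beta` once per target iteration.

```python
from typing import Callable, Optional


def armijo_backtracking(
    phi: Callable[[float], float],
    dphi0: float,
    init: float = 1.0,
    beta: float = 0.5,
    c: float = 1e-4,
    maxiter: int = 20,
) -> Optional[float]:
    """Shrink the step size by `beta` until the Armijo condition holds.

    phi(a) is the objective along the search direction and dphi0 = phi'(0) < 0.
    Returns None if no step size is accepted within `maxiter` trials.
    """
    phi0 = phi(0.0)
    a = init
    for _ in range(maxiter):
        if phi(a) <= phi0 + c * a * dphi0:  # sufficient decrease
            return a
        a *= beta
    return None


def next_initial_step(
    init: float, step_size: float, beta: float = 0.5, target_iters: int = 1, nplus: float = 2.0
) -> float:
    """Hypothetical rule for choosing the next initial step size."""
    if step_size == init:
        # the first trial was accepted, so try a larger step next time
        return init * nplus ** target_iters
    # otherwise start `target_iters` backtracking steps above the accepted step size
    return step_size / beta ** target_iters


def phi(a: float) -> float:
    return (a - 0.3) ** 2  # minimum at a = 0.3, phi'(0) = -0.6


a = armijo_backtracking(phi, dphi0=-0.6, init=1.0)
print(a, next_initial_step(1.0, a))  # 0.5 1.0
```

With the default `target_iters=1`, the next search starts one backtracking step above the step size that was accepted. The search therefore needs about one reduction when the curvature along the direction changes little between steps.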

AdaptiveHeavyBall

Bases: torchzero.core.transform.Transform

Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.

This is related to conjugate gradient methods. It may work very well on deterministic (non-stochastic) convex objectives, but it is not expected to work on stochastic ones.

Note

The step size is determined by the algorithm, so learning rate modules shouldn't be used.

Parameters:

  • f_star (float, default: 0 ) –

    (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.

Source code in torchzero/modules/adaptive/adaptive_heavyball.py
class AdaptiveHeavyBall(Transform):
    """Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.

    This is related to conjugate gradient methods. It may work very well on deterministic (non-stochastic) convex objectives,
    but it is not expected to work on stochastic ones.

    Note:
        The step size is determined by the algorithm, so learning rate modules shouldn't be used.

    Args:
        f_star (float, optional):
            (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
    """
    def __init__(self, f_star: float = 0):
        defaults = dict(f_star=f_star)
        super().__init__(defaults, uses_grad=False, uses_loss=True)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert loss is not None
        tensors = TensorList(tensors)
        f_star = self.defaults['f_star']

        f_prev = self.global_state.get('f_prev', None)
        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params,tensors], cls=TensorList)

        if f_prev is None:
            self.global_state['f_prev'] = loss
            h = 2*(loss - f_star) / tensors.dot(tensors)
            return h * tensors

        update = adaptive_heavy_ball(f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)

        self.global_state['f_prev'] = loss
        p_prev.copy_(params)
        g_prev.copy_(tensors)
        return update
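On the first call there is no previous point, so `apply_tensors` returns h * g with h = 2 * (f - f_star) / ||g||^2, which is twice the Polyak step size. The pure-Python sketch below reproduces only that first step, using plain lists in place of tensors. `first_heavy_ball_step` is a hypothetical name. Later steps go through `adaptive_heavy_ball`, which is not shown on this page and is not reproduced.

```python
from typing import List


def first_heavy_ball_step(loss: float, grad: List[float], f_star: float = 0.0) -> List[float]:
    """Update returned on the very first call: h * g with h = 2 * (f - f_star) / ||g||^2."""
    grad_sq_norm = sum(g * g for g in grad)
    h = 2.0 * (loss - f_star) / grad_sq_norm
    return [h * g for g in grad]


# f(x) = x1^2 + x2^2 at x = (1, 2): f = 5, grad = (2, 4)
x = [1.0, 2.0]
update = first_heavy_ball_step(loss=5.0, grad=[2.0, 4.0])
print([xi - ui for xi, ui in zip(x, update)])  # [0.0, 0.0]
```

On this quadratic, with the correct `f_star = 0`, the first step lands exactly on the minimizer. If `f_star` is underestimated, f - f_star is larger and the step becomes too long, so the quality of the estimate matters.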

AdaptiveTracking

Bases: torchzero.modules.line_search.line_search.LineSearchBase

A line search that first evaluates the previous step size. If the objective value increased, it backtracks until the value stops decreasing; otherwise, it forward-tracks until the value stops decreasing.

Parameters:

  • init (float, default: 1.0 ) –

    initial step size. Defaults to 1.0.

  • nplus (float, default: 2 ) –

    multiplier to step size if initial step size is optimal. Defaults to 2.

  • nminus (float, default: 0.5 ) –

    multiplier to step size if initial step size is too big. Defaults to 0.5.

  • maxiter (int, default: 10 ) –

    maximum number of function evaluations per step. Defaults to 10.

  • adaptive (bool, default: True ) –

    when enabled, if line search failed, step size will continue decreasing on the next step. Otherwise it will restart the line search from init step size. Defaults to True.

Source code in torchzero/modules/line_search/adaptive.py
class AdaptiveTracking(LineSearchBase):
    """A line search that first evaluates the previous step size. If the objective value increased, it backtracks until
    the value stops decreasing; otherwise, it forward-tracks until the value stops decreasing.

    Args:
        init (float, optional): initial step size. Defaults to 1.0.
        nplus (float, optional): multiplier to step size if initial step size is optimal. Defaults to 2.
        nminus (float, optional): multiplier to step size if initial step size is too big. Defaults to 0.5.
        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
        adaptive (bool, optional):
            when enabled, if line search failed, step size will continue decreasing on the next step.
            Otherwise it will restart the line search from ``init`` step size. Defaults to True.
    """
    def __init__(
        self,
        init: float = 1.0,
        nplus: float = 2,
        nminus: float = 0.5,
        maxiter: int = 10,
        adaptive=True,
    ):
        defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive)
        super().__init__(defaults=defaults)

    def reset(self):
        super().reset()

    @torch.no_grad
    def search(self, update, var):
        init, nplus, nminus, maxiter, adaptive = itemgetter(
            'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.defaults)

        objective = self.make_objective(var=var)

        # scale a_prev
        a_prev = self.global_state.get('a_prev', init)
        if adaptive: a_prev = a_prev * self.global_state.get('init_scale', 1)

        a_init = a_prev
        if a_init < torch.finfo(var.params[0].dtype).tiny * 2:
            a_init = torch.finfo(var.params[0].dtype).max / 2

        step_size, f, niter = adaptive_tracking(
            objective,
            a_init=a_init,
            maxiter=maxiter,
            nplus=nplus,
            nminus=nminus,
        )

        # found an alpha that reduces loss
        if step_size != 0:
            assert (var.loss is None) or (math.isfinite(f) and f < var.loss)
            self.global_state['init_scale'] = 1

            # if niter == 1, forward tracking failed to decrease function value compared to f_a_prev
            if niter == 1 and step_size >= a_init: step_size *= nminus

            self.global_state['a_prev'] = step_size
            return step_size

        # on fail, shrink init_scale so that the next search starts from a smaller step size
        self.global_state['init_scale'] = self.global_state.get('init_scale', 1) * nminus**maxiter
        self.global_state['a_prev'] = init
        return 0
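The tracking procedure described above can be sketched on a one-dimensional function phi(a) = f(x + a*p). `adaptive_tracking_sketch` is a hypothetical re-implementation written from the description. torchzero's own `adaptive_tracking` function is not shown on this page and may differ in its stopping rules and in how it counts evaluations; here, the evaluation of phi(0) is not counted.

```python
import math
from typing import Callable, Tuple


def _finite_or_inf(value: float) -> float:
    # treat NaN / inf objective values as "no decrease"
    return value if math.isfinite(value) else math.inf


def adaptive_tracking_sketch(
    phi: Callable[[float], float],
    a_init: float,
    nplus: float = 2.0,
    nminus: float = 0.5,
    maxiter: int = 10,
) -> Tuple[float, float, int]:
    """Return (step_size, value, n_evals); step_size is 0.0 if no decrease was found."""
    phi0 = phi(0.0)
    a, fa = a_init, _finite_or_inf(phi(a_init))
    n_evals = 1
    # backtrack if the previous step size did not decrease the objective, else forward-track
    backtracking = fa >= phi0
    factor = nminus if backtracking else nplus
    while n_evals < maxiter:
        a_next = a * factor
        f_next = _finite_or_inf(phi(a_next))
        n_evals += 1
        # keep moving while the value decreases; when backtracking, also keep
        # shrinking as long as no value below phi(0) has been found yet
        if f_next < fa or (backtracking and fa >= phi0):
            a, fa = a_next, f_next
        else:
            break  # the value stopped decreasing
    if fa < phi0:
        return a, fa, n_evals
    return 0.0, phi0, n_evals


def phi(a: float) -> float:
    return (a - 1.0) ** 2  # minimum at a = 1


print(adaptive_tracking_sketch(phi, 0.125))  # (1.0, 0.0, 5): forward-tracks 0.125 -> 1.0
print(adaptive_tracking_sketch(phi, 8.0))    # (1.0, 0.0, 5): backtracks 8.0 -> 1.0
```

In the module, the accepted step size is stored as `a_prev`, so the next search starts from it. On a smooth problem this usually means only a few evaluations per step.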

Add

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Add `other` to tensors. `other` can be a number or a module, and it is scaled by `alpha` (default 1) before being added.

If `other` is a module, this calculates `tensors + alpha * other(tensors)`.

Source code in torchzero/modules/ops/binary.py
class Add(BinaryOperationBase):
    """Add :code:`other` to tensors. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`tensors + other(tensors)`
    """
    def __init__(self, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.defaults['alpha'])
        else: torch._foreach_add_(update, other, alpha=self.defaults['alpha'])
        return update
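With plain Python lists of floats standing in for the tensor lists, the arithmetic in `transform` reduces to the function below. `add_sketch` is hypothetical. It only mirrors the two `torch._foreach_add_` branches, including the fact that `alpha` scales `other` in both cases.

```python
from typing import List, Union


def add_sketch(update: List[float], other: Union[float, List[float]], alpha: float = 1.0) -> List[float]:
    """Return update + alpha * other, where other is a scalar or a list of the same length."""
    if isinstance(other, (int, float)):
        # scalar operand: the same value is added to every element
        return [u + alpha * other for u in update]
    # tensor-like operand (e.g. the output of another module): element-wise add
    if len(other) != len(update):
        raise ValueError("other must have the same length as update")
    return [u + alpha * o for u, o in zip(update, other)]


print(add_sketch([1.0, 2.0], 0.5))                        # [1.5, 2.5]
print(add_sketch([1.0, 2.0], [3.0, 4.0], alpha=-1.0))     # [-2.0, -2.0]
```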

Alternate

Bases: torchzero.core.module.Module

Alternates between stepping with `modules`.

That is, the first step is performed with the first module, the second step with the second module, and so on; after the last module, it starts again from the first.

Parameters:

  • steps (int | Iterable[int], default: 1 ) –

    number of steps to perform with each module before switching to the next one. If an iterable, it must contain one entry per module. Defaults to 1.

Examples:

Alternate between Adam, SignSGD and RMSprop:

opt = tz.Modular(
    model.parameters(),
    tz.m.Alternate(
        tz.m.Adam(),
        [tz.m.SignSGD(), tz.m.Mul(0.5)],
        tz.m.RMSprop(),
    ),
    tz.m.LR(1e-3),
)

Source code in torchzero/modules/misc/switch.py
class Alternate(Module):
    """Alternates between stepping with :code:`modules`.

    That is, the first step is performed with the first module, the second step with the second module, and so on.

    Args:
        steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.

    Examples:
        Alternate between Adam, SignSGD and RMSprop

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Alternate(
                    tz.m.Adam(),
                    [tz.m.SignSGD(), tz.m.Mul(0.5)],
                    tz.m.RMSprop(),
                ),
                tz.m.LR(1e-3),
            )
    """
    LOOP = True
    def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
        if isinstance(steps, Iterable):
            steps = list(steps)
            if len(steps) != len(modules):
                raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")

        defaults = dict(steps=steps)
        super().__init__(defaults)

        self.set_children_sequence(modules)
        self.global_state['current_module_idx'] = 0
        self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps

    @torch.no_grad
    def step(self, var):
        # get current module
        current_module_idx = self.global_state.setdefault('current_module_idx', 0)
        module = self.children[f'module_{current_module_idx}']

        # step
        var = module.step(var.clone(clone_update=False))

        # number of steps until next module
        steps = self.defaults['steps']
        if isinstance(steps, int): steps = [steps]*len(self.children)

        if 'steps_to_next' not in self.global_state:
            self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps

        self.global_state['steps_to_next'] -= 1

        # switch to next module
        if self.global_state['steps_to_next'] == 0:
            self.global_state['current_module_idx'] += 1

            # loop to first module (or keep using last module on Switch)
            if self.global_state['current_module_idx'] > len(self.children) - 1:
                if self.LOOP: self.global_state['current_module_idx'] = 0
                else: self.global_state['current_module_idx'] = len(self.children) - 1

            self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]

        return var

LOOP class-attribute

LOOP = True

If True, Alternate wraps around to the first module after the last one. If False, it keeps using the last module once the sequence is exhausted (the comment in step refers to this as the Switch behavior).
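The module-switching bookkeeping in `step` (`current_module_idx`, `steps_to_next` and `LOOP`) can be traced without any optimizer. The function below is a hypothetical pure-Python re-implementation of just that schedule. It returns the index of the child module that would handle each step.

```python
from typing import List, Union


def alternate_schedule(
    steps: Union[int, List[int]], n_modules: int, n_steps: int, loop: bool = True
) -> List[int]:
    """Index of the child module used at each of the first `n_steps` optimizer steps."""
    if isinstance(steps, int):
        steps = [steps] * n_modules
    if len(steps) != n_modules:
        raise ValueError("steps must have one entry per module")
    idx, steps_to_next = 0, steps[0]
    schedule = []
    for _ in range(n_steps):
        schedule.append(idx)
        steps_to_next -= 1
        if steps_to_next == 0:
            idx += 1
            if idx > n_modules - 1:
                # wrap around (LOOP = True) or stay on the last module (LOOP = False)
                idx = 0 if loop else n_modules - 1
            steps_to_next = steps[idx]
    return schedule


print(alternate_schedule([2, 1, 1], n_modules=3, n_steps=8))      # [0, 0, 1, 2, 0, 0, 1, 2]
print(alternate_schedule(1, n_modules=3, n_steps=5, loop=False))  # [0, 1, 2, 2, 2]
```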

Averaging

Bases: torchzero.core.transform.TensorwiseTransform

Average of past history_size updates.

Parameters:

  • history_size (int) –

    Number of past updates to average

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/momentum/averaging.py
class Averaging(TensorwiseTransform):
    """Average of past ``history_size`` updates.

    Args:
        history_size (int): Number of past updates to average
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, history_size: int, target: Target = 'update'):
        defaults = dict(history_size=history_size)
        super().__init__(uses_grad=False, defaults=defaults, target=target)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        history_size = setting['history_size']
        if 'history' not in state:
            state['history'] = deque(maxlen=history_size)
            state['average'] = torch.zeros_like(tensor)

        history = state['history']; average = state['average']
        if len(history) == history_size: average -= history[0]
        history.append(tensor)
        average += tensor

        return average / len(history)
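`apply_tensor` keeps a running sum rather than re-summing the history on every step. It subtracts the oldest entry just before the bounded `deque` drops it. The scalar sketch below uses the same bookkeeping. `MovingAverage` is a hypothetical name and operates on floats rather than tensors.

```python
from collections import deque


class MovingAverage:
    """Average of the last `history_size` values, kept as a running sum (scalar version)."""

    def __init__(self, history_size: int):
        self.history = deque(maxlen=history_size)
        self.total = 0.0

    def __call__(self, value: float) -> float:
        if len(self.history) == self.history.maxlen:
            # the deque is about to drop its oldest entry, so remove it from the sum first
            self.total -= self.history[0]
        self.history.append(value)
        self.total += value
        return self.total / len(self.history)


avg = MovingAverage(2)
print([avg(v) for v in [1.0, 3.0, 5.0]])  # [1.0, 2.0, 4.0]
```

Each step costs O(1) instead of O(history_size). The trade-off is that, in floating point, the running sum can drift slightly from the exact sum of the stored values over very long runs.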

BBStab

Bases: torchzero.core.transform.Transform

Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).

This clips the norm of the Barzilai-Borwein update to at most delta, where delta is chosen adaptively (using c) when it is set to None.

Parameters:

  • c (float, default: 0.2 ) –

    adaptive delta parameter. If delta is set to None, the first inf_iters updates (3 by default) are performed with the non-stabilized Barzilai-Borwein step size. Then delta is set to c times the smallest update norm observed during those iterations. Defaults to 0.2.

  • delta (float | None, default: None ) –

    Barzilai-Borwein update is clipped to this value. Set to None to use an adaptive choice. Defaults to None.

  • type (str, default: 'geom' ) –

    one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use the geometric mean of short and long. Defaults to "geom". Note that "long" corresponds to BB1stab and "short" to BB2stab; however, I found that "geom" works really well.

  • inner (Chainable | None, default: None ) –

    step size will be applied to outputs of this module. Defaults to None.

Source code in torchzero/modules/step_size/adaptive.py
class BBStab(Transform):
    """Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).

    This clips the norm of the Barzilai-Borwein update to at most ``delta``, where ``delta`` is chosen adaptively (using ``c``) when it is set to None.

    Args:
        c (float, optional):
            adaptive delta parameter. If ``delta`` is set to None, the first ``inf_iters`` updates are performed
            with the non-stabilized Barzilai-Borwein step size. Then ``delta`` is set to ``c`` times the norm of
            the update that had the smallest norm. Defaults to 0.2.
        delta (float | None, optional):
            Barzilai-Borwein update is clipped to this value. Set to ``None`` to use an adaptive choice. Defaults to None.
        type (str, optional):
            one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use the geometric mean of short and long.
            Defaults to "geom". Note that "long" corresponds to BB1stab and "short" to BB2stab;
            however, I found that "geom" works really well.
        inner (Chainable | None, optional):
            step size will be applied to outputs of this module. Defaults to None.

    """
    def __init__(
        self,
        c=0.2,
        delta:float | None = None,
        type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
        alpha_0: float = 1e-7,
        use_grad=True,
        inf_iters: int = 3,
        inner: Chainable | None = None,
    ):
        defaults = dict(type=type,alpha_0=alpha_0, c=c, delta=delta, inf_iters=inf_iters)
        super().__init__(defaults, uses_grad=use_grad, inner=inner)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_g')
        self.global_state['reset'] = True

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
        type = self.defaults['type']
        c = self.defaults['c']
        delta = self.defaults['delta']
        inf_iters = self.defaults['inf_iters']

        g = grads if self._uses_grad else tensors
        assert g is not None
        g = TensorList(g)

        reset = self.global_state.get('reset', False)
        self.global_state.pop('reset', None)

        if step != 0 and not reset:
            s = params-prev_p
            y = g-prev_g
            sy = s.dot(y)
            eps = torch.finfo(sy.dtype).tiny

            if type == 'short': alpha = _bb_short(s, y, sy, eps)
            elif type == 'long': alpha = _bb_long(s, y, sy, eps)
            elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
            elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
            else: raise ValueError(type)

            if alpha is not None:

                # adaptive delta
                if delta is None:
                    niters = self.global_state.get('niters', 0) # this accounts for skipped negative curvature steps
                    self.global_state['niters'] = niters + 1


                    if niters == 0: pass # 1st iteration is scaled GD step, shouldn't be used to find s_norm_min
                    elif niters <= inf_iters:
                        s_norm_min = self.global_state.get('s_norm_min', None)
                        if s_norm_min is None: s_norm_min = s.global_vector_norm()
                        else: s_norm_min = min(s_norm_min, s.global_vector_norm())
                        self.global_state['s_norm_min'] = s_norm_min
                        # first few steps use delta=inf, so delta remains None

                    else:
                        delta = c * self.global_state['s_norm_min']

                if delta is None: # delta is inf for first few steps
                    self.global_state['alpha'] = alpha

                # BBStab step size
                else:
                    a_stab = delta / g.global_vector_norm()
                    self.global_state['alpha'] = min(alpha, a_stab)

        prev_p.copy_(params)
        prev_g.copy_(g)

    def get_H(self, var):
        return _get_H(self, var)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', None)

        if not _acceptable_alpha(alpha, tensors[0]):
            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])

        torch._foreach_mul_(tensors, alpha)
        return tensors
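The stabilization logic in `update_tensors` can be isolated from tensors. It has two parts: the adaptive choice of delta, and clipping the step size so that ||alpha * g|| <= delta. The sketch below is a hypothetical re-implementation. `AdaptiveDelta` and `stabilized_step_size` are illustrative names, and the Barzilai-Borwein step size `alpha` is taken as given. As in the module, `update` should only be called for iterations where a valid BB step size exists, since the source counts only those.

```python
from typing import Optional


class AdaptiveDelta:
    """Hypothetical tracker mirroring the adaptive delta logic of BBStab.update_tensors."""

    def __init__(self, c: float = 0.2, inf_iters: int = 3):
        self.c = c
        self.inf_iters = inf_iters
        self.niters = 0
        self.s_norm_min = None  # smallest ||s|| seen during the first inf_iters iterations

    def update(self, s_norm: float) -> Optional[float]:
        """Record ||s|| for one BB iteration and return delta (None means no clipping yet)."""
        niters = self.niters
        self.niters += 1
        if niters == 0:
            return None  # first step is a scaled gradient step, not used for s_norm_min
        if niters <= self.inf_iters:
            self.s_norm_min = s_norm if self.s_norm_min is None else min(self.s_norm_min, s_norm)
            return None  # delta is effectively infinite during these iterations
        return self.c * self.s_norm_min


def stabilized_step_size(alpha: float, grad_norm: float, delta: Optional[float]) -> float:
    """Clip the BB step size so that ||alpha * g|| <= delta."""
    if delta is None:
        return alpha
    return min(alpha, delta / grad_norm)


tracker = AdaptiveDelta(c=0.2, inf_iters=3)
print([tracker.update(n) for n in [10.0, 4.0, 2.0, 3.0, 1.0]])  # [None, None, None, None, 0.4]
print(stabilized_step_size(1.0, grad_norm=2.0, delta=0.4))      # 0.2
```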

BFGS

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.

Note

A line search or a trust region is recommended.

Warning

This uses at least O(N^2) memory.

Parameters:

  • init_scale (float | Literal['auto'], default: 'auto' ) –

    initial hessian matrix is set to identity times this.

    "auto" corresponds to a heuristic from Nocedal and Wright, Numerical Optimization, pp. 142-143.

    Defaults to "auto".

  • tol (float, default: 1e-32 ) –

    tolerance on curvature condition. Defaults to 1e-32.

  • ptol (float | None, default: 1e-32 ) –

    skips update if maximum difference between current and previous gradients is less than this, to avoid instability. Defaults to 1e-32.

  • ptol_restart (bool, default: False ) –

    whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.

  • restart_interval (int | None | Literal['auto'], default: None ) –

    interval between resetting the hessian approximation.

    "auto" corresponds to number of decision variables + 1.

    None - no resets.

    Defaults to None.

  • beta (float | None, default: None ) –

    momentum on H or B. Defaults to None.

  • update_freq (int, default: 1 ) –

    frequency of updating H or B. Defaults to 1.

  • scale_first (bool, default: False ) –

    whether to downscale the first step before the hessian approximation becomes available. Defaults to False.

  • scale_second (bool) –

    whether to downscale second step. Defaults to False.

  • concat_params (bool, default: True ) –

    If true, all parameters are treated as a single vector. If False, the update rule is applied to each parameter separately. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.

Examples:

BFGS with backtracking line search:

opt = tz.Modular(
    model.parameters(),
    tz.m.BFGS(),
    tz.m.Backtracking()
)

BFGS with trust region

opt = tz.Modular(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
)

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class BFGS(_InverseHessianUpdateStrategyDefaults):
    """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.

    Note:
        A line search or a trust region is recommended.

    Warning:
        This uses at least O(N^2) memory.

    Args:
        init_scale (float | Literal["auto"], optional):
            initial hessian matrix is set to identity times this.

            "auto" corresponds to a heuristic from Nocedal and Wright, Numerical Optimization, pp. 142-143.

            Defaults to "auto".
        tol (float, optional):
            tolerance on curvature condition. Defaults to 1e-32.
        ptol (float | None, optional):
            skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
            Defaults to 1e-32.
        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
        restart_interval (int | None | Literal["auto"], optional):
            interval between resetting the hessian approximation.

            "auto" corresponds to number of decision variables + 1.

            None - no resets.

            Defaults to None.
        beta (float | None, optional): momentum on H or B. Defaults to None.
        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
        scale_first (bool, optional):
            whether to downscale the first step before the hessian approximation becomes available. Defaults to False.
        scale_second (bool, optional): whether to downscale second step. Defaults to False.
        concat_params (bool, optional):
            If true, all parameters are treated as a single vector.
            If False, the update rule is applied to each parameter separately. Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    ## Examples:

    BFGS with backtracking line search:

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.BFGS(),
        tz.m.Backtracking()
    )
    ```

    BFGS with trust region
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
    )
    ```
    """

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return bfgs_B_(B=B, s=s, y=y, tol=setting['tol'])
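`bfgs_H_` itself is not shown on this page. For reference, the sketch below gives the standard textbook BFGS update of the inverse Hessian approximation H, written with plain Python lists: H_new = (I - rho s yᵀ) H (I - rho y sᵀ) + rho s sᵀ, with rho = 1 / (yᵀs). `bfgs_inverse_update` is a hypothetical name. The rule used for `tol` (skip the update when yᵀs <= tol) is an assumption and may differ from what torchzero does.

```python
from typing import List

Matrix = List[List[float]]


def bfgs_inverse_update(H: Matrix, s: List[float], y: List[float], tol: float = 1e-32) -> Matrix:
    """Textbook BFGS update of a dense, symmetric inverse Hessian approximation H."""
    n = len(s)
    sy = sum(si * yi for si, yi in zip(s, y))
    if sy <= tol:
        # curvature condition failed: skip the update (assumed rule)
        return [row[:] for row in H]
    rho = 1.0 / sy
    Hy = [sum(H[i][j] * y[j] for j in range(n)) for i in range(n)]
    yHy = sum(y[i] * Hy[i] for i in range(n))
    # expanded form of the product above, using the symmetry of H (y^T H = (H y)^T)
    return [
        [
            H[i][j]
            - rho * (Hy[i] * s[j] + s[i] * Hy[j])
            + (rho * rho * yHy + rho) * s[i] * s[j]
            for j in range(n)
        ]
        for i in range(n)
    ]


y = [2.0, 1.0]
H_new = bfgs_inverse_update([[1.0, 0.0], [0.0, 1.0]], s=[1.0, 0.0], y=y)
# the updated matrix satisfies the secant condition H_new y = s
print([sum(H_new[i][j] * y[j] for j in range(2)) for i in range(2)])  # [1.0, 0.0]
```

Each update costs O(N^2) time and memory for N decision variables, which matches the memory warning above.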

BacktrackOnSignChange

Bases: torchzero.core.transform.Transform

Negates or undoes the update for parameters where the gradient or update sign changes.

This is part of the RProp update rule.

Parameters:

  • use_grad (bool, default: False ) –

    if True, tracks sign changes of the gradient; otherwise, tracks sign changes of the update. Defaults to False.

  • backtrack (bool, default: True ) –

    if True, undoes the update when sign changes, otherwise negates it. Defaults to True.

Source code in torchzero/modules/adaptive/rprop.py
class BacktrackOnSignChange(Transform):
    """Negates or undoes the update for parameters where the gradient or update sign changes.

    This is part of the RProp update rule.

    Args:
        use_grad (bool, optional):
            if True, tracks sign changes of the gradient,
            otherwise tracks sign changes of the update. Defaults to False.
        backtrack (bool, optional):
            if True, undoes the update when sign changes, otherwise negates it.
            Defaults to True.

    """
    def __init__(self, use_grad = False, backtrack = True, target: Target = 'update'):
        defaults = dict(use_grad=use_grad, backtrack=backtrack, target=target)
        super().__init__(defaults, uses_grad=use_grad)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        tensors = as_tensorlist(tensors)
        use_grad = settings[0]['use_grad']
        backtrack = settings[0]['backtrack']

        if use_grad: cur = as_tensorlist(grads)
        else: cur = tensors

        tensors = backtrack_on_sign_change_(
            tensors_ = tensors,
            cur = cur,
            prev_ = unpack_states(states, tensors, 'prev', cls=TensorList),
            backtrack = backtrack,
            step = step,
        )

        return tensors
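The element-wise idea can be sketched in pure Python. The sketch follows the classic Rprop convention with weight backtracking, not torchzero's code. `backtrack_on_sign_change_sketch` and its `prev_update` argument are hypothetical: torchzero keeps a single `prev` state here, and `backtrack_on_sign_change_` (not shown) may define "undo" differently. The sketch assumes that updates are subtracted from the parameters, so undoing the previous step means returning its negation.

```python
from typing import List


def backtrack_on_sign_change_sketch(
    update: List[float],
    cur: List[float],
    prev: List[float],
    prev_update: List[float],
    backtrack: bool = True,
) -> List[float]:
    """Element-wise handling of sign changes between `prev` and `cur` (hypothetical helper)."""
    out = []
    for u, c, p, pu in zip(update, cur, prev, prev_update):
        if c * p < 0:
            # the sign flipped since the previous step, so the last step likely overshot
            out.append(-pu if backtrack else -u)  # undo the previous step, or negate this one
        else:
            out.append(u)
    return out


print(backtrack_on_sign_change_sketch([1.0, 1.0], cur=[1.0, -1.0], prev=[1.0, 1.0], prev_update=[0.5, 0.5]))
# [1.0, -0.5]
```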

Backtracking

Bases: torchzero.modules.line_search.line_search.LineSearchBase

Backtracking line search.

Parameters:

  • init (float, default: 1.0 ) –

    initial step size. Defaults to 1.0.

  • beta (float, default: 0.5 ) –

    multiplies each consecutive step size by this value. Defaults to 0.5.

  • c (float, default: 0.0001 ) –

    sufficient decrease condition. Defaults to 1e-4.

  • condition (Literal, default: 'armijo' ) –

    termination condition; only conditions that do not use the gradient at f(x+a*d) can be specified:

      - "armijo" - sufficient decrease condition.
      - "decrease" - any decrease in objective function value satisfies the condition.

    "goldstein" can technically be specified, but it doesn't make sense because there is no zoom stage. Defaults to 'armijo'.

  • maxiter (int, default: 10 ) –

    maximum number of function evaluations per step. Defaults to 10.

  • adaptive (bool, default: True ) –

    when enabled, if line search failed, step size will continue decreasing on the next step. Otherwise it will restart the line search from init step size. Defaults to True.

Examples:

Gradient descent with backtracking line search:

opt = tz.Modular(
    model.parameters(),
    tz.m.Backtracking()
)

L-BFGS with backtracking line search:

opt = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(),
    tz.m.Backtracking()
)

Source code in torchzero/modules/line_search/backtracking.py
class Backtracking(LineSearchBase):
    """Backtracking line search.

    Args:
        init (float, optional): initial step size. Defaults to 1.0.
        beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
        condition (TerminationCondition, optional):
            termination condition, only ones that do not use gradient at f(x+a*d) can be specified.
            - "armijo" - sufficient decrease condition.
            - "decrease" - any decrease in objective function value satisfies the condition.

            "goldstein" can technically be specified, but it doesn't make sense because there is no zoom stage.
            Defaults to 'armijo'.
        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
        adaptive (bool, optional):
            when enabled, if line search failed, step size will continue decreasing on the next step.
            Otherwise it will restart the line search from ``init`` step size. Defaults to True.

    Examples:
    Gradient descent with backtracking line search:

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Backtracking()
    )
    ```

    L-BFGS with backtracking line search:
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.LBFGS(),
        tz.m.Backtracking()
    )
    ```

    """
    def __init__(
        self,
        init: float = 1.0,
        beta: float = 0.5,
        c: float = 1e-4,
        condition: TerminationCondition = 'armijo',
        maxiter: int = 10,
        adaptive=True,
    ):
        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,adaptive=adaptive)
        super().__init__(defaults=defaults)

    def reset(self):
        super().reset()

    @torch.no_grad
    def search(self, update, var):
        init, beta, c, condition, maxiter, adaptive = itemgetter(
            'init', 'beta', 'c', 'condition', 'maxiter', 'adaptive')(self.defaults)

        objective = self.make_objective(var=var)

        # directional derivative (not needed when c == 0)
        if c == 0: d = 0
        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))

        # scale init
        init_scale = self.global_state.get('init_scale', 1)
        if adaptive: init = init * init_scale

        step_size = backtracking_line_search(objective, d, init=init, beta=beta,c=c, condition=condition, maxiter=maxiter)

        # found an alpha that reduces loss
        if step_size is not None:
            #self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
            self.global_state['init_scale'] = 1
            return step_size

        # on fail set init_scale to continue decreasing the step size
        # or set to large step size when it becomes too small
        if adaptive:
            finfo = torch.finfo(var.params[0].dtype)
            if init_scale <= finfo.tiny * 2:
                self.global_state["init_scale"] = finfo.max / 2
            else:
                self.global_state['init_scale'] = init_scale * beta**maxiter
        return 0
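The Armijo loop together with the `adaptive` restart rule can be sketched on a one-dimensional function phi(a) = f(x + a*p) with phi'(0) < 0. `BacktrackingSketch` is a hypothetical class. It omits the module's reset of `init_scale` to a large value once it becomes tiny, and supports only the "armijo" condition.

```python
from typing import Callable


class BacktrackingSketch:
    """Armijo backtracking with the `adaptive` restart rule (hypothetical class)."""

    def __init__(self, init=1.0, beta=0.5, c=1e-4, maxiter=10, adaptive=True):
        self.init, self.beta, self.c = init, beta, c
        self.maxiter, self.adaptive = maxiter, adaptive
        self.init_scale = 1.0

    def search(self, phi: Callable[[float], float], dphi0: float) -> float:
        """Return an accepted step size, or 0.0 if none was found within maxiter trials."""
        a = self.init * self.init_scale if self.adaptive else self.init
        phi0 = phi(0.0)
        for _ in range(self.maxiter):
            if phi(a) <= phi0 + self.c * a * dphi0:  # Armijo sufficient decrease
                self.init_scale = 1.0
                return a
            a *= self.beta
        if self.adaptive:
            # the next search continues decreasing from where this one stopped
            self.init_scale *= self.beta ** self.maxiter
        return 0.0


def phi(a: float) -> float:
    return (a - 0.1) ** 2  # minimum at a = 0.1, phi'(0) = -0.2


ls = BacktrackingSketch(maxiter=2)
print(ls.search(phi, dphi0=-0.2))  # 0.0: both trials (1.0 and 0.5) fail
print(ls.search(phi, dphi0=-0.2))  # 0.125: resumes from 0.25 and accepts 0.125
```

Without `adaptive`, every search would restart from `init` and fail again in the same way, so a small `maxiter` could never reach a suitable step size.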

BarzilaiBorwein

Bases: torchzero.core.transform.Transform

Barzilai-Borwein step size method.

Parameters:

  • type (str, default: 'geom' ) –

    one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long. Defaults to "geom".

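  • alpha_0 (float, default: 1e-7 ) –

    used to compute a small fallback step size when the Barzilai-Borwein step size is unavailable or not acceptable, for example on the first step, when no previous point exists. Defaults to 1e-7.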

  • inner (Chainable | None, default: None ) –

    step size will be applied to outputs of this module. Defaults to None.

Source code in torchzero/modules/step_size/adaptive.py
class BarzilaiBorwein(Transform):
    """Barzilai-Borwein step size method.

    Args:
        type (str, optional):
            one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
            Defaults to "geom".
        alpha_0 (float, optional): used to compute a small fallback step size when the Barzilai-Borwein step size is unavailable or not acceptable, for example on the first step. Defaults to 1e-7.
        inner (Chainable | None, optional):
            step size will be applied to outputs of this module. Defaults to None.
    """

    def __init__(
        self,
        type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
        alpha_0: float = 1e-7,
        use_grad=True,
        inner: Chainable | None = None,
    ):
        defaults = dict(type=type, alpha_0=alpha_0)
        super().__init__(defaults, uses_grad=use_grad, inner=inner)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_g')
        self.global_state['reset'] = True

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
        type = self.defaults['type']

        g = grads if self._uses_grad else tensors
        assert g is not None

        reset = self.global_state.get('reset', False)
        self.global_state.pop('reset', None)

        if step != 0 and not reset:
            s = params-prev_p
            y = g-prev_g
            sy = s.dot(y)
            eps = torch.finfo(sy.dtype).tiny * 2

            if type == 'short': alpha = _bb_short(s, y, sy, eps)
            elif type == 'long': alpha = _bb_long(s, y, sy, eps)
            elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
            elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
            else: raise ValueError(type)

            # if alpha is not None:
            self.global_state['alpha'] = alpha

        prev_p.copy_(params)
        prev_g.copy_(g)

    def get_H(self, var):
        return _get_H(self, var)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', None)

        if not _acceptable_alpha(alpha, tensors[0]):
            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])

        torch._foreach_mul_(tensors, alpha)
        return tensors
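A minimal pure-Python sketch of the three formulas and of a Barzilai-Borwein gradient descent loop, assuming plain lists instead of tensors. The names `bb_step_sizes` and `bb_minimize` are hypothetical, and the fallback rule (use `alpha_0` on the first step, or whenever the chosen step size is undefined or not positive) is a simplification of what `apply_tensors` does through `_acceptable_alpha` and `epsilon_step_size`.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def bb_step_sizes(s, y):
    """Return the (short, long, geom) Barzilai-Borwein step sizes.

    s = x_k - x_{k-1}, y = g_k - g_{k-1}. An entry is None when its
    denominator is not positive; "short" can still come out negative.
    """
    sy, ss, yy = dot(s, y), dot(s, s), dot(y, y)
    short = sy / yy if yy > 0 else None          # sᵀy / yᵀy
    long_ = ss / sy if sy > 0 else None          # sᵀs / sᵀy
    if short is not None and long_ is not None and short > 0:
        geom = math.sqrt(short * long_)          # geometric mean of the two
    else:
        geom = None
    return short, long_, geom


def bb_minimize(grad, x0, steps=100, kind="geom", alpha_0=1e-3):
    """Gradient descent with a Barzilai-Borwein step size (hypothetical helper)."""
    index = {"short": 0, "long": 1, "geom": 2}[kind]
    x, g = list(x0), grad(x0)
    alpha = alpha_0  # no curvature information on the first step
    for _ in range(steps):
        x_new = [xi - alpha * gi for xi, gi in zip(x, g)]
        g_new = grad(x_new)
        s = [a - b for a, b in zip(x_new, x)]
        y = [a - b for a, b in zip(g_new, g)]
        candidate = bb_step_sizes(s, y)[index]
        # fall back to alpha_0 on negative curvature or a degenerate step
        alpha = candidate if candidate is not None and candidate > 0 else alpha_0
        x, g = x_new, g_new
    return x
```

On the quadratic f(x) = ½(x₁² + 4x₂²), whose gradient is `[x1, 4*x2]`, all three step sizes lie between 1/4 and 1 (the inverse extreme eigenvalues of the Hessian), and `bb_minimize(..., kind="long")` drives x toward the minimizer at the origin.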

BinaryOperationBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for operations that use the update as the first operand. This is an abstract class: subclass it and override the transform method to use it.

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/binary.py
class BinaryOperationBase(Module, ABC):
    """Base class for operations that use the update as the first operand. This is an abstract class: subclass it and override the `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = {}
        for k,v in operands.items():

            if isinstance(v, (Module, Sequence)):
                self.set_child(k, v)
                self.operands[k] = self.children[k]
            else:
                self.operands[k] = v

    @abstractmethod
    def transform(self, var: Var, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    @torch.no_grad
    def step(self, var: Var) -> Var:
        # pass cloned update to all module operands
        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()

        for k,v in self.operands.items():
            if k in self.children:
                v: Module
                updated_var = v.step(var.clone(clone_update=True))
                processed_operands[k] = updated_var.get_update()
                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them

        transformed = self.transform(var, update=var.get_update(), **processed_operands)
        var.update = list(transformed)
        return var

transform

transform(var: Var, update: list[Tensor], **operands: Any | list[Tensor]) -> Iterable[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/binary.py
@abstractmethod
def transform(self, var: Var, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError
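The operand-handling pattern of `BinaryOperationBase.step` can be sketched in plain Python. This is not torchzero code: `OperandModule`, `BinaryOp`, `AddOp` and `Double` are hypothetical stand-ins, and lists of floats replace lists of tensors. Module operands are stepped on a copy of the update and replaced by their output, plain values are passed through unchanged, and a subclass only has to implement `transform`.

```python
from abc import ABC, abstractmethod
from typing import Any


class OperandModule(ABC):
    """Hypothetical stand-in for a torchzero Module used as an operand."""

    @abstractmethod
    def step(self, update: list[float]) -> list[float]: ...


class BinaryOp(ABC):
    """Simplified analogue of BinaryOperationBase (update is the first operand)."""

    def __init__(self, **operands: Any):
        self.operands = operands

    @abstractmethod
    def transform(self, update: list[float], **operands: Any) -> list[float]: ...

    def step(self, update: list[float]) -> list[float]:
        processed = {}
        for name, value in self.operands.items():
            if isinstance(value, OperandModule):
                # module operands run on a copy, so they cannot modify `update`
                processed[name] = value.step(list(update))
            else:
                processed[name] = value  # numbers are passed through as-is
        return list(self.transform(update, **processed))


class AddOp(BinaryOp):
    """update + other, where other is a number or a module operand's output."""

    def __init__(self, other: Any):
        super().__init__(other=other)

    def transform(self, update, other):
        if isinstance(other, list):
            return [u + o for u, o in zip(update, other)]
        return [u + other for u in update]


class Double(OperandModule):
    """Hypothetical operand that doubles the update it receives."""

    def step(self, update):
        return [2 * u for u in update]
```

With a number, `AddOp(1.0)` shifts every entry; with a module, `AddOp(Double())` adds the doubled update, so the result is three times the input.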

BirginMartinezRestart

Bases: torchzero.core.module.Module

The restart criterion for conjugate gradient methods designed by Birgin and Martinez.

This criterion restarts when the angle between the search direction dₖ₊₁ and the negative gradient −gₖ₊₁ is not acute enough.

The restart clears all state of module.

Parameters:

  • module (Module) –

    module to restart, should be a conjugate gradient or possibly a quasi-newton method.

  • cond (float, default: 0.001 ) –

    Restart is performed whenever dᵀg > −cond·‖d‖·‖g‖. The default condition value of 1e-3 is suggested by Birgin and Martinez.

Reference

Birgin, Ernesto G., and José Mario Martínez. "A spectral conjugate gradient method for unconstrained optimization." Applied Mathematics & Optimization 43.2 (2001): 117-128.

Source code in torchzero/modules/restarts/restars.py
class BirginMartinezRestart(Module):
    """The restart criterion for conjugate gradient methods designed by Birgin and Martinez.

    This criterion restarts when the angle between dₖ₊₁ and −gₖ₊₁ is not acute enough.

    The restart clears all states of ``module``.

    Args:
        module (Module):
            module to restart, should be a conjugate gradient or possibly a quasi-newton method.
        cond (float, optional):
            Restart is performed whenever d^Tg > -cond*||d||*||g||.
            The default condition value of 1e-3 is suggested by Birgin and Martinez.

    Reference:
        Birgin, Ernesto G., and José Mario Martínez. "A spectral conjugate gradient method for unconstrained optimization." Applied Mathematics & Optimization 43.2 (2001): 117-128.
    """
    def __init__(self, module: Module, cond:float = 1e-3):
        defaults=dict(cond=cond)
        super().__init__(defaults)

        self.set_child("module", module)

    def update(self, var):
        module = self.children['module']
        module.update(var)

    def apply(self, var):
        module = self.children['module']
        var = module.apply(var.clone(clone_update=False))

        cond = self.defaults['cond']
        g = TensorList(var.get_grad())
        d = TensorList(var.get_update())
        d_g = d.dot(g)
        d_norm = d.global_vector_norm()
        g_norm = g.global_vector_norm()

        # d in our case is same direction as g so it has a minus sign
        if -d_g > -cond * d_norm * g_norm:
            module.reset()
            var.update = g.clone()
            return var

        return var
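The test itself is a cosine check: a restart happens when the cosine of the angle between d and −g drops below `cond`. A minimal pure-Python sketch, where `birgin_martinez_restart` is a hypothetical helper and d is taken to be a descent direction (the usual CG sign convention):

```python
import math


def birgin_martinez_restart(d, g, cond=1e-3):
    """True if d is not a sufficiently good descent direction for gradient g.

    Restart when dᵀg > -cond * ||d|| * ||g||, i.e. when the cosine of the
    angle between d and -g is smaller than cond.
    """
    dg = sum(a * b for a, b in zip(d, g))
    d_norm = math.sqrt(sum(a * a for a in d))
    g_norm = math.sqrt(sum(b * b for b in g))
    return dg > -cond * d_norm * g_norm
```

torchzero's update u points along the gradient, so with d = −u the same test becomes −uᵀg > −cond·‖u‖·‖g‖, which is the check performed in `apply`; on a restart the update is replaced by the gradient.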

BroydenBad

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Broyden's "bad" Quasi-Newton method.

Note

A trust region or an accurate line search is recommended.

Warning

This uses at least O(N²) memory.

Reference

Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class BroydenBad(_InverseHessianUpdateStrategyDefaults):
    """Broyden's "bad" Quasi-Newton method.

    Note:
        a trust region or an accurate line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return broyden_bad_H_(H=H, s=s, y=y)
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return broyden_bad_B_(B=B, s=s, y=y)
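For reference, the textbook "bad" Broyden update of the inverse Hessian approximation is H₊ = H + (s − Hy)yᵀ / (yᵀy), which satisfies the secant condition H₊y = s. The sketch below is plain Python on dense nested lists (which makes the O(N²) memory explicit); `broyden_bad_H` is a hypothetical name, and torchzero's `broyden_bad_H_` works in place on tensors and may add safeguards not shown here.

```python
def matvec(M, v):
    """Dense matrix-vector product for nested lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def broyden_bad_H(H, s, y):
    """Textbook "bad" Broyden update of the inverse Hessian approximation.

    H_new = H + (s - H y) yᵀ / (yᵀ y); returns a new n x n nested list.
    """
    Hy = matvec(H, y)
    r = [si - hyi for si, hyi in zip(s, Hy)]  # s - H y
    yy = sum(yi * yi for yi in y)
    return [[H[i][j] + r[i] * y[j] / yy for j in range(len(y))] for i in range(len(s))]
```

Starting from the identity with s = [1, 2] and y = [2, 1], the update gives [[0.6, -0.2], [0.4, 1.2]], which maps y back onto s.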

BroydenGood

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Broyden's "good" Quasi-Newton method.

Note

A trust region or an accurate line search is recommended.

Warning

This uses at least O(N²) memory.

Reference

Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class BroydenGood(_InverseHessianUpdateStrategyDefaults):
    """Broyden's "good" Quasi-Newton method.

    Note:
        a trust region or an accurate line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return broyden_good_H_(H=H, s=s, y=y)
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return broyden_good_B_(B=B, s=s, y=y)
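The textbook "good" Broyden update acts on the Hessian approximation, B₊ = B + (y − Bs)sᵀ / (sᵀs), and its inverse form follows from the Sherman-Morrison formula, H₊ = H + (s − Hy)sᵀH / (sᵀHy). A plain-Python sketch on dense nested lists; `broyden_good_B` and `broyden_good_H` are hypothetical names, and torchzero's `broyden_good_B_` / `broyden_good_H_` may differ in details such as in-place updates and safeguards.

```python
def matvec(M, v):
    """Dense matrix-vector product for nested lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def broyden_good_B(B, s, y):
    """B_new = B + (y - B s) sᵀ / (sᵀ s)."""
    Bs = matvec(B, s)
    r = [yi - bsi for yi, bsi in zip(y, Bs)]  # y - B s
    ss = sum(si * si for si in s)
    return [[B[i][j] + r[i] * s[j] / ss for j in range(len(s))] for i in range(len(y))]


def broyden_good_H(H, s, y):
    """H_new = H + (s - H y) sᵀ H / (sᵀ H y), the inverse of broyden_good_B."""
    n = len(s)
    Hy = matvec(H, y)
    r = [si - hyi for si, hyi in zip(s, Hy)]                  # s - H y
    sH = [sum(s[i] * H[i][j] for i in range(n)) for j in range(n)]  # sᵀ H
    denom = sum(si * hyi for si, hyi in zip(s, Hy))           # sᵀ H y
    return [[H[i][j] + r[i] * sH[j] / denom for j in range(n)] for i in range(n)]
```

Both forms satisfy their secant conditions (B₊s = y, H₊y = s), and when started from inverse matrices they stay inverses of each other.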

CCD

Bases: torchzero.core.module.Module

Cumulative coordinate descent. This updates one gradient coordinate at a time and accumulates it into the update direction. The coordinate to update is picked at random, weighted by the magnitudes of the current update direction. When the update direction stops being a descent direction because of stale accumulated coordinates, it is decayed.

Parameters:

  • pmin (float, default: 0.1 ) –

    multiplier to the probability of picking the lowest magnitude gradient. Defaults to 0.1.

  • pmax (float, default: 1.0 ) –

    multiplier to the probability of picking the largest magnitude gradient. Defaults to 1.0.

  • pow (int, default: 2 ) –

    power transform applied to the probabilities. Defaults to 2.

  • decay (float, default: 0.8 ) –

    decay applied to the accumulated gradient on consecutive failed steps. Defaults to 0.8.

  • decay2 (float, default: 0.2 ) –

    multiplier applied to decay each time it is used, so that further failures decay the accumulated gradient more strongly; decay is reset after a successful step. Defaults to 0.2.

  • nplus (float, default: 1.5 ) –

    step size multiplier on successful steps. Defaults to 1.5.

  • nminus (float, default: 0.75 ) –

    step size multiplier on unsuccessful steps. Defaults to 0.75.

Source code in torchzero/modules/zeroth_order/cd.py
class CCD(Module):
    """Cumulative coordinate descent. This updates one gradient coordinate at a time and accumulates it
    into the update direction. The coordinate to update is picked at random, weighted by the magnitudes of the current update direction.
    When the update direction stops being a descent direction because of stale accumulated coordinates, it is decayed.

    Args:
        pmin (float, optional): multiplier to the probability of picking the lowest magnitude gradient. Defaults to 0.1.
        pmax (float, optional): multiplier to the probability of picking the largest magnitude gradient. Defaults to 1.0.
        pow (int, optional): power transform applied to the probabilities. Defaults to 2.
        decay (float, optional): decay applied to the accumulated gradient on consecutive failed steps. Defaults to 0.8.
        decay2 (float, optional): multiplier applied to ``decay`` each time it is used. Defaults to 0.2.
        nplus (float, optional): step size multiplier on successful steps. Defaults to 1.5.
        nminus (float, optional): step size multiplier on unsuccessful steps. Defaults to 0.75.
    """
    def __init__(self, pmin=0.1, pmax=1.0, pow=2, decay:float=0.8, decay2:float=0.2, nplus=1.5, nminus=0.75):

        defaults = dict(pmin=pmin, pmax=pmax, pow=pow, decay=decay, decay2=decay2, nplus=nplus, nminus=nminus)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        closure = var.closure
        if closure is None:
            raise RuntimeError("CD requires closure")

        params = TensorList(var.params)
        p_prev = self.get_state(params, "p_prev", init=params, cls=TensorList)

        f_0 = var.get_loss(False)
        step_size = self.global_state.get('step_size', 1)

        # ------------------------ hard reset on infinite loss ----------------------- #
        if not math.isfinite(f_0):
            self.global_state.pop('f_prev', None) # 'f_prev' may not exist yet (e.g. infinite loss on the first step)
            var.update = params - p_prev
            self.global_state.clear()
            self.state.clear()
            self.global_state["step_size"] = step_size / 10
            return var

        # ---------------------------- soft reset if stuck --------------------------- #
        if "igrad" in self.state[params[0]]:
            n_bad = self.global_state.get('n_bad', 0)

            f_prev = self.global_state.get("f_prev", None)
            if f_prev is not None:

                decay2 = self.defaults["decay2"]
                decay = self.global_state.get("decay", self.defaults["decay"])

                if f_0 >= f_prev:

                    igrad = self.get_state(params, "igrad", cls=TensorList)
                    del self.global_state['f_prev']

                    # undo previous update
                    var.update = params - p_prev

                    # increment n_bad
                    self.global_state['n_bad'] = n_bad + 1

                    # decay step size
                    self.global_state['step_size'] = step_size * self.defaults["nminus"]

                    # soft reset
                    if n_bad > 0:
                        igrad *= decay
                        self.global_state["decay"] = decay*decay2
                        self.global_state['n_bad'] = 0

                    return var

                else:
                    # increase step size and reset n_bad
                    self.global_state['step_size'] = step_size * self.defaults["nplus"]
                    self.global_state['n_bad'] = 0
                    self.global_state["decay"] = self.defaults["decay"]

            self.global_state['f_prev'] = float(f_0)

        # ------------------------------ determine index ----------------------------- #
        idx, igrad = _icd_get_idx(self, params)

        # -------------------------- find descent direction -------------------------- #
        h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, 1e-3), cls=TensorList)
        h = float(h_vec.flat_get(idx))

        params.flat_set_lambda_(idx, lambda x: x + h)
        f_p = closure(False)

        params.flat_set_lambda_(idx, lambda x: x - 2*h)
        f_n = closure(False)
        params.flat_set_lambda_(idx, lambda x: x + h)

        # ---------------------------------- adapt h --------------------------------- #
        if f_0 <= f_p and f_0 <= f_n:
            h_vec.flat_set_lambda_(idx, lambda x: max(x/2, 1e-10))
        else:
            if abs(f_0 - f_n) < 1e-12 or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
                h_vec.flat_set_lambda_(idx, lambda x: min(x*2, 1e10))

        # ------------------------------- update igrad ------------------------------- #
        if f_0 < f_p and f_0 < f_n: alpha = 0
        else: alpha = (f_p - f_n) / (2*h)

        igrad.flat_set_(idx, alpha)

        # ----------------------------- create the update ---------------------------- #
        var.update = igrad * step_size
        p_prev.copy_(params)
        return var
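A simplified pure-Python sketch of the idea behind CCD: refresh one coordinate of an accumulated gradient estimate with a central difference, step along the accumulated direction, multiply the step size by `nplus` after a success and by `nminus` after a failure, and decay the stale accumulated direction on failure. Several simplifications are assumptions of this sketch and differ from the listing above: coordinates are sampled with probability proportional to |dᵢ| plus a small floor (not the pmin/pmax/pow scheme), h is fixed, failed trial points are simply rejected instead of undone, and decay is applied on every failure without decay2. `ccd_sketch` and `central_difference` are hypothetical names.

```python
import random


def central_difference(f, x, i, h):
    """Central finite-difference estimate of df/dx_i."""
    xp, xn = list(x), list(x)
    xp[i] += h
    xn[i] -= h
    return (f(xp) - f(xn)) / (2 * h)


def ccd_sketch(f, x0, iters=300, h=1e-4, step=0.1, nplus=1.5, nminus=0.75,
               decay=0.5, seed=0):
    """Simplified cumulative coordinate descent (hypothetical helper)."""
    rng = random.Random(seed)
    x = list(x0)
    fx = f(x)
    d = [0.0] * len(x)  # accumulated gradient estimate (plays the role of igrad)
    for _ in range(iters):
        # pick a coordinate with probability growing with |d_i|; the +1e-3 floor
        # keeps every coordinate selectable (an assumption, not torchzero's rule)
        weights = [abs(di) + 1e-3 for di in d]
        i = rng.choices(range(len(x)), weights=weights)[0]
        d[i] = central_difference(f, x, i, h)  # refresh one coordinate
        trial = [xi - step * di for xi, di in zip(x, d)]
        f_trial = f(trial)
        if f_trial < fx:
            x, fx = trial, f_trial
            step *= nplus
        else:
            # failed step: shrink the step size and decay stale coordinates
            step *= nminus
            d = [di * decay for di in d]
    return x, fx
```

Only two function evaluations per iteration are spent on the derivative, which is the appeal of coordinate methods when gradients are unavailable.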

CCDLS

Bases: torchzero.core.module.Module

CCD with line search instead of adaptive step size.

Parameters:

  • pmin (float, default: 0.1 ) –

    multiplier to probability of picking the lowest magnitude gradient. Defaults to 0.1.

  • pmax (float, default: 1.0 ) –

    multiplier to probability of picking the largest magnitude gradient. Defaults to 1.0.

  • pow (int, default: 2 ) –

    power transform to probabilities. Defaults to 2.

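  • decay (float, default: 0.8 ) –

    decay applied to the accumulated gradient when the line search fails. Defaults to 0.8.

  • decay2 (float, default: 0.2 ) –

    multiplier applied to decay each time it is used, so that further failures decay the accumulated gradient more strongly. Defaults to 0.2.

  • maxiter (int, default: 10 ) –

    maximum number of line search iterations. Defaults to 10.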

Source code in torchzero/modules/zeroth_order/cd.py
class CCDLS(Module):
    """CCD with line search instead of adaptive step size.

    Args:
        pmin (float, optional): multiplier to probability of picking the lowest magnitude gradient. Defaults to 0.1.
        pmax (float, optional): multiplier to probability of picking the largest magnitude gradient. Defaults to 1.0.
        pow (int, optional): power transform to probabilities. Defaults to 2.
        decay (float, optional): decay applied to the accumulated gradient when the line search fails. Defaults to 0.8.
        decay2 (float, optional): multiplier applied to ``decay`` each time it is used. Defaults to 0.2.
        maxiter (int, optional): maximum number of line search iterations. Defaults to 10.
    """
    def __init__(self, pmin=0.1, pmax=1.0, pow=2, decay=0.8, decay2=0.2, maxiter=10, ):
        defaults = dict(pmin=pmin, pmax=pmax, pow=pow, maxiter=maxiter, decay=decay, decay2=decay2)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        closure = var.closure
        if closure is None:
            raise RuntimeError("CD requires closure")

        params = TensorList(var.params)
        finfo = torch.finfo(params[0].dtype)
        f_0 = var.get_loss(False)

        # ------------------------------ determine index ----------------------------- #
        idx, igrad = _icd_get_idx(self, params)

        # -------------------------- find descent direction -------------------------- #
        h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, 1e-3), cls=TensorList)
        h = float(h_vec.flat_get(idx))

        params.flat_set_lambda_(idx, lambda x: x + h)
        f_p = closure(False)

        params.flat_set_lambda_(idx, lambda x: x - 2*h)
        f_n = closure(False)
        params.flat_set_lambda_(idx, lambda x: x + h)

        # ---------------------------------- adapt h --------------------------------- #
        if f_0 <= f_p and f_0 <= f_n:
            h_vec.flat_set_lambda_(idx, lambda x: max(x/2, finfo.tiny * 2))
        else:
            # here eps, not tiny
            if abs(f_0 - f_n) < finfo.eps or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
                h_vec.flat_set_lambda_(idx, lambda x: min(x*2, finfo.max / 2))

        # ------------------------------- update igrad ------------------------------- #
        if f_0 < f_p and f_0 < f_n: alpha = 0
        else: alpha = (f_p - f_n) / (2*h)

        igrad.flat_set_(idx, alpha)

        # -------------------------------- line search ------------------------------- #
        x0 = params.clone()
        def f(a):
            params.sub_(igrad, alpha=a)
            loss = closure(False)
            params.copy_(x0)
            return loss

        a_prev = self.global_state.get('a_prev', 1)
        a, f_a, niter = adaptive_tracking(f, a_prev, maxiter=self.defaults['maxiter'], f_0=f_0)
        if (a is None) or (not math.isfinite(a)) or (not math.isfinite(f_a)):
            a = 0

        # -------------------------------- set a_prev -------------------------------- #
        decay2 = self.defaults["decay2"]
        decay = self.global_state.get("decay", self.defaults["decay"])

        if abs(a) > finfo.tiny * 2:
            assert f_a < f_0
            self.global_state['a_prev'] = max(min(a, finfo.max / 2), finfo.tiny * 2)
            self.global_state["decay"] = self.defaults["decay"]

        # ---------------------------- soft reset on fail ---------------------------- #
        else:
            igrad *= decay
            self.global_state["decay"] = decay*decay2
            self.global_state['a_prev'] = a_prev / 2

        # -------------------------------- set update -------------------------------- #
        var.update = igrad * a
        return var
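CCDLS replaces step-size adaptation with a line search along the accumulated direction, warm-started from the previously accepted step size (`a_prev`), and treats a zero step as a failure. The snippet below is a simplified tracking line search written for illustration; `tracking_line_search` is a hypothetical name and its rules (halve until the value drops below f0, otherwise double while the value keeps decreasing) are assumptions, not the exact logic of torchzero's `adaptive_tracking`.

```python
def tracking_line_search(phi, a0, f0, maxiter=10):
    """Simplified tracking line search (hypothetical helper).

    phi(a) is the loss after a step of size a. Starts from a0. If phi(a0)
    does not improve on f0, halves the step until it does; otherwise doubles
    it while the value keeps decreasing. Returns (step, value), with step 0.0
    if no decrease was found within maxiter evaluations.
    """
    a, fa = a0, phi(a0)
    evals = 1
    if fa >= f0:
        while fa >= f0 and evals < maxiter:
            a /= 2
            fa = phi(a)
            evals += 1
        return (a, fa) if fa < f0 else (0.0, f0)
    while evals < maxiter:
        a_next = 2 * a
        f_next = phi(a_next)
        evals += 1
        if f_next >= fa:
            break
        a, fa = a_next, f_next
    return a, fa
```

For phi(a) = (a − 3)² with f0 = phi(0) = 9, starting from a0 = 1 the search grows the step to 2, while starting from a0 = 10 it backtracks to 5.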

CD

Bases: torchzero.core.module.Module

Coordinate descent. Proposes a descent direction along a single coordinate. Follow it with a line search such as tz.m.ScipyMinimizeScalar, or with a fixed step size.

Parameters:

  • h (float, default: 0.001 ) –

    finite difference step size. Defaults to 1e-3.

  • grad (bool, default: True ) –

    if True, scales direction by gradient estimate. If False, the scale is fixed to 1. Defaults to True.

  • adaptive (bool, default: True ) –

    whether to adapt finite difference step size, this requires an additional buffer. Defaults to True.

  • index (str, default: 'cyclic2' ) –

    index selection strategy:
    - "cyclic" - repeatedly cycles through each coordinate, e.g. 1,2,3,1,2,3,...
    - "cyclic2" - cycles forward and then backward, e.g. 1,2,3,3,2,1,1,2,3,... (default).
    - "random" - picks a coordinate at random.

  • threepoint (bool, default: True ) –

    whether to use three points (three function evaluations) to determine the descent direction. If False, uses two points, but then adaptive can't be used. Defaults to True.

Source code in torchzero/modules/zeroth_order/cd.py
class CD(Module):
    """Coordinate descent. Proposes a descent direction along a single coordinate.
    You can then put a line search such as ``tz.m.ScipyMinimizeScalar``, or a fixed step size.

    Args:
        h (float, optional): finite difference step size. Defaults to 1e-3.
        grad (bool, optional):
            if True, scales direction by gradient estimate. If False, the scale is fixed to 1. Defaults to True.
        adaptive (bool, optional):
            whether to adapt finite difference step size, this requires an additional buffer. Defaults to True.
        index (str, optional):
            index selection strategy.
            - "cyclic" - repeatedly cycles through each coordinate, e.g. ``1,2,3,1,2,3,...``.
            - "cyclic2" - cycles forward and then backward, e.g. ``1,2,3,3,2,1,1,2,3,...`` (default).
            - "random" - picks a coordinate at random.
        threepoint (bool, optional):
            whether to use three points (three function evaluations) to determine the descent direction.
            If False, uses two points, but then ``adaptive`` can't be used. Defaults to True.
    """
    def __init__(self, h:float=1e-3, grad:bool=True, adaptive:bool=True, index:Literal['cyclic', 'cyclic2', 'random']="cyclic2", threepoint:bool=True,):
        defaults = dict(h=h, grad=grad, adaptive=adaptive, index=index, threepoint=threepoint)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        closure = var.closure
        if closure is None:
            raise RuntimeError("CD requires closure")

        params = TensorList(var.params)
        ndim = params.global_numel()

        grad_step_size = self.defaults['grad']
        adaptive = self.defaults['adaptive']
        index_strategy = self.defaults['index']
        h = self.defaults['h']
        threepoint = self.defaults['threepoint']

        # ------------------------------ determine index ----------------------------- #
        if index_strategy == 'cyclic':
            idx = self.global_state.get('idx', 0) % ndim
            self.global_state['idx'] = idx + 1

        elif index_strategy == 'cyclic2':
            idx = self.global_state.get('idx', 0)
            self.global_state['idx'] = idx + 1
            if idx >= ndim * 2:
                # restart the forward/backward sweep; the next call continues from index 1
                idx = 0
                self.global_state['idx'] = 1
            if idx >= ndim:
                idx  = (2*ndim - idx) - 1

        elif index_strategy == 'random':
            if 'generator' not in self.global_state:
                self.global_state['generator'] = random.Random(0)
            generator = self.global_state['generator']
            idx = generator.randrange(0, ndim)

        else:
            raise ValueError(index_strategy)

        # -------------------------- find descent direction -------------------------- #
        h_vec = None
        if adaptive:
            if threepoint:
                h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, h), cls=TensorList)
                h = float(h_vec.flat_get(idx))
            else:
                warnings.warn("CD adaptive=True only works with threepoint=True")

        f_0 = var.get_loss(False)
        params.flat_set_lambda_(idx, lambda x: x + h)
        f_p = closure(False)

        # -------------------------------- threepoint -------------------------------- #
        if threepoint:
            params.flat_set_lambda_(idx, lambda x: x - 2*h)
            f_n = closure(False)
            params.flat_set_lambda_(idx, lambda x: x + h)

            if adaptive:
                assert h_vec is not None
                if f_0 <= f_p and f_0 <= f_n:
                    h_vec.flat_set_lambda_(idx, lambda x: max(x/2, 1e-10))
                else:
                    if abs(f_0 - f_n) < 1e-12 or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
                        h_vec.flat_set_lambda_(idx, lambda x: min(x*2, 1e10))

            if grad_step_size:
                alpha = (f_p - f_n) / (2*h)

            else:
                if f_0 < f_p and f_0 < f_n: alpha = 0
                elif f_p < f_n: alpha = -1
                else: alpha = 1

        # --------------------------------- twopoint --------------------------------- #
        else:
            params.flat_set_lambda_(idx, lambda x: x - h)
            if grad_step_size:
                alpha = (f_p - f_0) / h
            else:
                if f_p < f_0: alpha = -1
                else: alpha = 1

        # ----------------------------- create the update ---------------------------- #
        update = params.zeros_like()
        update.flat_set_(idx, alpha)
        var.update = update
        return var
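The two ingredients of `CD.step` can be sketched separately in plain Python: the index selection strategies (0-based here, whereas the sequences in the docstring are 1-based) and the finite-difference slope that scales the direction when `grad=True`, central for `threepoint=True` and forward otherwise. `coordinate_indices` and `coordinate_slope` are hypothetical helpers, not torchzero functions.

```python
import random


def coordinate_indices(ndim, strategy="cyclic2", n=10, seed=0):
    """Return the first n coordinate indices chosen by a CD-style strategy.

    "cyclic":  0, 1, ..., ndim-1, 0, 1, ...
    "cyclic2": forward then backward, e.g. 0, 1, 2, 2, 1, 0, 0, 1, 2, ...
    "random":  uniform random indices from a seeded generator.
    """
    rng = random.Random(seed)
    out = []
    for k in range(n):
        if strategy == "cyclic":
            out.append(k % ndim)
        elif strategy == "cyclic2":
            pos = k % (2 * ndim)
            # first half of the period walks forward, second half walks back
            out.append(pos if pos < ndim else 2 * ndim - pos - 1)
        elif strategy == "random":
            out.append(rng.randrange(ndim))
        else:
            raise ValueError(strategy)
    return out


def coordinate_slope(f, x, i, h, threepoint=True):
    """Finite-difference estimate of df/dx_i."""
    xp = list(x)
    xp[i] += h
    if threepoint:
        xn = list(x)
        xn[i] -= h
        return (f(xp) - f(xn)) / (2 * h)  # central difference, 2 extra evaluations
    return (f(xp) - f(x)) / h             # forward difference, reuses f(x)
```

The central difference costs one extra evaluation but is second-order accurate, which is also what makes adapting h possible in the three-point mode.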

Cautious

Bases: torchzero.core.transform.Transform

Masks the update for parameters where the signs of the update and the gradient are inconsistent (by default these entries are set to zero, see mode). Optionally normalizes the update by the number of parameters that are not masked. This is meant to be used after momentum-based modules.

Parameters:

  • normalize (bool, default: False ) –

    renormalize the update after masking. Only has an effect when mode is 'zero'. Defaults to False.

  • eps (float, default: 1e-06 ) –

    epsilon for normalization. Defaults to 1e-6.

  • mode (str, default: 'zero' ) –

    what to do with updates with inconsistent signs:
    - "zero" - set them to zero (as in the paper)
    - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
    - "backtrack" - negate them

Examples:

Cautious Adam

opt = tz.Modular(
    bench.parameters(),
    tz.m.Adam(),
    tz.m.Cautious(),
    tz.m.LR(1e-2)
)

References

Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu

Source code in torchzero/modules/momentum/cautious.py
class Cautious(Transform):
    """Masks the update for parameters where the signs of the update and the gradient are inconsistent
    (by default these entries are set to zero, see ``mode``).
    Optionally normalizes the update by the number of parameters that are not masked.
    This is meant to be used after momentum-based modules.

    Args:
        normalize (bool, optional):
            renormalize the update after masking.
            Only has an effect when mode is 'zero'. Defaults to False.
        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
        mode (str, optional):
            what to do with updates with inconsistent signs.
            - "zero" - set them to zero (as in paper)
            - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
            - "backtrack" - negate them

    ## Examples:

    Cautious Adam

    ```python
    opt = tz.Modular(
        bench.parameters(),
        tz.m.Adam(),
        tz.m.Cautious(),
        tz.m.LR(1e-2)
    )
    ```

    References:
        Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
    """

    def __init__(
        self,
        normalize=False,
        eps=1e-6,
        mode: Literal["zero", "grad", "backtrack"] = "zero",
    ):
        defaults = dict(normalize=normalize, eps=eps, mode=mode)
        super().__init__(defaults, uses_grad=True)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
        return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)
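A minimal pure-Python sketch of the masking rule, assuming that an entry is kept when the update and gradient have strictly the same sign (uᵢgᵢ > 0) and that normalization divides by the fraction of kept entries plus `eps`; torchzero's `cautious_` may normalize differently. Only the "zero" and "backtrack" modes are sketched, and `cautious` is a hypothetical name.

```python
def cautious(update, grad, mode="zero", normalize=False, eps=1e-6):
    """Mask update entries whose sign disagrees with the gradient (illustrative)."""
    keep = [u * g > 0 for u, g in zip(update, grad)]
    if mode == "zero":
        out = [u if k else 0.0 for u, k in zip(update, keep)]
        if normalize:
            # rescale so the kept entries make up for the zeroed ones (assumed rule)
            frac = sum(keep) / len(keep)
            out = [o / (frac + eps) for o in out]
        return out
    if mode == "backtrack":
        # flip inconsistent entries so they agree with the gradient sign
        return [u if k else -u for u, k in zip(update, keep)]
    raise ValueError(mode)
```

In "zero" mode an update of [1, -2, 3] with gradient [0.5, 1, 2] becomes [1, 0, 3]; with `normalize=True` the surviving entries are scaled by about 3/2.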

CenteredEMASquared

Bases: torchzero.core.transform.Transform

Maintains a centered exponential moving average of squared updates. This also maintains an additional exponential moving average of un-squared updates, whose square is subtracted from the EMA.

Parameters:

  • beta (float, default: 0.99 ) –

    momentum value. Defaults to 0.99.

  • amsgrad (bool, default: False ) –

    whether to maintain the maximum of the exponential moving average. Defaults to False.

  • pow (float, default: 2 ) –

    power; it is always applied to the absolute value. Defaults to 2.

Source code in torchzero/modules/ops/higher_level.py
class CenteredEMASquared(Transform):
    """Maintains a centered exponential moving average of squared updates. This also maintains an additional
    exponential moving average of un-squared updates, whose square is subtracted from the EMA.

    Args:
        beta (float, optional): momentum value. Defaults to 0.99.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2):
        defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return centered_ema_sq_(
            TensorList(tensors),
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            max_exp_avg_sq_=max_exp_avg_sq,
            pow=pow,
        ).clone()
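For `pow=2` the output is a running variance estimate, EMA(u²) − EMA(u)². A pure-Python sketch over a sequence of updates, assuming zero-initialized EMAs and leaving out `amsgrad` and other powers; `centered_ema_sq` is a hypothetical name, not torchzero's `centered_ema_sq_`.

```python
def centered_ema_sq(updates, beta=0.99):
    """Centered EMA of squared updates for pow=2 (illustrative).

    updates is a list of update vectors (lists of floats), oldest first.
    """
    n = len(updates[0])
    exp_avg = [0.0] * n      # EMA of u
    exp_avg_sq = [0.0] * n   # EMA of u**2
    for u in updates:
        exp_avg = [beta * m + (1 - beta) * ui for m, ui in zip(exp_avg, u)]
        exp_avg_sq = [beta * v + (1 - beta) * ui * ui for v, ui in zip(exp_avg_sq, u)]
    return [v - m * m for m, v in zip(exp_avg, exp_avg_sq)]
```

A constant update gives a value that tends to zero (no variance), while an update alternating between +1 and −1 gives a value close to 1.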

CenteredSqrtEMASquared

Bases: torchzero.core.transform.Transform

Maintains a centered exponential moving average of squared updates and outputs its square root, optionally debiased. This also maintains an additional exponential moving average of un-squared updates, whose square is subtracted from the EMA.

Parameters:

  • beta (float, default: 0.99 ) –

    momentum value. Defaults to 0.99.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • debiased (bool, default: False ) –

    whether to multiply the output by a debiasing term from the Adam method. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Source code in torchzero/modules/ops/higher_level.py
class CenteredSqrtEMASquared(Transform):
    """Maintains a centered exponential moving average of squared updates and outputs its square root, optionally debiased.
    This also maintains an additional exponential moving average of un-squared updates, whose square is subtracted from the EMA.

    Args:
        beta (float, optional): momentum value. Defaults to 0.99.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2):
        defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return sqrt_centered_ema_sq_(
            TensorList(tensors),
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            debiased=debiased,
            step=step,
            max_exp_avg_sq_=max_exp_avg_sq,
            pow=pow,
        )
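A pure-Python sketch of the square-root variant for `pow=2`. The debiasing factor 1/√(1 − βᵗ), with t the number of steps, is an assumption borrowed from Adam's second-moment correction; torchzero's `sqrt_centered_ema_sq_` may debias differently. The clamp at zero only guards against rounding, and `sqrt_centered_ema_sq` is a hypothetical name.

```python
import math


def sqrt_centered_ema_sq(updates, beta=0.99, debiased=False):
    """Square root of the centered EMA of squared updates (illustrative)."""
    n = len(updates[0])
    exp_avg = [0.0] * n      # EMA of u
    exp_avg_sq = [0.0] * n   # EMA of u**2
    for u in updates:
        exp_avg = [beta * m + (1 - beta) * ui for m, ui in zip(exp_avg, u)]
        exp_avg_sq = [beta * v + (1 - beta) * ui * ui for v, ui in zip(exp_avg_sq, u)]
    out = [math.sqrt(max(v - m * m, 0.0)) for m, v in zip(exp_avg, exp_avg_sq)]
    if debiased:
        step = len(updates)
        correction = 1.0 / math.sqrt(1.0 - beta ** step)  # assumed Adam-style factor
        out = [o * correction for o in out]
    return out
```

The result behaves like a running standard deviation of the updates, which is what an RMSprop-style denominator built from it would divide by.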

Centralize

Bases: torchzero.core.transform.Transform

Centralizes the update.

Parameters:

  • dim (int | Sequence[int] | str | None, default: None ) –

    dimensions to centralize along. If list/tuple, tensors are centralized along all dimensions in dim that they have. Can be set to "global" to centralize by the global mean of all gradients concatenated into a vector. Defaults to None.

  • inverse_dims (bool, default: False ) –

    if True, the dim argument is inverted, and all other dimensions are centralized.

  • min_size (int, default: 2 ) –

    minimal size of a dimension to centralize along it. Defaults to 2.

Examples:

Standard gradient centralization:

opt = tz.Modular(
    model.parameters(),
    tz.m.Centralize(dim=0),
    tz.m.LR(1e-2),
)

References:

  • Yong, H., Huang, J., Hua, X., & Zhang, L. (2020). Gradient centralization: A new optimization technique for deep neural networks. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing. https://arxiv.org/abs/2004.01461

Source code in torchzero/modules/clipping/clipping.py
class Centralize(Transform):
    """Centralizes the update.

    Args:
        dim (int | Sequence[int] | str | None, optional):
            dimensions to centralize along.
            If list/tuple, tensors are centralized along all dimensions in `dim` that they have.
            Can be set to "global" to centralize by the global mean of all gradients concatenated into a vector.
            Defaults to None.
        inverse_dims (bool, optional):
            if True, the `dim` argument is inverted, and all other dimensions are centralized.
        min_size (int, optional):
            minimal size of a dimension to centralize along it. Defaults to 2.

    Examples:

    Standard gradient centralization:
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Centralize(dim=0),
        tz.m.LR(1e-2),
    )
    ```

    References:
    - Yong, H., Huang, J., Hua, X., & Zhang, L. (2020). Gradient centralization: A new optimization technique for deep neural networks. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing. https://arxiv.org/abs/2004.01461
    """
    def __init__(
        self,
        dim: int | Sequence[int] | Literal["global"] | None = None,
        inverse_dims: bool = False,
        min_size: int = 2,
        target: Target = "update",
    ):
        defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
        super().__init__(defaults, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])

        _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)

        return tensors
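Gradient centralization subtracts a mean from the gradient so that it sums to zero along the chosen axes. The sketch below assumes a 2D weight gradient stored as a list of rows and centralizes each row; which axis `dim=0` corresponds to in torchzero is not spelled out above, so the row convention and the `min_size` handling here are assumptions, and `centralize_rows` is a hypothetical name.

```python
def centralize_rows(matrix, min_size=2):
    """Subtract each row's mean from that row (gradient centralization of a 2D gradient)."""
    out = []
    for row in matrix:
        if len(row) < min_size:
            # too short to centralize; pass through unchanged
            out.append(list(row))
            continue
        mean = sum(row) / len(row)
        out.append([x - mean for x in row])
    return out

grad = [[1.0, 2.0, 3.0], [4.0, 4.0, 10.0]]
print(centralize_rows(grad))  # [[-1.0, 0.0, 1.0], [-2.0, -2.0, 4.0]]
```

Without the `min_size` guard, a length-1 row would always become zero, which is presumably why a minimum size exists.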

Clip

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Clips tensors to be in the (min, max) range. min and max can be None, numbers or modules.

If min and max are modules, this calculates tensors.clip(min(tensors), max(tensors)).

Source code in torchzero/modules/ops/binary.py
class Clip(BinaryOperationBase):
    """Clips tensors to be in the :code:`(min, max)` range. :code:`min` and :code:`max` can be None, numbers or modules.

    If :code:`min` and :code:`max` are modules, this calculates :code:`tensors.clip(min(tensors), max(tensors))`.
    """
    def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
        super().__init__({}, min=min, max=max)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
        return TensorList(update).clamp_(min=min,  max=max)
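The bound semantics described for Clip can be sketched with plain lists standing in for tensors: each bound may be None (no clipping on that side), a scalar, or a per-element list playing the role of a module's output. The name `clip` and the list representation are illustrative assumptions; the module itself delegates to `TensorList.clamp_`.

```python
def clip(update, lo=None, hi=None):
    """Clamp each element to [lo, hi]; a bound may be None, a scalar, or a per-element list."""
    def bound_at(bound, i):
        if bound is None or isinstance(bound, (int, float)):
            return bound
        return bound[i]  # per-element bound, e.g. the output of another module

    out = []
    for i, x in enumerate(update):
        lo_i, hi_i = bound_at(lo, i), bound_at(hi, i)
        if lo_i is not None:
            x = max(x, lo_i)
        if hi_i is not None:
            x = min(x, hi_i)
        out.append(x)
    return out

print(clip([-3.0, 0.5, 7.0], -1.0, 1.0))  # [-1.0, 0.5, 1.0]
```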

ClipModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates input(tensors).clip(min, max). min and max can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class ClipModules(MultiOperationBase):
    """Calculates :code:`input(tensors).clip(min, max)`. :code:`min` and :code:`max` can be numbers or modules."""
    def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
        defaults = {}
        super().__init__(defaults, input=input, min=min, max=max)

    @torch.no_grad
    def transform(self, var: Var, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
        return TensorList(input).clamp_(min=min, max=max)

ClipNorm

Bases: torchzero.core.transform.Transform

Clips update norm to be no larger than max_norm.

Parameters:

  • max_norm (float) –

    value to clip norm to.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • dim (int | Sequence[int] | str | None, default: None ) –

    calculates norm along those dimensions. If list/tuple, tensors are normalized along all dimensions in dim that they have. Can be set to "global" to normalize by the global norm of all gradients concatenated into a vector. Defaults to None.

  • inverse_dims (bool, default: False ) –

    if True, the dim argument is inverted, and all other dimensions are normalized.

  • min_size (int, default: 1 ) –

    minimal number of elements in a parameter or slice to clip norm. Defaults to 1.

  • target (str, default: 'update' ) –

    what this affects; refer to the target argument in the documentation. Defaults to "update".

Examples:

Gradient norm clipping:

opt = tz.Modular(
    model.parameters(),
    tz.m.ClipNorm(1),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)

Update norm clipping:

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.ClipNorm(1),
    tz.m.LR(1e-2),
)

Source code in torchzero/modules/clipping/clipping.py
class ClipNorm(Transform):
    """Clips update norm to be no larger than `max_norm`.

    Args:
        max_norm (float): value to clip norm to.
        ord (float, optional): norm order. Defaults to 2.
        dim (int | Sequence[int] | str | None, optional):
            calculates norm along those dimensions.
            If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
            Can be set to "global" to normalize by the global norm of all gradients concatenated into a vector.
            Defaults to None.
        inverse_dims (bool, optional):
            if True, the `dim` argument is inverted, and all other dimensions are normalized.
        min_size (int, optional):
            minimal number of elements in a parameter or slice to clip norm. Defaults to 1.
        target (str, optional):
            what this affects; refer to the `target` argument in the documentation. Defaults to "update".

    Examples:

    Gradient norm clipping:
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.ClipNorm(1),
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )
    ```

    Update norm clipping:
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.ClipNorm(1),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(
        self,
        max_norm: float,
        ord: Metrics = 2,
        dim: int | Sequence[int] | Literal["global"] | None = None,
        inverse_dims: bool = False,
        min_size: int = 1,
        target: Target = "update",
    ):
        defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
        super().__init__(defaults, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        max_norm = NumberList(s['max_norm'] for s in settings)
        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
        _clip_norm_(
            tensors_ = TensorList(tensors),
            min = 0,
            max = max_norm,
            norm_value = None,
            ord = ord,
            dim = dim,
            inverse_dims=inverse_dims,
            min_size = min_size,
        )
        return tensors
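The difference between per-tensor clipping (`dim=None`) and `dim="global"` can be shown with a small sketch for the L2 norm (`ord=2`). Lists stand in for tensors, and `clip_norm` / `l2_norm` are hypothetical names; the actual module uses `_clip_norm_`, which also supports other norm orders and per-dimension norms.

```python
import math

def l2_norm(vec):
    return math.sqrt(sum(x * x for x in vec))

def clip_norm(tensors, max_norm, global_norm=False):
    """Rescale so the L2 norm is at most max_norm, per tensor or over all tensors jointly."""
    if global_norm:
        # treat all tensors as one concatenated vector
        total = math.sqrt(sum(l2_norm(t) ** 2 for t in tensors))
        scale = min(1.0, max_norm / total) if total > 0 else 1.0
        return [[x * scale for x in t] for t in tensors]
    clipped = []
    for t in tensors:
        norm = l2_norm(t)
        scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
        clipped.append([x * scale for x in t])
    return clipped

grads = [[3.0, 0.0], [0.0, 4.0]]
print(clip_norm(grads, 1.0))                    # each tensor rescaled to norm 1
print(clip_norm(grads, 1.0, global_norm=True))  # joint norm 5 rescaled to 1
```

Global clipping preserves the relative sizes of the tensors, while per-tensor clipping equalizes any tensor that exceeds the limit.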

ClipNormByEMA

Bases: torchzero.core.transform.Transform

Clips norm to be no larger than the norm of an exponential moving average of past updates.

Parameters:

  • beta (float, default: 0.99 ) –

    beta for the exponential moving average. Defaults to 0.99.

  • ord (float, default: 2 ) –

    order of the norm. Defaults to 2.

  • eps (float, default: 1e-06 ) –

    epsilon for division. Defaults to 1e-6.

  • tensorwise (bool, default: True ) –

    if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.

  • max_ema_growth (float | None, default: 1.5 ) –

    if specified, restricts how quickly the exponential moving average norm can grow: the norm is allowed to grow by at most this factor per step. Defaults to 1.5.

  • ema_init (str, default: 'zeros' ) –

    How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.

Source code in torchzero/modules/clipping/ema_clipping.py
class ClipNormByEMA(Transform):
    """Clips norm to be no larger than the norm of an exponential moving average of past updates.

    Args:
        beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
        ord (float, optional): order of the norm. Defaults to 2.
        eps (float, optional): epsilon for division. Defaults to 1e-6.
        tensorwise (bool, optional):
            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
        max_ema_growth (float | None, optional):
            if specified, restricts how quickly the exponential moving average norm can grow: the norm is allowed to grow by at most this factor per step. Defaults to 1.5.
        ema_init (str, optional):
            How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
    """
    NORMALIZE = False
    def __init__(
        self,
        beta=0.99,
        ord: Metrics = 2,
        eps=1e-6,
        tensorwise:bool=True,
        max_ema_growth: float | None = 1.5,
        ema_init: Literal['zeros', 'update'] = 'zeros',
        inner: Chainable | None = None,
    ):
        defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, ema_init=ema_init, eps=eps, max_ema_growth=max_ema_growth)
        super().__init__(defaults, inner=inner)

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])

        beta, eps = unpack_dicts(settings, 'beta', 'eps', cls=NumberList)

        ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)

        ema.lerp_(tensors, 1-beta)

        if tensorwise:
            ema_norm = ema.metric(ord)

            # clip ema norm growth
            if max_ema_growth is not None:
                prev_ema_norm = unpack_states(states, tensors, 'prev_ema_norm', init=ema_norm, cls=TensorList)
                allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=1e-6)
                ema_denom = (ema_norm / allowed_norm).clip(min=1)
                ema.div_(ema_denom)
                ema_norm.div_(ema_denom)
                prev_ema_norm.set_(ema_norm)

            tensors_norm = tensors.norm(ord)
            denom = tensors_norm / ema_norm.clip(min=eps)
            if self.NORMALIZE: denom.clip_(min=eps)
            else: denom.clip_(min=1)

        else:
            ema_norm = ema.global_metric(ord)

            # clip ema norm growth
            if max_ema_growth is not None:
                prev_ema_norm = self.global_state.setdefault('prev_ema_norm', ema_norm)
                allowed_norm = prev_ema_norm * max_ema_growth
                if ema_norm > allowed_norm:
                    ema.div_(ema_norm / allowed_norm)
                    ema_norm = allowed_norm
                prev_ema_norm.set_(ema_norm)

            tensors_norm = tensors.global_metric(ord)
            denom = tensors_norm / ema_norm.clip(min=eps[0])
            if self.NORMALIZE: denom.clip_(min=eps[0])
            else: denom.clip_(min=1)

        self.global_state['denom'] = denom

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        denom = self.global_state.pop('denom')
        torch._foreach_div_(tensors, denom)
        return tensors

NORMALIZE class-attribute

NORMALIZE = False

If True, the update is rescaled to match the norm of the exponential moving average, so it can be scaled up as well as down (the ratio is clipped at eps instead of 1). If False (default), the update is only scaled down, as described above.
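The global (`tensorwise=False`) path of ClipNormByEMA can be sketched for a single flat vector with `ema_init='zeros'` and `NORMALIZE = False`. The class name `EMANormClipper` is hypothetical; the steps mirror the source listing above: lerp the EMA towards the current update, cap the EMA norm growth, then divide the update by its norm ratio clipped at 1.

```python
import math

def l2_norm(vec):
    return math.sqrt(sum(x * x for x in vec))

class EMANormClipper:
    """Hypothetical sketch of ClipNormByEMA with tensorwise=False and ema_init='zeros'."""

    def __init__(self, beta=0.99, eps=1e-6, max_ema_growth=1.5):
        self.beta, self.eps, self.max_ema_growth = beta, eps, max_ema_growth
        self.ema = None
        self.prev_ema_norm = None

    def step(self, update):
        if self.ema is None:
            self.ema = [0.0] * len(update)
        # lerp towards the current update with weight (1 - beta)
        self.ema = [e + (1 - self.beta) * (u - e) for e, u in zip(self.ema, update)]
        ema_norm = l2_norm(self.ema)
        if self.max_ema_growth is not None:
            if self.prev_ema_norm is None:
                self.prev_ema_norm = ema_norm
            allowed = self.prev_ema_norm * self.max_ema_growth
            if ema_norm > allowed:
                # shrink the EMA itself so its norm grows by at most max_ema_growth
                self.ema = [e * allowed / ema_norm for e in self.ema]
                ema_norm = allowed
            self.prev_ema_norm = ema_norm
        # only scale down: ratios below 1 are clipped to 1
        denom = max(l2_norm(update) / max(ema_norm, self.eps), 1.0)
        return [u / denom for u in update]

clipper = EMANormClipper(beta=0.5)
print(clipper.step([2.0, 0.0]))  # [1.0, 0.0]
```

With `ema_init='zeros'` the EMA starts small, so early updates are clipped heavily; `ema_init='update'` avoids that at the cost of trusting the first update.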

ClipNormGrowth

Bases: torchzero.core.transform.Transform

Clips update norm growth.

Parameters:

  • add (float | None, default: None ) –

    additive clipping, next update norm is at most previous norm + add. Defaults to None.

  • mul (float | None, default: 1.5 ) –

    multiplicative clipping, next update norm is at most previous norm * mul. Defaults to 1.5.

  • min_value (float | None, default: 0.0001 ) –

    minimum value for multiplicative clipping to prevent collapse to 0. Next norm is at most max(prev_norm, min_value) * mul. Defaults to 1e-4.

  • max_decay (float | None, default: 2 ) –

    bounds the tracked multiplicative clipping decay to prevent collapse to 0. Next norm is at most max(previous norm * mul, max_decay). Defaults to 2.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • parameterwise (bool, default: True ) –

    if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.

  • target (Literal, default: 'update' ) –

    what to set on var. Defaults to "update".

Source code in torchzero/modules/clipping/growth_clipping.py
class ClipNormGrowth(Transform):
    """Clips update norm growth.

    Args:
        add (float | None, optional): additive clipping, next update norm is at most `previous norm + add`. Defaults to None.
        mul (float | None, optional):
            multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
        min_value (float | None, optional):
            minimum value for multiplicative clipping to prevent collapse to 0.
            Next norm is at most :code:`max(prev_norm, min_value) * mul`. Defaults to 1e-4.
        max_decay (float | None, optional):
            bounds the tracked multiplicative clipping decay to prevent collapse to 0.
            Next norm is at most :code:`max(previous norm * mul, max_decay)`.
            Defaults to 2.
        ord (float, optional): norm order. Defaults to 2.
        parameterwise (bool, optional):
            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
        target (Target, optional): what to set on var. Defaults to "update".
    """
    def __init__(
        self,
        add: float | None = None,
        mul: float | None = 1.5,
        min_value: float | None = 1e-4,
        max_decay: float | None = 2,
        ord: float = 2,
        parameterwise=True,
        target: Target = "update",
    ):
        defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, parameterwise=parameterwise)
        super().__init__(defaults, target=target)



    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        parameterwise = settings[0]['parameterwise']
        tensors = TensorList(tensors)

        if parameterwise:
            ts = tensors
            stts = states
            stns = settings

        else:
            ts = [tensors.to_vec()]
            stts = [self.global_state]
            stns = [settings[0]]


        for t, state, setting in zip(ts, stts, stns):
            if 'prev_norm' not in state:
                state['prev_norm'] = torch.linalg.vector_norm(t, ord=setting['ord']) # pylint:disable=not-callable
                state['prev_denom'] = 1
                continue

            _,  state['prev_norm'], state['prev_denom'] = norm_growth_clip_(
                tensor_ = t,
                prev_norm = state['prev_norm'],
                add = setting['add'],
                mul = setting['mul'],
                min_value = setting['min_value'],
                max_decay = setting['max_decay'],
                ord = setting['ord'],
            )

        if not parameterwise:
            tensors.from_vec_(ts[0])

        return tensors
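The `add`, `mul` and `min_value` bounds documented for ClipNormGrowth can be sketched on a sequence of flat vectors. As in the source listing, the first step only records the norm. `max_decay` is not modelled, and tracking the clipped norm as the next reference is an assumption, since `norm_growth_clip_` is not shown. `clip_norm_growth` is a hypothetical name.

```python
import math

def clip_norm_growth(updates, add=None, mul=1.5, min_value=1e-4):
    """Limit how fast the L2 norm of successive updates may grow (max_decay not modelled)."""
    prev_norm = None
    out = []
    for u in updates:
        norm = math.sqrt(sum(x * x for x in u))
        if prev_norm is None:
            # first step: record the norm and pass the update through unchanged
            prev_norm = norm
            out.append(list(u))
            continue
        allowed = norm
        if add is not None:
            allowed = min(allowed, prev_norm + add)
        if mul is not None:
            allowed = min(allowed, max(prev_norm, min_value) * mul)
        scale = allowed / norm if norm > 0 else 1.0
        out.append([x * scale for x in u])
        prev_norm = allowed  # assumption: track the clipped norm, not the raw one
    return out

print(clip_norm_growth([[1.0], [10.0], [10.0]]))  # norms grow 1 -> 1.5 -> 2.25
```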

ClipValue

Bases: torchzero.core.transform.Transform

Clips update magnitude to be within (-value, value) range.

Parameters:

  • value (float) –

    value to clip to.

  • target (str, default: 'update' ) –

    refer to the target argument in the documentation.

Examples:

Gradient clipping:

opt = tz.Modular(
    model.parameters(),
    tz.m.ClipValue(1),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)

Update clipping:

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.ClipValue(1),
    tz.m.LR(1e-2),
)

Source code in torchzero/modules/clipping/clipping.py
class ClipValue(Transform):
    """Clips update magnitude to be within ``(-value, value)`` range.

    Args:
        value (float): value to clip to.
        target (str): refer to the ``target`` argument in the documentation.

    Examples:

    Gradient clipping:
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.ClipValue(1),
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )
    ```

    Update clipping:
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.ClipValue(1),
        tz.m.LR(1e-2),
    )
    ```

    """
    def __init__(self, value: float, target: Target = 'update'):
        defaults = dict(value=value)
        super().__init__(defaults, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        value = [s['value'] for s in settings]
        return TensorList(tensors).clip_([-v for v in value], value)

ClipValueByEMA

Bases: torchzero.core.transform.Transform

Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.

Parameters:

  • beta (float, default: 0.99 ) –

    beta for the exponential moving average. Defaults to 0.99.

  • ema_init (str, default: 'zeros' ) –

    How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.

  • ema_tfm (Chainable | None, default: None ) –

    optional modules applied to exponential moving average before clipping by it. Defaults to None.

Source code in torchzero/modules/clipping/ema_clipping.py
class ClipValueByEMA(Transform):
    """Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.

    Args:
        beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
        ema_init (str, optional):
            How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
        ema_tfm (Chainable | None, optional):
            optional modules applied to exponential moving average before clipping by it. Defaults to None.
    """
    def __init__(
        self,
        beta=0.99,
        ema_init: Literal['zeros', 'update'] = 'zeros',
        ema_tfm:Chainable | None=None,
        inner: Chainable | None = None,
    ):
        defaults = dict(beta=beta, ema_init=ema_init)
        super().__init__(defaults, inner=inner)

        if ema_tfm is not None:
            self.set_child('ema_tfm', ema_tfm)

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        ema_init = itemgetter('ema_init')(settings[0])

        beta = unpack_dicts(settings, 'beta', cls=NumberList)
        tensors = TensorList(tensors)

        ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
        ema.lerp_(tensors.abs(), 1-beta)

    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        ema = unpack_states(states, tensors, 'ema', cls=TensorList)

        if 'ema_tfm' in self.children:
            ema = TensorList(apply_transform(self.children['ema_tfm'], ema.clone(), params, grads, loss))

        tensors.clip_(-ema, ema)
        return tensors
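ClipValueByEMA keeps an elementwise EMA of absolute values and clips each element to plus or minus that EMA. The sketch assumes `ema_init='zeros'`, no `ema_tfm`, and that `update_tensors` runs before `apply_tensors` within a step, so the current (unclipped) magnitude enters the EMA before clipping. `clip_value_by_ema` is a hypothetical name.

```python
def clip_value_by_ema(updates, beta=0.99):
    """Clip each element to +/- an EMA of past absolute values (ema_init='zeros')."""
    ema = None
    out = []
    for u in updates:
        if ema is None:
            ema = [0.0] * len(u)
        # the EMA is updated with the current, unclipped magnitudes before clipping
        ema = [e + (1 - beta) * (abs(x) - e) for e, x in zip(ema, u)]
        out.append([max(-e, min(x, e)) for x, e in zip(u, ema)])
    return out

print(clip_value_by_ema([[2.0, -4.0], [1.0, 1.0]], beta=0.5))  # [[1.0, -2.0], [1.0, 1.0]]
```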

ClipValueGrowth

Bases: torchzero.core.transform.TensorwiseTransform

Clips update value magnitude growth.

Parameters:

  • add (float | None, default: None ) –

    additive clipping, next update is at most previous update + add. Defaults to None.

  • mul (float | None, default: 1.5 ) –

    multiplicative clipping, next update is at most previous update * mul. Defaults to 1.5.

  • min_value (float | None, default: 0.0001 ) –

    minimum value for multiplicative clipping to prevent collapse to 0. Next update is at most max(prev_update, min_value) * mul. Defaults to 1e-4.

  • max_decay (float | None, default: 2 ) –

    bounds the tracked multiplicative clipping decay to prevent collapse to 0. Next update is at most max(previous update * mul, max_decay). Defaults to 2.

  • target (Literal, default: 'update' ) –

    what to set on var. Defaults to "update".

Source code in torchzero/modules/clipping/growth_clipping.py
class ClipValueGrowth(TensorwiseTransform):
    """Clips update value magnitude growth.

    Args:
        add (float | None, optional): additive clipping, next update is at most `previous update + add`. Defaults to None.
        mul (float | None, optional): multiplicative clipping, next update is at most `previous update * mul`. Defaults to 1.5.
        min_value (float | None, optional):
            minimum value for multiplicative clipping to prevent collapse to 0.
            Next update is at most :code:`max(prev_update, min_value) * mul`. Defaults to 1e-4.
        max_decay (float | None, optional):
            bounds the tracked multiplicative clipping decay to prevent collapse to 0.
            Next update is at most :code:`max(previous update * mul, max_decay)`.
            Defaults to 2.
        target (Target, optional): what to set on var. Defaults to "update".
    """
    def __init__(
        self,
        add: float | None = None,
        mul: float | None = 1.5,
        min_value: float | None = 1e-4,
        max_decay: float | None = 2,
        target: Target = "update",
    ):
        defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
        super().__init__(defaults, target=target)


    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(setting)
        add: float | None

        if add is None and mul is None:
            return tensor

        if 'prev' not in state:
            state['prev'] = tensor.clone()
            return tensor

        prev: torch.Tensor = state['prev']

        # additive bound
        if add is not None:
            growth = (tensor.abs() - prev.abs()).clip(min=0)
            tensor.sub_(torch.where(growth > add, (growth-add).copysign_(tensor), 0))

        # multiplicative bound
        growth = None
        if mul is not None:
            prev_magn = prev.abs()
            if min_value is not None: prev_magn.clip_(min=min_value)
            growth = (tensor.abs() / prev_magn).clamp_(min=1e-8)

            denom = torch.where(growth > mul, growth/mul, 1)

            tensor.div_(denom)

        # limit max growth decay
        if max_decay is not None:
            if growth is None:
                prev_magn = prev.abs()
                if min_value is not None: prev_magn.clip_(min=min_value)
                growth = (tensor.abs() / prev_magn).clamp_(min=1e-8)

            new_prev = torch.where(growth < (1/max_decay), prev/max_decay, tensor)
        else:
            new_prev = tensor.clone()

        state['prev'] = new_prev
        return tensor
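The `apply_tensor` logic above can be mirrored elementwise on plain lists for the `max_decay=None` path, where the clipped update becomes the next reference. This sketch requires a positive `min_value` to keep the division safe (the module also accepts None). `clip_value_growth` is a hypothetical name.

```python
import math

def clip_value_growth(updates, add=None, mul=1.5, min_value=1e-4):
    """Elementwise growth clipping, mirroring the max_decay=None path (min_value must be > 0)."""
    prev = None
    out = []
    for u in updates:
        cur = list(u)
        if prev is None:
            # first step: store the update and return it unchanged
            prev = list(cur)
            out.append(cur)
            continue
        if add is not None:
            for i, (x, p) in enumerate(zip(cur, prev)):
                growth = max(abs(x) - abs(p), 0.0)
                if growth > add:
                    # pull the magnitude back to |p| + add, keeping the sign of x
                    cur[i] = x - math.copysign(growth - add, x)
        if mul is not None:
            for i, (x, p) in enumerate(zip(cur, prev)):
                prev_magn = max(abs(p), min_value)
                growth = max(abs(x) / prev_magn, 1e-8)
                if growth > mul:
                    cur[i] = x / (growth / mul)
        prev = list(cur)  # with max_decay=None the clipped update is the new reference
        out.append(cur)
    return out

print(clip_value_growth([[1.0, -1.0], [3.0, -1.2]]))  # second step: [1.5, -1.2]
```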

Clone

Bases: torchzero.core.module.Module

Clones the input. This may be useful to store an intermediate result and make sure it isn't affected by later in-place operations.

Source code in torchzero/modules/ops/utility.py
class Clone(Module):
    """Clones the input. This may be useful to store an intermediate result and make sure it isn't affected by later in-place operations."""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def step(self, var):
        var.update = [u.clone() for u in var.get_update()]
        return var
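The aliasing problem that Clone guards against can be seen with an analogy using mutable Python lists instead of tensors: keeping a reference is not the same as keeping a copy once a later step modifies the data in place.

```python
update = [1.0, 2.0]
alias = update            # no copy: both names refer to the same list
snapshot = list(update)   # an independent copy, analogous to what Clone produces

for i in range(len(update)):
    update[i] *= 10       # a later in-place operation

print(alias)     # [10.0, 20.0]
print(snapshot)  # [1.0, 2.0]
```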

ConjugateDescent

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Conjugate Descent (CD).

Note

This requires step size to be determined via a line search, so put a line search like tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this.

Source code in torchzero/modules/conjugate_gradient/cg.py
class ConjugateDescent(ConguateGradientBase):
    """Conjugate Descent (CD).

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return conjugate_descent_beta(g, prev_d, prev_g)
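For reference, Fletcher's Conjugate Descent uses beta = (g·g) / (-(d_prev·g_prev)) and the new direction d = -g + beta * d_prev. The helper `conjugate_descent_beta` is not shown here, so this textbook form, including its sign conventions, is an assumption about what it computes; the function name below is hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def conjugate_descent_direction(g, prev_g, prev_d):
    """Fletcher's CD: beta = g.g / (-prev_d.prev_g), d = -g + beta * prev_d."""
    beta = dot(g, g) / -dot(prev_d, prev_g)
    direction = [-gi + beta * di for gi, di in zip(g, prev_d)]
    return direction, beta

d, beta = conjugate_descent_direction([1.0, 0.0], [2.0, 0.0], [-2.0, 0.0])
print(beta, d)  # 0.25 [-1.5, 0.0]
```

The denominator is positive only when the previous direction was a descent direction, which is one reason the method is paired with a line search such as `tz.m.StrongWolfe`.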

CopyMagnitude

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns other(tensors) with the sign copied from tensors.

Source code in torchzero/modules/ops/binary.py
class RCopySign(BinaryOperationBase):
    """Returns :code:`other(tensors)` with sign copied from tensors."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [o.copysign_(u) for u, o in zip(update, other)]

CopySign

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns tensors with the sign copied from other(tensors).

Source code in torchzero/modules/ops/binary.py
class CopySign(BinaryOperationBase):
    """Returns tensors with sign copied from :code:`other(tensors)`."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [u.copysign_(o) for u, o in zip(update, other)]
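CopySign and CopyMagnitude differ only in which operand supplies the magnitude, which `math.copysign` makes easy to see on plain lists (standing in for the tensor `copysign_` calls in the listings above). The function names are hypothetical.

```python
import math

def copy_sign(update, other):
    # magnitude from update, sign from other (CopySign)
    return [math.copysign(u, o) for u, o in zip(update, other)]

def copy_magnitude(update, other):
    # magnitude from other, sign from update (CopyMagnitude)
    return [math.copysign(o, u) for u, o in zip(update, other)]

print(copy_sign([1.0, -2.0], [-5.0, 3.0]))       # [-1.0, 2.0]
print(copy_magnitude([1.0, -2.0], [-5.0, 3.0]))  # [5.0, -3.0]
```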

CubicRegularization

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Cubic regularization.

Parameters:

  • hess_module (Module | None) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • eta (float, default: 0.0 ) –

    if the ratio of actual to predicted reduction is larger than this, the step is accepted. When hess_module is GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if the ratio of actual to predicted reduction is larger than this, trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.0001 ) –

    if the ratio of actual to predicted reduction is less than this, trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • maxiter (int, default: 100 ) –

    maximum iterations when solving the cubic subproblem. Defaults to 100.

  • eps (float, default: 1e-08 ) –

    epsilon for the solver, defaults to 1e-8.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • fallback (bool) –

    if True, when hess_module maintains hessian inverse which can't be inverted efficiently, it will be inverted anyway. When False (default), a RuntimeError will be raised instead.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.

Examples:

Cubic regularized Newton:

opt = tz.Modular(
    model.parameters(),
    tz.m.CubicRegularization(tz.m.Newton()),
)

Source code in torchzero/modules/trust_region/cubic_regularization.py
class CubicRegularization(TrustRegionBase):
    """Cubic regularization.

    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if the ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`. Defaults to 0.99.
        rho_bad (float, optional):
            if the ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`. Defaults to 1e-4.
        init (float, optional): Initial trust region value. Defaults to 1.
        maxiter (int, optional): maximum iterations when solving the cubic subproblem. Defaults to 100.
        eps (float, optional): epsilon for the solver. Defaults to 1e-8.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        fallback (bool, optional):
            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.


    Examples:
        Cubic regularized newton

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.CubicRegularization(tz.m.Newton()),
            )

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        maxiter: int = 100,
        eps: float = 1e-8,
        check_decrease:bool=False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(maxiter=maxiter, eps=eps, check_decrease=check_decrease)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        params = TensorList(params)

        loss_at_params_plus_x_fn = None
        if settings['check_decrease']:
            def closure_plus_x(x):
                x_unflat = vec_to_tensors(x, params)
                params.add_(x_unflat)
                loss_x = closure(False)
                params.sub_(x_unflat)
                return loss_x
            loss_at_params_plus_x_fn = closure_plus_x


        d, _ = ls_cubic_solver(f=f, g=g, H=H, M=1/radius, loss_at_params_plus_x_fn=loss_at_params_plus_x_fn,
                               it_max=settings['maxiter'], epsilon=settings['eps'])
        return d.neg_()
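The eta, rho_good, rho_bad, nplus and nminus settings above describe the usual trust region ratio test. Below is a minimal plain-Python sketch of that rule as the docstring states it, using the signature's defaults; update_radius is a hypothetical helper, not part of torchzero, and the library's own bookkeeping (radius_strategy, max_attempts) is not modelled.

```python
def update_radius(radius, actual, predicted, eta=0.0, nplus=3.5, nminus=0.25,
                  rho_good=0.99, rho_bad=1e-4):
    """Trust region ratio test (hypothetical helper, not torchzero's API).

    actual:    f(x) - f(x + step), the reduction actually achieved
    predicted: m(0) - m(step), the reduction predicted by the model
    Returns (accepted, new_radius).
    """
    # A non-positive predicted reduction means the model is useless here: treat it as a bad step.
    rho = actual / predicted if predicted > 0 else float("-inf")
    if rho > rho_good:
        radius *= nplus      # model agrees well: enlarge the region
    elif rho < rho_bad:
        radius *= nminus     # model agrees poorly: shrink the region
    return rho > eta, radius
```

For example, a step whose actual reduction matches the prediction (rho = 1) is accepted and the radius grows from 1.0 to 3.5, while a step that increases the loss is rejected and the radius shrinks to 0.25.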

CustomUnaryOperation

Bases: torchzero.core.transform.Transform

Calls the tensor method called name on each tensor, i.e. getattr(tensor, name)().

Source code in torchzero/modules/ops/unary.py
class CustomUnaryOperation(Transform):
    """Applies :code:`getattr(tensor, name)` to each tensor
    """
    def __init__(self, name: str, target: "Target" = 'update'):
        defaults = dict(name=name)
        super().__init__(defaults=defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        return getattr(tensors, settings[0]['name'])()

DFP

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Davidon–Fletcher–Powell Quasi-Newton method.

Note

A trust region or an accurate line search is recommended.

Warning

This uses at least O(N^2) memory.

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class DFP(_InverseHessianUpdateStrategyDefaults):
    """Davidon–Fletcher–Powell Quasi-Newton method.

    Note:
        a trust region or an accurate line search is recommended.

    Warning:
        this uses at least O(N^2) memory.
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return dfp_H_(H=H, s=s, y=y, tol=setting['tol'])
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return dfp_B(B=B, s=s, y=y, tol=setting['tol'])
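The DFP inverse-Hessian update is H+ = H + s s^T / (s^T y) - H y y^T H / (y^T H y), which satisfies the secant equation H+ y = s. A plain-Python sketch of that formula on a small dense matrix follows; dfp_update_H is a hypothetical name, and skipping the update when a denominator falls below tol is an assumption about what the tol setting guards against. The dense n x n matrix it manipulates is where the O(N^2) memory cost comes from.

```python
def matvec(A, x):
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def dfp_update_H(H, s, y, tol=1e-32):
    """DFP update of a symmetric inverse Hessian approximation H (list of lists)."""
    Hy = matvec(H, y)
    sy = dot(s, y)
    yHy = dot(y, Hy)
    # Assumed safeguard: leave H unchanged when a denominator is (nearly) zero.
    if abs(sy) <= tol or abs(yHy) <= tol:
        return H
    n = len(s)
    # H is symmetric, so H y y^T H = (Hy)(Hy)^T.
    return [[H[i][j] + s[i] * s[j] / sy - Hy[i] * Hy[j] / yHy for j in range(n)]
            for i in range(n)]
```

Starting from the identity with s = [1, 2] and y = [3, 1], the update gives [[0.3, 0.1], [0.1, 1.7]], and multiplying it by y recovers s.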

DNRTR

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Diagonal quasi-newton method.

Reference

Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class DNRTR(HessianUpdateStrategy):
    """Diagonal quasi-newton method.

    Reference:
        Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
    """
    def __init__(
        self,
        lb: float = 1e-2,
        ub: float = 1e5,
        init_scale: float | Literal["auto"] = "auto",
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None | Literal['auto'] = None,
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        defaults = dict(lb=lb, ub=ub)
        super().__init__(
            defaults=defaults,
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=False,
            inner=inner,
        )

    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_wqc_B_(B=B, s=s, y=y)

    def modify_B(self, B, state, setting):
        return _truncate(B, setting['lb'], setting['ub'])

    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)

DYHS

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Dai-Yuan - Hestenes–Stiefel hybrid conjugate gradient method.

Note

This requires the step size to be determined by a line search, so place a line search such as tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this module.

Source code in torchzero/modules/conjugate_gradient/cg.py
class DYHS(ConguateGradientBase):
    """Dai-Yuan - Hestenes–Stiefel hybrid conjugate gradient method.

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return dyhs_beta(g, prev_d, prev_g)
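The standard Dai-Yuan / Hestenes-Stiefel hybrid takes beta = max(0, min(beta_HS, beta_DY)), where beta_HS = g^T y / (d^T y), beta_DY = ||g||^2 / (d^T y) and y = g - g_prev. The sketch below computes that textbook hybrid for the descent convention d = -g + beta * d_prev; hybrid_dyhs_beta is a hypothetical helper whose argument order differs from torchzero's dyhs_beta, and torchzero's internal sign conventions are not shown here.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def hybrid_dyhs_beta(g, prev_g, prev_d):
    """beta = max(0, min(beta_HS, beta_DY)) for the direction d = -g + beta * prev_d."""
    y = [a - b for a, b in zip(g, prev_g)]   # gradient change
    denom = dot(prev_d, y)
    if denom == 0:
        return 0.0                           # fall back to steepest descent
    beta_hs = dot(g, y) / denom
    beta_dy = dot(g, g) / denom
    return max(0.0, min(beta_hs, beta_dy))
```

Clipping at zero restarts the method along the negative gradient whenever the Hestenes-Stiefel value turns negative.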

DaiYuan

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Dai–Yuan nonlinear conjugate gradient method.

Note

This requires the step size to be determined by a line search, so place a line search such as tz.m.StrongWolfe(c2=0.1) after this module.

Source code in torchzero/modules/conjugate_gradient/cg.py
class DaiYuan(ConguateGradientBase):
    """Dai–Yuan nonlinear conjugate gradient method.

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1)`` after this.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return dai_yuan_beta(g, prev_d, prev_g)
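On a strictly convex quadratic with an exact line search, Dai-Yuan CG reduces to linear CG and reaches the minimizer in at most n steps. The sketch below illustrates this, with the closed-form step alpha = -g^T d / (d^T A d) standing in for a line search; in torchzero the step size would instead come from a module such as tz.m.StrongWolfe(c2=0.1) placed after DaiYuan. dai_yuan_cg is a hypothetical helper.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def matvec(A, x):
    return [dot(row, x) for row in A]

def dai_yuan_cg(A, b, x, iters):
    """Minimize f(x) = 0.5 x^T A x - b^T x with Dai-Yuan CG and an exact line search."""
    g = [ax - bi for ax, bi in zip(matvec(A, x), b)]   # gradient A x - b
    d = [-gi for gi in g]
    for _ in range(iters):
        Ad = matvec(A, d)
        dAd = dot(d, Ad)
        if dAd <= 0:                      # d == 0 (converged) or no positive curvature
            break
        alpha = -dot(g, d) / dAd          # exact minimizer along d for a quadratic
        x = [xi + alpha * di for xi, di in zip(x, d)]
        g_new = [gi + alpha * adi for gi, adi in zip(g, Ad)]
        y = [gn - gi for gn, gi in zip(g_new, g)]
        beta = dot(g_new, g_new) / dot(d, y)   # Dai-Yuan formula
        d = [-gn + beta * di for gn, di in zip(g_new, d)]
        g = g_new
    return x
```

For A = [[4, 1], [1, 3]] and b = [1, 2], two iterations from the origin reach the solution [1/11, 7/11].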

Debias

Bases: torchzero.core.transform.Transform

Multiplies the update by an Adam debiasing term based on the first and/or second momentum.

Parameters:

  • beta1 (float | None, default: None ) –

    first momentum, should be the same as the first momentum used by preceding modules. Defaults to None.

  • beta2 (float | None, default: None ) –

    second (squared) momentum, should be the same as the second momentum used by preceding modules. Defaults to None.

  • alpha (float, default: 1 ) –

    learning rate. Defaults to 1.

  • pow (float, default: 2 ) –

    power, assumes absolute value is used. Defaults to 2.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/higher_level.py
class Debias(Transform):
    """Multiplies the update by an Adam debiasing term based first and/or second momentum.

    Args:
        beta1 (float | None, optional):
            first momentum, should be the same as first momentum used in modules before. Defaults to None.
        beta2 (float | None, optional):
            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
        alpha (float, optional): learning rate. Defaults to 1.
        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2, target: Target = 'update',):
        defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        pow = settings[0]['pow']
        alpha, beta1, beta2 = unpack_dicts(settings, 'alpha', 'beta1', 'beta2', cls=NumberList)

        return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)
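With Adam's zero-initialized averages, m_hat / v_hat^(1/2) equals m / v^(1/2) times (1 - beta2^t)^(1/2) / (1 - beta1^t), so the correction can be applied as a single multiplier. This sketch computes that multiplier, generalizing the square root to 1/pow; adam_debias_factor is hypothetical, and the assumption that torchzero's debias uses exactly this form is not confirmed by the source shown.

```python
def adam_debias_factor(step, beta1=None, beta2=None, alpha=1.0, pow=2.0):
    """Multiplier that turns m / v**(1/pow) into m_hat / v_hat**(1/pow)."""
    factor = alpha
    if beta1 is not None:
        factor /= 1 - beta1 ** step                  # first-moment bias correction
    if beta2 is not None:
        factor *= (1 - beta2 ** step) ** (1 / pow)   # second-moment correction, moved to the numerator
    return factor
```

Passing neither beta leaves only the learning rate alpha, which matches the module's None defaults.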

Debias2

Bases: torchzero.core.transform.Transform

Multiplies the update by an Adam debiasing term based on the second momentum.

Parameters:

  • beta (float, default: 0.999 ) –

    second (squared) momentum, should be the same as the second momentum used by preceding modules. Defaults to 0.999.

  • pow (float, default: 2 ) –

    power, assumes absolute value is used. Defaults to 2.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/higher_level.py
class Debias2(Transform):
    """Multiplies the update by an Adam debiasing term based on the second momentum.

    Args:
        beta (float, optional):
            second (squared) momentum, should be the same as the second momentum used by preceding modules. Defaults to 0.999.
        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, beta: float = 0.999, pow: float = 2, target: Target = 'update',):
        defaults = dict(beta=beta, pow=pow)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        pow = settings[0]['pow']
        beta = NumberList(s['beta'] for s in settings)
        return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)

DiagonalBFGS

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
    """Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])

    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
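One way to read "BFGS with only the diagonal updated" is to apply the BFGS inverse update to H = diag(h) and keep only the diagonal of the result, which works out to h_i - 2 rho s_i y_i h_i + rho^2 s_i^2 (y^T H y) + rho s_i^2 with rho = 1/(s^T y). The sketch below implements that reading; diagonal_bfgs_update is hypothetical, torchzero's diagonal_bfgs_H_ may use a different formula, and the tol guard is an assumption.

```python
def diagonal_bfgs_update(h, s, y, tol=1e-32):
    """Diagonal of the BFGS inverse update applied to diag(h)."""
    sy = sum(si * yi for si, yi in zip(s, y))
    if abs(sy) <= tol:                       # assumed guard against vanishing curvature
        return list(h)
    rho = 1.0 / sy
    yhy = sum(hi * yi * yi for hi, yi in zip(h, y))
    return [hi - 2 * rho * si * yi * hi + rho * rho * si * si * yhy + rho * si * si
            for hi, si, yi in zip(h, s, y)]
```

With h = [1, 1], s = [1, 2] and y = [3, 1] this gives [0.4, 2.6], the diagonal of the full BFGS update. Multiplying y by that diagonal gives [1.2, 2.6] rather than s, which is the sense in which the secant equation is no longer satisfied.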

DiagonalQuasiCauchi

Bases: torchzero.modules.quasi_newton.quasi_newton._HessianUpdateStrategyDefaults

Diagonal quasi-Cauchy method.

Reference

Zhu, M., J. L. Nazareth, and H. Wolkowicz. "The quasi-Cauchy relation and diagonal updating." SIAM Journal on Optimization 9.4 (1999): 1192-1204.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
    """Diagonal quasi-cauchi method.

    Reference:
        Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_qc_B_(B=B, s=s, y=y)

    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
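The quasi-Cauchy update of Zhu, Nazareth and Wolkowicz is the smallest diagonal change (in Frobenius norm) that satisfies the weak secant condition s^T B s = s^T y. It gives b_i + lambda s_i^2 with lambda = (s^T y - s^T B s) / sum(s_i^4). A plain-Python sketch follows; quasi_cauchy_update is hypothetical and may differ in safeguards from torchzero's diagonal_qc_B_.

```python
def quasi_cauchy_update(b, s, y, tol=1e-32):
    """Least-change diagonal update satisfying s^T B s = s^T y."""
    s2 = [si * si for si in s]
    denom = sum(v * v for v in s2)               # sum of s_i**4
    if denom <= tol:                             # step too small to carry information
        return list(b)
    sy = sum(si * yi for si, yi in zip(s, y))
    sbs = sum(bi * v for bi, v in zip(b, s2))    # s^T B s for diagonal B
    lam = (sy - sbs) / denom
    return [bi + lam * v for bi, v in zip(b, s2)]
```

Nothing in this update keeps the diagonal positive; DNRTR, for instance, clips its diagonal to [lb, ub] after each update.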

DiagonalSR1

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
    """Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])

    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
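SR1 updates H by the rank-one term r r^T / (r^T y) with r = s - H y; keeping only the diagonal for H = diag(h) gives h_i + r_i^2 / (r^T y). The sketch uses the common relative skip rule |r^T y| <= tol ||r|| ||y||, which is an assumption about the role of tol. Because SR1 is self-dual, calling it with s and y swapped updates the Hessian approximation B instead, mirroring update_B above. diagonal_sr1_update is a hypothetical name.

```python
def diagonal_sr1_update(h, s, y, tol=1e-8):
    """Diagonal of the SR1 update of diag(h): h_i + r_i**2 / (r^T y), r = s - H y."""
    r = [si - hi * yi for si, hi, yi in zip(s, h, y)]
    denom = sum(ri * yi for ri, yi in zip(r, y))
    norm_r = sum(ri * ri for ri in r) ** 0.5
    norm_y = sum(yi * yi for yi in y) ** 0.5
    # Standard SR1 safeguard: skip when r^T y is tiny relative to ||r|| ||y||.
    if abs(denom) <= tol * norm_r * norm_y:
        return list(h)
    return [hi + ri * ri / denom for hi, ri in zip(h, r)]
```

Note that r^T y can be negative, in which case entries of the diagonal shrink and may even change sign.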

DiagonalWeightedQuasiCauchi

Bases: torchzero.modules.quasi_newton.quasi_newton._HessianUpdateStrategyDefaults

Diagonal quasi-Cauchy method with a weighted Frobenius norm.

Reference

Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
    """Diagonal quasi-cauchi method.

    Reference:
        Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_wqc_B_(B=B, s=s, y=y)

    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)

DirectWeightDecay

Bases: torchzero.core.module.Module

Directly applies weight decay to parameters.

Parameters:

  • weight_decay (float) –

    weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

Source code in torchzero/modules/weight_decay/weight_decay.py
class DirectWeightDecay(Module):
    """Directly applies weight decay to parameters.

    Args:
        weight_decay (float): weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
    """
    def __init__(self, weight_decay: float, ord: int = 2,):
        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
        ord = self.defaults['ord']

        decay_weights_(var.params, weight_decay, ord)
        return var
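The source does not show decay_weights_, so the rule below is an assumption: each parameter moves against the gradient of |w|^ord / ord, scaled by weight_decay. That gives w -= weight_decay * w for ord=2 and w -= weight_decay * sign(w) for ord=1. decay_weights is a hypothetical stand-in that works on plain floats.

```python
import math

def decay_weights(weights, weight_decay, ord=2):
    """Hypothetical direct weight decay on a list of floats (assumed rule, see above)."""
    out = []
    for w in weights:
        if w == 0:
            out.append(0.0)  # use the subgradient 0 at the origin
            continue
        grad = math.copysign(abs(w) ** (ord - 1), w)   # d/dw of |w|**ord / ord
        out.append(w - weight_decay * grad)
    return out
```

Unlike an additive penalty on the gradient, this modifies the parameters directly, so it is unaffected by whatever preconditioning the other modules apply.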

Div

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Divides tensors by other. other can be a number or a module.

If other is a module, this calculates tensors / other(tensors)

Source code in torchzero/modules/ops/binary.py
class Div(BinaryOperationBase):
    """Divide tensors by :code:`other`. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`tensors / other(tensors)`
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        torch._foreach_div_(update, other)
        return update

DivByLoss

Bases: torchzero.core.module.Module

Divides the update by loss times alpha, with the divisor clamped from below by min_value.

Source code in torchzero/modules/misc/misc.py
class DivByLoss(Module):
    """Divides update by loss times :code:`alpha`"""
    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
        loss = var.get_loss(backward=self.defaults['backward'])
        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
        torch._foreach_div_(var.update, mul)
        return var
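DivByLoss.step divides every update entry by max(loss * alpha, min_value). With alpha=1 and a positive loss, dividing the gradient by the loss gives the gradient of log(loss), since d log f = df / f. A scalar sketch of the clamped division; div_by_loss is a hypothetical helper.

```python
def div_by_loss(update, loss, alpha=1.0, min_value=1e-8):
    """Divide each entry by max(loss * alpha, min_value), as DivByLoss.step does."""
    denom = max(loss * alpha, min_value)   # clamp keeps the divisor away from zero and negatives
    return [u / denom for u in update]
```

The min_value clamp means a zero or negative loss produces a large but finite update instead of a division by zero or a sign flip.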

DivModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates input / other. input and other can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class DivModules(MultiOperationBase):
    """Calculates :code:`input / other`. :code:`input` and :code:`other` can be numbers or modules."""
    def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
        defaults = {}
        if other_first: super().__init__(defaults, other=other, input=input)
        else: super().__init__(defaults, input=input, other=other)

    @torch.no_grad
    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
        if isinstance(input, (int,float)):
            assert isinstance(other, list)
            return input / TensorList(other)

        torch._foreach_div_(input, other)
        return input

Dogleg

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Dogleg trust region algorithm.

Parameters:

  • hess_module (Chainable) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • eta (float, default: 0.0 ) –

    if the ratio of actual to predicted reduction is larger than this, the step is accepted. When hess_module is GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 2 ) –

    increase factor on successful steps. Defaults to 2.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.75 ) –

    if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by nplus. Defaults to 0.75.

  • rho_bad (float, default: 0.25 ) –

    if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by nminus. Defaults to 0.25.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.

Source code in torchzero/modules/trust_region/dogleg.py
class Dogleg(TrustRegionBase):
    """Dogleg trust region algorithm.


    Args:
        hess_module (Chainable):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 2.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by `nplus`.
            Defaults to 0.75.
        rho_bad (float, optional):
            if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by `nminus`.
            Defaults to 0.25.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 2,
        nminus: float = 0.25,
        rho_good: float = 0.75,
        rho_bad: float = 0.25,
        boundary_tol: float | None = None,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict()
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        if radius > 2: radius = self.global_state['radius'] = 2
        eps = torch.finfo(g.dtype).tiny * 2

        gHg = g.dot(H.matvec(g))
        if gHg <= eps:
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        p_cauchy = (g.dot(g) / gHg) * g
        p_newton = H.solve(g)

        a = p_newton - p_cauchy
        b = p_cauchy

        aa = a.dot(a)
        if aa < eps:
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        ab = a.dot(b)
        bb = b.dot(b)
        c = bb - radius**2
        discriminant = (2*ab)**2 - 4*aa*c
        beta = (-2*ab + torch.sqrt(discriminant.clip(min=0))) / (2 * aa)
        return p_cauchy + beta * (p_newton - p_cauchy)
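The trust_solve above appears to work with update vectors that are subtracted from the parameters (its p_cauchy points along +g). For comparison, here is the textbook dogleg step in the usual descent-sign convention with its three cases written out: the full Newton step when it fits inside the region, the steepest descent step truncated at the boundary when the Cauchy point lies outside, and otherwise the point where the segment from the Cauchy point to the Newton point crosses the boundary. It is a 2-D plain-Python sketch (Cramer's rule stands in for a linear solver) that assumes B is positive definite; dogleg_step and solve2 are hypothetical helpers, not torchzero code.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def matvec(A, x):
    return [dot(row, x) for row in A]

def solve2(A, b):
    """Solve a 2x2 linear system by Cramer's rule (keeps the sketch dependency-free)."""
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    return [(b[0] * A[1][1] - A[0][1] * b[1]) / det,
            (A[0][0] * b[1] - b[0] * A[1][0]) / det]

def dogleg_step(g, B, radius):
    """Dogleg step for the model g^T p + 0.5 p^T B p subject to ||p|| <= radius."""
    p_newton = [-v for v in solve2(B, g)]
    if math.sqrt(dot(p_newton, p_newton)) <= radius:
        return p_newton                              # full Newton step fits
    g_norm = math.sqrt(dot(g, g))
    tau_c = dot(g, g) / dot(g, matvec(B, g))
    if tau_c * g_norm >= radius:
        return [-radius / g_norm * v for v in g]     # Cauchy point outside: truncate along -g
    p_cauchy = [-tau_c * v for v in g]               # unconstrained minimizer along -g
    # Solve ||p_cauchy + t a|| = radius for t in [0, 1], with a = p_newton - p_cauchy.
    a = [pn - pc for pn, pc in zip(p_newton, p_cauchy)]
    aa, ab = dot(a, a), dot(a, p_cauchy)
    c = dot(p_cauchy, p_cauchy) - radius ** 2        # negative here, so the root is real and positive
    t = (-ab + math.sqrt(ab * ab - aa * c)) / aa
    return [pc + t * ai for pc, ai in zip(p_cauchy, a)]
```

With g = [2, 1] and B = diag(2, 1), a radius of 2 returns the Newton step [-1, -1], a radius of 0.1 returns a steepest descent step of length 0.1, and a radius of 1.3 lands on the dogleg segment at distance 1.3.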

Dropout

Bases: torchzero.core.transform.Transform

Applies dropout to the update.

For each weight, the update to that weight is set to 0 with probability p. Depending on where this module is placed, this implements gradient dropout or update dropout.

Parameters:

  • p (float, default: 0.5 ) –

    probability that update for a weight is replaced with 0. Defaults to 0.5.

  • graft (bool, default: False ) –

    if True, update after dropout is rescaled to have the same norm as before dropout. Defaults to False.

  • target (Literal, default: 'update' ) –

    what to set on var, refer to documentation. Defaults to 'update'.

Examples:

Gradient dropout.

.. code-block:: python

opt = tz.Modular(
    model.parameters(),
    tz.m.Dropout(0.5),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)

Update dropout.

.. code-block:: python

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.Dropout(0.5),
    tz.m.LR(1e-3)
)
Source code in torchzero/modules/misc/regularization.py
class Dropout(Transform):
    """Applies dropout to the update.

    For each weight, the update to that weight is set to 0 with probability :code:`p`.
    This can be used to implement gradient dropout or update dropout depending on placement.

    Args:
        p (float, optional): probability that update for a weight is replaced with 0. Defaults to 0.5.
        graft (bool, optional):
            if True, update after dropout is rescaled to have the same norm as before dropout. Defaults to False.
        target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.


    Examples:
        Gradient dropout.

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Dropout(0.5),
                tz.m.Adam(),
                tz.m.LR(1e-3)
            )

        Update dropout.

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Adam(),
                tz.m.Dropout(0.5),
                tz.m.LR(1e-3)
            )

    """
    def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
        defaults = dict(p=p, graft=graft)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        p = NumberList(s['p'] for s in settings)
        graft = settings[0]['graft']

        if graft:
            target_norm = tensors.global_vector_norm()
            tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
            return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft

        return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
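The source maps a Rademacher draw to {0, 1} with (x + 1) / 2, which, assuming rademacher_like(1-p) yields +1 with probability 1 - p, keeps each entry with probability 1 - p as the docstring describes. The sketch below draws the same kind of mask with the standard random module and shows the graft option rescaling the result back to the original L2 norm; dropout_update is hypothetical, and the guard against an all-zero result is an addition of this sketch.

```python
import random

def dropout_update(update, p=0.5, graft=False, rng=None):
    """Zero each entry with probability p; optionally rescale back to the original L2 norm."""
    if rng is None:
        rng = random.Random()
    out = [0.0 if rng.random() < p else u for u in update]
    if graft:
        before = sum(u * u for u in update) ** 0.5
        after = sum(u * u for u in out) ** 0.5
        if after > 0:  # guard added in this sketch: every entry may have been dropped
            out = [u * before / after for u in out]
    return out
```

As in the examples above, placing the module before Adam drops gradient entries, while placing it after Adam drops entries of the final update.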

DualNormCorrection

Bases: torchzero.core.transform.TensorwiseTransform

Dual norm correction for dualizer-based optimizers (https://github.com/leloykun/adaptive-muon). Orthogonalize already has this built in with the dual_norm_correction setting.

Source code in torchzero/modules/adaptive/muon.py
class DualNormCorrection(TensorwiseTransform):
    """Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).
    Orthogonalize already has this built in with the `dual_norm_correction` setting."""
    def __init__(self, target: Target='update'):
        super().__init__({}, uses_grad=True, target=target)

    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        assert grad is not None
        if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
            return _dual_norm_correction(tensor, grad, batch_first=False)
        return tensor

EMA

Bases: torchzero.core.transform.Transform

Maintains an exponential moving average of the update.

Parameters:

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • dampening (float, default: 0 ) –

    momentum dampening. Defaults to 0.

  • debiased (bool, default: False ) –

    whether to debias the EMA like in Adam. Defaults to False.

  • lerp (bool, default: True ) –

    whether to use linear interpolation. Defaults to True.

  • ema_init (str, default: 'zeros' ) –

    initial values for the EMA, "zeros" or "update".

  • target (Literal, default: 'update' ) –

    target to apply EMA to. Defaults to 'update'.

Source code in torchzero/modules/momentum/momentum.py
class EMA(Transform):
    """Maintains an exponential moving average of update.

    Args:
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        dampening (float, optional): momentum dampening. Defaults to 0.
        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
        lerp (bool, optional): whether to use linear interpolation. Defaults to True.
        ema_init (str, optional): initial values for the EMA, "zeros" or "update". Defaults to "zeros".
        target (Target, optional): target to apply EMA to. Defaults to 'update'.
    """
    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
        defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])

        exp_avg = unpack_states(states, tensors, 'exp_avg',
                                init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
        momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)

        exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)

        if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
        else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
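Here is a scalar sketch of the EMA with its Adam-style debiasing. With zero initialization the average after t steps is biased towards zero by a factor of 1 - momentum^t, and debiasing divides that factor out; with ema_init="update" the average starts from the first update. Only the lerp form momentum * avg + (1 - momentum) * x is modelled, because how dampening and lerp=False combine inside torchzero's ema_ is not shown here. ema_outputs is hypothetical.

```python
def ema_outputs(updates, momentum=0.9, debiased=False, ema_init="zeros"):
    """Return the EMA output after each update (scalar sketch, lerp form only)."""
    exp_avg = None if ema_init == "update" else 0.0
    outputs = []
    for step, x in enumerate(updates, start=1):
        if exp_avg is None:
            exp_avg = x  # "update" init: the average starts at the first update
        exp_avg = momentum * exp_avg + (1 - momentum) * x  # linear interpolation toward x
        outputs.append(exp_avg / (1 - momentum ** step) if debiased else exp_avg)
    return outputs
```

Feeding a constant update of 2.0 shows the effect: without debiasing the first output is only 0.2, while the debiased outputs equal 2.0 at every step.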

EMASquared

Bases: torchzero.core.transform.Transform

Maintains an exponential moving average of squared updates.

Parameters:

  • beta (float, default: 0.999 ) –

    momentum value. Defaults to 0.999.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Methods:

  • EMA_SQ_FN

    Updates exp_avg_sq_ with EMA of squared tensors, if max_exp_avg_sq_ is not None, updates it with maximum of EMA.

Source code in torchzero/modules/ops/higher_level.py
class EMASquared(Transform):
    """Maintains an exponential moving average of squared updates.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    EMA_SQ_FN: staticmethod = staticmethod(ema_sq_)

    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()

EMA_SQ_FN

EMA_SQ_FN(tensors: TensorList, exp_avg_sq_: TensorList, beta: float | NumberList, max_exp_avg_sq_: TensorList | None, pow: float = 2)

Updates exp_avg_sq_ with EMA of squared tensors, if max_exp_avg_sq_ is not None, updates it with maximum of EMA.

Returns exp_avg_sq_ or max_exp_avg_sq_.

Source code in torchzero/modules/functional.py
def ema_sq_(
    tensors: TensorList,
    exp_avg_sq_: TensorList,
    beta: float | NumberList,
    max_exp_avg_sq_: TensorList | None,
    pow: float = 2,
):
    """
    Updates `exp_avg_sq_` with EMA of squared `tensors`, if `max_exp_avg_sq_` is not None, updates it with maximum of EMA.

    Returns `exp_avg_sq_` or `max_exp_avg_sq_`.
    """
    lerp_power_(tensors=tensors, exp_avg_pow_=exp_avg_sq_,beta=beta,pow=pow)

    # AMSGrad
    if max_exp_avg_sq_ is not None:
        max_exp_avg_sq_.maximum_(exp_avg_sq_)
        exp_avg_sq_ = max_exp_avg_sq_

    return exp_avg_sq_
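ema_sq_ keeps an EMA of |x|^pow and, for AMSGrad, a running maximum that is returned instead of the EMA. A scalar sketch of one step follows, assuming lerp_power_ computes beta * v + (1 - beta) * |x|^pow; ema_sq_step is hypothetical and returns the new EMA, the new maximum (or None) and the value the module would output.

```python
def ema_sq_step(x, exp_avg_sq, beta=0.999, max_exp_avg_sq=None, pow=2.0):
    """One scalar step of an EMA of |x|**pow with an optional AMSGrad maximum."""
    exp_avg_sq = beta * exp_avg_sq + (1 - beta) * abs(x) ** pow
    if max_exp_avg_sq is not None:
        max_exp_avg_sq = max(max_exp_avg_sq, exp_avg_sq)   # never decreases
        return exp_avg_sq, max_exp_avg_sq, max_exp_avg_sq
    return exp_avg_sq, None, exp_avg_sq
```

After a large update followed by zeros, the EMA decays (4.5, 2.25, 1.125 with beta=0.5) while the AMSGrad output stays at 4.5, so the effective step size cannot grow again.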

ESGD

Bases: torchzero.core.module.Module

Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

This is similar to Adagrad, but it accumulates squared randomized hessian diagonal estimates instead of squared gradients.

.. note:: In most cases ESGD should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply ESGD preconditioning to another module's output.

.. note:: If you are using gradient estimators or reformulations, set hvp_method to "forward" or "central".

.. note:: This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

Parameters:

  • damping (float, default: 0.0001 ) –

    added to denominator for stability. Defaults to 1e-4.

  • update_freq (int, default: 20 ) –

    frequency of updating hessian diagonal estimate via a hessian-vector product. This value can be increased to reduce computational cost. Defaults to 20.

  • hvp_method (str, default: 'autograd' ) –

    Determines how Hessian-vector products are evaluated.

    • "autograd": Use PyTorch's autograd to calculate exact HVPs. This requires creating a graph for the gradient.
    • "forward": Use a forward finite difference formula to approximate the HVP. This requires one extra gradient evaluation.
    • "central": Use a central finite difference formula for a more accurate HVP approximation. This requires two extra gradient evaluations.

    Defaults to "autograd".

  • fd_h (float, default: 0.001 ) –

    finite difference step size if hvp_method is "forward" or "central". Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of Hessian-vector products with random vectors to evaluate each time the preconditioner is updated. Larger values may lead to a better Hessian diagonal estimate. Defaults to 1.

  • seed (int | None, default: None ) –

    seed for random vectors. Defaults to None.

  • inner (Chainable | None, default: None ) –

    Inner module. If this is specified, operations are performed in the following order:

    1. compute the Hessian diagonal estimate.
    2. pass inputs to inner.
    3. apply preconditioning to the outputs of inner.
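The three hvp_method options can be compared on a toy function with the sketch below. It assumes the finite-difference variants perturb the parameters along v and difference the gradients, which is the standard construction; torchzero's own helpers (including their normalize option) are not shown here, and all function names are hypothetical. The forward formula reuses the gradient at x and needs one extra gradient, the central formula needs two, and the central one has O(h^2) rather than O(h) truncation error.

```python
import torch


def grad_at(f, x):
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(f(x), x)
    return g


def hvp_forward(f, x, v, h=1e-3):
    # (g(x + h v) - g(x)) / h: one extra gradient evaluation if g(x) is already known
    return (grad_at(f, x + h * v) - grad_at(f, x)) / h


def hvp_central(f, x, v, h=1e-3):
    # (g(x + h v) - g(x - h v)) / (2 h): two extra gradient evaluations, O(h^2) error
    return (grad_at(f, x + h * v) - grad_at(f, x - h * v)) / (2 * h)


def hvp_exact(f, x, v):
    # "autograd": differentiate the gradient graph, so no step size is involved
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(f(x), x, create_graph=True)
    (hv,) = torch.autograd.grad(g, x, grad_outputs=v)
    return hv
```

For f(x) = sum(x^3), the forward estimate is off by exactly 3h per coordinate when v is all ones, while the central estimate is exact up to rounding.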

Examples:

Using ESGD:

opt = tz.Modular(
    model.parameters(),
    tz.m.ESGD(),
    tz.m.LR(0.1)
)

The ESGD preconditioner can be applied to the output of any other module by passing that module to the inner argument. Here is an example of applying ESGD preconditioning to Nesterov momentum (tz.m.NAG):

opt = tz.Modular(
    model.parameters(),
    tz.m.ESGD(inner=tz.m.NAG(0.9)),
    tz.m.LR(0.1)
)
Source code in torchzero/modules/adaptive/esgd.py
class ESGD(Module):
    """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

    This is similar to Adagrad, but it accumulates squared randomized hessian diagonal estimates instead of squared gradients.

    .. note::
        In most cases ESGD should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply ESGD preconditioning to another module's output.

    .. note::
        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".

    .. note::
        This module requires a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients for calculating HVPs.
        The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        damping (float, optional): added to denominator for stability. Defaults to 1e-4.
        update_freq (int, optional):
            frequency of updating hessian diagonal estimate via a hessian-vector product.
            This value can be increased to reduce computational cost. Defaults to 20.
        hvp_method (str, optional):
            Determines how Hessian-vector products are evaluated.

            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
              This requires creating a graph for the gradient.
            - ``"forward"``: Use a forward finite difference formula to
              approximate the HVP. This requires one extra gradient evaluation.
            - ``"central"``: Use a central finite difference formula for a
              more accurate HVP approximation. This requires two extra
              gradient evaluations.
            Defaults to "autograd".
        fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
        n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
        seed (int | None, optional): seed for random vectors. Defaults to None.
        inner (Chainable | None, optional):
            Inner module. If this is specified, operations are performed in the following order.
            1. compute hessian diagonal estimate.
            2. pass inputs to :code:`inner`.
            3. preconditioning is applied to the outputs of :code:`inner`.

    Examples:
        Using ESGD:

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.ESGD(),
                tz.m.LR(0.1)
            )

        ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
        ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.ESGD(inner=tz.m.NAG(0.9)),
                tz.m.LR(0.1)
            )

    """
    def __init__(
        self,
        damping: float = 1e-4,
        update_freq: int = 20,
        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
        fd_h: float = 1e-3,
        n_samples = 1,
        seed: int | None = None,
        inner: Chainable | None = None
    ):
        defaults = dict(damping=damping, update_freq=update_freq, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
        super().__init__(defaults)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def step(self, var):
        params = var.params
        settings = self.settings[params[0]]
        hvp_method = settings['hvp_method']
        fd_h = settings['fd_h']
        update_freq = settings['update_freq']
        n_samples = settings['n_samples']

        seed = settings['seed']
        generator = None
        if seed is not None:
            if 'generator' not in self.global_state:
                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
            generator = self.global_state['generator']

        damping = self.get_settings(params, 'damping', cls=NumberList)
        D_sq_acc = self.get_state(params, 'D_sq_acc', cls=TensorList)
        i = self.global_state.get('i', 0)

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        closure = var.closure
        assert closure is not None

        D = None
        if step % update_freq == 0:

            rgrad=None
            for j in range(n_samples):
                u = [torch.randn(p.size(), generator=generator, device=p.device, dtype=p.dtype) for p in params]

                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
                                     h=fd_h, normalize=True, retain_grad=j < n_samples-1)

                if D is None: D = Hvp
                else: torch._foreach_add_(D, Hvp)

            assert D is not None
            if n_samples > 1: torch._foreach_div_(D, n_samples)

            D = TensorList(D)

        update = var.get_update()
        if 'inner' in self.children:
            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)

        var.update, self.global_state['i'] = esgd_(
            tensors_=TensorList(update),
            D=TensorList(D) if D is not None else None,
            D_sq_acc_=D_sq_acc,
            damping=damping,
            update_freq=update_freq,
            step=step,
            i=i,
        )
        return var

EscapeAnnealing

Bases: torchzero.core.module.Module

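If the parameters stop changing (the largest parameter change stays at or below tol for n_tol consecutive steps), this runs a backward annealing random search: random perturbations of linearly increasing radius, up to max_region, are tried for up to max_iter evaluations until one of them decreases the loss.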
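A minimal sketch of the search performed once progress stalls, on a plain tensor. The radius grows linearly from max_region / max_iter up to max_region, each candidate is a random point on a sphere of that radius around the current parameters, and the first candidate that lowers the loss is accepted, mirroring the loop in step. TensorList.sphere_like is replaced by a hypothetical normalized-Gaussian helper, and both function names are hypothetical.

```python
import math

import torch


def random_on_sphere(x, radius, generator=None):
    # hypothetical stand-in for sphere_like: random direction scaled to a fixed radius
    d = torch.randn(x.shape, generator=generator, dtype=x.dtype)
    return d * (radius / d.norm())


def escape_search(f, x, max_region=1.0, max_iter=1000, generator=None):
    f_0 = f(x)
    for i in range(1, max_iter + 1):
        alpha = max_region * (i / max_iter)  # radius grows: "backward" annealing
        candidate = x + random_on_sphere(x, alpha, generator)
        f_star = f(candidate)
        if math.isfinite(f_star) and f_star < f_0 - 1e-12:
            return candidate  # first improving point is accepted
    return x  # no improvement found; parameters are left unchanged
```

Starting with small radii keeps the escape as local as possible, and only widens the search when nearby points offer no improvement.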

Source code in torchzero/modules/misc/escape.py
class EscapeAnnealing(Module):
    """If parameters stop changing, this runs a backward annealing random search"""
    def __init__(self, max_region:float = 1, max_iter:int = 1000, tol=1e-6, n_tol: int = 10):
        defaults = dict(max_region=max_region, max_iter=max_iter, tol=tol, n_tol=n_tol)
        super().__init__(defaults)


    @torch.no_grad
    def step(self, var):
        closure = var.closure
        if closure is None: raise RuntimeError("Escape requires closure")

        params = TensorList(var.params)
        settings = self.settings[params[0]]
        max_region = self.get_settings(params, 'max_region', cls=NumberList)
        max_iter = settings['max_iter']
        tol = settings['tol']
        n_tol = settings['n_tol']

        n_bad = self.global_state.get('n_bad', 0)

        prev_params = self.get_state(params, 'prev_params', cls=TensorList)
        diff = params-prev_params
        prev_params.copy_(params)

        if diff.abs().global_max() <= tol:
            n_bad += 1

        else:
            n_bad = 0

        self.global_state['n_bad'] = n_bad

        # no progress
        f_0 = var.get_loss(False)
        if n_bad >= n_tol:
            for i in range(1, max_iter+1):
                alpha = max_region * (i / max_iter)
                pert = params.sphere_like(radius=alpha)

                params.add_(pert)
                f_star = closure(False)

                if math.isfinite(f_star) and f_star < f_0-1e-12:
                    var.update = None
                    var.stop = True
                    var.skip_update = True
                    return var

                params.sub_(pert)

            self.global_state['n_bad'] = 0
        return var

Exp

Bases: torchzero.core.transform.Transform

Returns exp(input)

Source code in torchzero/modules/ops/unary.py
class Exp(Transform):
    """Returns :code:`exp(input)`"""
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        torch._foreach_exp_(tensors)
        return tensors

ExpHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class ExpHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return loss.exp()

FDM

Bases: torchzero.modules.grad_approximation.grad_approximator.GradApproximator

Approximate gradients via finite difference method.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.001 ) –

    magnitude of parameter perturbation. Defaults to 1e-3.

  • formula (Literal, default: 'central' ) –

    finite difference formula. Defaults to 'central'.

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to 'closure'.
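A minimal sketch of the coordinate-wise loop in approximate, on a single tensor. It assumes the 'central' formula is the standard two-sided difference (f(x + h e_i) - f(x - h e_i)) / (2h); the real formulas are looked up in _FD_FUNCS, which is not shown here, and fd_central_grad is a hypothetical name. Each gradient costs 2 * ndim loss evaluations, which is why this approach is usually reserved for small problems.

```python
import torch


def fd_central_grad(f, x, h=1e-3):
    """Central finite-difference gradient, one coordinate at a time."""
    x = x.detach().clone(memory_format=torch.contiguous_format)
    g = torch.zeros_like(x)
    x_flat, g_flat = x.view(-1), g.view(-1)
    for i in range(x_flat.numel()):
        orig = x_flat[i].item()
        x_flat[i] = orig + h
        f_plus = f(x)
        x_flat[i] = orig - h
        f_minus = f(x)
        x_flat[i] = orig  # restore the coordinate before moving on
        g_flat[i] = (f_plus - f_minus) / (2 * h)
    return g
```

For a quadratic the central difference is exact up to rounding, so the result can be checked against the analytic gradient.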

Examples:

Plain FDM:

fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))

Any gradient-based method can use FDM-estimated gradients.

fdm_ncg = tz.Modular(
    model.parameters(),
    tz.m.FDM(),
    # set hvp_method to "forward" so that it
    # uses gradient difference instead of autograd
    tz.m.NewtonCG(hvp_method="forward"),
    tz.m.Backtracking()
)

Source code in torchzero/modules/grad_approximation/fdm.py
class FDM(GradApproximator):
    """Approximate gradients via finite difference method.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): magnitude of parameter perturbation. Defaults to 1e-3.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central'.
        target (GradTarget, optional): what to set on var. Defaults to 'closure'.

    Examples:
    plain FDM:

    ```python
    fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
    ```

    Any gradient-based method can use FDM-estimated gradients.
    ```python
    fdm_ncg = tz.Modular(
        model.parameters(),
        tz.m.FDM(),
        # set hvp_method to "forward" so that it
        # uses gradient difference instead of autograd
        tz.m.NewtonCG(hvp_method="forward"),
        tz.m.Backtracking()
    )
    ```
    """
    def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central', target: GradTarget = 'closure'):
        defaults = dict(h=h, formula=formula)
        super().__init__(defaults, target=target)

    @torch.no_grad
    def approximate(self, closure, params, loss):
        grads = []
        loss_approx = None

        for p in params:
            g = torch.zeros_like(p)
            grads.append(g)

            settings = self.settings[p]
            h = settings['h']
            fd_fn = _FD_FUNCS[settings['formula']]

            p_flat = p.ravel(); g_flat = g.ravel()
            for i in range(len(p_flat)):
                loss, loss_approx, d = fd_fn(closure=closure, param=p_flat, idx=i, h=h, v_0=loss)
                g_flat[i] = d

        return grads, loss, loss_approx

Fill

Bases: torchzero.core.module.Module

Outputs tensors filled with value

Source code in torchzero/modules/ops/utility.py
class Fill(Module):
    """Outputs tensors filled with :code:`value`"""
    def __init__(self, value: float):
        defaults = dict(value=value)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        var.update = [torch.full_like(p, self.settings[p]['value']) for p in var.params]
        return var

FillLoss

Bases: torchzero.core.module.Module

Outputs tensors filled with the loss value times alpha

Source code in torchzero/modules/misc/misc.py
class FillLoss(Module):
    """Outputs tensors filled with loss value times :code:`alpha`"""
    def __init__(self, alpha: float = 1, backward: bool = True):
        defaults = dict(alpha=alpha, backward=backward)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        alpha = self.get_settings(var.params, 'alpha')
        loss = var.get_loss(backward=self.defaults['backward'])
        var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
        return var

FletcherReeves

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Fletcher–Reeves nonlinear conjugate gradient method.

Note

This requires the step size to be determined by a line search, so place a line search such as tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this module.
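For reference, the Fletcher-Reeves coefficient is beta_k = (g_k . g_k) / (g_{k-1} . g_{k-1}), presumably what fletcher_reeves_beta computes from the stored g.dot(g) values, and the search direction is d_k = -g_k + beta_k d_{k-1}. The sketch below runs it on a quadratic, where an exact line search has a closed form; that exact step stands in for the recommended StrongWolfe search and is an assumption of the example, not how the module chooses step sizes. The function name is hypothetical.

```python
import torch


def fletcher_reeves_quadratic(A, b, x, n_iter):
    """Nonlinear CG with the Fletcher-Reeves beta on f(x) = 0.5 x^T A x - b^T x."""
    g = A @ x - b
    d = -g
    for _ in range(n_iter):
        alpha = -(g @ d) / (d @ A @ d)  # exact line search along d for a quadratic
        x = x + alpha * d
        g_new = A @ x - b
        beta = (g_new @ g_new) / (g @ g)  # Fletcher-Reeves coefficient
        d = -g_new + beta * d
        g = g_new
    return x
```

With exact line searches on an n-dimensional quadratic this reduces to linear conjugate gradient, so it reaches the minimizer in at most n iterations (up to rounding).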

Source code in torchzero/modules/conjugate_gradient/cg.py
class FletcherReeves(ConguateGradientBase):
    """Fletcher–Reeves nonlinear conjugate gradient method.

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def initialize(self, p, g):
        self.global_state['prev_gg'] = g.dot(g)

    def get_beta(self, p, g, prev_g, prev_d):
        gg = g.dot(g)
        beta = fletcher_reeves_beta(gg, self.global_state['prev_gg'])
        self.global_state['prev_gg'] = gg
        return beta

FletcherVMM

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Fletcher's variable metric Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory, where N is the number of parameters.

Reference

Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class FletcherVMM(_InverseHessianUpdateStrategyDefaults):
    """
    Fletcher's variable metric Quasi-Newton method.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return fletcher_vmm_H_(H=H, s=s, y=y, tol=setting['tol'])

ForwardGradient

Bases: torchzero.modules.grad_approximation.rfdm.RandomizedFDM

Forward gradient method.

This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.
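A single-sample forward-gradient estimate can be sketched with torch.func.jvp: sample v from a standard normal, compute the directional derivative grad f(x) . v in one forward-mode pass, and return (grad f(x) . v) v. Because E[v v^T] = I for standard normal v, the estimate is unbiased, and averaging n_samples estimates reduces its variance. This corresponds to jvp_method='autograd' with distribution='gaussian'; the function name is hypothetical.

```python
import torch
from torch.func import jvp


def forward_gradient(f, x, n_samples=1, generator=None):
    grad = torch.zeros_like(x)
    for _ in range(n_samples):
        v = torch.randn(x.shape, generator=generator, dtype=x.dtype)
        _, dir_deriv = jvp(f, (x,), (v,))  # forward mode: grad f(x) . v
        grad += dir_deriv * v
    return grad / n_samples
```

No backward pass is needed, which is the point of the method; the price is a noisy estimate whose variance grows with the number of parameters.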

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • n_samples (int, default: 1 ) –

    number of random gradient samples. Defaults to 1.

  • distribution (Literal, default: 'gaussian' ) –

    distribution for random gradient samples. Defaults to "gaussian".

  • beta (float, default: 0 ) –

    If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.

  • pre_generate (bool, default: True ) –

    whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.

  • jvp_method (str, default: 'autograd' ) –

    how to calculate the Jacobian-vector product. With "forward" or "central", this is equivalent to randomized finite differences. Defaults to 'autograd'.

  • h (float, default: 0.001 ) –

    finite difference step size if jvp_method is set to forward or central. Defaults to 1e-3.

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to "closure".

References

Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022). Gradients without backpropagation. arXiv preprint arXiv:2202.08587.

Source code in torchzero/modules/grad_approximation/forward_gradient.py
class ForwardGradient(RandomizedFDM):
    """Forward gradient method.

    This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.


    Args:
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
        beta (float, optional):
            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        jvp_method (str, optional):
            how to calculate the Jacobian-vector product. With `forward` or `central`, this is equivalent to randomized finite differences. Defaults to 'autograd'.
        h (float, optional): finite difference step size if jvp_method is set to `forward` or `central`. Defaults to 1e-3.
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    References:
        Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022). Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
    """
    PRE_MULTIPLY_BY_H = False
    def __init__(
        self,
        n_samples: int = 1,
        distribution: Distributions = "gaussian",
        beta: float = 0,
        pre_generate = True,
        jvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
        h: float = 1e-3,
        target: GradTarget = "closure",
        seed: int | None | torch.Generator = None,
    ):
        super().__init__(h=h, n_samples=n_samples, distribution=distribution, beta=beta, target=target, pre_generate=pre_generate, seed=seed)
        self.defaults['jvp_method'] = jvp_method

    @torch.no_grad
    def approximate(self, closure, params, loss):
        params = TensorList(params)
        loss_approx = None

        settings = self.settings[params[0]]
        n_samples = settings['n_samples']
        jvp_method = settings['jvp_method']
        h = settings['h']
        distribution = settings['distribution']
        default = [None]*n_samples
        perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
        generator = self._get_generator(settings['seed'], params)

        grad = None
        for i in range(n_samples):
            prt = perturbations[i]
            if prt[0] is None:
                prt = params.sample_like(distribution=distribution, variance=1, generator=generator)

            else: prt = TensorList(prt)

            if jvp_method == 'autograd':
                with torch.enable_grad():
                    loss, d = jvp(partial(closure, False), params=params, tangent=prt)

            elif jvp_method == 'forward':
                loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss, normalize=True, h=h)

            elif jvp_method == 'central':
                loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt, normalize=True, h=h)

            else: raise ValueError(jvp_method)

            if grad is None: grad = prt * d
            else: grad += prt * d

        assert grad is not None
        if n_samples > 1: grad.div_(n_samples)
        return grad, loss, loss_approx

PRE_MULTIPLY_BY_H class-attribute

PRE_MULTIPLY_BY_H = False

FullMatrixAdagrad

Bases: torchzero.core.transform.TensorwiseTransform

Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

Note

A more memory-efficient version, equivalent to full-matrix Adagrad on the last n gradients, is implemented in tz.m.LMAdagrad.

Parameters:

  • beta (float | None, default: None ) –

    momentum for gradient outer product accumulators. If None, uses a sum. Defaults to None.

  • decay (float | None, default: None ) –

    decay for gradient outer product accumulators. Defaults to None.

  • sqrt (bool, default: True ) –

    whether to take the square root of the accumulator. Defaults to True.

  • concat_params (bool, default: True ) –

    if False, each parameter will have its own accumulator. Defaults to True.

  • precond_freq (int, default: 1 ) –

    frequency of updating the inverse square root of the accumulator. Defaults to 1.

  • init (Literal[str], default: 'identity' ) –

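    how to initialize the accumulator:

    • "identity" - with identity matrix (default).
    • "zeros" - with zero matrix.
    • "ones" - with matrix of ones.
    • "GGT" - with the first outer product.
  • reg (float, default: 1e-12 ) –

    regularization added to the diagonal of the accumulator before it is inverted. Defaults to 1e-12.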

  • divide (bool, default: False ) –

    whether to divide the accumulator by number of gradients in it. Defaults to False.

  • inner (Chainable | None, default: None ) –

    inner modules to apply preconditioning to. Defaults to None.
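With the default settings (beta=None, sqrt=True, init="identity"), update_tensor and apply_tensor amount to accumulating GG = I + sum g g^T and multiplying the flattened update by GG^(-1/2). The sketch below reproduces that with torch.linalg.eigh, standing in for torchzero's matrix_power_eigh helper (assumed to compute a matrix power through an eigendecomposition); it ignores reg, decay, divide, precond_freq, and concat_params, and the class name is hypothetical.

```python
import torch


def inv_sqrt_psd(M):
    # M^(-1/2) for a symmetric positive-definite matrix via eigendecomposition
    eigvals, eigvecs = torch.linalg.eigh(M)
    return eigvecs @ torch.diag(eigvals.rsqrt()) @ eigvecs.T


class FullMatrixAdagradSketch:
    """Hypothetical minimal full-matrix Adagrad on a single flattened tensor."""

    def __init__(self, n):
        self.GG = torch.eye(n, dtype=torch.float64)  # init="identity"

    def step(self, g):
        g = g.reshape(-1).to(torch.float64)
        self.GG += torch.outer(g, g)  # beta=None: plain sum of outer products
        return inv_sqrt_psd(self.GG) @ g
```

For a first gradient g = (3, 4), GG has eigenvalue 1 + |g|^2 = 26 along g, so the preconditioned update is g / sqrt(26). The N x N accumulator is why this variant is only practical for small parameter counts.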

Examples:

Plain full-matrix adagrad

opt = tz.Modular(
    model.parameters(),
    tz.m.FullMatrixAdagrad(),
    tz.m.LR(1e-2),
)

Full-matrix RMSprop

opt = tz.Modular(
    model.parameters(),
    tz.m.FullMatrixAdagrad(beta=0.99),
    tz.m.LR(1e-2),
)

Full-matrix Adam

opt = tz.Modular(
    model.parameters(),
    tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
    tz.m.Debias(0.9, 0.999),
    tz.m.LR(1e-2),
)

Source code in torchzero/modules/adaptive/adagrad.py
class FullMatrixAdagrad(TensorwiseTransform):
    """Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

    Note:
        A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.LMAdagrad``.

    Args:
        beta (float | None, optional): momentum for gradient outer product accumulators. if None, uses sum. Defaults to None.
        decay (float | None, optional): decay for gradient outer product accumulators. Defaults to None.
        sqrt (bool, optional): whether to take the square root of the accumulator. Defaults to True.
        concat_params (bool, optional): if False, each parameter will have its own accumulator. Defaults to True.
        precond_freq (int, optional): frequency of updating the inverse square root of the accumulator. Defaults to 1.
        init (Literal[str], optional):
            how to initialize the accumulator.
            - "identity" - with identity matrix (default).
            - "zeros" - with zero matrix.
            - "ones" - with matrix of ones.
            - "GGT" - with the first outer product.
        divide (bool, optional): whether to divide the accumulator by number of gradients in it. Defaults to False.
        inner (Chainable | None, optional): inner modules to apply preconditioning to. Defaults to None.

    ## Examples:

    Plain full-matrix adagrad
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.FullMatrixAdagrad(),
        tz.m.LR(1e-2),
    )
    ```

    Full-matrix RMSprop
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.FullMatrixAdagrad(beta=0.99),
        tz.m.LR(1e-2),
    )
    ```

    Full-matrix Adam
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
        tz.m.Debias(0.9, 0.999),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(
        self,
        beta: float | None = None,
        decay: float | None = None,
        sqrt: bool = True,
        concat_params=True,
        precond_freq: int = 1,
        init: Literal["identity", "zeros", "ones", "GGT"] = "identity",
        reg: float = 1e-12,
        divide: bool = False,
        inner: Chainable | None = None,
    ):
        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, precond_freq=precond_freq, init=init, divide=divide, reg=reg)
        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner,)

    @torch.no_grad
    def update_tensor(self, tensor, param, grad, loss, state, setting):
        G = tensor.ravel()
        GG = torch.outer(G, G)
        decay = setting['decay']
        beta = setting['beta']
        init = setting['init']

        if 'GG' not in state:
            if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
            elif init == 'zeros': state['GG'] =  torch.zeros_like(GG)
            elif init == 'ones': state['GG'] = torch.ones_like(GG)
            elif init == 'GGT': state['GG'] = GG.clone()
            else: raise ValueError(init)
        if decay is not None: state['GG'].mul_(decay)

        if beta is not None: state['GG'].lerp_(GG, 1-beta)
        else: state['GG'].add_(GG)
        state['i'] = state.get('i', 0) + 1 # number of GGTs in sum

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        step = state.get('step', 0)
        state['step'] = step + 1

        GG: torch.Tensor = state['GG']
        sqrt = setting['sqrt']
        divide = setting['divide']
        precond_freq = setting['precond_freq']
        reg = setting['reg']

        if divide: GG = GG/state.get('i', 1)

        if reg != 0:
            GG = GG + torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype).mul_(reg)

        if tensor.numel() == 1:
            GG = GG.squeeze()
            if sqrt: return tensor / GG.sqrt()
            return tensor / GG

        try:
            if sqrt:
                if "B" not in state or step % precond_freq == 0:
                    B = state["B"] = matrix_power_eigh(GG, -1/2)
                else:
                    B = state["B"]

            else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable

        except torch.linalg.LinAlgError:
            # fallback to diagonal AdaGrad
            denom = GG.diagonal()
            if sqrt: denom = denom.sqrt()
            return tensor.div_(denom + max(reg, 1e-12))

        return (B @ tensor.ravel()).view_as(tensor)

GaussNewton

Bases: torchzero.core.module.Module

Gauss-Newton method.

To use this, the closure should return a vector of residuals whose sum of squares is minimized. The closure must still accept a backward argument; it will always be False, but it is required. Gradients are calculated via batched autograd within this module, so you don't need to implement the backward pass. Please see below for an example.

Note

This method requires ndim^2 memory. However, if it is used within the tz.m.TrustCG trust region, the memory requirement is ndim*m, where m is the number of values in the output.

Parameters:

  • reg (float, default: 1e-08 ) –

    regularization parameter. Defaults to 1e-8.

  • batched (bool, default: True ) –

    whether to use vmapping. Defaults to True.
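The core Gauss-Newton step can be sketched without torchzero: with residuals r(x) and Jacobian J, the direction solves (J^T J + reg I) d = J^T r, and the parameters move to x - d. In the module, J^T r is the Gtf stored in update and J^T J is the GtG formed in apply; the solve and the sign convention here are assumptions of the sketch, since the rest of apply is not shown. The residuals use 10 * (x2 - x1^2) so that their sum of squares is the standard Rosenbrock function; function names are hypothetical.

```python
import torch
from torch.autograd.functional import jacobian


def residuals(x):
    # Rosenbrock as least squares: (1 - x1)^2 + 100 (x2 - x1^2)^2
    return torch.stack([1 - x[0], 10 * (x[1] - x[0] ** 2)])


def gauss_newton(res_fn, x, n_iter=5, reg=1e-8):
    for _ in range(n_iter):
        r = res_fn(x)
        J = jacobian(res_fn, x)  # (n_out, ndim)
        GtG = J.T @ J + reg * torch.eye(x.numel(), dtype=x.dtype)
        d = torch.linalg.solve(GtG, J.T @ r)
        # full step; the examples in this section add a line search or trust region
        x = x - d
    return x
```

Here J is square and invertible, so the step coincides with Newton's method on r(x) = 0, which reaches (1, 1) in two steps from any starting point.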

Examples:

Minimizing the Rosenbrock function:

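import torch
import torchzero as tz

def rosenbrock(X):
    x1, x2 = X
    # residuals whose sum of squares is (1 - x1)^2 + 100 * (x2 - x1^2)^2
    return torch.stack([(1 - x1), 10 * (x2 - x1**2)])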

X = torch.tensor([-1.1, 2.5], requires_grad=True)
opt = tz.Modular([X], tz.m.GaussNewton(), tz.m.Backtracking())

# define the closure for line search
def closure(backward=True):
    return rosenbrock(X)

# minimize
for i in range(10):
    loss = opt.step(closure)
    print(f'{loss = }')

Training a neural network with a matrix-free GN trust region:

import torch
import torch.nn as nn
import torchzero as tz

X = torch.randn(64, 20)
y = torch.randn(64, 10)

model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
opt = tz.Modular(
    model.parameters(),
    tz.m.TrustCG(tz.m.GaussNewton()),
)

def closure(backward=True):
    y_hat = model(X) # (64, 10)
    return (y_hat - y).pow(2).mean(0) # (10, )

for i in range(100):
    losses = opt.step(closure)
    if i % 10 == 0:
        print(f'{losses.mean() = }')

Source code in torchzero/modules/least_squares/gn.py
class GaussNewton(Module):
    """Gauss-Newton method.

    To use this, the closure should return a vector of residuals whose sum of squares is minimized.
    The closure must still accept a ``backward`` argument; it will always be False, but it is required.
    Gradients will be calculated via batched autograd within this module, you don't need to
    implement the backward pass. Please see below for an example.

    Note:
        This method requires ``ndim^2`` memory. However, if it is used within the ``tz.m.TrustCG`` trust region,
        the memory requirement is ``ndim*m``, where ``m`` is the number of values in the output.

    Args:
        reg (float, optional): regularization parameter. Defaults to 1e-8.
        batched (bool, optional): whether to use vmapping. Defaults to True.

    Examples:

    minimizing the rosenbrock function:
    ```python
    def rosenbrock(X):
        x1, x2 = X
        return torch.stack([(1 - x1), 10 * (x2 - x1**2)])

    X = torch.tensor([-1.1, 2.5], requires_grad=True)
    opt = tz.Modular([X], tz.m.GaussNewton(), tz.m.Backtracking())

    # define the closure for line search
    def closure(backward=True):
        return rosenbrock(X)

    # minimize
    for iter in range(10):
        loss = opt.step(closure)
        print(f'{loss = }')
    ```

    training a neural network with a matrix-free GN trust region:
    ```python
    X = torch.randn(64, 20)
    y = torch.randn(64, 10)

    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
    opt = tz.Modular(
        model.parameters(),
        tz.m.TrustCG(tz.m.GaussNewton()),
    )

    def closure(backward=True):
        y_hat = model(X) # (64, 10)
        return (y_hat - y).pow(2).mean(0) # (10, )

    for i in range(100):
        losses = opt.step(closure)
        if i % 10 == 0:
            print(f'{losses.mean() = }')
    ```
    """
    def __init__(self, reg:float = 1e-8, batched:bool=True, ):
        super().__init__(defaults=dict(batched=batched, reg=reg))

    @torch.no_grad
    def update(self, var):
        params = var.params
        batched = self.defaults['batched']

        closure = var.closure
        assert closure is not None

        # gauss newton direction
        with torch.enable_grad():
            f = var.get_loss(backward=False) # n_out
            assert isinstance(f, torch.Tensor)
            G_list = jacobian_wrt([f.ravel()], params, batched=batched)

        var.loss = f.pow(2).sum()

        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_out, ndim)
        Gtf = G.T @ f.detach() # (ndim)
        self.global_state["Gtf"] = Gtf
        var.grad = vec_to_tensors(Gtf, var.params)

        # set closure to calculate sum of squares for line searches etc
        if var.closure is not None:
            def sos_closure(backward=True):
                if backward:
                    var.zero_grad()
                    with torch.enable_grad():
                        loss = closure(False).pow(2).sum()
                        loss.backward()
                    return loss

                loss = closure(False).pow(2).sum()
                return loss

            var.closure = sos_closure

    @torch.no_grad
    def apply(self, var):
        reg = self.defaults['reg']

        G = self.global_state['G']
        Gtf = self.global_state['Gtf']

        GtG = G.T @ G # (ndim, ndim)
        if reg != 0:
            GtG.add_(torch.eye(GtG.size(0), device=GtG.device, dtype=GtG.dtype).mul_(reg))

        v = torch.linalg.lstsq(GtG, Gtf).solution # pylint:disable=not-callable

        var.update = vec_to_tensors(v, var.params)
        return var

    def get_H(self, var):
        G = self.global_state['G']
        return linear_operator.AtA(G)
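The `update`/`apply` pair above builds the Jacobian G of the closure's output f and solves (GᵀG + reg·I) v = Gᵀf. A minimal standalone sketch of the same step in plain PyTorch, not torchzero's code, follows. It assumes a single flat parameter tensor, uses `torch.autograd.functional.jacobian` in place of the module's `jacobian_wrt` and `torch.linalg.solve` in place of `lstsq`, and takes the full step without a line search. `residuals` and `gauss_newton_step` are hypothetical names.

```python
import torch
from torch.autograd.functional import jacobian


def residuals(x: torch.Tensor) -> torch.Tensor:
    # Rosenbrock as residuals: their sum of squares is (1 - x1)^2 + 100 * (x2 - x1^2)^2
    return torch.stack([1 - x[0], 10 * (x[1] - x[0] ** 2)])


def gauss_newton_step(fn, x: torch.Tensor, reg: float = 1e-8) -> torch.Tensor:
    # hypothetical helper, not part of torchzero
    r = fn(x)                                   # residual vector, shape (n_out,)
    J = jacobian(fn, x)                         # Jacobian, shape (n_out, ndim)
    JtJ = J.T @ J + reg * torch.eye(x.numel(), dtype=x.dtype)  # regularized GN matrix
    Jtr = J.T @ r                               # gradient of 0.5 * ||r||^2
    v = torch.linalg.solve(JtJ, Jtr)            # Gauss-Newton direction
    return x - v


x = torch.tensor([-1.1, 2.5], dtype=torch.float64)
for _ in range(10):
    x = gauss_newton_step(residuals, x)
print(x)  # close to the minimizer (1, 1)
```

For these two residuals the Jacobian is square and invertible, so the Gauss-Newton step coincides with Newton's method on r(x) = 0 and reaches (1, 1) in two steps. For general problems, a line search or trust region (as in the examples above) is what keeps the full step from overshooting.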

GaussianSmoothing

Bases: torchzero.modules.grad_approximation.rfdm.RandomizedFDM

Gradient approximation via Gaussian smoothing method.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.01 ) –

    finite difference step size if jvp_method is set to forward or central. Defaults to 1e-2.

  • n_samples (int, default: 100 ) –

    number of random gradient samples. Defaults to 100.

  • formula (Literal, default: 'forward2' ) –

    finite difference formula. Defaults to 'forward2'.

  • distribution (Literal, default: 'gaussian' ) –

    distribution of the random directions. Defaults to "gaussian".

  • beta (float, default: 0 ) –

    If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.

  • pre_generate (bool, default: True ) –

    whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.

  • seed (int | None | Generator, default: None ) –

    Seed for random generator. Defaults to None.

  • target (Literal, default: 'closure' ) –

    which attribute of var to set: "closure", "grad" or "update". Defaults to "closure".

References

Yurii Nesterov, Vladimir Spokoiny. (2015). Random Gradient-Free Minimization of Convex Functions. https://gwern.net/doc/math/2015-nesterov.pdf

Source code in torchzero/modules/grad_approximation/rfdm.py
class GaussianSmoothing(RandomizedFDM):
    """
    Gradient approximation via Gaussian smoothing method.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): finite difference step size if jvp_method is set to `forward` or `central`. Defaults to 1e-2.
        n_samples (int, optional): number of random gradient samples. Defaults to 100.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'forward2'.
        distribution (Distributions, optional): distribution. Defaults to "gaussian".
        beta (float, optional):
            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
        target (GradTarget, optional): what to set on var. Defaults to "closure".


    References:
        Yurii Nesterov, Vladimir Spokoiny. (2015). Random Gradient-Free Minimization of Convex Functions. https://gwern.net/doc/math/2015-nesterov.pdf
    """
    def __init__(
        self,
        h: float = 1e-2,
        n_samples: int = 100,
        formula: _FD_Formula = "forward2",
        distribution: Distributions = "gaussian",
        beta: float = 0,
        pre_generate = True,
        target: GradTarget = "closure",
        seed: int | None | torch.Generator = None,
    ):
        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
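The Nesterov-Spokoiny estimator averages finite-difference directional derivatives along random Gaussian directions u, giving an estimate of the gradient of the smoothed function E[f(x + h·u)]. A minimal sketch of the plain forward-difference form on a flat tensor follows. It is an illustration, not torchzero's `RandomizedFDM` implementation, whose `formula`, `beta` and `pre_generate` options add behavior not modeled here. `gaussian_smoothing_grad` is a hypothetical name.

```python
import torch


def gaussian_smoothing_grad(f, x, h=1e-2, n_samples=100, generator=None):
    # hypothetical helper, not part of torchzero
    # one Gaussian direction u_i ~ N(0, I) per row
    U = torch.randn(n_samples, x.numel(), dtype=x.dtype, generator=generator)
    f0 = f(x)
    # forward differences along each direction: (f(x + h u) - f(x)) / h
    diffs = torch.stack([(f(x + h * u) - f0) / h for u in U])
    # average of (directional derivative estimate) * (direction)
    return (diffs[:, None] * U).mean(0)


def f(x):
    return (x ** 2).sum()


gen = torch.Generator().manual_seed(0)
x = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
g = gaussian_smoothing_grad(f, x, n_samples=20000, generator=gen)
print(g)  # approximately the true gradient 2 * x = [2, -4, 1]
```

For a quadratic the smoothed gradient equals the true gradient, so the only error left is sampling noise, which shrinks as `n_samples` grows. This is why the default of 100 samples gives a noisy but cheap estimate.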

Grad

Bases: torchzero.core.module.Module

Outputs the gradient

Source code in torchzero/modules/ops/utility.py
class Grad(Module):
    """Outputs the gradient"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def step(self, var):
        var.update = [g.clone() for g in var.get_grad()]
        return var

GradApproximator

Bases: torchzero.core.module.Module, abc.ABC

Base class for gradient approximations. This is an abstract class; to use it, subclass it and override approximate.

With the default target='closure', GradApproximator replaces the closure with one that evaluates the estimated gradients, and further closure-based modules will use the modified closure.

Parameters:

  • defaults (dict[str, Any] | None, default: None ) –

    dict with defaults. Defaults to None.

  • target (str, default: 'closure' ) –

    whether to set var.grad ('grad'), var.update ('update') or var.closure ('closure'). Defaults to 'closure'.

Example:

Basic SPSA method implementation.

class SPSA(GradApproximator):
    def __init__(self, h=1e-3):
        defaults = dict(h=h)
        super().__init__(defaults)

    @torch.no_grad
    def approximate(self, closure, params, loss):
        perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]

        # evaluate params + perturbation
        torch._foreach_add_(params, perturbation)
        loss_plus = closure(False)

        # evaluate params - perturbation
        torch._foreach_sub_(params, perturbation)
        torch._foreach_sub_(params, perturbation)
        loss_minus = closure(False)

        # restore original params
        torch._foreach_add_(params, perturbation)

        # calculate SPSA gradients
        spsa_grads = []
        for p, pert in zip(params, perturbation):
            settings = self.settings[p]
            h = settings['h']
            d = (loss_plus - loss_minus) / (2*(h**2))
            spsa_grads.append(pert * d)

        # returns tuple: (grads, loss, loss_approx)
        # loss must be with initial parameters
        # since we only evaluated loss with perturbed parameters
        # we only have loss_approx
        return spsa_grads, None, loss_plus
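With target='closure', `step` wraps the closure so that calling it with backward=True runs `approximate` and writes the estimates into each parameter's `.grad`. The following is a minimal sketch of that wrapping pattern in plain PyTorch, outside torchzero. Coordinate-wise central differences stand in for SPSA because they are exact on a quadratic. `central_fd_approximate` and `make_approx_closure` are hypothetical helpers.

```python
import torch


def central_fd_approximate(closure, params, h=1e-4):
    """Hypothetical helper: coordinate-wise central differences.
    Restores every parameter it perturbs."""
    grads = []
    with torch.no_grad():
        for p in params:
            g = torch.zeros_like(p)
            flat_p, flat_g = p.view(-1), g.view(-1)
            for i in range(flat_p.numel()):
                orig = flat_p[i].item()
                flat_p[i] = orig + h
                f_plus = closure(False)
                flat_p[i] = orig - h
                f_minus = closure(False)
                flat_p[i] = orig  # restore the original value
                flat_g[i] = (f_plus - f_minus) / (2 * h)
            grads.append(g)
    # (grad, loss, loss_approx); no loss was evaluated at the unperturbed point
    return grads, None, None


def make_approx_closure(closure, params, approximate):
    """Hypothetical helper: backward=True fills p.grad with estimated gradients."""
    def approx_closure(backward=True):
        if backward:
            grads, loss, _ = approximate(closure, params)
            for p, g in zip(params, grads):
                p.grad = g
            return loss if loss is not None else closure(False)
        return closure(False)
    return approx_closure


x = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64, requires_grad=True)


def closure(backward=True):
    loss = (x ** 2).sum()
    if backward:
        x.grad = None
        loss.backward()
    return loss


approx_closure = make_approx_closure(closure, [x], central_fd_approximate)
loss = approx_closure()  # the original closure is only ever called with backward=False
print(x.grad)  # estimated gradient, close to 2 * x
```

Because the wrapped closure keeps the standard `closure()` signature, a stock optimizer such as `torch.optim.SGD` can call `opt.step(approx_closure)` and will read the estimated gradients from `.grad`.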

Methods:

  • approximate

    Returns a tuple (grad, loss, loss_approx). Make sure this restores the parameters to their original values!

  • pre_step

    This runs once before each step, whereas approximate may run multiple times per step if further modules evaluate gradients at multiple points.

Source code in torchzero/modules/grad_approximation/grad_approximator.py
class GradApproximator(Module, ABC):
    """Base class for gradient approximations.
    This is an abstract class, to use it, subclass it and override `approximate`.

    With the default ``target='closure'``, GradApproximator replaces the closure with one that evaluates the estimated gradients,
    and further closure-based modules will use the modified closure.

    Args:
        defaults (dict[str, Any] | None, optional): dict with defaults. Defaults to None.
        target (str, optional):
            whether to set `var.grad`, `var.update` or `var.closure`. Defaults to 'closure'.

    Example:

    Basic SPSA method implementation.
    ```python
    class SPSA(GradApproximator):
        def __init__(self, h=1e-3):
            defaults = dict(h=h)
            super().__init__(defaults)

        @torch.no_grad
        def approximate(self, closure, params, loss):
            perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]

            # evaluate params + perturbation
            torch._foreach_add_(params, perturbation)
            loss_plus = closure(False)

            # evaluate params - perturbation
            torch._foreach_sub_(params, perturbation)
            torch._foreach_sub_(params, perturbation)
            loss_minus = closure(False)

            # restore original params
            torch._foreach_add_(params, perturbation)

            # calculate SPSA gradients
            spsa_grads = []
            for p, pert in zip(params, perturbation):
                settings = self.settings[p]
                h = settings['h']
                d = (loss_plus - loss_minus) / (2*(h**2))
                spsa_grads.append(pert * d)

            # returns tuple: (grads, loss, loss_approx)
            # loss must be with initial parameters
            # since we only evaluated loss with perturbed parameters
            # we only have loss_approx
            return spsa_grads, None, loss_plus
    ```
    """
    def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
        super().__init__(defaults)
        self._target: GradTarget = target

    @abstractmethod
    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
        """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""

    def pre_step(self, var: Var) -> None:
        """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""

    @torch.no_grad
    def step(self, var):
        self.pre_step(var)

        if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
        params, closure, loss = var.params, var.closure, var.loss

        if self._target == 'closure':

            def approx_closure(backward=True):
                if backward:
                    # set loss to None because closure might be evaluated at different points
                    grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None)
                    for p, g in zip(params, grad): p.grad = g
                    return l if l is not None else closure(False)
                return closure(False)

            var.closure = approx_closure
            return var

        # if var.grad is not None:
        #     warnings.warn('Using grad approximator when `var.grad` is already set.')
        grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss)
        if loss_approx is not None: var.loss_approx = loss_approx
        if loss is not None: var.loss = var.loss_approx = loss
        if self._target == 'grad': var.grad = list(grad)
        elif self._target == 'update': var.update = list(grad)
        else: raise ValueError(self._target)
        return var

approximate

approximate(closure: Callable, params: list[Tensor], loss: Tensor | None) -> tuple[Iterable[Tensor], Tensor | None, Tensor | None]

Returns a tuple (grad, loss, loss_approx). Make sure this restores the parameters to their original values!

Source code in torchzero/modules/grad_approximation/grad_approximator.py
@abstractmethod
def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
    """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""

pre_step

pre_step(var: Var) -> None

This runs once before each step, whereas approximate may run multiple times per step if further modules evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations.

Source code in torchzero/modules/grad_approximation/grad_approximator.py
def pre_step(self, var: Var) -> None:
    """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
    evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""

GradSign

Bases: torchzero.core.transform.Transform

Copies gradient sign to update.

Source code in torchzero/modules/misc/misc.py
class GradSign(Transform):
    """Copies gradient sign to update."""
    def __init__(self, target: Target = 'update'):
        super().__init__({}, uses_grad=True, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        return [t.copysign_(g) for t,g in zip(tensors, grads)]
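GradSign relies on `torch.Tensor.copysign_`: each element keeps the magnitude of the incoming update and takes the sign of the corresponding gradient element. A tiny illustration of that elementwise rule on plain tensors, outside any torchzero pipeline:

```python
import torch

update = torch.tensor([1.0, -2.0, 3.0])
grad = torch.tensor([-0.5, -0.1, 0.2])

# magnitude from the update, sign from the gradient (elementwise)
result = update.copysign(grad)
print(result.tolist())  # [-1.0, -2.0, 3.0]
```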

GradToNone

Bases: torchzero.core.module.Module

Sets the grad attribute of var to None.

Source code in torchzero/modules/ops/utility.py
class GradToNone(Module):
    """Sets :code:`grad` attribute to None on :code:`var`."""
    def __init__(self): super().__init__()
    def step(self, var):
        var.grad = None
        return var

GradientAccumulation

Bases: torchzero.core.module.Module

Uses n steps to accumulate gradients. After n gradients have been accumulated, their mean (or sum, if mean=False) is passed to the subsequent modules and the parameters are updated.

Accumulating gradients for n steps is equivalent to multiplying the batch size by n. A larger batch is more computationally efficient, but sometimes it is not feasible due to memory constraints.

Note

Technically this can accumulate any inputs, including updates generated by previous modules. When this module is placed first, it accumulates the gradients.

Parameters:

  • n (int) –

    number of gradients to accumulate.

  • mean (bool, default: True ) –

    if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.

  • stop (bool, default: True ) –

    if True, this module prevents subsequent modules from stepping until n gradients have been accumulated. Setting this argument to False disables that. Defaults to True.

Examples:

Adam with gradients accumulated for 16 batches.

opt = tz.Modular(
    model.parameters(),
    tz.m.GradientAccumulation(n=16),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)
Source code in torchzero/modules/misc/gradient_accumulation.py
class GradientAccumulation(Module):
    """Uses ``n`` steps to accumulate gradients. After ``n`` gradients have been accumulated, they are passed to the subsequent modules and the parameters are updated.

    Accumulating gradients for ``n`` steps is equivalent to multiplying the batch size by ``n``. A larger batch size
    is more computationally efficient, but sometimes it is not feasible due to memory constraints.

    Note:
        Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.

    Args:
        n (int): number of gradients to accumulate.
        mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
        stop (bool, optional):
            if True, this module prevents subsequent modules from stepping until ``n`` gradients have been accumulated. Setting this argument to False disables that. Defaults to True.

    ## Examples:

    Adam with gradients accumulated for 16 batches.

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.GradientAccumulation(n=16),
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(self, n: int, mean=True, stop=True):
        defaults = dict(n=n, mean=mean, stop=stop)
        super().__init__(defaults)


    @torch.no_grad
    def step(self, var):
        accumulator = self.get_state(var.params, 'accumulator')
        settings = self.defaults
        n = settings['n']; mean = settings['mean']; stop = settings['stop']
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        # add update to accumulator
        torch._foreach_add_(accumulator, var.get_update())

        # step with accumulated updates
        if step % n == 0:
            if mean:
                torch._foreach_div_(accumulator, n)

            var.update = accumulator

            # zero accumulator
            self.clear_state_keys('accumulator')

        else:
            # prevent update
            if stop:
                var.update = None
                var.stop=True
                var.skip_update=True

        return var
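The equivalence claimed above holds for a mean-reduced loss and equally sized micro-batches: the mean of the micro-batch gradients equals the gradient of the combined batch. A minimal plain-PyTorch check of that claim, mirroring the module's `mean=True` path, follows. It does not use torchzero, and the model, data and loss are arbitrary illustrative choices.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1).double()
X = torch.randn(64, 4, dtype=torch.float64)
y = torch.randn(64, 1, dtype=torch.float64)
loss_fn = nn.MSELoss()  # mean reduction over the batch

# reference: gradient of one batch of 64 samples
model.zero_grad()
loss_fn(model(X), y).backward()
full_grad = [p.grad.clone() for p in model.parameters()]

# accumulate gradients of n = 4 micro-batches of 16 samples each
n = 4
accumulator = [torch.zeros_like(p) for p in model.parameters()]
for X_mb, y_mb in zip(X.chunk(n), y.chunk(n)):
    model.zero_grad()
    loss_fn(model(X_mb), y_mb).backward()
    for acc, p in zip(accumulator, model.parameters()):
        acc.add_(p.grad)

# mean=True: divide the accumulated sum by n before stepping
for acc in accumulator:
    acc.div_(n)

print(all(torch.allclose(a, b) for a, b in zip(accumulator, full_grad)))  # True
```

Layers whose output depends on batch statistics, such as BatchNorm, break this exact equivalence, since each micro-batch is normalized separately.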

GradientCorrection

Bases: torchzero.core.transform.Transform

Estimates the gradient at the minimum along the search direction, assuming the function is quadratic.

This can be useful as an inner module for second-order methods with an inexact line search.

Example:

L-BFGS with gradient correction

opt = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(inner=tz.m.GradientCorrection()),
    tz.m.Backtracking()
)
Reference

HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class GradientCorrection(Transform):
    """
    Estimates the gradient at the minimum along the search direction, assuming the function is quadratic.

    This can be useful as an inner module for second-order methods with an inexact line search.

    ## Example:
    L-BFGS with gradient correction

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.LBFGS(inner=tz.m.GradientCorrection()),
        tz.m.Backtracking()
    )
    ```

    Reference:
        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394

    """
    def __init__(self):
        super().__init__(None, uses_grad=False)

    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        if 'p_prev' not in states[0]:
            p_prev = unpack_states(states, tensors, 'p_prev', init=params)
            g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
            return tensors

        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)
        g_hat = gradient_correction(TensorList(tensors), params-p_prev, tensors-g_prev)

        p_prev.copy_(params)
        g_prev.copy_(tensors)
        return g_hat
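`apply_tensors` passes the step s = params - p_prev and the gradient change y = tensors - g_prev to `gradient_correction`. Under the quadratic model, the gradient along the line x + t·s is g + t·y, and the line minimum sits at t* = -sᵀg / sᵀy. The following is a minimal sketch of that derivation. It is our reconstruction from the quadratic assumption and may differ from torchzero's `gradient_correction` helper or from Hoshino's exact formulation. `corrected_gradient` is a hypothetical name.

```python
import torch


def corrected_gradient(g, s, y):
    # hypothetical helper, not part of torchzero
    # On a quadratic the gradient along x + t*s is g + t*y (because y = H s),
    # and the minimum along s is at t* = -(s^T g) / (s^T y).
    t_star = -(s @ g) / (s @ y)
    return g + t_star * y


def quad_grad(x, H):
    # gradient of the quadratic 0.5 * x^T H x
    return H @ x


torch.manual_seed(0)
A = torch.randn(5, 5, dtype=torch.float64)
H = A @ A.T + 5 * torch.eye(5, dtype=torch.float64)  # symmetric positive definite Hessian

x_prev = torch.randn(5, dtype=torch.float64)
x = torch.randn(5, dtype=torch.float64)
s = x - x_prev                                # last step (params - p_prev)
y = quad_grad(x, H) - quad_grad(x_prev, H)    # gradient change (tensors - g_prev)
g = quad_grad(x, H)

g_hat = corrected_gradient(g, s, y)

# g_hat equals the true gradient at the minimum along the line x + t*s
t_star = -(s @ g) / (s @ y)
print(torch.allclose(g_hat, quad_grad(x + t_star * s, H)))  # True
```

By construction sᵀĝ = 0, i.e. the corrected gradient is what an exact line search along s would have produced on a quadratic.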

GradientSampling

Bases: torchzero.core.reformulation.Reformulation

Samples and aggregates gradients and values at perturbed points.

This module can be used for Gaussian homotopy and gradient sampling methods.

Parameters:

  • modules (Chainable | None, default: None ) –

    modules that will be optimizing the modified objective. If None, returns the gradient of the modified objective as the update. Defaults to None.

  • sigma (float, default: 1.0 ) –

    initial magnitude of the perturbations. Defaults to 1.

  • n (int, default: 100 ) –

    number of perturbations per step. Defaults to 100.

  • aggregate (str, default: 'mean' ) –

    how to aggregate values and gradients:
      - "mean" - uses the mean of the gradients, as in Gaussian homotopy.
      - "max" - uses the element-wise maximum of the gradient magnitudes.
      - "min" - uses the element-wise minimum of the gradient magnitudes.
      - "min-norm" - picks the gradient with the lowest norm.
      - "min-value" - picks the gradient at the point with the lowest value.

    Defaults to 'mean'.

  • distribution (Literal, default: 'gaussian' ) –

    distribution for random perturbations. Defaults to 'gaussian'.

  • include_x0 (bool, default: True ) –

    whether to include gradient at un-perturbed point. Defaults to True.

  • fixed (bool, default: True ) –

    if True, perturbations are not replaced by new random perturbations until the termination criterion is satisfied. Defaults to True.

  • pre_generate (bool, default: True ) –

    if True, perturbations are pre-generated before each step. This requires more memory to store all of them, but ensures they do not change when closure is evaluated multiple times. Defaults to True.

  • termination (TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None, default: None ) –

    a termination criteria module. When the termination criterion is satisfied, sigma is multiplied by decay and, if fixed is True, new perturbations are generated. Defaults to None.

  • decay (float, default: 0.6666666666666666 ) –

    sigma multiplier applied when the termination criterion is satisfied. Defaults to 2/3.

  • reset_on_termination (bool, default: True ) –

    whether to reset states of all other modules on termination. Defaults to True.

  • sigma_strategy (str | None, default: None ) –

    strategy for adapting sigma. If the condition is satisfied, sigma is multiplied by sigma_nplus, otherwise it is multiplied by sigma_nminus:
      - "grad-norm" - at least sigma_target gradients should have a lower norm than at the un-perturbed point.
      - "value" - at least sigma_target values (losses) should be lower than at the un-perturbed point.
      - None - doesn't use adaptive sigma.

    This introduces a side effect to the closure, so it should be left at None if you use a trust region or line search to optimize the modified objective. Defaults to None.

  • sigma_target (int | float, default: 0.2 ) –

    number of elements to satisfy the condition in sigma_strategy. Defaults to 0.2.

  • sigma_nplus (float, default: 1.3333333333333333 ) –

    sigma multiplier when sigma_strategy condition is satisfied. Defaults to 4/3.

  • sigma_nminus (float, default: 0.6666666666666666 ) –

    sigma multiplier when sigma_strategy condition is not satisfied. Defaults to 2/3.

  • seed (int | None, default: None ) –

    seed. Defaults to None.
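The core idea, evaluating gradients at the current point and at randomly perturbed points and then aggregating them, can be sketched in a few lines of plain PyTorch. This is a minimal sketch assuming a function of a single flat tensor. It omits the sigma adaptation, termination handling and non-finite filtering that the module performs, and shows only three of the aggregation modes. `sample_gradients` and `aggregate` are hypothetical helpers, not torchzero API.

```python
import torch


def sample_gradients(f, x, sigma=1.0, n=100, include_x0=True, generator=None):
    """Hypothetical helper: values and gradients at x and at x + sigma * u, u ~ N(0, I)."""
    points = [x] if include_x0 else []
    n_perturbed = n - int(include_x0)  # the un-perturbed point counts toward n
    for _ in range(n_perturbed):
        u = torch.randn(x.shape, dtype=x.dtype, generator=generator)
        points.append(x + sigma * u)
    values, grads = [], []
    for p in points:
        p = p.detach().requires_grad_(True)
        v = f(p)
        (g,) = torch.autograd.grad(v, p)
        values.append(v.detach())
        grads.append(g)
    return torch.stack(values), torch.stack(grads)


def aggregate(values, grads, how="mean"):
    """Hypothetical helper mirroring three of the module's aggregate options."""
    if how == "mean":        # Gaussian homotopy: average value and gradient
        return values.mean(), grads.mean(0)
    if how == "min-norm":    # the sampled gradient with the smallest norm
        i = grads.norm(dim=1).argmin()
        return values[i], grads[i]
    if how == "min-value":   # the gradient at the point with the lowest value
        i = values.argmin()
        return values[i], grads[i]
    raise ValueError(how)


gen = torch.Generator().manual_seed(0)
x = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
values, grads = sample_gradients(lambda z: (z ** 2).sum(), x, sigma=0.5, n=2000, generator=gen)
loss_mean, g_mean = aggregate(values, grads, "mean")  # roughly 2 * x for this quadratic
```

"mean" smooths the objective, which is the Gaussian homotopy view. "min-norm" and "min-value" instead select a single sampled gradient, which is closer in spirit to gradient sampling for nonsmooth problems.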

Source code in torchzero/modules/smoothing/sampling.py
class GradientSampling(Reformulation):
    """Samples and aggregates gradients and values at perturbed points.

    This module can be used for gaussian homotopy and gradient sampling methods.

    Args:
        modules (Chainable | None, optional):
            modules that will be optimizing the modified objective.
            if None, returns gradient of the modified objective as the update. Defaults to None.
        sigma (float, optional): initial magnitude of the perturbations. Defaults to 1.
        n (int, optional): number of perturbations per step. Defaults to 100.
        aggregate (str, optional):
            how to aggregate values and gradients
            - "mean" - uses mean of the gradients, as in gaussian homotopy.
            - "max" - uses element-wise maximum of the gradients.
            - "min" - uses element-wise minimum of the gradients.
            - "min-norm" - picks gradient with the lowest norm.

            Defaults to 'mean'.
        distribution (Distributions, optional): distribution for random perturbations. Defaults to 'gaussian'.
        include_x0 (bool, optional): whether to include gradient at un-perturbed point. Defaults to True.
        fixed (bool, optional):
            if True, perturbations do not get replaced by new random perturbations until termination criteria is satisfied. Defaults to True.
        pre_generate (bool, optional):
            if True, perturbations are pre-generated before each step.
            This requires more memory to store all of them,
            but ensures they do not change when closure is evaluated multiple times.
            Defaults to True.
        termination (TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None, optional):
            a termination criteria module, sigma will be multiplied by ``decay`` when termination criteria is satisfied,
            and new perturbations will be generated if ``fixed``. Defaults to None.
        decay (float, optional): sigma multiplier on termination criteria. Defaults to 2/3.
        reset_on_termination (bool, optional): whether to reset states of all other modules on termination. Defaults to True.
        sigma_strategy (str | None, optional):
            strategy for adapting sigma. If condition is satisfied, sigma is multiplied by ``sigma_nplus``,
            otherwise it is multiplied by ``sigma_nminus``.
            - "grad-norm" - at least ``sigma_target`` gradients should have lower norm than at un-perturbed point.
            - "value" - at least ``sigma_target`` values (losses) should be lower than at un-perturbed point.
            - None - doesn't use adaptive sigma.

            This introduces a side effect to the closure, so it should be left at None if you use
            a trust region or line search to optimize the modified objective.
            Defaults to None.
        sigma_target (int | float, optional):
            number of elements to satisfy the condition in ``sigma_strategy``. Defaults to 0.2.
        sigma_nplus (float, optional): sigma multiplier when ``sigma_strategy`` condition is satisfied. Defaults to 4/3.
        sigma_nminus (float, optional): sigma multiplier when ``sigma_strategy`` condition is not satisfied. Defaults to 2/3.
        seed (int | None, optional): seed. Defaults to None.
    """
    def __init__(
        self,
        modules: Chainable | None = None,
        sigma: float = 1.,
        n:int = 100,
        aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = 'mean',
        distribution: Distributions = 'gaussian',
        include_x0: bool = True,

        fixed: bool=True,
        pre_generate: bool = True,
        termination: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
        decay: float = 2/3,
        reset_on_termination: bool = True,

        sigma_strategy: Literal['grad-norm', 'value'] | None = None,
        sigma_target: int | float = 0.2,
        sigma_nplus: float = 4/3,
        sigma_nminus: float = 2/3,

        seed: int | None = None,
    ):

        defaults = dict(sigma=sigma, n=n, aggregate=aggregate, distribution=distribution, seed=seed, include_x0=include_x0, fixed=fixed, decay=decay, reset_on_termination=reset_on_termination, sigma_strategy=sigma_strategy, sigma_target=sigma_target, sigma_nplus=sigma_nplus, sigma_nminus=sigma_nminus, pre_generate=pre_generate)
        super().__init__(defaults, modules)

        if termination is not None:
            self.set_child('termination', make_termination_criteria(extra=termination))

    @torch.no_grad
    def pre_step(self, var):
        params = TensorList(var.params)

        fixed = self.defaults['fixed']

        # check termination criteria
        if 'termination' in self.children:
            termination = cast(TerminationCriteriaBase, self.children['termination'])
            if termination.should_terminate(var):

                # decay sigmas
                states = [self.state[p] for p in params]
                settings = [self.settings[p] for p in params]

                for state, setting in zip(states, settings):
                    if 'sigma' not in state: state['sigma'] = setting['sigma']
                    state['sigma'] *= setting['decay']

                # reset on sigmas decay
                if self.defaults['reset_on_termination']:
                    var.post_step_hooks.append(partial(_reset_except_self, self=self))

                # clear perturbations
                self.global_state.pop('perts', None)

        # pre-generate perturbations if not already pre-generated or not fixed
        if self.defaults['pre_generate'] and (('perts' not in self.global_state) or (not fixed)):
            states = [self.state[p] for p in params]
            settings = [self.settings[p] for p in params]

            n = self.defaults['n'] - self.defaults['include_x0']
            generator = self.get_generator(params[0].device, self.defaults['seed'])

            perts = [params.sample_like(self.defaults['distribution'], generator=generator) for _ in range(n)]

            self.global_state['perts'] = perts

    @torch.no_grad
    def closure(self, backward, closure, params, var):
        params = TensorList(params)
        loss_agg = None
        grad_agg = None

        states = [self.state[p] for p in params]
        settings = [self.settings[p] for p in params]
        sigma_inits = [s['sigma'] for s in settings]
        sigmas = [s.setdefault('sigma', si) for s, si in zip(states, sigma_inits)]

        include_x0 = self.defaults['include_x0']
        pre_generate = self.defaults['pre_generate']
        aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = self.defaults['aggregate']
        sigma_strategy: Literal['grad-norm', 'value'] | None = self.defaults['sigma_strategy']
        distribution = self.defaults['distribution']
        generator = self.get_generator(params[0].device, self.defaults['seed'])


        n_finite = 0
        n_good = 0
        f_0 = None; g_0 = None

        # evaluate at x_0
        if include_x0:
            f_0 = cast(torch.Tensor, var.get_loss(backward=backward))

            isfinite = math.isfinite(f_0)
            if isfinite:
                n_finite += 1
                loss_agg = f_0

            if backward:
                g_0 = var.get_grad()
                if isfinite: grad_agg = g_0

        # evaluate at x_0 + p for each perturbation
        if pre_generate:
            perts = self.global_state['perts']
        else:
            perts = [None] * (self.defaults['n'] - include_x0)

        x_0 = [p.clone() for p in params]

        for pert in perts:
            loss = None; grad = None

            # generate if not pre-generated
            if pert is None:
                pert = params.sample_like(distribution, generator=generator)

            # add perturbation and evaluate
            pert = pert * sigmas
            torch._foreach_add_(params, pert)

            with torch.enable_grad() if backward else nullcontext():
                loss = closure(backward)

            if math.isfinite(loss):
                n_finite += 1

                # add loss
                if loss_agg is None:
                    loss_agg = loss
                else:
                    if aggregate == 'mean':
                        loss_agg += loss

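                    # with backward, 'min-value' and 'min-norm' update loss_agg together with grad_agg below;
                    # clamping here first would make the `loss < loss_agg` check below always False
                    elif (aggregate=='min') or (aggregate in ('min-value', 'min-norm') and not backward):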
                        loss_agg = loss_agg.clamp(max=loss)

                    elif aggregate == 'max':
                        loss_agg = loss_agg.clamp(min=loss)

                # add grad
                if backward:
                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                    if grad_agg is None:
                        grad_agg = grad
                    else:
                        if aggregate == 'mean':
                            torch._foreach_add_(grad_agg, grad)

                        elif aggregate == 'min':
                            grad_agg_abs = torch._foreach_abs(grad_agg)
                            torch._foreach_minimum_(grad_agg_abs, torch._foreach_abs(grad))
                            grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]

                        elif aggregate == 'max':
                            grad_agg_abs = torch._foreach_abs(grad_agg)
                            torch._foreach_maximum_(grad_agg_abs, torch._foreach_abs(grad))
                            grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]

                        elif aggregate == 'min-norm':
                            if TensorList(grad).global_vector_norm() < TensorList(grad_agg).global_vector_norm():
                                grad_agg = grad
                                loss_agg = loss

                        elif aggregate == 'min-value':
                            if loss < loss_agg:
                                grad_agg = grad
                                loss_agg = loss

            # undo perturbation
            torch._foreach_copy_(params, x_0)

            # adaptive sigma
            # by value
            if sigma_strategy == 'value':
                if f_0 is None:
                    with torch.enable_grad() if backward else nullcontext():
                        f_0 = closure(False)

                if loss < f_0:
                    n_good += 1

            # by gradient norm
            elif sigma_strategy == 'grad-norm' and backward and math.isfinite(loss):
                assert grad is not None
                if g_0 is None:
                    with torch.enable_grad() if backward else nullcontext():
                        closure()
                        g_0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

                if TensorList(grad).global_vector_norm() < TensorList(g_0).global_vector_norm():
                    n_good += 1

        # update sigma if strategy is enabled
        if sigma_strategy is not None:

            sigma_target = self.defaults['sigma_target']
            if isinstance(sigma_target, float):
                sigma_target = int(max(1, n_finite * sigma_target))

            if n_good >= sigma_target:
                key = 'sigma_nplus'
            else:
                key = 'sigma_nminus'

            for p in params:
                self.state[p]['sigma'] *= self.settings[p][key]

        # if no finite losses, just return inf
        if n_finite == 0:
            assert loss_agg is None and grad_agg is None
            loss = torch.tensor(torch.inf, dtype=params[0].dtype, device=params[0].device)
            grad = [torch.full_like(p, torch.inf) for p in params]
            return loss, grad

        assert loss_agg is not None

        # no post processing needed when aggregate is 'max', 'min', 'min-norm', 'min-value'
        if aggregate != 'mean':
            return loss_agg, grad_agg

        # on mean divide by number of evals
        loss_agg /= n_finite

        if backward:
            assert grad_agg is not None
            torch._foreach_div_(grad_agg, n_finite)

        return loss_agg, grad_agg
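The evaluation loop above can be summarized without PyTorch. The following is a minimal sketch, not the module's implementation. It works on plain Python lists and draws Gaussian perturbations, which is an assumption, since the module samples from its configurable distribution. It implements only the "mean" and "min-value" aggregates. The names sample_and_aggregate, sq and sq_grad are hypothetical.

```python
import math
import random


def sample_and_aggregate(f, grad_f, x0, sigma=0.1, n=8, aggregate="mean", include_x0=True, seed=0):
    """Hypothetical helper: evaluate f and grad_f at n points around x0 and aggregate them."""
    rng = random.Random(seed)
    # as in the module, n counts all evaluations, including x0 itself when include_x0 is set
    points = [list(x0)] if include_x0 else []
    while len(points) < n:
        points.append([xi + sigma * rng.gauss(0.0, 1.0) for xi in x0])

    losses, grads = [], []
    for p in points:
        loss = f(p)
        if math.isfinite(loss):  # non-finite evaluations are skipped
            losses.append(loss)
            grads.append(grad_f(p))

    # no finite evaluations: return an inf loss and an inf gradient, like the module
    if not losses:
        return math.inf, [math.inf] * len(x0)

    if aggregate == "mean":
        k = len(losses)
        mean_grad = [sum(g[i] for g in grads) / k for i in range(len(x0))]
        return sum(losses) / k, mean_grad
    if aggregate == "min-value":
        # keep the loss and gradient of the lowest-loss sample
        best = min(range(len(losses)), key=losses.__getitem__)
        return losses[best], grads[best]
    raise ValueError(f"unsupported aggregate: {aggregate}")


sq = lambda x: sum(v * v for v in x)
sq_grad = lambda x: [2 * v for v in x]
print(sample_and_aggregate(sq, sq_grad, [1.0, -1.0], sigma=0.0))  # (2.0, [2.0, -2.0])
```

With sigma=0 every sample coincides with x0, so the mean aggregate reproduces the unperturbed loss and gradient.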

Graft

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs tensors rescaled to have the same norm as magnitude(tensors).

Source code in torchzero/modules/ops/binary.py
class Graft(BinaryOperationBase):
    """Outputs tensors rescaled to have the same norm as :code:`magnitude(tensors)`."""
    def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, magnitude=magnitude)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
        return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
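Per tensor, grafting keeps the direction of the update and takes its norm from magnitude(tensors). Below is a minimal framework-free sketch of the tensorwise case, with each tensor represented as a flat list of floats. It assumes that eps floors the denominator, following the eps description given for GraftModules. The helper names lp_norm and graft_tensorwise are hypothetical, and this is not TensorList.graft_ itself.

```python
def lp_norm(t, ord=2.0):
    """ord-norm of a flat list of floats."""
    return sum(abs(v) ** ord for v in t) ** (1.0 / ord)


def graft_tensorwise(update, magnitude, ord=2.0, eps=1e-6):
    """Rescale each tensor of `update` to the norm of the matching tensor of `magnitude`."""
    out = []
    for u, m in zip(update, magnitude):
        scale = lp_norm(m, ord) / max(lp_norm(u, ord), eps)  # eps keeps the division finite
        out.append([v * scale for v in u])
    return out


print(graft_tensorwise([[3.0, 4.0], [1.0, 0.0]], [[0.0, 10.0], [0.0, 0.5]]))
# [[6.0, 8.0], [0.5, 0.0]]
```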

GraftGradToUpdate

Bases: torchzero.core.transform.Transform

Outputs the gradient grafted to the update, that is, the gradient rescaled to have the same norm as the update.

Source code in torchzero/modules/misc/misc.py
class GraftGradToUpdate(Transform):
    """Outputs the gradient grafted to the update, that is, the gradient rescaled to have the same norm as the update."""
    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, uses_grad=True, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
        return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
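GraftGradToUpdate, GraftToGrad and GraftToParams differ only in which tensors supply the direction and which supply the norm. They default to tensorwise=False, so the norm is taken over all tensors at once. Below is a hedged sketch of that global variant on plain lists. For GraftGradToUpdate, the direction is the gradient and the magnitude is the update. The names global_norm and graft_global are hypothetical.

```python
import math


def global_norm(tensors):
    """Euclidean norm of all entries of a list of flat tensors, treated as one vector."""
    return math.sqrt(sum(v * v for t in tensors for v in t))


def graft_global(direction, magnitude, eps=1e-6):
    """Rescale all of `direction` by one factor so its global norm matches that of `magnitude`."""
    scale = global_norm(magnitude) / max(global_norm(direction), eps)
    return [[v * scale for v in t] for t in direction]


grads = [[3.0], [4.0]]    # global norm 5
update = [[10.0], [0.0]]  # global norm 10
print(graft_global(grads, update))  # [[6.0], [8.0]]
```

A tensorwise graft of the same inputs would rescale each tensor separately, so the second tensor, whose magnitude counterpart is zero, would become zero.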

GraftModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Outputs the direction output rescaled to have the same norm as the magnitude output.

Parameters:

  • direction (Chainable) –

    module to use the direction from

  • magnitude (Chainable) –

    module to use the magnitude from

  • tensorwise (bool, default: True ) –

    whether to calculate norm per-tensor or globally. Defaults to True.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • eps (float, default: 1e-06 ) –

    clips denominator to be no less than this value. Defaults to 1e-6.

  • strength (float, default: 1 ) –

    strength of grafting. Defaults to 1.

Example

Shampoo grafted to Adam

import torchzero as tz

# model is any torch.nn.Module
opt = tz.Modular(
    model.parameters(),
    tz.m.GraftModules(
        direction = tz.m.Shampoo(),
        magnitude = tz.m.Adam(),
    ),
    tz.m.LR(1e-3)
)

Reference

Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803

Source code in torchzero/modules/ops/multi.py
class GraftModules(MultiOperationBase):
    """Outputs :code:`direction` output rescaled to have the same norm as :code:`magnitude` output.

    Args:
        direction (Chainable): module to use the direction from
        magnitude (Chainable): module to use the magnitude from
        tensorwise (bool, optional): whether to calculate norm per-tensor or globally. Defaults to True.
        ord (float, optional): norm order. Defaults to 2.
        eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
        strength (float, optional): strength of grafting. Defaults to 1.

    Example:
        Shampoo grafted to Adam

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.GraftModules(
                    direction = tz.m.Shampoo(),
                    magnitude = tz.m.Adam(),
                ),
                tz.m.LR(1e-3)
            )

    Reference:
        Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803
    """
    def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:Metrics=2, eps:float = 1e-6, strength:float=1):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
        super().__init__(defaults, direction=direction, magnitude=magnitude)

    @torch.no_grad
    def transform(self, var, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.defaults)
        return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)

GraftToGrad

Bases: torchzero.core.transform.Transform

Grafts the update to the gradient, that is, the update is rescaled to have the same norm as the gradient.

Source code in torchzero/modules/misc/misc.py
class GraftToGrad(Transform):
    """Grafts the update to the gradient, that is, the update is rescaled to have the same norm as the gradient."""
    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, uses_grad=True, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
        return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)

GraftToParams

Bases: torchzero.core.transform.Transform

Grafts the update to the parameters, that is, the update is rescaled to have the same norm as the parameters, but no smaller than eps.

Source code in torchzero/modules/misc/misc.py
class GraftToParams(Transform):
    """Grafts the update to the parameters, that is, the update is rescaled to have the same norm as the parameters, but no smaller than :code:`eps`."""
    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-4, target: Target = 'update'):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
        return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)

GraftToUpdate

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs direction(tensors) rescaled to have the same norm as tensors.

Source code in torchzero/modules/ops/binary.py
class RGraft(BinaryOperationBase):
    """Outputs :code:`direction(tensors)` rescaled to have the same norm as tensors."""

    def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, direction=direction)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
        return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)

GramSchimdt

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs tensors made orthogonal to other(tensors) via Gram-Schmidt.

Source code in torchzero/modules/ops/binary.py
class GramSchimdt(BinaryOperationBase):
    """Outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
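        update = TensorList(update); other = TensorList(other)
        min = torch.finfo(update[0].dtype).tiny * 2
        # project out the component along `other`, using inner products over all tensors
        dot_ou = sum((o * u).sum() for o, u in zip(other, update))
        dot_oo = sum((o * o).sum() for o in other)
        coef = dot_ou / dot_oo.clip(min=min)
        return [u - o * coef for u, o in zip(update, other)]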
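Gram-Schmidt orthogonalization removes from the update its projection onto other, giving u - (dot(o, u) / dot(o, o)) * o, where the dot products sum over all entries. Below is a framework-free sketch that treats the update and other as single flat vectors. Whether the module intends a global or a per-tensor projection is an assumption here. The name gram_schmidt is hypothetical, and tiny plays the role of the module's smallest-positive-float clip.

```python
import sys


def gram_schmidt(update, other):
    """Return `update` with its component along `other` removed (flat lists of floats)."""
    tiny = sys.float_info.min * 2  # clip the denominator so a zero `other` leaves update unchanged
    dot_ou = sum(o * u for o, u in zip(other, update))
    dot_oo = sum(o * o for o in other)
    coef = dot_ou / max(dot_oo, tiny)
    return [u - coef * o for u, o in zip(update, other)]


print(gram_schmidt([1.0, 1.0], [1.0, 0.0]))  # [0.0, 1.0]
```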

Greenstadt1

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Greenstadt's first Quasi-Newton method.

Note

A trust region or an accurate line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
    """Greenstadt's first Quasi-Newton method.

    Note:
        A trust region or an accurate line search is recommended.

    Warning:
        This uses at least O(N^2) memory.

    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev)

Greenstadt2

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Greenstadt's second Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
    """Greenstadt's second Quasi-Newton method.

    Note:
        A line search is recommended.

    Warning:
        This uses at least O(N^2) memory.

    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return greenstadt2_H_(H=H, s=s, y=y)

HagerZhang

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Hager-Zhang nonlinear conjugate gradient method.

Note

This requires the step size to be determined by a line search, so put a line search such as tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this module.

Source code in torchzero/modules/conjugate_gradient/cg.py
class HagerZhang(ConguateGradientBase):
    """Hager-Zhang nonlinear conjugate gradient method.

    Note:
        This requires the step size to be determined by a line search, so put a line search such as ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this module.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return hager_zhang_beta(g, prev_d, prev_g)
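For reference, the textbook Hager-Zhang coefficient is beta = (y - 2 d_prev ||y||^2 / (d_prev·y))·g / (d_prev·y). Here y = g - g_prev and d_prev is the previous search direction. Below is a plain-Python sketch of that formula. The name hager_zhang_beta_sketch is hypothetical. It follows the textbook convention d = -g + beta d_prev, which may differ in sign from how the module stores prev_d. It also assumes d_prev·y is nonzero; a Wolfe line search guarantees this for descent directions.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def hager_zhang_beta_sketch(g, g_prev, d_prev):
    """Textbook Hager-Zhang beta for nonlinear conjugate gradient (flat lists of floats)."""
    y = [gi - gpi for gi, gpi in zip(g, g_prev)]
    dy = dot(d_prev, y)  # assumed nonzero; positive under a Wolfe line search
    yy = dot(y, y)
    v = [yi - 2.0 * di * yy / dy for yi, di in zip(y, d_prev)]
    return dot(v, g) / dy


g_prev = [2.0, 1.0]
d_prev = [-2.0, -1.0]  # first direction: steepest descent
g = [1.0, 0.0]
beta = hager_zhang_beta_sketch(g, g_prev, d_prev)   # 5/9 for these numbers
d = [-gi + beta * di for gi, di in zip(g, d_prev)]  # next search direction
```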

HeavyBall

Bases: torchzero.modules.momentum.momentum.EMA

Polyak's momentum (heavy-ball method).

Parameters:

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • dampening (float, default: 0 ) –

    momentum dampening. Defaults to 0.

  • debiased (bool, default: False ) –

    whether to debias the EMA like in Adam. Defaults to False.

  • lerp (bool, default: False ) –

    whether to use linear interpolation; if True, this becomes an exponential moving average. Defaults to False.

  • ema_init (str, default: 'update' ) –

    initial values for the EMA, "zeros" or "update".

  • target (Literal, default: 'update' ) –

    target to apply EMA to. Defaults to 'update'.

Source code in torchzero/modules/momentum/momentum.py
class HeavyBall(EMA):
    """Polyak's momentum (heavy-ball method).

    Args:
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        dampening (float, optional): momentum dampening. Defaults to 0.
        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
        lerp (bool, optional):
            whether to use linear interpolation; if True, this becomes an exponential moving average. Defaults to False.
        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
        target (Target, optional): target to apply EMA to. Defaults to 'update'.
    """
    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update', target: Target = 'update'):
        super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init, target=target)
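Below is a plain-Python sketch of the heavy-ball recurrence on a stream of updates. It assumes the torch.optim.SGD conventions: the buffer starts as the first update (ema_init="update"), then buf = momentum * buf + (1 - dampening) * update. With lerp=True it switches to the EMA buf += (1 - momentum) * (update - buf). This illustrates the recurrence and is not the EMA class's actual code. Debiasing is omitted, and the name heavy_ball is hypothetical.

```python
def heavy_ball(updates, momentum=0.9, dampening=0.0, lerp=False):
    """Return the momentum buffer after each update in `updates` (lists of floats)."""
    buf = None
    history = []
    for u in updates:
        if buf is None:
            buf = list(u)  # ema_init="update": start from the first update
        elif lerp:
            # exponential moving average: buf moves a (1 - momentum) fraction toward u
            buf = [b + (1.0 - momentum) * (x - b) for b, x in zip(buf, u)]
        else:
            # classic heavy ball: decay the buffer and add the (dampened) update
            buf = [momentum * b + (1.0 - dampening) * x for b, x in zip(buf, u)]
        history.append(list(buf))
    return history


print(heavy_ball([[1.0], [1.0], [1.0]], momentum=0.5))             # [[1.0], [1.5], [1.75]]
print(heavy_ball([[1.0], [1.0], [1.0]], momentum=0.5, lerp=True))  # [[1.0], [1.0], [1.0]]
```

With a constant update, the heavy-ball buffer grows toward update / (1 - momentum), while the lerp variant stays at the update itself.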

HestenesStiefel

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Hestenes–Stiefel nonlinear conjugate gradient method.

Note

This requires the step size to be determined by a line search, so put a line search such as tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this module.

Source code in torchzero/modules/conjugate_gradient/cg.py
class HestenesStiefel(ConguateGradientBase):
    """Hestenes–Stiefel nonlinear conjugate gradient method.

    Note:
        This requires the step size to be determined by a line search, so put a line search such as ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this module.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return hestenes_stiefel_beta(g, prev_d, prev_g)
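The textbook Hestenes-Stiefel coefficient is beta = g·y / (d_prev·y), with y = g - g_prev and d_prev the previous search direction. The new direction is d = -g + beta d_prev. Below is a plain-Python sketch under that textbook sign convention. The name hestenes_stiefel_beta_sketch is hypothetical, and the sketch may not match how the module stores prev_d.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def hestenes_stiefel_beta_sketch(g, g_prev, d_prev):
    """Textbook Hestenes-Stiefel beta (flat lists of floats); assumes d_prev . y != 0."""
    y = [gi - gpi for gi, gpi in zip(g, g_prev)]
    return dot(g, y) / dot(d_prev, y)


g_prev, d_prev, g = [2.0, 1.0], [-2.0, -1.0], [1.0, 0.0]
beta = hestenes_stiefel_beta_sketch(g, g_prev, d_prev)  # (-1) / 3
d = [-gi + beta * di for gi, di in zip(g, d_prev)]      # next search direction
```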

HigherOrderNewton

Bases: torchzero.core.module.Module

A basic arbitrary-order Newton's method with an optional trust region and proximal penalty.

This constructs an nth-order Taylor approximation via autograd and minimizes it with scipy.optimize.minimize trust region Newton solvers, with an optional proximal penalty.

The Hessian of the Taylor approximation is easy to evaluate and can be evaluated in batched mode, so this can be more efficient in some very specific cases.

Notes
  • In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd.
  • This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate higher-order derivatives. The closure must accept a backward argument (refer to the documentation).
  • This uses roughly O(N^order) memory, and solving the subproblem is very expensive.
  • The "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.
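The model being minimized is the Taylor polynomial m(d) = f(x0) + sum over k = 1..order of D^k f(x0)[d, ..., d] / k!. Here D^k f is the k-th derivative tensor: the gradient, the Hessian, and so on. Below is a plain-Python sketch that evaluates this model from nested-list derivative tensors. The names taylor_model and full_contract are hypothetical. The 1/k! weighting is the standard Taylor form and is assumed here, not read from _poly_minimize.

```python
import math


def full_contract(T, d):
    """Contract every index of a nested-list tensor T with the vector d."""
    if not isinstance(T, list):
        return T
    return sum(di * full_contract(Ti, d) for di, Ti in zip(d, T))


def taylor_model(f0, derivatives, d):
    """f0 + sum_k D^k[d, ..., d] / k!, with derivatives = [gradient, Hessian, third-order, ...]."""
    value = f0
    for k, D in enumerate(derivatives, start=1):
        value += full_contract(D, d) / math.factorial(k)
    return value


# f(x1, x2) = x1**3 + x1*x2 expanded around x0 = (1, 2)
f0 = 3.0
grad = [5.0, 1.0]                   # (3*x1**2 + x2, x1)
hess = [[6.0, 1.0], [1.0, 0.0]]     # [[6*x1, 1], [1, 0]]
third = [[[6.0, 0.0], [0.0, 0.0]],  # only d^3 f / dx1^3 = 6 is nonzero
         [[0.0, 0.0], [0.0, 0.0]]]

d = [0.5, -1.0]
print(taylor_model(f0, [grad, hess, third], d))  # 4.875, equal to f(1.5, 1.0) for this cubic
```

For a cubic function, the third-order model is exact, which makes the example easy to check.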

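Parameters:

  • order (int, default: 4 ) –

    order of the method, i.e. the number of Taylor series terms (orders of derivatives) used to approximate the function. Defaults to 4.

  • trust_method (str | None, default: 'bounds' ) –

    method used for the trust region:
      - "bounds" - the model is minimized within bounds defined by the trust region.
      - "proximal" - the model is minimized with a penalty for going too far from the current point.
      - "none" - disables the trust region.
    Defaults to 'bounds'.

  • nplus (float, default: 3.5 ) –

    trust region multiplier after a very good step (rho > rho_good) that reached the trust region boundary; with "proximal" it is applied after every very good step. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    trust region multiplier after a failed step (rho < rho_bad). Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    the trust region may grow when the ratio rho of actual to predicted loss reduction exceeds this value. Defaults to 0.99.

  • rho_bad (float, default: 0.0001 ) –

    the trust region shrinks when rho is below this value. Defaults to 1e-4.

  • init (float | None, default: None ) –

    initial trust region size. If None, defaults to 1 for trust_method="bounds" and 0.1 otherwise. Defaults to None.

  • eta (float, default: 1e-06 ) –

    a step is accepted when rho exceeds this value. Defaults to 1e-6.

  • max_attempts (int, default: 10 ) –

    maximum number of model minimizations per step; if no step is accepted, the update is zero. Defaults to 10.

  • boundary_tol (float, default: 0.01 ) –

    with "bounds", the trust region only grows if (radius - step norm) / radius <= boundary_tol, i.e. the step reached the boundary. Defaults to 1e-2.

  • de_iters (int | None, default: None ) –

    if specified, the model is first minimized via differential evolution to possibly escape local minima, and the result is then passed to scipy.optimize.minimize. Defaults to None.

  • vectorize (bool, default: True ) –

    whether to enable vectorized Jacobians (usually faster). Defaults to True.
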
Source code in torchzero/modules/higher_order/higher_order_newton.py
class HigherOrderNewton(Module):
    """A basic arbitrary-order Newton's method with an optional trust region and proximal penalty.

    This constructs an nth-order Taylor approximation via autograd and minimizes it with
    ``scipy.optimize.minimize`` trust region Newton solvers, with an optional proximal penalty.

    The Hessian of the Taylor approximation is easy to evaluate and can be evaluated in batched mode,
    so this can be more efficient in some very specific cases.

    Notes:
        - In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd.
        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate higher-order derivatives. The closure must accept a ``backward`` argument (refer to the documentation).
        - This uses roughly O(N^order) memory, and solving the subproblem is very expensive.
        - The "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.

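    Args:
        order (int, optional):
            Order of the method, i.e. the number of Taylor series terms (orders of derivatives) used to approximate the function. Defaults to 4.
        trust_method (str | None, optional):
            Method used for the trust region.
            - "bounds" - the model is minimized within bounds defined by the trust region.
            - "proximal" - the model is minimized with a penalty for going too far from the current point.
            - "none" - disables the trust region.

            Defaults to 'bounds'.
        nplus (float, optional):
            Trust region multiplier after a very good step (``rho > rho_good``) that reached the trust region boundary;
            with "proximal" it is applied after every very good step. Defaults to 3.5.
        nminus (float, optional): trust region multiplier after a failed step (``rho < rho_bad``). Defaults to 0.25.
        rho_good (float, optional):
            The trust region may grow when the ratio ``rho`` of actual to predicted loss reduction exceeds this value. Defaults to 0.99.
        rho_bad (float, optional): the trust region shrinks when ``rho`` is below this value. Defaults to 1e-4.
        init (float | None, optional):
            Initial trust region size. If None, defaults to 1 for :code:`trust_method="bounds"` and 0.1 otherwise. Defaults to None.
        eta (float, optional): a step is accepted when ``rho`` exceeds this value. Defaults to 1e-6.
        max_attempts (int, optional):
            Maximum number of model minimizations per step; if no step is accepted, the update is zero. Defaults to 10.
        boundary_tol (float, optional):
            With "bounds", the trust region only grows if ``(radius - step norm) / radius <= boundary_tol``,
            i.e. the step reached the boundary. Defaults to 1e-2.
        de_iters (int | None, optional):
            If this is specified, the model is minimized via differential evolution first to possibly escape local minima,
            then it is passed to scipy.optimize.minimize. Defaults to None.
        vectorize (bool, optional): whether to enable vectorized Jacobians (usually faster). Defaults to True.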
    """
    def __init__(
        self,
        order: int = 4,
        trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float | None = None,
        eta: float = 1e-6,
        max_attempts = 10,
        boundary_tol: float = 1e-2,
        de_iters: int | None = None,
        vectorize: bool = True,
    ):
        if init is None:
            if trust_method == 'bounds': init = 1
            else: init = 0.1

        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        params = TensorList(var.params)
        closure = var.closure
        if closure is None: raise RuntimeError('HigherOrderNewton requires closure')

        settings = self.settings[params[0]]
        order = settings['order']
        nplus = settings['nplus']
        nminus = settings['nminus']
        eta = settings['eta']
        init = settings['init']
        trust_method = settings['trust_method']
        de_iters = settings['de_iters']
        max_attempts = settings['max_attempts']
        vectorize = settings['vectorize']
        boundary_tol = settings['boundary_tol']
        rho_good = settings['rho_good']
        rho_bad = settings['rho_bad']

        # ------------------------ calculate grad and hessian ------------------------ #
        with torch.enable_grad():
            loss = var.loss = var.loss_approx = closure(False)

            g_list = torch.autograd.grad(loss, params, create_graph=True)
            var.grad = list(g_list)

            g = torch.cat([t.ravel() for t in g_list])
            n = g.numel()
            derivatives = [g]
            T = g # current derivatives tensor

            # get all derivative up to order
            for o in range(2, order + 1):
                is_last = o == order
                T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
                with torch.no_grad() if is_last else nullcontext():
                    # the shape is (ndim, ) * order
                    T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
                    derivatives.append(T)

        x0 = torch.cat([p.ravel() for p in params])

        success = False
        x_star = None
        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            # load trust region value
            trust_value = self.global_state.get('trust_region', init)

            # make sure its not too small or too large
            finfo = torch.finfo(x0.dtype)
            if trust_value < finfo.tiny*2 or trust_value > finfo.max / (2*nplus):
                trust_value = self.global_state['trust_region'] = settings['init']

            # determine tr and prox values
            if trust_method is None: trust_method = 'none'
            else: trust_method = trust_method.lower()

            if trust_method == 'none':
                trust_region = None
                prox = 0

            elif trust_method == 'bounds':
                trust_region = trust_value
                prox = 0

            elif trust_method == 'proximal':
                trust_region = None
                prox = 1 / trust_value

            else:
                raise ValueError(trust_method)

            # minimize the model
            x_star, expected_loss = _poly_minimize(
                trust_region=trust_region,
                prox=prox,
                de_iters=de_iters,
                c=loss.item(),
                x=x0,
                derivatives=derivatives,
            )

            # update trust region
            if trust_method == 'none':
                success = True
            else:
                pred_reduction = loss - expected_loss

                vec_to_tensors_(x_star, params)
                loss_star = closure(False)
                vec_to_tensors_(x0, params)
                reduction = loss - loss_star

                rho = reduction / (max(pred_reduction, 1e-8))
                # failed step
                if rho < rho_bad:
                    self.global_state['trust_region'] = trust_value * nminus

                # very good step
                elif rho > rho_good:
                    step = (x_star - x0)
                    magn = torch.linalg.vector_norm(step) # pylint:disable=not-callable
                    if trust_method == 'proximal' or (trust_value - magn) / trust_value <= boundary_tol:
                        # close to boundary
                        self.global_state['trust_region'] = trust_value * nplus

                # if the ratio is high enough then accept the proposed step
                success = rho > eta

        assert x_star is not None
        if success:
            difference = vec_to_tensors(x0 - x_star, params)
            var.update = list(difference)
        else:
            var.update = params.zeros_like()
        return var

Horisho

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Hoshino's variable metric Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Hoshino, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class Horisho(_InverseHessianUpdateStrategyDefaults):
    """
    Hoshino's variable metric Quasi-Newton method.

    Note:
        A line search is recommended.

    Warning:
        This uses at least O(N^2) memory.

    Reference:
        Hoshino, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
    """

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return hoshino_H_(H=H, s=s, y=y, tol=setting['tol'])

HpuEstimate

Bases: torchzero.core.transform.Transform

Returns y/||s||, where y is the difference between the current and previous update (for example, the gradient) and s is the difference between the current and previous parameters. The returned tensors are a finite-difference approximation to the Hessian times the normalized step s/||s||.

Source code in torchzero/modules/misc/misc.py
class HpuEstimate(Transform):
    """Returns ``y/||s||``, where ``y`` is the difference between the current and previous update (for example, the gradient) and ``s`` is the difference between the current and previous parameters. The returned tensors are a finite-difference approximation to the Hessian times the normalized step ``s/||s||``."""
    def __init__(self):
        defaults = dict()
        super().__init__(defaults, uses_grad=False)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_params', 'prev_update')

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        prev_params, prev_update = self.get_state(params, 'prev_params', 'prev_update') # initialized to 0
        s = torch._foreach_sub(params, prev_params)
        y = torch._foreach_sub(tensors, prev_update)
        for p, c in zip(prev_params, params): p.copy_(c)
        for p, c in zip(prev_update, tensors): p.copy_(c)
        torch._foreach_div_(y, torch.linalg.norm(torch.cat([t.ravel() for t in s])).clip(min=1e-8)) # pylint:disable=not-callable
        self.store(params, 'y', y)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        return [self.state[p]['y'] for p in params]
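On a quadratic f(x) = 1/2 x^T A x, the gradient difference is exactly y = A s. So y/||s|| equals the Hessian applied to the unit step s/||s||. Below is a plain-Python sketch of the same computation on flat lists, using the module's 1e-8 floor on ||s||. The name hpu_estimate is hypothetical. The module initializes the previous parameters and update to zero, so its first call compares against zeros.

```python
import math


def hpu_estimate(params, prev_params, update, prev_update):
    """y / ||s|| with s = params - prev_params and y = update - prev_update (flat lists)."""
    s = [p - q for p, q in zip(params, prev_params)]
    y = [u - v for u, v in zip(update, prev_update)]
    s_norm = max(math.sqrt(sum(v * v for v in s)), 1e-8)  # same floor as the module
    return [v / s_norm for v in y]


# quadratic with Hessian A = diag(2, 4): the gradient is (2*x1, 4*x2)
grad = lambda x: [2.0 * x[0], 4.0 * x[1]]
x_prev, x = [0.0, 0.0], [3.0, 4.0]
print(hpu_estimate(x, x_prev, grad(x), grad(x_prev)))  # [1.2, 3.2] = A s / ||s||
```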

ICUM

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods because it updates only one column of the inverse Hessian approximation per step.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class ICUM(_InverseHessianUpdateStrategyDefaults):
    """
    Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods
    because it updates only one column of the inverse Hessian approximation per step.

    Note:
        A line search is recommended.

    Warning:
        This uses at least O(N^2) memory.

    Reference:
        Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return icum_H_(H=H, s=s, y=y)
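The column-updating idea can be sketched in a few lines. This sketch assumes the textbook inverse column-updating formula H_new = H + (s - H y) e_j^T / y_j, with j the index of the largest |y_j|. Under that formula, only column j of H changes, and the secant condition H_new y = s holds by construction. The sketch uses nested lists and the hypothetical name icum_update; it is not icum_H_ itself.

```python
def icum_update(H, s, y):
    """Return H + (s - H y) e_j^T / y_j, changing only column j of H (nested lists)."""
    n = len(y)
    j = max(range(n), key=lambda i: abs(y[i]))  # assumption: pivot on the largest |y_j|
    Hy = [sum(H[r][c] * y[c] for c in range(n)) for r in range(n)]
    new_H = [row[:] for row in H]
    for r in range(n):
        new_H[r][j] += (s[r] - Hy[r]) / y[j]
    return new_H


H = [[1.0, 0.0], [0.0, 1.0]]
s, y = [1.0, 2.0], [0.5, 4.0]
new_H = icum_update(H, s, y)
print(new_H)  # [[1.0, 0.125], [0.0, 0.5]]: only column 1 changed
```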

Identity

Bases: torchzero.core.module.Module

Identity operator that ignores its constructor arguments. It can also be used as the identity Hessian for trust region methods.

Source code in torchzero/modules/ops/utility.py
class Identity(Module):
    """Identity operator that ignores its constructor arguments. It can also be used as the identity Hessian for trust region methods."""
    def __init__(self, *args, **kwargs): super().__init__()
    def step(self, var): return var
    def get_H(self, var):
        n = sum(p.numel() for p in var.params)
        p = var.params[0]
        return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)

IntermoduleCautious

Bases: torchzero.core.module.Module

Masks the update of the :code:main module wherever its sign doesn't match the output of the :code:compare module; mode controls how mismatched entries are handled.

Parameters:

  • main (Chainable) –

    main module or sequence of modules whose update will be cautioned.

  • compare (Chainable) –

    modules or sequence of modules to compare the sign to.

  • normalize (bool, default: False ) –

    renormalize update after masking. Defaults to False.

  • eps (float, default: 1e-06 ) –

    epsilon for normalization. Defaults to 1e-6.

  • mode (str, default: 'zero' ) –

    what to do with updates with inconsistent signs:
      - "zero" - set them to zero (as in the paper)
      - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
      - "backtrack" - negate them

Source code in torchzero/modules/momentum/cautious.py
class IntermoduleCautious(Module):
    """Negaties update on :code:`main` module where it's sign doesn't match with output of :code:`compare` module.

    Args:
        main (Chainable): main module or sequence of modules whose update will be cautioned.
        compare (Chainable): modules or sequence of modules to compare the sign to.
        normalize (bool, optional):
            renormalize update after masking. Defaults to False.
        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
        mode (str, optional):
            what to do with updates with inconsistent signs.
            - "zero" - set them to zero (as in paper)
            - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
            - "backtrack" - negate them
    """
    def __init__(
        self,
        main: Chainable,
        compare: Chainable,
        normalize=False,
        eps=1e-6,
        mode: Literal["zero", "grad", "backtrack"] = "zero",
    ):

        defaults = dict(normalize=normalize, eps=eps, mode=mode)
        super().__init__(defaults)

        self.set_child('main', main)
        self.set_child('compare', compare)

    @torch.no_grad
    def step(self, var):
        main = self.children['main']
        compare = self.children['compare']

        main_var = main.step(var.clone(clone_update=True))
        var.update_attrs_from_clone_(main_var)

        compare_var = compare.step(var.clone(clone_update=True))
        var.update_attrs_from_clone_(compare_var)

        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.defaults)
        var.update = cautious_(
            TensorList(main_var.get_update()),
            TensorList(compare_var.get_update()),
            normalize=normalize,
            mode=mode,
            eps=eps,
        )

        return var

InverseFreeNewton

Bases: torchzero.core.module.Module

Inverse-free Newton's method.

.. note:: In most cases InverseFreeNewton should be the first module in the chain because it relies on autograd. Use the :code:inner argument if you wish to apply its preconditioning to another module's output.

.. note:: This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating the Hessian. The closure must accept a backward argument (refer to documentation).

.. warning:: This uses roughly O(N^2) memory.

Reference: Massalski, M., & Nockowska-Rosiak, M. (2025). Inverse-free Newton's method. Journal of Applied Analysis & Computation, 15(4), 2238–2257.

Source code in torchzero/modules/second_order/newton.py
class InverseFreeNewton(Module):
    """Inverse-free newton's method

    .. note::
        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.

    .. note::
        This module requires the a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients for calculating the hessian.
        The closure must accept a ``backward`` argument (refer to documentation).

    .. warning::
        this uses roughly O(N^2) memory.

    Reference
        Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.
    """
    def __init__(
        self,
        update_freq: int = 1,
        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
        vectorize: bool = True,
        inner: Chainable | None = None,
    ):
        defaults = dict(hessian_method=hessian_method, vectorize=vectorize, update_freq=update_freq)
        super().__init__(defaults)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def update(self, var):
        params = TensorList(var.params)
        closure = var.closure
        if closure is None: raise RuntimeError('InverseFreeNewton requires closure')

        settings = self.settings[params[0]]
        hessian_method = settings['hessian_method']
        vectorize = settings['vectorize']
        update_freq = settings['update_freq']

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        g_list = var.grad
        Y = None
        if step % update_freq == 0:
            # ------------------------ calculate grad and hessian ------------------------ #
            if hessian_method == 'autograd':
                with torch.enable_grad():
                    loss = var.loss = var.loss_approx = closure(False)
                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
                    g_list = [t[0] for t in g_list] # remove leading dim from loss
                    var.grad = g_list
                    H = flatten_jacobian(H_list)

            elif hessian_method in ('func', 'autograd.functional'):
                strat = 'forward-mode' if vectorize else 'reverse-mode'
                with torch.enable_grad():
                    g_list = var.get_grad(retain_graph=True)
                    H = hessian_mat(partial(closure, backward=False), params,
                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]

            else:
                raise ValueError(hessian_method)

            self.global_state["H"] = H

            # inverse free part
            if 'Y' not in self.global_state:
                num = H.T
                denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable
                finfo = torch.finfo(H.dtype)
                Y = self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))

            else:
                Y = self.global_state['Y']
                I = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
                I -= H @ Y
                Y = self.global_state['Y'] = Y @ I


    def apply(self, var):
        Y = self.global_state["Y"]
        params = var.params

        # -------------------------------- inner step -------------------------------- #
        update = var.get_update()
        if 'inner' in self.children:
            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)

        g = torch.cat([t.ravel() for t in update])

        # ----------------------------------- solve ---------------------------------- #
        var.update = vec_to_tensors(Y@g, params)

        return var

    def get_H(self,var):
        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
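The "inverse free part" of update maintains Y ≈ H^-1 with the Newton-Schulz iteration Y <- Y (2I - H Y), starting from Y0 = Hᵀ / (‖H‖_1 ‖H‖_inf), so no linear solve or factorization is needed. The plain-Python sketch below runs that iteration to convergence on a fixed 2x2 matrix. matmul and newton_schulz_inverse are hypothetical helpers, and iterating on a fixed H is a simplification: the module itself performs one step per Hessian evaluation, warm-starting from the previous Y while the Hessian changes.

```python
def matmul(A, B):
    """Plain-Python product of two dense matrices given as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def newton_schulz_inverse(H, iters=30):
    """Approximate H^-1 with the Newton-Schulz iteration Y <- Y (2I - H Y)."""
    n = len(H)
    norm_1 = max(sum(abs(H[i][j]) for i in range(n)) for j in range(n))  # max column sum
    norm_inf = max(sum(abs(v) for v in row) for row in H)  # max row sum
    # Y0 = H^T / (||H||_1 * ||H||_inf), the same initialization as InverseFreeNewton.update
    Y = [[H[j][i] / (norm_1 * norm_inf) for j in range(n)] for i in range(n)]
    for _ in range(iters):
        HY = matmul(H, Y)
        R = [[(2.0 if i == j else 0.0) - HY[i][j] for j in range(n)] for i in range(n)]
        Y = matmul(Y, R)  # one Newton-Schulz step: residual I - H Y gets squared
    return Y


Y = newton_schulz_inverse([[4.0, 1.0], [2.0, 3.0]])
# Y is close to the exact inverse [[0.3, -0.1], [-0.2, 0.4]]
```

Because the residual I - H Y is squared at every step, convergence is quadratic once it is small, which is why a single step per Hessian evaluation can be enough to track a slowly changing Hessian.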

LBFGS

Bases: torchzero.core.transform.Transform

Limited-memory BFGS algorithm. A line search or trust region is recommended.

Parameters:

  • history_size (int, default: 10 ) –

    number of past parameter differences and gradient differences to store. Defaults to 10.

  • ptol (float | None, default: 1e-32 ) –

    skips updating the history if the maximum absolute value of the parameter difference does not exceed this value. Defaults to 1e-32.

  • ptol_restart (bool, default: False ) –

    If True, whenever the parameter difference does not exceed ptol, L-BFGS state will be reset. Defaults to False.

  • gtol (float | None, default: 1e-32 ) –

    skips updating the history if the maximum absolute value of the gradient difference does not exceed this value. Defaults to 1e-32.

  • gtol_restart (bool, default: False ) –

    If True, whenever the gradient difference does not exceed gtol, L-BFGS state will be reset. Defaults to False.

  • sy_tol (float, default: 1e-32 ) –

    history will not be updated whenever s⋅y does not exceed this value (negative s⋅y means negative curvature). Defaults to 1e-32.

  • scale_first (bool, default: True ) –

    makes the first step, when no Hessian approximation is available yet, small to reduce the number of line search iterations. Defaults to True.

  • update_freq (int, default: 1 ) –

    how often to update L-BFGS history. Larger values may be better for stochastic optimization. Defaults to 1.

  • damping (Union, default: None ) –

    damping to use, can be "powell" or "double". Defaults to None.

  • inner (Chainable | None, default: None ) –

    optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.

Examples:

L-BFGS with line search

opt = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(100),
    tz.m.Backtracking()
)

L-BFGS with trust region

opt = tz.Modular(
    model.parameters(),
    tz.m.TrustCG(tz.m.LBFGS())
)
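The preconditioning step applies the inverse-Hessian approximation built from the stored (s, y) pairs. A common way to do this is the standard L-BFGS two-loop recursion, sketched below in plain Python. lbfgs_two_loop is a hypothetical helper, not torchzero's lbfgs_Hx, and the initial scaling gamma = s⋅y / y⋅y from the newest pair is an assumption that may differ from the library. Since only pairs with s⋅y > sy_tol are stored, rho = 1 / (s⋅y) stays positive.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def lbfgs_two_loop(g, s_history, y_history):
    """Approximate H^-1 g from stored pairs (oldest first) via the two-loop recursion.

    Hypothetical helper for illustration, not torchzero's lbfgs_Hx.
    """
    q = list(g)
    coeffs = []
    for s, y in reversed(list(zip(s_history, y_history))):  # newest pair first
        rho = 1.0 / dot(s, y)
        alpha = rho * dot(s, q)
        q = [qi - alpha * yi for qi, yi in zip(q, y)]
        coeffs.append((rho, alpha))
    # initial approximation H0 = gamma * I (identity if there is no history yet)
    if s_history:
        gamma = dot(s_history[-1], y_history[-1]) / dot(y_history[-1], y_history[-1])
    else:
        gamma = 1.0
    r = [gamma * qi for qi in q]
    for (s, y), (rho, alpha) in zip(zip(s_history, y_history), reversed(coeffs)):  # oldest first
        beta = rho * dot(y, r)
        r = [ri + (alpha - beta) * si for ri, si in zip(r, s)]
    return r


# Quadratic with Hessian diag(2, 4): two conjugate pairs recover the Newton direction
direction = lbfgs_two_loop([2.0, 4.0], [[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 4.0]])
print(direction)  # [1.0, 1.0]
```

On this quadratic the two stored pairs span the space and are conjugate, so the recursion reproduces the exact inverse Hessian; on general problems it only approximates it.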

Source code in torchzero/modules/quasi_newton/lbfgs.py
class LBFGS(Transform):
    """Limited-memory BFGS algorithm. A line search or trust region is recommended.

    Args:
        history_size (int, optional):
            number of past parameter differences and gradient differences to store. Defaults to 10.
        ptol (float | None, optional):
            skips updating the history if the maximum absolute value of the
            parameter difference does not exceed this value. Defaults to 1e-32.
        ptol_restart (bool, optional):
            If True, whenever the parameter difference does not exceed ``ptol``,
            L-BFGS state will be reset. Defaults to False.
        gtol (float | None, optional):
            skips updating the history if the maximum absolute value of the
            gradient difference does not exceed this value. Defaults to 1e-32.
        gtol_restart (bool, optional):
            If True, whenever the gradient difference does not exceed ``gtol``,
            L-BFGS state will be reset. Defaults to False.
        sy_tol (float, optional):
            history will not be updated whenever s⋅y does not exceed this value (negative s⋅y means negative curvature). Defaults to 1e-32.
        scale_first (bool, optional):
            makes the first step, when no Hessian approximation is available yet,
            small to reduce the number of line search iterations. Defaults to True.
        update_freq (int, optional):
            how often to update L-BFGS history. Larger values may be better for stochastic optimization. Defaults to 1.
        damping (DampingStrategyType, optional):
            damping to use, can be "powell" or "double". Defaults to None.
        inner (Chainable | None, optional):
            optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.

    ## Examples:

    L-BFGS with line search
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.LBFGS(100),
        tz.m.Backtracking()
    )
    ```

    L-BFGS with trust region
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.TrustCG(tz.m.LBFGS())
    )
    ```
    """
    def __init__(
        self,
        history_size=10,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        gtol_restart: bool = False,
        sy_tol: float = 1e-32,
        scale_first:bool=True,
        update_freq = 1,
        damping: DampingStrategyType = None,
        inner: Chainable | None = None,
    ):
        defaults = dict(
            history_size=history_size,
            scale_first=scale_first,
            ptol=ptol,
            gtol=gtol,
            ptol_restart=ptol_restart,
            gtol_restart=gtol_restart,
            sy_tol=sy_tol,
            damping = damping,
        )
        super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)

        self.global_state['s_history'] = deque(maxlen=history_size)
        self.global_state['y_history'] = deque(maxlen=history_size)
        self.global_state['sy_history'] = deque(maxlen=history_size)

    def _reset_self(self):
        self.state.clear()
        self.global_state['step'] = 0
        self.global_state['s_history'].clear()
        self.global_state['y_history'].clear()
        self.global_state['sy_history'].clear()

    def reset(self):
        self._reset_self()
        for c in self.children.values(): c.reset()

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('p_prev', 'g_prev')
        self.global_state.pop('step', None)

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        p = as_tensorlist(params)
        g = as_tensorlist(tensors)
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        # history of s and y
        s_history: deque[TensorList] = self.global_state['s_history']
        y_history: deque[TensorList] = self.global_state['y_history']
        sy_history: deque[torch.Tensor] = self.global_state['sy_history']

        ptol = self.defaults['ptol']
        gtol = self.defaults['gtol']
        ptol_restart = self.defaults['ptol_restart']
        gtol_restart = self.defaults['gtol_restart']
        sy_tol = self.defaults['sy_tol']
        damping = self.defaults['damping']

        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)

        # 1st step - there are no previous params and grads, lbfgs will do normalized SGD step
        if step == 0:
            s = None; y = None; sy = None
        else:
            s = p - p_prev
            y = g - g_prev

            if damping is not None:
                s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())

            sy = s.dot(y)

        below_tol = False
        # tolerance on parameter difference to avoid exploding after converging
        if ptol is not None:
            if s is not None and s.abs().global_max() <= ptol:
                if ptol_restart:
                    self._reset_self()
                sy = None
                below_tol = True

        # tolerance on gradient difference to avoid exploding when there is no curvature
        if gtol is not None:
            if y is not None and y.abs().global_max() <= gtol:
                if gtol_restart: self._reset_self()
                sy = None
                below_tol = True

        # store previous params and grads
        if not below_tol:
            p_prev.copy_(p)
            g_prev.copy_(g)

        # update effective preconditioning state
        if sy is not None and sy > sy_tol:
            assert s is not None and y is not None and sy is not None

            s_history.append(s)
            y_history.append(y)
            sy_history.append(sy)

    def get_H(self, var=...):
        s_history = [tl.to_vec() for tl in self.global_state['s_history']]
        y_history = [tl.to_vec() for tl in self.global_state['y_history']]
        sy_history = self.global_state['sy_history']
        return LBFGSLinearOperator(s_history, y_history, sy_history)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        scale_first = self.defaults['scale_first']

        tensors = as_tensorlist(tensors)

        s_history = self.global_state['s_history']
        y_history = self.global_state['y_history']
        sy_history = self.global_state['sy_history']

        # precondition
        dir = lbfgs_Hx(
            x=tensors,
            s_history=s_history,
            y_history=y_history,
            sy_history=sy_history,
        )

        # scale 1st step
        if scale_first and self.global_state.get('step', 1) == 1:
            dir *= initial_step_size(dir, eps=1e-7)

        return dir

LMAdagrad

Bases: torchzero.core.transform.TensorwiseTransform

Limited-memory full matrix Adagrad.

The update rule stacks recent gradients into M, computes U, S <- SVD(M), and calculates the update as U S^-1 Uᵀg. It obtains U and S^2 via an eigendecomposition of MᵀM rather than a full SVD, which is faster when V is not needed.

This is equivalent to full-matrix Adagrad on recent gradients.

Parameters:

  • history_size (int, default: 100 ) –

    number of past gradients to store. Defaults to 100.

  • update_freq (int, default: 1 ) –

    frequency of updating the preconditioner (U and S). Defaults to 1.

  • damping (float, default: 0.0001 ) –

    damping value. Defaults to 1e-4.

  • rdamping (float, default: 0 ) –

    value of damping relative to singular values norm. Defaults to 0.

  • order (int, default: 1 ) –

    order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.

  • true_damping (bool, default: True ) –

    If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.

  • U_beta (float | None, default: None ) –

    momentum for U (too unstable, don't use). Defaults to None.

  • L_beta (float | None, default: None ) –

    momentum for L (too unstable, don't use). Defaults to None.

  • interval (int, default: 1 ) –

    Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.

  • concat_params (bool, default: True ) –

    if True, treats all parameters as a single vector, so that correlations between different parameters are also whitened. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioner will be applied to output of this module. Defaults to None.

Examples:

Limited-memory Adagrad

optimizer = tz.Modular(
    model.parameters(),
    tz.m.LMAdagrad(),
    tz.m.LR(0.1)
)
Adam with L-Adagrad preconditioner (the second beta used for debiasing, 0.999, is chosen arbitrarily)

optimizer = tz.Modular(
    model.parameters(),
    tz.m.LMAdagrad(inner=tz.m.EMA()),
    tz.m.Debias(0.9, 0.999),
    tz.m.LR(0.01)
)

Stable Adam with L-Adagrad preconditioner (this is what I would recommend)

optimizer = tz.Modular(
    model.parameters(),
    tz.m.LMAdagrad(inner=tz.m.EMA()),
    tz.m.Debias(0.9, 0.999),
    tz.m.ClipNormByEMA(max_ema_growth=1.2),
    tz.m.LR(0.01)
)
Reference: Agarwal, N., et al. (2019). Efficient full-matrix adaptive regularization. In International Conference on Machine Learning (pp. 102–110). PMLR.

Source code in torchzero/modules/adaptive/lmadagrad.py
class LMAdagrad(TensorwiseTransform):
    """
    Limited-memory full matrix Adagrad.

    The update rule stacks recent gradients into M, computes U, S <- SVD(M), and calculates the update as U S^-1 Uᵀg.
    It obtains U and S^2 via an eigendecomposition of MᵀM rather than a full SVD, which is faster when V is not needed.

    This is equivalent to full-matrix Adagrad on recent gradients.

    Args:
        history_size (int, optional): number of past gradients to store. Defaults to 100.
        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
        damping (float, optional): damping value. Defaults to 1e-4.
        rdamping (float, optional): value of damping relative to singular values norm. Defaults to 0.
        order (int, optional):
            order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
        true_damping (bool, optional):
            If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
        U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
        L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
        interval (int, optional): Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
        concat_params (bool, optional): if True, treats all parameters as a single vector, so that correlations between different parameters are also whitened. Defaults to True.
        inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.

    ## Examples:

    Limited-memory Adagrad

    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.LMAdagrad(),
        tz.m.LR(0.1)
    )
    ```
    Adam with L-Adagrad preconditioner (the second beta used for debiasing, 0.999, is chosen arbitrarily)

    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.LMAdagrad(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.LR(0.01)
    )
    ```

    Stable Adam with L-Adagrad preconditioner (this is what I would recommend)

    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.LMAdagrad(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.ClipNormByEMA(max_ema_growth=1.2),
        tz.m.LR(0.01)
    )
    ```
    Reference:
        Agarwal, N., et al. (2019). Efficient full-matrix adaptive regularization. In International Conference on Machine Learning (pp. 102–110). PMLR.
    """

    def __init__(
        self,
        history_size: int = 100,
        update_freq: int = 1,
        damping: float = 1e-4,
        rdamping: float = 0,
        order: int = 1,
        true_damping: bool = True,
        U_beta: float | None = None,
        L_beta: float | None = None,
        interval: int = 1,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        # history is still updated each step so Precondition's update_freq has different meaning
        defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, L_beta=L_beta)
        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner, update_freq=interval)

    @torch.no_grad
    def update_tensor(self, tensor, param, grad, loss, state, setting):
        order = setting['order']
        history_size = setting['history_size']
        update_freq = setting['update_freq']
        damping = setting['damping']
        rdamping = setting['rdamping']
        U_beta = setting['U_beta']
        L_beta = setting['L_beta']

        if 'history' not in state: state['history'] = deque(maxlen=history_size)
        history = state['history']

        if order == 1:
            t = tensor.clone().view(-1)
            history.append(t)
        else:

            # if order=2, history is of gradient differences, order 3 is differences between differences, etc
            # scaled by parameter differences
            cur_p = param.clone()
            cur_g = tensor.clone()
            eps = torch.finfo(cur_p.dtype).tiny * 2
            for i in range(1, order):
                if f'prev_g_{i}' not in state:
                    state[f'prev_p_{i}'] = cur_p
                    state[f'prev_g_{i}'] = cur_g
                    break

                s = cur_p - state[f'prev_p_{i}']
                y = cur_g - state[f'prev_g_{i}']
                state[f'prev_p_{i}'] = cur_p
                state[f'prev_g_{i}'] = cur_g
                cur_p = s
                cur_g = y

                if i == order - 1:
                    cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=eps) # pylint:disable=not-callable
                    history.append(cur_g.view(-1))

        step = state.get('step', 0)
        if step % update_freq == 0 and len(history) != 0:
            U, L = lm_adagrad_update(history, damping=damping, rdamping=rdamping)
            maybe_lerp_(state, U_beta, 'U', U)
            maybe_lerp_(state, L_beta, 'L', L)

        if len(history) != 0:
            state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        U = state.get('U', None)
        if U is None:
            # make a conservative step to avoid issues due to different GD scaling
            return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]

        L = state['L']
        update = lm_adagrad_apply(tensor.view(-1), U, L).view_as(tensor)

        return update
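The preconditioner U S^-1 Uᵀ can be illustrated for the special case of two stored gradients, where the eigendecomposition of the 2x2 Gram matrix MᵀM has a closed form. lm_adagrad_apply_two is a hypothetical plain-Python helper, not torchzero's lm_adagrad_update or lm_adagrad_apply, and it omits damping. Since M Mᵀ = U S² Uᵀ, applying U S^-1 Uᵀ amounts to the pseudo-inverse square root of the sum of outer products of the stored gradients, i.e. the full-matrix Adagrad preconditioner restricted to that history.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def lm_adagrad_apply_two(g1, g2, g, eps=1e-12):
    """Return U S^-1 U^T g, where U S V^T is the thin SVD of M = [g1 g2].

    Hypothetical helper for illustration only; damping is omitted.
    The 2x2 Gram matrix M^T M = V S^2 V^T is diagonalized in closed form,
    and U is recovered column by column as u = M v / sigma.
    """
    a, b, c = dot(g1, g1), dot(g1, g2), dot(g2, g2)
    if abs(b) <= eps:
        # Gram matrix is already diagonal: eigenvectors are the coordinate axes
        pairs = [(a, (1.0, 0.0)), (c, (0.0, 1.0))]
    else:
        mid, rad = (a + c) / 2, math.hypot((a - c) / 2, b)
        pairs = []
        for lam in (mid + rad, mid - rad):  # eigenvalues = squared singular values
            norm = math.hypot(b, lam - a)
            pairs.append((lam, (b / norm, (lam - a) / norm)))

    out = [0.0] * len(g)
    for lam, (v1, v2) in pairs:
        if lam <= eps:
            continue  # skip directions with (near) zero singular value
        sigma = math.sqrt(lam)
        u = [(v1 * x + v2 * y) / sigma for x, y in zip(g1, g2)]  # left singular vector
        coef = dot(u, g) / sigma
        out = [o + coef * ui for o, ui in zip(out, u)]
    return out


# Applying the preconditioner twice gives (M M^T)^-1 g = [0.6, -0.4] up to rounding
g1, g2, g = [2.0, 1.0], [1.0, 3.0], [1.0, -1.0]
twice = lm_adagrad_apply_two(g1, g2, lm_adagrad_apply_two(g1, g2, g))
```

When the history does not span the space (for example g2 parallel to g1), the zero singular value is skipped and the update is confined to the span of the stored gradients.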

LR

Bases: torchzero.core.transform.Transform

Learning rate. Adding this module also adds support for LR schedulers.

Source code in torchzero/modules/step_size/lr.py
class LR(Transform):
    """Learning rate. Adding this module also adds support for LR schedulers."""
    def __init__(self, lr: float):
        defaults=dict(lr=lr)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)

LSR1

Bases: torchzero.core.transform.Transform

Limited-memory SR1 algorithm. A line search or trust region is recommended.

Parameters:

  • history_size (int, default: 10 ) –

    number of past parameter differences and gradient differences to store. Defaults to 10.

  • ptol (float | None, default: None ) –

    skips updating the history if the maximum absolute value of the parameter difference does not exceed this value. Defaults to None.

  • ptol_restart (bool, default: False ) –

    If True, whenever the parameter difference does not exceed ptol, L-SR1 state will be reset. Defaults to False.

  • gtol (float | None, default: None ) –

    skips updating the history if the maximum absolute value of the gradient difference does not exceed this value. Defaults to None.

  • gtol_restart (bool, default: False ) –

    If True, whenever the gradient difference does not exceed gtol, L-SR1 state will be reset. Defaults to False.

  • scale_first (bool, default: False ) –

    makes the first step, when no Hessian approximation is available yet, small to reduce the number of line search iterations. Defaults to False.

  • update_freq (int, default: 1 ) –

    how often to update L-SR1 history. Larger values may be better for stochastic optimization. Defaults to 1.

  • damping (Union, default: None ) –

    damping to use, can be "powell" or "double". Defaults to None.

  • compact (bool) –

    if True, uses the compact representation of L-SR1, which is much faster computationally but less stable.

  • inner (Chainable | None, default: None ) –

    optional inner modules applied after updating L-SR1 history and before preconditioning. Defaults to None.

Examples:

L-SR1 with line search

opt = tz.Modular(
    model.parameters(),
    tz.m.LSR1(),
    tz.m.StrongWolfe(c2=0.1, fallback=True)
)

L-SR1 with trust region

opt = tz.Modular(
    model.parameters(),
    tz.m.TrustCG(tz.m.LSR1())
)
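For intuition, here is one SR1 update of a dense inverse-Hessian approximation, H_new = H + v vᵀ / (vᵀy) with v = s - H y. This is a plain-Python sketch with a hypothetical helper, not torchzero's lsr1_Hx: L-SR1 applies the same kind of updates implicitly from the last history_size pairs instead of storing H. The skipping rule |vᵀy| <= r ‖v‖ ‖y‖ is the usual SR1 safeguard and is an assumption here; the update_tensors code stores every pair that passes ptol and gtol, and whether lsr1_Hx skips pairs internally is not shown.

```python
import math


def sr1_inverse_update(H, s, y, r=1e-8):
    """One SR1 update of a dense inverse-Hessian approximation H (list of rows).

    Hypothetical helper for illustration, not torchzero's lsr1_Hx.
    """
    n = len(s)
    Hy = [sum(H[i][k] * y[k] for k in range(n)) for i in range(n)]
    v = [s[i] - Hy[i] for i in range(n)]  # secant residual s - H y
    vy = sum(vi * yi for vi, yi in zip(v, y))
    norm_v = math.sqrt(sum(vi * vi for vi in v))
    norm_y = math.sqrt(sum(yi * yi for yi in y))
    if abs(vy) <= r * norm_v * norm_y:
        return [row[:] for row in H]  # denominator too small: skip this pair
    # symmetric rank-one correction; H_new @ y == s afterwards
    return [[H[i][j] + v[i] * v[j] / vy for j in range(n)] for i in range(n)]


H_new = sr1_inverse_update([[1.0, 0.0], [0.0, 1.0]], s=[1.0, 0.5], y=[2.0, 1.5])
```

Unlike BFGS, SR1 updates need not keep H positive definite, so the resulting direction is not guaranteed to be a descent direction.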

Source code in torchzero/modules/quasi_newton/lsr1.py
class LSR1(Transform):
    """Limited-memory SR1 algorithm. A line search or trust region is recommended.

    Args:
        history_size (int, optional):
            number of past parameter differences and gradient differences to store. Defaults to 10.
        ptol (float | None, optional):
            skips updating the history if the maximum absolute value of the
            parameter difference does not exceed this value. Defaults to None.
        ptol_restart (bool, optional):
            If True, whenever the parameter difference does not exceed ``ptol``,
            L-SR1 state will be reset. Defaults to False.
        gtol (float | None, optional):
            skips updating the history if the maximum absolute value of the
            gradient difference does not exceed this value. Defaults to None.
        gtol_restart (bool, optional):
            If True, whenever the gradient difference does not exceed ``gtol``,
            L-SR1 state will be reset. Defaults to False.
        scale_first (bool, optional):
            makes the first step, when no Hessian approximation is available yet,
            small to reduce the number of line search iterations. Defaults to False.
        update_freq (int, optional):
            how often to update L-SR1 history. Larger values may be better for stochastic optimization. Defaults to 1.
        damping (DampingStrategyType, optional):
            damping to use, can be "powell" or "double". Defaults to None.
        compact (bool, optional):
            if True, uses the compact representation of L-SR1, which is much faster computationally but less stable.
        inner (Chainable | None, optional):
            optional inner modules applied after updating L-SR1 history and before preconditioning. Defaults to None.

    ## Examples:

    L-SR1 with line search
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.LSR1(),
        tz.m.StrongWolfe(c2=0.1, fallback=True)
    )
    ```

    L-SR1 with trust region
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.TrustCG(tz.m.LSR1())
    )
    ```
    """
    def __init__(
        self,
        history_size=10,
        ptol: float | None = None,
        ptol_restart: bool = False,
        gtol: float | None = None,
        gtol_restart: bool = False,
        scale_first:bool=False,
        update_freq = 1,
        damping: DampingStrategyType = None,
        inner: Chainable | None = None,
    ):
        defaults = dict(
            history_size=history_size,
            scale_first=scale_first,
            ptol=ptol,
            gtol=gtol,
            ptol_restart=ptol_restart,
            gtol_restart=gtol_restart,
            damping = damping,
        )
        super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)

        self.global_state['s_history'] = deque(maxlen=history_size)
        self.global_state['y_history'] = deque(maxlen=history_size)

    def _reset_self(self):
        self.state.clear()
        self.global_state['step'] = 0
        self.global_state['s_history'].clear()
        self.global_state['y_history'].clear()

    def reset(self):
        self._reset_self()
        for c in self.children.values(): c.reset()

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('p_prev', 'g_prev')
        self.global_state.pop('step', None)

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        p = as_tensorlist(params)
        g = as_tensorlist(tensors)
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        # history of s and y
        s_history: deque = self.global_state['s_history']
        y_history: deque = self.global_state['y_history']

        ptol = self.defaults['ptol']
        gtol = self.defaults['gtol']
        ptol_restart = self.defaults['ptol_restart']
        gtol_restart = self.defaults['gtol_restart']
        damping = self.defaults['damping']

        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)

        # 1st step - there are no previous params and grads, lsr1 will do normalized SGD step
        if step == 0:
            s = None; y = None; sy = None
        else:
            s = p - p_prev
            y = g - g_prev

            if damping is not None:
                s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())

            sy = s.dot(y)

        below_tol = False
        # tolerance on parameter difference to avoid exploding after converging
        if ptol is not None:
            if s is not None and s.abs().global_max() <= ptol:
                if ptol_restart: self._reset_self()
                sy = None
                below_tol = True

        # tolerance on gradient difference to avoid exploding when there is no curvature
        if gtol is not None:
            if y is not None and y.abs().global_max() <= gtol:
                if gtol_restart: self._reset_self()
                sy = None
                below_tol = True

        # store previous params and grads
        if not below_tol:
            p_prev.copy_(p)
            g_prev.copy_(g)

        # update effective preconditioning state
        if sy is not None:
            assert s is not None and y is not None and sy is not None

            s_history.append(s)
            y_history.append(y)

    def get_H(self, var=...):
        s_history = [tl.to_vec() for tl in self.global_state['s_history']]
        y_history = [tl.to_vec() for tl in self.global_state['y_history']]
        return LSR1LinearOperator(s_history, y_history)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        scale_first = self.defaults['scale_first']

        tensors = as_tensorlist(tensors)

        s_history = self.global_state['s_history']
        y_history = self.global_state['y_history']

        # precondition
        dir = lsr1_Hx(
            x=tensors,
            s_history=s_history,
            y_history=y_history,
        )

        # scale 1st step
        if scale_first and self.global_state.get('step', 1) == 1:
            dir *= initial_step_size(dir, eps=1e-7)

        return dir
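update_tensors above stores the pairs s = p - p_prev and y = g - g_prev. apply_tensors then multiplies the update by the L-SR1 approximation through lsr1_Hx, which is not shown in this section. As a rough illustration of what a limited-memory SR1 inverse-Hessian product computes, the sketch below replays the SR1 inverse update H <- H + u u^T / (u^T y), with u = s - H y, over the stored pairs on flat vectors. The starting point H_0 = I, the skipping threshold and the name lsr1_inverse_hvp are assumptions for illustration; the library's lsr1_Hx may scale or safeguard differently.

```python
import torch


def lsr1_inverse_hvp(x, s_history, y_history, skip_tol=1e-8):
    """Hypothetical helper: apply an L-SR1 inverse-Hessian approximation H to x.

    Starts from H_0 = I and, for each stored pair (s, y), applies
    H <- H + u u^T / (u^T y) with u = s - H y.
    H is never formed; it is kept as a list of rank-one terms (u, u^T y).
    """
    us, denoms = [], []

    def apply_h(v):
        out = v.clone()  # H_0 = I (assumption: no initial scaling)
        for u, d in zip(us, denoms):
            out = out + u * (u.dot(v) / d)
        return out

    for s, y in zip(s_history, y_history):
        u = s - apply_h(y)
        d = u.dot(y)
        # standard SR1 safeguard: skip pairs whose denominator is tiny
        if d.abs() <= skip_tol * u.norm() * y.norm():
            continue
        us.append(u)
        denoms.append(d)

    return apply_h(x)
```

With an empty history the function returns x unchanged. On a quadratic with y = A s, SR1 satisfies the secant condition H y = s for every accepted pair, which is a convenient sanity check.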

LambdaHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class LambdaHomotopy(HomotopyBase):
    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        defaults = dict(fn=fn)
        super().__init__(defaults)

    def loss_transform(self, loss): return self.defaults['fn'](loss)

LaplacianSmoothing

Bases: torchzero.core.transform.Transform

Applies Laplacian smoothing, solved with a fast Fourier transform (FFT). Laplacian smoothing can improve generalization.
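The FFT solver works because the smoothing operator is circulant. _precompute_denominator is not shown in this section; following the construction in Osher et al., the sketch below assumes the denominator holds the eigenvalues of A = I - sigma * L, where L is the periodic 1-D discrete Laplacian with stencil [1, -2, 1]. Solving A x = v then reduces to an elementwise division in Fourier space. laplacian_denominator and laplacian_smooth are hypothetical names, and the library's exact denominator may differ.

```python
import math

import torch


def laplacian_denominator(n: int, sigma: float, dtype=torch.float64) -> torch.Tensor:
    """Eigenvalues of A = I - sigma * L for the periodic 1-D Laplacian L.

    L has -2 on the diagonal and 1 on the two (wrapped-around) off-diagonals,
    so A is circulant and its eigenvalues are 1 + 2*sigma*(1 - cos(2*pi*k/n)).
    """
    k = torch.arange(n, dtype=dtype)
    return 1 + 2 * sigma * (1 - torch.cos(2 * math.pi * k / n))


def laplacian_smooth(v: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Solve A x = v with an FFT, treating v as one flattened vector."""
    flat = v.reshape(-1).to(torch.float64)
    denom = laplacian_denominator(flat.numel(), sigma)
    # the DFT diagonalizes a circulant matrix, so the solve is an elementwise division
    smoothed = torch.fft.ifft(torch.fft.fft(flat) / denom).real
    return smoothed.reshape(v.shape).to(v.dtype)
```

The k = 0 eigenvalue is 1, so a constant vector passes through unchanged, while high-frequency components are damped more strongly as sigma grows; sigma = 0 leaves the input untouched.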

Parameters:

  • sigma (float, default: 1 ) –

    controls the amount of smoothing. Defaults to 1.

  • layerwise (bool, default: True ) –

    If True, applies smoothing to each parameter's gradient separately; otherwise applies it to all gradients concatenated into a single vector. Defaults to True.

  • min_numel (int, default: 4 ) –

    Size threshold: only parameters with more than min_numel elements are smoothed. Only has an effect if layerwise is True. Defaults to 4.

  • target (str, default: 'update' ) –

    which attribute of var to transform. Defaults to 'update'.

Examples:

Laplacian Smoothing Gradient Descent optimizer as in the paper

opt = tz.Modular(
    model.parameters(),
    tz.m.LaplacianSmoothing(),
    tz.m.LR(1e-2),
)

Reference

Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.

Source code in torchzero/modules/smoothing/laplacian.py
class LaplacianSmoothing(Transform):
    """Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.

    Args:
        sigma (float, optional): controls the amount of smoothing. Defaults to 1.
        layerwise (bool, optional):
            If True, applies smoothing to each parameter's gradient separately;
            otherwise applies it to all gradients concatenated into a single vector. Defaults to True.
        min_numel (int, optional):
            size threshold: only parameters with more than ``min_numel`` elements are smoothed.
            Only has an effect if `layerwise` is True. Defaults to 4.
        target (str, optional):
            which attribute of var to transform. Defaults to 'update'.

    Examples:
        Laplacian Smoothing Gradient Descent optimizer as in the paper

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LaplacianSmoothing(),
                tz.m.LR(1e-2),
            )

    Reference:
        Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.

    """
    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Target = 'update'):
        defaults = dict(sigma = sigma, layerwise=layerwise, min_numel=min_numel)
        super().__init__(defaults, uses_grad=False, target=target)
        # precomputed denominator for when layerwise=False
        self.global_state['full_denominator'] = None


    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        layerwise = settings[0]['layerwise']

        # layerwise laplacian smoothing
        if layerwise:

            # precompute the denominator for each layer and store it in each parameters state
            smoothed_target = TensorList()
            for p, t, state, setting in zip(params, tensors, states, settings):
                if p.numel() > setting['min_numel']:
                    if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, setting['sigma'])
                    smoothed_target.append(torch.fft.ifft(torch.fft.fft(t.view(-1)) / state['denominator']).real.view_as(t)) #pylint:disable=not-callable
                else:
                    smoothed_target.append(t)

            return smoothed_target

        # else
        # full laplacian smoothing
        # precompute full denominator
        tensors = TensorList(tensors)
        if self.global_state.get('full_denominator', None) is None:
            self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(), settings[0]['sigma'])

        # apply the smoothing
        vec = tensors.to_vec()
        return tensors.from_vec(torch.fft.ifft(torch.fft.fft(vec) / self.global_state['full_denominator']).real)#pylint:disable=not-callable

LastAbsoluteRatio

Bases: torchzero.core.transform.Transform

Outputs the ratio between the absolute values of the past two updates. The numerator is selected by the numerator argument: 'cur' for the current update, 'prev' for the previous one.
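To make the bookkeeping in apply_tensors concrete, here is a standalone single-tensor sketch of the same logic. AbsoluteRatioTracker is a hypothetical name and not part of torchzero. As in the source, the stored previous value starts at ones, only the stored value is clamped to eps, and the current absolute value replaces it after each call.

```python
import torch


class AbsoluteRatioTracker:
    """Hypothetical stand-alone illustration of the LastAbsoluteRatio logic."""

    def __init__(self, numerator: str = "cur", eps: float = 1e-8):
        self.numerator = numerator
        self.eps = eps
        self.prev = None  # absolute value of the previous update

    def __call__(self, update: torch.Tensor) -> torch.Tensor:
        cur = update.abs()
        # state is initialized to ones on the first call
        prev = torch.ones_like(cur) if self.prev is None else self.prev
        prev = prev.clamp_min(self.eps)  # only the stored value is clamped
        ratio = cur / prev if self.numerator == "cur" else prev / cur
        self.prev = cur
        return ratio
```

Because only the previous value is clamped, numerator='prev' divides by the raw current absolute value, so a zero entry in the current update produces inf.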

Source code in torchzero/modules/misc/misc.py
class LastAbsoluteRatio(Transform):
    """Outputs the ratio between the absolute values of the past two updates; the numerator is determined by the :code:`numerator` argument."""
    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
        defaults = dict(numerator=numerator, eps=eps)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
        numerator = settings[0]['numerator']
        eps = NumberList(s['eps'] for s in settings)

        torch._foreach_abs_(tensors)
        torch._foreach_clamp_min_(prev, eps)

        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
        else: ratio = torch._foreach_div(prev, tensors)
        for p, c in zip(prev, tensors): p.set_(c)
        return ratio

LastDifference

Bases: torchzero.core.transform.Transform

Outputs difference between past two updates.

Source code in torchzero/modules/misc/misc.py
class LastDifference(Transform):
    """Outputs difference between past two updates."""
    def __init__(self,target: Target = 'update'):
        super().__init__({}, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prev_tensors = unpack_states(states, tensors, 'prev_tensors') # initialized to 0
        difference = torch._foreach_sub(tensors, prev_tensors)
        for p, c in zip(prev_tensors, tensors): p.set_(c)
        return difference

LastGradDifference

Bases: torchzero.core.module.Module

Outputs difference between past two gradients.

Source code in torchzero/modules/misc/misc.py
class LastGradDifference(Module):
    """Outputs difference between past two gradients."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def step(self, var):
        grad = var.get_grad()
        prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
        difference = torch._foreach_sub(grad, prev_grad)
        for p, c in zip(prev_grad, grad): p.copy_(c)
        var.update = list(difference)
        return var

LastProduct

Bases: torchzero.core.transform.Transform

Outputs the product of the past two updates.

Source code in torchzero/modules/misc/misc.py
class LastProduct(Transform):
    """Outputs the product of the past two updates."""
    def __init__(self,target: Target = 'update'):
        super().__init__({}, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
        prod = torch._foreach_mul(tensors, prev)
        for p, c in zip(prev, tensors): p.set_(c)
        return prod

LastRatio

Bases: torchzero.core.transform.Transform

Outputs the ratio between the past two updates. The numerator is selected by the numerator argument: 'cur' for the current update, 'prev' for the previous one.

Source code in torchzero/modules/misc/misc.py
class LastRatio(Transform):
    """Outputs the ratio between the past two updates; the numerator is determined by the :code:`numerator` argument."""
    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
        defaults = dict(numerator=numerator)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
        numerator = settings[0]['numerator']
        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
        else: ratio = torch._foreach_div(prev, tensors)
        for p, c in zip(prev, tensors): p.set_(c)
        return ratio

LerpModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Does a linear interpolation of input(tensors) and end(tensors) based on a scalar weight.

The output is given by output = input(tensors) + weight * (end(tensors) - input(tensors)).
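The formula above is the standard lerp. The module applies it in place with torch._foreach_lerp_, overwriting the input tensors. A minimal out-of-place sketch of the same computation using the public torch.lerp; lerp_lists is a hypothetical helper, not part of torchzero.

```python
import torch


def lerp_lists(inputs, ends, weight: float):
    """output = input + weight * (end - input), applied tensor by tensor."""
    return [torch.lerp(a, b, weight) for a, b in zip(inputs, ends)]
```

weight = 0 returns the input tensors and weight = 1 returns the end tensors; values in between blend the two.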

Source code in torchzero/modules/ops/multi.py
class LerpModules(MultiOperationBase):
    """Does a linear interpolation of :code:`input(tensors)` and :code:`end(tensors)` based on a scalar :code:`weight`.

    The output is given by :code:`output = input(tensors) + weight * (end(tensors) - input(tensors))`
    """
    def __init__(self, input: Chainable, end: Chainable, weight: float):
        defaults = dict(weight=weight)
        super().__init__(defaults, input=input, end=end)

    @torch.no_grad
    def transform(self, var: Var, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
        torch._foreach_lerp_(input, end, weight=self.defaults['weight'])
        return input

LevenbergMarquardt

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Levenberg-Marquardt trust region algorithm.
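trust_solve in the source listing below turns the trust-region radius into a damping term: it solves (H + D / radius) p = g, where D is the identity for y = 0 and a blend of the identity and the clamped Hessian diagonal otherwise. A dense sketch of that computation with plain tensors; lm_step is a hypothetical name, and the library works through its own linear-operator interface rather than torch.linalg.solve.

```python
import torch


def lm_step(H: torch.Tensor, g: torch.Tensor, radius: float, y: float = 0.0) -> torch.Tensor:
    """Solve (H + D / radius) p = g, mirroring the dense logic of trust_solve."""
    reg = 1.0 / radius
    if y == 0:
        # identity regularization: H + reg * I
        D = torch.ones(H.shape[0], dtype=H.dtype)
    else:
        diag = torch.diagonal(H).clone()
        # entries below 2 * tiny (zero or negative) fall back to 1
        tiny = torch.finfo(diag.dtype).tiny * 2
        diag = torch.where(diag < tiny, torch.ones_like(diag), diag)
        # y interpolates between the identity (y=0) and diag(H) (y=1)
        D = y * diag + (1 - y)
    return torch.linalg.solve(H + torch.diag(D * reg), g)
```

As the radius grows the damping vanishes and the step approaches the Newton step H^-1 g; as it shrinks (with y = 0) the step approaches radius * g, a short gradient step.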

Parameters:

  • hess_module (Chainable) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • y (float, default: 0 ) –

    when y=0, the identity matrix scaled by 1/radius is added to the hessian; when y=1, the diagonal of the hessian approximation scaled by 1/radius is added. Values in between interpolate. A nonzero y should only be used with Gauss-Newton. Defaults to 0.

  • eta (float, default: 0.0 ) –

    if the ratio of actual to predicted reduction is larger than this, the step is accepted. When hess_module is Newton or GaussNewton, this can be set to 0. Defaults to 0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by nplus. Defaults to 0.99.

  • rho_bad (float, default: 0.0001 ) –

    if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by nminus. Defaults to 1e-4.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • fallback (bool, default: False ) –

    if True and hess_module maintains a hessian inverse, which can't be inverted efficiently, it is inverted anyway. When False (default), a RuntimeError is raised instead.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.
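The eta, nplus, nminus, rho_good and rho_bad parameters describe a ratio test on rho = actual reduction / predicted reduction. A minimal sketch of those rules exactly as documented above, using a hypothetical update_radius helper; the library's 'default' radius_strategy is not shown in this section and may handle edge cases differently.

```python
def update_radius(radius, actual_reduction, predicted_reduction,
                  eta=0.0, nplus=3.5, nminus=0.25, rho_good=0.99, rho_bad=1e-4):
    """Hypothetical helper: return (new_radius, accepted) per the documented rules."""
    # assumes predicted_reduction > 0, i.e. the model predicts a decrease
    rho = actual_reduction / predicted_reduction
    accepted = rho > eta
    if rho > rho_good:
        radius *= nplus   # model was very accurate: expand the region
    elif rho < rho_bad:
        radius *= nminus  # model was poor: shrink the region
    return radius, accepted
```

A rejected step shrinks the radius, and the solve is retried with stronger damping, up to max_attempts times.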

Examples:

Gauss-Newton with Levenberg-Marquardt trust-region

opt = tz.Modular(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
)

LM-SR1

opt = tz.Modular(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
)

First order trust region (hessian is assumed to be identity)

opt = tz.Modular(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.Identity()),
)

Source code in torchzero/modules/trust_region/levenberg_marquardt.py
class LevenbergMarquardt(TrustRegionBase):
    """Levenberg-Marquardt trust region algorithm.


    Args:
        hess_module (Chainable):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set ``inverse=False`` when constructing them.
        y (float, optional):
            when ``y=0``, the identity matrix scaled by ``1/radius`` is added to the hessian; when ``y=1``, the
            diagonal of the hessian approximation scaled by ``1/radius`` is added. Values in between interpolate.
            A nonzero ``y`` should only be used with Gauss-Newton. Defaults to 0.
        eta (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the step is accepted.
            When ``hess_module`` is ``Newton`` or ``GaussNewton``, this can be set to 0. Defaults to 0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the trust region size is
            multiplied by `nplus`. Defaults to 0.99.
        rho_bad (float, optional):
            if the ratio of actual to predicted reduction is less than this, the trust region size is
            multiplied by `nminus`. Defaults to 1e-4.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        fallback (bool, optional):
            if ``True`` and ``hess_module`` maintains a hessian inverse, which can't be inverted efficiently,
            it is inverted anyway. When ``False`` (default), a ``RuntimeError`` is raised instead.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    Examples:
        Gauss-Newton with Levenberg-Marquardt trust-region

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
            )

        LM-SR1

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
            )

        First order trust region (hessian is assumed to be identity)

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.Identity()),
            )

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        y: float = 0,
        fallback: bool = False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(y=y, fallback=fallback)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        y = settings['y']

        if isinstance(H, linear_operator.DenseInverse):
            if settings['fallback']:
                H = H.to_dense()
            else:
                raise RuntimeError(
                    f"{self.children['hess_module']} maintains a hessian inverse. "
                    "LevenbergMarquardt requires the hessian, not the inverse. "
                    "If that module is a quasi-newton module, pass `inverse=False` on initialization. "
                    "Or pass `fallback=True` to LevenbergMarquardt to allow inverting the hessian inverse, "
                    "however that can be inefficient and unstable."
                )

        reg = 1/radius
        if y == 0:
            return H.add_diagonal(reg).solve(g)

        diag = H.diagonal()
        diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
        if y != 1: diag = (diag*y) + (1-y)
        return H.add_diagonal(diag*reg).solve(g)

LineSearchBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for line searches.

This is an abstract class; to use it, subclass it and override search.

Parameters:

  • defaults (dict[str, Any] | None) –

    dictionary with defaults.

  • maxiter (int | None, default: None ) –

    If specified, the search terminates once the objective has been evaluated this many times, and the step size with the lowest loss so far is used. This is useful when passing make_objective to an external library which doesn't have a maxiter option. Defaults to None.
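The maxiter budget is enforced by counting objective evaluations: _loss raises MaxLineSearchItersReached once the budget is spent, and step catches it and falls back to the best step size recorded so far. A standalone sketch of the same pattern; BudgetExceeded, BudgetedObjective and run_search are hypothetical names. Unlike _loss, it does not skip counting an evaluation at step size 0 when the loss is already known.

```python
import math


class BudgetExceeded(Exception):
    """Raised when the evaluation budget is used up (hypothetical name)."""


class BudgetedObjective:
    """Wraps phi(step_size), counting calls and tracking the best step so far."""

    def __init__(self, phi, maxiter=None):
        self.phi = phi
        self.maxiter = maxiter
        self.calls = 0
        self.best_loss = math.inf
        self.best_step = 0.0

    def __call__(self, step_size):
        if self.maxiter is not None and self.calls >= self.maxiter:
            raise BudgetExceeded
        self.calls += 1
        loss = self.phi(step_size)
        if loss < self.best_loss:
            self.best_loss, self.best_step = loss, step_size
        return loss


def run_search(search, phi, maxiter=None):
    """Run search(objective); on budget exhaustion return the best step seen."""
    objective = BudgetedObjective(phi, maxiter)
    try:
        return search(objective)
    except BudgetExceeded:
        return objective.best_step
```

This is what lets an external solver without its own iteration limit be cut off cleanly.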

Other useful methods
  • evaluate_f - returns loss with a given scalar step size
  • evaluate_f_d - returns loss and directional derivative with a given scalar step size
  • make_objective - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
  • make_objective_with_derivative - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
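These helpers evaluate the loss along the ray params - step_size * update (see set_step_size_ in the source), so the directional derivative they report is the negative dot product of the gradient with the update. A minimal standalone sketch of both objectives on a plain differentiable function, without torchzero; make_line_objective is a hypothetical name.

```python
import torch


def make_line_objective(f, x0, d):
    """Return phi(alpha) = f(x0 - alpha*d) and a function giving (phi, phi')."""

    def phi(alpha):
        with torch.no_grad():
            return float(f(x0 - alpha * d))

    def phi_and_derivative(alpha):
        x = (x0 - alpha * d).detach().requires_grad_(True)
        loss = f(x)
        (grad,) = torch.autograd.grad(loss, x)
        # chain rule: d/dalpha f(x0 - alpha*d) = -grad . d
        return float(loss), float(-(grad * d).sum())

    return phi, phi_and_derivative
```

phi can be handed to a scalar solver such as scipy.optimize.minimize_scalar, just as make_objective is intended to be.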

Examples:

This example evaluates every step size in a range using the evaluate_f method.

import torch
from torchzero.modules.line_search.line_search import LineSearchBase

class GridLineSearch(LineSearchBase):
    def __init__(self, start, end, num):
        defaults = dict(start=start, end=end, num=num)
        super().__init__(defaults)

    @torch.no_grad
    def search(self, update, var):

        start = self.defaults["start"]
        end = self.defaults["end"]
        num = self.defaults["num"]

        lowest_loss = float("inf")
        best_step_size = 0.0  # fallback if no candidate beats float("inf")

        # evaluate the loss at each candidate step size (forward pass only)
        for step_size in torch.linspace(start, end, num):
            loss = self.evaluate_f(step_size.item(), var=var, backward=False)
            if loss < lowest_loss:
                lowest_loss = loss
                best_step_size = step_size.item()

        return best_step_size

Using an external solver via self.make_objective

Here scipy.optimize.minimize_scalar finds the best step size, using the objective returned by self.make_objective.

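import scipy.optimize
import torch
from torchzero.modules.line_search.line_search import LineSearchBase

class ScipyMinimizeScalar(LineSearchBase):
    def __init__(self, method: str | None = None):
        defaults = dict(method=method)
        super().__init__(defaults)

    @torch.no_grad
    def search(self, update, var):
        # objective(step_size) returns the loss at params - step_size * update
        objective = self.make_objective(var=var)
        method = self.defaults["method"]

        res = scipy.optimize.minimize_scalar(objective, method=method)
        return float(res.x)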

Methods:

  • evaluate_f

    evaluate function value at step size step_size.

  • evaluate_f_d

    evaluate function value and directional derivative in the direction of the update at step size step_size.

  • evaluate_f_d_g

    evaluate function value, directional derivative, and gradient list at step size step_size.

  • search

    Finds the step size to use

Source code in torchzero/modules/line_search/line_search.py
class LineSearchBase(Module, ABC):
    """Base class for line searches.

    This is an abstract class; to use it, subclass it and override `search`.

    Args:
        defaults (dict[str, Any] | None): dictionary with defaults.
        maxiter (int | None, optional):
            if specified, the search terminates once the objective has been evaluated
            this many times, and the step size with the lowest loss so far is used.
            This is useful when passing `make_objective` to an external library which
            doesn't have a maxiter option. Defaults to None.

    Other useful methods:
        * ``evaluate_f`` - returns loss with a given scalar step size
        * ``evaluate_f_d`` - returns loss and directional derivative with a given scalar step size
        * ``make_objective`` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
        * ``make_objective_with_derivative`` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.

    Examples:

    #### Basic line search

    This example evaluates every step size in a range using the :code:`self.evaluate_f` method.
    ```python
    import torch
    from torchzero.modules.line_search.line_search import LineSearchBase

    class GridLineSearch(LineSearchBase):
        def __init__(self, start, end, num):
            defaults = dict(start=start, end=end, num=num)
            super().__init__(defaults)

        @torch.no_grad
        def search(self, update, var):

            start = self.defaults["start"]
            end = self.defaults["end"]
            num = self.defaults["num"]

            lowest_loss = float("inf")
            best_step_size = 0.0  # fallback if no candidate beats float("inf")

            # evaluate the loss at each candidate step size (forward pass only)
            for step_size in torch.linspace(start, end, num):
                loss = self.evaluate_f(step_size.item(), var=var, backward=False)
                if loss < lowest_loss:
                    lowest_loss = loss
                    best_step_size = step_size.item()

            return best_step_size
    ```

    #### Using external solver via self.make_objective

    Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`

    ```python
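    import scipy.optimize
    import torch
    from torchzero.modules.line_search.line_search import LineSearchBase

    class ScipyMinimizeScalar(LineSearchBase):
        def __init__(self, method: str | None = None):
            defaults = dict(method=method)
            super().__init__(defaults)

        @torch.no_grad
        def search(self, update, var):
            # objective(step_size) returns the loss at params - step_size * update
            objective = self.make_objective(var=var)
            method = self.defaults["method"]

            res = scipy.optimize.minimize_scalar(objective, method=method)
            return float(res.x)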
    ```
    """
    def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
        super().__init__(defaults)
        self._maxiter = maxiter
        self._reset()

    def _reset(self):
        self._current_step_size: float = 0
        self._lowest_loss = float('inf')
        self._best_step_size: float = 0
        self._current_iter = 0
        self._initial_params = None

    def set_step_size_(
        self,
        step_size: float,
        params: list[torch.Tensor],
        update: list[torch.Tensor],
    ):
        if not math.isfinite(step_size): return

        # fixes overflow when backtracking keeps increasing alpha after converging
        step_size = max(min(tofloat(step_size), 1e36), -1e36)

        # skip if parameters are already at the suggested step size
        if self._current_step_size == step_size: return

        # this was basically causing floating point imprecision to build up
        #if False:
        # if abs(alpha) < abs(step_size) and step_size != 0:
        #     torch._foreach_add_(params, update, alpha=alpha)

        # else:
        assert self._initial_params is not None
        if step_size == 0:
            new_params = [p.clone() for p in self._initial_params]
        else:
            new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
        for c, n in zip(params, new_params):
            set_storage_(c, n)

        self._current_step_size = step_size

    def _set_per_parameter_step_size_(
        self,
        step_size: Sequence[float],
        params: list[torch.Tensor],
        update: list[torch.Tensor],
    ):
        # if not np.isfinite(step_size): step_size = [0 for _ in step_size]
        # alpha = [self._current_step_size - s for s in step_size]
        # if any(a!=0 for a in alpha):
        #     torch._foreach_add_(params, torch._foreach_mul(update, alpha))
        assert self._initial_params is not None
        if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]

        if any(s!=0 for s in step_size):
            new_params = torch._foreach_sub(self._initial_params, torch._foreach_mul(update, step_size))
        else:
            new_params = [p.clone() for p in self._initial_params]

        for c, n in zip(params, new_params):
            set_storage_(c, n)

    def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
              update: list[torch.Tensor], backward:bool=False) -> float:

        # if step_size is 0, we might already know the loss
        if (var.loss is not None) and (step_size == 0):
            return tofloat(var.loss)

        # check max iter
        if self._maxiter is not None and self._current_iter >= self._maxiter: raise MaxLineSearchItersReached
        self._current_iter += 1

        # set new lr and evaluate loss with it
        self.set_step_size_(step_size, params=params, update=update)
        if backward:
            with torch.enable_grad(): loss = closure()
        else:
            loss = closure(False)

        # if it is the best so far, record it
        if loss < self._lowest_loss:
            self._lowest_loss = tofloat(loss)
            self._best_step_size = step_size

        # if evaluated loss at step size 0, set it to var.loss
        if step_size == 0:
            var.loss = loss
            if backward: var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

        return tofloat(loss)

    def _loss_derivative_gradient(self, step_size: float, var: Var, closure,
                         params: list[torch.Tensor], update: list[torch.Tensor]):
        # if step_size is 0, we might already know the derivative
        if (var.grad is not None) and (step_size == 0):
            loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=False)
            derivative = - sum(t.sum() for t in torch._foreach_mul(var.grad, update))

        else:
            # loss with a backward pass sets params.grad
            loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=True)

            # directional derivative
            derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
                                                                    else torch.zeros_like(p) for p in params], update))

        assert var.grad is not None
        return loss, tofloat(derivative), var.grad

    def _loss_derivative(self, step_size: float, var: Var, closure,
                         params: list[torch.Tensor], update: list[torch.Tensor]):
        return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]

    def evaluate_f(self, step_size: float, var: Var, backward:bool=False):
        """evaluate function value at step size `step_size`."""
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)

    def evaluate_f_d(self, step_size: float, var: Var):
        """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())

    def evaluate_f_d_g(self, step_size: float, var: Var):
        """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())

    def make_objective(self, var: Var, backward:bool=False):
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_update(), backward=backward)

    def make_objective_with_derivative(self, var: Var):
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())

    def make_objective_with_derivative_and_gradient(self, var: Var):
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_update())

    @abstractmethod
    def search(self, update: list[torch.Tensor], var: Var) -> float:
        """Finds the step size to use"""

    @torch.no_grad
    def step(self, var: Var) -> Var:
        self._reset()

        params = var.params
        self._initial_params = [p.clone() for p in params]
        update = var.get_update()

        try:
            step_size = self.search(update=update, var=var)
        except MaxLineSearchItersReached:
            step_size = self._best_step_size

        # set loss_approx
        if var.loss_approx is None: var.loss_approx = self._lowest_loss

        # this is last module - set step size to found step_size times lr
        if var.is_last:
            if var.last_module_lrs is None:
                self.set_step_size_(step_size, params=params, update=update)

            else:
                self._set_per_parameter_step_size_([step_size*lr for lr in var.last_module_lrs], params=params, update=update)

            var.stop = True; var.skip_update = True
            return var

        # revert parameters and multiply update by step size
        self.set_step_size_(0, params=params, update=update)
        torch._foreach_mul_(var.update, step_size)
        return var

evaluate_f

evaluate_f(step_size: float, var: Var, backward: bool = False)

evaluate function value at step size step_size.

Source code in torchzero/modules/line_search/line_search.py
def evaluate_f(self, step_size: float, var: Var, backward:bool=False):
    """evaluate function value at step size `step_size`."""
    closure = var.closure
    if closure is None: raise RuntimeError('line search requires closure')
    return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)

evaluate_f_d

evaluate_f_d(step_size: float, var: Var)

evaluate function value and directional derivative in the direction of the update at step size step_size.

Source code in torchzero/modules/line_search/line_search.py
def evaluate_f_d(self, step_size: float, var: Var):
    """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
    closure = var.closure
    if closure is None: raise RuntimeError('line search requires closure')
    return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())

evaluate_f_d_g

evaluate_f_d_g(step_size: float, var: Var)

evaluate function value, directional derivative, and gradient list at step size step_size.

Source code in torchzero/modules/line_search/line_search.py
def evaluate_f_d_g(self, step_size: float, var: Var):
    """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
    closure = var.closure
    if closure is None: raise RuntimeError('line search requires closure')
    return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())

search

search(update: list[Tensor], var: Var) -> float

Finds the step size to use
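To illustrate the contract of search (it receives the update and returns a step size, typically probing the objective through evaluate_f), here is a minimal pure-Python sketch of a backtracking (Armijo) search. It does not subclass the torchzero base class. The function backtracking_search, its parameters, and the plain callable standing in for evaluate_f are hypothetical. Like step(), it falls back to the best step size seen if it runs out of iterations.

```python
from typing import Callable


def backtracking_search(
    evaluate_f: Callable[[float], float],  # hypothetical stand-in for evaluate_f: loss at a step size
    f0: float,                             # loss at step size 0
    d0: float,                             # directional derivative at step size 0 (< 0 for descent)
    init: float = 1.0,
    shrink: float = 0.5,
    c: float = 1e-4,
    maxiter: int = 30,
) -> float:
    """Return a step size satisfying the Armijo sufficient-decrease condition."""
    step_size = init
    best_step, best_f = 0.0, f0
    for _ in range(maxiter):
        f_a = evaluate_f(step_size)
        if f_a < best_f:  # remember the best point seen so far
            best_step, best_f = step_size, f_a
        if f_a <= f0 + c * step_size * d0:  # Armijo condition
            return step_size
        step_size *= shrink
    # out of iterations: fall back to the best step size seen so far
    return best_step


# Usage: objective (x - 3)^2 at x = 0, whose gradient is -6. Assuming the update
# is subtracted from the parameters, x(a) = 0 - a * (-6) = 6a.
def loss_at(a: float) -> float:
    return (6 * a - 3) ** 2


print(backtracking_search(loss_at, f0=9.0, d0=-36.0))  # 0.5
```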

Source code in torchzero/modules/line_search/line_search.py
@abstractmethod
def search(self, update: list[torch.Tensor], var: Var) -> float:
    """Finds the step size to use"""

Lion

Bases: torchzero.core.transform.Transform

Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

Parameters:

  • beta1 (float, default: 0.9 ) –

    dampening for momentum. Defaults to 0.9.

  • beta2 (float, default: 0.99 ) –

    momentum factor. Defaults to 0.99.
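The roles of beta1 and beta2 are easiest to see in the Lion update rule from the paper: the sign of an interpolation between the momentum buffer and the gradient (controlled by beta1) is the update, while the buffer itself is an exponential moving average with factor beta2. The pure-Python sketch below performs one step on a flat list of numbers; lion_step is a hypothetical name, not the library's lion_ function, and the learning rate is assumed to be applied separately (e.g. by tz.m.LR).

```python
def sign(x: float) -> float:
    return float((x > 0) - (x < 0))


def lion_step(grad: list[float], exp_avg: list[float], beta1: float = 0.9, beta2: float = 0.99):
    """One Lion step. Returns (update, new_exp_avg); the update is later scaled by a learning rate."""
    update, new_exp_avg = [], []
    for g, m in zip(grad, exp_avg):
        c = beta1 * m + (1 - beta1) * g                  # interpolate momentum and gradient
        update.append(sign(c))                           # only the sign is kept
        new_exp_avg.append(beta2 * m + (1 - beta2) * g)  # the momentum EMA uses beta2
    return update, new_exp_avg


update, m = lion_step([0.5, -2.0, 0.0], [0.0, 0.0, 0.0])
print(update)  # [1.0, -1.0, 0.0]
```

Because beta1 weights the old momentum, a large gradient of the opposite sign does not necessarily flip the update: lion_step([-5.0], [1.0]) still returns an update of [1.0].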

Source code in torchzero/modules/adaptive/lion.py
class Lion(Transform):
    """Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

    Args:
        beta1 (float, optional): dampening for momentum. Defaults to 0.9.
        beta2 (float, optional): momentum factor. Defaults to 0.99.
    """

    def __init__(self, beta1: float = 0.9, beta2: float = 0.99):
        defaults = dict(beta1=beta1, beta2=beta2)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
        exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
        return lion_(TensorList(tensors),exp_avg,beta1,beta2)

LiuStorey

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Liu-Storey nonlinear conjugate gradient method.

Note

This requires step size to be determined via a line search, so put a line search like tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this.
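For reference, the textbook Liu-Storey coefficient is beta = g_k^T (g_k - g_{k-1}) / (-d_{k-1}^T g_{k-1}), where d is the search (descent) direction, and the next direction is d_k = -g_k + beta * d_{k-1}. The sketch below uses that sign convention on plain lists; ls_beta and next_direction are hypothetical names, and torchzero's liu_storey_beta may store directions with the opposite sign.

```python
def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def ls_beta(g: list[float], g_prev: list[float], d_prev: list[float]) -> float:
    """Liu-Storey beta; d_prev is the previous descent direction."""
    y = [a - b for a, b in zip(g, g_prev)]  # gradient difference
    return dot(g, y) / -dot(d_prev, g_prev)


def next_direction(g: list[float], g_prev: list[float], d_prev: list[float]) -> list[float]:
    beta = ls_beta(g, g_prev, d_prev)
    return [-gi + beta * di for gi, di in zip(g, d_prev)]


g_prev = [1.0, 0.0]
d_prev = [-1.0, 0.0]  # first step: steepest descent
g = [0.5, 1.0]
print(ls_beta(g, g_prev, d_prev))         # 0.75
print(next_direction(g, g_prev, d_prev))  # [-1.25, -1.0]
```

The step length along each new direction is left to the line search placed after the module.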

Source code in torchzero/modules/conjugate_gradient/cg.py
class LiuStorey(ConguateGradientBase):
    """Liu-Storey nonlinear conjugate gradient method.

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return liu_storey_beta(g, prev_d, prev_g)

LogHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class LogHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return (loss+1e-12).log()
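LogHomotopy has no docstring; its only behavior visible here is loss_transform, which replaces the loss f with log(f + 1e-12). Assuming HomotopyBase then differentiates the transformed loss (HomotopyBase is not shown, so this is an assumption), the chain rule gives grad log(f + eps) = grad f / (f + eps): gradients are rescaled by the inverse of the loss and grow as the loss approaches zero. The pure-Python check below compares that formula with a central finite difference; the example loss f is made up.

```python
import math

EPS = 1e-12  # same constant as in LogHomotopy.loss_transform


def f(x: float) -> float:  # example loss, always positive
    return (x - 2.0) ** 2 + 0.5


def df(x: float) -> float:  # its derivative
    return 2.0 * (x - 2.0)


def log_loss(x: float) -> float:
    return math.log(f(x) + EPS)


def log_loss_grad(x: float) -> float:
    return df(x) / (f(x) + EPS)  # chain rule: d/dx log(f) = f' / f


def central_diff(fn, x: float, h: float = 1e-5) -> float:
    return (fn(x + h) - fn(x - h)) / (2 * h)


for x in (0.0, 1.5, 1.99):
    print(x, log_loss_grad(x), central_diff(log_loss, x))
```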

MARSCorrection

Bases: torchzero.core.transform.Transform

MARS variance reduction correction.

Place a momentum-based optimizer after this module, and make sure beta matches the momentum parameter of that optimizer.

Parameters:

  • beta (float, default: 0.9 ) –

    use the same beta as you use in the momentum module. Defaults to 0.9.

  • scaling (float, default: 0.025 ) –

    controls the scale of gradient correction in variance reduction. Defaults to 0.025.

  • max_norm (float, default: 1 ) –

    clips norm of corrected gradients, None to disable. Defaults to 1.
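As a rough guide to what beta, scaling and max_norm do, the approximate MARS correction described in the MARS paper is c_t = g_t + scaling * beta / (1 - beta) * (g_t - g_{t-1}), after which c_t is rescaled if its norm exceeds the clipping threshold. The sketch below assumes mars_correction_ follows that formula (its body is not shown here); mars_correct is a hypothetical helper that works on one flat list, whereas the module operates on lists of tensors.

```python
import math


def mars_correct(g, g_prev, beta=0.9, scaling=0.025, max_norm=1.0):
    """Approximate MARS correction on a flat list (hypothetical helper)."""
    coef = scaling * beta / (1 - beta)
    c = [gi + coef * (gi - gp) for gi, gp in zip(g, g_prev)]
    if max_norm is not None:
        norm = math.sqrt(sum(x * x for x in c))
        if norm > max_norm:  # clip the norm, keep the direction
            c = [x * max_norm / norm for x in c]
    return c


print(mars_correct([1.0], [0.0], max_norm=None))
```

With the defaults beta=0.9 and scaling=0.025 the coefficient is 0.225. In the module source the stored previous value is initialized to the current tensors (init=tensors), so on the first step the correction term is zero.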

Examples:

Mars-AdamW

optimizer = tz.Modular(
    model.parameters(),
    tz.m.MARSCorrection(beta=0.95),
    tz.m.Adam(beta1=0.95, beta2=0.99),
    tz.m.WeightDecay(1e-3),
    tz.m.LR(0.1)
)

Mars-Lion

optimizer = tz.Modular(
    model.parameters(),
    tz.m.MARSCorrection(beta=0.9),
    tz.m.Lion(beta1=0.9),
    tz.m.LR(0.1)
)

Source code in torchzero/modules/adaptive/mars.py
class MARSCorrection(Transform):
    """MARS variance reduction correction.

    Place a momentum-based optimizer after this module,
    and make sure ``beta`` matches the momentum parameter of that optimizer.

    Args:
        beta (float, optional): use the same beta as you use in the momentum module. Defaults to 0.9.
        scaling (float, optional): controls the scale of gradient correction in variance reduction. Defaults to 0.025.
        max_norm (float, optional): clips norm of corrected gradients, None to disable. Defaults to 1.

    ## Examples:

    Mars-AdamW
    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.MARSCorrection(beta=0.95),
        tz.m.Adam(beta1=0.95, beta2=0.99),
        tz.m.WeightDecay(1e-3),
        tz.m.LR(0.1)
    )
    ```

    Mars-Lion
    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.MARSCorrection(beta=0.9),
        tz.m.Lion(beta1=0.9),
        tz.m.LR(0.1)
    )
    ```

    """
    def __init__(
        self,
        beta: float = 0.9,
        scaling: float = 0.025,
        max_norm: float | None = 1,
    ):
        defaults=dict(beta=beta, scaling=scaling, max_norm=max_norm)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
        beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
        max_norm = settings[0]['max_norm']

        return mars_correction_(
            tensors_=TensorList(tensors),
            prev_=prev,
            beta=beta,
            scaling=scaling,
            max_norm=max_norm,
        )

MSAM

Bases: torchzero.core.transform.Transform

Momentum-SAM from https://arxiv.org/pdf/2401.12033.

This implementation expresses the update rule as a function of the gradient, so it can be used as a drop-in replacement for momentum strategies in other optimizers.

To combine MSAM with other optimizers in the way done in the official implementation, e.g. to make Adam_MSAM, use tz.m.MSAMObjective module.

Note

MSAM has a learning rate hyperparameter that can't really be removed from the update rule. To avoid compounding learning rate modifications, remove the tz.m.LR module if you had one.

Parameters:

  • lr (float) –

    learning rate. Adding this module adds support for learning rate schedulers.

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • rho (float, default: 0.3 ) –

    perturbation strength. Defaults to 0.3.

  • weight_decay (float, default: 0 ) –

    weight decay. It is applied to the perturbed parameters, so it is different from applying tz.m.WeightDecay after MSAM. Defaults to 0.

  • nesterov (bool, default: False ) –

    whether to use nesterov momentum formula. Defaults to False.

  • lerp (bool, default: False ) –

    whether to use linear interpolation; if True, this becomes similar to an exponential moving average. Defaults to False.

Examples:

MSAM

opt = tz.Modular(
    model.parameters(),
    tz.m.MSAM(1e-3)
)

Adam with MSAM in place of the exponential moving average. Note that this is different from Adam_MSAM; to make Adam_MSAM and similar optimizers, use the tz.m.MSAMObjective module.

opt = tz.Modular(
    model.parameters(),
    tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
    tz.m.Debias(0.9, 0.999),
)

Source code in torchzero/modules/adaptive/msam.py
class MSAM(Transform):
    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

    This implementation expresses the update rule as a function of the gradient, so it can be used as a drop-in
    replacement for momentum strategies in other optimizers.

    To combine MSAM with other optimizers in the way done in the official implementation,
    e.g. to make Adam_MSAM, use ``tz.m.MSAMObjective`` module.

    Note:
        MSAM has a learning rate hyperparameter that can't really be removed from the update rule.
        To avoid compounding learning rate modifications, remove the ``tz.m.LR`` module if you had one.

    Args:
        lr (float): learning rate. Adding this module adds support for learning rate schedulers.
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        rho (float, optional): perturbation strength. Defaults to 0.3.
        weight_decay (float, optional):
            weight decay. It is applied to the perturbed parameters, so it is different
            from applying :code:`tz.m.WeightDecay` after MSAM. Defaults to 0.
        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
        lerp (bool, optional):
            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.

    Examples:
        MSAM

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.MSAM(1e-3)
            )

        Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
        To make Adam_MSAM and such, use the :code:`tz.m.MSAMObjective` module.

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
                tz.m.Debias(0.9, 0.999),
            )
    """
    _USES_LR = True
    def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3,  weight_decay:float=0, nesterov=False, lerp=False,):
        defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
        if self._USES_LR: defaults['lr'] = lr
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
        s = self.settings[params[0]]
        lerp = s['lerp']
        nesterov = s['nesterov']

        if self._USES_LR:
            lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)

        else:
            lr=None
            momentum,rho,weight_decay = unpack_dicts(settings, 'momentum','rho','weight_decay', cls=NumberList)

        return msam_(
            TensorList(tensors),
            params=TensorList(params),
            velocity_=velocity,
            momentum=momentum,
            lr=lr,
            rho=rho,
            weight_decay=weight_decay,
            nesterov=nesterov,
            lerp=lerp,

            # inner args
            inner=self.children.get("modules", None),
            grads=grads,
        )

MSAMObjective

Bases: torchzero.modules.adaptive.msam.MSAM

Momentum-SAM from https://arxiv.org/pdf/2401.12033.

Note

Please make sure to place tz.m.LR inside the modules argument. For example, tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)]). Putting LR after MSAM will lead to an incorrect update rule.

Parameters:

  • modules (Chainable) –

    modules that will optimize the MSAM objective. Make sure tz.m.LR is one of them.

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • rho (float, default: 0.3 ) –

    perturbation strength. Defaults to 0.3.

  • nesterov (bool, default: False ) –

    whether to use nesterov momentum formula. Defaults to False.

  • lerp (bool, default: False ) –

    whether to use linear interpolation; if True, MSAM momentum becomes similar to an exponential moving average. Defaults to False.

Examples:

AdamW-MSAM

opt = tz.Modular(
    model.parameters(),
    tz.m.MSAMObjective(
        [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
        rho=1.
    )
)

Source code in torchzero/modules/adaptive/msam.py
class MSAMObjective(MSAM):
    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

    Note:
        Please make sure to place ``tz.m.LR`` inside the ``modules`` argument. For example,
        ``tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])``. Putting LR after MSAM will lead
        to an incorrect update rule.

    Args:
        modules (Chainable): modules that will optimize the MSAM objective. Make sure :code:`tz.m.LR` is one of them.
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        rho (float, optional): perturbation strength. Defaults to 0.3.
        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
        lerp (bool, optional):
            whether to use linear interpolation, if True, MSAM momentum becomes similar to exponential moving average.
            Defaults to False.

    Examples:
        AdamW-MSAM

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.MSAMObjective(
                    [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
                    rho=1.
                )
            )
    """
    _USES_LR = False
    def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
        super().__init__(lr=0, momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
        self.set_child('modules', modules)

MatrixMomentum

Bases: torchzero.core.module.Module

Second order momentum method.

Matrix momentum is useful for convex objectives; for some reason, it also generalizes very well on elastic net logistic regression.

Notes
  • mu needs to be tuned very carefully. It should be smaller than 1 / (largest eigenvalue of the Hessian); otherwise the method becomes very unstable. I have devised an adaptive version, tz.m.AdaptiveMatrixMomentum, which works well without tuning mu; however, the adaptive version doesn't work on stochastic objectives.

  • In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.

  • This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate HVPs. The closure must accept a backward argument.
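In the source of apply(), the update is lr * g - s + mu * H s with s = p - p_prev. Assuming the optimizer subtracts the update from the parameters, one step is p_new = p - lr * g + (I - mu * H)(p - p_prev): heavy-ball momentum with the matrix I - mu * H in place of a scalar beta, which is why mu must stay below 1 / (largest eigenvalue). The pure-Python sketch below runs that recurrence on a 2-D quadratic with an explicit Hessian. The function name, the quadratic and the values lr=0.1, mu=0.1 are illustrative assumptions, and the special first-step scaling in apply() is replaced by a plain gradient step (s = 0).

```python
def matvec(A: list[list[float]], v: list[float]) -> list[float]:
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def matrix_momentum(A, b, x0, lr=0.1, mu=0.1, steps=300):
    """Minimize 0.5 * x^T A x - b^T x; A is the (constant) Hessian."""
    x, x_prev = list(x0), list(x0)  # s = 0 on the first step -> plain gradient step
    for _ in range(steps):
        g = [ax - bi for ax, bi in zip(matvec(A, x), b)]       # gradient A x - b
        s = [xi - pi for xi, pi in zip(x, x_prev)]             # s = p - p_prev
        Hs = matvec(A, s)                                      # exact Hessian-vector product
        update = [lr * gi - si + mu * hi for gi, si, hi in zip(g, s, Hs)]
        x_prev, x = x, [xi - ui for xi, ui in zip(x, update)]  # parameters -= update
    return x


A = [[3.0, 1.0], [1.0, 2.0]]  # eigenvalues ~3.62 and ~1.38, so mu = 0.1 < 1 / 3.62
b = [1.0, 1.0]                # exact minimizer is A^{-1} b = [0.2, 0.4]
print(matrix_momentum(A, b, [0.0, 0.0]))
```

With mu=1.0, which is larger than 1 / 3.62, the same recurrence diverges within a few dozen steps.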

Parameters:

  • mu (float, default: 0.1 ) –

    this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.

  • hvp_method (str, default: 'autograd' ) –

    Determines how Hessian-vector products are evaluated.

    • "autograd": Use PyTorch's autograd to calculate exact HVPs. This requires creating a graph for the gradient.
    • "forward": Use a forward finite difference formula to approximate the HVP. This requires one extra gradient evaluation.
    • "central": Use a central finite difference formula for a more accurate HVP approximation. This requires two extra gradient evaluations.

    Defaults to "autograd".
  • h (float, default: 0.001 ) –

    finite difference step size, used when hvp_method is "forward" or "central". Defaults to 1e-3.

  • hvp_tfm (Chainable | None, default: None ) –

    optional module applied to hessian-vector products. Defaults to None.
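The "forward" and "central" options correspond to the standard finite-difference formulas H v ≈ (g(x + h v) - g(x)) / h and H v ≈ (g(x + h v) - g(x - h v)) / (2h), which explains the one and two extra gradient evaluations (g(x) is already available). The sketch below applies both to a hand-written gradient. It ignores any normalization of v that the library may perform (the source passes normalize=True to Hvp, whose behavior is not shown), and all names are hypothetical.

```python
def grad_f(x: list[float]) -> list[float]:
    # gradient of the example function f(x) = x0**2 * x1 + x1**3
    return [2 * x[0] * x[1], x[0] ** 2 + 3 * x[1] ** 2]


def shifted(x: list[float], v: list[float], t: float) -> list[float]:
    return [xi + t * vi for xi, vi in zip(x, v)]


def hvp_forward(grad_fn, x, v, h=1e-3):
    g0 = grad_fn(x)                 # usually already available
    g1 = grad_fn(shifted(x, v, h))  # one extra gradient evaluation
    return [(a - b) / h for a, b in zip(g1, g0)]


def hvp_central(grad_fn, x, v, h=1e-3):
    gp = grad_fn(shifted(x, v, h))  # two extra gradient evaluations
    gm = grad_fn(shifted(x, v, -h))
    return [(a - b) / (2 * h) for a, b in zip(gp, gm)]


x, v = [1.0, 2.0], [1.0, 0.0]
# the exact H v at this point is [2 * x1, 2 * x0] = [4.0, 2.0]
print(hvp_forward(grad_f, x, v))
print(hvp_central(grad_f, x, v))
```

The forward formula has an O(h) error (its second component is about 2.001 here), while the central formula has an O(h^2) error and is exact up to rounding for this gradient, whose components are quadratic.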

Reference

Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).

Source code in torchzero/modules/adaptive/matrix_momentum.py
class MatrixMomentum(Module):
    """Second order momentum method.

    Matrix momentum is useful for convex objectives; for some reason, it also generalizes very well on elastic net logistic regression.

    Notes:
        - ``mu`` needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable. I have devised an adaptive version of this - ``tz.m.AdaptiveMatrixMomentum``, and it works well without having to tune ``mu``, however the adaptive version doesn't work on stochastic objectives.

        - In most cases ``MatrixMomentum`` should be the first module in the chain because it relies on autograd.

        - This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate HVPs. The closure must accept a ``backward`` argument.

    Args:
        mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
        hvp_method (str, optional):
            Determines how Hessian-vector products are evaluated.

            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
              This requires creating a graph for the gradient.
            - ``"forward"``: Use a forward finite difference formula to
              approximate the HVP. This requires one extra gradient evaluation.
            - ``"central"``: Use a central finite difference formula for a
              more accurate HVP approximation. This requires two extra
              gradient evaluations.
            Defaults to "autograd".
        h (float, optional): finite difference step size, used when hvp_method is "forward" or "central". Defaults to 1e-3.
        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.

    Reference:
        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
    """

    def __init__(
        self,
        lr:float,
        mu=0.1,
        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
        h: float = 1e-3,
        adaptive:bool = False,
        adapt_freq: int | None = None,
        hvp_tfm: Chainable | None = None,
    ):
        defaults = dict(lr=lr, mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
        super().__init__(defaults)

        if hvp_tfm is not None:
            self.set_child('hvp_tfm', hvp_tfm)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('p_prev')

    @torch.no_grad
    def update(self, var):
        assert var.closure is not None
        p = TensorList(var.params)
        p_prev = self.get_state(p, 'p_prev', init=var.params)

        hvp_method = self.defaults['hvp_method']
        h = self.defaults['h']
        step = self.global_state.get("step", 0)
        self.global_state["step"] = step + 1

        if step > 0:
            s = p - p_prev

            Hs, _ = self.Hvp(s, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
            Hs = [t.detach() for t in Hs]

            if 'hvp_tfm' in self.children:
                Hs = TensorList(apply_transform(self.children['hvp_tfm'], Hs, params=p, grads=var.grad, var=var))

            self.store(p, ("Hs", "s"), (Hs, s))

            # -------------------------------- adaptive mu ------------------------------- #
            if self.defaults["adaptive"]:
                g = TensorList(var.get_grad())

                if self.defaults["adapt_freq"] is None:
                    # ---------------------------- deterministic case ---------------------------- #
                    g_prev = self.get_state(var.params, "g_prev", cls=TensorList)
                    y = g - g_prev
                    g_prev.copy_(g)
                    denom = y.global_vector_norm()
                    denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
                    self.global_state["mu_mul"] = s.global_vector_norm() / denom

                else:
                    # -------------------------------- stochastic -------------------------------- #
                    adapt_freq = self.defaults["adapt_freq"]

                    # we start on 1nd step, and want to adapt when we start, so use (step - 1)
                    if (step - 1) % adapt_freq == 0:
                        assert var.closure is not None
                        params = TensorList(var.params)
                        p_cur = params.clone()

                        # move to previous params and evaluate p_prev with current mini-batch
                        params.copy_(self.get_state(var.params, 'p_prev'))
                        with torch.enable_grad():
                            var.closure()
                        g_prev = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                        y = g - g_prev

                        # move back to current params
                        params.copy_(p_cur)

                        denom = y.global_vector_norm()
                        denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
                        self.global_state["mu_mul"] = s.global_vector_norm() / denom

        torch._foreach_copy_(p_prev, var.params)

    @torch.no_grad
    def apply(self, var):
        update = TensorList(var.get_update())
        lr,mu = self.get_settings(var.params, "lr", 'mu', cls=NumberList)

        if "mu_mul" in self.global_state:
            mu = mu * self.global_state["mu_mul"]

        # --------------------------------- 1st step --------------------------------- #
        # p_prev is not available so make a small step
        step = self.global_state["step"]
        if step == 1:
            if self.defaults["adaptive"]: self.get_state(var.params, "g_prev", init=var.get_grad())
            update.mul_(lr) # separate so that initial_step_size can clip correctly
            update.mul_(initial_step_size(update, 1e-7))
            return var

        # -------------------------- matrix momentum update -------------------------- #
        s, Hs = self.get_state(var.params, 's', 'Hs', cls=TensorList)

        update.mul_(lr).sub_(s).add_(Hs*mu)
        var.update = update
        return var

Maximum

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs maximum(tensors, other(tensors))

Source code in torchzero/modules/ops/binary.py
class Maximum(BinaryOperationBase):
    """Outputs :code:`maximum(tensors, other(tensors))`"""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
        torch._foreach_maximum_(update, other)
        return update

MaximumModules

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs the elementwise maximum of inputs, which can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class MaximumModules(ReduceOperationBase):
    """Outputs elementwise maximum of :code:`inputs` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        maximum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_maximum_(maximum, v)

        return maximum
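MaximumModules.transform sorts its inputs with key isinstance(x, float), so tensor lists (key False) come before plain numbers (key True). The first tensor list then serves as the accumulator that every later input, including numbers, is compared against. The sketch below mirrors that ordering with flat lists of floats standing in for lists of tensors; elementwise_max is a hypothetical helper, and MinimumModules follows the same pattern with min.

```python
from typing import Union

Input = Union[float, list[float]]  # a number, or a flat list standing in for a list of tensors


def elementwise_max(*inputs: Input) -> list[float]:
    # key is False for lists and True for floats, so lists sort first (sorted is stable)
    ordered = sorted(inputs, key=lambda x: isinstance(x, float))
    acc = list(ordered[0])  # the first list is the accumulator; a float here would raise TypeError
    for v in ordered[1:]:
        if isinstance(v, float):
            acc = [max(a, v) for a in acc]             # compare every element with the number
        else:
            acc = [max(a, b) for a, b in zip(acc, v)]  # elementwise against another list
    return acc


print(elementwise_max(0.0, [1.0, 5.0, -2.0], [3.0, 4.0, -1.0]))  # [3.0, 5.0, 0.0]
```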

McCormick

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

McCormick's Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory, where N is the number of parameters.

Reference

Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.

This is "Algorithm 2" in that paper, where it is attributed to McCormick. Other sources also call this method Pearson's 2nd method.
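The standard form of the McCormick (Pearson's second) update of the inverse-Hessian approximation is H+ = H + (s - H y) s^T / (s^T y), with s the parameter difference and y the gradient difference. It satisfies the secant condition H+ y = s but, unlike BFGS, does not keep H symmetric. mccormick_H_ is not shown here, so treat this pure-Python sketch (mccormick_update is a hypothetical name) as the textbook formula rather than the library's exact code.

```python
def mccormick_update(H: list[list[float]], s: list[float], y: list[float]) -> list[list[float]]:
    """Return H + (s - H y) s^T / (s^T y)."""
    n = len(s)
    Hy = [sum(H[i][j] * y[j] for j in range(n)) for i in range(n)]
    r = [s[i] - Hy[i] for i in range(n)]  # residual of the secant equation
    sy = sum(si * yi for si, yi in zip(s, y))
    if sy == 0:
        raise ValueError("s^T y must be nonzero")
    return [[H[i][j] + r[i] * s[j] / sy for j in range(n)] for i in range(n)]


H = mccormick_update([[1.0, 0.0], [0.0, 1.0]], s=[1.0, 2.0], y=[2.0, 1.0])
print(H)  # [[0.75, -0.5], [0.25, 1.5]]
```

The result maps y = [2.0, 1.0] to s = [1.0, 2.0] exactly, and its off-diagonal entries differ, showing the loss of symmetry.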

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class McCormick(_InverseHessianUpdateStrategyDefaults):
    """McCormick's Quasi-Newton method.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.

        This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method in other sources.
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return mccormick_H_(H=H, s=s, y=y)

MeZO

Bases: torchzero.modules.grad_approximation.grad_approximator.GradApproximator

Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.001 ) –

    finite difference step size. Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of random gradient samples. Defaults to 1.

  • formula (Literal, default: 'central2' ) –

    finite difference formula. Defaults to 'central2'.

  • distribution (Literal, default: 'rademacher' ) –

    distribution of the random perturbation directions. Defaults to "rademacher".

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to "closure".
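The idea behind MeZO is to sample a random direction z, estimate the directional derivative with a central difference (f(x + h z) - f(x - h z)) / (2h), and use that scalar times z as the gradient estimate. Because z is regenerated from a seed (pre_step derives it as 1_000_000 * step + i), it never has to be stored, and the parameters can be perturbed in place. The pure-Python sketch below uses Rademacher directions; perturb_ and mezo_grad are hypothetical helpers, and the plain h * z perturbation is an assumption (the module passes variance=h to sample_like, which may scale differently).

```python
import random


def perturb_(x: list[float], seed: int, scale: float) -> None:
    """Add scale * z to x in place, regenerating the Rademacher vector z from seed."""
    rng = random.Random(seed)
    for k in range(len(x)):
        x[k] += scale * rng.choice((-1.0, 1.0))


def mezo_grad(f, x: list[float], h: float = 1e-3, n_samples: int = 1, step: int = 0) -> list[float]:
    grad = [0.0] * len(x)
    for i in range(n_samples):
        seed = 1_000_000 * step + i
        perturb_(x, seed, h)             # x + h z
        f_plus = f(x)
        perturb_(x, seed, -2 * h)        # x - h z
        f_minus = f(x)
        perturb_(x, seed, h)             # back to x (up to rounding)
        d = (f_plus - f_minus) / (2 * h)  # directional derivative estimate
        rng = random.Random(seed)        # regenerate z once more to accumulate d * z
        for k in range(len(x)):
            grad[k] += d * rng.choice((-1.0, 1.0))
    return [gk / n_samples for gk in grad]


def quad(x: list[float]) -> float:
    return x[0] ** 2 + 2 * x[1] ** 2 + 3 * x[2] ** 2


print(mezo_grad(quad, [1.0, 1.0, 1.0], n_samples=20000))  # approaches [2.0, 4.0, 6.0] as n_samples grows
```

A single sample gives an unbiased but noisy estimate; averaging n_samples reduces the noise at the cost of two extra loss evaluations per sample.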

References

Malladi, S., Gao, T., Nichani, E., Damian, A., Lee, J. D., Chen, D., & Arora, S. (2023). Fine-tuning language models with just forward passes. Advances in Neural Information Processing Systems, 36, 53038-53075. https://arxiv.org/abs/2305.17333

Source code in torchzero/modules/grad_approximation/rfdm.py
class MeZO(GradApproximator):
    """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): finite difference step size. Defaults to 1e-3.
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
        distribution (Distributions, optional): distribution of the random perturbation directions. Defaults to "rademacher".
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    References:
        Malladi, S., Gao, T., Nichani, E., Damian, A., Lee, J. D., Chen, D., & Arora, S. (2023). Fine-tuning language models with just forward passes. Advances in Neural Information Processing Systems, 36, 53038-53075. https://arxiv.org/abs/2305.17333
    """

    def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
                 distribution: Distributions = 'rademacher', target: GradTarget = 'closure'):

        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
        super().__init__(defaults, target=target)

    def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
        prt = TensorList(params).sample_like(
            distribution=distribution,
            variance=h,
            generator=torch.Generator(params[0].device).manual_seed(seed)
        )
        return prt

    def pre_step(self, var):
        h = NumberList(self.settings[p]['h'] for p in var.params)

        n_samples = self.defaults['n_samples']
        distribution = self.defaults['distribution']

        step = var.current_step

        # create functions that generate a deterministic perturbation from seed based on current step
        prt_fns = []
        for i in range(n_samples):

            prt_fn = partial(self._seeded_perturbation, params=var.params, distribution=distribution, seed=1_000_000*step + i, h=h)
            prt_fns.append(prt_fn)

        self.global_state['prt_fns'] = prt_fns

    @torch.no_grad
    def approximate(self, closure, params, loss):
        params = TensorList(params)
        loss_approx = None

        h = NumberList(self.settings[p]['h'] for p in params)
        settings = self.settings[params[0]]
        n_samples = settings['n_samples']
        fd_fn = _RFD_FUNCS[settings['formula']]
        prt_fns = self.global_state['prt_fns']

        grad = None
        for i in range(n_samples):
            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h, f_0=loss)
            if grad is None: grad = prt_fns[i]().mul_(d)
            else: grad += prt_fns[i]().mul_(d)

        assert grad is not None
        if n_samples > 1: grad.div_(n_samples)
        return grad, loss, loss_approx

Mean

Bases: torchzero.modules.ops.reduce.Sum

Outputs the mean of inputs, which can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class Mean(Sum):
    """Outputs a mean of :code:`inputs` that can be modules or numbers."""
    USE_MEAN = True

USE_MEAN class-attribute

USE_MEAN = True

MedianAveraging

Bases: torchzero.core.transform.TensorwiseTransform

Median of past history_size updates.

Parameters:

  • history_size (int) –

    Number of past updates to take the median over.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.
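The bookkeeping in apply_tensor is a bounded deque plus an elementwise median over the stacked history. The sketch below reproduces it with plain lists; MedianOfHistory is a hypothetical class. statistics.median averages the two middle values for an even count, which matches what torch.quantile(..., 0.5) returns with its default linear interpolation.

```python
from collections import deque
from statistics import median


class MedianOfHistory:
    """Elementwise median of the last history_size updates (plain-list sketch)."""

    def __init__(self, history_size: int):
        self.history = deque(maxlen=history_size)  # the oldest entries drop out automatically

    def __call__(self, update: list[float]) -> list[float]:
        self.history.append(list(update))  # store a copy of the incoming update
        return [median(column) for column in zip(*self.history)]


avg = MedianOfHistory(history_size=3)
for u in ([1.0], [5.0], [3.0], [10.0]):
    print(avg(u))  # successively [1.0], [3.0], [3.0], [5.0]
```

On the fourth call the first update has been evicted, so the median is taken over [5.0, 3.0, 10.0].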

Source code in torchzero/modules/momentum/averaging.py
class MedianAveraging(TensorwiseTransform):
    """Median of past ``history_size`` updates.

    Args:
        history_size (int): Number of past updates to take the median over
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, history_size: int, target: Target = 'update'):
        defaults = dict(history_size = history_size)
        super().__init__(uses_grad=False, defaults=defaults, target=target)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        history_size = setting['history_size']

        if 'history' not in state:
            state['history'] = deque(maxlen=history_size)

        history = state['history']
        history.append(tensor)

        stacked = torch.stack(tuple(history), 0)
        return torch.quantile(stacked, 0.5, dim = 0)
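The rolling-median logic above can be reproduced outside torchzero in a few lines. This is a minimal standalone sketch for a single tensor; the class name RollingMedian is hypothetical, and unlike the listing above it clones each tensor before storing it, an assumption made so that later in-place edits of the input cannot alter the stored history.

```python
from collections import deque

import torch


class RollingMedian:
    """Hypothetical standalone analogue of MedianAveraging for a single tensor."""

    def __init__(self, history_size: int):
        # the deque drops the oldest entry automatically once maxlen is reached
        self.history = deque(maxlen=history_size)

    @torch.no_grad()
    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        # store a copy (assumption of this sketch) so in-place edits of `tensor` don't leak into the history
        self.history.append(tensor.clone())
        stacked = torch.stack(tuple(self.history), dim=0)
        # the 0.5 quantile interpolates between the two middle values when the count is even
        return torch.quantile(stacked, 0.5, dim=0)


median = RollingMedian(history_size=3)
for value in [1.0, 5.0, 2.0, 10.0]:
    out = median(torch.tensor([value]))
print(out)  # the window is [5, 2, 10], so the median is 5
```

torch.quantile is used rather than torch.median because torch.median returns the lower of the two middle values for an even count, while the 0.5 quantile averages them.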

Minimum

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs minimum(tensors, other(tensors)).

Source code in torchzero/modules/ops/binary.py
class Minimum(BinaryOperationBase):
    """Outputs :code:`minimum(tensors, other(tensors))`"""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
        torch._foreach_minimum_(update, other)
        return update

MinimumModules

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs the elementwise minimum of inputs, which can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class MinimumModules(ReduceOperationBase):
    """Outputs elementwise minimum of :code:`inputs` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        minimum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_minimum_(minimum, v)

        return minimum
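The sort in transform relies on sorted() being stable and on False ordering before True, so lists of tensors come first and the accumulator is always a list of tensors; number operands are applied afterwards. A minimal out-of-place sketch of the same idea follows; elementwise_minimum is a hypothetical helper, and torch.clamp stands in for the number case.

```python
import torch


def elementwise_minimum(*inputs):
    """Hypothetical helper: elementwise minimum of tensor lists and Python floats."""
    # sorted() is stable and False < True, so tensor lists come before floats
    ordered = sorted(inputs, key=lambda x: isinstance(x, float))
    if isinstance(ordered[0], float):
        raise ValueError("at least one input must be a list of tensors")
    # copy the first tensor list so the caller's tensors are not modified
    result = [t.clone() for t in ordered[0]]
    for other in ordered[1:]:
        if isinstance(other, float):
            result = [torch.clamp(r, max=other) for r in result]
        else:
            result = [torch.minimum(r, o) for r, o in zip(result, other)]
    return result


a = [torch.tensor([1.0, 5.0])]
b = [torch.tensor([3.0, 2.0])]
out = elementwise_minimum(1.5, a, b)
print(out)  # min([1, 5], [3, 2]) = [1, 2], then capped at 1.5
```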

Mul

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Multiplies tensors by other, which can be a number or a module.

If other is a module, this calculates tensors * other(tensors).

Source code in torchzero/modules/ops/binary.py
class Mul(BinaryOperationBase):
    """Multiply tensors by :code:`other`. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`tensors * other(tensors)`
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        torch._foreach_mul_(update, other)
        return update

MulByLoss

Bases: torchzero.core.module.Module

Multiplies the update by loss * alpha, clamped from below at min_value.

Source code in torchzero/modules/misc/misc.py
class MulByLoss(Module):
    """Multiplies update by loss times :code:`alpha`"""
    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
        loss = var.get_loss(backward=self.defaults['backward'])
        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
        torch._foreach_mul_(var.update, mul)
        return var
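Put differently, each update tensor is scaled by max(loss * alpha, min_value), so min_value keeps the step from vanishing as the loss approaches zero. A minimal sketch with scalar settings (the listing reads alpha and min_value per parameter); mul_by_loss is a hypothetical name.

```python
import torch


def mul_by_loss(update, loss, alpha=1.0, min_value=1e-8):
    """Hypothetical helper: scale every update tensor by max(loss * alpha, min_value)."""
    scale = max(float(loss) * alpha, min_value)
    # min_value is a floor: a tiny or negative loss never shrinks the step below it
    return [u * scale for u in update]


update = [torch.tensor([2.0, -4.0])]
scaled = mul_by_loss(update, torch.tensor(0.5), alpha=2.0)  # scale is 1.0
floored = mul_by_loss(update, 0.0)  # scale is clamped to 1e-8
print(scaled, floored)
```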

MultiOperationBase

Bases: torchzero.core.module.Module, abc.ABC

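Abstract base class for operations that take several operands. To use it, subclass it and override the transform method. Operands that are modules are stepped on a clone of the update and their outputs are passed to transform; other operands are passed as-is. At least one operand must be a module.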

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/multi.py
class MultiOperationBase(Module, ABC):
    """Base class for operations that use operands. This is an abstract class, subclass it and override `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = {}
        for k,v in operands.items():

            if isinstance(v, (Module, Sequence)):
                self.set_child(k, v)
                self.operands[k] = self.children[k]
            else:
                self.operands[k] = v

        if not self.children:
            raise ValueError('At least one operand must be a module')

    @abstractmethod
    def transform(self, var: Var, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    @torch.no_grad
    def step(self, var: Var) -> Var:
        # pass cloned update to all module operands
        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()

        for k,v in self.operands.items():
            if k in self.children:
                v: Module
                updated_var = v.step(var.clone(clone_update=True))
                processed_operands[k] = updated_var.get_update()
                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them

        transformed = self.transform(var, **processed_operands)
        var.update = transformed
        return var

transform

transform(var: Var, **operands: Any | list[Tensor]) -> list[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/multi.py
@abstractmethod
def transform(self, var: Var, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError
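The operand handling in step can be illustrated without torchzero. The sketch below is a toy analogue, not the library API: ToyMultiOperation and ToyLerp are hypothetical names, callables play the role of child modules, and the library's copying of loss and gradient back from the cloned Var is omitted.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable

import torch


class ToyMultiOperation(ABC):
    """Hypothetical, torchzero-free analogue of MultiOperationBase.

    Callable operands stand in for child modules; anything else is passed through.
    """

    def __init__(self, **operands: Any):
        self.operands = operands
        if not any(callable(v) for v in operands.values()):
            raise ValueError("At least one operand must be callable")

    @abstractmethod
    def transform(self, update: list, **operands: Any) -> list:
        """Combine the current update with the processed operands."""

    def step(self, update: list) -> list:
        processed = {}
        for name, op in self.operands.items():
            if callable(op):
                # each "module" operand works on its own copy of the update
                processed[name] = op([t.clone() for t in update])
            else:
                processed[name] = op
        return self.transform(update, **processed)


class ToyLerp(ToyMultiOperation):
    """Example subclass: interpolates from the update towards end(update)."""

    def __init__(self, end: Callable, weight: float):
        super().__init__(end=end, weight=weight)

    def transform(self, update, end, weight):
        return [torch.lerp(u, e, weight) for u, e in zip(update, end)]


op = ToyLerp(end=lambda u: [t * 3 for t in u], weight=0.5)
print(op.step([torch.tensor([2.0])]))  # lerp(2, 6, 0.5) = 4
```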

Multistep

Bases: torchzero.core.module.Module

Performs steps inner steps with module on every step of the outer optimizer.

The update is the difference between the parameters before and after the inner loop.

Source code in torchzero/modules/misc/multistep.py
class Multistep(Module):
    """Performs :code:`steps` inner steps with :code:`module` per each step.

    The update is taken to be the parameter difference between parameters before and after the inner loop."""
    def __init__(self, module: Chainable, steps: int):
        defaults = dict(steps=steps)
        super().__init__(defaults)
        self.set_child('module', module)

    @torch.no_grad
    def step(self, var):
        return _sequential_step(self, var, sequential=False)
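The idea can be sketched in plain PyTorch. This is a minimal sketch, not the library's _sequential_step: multistep_update and sgd_step are hypothetical helpers, and it assumes, as elsewhere in torchzero, that the final update is subtracted from the parameters, so the update is params_before - params_after and the parameters are restored after the inner loop.

```python
import torch


def multistep_update(params, loss_fn, inner_step, steps):
    """Hypothetical helper: run `steps` inner steps and return params_before - params_after.

    The parameters are restored afterwards, assuming the outer optimizer applies
    the returned update itself by subtracting it.
    """
    before = [p.detach().clone() for p in params]
    for _ in range(steps):
        inner_step(params, loss_fn)
    with torch.no_grad():
        update = [b - p for b, p in zip(before, params)]
        for p, b in zip(params, before):
            p.copy_(b)  # undo the inner loop; only the difference is kept
    return update


def sgd_step(params, loss_fn, lr=0.1):
    """Plain gradient descent step, used here as the inner optimizer."""
    loss = loss_fn()
    grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for p, g in zip(params, grads):
            p.sub_(lr * g)


x = torch.tensor([1.0], requires_grad=True)
update = multistep_update([x], lambda: (x ** 2).sum(), sgd_step, steps=2)
print(update)  # each SGD step multiplies x by 0.8, so the update is 1 - 0.64 = 0.36
```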

MuonAdjustLR

Bases: torchzero.core.transform.Transform

LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master). Orthogonalize already applies this adjustment through its adjust_lr setting; use this module instead when you want the adjustment to happen later in the chain.

Source code in torchzero/modules/adaptive/muon.py
class MuonAdjustLR(Transform):
    """LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).
    Orthogonalize already has this built in with the `adjust_lr` setting, however you might want to move this to be later in the chain."""
    def __init__(self, alpha: float = 1, target: Target='update'):
        defaults = dict(alpha=alpha)
        super().__init__(defaults=defaults, uses_grad=False, target=target)

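    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        alphas = [s['alpha'] for s in settings]
        # only tensors with at least 2 dimensions are rescaled; 1D tensors are left unchanged
        tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
        tensors_2d = [i[0] for i in tensors_alphas]
        adjusted = [i[1] for i in tensors_alphas]
        if len(tensors_2d) > 0:
            torch._foreach_mul_(tensors_2d, adjusted) # in-place, so `tensors` is updated as well
        return tensors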
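The helper adjust_lr_for_muon is not shown on this page. The Moonlight report scales Muon updates by 0.2 * sqrt(max(A, B)) for an A x B weight so that their RMS roughly matches AdamW's, and the reference code in the Moonlight repository applies that factor to the learning rate. Assuming torchzero's helper follows the same formula (not verified here), the per-tensor multiplier would be as below; muon_lr_scale is a hypothetical name.

```python
import math


def muon_lr_scale(alpha: float, shape: tuple) -> float:
    """Assumed Moonlight-style factor: alpha * 0.2 * sqrt(max(A, B)) for an (A, B, ...) tensor."""
    A, B = shape[:2]
    return alpha * 0.2 * math.sqrt(max(A, B))


print(muon_lr_scale(1.0, (16, 4)))    # 0.2 * sqrt(16) = 0.8
print(muon_lr_scale(0.5, (64, 256)))  # 0.5 * 0.2 * sqrt(256) = 1.6
```

Larger matrices get a proportionally larger multiplier, which is why this adjustment only makes sense for tensors with at least two dimensions.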

NAG

Bases: torchzero.core.transform.Transform

Nesterov accelerated gradient method (nesterov momentum).

Parameters:

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • dampening (float, default: 0 ) –

    momentum dampening. Defaults to 0.

  • lerp (bool, default: False ) –

    whether to use linear interpolation. If True, the momentum buffer is updated by linear interpolation, which makes it similar to an exponential moving average. Defaults to False.

  • target (Literal, default: 'update' ) –

    target to apply momentum to. Defaults to 'update'.

Source code in torchzero/modules/momentum/momentum.py
class NAG(Transform):
    """Nesterov accelerated gradient method (nesterov momentum).

    Args:
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        dampening (float, optional): momentum dampening. Defaults to 0.
        lerp (bool, optional):
            whether to use linear interpolation. If True, the momentum buffer is updated by linear interpolation, which makes it similar to an exponential moving average. Defaults to False.
        target (Target, optional): target to apply momentum to. Defaults to 'update'.
    """
    def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False, target: Target = 'update'):
        defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
        lerp = self.settings[params[0]]['lerp']

        momentum,dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
        return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
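For reference, the Nesterov step in the formulation used by torch.optim.SGD with nesterov=True is sketched below, starting from a zero velocity. torchzero's nag_ may differ in details, and this sketch omits the lerp option; nag_step is a hypothetical name.

```python
import torch


@torch.no_grad()
def nag_step(grad, velocity, momentum=0.9, dampening=0.0):
    """Hypothetical helper: one Nesterov momentum step; `velocity` is updated in place."""
    velocity.mul_(momentum).add_(grad, alpha=1 - dampening)
    # look ahead: add the momentum-scaled, freshly updated velocity to the current gradient
    return grad + momentum * velocity


g = torch.tensor([1.0])
v = torch.zeros(1)
first = nag_step(g, v)   # v = 1.0, direction = 1 + 0.9 * 1.0 = 1.9
second = nag_step(g, v)  # v = 1.9, direction = 1 + 0.9 * 1.9 = 2.71
```

Compared with heavy-ball momentum, which would return the velocity itself, the extra momentum * velocity term evaluates the step as if the momentum had already been applied.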

NanToNum

Bases: torchzero.core.transform.Transform

Convert nan, inf and -inf to numbers.

Parameters:

  • nan (optional, default: None ) –

    the value to replace NaNs with. Default is zero.

  • posinf (optional, default: None ) –

    if a Number, the value to replace positive infinity values with. If None, positive infinity values are replaced with the greatest finite value representable by input's dtype. Default is None.

  • neginf (optional, default: None ) –

    if a Number, the value to replace negative infinity values with. If None, negative infinity values are replaced with the lowest finite value representable by input's dtype. Default is None.

  • target (Literal, default: 'update' ) –

    target to apply the transform to. Defaults to 'update'.

Source code in torchzero/modules/ops/unary.py
class NanToNum(Transform):
    """Convert `nan`, `inf` and `-inf` to numbers.

    Args:
        nan (optional): the value to replace NaNs with. Default is zero.
        posinf (optional): if a Number, the value to replace positive infinity values with.
            If None, positive infinity values are replaced with the greatest finite value
            representable by input's dtype. Default is None.
        neginf (optional): if a Number, the value to replace negative infinity values with.
            If None, negative infinity values are replaced with the lowest finite value
            representable by input's dtype. Default is None.
    """
    def __init__(self, nan=None, posinf=None, neginf=None, target: "Target" = 'update'):
        defaults = dict(nan=nan, posinf=posinf, neginf=neginf)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
        return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
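The module applies PyTorch's nan_to_num in place to each tensor, with per-parameter settings. The underlying replacement rules can be checked directly with torch.nan_to_num:

```python
import torch

t = torch.tensor([float("nan"), float("inf"), -float("inf"), 1.0])

# defaults: NaN -> 0, +inf / -inf -> largest / lowest finite value of the dtype (float32 here)
default = torch.nan_to_num(t)

# explicit replacement values, as with NanToNum(nan=0.0, posinf=1e3, neginf=-1e3)
clipped = torch.nan_to_num(t, nan=0.0, posinf=1e3, neginf=-1e3)
print(default, clipped)
```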

NaturalGradient

Bases: torchzero.core.module.Module

Natural gradient approximated via empirical fisher information matrix.

To use this, either pass a vector of per-sample losses to the step method, or make sure the closure returns such a vector. Gradients are calculated via batched autograd within this module, so you don't need to implement the backward pass. When using a closure, it must still accept a backward argument; this module always calls it with backward=False. See the examples below.

Note

The empirical Fisher information matrix can be a poor approximation in some cases. If so, set sqrt to True to perform whitening instead, which is considerably more robust.

Parameters:

  • reg (float, default: 1e-08 ) –

    regularization parameter. Defaults to 1e-8.

  • sqrt (bool, default: False ) –

    if True, uses the square root of the empirical Fisher information matrix. Both the EFIM and its square root can be calculated and stored efficiently without ndim^2 memory. The square root whitens the gradient and often performs much better, especially when NGD is used with a vector that isn't strictly made of per-sample losses but, for example, of several different losses.

  • gn_grad (bool, default: False ) –

    if True, uses the Gauss-Newton gradient G^T @ f, which is the sum of per-sample gradients weighted by their values; this equals half the gradient of the sum of squared values, so you can solve least-squares objectives with an NGD-like algorithm. If False, uses the sum of per-sample gradients. This setting only affects the update when sqrt=True, but it always affects the grad attribute. Defaults to False.

  • batched (bool, default: True ) –

    whether to use vmapping. Defaults to True.

Examples:

training a neural network:

import torch
import torch.nn as nn
import torchzero as tz

X = torch.randn(64, 20)
y = torch.randn(64, 10)

model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
opt = tz.Modular(
    model.parameters(),
    tz.m.NaturalGradient(),
    tz.m.LR(3e-2)
)

for i in range(100):
    y_hat = model(X) # (64, 10)
    losses = (y_hat - y).pow(2).mean(0) # (10, )
    opt.step(loss=losses)
    if i % 10 == 0:
        print(f'{losses.mean() = }')

training a neural network - closure version

X = torch.randn(64, 20)
y = torch.randn(64, 10)

model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
opt = tz.Modular(
    model.parameters(),
    tz.m.NaturalGradient(),
    tz.m.LR(3e-2)
)

def closure(backward=True):
    y_hat = model(X) # (64, 10)
    return (y_hat - y).pow(2).mean(0) # (10, )

for i in range(100):
    losses = opt.step(closure)
    if i % 10 == 0:
        print(f'{losses.mean() = }')

minimizing the rosenbrock function with a mix of natural gradient, whitening and gauss-newton:

import torch
import torchzero as tz

def rosenbrock(X):
    x1, x2 = X
    return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])

X = torch.tensor([-1.1, 2.5], requires_grad=True)
opt = tz.Modular([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))

for iter in range(200):
    losses = rosenbrock(X)
    opt.step(loss=losses)
    if iter % 20 == 0:
        print(f'{losses.mean() = }')

Source code in torchzero/modules/adaptive/natural_gradient.py
class NaturalGradient(Module):
    """Natural gradient approximated via empirical fisher information matrix.

    To use this, either pass vector of per-sample losses to the step method, or make sure
    the closure returns it. Gradients will be calculated via batched autograd within this module,
    you don't need to implement the backward pass. When using closure, please add the ``backward`` argument,
    it will always be False but it is required. See below for an example.

    Note:
        Empirical fisher information matrix may give a really bad approximation in some cases.
        If that is the case, set ``sqrt`` to True to perform whitening instead, which is way more robust.

    Args:
        reg (float, optional): regularization parameter. Defaults to 1e-8.
        sqrt (bool, optional):
            if True, uses square root of empirical fisher information matrix. Both EFIM and its square
            root can be calculated and stored efficiently without ndim^2 memory. Square root
            whitens the gradient and often performs much better, especially when you try to use NGD
            with a vector that isn't strictly per-sample gradients, but rather for example different losses.
        gn_grad (bool, optional):
            if True, uses Gauss-Newton G^T @ f as the gradient, which is effectively sum weighted by value
            and is equivalent to squaring the values. This way you can solve least-squares
            objectives with a NGD-like algorithm. If False, uses sum of per-sample gradients.
            This only affects the update when ``sqrt=True``, but it always affects the ``grad`` attribute.
            Defaults to False.
        batched (bool, optional): whether to use vmapping. Defaults to True.

    Examples:

    training a neural network:
    ```python
    X = torch.randn(64, 20)
    y = torch.randn(64, 10)

    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
    opt = tz.Modular(
        model.parameters(),
        tz.m.NaturalGradient(),
        tz.m.LR(3e-2)
    )

    for i in range(100):
        y_hat = model(X) # (64, 10)
        losses = (y_hat - y).pow(2).mean(0) # (10, )
        opt.step(loss=losses)
        if i % 10 == 0:
            print(f'{losses.mean() = }')
    ```

    training a neural network - closure version
    ```python
    X = torch.randn(64, 20)
    y = torch.randn(64, 10)

    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
    opt = tz.Modular(
        model.parameters(),
        tz.m.NaturalGradient(),
        tz.m.LR(3e-2)
    )

    def closure(backward=True):
        y_hat = model(X) # (64, 10)
        return (y_hat - y).pow(2).mean(0) # (10, )

    for i in range(100):
        losses = opt.step(closure)
        if i % 10 == 0:
            print(f'{losses.mean() = }')
    ```

    minimizing the rosenbrock function with a mix of natural gradient, whitening and gauss-newton:
    ```python
    def rosenbrock(X):
        x1, x2 = X
        return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])

    X = torch.tensor([-1.1, 2.5], requires_grad=True)
    opt = tz.Modular([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))

    for iter in range(200):
        losses = rosenbrock(X)
        opt.step(loss=losses)
        if iter % 20 == 0:
            print(f'{losses.mean() = }')
    ```
    """
    def __init__(self, reg:float = 1e-8, sqrt:bool=False, gn_grad:bool=False, batched:bool=True, ):
        super().__init__(defaults=dict(batched=batched, reg=reg, sqrt=sqrt, gn_grad=gn_grad))

    @torch.no_grad
    def update(self, var):
        params = var.params
        batched = self.defaults['batched']
        gn_grad = self.defaults['gn_grad']

        closure = var.closure
        assert closure is not None

        with torch.enable_grad():
            f = var.get_loss(backward=False) # n_out
            assert isinstance(f, torch.Tensor)
            G_list = jacobian_wrt([f.ravel()], params, batched=batched)

        var.loss = f.sum()
        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)

        if gn_grad:
            g = self.global_state["g"] = G.H @ f.detach()

        else:
            g = self.global_state["g"] = G.sum(0)

        var.grad = vec_to_tensors(g, params)

        # set closure to calculate scalar value for line searches etc
        if var.closure is not None:
            def ngd_closure(backward=True):
                if backward:
                    var.zero_grad()
                    with torch.enable_grad():
                        loss = closure(False)
                        if gn_grad: loss = loss.pow(2)
                        loss = loss.sum()
                        loss.backward()
                    return loss

                loss = closure(False)
                if gn_grad: loss = loss.pow(2)
                return loss.sum()

            var.closure = ngd_closure

    @torch.no_grad
    def apply(self, var):
        params = var.params
        reg = self.defaults['reg']
        sqrt = self.defaults['sqrt']

        G: torch.Tensor = self.global_state['G'] # (n_samples, n_dim)

        if sqrt:
            # this computes U, S <- SVD(M), then calculate update as U S^-1 Uᵀg,
            # but it computes it through eigendecomposition
            U, L = lm_adagrad_update(G.H, reg, 0)
            if U is None or L is None: return var

            v = lm_adagrad_apply(self.global_state["g"], U, L)
            var.update = vec_to_tensors(v, params)
            return var

        GGT = G @ G.H # (n_samples, n_samples)

        if reg != 0:
            GGT.add_(torch.eye(GGT.size(0), device=GGT.device, dtype=GGT.dtype).mul_(reg))

        z, _ = torch.linalg.solve_ex(GGT, torch.ones_like(GGT[0])) # pylint:disable=not-callable
        v = G.H @ z

        var.update = vec_to_tensors(v, params)
        return var


    def get_H(self, var):
        if "G" not in self.global_state: return linear_operator.ScaledIdentity()
        G = self.global_state['G']
        return linear_operator.AtA(G)
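In the non-sqrt branch of apply, the listing solves the small n_samples x n_samples system (G G^T + reg I) z = 1 and returns v = G^T z. By the push-through identity G^T (G G^T + reg I)^-1 = (G^T G + reg I)^-1 G^T, this equals (F + reg I)^-1 g with the (unnormalized) empirical Fisher F = G^T G and g = G^T 1 = G.sum(0), while never forming the ndim x ndim matrix. The sketch below checks that equivalence on random real data; ngd_direction is a hypothetical name, and it uses torch.linalg.solve and G.T where the listing uses solve_ex and G.H.

```python
import torch


def ngd_direction(G: torch.Tensor, reg: float = 1e-8) -> torch.Tensor:
    """Hypothetical helper: empirical-Fisher natural gradient from an (n_samples, ndim) Jacobian G."""
    n = G.shape[0]
    # small n_samples x n_samples system instead of the ndim x ndim Fisher matrix
    GGT = G @ G.T + reg * torch.eye(n, dtype=G.dtype, device=G.device)
    z = torch.linalg.solve(GGT, torch.ones(n, dtype=G.dtype, device=G.device))
    return G.T @ z


# compare against the direct (ndim x ndim) formula on random data
torch.manual_seed(0)
G = torch.randn(5, 20, dtype=torch.float64)
reg = 1e-3
v = ngd_direction(G, reg)
F = G.T @ G + reg * torch.eye(20, dtype=torch.float64)
v_direct = torch.linalg.solve(F, G.sum(0))
print(torch.allclose(v, v_direct))
```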

Negate

Bases: torchzero.core.transform.Transform

Returns -input.

Source code in torchzero/modules/ops/unary.py
class Negate(Transform):
    """Returns :code:`- input`"""
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        torch._foreach_neg_(tensors)
        return tensors

NegateOnLossIncrease

Bases: torchzero.core.module.Module

Uses an extra forward pass to evaluate the loss at parameters - update, the point the step would move to. If that loss is larger than the loss at the current parameters, the update is set to 0 if backtrack=False and to -update otherwise.

Source code in torchzero/modules/misc/multistep.py
class NegateOnLossIncrease(Module):
    """Uses an extra forward pass to evaluate loss at :code:`parameters+update`,
    if loss is larger than at :code:`parameters`,
    the update is set to 0 if :code:`backtrack=False` and to :code:`-update` otherwise"""
    def __init__(self, backtrack=False):
        defaults = dict(backtrack=backtrack)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def step(self, var):
        closure = var.closure
        if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
        backtrack = self.defaults['backtrack']

        update = var.get_update()
        f_0 = var.get_loss(backward=False)

        torch._foreach_sub_(var.params, update)
        f_1 = closure(False)

        if f_1 <= f_0:
            if var.is_last and var.last_module_lrs is None:
                var.stop = True
                var.skip_update = True
                return var

            torch._foreach_add_(var.params, update)
            return var

        torch._foreach_add_(var.params, update)
        if backtrack:
            torch._foreach_neg_(var.update)
        else:
            torch._foreach_zero_(var.update)
        return var
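The check can be sketched as a plain function. This is a minimal sketch; negate_on_loss_increase is a hypothetical name, and unlike the module it always restores the parameters (the module, when it is the last one and the loss did not increase, instead leaves the parameters at the trial point and skips the update).

```python
import torch


@torch.no_grad()
def negate_on_loss_increase(params, update, closure, backtrack=False):
    """Hypothetical helper mirroring the check above.

    Returns `update` if moving to params - update does not increase the loss,
    otherwise zeros (backtrack=False) or -update (backtrack=True).
    """
    f_0 = closure()
    for p, u in zip(params, update):
        p.sub_(u)  # move to the trial point
    f_1 = closure()
    for p, u in zip(params, update):
        p.add_(u)  # restore the original parameters
    if f_1 <= f_0:
        return update
    if backtrack:
        return [-u for u in update]
    return [torch.zeros_like(u) for u in update]


x = torch.tensor([1.0])


def closure():
    return (x ** 2).sum()


kept = negate_on_loss_increase([x], [torch.tensor([0.5])], closure)                   # 0.25 <= 1: kept
zeroed = negate_on_loss_increase([x], [torch.tensor([3.0])], closure)                 # 4 > 1: zeroed
negated = negate_on_loss_increase([x], [torch.tensor([3.0])], closure, backtrack=True)  # 4 > 1: negated
```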

NewDQN

Bases: torchzero.modules.quasi_newton.diagonal_quasi_newton.DNRTR

Diagonal quasi-newton method.

Reference

Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class NewDQN(DNRTR):
    """Diagonal quasi-newton method.

    Reference:
        Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return new_dqn_B_(B=B, s=s, y=y)

NewSSM

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Self-scaling Quasi-Newton method.

Note

A line search such as tz.m.StrongWolfe() is required.

Warning

This uses roughly O(N^2) memory, where N is the number of parameters.

Reference

Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class NewSSM(HessianUpdateStrategy):
    """Self-scaling Quasi-Newton method.

    Note:
        a line search such as ``tz.m.StrongWolfe()`` is required.

    Warning:
        this uses roughly O(N^2) memory.

    Reference:
        Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
    """
    def __init__(
        self,
        type: Literal[1, 2] = 1,
        init_scale: float | Literal["auto"] = "auto",
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None = None,
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        super().__init__(
            defaults=dict(type=type),
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
        )
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        f = state['f']
        f_prev = state['f_prev']
        return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=setting['type'], tol=setting['tol'])

Newton

Bases: torchzero.core.module.Module

Exact newton's method via autograd.

Newton's method produces a direction that jumps to the stationary point of a quadratic approximation of the objective. The update rule is (H + yI)⁻¹g, where H is the hessian, g is the gradient and y is the damping parameter. If the inner argument is specified, g is the output of those modules instead of the gradient.

Note

In most cases Newton should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

Note

This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate the hessian. The closure must accept a backward argument (refer to the documentation).

Parameters:

  • damping (float, default: 0 ) –

    Tikhonov regularization value. Set this to 0 when using a trust region. Defaults to 0.

  • search_negative (bool, default: False ) –

    if True, whenever a negative eigenvalue is detected, the search direction is proposed along a weighted sum of the eigenvectors corresponding to negative eigenvalues.

  • use_lstsq (bool, default: False ) –

    if True, least squares is used to solve the linear system, which may generate reasonable directions when the hessian is not invertible. If False, Cholesky is tried first, then LU, then least squares. If eigval_fn is specified, eigendecomposition is always used to solve the linear system and this argument is ignored.

  • hessian_method (str, default: 'autograd' ) –

    how to calculate hessian. Defaults to "autograd".

  • vectorize (bool, default: True ) –

    whether to enable vectorized hessian. Defaults to True.

  • inner (Chainable | None, default: None ) –

    modules to apply hessian preconditioner to. Defaults to None.

  • H_tfm (Callable | None, default: None ) –

    optional hessian transform that takes two arguments, (hessian, gradient).

    It must return either a tuple (hessian, is_inverted), with the transformed hessian and a boolean that is True if the transform inverted the hessian and False otherwise, or a single tensor, which is used as the update.

    Defaults to None.

  • eigval_fn (Callable | None, default: None ) –

    optional eigenvalues transform, for example torch.abs or lambda L: torch.clip(L, min=1e-8). If this is specified, eigendecomposition will be used to invert the hessian.

See also

  • tz.m.NewtonCG: uses a matrix-free conjugate gradient solver and hessian-vector products, useful for large scale problems as it doesn't form the full hessian.
  • tz.m.NewtonCGSteihaug: trust region version of tz.m.NewtonCG.
  • tz.m.InverseFreeNewton: an inverse-free variant of Newton's method.
  • tz.m.quasi_newton: large collection of quasi-newton methods that estimate the hessian.

Notes

Implementation details

(H + yI)⁻¹g is calculated by solving the linear system (H + yI)x = g. The linear system is solved via Cholesky decomposition; if that fails, LU decomposition is used, and if that also fails, least squares. Least squares can be forced by setting use_lstsq=True, which may generate better search directions when the hessian is singular.

Additionally, if eigval_fn is specified or search_negative is True, the eigendecomposition of the hessian is computed, eigval_fn is applied to the eigenvalues, and (H + yI)⁻¹ is computed from the eigenvectors and the transformed eigenvalues. This is generally more computationally expensive.
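The solve order described above (Cholesky, then LU, then least squares), together with the eigendecomposition path used when eigval_fn is given, can be sketched with torch.linalg. This is a minimal dense sketch under assumptions: newton_direction is a hypothetical name, damping is added before eigval_fn is applied, search_negative is not covered, and the library's internal details may differ.

```python
import torch


def newton_direction(H, g, damping=0.0, eigval_fn=None):
    """Hypothetical helper: solve (H + damping * I) x = g for a dense symmetric H."""
    n = H.shape[0]
    H = H + damping * torch.eye(n, dtype=H.dtype, device=H.device)

    if eigval_fn is not None:
        # H = Q diag(L) Q^T, so H^-1 g = Q diag(1 / L) Q^T g; eigval_fn transforms L first
        L, Q = torch.linalg.eigh(H)
        return Q @ ((Q.T @ g) / eigval_fn(L))

    # 1) Cholesky: succeeds only if H is positive definite
    C, info = torch.linalg.cholesky_ex(H)
    if info.item() == 0:
        return torch.cholesky_solve(g.unsqueeze(-1), C).squeeze(-1)

    # 2) LU-based solve: works for any non-singular H
    x, info = torch.linalg.solve_ex(H, g)
    if info.item() == 0:
        return x

    # 3) least squares: still returns a direction when H is singular
    return torch.linalg.lstsq(H, g.unsqueeze(-1)).solution.squeeze(-1)


dt = torch.float64
H_spd = torch.tensor([[2.0, 0.0], [0.0, 4.0]], dtype=dt)
H_indef = torch.tensor([[1.0, 0.0], [0.0, -2.0]], dtype=dt)
spd = newton_direction(H_spd, torch.tensor([2.0, 4.0], dtype=dt))                     # Cholesky path
indef = newton_direction(H_indef, torch.tensor([1.0, 2.0], dtype=dt))                 # LU path
fixed = newton_direction(H_indef, torch.tensor([1.0, 2.0], dtype=dt), eigval_fn=torch.abs)
print(spd, indef, fixed)
```

In the indefinite case the plain solve moves uphill along the negative-curvature direction (second component -1), whereas eigval_fn=torch.abs flips that eigenvalue and yields a descent direction.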

Handling non-convexity

Standard Newton's method does not handle non-convexity well without some modifications. This is because it jumps to the stationary point of the quadratic approximation, which may be a maximum or a saddle point rather than a minimum.

The first modification to handle non-convexity is to make the eigenvalues positive, for example by setting eigval_fn = lambda L: L.abs().clip(min=1e-4).

Second modification is search_negative=True, which will search along a negative curvature direction if one is detected. This also requires an eigendecomposition.

The Newton direction can also be forced to be a descent direction by using tz.m.GradSign() or tz.m.Cautious, but that may be significantly less efficient.

Examples:

Newton's method with backtracking line search

opt = tz.Modular(
    model.parameters(),
    tz.m.Newton(),
    tz.m.Backtracking()
)

Newton preconditioning applied to momentum

opt = tz.Modular(
    model.parameters(),
    tz.m.Newton(inner=tz.m.EMA(0.9)),
    tz.m.LR(0.1)
)

Diagonal Newton example. This still evaluates the full hessian, so it isn't efficient, but it is useful if you want to see how diagonal Newton behaves or compare it to full Newton.

opt = tz.Modular(
    model.parameters(),
    tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
    tz.m.Backtracking()
)
Source code in torchzero/modules/second_order/newton.py
class Newton(Module):
    """Exact newton's method via autograd.

    Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
    The update rule is given by ``(H + yI)⁻¹g``, where ``H`` is the hessian and ``g`` is the gradient, ``y`` is the ``damping`` parameter.
    ``g`` can be the output of another module if it is specified in the ``inner`` argument.

    Note:
        In most cases Newton should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.

    Note:
        This module requires a closure to be passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients for calculating the hessian.
        The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        damping (float, optional): tikhonov regularizer value. Set this to 0 when using trust region. Defaults to 0.
        search_negative (bool, Optional):
            if True, whenever a negative eigenvalue is detected,
            search direction is proposed along weighted sum of eigenvectors corresponding to negative eigenvalues.
        use_lstsq (bool, Optional):
            if True, least squares will be used to solve the linear system, this may generate reasonable directions
            when hessian is not invertible. If False, tries cholesky, if it fails tries LU, and then least squares.
            If ``eigval_fn`` is specified, eigendecomposition will always be used to solve the linear system and this
            argument will be ignored.
        hessian_method (str):
            how to calculate hessian. Defaults to "autograd".
        vectorize (bool, optional):
            whether to enable vectorized hessian. Defaults to True.
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
        H_tfm (Callable | None, optional):
            optional hessian transform that takes in two arguments, ``(hessian, gradient)``.

            It must return either a tuple ``(hessian, is_inverted)`` with the transformed hessian and a boolean
            which must be True if the transform inverted the hessian and False otherwise,
            or a single tensor which is used directly as the update.

            Defaults to None.
        eigval_fn (Callable | None, optional):
            optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
            If this is specified, eigendecomposition will be used to invert the hessian.

    # See also

    * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products,
    useful for large scale problems as it doesn't form the full hessian.
    * ``tz.m.NewtonCGSteihaug``: trust region version of ``tz.m.NewtonCG``.
    * ``tz.m.InverseFreeNewton``: an inverse-free variant of Newton's method.
    * ``tz.m.quasi_newton``: large collection of quasi-newton methods that estimate the hessian.

    # Notes

    ## Implementation details

    ``(H + yI)⁻¹g`` is calculated by solving the linear system ``(H + yI)x = g``.
    The linear system is solved via Cholesky decomposition; if that fails, via LU decomposition, and if that fails, via least squares.
    Least squares can be forced by setting ``use_lstsq=True``, which may generate better search directions when the hessian is singular.

    Additionally, if ``eigval_fn`` is specified or ``search_negative`` is ``True``,
    eigendecomposition of the hessian is computed, ``eigval_fn`` is applied to the eigenvalues,
    and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues.
    This is generally more computationally expensive.

    ## Handling non-convexity

    Standard Newton's method does not handle non-convexity well without some modifications.
    This is because it jumps to the stationary point of the quadratic approximation, which may be a maximum or a saddle point.

    The first modification to handle non-convexity is to make the eigenvalues positive,
    for example by setting ``eigval_fn = lambda L: L.abs().clip(min=1e-4)``.

    The second modification is ``search_negative=True``, which will search along a negative curvature direction if one is detected.
    This also requires an eigendecomposition.

    The Newton direction can also be forced to be a descent direction by using ``tz.m.GradSign()`` or ``tz.m.Cautious``,
    but that may be significantly less efficient.

    # Examples:

    Newton's method with backtracking line search

    ```py
    opt = tz.Modular(
        model.parameters(),
        tz.m.Newton(),
        tz.m.Backtracking()
    )
    ```

    Newton preconditioning applied to momentum

    ```py
    opt = tz.Modular(
        model.parameters(),
        tz.m.Newton(inner=tz.m.EMA(0.9)),
        tz.m.LR(0.1)
    )
    ```

    Diagonal Newton example. This still evaluates the entire hessian, so it isn't efficient,
    but it can be used to see how diagonal Newton behaves or compares to full Newton.

    ```py
    opt = tz.Modular(
        model.parameters(),
        tz.m.Newton(H_tfm = lambda H, g: g/H.diag()),
        tz.m.Backtracking()
    )
    ```

    """
    def __init__(
        self,
        damping: float = 0,
        search_negative: bool = False,
        use_lstsq: bool = False,
        update_freq: int = 1,
        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
        vectorize: bool = True,
        inner: Chainable | None = None,
        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
    ):
        defaults = dict(damping=damping, hessian_method=hessian_method, use_lstsq=use_lstsq, vectorize=vectorize, H_tfm=H_tfm, eigval_fn=eigval_fn, search_negative=search_negative, update_freq=update_freq)
        super().__init__(defaults)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def update(self, var):
        params = TensorList(var.params)
        closure = var.closure
        if closure is None: raise RuntimeError('Newton requires closure')

        settings = self.settings[params[0]]
        damping = settings['damping']
        hessian_method = settings['hessian_method']
        vectorize = settings['vectorize']
        update_freq = settings['update_freq']

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        g_list = var.grad
        H = None
        if step % update_freq == 0:
            # ------------------------ calculate grad and hessian ------------------------ #
            if hessian_method == 'autograd':
                with torch.enable_grad():
                    loss = var.loss = var.loss_approx = closure(False)
                    g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
                    g_list = [t[0] for t in g_list] # remove leading dim from loss
                    var.grad = g_list
                    H = flatten_jacobian(H_list)

            elif hessian_method in ('func', 'autograd.functional'):
                strat = 'forward-mode' if vectorize else 'reverse-mode'
                with torch.enable_grad():
                    g_list = var.get_grad(retain_graph=True)
                    H = hessian_mat(partial(closure, backward=False), params,
                                    method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]

            else:
                raise ValueError(hessian_method)

            if damping != 0: H.add_(torch.eye(H.size(-1), dtype=H.dtype, device=H.device).mul_(damping))
            self.global_state['H'] = H

    @torch.no_grad
    def apply(self, var):
        H = self.global_state["H"]

        params = var.params
        settings = self.settings[params[0]]
        search_negative = settings['search_negative']
        H_tfm = settings['H_tfm']
        eigval_fn = settings['eigval_fn']
        use_lstsq = settings['use_lstsq']

        # -------------------------------- inner step -------------------------------- #
        update = var.get_update()
        if 'inner' in self.children:
            update = apply_transform(self.children['inner'], update, params=params, grads=var.grad, var=var)

        g = torch.cat([t.ravel() for t in update])

        # ----------------------------------- solve ---------------------------------- #
        update = None
        if H_tfm is not None:
            ret = H_tfm(H, g)

            if isinstance(ret, torch.Tensor):
                update = ret

            else: # returns (H, is_inv)
                H, is_inv = ret
                if is_inv: update = H @ g

        if search_negative or (eigval_fn is not None):
            update = _eigh_solve(H, g, eigval_fn, search_negative=search_negative)

        if update is None and use_lstsq: update = _least_squares_solve(H, g)
        if update is None: update = _cholesky_solve(H, g)
        if update is None: update = _lu_solve(H, g)
        if update is None: update = _least_squares_solve(H, g)

        var.update = vec_to_tensors(update, params)

        return var

    def get_H(self,var):
        H = self.global_state["H"]
        settings = self.defaults
        if settings['eigval_fn'] is not None:
            try:
                L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
                L = settings['eigval_fn'](L)
                H = Q @ L.diag_embed() @ Q.mH
                H_inv = Q @ L.reciprocal().diag_embed() @ Q.mH
                return DenseWithInverse(H, H_inv)

            except torch.linalg.LinAlgError:
                pass

        return Dense(H)
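The solve order described in the implementation notes (eigendecomposition when `eigval_fn` is given, otherwise Cholesky, then LU, then least squares) can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not torchzero's code: the name `newton_direction` is hypothetical, `search_negative` and `H_tfm` are left out, and it returns `(H + yI)⁻¹g`, the vector the module stores as its update.

```python
import torch


def newton_direction(H, g, damping=0.0, eigval_fn=None, use_lstsq=False):
    # Hypothetical helper: returns (H + damping*I)^-1 g for a dense symmetric H.
    n = H.shape[-1]
    if damping != 0:
        H = H + damping * torch.eye(n, dtype=H.dtype, device=H.device)
    b = g.unsqueeze(-1)  # the solvers below expect an (n, 1) right-hand side

    # Eigendecomposition path: transform the eigenvalues, then invert in the eigenbasis.
    if eigval_fn is not None:
        L, Q = torch.linalg.eigh(H)
        L = eigval_fn(L)
        return Q @ ((Q.mH @ g) / L)

    if not use_lstsq:
        # 1) Cholesky succeeds only when H is numerically positive definite.
        C, info = torch.linalg.cholesky_ex(H)
        if info.item() == 0:
            return torch.cholesky_solve(b, C).squeeze(-1)

        # 2) LU handles any non-singular H, including indefinite ones.
        LU, pivots, info = torch.linalg.lu_factor_ex(H)
        if info.item() == 0:
            return torch.linalg.lu_solve(LU, pivots, b).squeeze(-1)

    # 3) Least squares still returns a direction when H is singular.
    return torch.linalg.lstsq(H, b).solution.squeeze(-1)
```

On the indefinite matrix `diag(1, -2)` with `g = [1, 2]`, for instance, the LU solve gives `[1, -1]`, whose dot product with `g` is negative, while `eigval_fn=lambda L: L.abs().clip(min=1e-4)` gives `[1, 1]`, which is a descent direction assuming the update is subtracted from the parameters.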

NewtonCG

Bases: torchzero.core.module.Module

Newton's method with a matrix-free conjugate gradient or minimal-residual solver.

Notes
  • In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

  • This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

Warning

CG may fail if the hessian is not positive-definite.

Parameters:

  • maxiter (int | None, default: None ) –

    Maximum number of iterations for the conjugate gradient solver. By default, this is set to the number of dimensions in the objective function, which is the theoretical upper bound for CG convergence. Setting this to a smaller value (truncated Newton) can still generate good search directions. Defaults to None.

  • tol (float, default: 1e-08 ) –

    Relative tolerance for the conjugate gradient solver to determine convergence. Defaults to 1e-8.

  • reg (float, default: 1e-08 ) –

    Regularization parameter (damping) added to the Hessian diagonal. This helps ensure the system is positive-definite. Defaults to 1e-8.

  • hvp_method (str, default: 'autograd' ) –

    Determines how Hessian-vector products are evaluated.

    • "autograd": Use PyTorch's autograd to calculate exact HVPs. This requires creating a graph for the gradient.
    • "forward": Use a forward finite difference formula to approximate the HVP. This requires one extra gradient evaluation.
    • "central": Use a central finite difference formula for a more accurate HVP approximation. This requires two extra gradient evaluations.

    Defaults to "autograd".
  • h (float, default: 0.001 ) –

    The step size for finite differences if hvp_method is "forward" or "central". Defaults to 1e-3.

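  • solver (str, default: 'cg' ) –

    which solver to use: "cg" (conjugate gradient), "minres", or "minres_npc" (MINRES that terminates when negative curvature is detected). Defaults to 'cg'.

  • miniter (int, default: 1 ) –

    minimal number of CG iterations (used by the "cg" solver). Defaults to 1.

  • warm_start (bool, default: False ) –

    If True, the conjugate gradient solver is initialized with the solution from the previous optimization step. This can accelerate convergence, especially in truncated Newton methods. Defaults to False.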

  • inner (Chainable | None, default: None ) –

    NewtonCG will attempt to apply preconditioning to the output of this module.
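The three `hvp_method` options trade accuracy for cost roughly as sketched below. The sketch works on a single flat tensor `x` and a scalar function `f`; the helper names are hypothetical, and torchzero's own helpers (`hvp`, `hvp_fd_forward`, `hvp_fd_central`, called with `normalize=True`) may differ in details such as how the step `h` is scaled.

```python
import torch


def grad_at(f, x):
    # Gradient of the scalar function f at x (x is treated as a fresh leaf tensor).
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(f(x), x)
    return g


def hvp_autograd(f, x, v):
    # Exact HVP via double backward: differentiate (grad f(x) . v) with respect to x.
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(f(x), x, create_graph=True)
    (Hv,) = torch.autograd.grad(g, x, grad_outputs=v)
    return Hv


def hvp_forward(f, x, v, h=1e-3, g0=None):
    # (g(x + h v) - g(x)) / h: one extra gradient when g0 = g(x) is already known.
    if g0 is None:
        g0 = grad_at(f, x)
    return (grad_at(f, x + h * v) - g0) / h


def hvp_central(f, x, v, h=1e-3):
    # (g(x + h v) - g(x - h v)) / (2h): two extra gradients, error of order h^2.
    return (grad_at(f, x + h * v) - grad_at(f, x - h * v)) / (2 * h)
```

For `f(x) = sum(x**4)` the forward formula has an error proportional to `h`, while the central formula's error is proportional to `h**2`, which is why "central" is described as more accurate at the price of one more gradient.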

Examples: Newton-CG with a backtracking line search:

opt = tz.Modular(
    model.parameters(),
    tz.m.NewtonCG(),
    tz.m.Backtracking()
)

Truncated Newton method (useful for large-scale problems):

opt = tz.Modular(
    model.parameters(),
    tz.m.NewtonCG(maxiter=10),
    tz.m.Backtracking()
)
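The notes above require a closure that accepts a `backward` argument. A minimal closure of that shape is sketched below; so that it runs without torchzero, it is driven by a stock `torch.optim.SGD`, which calls `closure()` with the default `backward=True`. With a `tz.Modular` optimizer the same closure would be passed to `opt.step(closure)`. The model, data and loss are placeholders.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
inputs, targets = torch.randn(8, 3), torch.randn(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def closure(backward=True):
    # Always recompute the loss; only populate .grad when asked to.
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    if backward:
        optimizer.zero_grad()
        loss.backward()
    return loss


loss = optimizer.step(closure)
```

Calling `closure(False)` only evaluates the loss, which is what a module needs when it re-evaluates the objective, for example at perturbed parameters.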

Source code in torchzero/modules/second_order/newton_cg.py
class NewtonCG(Module):
    """Newton's method with a matrix-free conjugate gradient or minimal-residual solver.

    Notes:
        * In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

        * This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Warning:
        CG may fail if the hessian is not positive-definite.

    Args:
        maxiter (int | None, optional):
            Maximum number of iterations for the conjugate gradient solver.
            By default, this is set to the number of dimensions in the
            objective function, which is the theoretical upper bound for CG
            convergence. Setting this to a smaller value (truncated Newton)
            can still generate good search directions. Defaults to None.
        tol (float, optional):
            Relative tolerance for the conjugate gradient solver to determine
            convergence. Defaults to 1e-8.
        reg (float, optional):
            Regularization parameter (damping) added to the Hessian diagonal.
            This helps ensure the system is positive-definite. Defaults to 1e-8.
        hvp_method (str, optional):
            Determines how Hessian-vector products are evaluated.

            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
              This requires creating a graph for the gradient.
            - ``"forward"``: Use a forward finite difference formula to
              approximate the HVP. This requires one extra gradient evaluation.
            - ``"central"``: Use a central finite difference formula for a
              more accurate HVP approximation. This requires two extra
              gradient evaluations.
            Defaults to "autograd".
        h (float, optional):
            The step size for finite differences if :code:`hvp_method` is
            ``"forward"`` or ``"central"``. Defaults to 1e-3.
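        solver (str, optional):
            which solver to use: ``"cg"`` (conjugate gradient), ``"minres"``, or ``"minres_npc"``
            (MINRES that terminates when negative curvature is detected). Defaults to ``'cg'``.
        miniter (int, optional):
            minimal number of CG iterations (used by the ``"cg"`` solver). Defaults to 1.
        warm_start (bool, optional):
            If ``True``, the conjugate gradient solver is initialized with the
            solution from the previous optimization step. This can accelerate
            convergence, especially in truncated Newton methods.
            Defaults to False.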
        inner (Chainable | None, optional):
            NewtonCG will attempt to apply preconditioning to the output of this module.

    Examples:
    Newton-CG with a backtracking line search:

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.NewtonCG(),
        tz.m.Backtracking()
    )
    ```

    Truncated Newton method (useful for large-scale problems):
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.NewtonCG(maxiter=10),
        tz.m.Backtracking()
    )
    ```

    """
    def __init__(
        self,
        maxiter: int | None = None,
        tol: float = 1e-8,
        reg: float = 1e-8,
        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
        solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
        h: float = 1e-3, # tuned 1e-4 or 1e-3
        miniter:int = 1,
        warm_start=False,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner']
        super().__init__(defaults,)

        if inner is not None:
            self.set_child('inner', inner)

        self._num_hvps = 0
        self._num_hvps_last_step = 0

    @torch.no_grad
    def step(self, var):
        params = TensorList(var.params)
        closure = var.closure
        if closure is None: raise RuntimeError('NewtonCG requires closure')

        settings = self.settings[params[0]]
        tol = settings['tol']
        reg = settings['reg']
        maxiter = settings['maxiter']
        hvp_method = settings['hvp_method']
        solver = settings['solver'].lower().strip()
        h = settings['h']
        warm_start = settings['warm_start']

        self._num_hvps_last_step = 0
        # ---------------------- Hessian vector product function --------------------- #
        if hvp_method == 'autograd':
            grad = var.get_grad(create_graph=True)

            def H_mm(x):
                self._num_hvps_last_step += 1
                with torch.enable_grad():
                    return TensorList(hvp(params, grad, x, retain_graph=True))

        else:

            with torch.enable_grad():
                grad = var.get_grad()

            if hvp_method == 'forward':
                def H_mm(x):
                    self._num_hvps_last_step += 1
                    return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])

            elif hvp_method == 'central':
                def H_mm(x):
                    self._num_hvps_last_step += 1
                    return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])

            else:
                raise ValueError(hvp_method)


        # -------------------------------- inner step -------------------------------- #
        b = var.get_update()
        if 'inner' in self.children:
            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
        b = as_tensorlist(b)

        # ---------------------------------- run cg ---------------------------------- #
        x0 = None
        if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway

        if solver == 'cg':
            d, _ = cg(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, miniter=self.defaults["miniter"],reg=reg)

        elif solver == 'minres':
            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)

        elif solver == 'minres_npc':
            d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)

        else:
            raise ValueError(f"Unknown solver {solver}")

        if warm_start:
            assert x0 is not None
            x0.copy_(d)

        var.update = d

        self._num_hvps += self._num_hvps_last_step
        return var
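The conjugate gradient solve at the core of `NewtonCG` only needs a function that returns Hessian-vector products, which is why the full hessian is never formed. The following is a textbook CG sketch, not torchzero's `cg` helper: the name `cg_solve` is hypothetical, `tol` is treated as relative to the norm of `b` as described for the `tol` argument, `reg` is added to the diagonal, and the loop simply stops at non-positive curvature (one possible reaction to the failure mode in the warning).

```python
import torch


def cg_solve(A_mm, b, x0=None, tol=1e-8, maxiter=None, reg=0.0):
    # Hypothetical helper: solves (A + reg*I) x = b using only products with A (e.g. HVPs).
    def matvec(v):
        return A_mm(v) + reg * v

    x = torch.zeros_like(b) if x0 is None else x0.clone()
    r = b - matvec(x)              # residual
    p = r.clone()                  # search direction
    rs = r.dot(r)
    threshold = tol * b.norm()     # relative tolerance
    if maxiter is None:
        maxiter = b.numel()        # exact-arithmetic upper bound for CG

    for _ in range(maxiter):
        if rs.sqrt() <= threshold:
            break
        Ap = matvec(p)
        curvature = p.dot(Ap)
        if curvature <= 0:         # A is not positive definite along p
            break
        alpha = rs / curvature
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r.dot(r)
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x
```

Passing a small `maxiter` gives the truncated Newton variant from the examples (with `maxiter=1` the result is a single step along `b` that minimizes the quadratic model along that line), and passing the previous solution as `x0` mirrors what `warm_start=True` does.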

NewtonCGSteihaug

Bases: torchzero.core.module.Module

Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.

Notes
  • In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

  • This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

Parameters:

  • eta (float, default: 0.0 ) –

    if the ratio of actual to predicted reduction is larger than this, the step is accepted. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.0001 ) –

    if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • max_attempts (int, default: 100 ) –

    maximum number of trust radius reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 100.

  • max_history (int, default: 100 ) –

    CG will store this many intermediate solutions, reusing them when the trust radius is reduced instead of re-running CG. Each stored solution requires 2N memory. Defaults to 100.

  • boundary_tol (float | None, default: 1e-06 ) –

    The trust region only increases when the suggested step's norm is at least (1-boundary_tol)*trust_region. This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.

  • maxiter (int | None, default: None ) –

    maximum number of CG iterations per step. Each iteration requires one backward pass if hvp_method="forward", two otherwise. Defaults to None.

  • miniter (int, default: 1 ) –

    minimal number of CG iterations. This prevents making no progress when the initial guess is below tolerance. Defaults to 1.

  • tol (float, default: 1e-08 ) –

    terminates CG when the norm of the residual is less than this value. Defaults to 1e-8.

  • reg (float, default: 1e-08 ) –

    hessian regularization. Defaults to 1e-8.

  • solver (str, default: 'cg' ) –

    solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.

  • adapt_tol (bool, default: True ) –

    if True, whenever trust radius collapses to smallest representable number, the tolerance is multiplied by 0.1. Defaults to True.

  • npc_terminate (bool, default: False ) –

    whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.

  • hvp_method (str, default: 'central' ) –

    either "forward" to use the forward formula, which requires one backward pass per HVP, or "central" to use the more accurate central formula, which requires two backward passes. "forward" is usually accurate enough. Defaults to "central".

  • h (float, default: 0.001 ) –

    finite difference step size. Defaults to 1e-3.

  • inner (Chainable | None, default: None ) –

    applies preconditioning to output of this module. Defaults to None.
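The trust-region parameters above interact roughly as follows. This sketch is a reading of the parameter descriptions, not torchzero's `default_radius` function; the names `reduction_ratio` and `update_radius` are hypothetical, and the step `s` is taken to move the parameters from `x` to `x + s`.

```python
import torch


def reduction_ratio(f_old, f_new, g, H_mm, s):
    # Hypothetical helper. Quadratic model m(s) = f + g.s + 0.5 s.Hs;
    # the predicted reduction is m(0) - m(s).
    predicted = -(g.dot(s) + 0.5 * s.dot(H_mm(s)))
    return (f_old - f_new) / predicted


def update_radius(rho, step_norm, radius, eta=0.0, nplus=3.5, nminus=0.25,
                  rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-6):
    # Hypothetical helper. Shrink when the model badly over-predicted the decrease.
    if rho < rho_bad:
        radius = radius * nminus
    # Grow only when the model was accurate AND the step reached the boundary.
    elif rho > rho_good and step_norm >= (1 - boundary_tol) * radius:
        radius = radius * nplus
    accepted = rho > eta
    return radius, accepted
```

On an exactly quadratic objective the ratio is 1, so an accurate step that lands on the boundary multiplies the radius by `nplus`, while an interior step leaves it unchanged.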

Examples:

Trust-region Newton-CG:

opt = tz.Modular(
    model.parameters(),
    tz.m.NewtonCGSteihaug(),
)
Reference:
Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
Source code in torchzero/modules/second_order/newton_cg.py
class NewtonCGSteihaug(Module):
    """Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.

    Notes:
        * In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

        * This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        eta (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the step is accepted. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust radius reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 100.
        max_history (int, optional):
            CG will store this many intermediate solutions, reusing them when the trust radius is reduced
            instead of re-running CG. Each stored solution requires 2N memory. Defaults to 100.
        boundary_tol (float | None, optional):
            The trust region only increases when the suggested step's norm is at least `(1-boundary_tol)*trust_region`.
            This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.

        maxiter (int | None, optional):
            maximum number of CG iterations per step. Each iteration requires one backward pass if `hvp_method="forward"`, two otherwise. Defaults to None.
        miniter (int, optional):
            minimal number of CG iterations. This prevents making no progress
            when the initial guess is below tolerance. Defaults to 1.
        tol (float, optional):
            terminates CG when the norm of the residual is less than this value. Defaults to 1e-8.
        reg (float, optional): hessian regularization. Defaults to 1e-8.
        solver (str, optional): solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.
        adapt_tol (bool, optional):
            if True, whenever trust radius collapses to smallest representable number,
            the tolerance is multiplied by 0.1. Defaults to True.
        npc_terminate (bool, optional):
            whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.

        hvp_method (str, optional):
            either "forward" to use the forward formula, which requires one backward pass per HVP, or "central" to use the more accurate central formula, which requires two backward passes. "forward" is usually accurate enough. Defaults to "central".
        h (float, optional): finite difference step size. Defaults to 1e-3.

        inner (Chainable | None, optional):
            applies preconditioning to output of this module. Defaults to None.

    ### Examples:
    Trust-region Newton-CG:

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.NewtonCGSteihaug(),
    )
    ```

    ### Reference:
        Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
    """
    def __init__(
        self,
        # trust region settings
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 100,
        max_history: int = 100,
        boundary_tol: float = 1e-6, # tuned

        # cg settings
        maxiter: int | None = None,
        miniter: int = 1,
        tol: float = 1e-8,
        reg: float = 1e-8,
        solver: Literal['cg', "minres"] = 'cg',
        adapt_tol: bool = True,
        npc_terminate: bool = False,

        # hvp settings
        hvp_method: Literal["forward", "central"] = "central",
        h: float = 1e-3, # tuned 1e-4 or 1e-3

        # inner
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner']
        super().__init__(defaults,)

        if inner is not None:
            self.set_child('inner', inner)

        self._num_hvps = 0
        self._num_hvps_last_step = 0

    @torch.no_grad
    def step(self, var):
        params = TensorList(var.params)
        closure = var.closure
        if closure is None: raise RuntimeError('NewtonCGSteihaug requires closure')

        tol = self.defaults['tol'] * self.global_state.get('tol_mul', 1)
        solver = self.defaults['solver'].lower().strip()

        (reg, maxiter, hvp_method, h, max_attempts, boundary_tol,
         eta, nplus, nminus, rho_good, rho_bad, init, npc_terminate,
         miniter, max_history, adapt_tol) = itemgetter(
             "reg", "maxiter", "hvp_method", "h", "max_attempts", "boundary_tol",
             "eta", "nplus", "nminus", "rho_good", "rho_bad", "init", "npc_terminate",
             "miniter", "max_history", "adapt_tol",
        )(self.defaults)

        self._num_hvps_last_step = 0

        # ---------------------- Hessian vector product function --------------------- #
        if hvp_method == 'autograd':
            grad = var.get_grad(create_graph=True)

            def H_mm(x):
                self._num_hvps_last_step += 1
                with torch.enable_grad():
                    return TensorList(hvp(params, grad, x, retain_graph=True))

        else:

            with torch.enable_grad():
                grad = var.get_grad()

            if hvp_method == 'forward':
                def H_mm(x):
                    self._num_hvps_last_step += 1
                    return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])

            elif hvp_method == 'central':
                def H_mm(x):
                    self._num_hvps_last_step += 1
                    return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])

            else:
                raise ValueError(hvp_method)


        # -------------------------------- inner step -------------------------------- #
        b = var.get_update()
        if 'inner' in self.children:
            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
        b = as_tensorlist(b)

        # ------------------------------- trust region ------------------------------- #
        success = False
        d = None
        x0 = [p.clone() for p in params]
        solution = None

        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            trust_radius = self.global_state.get('trust_radius', init)

            # -------------- make sure trust radius isn't too small or large ------------- #
            finfo = torch.finfo(x0[0].dtype)
            if trust_radius < finfo.tiny * 2:
                trust_radius = self.global_state['trust_radius'] = init
                if adapt_tol:
                    self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1

            elif trust_radius > finfo.max / 2:
                trust_radius = self.global_state['trust_radius'] = init

            # ----------------------------------- solve ---------------------------------- #
            d = None
            if solution is not None and solution.history is not None:
                d = find_within_trust_radius(solution.history, trust_radius)

            if d is None:
                if solver == 'cg':
                    d, solution = cg(
                        A_mm=H_mm,
                        b=b,
                        tol=tol,
                        maxiter=maxiter,
                        reg=reg,
                        trust_radius=trust_radius,
                        miniter=miniter,
                        npc_terminate=npc_terminate,
                        history_size=max_history,
                    )

                elif solver == 'minres':
                    d = minres(A_mm=H_mm, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)

                else:
                    raise ValueError(f"unknown solver {solver}")

            # ---------------------------- update trust radius --------------------------- #
            self.global_state["trust_radius"], success = default_radius(
                params=params,
                closure=closure,
                f=tofloat(var.get_loss(False)),
                g=b,
                H=H_mm,
                d=d,
                trust_radius=trust_radius,
                eta=eta,
                nplus=nplus,
                nminus=nminus,
                rho_good=rho_good,
                rho_bad=rho_bad,
                boundary_tol=boundary_tol,

                init=init, # init isn't used because check_overflow=False
                state=self.global_state, # not used
                settings=self.defaults, # not used
                check_overflow=False, # this is checked manually to adapt tolerance
            )

        # --------------------------- assign new direction --------------------------- #
        assert d is not None
        if success:
            var.update = d

        else:
            var.update = params.zeros_like()

        self._num_hvps += self._num_hvps_last_step
        return var
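The Steihaug-Toint modification of CG is what keeps each solve inside the trust region: it stops early and moves to the boundary when an iterate would leave the region or when negative curvature is found. The sketch below follows the textbook algorithm under simplifying assumptions and is not torchzero's `cg`: the name `steihaug_cg` is hypothetical, it solves `H x = b` with `b` being the (possibly preconditioned) gradient, matching the `H⁻¹g` convention of the Newton modules, `tol` is an absolute residual tolerance as in the `tol` description, and it always takes the positive root when stepping to the boundary.

```python
import torch


def _to_boundary(x, p, radius):
    # Positive root tau of ||x + tau * p|| = radius.
    a = p.dot(p)
    b = 2 * x.dot(p)
    c = x.dot(x) - radius ** 2
    return (-b + torch.sqrt(b * b - 4 * a * c)) / (2 * a)


def steihaug_cg(H_mm, b, radius, tol=1e-8, maxiter=None):
    # Hypothetical helper: approximately solves H x = b subject to ||x|| <= radius.
    x = torch.zeros_like(b)
    r = b.clone()
    p = r.clone()
    rs = r.dot(r)
    if rs.sqrt() <= tol:
        return x
    if maxiter is None:
        maxiter = b.numel()

    for _ in range(maxiter):
        Hp = H_mm(p)
        curvature = p.dot(Hp)
        if curvature <= 0:
            # Negative curvature: follow p all the way to the boundary.
            return x + _to_boundary(x, p, radius) * p
        alpha = rs / curvature
        x_next = x + alpha * p
        if x_next.norm() >= radius:
            # The CG iterate would leave the region: stop on the boundary instead.
            return x + _to_boundary(x, p, radius) * p
        x = x_next
        r = r - alpha * Hp
        rs_new = r.dot(r)
        if rs_new.sqrt() <= tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x
```

With a large radius this reduces to plain CG; with a small radius, or an indefinite `H`, the returned direction has norm equal to the radius, and the radius update then decides whether the step is accepted.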

NoiseSign

Bases: torchzero.core.transform.Transform

Outputs random tensors with sign copied from the update.

Source code in torchzero/modules/misc/misc.py
class NoiseSign(Transform):
    """Outputs random tensors with sign copied from the update."""
    def __init__(self, distribution:Distributions = 'normal', variance:float | None = None):
        defaults = dict(distribution=distribution, variance=variance)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        variance = unpack_dicts(settings, 'variance')
        return TensorList(tensors).sample_like(settings[0]['distribution'], variance=variance).copysign_(tensors)
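What `NoiseSign` computes for a single tensor can be sketched with `torch.copysign`: fresh noise supplies the magnitudes and the update supplies only the signs. This assumes standard normal noise; how torchzero's `distribution` and `variance` settings shape the samples is not reproduced here, and `noise_sign` is a hypothetical name.

```python
import torch


def noise_sign(update, generator=None):
    # Hypothetical helper: random magnitudes from N(0, 1), signs taken from the update.
    noise = torch.randn(update.shape, generator=generator,
                        dtype=update.dtype, device=update.device)
    return torch.copysign(noise, update)
```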

Noop

Bases: torchzero.core.module.Module

Identity operator that accepts and ignores any arguments. It can also be used as an identity hessian for trust region methods.

Source code in torchzero/modules/ops/utility.py
class Identity(Module):
    """Identity operator that accepts and ignores any arguments. It can also be used as an identity hessian for trust region methods."""
    def __init__(self, *args, **kwargs): super().__init__()
    def step(self, var): return var
    def get_H(self, var):
        n = sum(p.numel() for p in var.params)
        p = var.params[0]
        return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)

Normalize

Bases: torchzero.core.transform.Transform

Normalizes the update.

Parameters:

  • norm_value (float, default: 1 ) –

    desired norm value.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • dim (int | Sequence[int] | str | None, default: None ) –

    calculates the norm along these dimensions. If a list or tuple, each tensor is normalized along all dimensions in dim that it has. Can be set to "global" to normalize by the global norm of all tensors concatenated into a single vector. Defaults to None.

  • inverse_dims (bool, default: False ) –

    if True, the dim argument is inverted and all other dimensions are normalized.

  • min_size (int, default: 1 ) –

    minimal size of a dimension to normalize along it. Defaults to 1.

  • target (str, default: 'update' ) –

    which attribute of var this transform modifies. Defaults to 'update'.

Examples:

Gradient normalization:

opt = tz.Modular(
    model.parameters(),
    tz.m.Normalize(1),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)

Update normalization:

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.Normalize(1),
    tz.m.LR(1e-2),
)
Source code in torchzero/modules/clipping/clipping.py
class Normalize(Transform):
    """Normalizes the update.

    Args:
        norm_value (float): desired norm value.
        ord (float, optional): norm order. Defaults to 2.
        dim (int | Sequence[int] | str | None, optional):
            calculates norm along those dimensions.
            If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
            Can be set to "global" to normalize by the global norm of all tensors concatenated into a vector.
            Defaults to None.
        inverse_dims (bool, optional):
            if True, the `dim` argument is inverted, and all other dimensions are normalized.
        min_size (int, optional):
            minimal size of a dimension to normalize along it. Defaults to 1.
        target (str, optional):
            what this affects.

    Examples:
    Gradient normalization:
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Normalize(1),
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )
    ```

    Update normalization:

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.Normalize(1),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(
        self,
        norm_value: float = 1,
        ord: Metrics = 2,
        dim: int | Sequence[int] | Literal["global"] | None = None,
        inverse_dims: bool = False,
        min_size: int = 1,
        target: Target = "update",
    ):
        defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
        super().__init__(defaults, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        norm_value = NumberList(s['norm_value'] for s in settings)
        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])

        _clip_norm_(
            tensors_ = TensorList(tensors),
            min = None,
            max = None,
            norm_value = norm_value,
            ord = ord,
            dim = dim,
            inverse_dims=inverse_dims,
            min_size = min_size,
        )

        return tensors
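Here is a pure-Python sketch of the core rescaling for `ord=2`, with each tensor represented as a flat list of floats. It assumes `dim=None` means one norm per tensor, and it models `dim="global"` as a single norm over everything. `min_size`, `inverse_dims` and other norm orders are not modelled, and `normalize` is a hypothetical helper, not torchzero API.

```python
import math


def normalize(tensors, norm_value=1.0, global_norm=False, eps=1e-12):
    """Hypothetical sketch: rescale flat float lists so their L2 norm equals norm_value.

    global_norm=False -> one norm per tensor (assumed meaning of dim=None)
    global_norm=True  -> one norm over all tensors concatenated (dim="global")
    """
    def l2(values):
        return math.sqrt(sum(v * v for v in values))

    if global_norm:
        total = math.sqrt(sum(l2(t) ** 2 for t in tensors))
        scales = [norm_value / (total + eps)] * len(tensors)
    else:
        scales = [norm_value / (l2(t) + eps) for t in tensors]
    return [[v * s for v in t] for t, s in zip(tensors, scales)]


per_tensor = normalize([[3.0, 4.0], [0.0, 2.0]])
global_ = normalize([[3.0, 4.0], [0.0, 2.0]], global_norm=True)
```

Where Normalize sits in the chain determines what gets normalized. As in the two examples above, placing it before Adam normalizes the gradient that Adam sees, while placing it after Adam fixes the norm of the final step.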

NormalizeByEMA

Bases: torchzero.modules.clipping.ema_clipping.ClipNormByEMA

Sets norm of the update to be the same as the norm of an exponential moving average of past updates.

Parameters:

  • beta (float, default: 0.99 ) –

    beta for the exponential moving average. Defaults to 0.99.

  • ord (float, default: 2 ) –

    order of the norm. Defaults to 2.

  • eps (float, default: 1e-06 ) –

    epsilon for division. Defaults to 1e-6.

  • tensorwise (bool, default: True ) –

    if True, norms are calculated per parameter; otherwise all parameters are treated as a single vector. Defaults to True.

  • max_ema_growth (float | None, default: 1.5 ) –

    if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.

  • ema_init (str, default: 'zeros' ) –

    How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.

Source code in torchzero/modules/clipping/ema_clipping.py
class NormalizeByEMA(ClipNormByEMA):
    """Sets norm of the update to be the same as the norm of an exponential moving average of past updates.

    Args:
        beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
        ord (float, optional): order of the norm. Defaults to 2.
        eps (float, optional): epsilon for division. Defaults to 1e-6.
        tensorwise (bool, optional):
            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
        max_ema_growth (float | None, optional):
            if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
        ema_init (str, optional):
            How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
    """
    NORMALIZE = True

NORMALIZE class-attribute

NORMALIZE = True
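Here is a simplified pure-Python sketch of the idea. It treats all parameters as one flat vector (`tensorwise=False`) with `ema_init="zeros"` and omits the `max_ema_growth` cap. The exact order of operations in torchzero is an assumption here, for example whether the EMA is updated before or after its norm is measured. `EMANormalizer` is a hypothetical class, not torchzero API.

```python
import math


class EMANormalizer:
    """Hypothetical sketch: rescale each update to the norm of an EMA of past updates."""

    def __init__(self, beta=0.99, eps=1e-6):
        self.beta = beta
        self.eps = eps
        self.ema = None  # zeros on the first call (ema_init="zeros")

    def __call__(self, update):
        if self.ema is None:
            self.ema = [0.0] * len(update)
        # Update the exponential moving average of the updates themselves.
        self.ema = [self.beta * e + (1 - self.beta) * u for e, u in zip(self.ema, update)]
        ema_norm = math.sqrt(sum(e * e for e in self.ema))
        upd_norm = math.sqrt(sum(u * u for u in update))
        # Keep the update's direction but give it the EMA's norm.
        scale = ema_norm / (upd_norm + self.eps)
        return [u * scale for u in update]
```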


NystromPCG

Bases: torchzero.core.module.Module

Newton's method with a Nyström-preconditioned conjugate gradient solver. This tends to outperform NewtonCG but requires tuning the sketch size. An adaptive version is described in https://arxiv.org/abs/2110.02820; I might implement it at some point.

Note

This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate HVPs. The closure must accept a backward argument (refer to documentation).

Note

In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

Parameters:

  • sketch_size (int) –

    size of the sketch for preconditioning; this many Hessian-vector products will be evaluated before running the conjugate gradient solver. A larger value improves the preconditioning and speeds up conjugate gradient.

  • maxiter (int | None, default: None ) –

    maximum number of conjugate gradient iterations. By default this is set to the number of dimensions of the objective function, which is enough for conjugate gradient to converge in exact arithmetic when the Hessian is positive definite. Setting this to a small value can still generate good enough directions. Defaults to None.

  • tol (float, default: 0.001 ) –

    relative tolerance for the conjugate gradient solver. Defaults to 1e-3.

  • reg (float, default: 1e-06 ) –

    regularization parameter. Defaults to 1e-6.

  • hvp_method (str, default: 'autograd' ) –

    Determines how Hessian-vector products are evaluated.

    • "autograd": Use PyTorch's autograd to calculate exact HVPs. This requires creating a graph for the gradient.
    • "forward": Use a forward finite difference formula to approximate the HVP. This requires one extra gradient evaluation.
    • "central": Use a central finite difference formula for a more accurate HVP approximation. This requires two extra gradient evaluations.

    Defaults to "autograd".

  • h (float, default: 0.001 ) –

    finite difference step size if hvp_method is "forward" or "central". Defaults to 1e-3.

  • inner (Chainable | None, default: None ) –

    modules whose output the Hessian preconditioner is applied to. Defaults to None.

  • seed (int | None, default: None ) –

    seed for random generator. Defaults to None.

Examples:

NystromPCG with backtracking line search

opt = tz.Modular(
    model.parameters(),
    tz.m.NystromPCG(10),
    tz.m.Backtracking()
)
Reference

Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820

Source code in torchzero/modules/second_order/nystrom.py
class NystromPCG(Module):
    """Newton's method with a Nyström-preconditioned conjugate gradient solver.
    This tends to outperform NewtonCG but requires tuning sketch size.
    An adaptive version is described in https://arxiv.org/abs/2110.02820; I might implement it at some point.

    .. note::
        This module requires a closure to be passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients for calculating HVPs.
        The closure must accept a ``backward`` argument (refer to documentation).

    .. note::
        In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.

    Args:
        sketch_size (int):
            size of the sketch for preconditioning, this many hessian-vector products will be evaluated before
            running the conjugate gradient solver. Larger value improves the preconditioning and speeds up
            conjugate gradient.
        maxiter (int | None, optional):
            maximum number of iterations. By default this is set to the number of dimensions
            in the objective function, which is supposed to be enough for conjugate gradient
            to have guaranteed convergence. Setting this to a small value can still generate good enough directions.
            Defaults to None.
        tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-3.
        reg (float, optional): regularization parameter. Defaults to 1e-6.
        hvp_method (str, optional):
            Determines how Hessian-vector products are evaluated.

            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
              This requires creating a graph for the gradient.
            - ``"forward"``: Use a forward finite difference formula to
              approximate the HVP. This requires one extra gradient evaluation.
            - ``"central"``: Use a central finite difference formula for a
              more accurate HVP approximation. This requires two extra
              gradient evaluations.
            Defaults to "autograd".
        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
        seed (int | None, optional): seed for random generator. Defaults to None.

    Examples:

        NystromPCG with backtracking line search

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.NystromPCG(10),
                tz.m.Backtracking()
            )

    Reference:
        Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820

    """
    def __init__(
        self,
        sketch_size: int,
        maxiter=None,
        tol=1e-3,
        reg: float = 1e-6,
        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
        h=1e-3,
        inner: Chainable | None = None,
        seed: int | None = None,
    ):
        defaults = dict(sketch_size=sketch_size, reg=reg, maxiter=maxiter, tol=tol, hvp_method=hvp_method, h=h, seed=seed)
        super().__init__(defaults,)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def step(self, var):
        params = TensorList(var.params)

        closure = var.closure
        if closure is None: raise RuntimeError('NystromPCG requires closure')

        settings = self.settings[params[0]]
        sketch_size = settings['sketch_size']
        maxiter = settings['maxiter']
        tol = settings['tol']
        reg = settings['reg']
        hvp_method = settings['hvp_method']
        h = settings['h']


        seed = settings['seed']
        generator = None
        if seed is not None:
            if 'generator' not in self.global_state:
                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
            generator = self.global_state['generator']


        # ---------------------- Hessian vector product function --------------------- #
        if hvp_method == 'autograd':
            grad = var.get_grad(create_graph=True)

            def H_mm(x):
                with torch.enable_grad():
                    Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
                    return torch.cat([t.ravel() for t in Hvp])

        else:

            with torch.enable_grad():
                grad = var.get_grad()

            if hvp_method == 'forward':
                def H_mm(x):
                    Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
                    return torch.cat([t.ravel() for t in Hvp])

            elif hvp_method == 'central':
                def H_mm(x):
                    Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
                    return torch.cat([t.ravel() for t in Hvp])

            else:
                raise ValueError(hvp_method)


        # -------------------------------- inner step -------------------------------- #
        b = var.get_update()
        if 'inner' in self.children:
            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)

        # ---------------------------- Nyström PCG solve ---------------------------- #
        x = nystrom_pcg(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), sketch_size=sketch_size, reg=reg, tol=tol, maxiter=maxiter, x0_=None, generator=generator)
        var.update = vec_to_tensors(x, reference=params)
        return var
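NystromPCG touches the Hessian only through the `H_mm` closure, so the Newton system is solved without ever forming the matrix. The pure-Python sketch below shows that matrix-free pattern with plain, unpreconditioned conjugate gradient and the same roles for `tol` (relative residual) and `maxiter` (defaulting to the dimension). It deliberately omits the Nyström preconditioner and `reg`, so it illustrates the solver loop rather than torchzero's `nystrom_pcg`. `conjugate_gradient` and `matvec` are hypothetical names.

```python
import math


def conjugate_gradient(matvec, b, tol=1e-3, maxiter=None):
    """Solve A x = b for symmetric positive definite A, given only v -> A v."""
    n = len(b)
    maxiter = n if maxiter is None else maxiter  # n steps suffice in exact arithmetic
    x = [0.0] * n
    r = list(b)  # residual b - A x for the starting point x = 0
    p = list(r)
    rs = sum(ri * ri for ri in r)
    b_norm = math.sqrt(sum(bi * bi for bi in b))
    for _ in range(maxiter):
        if math.sqrt(rs) <= tol * b_norm:  # relative residual tolerance
            break
        Ap = matvec(p)  # the only access to A: one "Hessian-vector product" per step
        alpha = rs / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


# Toy "Hessian" H = [[4, 1], [1, 3]] exposed only through a matvec closure.
H = [[4.0, 1.0], [1.0, 3.0]]
matvec = lambda v: [sum(h * vi for h, vi in zip(row, v)) for row in H]
x = conjugate_gradient(matvec, [1.0, 2.0], tol=1e-10)
```

A good preconditioner reduces the number of iterations, and therefore the number of HVPs, needed to reach `tol`. That is the benefit the Nyström sketch buys, at the upfront cost of `sketch_size` extra HVPs.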

NystromSketchAndSolve

Bases: torchzero.core.module.Module

Newton's method with a Nyström sketch-and-solve solver.

Note

This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate HVPs. The closure must accept a backward argument (refer to documentation).

Note

In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

Note

If this is unstable, increase the reg parameter and tune the rank.

Note

tz.m.NystromPCG usually outperforms this.

Parameters:

  • rank (int) –

    size of the sketch; this many Hessian-vector products will be evaluated per step.

  • reg (float, default: 0.001 ) –

    regularization parameter. Defaults to 1e-3.

  • hvp_method (str, default: 'autograd' ) –

    Determines how Hessian-vector products are evaluated.

    • "autograd": Use PyTorch's autograd to calculate exact HVPs. This requires creating a graph for the gradient.
    • "forward": Use a forward finite difference formula to approximate the HVP. This requires one extra gradient evaluation.
    • "central": Use a central finite difference formula for a more accurate HVP approximation. This requires two extra gradient evaluations.

    Defaults to "autograd".

  • h (float, default: 0.001 ) –

    finite difference step size if hvp_method is "forward" or "central". Defaults to 1e-3.

  • inner (Chainable | None, default: None ) –

    modules whose output the Hessian preconditioner is applied to. Defaults to None.

  • seed (int | None, default: None ) –

    seed for random generator. Defaults to None.

Examples:

NystromSketchAndSolve with backtracking line search

opt = tz.Modular(
    model.parameters(),
    tz.m.NystromSketchAndSolve(10),
    tz.m.Backtracking()
)
Reference

Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820

Source code in torchzero/modules/second_order/nystrom.py
class NystromSketchAndSolve(Module):
    """Newton's method with a Nyström sketch-and-solve solver.

    .. note::
        This module requires a closure to be passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients for calculating HVPs.
        The closure must accept a ``backward`` argument (refer to documentation).

    .. note::
        In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.

    .. note::
        If this is unstable, increase the :code:`reg` parameter and tune the rank.

    .. note::
        :code:`tz.m.NystromPCG` usually outperforms this.

    Args:
        rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
        reg (float, optional): regularization parameter. Defaults to 1e-3.
        hvp_method (str, optional):
            Determines how Hessian-vector products are evaluated.

            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
              This requires creating a graph for the gradient.
            - ``"forward"``: Use a forward finite difference formula to
              approximate the HVP. This requires one extra gradient evaluation.
            - ``"central"``: Use a central finite difference formula for a
              more accurate HVP approximation. This requires two extra
              gradient evaluations.
            Defaults to "autograd".
        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
        seed (int | None, optional): seed for random generator. Defaults to None.

    Examples:
        NystromSketchAndSolve with backtracking line search

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.NystromSketchAndSolve(10),
                tz.m.Backtracking()
            )

    Reference:
        Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
    """
    def __init__(
        self,
        rank: int,
        reg: float = 1e-3,
        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
        h: float = 1e-3,
        inner: Chainable | None = None,
        seed: int | None = None,
    ):
        defaults = dict(rank=rank, reg=reg, hvp_method=hvp_method, h=h, seed=seed)
        super().__init__(defaults,)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def step(self, var):
        params = TensorList(var.params)

        closure = var.closure
        if closure is None: raise RuntimeError('NystromSketchAndSolve requires closure')

        settings = self.settings[params[0]]
        rank = settings['rank']
        reg = settings['reg']
        hvp_method = settings['hvp_method']
        h = settings['h']

        seed = settings['seed']
        generator = None
        if seed is not None:
            if 'generator' not in self.global_state:
                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
            generator = self.global_state['generator']

        # ---------------------- Hessian vector product function --------------------- #
        if hvp_method == 'autograd':
            grad = var.get_grad(create_graph=True)

            def H_mm(x):
                with torch.enable_grad():
                    Hvp = hvp(params, grad, params.from_vec(x), retain_graph=True)
                    return torch.cat([t.ravel() for t in Hvp])

        else:

            with torch.enable_grad():
                grad = var.get_grad()

            if hvp_method == 'forward':
                def H_mm(x):
                    Hvp = hvp_fd_forward(closure, params, params.from_vec(x), h=h, g_0=grad, normalize=True)[1]
                    return torch.cat([t.ravel() for t in Hvp])

            elif hvp_method == 'central':
                def H_mm(x):
                    Hvp = hvp_fd_central(closure, params, params.from_vec(x), h=h, normalize=True)[1]
                    return torch.cat([t.ravel() for t in Hvp])

            else:
                raise ValueError(hvp_method)


        # -------------------------------- inner step -------------------------------- #
        b = var.get_update()
        if 'inner' in self.children:
            b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)

        # ----------------------------- sketch-and-solve ----------------------------- #
        x = nystrom_sketch_and_solve(A_mm=H_mm, b=torch.cat([t.ravel() for t in b]), rank=rank, reg=reg, generator=generator)
        var.update = vec_to_tensors(x, reference=params)
        return var
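The NumPy sketch below illustrates the sketch-and-solve idea from Frangella et al.:

1. Build a randomized Nyström approximation U diag(lam) U^T of a symmetric positive semidefinite operator using `rank` matrix-vector products.
2. Invert the regularized approximation exactly.

NumPy stands in for PyTorch here. The function `sketch_and_solve` is hypothetical, not torchzero's `nystrom_sketch_and_solve`, whose internals may differ. The sketch assumes a PSD operator; with an indefinite Hessian the Cholesky step below can fail.

```python
import numpy as np


def sketch_and_solve(A_mm, b, rank, reg, seed=0):
    """Approximately solve (A + reg*I) x = b for symmetric PSD A, given only A_mm(X) = A @ X."""
    n = b.shape[0]
    rng = np.random.default_rng(seed)
    # Random orthonormal test matrix; A_mm costs `rank` matrix-vector products.
    Omega, _ = np.linalg.qr(rng.standard_normal((n, rank)))
    Y = A_mm(Omega)
    # Tiny shift for a numerically stable Cholesky factorization.
    nu = np.finfo(Y.dtype).eps * np.linalg.norm(Y)
    Y_nu = Y + nu * Omega
    M = Omega.T @ Y_nu
    L = np.linalg.cholesky((M + M.T) / 2)  # lower-triangular, L @ L.T = M
    B = np.linalg.solve(L, Y_nu.T).T  # B = Y_nu @ L^{-T}
    U, S, _ = np.linalg.svd(B, full_matrices=False)
    lam = np.maximum(S**2 - nu, 0.0)  # eigenvalues of the approximation U diag(lam) U^T
    # Exact inverse of (U diag(lam) U^T + reg*I): scale inside span(U), divide by reg outside it.
    Ub = U.T @ b
    return U @ (Ub / (lam + reg)) + (b - U @ Ub) / reg


rng = np.random.default_rng(1)
G = rng.standard_normal((8, 3))
A = G @ G.T  # PSD with rank 3, so a rank-3 sketch captures it
x = sketch_and_solve(lambda X: A @ X, rng.standard_normal(8), rank=3, reg=0.1)
```

The last line also shows why `reg` matters. Directions outside the sketched subspace are divided by `reg` alone, so a too-small `reg` can blow up components that the sketch missed.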

Ones

Bases: torchzero.core.module.Module

Outputs a tensor of ones shaped like each parameter.

Source code in torchzero/modules/ops/utility.py
class Ones(Module):
    """Outputs ones"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def step(self, var):
        var.update = [torch.ones_like(p) for p in var.params]
        return var

Online

Bases: torchzero.core.module.Module

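Allows certain modules, such as LBFGS, to be used for mini-batch optimization. On every step after the first, the wrapped module is updated first at the previous parameters and then at the current parameters, both using the current closure. Quantities built from consecutive gradients, such as L-BFGS curvature pairs, are therefore computed on the same mini-batch.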

Examples:

Online L-BFGS with Backtracking line search

opt = tz.Modular(
    model.parameters(),
    tz.m.Online(tz.m.LBFGS()),
    tz.m.Backtracking()
)

Online L-BFGS trust region

opt = tz.Modular(
    model.parameters(),
    tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
)

Source code in torchzero/modules/misc/multistep.py
class Online(Module):
    """Allows certain modules to be used for mini-batch optimization.

    Examples:

    Online L-BFGS with Backtracking line search
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Online(tz.m.LBFGS()),
        tz.m.Backtracking()
    )
    ```

    Online L-BFGS trust region
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
    )
    ```

    """
    def __init__(self, *modules: Module,):
        super().__init__()

        self.set_child('module', modules)

    @torch.no_grad
    def update(self, var):
        closure = var.closure
        if closure is None: raise ValueError("Closure must be passed for Online")

        step = self.global_state.get('step', 0) + 1
        self.global_state['step'] = step

        params = TensorList(var.params)
        p_cur = params.clone()
        p_prev = self.get_state(params, 'p_prev', cls=TensorList)

        module = self.children['module']
        var_c = var.clone(clone_update=False)

        # on 1st step just step and store previous params
        if step == 1:
            p_prev.copy_(params)

            module.update(var_c)
            var.update_attrs_from_clone_(var_c)
            return

        # restore previous params and update
        var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
        params.set_(p_prev)
        module.reset_for_online()
        module.update(var_prev)

        # restore current params and update
        params.set_(p_cur)
        p_prev.copy_(params)
        module.update(var_c)
        var.update_attrs_from_clone_(var_c)

    @torch.no_grad
    def apply(self, var):
        module = self.children['module']
        return module.apply(var.clone(clone_update=False))

    def get_H(self, var):
        return self.children['module'].get_H(var)
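A toy example shows why the previous parameters are re-evaluated. An L-BFGS curvature pair uses s = x_k - x_{k-1} and y = g(x_k) - g(x_{k-1}). With mini-batches, taking the two gradients from different batches mixes curvature with batch-to-batch noise. The sketch uses a hypothetical one-dimensional per-batch loss f_b(x) = 0.5 * a_b * x**2 + c_b * x with the same curvature a_b = 2 in every batch. It illustrates the motivation only and is not torchzero's implementation.

```python
def grad(x, batch):
    """Gradient of the toy per-batch loss 0.5 * a * x**2 + c * x."""
    a, c = batch
    return a * x + c


batches = [(2.0, 1.0), (2.0, -3.0)]  # identical curvature a = 2, different offsets c
x_prev, x_cur = 1.0, 0.5
s = x_cur - x_prev

# Naive: previous gradient from the previous batch, current gradient from the current batch.
y_mixed = grad(x_cur, batches[1]) - grad(x_prev, batches[0])
# Online-style: both gradients evaluated on the current batch.
y_same = grad(x_cur, batches[1]) - grad(x_prev, batches[1])

print(y_mixed / s, y_same / s)
```

Here the mixed-batch estimate y / s is 10.0, while the same-batch estimate recovers the true curvature 2.0.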

OrthoGrad

Bases: torchzero.core.transform.Transform

Applies ⟂Grad: projects the gradient (or update) of an iterable of parameters to be orthogonal to the weights.

Parameters:

  • eps (float, default: 1e-08 ) –

    epsilon added to the denominator for numerical stability. Defaults to 1e-8.

  • renormalize (bool, default: True ) –

    whether to rescale the projected update back to the norm of the original update. Defaults to True.

  • target (Literal, default: 'update' ) –

    what to set on var. Defaults to 'update'.

Source code in torchzero/modules/adaptive/orthograd.py
class OrthoGrad(Transform):
    """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

    Args:
        eps (float, optional): epsilon added to the denominator for numerical stability. Defaults to 1e-8.
        renormalize (bool, optional): whether to graft projected gradient to original gradient norm. Defaults to True.
        target (Target, optional): what to set on var. Defaults to 'update'.
    """
    def __init__(self, eps: float = 1e-8, renormalize=True, target: Target = 'update'):
        defaults = dict(eps=eps, renormalize=renormalize)
        super().__init__(defaults, uses_grad=False, target=target)

    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        eps = settings[0]['eps']
        renormalize = settings[0]['renormalize']

        params = as_tensorlist(params)
        target = as_tensorlist(tensors)

        scale = params.dot(target)/(params.dot(params) + eps)
        if renormalize:
            norm = target.global_vector_norm()
            target -= params * scale
            target *= (norm / target.global_vector_norm())
            return target

        target -= params * scale
        return target
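Here is a pure-Python sketch of the projection, with parameters and update flattened into single vectors. Treating everything as one vector is an assumption, since the docs do not state whether torchzero's `TensorList.dot` works per tensor or globally. `ortho_grad` is a hypothetical name.

```python
import math


def ortho_grad(params, update, eps=1e-8, renormalize=True):
    """Remove the component of `update` parallel to `params` (both flat lists of floats)."""
    def dot(u, v):
        return sum(a * b for a, b in zip(u, v))

    scale = dot(params, update) / (dot(params, params) + eps)
    projected = [u - scale * p for u, p in zip(update, params)]
    proj_norm = math.sqrt(dot(projected, projected))
    if renormalize and proj_norm > 0:
        # Graft: rescale the projected update back to the original update's norm.
        ratio = math.sqrt(dot(update, update)) / proj_norm
        projected = [v * ratio for v in projected]
    return projected


out = ortho_grad([1.0, 0.0, 1.0], [1.0, 2.0, 3.0])
```

The `proj_norm > 0` guard is an addition in this sketch. It avoids dividing by zero when the update is exactly parallel to the weights.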

Orthogonalize

Bases: torchzero.core.transform.TensorwiseTransform

Uses Newton-Schulz iteration or SVD to compute the zeroth power (orthogonalization) of the update along its first 2 dims.

To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False. The Muon page says that embeddings and classifier heads should not be orthogonalized. Usually only matrix parameters that are directly used in matmuls should be orthogonalized.

To build Muon, use Split to apply this module to matrix parameters and Adam to 1D parameters.

Parameters:

  • ns_steps (int, default: 5 ) –

    The number of Newton-Schulz iterations to run. Defaults to 5.

  • adjust_lr (bool, default: False ) –

    Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.

  • dual_norm_correction (bool, default: False ) –

    enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.

  • method (str, default: 'newton-schulz' ) –

    orthogonalization method, "newton-schulz" or "svd". Newton-Schulz is very fast; SVD is extremely slow but can be slightly more precise.

  • target (str, default: 'update' ) –

    what to set on var.

Examples:

standard Muon with Adam fallback

opt = tz.Modular(
    model.head.parameters(),
    tz.m.Split(
        # apply muon only to 2D+ parameters
        filter = lambda t: t.ndim >= 2,
        true = [
            tz.m.HeavyBall(),
            tz.m.Orthogonalize(),
            tz.m.LR(1e-2),
        ],
        false = tz.m.Adam()
    ),
    tz.m.LR(1e-2)
)

Reference

Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon

Source code in torchzero/modules/adaptive/muon.py
class Orthogonalize(TensorwiseTransform):
    """Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.

    To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False.
    The Muon page says that embeddings and classifier heads should not be orthogonalized.
    Usually only matrix parameters that are directly used in matmuls should be orthogonalized.

    To make Muon, use Split with Adam on 1d params

    Args:
        ns_steps (int, optional):
            The number of Newton-Schulz iterations to run. Defaults to 5.
        adjust_lr (bool, optional):
            Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.
        dual_norm_correction (bool, optional):
            enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
        method (str, optional):
            Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.
        target (str, optional):
            what to set on var.

    ## Examples:

    standard Muon with Adam fallback
    ```py
    opt = tz.Modular(
        model.head.parameters(),
        tz.m.Split(
            # apply muon only to 2D+ parameters
            filter = lambda t: t.ndim >= 2,
            true = [
                tz.m.HeavyBall(),
                tz.m.Orthogonalize(),
                tz.m.LR(1e-2),
            ],
            false = tz.m.Adam()
        ),
        tz.m.LR(1e-2)
    )
    ```

    Reference:
        Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
    """
    def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
                 method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
        defaults = dict(orthogonalize=True, ns_steps=ns_steps, dual_norm_correction=dual_norm_correction, adjust_lr=adjust_lr, method=method.lower())
        super().__init__(uses_grad=False, defaults=defaults, target=target)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(setting)

        if not orthogonalize: return tensor

        if _is_at_least_2d(tensor):

            X = _orthogonalize_tensor(tensor, ns_steps, method)

            if dual_norm_correction:
                X = _dual_norm_correction(X, tensor, batch_first=False)

            if adjust_lr:
                X.mul_(adjust_lr_for_muon(1, param.shape))

            return X.view_as(param)

        return tensor
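This NumPy sketch shows the quintic Newton-Schulz iteration from Keller Jordan's Muon repository (coefficients 3.4445, -4.7750, 2.0315) applied to a single 2D matrix. NumPy stands in for PyTorch. The sketch is not torchzero's `_orthogonalize_tensor`, and `newton_schulz5` is a hypothetical name.

```python
import numpy as np


def newton_schulz5(G, steps=5, eps=1e-7):
    """Approximately map G = U S V^T to U V^T with the quintic Newton-Schulz iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)  # Frobenius scaling puts all singular values in [0, 1]
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T  # iterate on the wide orientation so X @ X.T is the smaller Gram matrix
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X  # each singular value s becomes a*s + b*s**3 + c*s**5
    return X.T if transposed else X


G = np.array([[2.0, 0.1, 0.0], [0.1, 1.5, 0.2], [0.0, 0.2, 1.0], [0.1, 0.0, 0.1]])
X = newton_schulz5(G)
```

With these coefficients the singular values do not converge exactly to 1; they settle in a band around 1. This trades precision for very few iterations, whereas an SVD-based method can compute U V^T directly.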

PSB

Bases: torchzero.modules.quasi_newton.quasi_newton._HessianUpdateStrategyDefaults

Powell's Symmetric Broyden Quasi-Newton method.

Note

A line search or a trust region is recommended.

Warning

This uses at least O(N^2) memory, where N is the number of parameters.

Reference

Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class PSB(_HessianUpdateStrategyDefaults):
    """Powell's Symmetric Broyden Quasi-Newton method.

    Note:
        a line search or a trust region is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return psb_B_(B=B, s=s, y=y)
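The PSB update has a closed form. With r = y - B s, it is B+ = B + (r s^T + s r^T) / (s^T s) - (r^T s) s s^T / (s^T s)^2. The result stays symmetric and satisfies the secant condition B+ s = y. The pure-Python sketch below applies this textbook formula to a small dense matrix. It omits any safeguards torchzero's `psb_B_` may add, such as skipping updates when s^T s is tiny, and `psb_update` is a hypothetical name.

```python
def psb_update(B, s, y):
    """Powell-symmetric-Broyden update of a dense Hessian approximation B (list of rows)."""
    n = len(s)
    Bs = [sum(B[i][j] * s[j] for j in range(n)) for i in range(n)]
    r = [y[i] - Bs[i] for i in range(n)]  # secant residual y - B s
    ss = sum(si * si for si in s)
    rs = sum(ri * si for ri, si in zip(r, s))
    return [
        [
            B[i][j] + (r[i] * s[j] + s[i] * r[j]) / ss - rs * s[i] * s[j] / ss**2
            for j in range(n)
        ]
        for i in range(n)
    ]


B_new = psb_update([[1.0, 0.0], [0.0, 1.0]], s=[1.0, 2.0], y=[2.0, 3.0])
```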

Params

Bases: torchzero.core.module.Module

Outputs a copy of the parameters.

Source code in torchzero/modules/ops/utility.py
class Params(Module):
    """Outputs parameters"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def step(self, var):
        var.update = [p.clone() for p in var.params]
        return var

Pearson

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Pearson's Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory, where N is the number of parameters.

Reference

Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class Pearson(_InverseHessianUpdateStrategyDefaults):
    """
    Pearson's Quasi-Newton method.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return pearson_H_(H=H, s=s, y=y)

PerturbWeights

Bases: torchzero.core.module.Module

Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.

Can be disabled for a parameter by setting perturb=False in the corresponding parameter group.

Parameters:

  • alpha (float, default: 0.1 ) –

    multiplier for perturbation magnitude. Defaults to 0.1.

  • relative (bool, default: True ) –

    whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.

  • distribution (Distributions, default: 'normal' ) –

    distribution of the random perturbation: 'normal' (or 'gaussian'), 'uniform' or 'sphere'. Defaults to 'normal'.

Source code in torchzero/modules/misc/regularization.py
class PerturbWeights(Module):
    """
    Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.

    Can be disabled for a parameter by setting :code:`perturb=False` in corresponding parameter group.

    Args:
        alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
        relative (bool, optional): whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.
        distribution (Distributions, optional):
            distribution of the random perturbation: 'normal' (or 'gaussian'), 'uniform' or 'sphere'. Defaults to 'normal'.
    """
    def __init__(self, alpha: float = 0.1, relative:bool=True, distribution:Distributions = 'normal'):
        defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        closure = var.closure
        if closure is None: raise RuntimeError('PerturbWeights requires closure')
        params = TensorList(var.params)

        # create perturbations
        perts = []
        for p in params:
            settings = self.settings[p]
            if not settings['perturb']:
                perts.append(torch.zeros_like(p))
                continue

            alpha = settings['alpha']
            if settings['relative']:
                alpha *= p.abs().mean()

            distribution = self.settings[p]['distribution'].lower()
            if distribution in ('normal', 'gaussian'):
                perts.append(torch.randn_like(p).mul_(alpha))
            elif distribution == 'uniform':
                perts.append(torch.empty_like(p).uniform_(-alpha,alpha))
            elif distribution == 'sphere':
                r = torch.randn_like(p)
                perts.append((r * alpha) / torch.linalg.vector_norm(r)) # pylint:disable=not-callable
            else:
                raise ValueError(distribution)

        @torch.no_grad
        def perturbed_closure(backward=True):
            params.add_(perts)
            if backward:
                with torch.enable_grad(): loss = closure()
            else:
                loss = closure(False)
            params.sub_(perts)
            return loss

        var.closure = perturbed_closure
        return var
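The closure wrapping above can be illustrated without torch: draw one perturbation per parameter, shift the weights, evaluate, then shift them back. In this minimal sketch parameters are plain lists of floats and the closure takes no arguments; sample_perturbation and make_perturbed_closure are hypothetical names, not torchzero API. As in the module, the perturbation is drawn once and reused for every call of the wrapped closure.

```python
import math
import random


def sample_perturbation(p, alpha=0.1, relative=True, distribution="normal", rng=random):
    """Draw a random perturbation for one parameter given as a list of floats."""
    if relative:
        alpha *= sum(abs(v) for v in p) / len(p)  # scale by the mean |p|
    distribution = distribution.lower()
    if distribution in ("normal", "gaussian"):
        return [rng.gauss(0.0, 1.0) * alpha for _ in p]
    if distribution == "uniform":
        return [rng.uniform(-alpha, alpha) for _ in p]
    if distribution == "sphere":
        r = [rng.gauss(0.0, 1.0) for _ in p]
        norm = math.sqrt(sum(v * v for v in r))
        return [v * alpha / norm for v in r]  # result has norm alpha
    raise ValueError(distribution)


def make_perturbed_closure(closure, params, perts):
    """Wrap closure so it is evaluated at params + perts; params are restored afterwards."""
    def perturbed_closure():
        for p, d in zip(params, perts):
            for i, di in enumerate(d):
                p[i] += di
        try:
            return closure()
        finally:
            for p, d in zip(params, perts):
                for i, di in enumerate(d):
                    p[i] -= di
    return perturbed_closure


params = [[1.0, -2.0], [0.5]]


def loss_fn():
    return sum(v * v for p in params for v in p)


rng = random.Random(0)
perts = [sample_perturbation(p, distribution="sphere", rng=rng) for p in params]
loss = make_perturbed_closure(loss_fn, params, perts)()
```

The try/finally is a small design choice of this sketch: it restores the weights even if the closure raises, which the module above does not do.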

PolakRibiere

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Polak-Ribière-Polyak nonlinear conjugate gradient method.

Note

This requires the step size to be determined by a line search, so put a line search such as tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this module.

Source code in torchzero/modules/conjugate_gradient/cg.py
class PolakRibiere(ConguateGradientBase):
    """Polak-Ribière-Polyak nonlinear conjugate gradient method.

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
    """
    def __init__(self, clip_beta=True, restart_interval: int | None | Literal['auto'] = 'auto', inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return polak_ribiere_beta(g, prev_g)
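polak_ribiere_beta is not reproduced in this excerpt. As a sketch, assuming it computes the standard Polak-Ribière-Polyak coefficient beta_k = g_k^T (g_k - g_{k-1}) / (g_{k-1}^T g_{k-1}), and assuming clip_beta=True clips negative beta to zero (the common "PR+" safeguard), one direction update in the textbook convention d_k = -g_k + beta_k d_{k-1} looks like this. Lists of floats stand in for tensors and the helper names are hypothetical.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def polak_ribiere_beta(g, g_prev):
    """PRP coefficient g^T (g - g_prev) / ||g_prev||^2."""
    return dot(g, [a - b for a, b in zip(g, g_prev)]) / dot(g_prev, g_prev)


def cg_direction(g, g_prev, d_prev, clip_beta=True):
    """Nonlinear conjugate gradient direction d = -g + beta * d_prev."""
    beta = polak_ribiere_beta(g, g_prev)
    if clip_beta:
        beta = max(beta, 0.0)  # negative beta falls back to steepest descent
    return [-gi + beta * di for gi, di in zip(g, d_prev)]


g_prev = [1.0, 1.0]
d_prev = [-1.0, -1.0]
g = [0.5, 1.5]
print(polak_ribiere_beta(g, g_prev))  # 0.25
print(cg_direction(g, g_prev, d_prev))  # [-0.75, -1.75]
```

The step length along the returned direction still has to come from the line search placed after this module.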

PolyakStepSize

Bases: torchzero.core.transform.Transform

Polyak's subgradient method with known or unknown f*.

Parameters:

  • f_star (float | None, default: 0 ) –

    minimal possible value of the objective function. If not known, set to None. Defaults to 0.

  • y (float, default: 1 ) –

    when f_star is set to None, it is calculated as f_best - y. Defaults to 1.

  • y_decay (float, default: 0.001 ) –

    y is multiplied by (1 - y_decay) after each step. Defaults to 1e-3.

  • max (float | None, default: None ) –

    maximum possible step size. Defaults to None.

  • use_grad (bool, default: True ) –

    if True, uses dot product of update and gradient to compute the step size. Otherwise, dot product of update with itself is used.

  • alpha (float, default: 1 ) –

    multiplier to Polyak step-size. Defaults to 1.

Source code in torchzero/modules/step_size/adaptive.py
class PolyakStepSize(Transform):
    """Polyak's subgradient method with known or unknown f*.

    Args:
        f_star (float | None, optional):
            minimal possible value of the objective function. If not known, set to ``None``. Defaults to 0.
        y (float, optional):
            when ``f_star`` is set to None, it is calculated as ``f_best - y``. Defaults to 1.
        y_decay (float, optional):
            ``y`` is multiplied by ``(1 - y_decay)`` after each step. Defaults to 1e-3.
        max (float | None, optional): maximum possible step size. Defaults to None.
        use_grad (bool, optional):
            if True, uses dot product of update and gradient to compute the step size.
            Otherwise, dot product of update with itself is used.
        alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
    """
    def __init__(self, f_star: float | None = 0, y: float = 1, y_decay: float = 1e-3, max: float | None = None, use_grad=True, alpha: float = 1, inner: Chainable | None = None):

        defaults = dict(alpha=alpha, max=max, f_star=f_star, y=y, y_decay=y_decay)
        super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)

    @torch.no_grad
    def update_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None and loss is not None
        tensors = TensorList(tensors)
        grads = TensorList(grads)

        # load variables
        max, f_star, y, y_decay = itemgetter('max', 'f_star', 'y', 'y_decay')(settings[0])
        y_val = self.global_state.get('y_val', y)
        f_best = self.global_state.get('f_best', None)

        # gg
        if self._uses_grad: gg = tensors.dot(grads)
        else: gg = tensors.dot(tensors)

        # store loss
        if f_best is None or loss < f_best: f_best = tofloat(loss)
        if f_star is None: f_star = f_best - y_val

        # calculate the step size
        if gg <= torch.finfo(gg.dtype).tiny * 2: alpha = 0 # converged
        else: alpha = (loss - f_star) / gg

        # clip
        if max is not None:
            if alpha > max: alpha = max

        # store state
        self.global_state['f_best'] = f_best
        self.global_state['y_val'] = y_val * (1 - y_decay)
        self.global_state['alpha'] = alpha

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', 1)
        if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))

        torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
        return tensors

    def get_H(self, var):
        return _get_H(self, var)
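The step size computed in update_tensors is alpha = (f(x) - f*) / (d^T g), where d is the incoming update (d = g for plain gradient descent, so the denominator becomes ||g||^2). With f_star=None, f* is replaced by f_best - y and y shrinks by (1 - y_decay) after each step. A minimal scalar sketch for the known-f* case, with a hypothetical function name and without the module's fallback for unacceptable step sizes:

```python
import sys


def polyak_alpha(loss, f_star, d, g, max_alpha=None):
    """Polyak step size (loss - f_star) / (d^T g), optionally clipped to max_alpha."""
    dg = sum(a * b for a, b in zip(d, g))
    if dg <= 2 * sys.float_info.min:  # vanishing denominator: treat as converged
        return 0.0
    alpha = (loss - f_star) / dg
    if max_alpha is not None:
        alpha = min(alpha, max_alpha)
    return alpha


# gradient descent on f(x) = x^2, whose minimal value f* = 0 is known
x = 3.0
for _ in range(5):
    loss, g = x * x, 2 * x
    alpha = polyak_alpha(loss, 0.0, [g], [g])  # always 0.25 for this f
    x -= alpha * g
print(x)  # 0.09375: x is halved every step
```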

Pow

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Takes tensors to the power of exponent, which can be a number or a module.

If exponent is a module, this calculates tensors ^ exponent(tensors).

Source code in torchzero/modules/ops/binary.py
class Pow(BinaryOperationBase):
    """Take tensors to the power of :code:`exponent`. :code:`exponent` can be a number or a module.

    If :code:`exponent` is a module, this calculates :code:`tensors ^ exponent(tensors)`
    """
    def __init__(self, exponent: Chainable | float):
        super().__init__({}, exponent=exponent)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
        torch._foreach_pow_(update, exponent)
        return update

PowModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates input ** exponent. Both input and exponent can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class PowModules(MultiOperationBase):
    """Calculates :code:`input ** exponent`. :code:`input` and :code:`exponent` can be numbers or modules."""
    def __init__(self, input: Chainable | float, exponent: Chainable | float):
        defaults = {}
        super().__init__(defaults, input=input, exponent=exponent)

    @torch.no_grad
    def transform(self, var: Var, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
        if isinstance(input, (int,float)):
            assert isinstance(exponent, list)
            return input ** TensorList(exponent)

        # input is a list of tensors; raise it elementwise to exponent (a number or a list of tensors)
        torch._foreach_pow_(input, exponent)
        return input
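The docstring describes elementwise input ** exponent, where either side may be a number or a list of tensors produced by modules. A plain-Python sketch of those semantics, with flat lists of floats standing in for tensors; pow_operands is a hypothetical name and base plays the role of input:

```python
def pow_operands(base, exponent):
    """Elementwise base ** exponent; each operand is a number or a list of lists of floats."""
    if isinstance(base, (int, float)):
        # number ** tensors: the base is broadcast over every element of exponent
        return [[base ** e for e in t] for t in exponent]
    if isinstance(exponent, (int, float)):
        # tensors ** number
        return [[x ** exponent for x in t] for t in base]
    # tensors ** tensors, matched element by element
    return [[x ** e for x, e in zip(t, te)] for t, te in zip(base, exponent)]


print(pow_operands(2.0, [[0.0, 1.0, 3.0]]))      # [[1.0, 2.0, 8.0]]
print(pow_operands([[1.0, 2.0], [3.0]], 2))      # [[1.0, 4.0], [9.0]]
print(pow_operands([[4.0, 9.0]], [[0.5, 0.5]]))  # [[2.0, 3.0]]
```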

PowellRestart

Bases: torchzero.modules.restarts.restars.RestartStrategyBase

Powell's two restart criteria for conjugate gradient methods.

A restart clears all states of the modules passed as modules (all modules if None).

Parameters:

  • modules (Chainable | None) –

    modules to reset. If None, resets all modules.

  • cond1 (float | None, default: 0.2 ) –

    criterion that checks for loss of conjugacy of the search directions. A restart is performed whenever |g_k^T g_{k+1}| >= cond1*||g_{k+1}||^2. The default value of 0.2 is suggested by Powell. Set to None to disable this criterion.

  • cond2 (float | None, default: 0.2 ) –

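    criterion that checks whether the search direction d is not effectively downhill. A restart is performed if -(1+cond2)||g||^2 < d^Tg < -(1-cond2)||g||^2, which for the default cond2=0.2 is -1.2||g||^2 < d^Tg < -0.8||g||^2. Defaults to 0.2. Set to None to disable this criterion.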

Reference

Powell, Michael James David. "Restart procedures for the conjugate gradient method." Mathematical programming 12.1 (1977): 241-254.

Source code in torchzero/modules/restarts/restars.py
class PowellRestart(RestartStrategyBase):
    """Powell's two restart criteria for conjugate gradient methods.

    The restart clears all states of ``modules``.

    Args:
        modules (Chainable | None):
            modules to reset. If None, resets all modules.
        cond1 (float | None, optional):
            criterion that checks for loss of conjugacy of the search directions.
            A restart is performed whenever |g_k^T g_{k+1}| >= cond1*||g_{k+1}||^2.
            The default value of 0.2 is suggested by Powell. Set to None to disable this criterion.
        cond2 (float | None, optional):
            criterion that checks whether the search direction d is not effectively downhill.
            A restart is performed if -(1+cond2)||g||^2 < d^Tg < -(1-cond2)||g||^2.
            Defaults to 0.2. Set to None to disable this criterion.

    Reference:
        Powell, Michael James David. "Restart procedures for the conjugate gradient method." Mathematical programming 12.1 (1977): 241-254.
    """
    def __init__(self, modules: Chainable | None, cond1:float | None = 0.2, cond2:float | None = 0.2):
        defaults=dict(cond1=cond1, cond2=cond2)
        super().__init__(defaults, modules)

    def should_reset(self, var):
        g = TensorList(var.get_grad())
        cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']

        # -------------------------------- initialize -------------------------------- #
        if 'initialized' not in self.global_state:
            self.global_state['initialized'] = 0
            g_prev = self.get_state(var.params, 'g_prev', init=g)
            return False

        g_g = g.dot(g)

        reset = False
        # ------------------------------- 1st condition ------------------------------ #
        if cond1 is not None:
            g_prev = self.get_state(var.params, 'g_prev', must_exist=True, cls=TensorList)
            g_g_prev = g_prev.dot(g)

            if g_g_prev.abs() >= cond1 * g_g:
                reset = True

        # ------------------------------- 2nd condition ------------------------------ #
        if (cond2 is not None) and (not reset):
            d_g = TensorList(var.get_update()).dot(g)
            if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
                reset = True

        # ------------------------------ clear on reset ------------------------------ #
        if reset:
            self.global_state.clear()
            self.clear_state_keys('g_prev')
            return True

        return False
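Stripped of module state, the two checks in should_reset reduce to a pure function of the previous gradient, the current gradient and the current update. This sketch mirrors the code above, including its use of the update as d in the second criterion; powell_should_restart is a hypothetical name and lists of floats stand in for tensors.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def powell_should_restart(g_prev, g, d, cond1=0.2, cond2=0.2):
    """Return True if either restart criterion, as coded in should_reset, fires."""
    g_g = dot(g, g)
    # 1st criterion: consecutive gradients are far from orthogonal
    if cond1 is not None and abs(dot(g_prev, g)) >= cond1 * g_g:
        return True
    # 2nd criterion: same open-interval test on d^T g as in should_reset
    if cond2 is not None:
        d_g = dot(d, g)
        if (-1 - cond2) * g_g < d_g < (-1 + cond2) * g_g:
            return True
    return False


print(powell_should_restart([1.0, 0.0], [0.1, 1.0], [1.0, 1.0]))  # False
print(powell_should_restart([1.0, 0.0], [1.0, 1.0], [1.0, 1.0]))  # True (1st criterion)
```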

Previous

Bases: torchzero.core.transform.TensorwiseTransform

Returns the update from n steps back; for example, with n=1 it returns the previous update. During the first n steps, the first update it received is returned.

Source code in torchzero/modules/misc/misc.py
class Previous(TensorwiseTransform):
    """Returns the update from n steps back; for example, with n=1 it returns the previous update. During the first n steps, the first received update is returned."""
    def __init__(self, n=1, target: Target = 'update'):
        defaults = dict(n=n)
        super().__init__(uses_grad=False, defaults=defaults, target=target)


    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        n = setting['n']

        if 'history' not in state:
            state['history'] = deque(maxlen=n+1)

        state['history'].append(tensor)

        return state['history'][0]
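The history in apply_tensor is a deque with maxlen=n+1, so once it is full its oldest element is exactly n steps old; before that, element 0 is the very first update. A framework-free sketch of that behaviour, with numbers in place of tensors (DelayLine is a hypothetical name):

```python
from collections import deque


class DelayLine:
    """Returns the value received n calls ago (the first value while history is short)."""

    def __init__(self, n=1):
        # holds the current value plus up to n older ones
        self.history = deque(maxlen=n + 1)

    def __call__(self, value):
        self.history.append(value)  # drops the oldest value once full
        return self.history[0]


delay = DelayLine(n=2)
print([delay(u) for u in [1, 2, 3, 4, 5]])  # [1, 1, 1, 2, 3]
```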

PrintLoss

Bases: torchzero.core.module.Module

Prints var.get_loss().

Source code in torchzero/modules/misc/debug.py
class PrintLoss(Module):
    """Prints var.get_loss()."""
    def __init__(self, text = 'loss = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def step(self, var):
        self.defaults["print_fn"](f'{self.defaults["text"]}{var.get_loss(False)}')
        return var

PrintParams

Bases: torchzero.core.module.Module

Prints current parameters.

Source code in torchzero/modules/misc/debug.py
class PrintParams(Module):
    """Prints current parameters."""
    def __init__(self, text = 'params = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def step(self, var):
        self.defaults["print_fn"](f'{self.defaults["text"]}{var.params}')
        return var

PrintShape

Bases: torchzero.core.module.Module

Prints shapes of the update.

Source code in torchzero/modules/misc/debug.py
class PrintShape(Module):
    """Prints shapes of the update."""
    def __init__(self, text = 'shapes = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def step(self, var):
        shapes = [u.shape for u in var.update] if var.update is not None else None
        self.defaults["print_fn"](f'{self.defaults["text"]}{shapes}')
        return var

PrintUpdate

Bases: torchzero.core.module.Module

Prints current update.

Source code in torchzero/modules/misc/debug.py
class PrintUpdate(Module):
    """Prints current update."""
    def __init__(self, text = 'update = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def step(self, var):
        self.defaults["print_fn"](f'{self.defaults["text"]}{var.update}')
        return var

Prod

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs the product of inputs, each of which can be a module or a number.

Source code in torchzero/modules/ops/reduce.py
class Prod(ReduceOperationBase):
    """Outputs product of :code:`inputs` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        # move numbers to the end so that the running product starts from a list of tensors
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, (int, float)))
        prod = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_mul_(prod, v)

        return prod
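The semantics of the reduction can be sketched in plain Python: numbers broadcast over tensors and tensors multiply elementwise. Here flat lists of floats stand in for tensors, and mul and product_of are hypothetical names. Since multiplication commutes, this sketch does not need the sorting step; the module sorts so that the in-place _foreach_mul_ starts from a list of tensors rather than a number.

```python
from functools import reduce


def mul(a, b):
    """Multiply two operands, each a number or a flat list of floats (elementwise)."""
    a_is_num = isinstance(a, (int, float))
    b_is_num = isinstance(b, (int, float))
    if a_is_num and b_is_num:
        return a * b
    if a_is_num:
        return [a * y for y in b]
    if b_is_num:
        return [x * b for x in a]
    return [x * y for x, y in zip(a, b)]


def product_of(*inputs):
    """Product of all inputs; numbers broadcast over lists."""
    return reduce(mul, inputs)


print(product_of([1.0, 2.0], 3, [0.5, 0.25]))  # [1.5, 1.5]
print(product_of(2, [1.0, -1.0]))              # [2.0, -2.0]
```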

ProjectedGradientMethod

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Projected gradient method. Directly projects the gradient onto subspace conjugate to past directions.

Notes

  • This method uses O(N^2) memory.
  • This requires the step size to be determined by a line search, so put a line search such as tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this module.
  • This is not the same as projected gradient descent.

Reference

Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171. (algorithm 5 in section 6)

Source code in torchzero/modules/conjugate_gradient/cg.py
class ProjectedGradientMethod(HessianUpdateStrategy): # this doesn't maintain hessian
    """Projected gradient method. Directly projects the gradient onto subspace conjugate to past directions.

    Notes:
        - This method uses N^2 memory.
        - This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
        - This is not the same as projected gradient descent.

    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.  (algorithm 5 in section 6)

    """

    def __init__(
        self,
        init_scale: float | Literal["auto"] = 1,
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None | Literal['auto'] = 'auto',
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        # inverse: bool = True,
        inner: Chainable | None = None,
    ):
        super().__init__(
            defaults=None,
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
        )



    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return projected_gradient_(H=H, y=y)
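projected_gradient_ is not shown here. Assuming it implements the classic projection update H_+ = H - (H y)(H y)^T / (y^T H y) for symmetric H (it only receives H and y, which is consistent with that form), each update removes the direction y from the range of H, so H_+ y = 0 and a new direction -H_+ g is orthogonal to y. A pure-Python sketch with nested lists; the helper names are hypothetical.

```python
def matvec(H, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(h * x for h, x in zip(row, v)) for row in H]


def projection_update(H, y):
    """H - (H y)(H y)^T / (y^T H y) for a symmetric matrix H."""
    Hy = matvec(H, y)
    yHy = sum(a * b for a, b in zip(y, Hy))
    n = len(y)
    return [[H[i][j] - Hy[i] * Hy[j] / yHy for j in range(n)] for i in range(n)]


H = [[1.0, 0.0], [0.0, 1.0]]  # init_scale=1 gives the identity
y = [1.0, 1.0]                # gradient difference from the last step
H = projection_update(H, y)
print(H)                      # [[0.5, -0.5], [-0.5, 0.5]]
print(matvec(H, y))           # [0.0, 0.0]
d = [-v for v in matvec(H, [1.0, 0.0])]  # projected direction for g = [1, 0]
```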

ProjectedNewtonRaphson

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Projected Newton-Raphson method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.

This is Algorithm 7 in that paper.

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class ProjectedNewtonRaphson(HessianUpdateStrategy):
    """
    Projected Newton Raphson method.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.

        This one is Algorithm 7.
    """
    def __init__(
        self,
        init_scale: float | Literal["auto"] = 'auto',
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None | Literal['auto'] = 'auto',
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        super().__init__(
            init_scale=init_scale,
            tol=tol,
            ptol = ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
        )

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
        H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y)
        state["R"] = R
        return H

    def reset_P(self, P, s, y, inverse, init_scale, state):
        assert inverse
        if 'R' not in state: state['R'] = torch.eye(P.size(-1), device=P.device, dtype=P.dtype)
        P.copy_(state["R"])

ProjectionBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for projections. This is an abstract class; to use it, subclass it and override project and unproject.

Parameters:

  • modules (Chainable) –

    modules that will be applied in the projected domain.

  • project_update (bool, default: True ) –

    whether to project the update. Defaults to True.

  • project_params (bool, default: False ) –

    whether to project the params. This is necessary for modules that use closure. Defaults to False.

  • project_grad (bool, default: False ) –

    whether to project the gradients (separately from update). Defaults to False.

  • defaults (dict[str, Any] | None, default: None ) –

    dictionary with defaults. Defaults to None.

Methods:

  • project

    Projects tensors. Note that this can be called multiple times per step, with params, grads, and update.

  • unproject

    Unprojects tensors. Note that this can be called multiple times per step, with params, grads, and update.

Source code in torchzero/modules/projections/projection.py
class ProjectionBase(Module, ABC):
    """
    Base class for projections.
    This is an abstract class, to use it, subclass it and override `project` and `unproject`.

    Args:
        modules (Chainable): modules that will be applied in the projected domain.
        project_update (bool, optional): whether to project the update. Defaults to True.
        project_params (bool, optional):
            whether to project the params. This is necessary for modules that use closure. Defaults to False.
        project_grad (bool, optional): whether to project the gradients (separately from update). Defaults to False.
        defaults (dict[str, Any] | None, optional): dictionary with defaults. Defaults to None.
    """

    def __init__(
        self,
        modules: Chainable,
        project_update=True,
        project_params=False,
        project_grad=False,
        defaults: dict[str, Any] | None = None,
    ):
        super().__init__(defaults)
        self.set_child('modules', modules)
        self.global_state['current_step'] = 0
        self._project_update = project_update
        self._project_params = project_params
        self._project_grad = project_grad
        self._projected_params = None

        self._states: dict[str, list[dict[str, Any]]] = {}
        """per-parameter states for each projection target"""

    @abstractmethod
    def project(
        self,
        tensors: list[torch.Tensor],
        params: list[torch.Tensor],
        grads: list[torch.Tensor] | None,
        loss: torch.Tensor | None,
        states: list[dict[str, Any]],
        settings: list[ChainMap[str, Any]],
        current: str,
    ) -> Iterable[torch.Tensor]:
        """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""

    @abstractmethod
    def unproject(
        self,
        projected_tensors: list[torch.Tensor],
        params: list[torch.Tensor],
        grads: list[torch.Tensor] | None,
        loss: torch.Tensor | None,
        states: list[dict[str, Any]],
        settings: list[ChainMap[str, Any]],
        current: str,
    ) -> Iterable[torch.Tensor]:
        """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`.

        Args:
            projected_tensors (list[torch.Tensor]): projected tensors to unproject.
            params (list[torch.Tensor]): original, unprojected parameters.
            grads (list[torch.Tensor] | None): original, unprojected gradients
            loss (torch.Tensor | None): loss at initial point.
            states (list[dict[str, Any]]): list of state dictionaries, one per UNPROJECTED tensor.
            settings (list[ChainMap[str, Any]]): list of setting dictionaries, one per UNPROJECTED tensor.
            current (str): string representing what is being unprojected, e.g. "params", "grads" or "update".

        Returns:
            Iterable[torch.Tensor]: unprojected tensors of the same shape as params
        """

    @torch.no_grad
    def step(self, var: Var):
        params = var.params
        settings = [self.settings[p] for p in params]

        def _project(tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
            states = self._states.setdefault(current, [{} for _ in params])
            return list(self.project(
                tensors=tensors,
                params=params,
                grads=var.grad,
                loss=var.loss,
                states=states,
                settings=settings,
                current=current,
            ))

        projected_var = var.clone(clone_update=False, parent=var)

        closure = var.closure

        # if this is True, update and grad were projected simultaneously under current="grads"
        # so update will have to be unprojected with current="grads"
        update_is_grad = False

        # if closure is provided and project_params=True, make new closure that evaluates projected params
        # that also means projected modules can evaluate grad/update at will, it shouldn't be computed here
        # but if it has already been computed, it should be projected
        if self._project_params and closure is not None:

            if self._project_update and var.update is not None:
                # project update only if it already exists
                projected_var.update = _project(var.update, current='update')

            else:
                # update will be set to gradients on var.get_grad()
                # therefore projection will happen with current="grads"
                update_is_grad = True

            # project grad only if it already exists
            if self._project_grad and var.grad is not None:
                projected_var.grad = _project(var.grad, current='grads')

        # otherwise update/grad needs to be calculated and projected here
        else:
            if self._project_update:
                if var.update is None:
                    # update is None, meaning it will be set to `grad`.
                    # we can project grad and use it for update
                    grad = var.get_grad()
                    projected_var.grad = _project(grad, current='grads')
                    projected_var.update = [g.clone() for g in projected_var.grad]
                    del var.update
                    update_is_grad = True

                else:
                    # update exists so it needs to be projected
                    update = var.get_update()
                    projected_var.update = _project(update, current='update')
                    del update, var.update

            if self._project_grad and projected_var.grad is None:
                # projected_var.grad may have been projected simultaneously with update
                # but if that didn't happen, it is projected here
                grad = var.get_grad()
                projected_var.grad = _project(grad, current='grads')


        original_params = None
        if self._project_params:
            original_params = [p.clone() for p in var.params]
            projected_params = _project(var.params, current='params')

        else:
            # make fake params for correct shapes and state storage
            # they reuse update or grad storage for memory efficiency
            projected_params = projected_var.update if projected_var.update is not None else projected_var.grad
            assert projected_params is not None

        if self._projected_params is None:
            # 1st step - create objects for projected_params. They have to remain the same python objects
            # to support per-parameter states which are stored by ids.
            self._projected_params = [p.view_as(p).requires_grad_() for p in projected_params]
        else:
            # set storage to new fake params while ID remains the same
            for empty_p, new_p in zip(self._projected_params, projected_params):
                empty_p.set_(new_p.view_as(new_p).requires_grad_()) # pyright: ignore[reportArgumentType]

        projected_params = self._projected_params
        # projected_settings = [self.settings[p] for p in projected_params]

        def _unproject(projected_tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
            states = self._states.setdefault(current, [{} for _ in params])
            return list(self.unproject(
                projected_tensors=projected_tensors,
                params=params,
                grads=var.grad,
                loss=var.loss,
                states=states,
                settings=settings,
                current=current,
            ))

        # project closure
        if self._project_params:
            projected_var.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
                                                            params=params, projected_params=projected_params)

        elif closure is not None:
            projected_var.closure = _FakeProjectedClosure(closure, project_fn=_project,
                                                          params=params, fake_params=projected_params)

        else:
            projected_var.closure = None

        # ----------------------------------- step ----------------------------------- #
        projected_var.params = projected_params
        projected_var = self.children['modules'].step(projected_var)

        # empty fake params storage
        # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
        if not self._project_params:
            for p in self._projected_params:
                set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))

        # --------------------------------- unproject -------------------------------- #
        unprojected_var = projected_var.clone(clone_update=False)
        unprojected_var.closure = var.closure
        unprojected_var.params = var.params
        unprojected_var.grad = var.grad # this may also be set by projected_var since it has var as parent

        if self._project_update:
            assert projected_var.update is not None
            unprojected_var.update = _unproject(projected_var.update, current='grads' if update_is_grad else 'update')
            del projected_var.update

        del projected_var

        # original params are stored if params are projected
        if original_params is not None:
            for p, o in zip(unprojected_var.params, original_params):
                p.set_(o) # pyright: ignore[reportArgumentType]

        return unprojected_var

project

project(tensors: list[Tensor], params: list[Tensor], grads: list[Tensor] | None, loss: Tensor | None, states: list[dict[str, Any]], settings: list[ChainMap[str, Any]], current: str) -> Iterable[Tensor]

Projects tensors. Note that this can be called multiple times per step, with params, grads, and update.

Source code in torchzero/modules/projections/projection.py
@abstractmethod
def project(
    self,
    tensors: list[torch.Tensor],
    params: list[torch.Tensor],
    grads: list[torch.Tensor] | None,
    loss: torch.Tensor | None,
    states: list[dict[str, Any]],
    settings: list[ChainMap[str, Any]],
    current: str,
) -> Iterable[torch.Tensor]:
    """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""

unproject

unproject(projected_tensors: list[Tensor], params: list[Tensor], grads: list[Tensor] | None, loss: Tensor | None, states: list[dict[str, Any]], settings: list[ChainMap[str, Any]], current: str) -> Iterable[Tensor]

Unprojects tensors. Note that this can be called multiple times per step, with params, grads, and update.

Parameters:

  • projected_tensors (list[Tensor]) –

    projected tensors to unproject.

  • params (list[Tensor]) –

    original, unprojected parameters.

  • grads (list[Tensor] | None) –

    original, unprojected gradients.

  • loss (Tensor | None) –

    loss at initial point.

  • states (list[dict[str, Any]]) –

    list of state dictionaries for each UNPROJECTED tensor.

  • settings (list[ChainMap[str, Any]]) –

    list of setting dictionaries for each UNPROJECTED tensor.

  • current (str) –

    string representing what is being unprojected, e.g. "params", "grads" or "update".

Returns:

  • Iterable[Tensor]

    Iterable[torch.Tensor]: unprojected tensors of the same shape as params
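As a minimal sketch of the project/unproject contract, the following uses plain Python lists in place of torch tensors and a hypothetical flattening projection. The names `flatten_project` and `flatten_unproject` are illustrative and are not part of torchzero. The sketch shows only the requirement stated above: unproject must return tensors with the same shapes as params. It ignores the `current` argument, which a real projection may use to handle params, grads, and update differently.

```python
from typing import List

def flatten_project(tensors: List[List[float]]) -> List[List[float]]:
    """Hypothetical projection: concatenate all per-parameter vectors into one vector."""
    flat: List[float] = []
    for t in tensors:
        flat.extend(t)
    return [flat]  # a single projected "tensor"

def flatten_unproject(projected: List[List[float]], params: List[List[float]]) -> List[List[float]]:
    """Split the projected vector back into pieces shaped like `params`."""
    (flat,) = projected
    out: List[List[float]] = []
    start = 0
    for p in params:
        out.append(flat[start:start + len(p)])
        start += len(p)
    assert start == len(flat), "projected size must match total parameter size"
    return out

params = [[1.0, 2.0], [3.0, 4.0, 5.0]]
update = [[0.1, 0.2], [0.3, 0.4, 0.5]]
projected = flatten_project(update)
restored = flatten_unproject(projected, params)
print(restored)  # [[0.1, 0.2], [0.3, 0.4, 0.5]]
```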

Source code in torchzero/modules/projections/projection.py
@abstractmethod
def unproject(
    self,
    projected_tensors: list[torch.Tensor],
    params: list[torch.Tensor],
    grads: list[torch.Tensor] | None,
    loss: torch.Tensor | None,
    states: list[dict[str, Any]],
    settings: list[ChainMap[str, Any]],
    current: str,
) -> Iterable[torch.Tensor]:
    """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`.

    Args:
        projected_tensors (list[torch.Tensor]): projected tensors to unproject.
        params (list[torch.Tensor]): original, unprojected parameters.
        grads (list[torch.Tensor] | None): original, unprojected gradients.
        loss (torch.Tensor | None): loss at initial point.
        states (list[dict[str, Any]]): list of state dictionaries for each UNPROJECTED tensor.
        settings (list[ChainMap[str, Any]]): list of setting dictionaries for each UNPROJECTED tensor.
        current (str): string representing what is being unprojected, e.g. "params", "grads" or "update".

    Returns:
        Iterable[torch.Tensor]: unprojected tensors of the same shape as params
    """

RCopySign

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns other(tensors) with the sign copied from tensors.

Source code in torchzero/modules/ops/binary.py
class RCopySign(BinaryOperationBase):
    """Returns :code:`other(tensors)` with sign copied from tensors."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [o.copysign_(u) for u, o in zip(update, other)]
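RCopySign works element-wise. It keeps the magnitude of other(tensors) and takes the sign from the incoming tensors, which matches the argument order of `o.copysign_(u)` in the source. The following minimal sketch uses plain Python floats and `math.copysign`. Here `other_values` is a hypothetical stand-in for the output of the other module.

```python
import math
from typing import List

def rcopysign(update: List[float], other_values: List[float]) -> List[float]:
    # magnitude from other_values, sign from update (same order as o.copysign_(u))
    return [math.copysign(o, u) for u, o in zip(update, other_values)]

update = [0.5, -2.0, 3.0]
other_values = [-4.0, 1.0, -0.25]
result = rcopysign(update, other_values)
print(result)  # [4.0, -1.0, 0.25]
```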

RDSA

Bases: torchzero.modules.grad_approximation.rfdm.RandomizedFDM

Gradient approximation via Random-direction stochastic approximation (RDSA) method.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.001 ) –

    finite difference step size. Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of random gradient samples. Defaults to 1.

  • formula (Literal, default: 'central2' ) –

    finite difference formula. Defaults to 'central2'.

  • distribution (Literal, default: 'gaussian' ) –

    distribution to sample random perturbations from. Defaults to "gaussian".

  • beta (float, default: 0 ) –

    If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.

  • pre_generate (bool, default: True ) –

    whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.

  • seed (int | None | Generator, default: None ) –

    Seed for random generator. Defaults to None.

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to "closure".

References

Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771

Source code in torchzero/modules/grad_approximation/rfdm.py
class RDSA(RandomizedFDM):
    """
    Gradient approximation via Random-direction stochastic approximation (RDSA) method.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): finite difference step size. Defaults to 1e-3.
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
        distribution (Distributions, optional): distribution to sample random perturbations from. Defaults to "gaussian".
        beta (float, optional):
            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    References:
        Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771

    """
    def __init__(
        self,
        h: float = 1e-3,
        n_samples: int = 1,
        formula: _FD_Formula = "central2",
        distribution: Distributions = "gaussian",
        beta: float = 0,
        pre_generate = True,
        target: GradTarget = "closure",
        seed: int | None | torch.Generator = None,
    ):
        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
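The textbook RDSA estimator averages central-difference directional derivatives along random Gaussian directions u: g ≈ (1/n) Σ [(f(x + h·u) − f(x − h·u)) / (2h)] · u. The following is a plain-Python sketch of that estimator. It assumes a scalar objective `f` defined on a list of floats. It does not claim to reproduce torchzero's exact 'central2' formula, its per-parameter step sizes, its pre-generated perturbations, or the beta momentum.

```python
import random
from typing import Callable, List, Optional

def rdsa_gradient(f: Callable[[List[float]], float], x: List[float], h: float = 1e-3,
                  n_samples: int = 1, seed: Optional[int] = None) -> List[float]:
    """Average of central-difference directional derivatives along Gaussian directions."""
    rng = random.Random(seed)
    grad = [0.0] * len(x)
    for _ in range(n_samples):
        u = [rng.gauss(0.0, 1.0) for _ in x]  # random direction, unit variance per coordinate
        f_plus = f([xi + h * ui for xi, ui in zip(x, u)])
        f_minus = f([xi - h * ui for xi, ui in zip(x, u)])
        d = (f_plus - f_minus) / (2 * h)  # directional derivative along u
        for i, ui in enumerate(u):
            grad[i] += d * ui
    return [g / n_samples for g in grad]

def f(x: List[float]) -> float:
    return sum(xi * xi for xi in x)  # true gradient is 2 * x

x = [1.0, -2.0, 0.5]
estimate = rdsa_gradient(f, x, n_samples=20000, seed=0)
print(estimate)  # close to [2.0, -4.0, 1.0]
```

If `rng.gauss` is replaced with `rng.choice((-1.0, 1.0))`, the same loop becomes the SPSA estimator with Rademacher perturbations.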

RDiv

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Divides other by tensors. other can be a number or a module.

If other is a module, this calculates other(tensors) / tensors.

Source code in torchzero/modules/ops/binary.py
class RDiv(BinaryOperationBase):
    """Divide :code:`other` by tensors. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`other(tensors) / tensors`
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        return other / TensorList(update)
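The operand order is the reverse of a plain division, because the incoming tensors are the divisor. The following minimal sketch uses plain Python lists. It assumes that `other` is either a number or a callable, and the callable is a hypothetical stand-in for other(tensors) as described above.

```python
from typing import Callable, List, Union

Other = Union[float, Callable[[List[float]], List[float]]]

def rdiv(update: List[float], other: Other) -> List[float]:
    if callable(other):
        # module case: other(tensors) / tensors, element-wise
        values = other(update)
        return [o / u for o, u in zip(values, update)]
    # number case: other / tensors, element-wise
    return [other / u for u in update]

print(rdiv([2.0, 4.0], 1.0))                          # [0.5, 0.25]
print(rdiv([2.0, 4.0], lambda t: [v * v for v in t]))  # [2.0, 4.0]
```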

RGraft

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs direction(tensors) rescaled to have the same norm as tensors.

Source code in torchzero/modules/ops/binary.py
class RGraft(BinaryOperationBase):
    """Outputs :code:`direction(tensors)` rescaled to have the same norm as tensors"""

    def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, direction=direction)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
        return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)
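The following minimal sketch of grafting covers the tensorwise case with ord=2. It assumes that `graft_` rescales each direction tensor to the L2 norm of the matching update tensor. Plain Python lists stand in for tensors. The exact way torchzero applies `eps` is an assumption here, and in the sketch it only guards against a zero-norm direction.

```python
import math
from typing import List

def graft(direction: List[List[float]], magnitude: List[List[float]], eps: float = 1e-6) -> List[List[float]]:
    """Rescale each direction tensor to have the L2 norm of the matching magnitude tensor."""
    out = []
    for d, m in zip(direction, magnitude):  # tensorwise=True: one norm per tensor
        d_norm = math.sqrt(sum(v * v for v in d))
        m_norm = math.sqrt(sum(v * v for v in m))
        scale = m_norm / max(d_norm, eps)  # assumption: eps only guards a zero-norm direction
        out.append([v * scale for v in d])
    return out

update = [[3.0, 4.0]]     # incoming tensors, norm 5
direction = [[0.0, 0.5]]  # output of the direction module, norm 0.5
result = graft(direction, update)
print(result)  # [[0.0, 5.0]]
```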

RMSprop

Bases: torchzero.core.transform.Transform

Divides the gradient by the EMA of gradient squares.

This implementation is identical to torch.optim.RMSprop.

Parameters:

  • smoothing (float, default: 0.99 ) –

    beta for exponential moving average of gradient squares. Defaults to 0.99.

  • eps (float, default: 1e-08 ) –

    epsilon for division. Defaults to 1e-8.

  • centered (bool, default: False ) –

    whether to center EMA of gradient squares using an additional EMA. Defaults to False.

  • debiased (bool, default: False ) –

    applies Adam debiasing. Defaults to False.

  • amsgrad (bool, default: False ) –

    Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.

  • pow (float, default: 2 ) –

    power used for the second moment and its root. Defaults to 2.

  • init (str, default: 'zeros' ) –

    how to initialize the EMA, either "update" to use the first update or "zeros". Defaults to "zeros".

  • inner (Chainable | None, default: None ) –

    Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.

Source code in torchzero/modules/adaptive/rmsprop.py
class RMSprop(Transform):
    """Divides the gradient by the EMA of gradient squares.

    This implementation is identical to :code:`torch.optim.RMSprop`.

    Args:
        smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
        eps (float, optional): epsilon for division. Defaults to 1e-8.
        centered (bool, optional): whether to center EMA of gradient squares using an additional EMA. Defaults to False.
        debiased (bool, optional): applies Adam debiasing. Defaults to False.
        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
        pow (float, optional): power used for the second moment and its root. Defaults to 2.
        init (str, optional): how to initialize the EMA, either "update" to use the first update or "zeros". Defaults to "zeros".
        inner (Chainable | None, optional):
            Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
    """
    def __init__(
        self,
        smoothing: float = 0.99,
        eps: float = 1e-8,
        centered: bool = False,
        debiased: bool = False,
        amsgrad: bool = False,
        pow: float = 2,
        init: Literal["zeros", "update"] = "zeros",
        inner: Chainable | None = None,
    ):
        defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
        super().__init__(defaults=defaults, uses_grad=False)

        if inner is not None:
            self.set_child('inner', inner)

    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
        smoothing, eps = unpack_dicts(settings, 'smoothing', 'eps', cls=NumberList)
        centered, debiased, amsgrad, pow, init = itemgetter('centered','debiased','amsgrad','pow','init')(settings[0])

        exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
        exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList) if centered else None
        max_exp_avg_sq = unpack_states(states, tensors, 'max_exp_avg_sq', cls=TensorList) if amsgrad else None

        if init == 'update' and step == 1:
            exp_avg_sq.set_([t**2 for t in tensors])
            if exp_avg is not None: exp_avg.set_([t.clone() for t in tensors])

        return rmsprop_(
            TensorList(tensors),
            exp_avg_sq_=exp_avg_sq,
            smoothing=smoothing,
            eps=eps,
            debiased=debiased,
            step=step,
            exp_avg_=exp_avg,
            max_exp_avg_sq_=max_exp_avg_sq,
            pow=pow,

            # inner args
            inner=self.children.get("inner", None),
            params=params,
            grads=grads,
        )
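With default settings (no centering, debiasing, or amsgrad, zeros initialization, and pow=2), the transform reduces to the usual RMSprop preconditioner. The following plain-Python sketch handles a single flat tensor. `RMSpropSketch` is a hypothetical name. The sketch assumes that eps is added after the square root, as in torch.optim.RMSprop. This source does not show where torchzero's `rmsprop_` helper actually places it.

```python
import math
from typing import List, Optional

class RMSpropSketch:
    """Hypothetical single-tensor illustration of RMSprop with default settings."""

    def __init__(self, smoothing: float = 0.99, eps: float = 1e-8):
        self.smoothing = smoothing
        self.eps = eps
        self.exp_avg_sq: Optional[List[float]] = None

    def step(self, grad: List[float]) -> List[float]:
        if self.exp_avg_sq is None:
            self.exp_avg_sq = [0.0] * len(grad)  # init="zeros"
        b = self.smoothing
        # exponential moving average of squared gradients
        self.exp_avg_sq = [b * v + (1 - b) * g * g for v, g in zip(self.exp_avg_sq, grad)]
        # divide by the square root of the EMA (pow=2); eps placement is an assumption
        return [g / (math.sqrt(v) + self.eps) for g, v in zip(grad, self.exp_avg_sq)]

opt = RMSpropSketch()
first = opt.step([0.5, -2.0])
print(first)  # roughly [10.0, -10.0]
```

The first step has a magnitude of about 10 because of the zero initialization. The EMA starts at (1 - smoothing) * g**2, so the ratio is 1 / sqrt(1 - smoothing). Adam-style debiasing (the debiased option) is meant to correct this kind of start-up bias.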

RPow

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Raises other to the power of tensors. other can be a number or a module.

If other is a module, this calculates other(tensors) ^ tensors.

Source code in torchzero/modules/ops/binary.py
class RPow(BinaryOperationBase):
    """Take :code:`other` to the power of tensors. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`other(tensors) ^ tensors`
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
        torch._foreach_pow_(other, update)
        return other
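In RPow, other is the base and the incoming tensors are the exponent. For example, other=2.0 computes 2 ** tensors element-wise. The following minimal sketch uses plain Python lists. It assumes a callable as a hypothetical stand-in for other(tensors).

```python
from typing import Callable, List, Union

Other = Union[float, Callable[[List[float]], List[float]]]

def rpow(update: List[float], other: Other) -> List[float]:
    if callable(other):
        # module case: other(tensors) ** tensors, element-wise
        return [o ** u for o, u in zip(other(update), update)]
    # number case: other ** tensors, element-wise
    return [other ** u for u in update]

print(rpow([1.0, 2.0, 3.0], 2.0))      # [2.0, 4.0, 8.0]
print(rpow([2.0, 3.0], lambda t: t))   # [4.0, 27.0]
```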

RSub

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Subtracts tensors from other. other can be a number or a module.

If other is a module, this calculates other(tensors) - tensors.

Source code in torchzero/modules/ops/binary.py
class RSub(BinaryOperationBase):
    """Subtract tensors from :code:`other`. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`other(tensors) - tensors`
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        return other - TensorList(update)

Randn

Bases: torchzero.core.module.Module

Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1.

Source code in torchzero/modules/ops/utility.py
class Randn(Module):
    """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def step(self, var):
        var.update = [torch.randn_like(p) for p in var.params]
        return var

RandomHvp

Bases: torchzero.core.module.Module

Returns a Hessian-vector product with a random vector.

Source code in torchzero/modules/misc/misc.py
class RandomHvp(Module):
    """Returns a hessian-vector product with a random vector"""

    def __init__(
        self,
        n_samples: int = 1,
        distribution: Distributions = "normal",
        update_freq: int = 1,
        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
        h=1e-3,
    ):
        defaults = dict(n_samples=n_samples, distribution=distribution, hvp_method=hvp_method, h=h, update_freq=update_freq)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        params = TensorList(var.params)
        settings = self.settings[params[0]]
        n_samples = settings['n_samples']
        distribution = settings['distribution']
        hvp_method = settings['hvp_method']
        h = settings['h']
        update_freq = settings['update_freq']

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        D = None
        if step % update_freq == 0:

            rgrad = None
            for i in range(n_samples):
                u = params.sample_like(distribution=distribution, variance=1)

                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
                                    h=h, normalize=True, retain_grad=i < n_samples-1)

                if D is None: D = Hvp
                else: torch._foreach_add_(D, Hvp)

            if n_samples > 1: torch._foreach_div_(D, n_samples)
            if update_freq != 1:
                assert D is not None
                D_buf = self.get_state(params, "D", cls=TensorList)
                D_buf.set_(D)

        if D is None:
            D = self.get_state(params, "D", cls=TensorList)

        var.update = list(D)
        return var
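With hvp_method="central", a Hessian-vector product can be approximated by differencing gradients: Hv ≈ (∇f(x + h·v) − ∇f(x − h·v)) / (2h). The module averages this over n_samples random vectors. The following plain-Python sketch uses a quadratic f(x) = ½ xᵀAx, whose gradient is Ax, so the result can be checked exactly. It omits the `normalize=True` handling and the update_freq caching seen in the source. `central_hvp` and `random_hvp` are hypothetical names.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]

def central_hvp(grad_fn: Callable[[Vector], Vector], x: Vector, v: Vector, h: float = 1e-3) -> Vector:
    """Central difference of gradients along v."""
    g_plus = grad_fn([xi + h * vi for xi, vi in zip(x, v)])
    g_minus = grad_fn([xi - h * vi for xi, vi in zip(x, v)])
    return [(a - b) / (2 * h) for a, b in zip(g_plus, g_minus)]

def random_hvp(grad_fn: Callable[[Vector], Vector], x: Vector, n_samples: int = 1,
               h: float = 1e-3, seed: Optional[int] = None) -> Vector:
    """Average of Hessian-vector products with random normal vectors."""
    rng = random.Random(seed)
    total = [0.0] * len(x)
    for _ in range(n_samples):
        v = [rng.gauss(0.0, 1.0) for _ in x]  # 'normal' distribution, variance 1
        hv = central_hvp(grad_fn, x, v, h)
        total = [t + c for t, c in zip(total, hv)]
    return [t / n_samples for t in total]

A = [[2.0, 1.0], [1.0, 3.0]]  # symmetric Hessian of f(x) = 0.5 * x^T A x

def grad_fn(x: Vector) -> Vector:
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]

x = [1.0, -1.0]
v = [1.0, 2.0]
hv = central_hvp(grad_fn, x, v)
print(hv)  # approximately [4.0, 7.0], i.e. A @ v
```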

RandomSample

Bases: torchzero.core.module.Module

Outputs tensors filled with random numbers; the distribution is selected by the distribution setting.

Source code in torchzero/modules/ops/utility.py
class RandomSample(Module):
    """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
    def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
        defaults = dict(distribution=distribution, variance=variance)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        distribution = self.defaults['distribution']
        variance = self.get_settings(var.params, 'variance')
        var.update = TensorList(var.params).sample_like(distribution=distribution, variance=variance)
        return var

RandomStepSize

Bases: torchzero.core.transform.Transform

Multiplies the update by a random step size drawn uniformly between low and high, either one global value or one value per parameter.

Parameters:

  • low (float, default: 0 ) –

    minimum learning rate. Defaults to 0.

  • high (float, default: 1 ) –

    maximum learning rate. Defaults to 1.

  • parameterwise (bool, default: False ) –

    if True, generate a separate random step size for each parameter; if False, generate one global random step size. Defaults to False.

  • seed (int | None, default: None ) –

    seed for the random generator. Defaults to None.

Source code in torchzero/modules/step_size/lr.py
class RandomStepSize(Transform):
    """Multiplies the update by a random step size drawn uniformly between `low` and `high`,
    either one global value or one value per parameter.

    Args:
        low (float, optional): minimum learning rate. Defaults to 0.
        high (float, optional): maximum learning rate. Defaults to 1.
        parameterwise (bool, optional):
            if True, generate a separate random step size for each parameter;
            if False, generate one global random step size. Defaults to False.
        seed (int | None, optional): seed for the random generator. Defaults to None.
    """
    def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
        defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        s = settings[0]
        parameterwise = s['parameterwise']

        seed = s['seed']
        if 'generator' not in self.global_state:
            self.global_state['generator'] = random.Random(seed)
        generator: random.Random = self.global_state['generator']

        if parameterwise:
            low, high = unpack_dicts(settings, 'low', 'high')
            lr = [generator.uniform(l, h) for l, h in zip(low, high)]
        else:
            low = s['low']
            high = s['high']
            lr = generator.uniform(low, high)

        torch._foreach_mul_(tensors, lr)
        return tensors
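The following plain-Python sketch shows the two modes. Like the source, it draws step sizes with `random.Random.uniform`. The source keeps one generator in global_state, so a seeded run continues the same random stream across steps. For simplicity, the sketch takes scalar low and high rather than per-parameter settings. `random_step_size` is a hypothetical name.

```python
import random
from typing import List, Optional, Tuple

def random_step_size(tensors: List[List[float]], low: float = 0.0, high: float = 1.0,
                     parameterwise: bool = False, rng: Optional[random.Random] = None
                     ) -> Tuple[List[List[float]], List[float]]:
    if rng is None:
        rng = random.Random()
    if parameterwise:
        lrs = [rng.uniform(low, high) for _ in tensors]  # one step size per parameter
    else:
        lr = rng.uniform(low, high)  # one global step size
        lrs = [lr] * len(tensors)
    scaled = [[v * step for v in t] for t, step in zip(tensors, lrs)]
    return scaled, lrs

tensors = [[1.0, 2.0], [3.0]]
_, global_lrs = random_step_size(tensors, 0.1, 0.5, rng=random.Random(0))
_, layer_lrs = random_step_size(tensors, 0.1, 0.5, parameterwise=True, rng=random.Random(0))
_, repeat_lrs = random_step_size(tensors, 0.1, 0.5, parameterwise=True, rng=random.Random(0))
print(global_lrs, layer_lrs)
```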

RandomizedFDM

Bases: torchzero.modules.grad_approximation.grad_approximator.GradApproximator

Gradient approximation via a randomized finite-difference method.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.001 ) –

    finite difference step size. Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of random gradient samples. Defaults to 1.

  • formula (Literal, default: 'central' ) –

    finite difference formula. Defaults to 'central'.

  • distribution (Literal, default: 'rademacher' ) –

    distribution to sample random perturbations from. Defaults to "rademacher".

  • beta (float, default: 0 ) –

    optional momentum for generated perturbations. If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.

  • pre_generate (bool, default: True ) –

    whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.

  • seed (int | None | Generator, default: None ) –

    Seed for random generator. Defaults to None.

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to "closure".

Examples:

Simultaneous perturbation stochastic approximation (SPSA) method

SPSA is a randomized finite difference method with a Rademacher distribution and the central formula.

spsa = tz.Modular(
    model.parameters(),
    tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
    tz.m.LR(1e-2)
)

Random-direction stochastic approximation (RDSA) method

RDSA is a randomized finite difference method, usually with a Gaussian distribution and the central formula.

rdsa = tz.Modular(
    model.parameters(),
    tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
    tz.m.LR(1e-2)
)

RandomizedFDM with momentum

Momentum might help by reducing the variance of the estimated gradients.

momentum_spsa = tz.Modular(
    model.parameters(),
    tz.m.RandomizedFDM(),
    tz.m.HeavyBall(0.9),
    tz.m.LR(1e-3)
)

Gaussian smoothing method

GS uses many Gaussian samples, possibly with a larger finite difference step size.

gs = tz.Modular(
    model.parameters(),
    tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
    tz.m.NewtonCG(hvp_method="forward"),
    tz.m.Backtracking()
)

SPSA-NewtonCG

When NewtonCG estimates Hessian-vector products via gradient differences, it calls the closure multiple times per step. If each closure call estimates gradients with different perturbations, NewtonCG is unable to produce useful directions.

By setting pre_generate to True (the default), perturbations are generated once before each step, and each closure call estimates gradients using the same pre-generated perturbations. This way, closure-based algorithms use consistently estimated gradients.

opt = tz.Modular(
    model.parameters(),
    tz.m.RandomizedFDM(n_samples=10, pre_generate=True),
    tz.m.NewtonCG(hvp_method="forward"),
    tz.m.Backtracking()
)
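The effect of pre_generate can be illustrated with a hypothetical estimator that either reuses a fixed set of perturbations or draws new ones on every call. The following sketch uses plain Python and Rademacher (±1) perturbations. `make_estimator` and `central_estimate` are illustrative names, not torchzero API. With fixed perturbations, two closure evaluations at the same point return identical gradient estimates. This consistency is what makes differencing those gradients meaningful.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]

def central_estimate(f: Callable[[Vector], float], x: Vector, perturbations: List[Vector], h: float) -> Vector:
    """Average of central-difference directional derivatives times the perturbation."""
    grad = [0.0] * len(x)
    for delta in perturbations:
        f_plus = f([xi + h * di for xi, di in zip(x, delta)])
        f_minus = f([xi - h * di for xi, di in zip(x, delta)])
        d = (f_plus - f_minus) / (2 * h)
        grad = [g + d * di for g, di in zip(grad, delta)]
    return [g / len(perturbations) for g in grad]

def make_estimator(f: Callable[[Vector], float], n_samples: int, h: float,
                   pre_generated: Optional[List[Vector]], rng: random.Random) -> Callable[[Vector], Vector]:
    """Hypothetical helper: returns a 'closure' that estimates the gradient at x."""
    def estimate(x: Vector) -> Vector:
        if pre_generated is not None:
            perturbations = pre_generated  # same directions for every call within a step
        else:
            # fresh Rademacher directions on every call
            perturbations = [[rng.choice((-1.0, 1.0)) for _ in x] for _ in range(n_samples)]
        return central_estimate(f, x, perturbations, h)
    return estimate

def f(x: Vector) -> float:
    return x[0] ** 2 + 3 * x[1] ** 2

rng = random.Random(0)
x = [1.0, 1.0]
fixed = [[rng.choice((-1.0, 1.0)) for _ in x] for _ in range(2)]
consistent = make_estimator(f, n_samples=2, h=1e-3, pre_generated=fixed, rng=rng)
fresh = make_estimator(f, n_samples=2, h=1e-3, pre_generated=None, rng=rng)

print(consistent(x) == consistent(x))  # True: identical estimates at the same point
print(fresh(x), fresh(x))              # may differ from call to call
```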

SPSA-LBFGS

LBFGS keeps a memory of past parameter and gradient differences. If past gradients were estimated with different perturbations, LBFGS directions will be useless.

To alleviate this, momentum can be added to the random perturbations so that they change only slightly between steps and the history stays relevant. The momentum is determined by the beta parameter. The disadvantage is that the subspace the algorithm is able to explore changes slowly.

Additionally, we reset the SPSA and LBFGS memory every 100 steps to remove the influence of old gradient estimates.

opt = tz.Modular(
    model.parameters(),
    tz.m.ResetEvery(
        [tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99), tz.m.LBFGS()],
        steps = 100,
    ),
    tz.m.Backtracking()
)
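The beta option blends old and new perturbations, which mirrors the `torch._foreach_lerp_` call in pre_step: new = beta·old + (1 − beta)·fresh. The following plain-Python sketch applies that update to a single perturbation vector. It measures the cosine similarity between successive perturbations to show that a beta close to 1 keeps them nearly aligned. `perturbation_sequence` is a hypothetical name.

```python
import math
import random
from typing import List

def lerp(old: List[float], new: List[float], weight: float) -> List[float]:
    # same formula as torch.lerp: old + weight * (new - old)
    return [o + weight * (n - o) for o, n in zip(old, new)]

def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def perturbation_sequence(beta: float, n_steps: int, dim: int, seed: int = 0) -> List[List[float]]:
    rng = random.Random(seed)
    current = [rng.choice((-1.0, 1.0)) for _ in range(dim)]  # Rademacher start
    seq = [current]
    for _ in range(n_steps):
        fresh = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        current = lerp(current, fresh, 1 - beta)  # beta * old + (1 - beta) * fresh
        seq.append(current)
    return seq

slow = perturbation_sequence(beta=0.99, n_steps=50, dim=10)
sims = [cosine(a, b) for a, b in zip(slow, slow[1:])]
print(min(sims))  # close to 1 for beta=0.99
```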
Source code in torchzero/modules/grad_approximation/rfdm.py
class RandomizedFDM(GradApproximator):
    """Gradient approximation via a randomized finite-difference method.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): finite difference step size. Defaults to 1e-3.
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central'.
        distribution (Distributions, optional): distribution to sample random perturbations from. Defaults to "rademacher".
        beta (float, optional): optional momentum for generated perturbations.
            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    Examples:
    #### Simultaneous perturbation stochastic approximation (SPSA) method

    SPSA is a randomized finite difference method with a Rademacher distribution and the central formula.
    ```py
    spsa = tz.Modular(
        model.parameters(),
        tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
        tz.m.LR(1e-2)
    )
    ```

    #### Random-direction stochastic approximation (RDSA) method

    RDSA is a randomized finite difference method, usually with a Gaussian distribution and the central formula.

    ```
    rdsa = tz.Modular(
        model.parameters(),
        tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
        tz.m.LR(1e-2)
    )
    ```

    #### RandomizedFDM with momentum

    Momentum might help by reducing the variance of the estimated gradients.

    ```
    momentum_spsa = tz.Modular(
        model.parameters(),
        tz.m.RandomizedFDM(),
        tz.m.HeavyBall(0.9),
        tz.m.LR(1e-3)
    )
    ```

    #### Gaussian smoothing method

    GS uses many Gaussian samples, possibly with a larger finite difference step size.

    ```
    gs = tz.Modular(
        model.parameters(),
        tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
        tz.m.NewtonCG(hvp_method="forward"),
        tz.m.Backtracking()
    )
    ```

    #### SPSA-NewtonCG

    When NewtonCG estimates Hessian-vector products via gradient differences,
    it calls the closure multiple times per step. If each closure call estimates gradients
    with different perturbations, NewtonCG is unable to produce useful directions.

    By setting `pre_generate` to True (the default), perturbations are generated once before each step,
    and each closure call estimates gradients using the same pre-generated perturbations.
    This way, closure-based algorithms use consistently estimated gradients.

    ```
    opt = tz.Modular(
        model.parameters(),
        tz.m.RandomizedFDM(n_samples=10, pre_generate=True),
        tz.m.NewtonCG(hvp_method="forward"),
        tz.m.Backtracking()
    )
    ```

    #### SPSA-LBFGS

    LBFGS keeps a memory of past parameter and gradient differences. If past gradients
    were estimated with different perturbations, LBFGS directions will be useless.

    To alleviate this, momentum can be added to the random perturbations so that they change
    only slightly between steps and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
    The disadvantage is that the subspace the algorithm is able to explore changes slowly.

    Additionally, we reset the SPSA and LBFGS memory every 100 steps to remove the influence of old gradient estimates.

    ```
    opt = tz.Modular(
        model.parameters(),
        tz.m.ResetEvery(
            [tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99), tz.m.LBFGS()],
            steps = 100,
        ),
        tz.m.Backtracking()
    )
    ```
    """
    PRE_MULTIPLY_BY_H = True
    def __init__(
        self,
        h: float = 1e-3,
        n_samples: int = 1,
        formula: _FD_Formula = "central",
        distribution: Distributions = "rademacher",
        beta: float = 0,
        pre_generate = True,
        seed: int | None | torch.Generator = None,
        target: GradTarget = "closure",
    ):
        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, beta=beta, pre_generate=pre_generate, seed=seed)
        super().__init__(defaults, target=target)

    def reset(self):
        self.state.clear()
        generator = self.global_state.get('generator', None) # avoid resetting generator
        self.global_state.clear()
        if generator is not None: self.global_state['generator'] = generator
        for c in self.children.values(): c.reset()

    def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
        if 'generator' not in self.global_state:
            if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
            elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
            else: self.global_state['generator'] = None
        return self.global_state['generator']

    def pre_step(self, var):
        h, beta = self.get_settings(var.params, 'h', 'beta')

        n_samples = self.defaults['n_samples']
        distribution = self.defaults['distribution']
        pre_generate = self.defaults['pre_generate']

        if pre_generate:
            params = TensorList(var.params)
            generator = self._get_generator(self.defaults['seed'], var.params)
            perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

            if self.PRE_MULTIPLY_BY_H:
                # perturbations are flattened sample-major, so repeat the per-parameter `h` list once per sample
                torch._foreach_mul_([p for l in perturbations for p in l], [v for _ in range(n_samples) for v in h])

            if all(i==0 for i in beta):
                # just use pre-generated perturbations
                for param, prt in zip(params, zip(*perturbations)):
                    self.state[param]['perturbations'] = prt

            else:
                # lerp old and new perturbations. This makes the subspace change gradually
                # which in theory might improve algorithms with history
                for i,p in enumerate(params):
                    state = self.state[p]
                    if 'perturbations' not in state: state['perturbations'] = [p[i] for p in perturbations]

                cur = [self.state[p]['perturbations'][:n_samples] for p in params]
                cur_flat = [p for l in cur for p in l]
                new_flat = [p for l in zip(*perturbations) for p in l]
                betas = [1-v for b in beta for v in [b]*n_samples]
                torch._foreach_lerp_(cur_flat, new_flat, betas)

    @torch.no_grad
    def approximate(self, closure, params, loss):
        params = TensorList(params)
        orig_params = params.clone() # store to avoid small changes due to float imprecision
        loss_approx = None

        h = NumberList(self.settings[p]['h'] for p in params)
        settings = self.settings[params[0]]
        n_samples = settings['n_samples']
        fd_fn = _RFD_FUNCS[settings['formula']]
        default = [None]*n_samples
        perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
        distribution = settings['distribution']
        generator = self._get_generator(settings['seed'], params)

        grad = None
        for i in range(n_samples):
            prt = perturbations[i]

            if prt[0] is None:
                prt = params.sample_like(distribution=distribution, generator=generator, variance=1).mul_(h)

            else: prt = TensorList(prt)

            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, f_0=loss)
            # here `d` is a numberlist of directional derivatives, due to per parameter `h` values.

            # support for per-sample values which gives better estimate
            if d[0].numel() > 1: d = d.map(torch.mean)

            if grad is None: grad = prt * d
            else: grad += prt * d

        params.set_(orig_params)
        assert grad is not None
        if n_samples > 1: grad.div_(n_samples)

        # mean if got per-sample values
        if loss is not None:
            if loss.numel() > 1:
                loss = loss.mean()

        if loss_approx is not None:
            if loss_approx.numel() > 1:
                loss_approx = loss_approx.mean()

        return grad, loss, loss_approx
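The approximate method above builds its estimate by accumulating perturbation * directional derivative over n_samples perturbations and averaging. The plain-Python sketch below illustrates the standard forward-difference version of that idea. It is not torchzero's implementation: rfd_gradient and gaussian_directions are hypothetical names, the directions are assumed to be unit-variance Gaussian vectors, and the exact scaling applied by the formulas in _RFD_FUNCS is not shown in this section, so it may differ.

```python
import random
from typing import Callable, List, Sequence

Vector = List[float]


def rfd_gradient(
    f: Callable[[Sequence[float]], float],
    x: Sequence[float],
    directions: Sequence[Sequence[float]],
    h: float = 1e-3,
) -> Vector:
    """Hypothetical helper: averages u * (f(x + h*u) - f(x)) / h over the directions u."""
    f0 = f(x)
    grad = [0.0] * len(x)
    for u in directions:
        shifted = [xi + h * ui for xi, ui in zip(x, u)]
        d = (f(shifted) - f0) / h  # forward-difference directional derivative along u
        for i, ui in enumerate(u):
            grad[i] += ui * d  # accumulate perturbation * directional derivative
    return [g / len(directions) for g in grad]


def gaussian_directions(dim: int, n_samples: int, seed: int = 0) -> List[Vector]:
    """Hypothetical helper: standard normal directions, for which E[u u^T] = I."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_samples)]


if __name__ == "__main__":
    f = lambda v: 3.0 * v[0] + 2.0 * v[1]  # true gradient is [3, 2]
    print(rfd_gradient(f, [0.5, -1.0], gaussian_directions(2, 20000)))
```

Because E[u u^T] = I for Gaussian u, the average approaches the true gradient as n_samples grows; the sampling error shrinks roughly like 1/sqrt(n_samples), so more samples give a better but more expensive estimate.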

PRE_MULTIPLY_BY_H class-attribute

PRE_MULTIPLY_BY_H = True

If True, the perturbations pre-generated in pre_step (when pre_generate is enabled) are multiplied by the finite-difference step size h.

Reciprocal

Bases: torchzero.core.transform.Transform

Returns 1 / input, element-wise. If eps is non-zero, it is added to the input first.

Source code in torchzero/modules/ops/unary.py
class Reciprocal(Transform):
    """Returns :code:`1 / input`"""
    def __init__(self, eps = 0, target: "Target" = 'update'):
        defaults = dict(eps = eps)
        super().__init__(defaults, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        eps = [s['eps'] for s in settings]
        if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
        torch._foreach_reciprocal_(tensors)
        return tensors
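A plain-Python sketch of what Reciprocal computes on each element; reciprocal is a hypothetical helper that works on a list of floats rather than on tensors.

```python
from typing import List, Sequence


def reciprocal(values: Sequence[float], eps: float = 0.0) -> List[float]:
    """Hypothetical helper: element-wise 1 / (v + eps), mirroring Reciprocal.apply_tensors."""
    if eps != 0:  # like the module, eps is only added when it is non-zero
        values = [v + eps for v in values]
    return [1.0 / v for v in values]


if __name__ == "__main__":
    print(reciprocal([2.0, 4.0]))  # [0.5, 0.25]
```

Because eps is simply added, it only guards against division by zero for non-negative inputs; an input equal to -eps still divides by zero.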

ReduceOperationBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class; to use it, subclass it and override the transform method.

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/reduce.py
class ReduceOperationBase(Module, ABC):
    """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = []
        for i, v in enumerate(operands):

            if isinstance(v, (Module, Sequence)):
                self.set_child(f'operand_{i}', v)
                self.operands.append(self.children[f'operand_{i}'])
            else:
                self.operands.append(v)

        if not self.children:
            raise ValueError('At least one operand must be a module')

    @abstractmethod
    def transform(self, var: Var, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    @torch.no_grad
    def step(self, var: Var) -> Var:
        # pass cloned update to all module operands
        processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()

        for i, v in enumerate(self.operands):
            if f'operand_{i}' in self.children:
                v: Module
                updated_var = v.step(var.clone(clone_update=True))
                processed_operands[i] = updated_var.get_update()
                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them

        transformed = self.transform(var, *processed_operands)
        var.update = transformed
        return var

transform

transform(var: Var, *operands: Any | list[Tensor]) -> list[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/reduce.py
@abstractmethod
def transform(self, var: Var, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError
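ReduceOperationBase.step passes a copy of the incoming update to every module operand, collects their outputs, and hands them to transform together with any non-module operands. The plain-Python mimic below reproduces that control flow under simplifying assumptions: callables stand in for child modules, lists of floats stand in for tensor lists, and ReduceOpSketch and SumSketch are hypothetical names, not torchzero classes.

```python
from abc import ABC, abstractmethod
from typing import Any, List

Update = List[float]


class ReduceOpSketch(ABC):
    """Hypothetical mimic of ReduceOperationBase: operands are callables or constants."""

    def __init__(self, *operands: Any) -> None:
        self.operands = list(operands)
        if not any(callable(op) for op in self.operands):
            raise ValueError("At least one operand must be a module")

    @abstractmethod
    def transform(self, update: Update, *operands: Any) -> Update:
        """Applies the operation to the processed operands."""

    def step(self, update: Update) -> Update:
        processed = []
        for op in self.operands:
            if callable(op):
                # each "module" operand gets its own copy of the incoming update
                processed.append(op(list(update)))
            else:
                processed.append(op)  # constants are passed through unchanged
        return self.transform(update, *processed)


class SumSketch(ReduceOpSketch):
    """Hypothetical subclass: element-wise sum of list operands plus scalar operands."""

    def transform(self, update: Update, *operands: Any) -> Update:
        out = [0.0] * len(update)
        for op in operands:
            if isinstance(op, list):
                out = [o + v for o, v in zip(out, op)]
            else:
                out = [o + op for o in out]
        return out


if __name__ == "__main__":
    double = lambda u: [2.0 * v for v in u]
    print(SumSketch(double, lambda u: u, 1.0).step([1.0, 2.0]))  # [4.0, 7.0]
```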

Relative

Bases: torchzero.core.transform.Transform

Multiplies the update by the absolute values of the parameters, making it relative to their magnitude. min_value is the minimum allowed multiplier, which keeps parameters at or near 0 from getting stuck.

Source code in torchzero/modules/misc/misc.py
class Relative(Transform):
    """Multiplies update by absolute parameter values to make it relative to their magnitude, :code:`min_value` is minimum allowed value to avoid getting stuck at 0."""
    def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
        defaults = dict(min_value=min_value)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
        torch._foreach_mul_(tensors, mul)
        return tensors
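In plain Python, the element-wise rule in Relative.apply_tensors is update * max(|param|, min_value). The helper name below is hypothetical, and it works on lists of floats instead of tensor lists.

```python
from typing import List, Sequence


def relative_update(
    update: Sequence[float], params: Sequence[float], min_value: float = 1e-4
) -> List[float]:
    """Hypothetical helper: scales each update entry by max(|param|, min_value)."""
    return [u * max(abs(p), min_value) for u, p in zip(update, params)]


if __name__ == "__main__":
    # the zero parameter still receives a small, non-zero update thanks to min_value
    print(relative_update([1.0, 1.0, 2.0], [0.5, -3.0, 0.0]))
```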

RelativeWeightDecay

Bases: torchzero.core.transform.Transform

Weight decay scaled relative to the mean absolute value of the update, gradient, or parameters, depending on the value of the norm_input argument.

Parameters:

  • weight_decay (float, default: 0.1 ) –

    relative weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

  • norm_input (str, default: 'update' ) –

    determines what the weight decay is relative to: "update", "grad" or "params". Defaults to "update".

  • metric (Ords, default: 'mad' ) –

    metric (norm, etc.) that the weight decay is relative to. Defaults to 'mad' (mean absolute deviation).

  • target (Literal, default: 'update' ) –

    what to set on var. Defaults to 'update'.

Examples:

Adam with non-decoupled relative weight decay

opt = tz.Modular(
    model.parameters(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)

Adam with decoupled relative weight decay

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.LR(1e-3)
)

Source code in torchzero/modules/weight_decay/weight_decay.py
class RelativeWeightDecay(Transform):
    """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of ``norm_input`` argument.

    Args:
        weight_decay (float): relative weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
        norm_input (str, optional):
            determines what should weight decay be relative to. "update", "grad" or "params".
            Defaults to "update".
        metric (Ords, optional):
            metric (norm, etc) that weight decay should be relative to.
            defaults to 'mad' (mean absolute deviation).
        target (Target, optional): what to set on var. Defaults to 'update'.

    ### Examples:

    Adam with non-decoupled relative weight decay
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.RelativeWeightDecay(1e-1),
        tz.m.Adam(),
        tz.m.LR(1e-3)
    )
    ```

    Adam with decoupled relative weight decay
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.RelativeWeightDecay(1e-1),
        tz.m.LR(1e-3)
    )
    ```
    """
    def __init__(
        self,
        weight_decay: float = 0.1,
        ord: int  = 2,
        norm_input: Literal["update", "grad", "params"] = "update",
        metric: Metrics = 'mad',
        target: Target = "update",
    ):
        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric)
        super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)

        ord = settings[0]['ord']
        norm_input = settings[0]['norm_input']
        metric = settings[0]['metric']

        if norm_input == 'update': src = TensorList(tensors)
        elif norm_input == 'grad':
            assert grads is not None
            src = TensorList(grads)
        elif norm_input == 'params':
            src = TensorList(params)
        else:
            raise ValueError(norm_input)

        norm = src.global_metric(metric)
        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * norm, ord)
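A plain-Python sketch of the idea, under two labeled assumptions: the 'mad' metric is taken to be the mean absolute value of the chosen source (as the summary line of this section describes), and weight_decay_ is assumed to add the gradient of an L-ord penalty, wd * sign(p) * |p|**(ord - 1), which is wd * p for ord=2 and wd * sign(p) for ord=1. The function names are hypothetical, and order stands in for the module's ord argument. In a chain, placing the module before Adam lets the decay term pass through Adam's normalization (non-decoupled), while placing it after Adam adds it to the already-normalized update (decoupled), matching the two examples in this section.

```python
import math
from typing import List, Sequence


def _penalty_grad(p: float, order: int) -> float:
    """Assumed L-order penalty gradient: sign(p) * |p| ** (order - 1)."""
    if p == 0:
        return 0.0
    return math.copysign(abs(p) ** (order - 1), p)


def relative_weight_decay(
    update: Sequence[float],
    params: Sequence[float],
    weight_decay: float = 0.1,
    order: int = 2,
    norm_input: str = "update",
    grads: Sequence[float] = (),
) -> List[float]:
    """Hypothetical helper: adds weight decay scaled by the mean |value| of norm_input."""
    sources = {"update": update, "grad": grads, "params": params}
    if norm_input not in sources:
        raise ValueError(norm_input)
    src = sources[norm_input]
    if len(src) == 0:
        raise ValueError(f"no values to compute the metric from for norm_input={norm_input!r}")
    # assumed 'mad' metric: mean absolute value of the chosen source
    scale = sum(abs(v) for v in src) / len(src)
    wd = weight_decay * scale
    return [u + wd * _penalty_grad(p, order) for u, p in zip(update, params)]
```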

RestartEvery

Bases: torchzero.modules.restarts.restars.RestartStrategyBase

Resets the state every n steps

Parameters:

  • modules (Chainable | None) –

    modules to reset. If None, resets all modules.

  • steps (int | Literal['ndim']) –

    number of steps between resets. Pass "ndim" to use the number of trainable parameters.

Source code in torchzero/modules/restarts/restars.py
class RestartEvery(RestartStrategyBase):
    """Resets the state every n steps

    Args:
        modules (Chainable | None):
            modules to reset. If None, resets all modules.
        steps (int | Literal["ndim"]):
            number of steps between resets. "ndim" to use number of parameters.
    """
    def __init__(self, modules: Chainable | None, steps: int | Literal['ndim']):
        defaults = dict(steps=steps)
        super().__init__(defaults, modules)

    def should_reset(self, var):
        step = self.global_state.get('step', 0) + 1
        self.global_state['step'] = step

        n = self.defaults['steps']
        if isinstance(n, str): n = sum(p.numel() for p in var.params if p.requires_grad)

        # reset every n steps
        if step % n == 0:
            self.global_state.clear()
            return True

        return False
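The counting logic of RestartEvery.should_reset can be sketched in plain Python. EveryNSketch is a hypothetical name, and param_sizes stands in for the element counts of the trainable parameters that "ndim" sums over.

```python
from typing import Sequence, Union


class EveryNSketch:
    """Hypothetical mimic of RestartEvery.should_reset."""

    def __init__(self, steps: Union[int, str]) -> None:
        self.steps = steps
        self.step = 0

    def should_reset(self, param_sizes: Sequence[int]) -> bool:
        self.step += 1
        n = self.steps
        if isinstance(n, str):  # "ndim": use the total number of trainable parameters
            n = sum(param_sizes)
        if self.step % n == 0:
            self.step = 0  # RestartEvery clears global_state, which also resets the counter
            return True
        return False


if __name__ == "__main__":
    s = EveryNSketch(3)
    print([s.should_reset([10]) for _ in range(7)])
```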

RestartOnStuck

Bases: torchzero.modules.restarts.restars.RestartStrategyBase

Resets the state when the update (the difference in parameters between steps) is zero for several consecutive steps.

Parameters:

  • modules (Chainable | None) –

    modules to reset. If None, resets all modules.

  • tol (float, default: None ) –

    step is considered failed when the maximum absolute parameter difference is at most this value. Defaults to None, which uses twice the smallest positive normal number of the parameter dtype (torch.finfo(dtype).tiny).

  • n_tol (int, default: 10 ) –

    number of consecutive failed steps required to trigger a reset. Defaults to 10.

Source code in torchzero/modules/restarts/restars.py
class RestartOnStuck(RestartStrategyBase):
    """Resets the state when update (difference in parameters) is zero for multiple steps in a row.

    Args:
        modules (Chainable | None):
            modules to reset. If None, resets all modules.
        tol (float, optional):
            step is considered failed when maximum absolute parameter difference is smaller than this. Defaults to None (uses twice the smallest representable number)
        n_tol (int, optional):
            number of failed consecutive steps required to trigger a reset. Defaults to 10.

    """
    def __init__(self, modules: Chainable | None, tol: float | None = None, n_tol: int = 10):
        defaults = dict(tol=tol, n_tol=n_tol)
        super().__init__(defaults, modules)

    @torch.no_grad
    def should_reset(self, var):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        params = TensorList(var.params)
        tol = self.defaults['tol']
        if tol is None: tol = torch.finfo(params[0].dtype).tiny * 2
        n_tol = self.defaults['n_tol']
        n_bad = self.global_state.get('n_bad', 0)

        # calculate difference in parameters
        prev_params = self.get_state(params, 'prev_params', cls=TensorList)
        update = params - prev_params
        prev_params.copy_(params)

        # if update is too small, it is considered bad, otherwise n_bad is reset to 0
        if step > 0:
            if update.abs().global_max() <= tol:
                n_bad += 1

            else:
                n_bad = 0

        self.global_state['n_bad'] = n_bad

        # no progress, reset
        if n_bad >= n_tol:
            self.global_state.clear()
            return True

        return False
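A plain-Python sketch of the stall detection in RestartOnStuck.should_reset, with a list of floats standing in for the parameters. StuckDetectorSketch is a hypothetical name; sys.float_info.min is used as the float64 counterpart of torch.finfo(dtype).tiny for the default tol.

```python
import sys
from typing import List, Optional, Sequence


class StuckDetectorSketch:
    """Hypothetical mimic of RestartOnStuck.should_reset."""

    def __init__(self, tol: Optional[float] = None, n_tol: int = 10) -> None:
        # RestartOnStuck defaults to 2 * torch.finfo(dtype).tiny; this is the float64 analogue
        self.tol = 2 * sys.float_info.min if tol is None else tol
        self.n_tol = n_tol
        self.prev: Optional[List[float]] = None
        self.n_bad = 0

    def should_reset(self, params: Sequence[float]) -> bool:
        if self.prev is not None:  # the first step after (re)start is never compared
            max_diff = max(abs(p - q) for p, q in zip(params, self.prev))
            self.n_bad = self.n_bad + 1 if max_diff <= self.tol else 0
        self.prev = list(params)
        if self.n_bad >= self.n_tol:
            self.prev, self.n_bad = None, 0  # comparable to global_state.clear()
            return True
        return False


if __name__ == "__main__":
    d = StuckDetectorSketch(n_tol=2)
    print([d.should_reset(p) for p in ([1.0], [1.0], [1.0], [1.0], [2.0])])
```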

RestartStrategyBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for restart strategies.

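On each update or step, this checks the reset condition and, if it is satisfied, resets the modules before updating or stepping. If modules is None, the other modules are reset after the step instead.

Methods:

  • should_reset

    returns whether reset should occur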

Source code in torchzero/modules/restarts/restars.py
class RestartStrategyBase(Module, ABC):
    """Base class for restart strategies.

    On each ``update``/``step`` this checks reset condition and if it is satisfied,
    resets the modules before updating or stepping.
    """
    def __init__(self, defaults: dict | None = None, modules: Chainable | None = None):
        if defaults is None: defaults = {}
        super().__init__(defaults)
        if modules is not None:
            self.set_child('modules', modules)

    @abstractmethod
    def should_reset(self, var: Var) -> bool:
        """returns whether reset should occur"""

    def _reset_on_condition(self, var):
        modules = self.children.get('modules', None)

        if self.should_reset(var):
            if modules is None:
                var.post_step_hooks.append(partial(_reset_except_self, self=self))
            else:
                modules.reset()

        return modules

    @final
    def update(self, var):
        modules = self._reset_on_condition(var)
        if modules is not None:
            modules.update(var)

    @final
    def apply(self, var):
        # don't check here because it was check in `update`
        modules = self.children.get('modules', None)
        if modules is None: return var
        return modules.apply(var.clone(clone_update=False))

    @final
    def step(self, var):
        modules = self._reset_on_condition(var)
        if modules is None: return var
        return modules.step(var.clone(clone_update=False))
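The control flow of RestartStrategyBase can be mimicked in plain Python: the wrapper asks should_reset first, resets the child's state if needed, and only then lets the child step. All names below (ChildSketch, RestartSketch, ResetEvery2) are hypothetical, the child is a toy momentum-like buffer, and the branch where modules is None (which defers the reset to a post-step hook) is omitted.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class ChildSketch:
    """Hypothetical stateful child: a momentum-like running buffer."""

    def __init__(self) -> None:
        self.state: Dict[str, List[float]] = {}

    def reset(self) -> None:
        self.state.clear()

    def step(self, update: List[float]) -> List[float]:
        buf = self.state.get("buf", [0.0] * len(update))
        buf = [0.9 * b + u for b, u in zip(buf, update)]
        self.state["buf"] = buf
        return buf


class RestartSketch(ABC):
    """Hypothetical mimic of RestartStrategyBase.step: check, maybe reset, then delegate."""

    def __init__(self, child: ChildSketch) -> None:
        self.child = child

    @abstractmethod
    def should_reset(self) -> bool:
        """Returns whether a reset should occur on this step."""

    def step(self, update: List[float]) -> List[float]:
        if self.should_reset():
            self.child.reset()  # the reset happens before the child sees this update
        return self.child.step(update)


class ResetEvery2(RestartSketch):
    """Hypothetical strategy: reset on every second step."""

    def __init__(self, child: ChildSketch) -> None:
        super().__init__(child)
        self.n = 0

    def should_reset(self) -> bool:
        self.n += 1
        return self.n % 2 == 0
```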

should_reset

should_reset(var: Var) -> bool

returns whether reset should occur

Source code in torchzero/modules/restarts/restars.py
@abstractmethod
def should_reset(self, var: Var) -> bool:
    """returns whether reset should occur"""

Rprop

Bases: torchzero.core.transform.Transform

Resilient propagation. The update magnitude is multiplied by nplus if the gradient did not change sign, or by nminus if it did. The update is then applied with the sign of the current gradient.

Additionally, if the gradient changes sign, the previous update for that weight is reverted, and on the next step the magnitude for that weight is left unchanged.

Unlike PyTorch's implementation, this also performs a backtracking update when the sign changes.

This implementation is identical to torch.optim.Rprop if backtrack is set to False.

Parameters:

  • nplus (float, default: 1.2 ) –

    multiplicative increase factor for when ascent didn't change sign (default: 1.2).

  • nminus (float, default: 0.5 ) –

    multiplicative decrease factor for when ascent changed sign (default: 0.5).

  • lb (float, default: 1e-06 ) –

    minimum step size, can be None (default: 1e-6)

  • ub (float, default: 50 ) –

    maximum step size, can be None (default: 50)

  • backtrack (bool, default: True ) –

    if True, when the ascent sign changes, undoes the last weight update; otherwise sets the update to 0. When this is False, this exactly matches PyTorch's Rprop. (default: True)

  • alpha (float, default: 1 ) –

    initial per-parameter learning rate (default: 1).

Reference: Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning: The RPROP algorithm. In IEEE International Conference on Neural Networks (pp. 586-591). IEEE.

Source code in torchzero/modules/adaptive/rprop.py
class Rprop(Transform):
    """
    Resilient propagation. The update magnitude gets multiplied by `nplus` if gradient didn't change the sign,
    or `nminus` if it did. Then the update is applied with the sign of the current gradient.

    Additionally, if gradient changes sign, the update for that weight is reverted.
    Next step, magnitude for that weight won't change.

    Compared to pytorch this also implements backtracking update when sign changes.

    This implementation is identical to :code:`torch.optim.Rprop` if :code:`backtrack` is set to False.

    Args:
        nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
        nminus (float): multiplicative decrease factor for when ascent changed sign (default: 0.5).
        lb (float): minimum step size, can be None (default: 1e-6)
        ub (float): maximum step size, can be None (default: 50)
        backtrack (float):
            if True, when ascent sign changes, undoes last weight update, otherwise sets update to 0.
            When this is False, this exactly matches pytorch Rprop. (default: True)
        alpha (float): initial per-parameter learning rate (default: 1).

    reference
        *Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning:
        The RPROP algorithm. In IEEE international conference on neural networks (pp. 586-591). IEEE.*
    """
    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb: float = 1e-6,
        ub: float = 50,
        backtrack=True,
        alpha: float = 1,
    ):
        defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, backtrack=backtrack)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        nplus, nminus, lb, ub, alpha = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', 'alpha', cls=NumberList)
        prev, allowed, magnitudes = unpack_states(
            states, tensors,
            'prev','allowed','magnitudes',
            init=[torch.zeros_like, _bool_ones_like, torch.zeros_like],
            cls = TensorList,
        )

        tensors = rprop_(
            tensors_ = as_tensorlist(tensors),
            prev_ = prev,
            allowed_ = allowed,
            magnitudes_ = magnitudes,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            alpha = alpha,
            backtrack=settings[0]['backtrack'],
            step=step,
        )

        return tensors
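The sign-based step-size rule with backtrack=False can be sketched for a single scalar weight in plain Python. It follows the torch.optim.Rprop rules that this module is described as matching; rprop_updates is a hypothetical helper, and its outputs are updates with the sign of the current gradient, like the tensors this module returns, not new parameter values.

```python
from typing import List, Sequence


def rprop_updates(
    grads: Sequence[float],
    alpha: float = 1.0,
    nplus: float = 1.2,
    nminus: float = 0.5,
    lb: float = 1e-6,
    ub: float = 50.0,
) -> List[float]:
    """Hypothetical helper: Rprop updates for one scalar weight, without backtracking."""
    step_size = alpha  # initial per-parameter step size
    prev = 0.0
    updates = []
    for g in grads:
        prod = g * prev
        if prod > 0:  # same sign as last time: grow the step
            factor = nplus
        elif prod < 0:  # sign flipped: shrink the step and skip this update
            factor = nminus
            g = 0.0
        else:  # first step, or the previous gradient was zeroed
            factor = 1.0
        step_size = min(max(step_size * factor, lb), ub)
        sign = (g > 0) - (g < 0)
        updates.append(sign * step_size)
        prev = g  # the zeroed gradient is stored, so the next step uses factor 1
    return updates


if __name__ == "__main__":
    print(rprop_updates([1.0, 1.0, -1.0, -1.0]))  # [1.0, 1.2, 0.0, -0.6]
```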

SAM

Bases: torchzero.core.module.Module

Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value. It performs two forward and backward passes per step.

This implementation modifies the closure to return loss and calculate gradients of the SAM objective. All modules after this will use the modified objective.

Note

This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients at two points on each step.

Parameters:

  • rho (float, default: 0.05 ) –

    Neighborhood size. Defaults to 0.05.

  • p (float, default: 2 ) –

    norm of the SAM objective. Defaults to 2.

  • asam (bool, default: False ) –

    enables the ASAM variant, which makes the perturbation relative to weight magnitudes. ASAM requires a much larger rho, like 0.5 or 1. The tz.m.ASAM class is identical to setting this argument to True, but it has a larger rho by default.

Examples:

SAM-SGD:

opt = tz.Modular(
    model.parameters(),
    tz.m.SAM(),
    tz.m.LR(1e-2)
)

SAM-Adam:

opt = tz.Modular(
    model.parameters(),
    tz.m.SAM(),
    tz.m.Adam(),
    tz.m.LR(1e-2)
)

References

Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412. https://arxiv.org/abs/2010.01412#page=3.16

Source code in torchzero/modules/adaptive/sam.py
class SAM(Module):
    """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
    It performs two forward and backward passes per step.

    This implementation modifies the closure to return loss and calculate gradients
    of the SAM objective. All modules after this will use the modified objective.

    .. note::
        This module requires a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients at two points on each step.

    Args:
        rho (float, optional): Neighborhood size. Defaults to 0.05.
        p (float, optional): norm of the SAM objective. Defaults to 2.
        asam (bool, optional):
            enables ASAM variant which makes perturbation relative to weight magnitudes.
            ASAM requires a much larger :code:`rho`, like 0.5 or 1.
            The :code:`tz.m.ASAM` class is identical to setting this argument to True, but
            it has larger :code:`rho` by default.

    Examples:
        SAM-SGD:

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.SAM(),
                tz.m.LR(1e-2)
            )

        SAM-Adam:

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.SAM(),
                tz.m.Adam(),
                tz.m.LR(1e-2)
            )

    References:
        Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412. https://arxiv.org/abs/2010.01412#page=3.16
    """
    def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
        defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):

        params = var.params
        closure = var.closure
        zero_grad = var.zero_grad
        if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
        p, rho = self.get_settings(var.params, 'p', 'rho', cls=NumberList)
        s = self.defaults
        eps = s['eps']
        asam = s['asam']

        # 1/p + 1/q = 1
        # okay, authors of SAM paper, I will manually solve your equation
        # so q = -p/(1-p)
        q = -p / (1-p)
        # as a validation for 2 it is -2 / -1 = 2

        @torch.no_grad
        def sam_closure(backward=True):
            orig_grads = None
            if not backward:
                # if backward is False, make sure this doesn't modify gradients
                # to avoid issues
                orig_grads = [p.grad for p in params]

            # gradient at initial parameters
            zero_grad()
            with torch.enable_grad():
                closure()

            grad = TensorList(p.grad if p.grad is not None else torch.zeros_like(p) for p in params)
            grad_abs = grad.abs()

            # compute e
            term1 = grad.sign().mul_(rho)
            term2 = grad_abs.pow(q-1)

            if asam:
                grad_abs.mul_(torch._foreach_abs(params))

            denom = grad_abs.pow_(q).sum().pow(1/p)

            e = term1.mul_(term2).div_(denom.clip(min=eps))

            if asam:
                e.mul_(torch._foreach_pow(params, 2))

            # calculate loss and gradient approximation of inner problem
            torch._foreach_add_(params, e)
            if backward:
                zero_grad()
                with torch.enable_grad():
                    # this sets .grad attributes
                    sam_loss = closure()

            else:
                sam_loss = closure(False)

            # and restore initial parameters
            torch._foreach_sub_(params, e)

            if orig_grads is not None:
                for param,orig_grad in zip(params, orig_grads):
                    param.grad = orig_grad

            return sam_loss

        var.closure = sam_closure
        return var
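The perturbation e that sam_closure adds to the parameters is the first-order solution of SAM's inner maximization: with q defined by 1/p + 1/q = 1, e = rho * sign(g) * |g|**(q - 1) / (sum |g|**q)**(1/p). A plain-Python version for one flat gradient vector follows; sam_perturbation is a hypothetical name, and the ASAM scaling by weight magnitudes is left out.

```python
from typing import List, Sequence


def sam_perturbation(
    grad: Sequence[float], rho: float = 0.05, p: float = 2.0, eps: float = 1e-10
) -> List[float]:
    """Hypothetical helper: SAM's first-order worst-case perturbation for a flat gradient."""
    q = p / (p - 1)  # dual exponent, 1/p + 1/q = 1 (same value as -p / (1 - p))
    denom = max(sum(abs(g) ** q for g in grad) ** (1 / p), eps)  # clipped like denom.clip(min=eps)
    e = []
    for g in grad:
        sign = (g > 0) - (g < 0)
        e.append(rho * sign * abs(g) ** (q - 1) / denom)
    return e


if __name__ == "__main__":
    print(sam_perturbation([3.0, 4.0]))  # approximately [0.03, 0.04], i.e. rho * g / ||g||_2
```

For p = 2 this reduces to rho * g / ||g||_2, and in general the p-norm of e equals rho. p must be greater than 1, since q is undefined at p = 1.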

SOAP

Bases: torchzero.core.transform.Transform

SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).

Parameters:

  • beta1 (float, default: 0.95 ) –

    beta for the first moment (momentum). Defaults to 0.95.

  • beta2 (float, default: 0.95 ) –

    beta for the second moment. Defaults to 0.95.

  • shampoo_beta (float | None, default: 0.95 ) –

    beta for the covariance matrix accumulators. If None, they are summed as in Adagrad, which tends to work worse. Defaults to 0.95.

  • precond_freq (int, default: 10 ) –

    How often, in steps, to update the preconditioner. Defaults to 10.

  • merge_small (bool, default: True ) –

    Whether to merge small dims. Defaults to True.

  • max_dim (int, default: 2000 ) –

    Dims of this size or larger are not preconditioned. Defaults to 2000.

  • precondition_1d (bool, default: True ) –

    Whether to precondition 1d params (the SOAP paper sets this to False). Defaults to True.

  • eps (float, default: 1e-08 ) –

    epsilon added to the square root of the second moment before the first moment is divided by it. Defaults to 1e-8.

  • decay (float | None, default: None ) –

    Decays the covariance matrix accumulators; this may be useful if shampoo_beta is None. Defaults to None.

  • alpha (float, default: 1 ) –

    learning rate. Defaults to 1.

  • bias_correction (bool, default: True ) –

    enables Adam bias correction. Defaults to True.

Examples:

SOAP:

opt = tz.Modular(model.parameters(), tz.m.SOAP(), tz.m.LR(1e-3))

Stabilized SOAP:

opt = tz.Modular(
    model.parameters(),
    tz.m.SOAP(),
    tz.m.NormalizeByEMA(max_ema_growth=1.2),
    tz.m.LR(1e-2)
)

Source code in torchzero/modules/adaptive/soap.py
class SOAP(Transform):
    """SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).

    Args:
        beta1 (float, optional): beta for first momentum. Defaults to 0.95.
        beta2 (float, optional): beta for second momentum. Defaults to 0.95.
        shampoo_beta (float | None, optional):
            beta for covariance matrices accumulators. Can be None, then it just sums them like Adagrad (which works worse). Defaults to 0.95.
        precond_freq (int, optional): How often to update the preconditioner. Defaults to 10.
        merge_small (bool, optional): Whether to merge small dims. Defaults to True.
        max_dim (int, optional): Won't precondition dims larger than this. Defaults to 2_000.
        precondition_1d (bool, optional):
            Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.
        eps (float, optional):
            epsilon for dividing first momentum by second. Defaults to 1e-8.
        decay (float | None, optional):
            Decays covariance matrix accumulators, this may be useful if `shampoo_beta` is None. Defaults to None.
        alpha (float, optional):
            learning rate. Defaults to 1.
        bias_correction (bool, optional):
            enables adam bias correction. Defaults to True.

    Examples:
        SOAP:

        .. code-block:: python

            opt = tz.Modular(model.parameters(), tz.m.SOAP(), tz.m.LR(1e-3))

        Stabilized SOAP:

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.SOAP(),
                tz.m.NormalizeByEMA(max_ema_growth=1.2),
                tz.m.LR(1e-2)
            )
    """
    def __init__(
        self,
        beta1: float = 0.95,
        beta2: float = 0.95,
        shampoo_beta: float | None = 0.95,
        precond_freq: int = 10,
        merge_small: bool = True,
        max_dim: int = 2_000,
        precondition_1d: bool = True,
        eps: float = 1e-8,
        decay: float | None = None,
        alpha: float = 1,
        bias_correction: bool = True,
    ):
        defaults = dict(
            beta1=beta1,
            beta2=beta2,
            shampoo_beta=shampoo_beta,
            precond_freq=precond_freq,
            merge_small=merge_small,
            max_dim=max_dim,
            precondition_1d=precondition_1d,
            eps=eps,
            decay=decay,
            bias_correction=bias_correction,
            alpha=alpha,
        )
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        updates = []
        # update preconditioners
        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps,alpha = itemgetter(
                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps','alpha')(setting)

            if merge_small:
                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)

            # initialize state on 1st step
            if 'GG' not in state:
                state["exp_avg"] = torch.zeros_like(t)
                state["exp_avg_sq_projected"] = torch.zeros_like(t)

                if not precondition_1d and t.ndim <= 1:
                    state['GG'] = []

                else:
                    state['GG'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]

                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
                if all(i is None for i in state['GG']):
                    state['GG'] = None

                if state['GG'] is not None:
                    update_soap_covariances_(t, GGs_=state['GG'], beta=shampoo_beta)
                    try: state['Q'] = get_orthogonal_matrix(state['GG'])
                    except torch.linalg.LinAlgError as e:
                        warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
                        state["GG"] = None

                state['step'] = 0
                updates.append(tensors[i].clip(-0.1, 0.1))
                continue  # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
                # I use scaled update instead as to not mess up with next modules.

            # Projecting gradients to the eigenbases of Shampoo's preconditioner
            # i.e. projecting to the eigenbases of matrices in state['GG']
            t_projected = None
            if state['GG'] is not None:
                t_projected = project(t, state['Q'])

            # exponential moving averages
            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
            exp_avg: torch.Tensor = state["exp_avg"]
            exp_avg_sq_projected: torch.Tensor = state["exp_avg_sq_projected"]

            exp_avg.lerp_(t, 1-beta1)

            if t_projected is None:
                exp_avg_sq_projected.mul_(beta2).addcmul_(t, t, value=1-beta2)
            else:
                exp_avg_sq_projected.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)

            # project exponential moving averages if they are accumulated unprojected
            exp_avg_projected = exp_avg
            if t_projected is not None:
                exp_avg_projected = project(exp_avg, state['Q'])

            denom = exp_avg_sq_projected.sqrt().add_(eps)

            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
            # to the original space
            update = exp_avg_projected / denom

            if t_projected is not None:
                update = project_back(update, state["Q"])

            if setting['bias_correction']:
                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
            elif alpha is not None:
                update *= alpha

            if merge_small:
                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])

            updates.append(update)
            state["step"] += 1

            # Update is done after the gradient step to avoid using current gradients in the projection.
            if state['GG'] is not None:
                update_soap_covariances_(t, state['GG'], shampoo_beta)
                if state['step'] % setting['precond_freq'] == 0:
                    try:
                        state['Q'], state['exp_avg_sq_projected'] = get_orthogonal_matrix_QR(exp_avg_sq_projected, state['GG'], state['Q'])
                    except torch.linalg.LinAlgError:
                        pass
        return updates
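In the SOAP source above, the bias-correction branch multiplies the Adam-style update by sqrt(1 - beta2**t) / (1 - beta1**t), where t = state["step"] + 1 is computed before the step counter is incremented, and then by alpha. This compensates for the exponential moving averages being initialized at zero. A small sketch of that factor alone (the function name is hypothetical, not part of torchzero):

```python
import math


def adam_bias_correction(step: int, beta1: float, beta2: float) -> float:
    """Multiplier sqrt(1 - beta2**t) / (1 - beta1**t), with t = step + 1.

    `step` counts completed updates starting from 0, like state["step"] in the source.
    """
    t = step + 1
    return math.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)


# Early on the factor is far from 1; it approaches 1 as t grows.
for step in (0, 10, 1000):
    print(step, adam_bias_correction(step, beta1=0.9, beta2=0.999))
```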

SPSA

Bases: torchzero.modules.grad_approximation.rfdm.RandomizedFDM

Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.001 ) –

    finite difference step size if jvp_method is set to forward or central. Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of random gradient samples. Defaults to 1.

  • formula (Literal, default: 'central' ) –

    finite difference formula. Defaults to 'central'.

  • distribution (Literal, default: 'rademacher' ) –

    distribution of the random perturbation directions. Defaults to "rademacher".

  • beta (float, default: 0 ) –

    If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.

  • pre_generate (bool, default: True ) –

    whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.

  • seed (int | None | Generator, default: None ) –

    Seed for random generator. Defaults to None.

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to "closure".

References

Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
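The classic central-difference SPSA estimator perturbs all coordinates at once along a random ±1 (Rademacher) direction delta and reuses one finite difference for every coordinate: g_i ≈ (f(x + h·delta) - f(x - h·delta)) / (2h·delta_i). A minimal pure-Python sketch of that textbook estimator, averaged over n_samples directions; it is not torchzero's RandomizedFDM code, and the function and argument names are hypothetical:

```python
import random
from typing import Callable, List, Optional, Sequence


def spsa_gradient(
    f: Callable[[List[float]], float],
    x: Sequence[float],
    h: float = 1e-3,
    n_samples: int = 1,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Central-difference SPSA gradient estimate using Rademacher (+/-1) perturbations."""
    rng = rng if rng is not None else random.Random()
    grad = [0.0] * len(x)
    for _ in range(n_samples):
        # Perturb every coordinate at once along a random +/-1 direction.
        delta = [rng.choice((-1.0, 1.0)) for _ in x]
        f_plus = f([xi + h * di for xi, di in zip(x, delta)])
        f_minus = f([xi - h * di for xi, di in zip(x, delta)])
        # One directional finite difference, shared by all coordinates.
        d = (f_plus - f_minus) / (2.0 * h)
        for i, di in enumerate(delta):
            # Divide by delta_i; for +/-1 entries that equals multiplying by delta_i.
            grad[i] += d * di / n_samples
    return grad
```

Each sample costs two function evaluations regardless of dimension, which is the main appeal of SPSA; the price is variance, which more samples (or the variance reduction offered by SVRG) can lower.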

Source code in torchzero/modules/grad_approximation/rfdm.py
class SPSA(RandomizedFDM):
    """
    Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): finite difference step size if jvp_method is set to `forward` or `central`. Defaults to 1e-3.
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central'.
        distribution (Distributions, optional): distribution of the random perturbation directions. Defaults to "rademacher".
        beta (float, optional):
            If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    References:
        Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
    """

SR1

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Symmetric Rank 1. This works best with a trust region:

tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False))

Parameters:

  • init_scale (float | Literal['auto'], default: 'auto' ) –

    the initial Hessian approximation is set to the identity times this value.

    "auto" corresponds to a heuristic from [1] pp. 142-143.

    Defaults to "auto".

  • tol (float, default: 1e-32 ) –

    tolerance for the denominator in the SR1 update rule, as in [1] p. 146. Defaults to 1e-32.

  • ptol (float | None, default: 1e-32 ) –

    skips update if maximum difference between current and previous gradients is less than this, to avoid instability. Defaults to 1e-32.

  • ptol_restart (bool, default: False ) –

    whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.

  • restart_interval (int | None | Literal['auto'], default: None ) –

    interval between resetting the hessian approximation.

    "auto" corresponds to number of decision variables + 1.

    None - no resets.

    Defaults to None.

  • beta (float | None, default: None ) –

    momentum on H or B. Defaults to None.

  • update_freq (int, default: 1 ) –

    frequency of updating H or B. Defaults to 1.

  • scale_first (bool, default: False ) –

    whether to downscale the first step before the Hessian approximation becomes available. Defaults to False.

  • scale_second (bool) –

    whether to downscale second step. Defaults to False.

  • concat_params (bool, default: True ) –

    If True, all parameters are treated as a single vector. If False, the update rule is applied to each parameter separately. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.

Examples:

SR1 with trust region

opt = tz.Modular(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
)

References:

[1] Nocedal, J., & Wright, S. J. (2006). Numerical Optimization (2nd ed.). Springer.
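The SR1 update of an inverse Hessian approximation H is H+ = H + u uᵀ / (uᵀy) with u = s - H y, and it is skipped when the denominator is small relative to ‖u‖‖y‖ (the dual form of the safeguard in [1]). A minimal pure-Python sketch with dense nested lists; the function names and the relative threshold r are assumptions, and torchzero's `tol` check may be applied differently:

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(A: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def sr1_update_H(H: Matrix, s: Vector, y: Vector, r: float = 1e-8) -> Matrix:
    """SR1 update of the inverse-Hessian approximation H (returns a new matrix)."""
    u = [si - hyi for si, hyi in zip(s, matvec(H, y))]  # u = s - H y
    denom = sum(ui * yi for ui, yi in zip(u, y))         # u^T y
    norm_u = sum(ui * ui for ui in u) ** 0.5
    norm_y = sum(yi * yi for yi in y) ** 0.5
    # Skip the update when the denominator is tiny relative to ||u|| * ||y||.
    if abs(denom) <= r * norm_u * norm_y:
        return [row[:] for row in H]
    n = len(H)
    return [[H[i][j] + u[i] * u[j] / denom for j in range(n)] for i in range(n)]
```

After a non-skipped update the secant equation H+ y = s holds. Because the source's update_B calls sr1_ with s and y swapped, calling this sketch as sr1_update_H(B, y, s) would likewise update a Hessian approximation B.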
Source code in torchzero/modules/quasi_newton/quasi_newton.py
class SR1(_InverseHessianUpdateStrategyDefaults):
    """Symmetric Rank 1. This works best with a trust region:
    ```python
    tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False))
    ```

    Args:
        init_scale (float | Literal["auto"], optional):
            the initial Hessian approximation is set to the identity times this value.

            "auto" corresponds to a heuristic from [1] pp. 142-143.

            Defaults to "auto".
        tol (float, optional):
            tolerance for the denominator in the SR1 update rule, as in [1] p. 146. Defaults to 1e-32.
        ptol (float | None, optional):
            skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
            Defaults to 1e-32.
        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
        restart_interval (int | None | Literal["auto"], optional):
            interval between resetting the hessian approximation.

            "auto" corresponds to number of decision variables + 1.

            None - no resets.

            Defaults to None.
        beta (float | None, optional): momentum on H or B. Defaults to None.
        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
        scale_first (bool, optional):
            whether to downscale the first step before the Hessian approximation becomes available. Defaults to False.
        scale_second (bool, optional): whether to downscale second step. Defaults to False.
        concat_params (bool, optional):
            If True, all parameters are treated as a single vector.
            If False, the update rule is applied to each parameter separately. Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    ### Examples:

    SR1 with trust region
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
    )
    ```

    ### References:
        [1] Nocedal, J., & Wright, S. J. (2006). Numerical Optimization (2nd ed.). Springer.
    """

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return sr1_(H=H, s=s, y=y, tol=setting['tol'])
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return sr1_(H=B, s=y, y=s, tol=setting['tol'])

SSVM

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Self-scaling variable metric Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class SSVM(HessianUpdateStrategy):
    """
    Self-scaling variable metric Quasi-Newton method.

    Note:
        A line search is recommended.

    Warning:
        This uses at least O(N^2) memory.

    Reference:
        Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
    """
    def __init__(
        self,
        switch: tuple[float,float] | Literal[1,2,3,4] = 3,
        init_scale: float | Literal["auto"] = 'auto',
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None = None,
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        defaults = dict(switch=switch)
        super().__init__(
            defaults=defaults,
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
        )

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return ssvm_H_(H=H, s=s, y=y, g=g, switch=setting['switch'], tol=setting['tol'])

SVRG

Bases: torchzero.core.module.Module

Stochastic variance reduced gradient method (SVRG).

To use it, put SVRG first in the chain of modules; it can be combined with any other modules. To reduce the variance of a gradient estimator, place the gradient estimator before SVRG.

First, it uses the first accum_steps batches to compute the full gradient at the initial parameters via gradient accumulation; the model is not updated during this phase.

Then it performs svrg_steps SVRG steps, each of which requires two forward and backward passes.

After svrg_steps steps, it returns to the full-gradient computation phase.

As an alternative to gradient accumulation, you can pass a "full_closure" argument to the step method. It should compute the full gradient, store it in the .grad attributes of the parameters, and return the full loss.

Parameters:

  • svrg_steps (int) –

    number of steps before calculating full gradient. This can be set to length of the dataloader.

  • accum_steps (int | None, default: None ) –

    number of steps to accumulate the gradient for. Not used if "full_closure" is passed to the step method. If None, uses value of svrg_steps. Defaults to None.

  • reset_before_accum (bool, default: True ) –

    whether to reset all other modules when re-calculating full gradient. Defaults to True.

  • svrg_loss (bool, default: True ) –

    whether to replace loss with SVRG loss (calculated by same formula as SVRG gradient). Defaults to True.

  • alpha (float, default: 1 ) –

    multiplier for the g_full(x_0) - g_batch(x_0) term. Defaults to 1. It can be annealed linearly from 1 to 0 as suggested in https://arxiv.org/pdf/2311.05589#page=6

Examples:

SVRG-LBFGS

opt = tz.Modular(
    model.parameters(),
    tz.m.SVRG(len(dataloader)),
    tz.m.LBFGS(),
    tz.m.Backtracking(),
)

For extra variance reduction one can use Online versions of algorithms, although it won't always help.

opt = tz.Modular(
    model.parameters(),
    tz.m.SVRG(len(dataloader)),
    tz.m.Online(tz.m.LBFGS()),
    tz.m.Backtracking(),
)

Variance reduction can also be applied to gradient estimators.
opt = tz.Modular(
    model.parameters(),
    tz.m.SPSA(),
    tz.m.SVRG(100),
    tz.m.LR(1e-2),
)

Notes

The SVRG gradient is computed as g_b(x) - alpha * (g_b(x_0) - g_f(x_0)), where:

  • x denotes the current parameters;
  • x_0 denotes the initial parameters, where the full gradient was computed;
  • g_b refers to the mini-batch gradient at x or x_0;
  • g_f refers to the full gradient at x_0.

The SVRG loss is computed using the same formula.
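The estimator can be checked on a toy finite-sum problem. The pure-Python sketch below (hypothetical names; each (a, b) pair stands in for one mini-batch with loss 0.5 * (a * w - b)**2) computes g_b(x) - alpha * (g_b(x_0) - g_f(x_0)) for every batch. With alpha = 1, averaging the estimates over all batches recovers the full gradient at x exactly, which is why the correction removes variance without adding bias.

```python
from typing import List, Sequence


def svrg_gradient(
    g_b_x: Sequence[float],
    g_b_x0: Sequence[float],
    g_f_x0: Sequence[float],
    alpha: float = 1.0,
) -> List[float]:
    """Variance-reduced gradient g_b(x) - alpha * (g_b(x_0) - g_f(x_0)), elementwise."""
    return [gx - alpha * (gx0 - gf) for gx, gx0, gf in zip(g_b_x, g_b_x0, g_f_x0)]


def grad_i(w: float, a: float, b: float) -> float:
    """Gradient of the toy per-batch loss 0.5 * (a * w - b) ** 2 with respect to w."""
    return a * (a * w - b)


# Toy finite-sum problem: each (a, b) pair plays the role of one mini-batch.
data = [(1.0, 2.0), (2.0, -1.0), (0.5, 0.5)]
x0, x = 0.0, 0.7  # snapshot point (full gradient computed here) and current point

g_full_x0 = sum(grad_i(x0, a, b) for a, b in data) / len(data)
estimates = [
    svrg_gradient([grad_i(x, a, b)], [grad_i(x0, a, b)], [g_full_x0])[0]
    for a, b in data
]
g_full_x = sum(grad_i(x, a, b) for a, b in data) / len(data)
print(sum(estimates) / len(estimates), g_full_x)
```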

Source code in torchzero/modules/variance_reduction/svrg.py
class SVRG(Module):
    """Stochastic variance reduced gradient method (SVRG).

    To use it, put SVRG first in the chain of modules; it can be combined with any other modules.
    To reduce the variance of a gradient estimator, place the gradient estimator before SVRG.

    First, it uses the first ``accum_steps`` batches to compute the full gradient at the initial
    parameters via gradient accumulation; the model is not updated during this phase.

    Then it performs ``svrg_steps`` SVRG steps, each of which requires two forward and backward passes.

    After ``svrg_steps`` steps, it returns to the full-gradient computation phase.

    As an alternative to gradient accumulation, you can pass a "full_closure" argument to the ``step`` method.
    It should compute the full gradient, store it in the ``.grad`` attributes of the parameters,
    and return the full loss.

    Args:
        svrg_steps (int): number of steps before calculating full gradient. This can be set to length of the dataloader.
        accum_steps (int | None, optional):
            number of steps to accumulate the gradient for. Not used if "full_closure" is passed to the ``step`` method. If None, uses value of ``svrg_steps``. Defaults to None.
        reset_before_accum (bool, optional):
            whether to reset all other modules when re-calculating full gradient. Defaults to True.
        svrg_loss (bool, optional):
            whether to replace loss with SVRG loss (calculated by same formula as SVRG gradient). Defaults to True.
        alpha (float, optional):
            multiplier for the ``g_full(x_0) - g_batch(x_0)`` term. Defaults to 1. It can be annealed linearly from 1 to 0 as suggested in https://arxiv.org/pdf/2311.05589#page=6

    ## Examples:
    SVRG-LBFGS
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.SVRG(len(dataloader)),
        tz.m.LBFGS(),
        tz.m.Backtracking(),
    )
    ```

    For extra variance reduction one can use Online versions of algorithms, although it won't always help.
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.SVRG(len(dataloader)),
        tz.m.Online(tz.m.LBFGS()),
        tz.m.Backtracking(),
    )
    ```

    Variance reduction can also be applied to gradient estimators.
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.SPSA(),
        tz.m.SVRG(100),
        tz.m.LR(1e-2),
    )
    ```
    ## Notes

    The SVRG gradient is computed as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``, where:
    - ``x`` is current parameters
    - ``x_0`` is initial parameters, where full gradient was computed
    - ``g_b`` refers to mini-batch gradient at ``x`` or ``x_0``
    - ``g_f`` refers to full gradient at ``x_0``.

    The SVRG loss is computed using the same formula.
    """
    def __init__(self, svrg_steps: int, accum_steps: int | None = None, reset_before_accum:bool=True, svrg_loss:bool=True, alpha:float=1):
        defaults = dict(svrg_steps = svrg_steps, accum_steps=accum_steps, reset_before_accum=reset_before_accum, svrg_loss=svrg_loss, alpha=alpha)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        params = var.params
        closure = var.closure
        assert closure is not None

        if "full_grad" not in self.global_state:

            # -------------------------- calculate full gradient ------------------------- #
            if "full_closure" in var.storage:
                full_closure = var.storage['full_closure']
                with torch.enable_grad():
                    full_loss = full_closure()
                    if all(p.grad is None for p in params):
                        warnings.warn("all gradients are None after evaluating full_closure.")

                    full_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                    self.global_state["full_loss"] = full_loss
                    self.global_state["full_grad"] = full_grad
                    self.global_state['x_0'] = [p.clone() for p in params]

                # current batch will be used for svrg update

            else:
                # accumulate gradients over n steps
                accum_steps = self.defaults['accum_steps']
                if accum_steps is None: accum_steps = self.defaults['svrg_steps']

                current_accum_step = self.global_state.get('current_accum_step', 0) + 1
                self.global_state['current_accum_step'] = current_accum_step

                # accumulate grads
                accumulator = self.get_state(params, 'accumulator')
                grad = var.get_grad()
                torch._foreach_add_(accumulator, grad)

                # accumulate loss
                loss_accumulator = self.global_state.get('loss_accumulator', 0)
                loss_accumulator += tofloat(var.loss)
                self.global_state['loss_accumulator'] = loss_accumulator

                # on nth step, use the accumulated gradient
                if current_accum_step >= accum_steps:
                    torch._foreach_div_(accumulator, accum_steps)
                    self.global_state["full_grad"] = accumulator
                    self.global_state["full_loss"] = loss_accumulator / accum_steps

                    self.global_state['x_0'] = [p.clone() for p in params]
                    self.clear_state_keys('accumulator')
                    del self.global_state['current_accum_step']

                # otherwise skip update until enough grads are accumulated
                else:
                    var.update = None
                    var.stop = True
                    var.skip_update = True
                    return var


        svrg_steps = self.defaults['svrg_steps']
        current_svrg_step = self.global_state.get('current_svrg_step', 0) + 1
        self.global_state['current_svrg_step'] = current_svrg_step

        # --------------------------- SVRG gradient closure -------------------------- #
        x0 = self.global_state['x_0']
        gf_x0 = self.global_state["full_grad"]
        ff_x0 = self.global_state['full_loss']
        use_svrg_loss = self.defaults['svrg_loss']
        alpha = self.get_settings(params, 'alpha')
        alpha_0 = alpha[0]
        if all(a == 1 for a in alpha): alpha = None

        def svrg_closure(backward=True):
            # g_b(x) - α * (g_b(x_0) - g_f(x_0)) and same for loss
            with torch.no_grad():
                x = [p.clone() for p in params]

                if backward:
                    # f and g at x
                    with torch.enable_grad(): fb_x = closure()
                    gb_x = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

                    # f and g at x_0
                    torch._foreach_copy_(params, x0)
                    with torch.enable_grad(): fb_x0 = closure()
                    gb_x0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                    torch._foreach_copy_(params, x)

                    # g_svrg = gb_x - alpha * (gb_x0 - gf_x0)
                    correction = torch._foreach_sub(gb_x0, gf_x0)
                    if alpha is not None: torch._foreach_mul_(correction, alpha)
                    g_svrg = torch._foreach_sub(gb_x, correction)

                    f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
                    for p, g in zip(params, g_svrg):
                        p.grad = g

                    if use_svrg_loss: return f_svrg
                    return fb_x

            # no backward
            if use_svrg_loss:
                fb_x = closure(False)
                torch._foreach_copy_(params, x0)
                fb_x0 = closure(False)
                torch._foreach_copy_(params, x)
                f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
                return f_svrg

            return closure(False)

        var.closure = svrg_closure

        # --- after svrg_steps steps reset so that new full gradient is calculated on next step --- #
        if current_svrg_step >= svrg_steps:
            del self.global_state['current_svrg_step']
            del self.global_state['full_grad']
            del self.global_state['full_loss']
            del self.global_state['x_0']
            if self.defaults['reset_before_accum']:
                var.post_step_hooks.append(partial(_reset_except_self, self=self))

        return var

SaveBest

Bases: torchzero.core.module.Module

Saves best parameters found so far, ones that have lowest loss. Put this as the last module.

Adds the following attrs:

  • best_params - a list of tensors with best parameters.
  • best_loss - loss value with best_params.
  • load_best_params - a function that sets the parameters to the best parameters.

Examples

```python
import torch
import torchzero as tz

def rosenbrock(x, y):
    return (1 - x)**2 + (100 * (y - x**2))**2

xy = torch.tensor((-1.1, 2.5), requires_grad=True)
opt = tz.Modular(
    [xy],
    tz.m.NAG(0.999),
    tz.m.LR(1e-6),
    tz.m.SaveBest()
)

# optimize for 1000 steps
for i in range(1000):
    loss = rosenbrock(*xy)
    opt.zero_grad()
    loss.backward()
    opt.step(loss=loss) # SaveBest needs closure or loss

# NAG overshot, but we saved the best params
print(f'{rosenbrock(*xy) = }') # >> 3.6583
print(f"{opt.attrs['best_loss'] = }") # >> 0.000627

# load best parameters
opt.attrs['load_best_params']()
print(f'{rosenbrock(*xy) = }') # >> 0.000627
```

Source code in torchzero/modules/misc/misc.py
class SaveBest(Module):
    """Saves best parameters found so far, ones that have lowest loss. Put this as the last module.

    Adds the following attrs:

    - ``best_params`` - a list of tensors with best parameters.
    - ``best_loss`` - loss value with ``best_params``.
    - ``load_best_params`` - a function that sets the parameters to the best parameters.

    ## Examples
    ```python
    def rosenbrock(x, y):
        return (1 - x)**2 + (100 * (y - x**2))**2

    xy = torch.tensor((-1.1, 2.5), requires_grad=True)
    opt = tz.Modular(
        [xy],
        tz.m.NAG(0.999),
        tz.m.LR(1e-6),
        tz.m.SaveBest()
    )

    # optimize for 1000 steps
    for i in range(1000):
        loss = rosenbrock(*xy)
        opt.zero_grad()
        loss.backward()
        opt.step(loss=loss) # SaveBest needs closure or loss

    # NAG overshot, but we saved the best params
    print(f'{rosenbrock(*xy) = }') # >> 3.6583
    print(f"{opt.attrs['best_loss'] = }") # >> 0.000627

    # load best parameters
    opt.attrs['load_best_params']()
    print(f'{rosenbrock(*xy) = }') # >> 0.000627
    ```
    """
    def __init__(self):
        super().__init__()

    @torch.no_grad
    def step(self, var):
        loss = tofloat(var.get_loss(False))
        lowest_loss = self.global_state.get('lowest_loss', float("inf"))

        if loss < lowest_loss:
            self.global_state['lowest_loss'] = loss
            best_params = var.attrs['best_params'] = [p.clone() for p in var.params]
            var.attrs['best_loss'] = loss
            var.attrs['load_best_params'] = partial(_load_best_parameters, params=var.params, best_params=best_params)

        return var

ScalarProjection

Bases: torchzero.modules.projections.projection.ProjectionBase

Projection that splits all parameters into individual scalars.

Source code in torchzero/modules/projections/projection.py
class ScalarProjection(ProjectionBase):
    """Projection that splits all parameters into individual scalars."""
    def __init__(
        self,
        modules: Chainable,
        project_update=True,
        project_params=True,
        project_grad=True,
    ):
        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        return [s for t in tensors for s in t.ravel().unbind(0)]

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        return vec_to_tensors(vec=torch.stack(projected_tensors), reference=params)

ScaleByGradCosineSimilarity

Bases: torchzero.core.transform.Transform

Multiplies the update by its cosine similarity with the gradient. If the cosine similarity is negative, the update is therefore negated as well.

Parameters:

  • eps (float, default: 1e-06 ) –

    epsilon for division. Defaults to 1e-6.
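The scaling rule can be sketched in pure Python by treating all parameters as one flattened vector, mirroring the global dot product and norms in the source below. The function name is hypothetical; the product of norms is clamped at eps the same way the source clips it.

```python
import math
from typing import List, Sequence


def scale_by_grad_cosine_similarity(
    update: Sequence[float], grad: Sequence[float], eps: float = 1e-6
) -> List[float]:
    """Multiply a flattened update by its cosine similarity with the flattened gradient."""
    dot = sum(u * g for u, g in zip(update, grad))
    norms = math.sqrt(sum(u * u for u in update)) * math.sqrt(sum(g * g for g in grad))
    cos_sim = dot / max(norms, eps)  # mirrors .clip(min=eps) on the product of norms
    return [u * cos_sim for u in update]
```

An update that points away from the gradient (negative cosine similarity) comes out flipped and shrunk, while an update aligned with the gradient is left almost unchanged.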

Examples:

Scaled Adam

opt = tz.Modular(
    bench.parameters(),
    tz.m.Adam(),
    tz.m.ScaleByGradCosineSimilarity(),
    tz.m.LR(1e-2)
)

Source code in torchzero/modules/momentum/cautious.py
class ScaleByGradCosineSimilarity(Transform):
    """Multiplies the update by its cosine similarity with the gradient.
    If the cosine similarity is negative, the update is therefore negated as well.

    Args:
        eps (float, optional): epsilon for division. Defaults to 1e-6.

    ## Examples:

    Scaled Adam
    ```python
    opt = tz.Modular(
        bench.parameters(),
        tz.m.Adam(),
        tz.m.ScaleByGradCosineSimilarity(),
        tz.m.LR(1e-2)
    )
    ```
    """
    def __init__(
        self,
        eps: float = 1e-6,
    ):
        defaults = dict(eps=eps)
        super().__init__(defaults, uses_grad=True)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        eps = settings[0]['eps']
        tensors = TensorList(tensors)
        grads = TensorList(grads)
        cos_sim = tensors.dot(grads) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)

        return tensors.mul_(cos_sim)

ScaleLRBySignChange

Bases: torchzero.core.transform.Transform

The per-element learning rate is multiplied by nplus if the sign of the update (or of the gradient, when use_grad=True) did not change since the previous step, and by nminus if it did.

This is part of RProp update rule.

Parameters:

  • nplus (float, default: 1.2 ) –

    learning rate gets multiplied by nplus if ascent/gradient didn't change the sign

  • nminus (float, default: 0.5 ) –

    learning rate gets multiplied by nminus if ascent/gradient changed the sign

  • lb (float, default: 1e-06 ) –

    lower bound for lr.

  • ub (float, default: 50.0 ) –

    upper bound for lr.

  • alpha (float, default: 1.0 ) –

    initial learning rate.
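The RProp-style rule can be sketched elementwise in pure Python. This is an illustration, not torchzero's scale_by_sign_change_: the function name is hypothetical, and leaving the learning rate unchanged when the current or previous value is exactly zero is an assumption.

```python
from typing import List, Sequence, Tuple


def scale_lr_by_sign_change(
    cur: Sequence[float],
    prev: Sequence[float],
    lrs: Sequence[float],
    nplus: float = 1.2,
    nminus: float = 0.5,
    lb: float = 1e-6,
    ub: float = 50.0,
) -> Tuple[List[float], List[float]]:
    """Return (cur scaled by the new per-element lrs, the new lrs)."""
    new_lrs = []
    for c, p, lr in zip(cur, prev, lrs):
        if c * p > 0:      # same sign as the previous step: grow the step
            lr *= nplus
        elif c * p < 0:    # sign flipped: shrink the step
            lr *= nminus
        new_lrs.append(min(max(lr, lb), ub))  # clamp to [lb, ub]
    scaled = [c * lr for c, lr in zip(cur, new_lrs)]
    return scaled, new_lrs
```

In the module, the per-element learning rates start at alpha on the first step (the step == 0 branch in the source below) and are carried between steps together with the previous values.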

Source code in torchzero/modules/adaptive/rprop.py
class ScaleLRBySignChange(Transform):
    """
    The per-element learning rate is multiplied by `nplus` if the sign of the update
    (or of the gradient, when `use_grad=True`) did not change since the previous step, and by `nminus` if it did.

    This is part of RProp update rule.

    Args:
        nplus (float): learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign
        nminus (float): learning rate gets multiplied by `nminus` if ascent/gradient changed the sign
        lb (float): lower bound for lr.
        ub (float): upper bound for lr.
        alpha (float): initial learning rate.

    """

    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb=1e-6,
        ub=50.0,
        alpha=1.0,
        use_grad=False,
        target: Target = "update",
    ):
        defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
        super().__init__(defaults, uses_grad=use_grad, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        tensors = as_tensorlist(tensors)
        use_grad = settings[0]['use_grad']
        if use_grad: cur = as_tensorlist(grads)
        else: cur = tensors

        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)

        if step == 0:
            lrs.set_(tensors.full_like([s['alpha'] for s in settings]))

        tensors = scale_by_sign_change_(
            tensors_ = tensors,
            cur = cur,
            prev_ = prev,
            lrs_ = lrs,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            step = step,
        )
        return tensors

ScaleModulesByCosineSimilarity

Bases: torchzero.core.module.Module

Scales the output of the main module by its cosine similarity to the output of the compare module.

Parameters:

  • main (Chainable) –

    main module or sequence of modules whose update will be scaled.

  • compare (Chainable) –

    module or sequence of modules to compare to

  • eps (float, default: 1e-06 ) –

    epsilon for division. Defaults to 1e-6.

Examples:

Adam scaled by similarity to RMSprop

opt = tz.Modular(
    bench.parameters(),
    tz.m.ScaleModulesByCosineSimilarity(
        main = tz.m.Adam(),
        compare = tz.m.RMSprop(0.999, debiased=True),
    ),
    tz.m.LR(1e-2)
)

Source code in torchzero/modules/momentum/cautious.py
class ScaleModulesByCosineSimilarity(Module):
    """Scales the output of :code:`main` module by its cosine similarity to the output
    of :code:`compare` module.

    Args:
        main (Chainable): main module or sequence of modules whose update will be scaled.
        compare (Chainable): module or sequence of modules to compare to
        eps (float, optional): epsilon for division. Defaults to 1e-6.

    ## Examples:

    Adam scaled by similarity to RMSprop
    ```python
    opt = tz.Modular(
        bench.parameters(),
        tz.m.ScaleModulesByCosineSimilarity(
            main = tz.m.Adam(),
            compare = tz.m.RMSprop(0.999, debiased=True),
        ),
        tz.m.LR(1e-2)
    )
    ```
    """
    def __init__(
        self,
        main: Chainable,
        compare: Chainable,
        eps=1e-6,
    ):
        defaults = dict(eps=eps)
        super().__init__(defaults)

        self.set_child('main', main)
        self.set_child('compare', compare)

    @torch.no_grad
    def step(self, var):
        main = self.children['main']
        compare = self.children['compare']

        main_var = main.step(var.clone(clone_update=True))
        var.update_attrs_from_clone_(main_var)

        compare_var = compare.step(var.clone(clone_update=True))
        var.update_attrs_from_clone_(compare_var)

        m = TensorList(main_var.get_update())
        c = TensorList(compare_var.get_update())
        eps = self.defaults['eps']

        cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)

        var.update = m.mul_(cos_sim)
        return var

ScipyMinimizeScalar

Bases: torchzero.modules.line_search.line_search.LineSearchBase

Line search via scipy.optimize.minimize_scalar, which implements Brent's method, golden-section search and bounded Brent's method.

Parameters:

  • method (str | None, default: None ) –

    "brent", "golden" or "bounded". Defaults to None.

  • maxiter (int | None, default: None ) –

    maximum number of function evaluations the line search is allowed to perform. Defaults to None.

  • bracket (Sequence | None, default: None ) –

    Either a triple (xa, xb, xc) satisfying xa < xb < xc and func(xb) < func(xa) and func(xb) < func(xc), or a pair (xa, xb) to be used as initial points for a downhill bracket search. Defaults to None.

  • bounds (Sequence | None, default: None ) –

    For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.

  • tol (float | None, default: None ) –

    Tolerance for termination. Defaults to None.

  • options (dict | None, default: None ) –

    A dictionary of solver options. Defaults to None.

For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
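To illustrate what the "golden" method does inside a line search, here is a minimal pure-Python golden-section search over the step size. It is not scipy's implementation: it assumes a known interval [a, b] containing a single minimum, whereas scipy can also search for a bracket first. The function phi (the loss along the search direction as a function of the step size) and the helper name golden_section are hypothetical.

```python
import math

INV_PHI = (math.sqrt(5) - 1) / 2  # 1 / golden ratio, about 0.618


def golden_section(phi, a, b, tol=1e-8, maxiter=200):
    """Minimize a unimodal function phi on [a, b] by golden-section search."""
    c = b - INV_PHI * (b - a)
    d = a + INV_PHI * (b - a)
    fc, fd = phi(c), phi(d)
    for _ in range(maxiter):
        if abs(b - a) < tol:
            break
        if fc < fd:
            # minimum lies in [a, d]; the old c becomes the new d
            b, d, fd = d, c, fc
            c = b - INV_PHI * (b - a)
            fc = phi(c)
        else:
            # minimum lies in [c, b]; the old d becomes the new c
            a, c, fc = c, d, fd
            d = a + INV_PHI * (b - a)
            fd = phi(d)
    return (a + b) / 2


# loss along a direction for f(x) = (x - 3)^2, starting at x = 0 and stepping
# x - t * g with g = f'(0) = -6, so the best step size is t = 0.5
phi = lambda t: (0.0 - t * (-6.0) - 3.0) ** 2
step = golden_section(phi, 0.0, 2.0)
```

Each iteration reuses one of the two interior evaluations, so the interval shrinks by about 0.618 per function evaluation.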

Source code in torchzero/modules/line_search/scipy.py
class ScipyMinimizeScalar(LineSearchBase):
    """Line search via :code:`scipy.optimize.minimize_scalar` which implements brent, golden search and bounded brent methods.

    Args:
        method (str | None, optional): "brent", "golden" or "bounded". Defaults to None.
        maxiter (int | None, optional): maximum number of function evaluations the line search is allowed to perform. Defaults to None.
        bracket (Sequence | None, optional):
            Either a triple (xa, xb, xc) satisfying xa < xb < xc and func(xb) < func(xa) and func(xb) < func(xc), or a pair (xa, xb) to be used as initial points for a downhill bracket search. Defaults to None.
        bounds (Sequence | None, optional):
            For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
        tol (float | None, optional): Tolerance for termination. Defaults to None.
        options (dict | None, optional): A dictionary of solver options. Defaults to None.

    For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html

    """
    def __init__(
        self,
        method: str | None = None,
        maxiter: int | None = None,
        bracket=None,
        bounds=None,
        tol: float | None = None,
        options=None,
    ):
        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter)
        super().__init__(defaults)

        import scipy.optimize
        self.scopt = scipy.optimize


    @torch.no_grad
    def search(self, update, var):
        objective = self.make_objective(var=var)
        method, bracket, bounds, tol, options, maxiter = itemgetter(
            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.defaults)

        if maxiter is not None:
            options = dict(options) if isinstance(options, Mapping) else {}
            options['maxiter'] = maxiter

        res = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
        return res.x

Sequential

Bases: torchzero.core.module.Module

On each step, this module steps with modules in sequence, repeating the whole sequence a number of times given by steps.

The update is taken to be the parameter difference between parameters before and after the inner loop.
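A minimal sketch of how such an update can be formed, assuming that the final update is subtracted from the parameters, so that update = params_before - params_after. Each inner "module" is modeled here as a hypothetical function mapping parameters to new parameters; this is not torchzero's _sequential_step.

```python
def sequential_update(params, modules, steps=1):
    """Run `modules` in order `steps` times and return before - after.

    `params` is a list of floats; each module is a hypothetical function
    that takes the current parameters and returns updated parameters.
    """
    before = list(params)
    current = list(params)
    for _ in range(steps):
        for module in modules:
            current = module(current)
    # the update is the total displacement produced by the inner loop
    return [b - a for b, a in zip(before, current)]


# two gradient-descent steps on f(x) = x^2 (gradient 2x) with lr = 0.25:
# x goes 4 -> 2 -> 1, so the combined update is 3
gd = lambda xs: [x - 0.25 * 2 * x for x in xs]
update = sequential_update([4.0], [gd], steps=2)
```

Under this sign convention, subtracting the returned update from the original parameters reproduces the parameters reached by the inner loop.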

Source code in torchzero/modules/misc/multistep.py
class Sequential(Module):
    """On each step, this sequentially steps with :code:`modules` :code:`steps` times.

    The update is taken to be the parameter difference between parameters before and after the inner loop."""
    def __init__(self, modules: Iterable[Chainable], steps: int=1):
        defaults = dict(steps=steps)
        super().__init__(defaults)
        self.set_children_sequence(modules)

    @torch.no_grad
    def step(self, var):
        return _sequential_step(self, var, sequential=True)

Shampoo

Bases: torchzero.core.transform.Transform

Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

Note: Shampoo is usually grafted to another optimizer such as Adam; otherwise it can be unstable. An example of how to do grafting is given in the Examples section below.

Note: Shampoo is very computationally expensive; increase update_freq if it is too slow.

Note: The SOAP optimizer usually outperforms Shampoo and is also less computationally expensive. A SOAP implementation is available as tz.m.SOAP.

Parameters:

  • decay (float | None, default: None ) –

    slowly decays preconditioners. Defaults to None.

  • beta (float | None, default: None ) –

    if None, accumulates a sum as in standard Shampoo; otherwise uses an EMA of preconditioners. Defaults to None.

  • update_freq (int, default: 10 ) –

    preconditioner update frequency. Defaults to 10.

  • exp_override (int | None, default: 2 ) –

    matrix exponent override; if None, uses 2*ndim. Defaults to 2.

  • merge_small (bool, default: True ) –

    whether to merge small dims on tensors. Defaults to True.

  • max_dim (int, default: 2000 ) –

    maximum dimension size for preconditioning. Defaults to 2_000.

  • precondition_1d (bool, default: True ) –

    whether to precondition 1d tensors. Defaults to True.

  • adagrad_eps (float, default: 1e-08 ) –

    epsilon for adagrad division for tensors where shampoo can't be applied. Defaults to 1e-8.

  • inner (Chainable | None, default: None ) –

    module applied after updating preconditioners and before applying preconditioning. For example if beta≈0.999 and inner=tz.m.EMA(0.9), this becomes Adam with shampoo preconditioner (ignoring debiasing). Defaults to None.
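For reference, the Shampoo paper's update for an m×n parameter W with gradient G_t is shown below. This is the standard formulation from the paper, not a transcription of this implementation: the source initializes the accumulators with torch.eye rather than a small multiple of the identity, and how beta and decay modify these sums is not shown here. For a tensor with k dimensions the paper uses the root 2k for each dimension's preconditioner, which appears to correspond to the 2*ndim fallback of exp_override.

```latex
\begin{aligned}
L_t &= L_{t-1} + G_t G_t^{\top}, \qquad L_0 = \epsilon I_m, \\
R_t &= R_{t-1} + G_t^{\top} G_t, \qquad R_0 = \epsilon I_n, \\
W_{t+1} &= W_t - \eta \, L_t^{-1/4} \, G_t \, R_t^{-1/4}.
\end{aligned}
```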

Examples:

Shampoo grafted to Adam

opt = tz.Modular(
    model.parameters(),
    tz.m.GraftModules(
        direction = tz.m.Shampoo(),
        magnitude = tz.m.Adam(),
    ),
    tz.m.LR(1e-3)
)

Adam with Shampoo preconditioner

opt = tz.Modular(
    model.parameters(),
    tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
    tz.m.Debias(0.9, 0.999),
    tz.m.LR(1e-3)
)
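Grafting, as in the first example, keeps the direction of one update (direction) and the magnitude of another (magnitude). A minimal sketch on flat lists, assuming the magnitude is matched by the global L2 norm; GraftModules may instead work per tensor or handle epsilon differently, and the helper name graft is hypothetical.

```python
import math


def graft(direction, magnitude, eps=1e-8):
    """Rescale `direction` so that its L2 norm equals that of `magnitude`."""
    dir_norm = math.sqrt(sum(d * d for d in direction))
    mag_norm = math.sqrt(sum(m * m for m in magnitude))
    # eps guards against dividing by a zero direction
    scale = mag_norm / max(dir_norm, eps)
    return [d * scale for d in direction]


# a Shampoo-like direction with norm 5 grafted onto an Adam-like step with norm 0.5
grafted = graft([3.0, 4.0], [0.3, 0.4])
```

The grafted update points where Shampoo points but takes a step as large as Adam's, which is what makes the combination more stable than raw Shampoo.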
Source code in torchzero/modules/adaptive/shampoo.py
class Shampoo(Transform):
    """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

    .. note::
        Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.

    .. note::
        Shampoo is a very computationally expensive optimizer, increase :code:`update_freq` if it is too slow.

    .. note::
        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as :code:`tz.m.SOAP`.

    Args:
        decay (float | None, optional): slowly decays preconditioners. Defaults to None.
        beta (float | None, optional):
            if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
        update_freq (int, optional): preconditioner update frequency. Defaults to 10.
        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to 2.
        merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 2_000.
        precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
        adagrad_eps (float, optional): epsilon for adagrad division for tensors where shampoo can't be applied. Defaults to 1e-8.
        inner (Chainable | None, optional):
            module applied after updating preconditioners and before applying preconditioning.
            For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
            Defaults to None.

    Examples:
        Shampoo grafted to Adam

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.GraftModules(
                    direction = tz.m.Shampoo(),
                    magnitude = tz.m.Adam(),
                ),
                tz.m.LR(1e-3)
            )

        Adam with Shampoo preconditioner

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
                tz.m.Debias(0.9, 0.999),
                tz.m.LR(1e-3)
            )
    """
    def __init__(
        self,
        decay: float | None = None,
        beta: float | None = None,
        reg: float = 1e-12,
        update_freq: int = 10,
        exp_override: int | None = 2,
        merge_small: bool = True,
        max_dim: int = 2_000,
        precondition_1d: bool = True,
        adagrad_eps: float = 1e-8,
        inner: Chainable | None = None,
    ):
        defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps, reg=reg)
        super().__init__(defaults, uses_grad=False)

        if inner is not None:
            self.set_child('inner', inner)

    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        merged_tensors = [] # target with merged dims

        # update preconditioners
        for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
            beta, update_freq, exp_override, merge_small, max_dim, precondition_1d, reg = itemgetter(
                'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d', "reg")(setting)

            if merge_small:
                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)

            merged_tensors.append(t)

            # initialize accumulators and preconditioners for each dim on 1st step
            if 'accumulators' not in state:

                if not precondition_1d and t.ndim <= 1:
                    state['accumulators'] = []

                else:
                    state['accumulators'] = [torch.eye(s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
                    state['preconditioners'] = [torch.eye(s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]

                # either scalar parameter, 1d with precondition_1d=False, or too big, then basic diagonal preconditioner is used.
                if all(a is None for a in state['accumulators']):
                    state['diagonal_accumulator'] = torch.zeros_like(t)

                state['step'] = 0

            # update preconditioners
            if 'diagonal_accumulator' in state:
                update_diagonal_(t, state['diagonal_accumulator'], beta)
            else:
                update_shampoo_preconditioner_(
                    t,
                    accumulators_=state['accumulators'],
                    preconditioners_=state['preconditioners'],
                    step=state['step'],
                    update_freq=update_freq,
                    exp_override=exp_override,
                    beta=beta,
                    reg=reg,
                )

        # inner step
        if 'inner' in self.children:
            tensors = apply_transform(self.children['inner'], tensors, params=params, grads=grads)

            # have to merge small dims again
            merged_tensors = [] # target with merged dims
            for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
                if setting['merge_small']:
                    t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, setting['max_dim'])
                merged_tensors.append(t)

        # precondition
        for i,(t,state, setting) in enumerate(zip(merged_tensors, states, settings)):
            decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(setting)

            if 'diagonal_accumulator' in state:
                tensors[i] = apply_diagonal_(t, state['diagonal_accumulator'], decay=decay, eps=adagrad_eps)
            else:
                tensors[i] = apply_shampoo_preconditioner(t, preconditioners_=state['preconditioners'], decay=decay)

            if merge_small:
                tensors[i] = _unmerge_small_dims(tensors[i], state['flat_sizes'], state['sort_idxs'])

            state['step'] += 1

        return tensors

ShorR

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Shor’s r-algorithm.

Note

A line search such as tz.m.StrongWolfe(a_init="quadratic", fallback=True) is required. Like conjugate gradient, ShorR has no automatic step size scaling, so setting a_init in the line search is recommended.

References

Shor, N. Z. (1985). Minimization Methods for Non-differentiable Functions. New York: Springer.

Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of Numerical Analysis 28.4 (2008): 711-720. - good overview.

Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998. - this is where a more efficient formula is described.
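The core of the r-algorithm is a space-dilation update of the matrix H along the gradient difference y = g_new - g_old. The sketch below uses the textbook form H_new = H + (alpha^2 - 1) (H y)(H y)^T / (y^T H y), obtained by dilating a factor B with H = B B^T. It assumes alpha is the shrink coefficient (0 < alpha < 1) and is not a transcription of torchzero's shor_r_ or of Ansari's limited-memory formula; the helpers are hypothetical.

```python
def shor_r_update(H, y, alpha=0.5):
    """One space-dilation update of a symmetric matrix H (list of lists).

    Shrinks H by a factor alpha**2 along the direction H @ y while leaving
    directions v with y^T H v = 0 unchanged.
    """
    n = len(y)
    Hy = [sum(H[i][j] * y[j] for j in range(n)) for i in range(n)]
    yHy = sum(y[i] * Hy[i] for i in range(n))
    if yHy <= 0:
        return [row[:] for row in H]  # skip degenerate updates
    c = (alpha ** 2 - 1) / yHy
    return [[H[i][j] + c * Hy[i] * Hy[j] for j in range(n)] for i in range(n)]


def matvec(A, v):
    """Matrix-vector product for a list-of-lists matrix."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]


H = [[1.0, 0.0], [0.0, 1.0]]
y = [1.0, 1.0]
H_new = shor_r_update(H, y, alpha=0.5)
# after the update, H_new @ y equals alpha**2 * (H @ y)
```

Repeatedly contracting H along gradient differences is what lets the method make progress on non-differentiable functions, where successive subgradients can point in very different directions.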

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class ShorR(HessianUpdateStrategy):
    """Shor’s r-algorithm.

    Note:
        A line search such as ``tz.m.StrongWolfe(a_init="quadratic", fallback=True)`` is required.
        Similarly to conjugate gradient, ShorR doesn't have an automatic step size scaling,
        so setting ``a_init`` in the line search is recommended.

    References:
        Shor, N. Z. (1985). Minimization Methods for Non-differentiable Functions. New York: Springer.

        Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720. - good overview.

        Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998. - this is where a more efficient formula is described.
    """

    def __init__(
        self,
        alpha=0.5,
        init_scale: float | Literal["auto"] = 1,
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None | Literal['auto'] = None,
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        # inverse: bool = True,
        inner: Chainable | None = None,
    ):
        defaults = dict(alpha=alpha)
        super().__init__(
            defaults=defaults,
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
        )

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return shor_r_(H=H, y=y, alpha=setting['alpha'])

Sign

Bases: torchzero.core.transform.Transform

Returns sign(input).

Source code in torchzero/modules/ops/unary.py
class Sign(Transform):
    """Returns :code:`sign(input)`"""
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        torch._foreach_sign_(tensors)
        return tensors

SignConsistencyLRs

Bases: torchzero.core.transform.Transform

Outputs per-weight learning rates based on consecutive sign consistency.

The learning rates start at alpha. The learning rate for a weight is multiplied by nplus when two consecutive update signs are the same; otherwise it is multiplied by nminus. The learning rates are bounded to the (lb, ub) range.

Examples:

GD scaled by consecutive gradient sign consistency

    opt = tz.Modular(
        model.parameters(),
        tz.m.Mul(tz.m.SignConsistencyLRs()),
        tz.m.LR(1e-2)
    )
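The per-weight rule can be sketched in plain Python as an Rprop-style loop. This is an illustration, not torchzero's sign_consistency_lrs_: it assumes the rates start at alpha on the first step (as the source does) without being multiplied, that a product of exactly zero leaves a rate unchanged, and that lb or ub set to None means no bound.

```python
def sign_consistency_lrs(updates, alpha=1.0, nplus=1.2, nminus=0.5, lb=1e-6, ub=50.0):
    """Yield per-weight learning rates for a sequence of update vectors.

    Rates start at `alpha`; each rate is multiplied by `nplus` when the sign
    of a weight's update matches the previous step, by `nminus` when it
    flips, and is then clamped to [lb, ub].
    """
    prev = None
    lrs = None
    for update in updates:
        if prev is None:
            lrs = [alpha] * len(update)
        else:
            new_lrs = []
            for lr, u, p in zip(lrs, update, prev):
                if u * p > 0:
                    lr *= nplus   # same sign: grow the rate
                elif u * p < 0:
                    lr *= nminus  # sign flipped: shrink the rate
                if lb is not None:
                    lr = max(lr, lb)
                if ub is not None:
                    lr = min(lr, ub)
                new_lrs.append(lr)
            lrs = new_lrs
        prev = update
        yield list(lrs)


rates = list(sign_consistency_lrs([[1.0, 1.0], [2.0, -1.0], [3.0, -2.0]]))
```

In the example above, these rates are then multiplied into the update by tz.m.Mul.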
Source code in torchzero/modules/adaptive/rprop.py
class SignConsistencyLRs(Transform):
    """Outputs per-weight learning rates based on consecutive sign consistency.

    The learning rate for a weight is multiplied by :code:`nplus` when two consecutive update signs are the same, otherwise it is multiplied by :code:`nminus`. The learning rates are bounded to be in :code:`(lb, ub)` range.

    Examples:

        GD scaled by consecutive gradient sign consistency

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Mul(tz.m.SignConsistencyLRs()),
                tz.m.LR(1e-2)
            )

    """
    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb: float | None = 1e-6,
        ub: float | None = 50,
        alpha: float = 1,
        target: Target = 'update'
    ):
        defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
        super().__init__(defaults, uses_grad=False, target = target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        target = as_tensorlist(tensors)
        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)

        if step == 0:
            lrs.set_(target.full_like([s['alpha'] for s in settings]))

        target = sign_consistency_lrs_(
            tensors = target,
            prev_ = prev,
            lrs_ = lrs,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            step = step,
        )
        return target.clone()

SignConsistencyMask

Bases: torchzero.core.transform.Transform

Outputs a mask of sign consistency of current and previous inputs.

The output is 1 for weights where the current and previous inputs have the same sign, and 0 otherwise, including where the sign changed or either value is zero.

Examples:

GD that skips the update for weights whose gradient sign changed compared to the previous gradient.

    opt = tz.Modular(
        model.parameters(),
        tz.m.Mul(tz.m.SignConsistencyMask()),
        tz.m.LR(1e-2)
    )
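A plain-Python sketch of the mask, following the product test in the source (previous * current > 0). It assumes, as labeled in the code, that the stored previous input starts as all zeros, so the first mask is all zeros; the helper name is hypothetical.

```python
def sign_consistency_masks(inputs):
    """Yield a 0/1 mask per step: 1 where the sign matches the previous input."""
    prev = None
    for current in inputs:
        if prev is None:
            prev = [0.0] * len(current)  # assumed initial state
        mask = [1.0 if p * c > 0 else 0.0 for p, c in zip(prev, current)]
        prev = list(current)  # remember the raw input, not the mask
        yield mask


grads = [[0.5, -1.0, 2.0], [0.3, 1.0, 2.0], [-0.1, 2.0, 0.0]]
masks = list(sign_consistency_masks(grads))
```

Multiplying the update by this mask, as in tz.m.Mul(tz.m.SignConsistencyMask()), zeroes the update for those weights on that step.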
Source code in torchzero/modules/adaptive/rprop.py
class SignConsistencyMask(Transform):
    """
    Outputs a mask of sign consistency of current and previous inputs.

    The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.

    Examples:

        GD that skips update for weights where gradient sign changed compared to previous gradient.

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Mul(tz.m.SignConsistencyMask()),
                tz.m.LR(1e-2)
            )

    """
    def __init__(self,target: Target = 'update'):
        super().__init__({}, uses_grad=False, target = target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', cls=TensorList)
        # compute the mask on a copy; in-place ops on prev would be overwritten by copy_ below
        mask = prev.clone().mul_(tensors).gt_(0)
        prev.copy_(tensors)
        return mask

SixthOrder3P

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Sixth-order iterative method.

Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.

Source code in torchzero/modules/second_order/multipoint.py
class SixthOrder3P(HigherOrderMethodBase):
    """Sixth-order iterative method.

    Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
    """
    def __init__(self, lstsq: bool=False, vectorize: bool = True):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, vectorize=vectorize)

    def one_iteration(self, x, evaluate, var):
        settings = self.defaults
        lstsq = settings['lstsq']
        def f(x): return evaluate(x, 1)[1]
        def f_j(x): return evaluate(x, 2)[1:]
        x_star = sixth_order_3p(x, f, f_j, lstsq)
        return x - x_star

SixthOrder3PM2

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45.

Source code in torchzero/modules/second_order/multipoint.py
class SixthOrder3PM2(HigherOrderMethodBase):
    """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
    def __init__(self, lstsq: bool=False, vectorize: bool = True):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, vectorize=vectorize)

    def one_iteration(self, x, evaluate, var):
        settings = self.defaults
        lstsq = settings['lstsq']
        def f_j(x): return evaluate(x, 2)[1:]
        def f(x): return evaluate(x, 1)[1]
        x_star = sixth_order_3pm2(x, f, f_j, lstsq)
        return x - x_star

SixthOrder5P

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139.

Source code in torchzero/modules/second_order/multipoint.py
class SixthOrder5P(HigherOrderMethodBase):
    """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
    def __init__(self, lstsq: bool=False, vectorize: bool = True):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, vectorize=vectorize)

    def one_iteration(self, x, evaluate, var):
        settings = self.defaults
        lstsq = settings['lstsq']
        def f_j(x): return evaluate(x, 2)[1:]
        x_star = sixth_order_5p(x, f_j, lstsq)
        return x - x_star

SophiaH

Bases: torchzero.core.module.Module

SophiaH optimizer from https://arxiv.org/abs/2305.14342

This is similar to Adam, but the second moment is replaced by an exponential moving average of randomized Hessian diagonal estimates, and the update is aggressively clipped.
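The two ingredients named above can be sketched in plain Python: Hutchinson's estimate of the Hessian diagonal, the average of u * (H u) over Gaussian vectors u (the source draws u with torch.randn), and the clipped update from the Sophia paper, clip(m / max(gamma * h, eps), -clip, clip). The sketch uses an explicit Hessian matrix instead of autograd Hessian-vector products, assumes precond_scale plays the role of gamma, and omits the EMAs, the normalize=True handling and other details of torchzero's sophia_H; the helper names are hypothetical.

```python
import random


def hutchinson_diag(H, n_samples, rng):
    """Estimate diag(H) by averaging u * (H @ u) over Gaussian vectors u."""
    n = len(H)
    est = [0.0] * n
    for _ in range(n_samples):
        u = [rng.gauss(0.0, 1.0) for _ in range(n)]
        Hu = [sum(H[i][j] * u[j] for j in range(n)) for i in range(n)]
        for i in range(n):
            est[i] += u[i] * Hu[i] / n_samples
    return est


def sophia_direction(exp_avg, h_exp_avg, precond_scale=1.0, eps=1e-12, clip=1.0):
    """Divide momentum by the floored curvature estimate, then clip elementwise."""
    out = []
    for m, h in zip(exp_avg, h_exp_avg):
        ratio = m / max(precond_scale * h, eps)
        out.append(max(-clip, min(clip, ratio)))
    return out


rng = random.Random(0)
H = [[2.0, 0.5], [0.5, 1.0]]
h = hutchinson_diag(H, n_samples=20000, rng=rng)  # close to [2.0, 1.0]
direction = sophia_direction([0.4, -3.0], h)
```

When the curvature estimate is negative or tiny, max(gamma * h, eps) makes the ratio huge, so the clipped update falls back to sign(m) * clip.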

Note: In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply SophiaH preconditioning to another module's output.

Note: If you are using gradient estimators or reformulations, set hvp_method to "forward" or "central".

Note: This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

Parameters:

  • beta1 (float, default: 0.96 ) –

    first momentum. Defaults to 0.96.

  • beta2 (float, default: 0.99 ) –

    momentum for hessian diagonal estimate. Defaults to 0.99.

  • update_freq (int, default: 10 ) –

    frequency of updating hessian diagonal estimate via a hessian-vector product. Defaults to 10.

  • precond_scale (float, default: 1 ) –

    scale of the preconditioner. Defaults to 1.

  • clip (float, default: 1 ) –

    clips update to (-clip, clip). Defaults to 1.

  • eps (float, default: 1e-12 ) –

    clips the Hessian diagonal estimate to be no less than this value. Defaults to 1e-12.

  • hvp_method (str, default: 'autograd' ) –

    Determines how Hessian-vector products are evaluated.

    • "autograd": Use PyTorch's autograd to calculate exact HVPs. This requires creating a graph for the gradient.
    • "forward": Use a forward finite difference formula to approximate the HVP. This requires one extra gradient evaluation.
    • "central": Use a central finite difference formula for a more accurate HVP approximation. This requires two extra gradient evaluations.

    Defaults to "autograd".

  • fd_h (float, default: 0.001 ) –

    finite difference step size if hvp_method is "forward" or "central". Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of hessian-vector products with random vectors to evaluate each time when updating the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.

  • seed (int | None, default: None ) –

    seed for random vectors. Defaults to None.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.
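The "forward" and "central" options replace the exact Hessian-vector product with differences of gradients. A minimal sketch of the two standard formulas on flat lists, where grad is a hypothetical gradient function; torchzero's Hvp helper is called with normalize=True, which may rescale the finite-difference step and is not modeled here.

```python
def hvp_forward(grad, x, v, h=1e-3, g0=None):
    """(grad(x + h*v) - grad(x)) / h; reuses g0 = grad(x) if already known."""
    if g0 is None:
        g0 = grad(x)
    g1 = grad([xi + h * vi for xi, vi in zip(x, v)])
    return [(a - b) / h for a, b in zip(g1, g0)]


def hvp_central(grad, x, v, h=1e-3):
    """(grad(x + h*v) - grad(x - h*v)) / (2h); two extra gradient calls."""
    gp = grad([xi + h * vi for xi, vi in zip(x, v)])
    gm = grad([xi - h * vi for xi, vi in zip(x, v)])
    return [(a - b) / (2 * h) for a, b in zip(gp, gm)]


# f(x) = x0^3 + x0 * x1, so grad f = [3 x0^2 + x1, x0] and
# H = [[6 x0, 1], [1, 0]]; at x = [1, 2] and v = [1, 0], H v = [6, 1]
grad = lambda x: [3 * x[0] ** 2 + x[1], x[0]]
x, v = [1.0, 2.0], [1.0, 0.0]
fwd = hvp_forward(grad, x, v)
cen = hvp_central(grad, x, v)
```

On this cubic the forward formula is off by 3h in the first component, while the central formula cancels that error term, which is why "central" is more accurate at the cost of one more gradient evaluation.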

Examples:

Using SophiaH:

opt = tz.Modular(
    model.parameters(),
    tz.m.SophiaH(),
    tz.m.LR(0.1)
)

The SophiaH preconditioner can be applied to any other module by passing that module to the inner argument. Set beta1=0 to turn off SophiaH's first momentum and get just the preconditioning. Here is an example of applying SophiaH preconditioning to Nesterov momentum (tz.m.NAG):

opt = tz.Modular(
    model.parameters(),
    tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
    tz.m.LR(0.1)
)
Source code in torchzero/modules/adaptive/sophia_h.py
class SophiaH(Module):
    """SophiaH optimizer from https://arxiv.org/abs/2305.14342

    This is similar to Adam, but the second moment is replaced by an exponential moving average of randomized Hessian diagonal estimates, and the update is aggressively clipped.

    .. note::
        In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply SophiaH preconditioning to another module's output.

    .. note::
        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".

    .. note::
        This module requires a closure to be passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients for calculating HVPs.
        The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        beta1 (float, optional): first momentum. Defaults to 0.96.
        beta2 (float, optional): momentum for hessian diagonal estimate. Defaults to 0.99.
        update_freq (int, optional):
            frequency of updating hessian diagonal estimate via a hessian-vector product. Defaults to 10.
        precond_scale (float, optional):
            scale of the preconditioner. Defaults to 1.
        clip (float, optional):
            clips update to (-clip, clip). Defaults to 1.
        eps (float, optional):
            clips the Hessian diagonal estimate to be no less than this value. Defaults to 1e-12.
        hvp_method (str, optional):
            Determines how Hessian-vector products are evaluated.

            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
              This requires creating a graph for the gradient.
            - ``"forward"``: Use a forward finite difference formula to
              approximate the HVP. This requires one extra gradient evaluation.
            - ``"central"``: Use a central finite difference formula for a
              more accurate HVP approximation. This requires two extra
              gradient evaluations.
            Defaults to "autograd".
        fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
        n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
        seed (int | None, optional): seed for random vectors. Defaults to None.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    Examples:
        Using SophiaH:

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.SophiaH(),
                tz.m.LR(0.1)
            )

        SophiaH preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
        Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying
        SophiaH preconditioning to nesterov momentum (:code:`tz.m.NAG`):

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
                tz.m.LR(0.1)
            )

    """
    def __init__(
        self,
        beta1: float = 0.96,
        beta2: float = 0.99,
        update_freq: int = 10,
        precond_scale: float = 1,
        clip: float = 1,
        eps: float = 1e-12,
        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
        fd_h: float = 1e-3,
        n_samples = 1,
        seed: int | None = None,
        inner: Chainable | None = None
    ):
        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, precond_scale=precond_scale, clip=clip, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
        super().__init__(defaults)

        if inner is not None:
            self.set_child('inner', inner)

    @torch.no_grad
    def step(self, var):
        params = var.params
        settings = self.settings[params[0]]
        hvp_method = settings['hvp_method']
        fd_h = settings['fd_h']
        update_freq = settings['update_freq']
        n_samples = settings['n_samples']

        seed = settings['seed']
        generator = None
        if seed is not None:
            if 'generator' not in self.global_state:
                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
            generator = self.global_state['generator']

        beta1, beta2, precond_scale, clip, eps = self.get_settings(params,
            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', cls=NumberList)

        exp_avg, h_exp_avg = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        closure = var.closure
        assert closure is not None

        h = None
        if step % update_freq == 0:

            rgrad=None
            for i in range(n_samples):
                u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]

                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
                                     h=fd_h, normalize=True, retain_grad=i < n_samples-1)
                Hvp = tuple(Hvp)

                if h is None: h = Hvp
                else: torch._foreach_add_(h, Hvp)

            assert h is not None
            if n_samples > 1: torch._foreach_div_(h, n_samples)

        update = var.get_update()
        if 'inner' in self.children:
            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)

        var.update = sophia_H(
            tensors=TensorList(update),
            h=TensorList(h) if h is not None else None,
            exp_avg_=exp_avg,
            h_exp_avg_=h_exp_avg,
            beta1=beta1,
            beta2=beta2,
            update_freq=update_freq,
            precond_scale=precond_scale,
            clip=clip,
            eps=eps,
            step=step,
        )
        return var

Split

Bases: torchzero.core.module.Module

Apply `true` modules to all parameters selected by `filter`, and `false` modules to all other parameters.

Parameters:

  • filter (Filter) –

    a filter that selects the tensors to be optimized by `true`. It can be:
    - a tensor or an iterable of tensors (e.g. `encoder.parameters()`);
    - a function that takes a tensor and returns a bool (e.g. `lambda x: x.ndim >= 2`);
    - a sequence of the above, combined with "or" (a tensor is selected if any of them selects it).

  • true (Chainable | None) –

    modules that are applied to tensors where filter is True.

  • false (Chainable | None) –

    modules that are applied to tensors where filter is False.

Examples:

Muon with an Adam fallback, using the same hyperparameters as https://github.com/KellerJordan/Muon:

opt = tz.Modular(
    model.parameters(),
    tz.m.NAG(0.95),
    tz.m.Split(
        lambda p: p.ndim >= 2,
        true = tz.m.Orthogonalize(),
        false = [tz.m.Adam(0.9, 0.95), tz.m.Mul(1/66)],
    ),
    tz.m.LR(1e-2),
)
Source code in torchzero/modules/misc/split.py
class Split(Module):
    """Apply ``true`` modules to all parameters filtered by ``filter``, apply ``false`` modules to all other parameters.

    Args:
        filter (Filter):
            a filter that selects tensors to be optimized by ``true``.
            - tensor or iterable of tensors (e.g. ``encoder.parameters()``).
            - function that takes in tensor and outputs a bool (e.g. ``lambda x: x.ndim >= 2``).
            - a sequence of above (acts as "or", so returns true if any of them is true).

        true (Chainable | None): modules that are applied to tensors where ``filter`` is ``True``.
        false (Chainable | None): modules that are applied to tensors where ``filter`` is ``False``.

    ### Examples:

    Muon with Adam fallback using same hyperparams as https://github.com/KellerJordan/Muon

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.NAG(0.95),
        tz.m.Split(
            lambda p: p.ndim >= 2,
            true = tz.m.Orthogonalize(),
            false = [tz.m.Adam(0.9, 0.95), tz.m.Mul(1/66)],
        ),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(self, filter: Filter, true: Chainable | None, false: Chainable | None):
        defaults = dict(filter=filter)
        super().__init__(defaults)

        if true is not None: self.set_child('true', true)
        if false is not None: self.set_child('false', false)

    def step(self, var):

        params = var.params
        filter = _make_filter(self.settings[params[0]]['filter'])

        true_idxs = []
        false_idxs = []
        for i,p in enumerate(params):
            if filter(p): true_idxs.append(i)
            else: false_idxs.append(i)

        if 'true' in self.children and len(true_idxs) > 0:
            true = self.children['true']
            var = _split(true, idxs=true_idxs, params=params, var=var)

        if 'false' in self.children and len(false_idxs) > 0:
            false = self.children['false']
            var = _split(false, idxs=false_idxs, params=params, var=var)

        return var
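The index partitioning in `Split.step` can be sketched in plain Python. This is a minimal illustration, not torchzero's `_make_filter`: `make_filter_sketch` is a hypothetical helper that handles the three documented filter kinds (a parameter matched by identity, a predicate, or an iterable of either combined with "or"), and `FakeParam` is a hypothetical stand-in for a tensor with an `ndim` attribute.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(eq=False)  # eq=False keeps default identity comparison, as for tensors
class FakeParam:
    """Hypothetical stand-in for a tensor; only `ndim` is used."""
    ndim: int


def make_filter_sketch(spec: Any) -> Callable[[Any], bool]:
    """Hypothetical filter factory; not torchzero's `_make_filter`."""
    if callable(spec):
        # A predicate such as `lambda p: p.ndim >= 2` is used as-is.
        return spec
    if isinstance(spec, FakeParam):
        # A single parameter selects only itself (identity, not value equality).
        return lambda p: p is spec
    # Any other iterable: every element is itself a filter, combined with "or".
    filters = [make_filter_sketch(s) for s in spec]
    return lambda p: any(f(p) for f in filters)


def split_indices(params: list, spec: Any) -> tuple[list[int], list[int]]:
    """Partition parameter indices into (selected, not selected), like Split.step."""
    keep = make_filter_sketch(spec)
    true_idxs, false_idxs = [], []
    for i, p in enumerate(params):
        (true_idxs if keep(p) else false_idxs).append(i)
    return true_idxs, false_idxs


params = [FakeParam(2), FakeParam(1), FakeParam(4)]
print(split_indices(params, lambda p: p.ndim >= 2))  # ([0, 2], [1])
print(split_indices(params, [params[1]]))            # ([1], [0, 2])
```

Treating an iterable of tensors as a sequence of single-tensor filters lets the two documented forms share one code path in this sketch.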

Sqrt

Bases: torchzero.core.transform.Transform

Returns `sqrt(input)`.

Source code in torchzero/modules/ops/unary.py
class Sqrt(Transform):
    """Returns :code:`sqrt(input)`"""
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        torch._foreach_sqrt_(tensors)
        return tensors

SqrtEMASquared

Bases: torchzero.core.transform.Transform

Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

Parameters:

  • beta (float, default: 0.999 ) –

    momentum value. Defaults to 0.999.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • debiased (bool, default: False ) –

    whether to multiply the output by a debiasing term from the Adam method. Defaults to False.

  • pow (float, default: 2 ) –

    exponent applied to the absolute value of the updates before averaging; the output is the `pow`-th root of the average. Defaults to 2.

Methods:

  • SQRT_EMA_SQ_FN

    Updates `exp_avg_sq_` with an EMA of squared tensors and calculates its square root, with optional AMSGrad and debiasing.

Source code in torchzero/modules/ops/higher_level.py
class SqrtEMASquared(Transform):
    """Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
    def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
        super().__init__(defaults, uses_grad=False)


    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return self.SQRT_EMA_SQ_FN(
            TensorList(tensors),
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            max_exp_avg_sq_=max_exp_avg_sq,
            debiased=debiased,
            step=step,
            pow=pow,
        )

SQRT_EMA_SQ_FN

SQRT_EMA_SQ_FN(tensors: TensorList, exp_avg_sq_: TensorList, beta: float | NumberList, max_exp_avg_sq_: TensorList | None, debiased: bool, step: int, pow: float = 2, ema_sq_fn: Callable = ema_sq_)

Updates `exp_avg_sq_` with an EMA of squared `tensors` and calculates its square root, with optional AMSGrad and debiasing.

Returns new tensors.

Source code in torchzero/modules/functional.py
def sqrt_ema_sq_(
    tensors: TensorList,
    exp_avg_sq_: TensorList,
    beta: float | NumberList,
    max_exp_avg_sq_: TensorList | None,
    debiased: bool,
    step: int,
    pow: float = 2,
    ema_sq_fn: Callable = ema_sq_,
):
    """
    Updates `exp_avg_sq_` with EMA of squared `tensors` and calculates its square root,
    with optional AMSGrad and debiasing.

    Returns new tensors.
    """
    exp_avg_sq_=ema_sq_fn(
        tensors=tensors,
        exp_avg_sq_=exp_avg_sq_,
        beta=beta,
        max_exp_avg_sq_=max_exp_avg_sq_,
        pow=pow,
    )

    sqrt_exp_avg_sq = root(exp_avg_sq_, pow, inplace=False)

    if debiased: sqrt_exp_avg_sq = debias_second_momentum(sqrt_exp_avg_sq, step=step, beta=beta, pow=pow, inplace=True)
    return sqrt_exp_avg_sq
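The arithmetic of `sqrt_ema_sq_` can be sketched element-wise in plain Python. The bodies of `ema_sq_` and `debias_second_momentum` are not shown on this page, so the sketch assumes the usual Adam-style forms: `v = beta*v + (1-beta)*|x|**pow`, an AMSGrad running maximum of `v`, and debiasing by dividing the root by `(1 - beta**step)**(1/pow)`. The function name `sqrt_ema_sq_sketch` is hypothetical.

```python
import math
from typing import Optional


def sqrt_ema_sq_sketch(
    x: list[float],
    exp_avg_sq: list[float],
    beta: float,
    step: int,
    pow: float = 2,
    max_exp_avg_sq: Optional[list[float]] = None,
    debiased: bool = False,
) -> list[float]:
    """Element-wise sketch; exp_avg_sq (and max_exp_avg_sq) are updated in place."""
    out = []
    for i, xi in enumerate(x):
        # Assumed EMA of |x|**pow (the absolute value is always taken).
        exp_avg_sq[i] = beta * exp_avg_sq[i] + (1 - beta) * abs(xi) ** pow
        v = exp_avg_sq[i]
        if max_exp_avg_sq is not None:
            # AMSGrad: keep the running maximum and use it instead of the EMA.
            max_exp_avg_sq[i] = max(max_exp_avg_sq[i], v)
            v = max_exp_avg_sq[i]
        r = v ** (1 / pow)  # pow-th root; the square root when pow == 2
        if debiased:
            # Assumed Adam-style correction for the zero-initialised EMA.
            r /= (1 - beta ** step) ** (1 / pow)
        out.append(r)
    return out


state = [0.0, 0.0]
print(sqrt_ema_sq_sketch([2.0, -3.0], state, beta=0.5, step=1, debiased=True))
# approximately [2.0, 3.0]: on the first step the debiased output equals |x|
```

The first-step identity in the usage line is a quick sanity check of the assumed debiasing: with a zero-initialised EMA, `v = (1-beta)*|x|**pow`, and dividing its root by `(1-beta)**(1/pow)` recovers `|x|`.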

SqrtHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class SqrtHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return (loss+1e-12).sqrt()

SquareHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class SquareHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return loss.square().copysign(loss)
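The two `loss_transform` methods above are scalar maps. The plain-Python sketch below mirrors them with `math` instead of tensor ops (function names are hypothetical). Both maps are increasing on their domain (`loss >= -1e-12` for the square root, all reals for the signed square), so they rescale the loss without changing which of two points has the lower value. How `HomotopyBase` uses the transformed loss is not shown here.

```python
import math


def sqrt_transform(loss: float) -> float:
    # Mirrors SqrtHomotopy: the small offset keeps the gradient of sqrt finite at loss == 0.
    return math.sqrt(loss + 1e-12)


def signed_square_transform(loss: float) -> float:
    # Mirrors SquareHomotopy: square the loss but keep its sign, so negative losses stay negative.
    return math.copysign(loss * loss, loss)


losses = [-2.0, -0.5, 0.0, 0.5, 2.0]
print([signed_square_transform(f) for f in losses])        # [-4.0, -0.25, 0.0, 0.25, 4.0]
print([round(sqrt_transform(f), 6) for f in losses[2:]])   # [1e-06, 0.707107, 1.414214]
```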

StepSize

Bases: torchzero.core.transform.Transform

This is exactly the same as `LR`, except that the `lr` parameter can be renamed (via `key`) to any other name to avoid name clashes.

Source code in torchzero/modules/step_size/lr.py
class StepSize(Transform):
    """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
    def __init__(self, step_size: float, key = 'step_size'):
        defaults={"key": key, key: step_size}
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
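The renaming works because the step size is stored in the settings under whatever name `key` holds, and `apply_tensors` looks it up indirectly via `s[s['key']]`. A plain-Python sketch of that lookup follows. That `lazy_lr` multiplies the update by the step size is an assumption (its source is not shown here), and the key name `"alpha"` is purely illustrative.

```python
def make_defaults(step_size: float, key: str = "step_size") -> dict:
    # Same layout as StepSize.__init__: the chosen key name is itself stored under "key".
    return {"key": key, key: step_size}


def apply_step_size(tensors: list[float], settings: dict) -> list[float]:
    # Indirect lookup, as in StepSize.apply_tensors: settings[settings["key"]].
    lr = settings[settings["key"]]
    # Assumption: lazy_lr scales the update by lr.
    return [t * lr for t in tensors]


settings = make_defaults(0.1, key="alpha")  # "alpha" is a hypothetical key name
print(settings)                               # {'key': 'alpha', 'alpha': 0.1}
print(apply_step_size([1.0, -2.0], settings))  # [0.1, -0.2]
```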

StrongWolfe

Bases: torchzero.modules.line_search.line_search.LineSearchBase

Interpolation line search satisfying the strong Wolfe conditions.

Parameters:

  • c1 (float, default: 0.0001 ) –

    sufficient decrease (Armijo) condition parameter. Defaults to 1e-4.

  • c2 (float, default: 0.9 ) –

    strong curvature condition parameter. For conjugate gradient, set it to 0.1. Defaults to 0.9.

  • a_init (str, default: 'fixed' ) –

    strategy for choosing the initial step size guess:
    - "fixed" - uses the fixed value given by the init_value argument.
    - "first-order" - assumes that the first-order change in the function at the current iterate will be the same as that obtained at the previous step.
    - "quadratic" - interpolates a quadratic to f(x_{-1}) and f(x).
    - "quadratic-clip" - same as "quadratic", but uses min(1, 1.01*alpha) as described in Numerical Optimization.
    - "previous" - uses the final step size found on the previous iteration.

    For second-order methods it is usually best to leave this at "fixed". For methods that do not produce well-scaled search directions, e.g. conjugate gradient, "first-order" or "quadratic-clip" are recommended. Defaults to 'fixed'.

  • a_max (float, default: 1000000000000.0 ) –

    upper bound for the proposed step sizes. Defaults to 1e12.

  • init_value (float, default: 1 ) –

    initial step size. Used when a_init="fixed", and as a fallback value for the other strategies. Defaults to 1.

  • maxiter (int, default: 25 ) –

    maximum number of line search iterations. Defaults to 25.

  • maxzoom (int, default: 10 ) –

    maximum number of zoom iterations. Defaults to 10.

  • maxeval (int | None, default: None ) –

    maximum number of function evaluations. Defaults to None.

  • tol_change (float, default: 1e-09 ) –

    tolerance, terminates on small brackets. Defaults to 1e-9.

  • interpolation (str, default: 'cubic' ) –

    what type of interpolation to use:
    - "bisection" - uses the middle point. This is robust, especially if the objective function is non-smooth, but it may need more function evaluations.
    - "quadratic" - minimizes a quadratic model; generally outperformed by "cubic".
    - "cubic" - minimizes a cubic model; this is the most widely used interpolation strategy.
    - "polynomial" - fits a polynomial to all points obtained during line search.
    - "polynomial2" - alternative polynomial fit, where if a point is outside of bounds, a lower-degree polynomial is tried. This may converge faster than "cubic" and "polynomial".

    Defaults to 'cubic'.

  • adaptive (bool, default: True ) –

    if True, the initial step size is halved whenever line search fails to find a suitable step size; once a suitable step size is found, the initial step size is reset to the original value. Defaults to True.

  • fallback (bool, default: False ) –

    if True and no point satisfies the strong Wolfe conditions, returns the evaluated point with the lowest value, provided that value is lower than the initial value, even though it does not satisfy the conditions. Defaults to False.

  • plus_minus (bool, default: False ) –

    if True, enables the plus-minus variant: if the directional derivative along the update is positive (the update is not a descent direction), line search is performed in the opposite direction. Defaults to False.
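For reference, a step `a` along a descent direction satisfies the strong Wolfe conditions when `f(a) <= f(0) + c1*a*g(0)` (sufficient decrease) and `|g(a)| <= c2*|g(0)|` (strong curvature), where `g` is the directional derivative. The sketch below checks both for a one-dimensional function using the documented defaults `c1=1e-4` and `c2=0.9`. It is a standalone illustration with a hypothetical helper name, not torchzero's internal `_StrongWolfe` class.

```python
from typing import Callable


def satisfies_strong_wolfe(
    phi: Callable[[float], tuple[float, float]],
    a: float,
    c1: float = 1e-4,
    c2: float = 0.9,
) -> bool:
    """phi(a) returns (f(x + a*d), directional derivative of f along d at a)."""
    f_0, g_0 = phi(0.0)
    f_a, g_a = phi(a)
    sufficient_decrease = f_a <= f_0 + c1 * a * g_0
    strong_curvature = abs(g_a) <= c2 * abs(g_0)
    return sufficient_decrease and strong_curvature


def phi(a: float) -> tuple[float, float]:
    # 1-D quadratic along the search direction; exact minimiser at a = 1.
    return (a - 1.0) ** 2, 2.0 * (a - 1.0)


print(satisfies_strong_wolfe(phi, 1.0))   # True: exact minimiser
print(satisfies_strong_wolfe(phi, 0.01))  # False: slope has barely changed
print(satisfies_strong_wolfe(phi, 2.5))   # False: the function value increased
```

A smaller `c2` such as the 0.1 recommended for conjugate gradient tightens the curvature test, so only steps closer to the exact minimiser along the line are accepted.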

Examples:

Conjugate gradient method with strong Wolfe line search. Nocedal and Wright recommend setting c2 to 0.1 for CG. Since CG doesn't produce well-scaled directions, the initial step size can be estimated from the previous step with a_init="first-order".

opt = tz.Modular(
    model.parameters(),
    tz.m.PolakRibiere(),
    tz.m.StrongWolfe(c2=0.1, a_init="first-order")
)

L-BFGS with strong Wolfe line search:

opt = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(),
    tz.m.StrongWolfe()
)
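The default `interpolation="cubic"` fits a cubic to the function values and directional derivatives at two trial steps and moves to its minimiser. Below is a sketch of the standard two-point cubic interpolation formula described in Nocedal and Wright's Numerical Optimization. The function name is hypothetical, and torchzero's own implementation may add safeguards (such as clamping to the bracket or handling a zero denominator) that are omitted here.

```python
import math


def cubic_step(a0: float, f0: float, g0: float, a1: float, f1: float, g1: float) -> float:
    """Minimiser of the cubic matching value and slope at a0 and a1.

    Raises ValueError when the cubic has no real minimiser (negative discriminant).
    """
    d1 = g0 + g1 - 3.0 * (f0 - f1) / (a0 - a1)
    disc = d1 * d1 - g0 * g1
    if disc < 0:
        raise ValueError("cubic has no real minimiser; fall back to e.g. bisection")
    d2 = math.copysign(math.sqrt(disc), a1 - a0)
    return a1 - (a1 - a0) * (g1 + d2 - d1) / (g1 - g0 + 2.0 * d2)


# f(a) = (a - 1)**2 sampled at a = 0 and a = 3: the interpolated minimiser is exact.
print(cubic_step(0.0, 1.0, -2.0, 3.0, 4.0, 4.0))  # 1.0
```

Because a cubic can represent any quadratic or cubic exactly, the step lands on the true minimiser for such functions; for general objectives it is one iteration of the bracketing search.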

Source code in torchzero/modules/line_search/strong_wolfe.py
class StrongWolfe(LineSearchBase):
    """Interpolation line search satisfying Strong Wolfe condition.

    Args:
        c1 (float, optional): sufficient descent condition. Defaults to 1e-4.
        c2 (float, optional): strong curvature condition. For CG set to 0.1. Defaults to 0.9.
        a_init (str, optional):
            strategy for initializing the initial step size guess.
            - "fixed" - uses a fixed value specified in `init_value` argument.
            - "first-order" - assumes first-order change in the function at iterate will be the same as that obtained at the previous step.
            - "quadratic" - interpolates quadratic to f(x_{-1}) and f_x.
            - "quadratic-clip" - same as "quadratic", but uses min(1, 1.01*alpha) as described in Numerical Optimization.
            - "previous" - uses final step size found on previous iteration.

            For 2nd order methods it is usually best to leave at "fixed".
            For methods that do not produce well scaled search directions, e.g. conjugate gradient,
            "first-order" or "quadratic-clip" are recommended. Defaults to 'fixed'.
        a_max (float, optional): upper bound for the proposed step sizes. Defaults to 1e12.
        init_value (float, optional):
            initial step size. Used when ``a_init``="fixed", and with other strategies as fallback value. Defaults to 1.
        maxiter (int, optional): maximum number of line search iterations. Defaults to 25.
        maxzoom (int, optional): maximum number of zoom iterations. Defaults to 10.
        maxeval (int | None, optional): maximum number of function evaluations. Defaults to None.
        tol_change (float, optional): tolerance, terminates on small brackets. Defaults to 1e-9.
        interpolation (str, optional):
            What type of interpolation to use.
            - "bisection" - uses the middle point. This is robust, especially if the objective function is non-smooth, however it may need more function evaluations.
            - "quadratic" - minimizes a quadratic model, generally outperformed by "cubic".
            - "cubic" - minimizes a cubic model - this is the most widely used interpolation strategy.
            - "polynomial" - fits a polynomial to all points obtained during line search.
            - "polynomial2" - alternative polynomial fit, where if a point is outside of bounds, a lower degree polynomial is tried.
            This may have faster convergence than "cubic" and "polynomial".

            Defaults to 'cubic'.
        adaptive (bool, optional):
            if True, the initial step size is halved whenever line search fails to find a suitable step size.
            Once a suitable step size is found, the initial step size is reset to the original value. Defaults to True.
        fallback (bool, optional):
            if True and no point satisfies the strong Wolfe conditions,
            returns the evaluated point with the lowest value if it is lower than the initial value. Defaults to False.
        plus_minus (bool, optional):
            if True, enables the plus-minus variant: if the directional derivative along the update is positive
            (the update is not a descent direction), line search is performed in the opposite direction. Defaults to False.


    ## Examples:

    Conjugate gradient method with strong Wolfe line search. Nocedal and Wright recommend setting c2 to 0.1 for CG. Since CG doesn't produce well-scaled directions, the initial step size can be estimated from the previous step with ``a_init="first-order"``.

    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.PolakRibiere(),
        tz.m.StrongWolfe(c2=0.1, a_init="first-order")
    )
    ```

    L-BFGS with strong Wolfe line search:
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.LBFGS(),
        tz.m.StrongWolfe()
    )
    ```

    """
    def __init__(
        self,
        c1: float = 1e-4,
        c2: float = 0.9,
        a_init: Literal['first-order', 'quadratic', 'quadratic-clip', 'previous', 'fixed'] = 'fixed',
        a_max: float = 1e12,
        init_value: float = 1,
        maxiter: int = 25,
        maxzoom: int = 10,
        maxeval: int | None = None,
        tol_change: float = 1e-9,
        interpolation: Literal["quadratic", "cubic", "bisection", "polynomial", 'polynomial2'] = 'cubic',
        adaptive = True,
        fallback:bool = False,
        plus_minus = False,
    ):
        defaults=dict(init_value=init_value,init=a_init,a_max=a_max,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom, fallback=fallback,
                      maxeval=maxeval, adaptive=adaptive, interpolation=interpolation, plus_minus=plus_minus, tol_change=tol_change)
        super().__init__(defaults=defaults)

        self.global_state['initial_scale'] = 1.0

    @torch.no_grad
    def search(self, update, var):
        self._g_prev = self._f_prev = None
        objective = self.make_objective_with_derivative(var=var)

        init_value, init, c1, c2, a_max, maxiter, maxzoom, maxeval, interpolation, adaptive, plus_minus, fallback, tol_change = itemgetter(
            'init_value', 'init', 'c1', 'c2', 'a_max', 'maxiter', 'maxzoom',
            'maxeval', 'interpolation', 'adaptive', 'plus_minus', 'fallback', 'tol_change')(self.defaults)

        dir = as_tensorlist(var.get_update())
        grad_list = var.get_grad()

        g_0 = -sum(t.sum() for t in torch._foreach_mul(grad_list, dir))
        f_0 = var.get_loss(False)
        dir_norm = dir.global_vector_norm()

        inverted = False
        if plus_minus and g_0 > 0:
            original_objective = objective
            def inverted_objective(a):
                l, g_a = original_objective(-a)
                return l, -g_a
            objective = inverted_objective
            inverted = True

        # --------------------- determine initial step size guess -------------------- #
        init = init.lower().strip()

        a_init = init_value
        if init == 'fixed':
            pass # use init_value

        elif init == 'previous':
            if 'a_prev' in self.global_state:
                a_init = self.global_state['a_prev']

        elif init == 'first-order':
            if 'g_prev' in self.global_state and g_0 < -torch.finfo(dir[0].dtype).tiny * 2:
                a_prev = self.global_state['a_prev']
                g_prev = self.global_state['g_prev']
                if g_prev < 0:
                    a_init = a_prev * g_prev / g_0

        elif init in ('quadratic', 'quadratic-clip'):
            if 'f_prev' in self.global_state and g_0 < -torch.finfo(dir[0].dtype).tiny * 2:
                f_prev = self.global_state['f_prev']
                if f_0 < f_prev:
                    a_init = 2 * (f_0 - f_prev) / g_0
                    if init == 'quadratic-clip': a_init = min(1, 1.01*a_init)
        else:
            raise ValueError(init)

        if adaptive:
            a_init *= self.global_state.get('initial_scale', 1)

        strong_wolfe = _StrongWolfe(
            f=objective,
            f_0=f_0,
            g_0=g_0,
            d_norm=dir_norm,
            a_init=a_init,
            a_max=a_max,
            c1=c1,
            c2=c2,
            maxiter=maxiter,
            maxzoom=maxzoom,
            maxeval=maxeval,
            tol_change=tol_change,
            interpolation=interpolation,
        )

        a, f_a, g_a = strong_wolfe.search()
        if inverted and a is not None: a = -a
        if f_a is not None and (f_a > f_0 or not math.isfinite(f_a)): a = None

        if fallback:
            if a is None or a==0 or not math.isfinite(a):
                lowest = min(strong_wolfe.history.items(), key=lambda x: x[1][0])
                if lowest[1][0] < f_0:
                    a = lowest[0]
                    f_a, g_a = lowest[1]
                    if inverted: a = -a

        if a is not None and a != 0 and math.isfinite(a):
            self.global_state['initial_scale'] = 1
            self.global_state['a_prev'] = a
            self.global_state['f_prev'] = f_0
            self.global_state['g_prev'] = g_0
            return a

        # fail
        if adaptive:
            self.global_state['initial_scale'] = self.global_state.get('initial_scale', 1) * 0.5
            finfo = torch.finfo(dir[0].dtype)
            if self.global_state['initial_scale'] < finfo.tiny * 2:
                self.global_state['initial_scale'] = finfo.max / 2

        return 0
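The initial-guess strategies in `search` reduce to a few scalar formulas. The standalone sketch below restates that branch with plain floats: `g_0` is the directional derivative at the current point (negative for a descent direction), and `prev` holds `a_prev`, `g_prev` and `f_prev` from the previous successful search. The function name is hypothetical, and `tiny` assumes float64 (`sys.float_info.min`, the same value as `torch.finfo(torch.float64).tiny`).

```python
import sys


def initial_step(
    init: str,
    init_value: float,
    f_0: float,
    g_0: float,
    prev: dict,
    tiny: float = sys.float_info.min,
) -> float:
    init = init.lower().strip()
    a_init = init_value  # used by "fixed" and as the fallback for every other strategy
    descent = g_0 < -tiny * 2
    if init == "previous":
        a_init = prev.get("a_prev", a_init)
    elif init == "first-order":
        # Assume a_init * g_0 == a_prev * g_prev (same first-order change as last step).
        if "g_prev" in prev and descent and prev["g_prev"] < 0:
            a_init = prev["a_prev"] * prev["g_prev"] / g_0
    elif init in ("quadratic", "quadratic-clip"):
        # Quadratic estimate 2 * (f_0 - f_prev) / g_0 from Numerical Optimization.
        if "f_prev" in prev and descent and f_0 < prev["f_prev"]:
            a_init = 2 * (f_0 - prev["f_prev"]) / g_0
            if init == "quadratic-clip":
                a_init = min(1.0, 1.01 * a_init)
    elif init != "fixed":
        raise ValueError(init)
    return a_init


prev = {"a_prev": 0.5, "g_prev": -4.0, "f_prev": 10.0}
for strategy in ("fixed", "previous", "first-order", "quadratic", "quadratic-clip"):
    print(strategy, initial_step(strategy, 1.0, f_0=9.5, g_0=-0.5, prev=prev))
# fixed 1.0, previous 0.5, first-order 4.0, quadratic 2.0, quadratic-clip 1.0
```

In the method itself, the result is then multiplied by the adaptive `initial_scale` when `adaptive=True`.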

Sub

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Subtract `other` from tensors. `other` can be a number or a module.

If `other` is a module, this calculates `tensors - alpha * other(tensors)`; if it is a number, `alpha * other` is subtracted from every element.

Source code in torchzero/modules/ops/binary.py
class Sub(BinaryOperationBase):
    """Subtract :code:`other` from tensors. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`tensors - other(tensors)`
    """
    def __init__(self, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.defaults['alpha'])
        else: torch._foreach_sub_(update, other, alpha=self.defaults['alpha'])
        return update
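Both branches of `Sub.transform` subtract `alpha * other`; when `other` is a module, `other` is the list of tensors that module produced. A plain-Python sketch with lists of floats in place of tensor lists (the name `sub_sketch` is hypothetical, and unlike `torch._foreach_sub_` it returns a new list rather than modifying `update` in place):

```python
def sub_sketch(update: list[float], other, alpha: float = 1.0) -> list[float]:
    if isinstance(other, (int, float)):
        # Number: subtract the same scalar, scaled by alpha, from every element.
        return [u - other * alpha for u in update]
    # List (output of a module): element-wise update - alpha * other.
    return [u - alpha * o for u, o in zip(update, other)]


print(sub_sketch([1.0, 2.0], 0.5))                    # [0.5, 1.5]
print(sub_sketch([1.0, 2.0], [1.0, 1.0], alpha=0.5))  # [0.5, 1.5]
```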

SubModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates `input - alpha * other`. `input` and `other` can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class SubModules(MultiOperationBase):
    """Calculates :code:`input - other`. :code:`input` and :code:`other` can be numbers or modules."""
    def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, input=input, other=other)

    @torch.no_grad
    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
        alpha = self.defaults['alpha']

        if isinstance(input, (int,float)):
            assert isinstance(other, list)
            return input - TensorList(other).mul_(alpha)

        if isinstance(other, (int, float)): torch._foreach_sub_(input, other * alpha)
        else: torch._foreach_sub_(input, other, alpha=alpha)
        return input

Sum

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs the sum of `inputs`, which can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class Sum(ReduceOperationBase):
    """Outputs sum of :code:`inputs` that can be modules or numbers."""
    USE_MEAN = False
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        sum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_add_(sum, v)

        if self.USE_MEAN and len(sorted_inputs) > 1: torch._foreach_div_(sum, len(sorted_inputs))
        return sum

USE_MEAN class-attribute

USE_MEAN = False

If True, the result is divided by the number of inputs, so the module outputs their mean instead of their sum.
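`Sum.transform` sorts its inputs so that a tensor list comes first, adds every other input (lists element-wise, numbers as scalars) into it, and divides by the number of inputs when `USE_MEAN` is set. A plain-Python sketch with lists of floats standing in for tensor lists; the class names `SumSketch` and `MeanSketch` are hypothetical.

```python
class SumSketch:
    USE_MEAN = False

    def reduce(self, *inputs):
        # Put list inputs first so the accumulator is a list, not a scalar.
        ordered = sorted(inputs, key=lambda x: isinstance(x, (int, float)))
        acc = list(ordered[0])
        for v in ordered[1:]:
            if isinstance(v, (int, float)):
                acc = [a + v for a in acc]               # scalar added to every element
            else:
                acc = [a + b for a, b in zip(acc, v)]    # element-wise addition
        if self.USE_MEAN and len(ordered) > 1:
            acc = [a / len(ordered) for a in acc]
        return acc


class MeanSketch(SumSketch):
    USE_MEAN = True  # same reduction, divided by the number of inputs


print(SumSketch().reduce(1.0, [1.0, 2.0], [3.0, 4.0]))   # [5.0, 7.0]
print(MeanSketch().reduce(1.0, [1.0, 2.0], [3.0, 4.0]))  # approximately [1.667, 2.333]
```

Flipping a single class attribute in a subclass is the same pattern the `USE_MEAN` flag suggests for deriving a mean module from `Sum`.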

SumOfSquares

Bases: torchzero.core.module.Module

Sets loss to be the sum of squares of values returned by the closure.

This is meant to be used to test least squares methods against ordinary minimization methods.

To use this, the closure should return the vector of values whose sum of squares is to be minimized. The closure must also accept a `backward` argument: SumOfSquares always calls it with `backward=False`, but the argument is still required.

Source code in torchzero/modules/least_squares/gn.py
class SumOfSquares(Module):
    """Sets loss to be the sum of squares of values returned by the closure.

    This is meant to be used to test least squares methods against ordinary minimization methods.

    To use this, the closure should return a vector of values to minimize sum of squares of.
    Please add the `backward` argument, it will always be False but it is required.
    """
    def __init__(self):
        super().__init__()

    @torch.no_grad
    def step(self, var):
        closure = var.closure

        if closure is not None:
            def sos_closure(backward=True):
                if backward:
                    var.zero_grad()
                    with torch.enable_grad():
                        loss = closure(False)
                        loss = loss.pow(2).sum()
                        loss.backward()
                    return loss

                loss = closure(False)
                return loss.pow(2).sum()

            var.closure = sos_closure

        if var.loss is not None:
            var.loss = var.loss.pow(2).sum()

        if var.loss_approx is not None:
            var.loss_approx = var.loss_approx.pow(2).sum()

        return var
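The closure handed to the optimizer must return the residual vector and accept a `backward` argument. The plain-Python sketch below mimics the wrapping without autograd, so no gradients are computed: `make_residual_closure` builds a hypothetical least-squares problem (fitting `y = w * x`), and `sum_of_squares` plays the role of `sos_closure` for the value computation only.

```python
def make_residual_closure(w: list[float], xs: list[float], ys: list[float]):
    # Hypothetical problem: residuals of the model y = w[0] * x.
    def residual_closure(backward: bool = True) -> list[float]:
        # SumOfSquares always calls this with backward=False; the argument must still exist.
        return [w[0] * x - y for x, y in zip(xs, ys)]
    return residual_closure


def sum_of_squares(closure):
    # Like sos_closure: call the residual closure with backward=False and sum the squares.
    def sos_closure(backward: bool = True) -> float:
        # `backward` is accepted for signature compatibility; this sketch has no autograd.
        return sum(r * r for r in closure(False))
    return sos_closure


w = [1.0]
loss_fn = sum_of_squares(make_residual_closure(w, xs=[1.0, 2.0], ys=[2.0, 4.0]))
print(loss_fn(False))  # 5.0: residuals are -1 and -2
w[0] = 2.0
print(loss_fn(False))  # 0.0 at the least-squares solution
```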

Switch

Bases: torchzero.modules.misc.switch.Alternate

After `steps` steps, switches to the next module.

Parameters:

  • steps (int | Iterable[int]) –

    Number of steps to perform with each module. If an iterable is given, it must contain one entry fewer than the number of modules.

Examples:

Start with Adam, switch to L-BFGS after the 1000th step, and switch to truncated Newton (NewtonCG) on the 2000th step.

opt = tz.Modular(
    model.parameters(),
    tz.m.Switch(
        [tz.m.Adam(), tz.m.LR(1e-3)],
        [tz.m.LBFGS(), tz.m.Backtracking()],
        [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
        steps = (1000, 2000)
    )
)
Source code in torchzero/modules/misc/switch.py
class Switch(Alternate):
    """After :code:`steps` steps switches to the next module.

    Args:
        steps (int | Iterable[int]): Number of steps to perform with each module.

    Examples:
        Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Switch(
                    [tz.m.Adam(), tz.m.LR(1e-3)],
                    [tz.m.LBFGS(), tz.m.Backtracking()],
                    [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
                    steps = (1000, 2000)
                )
            )
    """

    LOOP = False
    def __init__(self, *modules: Chainable, steps: int | Iterable[int]):

        if isinstance(steps, Iterable):
            steps = list(steps)
            if len(steps) != len(modules) - 1:
                raise ValueError(f"steps must be the same length as modules minus 1, got {len(modules) = }, {len(steps) = }")

            steps.append(1)

        super().__init__(*modules, steps=steps)

LOOP class-attribute

LOOP = False

TerminateAfterNEvaluations

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateAfterNEvaluations(TerminationCriteriaBase):
    def __init__(self, maxevals:int):
        defaults = dict(maxevals=maxevals)
        super().__init__(defaults)

    def termination_criteria(self, var):
        maxevals = self.defaults['maxevals']
        return var.modular.num_evaluations >= maxevals

TerminateAfterNSeconds

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateAfterNSeconds(TerminationCriteriaBase):
    def __init__(self, seconds:float, sec_fn = time.time):
        defaults = dict(seconds=seconds, sec_fn=sec_fn)
        super().__init__(defaults)

    def termination_criteria(self, var):
        max_seconds = self.defaults['seconds']
        sec_fn = self.defaults['sec_fn']

        if 'start' not in self.global_state:
            self.global_state['start'] = sec_fn()
            return False

        seconds_passed = sec_fn() - self.global_state['start']
        return seconds_passed >= max_seconds
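Because the clock is passed in as `sec_fn`, the criterion can be exercised deterministically. The sketch below mirrors `termination_criteria` as a plain class (`SecondsSketch` is a hypothetical name) driven by a fake clock; as in the source, the first call only records the start time and never terminates.

```python
import time


class SecondsSketch:
    def __init__(self, seconds: float, sec_fn=time.time):
        self.seconds = seconds
        self.sec_fn = sec_fn
        self.start = None

    def should_terminate(self) -> bool:
        if self.start is None:
            self.start = self.sec_fn()  # first call: start the timer
            return False
        return self.sec_fn() - self.start >= self.seconds


ticks = iter([0.0, 1.0, 2.5])  # hypothetical clock readings
crit = SecondsSketch(2.0, sec_fn=lambda: next(ticks))
print([crit.should_terminate() for _ in range(3)])  # [False, False, True]
```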

TerminateAfterNSteps

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateAfterNSteps(TerminationCriteriaBase):
    def __init__(self, steps:int):
        defaults = dict(steps=steps)
        super().__init__(defaults)

    def termination_criteria(self, var):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        max_steps = self.defaults['steps']
        return step >= max_steps

TerminateAll

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateAll(TerminationCriteriaBase):
    def __init__(self, *criteria: TerminationCriteriaBase):
        super().__init__()

        self.set_children_sequence(criteria)

    def termination_criteria(self, var: Var) -> bool:
        for c in self.get_children_sequence():
            if not cast(TerminationCriteriaBase, c).termination_criteria(var): return False

        return True

TerminateAny

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateAny(TerminationCriteriaBase):
    def __init__(self, *criteria: TerminationCriteriaBase):
        super().__init__()

        self.set_children_sequence(criteria)

    def termination_criteria(self, var: Var) -> bool:
        for c in self.get_children_sequence():
            if cast(TerminationCriteriaBase, c).termination_criteria(var): return True

        return False

TerminateByGradientNorm

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateByGradientNorm(TerminationCriteriaBase):
    def __init__(self, tol:float = 1e-8, n: int = 3, ord: Metrics = 2):
        defaults = dict(tol=tol, ord=ord)
        super().__init__(defaults, n=n)

    def termination_criteria(self, var):
        tol = self.defaults['tol']
        ord = self.defaults['ord']
        return TensorList(var.get_grad()).global_metric(ord) <= tol

TerminateByUpdateNorm

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

The update is computed as the difference between the parameters at consecutive steps.

Source code in torchzero/modules/termination/termination.py
class TerminateByUpdateNorm(TerminationCriteriaBase):
    """update is calculated as parameter difference"""
    def __init__(self, tol:float = 1e-8, n: int = 3, ord: Metrics = 2):
        defaults = dict(tol=tol, ord=ord)
        super().__init__(defaults, n=n)

    def termination_criteria(self, var):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        tol = self.defaults['tol']
        ord = self.defaults['ord']

        p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
        if step == 0:
            p_prev.copy_(var.params)
            return False

        should_terminate = (p_prev - var.params).global_metric(ord) <= tol
        p_prev.copy_(var.params)
        return should_terminate
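The stateful part of TerminateByUpdateNorm (remembering the previous parameters and never terminating on the first call) can be sketched in plain Python. This sketch assumes a single flat list of parameters and the Euclidean norm (ord=2); UpdateNormCriterion is a hypothetical name.

```python
import math


class UpdateNormCriterion:
    """Hypothetical stand-in for the TerminateByUpdateNorm check."""

    def __init__(self, tol=1e-8):
        self.tol = tol
        self.p_prev = None  # parameters seen on the previous call

    def __call__(self, params):
        if self.p_prev is None:
            # first call: there is no previous point to compare against yet
            self.p_prev = list(params)
            return False
        # Euclidean norm of the parameter difference
        diff = math.sqrt(sum((a - b) ** 2 for a, b in zip(self.p_prev, params)))
        self.p_prev = list(params)
        return diff <= self.tol
```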

TerminateNever

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateNever(TerminationCriteriaBase):
    def __init__(self):
        super().__init__()

    def termination_criteria(self, var): return False

TerminateOnLossReached

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateOnLossReached(TerminationCriteriaBase):
    def __init__(self, value: float):
        defaults = dict(value=value)
        super().__init__(defaults)

    def termination_criteria(self, var):
        value = self.defaults['value']
        return var.get_loss(False) <= value

TerminateOnNoImprovement

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateOnNoImprovement(TerminationCriteriaBase):
    def __init__(self, tol:float = 1e-8, n: int = 10):
        defaults = dict(tol=tol)
        super().__init__(defaults, n=n)

    def termination_criteria(self, var):
        tol = self.defaults['tol']

        f = tofloat(var.get_loss(False))
        if 'f_min' not in self.global_state:
            self.global_state['f_min'] = f
            return False

        f_min = self.global_state['f_min']
        d = f_min - f
        should_terminate = d <= tol
        self.global_state['f_min'] = min(f, f_min)
        return should_terminate
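TerminateOnNoImprovement compares each new loss with the best loss seen so far, not with the previous loss. A hypothetical plain-Python sketch of that bookkeeping (NoImprovementCriterion is an illustrative name):

```python
class NoImprovementCriterion:
    """Hypothetical stand-in for the TerminateOnNoImprovement check."""

    def __init__(self, tol=1e-8):
        self.tol = tol
        self.f_min = None  # best loss seen so far

    def __call__(self, f):
        if self.f_min is None:
            self.f_min = f
            return False
        # improvement is measured against the best loss, not the previous one
        improvement = self.f_min - f
        self.f_min = min(self.f_min, f)
        return improvement <= self.tol
```

Because the comparison is against f_min, a loss that goes up always counts as "no improvement".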

TerminationCriteriaBase

Bases: torchzero.core.module.Module

Source code in torchzero/modules/termination/termination.py
class TerminationCriteriaBase(Module):
    def __init__(self, defaults:dict | None = None, n: int = 1):
        if defaults is None: defaults = {}
        safe_dict_update_(defaults, {"_n": n})
        super().__init__(defaults)

    @abstractmethod
    def termination_criteria(self, var: Var) -> bool:
        ...

    def should_terminate(self, var: Var) -> bool:
        n_bad = self.global_state.get('_n_bad', 0)
        n = self.defaults['_n']

        if self.termination_criteria(var):
            n_bad += 1
            if n_bad >= n:
                self.global_state['_n_bad'] = 0
                return True

        else:
            n_bad = 0

        self.global_state['_n_bad'] = n_bad
        return False


    def update(self, var):
        var.should_terminate = self.should_terminate(var)
        if var.should_terminate: self.global_state['_n_bad'] = 0

    def apply(self, var):
        return var
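TerminationCriteriaBase.should_terminate only stops after termination_criteria has returned True on n consecutive steps; any step that fails the check resets the counter. A minimal sketch of that patience logic, with Patience as a hypothetical stand-in class that wraps a plain callable:

```python
class Patience:
    """Hypothetical sketch of the n-consecutive-hits logic in should_terminate."""

    def __init__(self, criterion, n=1):
        self.criterion = criterion  # callable returning True when "converged"
        self.n = n
        self.n_bad = 0

    def should_terminate(self, value):
        if self.criterion(value):
            self.n_bad += 1
            if self.n_bad >= self.n:
                self.n_bad = 0  # reset so a later run starts counting from zero
                return True
        else:
            self.n_bad = 0  # any miss breaks the streak
        return False
```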

ThomasOptimalMethod

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Thomas's "optimal" Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
    """
    Thomas's "optimal" Quasi-Newton method.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
        H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y)
        return H

    def reset_P(self, P, s, y, inverse, init_scale, state):
        super().reset_P(P, s, y, inverse, init_scale, state)
        for st in self.state.values():
            st.pop("R", None)

Threshold

Bases: torchzero.modules.ops.binary.BinaryOperationBase

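Outputs tensors thresholded against threshold. If update_above is True, values above threshold are kept and all other values are set to value; otherwise values below threshold are kept and all other values are set to value.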

Source code in torchzero/modules/ops/binary.py
class Threshold(BinaryOperationBase):
    """Outputs tensors thresholded such that values above :code:`threshold` are set to :code:`value`."""
    def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
        defaults = dict(update_above=update_above)
        super().__init__(defaults, threshold=threshold, value=value)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
        update_above = self.defaults['update_above']
        update = TensorList(update)
        if update_above:
            if isinstance(value, list): return update.where_(update>threshold, value)
            return update.masked_fill_(update<=threshold, value)

        if isinstance(value, list): return update.where_(update<threshold, value)
        return update.masked_fill_(update>=threshold, value)
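The branch on update_above decides which side of the threshold is kept. A plain-Python sketch of the same element-wise rule for a scalar threshold and value, with lists standing in for tensors (threshold_list is a hypothetical helper):

```python
def threshold_list(update, threshold, value, update_above):
    if update_above:
        # keep entries strictly above the threshold, replace the rest
        return [u if u > threshold else value for u in update]
    # keep entries strictly below the threshold, replace the rest
    return [u if u < threshold else value for u in update]
```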

To

Bases: torchzero.modules.projections.projection.ProjectionBase

Cast modules to specified device and dtype

Source code in torchzero/modules/projections/cast.py
class To(ProjectionBase):
    """Cast modules to specified device and dtype"""
    def __init__(self, modules: Chainable, dtype: torch.dtype | None, device:torch.types.Device | None = None):
        defaults = dict(dtype=dtype, device=device)
        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=defaults)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        casted = []
        for tensor, state, setting in zip(tensors,states, settings):
            state['dtype'] = tensor.dtype
            state['device'] = tensor.device
            tensor = tensor.to(dtype=setting['dtype'], device=setting['device'])
            casted.append(tensor)
        return casted

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        uncasted = []
        for tensor, state in zip(projected_tensors, states):
            tensor = tensor.to(dtype=state['dtype'], device=state['device'])
            uncasted.append(tensor)
        return uncasted

TrustCG

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Trust region via Steihaug-Toint Conjugate Gradient method.

Note

If you wish to use the exact Hessian, use the matrix-free tz.m.NewtonCGSteihaug, which only uses Hessian-vector products. Passing tz.m.Newton to this module is possible, but usually less efficient.

Parameters:

  • hess_module (Module | None) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • eta (float, default: 0.0 ) –

    if the ratio of actual to predicted reduction is larger than this, the step is accepted. When hess_module is GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by nplus. Defaults to 0.99.

  • rho_bad (float, default: 0.0001 ) –

    if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by nminus. Defaults to 1e-4.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • reg (float, default: 0 ) –

    regularization parameter for conjugate gradient. Defaults to 0.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • boundary_tol (float | None, default: 1e-06 ) –

    The trust region only increases when the suggested step's norm is at least (1-boundary_tol)*trust_region. This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.

  • prefer_exact (bool, default: True ) –

    when the exact solution can be easily calculated without CG (e.g. the hessian is stored as a scaled identity), uses the exact solution. If False, always uses CG. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.
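The acceptance and resizing parameters above combine into a radius update rule. The sketch below is one plausible reading of the parameter descriptions, not torchzero's actual default radius strategy (which is not shown on this page). Here rho is the ratio of actual to predicted reduction, and treating boundary_tol=None as "no boundary check" is an assumption. update_radius is a hypothetical name.

```python
def update_radius(rho, step_norm, radius, eta=0.0, nplus=3.5, nminus=0.25,
                  rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-6):
    """Return (new_radius, accepted) for one trust region attempt (hypothetical sketch)."""
    accepted = rho > eta  # step is accepted when the ratio exceeds eta
    if rho < rho_bad:
        # the quadratic model predicted the loss poorly: shrink the region
        return radius * nminus, accepted
    if rho > rho_good:
        # assumption: boundary_tol=None disables the boundary check
        on_boundary = boundary_tol is None or step_norm >= (1 - boundary_tol) * radius
        if on_boundary:
            # good agreement and the step was limited by the radius: expand
            return radius * nplus, accepted
    return radius, accepted
```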

Examples:

Trust-SR1

opt = tz.Modular(
    model.parameters(),
    tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
)
Source code in torchzero/modules/trust_region/trust_cg.py
class TrustCG(TrustRegionBase):
    """Trust region via Steihaug-Toint Conjugate Gradient method.

    .. note::

        If you wish to use exact hessian, use the matrix-free :code:`tz.m.NewtonCGSteihaug`
        which only uses hessian-vector products. While passing ``tz.m.Newton`` to this
        is possible, it is usually less efficient.

    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if the ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if the ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        reg (float, optional): regularization parameter for conjugate gradient. Defaults to 0.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        boundary_tol (float | None, optional):
            The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
            This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-6.
        prefer_exact (bool, optional):
            when exact solution can be easily calculated without CG (e.g. hessian is stored as scaled identity),
            uses the exact solution. If False, always uses CG. Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        Trust-SR1

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
            )
    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        boundary_tol: float | None = 1e-6, # tuned
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        reg: float = 0,
        maxiter: int | None = None,
        miniter: int = 1,
        cg_tol: float = 1e-8,
        prefer_exact: bool = True,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol, maxiter=maxiter, miniter=miniter)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        if settings['prefer_exact'] and isinstance(H, linear_operator.ScaledIdentity):
            return H.solve_bounded(g, radius)

        x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], maxiter=settings["maxiter"], miniter=settings["miniter"], tol=settings["cg_tol"])
        return x
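For reference, the kind of Steihaug-Toint iteration that trust_solve relies on can be sketched in plain Python for a small dense matrix. This is not torchzero's cg helper: it solves H x = g (the sign convention used by trust_solve), omits reg and miniter, and the names steihaug_cg, boundary_tau, dot, matvec and axpy are hypothetical. CG runs until the residual is small, stopping early on the trust region boundary if a step would leave the region or if non-positive curvature is detected.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(A, x):
    return [dot(row, x) for row in A]


def axpy(alpha, x, y):
    # returns alpha * x + y
    return [alpha * xi + yi for xi, yi in zip(x, y)]


def boundary_tau(x, p, radius):
    # positive root tau of ||x + tau * p|| = radius (x is inside the region)
    a = dot(p, p)
    b = 2.0 * dot(x, p)
    c = dot(x, x) - radius ** 2
    return (-b + math.sqrt(b * b - 4.0 * a * c)) / (2.0 * a)


def steihaug_cg(A, g, radius, tol=1e-8, maxiter=None):
    """Approximately solve A x = g subject to ||x|| <= radius."""
    n = len(g)
    maxiter = n if maxiter is None else maxiter
    x = [0.0] * n
    r = list(g)  # residual g - A x for x = 0
    p = list(r)
    if math.sqrt(dot(r, r)) < tol:
        return x
    for _ in range(maxiter):
        Ap = matvec(A, p)
        pAp = dot(p, Ap)
        if pAp <= 0:
            # non-positive curvature: follow p to the trust region boundary
            return axpy(boundary_tau(x, p, radius), p, x)
        rr = dot(r, r)
        alpha = rr / pAp
        x_next = axpy(alpha, p, x)
        if math.sqrt(dot(x_next, x_next)) >= radius:
            # the full CG step leaves the region: stop on the boundary instead
            return axpy(boundary_tau(x, p, radius), p, x)
        x = x_next
        r = axpy(-alpha, Ap, r)
        rr_new = dot(r, r)
        if math.sqrt(rr_new) < tol:
            return x
        p = axpy(rr_new / rr, p, r)  # new direction: r + beta * p
    return x
```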

TrustRegionBase

Bases: torchzero.core.module.Module, abc.ABC

Methods:

  • trust_region_apply

    Solves the trust region subproblem and outputs Var with the solution direction.

  • trust_region_update

    updates the state of this module after H or B have been updated, if necessary

  • trust_solve

    Solve Hx=g with a trust region penalty/bound defined by radius

Source code in torchzero/modules/trust_region/trust_region.py
class TrustRegionBase(Module, ABC):
    def __init__(
        self,
        defaults: dict | None,
        hess_module: Chainable,
        # suggested default values:
        # Gould, Nicholas IM, et al. "Sensitivity of trust-region algorithms to their parameters." 4OR 3.3 (2005): 227-241.
        # which I found from https://github.com/patrick-kidger/optimistix/blob/c1dad7e75fc35bd5a4977ac3a872991e51e83d2c/optimistix/_solver/trust_region.py#L113-200
        eta: float, # 0.0
        nplus: float, # 3.5
        nminus: float, # 0.25
        rho_good: float, # 0.99
        rho_bad: float, # 1e-4
        boundary_tol: float | None, # None or 1e-1
        init: float, # 1
        max_attempts: int, # 10
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS, # "default"
        radius_fn: Callable | None, # torch.linalg.vector_norm
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        if isinstance(radius_strategy, str): radius_strategy = _RADIUS_STRATEGIES[radius_strategy]
        if defaults is None: defaults = {}

        safe_dict_update_(
            defaults,
            dict(eta=eta, nplus=nplus, nminus=nminus, rho_good=rho_good, rho_bad=rho_bad, init=init,
                 update_freq=update_freq, max_attempts=max_attempts, radius_strategy=radius_strategy,
                 boundary_tol=boundary_tol)
        )

        super().__init__(defaults)

        self._radius_fn = radius_fn
        self.set_child('hess_module', hess_module)

        if inner is not None:
            self.set_child('inner', inner)

    @abstractmethod
    def trust_solve(
        self,
        f: float,
        g: torch.Tensor,
        H: LinearOperator,
        radius: float,
        params: list[torch.Tensor],
        closure: Callable,
        settings: Mapping[str, Any],
    ) -> torch.Tensor:
        """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
        ... # pylint:disable=unnecessary-ellipsis

    def trust_region_update(self, var: Var, H: LinearOperator | None) -> None:
        """updates the state of this module after H or B have been updated, if necessary"""

    def trust_region_apply(self, var: Var, tensors:list[torch.Tensor], H: LinearOperator | None) -> Var:
        """Solves the trust region subproblem and outputs ``Var`` with the solution direction."""
        assert H is not None

        params = TensorList(var.params)
        settings = self.settings[params[0]]
        g = _flatten_tensors(tensors)

        max_attempts = settings['max_attempts']

        # loss at x_0
        loss = var.loss
        closure = var.closure
        if closure is None: raise RuntimeError("Trust region requires closure")
        if loss is None: loss = var.get_loss(False)
        loss = tofloat(loss)

        # trust region step and update
        success = False
        d = None
        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            trust_radius = self.global_state.get('trust_radius', settings['init'])

            # solve Hx=g
            d = self.trust_solve(f=loss, g=g, H=H, radius=trust_radius, params=params, closure=closure, settings=settings)

            # update trust radius
            radius_strategy: _RadiusStrategy = settings['radius_strategy']
            self.global_state["trust_radius"], success = radius_strategy(
                params=params,
                closure=closure,
                d=d,
                f=loss,
                g=g,
                H=H,
                trust_radius=trust_radius,

                eta=settings["eta"],
                nplus=settings["nplus"],
                nminus=settings["nminus"],
                rho_good=settings["rho_good"],
                rho_bad=settings["rho_bad"],
                boundary_tol=settings["boundary_tol"],
                init=settings["init"],

                state=self.global_state,
                settings=settings,
                radius_fn=self._radius_fn,
            )

        assert d is not None
        if success: var.update = vec_to_tensors(d, params)
        else: var.update = params.zeros_like()

        return var


    @final
    @torch.no_grad
    def update(self, var):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        if step % self.defaults["update_freq"] == 0:

            hessian_module = self.children['hess_module']
            hessian_module.update(var)
            H = hessian_module.get_H(var)
            self.global_state["H"] = H

            self.trust_region_update(var, H=H)


    @final
    @torch.no_grad
    def apply(self, var):
        H = self.global_state.get('H', None)

        # -------------------------------- inner step -------------------------------- #
        update = var.get_update()
        if 'inner' in self.children:
            update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)

        # ----------------------------------- apply ---------------------------------- #
        return self.trust_region_apply(var=var, tensors=update, H=H)

trust_region_apply

trust_region_apply(var: Var, tensors: list[Tensor], H: LinearOperator | None) -> Var

Solves the trust region subproblem and outputs Var with the solution direction.

trust_region_update

trust_region_update(var: Var, H: LinearOperator | None) -> None

updates the state of this module after H or B have been updated, if necessary

trust_solve

trust_solve(f: float, g: Tensor, H: LinearOperator, radius: float, params: list[Tensor], closure: Callable, settings: Mapping[str, Any]) -> Tensor

Solve Hx=g with a trust region penalty/bound defined by radius

TwoPointNewton

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Two-point Newton method with frozen derivative and third-order convergence.

Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73.

Source code in torchzero/modules/second_order/multipoint.py
class TwoPointNewton(HigherOrderMethodBase):
    """two-point Newton method with frozen derivative with third order convergence.

    Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
    def __init__(self, lstsq: bool=False, vectorize: bool = True):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, vectorize=vectorize)

    def one_iteration(self, x, evaluate, var):
        settings = self.defaults
        lstsq = settings['lstsq']
        def f(x): return evaluate(x, 1)[1]
        def f_j(x): return evaluate(x, 2)[1:]
        x_star = two_point_newton(x, f, f_j, lstsq)
        return x - x_star
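two_point_newton itself is not shown on this page. For intuition, here is the classic two-step Newton iteration with a frozen derivative on a scalar equation F(x) = 0: the derivative is evaluated once per iteration and reused for a second correction, which gives third-order convergence near a simple root. In one_iteration the module appears to apply this idea to the gradient, with the Hessian as the derivative; whether the library's helper matches this scheme exactly is an assumption, and two_step_frozen_newton is a hypothetical name.

```python
def two_step_frozen_newton(x, F, dF, iters=5):
    """Solve F(x) = 0 with two Newton substeps per derivative evaluation."""
    for _ in range(iters):
        d = dF(x)          # derivative evaluated once per iteration
        y = x - F(x) / d   # ordinary Newton step
        x = y - F(y) / d   # second correction reusing the frozen derivative
    return x
```

Used on the gradient of f(x) = x**4 / 4 - 2 * x, i.e. F(x) = x**3 - 2 with F'(x) = 3 * x**2, the iteration converges to the minimiser 2 ** (1/3).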

UnaryLambda

Bases: torchzero.core.transform.Transform

Applies fn to input tensors.

fn must accept and return a list of tensors.

Source code in torchzero/modules/ops/unary.py
class UnaryLambda(Transform):
    """Applies :code:`fn` to input tensors.

    :code:`fn` must accept and return a list of tensors.
    """
    def __init__(self, fn, target: "Target" = 'update'):
        defaults = dict(fn=fn)
        super().__init__(defaults=defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        return settings[0]['fn'](tensors)

UnaryParameterwiseLambda

Bases: torchzero.core.transform.TensorwiseTransform

Applies fn to each input tensor.

fn must accept and return a tensor.

Source code in torchzero/modules/ops/unary.py
class UnaryParameterwiseLambda(TensorwiseTransform):
    """Applies :code:`fn` to each input tensor.

    :code:`fn` must accept and return a tensor.
    """
    def __init__(self, fn, target: "Target" = 'update'):
        defaults = dict(fn=fn)
        super().__init__(uses_grad=False, defaults=defaults, target=target)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        return setting['fn'](tensor)

Uniform

Bases: torchzero.core.module.Module

Outputs tensors filled with random numbers from a uniform distribution between low and high.

Source code in torchzero/modules/ops/utility.py
class Uniform(Module):
    """Outputs tensors filled with random numbers from uniform distribution between :code:`low` and :code:`high`."""
    def __init__(self, low: float, high: float):
        defaults = dict(low=low, high=high)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        low,high = self.get_settings(var.params, 'low','high')
        var.update = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(var.params, low, high)]
        return var

UpdateGradientSignConsistency

Bases: torchzero.core.transform.Transform

Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

Parameters:

  • normalize (bool, default: False ) –

    renormalize update after masking. Defaults to False.

  • eps (float, default: 1e-06 ) –

    epsilon for normalization. Defaults to 1e-6.

Source code in torchzero/modules/momentum/cautious.py
class UpdateGradientSignConsistency(Transform):
    """Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

    Args:
        normalize (bool, optional):
            renormalize update after masking. Defaults to False.
        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
    """
    def __init__(self, normalize = False, eps=1e-6):

        defaults = dict(normalize=normalize, eps=eps)
        super().__init__(defaults, uses_grad=True)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        normalize, eps = itemgetter('normalize', 'eps')(settings[0])

        mask = (TensorList(tensors).mul_(grads)).gt_(0)
        if normalize: mask = mask / mask.global_mean().clip(min = eps) # pyright: ignore[reportOperatorIssue]

        return mask
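The mask is 1 where the update and gradient have the same sign, and normalize=True divides it by its mean so the surviving entries are scaled up. A plain-Python sketch over a single flat list (the module itself takes the global mean over all tensors); sign_consistency_mask is a hypothetical name.

```python
def sign_consistency_mask(update, grad, normalize=False, eps=1e-6):
    # 1.0 where update and gradient have the same, strictly nonzero sign; else 0.0
    mask = [1.0 if u * g > 0 else 0.0 for u, g in zip(update, grad)]
    if normalize:
        mean = sum(mask) / len(mask)
        # dividing by the kept fraction (clipped below at eps) rescales the mask
        mask = [m / max(mean, eps) for m in mask]
    return mask
```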

UpdateSign

Bases: torchzero.core.transform.Transform

Outputs gradient with sign copied from the update.

Source code in torchzero/modules/misc/misc.py
class UpdateSign(Transform):
    """Outputs gradient with sign copied from the update."""
    def __init__(self, target: Target = 'update'):
        super().__init__({}, uses_grad=True, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
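copysign takes the magnitude from its first argument and the sign from its second, so each output entry has the gradient's magnitude and the update's sign. In plain Python, with math.copysign and lists standing in for tensors (update_sign is a hypothetical name):

```python
import math


def update_sign(update, grad):
    # magnitude comes from the gradient, sign comes from the update
    return [math.copysign(g, u) for u, g in zip(update, grad)]
```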

UpdateToNone

Bases: torchzero.core.module.Module

Sets the update attribute of var to None.

Source code in torchzero/modules/ops/utility.py
class UpdateToNone(Module):
    """Sets :code:`update` attribute to None on :code:`var`."""
    def __init__(self): super().__init__()
    def step(self, var):
        var.update = None
        return var

VectorProjection

Bases: torchzero.modules.projections.projection.ProjectionBase

projection that concatenates all parameters into a vector

Source code in torchzero/modules/projections/projection.py
class VectorProjection(ProjectionBase):
    """projection that concatenates all parameters into a vector"""
    def __init__(
        self,
        modules: Chainable,
        project_update=True,
        project_params=True,
        project_grad=True,
    ):
        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        return [torch.cat([t.ravel() for t in tensors])]

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        return vec_to_tensors(vec=projected_tensors[0], reference=params)
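Projecting concatenates everything into one vector, and unprojecting splits it back using the original shapes. A plain-Python sketch with flat lists standing in for tensors (the module uses torch.cat over ravelled tensors and vec_to_tensors to restore shapes); project_flat and unproject_flat are hypothetical free functions, not the module's methods.

```python
def project_flat(tensors):
    # concatenate every flat tensor into one long vector
    return [x for t in tensors for x in t]


def unproject_flat(vec, reference):
    # split the vector back into pieces with the same lengths as `reference`
    out, start = [], 0
    for ref in reference:
        out.append(vec[start:start + len(ref)])
        start += len(ref)
    return out
```

Any module wrapped by the projection then sees a single vector, so methods that need one flat parameter vector (for example full-matrix ones) can operate across all parameters at once.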

ViewAsReal

Bases: torchzero.modules.projections.projection.ProjectionBase

View complex tensors as real tensors. Tensors that are already real are not affected.

Source code in torchzero/modules/projections/cast.py
class ViewAsReal(ProjectionBase):
    """View complex tensors as real tensors. Tensors that are already real are not affected."""
    def __init__(self, modules: Chainable):
        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=None)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        views = []
        for tensor, state in zip(tensors,states):
            is_complex = torch.is_complex(tensor)
            state['is_complex'] = is_complex
            if is_complex: tensor = torch.view_as_real(tensor)
            views.append(tensor)
        return views

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        un_views = []
        for tensor, state in zip(projected_tensors, states):
            if state['is_complex']: tensor = torch.view_as_complex(tensor)
            un_views.append(tensor)
        return un_views
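The module records, per tensor, whether it was complex so that unproject reverses only those. A plain-Python sketch of that bookkeeping, with individual Python numbers standing in for tensors and [real, imag] pairs standing in for the trailing dimension of size 2 that torch.view_as_real adds; to_real_pairs and from_real_pairs are hypothetical helpers, not torch functions.

```python
def to_real_pairs(values):
    # complex numbers become [real, imag] pairs; real numbers pass through unchanged
    out, was_complex = [], []
    for v in values:
        is_c = isinstance(v, complex)
        was_complex.append(is_c)  # remembered so the inverse knows what to undo
        out.append([v.real, v.imag] if is_c else v)
    return out, was_complex


def from_real_pairs(values, was_complex):
    return [complex(v[0], v[1]) if c else v for v, c in zip(values, was_complex)]
```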

Warmup

Bases: torchzero.core.transform.Transform

Learning rate warmup: linearly increases the learning rate multiplier from start_lr to end_lr over steps steps.

Parameters:

  • steps (int, default: 100 ) –

    number of steps to perform warmup for. Defaults to 100.

  • start_lr (float, default: 1e-05 ) –

    initial learning rate multiplier on first step. Defaults to 1e-5.

  • end_lr (float, default: 1 ) –

    learning rate multiplier at the end and after warmup. Defaults to 1.

Example

Adam with 1000 steps warmup

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-2),
    tz.m.Warmup(steps=1000)
)
Source code in torchzero/modules/step_size/lr.py
class Warmup(Transform):
    """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.

    Args:
        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
        start_lr (float, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
        end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.

    Example:
        Adam with 1000 steps warmup

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Adam(),
                tz.m.LR(1e-2),
                tz.m.Warmup(steps=1000)
            )

    """
    def __init__(self, steps = 100, start_lr = 1e-5, end_lr:float = 1):
        defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
        num_steps = settings[0]['steps']
        step = self.global_state.get('step', 0)

        tensors = lazy_lr(
            TensorList(tensors),
            lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
            inplace=True
        )
        self.global_state['step'] = step + 1
        return tensors
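The private helper _warmup_lr is not shown on this page. Assuming it interpolates linearly, reaching end_lr exactly at step == steps and staying there afterwards, the multiplier schedule looks like this (warmup_multiplier is a hypothetical name):

```python
def warmup_multiplier(step, start_lr=1e-5, end_lr=1.0, steps=100):
    # linear interpolation from start_lr (step 0) to end_lr (step == steps), then constant
    if step >= steps:
        return end_lr
    return start_lr + (end_lr - start_lr) * step / steps
```

The update is multiplied by this value on each step, so with tz.m.LR(1e-2) after Adam the effective learning rate ramps from about 1e-7 up to 1e-2.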

WarmupNormClip

Bases: torchzero.core.transform.Transform

Warmup via clipping of the update norm.

Parameters:

  • start_norm (float, default: 1e-05 ) –

    maximal norm on the first step. Defaults to 1e-5.

  • end_norm (float, default: 1 ) –

    maximal norm on the last step. After that, norm clipping is disabled. Defaults to 1.

  • steps (int, default: 100 ) –

    number of steps to perform warmup for. Defaults to 100.

Example

Adam with 1000 steps norm clip warmup

.. code-block:: python

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.WarmupNormClip(steps=1000),
    tz.m.LR(1e-2),
)
Source code in torchzero/modules/step_size/lr.py
class WarmupNormClip(Transform):
    """Warmup via clipping of the update norm.

    Args:
        start_norm (float, optional): maximal norm on the first step. Defaults to 1e-5.
        end_norm (float, optional): maximal norm on the last step. After that, norm clipping is disabled. Defaults to 1.
        steps (int, optional): number of steps to perform warmup for. Defaults to 100.

    Example:
        Adam with 1000 steps norm clip warmup

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.Adam(),
                tz.m.WarmupNormClip(steps=1000),
                tz.m.LR(1e-2),
            )
    """
    def __init__(self, steps = 100, start_norm = 1e-5, end_norm:float = 1):
        defaults = dict(start_norm=start_norm,end_norm=end_norm, steps=steps)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
        num_steps = settings[0]['steps']
        step = self.global_state.get('step', 0)
        if step > num_steps: return tensors

        tensors = TensorList(tensors)
        norm = tensors.global_vector_norm()
        current_max_norm = _warmup_lr(step, start_norm[0], end_norm[0], num_steps)
        if norm > current_max_norm:
            tensors.mul_(current_max_norm / norm)

        self.global_state['step'] = step + 1
        return tensors
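The clipping logic above can be mirrored on a flat list of floats. This is a sketch, not the module itself: it assumes the same linear schedule as Warmup for the maximal norm (the source reuses `_warmup_lr`, whose body is not shown) and uses the L2 norm of the whole update, which is presumably what `global_vector_norm()` computes. `warmup_norm_clip` is a hypothetical name.

```python
import math


def warmup_norm_clip(update: list[float], step: int, start_norm: float = 1e-5,
                     end_norm: float = 1.0, steps: int = 100) -> list[float]:
    """Hypothetical sketch of norm-clipping warmup on a flat update vector."""
    if step > steps:
        # after warmup, clipping is disabled (mirrors `if step > num_steps: return tensors`)
        return list(update)
    # assumption: linear schedule for the maximal allowed norm
    max_norm = start_norm + (end_norm - start_norm) * (step / steps)
    norm = math.sqrt(sum(u * u for u in update))
    if norm > max_norm:
        # rescale so the update norm equals max_norm; the direction is unchanged
        scale = max_norm / norm
        return [u * scale for u in update]
    return list(update)


print(warmup_norm_clip([3.0, 4.0], step=5, steps=10))
```

Unlike Warmup, this leaves small updates untouched: only updates whose norm exceeds the current limit are scaled down.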

WeightDecay

Bases: torchzero.core.transform.Transform

Weight decay.

Parameters:

  • weight_decay (float) –

    weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

  • target (Literal, default: 'update' ) –

    what to set on var. Defaults to 'update'.

Examples:

Adam with non-decoupled weight decay

opt = tz.Modular(
    model.parameters(),
    tz.m.WeightDecay(1e-3),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)

Adam with decoupled weight decay that still scales with learning rate

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.WeightDecay(1e-3),
    tz.m.LR(1e-3)
)

Adam with fully decoupled weight decay that doesn't scale with learning rate

opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
    tz.m.WeightDecay(1e-6)
)
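Why the position of WeightDecay in the chain matters can be seen on a single scalar parameter. This toy sketch makes several assumptions: the final update is subtracted from the parameter, L2 weight decay adds `weight_decay * p` to whatever update reaches it, and Adam is replaced by its first step, where (with bias correction and ignoring eps) m/sqrt(v) reduces to the sign of its input. All function names are hypothetical.

```python
import math


def first_adam_direction(u: float) -> float:
    # toy stand-in for Adam: on its first step, bias-corrected m / sqrt(v) = u / |u| (eps ignored)
    return math.copysign(1.0, u) if u != 0 else 0.0


def step_coupled(p: float, g: float, lr: float, wd: float) -> float:
    # WeightDecay -> Adam -> LR: the decay term is normalized by Adam together with the gradient
    return p - lr * first_adam_direction(g + wd * p)


def step_decoupled(p: float, g: float, lr: float, wd: float) -> float:
    # Adam -> WeightDecay -> LR: decay bypasses Adam but is still scaled by the learning rate
    return p - lr * (first_adam_direction(g) + wd * p)


def step_fully_decoupled(p: float, g: float, lr: float, wd: float) -> float:
    # Adam -> LR -> WeightDecay: decay is added after the learning rate, so lr does not scale it
    return p - (lr * first_adam_direction(g) + wd * p)


p, g, lr, wd = 1.0, 1e-3, 1e-3, 1e-3
for fn in (step_coupled, step_decoupled, step_fully_decoupled):
    print(fn.__name__, fn(p, g, lr, wd))
```

In the coupled case the penalty only nudges the gradient before Adam normalizes it, so here it has no visible effect. In the fully decoupled case the decay is not multiplied by the learning rate, which is why the last example above uses a much smaller value (1e-6).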

Source code in torchzero/modules/weight_decay/weight_decay.py
class WeightDecay(Transform):
    """Weight decay.

    Args:
        weight_decay (float): weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
        target (Target, optional): what to set on var. Defaults to 'update'.

    ### Examples:

    Adam with non-decoupled weight decay
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.WeightDecay(1e-3),
        tz.m.Adam(),
        tz.m.LR(1e-3)
    )
    ```

    Adam with decoupled weight decay that still scales with learning rate
    ```python

    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.WeightDecay(1e-3),
        tz.m.LR(1e-3)
    )
    ```

    Adam with fully decoupled weight decay that doesn't scale with learning rate
    ```python
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-3),
        tz.m.WeightDecay(1e-6)
    )
    ```

    """
    def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):

        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)
        ord = settings[0]['ord']

        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)

WeightDropout

Bases: torchzero.core.module.Module

Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.

Dropout can be disabled for a parameter by setting :code:use_dropout=False in the corresponding parameter group.

Parameters:

  • p (float, default: 0.5 ) –

    probability that any weight is replaced with 0. Defaults to 0.5.

  • graft (bool, default: True ) –

    if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to True.

Source code in torchzero/modules/misc/regularization.py
class WeightDropout(Module):
    """
    Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.

    Dropout can be disabled for a parameter by setting :code:`use_dropout=False` in the corresponding parameter group.

    Args:
        p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
        graft (bool, optional):
            if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to True.
    """
    def __init__(self, p: float = 0.5, graft: bool = True):
        defaults = dict(p=p, graft=graft, use_dropout=True)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        closure = var.closure
        if closure is None: raise RuntimeError('WeightDropout requires closure')
        params = TensorList(var.params)

        # create masks; dropout can be disabled per parameter group via `use_dropout=False`
        mask = []
        for p in params:
            prob = self.settings[p]['p']
            use_dropout = self.settings[p]['use_dropout']
            if use_dropout: mask.append(_bernoulli_like(p, prob))
            else: mask.append(torch.ones_like(p))

        @torch.no_grad
        def dropout_closure(backward=True):
            orig_params = params.clone()
            params.mul_(mask)
            if backward:
                with torch.enable_grad(): loss = closure()
            else:
                loss = closure(False)
            params.copy_(orig_params)
            return loss

        var.closure = dropout_closure
        return var
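The closure-wrapping pattern used by WeightDropout can be sketched without PyTorch. In this minimal sketch the parameters are a plain list of floats that the loss closure reads, the mask is sampled once per optimizer step (when the wrapper is built), and only the loss is evaluated; the real module works on tensors, samples masks with `_bernoulli_like` (not shown), and also runs the backward pass. `make_dropout_closure` is a hypothetical name.

```python
import random
from typing import Callable, Optional


def make_dropout_closure(params: list[float], closure: Callable[[], float], p: float = 0.5,
                         rng: Optional[random.Random] = None) -> Callable[[], float]:
    """Hypothetical sketch: wraps `closure` so it is evaluated with some weights set to 0."""
    if rng is None:
        rng = random.Random()
    # each weight is dropped (mask 0) with probability p; the mask is fixed for this step
    mask = [0.0 if rng.random() < p else 1.0 for _ in params]

    def dropout_closure() -> float:
        orig = list(params)  # save the original weights
        for i, m in enumerate(mask):
            params[i] *= m  # zero out dropped weights in place
        try:
            return closure()  # loss at the masked weights
        finally:
            params[:] = orig  # always restore the original weights
    return dropout_closure


weights = [1.0, 2.0, 3.0, 4.0]


def loss() -> float:
    return sum(w * w for w in weights)


dropped = make_dropout_closure(weights, loss, p=0.5, rng=random.Random(0))
print(dropped(), loss())  # masked loss vs. full loss; `weights` is unchanged afterwards
```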

WeightedAveraging

Bases: torchzero.core.transform.TensorwiseTransform

Weighted average of past len(weights) updates.

Parameters:

  • weights (Sequence[float]) –

    a sequence of weights from oldest to newest.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/momentum/averaging.py
class WeightedAveraging(TensorwiseTransform):
    """Weighted average of past ``len(weights)`` updates.

    Args:
        weights (Sequence[float]): a sequence of weights from oldest to newest.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, weights: Sequence[float] | torch.Tensor | Any, target: Target = 'update'):
        defaults = dict(weights = tolist(weights))
        super().__init__(uses_grad=False, defaults=defaults, target=target)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        weights = setting['weights']

        if 'history' not in state:
            state['history'] = deque(maxlen=len(weights))

        history = state['history']
        history.append(tensor)
        if len(history) != len(weights):
            weights = weights[-len(history):]

        average = None
        for i, (h, w) in enumerate(zip(history, weights)):
            if average is None: average = h * (w / len(history))
            else:
                if w == 0: continue
                average += h * (w / len(history))

        assert average is not None
        return average
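The averaging rule above can be reproduced on scalars to see what the weights do. This sketch follows the source: it keeps a `deque` of the last `len(weights)` values, uses the newest weights while the history is still short, and divides by the number of stored values rather than by the sum of the weights. `WeightedAverager` is a hypothetical name, and scalars stand in for tensors.

```python
from collections import deque
from typing import Sequence


class WeightedAverager:
    """Hypothetical scalar version of WeightedAveraging's per-tensor rule."""

    def __init__(self, weights: Sequence[float]):
        self.weights = list(weights)  # ordered from oldest to newest
        self.history = deque(maxlen=len(self.weights))

    def __call__(self, x: float) -> float:
        self.history.append(x)
        n = len(self.history)
        # while fewer than len(weights) values are stored, use the newest weights
        w = self.weights[-n:]
        # note the division by n, not by sum(w)
        return sum(h * wi for h, wi in zip(self.history, w)) / n


avg = WeightedAverager([1.0, 2.0, 3.0])
print([avg(x) for x in (1.0, 2.0, 3.0, 4.0)])
```

With all weights equal to 1 this is the plain moving average of the last `len(weights)` updates; larger weights towards the end emphasize recent updates.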

WeightedMean

Bases: torchzero.modules.ops.reduce.WeightedSum

Outputs weighted mean of :code:inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class WeightedMean(WeightedSum):
    """Outputs weighted mean of :code:`inputs` that can be modules or numbers."""
    USE_MEAN = True

USE_MEAN class-attribute

USE_MEAN = True


WeightedSum

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs a weighted sum of inputs, which can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class WeightedSum(ReduceOperationBase):
    USE_MEAN = False
    def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
        """Outputs a weighted sum of :code:`inputs` that can be modules or numbers."""
        weights = list(weights)
        if len(inputs) != len(weights):
            raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
        defaults = dict(weights=weights)
        super().__init__(*inputs, defaults=defaults)

    @torch.no_grad
    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        weights = self.defaults['weights']
        # sort (input, weight) pairs together: tensor lists first, numbers last,
        # so that each weight stays attached to its own input
        pairs = sorted(zip(inputs, weights), key=lambda x: isinstance(x[0], (int, float)))
        sum = cast(list, pairs[0][0])
        torch._foreach_mul_(sum, pairs[0][1])
        for v, w in pairs[1:]:
            if isinstance(v, (int, float)): torch._foreach_add_(sum, v*w)
            else: torch._foreach_add_(sum, v, alpha=w)

        if self.USE_MEAN and len(pairs) > 1: torch._foreach_div_(sum, len(pairs))
        return sum

USE_MEAN class-attribute

USE_MEAN = False
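The reduction can be sketched with lists of floats standing in for the tensor lists produced by modules. As in the source, list-valued inputs are processed first and plain numbers are broadcast onto them, and the mean variant (WeightedMean, `USE_MEAN = True`) divides by the number of inputs rather than by the sum of the weights. This sketch sorts (input, weight) pairs together so every weight stays attached to its own input. `weighted_sum` is a hypothetical name.

```python
from typing import Sequence, Union

Value = Union[float, list[float]]


def weighted_sum(inputs: Sequence[Value], weights: Sequence[float], use_mean: bool = False) -> list[float]:
    """Hypothetical sketch of WeightedSum (use_mean=False) and WeightedMean (use_mean=True)."""
    if len(inputs) != len(weights):
        raise ValueError(f"Number of inputs {len(inputs)} must match number of weights {len(weights)}")
    # list-valued inputs first, numbers last; each weight travels with its input
    pairs = sorted(zip(inputs, weights), key=lambda iw: isinstance(iw[0], (int, float)))
    first, w0 = pairs[0]
    out = [w0 * v for v in first]
    for x, w in pairs[1:]:
        if isinstance(x, (int, float)):
            out = [o + x * w for o in out]  # numbers are broadcast to every element
        else:
            out = [o + w * xi for o, xi in zip(out, x)]
    if use_mean and len(pairs) > 1:
        out = [o / len(pairs) for o in out]  # mean over the number of inputs
    return out


print(weighted_sum([[1.0, 2.0], 3.0], weights=[2.0, 0.5]))  # [3.5, 5.5]
```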


Wrap

Bases: torchzero.core.module.Module

Wraps a pytorch optimizer to use it as a module.

.. note:: Custom param groups are supported only by set_param_groups, settings passed to Modular will be ignored.

Parameters:

  • opt_fn (Callable[..., Optimizer] | Optimizer) –

    function that takes in parameters and returns the optimizer, for example :code:torch.optim.Adam or :code:lambda parameters: torch.optim.Adam(parameters, lr=1e-3)

  • *args, **kwargs –

    Extra args to be passed to opt_fn. The function is called as :code:opt_fn(parameters, *args, **kwargs).

Example

wrapping pytorch_optimizer.StableAdamW

.. code-block:: py

from pytorch_optimizer import StableAdamW
opt = tz.Modular(
    model.parameters(),
    tz.m.Wrap(StableAdamW, lr=1),
    tz.m.Cautious(),
    tz.m.LR(1e-2)
)
Source code in torchzero/modules/wrappers/optim_wrapper.py
class Wrap(Module):
    """
    Wraps a pytorch optimizer to use it as a module.

    .. note::
        Custom param groups are supported only by `set_param_groups`, settings passed to Modular will be ignored.

    Args:
        opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
            function that takes in parameters and returns the optimizer, for example :code:`torch.optim.Adam`
            or :code:`lambda parameters: torch.optim.Adam(parameters, lr=1e-3)`
        *args:
        **kwargs:
            Extra args to be passed to opt_fn. The function is called as :code:`opt_fn(parameters, *args, **kwargs)`.

    Example:
        wrapping pytorch_optimizer.StableAdamW

        .. code-block:: py

            from pytorch_optimizer import StableAdamW
            opt = tz.Modular(
                model.parameters(),
                tz.m.Wrap(StableAdamW, lr=1),
                tz.m.Cautious(),
                tz.m.LR(1e-2)
            )


    """
    def __init__(self, opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer, *args, **kwargs):
        super().__init__()
        self._opt_fn = opt_fn
        self._opt_args = args
        self._opt_kwargs = kwargs
        self._custom_param_groups = None

        self.optimizer: torch.optim.Optimizer | None = None
        if isinstance(self._opt_fn, torch.optim.Optimizer) or not callable(self._opt_fn):
            self.optimizer = self._opt_fn

    def set_param_groups(self, param_groups):
        self._custom_param_groups = param_groups
        return super().set_param_groups(param_groups)

    @torch.no_grad
    def step(self, var):
        params = var.params

        # initialize opt on 1st step
        if self.optimizer is None:
            assert callable(self._opt_fn)
            param_groups = params if self._custom_param_groups is None else self._custom_param_groups
            self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)

        # set grad to update
        orig_grad = [p.grad for p in params]
        for p, u in zip(params, var.get_update()):
            p.grad = u

        # if this module is last, step with self.optimizer directly.
        # a direct step can't be applied if the next module is LR but self.optimizer has no 'lr' setting,
        # or if there are multiple different per-parameter lrs (would be annoying to support)
        if var.is_last and (
            (var.last_module_lrs is None)
            or
            (('lr' in self.optimizer.defaults) and (len(set(var.last_module_lrs)) == 1))
        ):
            lr = 1 if var.last_module_lrs is None else var.last_module_lrs[0]

            # update optimizer lr with desired lr
            if lr != 1:
                self.optimizer.defaults['__original_lr__'] = self.optimizer.defaults['lr']
                for g in self.optimizer.param_groups:
                    g['__original_lr__'] = g['lr']
                    g['lr'] = g['lr'] * lr

            # step
            self.optimizer.step()

            # restore original lr
            if lr != 1:
                self.optimizer.defaults['lr'] = self.optimizer.defaults.pop('__original_lr__')
                for g in self.optimizer.param_groups:
                    g['lr'] = g.pop('__original_lr__')

            # restore grad
            for p, g in zip(params, orig_grad):
                p.grad = g

            var.stop = True; var.skip_update = True
            return var

        # this is not the last module, meaning update is difference in parameters
        params_before_step = [p.clone() for p in params]
        self.optimizer.step() # step and update params
        for p, g in zip(params, orig_grad):
            p.grad = g
        var.update = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
        for p, o in zip(params, params_before_step):
            p.set_(o) # pyright: ignore[reportArgumentType]

        return var

    def reset(self):
        super().reset()
        assert self.optimizer is not None
        for g in self.optimizer.param_groups:
            for p in g['params']:
                state = self.optimizer.state[p]
                state.clear()

Zeros

Bases: torchzero.core.module.Module

Outputs zeros

Source code in torchzero/modules/ops/utility.py
class Zeros(Module):
    """Outputs zeros"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def step(self, var):
        var.update = [torch.zeros_like(p) for p in var.params]
        return var

clip_grad_norm_

clip_grad_norm_(params: Iterable[Tensor], max_norm: float | None, ord: Union[Literal['mad', 'std', 'var', 'sum', 'l0', 'l1', 'l2', 'l3', 'l4', 'linf'], float, Tensor] = 2, dim: Union[int, Sequence[int], Literal['global'], NoneType] = None, inverse_dims: bool = False, min_size: int = 2, min_norm: float | None = None)

Clips gradient of an iterable of parameters to specified norm value. Gradients are modified in-place.

Parameters:

  • params (Iterable[Tensor]) –

    parameters with gradients to clip.

  • max_norm (float | None) –

    value to clip the norm to.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • dim (int | Sequence[int] | str | None, default: None ) –

    calculates norm along those dimensions. If list/tuple, tensors are normalized along all dimensions in dim that they have. Can be set to "global" to normalize by global norm of all gradients concatenated to a vector. Defaults to None.

  • inverse_dims (bool, default: False ) –

    if True, the dim argument is inverted, and norms are computed along all other dimensions.

  • min_size (int, default: 2 ) –

    minimal size of a dimension to normalize along it. Defaults to 2.

Source code in torchzero/modules/clipping/clipping.py
def clip_grad_norm_(
    params: Iterable[torch.Tensor],
    max_norm: float | None,
    ord: Metrics = 2,
    dim: int | Sequence[int] | Literal["global"] | None = None,
    inverse_dims: bool = False,
    min_size: int = 2,
    min_norm: float | None = None,
):
    """Clips gradient of an iterable of parameters to specified norm value.
    Gradients are modified in-place.

    Args:
        params (Iterable[torch.Tensor]): parameters with gradients to clip.
        max_norm (float | None): value to clip the norm to.
        ord (float, optional): norm order. Defaults to 2.
        dim (int | Sequence[int] | str | None, optional):
            calculates norm along those dimensions.
            If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
            Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
            Defaults to None.
        inverse_dims (bool, optional):
            if True, the `dim` argument is inverted, and norms are computed along all other dimensions.
        min_size (int, optional):
            minimal size of a dimension to normalize along it. Defaults to 2.
    """
    grads = TensorList(p.grad for p in params if p.grad is not None)
    _clip_norm_(grads, min=min_norm, max=max_norm, norm_value=None, ord=ord, dim=dim, inverse_dims=inverse_dims, min_size=min_size)
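For the common case of one global L2 norm (`ord=2`, `dim="global"`), clipping can be sketched as below. This sketch ignores `min_norm`, `min_size` and the per-dimension modes, and uses lists of floats for gradients; `clip_global_norm_` is a hypothetical name, not part of the library.

```python
import math


def clip_global_norm_(grads: list[list[float]], max_norm: float) -> float:
    """Hypothetical sketch: clips the global L2 norm of all gradients in place, returns the norm before clipping."""
    total = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total > max_norm:
        scale = max_norm / total  # the same factor for every gradient keeps the overall direction
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= scale
    return total


grads = [[3.0], [4.0]]  # global norm 5
print(clip_global_norm_(grads, max_norm=1.0), grads)
```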

clip_grad_value_

clip_grad_value_(params: Iterable[Tensor], value: float)

Clips gradient of an iterable of parameters at specified value. Gradients are modified in-place.

Parameters:

  • params (Iterable[Tensor]) –

    iterable of tensors with gradients to clip.

  • value (float or int) –

    maximum allowed absolute value of the gradient; every entry is clamped to [-value, value].

Source code in torchzero/modules/clipping/clipping.py
def clip_grad_value_(params: Iterable[torch.Tensor], value: float):
    """Clips gradient of an iterable of parameters at specified value.
    Gradients are modified in-place.
    Args:
        params (Iterable[Tensor]): iterable of tensors with gradients to clip.
        value (float or int): maximum allowed value of gradient
    """
    grads = [p.grad for p in params if p.grad is not None]
    torch._foreach_clamp_min_(grads, -value)
    torch._foreach_clamp_max_(grads, value)
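The two `_foreach_clamp_*_` calls clamp every gradient entry to [-value, value]. On plain lists this is the sketch below (`clip_value_` is a hypothetical name); unlike norm clipping, it can change the direction of the gradient.

```python
def clip_value_(grads: list[list[float]], value: float) -> None:
    """Hypothetical sketch: clamps every gradient entry to [-value, value] in place."""
    for grad in grads:
        for i, g in enumerate(grad):
            grad[i] = min(max(g, -value), value)


grads = [[3.0, 0.5], [-5.0]]
clip_value_(grads, 1.0)
print(grads)  # [[1.0, 0.5], [-1.0]]
```

Here the first gradient goes from a 6:1 ratio between its entries to 2:1, so its direction is altered.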

decay_weights_

decay_weights_(params: Iterable[Tensor], weight_decay: float | NumberList, ord: int = 2)

directly decays weights in-place

Source code in torchzero/modules/weight_decay/weight_decay.py
@torch.no_grad
def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberList, ord:int=2):
    """directly decays weights in-place"""
    params = TensorList(params)
    weight_decay_(params, params, -weight_decay, ord)
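`decay_weights_` passes the parameters as both the tensors and the params of `weight_decay_`, with a negated coefficient. Assuming `weight_decay_` (not shown) adds `weight_decay * p` for ord=2 and `weight_decay * sign(p)` for ord=1, this amounts to the in-place update sketched below on a list of floats; `decay_list_` is a hypothetical name.

```python
import math


def decay_list_(params: list[float], weight_decay: float, ord: int = 2) -> None:
    """Hypothetical sketch of direct in-place weight decay on a flat list of parameters."""
    for i, p in enumerate(params):
        if ord == 2:
            penalty_grad = p  # gradient of p**2 / 2
        elif ord == 1:
            penalty_grad = math.copysign(1.0, p) if p != 0 else 0.0  # subgradient of |p|
        else:
            raise ValueError("this sketch only handles ord=1 and ord=2")
        params[i] = p - weight_decay * penalty_grad


w = [1.0, -2.0, 0.0]
decay_list_(w, 0.1)
print(w)
```

With ord=2 this is multiplicative shrinkage by (1 - weight_decay); with ord=1 every nonzero weight moves toward zero by the same amount.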

normalize_grads_

normalize_grads_(params: Iterable[Tensor], norm_value: float, ord: Union[Literal['mad', 'std', 'var', 'sum', 'l0', 'l1', 'l2', 'l3', 'l4', 'linf'], float, Tensor] = 2, dim: Union[int, Sequence[int], Literal['global'], NoneType] = None, inverse_dims: bool = False, min_size: int = 1)

Normalizes gradient of an iterable of parameters to specified norm value. Gradients are modified in-place.

Parameters:

  • params (Iterable[Tensor]) –

    parameters with gradients to normalize.

  • norm_value (float) –

    value to normalize the norm to.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • dim (int | Sequence[int] | str | None, default: None ) –

    calculates norm along those dimensions. If list/tuple, tensors are normalized along all dimensions in dim that they have. Can be set to "global" to normalize by global norm of all gradients concatenated to a vector. Defaults to None.

  • inverse_dims (bool, default: False ) –

    if True, the dim argument is inverted, and all other dimensions are normalized.

  • min_size (int, default: 1 ) –

    minimal size of a dimension to normalize along it. Defaults to 1.

Source code in torchzero/modules/clipping/clipping.py
def normalize_grads_(
    params: Iterable[torch.Tensor],
    norm_value: float,
    ord: Metrics = 2,
    dim: int | Sequence[int] | Literal["global"] | None = None,
    inverse_dims: bool = False,
    min_size: int = 1,
):
    """Normalizes gradient of an iterable of parameters to specified norm value.
    Gradients are modified in-place.

    Args:
        params (Iterable[torch.Tensor]): parameters with gradients to normalize.
        norm_value (float): value to normalize the norm to.
        ord (float, optional): norm order. Defaults to 2.
        dim (int | Sequence[int] | str | None, optional):
            calculates norm along those dimensions.
            If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
            Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
            Defaults to None.
        inverse_dims (bool, optional):
            if True, the `dim` argument is inverted, and all other dimensions are normalized.
        min_size (int, optional):
            minimal size of a dimension to normalize along it. Defaults to 1.
    """
    grads = TensorList(p.grad for p in params if p.grad is not None)
    _clip_norm_(grads, min=None, max=None, norm_value=norm_value, ord=ord, dim=dim, inverse_dims=inverse_dims, min_size=min_size)
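Normalization differs from clipping in that the gradients are always rescaled to `norm_value`, so small gradients are scaled up as well. A sketch for a single global L2 norm (`ord=2`, `dim="global"`), ignoring the per-dimension options; `normalize_global_` is a hypothetical name, and skipping all-zero gradients is this sketch's own choice, not documented library behavior.

```python
import math


def normalize_global_(grads: list[list[float]], norm_value: float) -> None:
    """Hypothetical sketch: rescales all gradients in place so their global L2 norm equals norm_value."""
    total = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total == 0:
        return  # sketch's choice: leave all-zero gradients unchanged
    scale = norm_value / total
    for grad in grads:
        for i in range(len(grad)):
            grad[i] *= scale


grads = [[0.3], [0.4]]  # global norm 0.5, scaled up to 1
normalize_global_(grads, 1.0)
print(grads)
```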

orthogonalize_grads_

orthogonalize_grads_(params: Iterable[Tensor], steps: int = 5, dual_norm_correction=False, method: Literal['newton-schulz', 'svd'] = 'newton-schulz')

Uses Newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.

This sets gradients in-place. Applies along first 2 dims (expected to be out_channels, in_channels).

Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.

Parameters:

  • params (Iterable[Tensor]) –

    parameters that hold gradients to orthogonalize.

  • steps (int, default: 5 ) –

    the number of Newton-Schulz iterations to run. Defaults to 5.

  • dual_norm_correction (bool, default: False ) –

    enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.

  • method (str, default: 'newton-schulz' ) –

    Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise. Defaults to 'newton-schulz'.

Source code in torchzero/modules/adaptive/muon.py
def orthogonalize_grads_(
    params: Iterable[torch.Tensor],
    steps: int = 5,
    dual_norm_correction=False,
    method: Literal["newton-schulz", "svd"] = "newton-schulz",
):
    """Uses Newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.

    This sets gradients in-place. Applies along first 2 dims (expected to be `out_channels, in_channels`).

    Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
        steps (int, optional):
            The number of Newton-Schulz iterations to run. Defaults to 5.
        dual_norm_correction (bool, optional):
            enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
        method (str, optional):
            Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.
            Defaults to "newton-schulz".
    """
    for p in params:
        if (p.grad is not None) and _is_at_least_2d(p.grad):
            X = _orthogonalize_tensor(p.grad, steps, method)
            if dual_norm_correction: X = _dual_norm_correction(X, p.grad, batch_first=False)
            p.grad.set_(X.view_as(p)) # pyright:ignore[reportArgumentType]
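The idea behind the Newton-Schulz option can be shown on a small matrix. This is a sketch of the classic cubic iteration X ← 1.5X − 0.5XXᵀX after scaling by the Frobenius norm, which drives every singular value to 1 and so converges to the orthogonal factor UVᵀ of the SVD G = USVᵀ. The library's `_orthogonalize_tensor` is not shown here and may use different (for example quintic) coefficients and far fewer steps. The helper names are hypothetical, and plain nested lists replace tensors.

```python
import math


def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: list[list[float]]) -> list[list[float]]:
    return [list(row) for row in zip(*a)]


def newton_schulz_orthogonalize(g: list[list[float]], steps: int = 30) -> list[list[float]]:
    """Hypothetical sketch: cubic Newton-Schulz iteration towards the orthogonal factor of g."""
    # scale so every singular value is at most 1 (inside the iteration's convergence region)
    fro = math.sqrt(sum(x * x for row in g for x in row))
    x = [[v / fro for v in row] for row in g]
    for _ in range(steps):
        xxtx = matmul(matmul(x, transpose(x)), x)
        # X <- 1.5 X - 0.5 X X^T X maps each singular value s to 1.5 s - 0.5 s^3
        x = [[1.5 * x[i][j] - 0.5 * xxtx[i][j] for j in range(len(x[0]))] for i in range(len(x))]
    return x


q = newton_schulz_orthogonalize([[1.0, 2.0], [3.0, 4.0]])
print(matmul(transpose(q), q))  # close to the identity matrix
```

For a symmetric positive definite input such as [[3, 1], [1, 2]] the orthogonal factor is the identity, which makes a convenient sanity check.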

orthograd_

orthograd_(params: Iterable[Tensor], eps: float = 1e-30)

Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

Parameters:

  • params (Iterable[Tensor]) –

    parameters that hold gradients to apply ⟂Grad to.

  • eps (float, default: 1e-30 ) –

    epsilon added to the denominator for numerical stability (default: 1e-30)

Reference: https://arxiv.org/abs/2501.04697

Source code in torchzero/modules/adaptive/orthograd.py
def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
    """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to apply ⟂Grad to.
        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)

    reference
        https://arxiv.org/abs/2501.04697
    """
    params = as_tensorlist(params).with_grad()
    grad = params.grad
    grad -= (params.dot(grad)/(params.dot(params) + eps)) * params
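The last line removes the component of the gradient along the weights: g ← g − (⟨w, g⟩ / (⟨w, w⟩ + eps)) w. A sketch for a single flat weight vector follows; `orthograd_vector` is a hypothetical name, and whether the source's TensorList `dot` works per tensor or over all parameters at once is not shown here.

```python
def orthograd_vector(w: list[float], g: list[float], eps: float = 1e-30) -> list[float]:
    """Hypothetical sketch: projects g onto the subspace orthogonal to w."""
    dot_wg = sum(wi * gi for wi, gi in zip(w, g))
    dot_ww = sum(wi * wi for wi in w)
    coef = dot_wg / (dot_ww + eps)  # eps guards against all-zero weights
    return [gi - coef * wi for wi, gi in zip(w, g)]


w = [1.0, 2.0, 2.0]
g = orthograd_vector(w, [0.5, -1.0, 3.0])
print(g, sum(wi * gi for wi, gi in zip(w, g)))  # [0.0, -2.0, 2.0] 0.0
```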