
List of all modules

A somewhat categorized list of modules is also available in Modules

Classes:

  • AEGD

    AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.

  • ASAM

    Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52

  • Abs

    Returns abs(input)

  • AccumulateMaximum

    Accumulates maximum of all past updates.

  • AccumulateMean

    Accumulates mean of all past updates.

  • AccumulateMinimum

    Accumulates minimum of all past updates.

  • AccumulateProduct

    Accumulates product of all past updates.

  • AccumulateSum

    Accumulates sum of all past updates.

  • AdGD

    AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)

  • AdaHessian

    AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)

  • Adagrad

    Adagrad. Divides the gradient by the square root of the sum of past squared gradients.

  • AdagradNorm

    Adagrad-Norm. Divides the gradient by the square root of the sum of past means of squared gradients.

  • Adam

    Adam. Divides the gradient EMA by the square root of the EMA of squared gradients, with a debiased step size.

  • Adan

    Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677

  • AdaptiveBacktracking

    Adaptive backtracking line search. After each line search procedure, a new initial step size is set

  • AdaptiveBisection

    A line search that evaluates previous step size, if value increased, backtracks until the value stops decreasing,

  • AdaptiveHeavyBall

    Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.

  • Add

    Add other to tensors. other can be a number or a module.

  • Alternate

    Alternates between stepping with modules.

  • Averaging

    Average of past history_size updates.

  • BBStab

    Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).

  • BFGS

    Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-Newton method.

  • BacktrackOnSignChange

    Negates or undoes update for parameters where the gradient or update sign changes.

  • Backtracking

    Backtracking line search.

  • BarzilaiBorwein

    Barzilai-Borwein step size method.

  • BinaryOperationBase

    Base class for operations that use update as the first operand. This is an abstract class, subclass it and override transform method to use it.

  • BirginMartinezRestart

    The restart criterion for conjugate gradient methods designed by Birgin and Martinez.

  • BoldDriver

    Multiplies step size by nplus if loss decreased compared to last iteration, otherwise multiplies by nminus.

  • BroydenBad

    Broyden's "bad" Quasi-Newton method.

  • BroydenGood

    Broyden's "good" Quasi-Newton method.

  • CD

    Coordinate descent. Proposes a descent direction along a single coordinate.

  • Cautious

    Negates update for parameters where the update and gradient signs are inconsistent.

  • CautiousWeightDecay

    Cautious weight decay (https://arxiv.org/pdf/2510.12402).

  • CenteredEMASquared

    Maintains a centered exponential moving average of squared updates. This also maintains an additional

  • CenteredSqrtEMASquared

    Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.

  • Centralize

    Centralizes the update.

  • Clip

    Clips tensors to be in the (min, max) range. min and max can be None, numbers or modules.

  • ClipModules

    Calculates input(tensors).clip(min, max). min and max can be numbers or modules.

  • ClipNorm

    Clips update norm to be no larger than value.

  • ClipNormByEMA

    Clips norm to be no larger than the norm of an exponential moving average of past updates.

  • ClipNormGrowth

    Clips update norm growth.

  • ClipValue

    Clips update magnitude to be within (-value, value) range.

  • ClipValueByEMA

    Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.

  • ClipValueGrowth

    Clips update value magnitude growth.

  • Clone

    Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations

  • ConjugateDescent

    Conjugate Descent (CD).

  • CopyMagnitude

    Returns other(tensors) with sign copied from tensors.

  • CopySign

    Returns tensors with sign copied from other(tensors).

  • CubicRegularization

    Cubic regularization.

  • CustomUnaryOperation

    Applies getattr(tensor, name) to each tensor

  • DFP

    Davidon–Fletcher–Powell Quasi-Newton method.

  • DNRTR

    Diagonal quasi-Newton method.

  • DYHS

    Dai-Yuan - Hestenes–Stiefel hybrid conjugate gradient method.

  • DaiYuan

    Dai–Yuan nonlinear conjugate gradient method.

  • Debias

    Multiplies the update by an Adam debiasing term based on the first and/or second momentum.

  • Debias2

    Multiplies the update by an Adam debiasing term based on the second momentum.

  • DiagonalBFGS

    Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful.

  • DiagonalQuasiCauchi

    Diagonal quasi-Cauchy method.

  • DiagonalSR1

    Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful.

  • DiagonalWeightedQuasiCauchi

    Diagonal weighted quasi-Cauchy method.

  • DirectWeightDecay

    Directly applies weight decay to parameters.

  • Div

    Divide tensors by other. other can be a number or a module.

  • DivByLoss

    Divides update by loss times alpha

  • DivModules

    Calculates input / other. input and other can be numbers or modules.

  • Dogleg

    Dogleg trust region algorithm.

  • Dropout

    Applies dropout to the update.

  • DualNormCorrection

    Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).

  • EMA

    Maintains an exponential moving average of update.

  • EMASquared

    Maintains an exponential moving average of squared updates.

  • ESGD

    Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

  • EscapeAnnealing

    If parameters stop changing, this runs a backward annealing random search

  • Exp

    Returns exp(input)

  • ExpHomotopy
  • FDM

    Approximate gradients via finite difference method.

  • Fill

    Outputs tensors filled with value

  • FillLoss

    Outputs tensors filled with loss value times alpha

  • FletcherReeves

    Fletcher–Reeves nonlinear conjugate gradient method.

  • FletcherVMM

    Fletcher's variable metric Quasi-Newton method.

  • ForwardGradient

    Forward gradient method.

  • FullMatrixAdagrad

    Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

  • GGT

    GGT method from https://arxiv.org/pdf/1806.02958

  • GGTBasis

    Run another optimizer in GGT eigenbasis. The eigenbasis is rank-sized, so it is possible to run expensive

  • GaussNewton

    Gauss-Newton method.

  • GaussianSmoothing

    Gradient approximation via Gaussian smoothing method.

  • Grad

    Outputs the gradient

  • GradApproximator

    Base class for gradient approximations.

  • GradSign

    Copies gradient sign to update.

  • GradToNone

    Sets grad attribute to None on objective.

  • GradientAccumulation

    Accumulates gradients over n steps. Once n gradients have been accumulated, they are passed to modules and the parameters are updated.

  • GradientCorrection

    Estimates gradient at minima along search direction assuming function is quadratic.

  • GradientSampling

    Samples and aggregates gradients and values at perturbed points.

  • Graft

    Outputs direction output rescaled to have the same norm as magnitude output.

  • GraftGradToUpdate

    Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update.

  • GraftInputToOutput

    Outputs tensors rescaled to have the same norm as magnitude(tensors).

  • GraftOutputToInput

    Outputs magnitude(tensors) rescaled to have the same norm as tensors

  • GraftToGrad

    Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient.

  • GraftToParams

    Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than eps.

  • GramSchimdt

    Outputs tensors made orthogonal to other(tensors) via Gram-Schmidt.

  • Greenstadt1

    Greenstadt's first Quasi-Newton method.

  • Greenstadt2

    Greenstadt's second Quasi-Newton method.

  • HagerZhang

    Hager-Zhang nonlinear conjugate gradient method.

  • HeavyBall

    Polyak's momentum (heavy-ball method).

  • HestenesStiefel

    Hestenes–Stiefel nonlinear conjugate gradient method.

  • Horisho

    Horisho's variable metric Quasi-Newton method.

  • HpuEstimate

    Returns y/||s||, where y is the difference between the current and previous update (gradient) and s is the difference between the current and previous parameters. The returned tensors are a finite-difference approximation to the Hessian times the previous update.

  • ICUM

    Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods

  • Identity

    Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods.

  • ImprovedNewton

    Improved Newton's Method (INM).

  • IntermoduleCautious

    Negates update on the main module where its sign doesn't match the output of the compare module.

  • InverseFreeNewton

    Inverse-free Newton's method.

  • LBFGS

    Limited-memory BFGS algorithm. A line search or trust region is recommended.

  • LR

    Learning rate. Adding this module also adds support for LR schedulers.

  • LSR1

    Limited-memory SR1 algorithm. A line search or trust region is recommended.

  • LambdaHomotopy
  • LaplacianSmoothing

    Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.

  • LastAbsoluteRatio

    Outputs the ratio between absolute values of the past two updates; the numerator is determined by the numerator argument.

  • LastDifference

    Outputs difference between past two updates.

  • LastGradDifference

    Outputs difference between past two gradients.

  • LastProduct

    Outputs product of past two updates.

  • LastRatio

    Outputs ratio between past two updates, the numerator is determined by numerator argument.

  • LerpModules

    Does a linear interpolation of input(tensors) and end(tensors) based on a scalar weight.

  • LevenbergMarquardt

    Levenberg-Marquardt trust region algorithm.

  • LineSearchBase

    Base class for line searches.

  • Lion

    Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

  • LiuStorey

    Liu-Storey nonlinear conjugate gradient method.

  • LogHomotopy
  • MARSCorrection

    MARS variance reduction correction.

  • MSAM

    Momentum-SAM from https://arxiv.org/pdf/2401.12033.

  • MSAMMomentum

    Momentum-SAM from https://arxiv.org/pdf/2401.12033.

  • MatrixMomentum

    Second order momentum method.

  • Maximum

    Outputs maximum(tensors, other(tensors))

  • MaximumModules

    Outputs elementwise maximum of inputs that can be modules or numbers.

  • McCormick

    McCormick's Quasi-Newton method.

  • MeZO

    Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

  • Mean

    Outputs a mean of inputs that can be modules or numbers.

  • MedianAveraging

    Median of past history_size updates.

  • Minimum

    Outputs minimum(tensors, other(tensors))

  • MinimumModules

    Outputs elementwise minimum of inputs that can be modules or numbers.

  • Mul

    Multiply tensors by other. other can be a number or a module.

  • MulByLoss

    Multiplies update by loss times alpha

  • MultiOperationBase

    Base class for operations that use operands. This is an abstract class, subclass it and override transform method to use it.

  • Multistep

    Performs steps inner steps with module on each step.

  • MuonAdjustLR

    LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).

  • NAG

    Nesterov accelerated gradient method (nesterov momentum).

  • NanToNum

    Converts nan, inf and -inf to numbers.

  • NaturalGradient

    Natural gradient approximated via empirical fisher information matrix.

  • Negate

    Returns -input

  • NegateOnLossIncrease

    Uses an extra forward pass to evaluate loss at parameters+update,

  • NewDQN

    Diagonal quasi-Newton method.

  • NewSSM

    Self-scaling Quasi-Newton method.

  • Newton

    Exact Newton's method via autograd.

  • NewtonCG

    Newton's method with a matrix-free conjugate gradient or minimal-residual solver.

  • NewtonCGSteihaug

    Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.

  • NoiseSign

    Outputs random tensors with sign copied from the update.

  • Noop

    Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods.

  • Normalize

    Normalizes the update.

  • NormalizeByEMA

    Sets norm of the update to be the same as the norm of an exponential moving average of past updates.

  • NystromPCG

    Newton's method with a Nyström-preconditioned conjugate gradient solver.

  • NystromSketchAndSolve

    Newton's method with a Nyström sketch-and-solve solver.

  • Ones

    Outputs ones

  • Online

    Allows certain modules to be used for mini-batch optimization.

  • OrthoGrad

    Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

  • Orthogonalize

    Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.

  • PSB

    Powell's Symmetric Broyden Quasi-Newton method.

  • PSGDDenseNewton

    Dense hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

  • PSGDKronNewton

    Kron hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

  • PSGDKronWhiten

    Kron whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

  • PSGDLRANewton

    Low rank hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

  • PSGDLRAWhiten

    Low rank whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

  • Params

    Outputs parameters

  • Pearson

    Pearson's Quasi-Newton method.

  • PerturbWeights

    Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.

  • PolakRibiere

    Polak-Ribière-Polyak nonlinear conjugate gradient method.

  • PolyakStepSize

    Polyak's subgradient method with known or unknown f*.

  • Pow

    Take tensors to the power of exponent. exponent can be a number or a module.

  • PowModules

    Calculates input ** exponent. input and other can be numbers or modules.

  • PowellRestart

    Powell's two restarting criteria for conjugate gradient methods.

  • Previous

    Maintains an update from n steps back, for example if n=1, returns previous update

  • PrintLoss

    Prints var.get_loss().

  • PrintParams

    Prints current parameters.

  • PrintShape

    Prints shapes of the update.

  • PrintUpdate

    Prints current update.

  • Prod

    Outputs product of inputs that can be modules or numbers.

  • ProjectedGradientMethod

    Projected gradient method. Directly projects the gradient onto subspace conjugate to past directions.

  • ProjectedNewtonRaphson

    Projected Newton Raphson method.

  • ProjectionBase

    Base class for projections.

  • RCopySign

    Returns other(tensors) with sign copied from tensors.

  • RDSA

    Gradient approximation via Random-direction stochastic approximation (RDSA) method.

  • RDiv

    Divide other by tensors. other can be a number or a module.

  • RMSprop

    Divides gradient by the square root of the EMA of squared gradients.

  • RPow

    Take other to the power of tensors. other can be a number or a module.

  • RSub

    Subtract tensors from other. other can be a number or a module.

  • Randn

    Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1.

  • RandomHvp

    Returns a hessian-vector product with a random vector, optionally times vector

  • RandomReinitialize

    On each step with probability p_reinit trigger reinitialization,

  • RandomSample

    Outputs tensors filled with random numbers from distribution depending on value of distribution.

  • RandomStepSize

    Uses random global or layer-wise step size from low to high.

  • RandomizedFDM

    Gradient approximation via a randomized finite-difference method.

  • Reciprocal

    Returns 1 / input

  • ReduceOperationBase

    Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override transform method to use it.

  • Relative

    Multiplies update by absolute parameter values to make it relative to their magnitude, min_value is minimum allowed value to avoid getting stuck at 0.

  • RelativeWeightDecay

    Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of norm_input argument.

  • RestartEvery

    Resets the state every n steps

  • RestartOnStuck

    Resets the state when update (difference in parameters) is zero for multiple steps in a row.

  • RestartStrategyBase

    Base class for restart strategies.

  • Rprop

    Resilient propagation. The update magnitude gets multiplied by nplus if gradient didn't change the sign,

  • SAM

    Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

  • SG2

    Second-order stochastic gradient.

  • SOAP

    SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).

  • SOAPBasis

    Run another optimizer in Shampoo eigenbases.

  • SPSA

    Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.

  • SPSA1

    One-measurement variant of SPSA. Unlike standard two-measurement SPSA, the estimated

  • SR1

    Symmetric Rank 1. This works best with a trust region:

  • SSVM

    Self-scaling variable metric Quasi-Newton method.

  • SVRG

    Stochastic variance reduced gradient method (SVRG).

  • SaveBest

    Saves best parameters found so far, ones that have lowest loss. Put this as the last module.

  • ScalarProjection

    Projection that splits all parameters into individual scalars.

  • ScaleByGradCosineSimilarity

    Multiplies the update by cosine similarity with gradient.

  • ScaleLRBySignChange

    Learning rate gets multiplied by nplus if ascent/gradient didn't change the sign,

  • ScaleModulesByCosineSimilarity

    Scales the output of main module by its cosine similarity to the output

  • ScipyMinimizeScalar

    Line search via scipy.optimize.minimize_scalar, which implements the Brent, golden-section and bounded Brent methods.

  • Sequential

    On each step, this sequentially steps with modules steps times.

  • Shampoo

    Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

  • ShorR

    Shor’s r-algorithm.

  • Sign

    Returns sign(input)

  • SignConsistencyLRs

    Outputs per-weight learning rates based on consecutive sign consistency.

  • SignConsistencyMask

    Outputs a mask of sign consistency of current and previous inputs.

  • SixthOrder3P

    Sixth-order iterative method.

  • SixthOrder3PM2

    Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45.

  • SixthOrder5P

    Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139.

  • SophiaH

    SophiaH optimizer from https://arxiv.org/abs/2305.14342

  • Split

    Apply true modules to all parameters filtered by filter, apply false modules to all other parameters.

  • Sqrt

    Returns sqrt(input)

  • SqrtEMASquared

    Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

  • SqrtHomotopy
  • SquareHomotopy
  • StepSize

    This is exactly the same as LR, except the lr parameter can be renamed to any other name to avoid clashes.

  • StrongWolfe

    Interpolation line search satisfying the strong Wolfe conditions.

  • Sub

    Subtract other from tensors. other can be a number or a module.

  • SubModules

    Calculates input - other. input and other can be numbers or modules.

  • SubspaceNewton

    Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).

  • Sum

    Outputs sum of inputs that can be modules or numbers.

  • SumOfSquares

    Sets loss to be the sum of squares of values returned by the closure.

  • Switch

    After steps steps switches to the next module.

  • TerminateAfterNEvaluations
  • TerminateAfterNSeconds
  • TerminateAfterNSteps
  • TerminateAll
  • TerminateAny
  • TerminateByGradientNorm
  • TerminateByUpdateNorm

    Update is calculated as the parameter difference.

  • TerminateNever
  • TerminateOnLossReached
  • TerminateOnNoImprovement
  • TerminationCriteriaBase
  • ThomasOptimalMethod

    Thomas's "optimal" Quasi-Newton method.

  • Threshold

    Outputs tensors thresholded such that values above threshold are set to value.

  • To

    Cast modules to specified device and dtype

  • TrustCG

    Trust region via Steihaug-Toint Conjugate Gradient method.

  • TrustRegionBase
  • TwoPointNewton

    Two-point Newton method with frozen derivative and third-order convergence.

  • UnaryLambda

    Applies fn to input tensors.

  • UnaryParameterwiseLambda

    Applies fn to each input tensor.

  • Uniform

    Outputs tensors filled with random numbers from uniform distribution between low and high.

  • UpdateGradientSignConsistency

    Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

  • UpdateSign

    Outputs gradient with sign copied from the update.

  • UpdateToNone

    Sets update attribute to None on var.

  • VectorProjection

    Projection that concatenates all parameters into a vector.

  • ViewAsReal

    Views complex tensors as real tensors. Doesn't affect tensors that are already real.

  • Warmup

    Learning rate warmup, linearly increases learning rate multiplier from start_lr to end_lr over steps steps.

  • WarmupNormClip

    Warmup via clipping of the update norm.

  • WeightDecay

    Weight decay.

  • WeightDropout

    Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.

  • WeightedAveraging

    Weighted average of past len(weights) updates.

  • WeightedMean

    Outputs weighted mean of inputs that can be modules or numbers.

  • WeightedSum

    Outputs a weighted sum of inputs that can be modules or numbers.

  • Wrap

    Wraps a pytorch optimizer to use it as a module.

  • Zeros

    Outputs zeros
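The adaptive entries in the list above (Adagrad, RMSprop, Adam) differ mainly in what they divide the gradient by. The following is a minimal scalar sketch of the textbook update rules in plain Python; it is not torchzero's implementation, and the eps placement and default hyperparameters are illustrative assumptions.

```python
import math
from typing import List


def adagrad(grads: List[float], lr: float = 1.0, eps: float = 1e-8) -> List[float]:
    """Adagrad: divide by the square root of the running sum of squared gradients."""
    acc, out = 0.0, []
    for g in grads:
        acc += g * g
        out.append(lr * g / (math.sqrt(acc) + eps))
    return out


def rmsprop(grads: List[float], lr: float = 1.0, beta: float = 0.99, eps: float = 1e-8) -> List[float]:
    """RMSprop: divide by the square root of an EMA of squared gradients (no debiasing)."""
    v, out = 0.0, []
    for g in grads:
        v = beta * v + (1 - beta) * g * g
        out.append(lr * g / (math.sqrt(v) + eps))
    return out


def adam(grads: List[float], lr: float = 1.0, beta1: float = 0.9,
         beta2: float = 0.999, eps: float = 1e-8) -> List[float]:
    """Adam: debiased gradient EMA over the square root of the debiased squared-gradient EMA."""
    m = v = 0.0
    out = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)  # debias first moment
        v_hat = v / (1 - beta2 ** t)  # debias second moment
        out.append(lr * m_hat / (math.sqrt(v_hat) + eps))
    return out
```

On the first step Adam's update is roughly sign(g), while RMSprop with beta = 0.99 takes a step about 10 times larger, because its squared-gradient EMA starts at zero and is not debiased.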

Functions:

  • clip_grad_norm_

    Clips gradient of an iterable of parameters to specified norm value.

  • clip_grad_value_

    Clips gradient of an iterable of parameters at specified value.

  • decay_weights_

    Directly decays weights in-place.

  • normalize_grads_

    Normalizes gradient of an iterable of parameters to specified norm value.

  • orthogonalize_grads_

    Computes the zeroth power / orthogonalization of gradients of an iterable of parameters.

  • orthograd_

    Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
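⟂Grad removes the component of the gradient that points along the weights. Below is a minimal plain-Python sketch of that projection for a single flattened parameter. Whether orthograd_ and OrthoGrad apply it per parameter, and whether they rescale the projected gradient afterwards, is not stated here, so the sketch covers only the projection step; the function names are hypothetical.

```python
from typing import List


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def project_orthogonal(grad: List[float], weights: List[float], eps: float = 1e-30) -> List[float]:
    """Return grad minus its component along weights: g - (w.g / w.w) * w."""
    ww = dot(weights, weights)
    if ww <= eps:
        # All-zero weights define no direction, so the gradient is returned unchanged.
        return list(grad)
    coef = dot(weights, grad) / ww
    return [g - coef * w for g, w in zip(grad, weights)]
```

A gradient that is exactly parallel to the weights is projected to zero, so the step can no longer just scale the weights up or down.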

AEGD

Bases: torchzero.core.transform.TensorTransform

AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.

Note

AEGD has a learning rate hyperparameter that can't really be removed from the update rule. To avoid compounding learning rate modifications, remove the tz.m.LR module if you had it.

Parameters:

  • lr (float, default: 0.1 ) –

    learning rate (default: 0.1)

  • c (float, default: 1 ) –

    term added to the original objective function (default: 1)

Reference

Liu, Hailiang, and Xuping Tian. "AEGD: Adaptive gradient descent with energy." arXiv preprint arXiv:2010.05109 (2020).

Source code in torchzero/modules/adaptive/aegd.py
class AEGD(TensorTransform):
    """AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.

    Note:
        AEGD has a learning rate hyperparameter that can't really be removed from the update rule.
        To avoid compounding learning rate modifications, remove the ``tz.m.LR`` module if you had it.

    Args:
        lr (float, optional): learning rate (default: 0.1)
        c (float, optional): term added to the original objective function (default: 1)

    Reference:
        [Liu, Hailiang, and Xuping Tian. "AEGD: Adaptive gradient descent with energy." arXiv preprint arXiv:2010.05109 (2020).](https://arxiv.org/pdf/2010.05109)
    """
    def __init__(
        self,
        lr: float = 0.1,
        c: float = 1,
    ):
        defaults = dict(c=c, lr=lr)
        super().__init__(defaults, uses_loss=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert loss is not None
        tensors = TensorList(tensors)

        c, lr = unpack_dicts(settings, 'c', 'lr', cls=NumberList)
        r = unpack_states(states, tensors, 'r', init=lambda t: torch.full_like(t, float(loss + c[0])**0.5), cls=TensorList)

        update = aegd_(
            f=loss,
            g=tensors,
            r_=r,
            c=c,
            eta=lr,
        )

        return update
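The body of aegd_ is not shown on this page. The following plain-Python sketch follows our reading of the elementwise AEGD rule from Liu and Tian (2020): with v = g / (2 sqrt(f + c)), the energy variable is updated as r <- r / (1 + 2 lr v^2) and the parameters as theta <- theta - 2 lr r v, starting from r = sqrt(f(theta0) + c) as in the source above. The function names are hypothetical, and torchzero's aegd_ may differ in details.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def aegd_init(theta: Vector, f: Callable[[Vector], float], c: float = 1.0) -> Vector:
    """Energy variable r starts at sqrt(f(theta) + c) for every coordinate."""
    return [math.sqrt(f(theta) + c)] * len(theta)


def aegd_step(theta: Vector, r: Vector, f: Callable[[Vector], float],
              grad: Callable[[Vector], Vector], lr: float = 0.1,
              c: float = 1.0) -> Tuple[Vector, Vector]:
    """One elementwise AEGD step (sketch of our reading of the paper)."""
    fc = f(theta) + c
    if fc <= 0:
        raise ValueError("AEGD needs f(theta) + c > 0")
    denom = 2.0 * math.sqrt(fc)
    new_theta, new_r = [], []
    for t_i, r_i, g_i in zip(theta, r, grad(theta)):
        v_i = g_i / denom
        r_next = r_i / (1.0 + 2.0 * lr * v_i * v_i)  # energy can only shrink
        new_r.append(r_next)
        new_theta.append(t_i - 2.0 * lr * r_next * v_i)
    return new_theta, new_r


# Usage: minimize f(x) = x^2 starting from x = 3.
f = lambda x: x[0] ** 2
grad = lambda x: [2.0 * x[0]]
theta = [3.0]
r = aegd_init(theta, f)
for _ in range(50):
    theta, r = aegd_step(theta, r, f, grad)
```

Because 1 + 2 lr v^2 is at least 1, r never increases, so the effective step size 2 lr r decays automatically as training proceeds.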

ASAM

Bases: torchzero.modules.adaptive.sam.SAM

Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52

SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value. It performs two forward and backward passes per step.

This implementation modifies the closure to return loss and calculate gradients of the SAM objective. All modules after this will use the modified objective.

Note

This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients at two points on each step.

Parameters:

  • rho (float, default: 0.5 ) –

    Neighborhood size. Defaults to 0.5.

  • p (float, default: 2 ) –

    norm of the SAM objective. Defaults to 2.

Examples:

ASAM-SGD:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.ASAM(),
    tz.m.LR(1e-2)
)

ASAM-Adam:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.ASAM(),
    tz.m.Adam(),
    tz.m.LR(1e-2)
)
References

Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR.

Source code in torchzero/modules/adaptive/sam.py
class ASAM(SAM):
    """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52

    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
    It performs two forward and backward passes per step.

    This implementation modifies the closure to return loss and calculate gradients
    of the SAM objective. All modules after this will use the modified objective.

    Note:
        This module requires a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients at two points on each step.

    Args:
        rho (float, optional): Neighborhood size. Defaults to 0.5.
        p (float, optional): norm of the SAM objective. Defaults to 2.

    ### Examples:

    ASAM-SGD:

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ASAM(),
        tz.m.LR(1e-2)
    )
    ```

    ASAM-Adam:

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ASAM(),
        tz.m.Adam(),
        tz.m.LR(1e-2)
    )
    ```
    References:
        [Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). ASAM: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR.](https://arxiv.org/abs/2102.11600)
    """
    def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
        super().__init__(rho=rho, p=p, eps=eps, asam=True)
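For p = 2, Kwon et al. define the adaptive ascent step as e = rho T^2 g / ||T g||, where T = diag(|w|), so the perturbation scales with each weight's magnitude. Below is a minimal plain-Python sketch of that step and of the two-gradient SAM evaluation. It uses T = diag(|w|) without any further refinements the paper may apply, and eps here is an illustrative stability constant, not necessarily what torchzero's eps argument does. The function names are hypothetical.

```python
import math
from typing import Callable, List

Vector = List[float]


def asam_perturbation(w: Vector, g: Vector, rho: float = 0.5, eps: float = 1e-12) -> Vector:
    """Adaptive SAM ascent step for p = 2: e = rho * T^2 g / ||T g||, T = diag(|w|)."""
    tg = [abs(wi) * gi for wi, gi in zip(w, g)]  # T g
    norm = math.sqrt(sum(x * x for x in tg))
    scale = rho / (norm + eps)
    return [abs(wi) * x * scale for wi, x in zip(w, tg)]  # T (T g) * rho / ||T g||


def asam_gradient(w: Vector, grad_fn: Callable[[Vector], Vector], rho: float = 0.5) -> Vector:
    """Two gradient evaluations: one at w to build e, one at w + e that is returned."""
    e = asam_perturbation(w, grad_fn(w), rho)
    return grad_fn([wi + ei for wi, ei in zip(w, e)])
```

Coordinates whose weight is zero receive no perturbation, and the perturbation lies on the boundary of the T-scaled ball, that is ||T^-1 e|| = rho up to eps. This is what makes the sharpness measure invariant to rescaling the weights.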

Abs

Bases: torchzero.core.transform.TensorTransform

Returns abs(input)

Source code in torchzero/modules/ops/unary.py
class Abs(TensorTransform):
    """Returns ``abs(input)``"""
    def __init__(self): super().__init__()
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_abs_(tensors)
        return tensors

AccumulateMaximum

Bases: torchzero.core.transform.TensorTransform

Accumulates maximum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMaximum(TensorTransform):
    """Accumulates maximum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0):
        defaults = dict(decay=decay)
        super().__init__(defaults)
        self.add_projected_keys("grad", "maximum")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return maximum.maximum_(tensors).lazy_mul(decay, clone=True)
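The arithmetic in multi_tensor_apply above can be replayed on plain lists. This sketch assumes that unpack_states zero-initializes the 'maximum' state (no init argument is passed) and that lazy_mul(..., clone=True) scales a copy, so decay affects only the returned tensors and the stored maximum itself is never decayed. The function name is hypothetical.

```python
from typing import List

Vector = List[float]


def accumulate_maximum(updates: List[Vector], decay: float = 0.0) -> List[Vector]:
    """Replay AccumulateMaximum's arithmetic over a sequence of flattened updates."""
    state = [0.0] * len(updates[0])  # assumption: state starts at zero
    outputs = []
    for u in updates:
        state = [max(s, x) for s, x in zip(state, u)]  # element-wise running maximum
        outputs.append([s * (1.0 - decay) for s in state])  # decay scales the output copy only
    return outputs
```

Under the zero-initialization assumption, a coordinate whose updates are always negative keeps outputting 0; for example, updates -3, -1, -2 produce 0 on every step.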

AccumulateMean

Bases: torchzero.core.transform.TensorTransform

Accumulates mean of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMean(TensorTransform):
    """Accumulates mean of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0):
        defaults = dict(decay=decay)
        super().__init__(defaults)
        self.add_projected_keys("grad", "mean")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
        mean = unpack_states(states, tensors, 'mean', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return mean.add_(tensors).lazy_mul(decay, clone=True).div_(step)
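Despite its name, the 'mean' state in AccumulateMean holds a running sum; the mean is formed on output by dividing by the step counter. Below is a plain-list replay of that arithmetic, assuming the state starts at zero and that lazy_mul(..., clone=True) returns a copy, so the scaling and division do not touch the stored sum. The function name is hypothetical.

```python
from typing import List

Vector = List[float]


def accumulate_mean(updates: List[Vector], decay: float = 0.0) -> List[Vector]:
    """Replay AccumulateMean's arithmetic: running sum divided by the step count."""
    total = [0.0] * len(updates[0])  # assumption: the 'mean' state starts at zero
    outputs = []
    for step, u in enumerate(updates, start=1):  # step mirrors global_state['step']
        total = [t + x for t, x in zip(total, u)]
        outputs.append([t * (1.0 - decay) / step for t in total])
    return outputs
```

Because step lives in global_state, all parameters share one counter, so with decay = 0 the output is an equally weighted mean of every past update rather than an exponential moving average.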

AccumulateMinimum

Bases: torchzero.core.transform.TensorTransform

Accumulates minimum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMinimum(TensorTransform):
    """Accumulates minimum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0):
        defaults = dict(decay=decay)
        super().__init__(defaults)
        self.add_projected_keys("grad", "minimum")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return minimum.minimum_(tensors).lazy_mul(decay, clone=True)

AccumulateProduct

Bases: torchzero.core.transform.TensorTransform

Accumulates product of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Target, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateProduct(TensorTransform):
    """Accumulates product of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        prod = unpack_states(states, tensors, 'prod', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return prod.mul_(tensors).lazy_mul(decay, clone=True)

AccumulateSum

Bases: torchzero.core.transform.TensorTransform

Accumulates sum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateSum(TensorTransform):
    """Accumulates sum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0):
        defaults = dict(decay=decay)
        super().__init__(defaults)
        self.add_projected_keys("grad", "sum")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        sum = unpack_states(states, tensors, 'sum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return sum.add_(tensors).lazy_mul(decay, clone=True)

AdGD

Bases: torchzero.core.transform.TensorTransform

AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)
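
This entry has no parameter or rule description. In multi_tensor_update below, each step estimates a local Lipschitz constant L = ||g_k - g_{k-1}|| / ||x_k - x_{k-1}|| and takes the smaller of a growth bound and a curvature bound. θ starts at 0 for variant 1 and at 1/3 for variant 2. The pure-Python sketch below restates that branch logic. `local_lipschitz`, `adgd_step_size` and `EPS` are hypothetical names, and `EPS` stands in for `torch.finfo(L.dtype).tiny * 2` at float64 precision:

```python
import math
import sys

# Hypothetical stand-in for torch.finfo(L.dtype).tiny * 2 (float64 here)
EPS = 2 * sys.float_info.min


def local_lipschitz(g, g_prev, x, x_prev):
    """L = ||g - g_prev|| / ||x - x_prev|| for flat lists of floats."""
    num = math.sqrt(sum((a - b) ** 2 for a, b in zip(g, g_prev)))
    den = math.sqrt(sum((a - b) ** 2 for a, b in zip(x, x_prev)))
    return num / den


def adgd_step_size(alpha, theta, L, variant=2, sqrt=True):
    """Return the new (alpha, theta), following AdGD.multi_tensor_update."""
    if variant == 1:
        a1 = math.sqrt(1 + theta) * alpha              # growth bound
        val = math.sqrt(2) if sqrt else 2
        a2 = 1 / (val * L) if L > EPS else math.inf    # curvature bound
    elif variant == 2:
        a1 = math.sqrt(2 / 3 + theta) * alpha
        a2 = alpha / math.sqrt(max(EPS, 2 * alpha**2 * L**2 - 1))
    else:
        raise ValueError(variant)

    alpha_new = min(a1, a2)
    if alpha_new < 0:
        alpha_new = max(a1, a2)
    if alpha_new > EPS:
        return alpha_new, alpha_new / alpha            # theta = alpha_k / alpha_{k-1}
    return alpha, theta  # the source keeps the previous values in this case
```

The module then multiplies the incoming update by the current alpha in multi_tensor_apply.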

Source code in torchzero/modules/step_size/adaptive.py
class AdGD(TensorTransform):
    """AdGD and AdGD-2 (https://arxiv.org/abs/2308.02261)"""
    def __init__(self, variant:Literal[1,2]=2, alpha_0:float = 1e-7, sqrt:bool=True, use_grad=True, inner: Chainable | None = None,):
        defaults = dict(variant=variant, alpha_0=alpha_0, sqrt=sqrt)
        super().__init__(defaults, uses_grad=use_grad, inner=inner,)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_g')
        self.global_state['reset'] = True

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        variant = settings[0]['variant']
        theta_0 = 0 if variant == 1 else 1/3
        theta = self.global_state.get('theta', theta_0)

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        p = TensorList(params)
        g = grads if self._uses_grad else tensors
        assert g is not None
        g = TensorList(g)

        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)

        # online
        if self.global_state.get('reset', False):
            del self.global_state['reset']
            prev_p.copy_(p)
            prev_g.copy_(g)
            return

        if step == 0:
            alpha_0 = settings[0]['alpha_0']
            if alpha_0 is None: alpha_0 = epsilon_step_size(g)
            self.global_state['alpha']  = alpha_0
            prev_p.copy_(p)
            prev_g.copy_(g)
            return

        sqrt = settings[0]['sqrt']
        alpha = self.global_state.get('alpha', math.inf)
        L = (g - prev_g).global_vector_norm() / (p - prev_p).global_vector_norm()
        eps = torch.finfo(L.dtype).tiny * 2

        if variant == 1:
            a1 = math.sqrt(1 + theta)*alpha
            val = math.sqrt(2) if sqrt else 2
            if L > eps: a2 = 1 / (val*L)
            else: a2 = math.inf

        elif variant == 2:
            a1 = math.sqrt(2/3 + theta)*alpha
            a2 = alpha / math.sqrt(max(eps, 2 * alpha**2 * L**2 - 1))

        else:
            raise ValueError(variant)

        alpha_new = min(a1, a2)
        if alpha_new < 0: alpha_new = max(a1, a2)
        if alpha_new > eps:
            self.global_state['theta'] = alpha_new/alpha
            self.global_state['alpha'] = alpha_new

        prev_p.copy_(p)
        prev_g.copy_(g)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', None)

        if not _acceptable_alpha(alpha, tensors[0]):
            # alpha isn't None on 1st step
            self.state.clear()
            self.global_state.clear()
            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])

        torch._foreach_mul_(tensors, alpha)
        return tensors

    def get_H(self, objective):
        return _get_scaled_identity_H(self, objective)

AdaHessian

Bases: torchzero.core.transform.Transform

AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)

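This is similar to Adam, but the second moment is replaced by an exponential moving average of squared Hessian diagonal estimates. These estimates are computed from Hessian-vector products with random vectors (Hutchinson's estimator). By default, the update is divided by the square root of this moving average.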
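
Here is a minimal pure-Python sketch of the two ingredients, a Hutchinson diagonal estimate and the Adam-like division. It assumes update_freq=1, no block averaging and hessian_power=1. The Hessian-vector product is passed in as a plain function; in the module it comes from autograd or finite differences. `hutchinson_diag` and `adahessian_update` are hypothetical names:

```python
import math
import random


def hutchinson_diag(hvp, n, n_samples=1, rng=None):
    """Estimate diag(H) as the mean of z * (H z) over Rademacher vectors z.

    hvp(v) must return H @ v as a list of floats; n is the parameter count.
    """
    rng = rng or random.Random(0)
    est = [0.0] * n
    for _ in range(n_samples):
        z = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        hz = hvp(z)
        est = [e + zi * hzi / n_samples for e, zi, hzi in zip(est, z, hz)]
    return est


def adahessian_update(g, D, state, beta1=0.9, beta2=0.999, eps=1e-8):
    """Debiased momentum divided by sqrt(debiased EMA of D**2) + eps."""
    t = state["step"] = state.get("step", 0) + 1
    m = state.setdefault("exp_avg", [0.0] * len(g))
    s = state.setdefault("D_exp_avg_sq", [0.0] * len(g))
    for i in range(len(g)):
        m[i] = beta1 * m[i] + (1 - beta1) * g[i]         # first moment
        s[i] = beta2 * s[i] + (1 - beta2) * D[i] * D[i]  # EMA of squared diagonal
    bc1, bc2 = 1 - beta1 ** t, 1 - beta2 ** t
    return [(mi / bc1) / (math.sqrt(si / bc2) + eps) for mi, si in zip(m, s)]
```

On a quadratic with a diagonal Hessian, every Rademacher probe recovers the diagonal exactly because z_i² = 1. The first update therefore equals the Newton step g / diag(H).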

Notes
  • In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply AdaHessian preconditioning to another module's output.

  • This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

Parameters:

  • beta1 (float, default: 0.9 ) –

    first momentum. Defaults to 0.9.

  • beta2 (float, default: 0.999 ) –

    second momentum for squared hessian diagonal estimates. Defaults to 0.999.

  • averaging (bool, default: True ) –

    whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions. This can be set per-parameter in param groups.

  • block_size (int, default: None ) –

    size of block in the block-diagonal averaging.

  • update_freq (int, default: 1 ) –

    frequency of updating hessian diagonal estimate via a hessian-vector product. This value can be increased to reduce computational cost. Defaults to 1.

  • eps (float, default: 1e-08 ) –

    division stability epsilon. Defaults to 1e-8.

  • hvp_method (str, default: 'autograd' ) –

    Determines how hessian-vector products are computed.

    • "batched_autograd" - uses autograd with batched hessian-vector products. If a single hessian-vector product is evaluated, this is equivalent to "autograd". Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than "batched_autograd" but uses less memory.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    Defaults to "autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of hessian-vector products with random vectors to evaluate each time when updating the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.

  • seed (int | None, default: None ) –

    seed for random vectors. Defaults to None.

  • inner (Chainable | None) –

    Inner module. If this is specified, operations are performed in the following order:

    1. compute the hessian diagonal estimate;
    2. pass inputs to inner;
    3. apply momentum and preconditioning to the outputs of inner.

Examples:

Using AdaHessian:

import torchzero as tz

opt = tz.Optimizer(
    model.parameters(),
    tz.m.AdaHessian(),
    tz.m.LR(0.1)
)

The AdaHessian preconditioner can be applied to the output of any other module by passing that module as the inner argument. Set beta1=0 to turn off AdaHessian's first momentum and keep only the preconditioning. For example, to apply AdaHessian preconditioning to Nesterov momentum (tz.m.NAG):

opt = tz.Optimizer(
    model.parameters(),
    tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
    tz.m.LR(0.1)
)

Source code in torchzero/modules/adaptive/adahessian.py
class AdaHessian(Transform):
    """AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)

    This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.

    Notes:
        - In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply AdaHessian preconditioning to another module's output.

        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        beta1 (float, optional): first momentum. Defaults to 0.9.
        beta2 (float, optional): second momentum for squared hessian diagonal estimates. Defaults to 0.999.
        averaging (bool, optional):
            whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions.
            This can be set per-parameter in param groups.
        block_size (int, optional):
            size of block in the block-diagonal averaging.
        update_freq (int, optional):
            frequency of updating hessian diagonal estimate via a hessian-vector product.
            This value can be increased to reduce computational cost. Defaults to 1.
        eps (float, optional):
            division stability epsilon. Defaults to 1e-8.
        hvp_method (str, optional):
            Determines how hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector product is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
        seed (int | None, optional): seed for random vectors. Defaults to None.
        inner (Chainable | None, optional):
            Inner module. If this is specified, operations are performed in the following order.
            1. compute hessian diagonal estimate.
            2. pass inputs to ``inner``.
            3. momentum and preconditioning are applied to the outputs of ``inner``.

    ## Examples:

    Using AdaHessian:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.AdaHessian(),
        tz.m.LR(0.1)
    )
    ```

    AdaHessian preconditioner can be applied to any other module by passing it to the ``inner`` argument.
    Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
    AdaHessian preconditioning to nesterov momentum (``tz.m.NAG``):
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
        tz.m.LR(0.1)
    )
    ```

    """
    def __init__(
        self,
        beta1: float = 0.9,
        beta2: float = 0.999,
        averaging: bool = True,
        block_size: int | None = None,
        update_freq: int = 1,
        eps: float = 1e-8,
        hessian_power: float = 1,
        distribution: Distributions = 'rademacher',
        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        n_samples = 1,
        zHz: bool = True,
        debias: bool = True,
        seed: int | None = None,

        exp_avg_tfm: Chainable | None = None,
        D_exp_avg_sq_tfm: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults["exp_avg_tfm"], defaults["D_exp_avg_sq_tfm"]
        super().__init__(defaults)

        self.set_child('exp_avg', exp_avg_tfm)
        self.set_child('D_exp_avg_sq', D_exp_avg_sq_tfm)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        params = objective.params

        beta1, beta2, averaging, block_size = unpack_dicts(settings, 'beta1', 'beta2', 'averaging', 'block_size', cls=NumberList)

        exp_avg, D_exp_avg_sq = unpack_states(states, params, 'exp_avg', 'D_exp_avg_sq', cls=TensorList)

        # ---------------------------- hutchinson hessian ---------------------------- #
        fs = settings[0]
        step = self.increment_counter("step", start=0) # 0 on 1st update
        update_freq = fs['update_freq']

        if step % update_freq == 0:
            self.increment_counter("num_Ds", start=1)

            D, _ = objective.hutchinson_hessian(
                rgrad = None,
                at_x0 = True,
                n_samples = fs['n_samples'],
                distribution = fs['distribution'],
                hvp_method = fs['hvp_method'],
                h = fs['h'],
                zHz = fs["zHz"],
                generator = self.get_generator(params[0].device, fs["seed"]),
            )

            D = TensorList(D).zipmap_args(_block_average, block_size, averaging)
            D_exp_avg_sq.mul_(beta2).addcmul_(D, D, value=1-beta2)

        # --------------------------------- momentum --------------------------------- #
        tensors = objective.get_updates() # do this after hutchinson to not disturb autograd
        exp_avg.lerp_(tensors, 1-beta1)


    @torch.no_grad
    def apply_states(self, objective, states, settings):
        params = objective.params

        beta1, beta2, eps, hessian_power = unpack_dicts(settings, 'beta1', 'beta2', 'eps', 'hessian_power', cls=NumberList)
        exp_avg, D_exp_avg_sq = unpack_states(states, params, 'exp_avg', 'D_exp_avg_sq', cls=TensorList)

        # ---------------------------------- debias ---------------------------------- #
        if settings[0]["debias"]:
            bias_correction1 = 1.0 - (beta1 ** (self.global_state["step"] + 1))
            bias_correction2 = 1.0 - (beta2 ** self.global_state["num_Ds"])
            exp_avg = exp_avg / bias_correction1
            D_exp_avg_sq = D_exp_avg_sq / bias_correction2


        # -------------------------------- transforms -------------------------------- #
        exp_avg = TensorList(self.inner_step_tensors(
            "exp_avg", tensors=exp_avg, clone=True, objective=objective, must_exist=False))

        D_exp_avg_sq = TensorList(self.inner_step_tensors(
            "D_exp_avg_sq", tensors=D_exp_avg_sq, clone=True, objective=objective, must_exist=False))

        # ------------------------------ compute update ------------------------------ #
        denom = D_exp_avg_sq.lazy_pow(hessian_power / 2) + eps
        objective.updates = exp_avg / denom
        return objective

Adagrad

Bases: torchzero.core.transform.TensorTransform

Adagrad: divides the update by the square root of the running sum of past squared gradients.

This implementation is identical to torch.optim.Adagrad.
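
The following pure-Python sketch over flat lists of floats mirrors multi_tensor_update and multi_tensor_apply below. Squared gradients are accumulated, the update is divided by √accumulator + eps, and it is then scaled by alpha / (1 + step · lr_decay), where step is 0 on the first update. `adagrad_steps` is a hypothetical name:

```python
import math


def adagrad_steps(grads, alpha=1.0, lr_decay=0.0, eps=1e-10, initial_accumulator_value=0.0):
    """Return the Adagrad update produced for each gradient in `grads`."""
    acc = [initial_accumulator_value] * len(grads[0])
    updates = []
    for step, g in enumerate(grads):  # step is 0 on the first update
        acc = [a + gi * gi for a, gi in zip(acc, g)]      # sum of squared gradients
        clr = alpha / (1 + step * lr_decay)               # decayed step size
        updates.append([clr * gi / (math.sqrt(a) + eps) for gi, a in zip(g, acc)])
    return updates
```

The step counter starts at 0, so the first update uses the full alpha. This matches torch.optim.Adagrad, which computes lr / (1 + (step - 1) · lr_decay) with step starting at 1.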

Parameters:

  • lr_decay (float, default: 0 ) –

    learning rate decay. Defaults to 0.

  • initial_accumulator_value (float, default: 0 ) –

    initial value of the sum of squares of gradients. Defaults to 0.

  • eps (float, default: 1e-10 ) –

    division epsilon. Defaults to 1e-10.

  • alpha (float, default: 1 ) –

    step size. Defaults to 1.

  • pow (float) –

    power for gradients and accumulator root. Defaults to 2.

  • use_sqrt (bool) –

    whether to take the root of the accumulator. Defaults to True.

  • inner (Chainable | None, default: None ) –

    Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.

Source code in torchzero/modules/adaptive/adagrad.py
class Adagrad(TensorTransform):
    """Adagrad, divides by sum of past squares of gradients.

    This implementation is identical to ``torch.optim.Adagrad``.

    Args:
        lr_decay (float, optional): learning rate decay. Defaults to 0.
        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
        eps (float, optional): division epsilon. Defaults to 1e-10.
        alpha (float, optional): step size. Defaults to 1.
        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
    """
    def __init__(
        self,

        # hyperparams
        lr_decay: float = 0,
        initial_accumulator_value: float = 0,
        eps: float = 1e-10,
        alpha: float = 1,

        # tfms
        inner: Chainable | None = None,
        accumulator_tfm: Chainable | None = None
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["accumulator_tfm"]
        super().__init__(defaults=defaults, inner=inner)

        self.set_child('accumulator', accumulator_tfm)
        self.add_projected_keys("grad", "accumulator")

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        state["accumulator"] = torch.full_like(tensor, fill_value=setting["initial_accumulator_value"])

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        torch._foreach_addcmul_([state["accumulator"] for state in states], tensors, tensors)
        self.increment_counter("step", start=0)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensors_ = TensorList(tensors)
        step = self.global_state["step"] # 0 on first apply
        eps, alpha, lr_decay = unpack_dicts(settings, "eps", "alpha", "lr_decay", cls=NumberList)

        accumulator = [state["accumulator"] for state in states]
        accumulator = TensorList(self.inner_step_tensors(
            "accumulator", tensors=accumulator, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

        denom = accumulator.sqrt().add_(eps)
        tensors_ /= denom

        clr = alpha / (1 + step * lr_decay)
        tensors_.lazy_mul_(clr)

        return tensors_

AdagradNorm

Bases: torchzero.core.transform.TensorTransform

Adagrad-Norm: divides the update by the square root of the running sum of past squared gradient norms, instead of per-element squared gradients as in Adagrad.
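
Here is a pure-Python sketch of the global case (layerwise=False in the source), mirroring multi_tensor_update and multi_tensor_apply below. It includes the source's beta option, which replaces the running sum with an EMA that can optionally be debiased. With layerwise=True, the same computation would run separately for each parameter tensor. `adagrad_norm_steps` is a hypothetical name:

```python
import math


def adagrad_norm_steps(grads, alpha=1.0, lr_decay=0.0, eps=1e-10, beta=None,
                       beta_debias=True, use_sqrt=True, initial_accumulator_value=0.0):
    """Return the Adagrad-Norm update for each gradient (flat lists of floats)."""
    acc = initial_accumulator_value
    updates = []
    for step, g in enumerate(grads):  # step is 0 on the first update
        gg = sum(gi * gi for gi in g)  # squared norm of the whole gradient
        if beta is None:
            acc += gg                                  # running sum
        else:
            acc = beta * acc + (1 - beta) * gg         # EMA, like acc.lerp_(gg, 1 - beta)
        a = acc
        if beta is not None and beta_debias:
            a = acc / (1 - beta ** (step + 1))         # debias the EMA
        denom = math.sqrt(a) + eps if use_sqrt else a + eps
        clr = alpha / (1 + step * lr_decay)
        updates.append([clr * gi / denom for gi in g])
    return updates
```

Because a single scalar divides the whole gradient, the update direction is unchanged and only its length adapts.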

Parameters:

  • lr_decay (float, default: 0 ) –

    learning rate decay. Defaults to 0.

  • initial_accumulator_value (float, default: 0 ) –

    initial value of the sum of squares of gradients. Defaults to 0.

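  • eps (float, default: 1e-10 ) –

    division epsilon. Defaults to 1e-10.

  • beta (float | None, default: None ) –

    if specified, the accumulator is an exponential moving average of squared gradient norms with this decay rate instead of a running sum. Defaults to None.

  • beta_debias (bool, default: True ) –

    whether to debias the accumulator when beta is specified. Defaults to True.

  • layerwise (bool, default: True ) –

    whether to keep a separate accumulator for each parameter tensor; if False, a single accumulator tracks the squared norm of the gradient across all parameters. Defaults to True.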

  • alpha (float, default: 1 ) –

    step size. Defaults to 1.

  • use_sqrt (bool, default: True ) –

    whether to take the root of the accumulator. Defaults to True.

  • inner (Chainable | None, default: None ) –

    Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.

Source code in torchzero/modules/adaptive/adagrad.py
class AdagradNorm(TensorTransform):
    """Adagrad-Norm, divides by sum of past means of squares of gradients.

    Args:
        lr_decay (float, optional): learning rate decay. Defaults to 0.
        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
        eps (float, optional): division epsilon. Defaults to 1e-10.
        alpha (float, optional): step size. Defaults to 1.
        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
    """
    def __init__(
        self,
        lr_decay: float = 0,
        initial_accumulator_value: float = 0,
        eps: float = 1e-10,
        beta:float | None = None,
        beta_debias: bool = True,
        layerwise: bool = True,
        use_sqrt: bool = True,
        alpha: float = 1,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner']
        super().__init__(defaults=defaults, inner=inner)

    @torch.no_grad
    def multi_tensor_initialize(self, tensors, params, grads, loss, states, settings):

        # layerwise initialize in each state
        if settings[0]["layerwise"]:
            for tensor, state, setting in zip(tensors, states, settings):

                initial_accumulator_value = setting["initial_accumulator_value"]
                state["accumulator"] = torch.tensor(initial_accumulator_value, device=tensor.device, dtype=tensor.dtype)

        # global initialize in global state
        else:
            initial_accumulator_value = settings[0]["initial_accumulator_value"]
            tensor = tensors[0]
            self.global_state["accumulator"] = torch.tensor(initial_accumulator_value, device=tensor.device, dtype=tensor.dtype)

    def _get_accumulator(self, states, settings) -> torch.Tensor | TensorList:
        layerwise = settings[0]["layerwise"]
        if layerwise:
            return TensorList(s["accumulator"] for s in states)

        return self.global_state["accumulator"]

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        accumulator = self._get_accumulator(states, settings)
        self.increment_counter("step", start=0)

        # compute squared gradient norm (gg)
        if isinstance(accumulator, TensorList): gg = tensors.tensorwise_dot(tensors)
        else: gg = tensors.dot(tensors)

        # update the accumulator
        beta = settings[0]["beta"]
        if beta is None: accumulator.add_(gg) # pyright:ignore[reportArgumentType]
        else: accumulator.lerp_(gg, weight=1-beta) # pyright:ignore[reportArgumentType, reportCallIssue]

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        accumulator = self._get_accumulator(states, settings)
        eps, alpha, lr_decay = unpack_dicts(settings, "eps", "alpha", "lr_decay", cls=NumberList)
        step = self.global_state["step"] # 0 on 1st step
        fs = settings[0]
        beta = fs["beta"]

        # ------------------------ debias if beta is not None ------------------------ #
        if fs["beta_debias"] and beta is not None:
            accumulator = accumulator / (1 - beta ** (step + 1))


        # ---------------------------- compute denominator --------------------------- #
        if fs["use_sqrt"]:
            denom = accumulator.sqrt().add_(eps) # pyright:ignore[reportArgumentType]
        else:
            denom = accumulator + eps # pyright:ignore[reportOperatorIssue]


        # ---------------------------- compute the update ---------------------------- #
        tensors /= denom
        clr = alpha / (1 + step * lr_decay) # lr decay
        tensors.lazy_mul_(clr)

        return tensors

Adam

Bases: torchzero.core.transform.TensorTransform

Adam. Divides the EMA of gradients by the square root of the EMA of squared gradients, with a debiased step size.

This implementation is identical to torch.optim.Adam.
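
Here is a pure-Python sketch of the update that multi_tensor_apply returns. The helper debiased_step_size is not shown in the source. As an assumption, this sketch uses α_t = α·√(1 − β2^t)/(1 − β1^t), the form given in the Adam paper. `adam_steps` is a hypothetical name:

```python
import math


def adam_steps(grads, beta1=0.9, beta2=0.999, eps=1e-8, alpha=1.0):
    """Return the Adam update for each gradient (flat lists of floats)."""
    n = len(grads[0])
    m = [0.0] * n
    v = [0.0] * n
    updates = []
    for t, g in enumerate(grads, start=1):
        m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]        # exp_avg
        v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]   # exp_avg_sq
        # assumed debiased step size: alpha * sqrt(1 - beta2^t) / (1 - beta1^t)
        alpha_t = alpha * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
        updates.append([alpha_t * mi / (math.sqrt(vi) + eps) for mi, vi in zip(m, v)])
    return updates
```

In this sketch, eps is added to the non-debiased √v. torch.optim.Adam adds eps after dividing √v by √(1 − β2^t), so the two can differ slightly when √v is comparable to eps.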

Parameters:

  • beta1 (float, default: 0.9 ) –

    momentum. Defaults to 0.9.

  • beta2 (float, default: 0.999 ) –

    second momentum. Defaults to 0.999.

  • eps (float, default: 1e-08 ) –

    epsilon. Defaults to 1e-8.

  • alpha (float, default: 1.0 ) –

    learning rate. Defaults to 1.

  • amsgrad (bool, default: False ) –

    Whether to divide by the running maximum of the EMA of squared gradients instead (AMSGrad). Defaults to False.

  • pow (float) –

    power used in second momentum power and root. Defaults to 2.

  • debias (bool, default: True ) –

    whether to apply debiasing to momentums based on current step. Defaults to True.

Source code in torchzero/modules/adaptive/adam.py
class Adam(TensorTransform):
    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.

    This implementation is identical to :code:`torch.optim.Adam`.

    Args:
        beta1 (float, optional): momentum. Defaults to 0.9.
        beta2 (float, optional): second momentum. Defaults to 0.999.
        eps (float, optional): epsilon. Defaults to 1e-8.
        alpha (float, optional): learning rate. Defaults to 1.
        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
        pow (float, optional): power used in second momentum power and root. Defaults to 2.
        debias (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
    """
    def __init__(
        self,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        amsgrad: bool = False,
        alpha: float = 1.,
        debias: bool = True,

        exp_avg_tfm: Chainable | None = None,
        exp_avg_sq_tfm: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults["exp_avg_tfm"], defaults["exp_avg_sq_tfm"]
        super().__init__(defaults)

        self.set_child('exp_avg', exp_avg_tfm)
        self.set_child('exp_avg_sq', exp_avg_sq_tfm)

        self.add_projected_keys("grad", "exp_avg")
        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        self.increment_counter("step", start=0)
        beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)

        # ----------------------------- initialize states ---------------------------- #
        if settings[0]["amsgrad"]:
            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(
                states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        # ------------------------------ update moments ------------------------------ #
        exp_avg.lerp_(tensors, weight=1-beta1)
        exp_avg_sq.mul_(beta2).addcmul_(tensors, tensors, value=1-beta2)

        if max_exp_avg_sq is not None:
            max_exp_avg_sq.maximum_(exp_avg_sq)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state["step"] # 0 on 1st step
        fs = settings[0]

        if fs["amsgrad"]: key = "max_exp_avg_sq"
        else: key = "exp_avg_sq"
        exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', key, cls=TensorList)
        beta1, beta2, alpha, eps = unpack_dicts(settings, 'beta1', 'beta2', 'alpha', 'eps', cls=NumberList)

        # -------------------------------- transforms -------------------------------- #
        exp_avg = TensorList(self.inner_step_tensors(
            "exp_avg", tensors=exp_avg, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

        exp_avg_sq = TensorList(self.inner_step_tensors(
            "exp_avg_sq", tensors=exp_avg_sq, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

        # ---------------------------------- debias ---------------------------------- #
        if fs["debias"]:
            alpha = debiased_step_size((step + 1), beta1=beta1, beta2=beta2, alpha=alpha)
            exp_avg = exp_avg * alpha

        # ---------------------------------- update ---------------------------------- #
        return exp_avg / exp_avg_sq.sqrt().add_(eps)

Adan

Bases: torchzero.core.transform.TensorTransform

Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677

Parameters:

  • beta1 (float, default: 0.98 ) –

    momentum. Defaults to 0.98.

  • beta2 (float, default: 0.92 ) –

    momentum for gradient differences. Defaults to 0.92.

  • beta3 (float, default: 0.99 ) –

    third (squared) momentum. Defaults to 0.99.

  • eps (float, default: 1e-08 ) –

    epsilon. Defaults to 1e-8.

Example:

import torchzero as tz

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adan(),
    tz.m.LR(1e-3),
)
Reference: Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence.
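
Here is a minimal pure-Python sketch of the Adan recurrences from the paper, written with this module's β convention. The defaults 0.98, 0.92 and 0.99 appear to be one minus the paper's 0.02, 0.08 and 0.01. The sketch makes three assumptions. Weight decay is omitted. The first step uses a zero gradient difference. No step-dependent bias correction is applied, although the source passes step to adan_update_ and adan_apply_, whose bodies are not shown. `adan_steps` is a hypothetical name:

```python
import math


def adan_steps(grads, beta1=0.98, beta2=0.92, beta3=0.99, eps=1e-8, lr=1.0):
    """Return the Adan update for each gradient (flat lists of floats)."""
    n = len(grads[0])
    m, v, nsq = [0.0] * n, [0.0] * n, [0.0] * n
    g_prev = list(grads[0])  # assumption: zero gradient difference on step 1
    updates = []
    for g in grads:
        diff = [gi - pi for gi, pi in zip(g, g_prev)]
        m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]          # gradient EMA
        v = [beta2 * vi + (1 - beta2) * di for vi, di in zip(v, diff)]       # difference EMA
        nsq = [beta3 * ni + (1 - beta3) * (gi + beta2 * di) ** 2             # squared EMA
               for ni, gi, di in zip(nsq, g, diff)]
        updates.append([lr * (mi + beta2 * vi) / (math.sqrt(ni) + eps)
                        for mi, vi, ni in zip(m, v, nsq)])
        g_prev = list(g)
    return updates
```

The gradient-difference term v adds a Nesterov-style look-ahead correction to the plain momentum m.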

Source code in torchzero/modules/adaptive/adan.py
class Adan(TensorTransform):
    """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677

    Args:
        beta1 (float, optional): momentum. Defaults to 0.98.
        beta2 (float, optional): momentum for gradient differences. Defaults to 0.92.
        beta3 (float, optional): third (squared) momentum. Defaults to 0.99.
        eps (float, optional): epsilon. Defaults to 1e-8.

    Example:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adan(),
        tz.m.LR(1e-3),
    )
    ```
    Reference:
        [Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence](https://arxiv.org/abs/2208.06677).
    """
    def __init__(
        self,
        beta1: float = 0.98,
        beta2: float = 0.92,
        beta3: float = 0.99,
        eps: float = 1e-8,

        m_tfm: Chainable | None = None,
        v_tfm: Chainable | None = None,
        n_tfm: Chainable | None = None,
    ):
        defaults=dict(beta1=beta1, beta2=beta2, beta3=beta3, eps=eps)
        super().__init__(defaults, uses_grad=False)

        self.set_child("m", m_tfm)
        self.set_child("v", v_tfm)
        self.set_child("n", n_tfm)

        self.add_projected_keys("grad_sq", "m", "v", "g_prev")
        self.add_projected_keys("grad", "n")

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        step = self.increment_counter("step", start=0)

        beta1, beta2, beta3 = unpack_dicts(settings, 'beta1','beta2','beta3', cls=NumberList)
        g_prev, m, v, n = unpack_states(states, tensors, 'g_prev', 'm', 'v', 'n', cls=TensorList)

        adan_update_(g=tensors, g_prev_=g_prev, m_=m, v_=v, n_=n, beta1=beta1, beta2=beta2, beta3=beta3, step=step+1)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        step = self.global_state["step"] # 0 on 1st step

        beta1, beta2, beta3, eps = unpack_dicts(settings, 'beta1','beta2','beta3', 'eps', cls=NumberList)
        m, v, n = unpack_states(states, tensors, 'm', 'v', 'n')

        # -------------------------------- transforms -------------------------------- #
        m = TensorList(self.inner_step_tensors("m", m, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
        v = TensorList(self.inner_step_tensors("v", v, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
        n = TensorList(self.inner_step_tensors("n", n, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

        # ---------------------------------- update ---------------------------------- #
        return adan_apply_(m_=m, v_=v, n_=n, beta1=beta1, beta2=beta2, beta3=beta3, eps=eps, step=step+1)
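The three buffers above are m (gradient EMA), v (EMA of gradient differences) and n (EMA of the squared look-ahead gradient). The minimal pure-Python sketch below shows how they combine in one Adan step on a flat list of floats. It assumes the form used by the paper's reference PyTorch implementation: β close to 1, a bias-corrected m + β2·v divided by sqrt(n) + eps, and a zero gradient difference on the first step. `adan_step` is a hypothetical helper, not torchzero's `adan_update_`/`adan_apply_`, whose bodies are not shown here. The learning rate (tz.m.LR in the example above) and weight decay are left out.

```python
import math


def adan_step(g, state, beta1=0.98, beta2=0.92, beta3=0.99, eps=1e-8):
    """One Adan update direction for a flat list of gradient values (hypothetical helper).

    `state` is a dict that must be passed back in on every call.
    """
    if not state:
        # first call: previous gradient = current one, so the first difference is zero
        state.update(step=0, g_prev=list(g), m=[0.0] * len(g), v=[0.0] * len(g), n=[0.0] * len(g))
    state["step"] += 1
    t = state["step"]
    bc1, bc2, bc3 = 1 - beta1**t, 1 - beta2**t, 1 - beta3**t  # bias corrections

    direction = []
    for i, gi in enumerate(g):
        diff = gi - state["g_prev"][i]
        state["m"][i] = beta1 * state["m"][i] + (1 - beta1) * gi     # EMA of gradients
        state["v"][i] = beta2 * state["v"][i] + (1 - beta2) * diff   # EMA of gradient differences
        z = gi + beta2 * diff                                        # look-ahead gradient
        state["n"][i] = beta3 * state["n"][i] + (1 - beta3) * z * z  # EMA of squared look-ahead gradient
        numer = state["m"][i] / bc1 + beta2 * state["v"][i] / bc2
        denom = math.sqrt(state["n"][i] / bc3) + eps
        direction.append(numer / denom)
    state["g_prev"] = list(g)
    return direction


state = {}
print(adan_step([1.0, -2.0], state))  # approximately [1.0, -1.0]
```

With a constant gradient, v stays at zero and the bias-corrected ratio m / sqrt(n) is about ±1 per coordinate, so the sketch behaves like Adam. v only contributes once the gradient changes between steps.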

AdaptiveBacktracking

Bases: torchzero.modules.line_search.line_search.LineSearchBase

Adaptive backtracking line search. After each line search procedure, a new initial step size is set such that the optimal step size in the procedure would be found on the second line search iteration.

Parameters:

  • init (float, default: 1.0 ) –

    initial step size. Defaults to 1.0.

  • beta (float, default: 0.5 ) –

    multiplies each consecutive step size by this value. Defaults to 0.5.

  • c (float, default: 0.0001 ) –

    sufficient decrease condition. Defaults to 1e-4.

  • condition (Literal, default: 'armijo' ) –

    termination condition. Only conditions that do not use the gradient at f(x+a*d) can be specified:

      - "armijo" - sufficient decrease condition.
      - "decrease" - any decrease in objective function value satisfies the condition.

    "goldstein" can technically be specified, but it doesn't make sense because there is no zoom stage. Defaults to 'armijo'.

  • maxiter (int, default: 20 ) –

    maximum number of function evaluations per step. Defaults to 20.

  • target_iters (int, default: 1 ) –

    sets next step size such that this number of iterations are expected to be performed until optimal step size is found. Defaults to 1.

  • nplus (float, default: 2.0 ) –

    if initial step size is optimal, it is multiplied by this value. Defaults to 2.0.

  • scale_beta (float, default: 0.0 ) –

    momentum for initial step size, at 0 disables momentum. Defaults to 0.0.

Source code in torchzero/modules/line_search/backtracking.py
class AdaptiveBacktracking(LineSearchBase):
    """Adaptive backtracking line search. After each line search procedure, a new initial step size is set
    such that the optimal step size in the procedure would be found on the second line search iteration.

    Args:
        init (float, optional): initial step size. Defaults to 1.0.
        beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
        condition (TerminationCondition, optional):
            termination condition, only ones that do not use gradient at f(x+a*d) can be specified.
            - "armijo" - sufficient decrease condition.
            - "decrease" - any decrease in objective function value satisfies the condition.

            "goldstein" can technically be specified, but it doesn't make sense because there is no zoom stage.
            Defaults to 'armijo'.
        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 20.
        target_iters (int, optional):
            sets next step size such that this number of iterations are expected
            to be performed until optimal step size is found. Defaults to 1.
        nplus (float, optional):
            if initial step size is optimal, it is multiplied by this value. Defaults to 2.0.
        scale_beta (float, optional):
            momentum for initial step size, at 0 disables momentum. Defaults to 0.0.
    """
    def __init__(
        self,
        init: float = 1.0,
        beta: float = 0.5,
        c: float = 1e-4,
        condition: TerminationCondition = 'armijo',
        maxiter: int = 20,
        target_iters = 1,
        nplus = 2.0,
        scale_beta = 0.0,
    ):
        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta)
        super().__init__(defaults=defaults)

        self.global_state['beta_scale'] = 1.0
        self.global_state['initial_scale'] = 1.0

    def reset(self):
        super().reset()
        self.global_state['beta_scale'] = 1.0
        self.global_state['initial_scale'] = 1.0

    @torch.no_grad
    def search(self, update, var):
        init, beta, c,condition, maxiter, target_iters, nplus, scale_beta=itemgetter(
            'init','beta','c','condition', 'maxiter','target_iters','nplus','scale_beta')(self.defaults)

        objective = self.make_objective(var=var)

        # directional derivative (0 if c = 0 because it is not needed)
        if c == 0: d = 0
        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), update))

        # scale beta
        beta = beta * self.global_state['beta_scale']

        # scale step size so that decrease is expected at target_iters
        init = init * self.global_state['initial_scale']

        step_size = backtracking_line_search(objective, d, init=init, beta=beta, c=c, condition=condition, maxiter=maxiter)

        # found an alpha that reduces loss
        if step_size is not None:

            # update initial_scale
            # initial step size satisfied conditions, increase initial_scale by nplus
            if step_size == init and target_iters > 0:
                self.global_state['initial_scale'] *= nplus ** target_iters

                # clip by maximum possible value to avoid overflow exception
                self.global_state['initial_scale'] = min(
                    self.global_state['initial_scale'],
                    torch.finfo(var.params[0].dtype).max / 2,
                )

            else:
                # otherwise make initial_scale such that target_iters iterations will satisfy armijo
                init_target = step_size
                for _ in range(target_iters):
                    init_target = init_target / beta

                self.global_state['initial_scale'] = _lerp(
                    self.global_state['initial_scale'], init_target / init, 1-scale_beta
                )

            # revert beta_scale
            self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))

            return step_size

        # on fail reduce beta scale value
        self.global_state['beta_scale'] /= 1.5
        return 0
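The rule for the next initial step size can be sketched in isolation. This is a minimal sketch of the behaviour the docstring describes. `next_initial_step` is a hypothetical helper; it divides by beta**target_iters (compounding per expected rejection) and omits the scale_beta momentum, the overflow clipping, and the beta_scale adjustment on failure that the source above also applies.

```python
def next_initial_step(step_size, init, beta=0.5, target_iters=1, nplus=2.0):
    """Initial step size for the next line search (hypothetical helper).

    step_size: step size accepted by the current search.
    init: initial step size the current search started from.
    """
    if step_size == init:
        # the very first trial was accepted, so try a larger step next time
        return init * nplus ** target_iters
    # each rejection multiplies the trial by beta, so starting from
    # step_size / beta**k means acceptance is expected after k rejections
    return step_size / beta ** target_iters


print(next_initial_step(0.25, 1.0))  # 0.5
```

For example, if a search starting at 1.0 accepts 0.25 after two halvings, the next search starts at 0.5 with the default target_iters=1. If the landscape is similar, acceptance is again expected on the second trial.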

AdaptiveBisection

Bases: torchzero.modules.line_search.line_search.LineSearchBase

A line search that first evaluates the previous step size. If the objective value increased, it backtracks until the value stops decreasing; otherwise, it forward-tracks until the value stops decreasing.

Parameters:

  • init (float, default: 1.0 ) –

    initial step size. Defaults to 1.0.

  • nplus (float, default: 2 ) –

    multiplier to step size if initial step size is optimal. Defaults to 2.

  • nminus (float, default: 0.5 ) –

    multiplier to step size if initial step size is too big. Defaults to 0.5.

  • maxiter (int, default: 10 ) –

    maximum number of function evaluations per step. Defaults to 10.

  • adaptive (bool, default: True ) –

    when enabled, if line search failed, step size will continue decreasing on the next step. Otherwise it will restart the line search from init step size. Defaults to True.

Source code in torchzero/modules/line_search/adaptive.py
class AdaptiveBisection(LineSearchBase):
    """A line search that first evaluates the previous step size. If the objective value increased,
    it backtracks until the value stops decreasing; otherwise, it forward-tracks until the value stops decreasing.

    Args:
        init (float, optional): initial step size. Defaults to 1.0.
        nplus (float, optional): multiplier to step size if initial step size is optimal. Defaults to 2.
        nminus (float, optional): multiplier to step size if initial step size is too big. Defaults to 0.5.
        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
        adaptive (bool, optional):
            when enabled, if line search failed, step size will continue decreasing on the next step.
            Otherwise it will restart the line search from ``init`` step size. Defaults to True.
    """
    def __init__(
        self,
        init: float = 1.0,
        nplus: float = 2,
        nminus: float = 0.5,
        maxiter: int = 10,
        adaptive=True,
    ):
        defaults=dict(init=init,nplus=nplus,nminus=nminus,maxiter=maxiter,adaptive=adaptive)
        super().__init__(defaults=defaults)

    def reset(self):
        super().reset()

    @torch.no_grad
    def search(self, update, var):
        init, nplus, nminus, maxiter, adaptive = itemgetter(
            'init', 'nplus', 'nminus', 'maxiter', 'adaptive')(self.defaults)

        objective = self.make_objective(var=var)

        # scale a_prev
        a_prev = self.global_state.get('a_prev', init)
        if adaptive: a_prev = a_prev * self.global_state.get('init_scale', 1)

        a_init = a_prev
        if a_init < torch.finfo(var.params[0].dtype).tiny * 2:
            a_init = torch.finfo(var.params[0].dtype).max / 2

        step_size, f, niter = adaptive_bisection(
            objective,
            a_init=a_init,
            maxiter=maxiter,
            nplus=nplus,
            nminus=nminus,
        )

        # found an alpha that reduces loss
        if step_size != 0:
            assert (var.loss is None) or (math.isfinite(f) and f < var.loss)
            self.global_state['init_scale'] = 1

            # if niter == 1, forward tracking failed to decrease function value compared to f_a_prev
            if niter == 1 and step_size >= a_init: step_size *= nminus

            self.global_state['a_prev'] = step_size
            return step_size

        # on fail, shrink init_scale so the next search starts from a smaller step size
        self.global_state['init_scale'] = self.global_state.get('init_scale', 1) * nminus**maxiter
        self.global_state['a_prev'] = init
        return 0
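The backtrack-or-forward-track idea can be illustrated with a simplified sketch on a 1-D function phi(a) = f(x + a·d). `bidirectional_search` is hypothetical and does not reproduce the bookkeeping of torchzero's `adaptive_bisection`, such as the niter == 1 rule or init_scale on failure. It also evaluates phi(0.0) for self-containment, whereas the module already knows the current loss.

```python
def bidirectional_search(phi, a_init, nplus=2.0, nminus=0.5, maxiter=10):
    """Shrink or grow the step from a_init while the value keeps decreasing (hypothetical helper).

    Returns (step_size, value); (0.0, phi(0)) means no step size decreased the objective.
    """
    f0 = phi(0.0)
    best_a, best_f = a_init, phi(a_init)
    # no decrease at the previous step size: backtrack, otherwise forward-track
    factor = nminus if best_f >= f0 else nplus
    a = best_a
    for _ in range(maxiter - 1):
        a *= factor
        fa = phi(a)
        if fa < best_f:
            best_a, best_f = a, fa
        elif best_f < f0:
            break  # we already have a decrease and the value stopped improving
    if best_f < f0:
        return best_a, best_f
    return 0.0, f0


phi = lambda a: (1 - 2 * a) ** 2  # minimum at a = 0.5
print(bidirectional_search(phi, 0.125))  # forward-tracks: (0.5, 0.0)
print(bidirectional_search(phi, 4.0))    # backtracks:    (0.5, 0.0)
```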

AdaptiveHeavyBall

Bases: torchzero.core.transform.TensorTransform

Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.

Suitable for quadratic objectives with known f* (loss at minimum).

Note

The step size is determined by the algorithm, so learning rate modules shouldn't be used.

Parameters:

  • f_star (float, default: 0 ) –

    (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.

Source code in torchzero/modules/adaptive/adaptive_heavyball.py
class AdaptiveHeavyBall(TensorTransform):
    """Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.

    Suitable for quadratic objectives with known f* (loss at minimum).

    Note:
        The step size is determined by the algorithm, so learning rate modules shouldn't be used.

    Args:
        f_star (float, optional):
            (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
    """
    def __init__(self, f_star: float = 0):
        defaults = dict(f_star=f_star)
        super().__init__(defaults, uses_loss=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert loss is not None
        tensors = TensorList(tensors)
        f_star = settings[0]['f_star']

        f_prev = self.global_state.get('f_prev', None)
        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params,tensors], cls=TensorList)

        # -------------------------------- first step -------------------------------- #
        if f_prev is None:
            self.global_state['f_prev'] = loss
            h = 2*(loss - f_star) / tensors.dot(tensors)
            return h * tensors

        # ------------------------------- further steps ------------------------------ #
        update = adaptive_heavy_ball(
            f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)

        # --------------------------- store previous values -------------------------- #
        self.global_state['f_prev'] = loss
        p_prev.copy_(params)
        g_prev.copy_(tensors)

        return update
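On the first step, the source above returns h·g with h = 2(f - f*) / ‖g‖², which is a Polyak-type step. The sketch below shows only that first step; later steps go through `adaptive_heavy_ball`, which is not shown here. `polyak_first_step` is a hypothetical helper, and the sketch assumes the returned update is subtracted from the parameters, as a gradient would be. For f(x) = ‖x‖² with f* = 0, we have g = 2x and h = 1/2, so the update equals x and the step lands exactly on the minimizer.

```python
def polyak_first_step(f, f_star, g):
    """First AdaptiveHeavyBall update h * g with h = 2 (f - f*) / ||g||^2 (hypothetical helper)."""
    g_sq = sum(gi * gi for gi in g)
    h = 2.0 * (f - f_star) / g_sq
    return [h * gi for gi in g]


x = [1.0, 2.0]
f = sum(xi * xi for xi in x)     # f(x) = ||x||^2 = 5
g = [2.0 * xi for xi in x]       # gradient 2x
update = polyak_first_step(f, 0.0, g)
print([xi - ui for xi, ui in zip(x, update)])  # [0.0, 0.0]
```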

Add

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Add other to tensors. other can be a number or a module.

If other is a module, this calculates tensors + other(tensors). In both cases, other is scaled by alpha (default 1) before it is added.

Source code in torchzero/modules/ops/binary.py
class Add(BinaryOperationBase):
    """Add ``other`` to tensors. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``tensors + other(tensors)``
    """
    def __init__(self, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.defaults['alpha'])
        else: torch._foreach_add_(update, other, alpha=self.defaults['alpha'])
        return update
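The two branches of `transform` can be mirrored on plain lists to show the role of `alpha`. `add_update` is a hypothetical pure-Python stand-in for the `torch._foreach_add_` calls above. Here `other` is either a number or a list playing the role of another module's output.

```python
def add_update(update, other, alpha=1.0):
    """Pure-Python mirror of Add.transform (hypothetical helper)."""
    if isinstance(other, (int, float)):
        # scalar branch: every element gets other * alpha
        return [u + other * alpha for u in update]
    # tensor branch: element-wise update + alpha * other
    return [u + alpha * o for u, o in zip(update, other)]


print(add_update([1.0, 2.0], 3.0, alpha=2.0))    # [7.0, 8.0]
print(add_update([1.0, 2.0], [10.0, 20.0]))      # [11.0, 22.0]
```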

Alternate

Bases: torchzero.core.module.Module

Alternates between stepping with modules.

That is, with the default steps=1, the first step is performed with the first module, the second step with the second module, and so on.

Parameters:

  • steps (int | Iterable[int], default: 1 ) –

    number of steps to perform with each module. Defaults to 1.

Examples:

Alternate between Adam, SignSGD and RMSprop

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Alternate(
        tz.m.Adam(),
        [tz.m.SignSGD(), tz.m.Mul(0.5)],
        tz.m.RMSprop(),
    ),
    tz.m.LR(1e-3),
)
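The counter logic in `Alternate.step` can be reproduced without any modules to see which module handles each step. `alternate_schedule` is a hypothetical helper. `loop=False` mimics `LOOP = False`, which according to the source comment is the Switch behaviour of staying on the last module.

```python
def alternate_schedule(steps, n_modules, n_total, loop=True):
    """Module index Alternate would use on each of n_total steps (hypothetical helper)."""
    if isinstance(steps, int):
        steps = [steps] * n_modules
    idx, steps_to_next = 0, steps[0]
    schedule = []
    for _ in range(n_total):
        schedule.append(idx)
        steps_to_next -= 1
        if steps_to_next == 0:
            idx += 1
            if idx > n_modules - 1:
                # loop back to the first module, or keep the last one when loop=False
                idx = 0 if loop else n_modules - 1
            steps_to_next = steps[idx]
    return schedule


print(alternate_schedule(1, 3, 7))       # [0, 1, 2, 0, 1, 2, 0]
print(alternate_schedule([2, 1], 2, 6))  # [0, 0, 1, 0, 0, 1]
```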
Source code in torchzero/modules/misc/switch.py
class Alternate(Module):
    """Alternates between stepping with ``modules``.

    That is, with the default ``steps=1``, the first step is performed with the first module,
    the second step with the second module, and so on.

    Args:
        steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.

    ### Examples:
    Alternate between Adam, SignSGD and RMSprop

    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Alternate(
            tz.m.Adam(),
            [tz.m.SignSGD(), tz.m.Mul(0.5)],
            tz.m.RMSprop(),
        ),
        tz.m.LR(1e-3),
    )
    ```
    """
    LOOP = True
    def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
        if isinstance(steps, Iterable):
            steps = list(steps)
            if len(steps) != len(modules):
                raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")

        defaults = dict(steps=steps)
        super().__init__(defaults)

        self.set_children_sequence(modules)
        self.global_state['current_module_idx'] = 0
        self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps

    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    @torch.no_grad
    def step(self, objective):
        # get current module
        current_module_idx = self.global_state.setdefault('current_module_idx', 0)
        module = self.children[f'module_{current_module_idx}']

        # step
        objective = module.step(objective.clone(clone_updates=False))

        # number of steps until next module
        steps = self.defaults['steps']
        if isinstance(steps, int): steps = [steps]*len(self.children)

        if 'steps_to_next' not in self.global_state:
            self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps

        self.global_state['steps_to_next'] -= 1

        # switch to next module
        if self.global_state['steps_to_next'] == 0:
            self.global_state['current_module_idx'] += 1

            # loop to first module (or keep using last module on Switch)
            if self.global_state['current_module_idx'] > len(self.children) - 1:
                if self.LOOP: self.global_state['current_module_idx'] = 0
                else: self.global_state['current_module_idx'] = len(self.children) - 1

            self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]

        return objective

LOOP class-attribute

LOOP = True

If True, Alternate loops back to the first module after the last one. If False (as on Switch), it keeps using the last module.

Averaging

Bases: torchzero.core.transform.TensorTransform

Average of past history_size updates.

Parameters:

  • history_size (int) –

    Number of past updates to average.

Source code in torchzero/modules/momentum/averaging.py
class Averaging(TensorTransform):
    """Average of past ``history_size`` updates.

    Args:
        history_size (int): Number of past updates to average.
    """
    def __init__(self, history_size: int):
        defaults = dict(history_size=history_size)
        super().__init__(defaults=defaults)

        self.add_projected_keys("grad", "history", "average")

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        history_size = setting['history_size']
        if 'history' not in state:
            state['history'] = deque(maxlen=history_size)
            state['average'] = torch.zeros_like(tensor)

        history = state['history']; average = state['average']
        if len(history) == history_size: average -= history[0]
        history.append(tensor)
        average += tensor

        return average / len(history)
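The running-sum trick above subtracts the value that is about to fall out of the bounded deque, then adds the new one. This keeps each step O(1) instead of re-summing the whole history. A scalar sketch of the same bookkeeping, with `MovingAverage` as a hypothetical class:

```python
from collections import deque


class MovingAverage:
    """Mean of the last `history_size` values, kept as a running sum (hypothetical helper)."""

    def __init__(self, history_size):
        self.history = deque(maxlen=history_size)
        self.total = 0.0

    def __call__(self, x):
        if len(self.history) == self.history.maxlen:
            self.total -= self.history[0]  # this value is evicted by the append below
        self.history.append(x)
        self.total += x
        return self.total / len(self.history)


avg = MovingAverage(2)
print([avg(x) for x in [1.0, 2.0, 3.0, 4.0]])  # [1.0, 1.5, 2.5, 3.5]
```

The running sum can accumulate rounding error over very long runs. A tensor version would also need to store copies if the stored tensors could later be modified in place.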

BBStab

Bases: torchzero.core.transform.TensorTransform

Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).

This clips the norm of the Barzilai-Borwein update by delta, where delta can be adaptive if c is specified.

Parameters:

  • c (float, default: 0.2 ) –

    adaptive delta parameter. If delta is None, the first inf_iters updates are performed with the non-stabilized Barzilai-Borwein step size. After that, delta is set to c times the smallest update norm observed during those updates. Defaults to 0.2.

  • delta (float | None, default: None ) –

    Barzilai-Borwein update is clipped to this value. Set to None to use an adaptive choice. Defaults to None.

  • type (str, default: 'geom' ) –

    one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long. Defaults to "geom". Note that "long" corresponds to BB1stab and "short" to BB2stab, however I found that "geom" works really well.

  • inner (Chainable | None, default: None ) –

    step size will be applied to outputs of this module. Defaults to None.

Source code in torchzero/modules/step_size/adaptive.py
class BBStab(TensorTransform):
    """Stabilized Barzilai-Borwein method (https://arxiv.org/abs/1907.06409).

    This clips the norm of the Barzilai-Borwein update by ``delta``, where ``delta`` can be adaptive if ``c`` is specified.

    Args:
        c (float, optional):
            adaptive delta parameter. If ``delta`` is None, the first ``inf_iters`` updates are performed
            with the non-stabilized Barzilai-Borwein step size. After that, ``delta`` is set to ``c`` times
            the smallest update norm observed during those updates. Defaults to 0.2.
        delta (float | None, optional):
            Barzilai-Borwein update is clipped to this value. Set to ``None`` to use an adaptive choice. Defaults to None.
        type (str, optional):
            one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
            Defaults to "geom". Note that "long" corresponds to BB1stab and "short" to BB2stab,
            however I found that "geom" works really well.
        inner (Chainable | None, optional):
            step size will be applied to outputs of this module. Defaults to None.

    """
    def __init__(
        self,
        c=0.2,
        delta:float | None = None,
        type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
        alpha_0: float = 1e-7,
        use_grad=True,
        inf_iters: int = 3,
        inner: Chainable | None = None,
    ):
        defaults = dict(type=type,alpha_0=alpha_0, c=c, delta=delta, inf_iters=inf_iters)
        super().__init__(defaults, uses_grad=use_grad, inner=inner)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_g')
        self.global_state['reset'] = True

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
        type = self.defaults['type']
        c = self.defaults['c']
        delta = self.defaults['delta']
        inf_iters = self.defaults['inf_iters']

        g = grads if self._uses_grad else tensors
        assert g is not None
        g = TensorList(g)

        reset = self.global_state.get('reset', False)
        self.global_state.pop('reset', None)

        if step != 0 and not reset:
            s = params-prev_p
            y = g-prev_g
            sy = s.dot(y)
            eps = torch.finfo(sy.dtype).tiny

            if type == 'short': alpha = _bb_short(s, y, sy, eps)
            elif type == 'long': alpha = _bb_long(s, y, sy, eps)
            elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
            elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
            else: raise ValueError(type)

            if alpha is not None:

                # adaptive delta
                if delta is None:
                    niters = self.global_state.get('niters', 0) # this accounts for skipped negative curvature steps
                    self.global_state['niters'] = niters + 1


                    if niters == 0: pass # 1st iteration is scaled GD step, shouldn't be used to find s_norm_min
                    elif niters <= inf_iters:
                        s_norm_min = self.global_state.get('s_norm_min', None)
                        if s_norm_min is None: s_norm_min = s.global_vector_norm()
                        else: s_norm_min = min(s_norm_min, s.global_vector_norm())
                        self.global_state['s_norm_min'] = s_norm_min
                        # first few steps use delta=inf, so delta remains None

                    else:
                        delta = c * self.global_state['s_norm_min']

                if delta is None: # delta is inf for first few steps
                    self.global_state['alpha'] = alpha

                # BBStab step size
                else:
                    a_stab = delta / g.global_vector_norm()
                    self.global_state['alpha'] = min(alpha, a_stab)

        prev_p.copy_(params)
        prev_g.copy_(g)

    def get_H(self, objective):
        return _get_scaled_identity_H(self, objective)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', None)

        if not _acceptable_alpha(alpha, tensors[0]):
            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])

        torch._foreach_mul_(tensors, alpha)
        return tensors
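The step-size formulas named in the `type` argument, and the BBStab clipping `min(alpha, delta / ||g||)` from `multi_tensor_update`, can be checked by hand on a small quadratic. The sketch uses s = (1, 1) and y = A·s with A = diag(1, 4). `bb_step_size` and `bbstab_step_size` are hypothetical helpers. Returning None on non-positive curvature is an assumption modelled on the `alpha is not None` check above, since `_bb_short`, `_bb_long` and `_bb_geom` are not shown.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def bb_step_size(s, y, kind="geom"):
    """Barzilai-Borwein step size from s = x_k - x_{k-1}, y = g_k - g_{k-1} (hypothetical helper)."""
    sy = dot(s, y)
    if sy <= 0:
        return None  # non-positive curvature: no usable BB step
    if kind == "long":   # BB1: s^T s / s^T y
        return dot(s, s) / sy
    if kind == "short":  # BB2: s^T y / y^T y
        return sy / dot(y, y)
    if kind == "geom":   # geometric mean of long and short, equal to ||s|| / ||y||
        return math.sqrt(dot(s, s) / dot(y, y))
    raise ValueError(kind)


def bbstab_step_size(alpha, g, delta):
    """Clip the step size so that ||alpha * g|| <= delta."""
    return min(alpha, delta / math.sqrt(dot(g, g)))


s, y = [1.0, 1.0], [1.0, 4.0]
print(bb_step_size(s, y, "long"), bb_step_size(s, y, "short"))  # 0.4 and 5/17
print(bbstab_step_size(0.4, [3.0, 4.0], delta=1.0))             # 0.2, since ||g|| = 5
```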

BFGS

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Broyden–Fletcher–Goldfarb–Shanno quasi-Newton method. This is usually the most stable quasi-Newton method.

Note

A line search or a trust region is recommended.

Warning

This uses at least O(N^2) memory, where N is the number of parameters.

Parameters:

  • init_scale (float | Literal['auto'], default: 'auto' ) –

    the initial Hessian approximation is set to the identity times this value.

    "auto" corresponds to a heuristic from Nocedal & Wright, Numerical Optimization, pp. 142-143.

    Defaults to "auto".

  • tol (float, default: 1e-32 ) –

    tolerance on curvature condition. Defaults to 1e-32.

  • ptol (float | None, default: 1e-32 ) –

    skips update if maximum difference between current and previous gradients is less than this, to avoid instability. Defaults to 1e-32.

  • ptol_restart (bool, default: False ) –

    whether to reset the Hessian approximation when ptol tolerance is not met. Defaults to False.

  • restart_interval (int | None | Literal['auto'], default: None ) –

    interval between resetting the Hessian approximation.

    "auto" corresponds to number of decision variables + 1.

    None - no resets.

    Defaults to None.

  • beta (float | None, default: None ) –

    momentum on H or B. Defaults to None.

  • update_freq (int, default: 1 ) –

    frequency of updating H or B. Defaults to 1.

  • scale_first (bool, default: False ) –

    whether to downscale the first step before the Hessian approximation becomes available. Defaults to False.

  • scale_second (bool) –

    whether to downscale second step. Defaults to False.

  • concat_params (bool, default: True ) –

    If true, all parameters are treated as a single vector. If False, the update rule is applied to each parameter separately. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.

Examples:

BFGS with backtracking line search:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.BFGS(),
    tz.m.Backtracking()
)

BFGS with trust region

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
)

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class BFGS(_InverseHessianUpdateStrategyDefaults):
    """Broyden–Fletcher–Goldfarb–Shanno quasi-Newton method. This is usually the most stable quasi-Newton method.

    Note:
        A line search or a trust region is recommended.

    Warning:
        This uses at least O(N^2) memory, where N is the number of parameters.

    Args:
        init_scale (float | Literal["auto"], optional):
            the initial Hessian approximation is set to the identity times this value.

            "auto" corresponds to a heuristic from Nocedal & Wright, Numerical Optimization, pp. 142-143.

            Defaults to "auto".
        tol (float, optional):
            tolerance on curvature condition. Defaults to 1e-32.
        ptol (float | None, optional):
            skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
            Defaults to 1e-32.
        ptol_restart (bool, optional): whether to reset the Hessian approximation when ptol tolerance is not met. Defaults to False.
        restart_interval (int | None | Literal["auto"], optional):
            interval between resetting the Hessian approximation.

            "auto" corresponds to number of decision variables + 1.

            None - no resets.

            Defaults to None.
        beta (float | None, optional): momentum on H or B. Defaults to None.
        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
        scale_first (bool, optional):
            whether to downscale the first step before the Hessian approximation becomes available. Defaults to False.
        scale_second (bool, optional): whether to downscale second step. Defaults to False.
        concat_params (bool, optional):
            If true, all parameters are treated as a single vector.
            If False, the update rule is applied to each parameter separately. Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    ## Examples:

    BFGS with backtracking line search:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.BFGS(),
        tz.m.Backtracking()
    )
    ```

    BFGS with trust region
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
    )
    ```
    """

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return bfgs_B_(B=B, s=s, y=y, tol=setting['tol'])
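For reference, the textbook BFGS inverse-Hessian update (Nocedal & Wright) is H₊ = (I - ρ s yᵀ) H (I - ρ y sᵀ) + ρ s sᵀ with ρ = 1 / (yᵀs). It enforces the secant condition H₊ y = s. The pure-Python sketch below applies it to a 2×2 example. It is not torchzero's `bfgs_H_`, whose body is not shown. Skipping the update when yᵀs ≤ tol is an assumed reading of "tolerance on curvature condition".

```python
def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def bfgs_update_H(H, s, y, tol=1e-32):
    """BFGS update of the inverse Hessian approximation H, a list of lists (hypothetical helper)."""
    ys = sum(a * b for a, b in zip(y, s))
    if ys <= tol:
        return H  # curvature condition y^T s > 0 not met: skip the update
    rho = 1.0 / ys
    n = len(s)
    left = [[float(i == j) - rho * s[i] * y[j] for j in range(n)] for i in range(n)]
    right = [[float(i == j) - rho * y[i] * s[j] for j in range(n)] for i in range(n)]
    core = matmul(matmul(left, H), right)
    return [[core[i][j] + rho * s[i] * s[j] for j in range(n)] for i in range(n)]


H0 = [[1.0, 0.0], [0.0, 1.0]]
s, y = [1.0, 1.0], [2.0, 3.0]  # y = A s for A = diag(2, 3)
H1 = bfgs_update_H(H0, s, y)
print([sum(H1[i][k] * y[k] for k in range(2)) for i in range(2)])  # approximately s = [1.0, 1.0]
```

The N×N matrix H is what makes the method O(N²) in memory, as the warning above notes.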

BacktrackOnSignChange

Bases: torchzero.core.transform.TensorTransform

Negates or undoes the update for parameters where the sign of the gradient or of the update changes.

This is part of the RProp update rule.

Parameters:

  • use_grad (bool, default: False ) –

    if True, tracks sign changes of the gradient; otherwise, tracks sign changes of the update. Defaults to False.

  • backtrack (bool, default: True ) –

    if True, undoes the update when sign changes, otherwise negates it. Defaults to True.

Source code in torchzero/modules/adaptive/rprop.py
class BacktrackOnSignChange(TensorTransform):
    """Negates or undoes the update for parameters where the sign of the gradient or of the update changes.

    This is part of the RProp update rule.

    Args:
        use_grad (bool, optional):
            if True, tracks sign changes of the gradient;
            otherwise, tracks sign changes of the update. Defaults to False.
        backtrack (bool, optional):
            if True, undoes the update when sign changes, otherwise negates it.
            Defaults to True.

    """
    def __init__(self, use_grad = False, backtrack = True):
        defaults = dict(use_grad=use_grad, backtrack=backtrack)
        super().__init__(defaults, uses_grad=use_grad)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        tensors = TensorList(tensors)
        backtrack = settings[0]['backtrack']

        if self._uses_grad:
            assert grads is not None
            cur = TensorList(grads)
        else: cur = tensors

        tensors = backtrack_on_sign_change_(
            tensors_ = tensors,
            cur = cur,
            prev_ = unpack_states(states, tensors, 'prev', cls=TensorList),
            backtrack = backtrack,
            step = step,
        )

        return tensors
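The element-wise rule can be sketched with one reading of "undo" taken from Rprop with backtracking: where the tracked sign flips, apply the negated previous update, which reverts the last step. With backtrack=False, the current update is negated instead. `sign_change_fix` is a hypothetical helper. The exact state handling in torchzero's `backtrack_on_sign_change_` is not shown here and may differ.

```python
def sign_change_fix(update, cur, prev_cur, prev_update, backtrack=True):
    """Element-wise sketch of BacktrackOnSignChange (hypothetical helper).

    cur / prev_cur: the tracked quantity (gradient or update) now and on the previous step.
    prev_update: the update applied on the previous step.
    """
    out = []
    for u, c, pc, pu in zip(update, cur, prev_cur, prev_update):
        if c * pc < 0:  # sign flipped: the last step likely overshot
            out.append(-pu if backtrack else -u)
        else:
            out.append(u)
    return out


print(sign_change_fix([1.0, 1.0], [1.0, -1.0], [1.0, 1.0], [0.5, 0.5]))  # [1.0, -0.5]
```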

Backtracking

Bases: torchzero.modules.line_search.line_search.LineSearchBase

Backtracking line search.

Parameters:

  • init (float, default: 1.0 ) –

    initial step size. Defaults to 1.0.

  • beta (float, default: 0.5 ) –

    multiplies each consecutive step size by this value. Defaults to 0.5.

  • c (float, default: 0.0001 ) –

    sufficient decrease condition. Defaults to 1e-4.

  • condition (Literal, default: 'armijo' ) –

    termination condition. Only conditions that do not use the gradient at f(x+a*d) can be specified:

      - "armijo" - sufficient decrease condition.
      - "decrease" - any decrease in objective function value satisfies the condition.

    "goldstein" can technically be specified, but it doesn't make sense because there is no zoom stage. Defaults to 'armijo'.

  • maxiter (int, default: 10 ) –

    maximum number of function evaluations per step. Defaults to 10.

  • adaptive (bool, default: True ) –

    when enabled, if line search failed, step size will continue decreasing on the next step. Otherwise it will restart the line search from init step size. Defaults to True.

Examples:

Gradient descent with backtracking line search:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Backtracking()
)

L-BFGS with backtracking line search:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LBFGS(),
    tz.m.Backtracking()
)
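To make the roles of init, beta, c and maxiter concrete, here is a minimal pure-Python sketch of backtracking with the Armijo ("armijo") condition. It is an illustration under simplifying assumptions, not torchzero's backtracking_line_search: it works on plain lists instead of tensors, takes the search direction explicitly, and returns None on failure. The name armijo_backtracking is hypothetical.

```python
from typing import Callable, Optional, Sequence


def armijo_backtracking(
    f: Callable[[Sequence[float]], float],
    grad: Sequence[float],
    x: Sequence[float],
    direction: Sequence[float],
    init: float = 1.0,
    beta: float = 0.5,
    c: float = 1e-4,
    maxiter: int = 10,
) -> Optional[float]:
    """Return a step size a with f(x + a*d) <= f(x) + c*a*g^T d, or None if none is found."""
    f0 = f(x)
    # directional derivative g^T d (negative when d is a descent direction)
    slope = sum(g * d for g, d in zip(grad, direction))
    a = init
    for _ in range(maxiter):  # one function evaluation per trial step size
        x_new = [xi + a * di for xi, di in zip(x, direction)]
        if f(x_new) <= f0 + c * a * slope:  # sufficient decrease (Armijo) condition
            return a
        a *= beta  # shrink the step and try again
    return None


def sphere(v: Sequence[float]) -> float:
    return sum(vi * vi for vi in v)


# steepest descent direction on f(x) = ||x||^2 from x = (1, 1); the gradient there is (2, 2)
print(armijo_backtracking(sphere, [2.0, 2.0], [1.0, 1.0], [-2.0, -2.0]))  # prints 0.5
```

The trial step a = 1 overshoots to (-1, -1), where the loss is unchanged, so the step is halved once. On failure, the source below (with adaptive=True) starts the next search from a smaller step by multiplying the initial-step scale by beta**maxiter.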

Source code in torchzero/modules/line_search/backtracking.py
class Backtracking(LineSearchBase):
    """Backtracking line search.

    Args:
        init (float, optional): initial step size. Defaults to 1.0.
        beta (float, optional): multiplies each consecutive step size by this value. Defaults to 0.5.
        c (float, optional): sufficient decrease condition. Defaults to 1e-4.
        condition (TerminationCondition, optional):
            termination condition, only ones that do not use gradient at f(x+a*d) can be specified.
            - "armijo" - sufficient decrease condition.
            - "decrease" - any decrease in objective function value satisfies the condition.

            "goldstein" can technically be specified, but it doesn't make sense because there is no zoom stage.
            Defaults to 'armijo'.
        maxiter (int, optional): maximum number of function evaluations per step. Defaults to 10.
        adaptive (bool, optional):
            when enabled, if line search failed, step size will continue decreasing on the next step.
            Otherwise it will restart the line search from ``init`` step size. Defaults to True.

    Examples:
    Gradient descent with backtracking line search:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Backtracking()
    )
    ```

    L-BFGS with backtracking line search:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LBFGS(),
        tz.m.Backtracking()
    )
    ```

    """
    def __init__(
        self,
        init: float = 1.0,
        beta: float = 0.5,
        c: float = 1e-4,
        condition: TerminationCondition = 'armijo',
        maxiter: int = 10,
        adaptive=True,
    ):
        defaults=dict(init=init,beta=beta,c=c,condition=condition,maxiter=maxiter,adaptive=adaptive)
        super().__init__(defaults=defaults)

    def reset(self):
        super().reset()

    @torch.no_grad
    def search(self, update, var):
        init, beta, c, condition, maxiter, adaptive = itemgetter(
            'init', 'beta', 'c', 'condition', 'maxiter', 'adaptive')(self.defaults)

        objective = self.make_objective(var=var)

        # # directional derivative
        if c == 0: d = 0
        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), var.get_updates()))

        # scale init
        init_scale = self.global_state.get('init_scale', 1)
        if adaptive: init = init * init_scale

        step_size = backtracking_line_search(objective, d, init=init, beta=beta,c=c, condition=condition, maxiter=maxiter)

        # found an alpha that reduces loss
        if step_size is not None:
            #self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
            self.global_state['init_scale'] = 1
            return step_size

        # on fail set init_scale to continue decreasing the step size
        # or set to large step size when it becomes too small
        if adaptive:
            finfo = torch.finfo(var.params[0].dtype)
            if init_scale <= finfo.tiny * 2:
                self.global_state["init_scale"] = init * 2
            else:
                self.global_state['init_scale'] = init_scale * beta**maxiter
        return 0

BarzilaiBorwein

Bases: torchzero.core.transform.TensorTransform

Barzilai-Borwein step size method.

Parameters:

  • type (str, default: 'geom' ) –

    one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long. Defaults to "geom".

  • alpha_0 (float, default: 1e-07 ) –

    used to compute the step size on the first step and whenever the Barzilai-Borwein step size is not acceptable (for example on negative curvature). Defaults to 1e-7.

  • inner (Chainable | None, default: None ) –

    step size will be applied to outputs of this module. Defaults to None.
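The three type options can be compared on a toy problem with a short pure-Python sketch. It assumes s = x_k - x_{k-1} and y = g_k - g_{k-1}, as in the source below; bb_step_sizes is a hypothetical helper, and its fallback argument is simply a constant returned when sᵀy is not positive, which is a simplification of how the module picks its fallback.

```python
import math
from typing import Dict, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def bb_step_sizes(s: Sequence[float], y: Sequence[float], fallback: float = 1e-3) -> Dict[str, float]:
    """Barzilai-Borwein step sizes from s = x_k - x_{k-1} and y = g_k - g_{k-1}."""
    sy = dot(s, y)
    if sy <= 0:
        # non-positive curvature along s: the BB formulas are not usable
        return {"short": fallback, "long": fallback, "geom": fallback}
    short = sy / dot(y, y)          # s^T y / y^T y
    long = dot(s, s) / sy           # s^T s / s^T y
    geom = math.sqrt(short * long)  # geometric mean, equal to ||s|| / ||y||
    return {"short": short, "long": long, "geom": geom}


# quadratic with Hessian diag(1, 4): a step s = (1, 1) changes the gradient by y = (1, 4)
print(bb_step_sizes([1.0, 1.0], [1.0, 4.0]))
```

By the Cauchy-Schwarz inequality the short step never exceeds the long one, and "geom" always lies between them.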

Source code in torchzero/modules/step_size/adaptive.py
class BarzilaiBorwein(TensorTransform):
    """Barzilai-Borwein step size method.

    Args:
        type (str, optional):
            one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
            Defaults to "geom".
        alpha_0 (float, optional): used to compute the step size on the first step and whenever the Barzilai-Borwein step size is not acceptable (for example on negative curvature). Defaults to 1e-7.
        inner (Chainable | None, optional):
            step size will be applied to outputs of this module. Defaults to None.
    """

    def __init__(
        self,
        type: Literal["long", "short", "geom", "geom-fallback"] = "geom",
        alpha_0: float = 1e-7,
        use_grad=True,
        inner: Chainable | None = None,
    ):
        defaults = dict(type=type, alpha_0=alpha_0)
        super().__init__(defaults, uses_grad=use_grad, inner=inner)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_g')
        self.global_state['reset'] = True

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
        type = self.defaults['type']

        g = grads if self._uses_grad else tensors
        assert g is not None

        reset = self.global_state.get('reset', False)
        self.global_state.pop('reset', None)

        if step != 0 and not reset:
            s = params-prev_p
            y = g-prev_g
            sy = s.dot(y)
            eps = torch.finfo(sy.dtype).tiny * 2

            if type == 'short': alpha = _bb_short(s, y, sy, eps)
            elif type == 'long': alpha = _bb_long(s, y, sy, eps)
            elif type == 'geom': alpha = _bb_geom(s, y, sy, eps, fallback=False)
            elif type == 'geom-fallback': alpha = _bb_geom(s, y, sy, eps, fallback=True)
            else: raise ValueError(type)

            # if alpha is not None:
            self.global_state['alpha'] = alpha

        prev_p.copy_(params)
        prev_g.copy_(g)

    def get_H(self, objective):
        return _get_scaled_identity_H(self, objective)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', None)

        if not _acceptable_alpha(alpha, tensors[0]):
            alpha = epsilon_step_size(TensorList(tensors), settings[0]['alpha_0'])

        torch._foreach_mul_(tensors, alpha)
        return tensors

BinaryOperationBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for operations that use the update as the first operand. This is an abstract class: to use it, subclass it and override the transform method.

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/binary.py
class BinaryOperationBase(Module, ABC):
    """Base class for operations that use update as the first operand. This is an abstract class, subclass it and override `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = {}
        for k,v in operands.items():

            if isinstance(v, (Module, Sequence)):
                self.set_child(k, v)
                self.operands[k] = self.children[k]
            else:
                self.operands[k] = v

    @abstractmethod
    def transform(self, objective: Objective, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    @torch.no_grad
    def step(self, objective: Objective) -> Objective:
        # pass cloned update to all module operands
        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()

        for k,v in self.operands.items():
            if k in self.children:
                v: Module
                updated_obj = v.step(objective.clone(clone_updates=True))
                processed_operands[k] = updated_obj.get_updates()
                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them

        transformed = self.transform(objective, update=objective.get_updates(), **processed_operands)
        objective.updates = list(transformed)
        return objective
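The control flow of step above can be mimicked in a few lines of plain Python. This is a toy re-creation, not torchzero code: ToyBinaryOperation and ToyAdd are hypothetical names, lists of floats stand in for tensor lists, and callables stand in for child modules. As in the real step, each child receives its own copy of the update, and constant operands are passed through unchanged.

```python
from abc import ABC, abstractmethod
from typing import Any, Iterable, List


class ToyBinaryOperation(ABC):
    """The update is always the first operand; other operands are constants or callables."""

    def __init__(self, **operands: Any):
        self.operands = operands

    @abstractmethod
    def transform(self, update: List[float], **operands: Any) -> Iterable[float]:
        """Apply the operation to the operands."""

    def step(self, update: List[float]) -> List[float]:
        processed = {}
        for name, op in self.operands.items():
            # a "child" gets its own copy of the update, like objective.clone(clone_updates=True)
            processed[name] = op(list(update)) if callable(op) else op
        return list(self.transform(update, **processed))


class ToyAdd(ToyBinaryOperation):
    def transform(self, update, other):
        if isinstance(other, list):  # operand produced by a child
            return [u + o for u, o in zip(update, other)]
        return [u + other for u in update]  # plain number


print(ToyAdd(other=1.0).step([1.0, 2.0]))
print(ToyAdd(other=lambda u: [2 * x for x in u]).step([1.0, 2.0]))
```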

transform

transform(objective: Objective, update: list[Tensor], **operands: Any | list[Tensor]) -> Iterable[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/binary.py
@abstractmethod
def transform(self, objective: Objective, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError

BirginMartinezRestart

Bases: torchzero.core.module.Module

The restart criterion for conjugate gradient methods designed by Birgin and Martinez.

This criterion triggers a restart when the angle between dₖ₊₁ and −gₖ₊₁ is not acute enough.

A restart clears all state of module and replaces the update with the gradient.

Parameters:

  • module (Module) –

    module to restart, should be a conjugate gradient or possibly a quasi-newton method.

  • cond (float, default: 0.001 ) –

    Restart is performed whenever dᵀg > −cond·‖d‖·‖g‖. The default condition value of 1e-3 is suggested by Birgin and Martinez.

Reference

Birgin, Ernesto G., and José Mario Martínez. "A spectral conjugate gradient method for unconstrained optimization." Applied Mathematics & Optimization 43.2 (2001): 117-128.
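The test itself is a single inequality, and a minimal pure-Python sketch makes its geometric meaning explicit. Here d is the descent direction, matching the formula above; the source below stores the update with the same sign as the gradient, which is why it negates d_g before comparing. The function name is hypothetical.

```python
import math
from typing import Sequence


def birgin_martinez_restart(d: Sequence[float], g: Sequence[float], cond: float = 1e-3) -> bool:
    """True if the descent direction d is not sufficiently downhill with respect to gradient g."""
    dg = sum(di * gi for di, gi in zip(d, g))
    d_norm = math.sqrt(sum(di * di for di in d))
    g_norm = math.sqrt(sum(gi * gi for gi in g))
    # cos(angle between d and -g) < cond  <=>  d^T g > -cond * ||d|| * ||g||
    return dg > -cond * d_norm * g_norm


print(birgin_martinez_restart([-1.0, 0.0], [1.0, 0.0]))     # steepest descent: no restart
print(birgin_martinez_restart([-0.0005, 1.0], [1.0, 0.0]))  # almost orthogonal: restart
```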

Source code in torchzero/modules/restarts/restars.py
class BirginMartinezRestart(Module):
    """The restart criterion for conjugate gradient methods designed by Birgin and Martinez.

    This criterion triggers a restart when the angle between d_{k+1} and -g_{k+1} is not acute enough.

    A restart clears all state of ``module``.

    Args:
        module (Module):
            module to restart, should be a conjugate gradient or possibly a quasi-newton method.
        cond (float, optional):
            Restart is performed whenever d^Tg > -cond*||d||*||g||.
            The default condition value of 1e-3 is suggested by Birgin and Martinez.

    Reference:
        Birgin, Ernesto G., and José Mario Martínez. "A spectral conjugate gradient method for unconstrained optimization." Applied Mathematics & Optimization 43.2 (2001): 117-128.
    """
    def __init__(self, module: Module, cond:float = 1e-3):
        defaults=dict(cond=cond)
        super().__init__(defaults)

        self.set_child("module", module)

    def update(self, objective):
        module = self.children['module']
        module.update(objective)

    def apply(self, objective):
        module = self.children['module']
        objective = module.apply(objective.clone(clone_updates=False))

        cond = self.defaults['cond']
        g = TensorList(objective.get_grads())
        d = TensorList(objective.get_updates())
        d_g = d.dot(g)
        d_norm = d.global_vector_norm()
        g_norm = g.global_vector_norm()

        # d in our case is same direction as g so it has a minus sign
        if -d_g > -cond * d_norm * g_norm:
            module.reset()
            objective.updates = g.clone()
            return objective

        return objective

BoldDriver

Bases: torchzero.core.transform.TensorTransform

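Bold driver step size: multiplies the step size by nplus if the loss decreased compared to the previous iteration, otherwise multiplies it by nminus. The step size starts at a_init (default 1e-3); nplus and nminus default to 1.1 and 0.1.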
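The step-size schedule can be traced without any tensors. This pure-Python sketch mirrors the update rule in the source below: the first call only records the loss, and a loss that does not decrease (including a tie) counts as a failure. The function name is hypothetical; in the module, the resulting step size is then multiplied into the update.

```python
from typing import List, Sequence


def bold_driver_step_sizes(losses: Sequence[float], a_init: float = 1e-3,
                           nplus: float = 1.1, nminus: float = 0.1) -> List[float]:
    """Step size in effect at each iteration, given the loss observed at that iteration."""
    alpha = a_init
    f_prev = None
    out = []
    for f in losses:
        if f_prev is not None:
            # grow after a decrease, shrink sharply otherwise (ties count as no decrease)
            alpha *= nplus if f < f_prev else nminus
        f_prev = f
        out.append(alpha)
    return out


print(bold_driver_step_sizes([3.0, 2.0, 2.5]))
```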

Source code in torchzero/modules/step_size/adaptive.py
class BoldDriver(TensorTransform):
    """Multiplies step size by ``nplus`` if loss decreased compared to last iteration, otherwise multiplies by ``nminus``."""
    def __init__(self, a_init=1e-3, nplus=1.1, nminus=0.1, inner: Chainable | None = None):
        defaults = dict(a_init=a_init, nplus=nplus, nminus=nminus)
        super().__init__(defaults, uses_loss=True, inner=inner)
        self.global_state["alpha"] = a_init

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('f_prev')

    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        fs = settings[0]
        if "f_prev" not in self.global_state:
            self.global_state["f_prev"] = tofloat(loss)
            return

        if self.global_state["f_prev"] <= loss:
            self.global_state["alpha"] *= fs["nminus"]

        else:
            self.global_state["alpha"] *= fs["nplus"]

        self.global_state["f_prev"] = tofloat(loss)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', None)

        if not _acceptable_alpha(alpha, tensors[0]):
            self.state.clear()
            self.global_state.clear()
            self.global_state["alpha"] = settings[0]["a_init"]
            alpha = epsilon_step_size(TensorList(tensors), 1e-7)

        torch._foreach_mul_(tensors, alpha)
        return tensors

    def get_H(self, objective):
        return _get_scaled_identity_H(self, objective)

BroydenBad

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Broyden's "bad" Quasi-Newton method.

Note

a trust region or an accurate line search is recommended.

Warning

this uses at least O(N^2) memory.

Reference

Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
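For reference, the textbook inverse form of Broyden's "bad" update is H₊ = H + (s − Hy)yᵀ / (yᵀy), where s is the parameter difference and y the gradient difference. The pure-Python sketch below implements that formula with nested lists and checks the secant condition H₊y = s. It is not torchzero's broyden_bad_H_ helper, whose body is not shown here, and the zero-denominator guard is an assumption of this sketch.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def matvec(A: Matrix, v: Sequence[float]) -> List[float]:
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def broyden_bad_H(H: Matrix, s: Sequence[float], y: Sequence[float], eps: float = 1e-12) -> Matrix:
    """H_new = H + (s - H y) y^T / (y^T y)."""
    yy = sum(v * v for v in y)
    if yy < eps:  # degenerate pair: skip the update
        return [row[:] for row in H]
    r = [si - hyi for si, hyi in zip(s, matvec(H, y))]
    n = len(s)
    return [[H[i][j] + r[i] * y[j] / yy for j in range(n)] for i in range(n)]


H = broyden_bad_H([[1.0, 0.0], [0.0, 1.0]], s=[1.0, 2.0], y=[3.0, 1.0])
print(matvec(H, [3.0, 1.0]))  # secant condition: H_new @ y recovers s (up to rounding)
```

The dense n×n matrix is also why this method needs at least O(N^2) memory.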

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class BroydenBad(_InverseHessianUpdateStrategyDefaults):
    """Broyden's "bad" Quasi-Newton method.

    Note:
        a trust region or an accurate line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return broyden_bad_H_(H=H, s=s, y=y)
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return broyden_bad_B_(B=B, s=s, y=y)

BroydenGood

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Broyden's "good" Quasi-Newton method.

Note

a trust region or an accurate line search is recommended.

Warning

this uses at least O(N^2) memory.

Reference

Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
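Broyden's "good" method is usually written as a rank-one update of the Hessian approximation, B₊ = B + (y − Bs)sᵀ / (sᵀs); applying the Sherman-Morrison formula gives the inverse form H₊ = H + (s − Hy)(sᵀH) / (sᵀHy). The pure-Python sketch below implements that textbook inverse form and checks the secant condition. It is not torchzero's broyden_good_H_ helper, and the small-denominator guard is an assumption of this sketch.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def matvec(A: Matrix, v: Sequence[float]) -> List[float]:
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def broyden_good_H(H: Matrix, s: Sequence[float], y: Sequence[float], eps: float = 1e-12) -> Matrix:
    """H_new = H + (s - H y)(s^T H) / (s^T H y)."""
    n = len(s)
    Hy = matvec(H, y)
    sHy = sum(si * hyi for si, hyi in zip(s, Hy))
    if abs(sHy) < eps:  # denominator too small: skip the update
        return [row[:] for row in H]
    r = [si - hyi for si, hyi in zip(s, Hy)]
    sH = [sum(s[i] * H[i][j] for i in range(n)) for j in range(n)]  # row vector s^T H
    return [[H[i][j] + r[i] * sH[j] / sHy for j in range(n)] for i in range(n)]


H = broyden_good_H([[1.0, 0.0], [0.0, 1.0]], s=[1.0, 2.0], y=[3.0, 1.0])
print(matvec(H, [3.0, 1.0]))  # secant condition: H_new @ y recovers s (up to rounding)
```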

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class BroydenGood(_InverseHessianUpdateStrategyDefaults):
    """Broyden's "good" Quasi-Newton method.

    Note:
        a trust region or an accurate line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return broyden_good_H_(H=H, s=s, y=y)
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return broyden_good_B_(B=B, s=s, y=y)

CD

Bases: torchzero.core.module.Module

Coordinate descent. Proposes a descent direction along a single coordinate. A line search such as tz.m.ScipyMinimizeScalar(maxiter=8) or a fixed step size can be used after this.

Parameters:

  • h (float, default: 0.001 ) –

    finite difference step size. Defaults to 1e-3.

  • grad (bool, default: False ) –

    if True, scales direction by gradient estimate. If False, the scale is fixed to 1. Defaults to False.

  • adaptive (bool, default: True ) –

    whether to adapt the finite difference step size; this requires an additional buffer. Defaults to True.

  • index (str, default: 'cyclic2' ) –

    index selection strategy.

      - "cyclic" - repeatedly cycles through each coordinate, e.g. 1,2,3,1,2,3,...
      - "cyclic2" - cycles forward and then backward, e.g. 1,2,3,3,2,1,1,2,3,... (default).
      - "random" - picks a coordinate at random.

  • threepoint (bool, default: True ) –

    whether to use three points (three function evaluations) to determine the descent direction. If False, uses two points, but then adaptive can't be used. Defaults to True.
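The three-point rule (threepoint=True, adaptive=False) can be sketched in pure Python. As in the source below, the returned vector is an update that gets subtracted from the parameters, so a value of -1 means "move in the +h direction". cd_update and quad are hypothetical names, and the usage loop picks coordinates cyclically with a fixed step size of 0.1 in place of a line search.

```python
from typing import Callable, List, Sequence


def cd_update(f: Callable[[Sequence[float]], float], x: Sequence[float], idx: int,
              h: float = 1e-3, grad: bool = False) -> List[float]:
    """Three-point coordinate descent update along coordinate idx (to be subtracted from x)."""
    f0 = f(x)
    xp = list(x)
    xp[idx] += h
    xn = list(x)
    xn[idx] -= h
    fp, fn = f(xp), f(xn)
    if grad:
        alpha = (fp - fn) / (2 * h)  # central-difference derivative estimate
    elif f0 < fp and f0 < fn:
        alpha = 0.0                  # x is already the best of the three points
    elif fp < fn:
        alpha = -1.0                 # +h is downhill, so the subtracted update is negative
    else:
        alpha = 1.0
    update = [0.0] * len(x)
    update[idx] = alpha
    return update


def quad(v: Sequence[float]) -> float:
    return (v[0] - 3.0) ** 2 + (v[1] + 1.0) ** 2


x = [0.0, 0.0]
for k in range(200):
    u = cd_update(quad, x, k % len(x), grad=True)
    x = [xi - 0.1 * ui for xi, ui in zip(x, u)]
print(x)  # approaches the minimizer (3, -1)
```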

Source code in torchzero/modules/zeroth_order/cd.py
class CD(Module):
    """Coordinate descent. Proposes a descent direction along a single coordinate.
    A line search such as ``tz.m.ScipyMinimizeScalar(maxiter=8)`` or a fixed step size can be used after this.

    Args:
        h (float, optional): finite difference step size. Defaults to 1e-3.
        grad (bool, optional):
            if True, scales direction by gradient estimate. If False, the scale is fixed to 1. Defaults to False.
        adaptive (bool, optional):
            whether to adapt finite difference step size, this requires an additional buffer. Defaults to True.
        index (str, optional):
            index selection strategy.
            - "cyclic" - repeatedly cycles through each coordinate, e.g. ``1,2,3,1,2,3,...``.
            - "cyclic2" - cycles forward and then backward, e.g. ``1,2,3,3,2,1,1,2,3,...`` (default).
            - "random" - picks a coordinate at random.
        threepoint (bool, optional):
            whether to use three points (three function evaluations) to determine the descent direction.
            If False, uses two points, but then ``adaptive`` can't be used. Defaults to True.
    """
    def __init__(self, h:float=1e-3, grad:bool=False, adaptive:bool=True, index:Literal['cyclic', 'cyclic2', 'random']="cyclic2", threepoint:bool=True,):
        defaults = dict(h=h, grad=grad, adaptive=adaptive, index=index, threepoint=threepoint)
        super().__init__(defaults)

    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    @torch.no_grad
    def step(self, objective):
        closure = objective.closure
        if closure is None:
            raise RuntimeError("CD requires closure")

        params = TensorList(objective.params)
        ndim = params.global_numel()

        grad_step_size = self.defaults['grad']
        adaptive = self.defaults['adaptive']
        index_strategy = self.defaults['index']
        h = self.defaults['h']
        threepoint = self.defaults['threepoint']

        # ------------------------------ determine index ----------------------------- #
        if index_strategy == 'cyclic':
            idx = self.global_state.get('idx', 0) % ndim
            self.global_state['idx'] = idx + 1

        elif index_strategy == 'cyclic2':
            idx = self.global_state.get('idx', 0)
            self.global_state['idx'] = idx + 1
            if idx >= ndim * 2:
                idx = self.global_state['idx'] = 0
            if idx >= ndim:
                idx  = (2*ndim - idx) - 1

        elif index_strategy == 'random':
            if 'generator' not in self.global_state:
                self.global_state['generator'] = random.Random(0)
            generator = self.global_state['generator']
            idx = generator.randrange(0, ndim)

        else:
            raise ValueError(index_strategy)

        # -------------------------- find descent direction -------------------------- #
        h_vec = None
        if adaptive:
            if threepoint:
                h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, h), cls=TensorList)
                h = float(h_vec.flat_get(idx))
            else:
                warnings.warn("CD adaptive=True only works with threepoint=True")

        f_0 = objective.get_loss(False)
        params.flat_set_lambda_(idx, lambda x: x + h)
        f_p = closure(False)

        # -------------------------------- threepoint -------------------------------- #
        if threepoint:
            params.flat_set_lambda_(idx, lambda x: x - 2*h)
            f_n = closure(False)
            params.flat_set_lambda_(idx, lambda x: x + h)

            if adaptive:
                assert h_vec is not None
                if f_0 <= f_p and f_0 <= f_n:
                    h_vec.flat_set_lambda_(idx, lambda x: max(x/2, 1e-10))
                else:
                    if abs(f_0 - f_n) < 1e-12 or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
                        h_vec.flat_set_lambda_(idx, lambda x: min(x*2, 1e10))

            if grad_step_size:
                alpha = (f_p - f_n) / (2*h)

            else:
                if f_0 < f_p and f_0 < f_n: alpha = 0
                elif f_p < f_n: alpha = -1
                else: alpha = 1

        # --------------------------------- twopoint --------------------------------- #
        else:
            params.flat_set_lambda_(idx, lambda x: x - h)
            if grad_step_size:
                alpha = (f_p - f_0) / h
            else:
                if f_p < f_0: alpha = -1
                else: alpha = 1

        # ----------------------------- create the update ---------------------------- #
        update = params.zeros_like()
        update.flat_set_(idx, alpha)
        objective.updates = update
        return objective

Cautious

Bases: torchzero.core.transform.TensorTransform

Handles update entries whose sign is inconsistent with the gradient sign; by default they are set to zero (see mode for the alternatives). Optionally renormalizes the update by the number of entries that are not masked. This is meant to be used after any momentum-based module.

Parameters:

  • normalize (bool, default: False ) –

    renormalize the update after masking; only has an effect when mode is 'zero'. Defaults to False.

  • eps (float, default: 1e-06 ) –

    epsilon for normalization. Defaults to 1e-6.

  • mode (str, default: 'zero' ) –

    what to do with updates with inconsistent signs.

      - "zero" - set them to zero (as in the paper).
      - "grad" - set them to the gradient (same as using update magnitude and gradient sign).
      - "backtrack" - negate them.

Examples:

Cautious Adam

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.Cautious(),
    tz.m.LR(1e-2)
)

References

Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
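The masking rule can be sketched in pure Python for the "zero" and "backtrack" modes; "grad" is left out. Two details are assumptions of this sketch rather than facts about cautious_: entries where update or gradient is exactly zero count as inconsistent, and normalize=True is read as rescaling by (number of entries) / (number of kept entries + eps). The function name is hypothetical.

```python
from typing import List, Sequence


def cautious(update: Sequence[float], grad: Sequence[float], mode: str = "zero",
             normalize: bool = False, eps: float = 1e-6) -> List[float]:
    """Handle update entries whose sign disagrees with the gradient sign."""
    agree = [u * g > 0 for u, g in zip(update, grad)]
    if mode == "zero":
        out = [u if a else 0.0 for u, a in zip(update, agree)]
        if normalize:
            # assumed normalization: total entries / kept entries
            scale = len(out) / (sum(agree) + eps)
            out = [o * scale for o in out]
        return out
    if mode == "backtrack":
        return [u if a else -u for u, a in zip(update, agree)]
    raise ValueError(f"mode not covered by this sketch: {mode}")


print(cautious([1.0, -2.0, 3.0], [0.5, 1.0, -1.0]))                    # [1.0, 0.0, 0.0]
print(cautious([1.0, -2.0, 3.0], [0.5, 1.0, -1.0], mode="backtrack"))  # [1.0, 2.0, -3.0]
```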

Source code in torchzero/modules/momentum/cautious.py
class Cautious(TensorTransform):
    """Handles update entries whose sign is inconsistent with the gradient sign; by default they are set to zero (see ``mode``).
    Optionally renormalizes the update by the number of entries that are not masked.
    This is meant to be used after any momentum-based module.

    Args:
        normalize (bool, optional):
            renormalize update after masking.
            only has effect when mode is 'zero'. Defaults to False.
        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
        mode (str, optional):
            what to do with updates with inconsistent signs.
            - "zero" - set them to zero (as in paper)
            - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
            - "backtrack" - negate them

    ## Examples:

    Cautious Adam

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.Cautious(),
        tz.m.LR(1e-2)
    )
    ```

    References:
        Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
    """

    def __init__(
        self,
        normalize=False,
        eps=1e-6,
        mode: Literal["zero", "grad", "backtrack"] = "zero",
    ):
        defaults = dict(normalize=normalize, eps=eps, mode=mode)
        super().__init__(defaults, uses_grad=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
        return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)

CautiousWeightDecay

Bases: torchzero.core.transform.TensorTransform

Cautious weight decay (https://arxiv.org/pdf/2510.12402).

Weight decay that is only applied to update entries whose sign matches the sign of the weight decay term.

Parameters:

  • weight_decay (float) –

    weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

  • target (Target) –

    what to set on var. Defaults to 'update'.

Examples:

Adam with non-decoupled cautious weight decay

opt = tz.Optimizer(
    model.parameters(),
    tz.m.CautiousWeightDecay(1e-3),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)

Adam with decoupled cautious weight decay that still scales with learning rate

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.CautiousWeightDecay(1e-3),
    tz.m.LR(1e-3)
)

Adam with fully decoupled cautious weight decay that doesn't scale with learning rate

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
    tz.m.CautiousWeightDecay(1e-6)
)
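A minimal pure-Python sketch of the masking idea, assuming the weight decay term is weight_decay * p for ord=2 and weight_decay * sign(p) for ord=1, and that it is added to the update (which is later subtracted from the parameters) only where both have the same strict sign. The exact rule inside cautious_weight_decay_ is not shown here, so treat this as an illustration; the function name is hypothetical.

```python
from typing import List, Sequence


def cautious_weight_decay(update: Sequence[float], params: Sequence[float],
                          weight_decay: float, ord: int = 2) -> List[float]:
    out = []
    for u, p in zip(update, params):
        if ord == 2:
            decay = weight_decay * p                     # gradient of (wd / 2) * p^2
        elif ord == 1:
            decay = weight_decay * ((p > 0) - (p < 0))   # wd * sign(p)
        else:
            raise ValueError("this sketch only covers ord=1 and ord=2")
        # add the decay only where the update already points the same way as the decay term
        out.append(u + decay if u * decay > 0 else u)
    return out


print(cautious_weight_decay([0.5, -0.5, 0.5], [2.0, 1.0, -2.0], weight_decay=0.1))
```

Only the first entry is decayed: the second and third updates point against their decay terms, so they are left unchanged.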

Source code in torchzero/modules/weight_decay/weight_decay.py
class CautiousWeightDecay(TensorTransform):
    """Cautious weight decay (https://arxiv.org/pdf/2510.12402).

    Weight decay but only applied to updates where update sign matches weight decay sign.

    Args:
        weight_decay (float): weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
        target (Target, optional): what to set on var. Defaults to 'update'.

    ### Examples:

    Adam with non-decoupled cautious weight decay
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.CautiousWeightDecay(1e-3),
        tz.m.Adam(),
        tz.m.LR(1e-3)
    )
    ```

    Adam with decoupled cautious weight decay that still scales with learning rate
    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.CautiousWeightDecay(1e-3),
        tz.m.LR(1e-3)
    )
    ```

    Adam with fully decoupled cautious weight decay that doesn't scale with learning rate
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-3),
        tz.m.CautiousWeightDecay(1e-6)
    )
    ```

    """
    def __init__(self, weight_decay: float, ord: int = 2):

        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)
        ord = settings[0]['ord']

        return cautious_weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)

CenteredEMASquared

Bases: torchzero.core.transform.TensorTransform

Maintains a centered exponential moving average of squared updates. It also maintains an exponential moving average of the un-squared updates, the square of which is subtracted from the EMA of squares.

Parameters:

  • beta (float, default: 0.99 ) –

    momentum value. Defaults to 0.99.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.
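For pow=2 the output is a running variance estimate: EMA of u² minus the square of the EMA of u. The pure-Python sketch below assumes the usual EMA update m ← beta·m + (1 − beta)·u and does not model amsgrad or other powers; the function name is hypothetical.

```python
from typing import List, Sequence, Tuple


def centered_ema_sq_step(
    update: Sequence[float],
    exp_avg: Sequence[float],
    exp_avg_sq: Sequence[float],
    beta: float = 0.99,
) -> Tuple[List[float], List[float], List[float]]:
    """One step with pow=2. Returns (output, new exp_avg, new exp_avg_sq)."""
    exp_avg = [beta * m + (1 - beta) * u for m, u in zip(exp_avg, update)]
    exp_avg_sq = [beta * v + (1 - beta) * u * u for v, u in zip(exp_avg_sq, update)]
    # EMA of squares minus square of EMA: a running variance estimate
    output = [v - m * m for v, m in zip(exp_avg_sq, exp_avg)]
    return output, exp_avg, exp_avg_sq


# an update that flips sign every step has ~zero mean, so nearly all of its energy is variance
m, v = [0.0], [0.0]
for t in range(200):
    out, m, v = centered_ema_sq_step([1.0 if t % 2 == 0 else -1.0], m, v, beta=0.9)
print(out)
```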

Source code in torchzero/modules/ops/higher_level.py
class CenteredEMASquared(TensorTransform):
    """Maintains a centered exponential moving average of squared updates. This also maintains an additional
    exponential moving average of un-squared updates, square of which is subtracted from the EMA.

    Args:
        beta (float, optional): momentum value. Defaults to 0.99.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2):
        defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
        super().__init__(defaults, uses_grad=False)
        self.add_projected_keys("grad", "exp_avg")
        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return centered_ema_sq_(
            TensorList(tensors),
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            max_exp_avg_sq_=max_exp_avg_sq,
            pow=pow,
        ).clone()

CenteredSqrtEMASquared

Bases: torchzero.core.transform.TensorTransform

Maintains a centered exponential moving average of squared updates and outputs its square root, optionally debiased. It also maintains an exponential moving average of the un-squared updates, the square of which is subtracted from the EMA of squares.

Parameters:

  • beta (float, default: 0.99 ) –

    momentum value. Defaults to 0.99.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • debiased (bool, default: False ) –

    whether to multiply the output by a debiasing term from the Adam method. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Source code in torchzero/modules/ops/higher_level.py
class CenteredSqrtEMASquared(TensorTransform):
    """Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.
    This also maintains an additional exponential moving average of un-squared updates, square of which is subtracted from the EMA.

    Args:
        beta (float, optional): momentum value. Defaults to 0.99.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2):
        defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
        super().__init__(defaults, uses_grad=False)
        self.add_projected_keys("grad", "exp_avg")
        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return sqrt_centered_ema_sq_(
            TensorList(tensors),
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            debiased=debiased,
            step=step,
            max_exp_avg_sq_=max_exp_avg_sq,
            pow=pow,
        )
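To illustrate the update rule in the docstring, here is a pure-Python sketch on a flat list of floats for the default pow=2. It assumes both moving averages use the same beta and that debiasing divides each average by (1 - beta**step) as in Adam. The actual sqrt_centered_ema_sq_ helper may differ in details, such as where debiasing is applied or how amsgrad interacts with centering. The function name is hypothetical.

```python
import math

def centered_sqrt_ema_sq_step(update, exp_avg, exp_avg_sq, beta, step, debiased=False):
    """One step of a centered EMA of squares (pow=2); updates the state lists in place."""
    out = []
    for i, u in enumerate(update):
        # EMA of the raw update and EMA of its square
        exp_avg[i] = beta * exp_avg[i] + (1 - beta) * u
        exp_avg_sq[i] = beta * exp_avg_sq[i] + (1 - beta) * u * u
        m, v = exp_avg[i], exp_avg_sq[i]
        if debiased:
            # Adam-style bias correction (an assumption, see the lead-in)
            m /= 1 - beta ** step
            v /= 1 - beta ** step
        # centered second moment E[u^2] - E[u]^2, clamped at 0 against round-off
        out.append(math.sqrt(max(v - m * m, 0.0)))
    return out
```

In this sketch, with debiasing enabled the first step returns zeros, because a single sample has no spread around its own mean.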

Centralize

Bases: torchzero.core.transform.TensorTransform

Centralizes the update.

Parameters:

  • dim (int | Sequence[int] | str | None, default: None ) –

    calculates the mean along those dimensions. If list/tuple, tensors are centralized along all dimensions in dim that they have. Can be set to "global" to centralize by the global mean of all gradients concatenated into a vector. Defaults to None.

  • inverse_dims (bool, default: False ) –

    if True, the dim argument is inverted, and all other dimensions are centralized.

  • min_size (int, default: 2 ) –

    minimal size of a dimension to centralize along it. Defaults to 2.

Examples:

Standard gradient centralization:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Centralize(dim=0),
    tz.m.LR(1e-2),
)

References:

  • Yong, H., Huang, J., Hua, X., & Zhang, L. (2020). Gradient centralization: A new optimization technique for deep neural networks. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing. https://arxiv.org/abs/2004.01461

Source code in torchzero/modules/clipping/clipping.py
class Centralize(TensorTransform):
    """Centralizes the update.

    Args:
        dim (int | Sequence[int] | str | None, optional):
            calculates the mean along those dimensions.
            If list/tuple, tensors are centralized along all dimensions in `dim` that they have.
            Can be set to "global" to centralize by global mean of all gradients concatenated to a vector.
            Defaults to None.
        inverse_dims (bool, optional):
            if True, the `dim` argument is inverted, and all other dimensions are centralized.
        min_size (int, optional):
            minimal size of a dimension to centralize along it. Defaults to 2.

    Examples:

    Standard gradient centralization:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Centralize(dim=0),
        tz.m.LR(1e-2),
    )
    ```

    References:
    - Yong, H., Huang, J., Hua, X., & Zhang, L. (2020). Gradient centralization: A new optimization technique for deep neural networks. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing. https://arxiv.org/abs/2004.01461
    """
    def __init__(
        self,
        dim: int | Sequence[int] | Literal["global"] | None = None,
        inverse_dims: bool = False,
        min_size: int = 2,
    ):
        defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])

        _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)

        return tensors
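The dim="global" mode described above can be sketched in pure Python. Every parameter's update is flattened and concatenated, and the single global mean is subtracted from every element. Representing each parameter as a plain list of floats is an assumption made to keep the example self-contained. This is not how torchzero stores tensors, and the function name is hypothetical.

```python
def centralize_global(updates):
    """Subtract the mean of all elements of all updates from every element."""
    # concatenate every parameter's update into one flat vector
    flat = [x for u in updates for x in u]
    mean = sum(flat) / len(flat)
    # subtract the same global mean from each element, keeping per-parameter structure
    return [[x - mean for x in u] for u in updates]
```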

Clip

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Clips tensors to be in the (min, max) range. min and max can be None, numbers or modules.

If min and max are modules, this calculates tensors.clip(min(tensors), max(tensors)).

Source code in torchzero/modules/ops/binary.py
class Clip(BinaryOperationBase):
    """Clips tensors to be in the ``(min, max)`` range. ``min`` and ``max`` can be None, numbers or modules.

    If ``min`` and ``max`` are modules, this calculates ``tensors.clip(min(tensors), max(tensors))``.
    """
    def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
        super().__init__({}, min=min, max=max)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
        return TensorList(update).clamp_(min=min,  max=max)
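Semantically, Clip is an elementwise clamp. In this pure-Python sketch, a bound may be None (unbounded), a number, or a list of per-element bounds. The list form is an assumption that stands in for a module's output. The function name is hypothetical.

```python
def clip(values, lo=None, hi=None):
    """Elementwise clamp; lo/hi may be None, a number, or a same-length list."""
    def bound(b, i):
        # None means "unbounded"; a list gives a separate bound per element
        if b is None or isinstance(b, (int, float)):
            return b
        return b[i]

    out = []
    for i, x in enumerate(values):
        low, high = bound(lo, i), bound(hi, i)
        if low is not None:
            x = max(x, low)
        if high is not None:
            x = min(x, high)  # applied last, so high wins if low > high
        out.append(x)
    return out
```

ClipModules, shown next, applies the same clamp to the output of its input module.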

ClipModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates input(tensors).clip(min, max). min and max can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class ClipModules(MultiOperationBase):
    """Calculates ``input(tensors).clip(min, max)``. ``min`` and ``max`` can be numbers or modules."""
    def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
        defaults = {}
        super().__init__(defaults, input=input, min=min, max=max)

    @torch.no_grad
    def transform(self, objective: Objective, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
        return TensorList(input).clamp_(min=min, max=max)

ClipNorm

Bases: torchzero.core.transform.TensorTransform

Clips update norm to be no larger than max_norm.

Parameters:

  • max_norm (float) –

    value to clip norm to.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • dim (int | Sequence[int] | str | None, default: None ) –

    calculates norm along those dimensions. If list/tuple, norms are calculated along all dimensions in dim that they have. Can be set to "global" to clip by the global norm of all gradients concatenated into a vector. Defaults to None.

  • inverse_dims (bool, default: False ) –

    if True, the dim argument is inverted, and norms are calculated along all other dimensions.

  • min_size (int, default: 1 ) –

    minimal number of elements in a parameter or slice to clip its norm. Defaults to 1.

  • target (str) –

    what this affects.

Examples:

Gradient norm clipping:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.ClipNorm(1),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)

Update norm clipping:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.ClipNorm(1),
    tz.m.LR(1e-2),
)

Source code in torchzero/modules/clipping/clipping.py
class ClipNorm(TensorTransform):
    """Clips update norm to be no larger than ``max_norm``.

    Args:
        max_norm (float): value to clip norm to.
        ord (float, optional): norm order. Defaults to 2.
        dim (int | Sequence[int] | str | None, optional):
            calculates norm along those dimensions.
            If list/tuple, norms are calculated along all dimensions in `dim` that they have.
            Can be set to "global" to clip by the global norm of all gradients concatenated into a vector.
            Defaults to None.
        inverse_dims (bool, optional):
            if True, the `dim` argument is inverted, and norms are calculated along all other dimensions.
        min_size (int, optional):
            minimal number of elements in a parameter or slice to clip its norm. Defaults to 1.
        target (str, optional):
            what this affects.

    Examples:

    Gradient norm clipping:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ClipNorm(1),
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )
    ```

    Update norm clipping:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.ClipNorm(1),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(
        self,
        max_norm: float,
        ord: Metrics = 2,
        dim: int | Sequence[int] | Literal["global"] | None = None,
        inverse_dims: bool = False,
        min_size: int = 1,
    ):
        defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        max_norm = NumberList(s['max_norm'] for s in settings)
        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
        _clip_norm_(
            tensors_ = TensorList(tensors),
            min = 0,
            max = max_norm,
            norm_value = None,
            ord = ord,
            dim = dim,
            inverse_dims=inverse_dims,
            min_size = min_size,
        )
        return tensors
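The two modes of ClipNorm can be sketched in pure Python for the Euclidean norm (ord=2). The first mode is per parameter, assumed here to be what the default dim=None does. The second is dim="global". Updates are rescaled only when their norm exceeds max_norm. The sketch ignores min_size and inverse_dims. The lists-of-floats representation and the function names are hypothetical.

```python
import math

def _scale(u, norm, max_norm):
    # shrink only when the norm exceeds max_norm; never scale up
    if norm > max_norm:
        return [x * max_norm / norm for x in u]
    return list(u)

def clip_norm(updates, max_norm, global_norm=False):
    """Clip the L2 norm of each update (or of all updates jointly) to max_norm."""
    if global_norm:
        # one norm over all parameters concatenated into a single vector
        total = math.sqrt(sum(x * x for u in updates for x in u))
        return [_scale(u, total, max_norm) for u in updates]
    # separate norm per parameter
    return [_scale(u, math.sqrt(sum(x * x for x in u)), max_norm) for u in updates]
```

Placing ClipNorm before Adam in the chain, as in the first example above, clips the raw gradient. Placing it after clips Adam's output.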

ClipNormByEMA

Bases: torchzero.core.transform.TensorTransform

Clips norm to be no larger than the norm of an exponential moving average of past updates.

Parameters:

  • beta (float, default: 0.99 ) –

    beta for the exponential moving average. Defaults to 0.99.

  • ord (float, default: 2 ) –

    order of the norm. Defaults to 2.

  • min_norm (float, default: 1e-06 ) –

    lower bound on the allowed norm of the exponential moving average when max_ema_growth is set. Defaults to 1e-6.

  • tensorwise (bool, default: True ) –

    if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.

  • max_ema_growth (float | None, default: 1.5 ) –

    if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.

  • init (float, default: 0.0 ) –

    value used to fill the exponential moving average before the first update. Defaults to 0.0.

Source code in torchzero/modules/clipping/ema_clipping.py
class ClipNormByEMA(TensorTransform):
    """Clips norm to be no larger than the norm of an exponential moving average of past updates.

    Args:
        beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
        ord (float, optional): order of the norm. Defaults to 2.
        min_norm (float, optional): lower bound on the allowed norm of the exponential moving average when ``max_ema_growth`` is set. Defaults to 1e-6.
        tensorwise (bool, optional):
            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
        max_ema_growth (float | None, optional):
            if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
        init (float, optional):
            value used to fill the exponential moving average before the first update. Defaults to 0.0.
    """
    NORMALIZE = False
    def __init__(
        self,
        beta=0.99,
        ord: Metrics = 2,
        tensorwise:bool=True,
        max_ema_growth: float | None = 1.5,
        init: float = 0.0,
        min_norm: float = 1e-6,

        inner: Chainable | None = None,
    ):
        defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, init=init, min_norm=min_norm, max_ema_growth=max_ema_growth)
        super().__init__(defaults, inner=inner)
        self.add_projected_keys("grad", "exp_avg")

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        eps = torch.finfo(tensors[0].dtype).tiny * 2
        ord, tensorwise, init, max_ema_growth = itemgetter('ord', 'tensorwise', 'init', 'max_ema_growth')(settings[0])

        beta, min_norm = unpack_dicts(settings, 'beta', 'min_norm', cls=NumberList)

        exp_avg = unpack_states(states, tensors, 'exp_avg', init = lambda x: torch.full_like(x, init), cls=TensorList)

        exp_avg.lerp_(tensors, 1-beta)

        # ----------------------------- tensorwise update ---------------------------- #
        if tensorwise:
            tensors_norm = tensors.norm(ord)
            ema_norm = exp_avg.metric(ord)

            # clip ema norm growth
            if max_ema_growth is not None:
                prev_ema_norm = unpack_states(states, tensors, 'prev_ema_norm', init=ema_norm, cls=TensorList)
                allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=min_norm)

                ema_denom = (ema_norm / allowed_norm).clip(min=1)
                exp_avg.div_(ema_denom)
                ema_norm.div_(ema_denom)

                prev_ema_norm.set_(ema_norm)


        # ------------------------------- global update ------------------------------ #
        else:
            tensors_norm = tensors.global_metric(ord)
            ema_norm = exp_avg.global_metric(ord)

            # clip ema norm growth
            if max_ema_growth is not None:
                prev_ema_norm = self.global_state.setdefault('prev_ema_norm', ema_norm)
                allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=min_norm[0])

                if ema_norm > allowed_norm:
                    exp_avg.div_(ema_norm / allowed_norm)
                    ema_norm = allowed_norm

                prev_ema_norm.set_(ema_norm)


        # ------------------- compute denominator to clip/normalize ------------------ #
        denom = tensors_norm / ema_norm.clip(min=eps)
        if self.NORMALIZE: denom.clip_(min=eps)
        else: denom.clip_(min=1)
        self.global_state['denom'] = denom

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        denom = self.global_state.pop('denom')
        torch._foreach_div_(tensors, denom)
        return tensors

NORMALIZE class-attribute

NORMALIZE = False

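Here is a pure-Python sketch of the global branch (tensorwise=False) of multi_tensor_update and multi_tensor_apply above. It uses ord=2 and init=0.0 on one flat list of floats. The function name clip_norm_by_ema and the dict used as state are hypothetical, but the arithmetic follows the source shown. The normalize flag plays the role of NORMALIZE. With NORMALIZE = False the divisor is clipped at 1, so updates are only shrunk. With True it is only kept positive, so an update can also be scaled up to the EMA norm.

```python
import math
import sys

def clip_norm_by_ema(update, state, beta=0.99, max_ema_growth=1.5, min_norm=1e-6, normalize=False):
    """One global ClipNormByEMA-style step on a flat list of floats (ord=2, init=0.0)."""
    # EMA of past updates, filled with init=0.0 on the first call
    exp_avg = state.setdefault("exp_avg", [0.0] * len(update))
    for i, u in enumerate(update):
        exp_avg[i] += (1 - beta) * (u - exp_avg[i])  # same as lerp_(update, 1 - beta)
    ema_norm = math.sqrt(sum(e * e for e in exp_avg))

    # limit how fast the EMA norm itself may grow between steps
    if max_ema_growth is not None:
        prev_ema_norm = state.setdefault("prev_ema_norm", ema_norm)
        allowed_norm = max(prev_ema_norm * max_ema_growth, min_norm)
        if ema_norm > allowed_norm:
            exp_avg[:] = [e * allowed_norm / ema_norm for e in exp_avg]
            ema_norm = allowed_norm
        state["prev_ema_norm"] = ema_norm

    eps = 2 * sys.float_info.min  # mirrors torch.finfo(dtype).tiny * 2 for float64
    update_norm = math.sqrt(sum(u * u for u in update))
    denom = update_norm / max(ema_norm, eps)
    # NORMALIZE=False only shrinks; NORMALIZE=True may also scale up
    denom = max(denom, eps) if normalize else max(denom, 1.0)
    return [u / denom for u in update]
```

In this sketch, with init=0.0 and beta=0.99 the EMA norm on the first step is 1% of the update norm, so the first update is scaled down to that size.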

ClipNormGrowth

Bases: torchzero.core.transform.TensorTransform

Clips update norm growth.

Parameters:

  • add (float | None, default: None ) –

    additive clipping, next update norm is at most previous norm + add. Defaults to None.

  • mul (float | None, default: 1.5 ) –

    multiplicative clipping, next update norm is at most previous norm * mul. Defaults to 1.5.

  • min_value (float | None, default: 0.0001 ) –

    minimum value for multiplicative clipping to prevent collapse to 0. Next norm is at most max(prev_norm, min_value) * mul. Defaults to 1e-4.

  • max_decay (float | None, default: 2 ) –

    bounds the tracked multiplicative clipping decay to prevent collapse to 0. Next norm is at most max(previous norm * mul, max_decay). Defaults to 2.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • tensorwise (bool, default: True ) –

    if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.

  • target (Target) –

    what to set on var. Defaults to "update".

Source code in torchzero/modules/clipping/growth_clipping.py
class ClipNormGrowth(TensorTransform):
    """Clips update norm growth.

    Args:
        add (float | None, optional): additive clipping, next update norm is at most `previous norm + add`. Defaults to None.
        mul (float | None, optional):
            multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
        min_value (float | None, optional):
            minimum value for multiplicative clipping to prevent collapse to 0.
            Next norm is at most ``max(prev_norm, min_value) * mul``. Defaults to 1e-4.
        max_decay (float | None, optional):
            bounds the tracked multiplicative clipping decay to prevent collapse to 0.
            Next norm is at most ``max(previous norm * mul, max_decay)``.
            Defaults to 2.
        ord (float, optional): norm order. Defaults to 2.
        tensorwise (bool, optional):
            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
        target (Target, optional): what to set on var. Defaults to "update".
    """
    def __init__(
        self,
        add: float | None = None,
        mul: float | None = 1.5,
        min_value: float | None = 1e-4,
        max_decay: float | None = 2,
        ord: float = 2,
        tensorwise=True,
    ):
        defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, tensorwise=tensorwise)
        super().__init__(defaults)


    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensorwise = settings[0]['tensorwise']
        tensors = TensorList(tensors)

        if tensorwise:
            ts = tensors
            stts = states
            stns = settings

        else:
            ts = [tensors.to_vec()]
            stts = [self.global_state]
            stns = [settings[0]]


        for t, state, setting in zip(ts, stts, stns):
            if 'prev_norm' not in state:
                state['prev_norm'] = torch.linalg.vector_norm(t, ord=setting['ord']) # pylint:disable=not-callable
                state['prev_denom'] = 1
                continue

            _,  state['prev_norm'], state['prev_denom'] = norm_growth_clip_(
                tensor_ = t,
                prev_norm = state['prev_norm'],
                add = setting['add'],
                mul = setting['mul'],
                min_value = setting['min_value'],
                max_decay = setting['max_decay'],
                ord = setting['ord'],
            )

        if not tensorwise:
            tensors.from_vec_(ts[0])

        return tensors
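norm_growth_clip_ itself is not shown here. The following is only a simplified pure-Python sketch of the add and mul bounds described above, on a single flat list of floats (tensorwise=False, ord=2). It assumes the clipped norm is what gets stored as prev_norm, and it leaves out max_decay entirely. The function name is hypothetical. As in the source, the first call only records the norm.

```python
import math

def clip_norm_growth(update, state, add=None, mul=1.5, min_value=1e-4):
    """Simplified norm-growth clipping on a flat list of floats (no max_decay)."""
    norm = math.sqrt(sum(x * x for x in update))
    if "prev_norm" not in state:
        # first step: nothing to compare against, just record the norm
        state["prev_norm"] = norm
        return list(update)

    prev_norm = state["prev_norm"]
    limit = math.inf
    if mul is not None:
        limit = min(limit, max(prev_norm, min_value) * mul)
    if add is not None:
        limit = min(limit, prev_norm + add)

    if norm > limit:
        # rescale so the norm equals the limit
        update = [x * limit / norm for x in update]
        norm = limit
    state["prev_norm"] = norm  # assumption: the clipped norm is tracked
    return list(update)
```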

ClipValue

Bases: torchzero.core.transform.TensorTransform

Clips update magnitude to be within (-value, value) range.

Parameters:

  • value (float) –

    value to clip to.

  • target (str) –

    refer to target argument in documentation.

Examples:

Gradient clipping:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.ClipValue(1),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)

Update clipping:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.ClipValue(1),
    tz.m.LR(1e-2),
)

Source code in torchzero/modules/clipping/clipping.py
class ClipValue(TensorTransform):
    """Clips update magnitude to be within ``(-value, value)`` range.

    Args:
        value (float): value to clip to.
        target (str): refer to ``target argument`` in documentation.

    Examples:

    Gradient clipping:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ClipValue(1),
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )
    ```

    Update clipping:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.ClipValue(1),
        tz.m.LR(1e-2),
    )
    ```

    """
    def __init__(self, value: float):
        defaults = dict(value=value)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        value = [s['value'] for s in settings]
        return TensorList(tensors).clip_([-v for v in value], value)

ClipValueByEMA

Bases: torchzero.core.transform.TensorTransform

Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.

Parameters:

  • beta (float, default: 0.99 ) –

    beta for the exponential moving average. Defaults to 0.99.

  • init (float, default: 0 ) –

    multiplier for the initial exponential moving average: on the first step it is set to abs(update) * init, so 0 starts from zeros and 1 from the magnitude of the first update. Defaults to 0.

  • exp_avg_tfm (Chainable | None, default: None ) –

    optional modules applied to exponential moving average before clipping by it. Defaults to None.

Source code in torchzero/modules/clipping/ema_clipping.py
class ClipValueByEMA(TensorTransform):
    """Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.

    Args:
        beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
        init (float, optional):
            multiplier for the initial exponential moving average: on the first step it is set to
            ``abs(update) * init``, so 0 starts from zeros and 1 from the magnitude of the first update. Defaults to 0.
        exp_avg_tfm (Chainable | None, optional):
            optional modules applied to exponential moving average before clipping by it. Defaults to None.
    """
    def __init__(
        self,
        beta=0.99,
        init: float = 0,

        inner: Chainable | None = None,
        exp_avg_tfm:Chainable | None=None,
    ):
        defaults = dict(beta=beta, init=init)
        super().__init__(defaults, inner=inner)

        self.set_child('exp_avg', exp_avg_tfm)
        self.add_projected_keys("grad", "exp_avg")

    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        state["exp_avg"] = tensor.abs() * setting["init"]

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        beta = unpack_dicts(settings, 'beta', cls=NumberList)

        exp_avg = unpack_states(states, tensors, 'exp_avg', must_exist=True, cls=TensorList)
        exp_avg.lerp_(tensors.abs(), 1-beta)

    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        exp_avg = unpack_states(states, tensors, 'exp_avg')

        exp_avg = TensorList(
            self.inner_step_tensors("exp_avg", exp_avg, clone=True, params=params, grads=grads, loss=loss, must_exist=False))

        tensors.clip_(-exp_avg, exp_avg)
        return tensors
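Here is a pure-Python sketch of multi_tensor_update followed by multi_tensor_apply above, applied per element of a flat list of floats. It omits the optional exp_avg_tfm child. The function name and dict-based state are hypothetical. The sketch assumes, as the method names suggest, that the update runs before apply. Under that assumption, the EMA already includes part of the current magnitude when the update is clipped.

```python
def clip_value_by_ema(update, state, beta=0.99, init=0.0):
    """Clip each element to +/- an EMA of past absolute values (flat lists of floats)."""
    if "exp_avg" not in state:
        # single_tensor_initialize: EMA starts at |update| * init
        state["exp_avg"] = [abs(u) * init for u in update]
    exp_avg = state["exp_avg"]
    for i, u in enumerate(update):
        exp_avg[i] += (1 - beta) * (abs(u) - exp_avg[i])  # lerp towards |update|
    # clip to [-exp_avg, exp_avg] elementwise
    return [max(-e, min(u, e)) for u, e in zip(update, exp_avg)]
```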

ClipValueGrowth

Bases: torchzero.core.transform.TensorTransform

Clips update value magnitude growth.

Parameters:

  • add (float | None, default: None ) –

    additive clipping, next update is at most previous update + add. Defaults to None.

  • mul (float | None, default: 1.5 ) –

    multiplicative clipping, next update is at most previous update * mul. Defaults to 1.5.

  • min_value (float | None, default: 0.0001 ) –

    minimum value for multiplicative clipping to prevent collapse to 0. Next update is at most max(prev_update, min_value) * mul. Defaults to 1e-4.

  • max_decay (float | None, default: 2 ) –

    bounds the tracked multiplicative clipping decay to prevent collapse to 0. Next update is at most max(previous update * mul, max_decay). Defaults to 2.

  • target (Target) –

    what to set on var. Defaults to "update".

Source code in torchzero/modules/clipping/growth_clipping.py
class ClipValueGrowth(TensorTransform):
    """Clips update value magnitude growth.

    Args:
        add (float | None, optional): additive clipping, next update is at most `previous update + add`. Defaults to None.
        mul (float | None, optional): multiplicative clipping, next update is at most `previous update * mul`. Defaults to 1.5.
        min_value (float | None, optional):
            minimum value for multiplicative clipping to prevent collapse to 0.
            Next update is at most ``max(prev_update, min_value) * mul``. Defaults to 1e-4.
        max_decay (float | None, optional):
            bounds the tracked multiplicative clipping decay to prevent collapse to 0.
            Next update is at most ``max(previous update * mul, max_decay)``.
            Defaults to 2.
        target (Target, optional): what to set on var. Defaults to "update".
    """
    def __init__(
        self,
        add: float | None = None,
        mul: float | None = 1.5,
        min_value: float | None = 1e-4,
        max_decay: float | None = 2,
    ):
        defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
        super().__init__(defaults)
        self.add_projected_keys("grad", "prev")


    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(setting)
        add: float | None

        if add is None and mul is None:
            return tensor

        if 'prev' not in state:
            state['prev'] = tensor.clone()
            return tensor

        prev: torch.Tensor = state['prev']

        # additive bound
        if add is not None:
            growth = (tensor.abs() - prev.abs()).clip(min=0)
            tensor.sub_(torch.where(growth > add, (growth-add).copysign_(tensor), 0))

        # multiplicative bound
        growth = None
        if mul is not None:
            prev_magn = prev.abs()
            if min_value is not None: prev_magn.clip_(min=min_value)
            growth = (tensor.abs() / prev_magn).clamp_(min=1e-8)

            denom = torch.where(growth > mul, growth/mul, 1)

            tensor.div_(denom)

        # limit max growth decay
        if max_decay is not None:
            if growth is None:
                prev_magn = prev.abs()
                if min_value is not None: prev_magn.clip_(min=min_value)
                growth = (tensor.abs() / prev_magn).clamp_(min=1e-8)

            new_prev = torch.where(growth < (1/max_decay), prev/max_decay, tensor)
        else:
            new_prev = tensor.clone()

        state['prev'] = new_prev
        return tensor
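For a single element, single_tensor_apply above reduces to the following pure-Python port on flat lists of floats. The function name and dict-based state are hypothetical, but the branches mirror the source one by one.

```python
import math

def clip_value_growth(update, state, add=None, mul=1.5, min_value=1e-4, max_decay=2.0):
    """Elementwise port of ClipValueGrowth.single_tensor_apply for flat lists of floats."""
    if add is None and mul is None:
        return list(update)
    if "prev" not in state:
        # first step: store the update and return it unchanged
        state["prev"] = list(update)
        return list(update)

    def growth_of(u, p):
        # ratio |u| / max(|p|, min_value), floored at 1e-8 as in the source
        prev_magn = abs(p) if min_value is None else max(abs(p), min_value)
        return max(abs(u) / prev_magn, 1e-8)

    out, new_prev = [], []
    for u, p in zip(update, state["prev"]):
        # additive bound: |u| may exceed |p| by at most `add`
        if add is not None:
            excess = max(abs(u) - abs(p), 0.0)
            if excess > add:
                u -= math.copysign(excess - add, u)

        # multiplicative bound: shrink u so that its growth is at most `mul`
        growth = None
        if mul is not None:
            growth = growth_of(u, p)
            if growth > mul:
                u /= growth / mul

        # limit decay: if u shrank a lot, decay the stored value by max_decay instead
        if max_decay is not None:
            if growth is None:
                growth = growth_of(u, p)
            new_prev.append(p / max_decay if growth < 1 / max_decay else u)
        else:
            new_prev.append(u)
        out.append(u)

    state["prev"] = new_prev
    return out
```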

Clone

Bases: torchzero.core.module.Module

Clones the input. This can be useful for storing an intermediate result and making sure it isn't affected by later in-place operations.

Source code in torchzero/modules/ops/utility.py
class Clone(Module):
    """Clones the input. This can be useful for storing an intermediate result and making sure it isn't affected by later in-place operations."""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def apply(self, objective):
        objective.updates = [u.clone() for u in objective.get_updates()]
        return objective

ConjugateDescent

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Conjugate Descent (CD).

Note

This requires step size to be determined via a line search, so put a line search like tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this.

Source code in torchzero/modules/conjugate_gradient/cg.py
class ConjugateDescent(ConguateGradientBase):
    """Conjugate Descent (CD).

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return conjugate_descent_beta(g, prev_d, prev_g)
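Conjugate Descent is usually stated with Fletcher's coefficient beta_k = ||g_k||^2 / (-d_{k-1}^T g_{k-1}) and the direction update d_k = -g_k + beta_k d_{k-1}. Below is a pure-Python sketch of that textbook rule. torchzero's conjugate_descent_beta may use a different sign convention for the stored direction, so treat this as illustrative only. The function names are hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def conjugate_descent_direction(g, prev_g, prev_d):
    """Fletcher's Conjugate Descent: beta = |g|^2 / (-prev_d . prev_g)."""
    denom = -dot(prev_d, prev_g)
    # fall back to steepest descent if the previous direction was orthogonal to prev_g
    beta = dot(g, g) / denom if denom != 0 else 0.0
    return [-gi + beta * di for gi, di in zip(g, prev_d)], beta
```

Under an exact line search, d_{k-1}^T g_{k-1} = -||g_{k-1}||^2, and CD then coincides with Fletcher-Reeves.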

CopyMagnitude

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns other(tensors) with sign copied from tensors.

Source code in torchzero/modules/ops/binary.py
class RCopySign(BinaryOperationBase):
    """Returns ``other(tensors)`` with sign copied from tensors."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [o.copysign_(u) for u, o in zip(update, other)]

CopySign

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns tensors with sign copied from other(tensors).

Source code in torchzero/modules/ops/binary.py
class CopySign(BinaryOperationBase):
    """Returns tensors with sign copied from ``other(tensors)``."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [u.copysign_(o) for u, o in zip(update, other)]
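CopySign and CopyMagnitude (implemented by RCopySign) differ only in which argument of copysign supplies the magnitude. torch.copysign(input, other), like Python's math.copysign(x, y), returns the magnitude of its first argument with the sign of its second. Here is a pure-Python sketch on flat lists of floats. The function names are hypothetical.

```python
import math

def copy_sign(update, other):
    """CopySign: magnitude from update, sign from other."""
    return [math.copysign(u, o) for u, o in zip(update, other)]

def copy_magnitude(update, other):
    """CopyMagnitude (RCopySign): magnitude from other, sign from update."""
    return [math.copysign(o, u) for u, o in zip(update, other)]
```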

CubicRegularization

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Cubic regularization.

Parameters:

  • hess_module (Module | None) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • eta (float, default: 0.0 ) –

    if the ratio of actual to predicted reduction is larger than this, the step is accepted. When hess_module is GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.0001 ) –

    if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • maxiter (int, default: 100 ) –

    maximum number of iterations when solving the cubic subproblem. Defaults to 100.

  • eps (float, default: 1e-08 ) –

    epsilon for the solver. Defaults to 1e-8.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • fallback (bool) –

    if True, when hess_module maintains hessian inverse which can't be inverted efficiently, it will be inverted anyway. When False (default), a RuntimeError will be raised instead.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.

Examples:

Cubic regularized Newton:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.CubicRegularization(tz.m.Newton()),
)

Source code in torchzero/modules/trust_region/cubic_regularization.py
class CubicRegularization(TrustRegionBase):
    """Cubic regularization.

    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the step is accepted.
            When ``hess_module`` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        maxiter (int, optional): maximum number of iterations when solving the cubic subproblem. Defaults to 100.
        eps (float, optional): epsilon for the solver. Defaults to 1e-8.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        fallback (bool, optional):
            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.


    Examples:
        Cubic regularized newton

        .. code-block:: python

            opt = tz.Optimizer(
                model.parameters(),
                tz.m.CubicRegularization(tz.m.Newton()),
            )

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        maxiter: int = 100,
        eps: float = 1e-8,
        check_decrease:bool=False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(maxiter=maxiter, eps=eps, check_decrease=check_decrease)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        params = TensorList(params)

        loss_at_params_plus_x_fn = None
        if settings['check_decrease']:
            def closure_plus_x(x):
                x_unflat = vec_to_tensors(x, params)
                params.add_(x_unflat)
                loss_x = closure(False)
                params.sub_(x_unflat)
                return loss_x
            loss_at_params_plus_x_fn = closure_plus_x


        d, _ = ls_cubic_solver(f=f, g=g, H=H, M=1/radius, loss_at_params_plus_x_fn=loss_at_params_plus_x_fn,
                               it_max=settings['maxiter'], epsilon=settings['eps'])
        return d.neg_()
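The `eta`, `nplus`, `nminus`, `rho_good` and `rho_bad` settings describe the usual trust-region ratio test. The sketch below shows that rule in plain Python, assuming the `'default'` radius strategy behaves like the textbook version (the base class code is not shown here). `update_radius` is a hypothetical helper, not part of torchzero; its defaults are copied from `CubicRegularization.__init__`.

```python
def update_radius(actual_reduction: float, predicted_reduction: float, radius: float,
                  eta: float = 0.0, nplus: float = 3.5, nminus: float = 0.25,
                  rho_good: float = 0.99, rho_bad: float = 1e-4) -> tuple[bool, float]:
    """Return (accepted, new_radius) for one trust-region iteration (hypothetical helper)."""
    # rho compares the loss decrease that actually happened with the decrease the model predicted
    if predicted_reduction <= 0:
        rho = float("-inf")  # the model promised no decrease, so treat the step as a failure
    else:
        rho = actual_reduction / predicted_reduction

    if rho < rho_bad:
        radius *= nminus  # poor agreement between model and loss: shrink the region
    elif rho > rho_good:
        radius *= nplus   # very good agreement: expand the region

    return rho > eta, radius
```

In `CubicRegularization` the radius is not a hard bound on the step: `trust_solve` passes `M=1/radius` to `ls_cubic_solver`, so a larger radius means a weaker cubic penalty.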

CustomUnaryOperation

Bases: torchzero.core.transform.TensorTransform

Calls the method name on each tensor, i.e. getattr(tensor, name)().

Source code in torchzero/modules/ops/unary.py
class CustomUnaryOperation(TensorTransform):
    """Calls the method ``name`` on each tensor, i.e. ``getattr(tensor, name)()``.
    """
    def __init__(self, name: str):
        defaults = dict(name=name)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        return getattr(tensors, settings[0]['name'])()
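`multi_tensor_apply` fetches the method called `name` from the `TensorList` with `getattr` and calls it with no arguments, so presumably any zero-argument `TensorList` method can be used. The dispatch pattern can be sketched with plain Python objects standing in for tensors; `custom_unary` is hypothetical, and `Decimal` is used only because it has zero-argument methods such as `sqrt`.

```python
from decimal import Decimal


def custom_unary(values, name: str):
    """Call the zero-argument method `name` on every element (hypothetical helper)."""
    return [getattr(v, name)() for v in values]


roots = custom_unary([Decimal(4), Decimal(9)], "sqrt")
```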

DFP

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Davidon–Fletcher–Powell Quasi-Newton method.

Note

A trust region or an accurate line search is recommended.

Warning

This uses at least O(N^2) memory.

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class DFP(_InverseHessianUpdateStrategyDefaults):
    """Davidon–Fletcher–Powell Quasi-Newton method.

    Note:
        A trust region or an accurate line search is recommended.

    Warning:
        This uses at least O(N^2) memory.
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return dfp_H_(H=H, s=s, y=y, tol=setting['tol'])
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return dfp_B(B=B, s=s, y=y, tol=setting['tol'])
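DFP keeps an approximation H of the inverse Hessian and, given the step s and the gradient change y, applies the rank-two update H_new = H + s s^T / (s^T y) - (H y)(H y)^T / (y^T H y). The dense, pure-Python sketch below uses lists of lists instead of tensors; skipping the update when a denominator falls below `tol` is an assumption about how the `tol` setting is used, and the helper names are hypothetical.

```python
def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def dfp_inverse_update(H, s, y, tol=1e-32):
    """Return the DFP update of a symmetric inverse-Hessian approximation H."""
    sy = dot(s, y)
    Hy = matvec(H, y)
    yHy = dot(y, Hy)
    if abs(sy) < tol or abs(yHy) < tol:
        return [row[:] for row in H]  # assumed safeguard: skip a degenerate update
    n = len(s)
    return [[H[i][j] + s[i] * s[j] / sy - Hy[i] * Hy[j] / yHy for j in range(n)]
            for i in range(n)]
```

The result satisfies the secant equation H_new y = s, and storing the full n-by-n matrix is where the O(N^2) memory warning comes from.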

DNRTR

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Diagonal quasi-newton method.

Reference

Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class DNRTR(HessianUpdateStrategy):
    """Diagonal quasi-newton method.

    Reference:
        Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
    """
    def __init__(
        self,
        lb: float = 1e-2,
        ub: float = 1e5,
        init_scale: float | Literal["auto"] = "auto",
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None | Literal['auto'] = None,
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        defaults = dict(lb=lb, ub=ub)
        super().__init__(
            defaults=defaults,
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=False,
            inner=inner,
        )

    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_wqc_B_(B=B, s=s, y=y)

    def modify_B(self, B, state, setting):
        return _truncate(B, setting['lb'], setting['ub'])

    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
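`DNRTR` keeps a diagonal Hessian approximation B and, in `modify_B`, passes it through `_truncate(B, lb, ub)`. Assuming `_truncate` clamps every entry into `[lb, ub]` (its body is not shown here), the effect on the resulting Newton-like step can be sketched as follows. The update formula in `diagonal_wqc_B_` is not reproduced, and both helper names are hypothetical.

```python
def truncate_diagonal(b, lb=1e-2, ub=1e5):
    """Clamp every diagonal Hessian entry into [lb, ub] (assumed behaviour of `_truncate`)."""
    return [min(max(bi, lb), ub) for bi in b]


def diagonal_newton_direction(b, g):
    """Direction -B^{-1} g for a diagonal B stored as a list of its entries."""
    return [-gi / bi for gi, bi in zip(g, b)]


b = truncate_diagonal([-3.0, 0.5, 1e7])  # negative curvature and a huge entry get clipped
d = diagonal_newton_direction(b, [1.0, 1.0, 1e5])
```

Because every entry stays at least `lb > 0`, the direction always has a negative inner product with a nonzero gradient, while `ub` caps how strongly any single coordinate is damped.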

DYHS

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Dai-Yuan - Hestenes–Stiefel hybrid conjugate gradient method.

Note

This requires the step size to be determined by a line search, so place a line search such as tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this module.

Source code in torchzero/modules/conjugate_gradient/cg.py
class DYHS(ConguateGradientBase):
    """Dai-Yuan - Hestenes–Stiefel hybrid conjugate gradient method.

    Note:
        This requires the step size to be determined by a line search, so place a line search such as ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this module.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return dyhs_beta(g, prev_d, prev_g)
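A common form of the Dai-Yuan / Hestenes-Stiefel hybrid clips the Hestenes-Stiefel coefficient into `[0, beta_DY]`, i.e. `beta = max(0, min(beta_HS, beta_DY))` with `y = g - prev_g`. The sketch assumes the textbook convention in which `prev_d` is the previous descent direction; whether torchzero's `dyhs_beta` uses exactly this formula and sign convention is not shown here, and `hybrid_dyhs_beta` is a hypothetical name.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def hybrid_dyhs_beta(g, prev_g, prev_d):
    """beta = max(0, min(beta_HS, beta_DY)) for nonlinear conjugate gradient."""
    y = [gi - pgi for gi, pgi in zip(g, prev_g)]  # gradient change
    denom = dot(prev_d, y)
    if denom == 0:
        return 0.0  # fall back to a steepest-descent step
    beta_hs = dot(g, y) / denom
    beta_dy = dot(g, g) / denom
    return max(0.0, min(beta_hs, beta_dy))
```

A returned value of 0 makes the next direction plain steepest descent, which has the same effect as a restart.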

DaiYuan

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Dai–Yuan nonlinear conjugate gradient method.

Note

This requires the step size to be determined by a line search, so place a line search such as tz.m.StrongWolfe(c2=0.1) after this module.

Source code in torchzero/modules/conjugate_gradient/cg.py
class DaiYuan(ConguateGradientBase):
    """Dai–Yuan nonlinear conjugate gradient method.

    Note:
        This requires the step size to be determined by a line search, so place a line search such as ``tz.m.StrongWolfe(c2=0.1)`` after this module.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return dai_yuan_beta(g, prev_d, prev_g)
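The Dai-Yuan coefficient is `beta_DY = g^T g / (prev_d^T y)` with `y = g - prev_g`, and the new search direction is `d = -g + beta_DY * prev_d`. The sketch uses the textbook sign convention (descent directions); torchzero's internal convention for `prev_d` may differ, and `dai_yuan_direction` is hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def dai_yuan_direction(g, prev_g, prev_d):
    """New search direction d = -g + beta_DY * prev_d."""
    y = [gi - pgi for gi, pgi in zip(g, prev_g)]
    denom = dot(prev_d, y)
    beta = dot(g, g) / denom if denom != 0 else 0.0  # Dai-Yuan coefficient
    return [-gi + beta * di for gi, di in zip(g, prev_d)]


d = dai_yuan_direction([1.0, 1.0], [0.5, 0.0], [1.0, 2.0])
```

The descent and convergence guarantees of Dai-Yuan are stated for step sizes satisfying the Wolfe conditions, which is why a Wolfe line search is recommended after this module.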

Debias

Bases: torchzero.core.transform.TensorTransform

Multiplies the update by an Adam debiasing term based on the first and/or second momentum.

Parameters:

  • beta1 (float | None, default: None ) –

    first momentum, should be the same as first momentum used in modules before. Defaults to None.

  • beta2 (float | None, default: None ) –

    second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.

  • alpha (float, default: 1 ) –

    learning rate. Defaults to 1.

  • pow (float, default: 2 ) –

    power, assumes absolute value is used. Defaults to 2.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/higher_level.py
class Debias(TensorTransform):
    """Multiplies the update by an Adam debiasing term based on the first and/or second momentum.

    Args:
        beta1 (float | None, optional):
            first momentum, should be the same as first momentum used in modules before. Defaults to None.
        beta2 (float | None, optional):
            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
        alpha (float, optional): learning rate. Defaults to 1.
        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2):
        defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        pow = settings[0]['pow']
        alpha, beta1, beta2 = unpack_dicts(settings, 'alpha', 'beta1', 'beta2', cls=NumberList)

        return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)
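Adam's bias correction multiplies `m / sqrt(v)` by `sqrt(1 - beta2**t) / (1 - beta1**t)`. Assuming `debias` generalises the square root to a `pow`-th root and scales by `alpha` (the helper body is not shown here), the multiplier can be sketched as below. `debias_factor` is hypothetical; in this sketch a `None` beta simply skips that correction.

```python
def debias_factor(step, beta1=None, beta2=None, alpha=1.0, pow=2.0):
    """Scalar that turns a raw EMA ratio into an Adam-style bias-corrected step (assumed form)."""
    factor = alpha
    if beta1 is not None:
        factor /= 1 - beta1 ** step  # corrects the first-moment EMA
    if beta2 is not None:
        factor *= (1 - beta2 ** step) ** (1 / pow)  # corrects the pow-th root of the second moment
    return factor
```

The correction matters most in early steps and tends to 1 as `step` grows, which is why `beta1` and `beta2` must match the momenta used by the modules before `Debias`.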

Debias2

Bases: torchzero.core.transform.TensorTransform

Multiplies the update by an Adam debiasing term based on the second momentum.

Parameters:

  • beta (float, default: 0.999 ) –

    second (squared) momentum, should be the same as second momentum used in modules before. Defaults to 0.999.

  • pow (float, default: 2 ) –

    power, assumes absolute value is used. Defaults to 2.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/higher_level.py
class Debias2(TensorTransform):
    """Multiplies the update by an Adam debiasing term based on the second momentum.

    Args:
        beta (float, optional):
            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to 0.999.
        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, beta: float = 0.999, pow: float = 2,):
        defaults = dict(beta=beta, pow=pow)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        pow = settings[0]['pow']
        beta = NumberList(s['beta'] for s in settings)
        return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)

DiagonalBFGS

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
    """Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])

    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
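Applying the BFGS inverse update `H_new = (I - rho s y^T) H (I - rho y s^T) + rho s s^T`, with `rho = 1 / (s^T y)`, to a diagonal H and keeping only the diagonal gives `H_ii - 2 rho s_i y_i H_ii + rho^2 s_i^2 (y^T H y) + rho s_i^2`. The sketch assumes this is what "only the diagonal being updated" means; `diagonal_bfgs_H_` may differ in details, and `diagonal_bfgs_inverse` is hypothetical.

```python
def diagonal_bfgs_inverse(h, s, y, tol=1e-32):
    """Diagonal of the BFGS inverse update applied to a diagonal H (list of entries)."""
    sy = sum(si * yi for si, yi in zip(s, y))
    if abs(sy) < tol:
        return list(h)  # assumed safeguard against a near-zero curvature pair
    rho = 1.0 / sy
    yHy = sum(hi * yi * yi for hi, yi in zip(h, y))
    return [hi - 2 * rho * si * yi * hi + rho * rho * si * si * yHy + rho * si * si
            for hi, si, yi in zip(h, s, y)]


h_new = diagonal_bfgs_inverse([1.0, 1.0], [1.0, 0.0], [2.0, 0.0])
```

With only one active coordinate the result is `s/y = 0.5` along that axis; in general the discarded off-diagonal terms mean `H_new y = s` no longer holds exactly, as the docstring notes.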

DiagonalQuasiCauchi

Bases: torchzero.modules.quasi_newton.quasi_newton._HessianUpdateStrategyDefaults

Diagonal quasi-Cauchy method.

Reference

Zhu, M., Nazareth, J. L., and Wolkowicz, H. "The quasi-Cauchy relation and diagonal updating." SIAM Journal on Optimization 9.4 (1999): 1192-1204.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
    """Diagonal quasi-Cauchy method.

    Reference:
        Zhu, M., Nazareth, J. L., and Wolkowicz, H. "The quasi-Cauchy relation and diagonal updating." SIAM Journal on Optimization 9.4 (1999): 1192-1204.
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_qc_B_(B=B, s=s, y=y)

    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
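The quasi-Cauchy approach replaces the secant equation with the weaker relation `s^T B_new s = s^T y` and picks the diagonal `B_new` closest to `B` in the Frobenius norm. Solving that least-change problem gives `B_new = B + lam * diag(s_i^2)` with `lam = (s^T y - s^T B s) / sum(s_i^4)`. The sketch implements this textbook update; whether `diagonal_qc_B_` adds further safeguards is not shown, and `diagonal_quasi_cauchy` is hypothetical.

```python
def diagonal_quasi_cauchy(b, s, y, tol=1e-32):
    """Least-change diagonal update satisfying s^T B_new s = s^T y."""
    s2 = [si * si for si in s]
    denom = sum(v * v for v in s2)  # sum of s_i^4
    if denom < tol:
        return list(b)  # step too small to carry curvature information
    residual = sum(si * yi for si, yi in zip(s, y)) - sum(bi * v for bi, v in zip(b, s2))
    lam = residual / denom
    return [bi + lam * v for bi, v in zip(b, s2)]


b_new = diagonal_quasi_cauchy([1.0, 1.0], [1.0, 2.0], [3.0, 4.0])
```

Nothing in this update keeps the entries positive; `DNRTR`, for example, clips its diagonal into `[lb, ub]` in `modify_B`.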

DiagonalSR1

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
    """Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])

    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
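SR1 adds the rank-one correction `(s - H y)(s - H y)^T / ((s - H y)^T y)`. Keeping only the diagonal gives `H_ii + (s_i - H_ii y_i)^2 / ((s - H y)^T y)`. The sketch assumes `diagonal_sr1_` follows this form and uses a simplified absolute-value safeguard; `diagonal_sr1` is hypothetical. The source's `update_B` reuses the same function with `s` and `y` swapped, because the SR1 update of B is the SR1 update of H with the roles of `s` and `y` exchanged.

```python
def diagonal_sr1(h, s, y, tol=1e-32):
    """Diagonal part of the SR1 update of a diagonal matrix h (list of its entries)."""
    r = [si - hi * yi for hi, si, yi in zip(h, s, y)]  # s - H y for diagonal H
    denom = sum(ri * yi for ri, yi in zip(r, y))
    if abs(denom) < tol:
        return list(h)  # skip near-singular updates
    return [hi + ri * ri / denom for hi, ri in zip(h, r)]


h = diagonal_sr1([1.0], [0.5], [2.0])  # inverse-Hessian form
b = diagonal_sr1([1.0], [2.0], [0.5])  # Hessian form: s and y swapped, as in update_B
```

In one dimension both updates satisfy the secant equation, so `h = s/y = 0.25` and `b = y/s = 4` are inverses of each other.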

DiagonalWeightedQuasiCauchi

Bases: torchzero.modules.quasi_newton.quasi_newton._HessianUpdateStrategyDefaults

Diagonal weighted quasi-Cauchy method.

Reference

Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
    """Diagonal weighted quasi-Cauchy method.

    Reference:
        Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return diagonal_wqc_B_(B=B, s=s, y=y)

    def initialize_P(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)

DirectWeightDecay

Bases: torchzero.core.module.Module

Directly applies weight decay to parameters.

Parameters:

  • weight_decay (float) –

    weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

Source code in torchzero/modules/weight_decay/weight_decay.py
class DirectWeightDecay(Module):
    """Directly applies weight decay to parameters.

    Args:
        weight_decay (float): weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
    """
    def __init__(self, weight_decay: float, ord: int = 2,):
        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective):
        weight_decay = self.get_settings(objective.params, 'weight_decay', cls=NumberList)
        ord = self.defaults['ord']

        decay_weights_(objective.params, weight_decay, ord)
        return objective
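Unlike weight decay added to the gradient, direct weight decay modifies the parameters themselves, independently of any preconditioning applied to the update. The sketch assumes `decay_weights_` takes one step on the penalty `(weight_decay / ord) * sum(|p| ** ord)`; that scaling is an assumption, and for `ord=2` it reduces to `p <- p - weight_decay * p`. `direct_weight_decay` is hypothetical.

```python
def direct_weight_decay(params, weight_decay, ord=2):
    """Return parameters after one step on the penalty (weight_decay / ord) * sum(|p| ** ord)."""
    out = []
    for p in params:
        sign = (p > 0) - (p < 0)  # sign(p), with sign(0) == 0
        # gradient of the penalty: weight_decay * sign(p) * |p| ** (ord - 1)
        out.append(p - weight_decay * sign * abs(p) ** (ord - 1))
    return out
```

With `ord=1` every nonzero weight moves toward zero by the same amount, which is why L1 decay drives small weights to exactly zero faster than L2.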

Div

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Divide tensors by other. other can be a number or a module.

If other is a module, this calculates tensors / other(tensors)

Source code in torchzero/modules/ops/binary.py
class Div(BinaryOperationBase):
    """Divide tensors by ``other``. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``tensors / other(tensors)``
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        torch._foreach_div_(update, other)
        return update

DivByLoss

Bases: torchzero.core.transform.TensorTransform

Divides the update by loss * alpha, clamped from below by min_value.

Source code in torchzero/modules/misc/misc.py
class DivByLoss(TensorTransform):
    """Divides the update by ``loss * alpha``, clamped from below by ``min_value``."""
    def __init__(self, alpha: float = 1, min_value:float = 1e-16, backward: bool = True):
        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
        super().__init__(defaults, uses_loss=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert loss is not None
        alpha, min_value = unpack_dicts(settings, 'alpha', 'min_value')
        denom = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
        torch._foreach_div_(tensors, denom)
        return tensors
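The source computes the divisor as `max(loss * alpha, min_value)`, so a zero or negative loss never flips the sign of the update or causes a division by zero. A scalar sketch of that arithmetic (`div_by_loss` is hypothetical):

```python
def div_by_loss(update, loss, alpha=1.0, min_value=1e-16):
    """Divide every entry of `update` by max(loss * alpha, min_value)."""
    denom = max(loss * alpha, min_value)  # clamp keeps the divisor positive
    return [u / denom for u in update]
```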

DivModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates input / other. input and other can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class DivModules(MultiOperationBase):
    """Calculates ``input / other``. ``input`` and ``other`` can be numbers or modules."""
    def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
        defaults = {}
        if other_first: super().__init__(defaults, other=other, input=input)
        else: super().__init__(defaults, input=input, other=other)

    @torch.no_grad
    def transform(self, objective: Objective, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
        if isinstance(input, (int,float)):
            assert isinstance(other, list)
            return input / TensorList(other)

        torch._foreach_div_(input, other)
        return input

Dogleg

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Dogleg trust region algorithm.

Parameters:

  • hess_module (Module | None) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • eta (float, default: 0.0 ) –

    if ratio of actual to predicted reduction is larger than this, step is accepted. When hess_module is GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 2 ) –

    increase factor on successful steps. Defaults to 2.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.75 ) –

    if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by nplus. Defaults to 0.75.

  • rho_bad (float, default: 0.25 ) –

    if ratio of actual to predicted reduction is less than this, trust region size is multiplied by nminus. Defaults to 0.25.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.

Source code in torchzero/modules/trust_region/dogleg.py
class Dogleg(TrustRegionBase):
    """Dogleg trust region algorithm.


    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 2.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
            Defaults to 0.75.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
            Defaults to 0.25.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 2,
        nminus: float = 0.25,
        rho_good: float = 0.75,
        rho_bad: float = 0.25,
        boundary_tol: float | None = None,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict()
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        if radius > 2: radius = self.global_state['radius'] = 2
        eps = torch.finfo(g.dtype).tiny * 2

        gHg = g.dot(H.matvec(g))
        if gHg <= eps:
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        p_cauchy = (g.dot(g) / gHg) * g
        p_newton = H.solve(g)

        a = p_newton - p_cauchy
        b = p_cauchy

        aa = a.dot(a)
        if aa < eps:
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        ab = a.dot(b)
        bb = b.dot(b)
        c = bb - radius**2
        discriminant = (2*ab)**2 - 4*aa*c
        beta = (-2*ab + torch.sqrt(discriminant.clip(min=0))) / (2 * aa)
        return p_cauchy + beta * (p_newton - p_cauchy)
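The source computes the Cauchy point `(g.g / g.Hg) * g` and the Newton point `H^{-1} g`, then moves along the segment between them until it reaches the radius. The standard textbook dogleg also returns the Newton step directly when it lies inside the region, and a scaled steepest-descent step when even the Cauchy point lies outside. Below is a dependency-free sketch of that textbook version for a positive definite H. It uses explicit minus signs so the result is a descent step, whereas the source appears to return the step in the gradient's orientation; `dogleg_step` and `solve` are hypothetical helpers.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def norm(v):
    return math.sqrt(dot(v, v))


def matvec(M, v):
    return [dot(row, v) for row in M]


def solve(H, b):
    """Solve H x = b with Gaussian elimination and partial pivoting (small dense H)."""
    n = len(b)
    A = [list(row) + [bi] for row, bi in zip(H, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(A[r][col]))
        A[col], A[piv] = A[piv], A[col]
        for r in range(col + 1, n):
            f = A[r][col] / A[col][col]
            for c in range(col, n + 1):
                A[r][c] -= f * A[col][c]
    x = [0.0] * n
    for r in reversed(range(n)):
        x[r] = (A[r][n] - sum(A[r][c] * x[c] for c in range(r + 1, n))) / A[r][r]
    return x


def dogleg_step(g, H, radius):
    """Textbook dogleg step p for a positive definite H; the new point is x + p."""
    p_newton = [-x for x in solve(H, g)]
    if norm(p_newton) <= radius:
        return p_newton  # the full Newton step fits inside the region
    gHg = dot(g, matvec(H, g))
    p_cauchy = [-(dot(g, g) / gHg) * x for x in g]  # minimiser of the model along -g
    if norm(p_cauchy) >= radius:
        return [-(radius / norm(g)) * x for x in g]  # steepest descent to the boundary
    # tau in [0, 1] such that ||p_cauchy + tau * (p_newton - p_cauchy)|| == radius
    a = [pn - pc for pn, pc in zip(p_newton, p_cauchy)]
    aa, ab, bb = dot(a, a), dot(a, p_cauchy), dot(p_cauchy, p_cauchy)
    tau = (-ab + math.sqrt(ab * ab - aa * (bb - radius ** 2))) / aa
    return [pc + tau * ai for pc, ai in zip(p_cauchy, a)]
```

The three branches correspond to a large, tiny, and intermediate radius; in the intermediate case the step lands exactly on the boundary of the trust region.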

Dropout

Bases: torchzero.core.transform.Transform

Applies dropout to the update.

For each weight, the update to that weight is set to 0 with probability p. This can be used to implement gradient dropout or update dropout depending on placement.

Parameters:

  • p (float, default: 0.5 ) –

    probability that update for a weight is replaced with 0. Defaults to 0.5.

  • graft (bool, default: False ) –

    if True, update after dropout is rescaled to have the same norm as before dropout. Defaults to False.

  • target (Target) –

    what to set on var, refer to documentation. Defaults to 'update'.

Examples:

Gradient dropout.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Dropout(0.5),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)

Update dropout.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.Dropout(0.5),
    tz.m.LR(1e-3)
)

Source code in torchzero/modules/misc/regularization.py
class Dropout(Transform):
    """Applies dropout to the update.

    For each weight, the update to that weight is set to 0 with probability ``p``.
    This can be used to implement gradient dropout or update dropout depending on placement.

    Args:
        p (float, optional): probability that update for a weight is replaced with 0. Defaults to 0.5.
        graft (bool, optional):
            if True, update after dropout is rescaled to have the same norm as before dropout. Defaults to False.
        target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.


    ### Examples:

    Gradient dropout.

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Dropout(0.5),
        tz.m.Adam(),
        tz.m.LR(1e-3)
    )
    ```

    Update dropout.

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.Dropout(0.5),
        tz.m.LR(1e-3)
    )
    ```

    """
    def __init__(self, p: float = 0.5, graft: bool=False):
        defaults = dict(p=p, graft=graft)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        p = NumberList(s['p'] for s in settings)
        graft = settings[0]['graft']

        if graft:
            target_norm = tensors.global_vector_norm()
            tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
            return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft

        return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
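`Dropout` multiplies the update by a random 0/1 mask and, with `graft=True`, rescales the result to the pre-dropout global norm. The pure-Python sketch below draws one Bernoulli sample per entry in place of torchzero's `rademacher_like`-based mask; `dropout_update` is hypothetical, and the zero-norm guard is an addition of this sketch that the source does not have.

```python
import math
import random


def dropout_update(update, p=0.5, graft=False, rng=None):
    """Zero each entry with probability p; optionally restore the original global norm."""
    rng = rng or random.Random()
    original_norm = math.sqrt(sum(u * u for u in update))
    out = [0.0 if rng.random() < p else u for u in update]
    if graft:
        new_norm = math.sqrt(sum(u * u for u in out))
        if new_norm > 0:  # avoid dividing by zero when every entry was dropped
            out = [u * original_norm / new_norm for u in out]
    return out
```

Grafting keeps the overall step length unchanged, so dropout only changes the direction of the update rather than shrinking it.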

DualNormCorrection

Bases: torchzero.core.transform.TensorTransform

Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon). Orthogonalize already has this built in with the dual_norm_correction setting.

Source code in torchzero/modules/adaptive/muon.py
class DualNormCorrection(TensorTransform):
    """Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).
    Orthogonalize already has this built in with the `dual_norm_correction` setting."""
    def __init__(self, channel_first: bool = True):
        defaults = dict(channel_first=channel_first)
        super().__init__(defaults)

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        assert grad is not None
        if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
            return _dual_norm_correction(tensor, grad, channel_first=setting["channel_first"])
        return tensor

EMA

Bases: torchzero.core.transform.TensorTransform

Maintains an exponential moving average of the update.

Parameters:

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • dampening (float, default: 0 ) –

    momentum dampening. Defaults to 0.

  • debias (bool, default: False ) –

    whether to debias the EMA like in Adam. Defaults to False.

  • lerp (bool, default: True ) –

    whether to use linear interpolation. Defaults to True.

  • ema_init (str, default: 'zeros' ) –

    initial values for the EMA, "zeros" or "update".

  • target (Target) –

    target to apply EMA to. Defaults to 'update'.

Source code in torchzero/modules/momentum/momentum.py
class EMA(TensorTransform):
    """Maintains an exponential moving average of the update.

    Args:
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        dampening (float, optional): momentum dampening. Defaults to 0.
        debias (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
        lerp (bool, optional): whether to use linear interpolation. Defaults to True.
        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
        target (Target, optional): target to apply EMA to. Defaults to 'update'.
    """
    def __init__(self, momentum:float=0.9, dampening:float=0, debias: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros'):
        defaults = dict(momentum=momentum,dampening=dampening,debias=debias,lerp=lerp,ema_init=ema_init)
        super().__init__(defaults, uses_grad=False)

        self.add_projected_keys("grad", "exp_avg")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        debias, lerp, ema_init = itemgetter('debias','lerp','ema_init')(settings[0])

        exp_avg = unpack_states(states, tensors, 'exp_avg',
                                init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
        momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)

        exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)

        if debias: return _debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
        else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
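With `lerp=True` the EMA is a convex combination of the old average and the new input; with `lerp=False` it accumulates like SGD-style momentum. The sketch assumes this split and that `dampening` only enters the non-lerp branch, since the body of `ema_` is not shown here. Debiasing divides by `1 - momentum**step`, matching the `_debias(..., beta1=momentum, alpha=1)` call in the source. `ema_step` is hypothetical.

```python
def ema_step(exp_avg, x, step, momentum=0.9, dampening=0.0, lerp=True, debias=False):
    """One EMA update; returns (new_state, output)."""
    if lerp:
        # exp_avg <- momentum * exp_avg + (1 - momentum) * x
        state = [momentum * m + (1 - momentum) * xi for m, xi in zip(exp_avg, x)]
    else:
        # heavy-ball accumulation, like the momentum buffer of torch.optim.SGD
        state = [momentum * m + (1 - dampening) * xi for m, xi in zip(exp_avg, x)]
    if debias:
        return state, [m / (1 - momentum ** step) for m in state]
    return state, list(state)  # copy, mirroring the .clone() in the source
```

With zero initialisation and `debias=True`, a constant input is reproduced from the first step onward, which is the purpose of the correction.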

EMASquared

Bases: torchzero.core.transform.TensorTransform

Maintains an exponential moving average of squared updates.

Parameters:

  • beta (float, default: 0.999 ) –

    momentum value. Defaults to 0.999.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Methods:

  • EMA_SQ_FN

    Updates exp_avg_sq_ with the EMA of squared tensors; if max_exp_avg_sq_ is not None, also updates it with the element-wise maximum of the EMA.

Source code in torchzero/modules/ops/higher_level.py
class EMASquared(TensorTransform):
    """Maintains an exponential moving average of squared updates.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    EMA_SQ_FN: staticmethod = staticmethod(ema_sq_)

    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
        super().__init__(defaults)
        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()

EMA_SQ_FN

EMA_SQ_FN(tensors: TensorList, exp_avg_sq_: TensorList, beta: float | NumberList, max_exp_avg_sq_: TensorList | None, pow: float = 2)

Updates exp_avg_sq_ with the EMA of squared tensors; if max_exp_avg_sq_ is not None, also updates it with the element-wise maximum of the EMA.

Returns exp_avg_sq_ or max_exp_avg_sq_.

Source code in torchzero/modules/opt_utils.py
def ema_sq_(
    tensors: TensorList,
    exp_avg_sq_: TensorList,
    beta: float | NumberList,
    max_exp_avg_sq_: TensorList | None,
    pow: float = 2,
):
    """
    Updates `exp_avg_sq_` with the EMA of squared `tensors`; if `max_exp_avg_sq_` is not None, also updates it with the element-wise maximum of the EMA.

    Returns `exp_avg_sq_` or `max_exp_avg_sq_`.
    """
    lerp_power_(tensors=tensors, exp_avg_pow_=exp_avg_sq_,beta=beta,pow=pow)

    # AMSGrad
    if max_exp_avg_sq_ is not None:
        max_exp_avg_sq_.maximum_(exp_avg_sq_)
        exp_avg_sq_ = max_exp_avg_sq_

    return exp_avg_sq_
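`ema_sq_` updates the EMA of `|x|**pow` in place and, for AMSGrad, also tracks its running maximum, returning whichever buffer the caller should use. The functional sketch below returns new lists instead of mutating, and assumes `lerp_power_` computes `beta * v + (1 - beta) * |x|**pow`; `ema_sq_step` is hypothetical.

```python
def ema_sq_step(exp_avg_sq, x, beta=0.999, max_exp_avg_sq=None, pow=2.0):
    """Returns (exp_avg_sq, max_exp_avg_sq, output) after one update."""
    exp_avg_sq = [beta * v + (1 - beta) * abs(xi) ** pow for v, xi in zip(exp_avg_sq, x)]
    if max_exp_avg_sq is None:
        return exp_avg_sq, None, exp_avg_sq
    # AMSGrad: keep the running element-wise maximum and return it instead
    max_exp_avg_sq = [max(m, v) for m, v in zip(max_exp_avg_sq, exp_avg_sq)]
    return exp_avg_sq, max_exp_avg_sq, max_exp_avg_sq
```

Returning the maximum means the second-moment estimate never decreases, so the effective step size for a coordinate cannot grow after a burst of large gradients.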

ESGD

Bases: torchzero.core.transform.Transform

Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

This is similar to Adagrad, but it accumulates squared randomized hessian diagonal estimates instead of squared gradients.

Notes:
    - In most cases ESGD should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply ESGD preconditioning to another module's output.

    - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

Args:
    damping (float, optional): added to denominator for stability. Defaults to 1e-4.
    update_freq (int, optional):
        frequency of updating hessian diagonal estimate via a hessian-vector product.
        This value can be increased to reduce computational cost. Defaults to 20.
    hvp_method (str, optional):
        Determines how hessian-vector products are computed.

        - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector product is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
        - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
        - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
        - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

        Defaults to ``"autograd"``.
    h (float, optional):
        The step size for finite difference if ``hvp_method`` is
        ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
    n_samples (int, optional):
        number of hessian-vector products with random vectors to evaluate each time when updating
        the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
    seed (int | None, optional): seed for random vectors. Defaults to None.
    inner (Chainable | None, optional):
        Inner module. If this is specified, operations are performed in the following order.
        1. compute hessian diagonal estimate.
        2. pass inputs to `inner`.
        3. preconditioning is applied to the outputs of `inner`.

### Examples:

Using ESGD:

```python

opt = tz.Optimizer(
    model.parameters(),
    tz.m.ESGD(),
    tz.m.LR(0.1)
)
```

The ESGD preconditioner can be applied to the output of any other module by passing that module to the `inner` argument. Here is an example of applying
ESGD preconditioning to Nesterov momentum (`tz.m.NAG`):

```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.ESGD(inner=tz.m.NAG(0.9)),
    tz.m.LR(0.1)
)
```
Source code in torchzero/modules/adaptive/esgd.py
class ESGD(Transform):
    """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)

    This is similar to Adagrad, but it accumulates squared Hessian-vector products Hz with random vectors z instead of squared gradients.

    Notes:
        - In most cases ESGD should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply ESGD preconditioning to another module's output.

        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        damping (float, optional): added to denominator for stability. Defaults to 1e-4.
        update_freq (int, optional):
            frequency of updating hessian diagonal estimate via a hessian-vector product.
            This value can be increased to reduce computational cost. Defaults to 20.
        hvp_method (str, optional):
            Determines how hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector product is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
        seed (int | None, optional): seed for random vectors. Defaults to None.
        inner (Chainable | None, optional):
            Inner module. If this is specified, operations are performed in the following order.
            1. compute hessian diagonal estimate.
            2. pass inputs to `inner`.
            3. preconditioning is applied to the outputs of `inner`.

    ### Examples:

    Using ESGD:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ESGD(),
        tz.m.LR(0.1)
    )
    ```

    The ESGD preconditioner can be applied to the output of any other module by passing that module to the `inner` argument. Here is an example of applying
    ESGD preconditioning to Nesterov momentum (`tz.m.NAG`):

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.ESGD(inner=tz.m.NAG(0.9)),
        tz.m.LR(0.1)
    )
    ```

    """
    def __init__(
        self,
        damping: float = 1e-4,
        update_freq: int = 20,
        distribution: Distributions = 'gaussian',
        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        n_samples = 1,
        zHz: bool = False,
        seed: int | None = None,
        beta: float | None = None,
        beta_debias: bool = True,

        inner: Chainable | None = None,
        Hz_sq_acc_tfm: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["Hz_sq_acc_tfm"]
        super().__init__(defaults, inner=inner)

        self.set_child("Hz_sq_acc", Hz_sq_acc_tfm)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        params = objective.params

        fs = settings[0]
        update_freq = fs['update_freq']

        # ------------------------------- accumulate Hz ------------------------------ #
        step = self.increment_counter("step", start=0)

        if step % update_freq == 0:
            self.increment_counter("num_Hzs", start=1)

            Hz, _ = objective.hutchinson_hessian(
                rgrad = None,
                at_x0 = True,
                n_samples = fs['n_samples'],
                distribution = fs['distribution'],
                hvp_method = fs['hvp_method'],
                h = fs['h'],
                zHz = fs["zHz"], # default is False, so it returns Hz, not z⊙Hz
                generator = self.get_generator(params[0].device, fs["seed"]),
            )

            Hz = TensorList(Hz)
            Hz_sq_acc = unpack_states(states, params, 'Hz_sq_acc', cls=TensorList)

            beta = fs["beta"]
            if beta is None:
                Hz_sq_acc.addcmul_(Hz, Hz)

            else:
                Hz_sq_acc.mul_(beta).addcmul_(Hz, Hz, value=1-beta)

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        tensors = TensorList(objective.get_updates())
        Hz_sq_acc = unpack_states(states, tensors, 'Hz_sq_acc', cls=TensorList)
        num_Hzs = self.global_state["num_Hzs"]
        fs = settings[0]

        # ---------------------------------- debias ---------------------------------- #
        beta = fs["beta"]
        beta_debias = fs["beta_debias"]

        if beta_debias and beta is not None:
            bias_correction = 1.0 - beta ** num_Hzs
            Hz_sq_acc = Hz_sq_acc / bias_correction

        else:
            Hz_sq_acc = Hz_sq_acc / num_Hzs

        # ---------------------------------- update ---------------------------------- #
        damping = [s["damping"] for s in settings]

        # Hz_sq_acc was already averaged (or debiased) above, so it is not divided by num_Hzs again
        denom = Hz_sq_acc.sqrt_().add_(damping)

        objective.updates = tensors.div_(denom)
        return objective
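The equilibration idea can be illustrated with a minimal plain-PyTorch sketch (not torchzero's implementation). Hessian-vector products are computed with double backward (the idea behind `hvp_method="autograd"`), squared, and averaged over Gaussian random vectors. The update is then the gradient divided by the square root of that average plus `damping`. All function names are hypothetical. For a separable quadratic with Hessian diag(a), E[(Hz)²] = a², so the estimate should approach |a| even when some entries of a are negative.

```python
import torch

def hessian_vector_product(f, x, v):
    """H(x) @ v via double backward."""
    x = x.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(f(x), x, create_graph=True)
    (hvp,) = torch.autograd.grad(grad, x, grad_outputs=v)
    return hvp

def equilibration_preconditioner(f, x, n_samples, generator=None):
    """Estimates D = sqrt(E[(Hz)^2]) with Gaussian z by averaging squared HVPs."""
    acc = torch.zeros_like(x)
    for _ in range(n_samples):
        z = torch.randn(x.shape, generator=generator, dtype=x.dtype)
        Hz = hessian_vector_product(f, x, z)
        acc += Hz * Hz  # same accumulation as Hz_sq_acc.addcmul_(Hz, Hz)
    return (acc / n_samples).sqrt()

def esgd_direction(grad, D, damping=1e-4):
    """Equilibrated update direction: grad / (D + damping)."""
    return grad / (D + damping)

# usage: non-convex separable quadratic f(x) = 0.5 * sum(a * x**2), so H = diag(a)
a = torch.tensor([1.0, -4.0, 9.0], dtype=torch.float64)

def f(x):
    return 0.5 * (a * x**2).sum()

x = torch.ones(3, dtype=torch.float64)
gen = torch.Generator().manual_seed(0)
D = equilibration_preconditioner(f, x, n_samples=2000, generator=gen)
direction = esgd_direction(a * x, D)  # the gradient of f is a * x
```

Because D estimates |a|, each coordinate of the direction has roughly unit magnitude, similar to how Adagrad normalizes by gradient magnitudes.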

EscapeAnnealing

Bases: torchzero.core.module.Module

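If parameters stop changing (largest absolute change at most tol) for n_tol consecutive steps, this runs a random search whose perturbation radius grows linearly up to max_region (a "backward" annealing schedule), accepting the first point that lowers the loss.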

Source code in torchzero/modules/misc/escape.py
class EscapeAnnealing(Module):
    """If parameters stop changing, this runs a backward annealing random search"""
    def __init__(self, max_region:float = 1, max_iter:int = 1000, tol=1e-6, n_tol: int = 10):
        defaults = dict(max_region=max_region, max_iter=max_iter, tol=tol, n_tol=n_tol)
        super().__init__(defaults)


    @torch.no_grad
    def apply(self, objective):
        closure = objective.closure
        if closure is None: raise RuntimeError("Escape requires closure")

        params = TensorList(objective.params)
        settings = self.settings[params[0]]
        max_region = self.get_settings(params, 'max_region', cls=NumberList)
        max_iter = settings['max_iter']
        tol = settings['tol']
        n_tol = settings['n_tol']

        n_bad = self.global_state.get('n_bad', 0)

        prev_params = self.get_state(params, 'prev_params', cls=TensorList)
        diff = params-prev_params
        prev_params.copy_(params)

        if diff.abs().global_max() <= tol:
            n_bad += 1

        else:
            n_bad = 0

        self.global_state['n_bad'] = n_bad

        # no progress
        f_0 = objective.get_loss(False)
        if n_bad >= n_tol:
            for i in range(1, max_iter+1):
                alpha = max_region * (i / max_iter)
                pert = params.sphere_like(radius=alpha)

                params.add_(pert)
                f_star = closure(False)

                if math.isfinite(f_star) and f_star < f_0-1e-12:
                    objective.updates = None
                    objective.stop = True
                    objective.skip_update = True
                    return objective

                params.sub_(pert)

            self.global_state['n_bad'] = 0
        return objective
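The two ingredients of this module, stall detection and a random search with a growing radius, can be sketched in plain PyTorch on a single flat tensor. The helper names are hypothetical. Drawing perturbations as normalized Gaussian vectors is an assumption about what `sphere_like` does. Unlike the module, the stall counter here starts from the first observed point rather than from a stored `prev_params` state.

```python
import torch

def sphere_sample(like, radius, generator=None):
    """Random direction (normalized Gaussian) scaled to the given radius."""
    v = torch.randn(like.shape, generator=generator, dtype=like.dtype)
    return v * (radius / v.norm())

class StallDetector:
    """Counts consecutive steps whose largest parameter change is <= tol."""

    def __init__(self, tol=1e-6, n_tol=10):
        self.tol, self.n_tol = tol, n_tol
        self.prev = None
        self.n_bad = 0

    def update(self, x):
        if self.prev is not None and (x - self.prev).abs().max() <= self.tol:
            self.n_bad += 1
        else:
            self.n_bad = 0
        self.prev = x.clone()
        return self.n_bad >= self.n_tol  # True means it is time to try escaping

@torch.no_grad()
def expanding_random_search(f, x, max_region=1.0, max_iter=1000, generator=None):
    """Tries perturbations of radius max_region * i / max_iter for i = 1..max_iter."""
    f0 = f(x)
    for i in range(1, max_iter + 1):
        radius = max_region * i / max_iter
        candidate = x + sphere_sample(x, radius, generator)
        f_new = f(candidate)
        if torch.isfinite(f_new) and f_new < f0 - 1e-12:
            return candidate, True  # the first strictly better point wins
    return x, False

# usage
def shifted_sq(x):
    return ((x - 3.0) ** 2).sum()

x0 = torch.zeros(3)
detector = StallDetector(tol=1e-6, n_tol=10)
flags = [detector.update(x0) for _ in range(11)]  # x0 never changes
gen = torch.Generator().manual_seed(0)
x_new, improved = expanding_random_search(shifted_sq, x0, generator=gen)
```

Starting with the smallest radius keeps the escape step local when a small move is enough, and only widens the search if nearby points do not help.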

Exp

Bases: torchzero.core.transform.TensorTransform

Returns exp(input)

Source code in torchzero/modules/ops/unary.py
class Exp(TensorTransform):
    """Returns ``exp(input)``"""
    def __init__(self): super().__init__()
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_exp_(tensors)
        return tensors

ExpHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class ExpHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return loss.exp()

FDM

Bases: torchzero.modules.grad_approximation.grad_approximator.GradApproximator

Approximate gradients via finite difference method.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.001 ) –

    magnitude of parameter perturbation. Defaults to 1e-3.

  • formula (Literal, default: 'central' ) –

    finite difference formula. Defaults to 'central'.

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to 'closure'.

Examples:

Plain FDM:

fdm = tz.Optimizer(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))

Any gradient-based method can use FDM-estimated gradients.

fdm_ncg = tz.Optimizer(
    model.parameters(),
    tz.m.FDM(),
    # set hvp_method to "fd_forward" so that it
    # uses gradient difference instead of autograd
    tz.m.NewtonCG(hvp_method="fd_forward"),
    tz.m.Backtracking()
)

Source code in torchzero/modules/grad_approximation/fdm.py
class FDM(GradApproximator):
    """Approximate gradients via finite difference method.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): magnitude of parameter perturbation. Defaults to 1e-3.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central'.
        target (GradTarget, optional): what to set on var. Defaults to 'closure'.

    Examples:
    plain FDM:

    ```python
    fdm = tz.Optimizer(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
    ```

    Any gradient-based method can use FDM-estimated gradients.
    ```python
    fdm_ncg = tz.Optimizer(
        model.parameters(),
        tz.m.FDM(),
        # set hvp_method to "fd_forward" so that it
        # uses gradient difference instead of autograd
        tz.m.NewtonCG(hvp_method="fd_forward"),
        tz.m.Backtracking()
    )
    ```
    """
    def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central', target: GradTarget = 'closure'):
        defaults = dict(h=h, formula=formula)
        super().__init__(defaults, target=target)

    @torch.no_grad
    def approximate(self, closure, params, loss):
        grads = []
        loss_approx = None

        for p in params:
            g = torch.zeros_like(p)
            grads.append(g)

            settings = self.settings[p]
            h = settings['h']
            fd_fn = _FD_FUNCS[settings['formula']]

            p_flat = p.ravel(); g_flat = g.ravel()
            for i in range(len(p_flat)):
                loss, loss_approx, d = fd_fn(closure=closure, param=p_flat, idx=i, h=h, v_0=loss)
                g_flat[i] = d

        return grads, loss, loss_approx
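The central formula can be illustrated with a minimal plain-PyTorch sketch: g_i ≈ (f(x + h e_i) - f(x - h e_i)) / (2h), one coordinate at a time. This is not torchzero's `_FD_FUNCS` code, and the function name is hypothetical. It needs 2n evaluations for n parameters, so its cost grows linearly with the parameter count.

```python
import torch

@torch.no_grad()
def central_difference_grad(f, x, h=1e-3):
    """Coordinate-wise central differences: g_i = (f(x + h e_i) - f(x - h e_i)) / (2h)."""
    x = x.detach().clone()  # private copy so the caller's tensor is untouched
    grad = torch.zeros_like(x)
    x_flat, g_flat = x.view(-1), grad.view(-1)
    for i in range(x_flat.numel()):
        original = x_flat[i].item()
        x_flat[i] = original + h
        f_plus = f(x)
        x_flat[i] = original - h
        f_minus = f(x)
        x_flat[i] = original  # restore the coordinate
        g_flat[i] = (f_plus - f_minus) / (2 * h)
    return grad

def cubic(x):
    return (x ** 3).sum()

x = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
g_fd = central_difference_grad(cubic, x)  # exact gradient is 3 * x**2
```

For x³ the central formula gives exactly 3x² + h², so with h = 1e-3 the error is 1e-6 per coordinate. The forward formula would have an error proportional to h instead of h².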

Fill

Bases: torchzero.core.module.Module

Outputs tensors filled with value

Source code in torchzero/modules/ops/utility.py
class Fill(Module):
    """Outputs tensors filled with ``value``"""
    def __init__(self, value: float):
        defaults = dict(value=value)
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective):
        objective.updates = [torch.full_like(p, self.settings[p]['value']) for p in objective.params]
        return objective

FillLoss

Bases: torchzero.core.module.Module

Outputs tensors filled with loss value times alpha

Source code in torchzero/modules/misc/misc.py
class FillLoss(Module):
    """Outputs tensors filled with loss value times ``alpha``"""
    def __init__(self, alpha: float = 1, backward: bool = True):
        defaults = dict(alpha=alpha, backward=backward)
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective):
        alpha = self.get_settings(objective.params, 'alpha')
        loss = objective.get_loss(backward=self.defaults['backward'])
        objective.updates = [torch.full_like(p, loss*a) for p,a in zip(objective.params, alpha)]
        return objective

FletcherReeves

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Fletcher–Reeves nonlinear conjugate gradient method.

Note

This requires step size to be determined via a line search, so put a line search like tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this.

Source code in torchzero/modules/conjugate_gradient/cg.py
class FletcherReeves(ConguateGradientBase):
    """Fletcher–Reeves nonlinear conjugate gradient method.

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def initialize(self, p, g):
        self.global_state['prev_gg'] = g.dot(g)

    def get_beta(self, p, g, prev_g, prev_d):
        gg = g.dot(g)
        beta = fletcher_reeves_beta(gg, self.global_state['prev_gg'])
        self.global_state['prev_gg'] = gg
        return beta
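The Fletcher-Reeves coefficient is β_k = (g_k · g_k) / (g_{k-1} · g_{k-1}). Judging by the stored `prev_gg`, this is presumably what `get_beta` computes via `fletcher_reeves_beta`. The minimal sketch below runs the method on a convex quadratic, using an exact line search in place of `tz.m.StrongWolfe` and omitting torchzero's restart logic. With exact line searches on a quadratic, Fletcher-Reeves reduces to linear conjugate gradient and should reach the minimizer in at most n iterations, up to round-off. The function name is hypothetical.

```python
import torch

def fletcher_reeves_quadratic(A, b, x0, n_iter):
    """Fletcher-Reeves CG on f(x) = 0.5 x^T A x - b^T x with exact line search."""
    x = x0.clone()
    g = A @ x - b  # gradient of the quadratic
    d = -g         # first direction: steepest descent
    for _ in range(n_iter):
        gg = g @ g
        if gg < 1e-30:  # already at the minimizer
            break
        alpha = -(g @ d) / (d @ (A @ d))  # exact minimizer of f(x + alpha * d)
        x = x + alpha * d
        g_new = A @ x - b
        beta = (g_new @ g_new) / gg       # Fletcher-Reeves coefficient
        d = -g_new + beta * d
        g = g_new
    return x

A = torch.tensor([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]], dtype=torch.float64)
b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
x_fr = fletcher_reeves_quadratic(A, b, torch.zeros(3, dtype=torch.float64), n_iter=3)
```

On non-quadratic objectives the line search is inexact, which is why the module relies on a separate line search and optional restarts.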

FletcherVMM

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Fletcher's variable metric Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class FletcherVMM(_InverseHessianUpdateStrategyDefaults):
    """
    Fletcher's variable metric Quasi-Newton method.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return fletcher_vmm_H_(H=H, s=s, y=y, tol=setting['tol'])

ForwardGradient

Bases: torchzero.modules.grad_approximation.rfdm.RandomizedFDM

Forward gradient method.

This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • n_samples (int, default: 1 ) –

    number of random gradient samples. Defaults to 1.

  • distribution (Literal, default: 'gaussian' ) –

    distribution for random gradient samples. Defaults to "gaussian".

  • pre_generate (bool, default: True ) –

    whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.

  • jvp_method (str, default: 'autograd' ) –

    how to calculate the Jacobian-vector product. With 'forward' or 'central', this is equivalent to randomized finite difference. Defaults to 'autograd'.

  • h (float, default: 0.001 ) –

    finite difference step size if jvp_method is set to 'forward' or 'central'. Defaults to 1e-3.

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to "closure".

References

Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022). Gradients without backpropagation. arXiv preprint arXiv:2202.08587.

Source code in torchzero/modules/grad_approximation/forward_gradient.py
class ForwardGradient(RandomizedFDM):
    """Forward gradient method.

    This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.


    Args:
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        jvp_method (str, optional):
            how to calculate the Jacobian-vector product. With `forward` or `central`, this is equivalent to randomized finite difference. Defaults to 'autograd'.
        h (float, optional): finite difference step size if jvp_method is set to `forward` or `central`. Defaults to 1e-3.
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    References:
        Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022). Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
    """
    PRE_MULTIPLY_BY_H = False
    def __init__(
        self,
        n_samples: int = 1,
        distribution: Distributions = "gaussian",
        pre_generate = True,
        jvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
        h: float = 1e-3,
        target: GradTarget = "closure",
        seed: int | None | torch.Generator = None,
    ):
        super().__init__(h=h, n_samples=n_samples, distribution=distribution, target=target, pre_generate=pre_generate, seed=seed)
        self.defaults['jvp_method'] = jvp_method

    @torch.no_grad
    def approximate(self, closure, params, loss):
        params = TensorList(params)
        loss_approx = None

        fs = self.settings[params[0]]
        n_samples = fs['n_samples']
        jvp_method = fs['jvp_method']
        h = fs['h']
        distribution = fs['distribution']
        default = [None]*n_samples
        perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
        generator = self.get_generator(params[0].device, self.defaults['seed'])

        grad = None
        for i in range(n_samples):
            prt = perturbations[i]
            if prt[0] is None:
                prt = params.sample_like(distribution=distribution, variance=1, generator=generator)

            else: prt = TensorList(prt)

            if jvp_method == 'autograd':
                with torch.enable_grad():
                    loss, d = jvp(partial(closure, False), params=params, tangent=prt)

            elif jvp_method == 'forward':
                loss, d = jvp_fd_forward(partial(closure, False), params=params, tangent=prt, v_0=loss, h=h)

            elif jvp_method == 'central':
                loss_approx, d = jvp_fd_central(partial(closure, False), params=params, tangent=prt, h=h)

            else: raise ValueError(jvp_method)

            if grad is None: grad = prt * d
            else: grad += prt * d

        assert grad is not None
        if n_samples > 1: grad.div_(n_samples)
        return grad, loss, loss_approx

PRE_MULTIPLY_BY_H class-attribute

PRE_MULTIPLY_BY_H = False

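The estimator behind the forward gradient method can be sketched with `torch.func.jvp`. For a random direction v, one forward-mode pass gives the directional derivative ∇f(x)·v. The product (∇f(x)·v) v is an unbiased estimate of ∇f(x) whenever E[v vᵀ] = I, as for standard Gaussian v. This is a plain-PyTorch illustration with hypothetical function names, not torchzero's code path, which can also use finite differences via `jvp_method`.

```python
import torch
from torch.func import jvp

def directional_derivative(f, x, v):
    """Returns (f(x), grad f(x) . v) from a single forward-mode pass."""
    value, dd = jvp(f, (x,), (v,))
    return value, dd

def forward_gradient(f, x, n_samples=1, generator=None):
    """Forward-gradient estimate: mean of (grad f . v) * v over Gaussian v."""
    grad = torch.zeros_like(x)
    for _ in range(n_samples):
        v = torch.randn(x.shape, generator=generator, dtype=x.dtype)
        _, dd = directional_derivative(f, x, v)
        grad += dd * v
    return grad / n_samples

def sq(x):
    return (x ** 2).sum()  # exact gradient is 2 * x

x = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
v = torch.tensor([0.3, 0.1, -1.0], dtype=torch.float64)
value, dd = directional_derivative(sq, x, v)
gen = torch.Generator().manual_seed(0)
est = forward_gradient(sq, x, n_samples=5000, generator=gen)
```

A single sample is cheap but noisy. Averaging over `n_samples` directions reduces the variance roughly in proportion to 1 / n_samples.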

FullMatrixAdagrad

Bases: torchzero.core.transform.TensorTransform

Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

Note

A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in tz.m.GGT.

Parameters:

  • reg (float, default: 1e-12 ) –

    regularization, scale of identity matrix added to accumulator. Defaults to 1e-12.

  • precond_freq (int, default: 1 ) –

    frequency of updating the inverse square root of the accumulator. Defaults to 1.

  • beta (float | None, default: None ) –

    momentum for the gradient outer product accumulator. If None, uses a sum. Defaults to None.

  • beta_debias (bool, default: True ) –

    whether to use debiasing, only has effect when beta is not None. Defaults to True.

  • init (Literal[str], default: 'identity' ) –

    how to initialize the accumulator:
      - "identity" - with the identity matrix (default).
      - "zeros" - with a zero matrix.
      - "GGT" - with the first outer product.

  • matrix_power (float, default: -0.5 ) –

    accumulator matrix power. Defaults to -1/2.

  • concat_params (bool, default: True ) –

    if False, each parameter will have its own accumulator. Defaults to True.

  • inner (Chainable | None, default: None ) –

    inner modules to apply preconditioning to. Defaults to None.

Examples:

Plain full-matrix Adagrad

opt = tz.Optimizer(
    model.parameters(),
    tz.m.FullMatrixAdagrad(),
    tz.m.LR(1e-2),
)

Full-matrix RMSprop

opt = tz.Optimizer(
    model.parameters(),
    tz.m.FullMatrixAdagrad(beta=0.99),
    tz.m.LR(1e-2),
)

Full-matrix Adam

opt = tz.Optimizer(
    model.parameters(),
    tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
    tz.m.Debias(0.9, 0.999),
    tz.m.LR(1e-2),
)

Source code in torchzero/modules/adaptive/adagrad.py
class FullMatrixAdagrad(TensorTransform):
    """Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).

    Note:
        A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.GGT``.

    Args:
        reg (float, optional): regularization, scale of identity matrix added to accumulator. Defaults to 1e-12.
        precond_freq (int, optional): frequency of updating the inverse square root of the accumulator. Defaults to 1.
        beta (float | None, optional): momentum for the gradient outer product accumulator. If None, uses a sum. Defaults to None.
        beta_debias (bool, optional): whether to use debiasing, only has effect when ``beta`` is not ``None``. Defaults to True.
        init (Literal[str], optional):
            how to initialize the accumulator.
            - "identity" - with identity matrix (default).
            - "zeros" - with zero matrix.
            - "GGT" - with the first outer product.
        matrix_power (float, optional): accumulator matrix power. Defaults to -1/2.
        concat_params (bool, optional): if False, each parameter will have its own accumulator. Defaults to True.
        inner (Chainable | None, optional): inner modules to apply preconditioning to. Defaults to None.

    ## Examples:

    Plain full-matrix Adagrad
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.FullMatrixAdagrad(),
        tz.m.LR(1e-2),
    )
    ```

    Full-matrix RMSprop
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.FullMatrixAdagrad(beta=0.99),
        tz.m.LR(1e-2),
    )
    ```

    Full-matrix Adam
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
        tz.m.Debias(0.9, 0.999),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(
        self,
        reg: float = 1e-12,
        precond_freq: int = 1,
        beta: float | None = None,
        beta_debias: bool=True,
        init: Literal["identity", "zeros", "GGT"] = "identity",
        matrix_power: float = -1/2,
        matrix_power_method: MatrixPowerMethod = "eigh_abs",
        concat_params=True,

        inner: Chainable | None = None,
        accumulator_tfm: Chainable | None = None
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["concat_params"], defaults["accumulator_tfm"]
        super().__init__(defaults=defaults, inner=inner, concat_params=concat_params)

        self.set_child("accumulator", accumulator_tfm)
        self.add_projected_keys("covariance", "accumulator")

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):

        G = tensor.ravel()
        GGT = torch.outer(G, G)

        # initialize
        if "accumulator" not in state:
            init = setting['init']
            if init == 'identity': state['accumulator'] = torch.eye(GGT.size(0), device=GGT.device, dtype=GGT.dtype)
            elif init == 'zeros': state['accumulator'] =  torch.zeros_like(GGT)
            elif init == 'GGT': state['accumulator'] = GGT.clone()
            else: raise ValueError(init)

        # update
        beta = setting['beta']
        accumulator: torch.Tensor = state["accumulator"]

        if beta is None: accumulator.add_(GGT)
        else: accumulator.lerp_(GGT, 1-beta)

        # update number of GGᵀ in accumulator for divide
        state['num_GGTs'] = state.get('num_GGTs', 0) + 1

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        step = state.get('step', 0)
        state['step'] = step + 1

        accumulator: torch.Tensor = state['accumulator']
        accumulator = self.inner_step_tensors("accumulator", [accumulator], clone=True, must_exist=False)[0]

        precond_freq = setting['precond_freq']
        reg = setting['reg']
        beta = setting["beta"]

        # add regularizer
        if reg != 0:
            device = accumulator.device; dtype = accumulator.dtype
            accumulator = accumulator + torch.eye(accumulator.size(0), device=device, dtype=dtype).mul_(reg)

        # for single value use sqrt
        if tensor.numel() == 1:
            dir = tensor.mul_(accumulator.squeeze() ** setting["matrix_power"])

        # otherwise use matrix inverse square root
        else:

            # compute inverse square root and store to state
            try:
                if "B" not in state or step % precond_freq == 0:
                    B = state["B"] = _matrix_power(accumulator, setting["matrix_power"], method=setting["matrix_power_method"])
                else:
                    B = state["B"]

                dir = (B @ tensor.ravel()).view_as(tensor)

            # fallback to diagonal Adagrad on fail
            except torch.linalg.LinAlgError:
                dir = tensor.mul_(accumulator.diagonal() ** setting["matrix_power"])

        # debias
        if setting["beta_debias"] and beta is not None:
            num_GGTs = state.get('num_GGTs', 1)
            bias_correction = 1 - beta ** num_GGTs
            dir *= bias_correction ** 0.5

        return dir
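The core of the full-matrix preconditioner can be sketched in plain PyTorch as follows. It accumulates g gᵀ (a sum, or an EMA when `beta` is set), adds `reg`·I, and multiplies the flattened gradient by the inverse matrix square root computed with `torch.linalg.eigh`. Unlike the module, the sketch starts from a zero accumulator (like `init="zeros"`), recomputes the root every step, and omits debiasing. The class name is hypothetical. After a single step the direction is g/‖g‖ up to `reg`, because g is an eigenvector of g gᵀ with eigenvalue ‖g‖².

```python
from __future__ import annotations

import torch

class FullMatrixAdagradSketch:
    """Hypothetical minimal full-matrix Adagrad; the accumulator starts at zeros."""

    def __init__(self, reg: float = 1e-8, beta: float | None = None):
        self.reg = reg
        self.beta = beta
        self.accumulator: torch.Tensor | None = None

    def step(self, grad: torch.Tensor) -> torch.Tensor:
        g = grad.reshape(-1)
        GGT = torch.outer(g, g)
        if self.accumulator is None:
            self.accumulator = torch.zeros_like(GGT)
        if self.beta is None:
            self.accumulator += GGT                     # Adagrad: running sum of g g^T
        else:
            self.accumulator.lerp_(GGT, 1 - self.beta)  # RMSprop-style EMA of g g^T
        eye = torch.eye(g.numel(), dtype=g.dtype, device=g.device)
        eigvals, eigvecs = torch.linalg.eigh(self.accumulator + self.reg * eye)
        # (A + reg*I)^{-1/2} = Q diag(lambda^{-1/2}) Q^T; abs/clamp guard against round-off
        inv_sqrt = eigvecs @ torch.diag(eigvals.abs().clamp_min(1e-30).rsqrt()) @ eigvecs.T
        return (inv_sqrt @ g).view_as(grad)

opt = FullMatrixAdagradSketch()
g = torch.tensor([3.0, 4.0], dtype=torch.float64)
d1 = opt.step(g)  # accumulator = g g^T
d2 = opt.step(g)  # accumulator = 2 g g^T
```

The eigendecomposition costs O(N³) for N parameters, which is why `precond_freq` exists and why `tz.m.GGT` works with a low-rank history instead.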

GGT

Bases: torchzero.core.transform.TensorTransform

GGT method from https://arxiv.org/pdf/1806.02958

The update rule is to stack recent gradients into M and compute the eigendecomposition of M M^T via the eigendecomposition of the much smaller matrix M^T M.

This is equivalent to full-matrix Adagrad on recent gradients.

Parameters:

  • history_size (int, default: 100 ) –

    number of past gradients to store. Defaults to 100.

  • update_freq (int, default: 1 ) –

    frequency of updating the preconditioner (U and S). Defaults to 1.

  • eig_tol (float, default: 1e-07 ) –

    removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.

  • truncate (int, default: None ) –

    number of largest eigenvalues to keep. None to disable. Defaults to None.

  • damping (float, default: 0.0001 ) –

    damping value. Defaults to 1e-4.

  • rdamping (float, default: 0 ) –

    value of damping relative to largest eigenvalue. Defaults to 0.

  • concat_params (bool, default: True ) –

    if True, treats all parameters as a single vector. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioner will be applied to output of this module. Defaults to None.

Examples:

Limited-memory Adagrad

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.GGT(),
    tz.m.LR(0.1)
)

Adam with L-Adagrad preconditioner (the second beta of 0.999 used for debiasing is an arbitrary choice)

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.GGT(inner=tz.m.EMA()),
    tz.m.Debias(0.9, 0.999),
    tz.m.LR(0.01)
)

Stable Adam with L-Adagrad preconditioner (this is what I would recommend)

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.GGT(inner=tz.m.EMA()),
    tz.m.Debias(0.9, 0.999),
    tz.m.ClipNormByEMA(max_ema_growth=1.2),
    tz.m.LR(0.01)
)

Reference: Agarwal, N., et al. (2019). Efficient full-matrix adaptive regularization. In International Conference on Machine Learning (pp. 102-110). PMLR.
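The M^T M trick can be sketched as follows. Let M have shape d×r with r ≪ d. The r×r Gram matrix gives MᵀM = V diag(s) Vᵀ, and U = M V diag(s)^{-1/2} are orthonormal eigenvectors of M Mᵀ with the same eigenvalues s. The sketch uses them to apply (M Mᵀ + damping·I)^{-1/2} to a gradient exactly, scaling directions orthogonal to U by damping^{-1/2}. It is an illustration under these assumptions, not torchzero's GGT code, which also supports `truncate`, `rdamping` and `update_freq` and may treat the orthogonal complement differently. The function name is hypothetical.

```python
import torch

def ggt_precondition(M, g, damping=1e-4, eig_tol=1e-7):
    """Applies (M M^T + damping*I)^{-1/2} to g using only the small r x r matrix M^T M.

    M has shape (d, r): its columns are the r most recent gradients (r << d).
    """
    s, V = torch.linalg.eigh(M.T @ M)  # M^T M = V diag(s) V^T, shape (r, r)
    keep = s > s.max() * eig_tol       # drop (near-)zero eigenvalues
    s, V = s[keep], V[:, keep]
    U = (M @ V) / s.sqrt()             # orthonormal eigenvectors of M M^T
    coeffs = U.T @ g                   # components of g inside span(U)
    in_span = U @ (coeffs / (s + damping).sqrt())
    orthogonal = (g - U @ coeffs) / damping ** 0.5  # M M^T has eigenvalue 0 there
    return in_span + orthogonal

gen = torch.Generator().manual_seed(0)
M = torch.randn(6, 3, generator=gen, dtype=torch.float64)  # d=6 parameters, r=3 gradients
g = torch.randn(6, generator=gen, dtype=torch.float64)
direction = ggt_precondition(M, g)
```

Only an r×r eigendecomposition and O(d·r) products are needed, compared with the d×d eigendecomposition a dense full-matrix Adagrad would require.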

Source code in torchzero/modules/adaptive/ggt.py
class GGT(TensorTransform):
    """
    GGT method from https://arxiv.org/pdf/1806.02958

    The update rule is to stack recent gradients into M and
    compute eigendecomposition of M M^T via eigendecomposition of M^T M.

    This is equivalent to full-matrix Adagrad on recent gradients.

    Args:
        history_size (int, optional): number of past gradients to store. Defaults to 100.
        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
        eig_tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
        truncate (int, optional): number of largest eigenvalues to keep. None to disable. Defaults to None.
        damping (float, optional): damping value. Defaults to 1e-4.
        rdamping (float, optional): value of damping relative to largest eigenvalue. Defaults to 0.
        concat_params (bool, optional): if True, treats all parameters as a single vector. Defaults to True.
        inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.

    ## Examples:

    Limited-memory Adagrad

    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.GGT(),
        tz.m.LR(0.1)
    )
    ```
    Adam with L-Adagrad preconditioner (the second debiasing beta is set to 0.999 arbitrarily)

    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.GGT(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.LR(0.01)
    )
    ```

    Stable Adam with L-Adagrad preconditioner (this is what I would recommend)

    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.GGT(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.ClipNormByEMA(max_ema_growth=1.2),
        tz.m.LR(0.01)
    )
    ```
    Reference:
        Agarwal N. et al. Efficient full-matrix adaptive regularization. International Conference on Machine Learning, PMLR, 2019, pp. 102-110.
    """

    def __init__(
        self,
        history_size: int = 100,
        update_freq: int = 1,
        eig_tol: float = 1e-7,
        truncate: int | None = None,
        damping: float = 1e-4,
        rdamping: float = 0,
        matrix_power: float = -1/2,
        basis_optimizer: LREOptimizerBase | None = None,
        concat_params: bool = True,

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults['concat_params']

        super().__init__(defaults, concat_params=concat_params, inner=inner)
        self.add_projected_keys("grad", "history")

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        history_size = setting['history_size']
        update_freq = setting['update_freq']

        if 'history' not in state: state['history'] = deque(maxlen=history_size)
        history = state['history']

        t = tensor.clone().view(-1)
        history.append(t)

        step = state.get('step', 0)
        state['step'] = step + 1

        if step % update_freq == 0 :

            # compute new factors
            L = state.get("L", None)
            U = state.get("U", None)

            L_new, U_new = ggt_update(
                history,
                damping=setting["damping"],
                rdamping=setting["rdamping"],
                truncate=setting["truncate"],
                eig_tol=setting["eig_tol"],
                matrix_power=setting["matrix_power"],
            )

            # reproject basis optimizer
            basis_optimizer: LREOptimizerBase | None = setting["basis_optimizer"]
            if basis_optimizer is not None:
                if (L is not None) and (U is not None) and (L_new is not None) and (U_new is not None):
                    basis_state = state["basis_state"]
                    basis_optimizer.reproject(L_old=L, Q_old=U, L_new=L_new, Q_new=U_new, state=basis_state)


            # store new factors
            if L_new is not None: state["L"] = L_new
            if U_new is not None: state["U"] = U_new


    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        g = tensor.view(-1)
        U = state.get('U', None)

        if U is None:
            # fallback to element-wise preconditioning
            history = torch.stack(tuple(state["history"]), 0)
            g /= history.square().mean(0).sqrt().add(1e-8)
            return g.view_as(tensor)

        L = state['L']

        # step with basis optimizer
        basis_optimizer: LREOptimizerBase | None = setting["basis_optimizer"]
        if basis_optimizer is not None:

            if "basis_state" not in state: state["basis_state"] = {}
            basis_state = state["basis_state"]

            update = basis_optimizer.step(g, L=L, Q=U, state=basis_state)
            return update.view_as(tensor)

        # or just whiten
        z = U.T @ g
        update = (U * L.pow(setting["matrix_power"])) @ z
        return update.view_as(tensor)

GGTBasis

Bases: torchzero.core.transform.TensorTransform

Run another optimizer in GGT eigenbasis. The eigenbasis is rank-sized, so it is possible to run expensive methods such as Full-matrix Adagrad/Adam.

The update rule is to stack recent gradients into M and compute eigendecomposition of M M^T via eigendecomposition of M^T M.

This is equivalent to full-matrix Adagrad on recent gradients.

Note

The buffers of basis_opt are re-projected whenever the basis changes. The reprojection logic is not implemented on all modules. Some supported modules are:

Adagrad, FullMatrixAdagrad, Adam, Adan, Lion, MARSCorrection, MSAMMomentum, RMSprop, GGT, EMA, HeavyBall, NAG, ClipNormByEMA, ClipValueByEMA, NormalizeByEMA, ClipValueGrowth, CoordinateMomentum, CubicAdam.

Additionally, most modules with no internal buffers are supported, e.g. Cautious, Sign, ClipNorm, Orthogonalize, etc. However, modules that use weight values, such as WeightDecay, can't be supported, as weights can't be projected.

Also, if you, say, use EMA on the output of Pow(2), the exponential average will be reprojected as a gradient and not as squared gradients. Use modules like EMASquared or SqrtEMASquared to get correct reprojections.
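The re-projection can be illustrated with a small NumPy sketch that mirrors the change-of-basis matrix C = U_new.T @ U used in the GGTBasis source code. A buffer b stored in old rank-space coordinates represents the vector U @ b, so its coordinates in the new basis are U_new.T @ U @ b = C @ b. Gradient-like buffers map as C @ b, covariance-like buffers as C @ S @ C.T, and diagonal (squared-gradient-like) buffers as the diagonal of C @ diag(b) @ C.T. The function names are hypothetical, and the transformations are exact only for vectors lying in the span of the old basis.

```python
import numpy as np


def random_basis(rng, d, r):
    """Hypothetical helper: a random d x r matrix with orthonormal columns."""
    Q, _ = np.linalg.qr(rng.standard_normal((d, r)))
    return Q


def reproject_grad(C, b):
    return C @ b                               # first-moment buffers (e.g. momentum)


def reproject_cov(C, S):
    return C @ S @ C.T                         # full covariance buffers (rank x rank)


def reproject_diag(C, b):
    return (C @ np.diag(b) @ C.T).diagonal()   # diagonal second-moment buffers


rng = np.random.default_rng(0)
d = 8
U_old, U_new = random_basis(rng, d, 3), random_basis(rng, d, 4)
C = U_new.T @ U_old                            # (4, 3) change-of-basis matrix

g = U_old @ rng.standard_normal(3)             # a vector inside the old subspace
b_old = U_old.T @ g                            # its old coordinates
b_new = reproject_grad(C, b_old)               # equals U_new.T @ g
```

The diagonal of C @ diag(b) @ C.T can also be computed as (C ** 2) @ b, without forming the rank x rank matrix.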

Parameters:

  • basis_opt (Chainable) –

    module or modules to run in GGT eigenbasis.

  • history_size (int, default: 100 ) –

    number of past gradients to store; this is also the rank of the preconditioner. Defaults to 100.

  • update_freq (int, default: 1 ) –

    frequency of updating the preconditioner (U and S). Defaults to 1.

  • eig_tol (float, default: 1e-07 ) –

    removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.

  • truncate (int, default: None ) –

    number of largest eigenvalues to keep. None to disable. Defaults to None.

  • damping (float, default: 0.0001 ) –

    damping value. Defaults to 1e-4.

  • rdamping (float, default: 0 ) –

    value of damping relative to largest eigenvalue. Defaults to 0.

  • concat_params (bool) –

    not accepted by GGTBasis.__init__; all parameters are always treated as a single vector.

  • inner (Chainable | None, default: None ) –

    output of this module is projected and basis_opt will run on it, but preconditioners are updated from original gradients.

Examples:

Adam in GGT eigenbasis:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.GGTBasis(tz.m.Adam(beta2=0.99)),
    tz.m.LR(1e-3)
)

Full-matrix Adam in GGT eigenbasis. We can define full-matrix Adam through FullMatrixAdagrad.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.GGTBasis(
        [tz.m.FullMatrixAdagrad(beta=0.99, inner=tz.m.EMA(0.9, debias=True))]
    ),
    tz.m.LR(1e-3)
)

LaProp in GGT eigenbasis:

# we define LaProp through other modules, moved it out for brevity
laprop = (
    tz.m.RMSprop(0.95),
    tz.m.Debias(beta1=None, beta2=0.95),
    tz.m.EMA(0.95),
    tz.m.Debias(beta1=0.95, beta2=None),
)

opt = tz.Optimizer(
    model.parameters(),
    tz.m.GGTBasis(laprop),
    tz.m.LR(1e-3)
)

Reference

Agarwal N. et al. Efficient full-matrix adaptive regularization. International Conference on Machine Learning, PMLR, 2019, pp. 102-110.

Source code in torchzero/modules/basis/ggt_basis.py
class GGTBasis(TensorTransform):
    """
    Run another optimizer in GGT eigenbasis. The eigenbasis is ``rank``-sized, so it is possible to run expensive
    methods such as Full-matrix Adagrad/Adam.

    The update rule is to stack recent gradients into M and
    compute eigendecomposition of M M^T via eigendecomposition of M^T M.

    This is equivalent to full-matrix Adagrad on recent gradients.

    Note:
        The buffers of ``basis_opt`` are re-projected whenever the basis changes. The reprojection logic is not implemented on all modules. Some supported modules are:

        ``Adagrad``, ``FullMatrixAdagrad``, ``Adam``, ``Adan``, ``Lion``, ``MARSCorrection``, ``MSAMMomentum``, ``RMSprop``, ``GGT``, ``EMA``, ``HeavyBall``, ``NAG``, ``ClipNormByEMA``, ``ClipValueByEMA``, ``NormalizeByEMA``, ``ClipValueGrowth``, ``CoordinateMomentum``, ``CubicAdam``.

        Additionally, most modules with no internal buffers are supported, e.g. ``Cautious``, ``Sign``, ``ClipNorm``, ``Orthogonalize``, etc. However, modules that use weight values, such as ``WeightDecay``, can't be supported, as weights can't be projected.

        Also, if you, say, use ``EMA`` on the output of ``Pow(2)``, the exponential average will be reprojected as a gradient and not as squared gradients. Use modules like ``EMASquared`` or ``SqrtEMASquared`` to get correct reprojections.


    Args:
        basis_opt (Chainable): module or modules to run in GGT eigenbasis.
        history_size (int, optional): number of past gradients to store; this is also the rank of the preconditioner. Defaults to 100.
        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
        eig_tol (float, optional): removes eigenvalues this much smaller than largest eigenvalue. Defaults to 1e-7.
        truncate (int, optional): number of largest eigenvalues to keep. None to disable. Defaults to None.
        damping (float, optional): damping value. Defaults to 1e-4.
        rdamping (float, optional): value of damping relative to largest eigenvalue. Defaults to 0.
        concat_params (bool, optional): not accepted by ``GGTBasis.__init__``; all parameters are always treated as a single vector.
        inner (Chainable | None, optional):
            output of this module is projected and ``basis_opt`` will run on it, but preconditioners are updated
            from original gradients.

    ## Examples:

    Adam in GGT eigenbasis:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.GGTBasis(tz.m.Adam(beta2=0.99)),
        tz.m.LR(1e-3)
    )
    ```

    Full-matrix Adam in GGT eigenbasis. We can define full-matrix Adam through ``FullMatrixAdagrad``.
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.GGTBasis(
            [tz.m.FullMatrixAdagrad(beta=0.99, inner=tz.m.EMA(0.9, debias=True))]
        ),
        tz.m.LR(1e-3)
    )
    ```

    LaProp in GGT eigenbasis:
    ```python

    # we define LaProp through other modules, moved it out for brevity
    laprop = (
        tz.m.RMSprop(0.95),
        tz.m.Debias(beta1=None, beta2=0.95),
        tz.m.EMA(0.95),
        tz.m.Debias(beta1=0.95, beta2=None),
    )

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.GGTBasis(laprop),
        tz.m.LR(1e-3)
    )
    ```

    Reference:
        Agarwal N. et al. Efficient full-matrix adaptive regularization. International Conference on Machine Learning, PMLR, 2019, pp. 102-110.
    """

    def __init__(
        self,
        basis_opt: Chainable,
        history_size: int = 100,
        update_freq: int = 1,
        eig_tol: float = 1e-7,
        truncate: int | None = None,
        damping: float = 1e-4,
        rdamping: float = 0,
        matrix_power: float = -1/2,
        approx_sq_reproject:bool = False,
        approx_cu_reproject:bool = False,

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["basis_opt"]

        super().__init__(defaults, concat_params=True, inner=inner)
        self.set_child("basis_opt", basis_opt)

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        history_size = setting['history_size']
        update_freq = setting['update_freq']

        if 'history' not in state: state['history'] = deque(maxlen=history_size)
        history = state['history']

        t = tensor.clone().view(-1)
        history.append(t)

        step = state.get('step', 0)
        state['step'] = step + 1

        if step % update_freq == 0 :

            # compute new factors
            L = state.get("L", None)
            U = state.get("U", None)

            L_new, U_new = ggt_update(
                history,
                damping=setting["damping"],
                rdamping=setting["rdamping"],
                truncate=setting["truncate"],
                eig_tol=setting["eig_tol"],
                matrix_power=setting["matrix_power"],
            )

            if (L is not None) and (U is not None) and (L_new is not None) and (U_new is not None):
                # reproject basis optimizer
                # this happens after first step, so basis opt is initialized by then
                # note that because we concatenate parameters, each buffer will be a single rank-length vector
                C = U_new.T @ U # change of basis matrix

                # reproject gradient-like buffers
                for (buff,) in self.get_child_projected_buffers("basis_opt", "grad"):
                    set_storage_(buff, C @ buff)

                # reproject covariance diagonal-like buffers
                for (buff,) in self.get_child_projected_buffers("basis_opt", "grad_sq"):
                    if setting["approx_sq_reproject"]: set_storage_(buff, C.pow(2) @ buff)
                    else: set_storage_(buff, (C @ buff.diag_embed() @ C.T).diagonal())

                # reproject third order diagonal-like buffers
                for (buff,) in self.get_child_projected_buffers("basis_opt", "grad_cu"):
                    buff_r = _cubic_reproject(C, buff, setting["approx_cu_reproject"])
                    set_storage_(buff, buff_r)

                # reproject covariance-like buffers
                for (buff,) in self.get_child_projected_buffers("basis_opt", "covariance"):
                    set_storage_(buff, C @ buff @ C.T)

            # store new factors
            if L_new is not None: state["L"] = L_new
            if U_new is not None: state["U"] = U_new


    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        g = tensor.view(-1)
        U = state.get('U', None)

        if U is None:
            # fallback to element-wise preconditioning
            history = torch.stack(tuple(state["history"]), 0)
            g /= history.square().mean(0).sqrt().add(1e-8)
            return g.view_as(tensor)

        # project
        g_proj = U.T @ g

        # step
        dir_proj = self.inner_step_tensors("basis_opt", tensors=[g_proj], clone=False, grads=[g_proj])[0]

        # unproject
        update = U @ dir_proj

        # update = (U * L.pow(setting["matrix_power"])) @ z
        return update.view_as(tensor)

GaussNewton

Bases: torchzero.core.transform.Transform

Gauss-Newton method.

To use this, the closure should return a vector of residuals; the module minimizes the sum of their squares. The closure must accept a backward argument: it will always be False, but it is required. Gradients are calculated via batched autograd within this module, so you don't need to implement the backward pass. Please see below for an example.

Note

This method requires ndim^2 memory. However, if it is used within the tz.m.TrustCG trust region, the memory requirement is ndim*m, where m is the number of values in the output.
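The choice between the two linear systems in the GaussNewton source (J^T J when there are at least as many residuals as parameters, J J^T otherwise) can be sketched in NumPy. This is an illustration, not torchzero's code: gauss_newton_direction and rosenbrock_residuals are hypothetical names, and the loop assumes the usual convention that the returned direction v is subtracted from the parameters. For reg > 0 both branches give the same v, because (J^T J + reg I)^-1 J^T = J^T (J J^T + reg I)^-1; the second branch only needs an m x m matrix.

```python
import numpy as np


def gauss_newton_direction(J, r, reg=1e-8):
    """Hypothetical helper: Gauss-Newton direction v for residuals r with Jacobian J (m x n)."""
    m, n = J.shape
    if m >= n:
        # more residuals than parameters: solve (J^T J + reg I) v = J^T r, an n x n system
        return np.linalg.solve(J.T @ J + reg * np.eye(n), J.T @ r)
    # fewer residuals: solve (J J^T + reg I) z = r, an m x m system, then v = J^T z
    z = np.linalg.solve(J @ J.T + reg * np.eye(m), r)
    return J.T @ z


def rosenbrock_residuals(x):
    """Residuals whose sum of squares is the Rosenbrock function, and their Jacobian."""
    r = np.array([1 - x[0], 10 * (x[1] - x[0] ** 2)])
    J = np.array([[-1.0, 0.0], [-20 * x[0], 10.0]])
    return r, J


x = np.array([-1.1, 2.5])
for _ in range(5):
    r, J = rosenbrock_residuals(x)
    x = x - gauss_newton_direction(J, r)   # subtract the direction, like an LR of 1
print(np.round(x, 6))                      # [1. 1.]
```

With reg=0, the J J^T branch yields the minimum-norm solution pinv(J) @ r, which is the property noted in the source comments.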

Parameters:

  • reg (float, default: 1e-08 ) –

    regularization parameter. Defaults to 1e-8.

  • update_freq (int, default: 1 ) –

    frequency of computing the Jacobian. When the Jacobian is not recomputed, the previous Jacobian is reused and only the residuals are recomputed. Defaults to 1.

  • batched (bool, default: True ) –

    whether to compute the Jacobian with batched autograd (vmap). Defaults to True.

Examples:

Minimizing the Rosenbrock function, written as the residuals 1 - x1 and 10 * (x2 - x1**2), whose sum of squares is the standard Rosenbrock function:

import torch
import torchzero as tz

def rosenbrock(X):
    x1, x2 = X
    return torch.stack([(1 - x1), 10 * (x2 - x1**2)])

X = torch.tensor([-1.1, 2.5], requires_grad=True)
opt = tz.Optimizer([X], tz.m.GaussNewton(), tz.m.Backtracking())

# define the closure for line search
def closure(backward=True):
    return rosenbrock(X)

# minimize
for i in range(10):
    loss = opt.step(closure)
    print(f'{loss = }')

Training a neural network with a matrix-free GN trust region:

import torch
import torch.nn as nn
import torchzero as tz

X = torch.randn(64, 20)
y = torch.randn(64, 10)

model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
opt = tz.Optimizer(
    model.parameters(),
    tz.m.TrustCG(tz.m.GaussNewton()),
)

def closure(backward=True):
    y_hat = model(X) # (64, 10)
    return (y_hat - y).pow(2).mean(0) # (10, )

for i in range(100):
    losses = opt.step(closure)
    if i % 10 == 0:
        print(f'{losses.mean() = }')

Source code in torchzero/modules/least_squares/gn.py
class GaussNewton(Transform):
    """Gauss-Newton method.

    To use this, the closure should return a vector of residuals; the module minimizes the sum of their squares.
    The closure must accept a ``backward`` argument: it will always be False, but it is required.
    Gradients are calculated via batched autograd within this module, so you don't need to
    implement the backward pass. Please see below for an example.

    Note:
        This method requires ``ndim^2`` memory. However, if it is used within the ``tz.m.TrustCG`` trust region,
        the memory requirement is ``ndim*m``, where ``m`` is the number of values in the output.

    Args:
        reg (float, optional): regularization parameter. Defaults to 1e-8.
        update_freq (int, optional):
            frequency of computing the Jacobian. When the Jacobian is not recomputed, the previous Jacobian is reused
            and only the residuals are recomputed. Defaults to 1.
        batched (bool, optional): whether to compute the Jacobian with batched autograd (vmap). Defaults to True.

    Examples:

    Minimizing the Rosenbrock function, written as the residuals ``1 - x1`` and ``10 * (x2 - x1**2)``:
    ```python
    import torch
    import torchzero as tz

    def rosenbrock(X):
        x1, x2 = X
        return torch.stack([(1 - x1), 10 * (x2 - x1**2)])

    X = torch.tensor([-1.1, 2.5], requires_grad=True)
    opt = tz.Optimizer([X], tz.m.GaussNewton(), tz.m.Backtracking())

    # define the closure for line search
    def closure(backward=True):
        return rosenbrock(X)

    # minimize
    for i in range(10):
        loss = opt.step(closure)
        print(f'{loss = }')
    ```

    Training a neural network with a matrix-free GN trust region:
    ```python
    import torch
    import torch.nn as nn
    import torchzero as tz

    X = torch.randn(64, 20)
    y = torch.randn(64, 10)

    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.TrustCG(tz.m.GaussNewton()),
    )

    def closure(backward=True):
        y_hat = model(X) # (64, 10)
        return (y_hat - y).pow(2).mean(0) # (10, )

    for i in range(100):
        losses = opt.step(closure)
        if i % 10 == 0:
            print(f'{losses.mean() = }')
    ```
    """
    def __init__(self, reg:float = 1e-8, update_freq: int= 1, batched:bool=True, inner: Chainable | None = None):
        defaults=dict(update_freq=update_freq,batched=batched, reg=reg)
        super().__init__(defaults=defaults)
        if inner is not None: self.set_child('inner', inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]
        params = objective.params
        closure = objective.closure
        batched = fs['batched']
        update_freq = fs['update_freq']

        # compute residuals
        r = objective.loss
        if r is None:
            assert closure is not None
            with torch.enable_grad():
                r = objective.get_loss(backward=False) # n_residuals
                assert isinstance(r, torch.Tensor)

        if r.numel() == 1:
            r = r.view(1,1)
            warnings.warn("Gauss-newton got a single residual. Make sure objective function returns a vector of residuals.")

        # set sum of squares scalar loss and its gradient to objective
        objective.loss = r.pow(2).sum()

        step = self.increment_counter("step", start=0)

        if step % update_freq == 0:

            # compute jacobian
            with torch.enable_grad():
                J_list = jacobian_wrt([r.ravel()], params, batched=batched)

            J = self.global_state["J"] = flatten_jacobian(J_list) # (n_residuals, ndim)

        else:
            J = self.global_state["J"]

        Jr = J.T @ r.detach() # (ndim)

        # if there are more residuals, solve (J^T J)x = J^T r, so we need Jr
        # otherwise solve (J J^T)z = r and set x = J^T z, so we need r
        n_residuals, ndim = J.shape
        if n_residuals >= ndim or "inner" in self.children:
            self.global_state["Jr"] = Jr

        else:
            self.global_state["r"] = r

        objective.grads = vec_to_tensors(Jr, objective.params)

        # set closure to calculate sum of squares for line searches etc
        if closure is not None:
            def sos_closure(backward=True):

                if backward:
                    objective.zero_grad()
                    with torch.enable_grad():
                        loss = closure(False).pow(2).sum()
                        loss.backward()
                    return loss

                loss = closure(False).pow(2).sum()
                return loss

            objective.closure = sos_closure

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        fs = settings[0]
        reg = fs['reg']

        J: torch.Tensor = self.global_state['J']
        nresiduals, ndim = J.shape
        if nresiduals >= ndim or "inner" in self.children:

            # (J^T J)v = J^T r
            Jr: torch.Tensor = self.global_state['Jr']

            # inner step
            if "inner" in self.children:

                # var.grad is set to unflattened Jr
                assert objective.grads is not None
                objective = self.inner_step("inner", objective, must_exist=True)
                Jr_list = objective.get_updates()
                Jr = torch.cat([t.ravel() for t in Jr_list])

            JtJ = J.T @ J # (ndim, ndim)
            if reg != 0:
                JtJ.add_(torch.eye(JtJ.size(0), device=JtJ.device, dtype=JtJ.dtype).mul_(reg))

            if nresiduals >= ndim:
                v, info = torch.linalg.solve_ex(JtJ, Jr) # pylint:disable=not-callable
            else:
                v = torch.linalg.lstsq(JtJ, Jr).solution # pylint:disable=not-callable

            objective.updates = vec_to_tensors(v, objective.params)
            return objective

        # else:
        # solve (J J^T)z = r and set v = J^T z
        # we need (J^T J)v = J^T r
        # if z is solution to (J J^T)z = r, and v = J^T z
        # then (J^T J)v = (J^T J) (J^T z) = J^T (J J^T) z = J^T r
        # therefore (J^T J)v = J^T r
        # also this gives a minimum norm solution

        r = self.global_state['r']

        JJT = J @ J.T # (nresiduals, nresiduals)
        if reg != 0:
            JJT.add_(torch.eye(JJT.size(0), device=JJT.device, dtype=JJT.dtype).mul_(reg))

        z, info = torch.linalg.solve_ex(JJT, r) # pylint:disable=not-callable
        v = J.T @ z

        objective.updates = vec_to_tensors(v, objective.params)
        return objective

    def get_H(self, objective=...):
        J = self.global_state['J']
        return linear_operator.AtA(J)

GaussianSmoothing

Bases: torchzero.modules.grad_approximation.rfdm.RandomizedFDM

Gradient approximation via Gaussian smoothing method.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.01 ) –

    finite difference step size. Defaults to 1e-2.

  • n_samples (int, default: 100 ) –

    number of random gradient samples. Defaults to 100.

  • formula (Literal, default: 'forward2' ) –

    finite difference formula. Defaults to 'forward2'.

  • distribution (Literal, default: 'gaussian' ) –

    distribution of the random perturbation directions. Defaults to "gaussian".

  • pre_generate (bool, default: True ) –

    whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.

  • seed (int | None | Generator, default: None ) –

    Seed for random generator. Defaults to None.

  • target (Literal, default: 'closure' ) –

    what to set on var: "closure", "grad" or "update". Defaults to "closure".

References

Yurii Nesterov, Vladimir Spokoiny. (2015). Random Gradient-Free Minimization of Convex Functions. https://gwern.net/doc/math/2015-nesterov.pdf
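The Nesterov-Spokoiny estimator behind Gaussian smoothing can be sketched in NumPy: sample directions u ~ N(0, I) and average the forward differences (f(x + h u) - f(x)) / h times u. This is a hedged illustration only; it assumes the plain forward-difference form and does not reproduce torchzero's "forward2" formula, pre-generation or seeding logic. gaussian_smoothing_grad is a hypothetical name.

```python
import numpy as np


def gaussian_smoothing_grad(f, x, h=1e-2, n_samples=100, rng=None):
    """Hypothetical helper: gradient estimate of the Gaussian-smoothed f at x."""
    rng = np.random.default_rng() if rng is None else rng
    fx = f(x)                                         # one evaluation at the base point
    U = rng.standard_normal((n_samples, x.size))      # random Gaussian directions
    diffs = np.array([(f(x + h * u) - fx) / h for u in U])  # forward differences along u
    return (diffs[:, None] * U).mean(axis=0)          # average of u * (f(x + hu) - f(x)) / h


def quadratic(x):
    return 0.5 * np.dot(x, x)                         # true gradient is x itself


x0 = np.array([1.0, -2.0, 0.5])
est = gaussian_smoothing_grad(quadratic, x0, n_samples=50000, rng=np.random.default_rng(0))
print(est, x0)                                        # estimate vs. true gradient
```

The estimate is noisy; its error shrinks roughly as 1/sqrt(n_samples), which is why n_samples defaults to a fairly large value.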

Source code in torchzero/modules/grad_approximation/rfdm.py
class GaussianSmoothing(RandomizedFDM):
    """
    Gradient approximation via Gaussian smoothing method.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): finite difference step size. Defaults to 1e-2.
        n_samples (int, optional): number of random gradient samples. Defaults to 100.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'forward2'.
        distribution (Distributions, optional): distribution of the random perturbation directions. Defaults to "gaussian".
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
        target (GradTarget, optional): what to set on var: "closure", "grad" or "update". Defaults to "closure".


    References:
        Yurii Nesterov, Vladimir Spokoiny. (2015). Random Gradient-Free Minimization of Convex Functions. https://gwern.net/doc/math/2015-nesterov.pdf
    """
    def __init__(
        self,
        h: float = 1e-2,
        n_samples: int = 100,
        formula: _FD_Formula = "forward2",
        distribution: Distributions = "gaussian",
        pre_generate: bool = True,
        return_approx_loss: bool = False,
        target: GradTarget = "closure",
        seed: int | None | torch.Generator = None,
    ):
        super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed, return_approx_loss=return_approx_loss)

Grad

Bases: torchzero.core.module.Module

Outputs the gradient

Source code in torchzero/modules/ops/utility.py
class Grad(Module):
    """Outputs the gradient"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def apply(self, objective):
        objective.updates = [g.clone() for g in objective.get_grads()]
        return objective

GradApproximator

Bases: torchzero.core.module.Module, abc.ABC

Base class for gradient approximations. This is an abstract class; to use it, subclass it and override approximate.

With the default target='closure', GradApproximator modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure.

Parameters:

  • defaults (dict[str, Any] | None, default: None ) –

    dict with defaults. Defaults to None.

  • target (str, default: 'closure' ) –

    whether to set var.grad, var.update or var.closure. Defaults to 'closure'.

Example:

Basic SPSA method implementation.

class SPSA(GradApproximator):
    def __init__(self, h=1e-3):
        defaults = dict(h=h)
        super().__init__(defaults)

    @torch.no_grad
    def approximate(self, closure, params, loss):
        perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]

        # evaluate params + perturbation
        torch._foreach_add_(params, perturbation)
        loss_plus = closure(False)

        # evaluate params - perturbation
        torch._foreach_sub_(params, perturbation)
        torch._foreach_sub_(params, perturbation)
        loss_minus = closure(False)

        # restore original params
        torch._foreach_add_(params, perturbation)

        # calculate SPSA gradients
        spsa_grads = []
        for p, pert in zip(params, perturbation):
            settings = self.settings[p]
            h = settings['h']
            d = (loss_plus - loss_minus) / (2*(h**2))
            spsa_grads.append(pert * d)

        # returns tuple: (grads, loss, loss_approx)
        # loss must be with initial parameters
        # since we only evaluated loss with perturbed parameters
        # we only have loss_approx
        return spsa_grads, None, loss_plus
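Stripped of the module machinery, the SPSA estimate in the example above reduces to the following NumPy sketch (spsa_grad is a hypothetical name, and NumPy's random generator replaces rademacher_like). With perturbation = h * delta and delta a Rademacher vector, pert * (loss_plus - loss_minus) / (2 * h**2) equals delta * (loss_plus - loss_minus) / (2 * h), the usual central-difference SPSA estimate; averaging many estimates approaches the true gradient.

```python
import numpy as np


def spsa_grad(f, x, h=1e-3, rng=None):
    """Hypothetical helper: one SPSA gradient estimate using a Rademacher perturbation."""
    rng = np.random.default_rng() if rng is None else rng
    delta = rng.choice([-1.0, 1.0], size=x.shape)    # Rademacher directions
    pert = h * delta
    loss_plus, loss_minus = f(x + pert), f(x - pert) # x itself is never modified
    d = (loss_plus - loss_minus) / (2 * h ** 2)
    return pert * d                                  # == delta * (f+ - f-) / (2h)


def quadratic(x):
    return 0.5 * np.dot(x, x)                        # true gradient is x itself


rng = np.random.default_rng(0)
x0 = np.array([1.0, -2.0, 0.5])
avg = np.mean([spsa_grad(quadratic, x0, rng=rng) for _ in range(20000)], axis=0)
print(avg)
```

Unlike the module version, this pure function evaluates f at new arrays instead of perturbing parameters in place, so there is nothing to restore afterwards.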

Methods:

  • approximate

    Returns a tuple (grad, loss, loss_approx). Make sure this resets parameters to their original values!

  • pre_step

    This runs once before each step, whereas approximate may run multiple times per step if further modules evaluate gradients at multiple points.

Source code in torchzero/modules/grad_approximation/grad_approximator.py
class GradApproximator(Module, ABC):
    """Base class for gradient approximations.
    This is an abstract class; to use it, subclass it and override `approximate`.

    With the default ``target='closure'``, GradApproximator modifies the closure to evaluate the estimated gradients,
    and further closure-based modules will use the modified closure.

    Args:
        defaults (dict[str, Any] | None, optional): dict with defaults. Defaults to None.
        target (str, optional):
            whether to set `var.grad`, `var.update` or `var.closure`. Defaults to 'closure'.

    Example:

    Basic SPSA method implementation.
    ```python
    class SPSA(GradApproximator):
        def __init__(self, h=1e-3):
            defaults = dict(h=h)
            super().__init__(defaults)

        @torch.no_grad
        def approximate(self, closure, params, loss):
            perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]

            # evaluate params + perturbation
            torch._foreach_add_(params, perturbation)
            loss_plus = closure(False)

            # evaluate params - perturbation
            torch._foreach_sub_(params, perturbation)
            torch._foreach_sub_(params, perturbation)
            loss_minus = closure(False)

            # restore original params
            torch._foreach_add_(params, perturbation)

            # calculate SPSA gradients
            spsa_grads = []
            for p, pert in zip(params, perturbation):
                settings = self.settings[p]
                h = settings['h']
                d = (loss_plus - loss_minus) / (2*(h**2))
                spsa_grads.append(pert * d)

            # returns tuple: (grads, loss, loss_approx)
            # loss must be with initial parameters
            # since we only evaluated loss with perturbed parameters
            # we only have loss_approx
            return spsa_grads, None, loss_plus
    ```
    """
    def __init__(self, defaults: dict[str, Any] | None = None, return_approx_loss:bool=False, target: GradTarget = 'closure'):
        super().__init__(defaults)
        self._target: GradTarget = target
        self._return_approx_loss = return_approx_loss

    @abstractmethod
    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
        """Returns a tuple ``(grad, loss, loss_approx)``. Make sure this resets parameters to their original values!"""

    def pre_step(self, objective: Objective) -> None:
        """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""

    @torch.no_grad
    def update(self, objective):
        self.pre_step(objective)

        if objective.closure is None: raise RuntimeError("Gradient approximation requires closure")
        params, closure, loss = objective.params, objective.closure, objective.loss

        if self._target == 'closure':

            def approx_closure(backward=True):
                if backward:
                    # set loss to None because closure might be evaluated at different points
                    grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None)
                    for p, g in zip(params, grad): p.grad = g
                    if l is not None: return l
                    if self._return_approx_loss and l_approx is not None: return l_approx
                    return closure(False)

                return closure(False)

            objective.closure = approx_closure
            return

        # if var.grad is not None:
        #     warnings.warn('Using grad approximator when `var.grad` is already set.')
        grad, loss, loss_approx = self.approximate(closure=closure, params=params, loss=loss)
        if loss_approx is not None: objective.loss_approx = loss_approx
        if loss is not None: objective.loss = objective.loss_approx = loss
        if self._target == 'grad': objective.grads = list(grad)
        elif self._target == 'update': objective.updates = list(grad)
        else: raise ValueError(self._target)
        return

    def apply(self, objective):
        return objective

approximate

approximate(closure: Callable, params: list[Tensor], loss: Tensor | None) -> tuple[Iterable[Tensor], Tensor | None, Tensor | None]

Returns a tuple (grad, loss, loss_approx). Implementations must restore the parameters to their original values before returning.

Source code in torchzero/modules/grad_approximation/grad_approximator.py
@abstractmethod
def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
    """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""

pre_step

pre_step(objective: Objective) -> None

Runs once before each step, whereas approximate may run multiple times per step if later modules evaluate gradients at multiple points. This is useful, for example, for pre-generating new random perturbations.

Source code in torchzero/modules/grad_approximation/grad_approximator.py
def pre_step(self, objective: Objective) -> None:
    """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
    evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
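The SPSA estimator from the GradApproximator example can be checked without torchzero. The sketch below is a standalone, hypothetical helper (spsa_grad is not part of torchzero's API) written in plain PyTorch for CPU tensors: it draws a Rademacher direction r, evaluates f at x + h·r and x - h·r, and returns r·(f+ - f-)/(2h). This matches the docstring's `pert * d`, since there pert = h·r and d = (f+ - f-)/(2h²).

```python
import torch
from typing import Callable, Optional

def spsa_grad(
    f: Callable[[torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    h: float = 1e-3,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """One SPSA gradient estimate of f at x (hypothetical helper, CPU tensors)."""
    # Rademacher direction: each entry is +1 or -1 with equal probability
    r = torch.randint(0, 2, x.shape, generator=generator, dtype=x.dtype) * 2 - 1
    loss_plus = f(x + h * r)   # evaluate at x + perturbation
    loss_minus = f(x - h * r)  # evaluate at x - perturbation
    # r_i is +1 or -1, so dividing by r_i is the same as multiplying by r_i
    return r * (loss_plus - loss_minus) / (2 * h)
```

In one dimension r = ±1, so the estimate reduces to an ordinary central difference. In higher dimensions a single estimate is noisy, and only its average over many random directions approaches the true gradient.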

GradSign

Bases: torchzero.core.transform.TensorTransform

Copies gradient sign to update.

Source code in torchzero/modules/misc/misc.py
class GradSign(TensorTransform):
    """Copies gradient sign to update."""
    def __init__(self):
        super().__init__(uses_grad=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        return [t.copysign_(g) for t,g in zip(tensors, grads)]

GradToNone

Bases: torchzero.core.module.Module

Sets grad attribute to None on objective.

Source code in torchzero/modules/ops/utility.py
class GradToNone(Module):
    """Sets ``grad`` attribute to None on ``objective``."""
    def __init__(self): super().__init__()
    def apply(self, objective):
        objective.grads = None
        return objective

GradientAccumulation

Bases: torchzero.core.module.Module

Accumulates gradients over n steps. Once n gradients have been accumulated, they are passed to the subsequent modules and the parameters are updated.

Accumulating gradients for n steps is equivalent to increasing batch size by n. Increasing the batch size is more computationally efficient, but sometimes it is not feasible due to memory constraints.

Note

Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.

Parameters:

  • n (int) –

    number of gradients to accumulate.

  • mean (bool, default: True ) –

    if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.

  • stop (bool, default: True ) –

    this module prevents the next modules from stepping until n gradients have been accumulated. Setting this argument to False disables that. Defaults to True.

Examples:

Adam with gradients accumulated for 16 batches.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.GradientAccumulation(n=16),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)
Source code in torchzero/modules/misc/gradient_accumulation.py
class GradientAccumulation(Module):
    """Accumulates gradients over ``n`` steps. Once ``n`` gradients have been accumulated, they are passed to the subsequent modules and the parameters are updated.

    Accumulating gradients for ``n`` steps is equivalent to increasing batch size by ``n``. Increasing the batch size
    is more computationally efficient, but sometimes it is not feasible due to memory constraints.

    Note:
        Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.

    Args:
        n (int): number of gradients to accumulate.
        mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
        stop (bool, optional):
            this module prevents the next modules from stepping until ``n`` gradients have been accumulated. Setting this argument to False disables that. Defaults to True.

    ## Examples:

    Adam with gradients accumulated for 16 batches.

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.GradientAccumulation(n=16),
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(self, n: int, mean=True, stop=True):
        defaults = dict(n=n, mean=mean, stop=stop)
        super().__init__(defaults)
        self.add_projected_keys("grad", "accumulator")


    @torch.no_grad
    def apply(self, objective):
        accumulator = self.get_state(objective.params, 'accumulator')
        settings = self.defaults
        n = settings['n']; mean = settings['mean']; stop = settings['stop']
        step = self.increment_counter("step", 0)

        # add update to accumulator
        torch._foreach_add_(accumulator, objective.get_updates())

        # step with accumulated updates
        if (step + 1) % n == 0:
            if mean:
                torch._foreach_div_(accumulator, n)

            objective.updates = accumulator

            # zero accumulator
            self.clear_state_keys('accumulator')

        else:
            # prevent update
            if stop:
                objective.updates = None
                objective.stop=True
                objective.skip_update=True

        return objective
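The claim that accumulating n gradients is equivalent to an n-times larger batch can be checked in plain PyTorch. This sketch is a hypothetical helper, not torchzero's implementation: it accumulates through repeated `.backward()` calls within a single step (GradientAccumulation accumulates across optimizer steps instead), and it assumes a mean-reduced loss and equal-sized micro-batches, which is what makes the mean of the micro-batch gradients equal the full-batch gradient.

```python
import torch

def accumulated_grad(model: torch.nn.Module, loss_fn, inputs, targets, n: int, mean: bool = True):
    """Accumulate gradients over n equal-sized micro-batches (hypothetical helper)."""
    model.zero_grad()
    for xb, yb in zip(inputs.chunk(n), targets.chunk(n)):
        # each .backward() call adds its gradient into p.grad
        loss_fn(model(xb), yb).backward()
    grads = [p.grad.clone() for p in model.parameters()]
    if mean:
        # same role as the `mean=True` option: average instead of sum
        grads = [g / n for g in grads]
    return grads
```

With `mean=False` the result is n times the full-batch gradient, which corresponds to the sum option.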

GradientCorrection

Bases: torchzero.core.transform.TensorTransform

Estimates the gradient at the minimum along the search direction, assuming the function is quadratic.

This can be useful as an inner module for second-order methods with an inexact line search.

Example:

L-BFGS with gradient correction

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LBFGS(inner=tz.m.GradientCorrection()),
    tz.m.Backtracking()
)
Reference

HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class GradientCorrection(TensorTransform):
    """
    Estimates the gradient at the minimum along the search direction, assuming the function is quadratic.

    This can be useful as an inner module for second-order methods with an inexact line search.

    ## Example:
    L-BFGS with gradient correction

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LBFGS(inner=tz.m.GradientCorrection()),
        tz.m.Backtracking()
    )
    ```

    Reference:
        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394

    """
    def __init__(self):
        super().__init__()

    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        if 'p_prev' not in states[0]:
            p_prev = unpack_states(states, tensors, 'p_prev', init=params)
            g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
            return tensors

        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)
        g_hat = gradient_correction(TensorList(tensors), params-p_prev, tensors-g_prev)

        p_prev.copy_(params)
        g_prev.copy_(tensors)
        return g_hat
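Under the quadratic assumption, the gradient along the line x + t·s is g + t·y, where s and y play the roles of `params - p_prev` and `tensors - g_prev` in `multi_tensor_apply`. Setting the directional derivative s·(g + t·y) to zero gives t = -(s·g)/(s·y). The sketch below implements that formula for flat tensors; corrected_gradient is a hypothetical helper, and this is not necessarily the exact update computed by torchzero's `gradient_correction`, whose source is not shown here.

```python
import torch

def corrected_gradient(g: torch.Tensor, s: torch.Tensor, y: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Gradient at the minimizer along s, assuming a quadratic model (hypothetical helper)."""
    sy = torch.dot(s, y)  # curvature along s; equals s^T H s for a quadratic
    if sy.abs() <= eps:
        return g.clone()  # no usable curvature information: keep the gradient as is
    t = -torch.dot(s, g) / sy  # step length that zeros the directional derivative
    return g + t * y           # gradient of the quadratic at x + t*s
```

On an exact quadratic the corrected gradient is orthogonal to s and equals the true gradient at the line minimum; on other functions it is only an approximation.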

GradientSampling

Bases: torchzero.core.reformulation.Reformulation

Samples and aggregates gradients and values at perturbed points.

This module can be used for Gaussian homotopy and gradient sampling methods.

Parameters:

  • modules (Chainable | None, default: None ) –

    modules that will optimize the modified objective. If None, the gradient of the modified objective is returned as the update. Defaults to None.

  • sigma (float, default: 1.0 ) –

    initial magnitude of the perturbations. Defaults to 1.

  • n (int, default: 100 ) –

    number of perturbations per step. Defaults to 100.

  • aggregate (str, default: 'mean' ) –

    how to aggregate values and gradients:
    - "mean" - uses the mean of the gradients, as in Gaussian homotopy.
    - "max" - uses the element-wise maximum of the gradients.
    - "min" - uses the element-wise minimum of the gradients.
    - "min-norm" - picks the gradient with the lowest norm.
    - "min-value" - picks the gradient at the point with the lowest loss value.

    Defaults to 'mean'.

  • distribution (Literal, default: 'gaussian' ) –

    distribution for random perturbations. Defaults to 'gaussian'.

  • include_x0 (bool, default: True ) –

    whether to include gradient at un-perturbed point. Defaults to True.

  • fixed (bool, default: True ) –

    if True, perturbations are not replaced by new random perturbations until the termination criteria are satisfied. Defaults to True.

  • pre_generate (bool, default: True ) –

    if True, perturbations are pre-generated before each step. This requires more memory to store all of them, but ensures they do not change when closure is evaluated multiple times. Defaults to True.

  • termination (TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None, default: None ) –

    a termination criteria module. When the termination criteria are satisfied, sigma is multiplied by decay, and new perturbations are generated if fixed is True. Defaults to None.

  • decay (float, default: 0.6666666666666666 ) –

    sigma multiplier on termination criteria. Defaults to 2/3.

  • reset_on_termination (bool, default: True ) –

    whether to reset states of all other modules on termination. Defaults to True.

  • sigma_strategy (str | None, default: None ) –

    strategy for adapting sigma. If the condition is satisfied, sigma is multiplied by sigma_nplus; otherwise it is multiplied by sigma_nminus.
    - "grad-norm" - at least sigma_target gradients should have a lower norm than at the un-perturbed point.
    - "value" - at least sigma_target values (losses) should be lower than at the un-perturbed point.
    - None - doesn't use adaptive sigma.

    This introduces a side effect to the closure, so it should be left at None if you use a trust region or line search to optimize the modified objective. Defaults to None.

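  • sigma_target (int | float, default: 0.2 ) –

    number of evaluations that must satisfy the condition in sigma_strategy. If a float, it is treated as a fraction of the evaluations with a finite loss (rounded down, but at least 1). Defaults to 0.2.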

  • sigma_nplus (float, default: 1.3333333333333333 ) –

    sigma multiplier when sigma_strategy condition is satisfied. Defaults to 4/3.

  • sigma_nminus (float, default: 0.6666666666666666 ) –

    sigma multiplier when sigma_strategy condition is not satisfied. Defaults to 2/3.

  • seed (int | None, default: None ) –

    seed for the random number generator used to sample perturbations. Defaults to None.

Source code in torchzero/modules/smoothing/sampling.py
class GradientSampling(Reformulation):
    """Samples and aggregates gradients and values at perturbed points.

    This module can be used for gaussian homotopy and gradient sampling methods.

    Args:
        modules (Chainable | None, optional):
            modules that will be optimizing the modified objective.
            if None, returns gradient of the modified objective as the update. Defaults to None.
        sigma (float, optional): initial magnitude of the perturbations. Defaults to 1.
        n (int, optional): number of perturbations per step. Defaults to 100.
        aggregate (str, optional):
            how to aggregate values and gradients:
            - "mean" - uses the mean of the gradients, as in Gaussian homotopy.
            - "max" - uses the element-wise maximum of the gradients.
            - "min" - uses the element-wise minimum of the gradients.
            - "min-norm" - picks the gradient with the lowest norm.
            - "min-value" - picks the gradient at the point with the lowest loss value.

            Defaults to 'mean'.
        distribution (Distributions, optional): distribution for random perturbations. Defaults to 'gaussian'.
        include_x0 (bool, optional): whether to include gradient at un-perturbed point. Defaults to True.
        fixed (bool, optional):
            if True, perturbations do not get replaced by new random perturbations until termination criteria is satisfied. Defaults to True.
        pre_generate (bool, optional):
            if True, perturbations are pre-generated before each step.
            This requires more memory to store all of them,
            but ensures they do not change when closure is evaluated multiple times.
            Defaults to True.
        termination (TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None, optional):
            a termination criteria module, sigma will be multiplied by ``decay`` when termination criteria is satisfied,
            and new perturbations will be generated if ``fixed``. Defaults to None.
        decay (float, optional): sigma multiplier on termination criteria. Defaults to 2/3.
        reset_on_termination (bool, optional): whether to reset states of all other modules on termination. Defaults to True.
        sigma_strategy (str | None, optional):
            strategy for adapting sigma. If condition is satisfied, sigma is multiplied by ``sigma_nplus``,
            otherwise it is multiplied by ``sigma_nminus``.
            - "grad-norm" - at least ``sigma_target`` gradients should have lower norm than at un-perturbed point.
            - "value" - at least ``sigma_target`` values (losses) should be lower than at un-perturbed point.
            - None - doesn't use adaptive sigma.

            This introduces a side effect to the closure, so it should be left at None if you use
            a trust region or line search to optimize the modified objective.
            Defaults to None.
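        sigma_target (int | float, optional):
            number of evaluations that must satisfy the condition in ``sigma_strategy``.
            If a float, it is treated as a fraction of the evaluations with a finite loss
            (rounded down, but at least 1). Defaults to 0.2.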
        sigma_nplus (float, optional): sigma multiplier when ``sigma_strategy`` condition is satisfied. Defaults to 4/3.
        sigma_nminus (float, optional): sigma multiplier when ``sigma_strategy`` condition is not satisfied. Defaults to 2/3.
        seed (int | None, optional): seed. Defaults to None.
    """
    def __init__(
        self,
        modules: Chainable | None = None,
        sigma: float = 1.,
        n:int = 100,
        aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = 'mean',
        distribution: Distributions = 'gaussian',
        include_x0: bool = True,

        fixed: bool=True,
        pre_generate: bool = True,
        termination: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
        decay: float = 2/3,
        reset_on_termination: bool = True,

        sigma_strategy: Literal['grad-norm', 'value'] | None = None,
        sigma_target: int | float = 0.2,
        sigma_nplus: float = 4/3,
        sigma_nminus: float = 2/3,

        seed: int | None = None,
    ):

        defaults = dict(sigma=sigma, n=n, aggregate=aggregate, distribution=distribution, seed=seed, include_x0=include_x0, fixed=fixed, decay=decay, reset_on_termination=reset_on_termination, sigma_strategy=sigma_strategy, sigma_target=sigma_target, sigma_nplus=sigma_nplus, sigma_nminus=sigma_nminus, pre_generate=pre_generate)
        super().__init__(defaults, modules)

        if termination is not None:
            self.set_child('termination', make_termination_criteria(extra=termination))

    @torch.no_grad
    def pre_step(self, objective):
        params = TensorList(objective.params)

        fixed = self.defaults['fixed']

        # check termination criteria
        if 'termination' in self.children:
            termination = cast(TerminationCriteriaBase, self.children['termination'])
            if termination.should_terminate(objective):

                # decay sigmas
                states = [self.state[p] for p in params]
                settings = [self.settings[p] for p in params]

                for state, setting in zip(states, settings):
                    if 'sigma' not in state: state['sigma'] = setting['sigma']
                    state['sigma'] *= setting['decay']

                # reset on sigmas decay
                if self.defaults['reset_on_termination']:
                    objective.post_step_hooks.append(partial(_reset_except_self, self=self))

                # clear perturbations
                self.global_state.pop('perts', None)

        # pre-generate perturbations if not already pre-generated or not fixed
        if self.defaults['pre_generate'] and (('perts' not in self.global_state) or (not fixed)):
            states = [self.state[p] for p in params]
            settings = [self.settings[p] for p in params]

            n = self.defaults['n'] - self.defaults['include_x0']
            generator = self.get_generator(params[0].device, self.defaults['seed'])

            perts = [params.sample_like(self.defaults['distribution'], generator=generator) for _ in range(n)]

            self.global_state['perts'] = perts

    @torch.no_grad
    def closure(self, backward, closure, params, objective):
        params = TensorList(params)
        loss_agg = None
        grad_agg = None

        states = [self.state[p] for p in params]
        settings = [self.settings[p] for p in params]
        sigma_inits = [s['sigma'] for s in settings]
        sigmas = [s.setdefault('sigma', si) for s, si in zip(states, sigma_inits)]

        include_x0 = self.defaults['include_x0']
        pre_generate = self.defaults['pre_generate']
        aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = self.defaults['aggregate']
        sigma_strategy: Literal['grad-norm', 'value'] | None = self.defaults['sigma_strategy']
        distribution = self.defaults['distribution']
        generator = self.get_generator(params[0].device, self.defaults['seed'])


        n_finite = 0
        n_good = 0
        f_0 = None; g_0 = None

        # evaluate at x_0
        if include_x0:
            f_0 = objective.get_loss(backward=backward)

            isfinite = math.isfinite(f_0)
            if isfinite:
                n_finite += 1
                loss_agg = f_0

            if backward:
                g_0 = objective.get_grads()
                if isfinite: grad_agg = g_0

        # evaluate at x_0 + p for each perturbation
        if pre_generate:
            perts = self.global_state['perts']
        else:
            perts = [None] * (self.defaults['n'] - include_x0)

        x_0 = [p.clone() for p in params]

        for pert in perts:
            loss = None; grad = None

            # generate if not pre-generated
            if pert is None:
                pert = params.sample_like(distribution, generator=generator)

            # add perturbation and evaluate
            pert = pert * sigmas
            torch._foreach_add_(params, pert)

            with torch.enable_grad() if backward else nullcontext():
                loss = closure(backward)

            if math.isfinite(loss):
                n_finite += 1

                # add loss
                if loss_agg is None:
                    loss_agg = loss
                else:
                    if aggregate == 'mean':
                        loss_agg += loss

                    elif (aggregate=='min') or (aggregate=='min-value') or (aggregate=='min-norm' and not backward):
                        loss_agg = loss_agg.clamp(max=loss)

                    elif aggregate == 'max':
                        loss_agg = loss_agg.clamp(min=loss)

                # add grad
                if backward:
                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                    if grad_agg is None:
                        grad_agg = grad
                    else:
                        if aggregate == 'mean':
                            torch._foreach_add_(grad_agg, grad)

                        elif aggregate == 'min':
                            grad_agg_abs = torch._foreach_abs(grad_agg)
                            torch._foreach_minimum_(grad_agg_abs, torch._foreach_abs(grad))
                            grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]

                        elif aggregate == 'max':
                            grad_agg_abs = torch._foreach_abs(grad_agg)
                            torch._foreach_maximum_(grad_agg_abs, torch._foreach_abs(grad))
                            grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]

                        elif aggregate == 'min-norm':
                            if TensorList(grad).global_vector_norm() < TensorList(grad_agg).global_vector_norm():
                                grad_agg = grad
                                loss_agg = loss

                        elif aggregate == 'min-value':
                            if loss < loss_agg:
                                grad_agg = grad
                                loss_agg = loss

            # undo perturbation
            torch._foreach_copy_(params, x_0)

            # adaptive sigma
            # by value
            if sigma_strategy == 'value':
                if f_0 is None:
                    with torch.enable_grad() if backward else nullcontext():
                        f_0 = closure(False)

                if loss < f_0:
                    n_good += 1

            # by gradient norm
            elif sigma_strategy == 'grad-norm' and backward and math.isfinite(loss):
                assert grad is not None
                if g_0 is None:
                    with torch.enable_grad() if backward else nullcontext():
                        closure()
                        g_0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

                if TensorList(grad).global_vector_norm() < TensorList(g_0).global_vector_norm():
                    n_good += 1

        # update sigma if strategy is enabled
        if sigma_strategy is not None:

            sigma_target = self.defaults['sigma_target']
            if isinstance(sigma_target, float):
                sigma_target = int(max(1, n_finite * sigma_target))

            if n_good >= sigma_target:
                key = 'sigma_nplus'
            else:
                key = 'sigma_nminus'

            for p in params:
                self.state[p]['sigma'] *= self.settings[p][key]

        # if no finite losses, just return inf
        if n_finite == 0:
            assert loss_agg is None and grad_agg is None
            loss = torch.tensor(torch.inf, dtype=params[0].dtype, device=params[0].device)
            grad = [torch.full_like(p, torch.inf) for p in params]
            return loss, grad

        assert loss_agg is not None

        # no post processing needed when aggregate is 'max', 'min', 'min-norm', 'min-value'
        if aggregate != 'mean':
            return loss_agg, grad_agg

        # on mean divide by number of evals
        loss_agg /= n_finite

        if backward:
            assert grad_agg is not None
            torch._foreach_div_(grad_agg, n_finite)

        return loss_agg, grad_agg
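A stripped-down version of the sampling-and-aggregation loop can be written in plain PyTorch for a single tensor. sample_gradients is a hypothetical helper, not the module's API: it supports only the "mean" and "min-norm" aggregations, draws Gaussian perturbations scaled by sigma, and omits the module's non-finite-loss filtering, termination handling, and adaptive sigma.

```python
import torch
from typing import Callable, Optional

def sample_gradients(
    f: Callable[[torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    sigma: float = 1.0,
    n: int = 100,
    aggregate: str = "mean",
    include_x0: bool = True,
    generator: Optional[torch.Generator] = None,
):
    """Evaluate f and its gradient at x and at x + sigma * u, u ~ N(0, I) (hypothetical helper)."""
    points = [x] if include_x0 else []
    # n evaluations in total, one of which is x itself when include_x0 is True
    points += [x + sigma * torch.randn(x.shape, generator=generator, dtype=x.dtype)
               for _ in range(n - include_x0)]
    losses, grads = [], []
    for pt in points:
        pt = pt.detach().requires_grad_(True)
        loss = f(pt)
        (g,) = torch.autograd.grad(loss, pt)
        losses.append(loss.detach())
        grads.append(g)
    if aggregate == "mean":  # Gaussian-homotopy style averaging
        return torch.stack(losses).mean(), torch.stack(grads).mean(0)
    if aggregate == "min-norm":  # keep the sampled gradient with the smallest norm
        i = min(range(len(grads)), key=lambda k: grads[k].norm().item())
        return losses[i], grads[i]
    raise ValueError(f"unsupported aggregate: {aggregate}")
```

Averaging over Gaussian perturbations approximates the gradient of a smoothed objective, which is why larger sigma blurs out small-scale structure; "min-norm" instead favors points that look closer to stationary.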

Graft

Bases: torchzero.modules.ops.multi.MultiOperationBase

Outputs direction output rescaled to have the same norm as magnitude output.

Parameters:

  • direction (Chainable) –

    module to use the direction from

  • magnitude (Chainable) –

    module to use the magnitude from

  • tensorwise (bool, default: True ) –

    whether to calculate norm per-tensor or globally. Defaults to True.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • eps (float, default: 1e-06 ) –

    clips denominator to be no less than this value. Defaults to 1e-6.

  • strength (float, default: 1 ) –

    strength of grafting. Defaults to 1.

Example:

Shampoo grafted to Adam

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Graft(
        direction = tz.m.Shampoo(),
        magnitude = tz.m.Adam(),
    ),
    tz.m.LR(1e-3)
)

Reference

Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803.

Source code in torchzero/modules/ops/multi.py
class Graft(MultiOperationBase):
    """Outputs ``direction`` output rescaled to have the same norm as ``magnitude`` output.

    Args:
        direction (Chainable): module to use the direction from
        magnitude (Chainable): module to use the magnitude from
        tensorwise (bool, optional): whether to calculate norm per-tensor or globally. Defaults to True.
        ord (float, optional): norm order. Defaults to 2.
        eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
        strength (float, optional): strength of grafting. Defaults to 1.

    ### Example:

    Shampoo grafted to Adam
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Graft(
            direction = tz.m.Shampoo(),
            magnitude = tz.m.Adam(),
        ),
        tz.m.LR(1e-3)
    )
    ```

    Reference:
        [Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803.](https://arxiv.org/pdf/2002.11803)
    """
    def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:Metrics=2, eps:float = 1e-6, strength:float=1):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
        super().__init__(defaults, direction=direction, magnitude=magnitude)

    @torch.no_grad
    def transform(self, objective, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.defaults)
        return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
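The rescaling that Graft describes can be sketched in plain PyTorch. This is an illustration under assumptions, not `TensorList.graft_`: graft is a hypothetical helper, it uses eps to clip the denominator as the parameter description states, and it does not model the strength argument.

```python
import torch

def graft(direction, magnitude, tensorwise: bool = True, ord: float = 2, eps: float = 1e-6):
    """Rescale `direction` tensors to the norm of `magnitude` tensors (hypothetical helper)."""
    if tensorwise:
        # one scale factor per tensor
        return [d * (torch.linalg.vector_norm(m, ord) / torch.linalg.vector_norm(d, ord).clamp(min=eps))
                for d, m in zip(direction, magnitude)]
    # global: one scale factor from the norms over all tensors concatenated
    d_norm = torch.linalg.vector_norm(torch.cat([d.flatten() for d in direction]), ord)
    m_norm = torch.linalg.vector_norm(torch.cat([m.flatten() for m in magnitude]), ord)
    scale = m_norm / d_norm.clamp(min=eps)
    return [d * scale for d in direction]
```

In the Shampoo-grafted-to-Adam example, this keeps the direction proposed by Shampoo while taking the step size that Adam would have used.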

GraftGradToUpdate

Bases: torchzero.core.transform.TensorTransform

Outputs the gradient grafted to the update; that is, the gradient rescaled to have the same norm as the update.

Source code in torchzero/modules/misc/misc.py
class GraftGradToUpdate(TensorTransform):
    """Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update."""
    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, uses_grad=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
        return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)

GraftInputToOutput

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs tensors rescaled to have the same norm as magnitude(tensors).

Source code in torchzero/modules/ops/binary.py
class GraftInputToOutput(BinaryOperationBase):
    """Outputs ``tensors`` rescaled to have the same norm as ``magnitude(tensors)``."""
    def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, magnitude=magnitude)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
        return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)

GraftOutputToInput

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs direction(tensors) rescaled to have the same norm as tensors.

Source code in torchzero/modules/ops/binary.py
class GraftOutputToInput(BinaryOperationBase):
    """Outputs ``direction(tensors)`` rescaled to have the same norm as ``tensors``."""

    def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, direction=direction)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], direction: list[torch.Tensor]):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
        return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)

GraftToGrad

Bases: torchzero.core.transform.TensorTransform

Grafts the update to the gradient; that is, the update is rescaled to have the same norm as the gradient.

Source code in torchzero/modules/misc/misc.py
class GraftToGrad(TensorTransform):
    """Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient."""
    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, uses_grad=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
        return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)

GraftToParams

Bases: torchzero.core.transform.TensorTransform

Grafts the update to the parameters; that is, the update is rescaled to have the same norm as the parameters, but no smaller than eps.

Source code in torchzero/modules/misc/misc.py
class GraftToParams(TensorTransform):
    """Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than ``eps``."""
    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-4):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
        return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)

GramSchimdt

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs tensors made orthogonal to other(tensors) via Gram-Schmidt.

Source code in torchzero/modules/ops/binary.py
class GramSchimdt(BinaryOperationBase):
    """outputs tensors made orthogonal to ``other(tensors)`` via Gram-Schmidt."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        update = TensorList(update); other = TensorList(other)
        min = torch.finfo(update[0].dtype).tiny * 2
        return update - (other*update) / (other*other).clip(min=min)
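For reference, a single textbook Gram-Schmidt step removes the update's projection onto `other`: u ← u − (⟨o, u⟩ / ⟨o, o⟩) o. The plain-Python sketch below computes the inner products globally, treating the whole list of tensors as one flattened vector. That choice, the tiny clamp on the denominator, and the helper names are assumptions made for illustration. The sketch is not a transcription of the listing above.

```python
import sys


def dot(a, b):
    """Global inner product of two lists of flat 'tensors' (lists of floats)."""
    return sum(x * y for ta, tb in zip(a, b) for x, y in zip(ta, tb))


def gram_schmidt(update, other, tiny=sys.float_info.min * 2):
    """Remove from `update` its projection onto `other` (one Gram-Schmidt step)."""
    coef = dot(other, update) / max(dot(other, other), tiny)
    return [[u - coef * o for u, o in zip(tu, to)] for tu, to in zip(update, other)]


update = [[1.0, 2.0], [3.0]]
other = [[1.0, 0.0], [1.0]]
ortho = gram_schmidt(update, other)
print(ortho, dot(ortho, other))  # [[-1.0, 2.0], [1.0]] 0.0
```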

Greenstadt1

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Greenstadt's first Quasi-Newton method.

Note

A trust region or an accurate line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Spedicato, E., & Huang, Z. (1997). Numerical experience with Newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class Greenstadt1(_InverseHessianUpdateStrategyDefaults):
    """Greenstadt's first Quasi-Newton method.

    Note:
        a trust region or an accurate line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev)

Greenstadt2

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Greenstadt's second Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Spedicato, E., & Huang, Z. (1997). Numerical experience with Newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class Greenstadt2(_InverseHessianUpdateStrategyDefaults):
    """Greenstadt's second Quasi-Newton method.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return greenstadt2_H_(H=H, s=s, y=y)

HagerZhang

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Hager-Zhang nonlinear conjugate gradient method.

Note

This requires the step size to be determined by a line search, so place a line search such as tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this module.

Source code in torchzero/modules/conjugate_gradient/cg.py
class HagerZhang(ConguateGradientBase):
    """Hager-Zhang nonlinear conjugate gradient method.

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return hager_zhang_beta(g, prev_d, prev_g)
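The Hager-Zhang β can be written down directly. The sketch below uses the textbook convention, in which `d_prev` is the previous descent direction and the next direction is d = −g + β·d_prev. The listing does not show how torchzero's `hager_zhang_beta` orients its arguments, so treat the sign convention and the helper names here as assumptions.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def hager_zhang_beta(g, g_prev, d_prev):
    """Textbook Hager-Zhang beta for flat vectors (lists of floats)."""
    y = [a - b for a, b in zip(g, g_prev)]  # gradient difference
    dy = dot(d_prev, y)
    yy = dot(y, y)
    corrected = [yi - 2.0 * di * yy / dy for yi, di in zip(y, d_prev)]
    return dot(corrected, g) / dy


def hestenes_stiefel_beta(g, g_prev, d_prev):
    """Textbook Hestenes-Stiefel beta, used here only for comparison."""
    y = [a - b for a, b in zip(g, g_prev)]
    return dot(g, y) / dot(d_prev, y)


def next_direction(g, d_prev, beta):
    return [-gi + beta * di for gi, di in zip(g, d_prev)]


# after an exact line search g is orthogonal to d_prev, and HZ reduces to HS
d_prev, g_prev, g = [1.0, 0.0], [-1.0, 1.0], [0.0, 2.0]
beta = hager_zhang_beta(g, g_prev, d_prev)
print(beta, hestenes_stiefel_beta(g, g_prev, d_prev))  # 2.0 2.0
print(next_direction(g, d_prev, beta))                 # [2.0, -2.0]
```

The extra term in Hager-Zhang is −2‖y‖²(dᵀg)/(dᵀy)². It vanishes only when the line search is exact, which is why the quality of the line search matters for this method.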

HeavyBall

Bases: torchzero.modules.momentum.momentum.EMA

Polyak's momentum (heavy-ball method).

Parameters:

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • dampening (float, default: 0 ) –

    momentum dampening. Defaults to 0.

  • debias (bool, default: False ) –

    whether to debias the EMA like in Adam. Defaults to False.

  • lerp (bool, default: False ) –

    whether to use linear interpolation; if True, this becomes an exponential moving average. Defaults to False.

  • ema_init (str, default: 'update' ) –

    initial values for the EMA, "zeros" or "update". Defaults to "update".

Source code in torchzero/modules/momentum/momentum.py
class HeavyBall(EMA):
    """Polyak's momentum (heavy-ball method).

    Args:
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        dampening (float, optional): momentum dampening. Defaults to 0.
        debias (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
        lerp (bool, optional):
            whether to use linear interpolation; if True, this becomes an exponential moving average. Defaults to False.
        ema_init (str, optional): initial values for the EMA, "zeros" or "update". Defaults to "update".
    """
    def __init__(self, momentum:float=0.9, dampening:float=0, debias: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update'):
        super().__init__(momentum=momentum, dampening=dampening, debias=debias, lerp=lerp, ema_init=ema_init)
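The scalar sketch below shows the two update forms that the parameters describe. It assumes PyTorch-SGD-style semantics. The classic form is `buf = momentum * buf + (1 - dampening) * u`. The `lerp=True` form is an exponential moving average. How torchzero applies `ema_init` and `debias` is assumed here (debiasing is omitted), and `heavy_ball` is a hypothetical helper.

```python
def heavy_ball(updates, momentum=0.9, dampening=0.0, lerp=False, ema_init="update"):
    """Return the momentum buffer after each step for a stream of scalar updates."""
    buf = None
    history = []
    for u in updates:
        if buf is None and ema_init == "update":
            buf = u  # assumption: the buffer starts at the first update
        else:
            prev = 0.0 if buf is None else buf  # "zeros" init starts from 0
            if lerp:
                buf = prev + (1.0 - momentum) * (u - prev)  # exponential moving average
            else:
                buf = momentum * prev + (1.0 - dampening) * u  # classic heavy ball
        history.append(buf)
    return history


print(heavy_ball([1.0, 1.0, 1.0], momentum=0.9, ema_init="zeros"))             # ~[1.0, 1.9, 2.71]
print(heavy_ball([1.0, 1.0, 1.0], momentum=0.9, lerp=True, ema_init="zeros"))  # ~[0.1, 0.19, 0.271]
```

For a constant update u, the classic buffer (with no dampening) approaches u / (1 − momentum), while the EMA approaches u. The lerp form therefore effectively takes steps (1 − momentum) times smaller at the same learning rate.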

HestenesStiefel

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Hestenes–Stiefel nonlinear conjugate gradient method.

Note

This requires the step size to be determined by a line search, so place a line search such as tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this module.

Source code in torchzero/modules/conjugate_gradient/cg.py
class HestenesStiefel(ConguateGradientBase):
    """Hestenes–Stiefel nonlinear conjugate gradient method.

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return hestenes_stiefel_beta(g, prev_d, prev_g)
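To see why Hestenes-Stiefel relies on a good line search, consider the plain-Python sketch below on a quadratic f(x) = ½xᵀAx − bᵀx. On such a quadratic the exact line-search step has a closed form. With exact steps, nonlinear CG using the HS β reduces to linear CG, which in exact arithmetic reaches the minimizer of an n-dimensional quadratic in at most n iterations. `hs_cg` and its helpers are hypothetical, and the closed-form step stands in for a line search such as StrongWolfe.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(A, x):
    return [dot(row, x) for row in A]


def hs_cg(A, b, x, iters):
    """Nonlinear CG with the Hestenes-Stiefel beta on f(x) = 0.5 x^T A x - b^T x."""
    g = [ai - bi for ai, bi in zip(matvec(A, x), b)]  # gradient A x - b
    d = [-gi for gi in g]                              # first step: steepest descent
    for _ in range(iters):
        Ad = matvec(A, d)
        dAd = dot(d, Ad)
        if dAd == 0.0:
            break
        alpha = -dot(g, d) / dAd  # exact minimizer along d (quadratic only)
        x = [xi + alpha * di for xi, di in zip(x, d)]
        g_new = [gi + alpha * adi for gi, adi in zip(g, Ad)]
        y = [a - c for a, c in zip(g_new, g)]
        beta = dot(g_new, y) / dot(d, y)  # Hestenes-Stiefel
        d = [-gn + beta * di for gn, di in zip(g_new, d)]
        g = g_new
    return x


A = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]
x = hs_cg(A, b, [0.0, 0.0], iters=2)
print(x)  # close to [1/11, 7/11], the solution of A x = b
```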

Horisho

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Hoshino's variable metric Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Hoshino, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class Horisho(_InverseHessianUpdateStrategyDefaults):
    """
    Hoshino's variable metric Quasi-Newton method.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
    """

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return hoshino_H_(H=H, s=s, y=y, tol=setting['tol'])

HpuEstimate

Bases: torchzero.core.transform.TensorTransform

Returns y/||s||, where y is the difference between the current and previous update (gradient) and s is the difference between the current and previous parameters. When the updates are gradients, y ≈ H s, so the returned tensors are a finite-difference approximation to the Hessian times the normalized previous step s/||s||.

Source code in torchzero/modules/misc/misc.py
class HpuEstimate(TensorTransform):
    """Returns ``y/||s||``, where ``y`` is the difference between the current and previous update (gradient) and ``s`` is the difference between the current and previous parameters. When the updates are gradients, ``y ≈ H s``, so the returned tensors are a finite-difference approximation to the Hessian times the normalized previous step ``s/||s||``."""
    def __init__(self):
        defaults = dict()
        super().__init__(defaults)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('prev_params', 'prev_update')

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        prev_params, prev_update = self.get_state(params, 'prev_params', 'prev_update') # initialized to 0
        s = torch._foreach_sub(params, prev_params)
        y = torch._foreach_sub(tensors, prev_update)
        for p, c in zip(prev_params, params): p.copy_(c)
        for p, c in zip(prev_update, tensors): p.copy_(c)
        torch._foreach_div_(y, torch.linalg.norm(torch.cat([t.ravel() for t in s])).clip(min=1e-8)) # pylint:disable=not-callable
        self.store(params, 'y', y)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        return [self.state[p]['y'] for p in params]
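A plain-Python check of the finite-difference idea, using a quadratic with Hessian A. For such a quadratic the gradient difference is exactly y = A s, so y/‖s‖ equals the Hessian applied to the unit vector s/‖s‖. `hpu_estimate` is a hypothetical helper; its `1e-8` floor mirrors the clip in the listing.

```python
def hpu_estimate(x, x_prev, g, g_prev, min_norm=1e-8):
    s = [a - b for a, b in zip(x, x_prev)]  # parameter difference
    y = [a - b for a, b in zip(g, g_prev)]  # gradient (update) difference
    s_norm = max(sum(v * v for v in s) ** 0.5, min_norm)
    return [v / s_norm for v in y]


A = [[2.0, 0.0], [0.0, 5.0]]  # Hessian of f(x) = 0.5 * x^T A x


def grad(x):
    return [sum(a * b for a, b in zip(row, x)) for row in A]


x_prev, x = [1.0, 1.0], [1.0, 0.0]  # s = [0, -1], already unit length
print(hpu_estimate(x, x_prev, grad(x), grad(x_prev)))  # [0.0, -5.0], i.e. A @ s
```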

ICUM

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Inverse column-updating Quasi-Newton method. It is computationally cheaper than other Quasi-Newton methods because it updates only one column of the inverse Hessian approximation per step.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. Also available at https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class ICUM(_InverseHessianUpdateStrategyDefaults):
    """
    Inverse Column-updating Quasi-Newton method. This is computationally cheaper than other Quasi-Newton methods
    due to only updating one column of the inverse hessian approximation per step.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return icum_H_(H=H, s=s, y=y)
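A plain-Python sketch of one inverse column-updating step, assuming the usual column-updating formulation. The step picks the index j of the largest |y_j| and changes only column j of H, so that the secant condition H₊y = s holds: H₊ = H + (s − Hy)e_jᵀ / y_j. The helper names are hypothetical, and torchzero's `icum_H_` may differ in details such as safeguards.

```python
def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def icum_update(H, s, y):
    """One inverse column-updating step; returns the new H and the updated column j."""
    j = max(range(len(y)), key=lambda i: abs(y[i]))  # largest |y_j|
    r = [si - hi for si, hi in zip(s, matvec(H, y))]  # secant residual s - H y
    H_new = [row[:] for row in H]
    for i in range(len(H)):
        H_new[i][j] += r[i] / y[j]  # only column j changes
    return H_new, j


H = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
s, y = [1.0, 2.0, 3.0], [0.5, -4.0, 1.0]
H_new, j = icum_update(H, s, y)
print(j, matvec(H_new, y))  # 1 [1.0, 2.0, 3.0]
```

Each update costs one matrix-vector product plus O(N) work on a single column. A full-rank-two update such as BFGS touches every entry of H.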

Identity

Bases: torchzero.core.module.Module

Identity operator that accepts and ignores any arguments. It can also be used as an identity Hessian for trust region methods.

Source code in torchzero/modules/ops/utility.py
class Identity(Module):
    """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
    def __init__(self, *args, **kwargs): super().__init__()
    def update(self, objective): pass
    def apply(self, objective): return objective
    def get_H(self, objective):
        n = sum(p.numel() for p in objective.params)
        p = objective.params[0]
        return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)
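An identity Hessian is useful in a trust region for the following reason. With B = I, the model m(p) = gᵀp + ½‖p‖² is isotropic, so the trust-region subproblem has a closed-form solution: a gradient step clipped to the radius. The plain-Python sketch below works under that assumption. The helper name is hypothetical, and the listing does not show how torchzero's trust-region modules consume `get_H`.

```python
import math


def identity_trust_region_step(g, radius):
    """Minimize g.p + 0.5 * p.p subject to ||p|| <= radius."""
    g_norm = math.sqrt(sum(v * v for v in g))
    # the unconstrained minimizer is -g; if it leaves the region, shrink it onto the boundary
    scale = 1.0 if g_norm <= radius else radius / g_norm
    return [-scale * v for v in g]


print(identity_trust_region_step([3.0, 4.0], radius=10.0))  # [-3.0, -4.0]
print(identity_trust_region_step([3.0, 4.0], radius=1.0))   # approximately [-0.6, -0.8]
```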

ImprovedNewton

Bases: torchzero.core.transform.Transform

Improved Newton's Method (INM).

Reference

Saheya, B., et al. "A new Newton-like method for solving nonlinear equations." SpringerPlus 5.1 (2016): 1269.

Source code in torchzero/modules/second_order/inm.py
class ImprovedNewton(Transform):
    """Improved Newton's Method (INM).

    Reference:
        [Saheya, B., et al. "A new Newton-like method for solving nonlinear equations." SpringerPlus 5.1 (2016): 1269.](https://d-nb.info/1112813721/34)
    """

    def __init__(
        self,
        damping: float = 0,
        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
        eigv_tol: float | None = None,
        truncate: int | None = None,
        update_freq: int = 1,
        precompute_inverse: bool | None = None,
        use_lstsq: bool = False,
        hessian_method: HessianMethod = "batched_autograd",
        h: float = 1e-3,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["update_freq"]
        super().__init__(defaults, update_freq=update_freq, inner=inner, )

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        _, f_list, J = objective.hessian(
            hessian_method=fs['hessian_method'],
            h=fs['h'],
            at_x0=True
        )
        if f_list is None: f_list = objective.get_grads()

        f = torch.cat([t.ravel() for t in f_list])
        J = _eigval_fn(J, fs["eigval_fn"])

        x_list = TensorList(objective.params)
        f_list = TensorList(objective.get_grads())
        x_prev, f_prev = unpack_states(states, objective.params, "x_prev", "f_prev", cls=TensorList)

        # initialize on 1st step, do Newton step
        if "H" not in self.global_state:
            x_prev.copy_(x_list)
            f_prev.copy_(f_list)
            P = J

        # INM update
        else:
            s_list = x_list - x_prev
            y_list = f_list - f_prev
            x_prev.copy_(x_list)
            f_prev.copy_(f_list)

            P = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())

        # update state
        precompute_inverse = fs["precompute_inverse"]
        if precompute_inverse is None:
            precompute_inverse = fs["__update_freq"] >= 10

        _newton_update_state_(
            H=P,
            state = self.global_state,
            damping = fs["damping"],
            eigval_fn = fs["eigval_fn"],
            eigv_tol = fs["eigv_tol"],
            truncate = fs["truncate"],
            precompute_inverse = precompute_inverse,
            use_lstsq = fs["use_lstsq"]
        )

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()
        fs = settings[0]

        b = torch.cat([t.ravel() for t in updates])
        sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])

        vec_to_tensors_(sol, updates)
        return objective


    def get_H(self,objective=...):
        return _newton_get_H(self.global_state)

IntermoduleCautious

Bases: torchzero.core.module.Module

Masks the update of the main module where its sign doesn't match the output of the compare module; mode determines what happens to the mismatched entries (by default they are set to zero).

Parameters:

  • main (Chainable) –

    main module or sequence of modules whose update will be cautioned.

  • compare (Chainable) –

    modules or sequence of modules to compare the sign to.

  • normalize (bool, default: False ) –

    renormalize update after masking. Defaults to False.

  • eps (float, default: 1e-06 ) –

    epsilon for normalization. Defaults to 1e-6.

  • mode (str, default: 'zero' ) –

    what to do with updates with inconsistent signs:
      - "zero" - set them to zero (as in the paper)
      - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
      - "backtrack" - negate them

Source code in torchzero/modules/momentum/cautious.py
class IntermoduleCautious(Module):
    """Masks the update of the ``main`` module where its sign doesn't match the output of the ``compare`` module.

    Args:
        main (Chainable): main module or sequence of modules whose update will be cautioned.
        compare (Chainable): modules or sequence of modules to compare the sign to.
        normalize (bool, optional):
            renormalize update after masking. Defaults to False.
        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
        mode (str, optional):
            what to do with updates with inconsistent signs.
            - "zero" - set them to zero (as in paper)
            - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
            - "backtrack" - negate them
    """
    def __init__(
        self,
        main: Chainable,
        compare: Chainable,
        normalize=False,
        eps=1e-6,
        mode: Literal["zero", "grad", "backtrack"] = "zero",
    ):

        defaults = dict(normalize=normalize, eps=eps, mode=mode)
        super().__init__(defaults)

        self.set_child('main', main)
        self.set_child('compare', compare)

    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    @torch.no_grad
    def step(self, objective):
        main = self.children['main']
        compare = self.children['compare']

        main_var = main.step(objective.clone(clone_updates=True))
        objective.update_attrs_from_clone_(main_var)

        compare_var = compare.step(objective.clone(clone_updates=True))
        objective.update_attrs_from_clone_(compare_var)

        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.defaults)
        objective.updates = cautious_(
            TensorList(main_var.get_updates()),
            TensorList(compare_var.get_updates()),
            normalize=normalize,
            mode=mode,
            eps=eps,
        )

        return objective
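The plain-Python sketch below applies the masking rule element-wise. Entries where the main update and the comparison output agree in sign (product > 0) are kept; the rest are handled according to `mode`. Only "zero" and "backtrack" are sketched, because their behavior is stated unambiguously. `normalize` is omitted, treating a zero product as a mismatch is an assumption, and the helper name is hypothetical.

```python
def cautious(update, compare, mode="zero"):
    """Handle entries of `update` whose sign disagrees with `compare`."""
    out = []
    for u, c in zip(update, compare):
        if u * c > 0:  # signs agree: keep the entry
            out.append(u)
        elif mode == "zero":
            out.append(0.0)
        elif mode == "backtrack":
            out.append(-u)
        else:
            raise ValueError(f"unsupported mode in this sketch: {mode!r}")
    return out


update = [1.0, -2.0, 3.0]
compare = [0.5, 1.0, -1.0]
print(cautious(update, compare))                    # [1.0, 0.0, 0.0]
print(cautious(update, compare, mode="backtrack"))  # [1.0, 2.0, -3.0]
```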

InverseFreeNewton

Bases: torchzero.core.transform.Transform

Inverse-free Newton's method.

Reference

Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.

Source code in torchzero/modules/second_order/ifn.py
class InverseFreeNewton(Transform):
    """Inverse-free Newton's method.

    Reference:
        [Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.](https://www.jaac-online.com/article/doi/10.11948/20240428)
    """
    def __init__(
        self,
        update_freq: int = 1,
        hessian_method: HessianMethod = "batched_autograd",
        h: float = 1e-3,
        inner: Chainable | None = None,
    ):
        defaults = dict(hessian_method=hessian_method, h=h)
        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        _, _, H = objective.hessian(
            hessian_method=fs['hessian_method'],
            h=fs['h'],
            at_x0=True
        )

        self.global_state["H"] = H

        # inverse free part
        if 'Y' not in self.global_state:
            num = H.T
            denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable

            finfo = torch.finfo(H.dtype)
            self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))

        else:
            Y = self.global_state['Y']
            I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
            I2 -= H @ Y
            self.global_state['Y'] = Y @ I2


    def apply_states(self, objective, states, settings):
        Y = self.global_state["Y"]
        g = torch.cat([t.ravel() for t in objective.get_updates()])
        objective.updates = vec_to_tensors(Y@g, objective.params)
        return objective

    def get_H(self,objective=...):
        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])
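The inverse-free part of `update_states` is a Newton-Schulz iteration, Y ← Y(2I − HY), started from Y₀ = Hᵀ / (‖H‖₁‖H‖∞). The listing performs one such iteration per optimizer step while H keeps changing. The plain-Python sketch below instead iterates on a fixed matrix, to show that Y converges to H⁻¹. The helper names are hypothetical.

```python
def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def newton_schulz_inverse(H, iters=30):
    n = len(H)
    norm_1 = max(sum(abs(H[i][j]) for i in range(n)) for j in range(n))    # max column sum
    norm_inf = max(sum(abs(H[i][j]) for j in range(n)) for i in range(n))  # max row sum
    # Y0 = H^T / (||H||_1 * ||H||_inf) makes the iteration converge for nonsingular H
    Y = [[H[j][i] / (norm_1 * norm_inf) for j in range(n)] for i in range(n)]
    for _ in range(iters):
        HY = matmul(H, Y)
        two_I_minus_HY = [[(2.0 if i == j else 0.0) - HY[i][j] for j in range(n)]
                          for i in range(n)]
        Y = matmul(Y, two_I_minus_HY)  # Y <- Y (2I - H Y)
    return Y


H = [[4.0, 1.0], [1.0, 3.0]]
Y = newton_schulz_inverse(H)
print(Y)  # close to [[3/11, -1/11], [-1/11, 4/11]]
```

Each iteration uses only matrix products, and the residual I − HY is squared at every step. This is what lets the method avoid an explicit solve or factorization.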

LBFGS

Bases: torchzero.core.transform.TensorTransform

Limited-memory BFGS algorithm. A line search or trust region is recommended.

Parameters:

  • history_size (int, default: 10 ) –

    number of past parameter differences and gradient differences to store. Defaults to 10.

  • ptol (float | None, default: 1e-32 ) –

    skips updating the history if the maximum absolute value of the parameter difference does not exceed this value. Defaults to 1e-32.

  • ptol_restart (bool, default: False ) –

    if True, L-BFGS state is reset whenever the parameter difference does not exceed ptol. Defaults to False.

  • gtol (float | None, default: 1e-32 ) –

    skips updating the history if the maximum absolute value of the gradient difference does not exceed this value. Defaults to 1e-32.

  • gtol_restart (bool, default: False ) –

    if True, L-BFGS state is reset whenever the gradient difference does not exceed gtol. Defaults to False.

  • sy_tol (float, default: 1e-32 ) –

    history is not updated whenever s⋅y does not exceed this value (negative s⋅y means negative curvature). Defaults to 1e-32.

  • scale_first (bool, default: True ) –

    makes the first step, when no Hessian approximation is available yet, small to reduce the number of line search iterations. Defaults to True.

  • update_freq (int, default: 1 ) –

    how often to update L-BFGS history. Larger values may be better for stochastic optimization. Defaults to 1.

  • damping (Union, default: None ) –

    damping to use, can be "powell" or "double". Defaults to None.

  • inner (Chainable | None, default: None ) –

    optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.

Examples:

L-BFGS with line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LBFGS(100),
    tz.m.Backtracking()
)

L-BFGS with trust region

opt = tz.Optimizer(
    model.parameters(),
    tz.m.TrustCG(tz.m.LBFGS())
)

Source code in torchzero/modules/quasi_newton/lbfgs.py
class LBFGS(TensorTransform):
    """Limited-memory BFGS algorithm. A line search or trust region is recommended.

    Args:
        history_size (int, optional):
            number of past parameter differences and gradient differences to store. Defaults to 10.
        ptol (float | None, optional):
            skips updating the history if the maximum absolute value of
            the parameter difference does not exceed this value. Defaults to 1e-32.
        ptol_restart (bool, optional):
            if True, L-BFGS state is reset whenever the parameter difference
            does not exceed ``ptol``. Defaults to False.
        gtol (float | None, optional):
            skips updating the history if the maximum absolute value of
            the gradient difference does not exceed this value. Defaults to 1e-32.
        gtol_restart (bool, optional):
            if True, L-BFGS state is reset whenever the gradient difference
            does not exceed ``gtol``. Defaults to False.
        sy_tol (float, optional):
            history is not updated whenever s⋅y does not exceed this value
            (negative s⋅y means negative curvature). Defaults to 1e-32.
        scale_first (bool, optional):
            makes the first step, when no Hessian approximation is available yet,
            small to reduce the number of line search iterations. Defaults to True.
        update_freq (int, optional):
            how often to update L-BFGS history. Larger values may be better for stochastic optimization. Defaults to 1.
        damping (DampingStrategyType, optional):
            damping to use, can be "powell" or "double". Defaults to None.
        inner (Chainable | None, optional):
            optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.

    ## Examples:

    L-BFGS with line search
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LBFGS(100),
        tz.m.Backtracking()
    )
    ```

    L-BFGS with trust region
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.TrustCG(tz.m.LBFGS())
    )
    ```
    """
    def __init__(
        self,
        history_size=10,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        gtol_restart: bool = False,
        sy_tol: float = 1e-32,
        scale_first:bool=True,
        update_freq = 1,
        damping: DampingStrategyType = None,
        inner: Chainable | None = None,
    ):
        defaults = dict(
            history_size=history_size,
            scale_first=scale_first,
            ptol=ptol,
            gtol=gtol,
            ptol_restart=ptol_restart,
            gtol_restart=gtol_restart,
            sy_tol=sy_tol,
            damping = damping,
        )
        super().__init__(defaults, inner=inner, update_freq=update_freq)

        self.global_state['s_history'] = deque(maxlen=history_size)
        self.global_state['y_history'] = deque(maxlen=history_size)
        self.global_state['sy_history'] = deque(maxlen=history_size)

    def _reset_self(self):
        self.state.clear()
        self.global_state['step'] = 0
        self.global_state['s_history'].clear()
        self.global_state['y_history'].clear()
        self.global_state['sy_history'].clear()

    def reset(self):
        self._reset_self()
        for c in self.children.values(): c.reset()

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('p_prev', 'g_prev')
        self.global_state.pop('step', None)

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        p = as_tensorlist(params)
        g = as_tensorlist(tensors)
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        # history of s and y
        s_history: deque[TensorList] = self.global_state['s_history']
        y_history: deque[TensorList] = self.global_state['y_history']
        sy_history: deque[torch.Tensor] = self.global_state['sy_history']

        ptol = self.defaults['ptol']
        gtol = self.defaults['gtol']
        ptol_restart = self.defaults['ptol_restart']
        gtol_restart = self.defaults['gtol_restart']
        sy_tol = self.defaults['sy_tol']
        damping = self.defaults['damping']

        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)

        # 1st step - there are no previous params and grads, lbfgs will do normalized SGD step
        if step == 0:
            s = None; y = None; sy = None
        else:
            s = p - p_prev
            y = g - g_prev

            if damping is not None:
                s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())

            sy = s.dot(y)

        below_tol = False
        # tolerance on parameter difference to avoid exploding after converging
        if ptol is not None:
            if s is not None and s.abs().global_max() <= ptol:
                if ptol_restart:
                    self._reset_self()
                sy = None
                below_tol = True

        # tolerance on gradient difference to avoid exploding when there is no curvature
        if gtol is not None:
            if y is not None and y.abs().global_max() <= gtol:
                if gtol_restart: self._reset_self()
                sy = None
                below_tol = True

        # store previous params and grads
        if not below_tol:
            p_prev.copy_(p)
            g_prev.copy_(g)

        # update effective preconditioning state
        if sy is not None and sy > sy_tol:
            assert s is not None and y is not None and sy is not None

            s_history.append(s)
            y_history.append(y)
            sy_history.append(sy)

    def get_H(self, objective=...):
        s_history = [tl.to_vec() for tl in self.global_state['s_history']]
        y_history = [tl.to_vec() for tl in self.global_state['y_history']]
        sy_history = self.global_state['sy_history']
        return LBFGSLinearOperator(s_history, y_history, sy_history)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        scale_first = self.defaults['scale_first']

        tensors = as_tensorlist(tensors)

        s_history = self.global_state['s_history']
        y_history = self.global_state['y_history']
        sy_history = self.global_state['sy_history']

        # precondition
        dir = lbfgs_Hx(
            x=tensors,
            s_history=s_history,
            y_history=y_history,
            sy_history=sy_history,
        )

        # scale 1st step
        if scale_first and self.global_state.get('step', 1) == 1:
            dir *= initial_step_size(dir, eps=1e-7)

        return dir
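For reference, the preconditioning step follows the standard L-BFGS two-loop recursion. The plain-Python sketch below is the textbook version, with the usual initial scaling γ = sᵀy / yᵀy. torchzero's `lbfgs_Hx` may differ in details, and the helper name is hypothetical. The sketch assumes every stored pair satisfies sᵀy > 0, which the `sy_tol` check in `multi_tensor_update` enforces before a pair is appended.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def lbfgs_two_loop(g, s_hist, y_hist):
    """Approximate H^-1 g from (s, y) pairs stored oldest first."""
    if not s_hist:
        return list(g)  # no curvature information yet
    rho = [1.0 / dot(s, y) for s, y in zip(s_hist, y_hist)]
    q = list(g)
    alphas = []
    for s, y, r in reversed(list(zip(s_hist, y_hist, rho))):  # newest to oldest
        a = r * dot(s, q)
        q = [qi - a * yi for qi, yi in zip(q, y)]
        alphas.append(a)
    gamma = dot(s_hist[-1], y_hist[-1]) / dot(y_hist[-1], y_hist[-1])
    z = [gamma * qi for qi in q]  # initial inverse Hessian approximation gamma * I
    for (s, y, r), a in zip(zip(s_hist, y_hist, rho), reversed(alphas)):  # oldest to newest
        b = r * dot(y, z)
        z = [zi + (a - b) * si for zi, si in zip(z, s)]
    return z


# pairs from the quadratic with Hessian [[3, 1], [1, 2]], so y = A s
s_hist = [[1.0, 0.0], [0.0, 1.0]]
y_hist = [[3.0, 1.0], [1.0, 2.0]]
# secant condition: the preconditioner maps the newest y back to the newest s
print(lbfgs_two_loop(y_hist[-1], s_hist, y_hist))  # [0.0, 1.0]
```

Because all stored pairs have sᵀy > 0, the implied inverse Hessian is positive definite. The preconditioned direction therefore always has a positive inner product with the gradient.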

LR

Bases: torchzero.core.transform.TensorTransform

Learning rate. Adding this module also adds support for LR schedulers.

Source code in torchzero/modules/step_size/lr.py
class LR(TensorTransform):
    """Learning rate. Adding this module also adds support for LR schedulers."""
    def __init__(self, lr: float):
        defaults=dict(lr=lr)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)

LSR1

Bases: torchzero.core.transform.TensorTransform

Limited-memory SR1 algorithm. A line search or trust region is recommended.

Parameters:

  • history_size (int, default: 10 ) –

    number of past parameter differences and gradient differences to store. Defaults to 10.

  • ptol (float | None, default: None ) –

    skips updating the history if the maximum absolute value of the parameter difference is less than this value. Defaults to None.

  • ptol_restart (bool, default: False ) –

    if True, L-SR1 state is reset whenever the parameter difference is less than ptol. Defaults to False.

  • gtol (float | None, default: None ) –

    skips updating the history if the maximum absolute value of the gradient difference is less than this value. Defaults to None.

  • gtol_restart (bool, default: False ) –

    if True, L-SR1 state is reset whenever the gradient difference is less than gtol. Defaults to False.

  • scale_first (bool, default: False ) –

    makes the first step, when no Hessian approximation is available yet, small to reduce the number of line search iterations. Defaults to False.

  • update_freq (int, default: 1 ) –

    how often to update L-SR1 history. Larger values may be better for stochastic optimization. Defaults to 1.

  • damping (Union, default: None ) –

    damping to use, can be "powell" or "double". Defaults to None.

  • inner (Chainable | None, default: None ) –

    optional inner modules applied after updating L-SR1 history and before preconditioning. Defaults to None.

Examples:

L-SR1 with line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LSR1(),
    tz.m.StrongWolfe(c2=0.1, fallback=True)
)

L-SR1 with trust region

opt = tz.Optimizer(
    model.parameters(),
    tz.m.TrustCG(tz.m.LSR1())
)
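L-SR1 stores the same (s, y) pairs as L-BFGS but applies the symmetric rank-one update instead. Below is a dense, full-memory plain-Python sketch of the SR1 inverse update H₊ = H + vvᵀ / (vᵀy), with v = s − Hy. It includes the common rule of skipping the update when the denominator is near zero. The threshold `r = 1e-8` and the helper name are assumptions.

```python
import math


def sr1_inverse_update(H, s, y, r=1e-8):
    n = len(s)
    Hy = [sum(H[i][k] * y[k] for k in range(n)) for i in range(n)]
    v = [si - hi for si, hi in zip(s, Hy)]  # s - H y
    denom = sum(vi * yi for vi, yi in zip(v, y))
    # skip the update when v^T y is tiny relative to ||v|| ||y|| (it would blow up H)
    if abs(denom) <= r * math.sqrt(sum(vi * vi for vi in v)) * math.sqrt(sum(yi * yi for yi in y)):
        return H
    return [[H[i][j] + v[i] * v[j] / denom for j in range(n)] for i in range(n)]


H = sr1_inverse_update([[1.0, 0.0], [0.0, 1.0]], s=[1.0, 2.0], y=[2.0, 1.0])
print(H)  # [[0.0, 1.0], [1.0, 0.0]]: satisfies H y = s but is indefinite
```

The result has eigenvalues 1 and −1. Unlike BFGS, the SR1 update does not keep the approximation positive definite, which is one reason SR1-type methods pair well with a trust region.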

Source code in torchzero/modules/quasi_newton/lsr1.py
class LSR1(TensorTransform):
    """Limited-memory SR1 algorithm. A line search or trust region is recommended.

    Args:
        history_size (int, optional):
            number of past parameter differences and gradient differences to store. Defaults to 10.
        ptol (float | None, optional):
            skips updating the history if the maximum absolute value of
            the parameter difference is less than this value. Defaults to None.
        ptol_restart (bool, optional):
            if True, L-SR1 state is reset whenever the parameter difference
            is less than ``ptol``. Defaults to False.
        gtol (float | None, optional):
            skips updating the history if the maximum absolute value of
            the gradient difference is less than this value. Defaults to None.
        gtol_restart (bool, optional):
            if True, L-SR1 state is reset whenever the gradient difference
            is less than ``gtol``. Defaults to False.
        scale_first (bool, optional):
            makes the first step, when no Hessian approximation is available yet,
            small to reduce the number of line search iterations. Defaults to False.
        update_freq (int, optional):
            how often to update L-SR1 history. Larger values may be better for stochastic optimization. Defaults to 1.
        damping (DampingStrategyType, optional):
            damping to use, can be "powell" or "double". Defaults to None.
        inner (Chainable | None, optional):
            optional inner modules applied after updating L-SR1 history and before preconditioning. Defaults to None.

    ## Examples:

    L-SR1 with line search
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LSR1(),
        tz.m.StrongWolfe(c2=0.1, fallback=True)
    )
    ```

    L-SR1 with trust region
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.TrustCG(tz.m.LSR1())
    )
    ```
    """
    def __init__(
        self,
        history_size=10,
        ptol: float | None = None,
        ptol_restart: bool = False,
        gtol: float | None = None,
        gtol_restart: bool = False,
        scale_first:bool=False,
        update_freq = 1,
        damping: DampingStrategyType = None,
        inner: Chainable | None = None,
    ):
        defaults = dict(
            history_size=history_size,
            scale_first=scale_first,
            ptol=ptol,
            gtol=gtol,
            ptol_restart=ptol_restart,
            gtol_restart=gtol_restart,
            damping = damping,
        )
        super().__init__(defaults, inner=inner, update_freq=update_freq)

        self.global_state['s_history'] = deque(maxlen=history_size)
        self.global_state['y_history'] = deque(maxlen=history_size)

    def _reset_self(self):
        self.state.clear()
        self.global_state['step'] = 0
        self.global_state['s_history'].clear()
        self.global_state['y_history'].clear()

    def reset(self):
        self._reset_self()
        for c in self.children.values(): c.reset()

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('p_prev', 'g_prev')
        self.global_state.pop('step', None)

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        p = as_tensorlist(params)
        g = as_tensorlist(tensors)
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        # history of s and y
        s_history: deque = self.global_state['s_history']
        y_history: deque = self.global_state['y_history']

        ptol = self.defaults['ptol']
        gtol = self.defaults['gtol']
        ptol_restart = self.defaults['ptol_restart']
        gtol_restart = self.defaults['gtol_restart']
        damping = self.defaults['damping']

        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)

        # 1st step - there are no previous params and grads, lsr1 will do normalized SGD step
        if step == 0:
            s = None; y = None; sy = None
        else:
            s = p - p_prev
            y = g - g_prev

            if damping is not None:
                s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())

            sy = s.dot(y)

        below_tol = False
        # tolerance on parameter difference to avoid exploding after converging
        if ptol is not None:
            if s is not None and s.abs().global_max() <= ptol:
                if ptol_restart: self._reset_self()
                sy = None
                below_tol = True

        # tolerance on gradient difference to avoid exploding when there is no curvature
        if gtol is not None:
            if y is not None and y.abs().global_max() <= gtol:
                if gtol_restart: self._reset_self()
                sy = None
                below_tol = True

        # store previous params and grads
        if not below_tol:
            p_prev.copy_(p)
            g_prev.copy_(g)

        # update effective preconditioning state
        if sy is not None:
            assert s is not None and y is not None and sy is not None

            s_history.append(s)
            y_history.append(y)

    def get_H(self, objective=...):
        s_history = [tl.to_vec() for tl in self.global_state['s_history']]
        y_history = [tl.to_vec() for tl in self.global_state['y_history']]
        return LSR1LinearOperator(s_history, y_history)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        scale_first = self.defaults['scale_first']

        tensors = as_tensorlist(tensors)

        s_history = self.global_state['s_history']
        y_history = self.global_state['y_history']

        # precondition
        dir = lsr1_Hx(
            x=tensors,
            s_history=s_history,
            y_history=y_history,
        )

        # scale 1st step
        if scale_first and self.global_state.get('step', 1) == 1:
            dir *= initial_step_size(dir, eps=1e-7)

        return dir

LambdaHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class LambdaHomotopy(HomotopyBase):
    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        defaults = dict(fn=fn)
        super().__init__(defaults)

    def loss_transform(self, loss): return self.defaults['fn'](loss)

LaplacianSmoothing

Bases: torchzero.core.transform.TensorTransform

Applies Laplacian smoothing via a fast Fourier transform (FFT) solver; this can improve generalization.

Parameters:

  • sigma (float, default: 1 ) –

    controls the amount of smoothing. Defaults to 1.

  • layerwise (bool, default: True ) –

    If True, applies smoothing to each parameter's gradient separately; otherwise applies it to all gradients concatenated into a single vector. Defaults to True.

  • min_numel (int, default: 4 ) –

    minimum number of elements in a parameter to apply laplacian smoothing to. Only has effect if layerwise is True. Defaults to 4.

  • target (str) –

    what to set on var.

Examples: Laplacian Smoothing Gradient Descent optimizer as in the paper

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LaplacianSmoothing(),
    tz.m.LR(1e-2),
)
Reference

Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.
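The FFT solver works because the periodic 1D Laplacian is a circulant matrix, so the system (I − σL)x = g is diagonalized by the discrete Fourier transform and can be solved by a single element-wise division. The NumPy sketch below assumes the periodic stencil from the LSGD paper, i.e. a first column of [1 + 2σ, −σ, 0, …, 0, −σ]. `_precompute_denominator` is not shown in this listing, so the library's exact denominator may differ. The name `laplacian_smooth` is hypothetical.

```python
import numpy as np


def laplacian_smooth(g, sigma=1.0):
    """Solve (I - sigma * L) x = g, with L the periodic 1D Laplacian (hypothetical helper)."""
    n = g.size
    if n < 2:
        return g.copy()  # a single element has no neighbours, so L = 0
    # First column of the circulant matrix I - sigma * L.
    col = np.zeros(n)
    col[0] = 1.0 + 2.0 * sigma
    col[1] -= sigma
    col[-1] -= sigma
    # Its FFT gives the eigenvalues 1 + 2*sigma - 2*sigma*cos(2*pi*k/n), all >= 1.
    denominator = np.fft.fft(col)
    return np.fft.ifft(np.fft.fft(g.ravel()) / denominator).real.reshape(g.shape)


rng = np.random.default_rng(0)
g = rng.standard_normal(8)
x = laplacian_smooth(g, sigma=1.0)
# x satisfies (1 + 2*sigma) x[i] - sigma (x[i-1] + x[i+1]) = g[i] with wrap-around indices.
residual = (1 + 2.0) * x - 1.0 * (np.roll(x, 1) + np.roll(x, -1))
print(np.allclose(residual, g))  # True
```

Because the zero-frequency eigenvalue is exactly 1, the smoothing preserves the sum (and mean) of the gradient while damping its high-frequency components.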

Source code in torchzero/modules/smoothing/laplacian.py
class LaplacianSmoothing(TensorTransform):
    """Applies Laplacian smoothing via a fast Fourier transform (FFT) solver; this can improve generalization.

    Args:
        sigma (float, optional): controls the amount of smoothing. Defaults to 1.
        layerwise (bool, optional):
            If True, applies smoothing to each parameter's gradient separately;
            otherwise applies it to all gradients concatenated into a single vector. Defaults to True.
        min_numel (int, optional):
            minimum number of elements in a parameter to apply laplacian smoothing to.
            Only has effect if `layerwise` is True. Defaults to 4.
        target (str, optional):
            what to set on var.

    Examples:
    Laplacian Smoothing Gradient Descent optimizer as in the paper

    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LaplacianSmoothing(),
        tz.m.LR(1e-2),
    )
    ```

    Reference:
        Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.

    """
    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4):
        defaults = dict(sigma = sigma, layerwise=layerwise, min_numel=min_numel)
        super().__init__(defaults)
        # precomputed denominator for when layerwise=False
        self.global_state['full_denominator'] = None


    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        layerwise = settings[0]['layerwise']

        # layerwise laplacian smoothing
        if layerwise:

            # precompute the denominator for each layer and store it in each parameters state
            smoothed_target = TensorList()
            for p, t, state, setting in zip(params, tensors, states, settings):
                if p.numel() > setting['min_numel']:
                    if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, setting['sigma'])
                    smoothed_target.append(torch.fft.ifft(torch.fft.fft(t.view(-1)) / state['denominator']).real.view_as(t)) #pylint:disable=not-callable
                else:
                    smoothed_target.append(t)

            return smoothed_target

        # else
        # full laplacian smoothing
        # precompute full denominator
        tensors = TensorList(tensors)
        if self.global_state.get('full_denominator', None) is None:
            self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(), settings[0]['sigma'])

        # apply the smoothing
        vec = tensors.to_vec()
        return tensors.from_vec(torch.fft.ifft(torch.fft.fft(vec) / self.global_state['full_denominator']).real)#pylint:disable=not-callable

LastAbsoluteRatio

Bases: torchzero.core.transform.TensorTransform

Outputs the ratio between the absolute values of the past two updates; the numerator is determined by the numerator argument.

Source code in torchzero/modules/misc/misc.py
class LastAbsoluteRatio(TensorTransform):
    """Outputs the ratio between the absolute values of the past two updates; the numerator is determined by the ``numerator`` argument."""
    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8):
        defaults = dict(numerator=numerator, eps=eps)
        super().__init__(defaults)
        self.add_projected_keys("grad", "prev")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
        numerator = settings[0]['numerator']
        eps = NumberList(s['eps'] for s in settings)

        torch._foreach_abs_(tensors)
        torch._foreach_clamp_min_(prev, eps)

        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
        else: ratio = torch._foreach_div(prev, tensors)
        for p, c in zip(prev, tensors): p.set_(c)
        return ratio
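Per element, `LastAbsoluteRatio` keeps the previous absolute update (initialized to ones), clamps that stored value from below by `eps`, divides, and then stores the current absolute update unclamped. The NumPy sketch below reproduces that bookkeeping for a single tensor. The class `AbsoluteRatioTracker` is a hypothetical stand-in for the module's per-parameter state.

```python
import numpy as np


class AbsoluteRatioTracker:
    """Hypothetical stand-in for LastAbsoluteRatio's per-parameter state."""

    def __init__(self, numerator="cur", eps=1e-8):
        self.numerator = numerator
        self.eps = eps
        self.prev = None  # previous |update|; treated as ones on the first call

    def __call__(self, update):
        cur = np.abs(update)
        prev = np.ones_like(cur) if self.prev is None else self.prev
        prev = np.maximum(prev, self.eps)  # only the stored value is clamped, as in the module
        ratio = cur / prev if self.numerator == "cur" else prev / cur
        self.prev = cur  # store the unclamped current absolute update
        return ratio


tracker = AbsoluteRatioTracker()
print(tracker(np.array([2.0, -4.0])).tolist())  # [2.0, 4.0]  (previous value starts at ones)
print(tracker(np.array([1.0, 8.0])).tolist())   # [0.5, 2.0]
```

With `numerator='prev'` the current value becomes the denominator and is not clamped, so a zero entry in the current update produces `inf`.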

LastDifference

Bases: torchzero.core.transform.TensorTransform

Outputs difference between past two updates.

Source code in torchzero/modules/misc/misc.py
class LastDifference(TensorTransform):
    """Outputs difference between past two updates."""
    def __init__(self):
        super().__init__()
        self.add_projected_keys("grad", "prev_tensors")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        prev_tensors = unpack_states(states, tensors, 'prev_tensors') # initialized to 0
        difference = torch._foreach_sub(tensors, prev_tensors)
        for p, c in zip(prev_tensors, tensors): p.set_(c)
        return difference

LastGradDifference

Bases: torchzero.core.module.Module

Outputs difference between past two gradients.

Source code in torchzero/modules/misc/misc.py
class LastGradDifference(Module):
    """Outputs difference between past two gradients."""
    def __init__(self):
        super().__init__()
        self.add_projected_keys("grad", "prev_grad")

    @torch.no_grad
    def apply(self, objective):
        grad = objective.get_grads()
        prev_grad = self.get_state(objective.params, 'prev_grad') # initialized to 0
        difference = torch._foreach_sub(grad, prev_grad)
        for p, c in zip(prev_grad, grad): p.copy_(c)
        objective.updates = list(difference)
        return objective

LastProduct

Bases: torchzero.core.transform.TensorTransform

Outputs the product of the past two updates.

Source code in torchzero/modules/misc/misc.py
class LastProduct(TensorTransform):
    """Outputs the product of the past two updates."""
    def __init__(self):
        super().__init__()
        self.add_projected_keys("grad", "prev")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
        prod = torch._foreach_mul(tensors, prev)
        for p, c in zip(prev, tensors): p.set_(c)
        return prod

LastRatio

Bases: torchzero.core.transform.TensorTransform

Outputs the ratio between the past two updates; the numerator is determined by the numerator argument.

Source code in torchzero/modules/misc/misc.py
class LastRatio(TensorTransform):
    """Outputs the ratio between the past two updates; the numerator is determined by the ``numerator`` argument."""
    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur'):
        defaults = dict(numerator=numerator)
        super().__init__(defaults)
        self.add_projected_keys("grad", "prev")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
        numerator = settings[0]['numerator']
        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
        else: ratio = torch._foreach_div(prev, tensors)
        for p, c in zip(prev, tensors): p.set_(c)
        return ratio
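`LastDifference`, `LastProduct`, and `LastRatio` share one pattern. Each combines the current update with the stored previous one and then overwrites the stored value. They differ only in the operation and in the initial state: zeros for the difference and ones for the product and ratio, so the first call returns the update unchanged. A minimal NumPy sketch using a hypothetical `PairwiseTracker` helper:

```python
import numpy as np


class PairwiseTracker:
    """Hypothetical helper mimicking the Last* modules' per-parameter state."""

    def __init__(self, op, init):
        self.op = op      # binary function applied as op(current, previous)
        self.init = init  # fills the state before the first call (np.zeros_like or np.ones_like)
        self.prev = None

    def __call__(self, update):
        prev = self.init(update) if self.prev is None else self.prev
        out = self.op(update, prev)
        self.prev = update.copy()  # remember the current update for the next call
        return out


difference = PairwiseTracker(np.subtract, np.zeros_like)  # like LastDifference
product = PairwiseTracker(np.multiply, np.ones_like)      # like LastProduct
ratio = PairwiseTracker(np.divide, np.ones_like)          # like LastRatio with numerator='cur'

for tracker in (difference, product, ratio):
    tracker(np.array([2.0, 3.0]))
    print(tracker(np.array([4.0, 6.0])).tolist())
# [2.0, 3.0]
# [8.0, 18.0]
# [2.0, 2.0]
```

Unlike `LastAbsoluteRatio`, `LastRatio` has no `eps`, so a zero in the divisor produces `inf` or `nan`.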

LerpModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Does a linear interpolation of input(tensors) and end(tensors) based on a scalar weight.

The output is given by output = input(tensors) + weight * (end(tensors) - input(tensors))
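The formula above is the same definition `torch.lerp` uses (`start + weight * (end - start)`), applied element-wise to the two child outputs. In the NumPy check below, `input_update` and `end_update` are hypothetical values standing in for `input(tensors)` and `end(tensors)`.

```python
import numpy as np


def lerp(start, end, weight):
    """output = start + weight * (end - start), applied element-wise."""
    return start + weight * (end - start)


input_update = np.array([1.0, -2.0, 0.5])  # stands in for input(tensors)
end_update = np.array([3.0, 2.0, 0.5])     # stands in for end(tensors)

print(lerp(input_update, end_update, 0.25).tolist())  # [1.5, -1.0, 0.5]
print(lerp(input_update, end_update, 0.0).tolist())   # [1.0, -2.0, 0.5]  (weight 0 returns input)
print(lerp(input_update, end_update, 1.0).tolist())   # [3.0, 2.0, 0.5]   (weight 1 returns end)
```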

Source code in torchzero/modules/ops/multi.py
class LerpModules(MultiOperationBase):
    """Does a linear interpolation of ``input(tensors)`` and ``end(tensors)`` based on a scalar ``weight``.

    The output is given by ``output = input(tensors) + weight * (end(tensors) - input(tensors))``
    """
    def __init__(self, input: Chainable, end: Chainable, weight: float):
        defaults = dict(weight=weight)
        super().__init__(defaults, input=input, end=end)

    @torch.no_grad
    def transform(self, objective: Objective, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
        torch._foreach_lerp_(input, end, weight=self.defaults['weight'])
        return input

LevenbergMarquardt

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Levenberg-Marquardt trust region algorithm.

Parameters:

  • hess_module (Module | None) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • y (float, default: 0 ) –

    when y=0, the identity matrix is added to the Hessian; when y=1, the diagonal of the Hessian approximation is added. Intermediate values interpolate between the two. This should only be used with Gauss-Newton. Defaults to 0.

  • eta (float, default: 0.0 ) –

    if the ratio of actual to predicted reduction is larger than this, the step is accepted. When hess_module is Newton or GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if the ratio of actual to predicted reduction is larger than this, trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.0001 ) –

    if the ratio of actual to predicted reduction is less than this, trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • adaptive (bool, default: False ) –

    if True, the regularization 1/radius is multiplied by the square root of the gradient norm (equivalently, the effective trust radius is divided by it).

  • fallback (bool, default: False ) –

    if True, when hess_module maintains hessian inverse which can't be inverted efficiently, it will be inverted anyway. When False (default), a RuntimeError will be raised instead.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.
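Taken together, eta, rho_good, rho_bad, nplus, and nminus describe a ratio-based radius update. The sketch below encodes only what the parameter descriptions state. The actual logic lives in TrustRegionBase, which is not shown here and may add further conditions (for example, checking whether the step reached the boundary). The function name update_radius is hypothetical.

```python
def update_radius(actual_reduction, predicted_reduction, radius,
                  eta=0.0, nplus=3.5, nminus=0.25, rho_good=0.99, rho_bad=1e-4):
    """Return (step_accepted, new_radius); assumes predicted_reduction > 0."""
    rho = actual_reduction / predicted_reduction  # how well the model predicted the decrease
    accepted = rho > eta
    if rho > rho_good:
        radius *= nplus    # model was accurate: expand the trust region
    elif rho < rho_bad:
        radius *= nminus   # model was poor: shrink the trust region
    return accepted, radius


print(update_radius(1.0, 1.0, radius=1.0))   # (True, 3.5)
print(update_radius(-0.5, 1.0, radius=1.0))  # (False, 0.25)
```

In LevenbergMarquardt the radius enters the solve as the regularization 1/radius, so expanding the radius weakens the damping and moves the step toward the undamped Newton or Gauss-Newton step.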

Examples:

Gauss-Newton with Levenberg-Marquardt trust-region

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
)

LM-SR1

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
)

Source code in torchzero/modules/trust_region/levenberg_marquardt.py
class LevenbergMarquardt(TrustRegionBase):
    """Levenberg-Marquardt trust region algorithm.


    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set ``inverse=False`` when constructing them.
        y (float, optional):
            when ``y=0``, the identity matrix is added to the Hessian; when ``y=1``, the diagonal of the Hessian
            approximation is added. Intermediate values interpolate between the two. This should only be used with Gauss-Newton. Defaults to 0.
        eta (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the step is accepted.
            When ``hess_module`` is ``Newton`` or ``GaussNewton``, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if the ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
            Defaults to 0.99.
        rho_bad (float, optional):
            if the ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
            Defaults to 1e-4.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        adaptive (bool, optional):
            if True, the regularization ``1/radius`` is multiplied by the square root of the gradient norm
            (equivalently, the effective trust radius is divided by it). Defaults to False.
        fallback (bool, optional):
            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    ### Examples:

    Gauss-Newton with Levenberg-Marquardt trust-region

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
    )
    ```

    LM-SR1
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
    )
    ```

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        y: float = 0,
        adaptive: bool = False,
        fallback: bool = False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(y=y, fallback=fallback, adaptive=adaptive)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        y = settings['y']
        adaptive = settings["adaptive"]

        if isinstance(H, linear_operator.DenseInverse):
            if settings['fallback']:
                H = H.to_dense()
            else:
                raise RuntimeError(
                    f"{self.children['hess_module']} maintains a hessian inverse. "
                    "LevenbergMarquardt requires the hessian, not the inverse. "
                    "If that module is a quasi-newton module, pass `inverse=False` on initialization. "
                    "Or pass `fallback=True` to LevenbergMarquardt to allow inverting the hessian inverse, "
                    "however that can be inefficient and unstable."
                )

        reg = 1/radius
        if adaptive: reg = reg * torch.linalg.vector_norm(g).sqrt()

        if y == 0:
            return H.solve_plus_diag(g, reg) # pyright:ignore[reportAttributeAccessIssue]

        diag = H.diagonal()
        diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
        if y != 1: diag = (diag*y) + (1-y)
        return H.solve_plus_diag(g, diag*reg)
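With a dense Hessian, trust_solve reduces to solving a damped linear system (H + reg·D)d = g. The NumPy sketch below assumes H is an explicit symmetric matrix rather than a torchzero linear operator, so np.linalg.solve stands in for solve_plus_diag. The function name lm_direction is hypothetical.

```python
import numpy as np


def lm_direction(H, g, radius, y=0.0, adaptive=False):
    """Solve (H + reg * D) d = g, mirroring LevenbergMarquardt.trust_solve (hypothetical helper)."""
    H = np.asarray(H, dtype=float)
    reg = 1.0 / radius
    if adaptive:
        reg *= np.sqrt(np.linalg.norm(g))
    if y == 0:
        damping = np.full(len(g), reg)  # D = identity
    else:
        diag = np.diag(H)
        # Replace (near-)zero or negative diagonal entries with 1, as the source does.
        diag = np.where(diag < np.finfo(diag.dtype).tiny * 2, 1.0, diag)
        diag = y * diag + (1 - y)  # y=1 uses the Hessian diagonal alone
        damping = diag * reg
    return np.linalg.solve(H + np.diag(damping), g)


H = np.array([[4.0, 1.0], [1.0, 3.0]])
g = np.array([1.0, 2.0])
d_small = lm_direction(H, g, radius=1e-6)  # heavy damping: d is close to radius * g
d_large = lm_direction(H, g, radius=1e6)   # light damping: d is close to H^{-1} g
print(np.allclose(d_large, np.linalg.solve(H, g), atol=1e-5))  # True
```

The two extremes show why Levenberg-Marquardt interpolates between a short gradient-like step (small radius) and a full Newton or Gauss-Newton step (large radius).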

LineSearchBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for line searches.

This is an abstract class; to use it, subclass it and override search.

Parameters:

  • defaults (dict[str, Any] | None) –

    dictionary with defaults.

  • maxiter (int | None, default: None ) –

    if this is specified, the search method will terminate after evaluating the objective this many times, and the step size with the lowest loss value will be used. This is useful when passing make_objective to an external library which doesn't have a maxiter option. Defaults to None.
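In the source, the maxiter cutoff works by raising MaxLineSearchItersReached from inside the objective once the evaluation budget is spent. apply then falls back to the best step size seen so far. The self-contained sketch below shows the same pattern around a scalar function. The names BudgetedObjective, BudgetExhausted, and search_with_budget are hypothetical, and the built-in min stands in for an external solver with no maxiter option of its own.

```python
import math


class BudgetExhausted(Exception):
    """Raised once the evaluation budget is used up."""


class BudgetedObjective:
    """Wraps phi(step_size), stops after `maxiter` evaluations, and tracks the best step."""

    def __init__(self, phi, maxiter):
        self.phi = phi
        self.maxiter = maxiter
        self.evals = 0
        self.best_step, self.best_loss = 0.0, math.inf

    def __call__(self, step_size):
        if self.evals >= self.maxiter:
            raise BudgetExhausted
        self.evals += 1
        loss = self.phi(step_size)
        if loss < self.best_loss:
            self.best_step, self.best_loss = step_size, loss
        return loss


def search_with_budget(phi, candidate_steps, maxiter):
    objective = BudgetedObjective(phi, maxiter)
    try:
        best = min(candidate_steps, key=objective)  # the "external solver"
    except BudgetExhausted:
        best = objective.best_step  # fall back to the best step evaluated so far
    return best


def phi(a):
    return (a - 0.3) ** 2


print(search_with_budget(phi, [0.1, 0.2, 0.3, 0.4], maxiter=10))  # 0.3
print(search_with_budget(phi, [0.1, 0.2, 0.3, 0.4], maxiter=2))   # 0.2
```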

Other useful methods
  • evaluate_f - returns loss with a given scalar step size
  • evaluate_f_d - returns loss and directional derivative with a given scalar step size
  • make_objective - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
  • make_objective_with_derivative - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
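set_step_size_ moves the parameters to x0 - step_size * update. The directional derivative of phi(a) = f(x0 - a*u) is therefore -grad f(x0 - a*u) · u, which is why _loss_derivative_gradient negates the dot product. The NumPy finite-difference check below uses a toy quadratic f chosen for illustration (an assumption, not part of the library).

```python
import numpy as np

A = np.array([[3.0, 1.0], [1.0, 2.0]])


def f(x):
    return 0.5 * x @ A @ x


def grad_f(x):
    return A @ x


x0 = np.array([1.0, -1.0])
u = grad_f(x0)  # update direction; here simply the gradient


def phi(a):
    # Loss at step size a, as evaluate_f would compute it.
    return f(x0 - a * u)


def dphi(a):
    # Directional derivative, as evaluate_f_d would compute it.
    return -(grad_f(x0 - a * u) @ u)


a, h = 0.1, 1e-6
fd = (phi(a + h) - phi(a - h)) / (2 * h)  # central finite difference
print(np.isclose(dphi(a), fd))  # True
print(dphi(0.0) < 0)            # True: the update is a descent direction at step size 0
```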

Examples:

This evaluates every step size in a range using the self.evaluate_f method.

import torch
from torchzero.modules.line_search.line_search import LineSearchBase

class GridLineSearch(LineSearchBase):
    def __init__(self, start, end, num):
        defaults = dict(start=start, end=end, num=num)
        super().__init__(defaults)

    @torch.no_grad
    def search(self, update, var):

        start = self.defaults["start"]
        end = self.defaults["end"]
        num = self.defaults["num"]

        lowest_loss = float("inf")
        best_step_size = 0.0  # returned if no evaluation produces a finite loss

        # evaluate the loss at each candidate step size and keep the best one
        for step_size in torch.linspace(start, end, num):
            loss = self.evaluate_f(step_size.item(), var=var, backward=False)
            if loss < lowest_loss:
                lowest_loss = loss
                best_step_size = step_size.item()

        return best_step_size

Using external solver via self.make_objective

Here we let the scipy.optimize.minimize_scalar solver find the best step size via self.make_objective.

import scipy.optimize
import torch
from torchzero.modules.line_search.line_search import LineSearchBase

class ScipyMinimizeScalar(LineSearchBase):
    def __init__(self, method: str | None = None):
        defaults = dict(method=method)
        super().__init__(defaults)

    @torch.no_grad
    def search(self, update, var):
        # scalar function mapping a step size to the loss
        objective = self.make_objective(var=var)
        method = self.defaults["method"]

        res = scipy.optimize.minimize_scalar(objective, method=method)
        return res.x

Methods:

  • evaluate_f

    evaluate the function value at step size step_size.

  • evaluate_f_d

    evaluate function value and directional derivative in the direction of the update at step size step_size.

  • evaluate_f_d_g

    evaluate function value, directional derivative, and gradient list at step size step_size.

  • search

    Finds the step size to use

Source code in torchzero/modules/line_search/line_search.py
class LineSearchBase(Module, ABC):
    """Base class for line searches.

    This is an abstract class; to use it, subclass it and override `search`.

    Args:
        defaults (dict[str, Any] | None): dictionary with defaults.
        maxiter (int | None, optional):
            if this is specified, the search method will terminate after evaluating
            the objective this many times, and the step size with the lowest loss value will be used.
            This is useful when passing `make_objective` to an external library which
            doesn't have a maxiter option. Defaults to None.

    Other useful methods:
        * ``evaluate_f`` - returns loss with a given scalar step size
        * ``evaluate_f_d`` - returns loss and directional derivative with a given scalar step size
        * ``make_objective`` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
        * ``make_objective_with_derivative`` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.

    Examples:

    #### Basic line search

    This evaluates every step size in a range using the `self.evaluate_f` method.
    ```python
    import torch
    from torchzero.modules.line_search.line_search import LineSearchBase

    class GridLineSearch(LineSearchBase):
        def __init__(self, start, end, num):
            defaults = dict(start=start, end=end, num=num)
            super().__init__(defaults)

        @torch.no_grad
        def search(self, update, var):

            start = self.defaults["start"]
            end = self.defaults["end"]
            num = self.defaults["num"]

            lowest_loss = float("inf")
            best_step_size = 0.0  # returned if no evaluation produces a finite loss

            # evaluate the loss at each candidate step size and keep the best one
            for step_size in torch.linspace(start, end, num):
                loss = self.evaluate_f(step_size.item(), var=var, backward=False)
                if loss < lowest_loss:
                    lowest_loss = loss
                    best_step_size = step_size.item()

            return best_step_size
    ```

    #### Using external solver via self.make_objective

    Here we let the `scipy.optimize.minimize_scalar` solver find the best step size via `self.make_objective`.

    ```python
    import scipy.optimize
    import torch
    from torchzero.modules.line_search.line_search import LineSearchBase

    class ScipyMinimizeScalar(LineSearchBase):
        def __init__(self, method: str | None = None):
            defaults = dict(method=method)
            super().__init__(defaults)

        @torch.no_grad
        def search(self, update, var):
            # scalar function mapping a step size to the loss
            objective = self.make_objective(var=var)
            method = self.defaults["method"]

            res = scipy.optimize.minimize_scalar(objective, method=method)
            return res.x
    ```
    """
    def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
        super().__init__(defaults)
        self._maxiter = maxiter
        self._reset()

    def _reset(self):
        self._current_step_size: float = 0
        self._lowest_loss = float('inf')
        self._best_step_size: float = 0
        self._current_iter = 0
        self._initial_params = None

    def set_step_size_(
        self,
        step_size: float,
        params: list[torch.Tensor],
        update: list[torch.Tensor],
    ):
        if not math.isfinite(step_size): return

        # avoid overflow error
        step_size = clip_by_finfo(tofloat(step_size), torch.finfo(update[0].dtype))

        # skip if parameters are already at suggested step size
        if self._current_step_size == step_size: return

        assert self._initial_params is not None
        if step_size == 0:
            new_params = [p.clone() for p in self._initial_params]
        else:
            new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)

        for c, n in zip(params, new_params):
            set_storage_(c, n)

        self._current_step_size = step_size

    def _set_per_parameter_step_size_(
        self,
        step_size: Sequence[float],
        params: list[torch.Tensor],
        update: list[torch.Tensor],
    ):

        assert self._initial_params is not None
        if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]

        if any(s!=0 for s in step_size):
            new_params = torch._foreach_sub(self._initial_params, torch._foreach_mul(update, step_size))
        else:
            new_params = [p.clone() for p in self._initial_params]

        for c, n in zip(params, new_params):
            set_storage_(c, n)

    def _loss(self, step_size: float, var: Objective, closure, params: list[torch.Tensor],
              update: list[torch.Tensor], backward:bool=False) -> float:

        # if step_size is 0, we might already know the loss
        if (var.loss is not None) and (step_size == 0):
            return tofloat(var.loss)

        # check max iter
        if self._maxiter is not None and self._current_iter >= self._maxiter: raise MaxLineSearchItersReached
        self._current_iter += 1

        # set new lr and evaluate loss with it
        self.set_step_size_(step_size, params=params, update=update)
        if backward:
            with torch.enable_grad(): loss = closure()
        else:
            loss = closure(False)

        # if it is the best so far, record it
        if loss < self._lowest_loss:
            self._lowest_loss = tofloat(loss)
            self._best_step_size = step_size

        # if evaluated loss at step size 0, set it to var.loss
        if step_size == 0:
            var.loss = loss
            if backward: var.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

        return tofloat(loss)

    def _loss_derivative_gradient(self, step_size: float, var: Objective, closure,
                         params: list[torch.Tensor], update: list[torch.Tensor]):
        # if step_size is 0, we might already know the derivative
        if (var.grads is not None) and (step_size == 0):
            loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=False)
            derivative = - sum(t.sum() for t in torch._foreach_mul(var.grads, update))

        else:
            # loss with a backward pass sets params.grad
            loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=True)

            # directional derivative
            derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
                                                                    else torch.zeros_like(p) for p in params], update))

        assert var.grads is not None
        return loss, tofloat(derivative), var.grads

    def _loss_derivative(self, step_size: float, var: Objective, closure,
                         params: list[torch.Tensor], update: list[torch.Tensor]):
        return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]

    def evaluate_f(self, step_size: float, var: Objective, backward:bool=False):
        """evaluate the function value at step size `step_size`."""
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates(),backward=backward)

    def evaluate_f_d(self, step_size: float, var: Objective):
        """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())

    def evaluate_f_d_g(self, step_size: float, var: Objective):
        """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())

    def make_objective(self, var: Objective, backward:bool=False):
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_updates(), backward=backward)

    def make_objective_with_derivative(self, var: Objective):
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_updates())

    def make_objective_with_derivative_and_gradient(self, var: Objective):
        closure = var.closure
        if closure is None: raise RuntimeError('line search requires closure')
        return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_updates())

    @abstractmethod
    def search(self, update: list[torch.Tensor], var: Objective) -> float:
        """Finds the step size to use"""

    @torch.no_grad
    def apply(self, objective: Objective) -> Objective:
        self._reset()

        params = objective.params
        self._initial_params = [p.clone() for p in params]
        update = objective.get_updates()

        try:
            step_size = self.search(update=update, var=objective)
        except MaxLineSearchItersReached:
            step_size = self._best_step_size

        step_size = clip_by_finfo(step_size, torch.finfo(update[0].dtype))

        # set loss_approx
        if objective.loss_approx is None: objective.loss_approx = self._lowest_loss

        # if this is last module, directly update parameters to avoid redundant operations
        if objective.modular is not None and self is objective.modular.modules[-1]:
            self.set_step_size_(step_size, params=params, update=update)

            objective.stop = True; objective.skip_update = True
            return objective

        # revert parameters and multiply update by step size
        self.set_step_size_(0, params=params, update=update)
        torch._foreach_mul_(objective.updates, step_size)
        return objective

evaluate_f

evaluate_f(step_size: float, var: Objective, backward: bool = False)

Evaluate the function value at step size step_size.

Source code in torchzero/modules/line_search/line_search.py
def evaluate_f(self, step_size: float, var: Objective, backward:bool=False):
    """evaluate function value at alpha `step_size`."""
    closure = var.closure
    if closure is None: raise RuntimeError('line search requires closure')
    return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates(),backward=backward)

evaluate_f_d

evaluate_f_d(step_size: float, var: Objective)

Evaluate the function value and the directional derivative along the update direction at step size step_size.

Source code in torchzero/modules/line_search/line_search.py
def evaluate_f_d(self, step_size: float, var: Objective):
    """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
    closure = var.closure
    if closure is None: raise RuntimeError('line search requires closure')
    return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())

evaluate_f_d_g

evaluate_f_d_g(step_size: float, var: Objective)

Evaluate the function value, the directional derivative, and the list of gradients at step size step_size.

Source code in torchzero/modules/line_search/line_search.py
def evaluate_f_d_g(self, step_size: float, var: Objective):
    """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
    closure = var.closure
    if closure is None: raise RuntimeError('line search requires closure')
    return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())

search

search(update: list[Tensor], var: Objective) -> float

Finds the step size to use.

Source code in torchzero/modules/line_search/line_search.py
@abstractmethod
def search(self, update: list[torch.Tensor], var: Objective) -> float:
    """Finds the step size to use"""
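The base class leaves `search` abstract. As a rough illustration of what a `search` implementation computes, here is a backtracking (Armijo) search written against a plain scalar function `phi(alpha)`, which plays the role of `evaluate_f`. The derivative sign follows `_loss_derivative_gradient` above: parameters move to `params - alpha * update`, so the directional derivative at `alpha = 0` is `-(grad . update)`. The function `backtracking_search` and its defaults (`c1`, `beta`, `max_iters`) are hypothetical and are not taken from torchzero's own line search modules.

```python
from typing import Callable


def backtracking_search(
    phi: Callable[[float], float],
    f0: float,
    d0: float,
    init: float = 1.0,
    c1: float = 1e-4,
    beta: float = 0.5,
    max_iters: int = 50,
) -> float:
    """Shrink alpha until the Armijo sufficient-decrease condition holds.

    phi(alpha) is the loss at params - alpha * update, f0 = phi(0), and d0 is
    the directional derivative at alpha = 0 (negative for a descent direction).
    """
    alpha = init
    for _ in range(max_iters):
        if phi(alpha) <= f0 + c1 * alpha * d0:
            return alpha
        alpha *= beta
    # budget exhausted: return the last (smallest) trial step
    return alpha


# 1-D example: f(x) = (x - 3)^2 at x = 0, using the gradient itself as the update
x0 = 0.0
f = lambda x: (x - 3.0) ** 2
grad = 2.0 * (x0 - 3.0)
update = grad
phi = lambda alpha: f(x0 - alpha * update)
d0 = -grad * update  # same sign convention as the source: -(grad . update)
step = backtracking_search(phi, f(x0), d0)
```

Here the full step `alpha = 1` overshoots to `x = 6`, so the search halves once and accepts `alpha = 0.5`, which lands on the minimizer.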

Lion

Bases: torchzero.core.transform.TensorTransform

Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

Parameters:

  • beta1 (float, default: 0.9 ) –

    dampening for momentum. Defaults to 0.9.

  • beta2 (float, default: 0.99 ) –

    momentum factor. Defaults to 0.99.

Source code in torchzero/modules/adaptive/lion.py
class Lion(TensorTransform):
    """Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.

    Args:
        beta1 (float, optional): dampening for momentum. Defaults to 0.9.
        beta2 (float, optional): momentum factor. Defaults to 0.99.
    """

    def __init__(self, beta1: float = 0.9, beta2: float = 0.99):
        defaults = dict(beta1=beta1, beta2=beta2)
        super().__init__(defaults)

        self.add_projected_keys("grad", "exp_avg")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
        exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList)
        return lion_(TensorList(tensors), exp_avg, beta1, beta2)
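The body of `lion_` is not shown here. As a sketch, assuming it follows the Lion paper's update without weight decay: the update is the sign of an interpolation between the momentum buffer and the gradient (controlled by `beta1`), and the buffer is then updated with `beta2`. Plain Python lists stand in for tensors, and `lion_step` is a hypothetical name.

```python
def sign(v: float) -> float:
    return float((v > 0) - (v < 0))


def lion_step(grad, exp_avg, beta1=0.9, beta2=0.99):
    """Return (update, new_exp_avg); the caller subtracts lr * update from the params."""
    # direction: sign of the interpolation between momentum and the current gradient
    update = [sign(beta1 * m + (1 - beta1) * g) for g, m in zip(grad, exp_avg)]
    # momentum buffer: EMA of gradients with factor beta2
    new_exp_avg = [beta2 * m + (1 - beta2) * g for g, m in zip(grad, exp_avg)]
    return update, new_exp_avg


update, exp_avg = lion_step([2.0, -3.0, 0.0], [0.0, 0.0, 0.0])
```

Every nonzero coordinate of the update has magnitude exactly 1, so the step length is set entirely by the learning rate that follows this module.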

LiuStorey

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Liu-Storey nonlinear conjugate gradient method.

Note

This method requires the step size to be determined by a line search, so place a line search such as tz.m.StrongWolfe(c2=0.1, a_init="first-order") after it.

Source code in torchzero/modules/conjugate_gradient/cg.py
class LiuStorey(ConguateGradientBase):
    """Liu-Storey nonlinear conjugate gradient method.

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
    """
    def __init__(self, restart_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return liu_storey_beta(g, prev_d, prev_g)
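For reference, the textbook Liu-Storey coefficient is `beta_k = g_k^T (g_k - g_{k-1}) / (-d_{k-1}^T g_{k-1})`, with the new direction `d_k = -g_k + beta_k * d_{k-1}`. The sketch below runs it on a 2-D quadratic, with an exact line search standing in for `StrongWolfe`; under exact line search on a quadratic the method reduces to linear conjugate gradient and should reach the minimizer in two steps. It uses the textbook sign convention (`d` is a descent direction); torchzero's `liu_storey_beta` works with the library's own update convention, which is not shown here, and all names below are illustrative.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(A, v):
    return [dot(row, v) for row in A]


def liu_storey_beta(g, g_prev, d_prev):
    y = [a - b for a, b in zip(g, g_prev)]
    return dot(g, y) / -dot(d_prev, g_prev)


def cg_liu_storey(A, b, x, iters):
    """Minimize 0.5 x^T A x - b^T x with Liu-Storey nonlinear CG."""
    g = [ax - bi for ax, bi in zip(matvec(A, x), b)]
    d = [-gi for gi in g]
    for _ in range(iters):
        alpha = -dot(g, d) / dot(d, matvec(A, d))  # exact line search on a quadratic
        x = [xi + alpha * di for xi, di in zip(x, d)]
        g_prev, g = g, [ax - bi for ax, bi in zip(matvec(A, x), b)]
        beta = liu_storey_beta(g, g_prev, d)
        d = [-gi + beta * di for gi, di in zip(g, d)]
    return x


x = cg_liu_storey(A=[[3.0, 1.0], [1.0, 2.0]], b=[1.0, 1.0], x=[0.0, 0.0], iters=2)
```

The exact minimizer of this quadratic is `[0.2, 0.4]`. With an inexact line search the equivalence to linear CG no longer holds, which is why the note above asks for a reasonably tight Wolfe search.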

LogHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class LogHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return (loss+1e-12).log()
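LogHomotopy has no docstring; the source shows only that `loss_transform` maps the loss to `log(loss + 1e-12)`. Assuming `HomotopyBase` differentiates the transformed loss, the chain rule gives `grad log(f + eps) = grad f / (f + eps)`, so the gradient is divided by the current loss value, and the loss must stay positive. The check below compares that expression with a central finite difference; `f` and the evaluation point are arbitrary examples.

```python
import math

EPS = 1e-12  # same constant as loss_transform


def f(x: float) -> float:
    return (x - 1.0) ** 2


def grad_f(x: float) -> float:
    return 2.0 * (x - 1.0)


def log_loss(x: float) -> float:
    return math.log(f(x) + EPS)


x, h = 4.0, 1e-6
finite_diff = (log_loss(x + h) - log_loss(x - h)) / (2 * h)
chain_rule = grad_f(x) / (f(x) + EPS)  # 6 / 9
```

One consequence is that multiplying `f` by a positive constant leaves the transformed gradient unchanged (up to `eps`), since the constant cancels in `grad f / f`.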

MARSCorrection

Bases: torchzero.core.transform.TensorTransform

MARS variance reduction correction.

Place any other momentum-based optimizer after this module, and make sure beta matches the momentum parameter of that optimizer.

Parameters:

  • beta (float, default: 0.9 ) –

    use the same beta as you use in the momentum module. Defaults to 0.9.

  • scaling (float, default: 0.025 ) –

    controls the scale of gradient correction in variance reduction. Defaults to 0.025.

  • max_norm (float, default: 1 ) –

    clips norm of corrected gradients, None to disable. Defaults to 1.

Examples:

Mars-AdamW

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.MARSCorrection(beta=0.95),
    tz.m.Adam(beta1=0.95, beta2=0.99),
    tz.m.WeightDecay(1e-3),
    tz.m.LR(0.1)
)

Mars-Lion

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.MARSCorrection(beta=0.9),
    tz.m.Lion(beta1=0.9),
    tz.m.LR(0.1)
)

Source code in torchzero/modules/adaptive/mars.py
class MARSCorrection(TensorTransform):
    """MARS variance reduction correction.

    Place any other momentum-based optimizer after this,
    make sure ``beta`` parameter matches with momentum in the optimizer.

    Args:
        beta (float, optional): use the same beta as you use in the momentum module. Defaults to 0.9.
        scaling (float, optional): controls the scale of gradient correction in variance reduction. Defaults to 0.025.
        max_norm (float, optional): clips norm of corrected gradients, None to disable. Defaults to 1.

    ## Examples:

    Mars-AdamW
    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.MARSCorrection(beta=0.95),
        tz.m.Adam(beta1=0.95, beta2=0.99),
        tz.m.WeightDecay(1e-3),
        tz.m.LR(0.1)
    )
    ```

    Mars-Lion
    ```python
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.MARSCorrection(beta=0.9),
        tz.m.Lion(beta1=0.9),
        tz.m.LR(0.1)
    )
    ```

    """
    def __init__(
        self,
        beta: float = 0.9,
        scaling: float = 0.025,
        max_norm: float | None = 1,
    ):
        defaults = dict(beta=beta, scaling=scaling, max_norm=max_norm)
        super().__init__(defaults)
        self.add_projected_keys("grad", "g_prev")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        g_prev = unpack_states(states, tensors, 'g_prev', init=tensors, cls=TensorList)
        beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
        max_norm = settings[0]['max_norm']

        return mars_correction_(
            tensors_=TensorList(tensors),
            g_prev_=g_prev,
            beta=beta,
            scaling=scaling,
            max_norm=max_norm,
        )
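The body of `mars_correction_` is not shown. As a sketch, assuming it follows the approximate correction from the MARS paper with `gamma = scaling`: `c_t = g_t + scaling * beta / (1 - beta) * (g_t - g_prev)`, then `c_t` is rescaled to norm `max_norm` whenever its global norm exceeds it. The exact clipping rule in the library may differ; `mars_correction` below is a hypothetical plain-Python helper.

```python
import math


def mars_correction(g, g_prev, beta=0.9, scaling=0.025, max_norm=1.0):
    # variance-reduction term: scaled difference between current and previous gradient
    coef = scaling * beta / (1 - beta)
    c = [gi + coef * (gi - gp) for gi, gp in zip(g, g_prev)]
    if max_norm is not None:
        norm = math.sqrt(sum(ci * ci for ci in c))
        if norm > max_norm:
            # rescale so the corrected gradient has norm max_norm
            c = [ci * max_norm / norm for ci in c]
    return c


unchanged = mars_correction([0.3, 0.4], [0.3, 0.4])
clipped = mars_correction([3.0, 4.0], [3.0, 4.0])
corrected = mars_correction([0.2], [0.1], beta=0.5, scaling=1.0)
```

The coefficient `beta / (1 - beta)` grows quickly as `beta` approaches 1, which is one reason `beta` here has to match the momentum of the optimizer that follows.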

MSAM

Bases: torchzero.core.transform.Transform

Momentum-SAM from https://arxiv.org/pdf/2401.12033.

Note

Please make sure to place tz.m.LR inside the modules argument. For example, tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)]). Putting LR after MSAM will lead to an incorrect update rule.

Parameters:

  • modules (Chainable) –

    modules that will optimize the MSAM objective. Make sure tz.m.LR is one of them.

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • rho (float, default: 0.3 ) –

    perturbation strength. Defaults to 0.3.

  • nesterov (bool, default: False ) –

    whether to use nesterov momentum formula. Defaults to False.

  • lerp (bool, default: False ) –

    whether to use linear interpolation. If True, MSAM momentum becomes similar to an exponential moving average. Defaults to False.

Examples:

AdamW-MSAM

opt = tz.Optimizer(
    model.parameters(),
    tz.m.MSAMObjective(
        [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
        rho=1.
    )
)

Source code in torchzero/modules/adaptive/msam.py
class MSAM(Transform):
    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

    Note:
        Please make sure to place ``tz.m.LR`` inside the ``modules`` argument. For example,
        ``tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])``. Putting LR after MSAM will lead
        to an incorrect update rule.

    Args:
        modules (Chainable): modules that will optimize the MSAM objective. Make sure ``tz.m.LR`` is one of them.
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        rho (float, optional): perturbation strength. Defaults to 0.3.
        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
        lerp (bool, optional):
            whether to use linear interpolation, if True, MSAM momentum becomes similar to exponential moving average.
            Defaults to False.

    Examples:
    AdamW-MSAM

    ```py
    opt = tz.Optimizer(
        bench.parameters(),
        tz.m.MSAMObjective(
            [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
            rho=1.
        )
    )
    ```
    """
    def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
        defaults = dict(momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
        super().__init__(defaults)

        self.set_child('modules', modules)
        self.add_projected_keys("grad", "velocity")


    @torch.no_grad
    def apply_states(self, objective, states, settings):
        velocity = unpack_states(states, objective.params, 'velocity', cls=TensorList)
        fs = settings[0]

        momentum, rho, weight_decay = unpack_dicts(settings, 'momentum', 'rho', 'weight_decay', cls=NumberList)

        return msam_(
            TensorList(objective.get_updates()),
            params=TensorList(objective.params),
            velocity_=velocity,
            momentum=momentum,
            lr=None,
            rho=rho,
            weight_decay=weight_decay,
            nesterov=fs['nesterov'],
            lerp=fs['lerp'],

            # inner args
            inner=self.children["modules"],
            objective=objective,
        )

MSAMMomentum

Bases: torchzero.core.transform.TensorTransform

Momentum-SAM from https://arxiv.org/pdf/2401.12033.

This implementation expresses the update rule as a function of the gradient, so it can be used as a drop-in replacement for the momentum strategy in other optimizers.

To combine MSAM with other optimizers the way the official implementation does, e.g. to make Adam_MSAM, use the tz.m.MSAMObjective module.

Note: MSAM has a learning rate hyperparameter that can't really be removed from the update rule. To avoid compounding learning rate modifications, remove the tz.m.LR module if you had one.

Parameters:

  • lr (float) –

    learning rate. Adding this module adds support for learning rate schedulers.

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • rho (float, default: 0.3 ) –

    perturbation strength. Defaults to 0.3.

  • weight_decay (float, default: 0 ) –

    weight decay. It is applied to the perturbed parameters, so it is different from applying tz.m.WeightDecay after MSAM. Defaults to 0.

  • nesterov (bool, default: False ) –

    whether to use nesterov momentum formula. Defaults to False.

  • lerp (bool, default: False ) –

    whether to use linear interpolation. If True, this becomes similar to an exponential moving average. Defaults to False.

Examples:

MSAM

opt = tz.Optimizer(
    model.parameters(),
    tz.m.MSAM(1e-3)
)

Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM. To make Adam_MSAM and such, use the tz.m.MSAMObjective module.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
    tz.m.Debias(0.9, 0.999),
)

Source code in torchzero/modules/adaptive/msam.py
class MSAMMomentum(TensorTransform):
    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

    This implementation expresses the update rule as function of gradient. This way it can be used as a drop-in
    replacement for momentum strategies in other optimizers.

    To combine MSAM with other optimizers in the way done in the official implementation,
    e.g. to make Adam_MSAM, use ``tz.m.MSAMObjective`` module.

    Note:
        MSAM has a learning rate hyperparameter that can't really be removed from the update rule.
        To avoid compounding learning rate modifications, remove the ``tz.m.LR`` module if you had it.

    Args:
        lr (float): learning rate. Adding this module adds support for learning rate schedulers.
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        rho (float, optional): perturbation strength. Defaults to 0.3.
        weight_decay (float, optional):
            weight decay. It is applied to perturbed parameters, so it is different
            from applying ``tz.m.WeightDecay`` after MSAM. Defaults to 0.
        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
        lerp (bool, optional):
            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.

    ### Examples:

    MSAM

    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.MSAM(1e-3)
    )
    ```

    Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
    To make Adam_MSAM and such, use the ``tz.m.MSAMObjective`` module.

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
        tz.m.Debias(0.9, 0.999),
    )
    ```
    """

    def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3,  weight_decay:float=0, nesterov=False, lerp=False,):
        defaults = dict(lr = lr, momentum=momentum, rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
        super().__init__(defaults, uses_grad=False)

        self.add_projected_keys("grad", "velocity")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
        fs = settings[0]

        lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)

        return msam_(
            TensorList(tensors),
            params=TensorList(params),
            velocity_=velocity,
            momentum=momentum,
            lr=lr,
            rho=rho,
            weight_decay=weight_decay,
            nesterov=fs['nesterov'],
            lerp=fs['lerp'],

            # inner args
            inner=None,
            objective=None,
        )

MatrixMomentum

Bases: torchzero.core.transform.Transform

Second order momentum method.

Matrix momentum is useful for convex objectives. For reasons that are not clear, it also generalizes very well on elastic net logistic regression.

Notes
  • mu needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable. I have devised an adaptive version of this - tz.m.AdaptiveMatrixMomentum, and it works well without having to tune mu, however the adaptive version doesn't work on stochastic objectives.

  • In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.

  • This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate HVPs. The closure must accept a backward argument.

Parameters:

  • mu (float, default: 0.1 ) –

    this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.

  • hvp_method (str, default: 'autograd' ) –

    Determines how hessian-vector products are computed.

    • "batched_autograd" - uses autograd with batched hessian-vector products. If a single hessian-vector product is evaluated, equivalent to "autograd". Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than "batched_autograd" but uses less memory.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    Defaults to "autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • hvp_tfm (Chainable | None) –

    optional module applied to hessian-vector products. Defaults to None.

Reference

Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).

Source code in torchzero/modules/adaptive/matrix_momentum.py
class MatrixMomentum(Transform):
    """Second order momentum method.

    Matrix momentum is useful for convex objectives. For reasons that are not clear, it also generalizes very well on elastic net logistic regression.

    Notes:
        - ``mu`` needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable. I have devised an adaptive version of this - ``tz.m.AdaptiveMatrixMomentum``, and it works well without having to tune ``mu``, however the adaptive version doesn't work on stochastic objectives.

        - In most cases ``MatrixMomentum`` should be the first module in the chain because it relies on autograd.

        - This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate HVPs. The closure must accept a ``backward`` argument.

    Args:
        mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
        hvp_method (str, optional):
            Determines how hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector product is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.

    Reference:
        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
    """

    def __init__(
        self,
        lr:float,
        mu=0.1,
        hvp_method: HVPMethod = "autograd",
        h: float = 1e-3,
        adaptive:bool = False,
        adapt_freq: int | None = None,

        inner: Chainable | None = None,
    ):
        defaults = dict(lr=lr, mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
        super().__init__(defaults, inner=inner)

    def reset_for_online(self):
        super().reset_for_online()
        self.clear_state_keys('p_prev')

    @torch.no_grad
    def update_states(self, objective, states, settings):
        step = self.increment_counter("step", 0)
        p = TensorList(objective.params)
        p_prev = unpack_states(states, p, 'p_prev', init=p)

        fs = settings[0]
        hvp_method = fs['hvp_method']
        h = fs['h']

        if step > 0:
            s = p - p_prev

            Hs, _ = objective.hessian_vector_product(s, at_x0=True, rgrad=None, hvp_method=hvp_method, h=h, retain_graph=False)
            Hs = [t.detach() for t in Hs]

            self.store(p, ("Hs", "s"), (Hs, s))

            # -------------------------------- adaptive mu ------------------------------- #
            if fs["adaptive"]:
                g = TensorList(objective.get_grads())

                if fs["adapt_freq"] is None:
                    # ---------------------------- deterministic case ---------------------------- #
                    g_prev = unpack_states(states, p, "g_prev", cls=TensorList)
                    y = g - g_prev
                    g_prev.copy_(g)
                    denom = y.global_vector_norm()
                    denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
                    self.global_state["mu_mul"] = s.global_vector_norm() / denom

                else:
                    # -------------------------------- stochastic -------------------------------- #
                    adapt_freq = self.defaults["adapt_freq"]

                    # we start on the 1st step, and want to adapt when we start, so use (step - 1)
                    if (step - 1) % adapt_freq == 0:
                        assert objective.closure is not None
                        params = TensorList(objective.params)
                        p_cur = params.clone()

                        # move to previous params and evaluate p_prev with current mini-batch
                        params.copy_(unpack_states(states, p, 'p_prev'))
                        with torch.enable_grad():
                            objective.closure()
                        g_prev = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                        y = g - g_prev

                        # move back to current params
                        params.copy_(p_cur)

                        denom = y.global_vector_norm()
                        denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
                        self.global_state["mu_mul"] = s.global_vector_norm() / denom

        torch._foreach_copy_(p_prev, objective.params)

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        update = TensorList(objective.get_updates())
        lr, mu = unpack_dicts(settings, "lr", 'mu', cls=NumberList)

        if "mu_mul" in self.global_state:
            mu = mu * self.global_state["mu_mul"]

        # --------------------------------- 1st step --------------------------------- #
        # p_prev is not available so make a small step
        step = self.global_state["step"]
        if step == 1:
            if self.defaults["adaptive"]:
                # initialize
                unpack_states(states, objective.params, "g_prev", init=objective.get_grads())

            update.mul_(lr) # separate so that initial_step_size can clip correctly
            update.mul_(initial_step_size(update, 1e-7))
            return objective

        # -------------------------- matrix momentum update -------------------------- #
        s, Hs = unpack_states(states, objective.params, 's', 'Hs', cls=TensorList)

        update.mul_(lr).sub_(s).add_(Hs*mu)
        objective.updates = update
        return objective
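Reading `apply_states`, the update is `lr * g - s + mu * Hs` with `s = p - p_prev`. Assuming, as elsewhere in torchzero, that the final update is subtracted from the parameters, one step is `x_{t+1} = x_t - lr * grad f(x_t) + (I - mu * H)(x_t - x_{t-1})`: heavy-ball momentum with the scalar momentum factor replaced by the matrix `I - mu * H`. For every positive Hessian eigenvalue `lam`, the factor `1 - mu * lam` stays in `(0, 1)` only when `mu < 1 / lam_max`, consistent with the note on `mu` above. The sketch below applies this rule to a diagonal quadratic with an exact Hessian-vector product; the helper name is hypothetical, and the first step is a plain gradient step rather than the source's `initial_step_size` scaling.

```python
def matrix_momentum(a, x0, lr, mu, steps):
    """Minimize f(x) = 0.5 * sum(a_i * x_i**2) with matrix momentum.

    The Hessian is diag(a), so the Hessian-vector product H @ s is exact.
    """
    x_prev, x = list(x0), list(x0)  # s = 0 on the first step -> plain gradient step
    for _ in range(steps):
        g = [ai * xi for ai, xi in zip(a, x)]
        s = [xi - xp for xi, xp in zip(x, x_prev)]  # s = p - p_prev
        Hs = [ai * si for ai, si in zip(a, s)]
        update = [lr * gi - si + mu * hsi for gi, si, hsi in zip(g, s, Hs)]
        x_prev, x = x, [xi - ui for xi, ui in zip(x, update)]  # params -= update
    return x


# eigenvalues 1 and 10, so mu must stay below 1 / 10
x = matrix_momentum(a=[1.0, 10.0], x0=[1.0, 1.0], lr=0.01, mu=0.05, steps=500)
```

With `mu = 0.05`, the effective momentum is `0.95` along the flat direction and `0.5` along the steep one, so the low-curvature coordinate gets most of the acceleration.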

Maximum

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs maximum(tensors, other(tensors))

Source code in torchzero/modules/ops/binary.py
class Maximum(BinaryOperationBase):
    """Outputs ``maximum(tensors, other(tensors))``"""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        torch._foreach_maximum_(update, other)
        return update

MaximumModules

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs elementwise maximum of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class MaximumModules(ReduceOperationBase):
    """Outputs elementwise maximum of ``inputs`` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        maximum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_maximum_(maximum, v)

        return maximum

McCormick

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

McCormick's Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.

This is "Algorithm 2", attributed to McCormick in this paper. However, some other sources call this method Pearson's 2nd method.

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class McCormick(_InverseHessianUpdateStrategyDefaults):
    """McCormick's Quasi-Newton method.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.

        This is "Algorithm 2", attributed to McCormick in this paper. However, some other sources call this method Pearson's 2nd method.
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return mccormick_H_(H=H, s=s, y=y)
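The body of `mccormick_H_` is not shown. In the quasi-Newton literature, McCormick's update of the inverse-Hessian approximation is usually written `H_new = H + (s - H y) s^T / (s^T y)`; the sketch below assumes that form. It satisfies the secant condition `H_new y = s`, which the example checks. Plain nested lists stand in for the dense N x N matrix, which is where the O(N^2) memory in the warning comes from; `mccormick_update` is a hypothetical helper.

```python
def mccormick_update(H, s, y):
    """Return H + (s - H y) s^T / (s^T y) as a new nested list."""
    n = len(s)
    Hy = [sum(H[i][j] * y[j] for j in range(n)) for i in range(n)]
    sy = sum(si * yi for si, yi in zip(s, y))  # curvature s^T y; must be nonzero
    r = [s[i] - Hy[i] for i in range(n)]  # residual of the secant condition
    return [[H[i][j] + r[i] * s[j] / sy for j in range(n)] for i in range(n)]


y = [3.0, 1.0]
H_new = mccormick_update([[1.0, 0.0], [0.0, 1.0]], s=[1.0, 2.0], y=y)
Hy_new = [sum(H_new[i][j] * y[j] for j in range(2)) for i in range(2)]  # should equal s
```

Unlike BFGS, this rank-one correction does not keep `H` symmetric, which is one reason a line search is recommended.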

MeZO

Bases: torchzero.modules.grad_approximation.grad_approximator.GradApproximator

Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.001 ) –

    finite difference step size. Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of random gradient samples. Defaults to 1.

  • formula (Literal, default: 'central2' ) –

    finite difference formula. Defaults to 'central2'.

  • distribution (Literal, default: 'rademacher' ) –

    distribution of the random perturbation directions. Defaults to "rademacher".

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to "closure".

References

Malladi, S., Gao, T., Nichani, E., Damian, A., Lee, J. D., Chen, D., & Arora, S. (2023). Fine-tuning language models with just forward passes. Advances in Neural Information Processing Systems, 36, 53038-53075. https://arxiv.org/abs/2305.17333

Source code in torchzero/modules/grad_approximation/rfdm.py
class MeZO(GradApproximator):
    """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): finite difference step size. Defaults to 1e-3.
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
        distribution (Distributions, optional): distribution of the random perturbation directions. Defaults to "rademacher".
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    References:
        Malladi, S., Gao, T., Nichani, E., Damian, A., Lee, J. D., Chen, D., & Arora, S. (2023). Fine-tuning language models with just forward passes. Advances in Neural Information Processing Systems, 36, 53038-53075. https://arxiv.org/abs/2305.17333
    """

    def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
                 distribution: Distributions = 'rademacher', return_approx_loss: bool = False, target: GradTarget = 'closure'):

        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
        super().__init__(defaults, return_approx_loss=return_approx_loss, target=target)

    def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
        prt = TensorList(params).sample_like(
            distribution=distribution,
            variance=h,
            generator=torch.Generator(params[0].device).manual_seed(seed)
        )
        return prt

    def pre_step(self, objective):
        h = NumberList(self.settings[p]['h'] for p in objective.params)

        n_samples = self.defaults['n_samples']
        distribution = self.defaults['distribution']

        step = objective.current_step

        # create functions that generate a deterministic perturbation from seed based on current step
        prt_fns = []
        for i in range(n_samples):

            prt_fn = partial(self._seeded_perturbation, params=objective.params, distribution=distribution, seed=1_000_000*step + i, h=h)
            prt_fns.append(prt_fn)

        self.global_state['prt_fns'] = prt_fns

    @torch.no_grad
    def approximate(self, closure, params, loss):
        params = TensorList(params)
        loss_approx = None

        h = NumberList(self.settings[p]['h'] for p in params)
        n_samples = self.defaults['n_samples']
        fd_fn = _RFD_FUNCS[self.defaults['formula']]

        prt_fns = self.global_state['prt_fns']

        grad = None
        for i in range(n_samples):
            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h, f_0=loss)
            if grad is None: grad = prt_fns[i]().mul_(d)
            else: grad += prt_fns[i]().mul_(d)

        assert grad is not None
        if n_samples > 1: grad.div_(n_samples)
        return grad, loss, loss_approx
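The memory saving in the source above comes from never storing perturbations: each one is regenerated from a seed (`1_000_000*step + i`). Below is a minimal standalone sketch of a MeZO-style central-difference (SPSA) gradient estimate built on the same seeding idea. Assumptions: a ±1 Rademacher direction (torchzero's default; the MeZO paper samples Gaussian directions), a perturbation of exactly `h * z` (torchzero passes `h` as the sampling variance, so its scaling may differ), and hypothetical helper names `rademacher` and `mezo_grad_estimate` that are not part of torchzero.

```python
import torch

def rademacher(shape, seed):
    # regenerate the same +-1 direction from a seed instead of storing it
    gen = torch.Generator().manual_seed(seed)
    return torch.randint(0, 2, tuple(shape), generator=gen).mul(2).sub(1).to(torch.float64)

def mezo_grad_estimate(f, x, h=1e-3, n_samples=1, step=0):
    """Central-difference (SPSA-style) gradient estimate of f at x (a float64 tensor)."""
    grad = torch.zeros_like(x)
    for i in range(n_samples):
        seed = 1_000_000 * step + i  # same seeding scheme as pre_step above
        z = rademacher(x.shape, seed)
        # directional derivative of f along z from two forward passes
        d = (f(x + h * z) - f(x - h * z)) / (2 * h)
        del z  # only the seed is kept between the forward passes and the update
        grad += rademacher(x.shape, seed) * d
    return grad / n_samples
```

For a quadratic such as `f(v) = (v ** 2).sum()` the central difference is exact, so a single sample returns `z * (2 * x·z)`, and averaging many samples approaches the true gradient `2 * x`.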

Mean

Bases: torchzero.modules.ops.reduce.Sum

Outputs a mean of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class Mean(Sum):
    """Outputs a mean of ``inputs`` that can be modules or numbers."""
    USE_MEAN = True

USE_MEAN class-attribute

USE_MEAN = True


MedianAveraging

Bases: torchzero.core.transform.TensorTransform

Median of past history_size updates.

Parameters:

  • history_size (int) –

    Number of past updates to take the median of

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/momentum/averaging.py
class MedianAveraging(TensorTransform):
    """Median of past ``history_size`` updates.

    Args:
        history_size (int): Number of past updates to take the median of
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, history_size: int,):
        defaults = dict(history_size = history_size)
        super().__init__(defaults=defaults)

        self.add_projected_keys("grad", "history")

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        history_size = setting['history_size']

        if 'history' not in state:
            state['history'] = deque(maxlen=history_size)

        history = state['history']
        history.append(tensor)

        stacked = torch.stack(tuple(history), 0)
        return torch.quantile(stacked, 0.5, dim = 0)
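The per-tensor logic above keeps a bounded `deque` of past updates and takes the elementwise median. A minimal standalone sketch of the same idea for a single tensor; the class name `MedianOfHistory` is hypothetical and only mirrors the source for illustration.

```python
from collections import deque

import torch

class MedianOfHistory:
    """Hypothetical standalone running median over the last `history_size` updates."""

    def __init__(self, history_size: int):
        # the deque drops the oldest update automatically once it is full
        self.history = deque(maxlen=history_size)

    def __call__(self, update: torch.Tensor) -> torch.Tensor:
        self.history.append(update)
        stacked = torch.stack(tuple(self.history), 0)
        # quantile(0.5) averages the two middle values when the count is even
        return torch.quantile(stacked, 0.5, dim=0)
```

A likely reason for using `torch.quantile(..., 0.5)` rather than `torch.median` is that `torch.median` returns the lower of the two middle values for an even-sized history, while the quantile interpolates between them (for a history of `[1, 3]` the sketch returns 2, whereas `torch.median` returns 1).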

Minimum

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs minimum(tensors, other(tensors))

Source code in torchzero/modules/ops/binary.py
class Minimum(BinaryOperationBase):
    """Outputs ``minimum(tensors, other(tensors))``"""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        torch._foreach_minimum_(update, other)
        return update

MinimumModules

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs elementwise minimum of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class MinimumModules(ReduceOperationBase):
    """Outputs elementwise minimum of ``inputs`` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        minimum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_minimum_(minimum, v)

        return minimum

Mul

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Multiply tensors by other. other can be a number or a module.

If other is a module, this calculates tensors * other(tensors)

Source code in torchzero/modules/ops/binary.py
class Mul(BinaryOperationBase):
    """Multiply tensors by ``other``. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``tensors * other(tensors)``
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        torch._foreach_mul_(update, other)
        return update

MulByLoss

Bases: torchzero.core.transform.TensorTransform

Multiplies the update by loss times alpha, clamped from below at min_value.

Source code in torchzero/modules/misc/misc.py
class MulByLoss(TensorTransform):
    """Multiplies update by loss times ``alpha``"""
    def __init__(self, alpha: float = 1, min_value:float = 1e-16, backward: bool = True):
        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
        super().__init__(defaults, uses_loss=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert loss is not None
        alpha, min_value = unpack_dicts(settings, 'alpha', 'min_value')
        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
        torch._foreach_mul_(tensors, mul)
        return tensors
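The scaling factor in `multi_tensor_apply` above is `max(loss * alpha, min_value)`. A minimal sketch of that arithmetic on a single tensor; the function name `mul_by_loss` is hypothetical.

```python
import torch

def mul_by_loss(update: torch.Tensor, loss: float, alpha: float = 1.0, min_value: float = 1e-16) -> torch.Tensor:
    # scale by loss * alpha, clamped from below so that a zero or negative loss
    # shrinks the update to almost nothing instead of zeroing or flipping it
    return update * max(loss * alpha, min_value)
```

For example, a loss of 0.5 halves the update, while a negative loss leaves only `min_value` times the update.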

MultiOperationBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for operations that use operands. This is an abstract class; to use it, subclass it and override the transform method.

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/multi.py
class MultiOperationBase(Module, ABC):
    """Base class for operations that use operands. This is an abstract class, subclass it and override `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = {}
        for k,v in operands.items():

            if isinstance(v, (Module, Sequence)):
                self.set_child(k, v)
                self.operands[k] = self.children[k]
            else:
                self.operands[k] = v

        if not self.children:
            raise ValueError('At least one operand must be a module')

    @abstractmethod
    def transform(self, objective: Objective, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    @torch.no_grad
    def step(self, objective: Objective) -> Objective:
        # pass cloned update to all module operands
        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()

        for k,v in self.operands.items():
            if k in self.children:
                v: Module
                updated_obj = v.step(objective.clone(clone_updates=True))
                processed_operands[k] = updated_obj.get_updates()
                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them

        transformed = self.transform(objective, **processed_operands)
        objective.updates = transformed
        return objective

transform

transform(objective: Objective, **operands: Any | list[Tensor]) -> list[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/multi.py
@abstractmethod
def transform(self, objective: Objective, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError

Multistep

Bases: torchzero.core.module.Module

On each step, performs steps inner steps with module.

The update is taken to be the parameter difference between parameters before and after the inner loop.

Source code in torchzero/modules/misc/multistep.py
class Multistep(Module):
    """Performs ``steps`` inner steps with ``module`` per each step.

    The update is taken to be the parameter difference between parameters before and after the inner loop."""
    def __init__(self, module: Chainable, steps: int):
        defaults = dict(steps=steps)
        super().__init__(defaults)
        self.set_child('module', module)

    @torch.no_grad
    def apply(self, objective):
        return _sequential_step(self, objective, sequential=False)
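The rule described above, where the update is the difference between the parameters before and after the inner loop, can be sketched with plain tensors. Assumptions: `grad_fn` and `inner_step` are hypothetical callables standing in for the inner module, and the update is returned as `before - after` on the assumption that updates are subtracted from the parameters.

```python
import torch

def multistep_update(params: torch.Tensor, grad_fn, inner_step, steps: int) -> torch.Tensor:
    """Run `steps` inner steps and return the net change as a single update.

    Hypothetical helper: `grad_fn(p)` returns the gradient at p and
    `inner_step(p, g)` returns the parameters after one inner step.
    """
    before = params.clone()
    p = params.clone()
    for _ in range(steps):
        p = inner_step(p, grad_fn(p))
    # assumption: the outer optimizer subtracts the update, so
    # params - update lands exactly on the inner-loop result
    return before - p
```

With `f(x) = x**2` and inner gradient descent `p - 0.25 * g` (which halves `p` each step), three inner steps from 8 reach 1, so the returned update is 7.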

MuonAdjustLR

Bases: torchzero.core.transform.Transform

LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master). Orthogonalize already has this built in via its adjust_lr setting; use this module instead if you want to apply the adjustment later in the chain.

Source code in torchzero/modules/adaptive/muon.py
class MuonAdjustLR(Transform):
    """LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).
    Orthogonalize already has this built in with the ``adjust_lr`` setting, however you might want to move this to be later in the chain."""
    def __init__(self, channel_first: bool = True, alpha: float = 1):
        defaults = dict(channel_first=channel_first, alpha=alpha)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        alphas = [s['alpha'] for s in settings]
        channel_first = [s['channel_first'] for s in settings]
        # only tensors that are at least 2D get the Muon LR adjustment
        tensors_alphas = [
            (t, adjust_lr_for_muon(a, t.shape, cf)) for t, a, cf in zip(tensors, alphas, channel_first) if _is_at_least_2d(t, channel_first=cf)
        ]
        if tensors_alphas:
            adjusted = [i[0] for i in tensors_alphas]
            a = [i[1] for i in tensors_alphas]
            # in-place multiply, so the change is visible through `tensors`
            torch._foreach_mul_(adjusted, a)
        # return every tensor, including the ones that were left unchanged
        return tensors
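For intuition about what the adjustment does, the Moonlight paper scales the Muon update of a matrix with shape (A, B) by 0.2·sqrt(max(A, B)) so its RMS roughly matches AdamW's. A minimal sketch of that rule, assuming it uses the first two dimensions and leaves 1D tensors unchanged; `moonlight_adjusted_lr` is a hypothetical name, and how torchzero's `adjust_lr_for_muon` handles `channel_first` and higher-dimensional tensors is not shown here.

```python
import math

def moonlight_adjusted_lr(lr: float, shape: tuple) -> float:
    # assumption: Moonlight rule lr * 0.2 * sqrt(max(A, B)) using the first two dims
    if len(shape) < 2:
        return lr  # assumption: 1D tensors (biases, norms) are not adjusted
    a, b = shape[0], shape[1]
    return lr * 0.2 * math.sqrt(max(a, b))
```

For a (100, 25) weight matrix the multiplier is 0.2 · sqrt(100) = 2.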

NAG

Bases: torchzero.core.transform.TensorTransform

Nesterov accelerated gradient method (nesterov momentum).

Parameters:

  • momentum (float, default: 0.9 ) –

    momentum (beta). Defaults to 0.9.

  • dampening (float, default: 0 ) –

    momentum dampening. Defaults to 0.

  • lerp (bool, default: False ) –

    whether to use linear interpolation. If True, this becomes similar to an exponential moving average. Defaults to False.

  • target (Target) –

    target to apply EMA to. Defaults to 'update'.

Source code in torchzero/modules/momentum/momentum.py
class NAG(TensorTransform):
    """Nesterov accelerated gradient method (nesterov momentum).

    Args:
        momentum (float, optional): momentum (beta). Defaults to 0.9.
        dampening (float, optional): momentum dampening. Defaults to 0.
        lerp (bool, optional):
            whether to use linear interpolation. If True, this becomes similar to an exponential moving average. Defaults to False.
        target (Target, optional): target to apply EMA to. Defaults to 'update'.
    """
    def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False):
        defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
        super().__init__(defaults, uses_grad=False)

        self.add_projected_keys("grad", "velocity")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
        lerp = self.settings[params[0]]['lerp']

        momentum,dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
        return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
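A minimal sketch of Nesterov momentum in the convention used by `torch.optim.SGD(nesterov=True)`: the velocity accumulates dampened gradients, and the returned update adds the fresh gradient to the momentum-scaled velocity. This is an assumption about the general technique; torchzero's `nag_` may differ in details, and the `lerp` option is not modeled. The class name `NesterovMomentum` is hypothetical.

```python
import torch

class NesterovMomentum:
    """Hypothetical standalone Nesterov momentum (torch.optim.SGD convention)."""

    def __init__(self, momentum: float = 0.9, dampening: float = 0.0):
        self.momentum = momentum
        self.dampening = dampening
        self.velocity = None

    def __call__(self, g: torch.Tensor) -> torch.Tensor:
        if self.velocity is None:
            self.velocity = torch.zeros_like(g)
        # v <- momentum * v + (1 - dampening) * g
        self.velocity.mul_(self.momentum).add_(g, alpha=1 - self.dampening)
        # look-ahead: combine the fresh gradient with the updated velocity
        return g + self.momentum * self.velocity
```

With `dampening=0`, subtracting these updates reproduces `torch.optim.SGD(lr=1, momentum=0.9, nesterov=True)`; note that SGD skips dampening on its very first step, so the two differ slightly when dampening is non-zero.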

NanToNum

Bases: torchzero.core.transform.TensorTransform

Convert nan, inf and -inf to numbers.

Parameters:

  • nan (optional, default: None ) –

    the value to replace NaNs with. Default is zero.

  • posinf (optional, default: None ) –

    if a Number, the value to replace positive infinity values with. If None, positive infinity values are replaced with the greatest finite value representable by input's dtype. Default is None.

  • neginf (optional, default: None ) –

    if a Number, the value to replace negative infinity values with. If None, negative infinity values are replaced with the lowest finite value representable by input's dtype. Default is None.

Source code in torchzero/modules/ops/unary.py
class NanToNum(TensorTransform):
    """Convert ``nan``, ``inf`` and ``-inf`` to numbers.

    Args:
        nan (optional): the value to replace NaNs with. Default is zero.
        posinf (optional): if a Number, the value to replace positive infinity values with.
            If None, positive infinity values are replaced with the greatest finite value
            representable by input's dtype. Default is None.
        neginf (optional): if a Number, the value to replace negative infinity values with.
            If None, negative infinity values are replaced with the lowest finite value
            representable by input's dtype. Default is None.
    """
    def __init__(self, nan=None, posinf=None, neginf=None):
        defaults = dict(nan=nan, posinf=posinf, neginf=neginf)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
        return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
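The module calls `Tensor.nan_to_num_` on each update tensor. The snippet below uses the out-of-place `Tensor.nan_to_num` on a float32 tensor to show the default replacements documented above and the effect of explicit values.

```python
import torch

x = torch.tensor([float("nan"), float("inf"), float("-inf"), 1.5])

# defaults: NaN -> 0, +inf -> largest finite float32, -inf -> lowest finite float32
y = x.nan_to_num()

# explicit replacement values, as with NanToNum(nan=..., posinf=..., neginf=...)
z = x.nan_to_num(nan=0.0, posinf=1e4, neginf=-1e4)
print(y, z)
```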

NaturalGradient

Bases: torchzero.core.transform.Transform

Natural gradient approximated via empirical fisher information matrix.

To use this, either pass a vector of per-sample losses to the step method, or make sure the closure returns it. Gradients are computed via batched autograd within this module, so you don't need to implement the backward pass. When using a closure, it must still accept a backward argument; the module always passes False, but the argument is required. See below for an example.

Note

The empirical Fisher information matrix can be a very poor approximation in some cases. If so, set sqrt to True to perform whitening instead, which is much more robust.

Parameters:

  • reg (float, default: 1e-08 ) –

    regularization parameter. Defaults to 1e-8.

  • sqrt (bool, default: False ) –

    if True, uses the square root of the empirical Fisher information matrix. Both the EFIM and its square root can be computed and stored efficiently without ndim^2 memory. The square root whitens the gradient and often performs much better, especially when you use NGD with a vector that isn't strictly per-sample losses but, for example, a set of different losses.

  • gn_grad (bool, default: False ) –

    if True, uses the Gauss-Newton gradient G^T @ f, which is the sum of per-sample gradients weighted by their values and is equivalent to squaring the values. This makes the kernel trick solver incorrect, but it still works in practice. If False, uses the sum of per-sample gradients. This affects the update when sqrt=True, and always affects the grad attribute. Defaults to False.

  • batched (bool, default: True ) –

    whether to use vmapping. Defaults to True.

Examples:

training a neural network:

import torch
import torch.nn as nn
import torchzero as tz

X = torch.randn(64, 20)
y = torch.randn(64, 10)

model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
opt = tz.Optimizer(
    model.parameters(),
    tz.m.NaturalGradient(),
    tz.m.LR(3e-2)
)

for i in range(100):
    y_hat = model(X) # (64, 10)
    losses = (y_hat - y).pow(2).mean(0) # (10, )
    opt.step(loss=losses)
    if i % 10 == 0:
        print(f'{losses.mean() = }')

training a neural network - closure version

import torch
import torch.nn as nn
import torchzero as tz

X = torch.randn(64, 20)
y = torch.randn(64, 10)

model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
opt = tz.Optimizer(
    model.parameters(),
    tz.m.NaturalGradient(),
    tz.m.LR(3e-2)
)

def closure(backward=True):
    y_hat = model(X) # (64, 10)
    return (y_hat - y).pow(2).mean(0) # (10, )

for i in range(100):
    losses = opt.step(closure)
    if i % 10 == 0:
        print(f'{losses.mean() = }')

minimizing the rosenbrock function with a mix of natural gradient, whitening and gauss-newton:

import torch
import torchzero as tz

def rosenbrock(X):
    x1, x2 = X
    return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])

X = torch.tensor([-1.1, 2.5], requires_grad=True)
opt = tz.Optimizer([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))

for iter in range(200):
    losses = rosenbrock(X)
    opt.step(loss=losses)
    if iter % 20 == 0:
        print(f'{losses.mean() = }')

Source code in torchzero/modules/adaptive/natural_gradient.py
class NaturalGradient(Transform):
    """Natural gradient approximated via empirical fisher information matrix.

    To use this, either pass vector of per-sample losses to the step method, or make sure
    the closure returns it. Gradients will be calculated via batched autograd within this module,
    you don't need to implement the backward pass. When using closure, please add the ``backward`` argument,
    it will always be False but it is required. See below for an example.

    Note:
        Empirical fisher information matrix may give a really bad approximation in some cases.
        If that is the case, set ``sqrt`` to True to perform whitening instead, which is way more robust.

    Args:
        reg (float, optional): regularization parameter. Defaults to 1e-8.
        sqrt (bool, optional):
            if True, uses square root of empirical fisher information matrix. Both EFIM and its square
            root can be calculated and stored efficiently without ndim^2 memory. Square root
            whitens the gradient and often performs much better, especially when you try to use NGD
            with a vector that isn't strictly per-sample gradients, but rather for example different losses.
        gn_grad (bool, optional):
            if True, uses Gauss-Newton G^T @ f as the gradient, which is effectively sum weighted by value
            and is equivalent to squaring the values. That makes the kernel trick solver incorrect, but for
            some reason it still works. If False, uses sum of per-sample gradients.
            This affects the update when ``sqrt=True``, and always affects the ``grad`` attribute.
            Defaults to False.
        batched (bool, optional): whether to use vmapping. Defaults to True.

    Examples:

    training a neural network:
    ```python
    X = torch.randn(64, 20)
    y = torch.randn(64, 10)

    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NaturalGradient(),
        tz.m.LR(3e-2)
    )

    for i in range(100):
        y_hat = model(X) # (64, 10)
        losses = (y_hat - y).pow(2).mean(0) # (10, )
        opt.step(loss=losses)
        if i % 10 == 0:
            print(f'{losses.mean() = }')
    ```

    training a neural network - closure version
    ```python
    X = torch.randn(64, 20)
    y = torch.randn(64, 10)

    model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NaturalGradient(),
        tz.m.LR(3e-2)
    )

    def closure(backward=True):
        y_hat = model(X) # (64, 10)
        return (y_hat - y).pow(2).mean(0) # (10, )

    for i in range(100):
        losses = opt.step(closure)
        if i % 10 == 0:
            print(f'{losses.mean() = }')
    ```

    minimizing the rosenbrock function with a mix of natural gradient, whitening and gauss-newton:
    ```python
    def rosenbrock(X):
        x1, x2 = X
        return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])

    X = torch.tensor([-1.1, 2.5], requires_grad=True)
    opt = tz.Optimizer([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))

    for iter in range(200):
        losses = rosenbrock(X)
        opt.step(loss=losses)
        if iter % 20 == 0:
            print(f'{losses.mean() = }')
    ```
    """
    def __init__(self, reg:float = 1e-8, sqrt:bool=False, gn_grad:bool=False, batched:bool=True, ):
        super().__init__(defaults=dict(batched=batched, reg=reg, sqrt=sqrt, gn_grad=gn_grad))

    @torch.no_grad
    def update_states(self, objective, states, settings):
        params = objective.params
        closure = objective.closure
        fs = settings[0]
        batched = fs['batched']
        gn_grad = fs['gn_grad']

        # compute per-sample losses
        f = objective.loss
        if f is None:
            assert closure is not None
            with torch.enable_grad():
                f = objective.get_loss(backward=False) # n_out
                assert isinstance(f, torch.Tensor)

        # compute per-sample gradients
        with torch.enable_grad():
            G_list = jacobian_wrt([f.ravel()], params, batched=batched)

        # set scalar loss and its grad to objective
        objective.loss = f.sum()
        G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)

        if gn_grad:
            g = self.global_state["g"] = G.H @ f.detach()

        else:
            g = self.global_state["g"] = G.sum(0)

        objective.grads = vec_to_tensors(g, params)

        # set closure to calculate scalar value for line searches etc
        if closure is not None:

            def ngd_closure(backward=True):

                if backward:
                    objective.zero_grad()
                    with torch.enable_grad():
                        loss = closure(False)
                        if gn_grad: loss = loss.pow(2)
                        loss = loss.sum()
                        loss.backward()
                    return loss

                loss = closure(False)
                if gn_grad: loss = loss.pow(2)
                return loss.sum()

            objective.closure = ngd_closure

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        params = objective.params
        fs = settings[0]
        reg = fs['reg']
        sqrt = fs['sqrt']

        G: torch.Tensor = self.global_state['G'] # (n_samples, n_dim)

        if sqrt:
            # this computes U, S <- SVD(M), then calculate update as U S^-1 Uᵀg,
            # but it computes it through eigendecomposition
            L, U = ggt_update(G.H, damping=reg, rdamping=1e-16, truncate=0, eig_tol=1e-12)

            if U is None or L is None:

                # fallback to element-wise
                g = self.global_state["g"]
                g /= G.square().mean(0).sqrt().add(reg)
                objective.updates = vec_to_tensors(g, params)
                return objective

            # whiten
            z = U.T @ self.global_state["g"]
            v = (U * L.rsqrt()) @ z
            objective.updates = vec_to_tensors(v, params)
            return objective

        # we need (G^T G)v = g
        # where g = G^T 1 (sum of per-sample gradients)
        # so we solve (G G^T)z = 1 and set v = G^T z, which satisfies (G^T G)v = G^T 1
        GGt = G @ G.H # (n_samples, n_samples)

        if reg != 0:
            GGt.add_(torch.eye(GGt.size(0), device=GGt.device, dtype=GGt.dtype).mul_(reg))

        z, _ = torch.linalg.solve_ex(GGt, torch.ones_like(GGt[0])) # pylint:disable=not-callable
        v = G.H @ z

        objective.updates = vec_to_tensors(v, params)
        return objective


    def get_H(self, objective=...):
        if "G" not in self.global_state: return linear_operator.ScaledIdentity()
        G = self.global_state['G']
        return linear_operator.AtA(G)
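When sqrt=False, apply_states solves a small n_samples × n_samples system instead of forming the ndim × ndim empirical Fisher matrix. The two routes agree because of the identity Gᵀ(GGᵀ + λI)⁻¹ = (GᵀG + λI)⁻¹Gᵀ. A minimal numerical check of that identity with random float64 data (the sizes and `reg` value are arbitrary choices for illustration):

```python
import torch

torch.manual_seed(0)
n_samples, ndim, reg = 5, 50, 1e-3
G = torch.randn(n_samples, ndim, dtype=torch.float64)  # rows are per-sample gradients
g = G.sum(0)  # sum of per-sample gradients, i.e. G^T 1

# kernel trick: solve the small (n_samples x n_samples) system, then map back
GGt = G @ G.T + reg * torch.eye(n_samples, dtype=torch.float64)
z = torch.linalg.solve(GGt, torch.ones(n_samples, dtype=torch.float64))
v_kernel = G.T @ z

# direct route: form the full (ndim x ndim) regularized empirical Fisher
F = G.T @ G + reg * torch.eye(ndim, dtype=torch.float64)
v_direct = torch.linalg.solve(F, g)
```

The kernel route costs O(n_samples² · ndim) instead of O(ndim³), which is why it pays off when there are far fewer samples than parameters.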

Negate

Bases: torchzero.core.transform.TensorTransform

Returns - input

Source code in torchzero/modules/ops/unary.py
class Negate(TensorTransform):
    """Returns ``- input``"""
    def __init__(self): super().__init__()
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_neg_(tensors)
        return tensors

NegateOnLossIncrease

Bases: torchzero.core.module.Module

Uses an extra forward pass to evaluate the loss after applying the update. If that loss is larger than the loss at the current parameters, the update is set to 0 if backtrack=False, or negated if backtrack=True.

Source code in torchzero/modules/misc/multistep.py
class NegateOnLossIncrease(Module):
    """Uses an extra forward pass to evaluate loss at ``parameters+update``,
    if loss is larger than at ``parameters``,
    the update is set to 0 if ``backtrack=False`` and to ``-update`` otherwise"""
    def __init__(self, backtrack=False):
        defaults = dict(backtrack=backtrack)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def apply(self, objective):
        closure = objective.closure
        if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
        backtrack = self.defaults['backtrack']

        update = objective.get_updates()
        f_0 = objective.get_loss(backward=False)

        torch._foreach_sub_(objective.params, update)
        f_1 = closure(False)

        if f_1 <= f_0:
            # if var.is_last and var.last_module_lrs is None:
            #     var.stop = True
            #     var.skip_update = True
            #     return var

            torch._foreach_add_(objective.params, update)
            return objective

        torch._foreach_add_(objective.params, update)
        if backtrack:
            torch._foreach_neg_(objective.updates)
        else:
            torch._foreach_zero_(objective.updates)
        return objective
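The decision rule in apply can be sketched without the objective machinery. Assumptions: `loss_fn` is a hypothetical plain function of the parameters, and, as in the source above, the trial point is `params - update` because updates are subtracted from the parameters. The name `negate_on_loss_increase` is hypothetical.

```python
import torch

def negate_on_loss_increase(params, update, loss_fn, backtrack=False):
    """Keep the update only if applying it does not increase the loss."""
    f_0 = loss_fn(params)
    f_1 = loss_fn(params - update)  # the extra forward pass at the trial point
    if f_1 <= f_0:
        return update
    # loss increased: step the other way (backtrack=True) or skip the step
    return -update if backtrack else torch.zeros_like(update)
```

With `loss_fn = lambda p: (p ** 2).sum()` at `p = 1`, an update of 0.5 is kept, while an update of -1 (which would move to 2) is zeroed, or flipped to 1 with `backtrack=True`.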

NewDQN

Bases: torchzero.modules.quasi_newton.diagonal_quasi_newton.DNRTR

Diagonal quasi-newton method.

Reference

Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.

Source code in torchzero/modules/quasi_newton/diagonal_quasi_newton.py
class NewDQN(DNRTR):
    """Diagonal quasi-newton method.

    Reference:
        Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return new_dqn_B_(B=B, s=s, y=y)

NewSSM

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Self-scaling Quasi-Newton method.

Note

A line search such as tz.m.StrongWolfe() is required.

Warning

This uses roughly O(N^2) memory.

Reference

Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class NewSSM(HessianUpdateStrategy):
    """Self-scaling Quasi-Newton method.

    Note:
        a line search such as ``tz.m.StrongWolfe()`` is required.

    Warning:
        this uses roughly O(N^2) memory.

    Reference:
        Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
    """
    def __init__(
        self,
        type: Literal[1, 2] = 1,
        init_scale: float | Literal["auto"] = "auto",
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None = None,
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        super().__init__(
            defaults=dict(type=type),
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=True,
            uses_loss=True,
            inner=inner,
        )
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        f = state['f']
        f_prev = state['f_prev']
        return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=setting['type'], tol=setting['tol'])

Newton

Bases: torchzero.core.transform.Transform

Exact Newton's method via autograd.

Newton's method produces a direction that jumps to the stationary point of the quadratic approximation of the target function. The update rule is given by (H + yI)⁻¹g, where H is the hessian, g is the gradient, and y is the damping parameter.

g can be the output of another module if it is specified in the inner argument.

Note

In most cases Newton should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

Note

This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate the hessian. The closure must accept a backward argument (refer to documentation).

Parameters:

  • damping (float, default: 0 ) –

    tikhonov regularizer value. Defaults to 0.

  • eigval_fn (Callable | None, default: None ) –

    function to apply to eigenvalues, for example torch.abs or lambda L: torch.clip(L, min=1e-8). If this is specified, eigendecomposition will be used to invert the hessian.

  • update_freq (int, default: 1 ) –

    updates hessian every update_freq steps.

  • precompute_inverse (bool, default: None ) –

    if True, whenever hessian is computed, also computes the inverse. This is more efficient when update_freq is large. If None, this is True if update_freq >= 10.

  • use_lstsq ((bool, Optional), default: False ) –

    if True, least squares is used to solve the linear system, which can prevent the solution from exploding when the hessian is indefinite. If False, tries cholesky; if that fails, tries LU, and then least squares. If eigval_fn is specified, eigendecomposition is always used and this argument is ignored.

  • hessian_method (str, default: 'batched_autograd' ) –

    Determines how hessian is computed.

    • "batched_autograd" - uses autograd to compute ndim batched hessian-vector products. Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd to compute ndim hessian-vector products using for loop. Slower than "batched_autograd" but uses less memory.
    • "functional_revrev" - uses torch.autograd.functional with "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to "autograd".
    • "functional_fwdrev" - uses torch.autograd.functional with vectorized "forward-over-reverse" strategy. Faster than "functional_fwdrev" but uses more memory ("batched_autograd" seems to be faster)
    • "func" - uses torch.func.hessian which uses "forward-over-reverse" strategy. This method is the fastest and is recommended, however it is more restrictive and fails with some operators which is why it isn't the default.
    • "gfd_forward" - computes ndim hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "gfd_central" - computes ndim hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
    • "fd" - uses function values to estimate gradient and hessian via finite difference. This uses fewer evaluations than chaining "gfd_*" after tz.m.FDM.
    • "thoad" - uses the thoad library, which can be significantly faster than pytorch but has limited operator coverage.

    Defaults to "batched_autograd".

  • h (float, default: 0.001 ) –

    finite difference step size if the hessian is computed via finite differences.

  • inner (Chainable | None, default: None ) –

    modules to apply hessian preconditioner to. Defaults to None.
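The "gfd_forward" and "gfd_central" methods assemble the hessian from ndim hessian-vector products, one per unit vector, each obtained by differencing gradients with step size h. Below is a minimal NumPy sketch of that idea. It assumes an analytic gradient function grad_fn. The helper fd_hessian is hypothetical and is not torchzero's implementation, which works on torch tensors and may differ in details such as symmetrization.

```python
import numpy as np


def fd_hessian(grad_fn, x, h=1e-3, central=True):
    """Hypothetical helper: dense Hessian from gradient finite differences.

    Column i is a Hessian-vector product with the unit vector e_i.
    """
    n = x.size
    H = np.empty((n, n))
    g0 = None if central else grad_fn(x)
    for i in range(n):
        e = np.zeros(n)
        e[i] = 1.0
        if central:
            # central formula: two gradient evaluations per column
            H[:, i] = (grad_fn(x + h * e) - grad_fn(x - h * e)) / (2 * h)
        else:
            # forward formula: one extra gradient evaluation per column
            H[:, i] = (grad_fn(x + h * e) - g0) / h
    # symmetrize, since finite differences need not give an exactly symmetric matrix
    return 0.5 * (H + H.T)


# example: f(x) = x0**2 * x1 + 3 * x1**2, with its analytic gradient
grad_fn = lambda x: np.array([2 * x[0] * x[1], x[0] ** 2 + 6 * x[1]])
H = fd_hessian(grad_fn, np.array([1.0, 2.0]))
```

When the gradient is quadratic in x, as in this example, the central formula is exact up to rounding. The forward formula carries an O(h) error.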

See also

  • tz.m.NewtonCG: uses a matrix-free conjugate gradient solver and hessian-vector products. useful for large scale problems as it doesn't form the full hessian.
  • tz.m.NewtonCGSteihaug: trust region version of tz.m.NewtonCG.
  • tz.m.ImprovedNewton: Newton with additional rank one correction to the hessian, can be faster than Newton.
  • tz.m.InverseFreeNewton: an inverse-free variant of Newton's method.
  • tz.m.quasi_newton: large collection of quasi-newton methods that estimate the hessian.

Notes

Implementation details

(H + yI)⁻¹g is calculated by solving the linear system (H + yI)x = g. The system is solved via Cholesky decomposition. If that fails, LU decomposition is used, and if that also fails, least squares. Least squares can be forced by setting use_lstsq=True.
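The Cholesky, then LU, then least-squares fallback can be sketched with NumPy as follows. The name damped_newton_direction is hypothetical. torch.linalg offers analogous cholesky, solve and lstsq functions, but this is an illustration, not torchzero's code.

```python
import numpy as np


def damped_newton_direction(H, g, damping=0.0, use_lstsq=False):
    """Hypothetical sketch: solve (H + damping * I) x = g with fallbacks."""
    A = H + damping * np.eye(H.shape[0])
    if not use_lstsq:
        try:
            # Cholesky only succeeds for symmetric positive-definite A
            L = np.linalg.cholesky(A)
            y = np.linalg.solve(L, g)       # solve L y = g
            return np.linalg.solve(L.T, y)  # solve L^T x = y
        except np.linalg.LinAlgError:
            pass
        try:
            # np.linalg.solve uses an LU factorization and raises on singular A
            return np.linalg.solve(A, g)
        except np.linalg.LinAlgError:
            pass
    # least squares returns a (minimum-norm) solution even for singular A
    return np.linalg.lstsq(A, g, rcond=None)[0]
```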

Additionally, if eigval_fn is specified, the eigendecomposition of the hessian is computed, eigval_fn is applied to the eigenvalues, and (H + yI)⁻¹ is computed from the eigenvectors and the transformed eigenvalues. This is generally more computationally expensive, but not by much.
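Here is a minimal sketch of the eigendecomposition route, assuming a symmetric hessian. This sketch adds the damping before decomposing and then applies eigval_fn to the eigenvalues of H + yI. Whether torchzero applies eigval_fn before or after adding the damping is an assumption here. eig_newton_direction is a hypothetical name.

```python
import numpy as np


def eig_newton_direction(H, g, damping=0.0, eigval_fn=None):
    """Hypothetical sketch: (H + damping * I)^-1 g via eigendecomposition."""
    A = H + damping * np.eye(H.shape[0])
    L, Q = np.linalg.eigh(A)  # A = Q diag(L) Q^T for symmetric A
    if eigval_fn is not None:
        L = eigval_fn(L)      # e.g. np.abs or lambda L: np.clip(L, 1e-8, None)
    # A^-1 g = Q diag(1 / L) Q^T g, using the transformed eigenvalues
    return Q @ ((Q.T @ g) / L)
```

With eigval_fn=None this gives the same direction as a direct solve, so the extra cost is only the eigendecomposition itself.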

Handling non-convexity

Standard Newton's method does not handle non-convexity well without some modifications. This is because it jumps to the stationary point of the quadratic approximation, which may be a maximum or a saddle point rather than a minimum.

One way to handle non-convexity is to modify the eigenvalues to be positive, for example by setting eigval_fn = lambda L: L.abs().clip(min=1e-4).
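The toy quadratic below has the indefinite hessian diag(1, -2). It is an illustrative assumption, not taken from torchzero. Starting from x = (1, 1), the plain Newton step lands on the saddle point at the origin and increases f. The step built from L.abs().clip(min=1e-4) decreases f.

```python
import numpy as np

# toy quadratic f(x) = 0.5 * x^T H x with an indefinite Hessian (illustrative only)
H = np.diag([1.0, -2.0])
f = lambda x: 0.5 * x @ H @ x

x = np.array([1.0, 1.0])
g = H @ x                                # gradient of f at x

L, Q = np.linalg.eigh(H)
plain = Q @ ((Q.T @ g) / L)              # standard Newton direction H^-1 g
L_mod = np.clip(np.abs(L), 1e-4, None)   # NumPy analogue of L.abs().clip(min=1e-4)
modified = Q @ ((Q.T @ g) / L_mod)

# this sketch steps to x - direction, i.e. the direction is subtracted like a gradient
f_plain, f_modified = f(x - plain), f(x - modified)
```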

Examples:

Newton's method with backtracking line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Newton(),
    tz.m.Backtracking()
)

Newton's method for non-convex optimization.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Newton(eigval_fn = lambda L: L.abs().clip(min=1e-4)),
    tz.m.Backtracking()
)

Newton preconditioning applied to momentum

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Newton(inner=tz.m.EMA(0.9)),
    tz.m.LR(0.1)
)
Source code in torchzero/modules/second_order/newton.py
class Newton(Transform):
    """Exact Newton's method via autograd.

    Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
    The update rule is given by ``(H + yI)⁻¹g``, where ``H`` is the hessian, ``g`` is the gradient, and ``y`` is the ``damping`` parameter.

    ``g`` can be the output of another module if it is specified in the ``inner`` argument.

    Note:
        In most cases Newton should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

    Note:
        This module requires a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients for calculating the hessian.
        The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        damping (float, optional): tikhonov regularizer value. Defaults to 0.
        eigval_fn (Callable | None, optional):
            function to apply to eigenvalues, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
            If this is specified, eigendecomposition will be used to invert the hessian.
        update_freq (int, optional):
            updates hessian every ``update_freq`` steps.
        precompute_inverse (bool | None, optional):
            if ``True``, whenever hessian is computed, also computes the inverse. This is more efficient
            when ``update_freq`` is large. If ``None``, this is ``True`` if ``update_freq >= 10``.
        use_lstsq (bool, optional):
            if True, least squares is used to solve the linear system, which can prevent the step from exploding
            when the hessian is indefinite. If False, Cholesky is tried first, then LU, and finally least squares.
            If ``eigval_fn`` is specified, eigendecomposition is always used and this argument is ignored.
        hessian_method (str):
            Determines how hessian is computed.

            - ``"batched_autograd"`` - uses autograd to compute ``ndim`` batched hessian-vector products. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd to compute ``ndim`` hessian-vector products using for loop. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"functional_revrev"`` - uses ``torch.autograd.functional`` with "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to ``"autograd"``.
            - ``"functional_fwdrev"`` - uses ``torch.autograd.functional`` with vectorized "forward-over-reverse" strategy. Faster than ``"functional_revrev"`` but uses more memory (``"batched_autograd"`` seems to be faster).
            - ``"func"`` - uses ``torch.func.hessian`` which uses "forward-over-reverse" strategy. This method is the fastest and is recommended, however it is more restrictive and fails with some operators which is why it isn't the default.
            - ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
            - ``"fd"`` - uses function values to estimate gradient and hessian via finite difference. This uses fewer evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
            - ``"thoad"`` - uses the ``thoad`` library, which can be significantly faster than pytorch but has limited operator coverage.

            Defaults to ``"batched_autograd"``.
        h (float, optional):
            finite difference step size if the hessian is computed via finite differences.
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.

    # See also

    * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products.
    useful for large scale problems as it doesn't form the full hessian.
    * ``tz.m.NewtonCGSteihaug``: trust region version of ``tz.m.NewtonCG``.
    * ``tz.m.ImprovedNewton``: Newton with additional rank one correction to the hessian, can be faster than Newton.
    * ``tz.m.InverseFreeNewton``: an inverse-free variant of Newton's method.
    * ``tz.m.quasi_newton``: large collection of quasi-newton methods that estimate the hessian.

    # Notes

    ## Implementation details

    ``(H + yI)⁻¹g`` is calculated by solving the linear system ``(H + yI)x = g``.
    The system is solved via Cholesky decomposition. If that fails, LU decomposition is used, and if that also fails, least squares. Least squares can be forced by setting ``use_lstsq=True``.

    Additionally, if ``eigval_fn`` is specified, the eigendecomposition of the hessian is computed,
    ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed from the eigenvectors and the transformed eigenvalues. This is generally more computationally expensive, but not by much.

    ## Handling non-convexity

    Standard Newton's method does not handle non-convexity well without some modifications.
    This is because it jumps to the stationary point of the quadratic approximation, which may be a maximum or a saddle point rather than a minimum.

    One way to handle non-convexity is to modify the eigenvalues to be positive,
    for example by setting ``eigval_fn = lambda L: L.abs().clip(min=1e-4)``.

    # Examples:

    Newton's method with backtracking line search

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Newton(),
        tz.m.Backtracking()
    )
    ```

    Newton's method for non-convex optimization.

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Newton(eigval_fn = lambda L: L.abs().clip(min=1e-4)),
        tz.m.Backtracking()
    )
    ```

    Newton preconditioning applied to momentum

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Newton(inner=tz.m.EMA(0.9)),
        tz.m.LR(0.1)
    )
    ```

    """
    def __init__(
        self,
        damping: float = 0,
        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
        eigv_tol: float | None = None,
        truncate: int | None = None,
        update_freq: int = 1,
        precompute_inverse: bool | None = None,
        use_lstsq: bool = False,
        hessian_method: HessianMethod = "batched_autograd",
        h: float = 1e-3,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['update_freq'], defaults["inner"]
        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        precompute_inverse = fs["precompute_inverse"]
        if precompute_inverse is None:
            precompute_inverse = fs["__update_freq"] >= 10

        __, _, H = objective.hessian(hessian_method=fs["hessian_method"], h=fs["h"], at_x0=True)

        _newton_update_state_(
            state = self.global_state,
            H=H,
            damping = fs["damping"],
            eigval_fn = fs["eigval_fn"],
            eigv_tol = fs["eigv_tol"],
            truncate = fs["truncate"],
            precompute_inverse = precompute_inverse,
            use_lstsq = fs["use_lstsq"]
        )

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()
        fs = settings[0]

        b = torch.cat([t.ravel() for t in updates])
        sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])

        vec_to_tensors_(sol, updates)
        return objective

    def get_H(self,objective=...):
        return _newton_get_H(self.global_state)

NewtonCG

Bases: torchzero.core.transform.Transform

Newton's method with a matrix-free conjugate gradient (CG) or minimum-residual (MINRES) solver.

Notes
  • In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

  • This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

Warning

CG may fail if hessian is not positive-definite.

Parameters:

  • maxiter (int | None, default: None ) –

    Maximum number of iterations for the conjugate gradient solver. By default, this is set to the number of dimensions in the objective function, which is the theoretical upper bound for CG convergence. Setting this to a smaller value (truncated Newton) can still generate good search directions. Defaults to None.

  • tol (float, default: 1e-08 ) –

    Relative tolerance for the conjugate gradient solver to determine convergence. Defaults to 1e-8.

  • reg (float, default: 1e-08 ) –

    Regularization parameter (damping) added to the Hessian diagonal. This helps ensure the system is positive-definite. Defaults to 1e-8.

  • hvp_method (str, default: 'autograd' ) –

    Determines how Hessian-vector products are evaluated.

    • "autograd" - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    For NewtonCG "batched_autograd" is equivalent to "autograd". Defaults to "autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • warm_start (bool, default: False ) –

    If True, the conjugate gradient solver is initialized with the solution from the previous optimization step. This can accelerate convergence, especially in truncated Newton methods. Defaults to False.

  • inner (Chainable | None, default: None ) –

    NewtonCG will attempt to apply preconditioning to the output of this module.
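The solver at the core of NewtonCG only needs hessian-vector products, so the hessian is never formed. The sketch below is a plain truncated CG in NumPy, with A_mv standing in for the hessian-vector product function. cg_solve is hypothetical. torchzero's cg also takes options such as miniter, npc_terminate and a trust radius, which are omitted here.

```python
import numpy as np


def cg_solve(A_mv, b, x0=None, tol=1e-8, maxiter=None, reg=0.0):
    """Hypothetical sketch of truncated CG for (A + reg * I) x = b.

    A is only accessed through A_mv(v), e.g. a Hessian-vector product.
    """
    x = np.zeros_like(b) if x0 is None else x0.copy()
    mv = lambda v: A_mv(v) + reg * v  # damping applied matrix-free
    r = b - mv(x)
    p = r.copy()
    rr = r @ r
    b_norm = np.linalg.norm(b)
    if maxiter is None:
        maxiter = b.size  # in exact arithmetic CG converges in at most n steps
    for _ in range(maxiter):
        if np.sqrt(rr) <= tol * b_norm:  # relative residual tolerance
            break
        Ap = mv(p)
        curvature = p @ Ap
        if curvature <= 0:  # A is not positive-definite along p; plain CG breaks down
            break
        alpha = rr / curvature
        x = x + alpha * p
        r = r - alpha * Ap
        rr_new = r @ r
        p = r + (rr_new / rr) * p
        rr = rr_new
    return x
```

The warm_start option follows the same idea as passing the previous step's solution as x0. When the hessian changes little between steps, CG then starts near the answer.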

Examples:

Newton-CG with a backtracking line search:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.NewtonCG(),
    tz.m.Backtracking()
)

Truncated Newton method (useful for large-scale problems):

opt = tz.Optimizer(
    model.parameters(),
    tz.m.NewtonCG(maxiter=10),
    tz.m.Backtracking()
)

Source code in torchzero/modules/second_order/newton_cg.py
class NewtonCG(Transform):
    """Newton's method with a matrix-free conjugate gradient (CG) or minimum-residual (MINRES) solver.

    Notes:
        * In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

        * This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Warning:
        CG may fail if hessian is not positive-definite.

    Args:
        maxiter (int | None, optional):
            Maximum number of iterations for the conjugate gradient solver.
            By default, this is set to the number of dimensions in the
            objective function, which is the theoretical upper bound for CG
            convergence. Setting this to a smaller value (truncated Newton)
            can still generate good search directions. Defaults to None.
        tol (float, optional):
            Relative tolerance for the conjugate gradient solver to determine
            convergence. Defaults to 1e-8.
        reg (float, optional):
            Regularization parameter (damping) added to the Hessian diagonal.
            This helps ensure the system is positive-definite. Defaults to 1e-8.
        hvp_method (str, optional):
            Determines how Hessian-vector products are evaluated.

            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            For NewtonCG ``"batched_autograd"`` is equivalent to ``"autograd"``. Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        warm_start (bool, optional):
            If ``True``, the conjugate gradient solver is initialized with the
            solution from the previous optimization step. This can accelerate
            convergence, especially in truncated Newton methods.
            Defaults to False.
        inner (Chainable | None, optional):
            NewtonCG will attempt to apply preconditioning to the output of this module.

    Examples:
    Newton-CG with a backtracking line search:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NewtonCG(),
        tz.m.Backtracking()
    )
    ```

    Truncated Newton method (useful for large-scale problems):
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NewtonCG(maxiter=10),
        tz.m.Backtracking()
    )
    ```

    """
    def __init__(
        self,
        maxiter: int | None = None,
        tol: float = 1e-8,
        reg: float = 1e-8,
        hvp_method: HVPMethod = "autograd",
        solver: Literal['cg', 'minres'] = 'cg',
        npc_terminate: bool = False,
        h: float = 1e-3, # tuned 1e-4 or 1e-3
        miniter:int = 1,
        warm_start=False,
        warm_beta:float=0,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner']
        super().__init__(defaults, inner=inner)

        self._num_hvps = 0
        self._num_hvps_last_step = 0

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]
        hvp_method = fs['hvp_method']
        h = fs['h']

        # ---------------------- Hessian vector product function --------------------- #
        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
        objective.temp = H_mv

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        self._num_hvps_last_step = 0
        H_mv = objective.poptemp()

        fs = settings[0]
        tol = fs['tol']
        reg = fs['reg']
        maxiter = fs['maxiter']
        solver = fs['solver'].lower().strip()
        warm_start = fs['warm_start']
        npc_terminate = fs["npc_terminate"]

        # ---------------------------------- run cg ---------------------------------- #
        x0 = None
        if warm_start:
            x0 = unpack_states(states, objective.params, 'prev_x', cls=TensorList)

        b = TensorList(objective.get_updates())

        if solver == 'cg':
            d, _ = cg(A_mv=H_mv, b=b, x0=x0, tol=tol, maxiter=maxiter,
                      miniter=fs["miniter"], reg=reg, npc_terminate=npc_terminate)

        elif solver == 'minres':
            d = minres(A_mv=H_mv, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)

        else:
            raise ValueError(f"Unknown solver {solver}")

        if warm_start:
            assert x0 is not None
            x0.lerp_(d, weight = 1-fs["warm_beta"])

        objective.updates = d
        self._num_hvps += self._num_hvps_last_step
        return objective

NewtonCGSteihaug

Bases: torchzero.core.transform.Transform

Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.

Notes
  • In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

  • This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

Parameters:

  • eta (float, default: 0.0 ) –

    if the ratio of actual to predicted reduction is larger than this, the step is accepted. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.0001 ) –

    if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • max_attempts (int, default: 100 ) –

    maximum number of trust radius reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 100.

  • max_history (int, default: 100 ) –

    CG will store this many intermediate solutions, reusing them when the trust radius is reduced instead of re-running CG. Each stored solution requires 2N memory. Defaults to 100.

  • boundary_tol (float | None, default: 1e-06 ) –

    The trust region only increases when the suggested step's norm is at least (1-boundary_tol)*trust_region. This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.

  • maxiter (int | None, default: None ) –

    maximum number of CG iterations per step. Each iteration requires one backward pass if hvp_method="fd_forward", two otherwise. Defaults to None.

  • miniter (int, default: 1 ) –

    minimal number of CG iterations. This prevents making no progress when the initial guess is already below tolerance. Defaults to 1.

  • tol (float, default: 1e-08 ) –

    terminates CG when the norm of the residual is less than this value. Defaults to 1e-8.

  • reg (float, default: 1e-08 ) –

    hessian regularization. Defaults to 1e-8.

  • solver (str, default: 'cg' ) –

    solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.

  • adapt_tol (bool, default: False ) –

    if True, whenever the trust radius collapses to the smallest representable number, the tolerance is multiplied by 0.1. Defaults to False.

  • npc_terminate (bool, default: False ) –

    whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.

  • hvp_method (str, default: 'fd_central' ) –

    either "fd_forward" to use the forward formula, which requires one backward pass per hessian-vector product, or "fd_central" to use the more accurate central formula, which requires two backward passes. "fd_forward" is usually accurate enough. Defaults to "fd_central".

  • h (float, default: 0.001 ) –

    finite difference step size. Defaults to 1e-3.

  • inner (Chainable | None, default: None ) –

    applies preconditioning to output of this module. Defaults to None.
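The sketch below shows one plausible way that eta, nplus, nminus, rho_good, rho_bad and boundary_tol could interact. update_trust_radius is hypothetical. torchzero delegates this to an internal default_radius helper whose exact rules may differ.

```python
def update_trust_radius(radius, actual, predicted, step_norm, eta=0.0,
                        nplus=3.5, nminus=0.25, rho_good=0.99, rho_bad=1e-4,
                        boundary_tol=1e-6):
    """Hypothetical sketch of the trust-region rule; returns (new_radius, accepted)."""
    # rho compares the actual loss reduction with the reduction the quadratic model
    # predicted; a non-positive prediction is treated as a failure in this sketch
    rho = actual / predicted if predicted > 0 else -1.0
    if rho < rho_bad:
        radius *= nminus  # model was a poor fit: shrink the region
    elif rho > rho_good and step_norm >= (1 - boundary_tol) * radius:
        radius *= nplus   # good fit and the step hit the boundary: grow the region
    return radius, rho > eta
```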

Examples:

Trust-region Newton-CG:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.NewtonCGSteihaug(),
)
Reference:
Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
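The Steihaug-Toint method runs CG on the Newton system but stops as soon as an iterate would leave the trust region or negative curvature is found, and returns a point on the boundary instead. A compact NumPy sketch follows. steihaug_cg is hypothetical. It always takes the positive boundary root, which is a common simplification, and it leaves out the history reuse and the MINRES option of torchzero's implementation.

```python
import numpy as np


def _to_boundary(z, p, radius):
    """Return z + tau * p with tau >= 0 chosen so that ||z + tau * p|| == radius."""
    a = p @ p
    b = 2.0 * (z @ p)
    c = z @ z - radius ** 2  # <= 0 because z lies inside the trust region
    tau = (-b + np.sqrt(b * b - 4.0 * a * c)) / (2.0 * a)
    return z + tau * p


def steihaug_cg(H_mv, b, radius, tol=1e-8, maxiter=None):
    """Hypothetical sketch of Steihaug-Toint CG for H d = b with ||d|| <= radius."""
    z = np.zeros_like(b)
    r = b.copy()
    p = r.copy()
    if np.linalg.norm(r) < tol:
        return z
    for _ in range(maxiter or b.size):
        Hp = H_mv(p)
        curvature = p @ Hp
        if curvature <= 0:
            # negative curvature: move along p until the boundary is reached
            return _to_boundary(z, p, radius)
        alpha = (r @ r) / curvature
        z_next = z + alpha * p
        if np.linalg.norm(z_next) >= radius:
            # the CG iterate would leave the region: stop on the boundary instead
            return _to_boundary(z, p, radius)
        r_next = r - alpha * Hp
        if np.linalg.norm(r_next) < tol:
            return z_next
        beta = (r_next @ r_next) / (r @ r)
        p = r_next + beta * p
        z, r = z_next, r_next
    return z
```

With a large radius and a positive-definite hessian, this reduces to ordinary CG. With a small radius, the returned step has norm equal to the radius.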
Source code in torchzero/modules/second_order/newton_cg.py
class NewtonCGSteihaug(Transform):
    """Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.

    Notes:
        * In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

        * This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        eta (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the step is accepted. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust radius reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 100.
        max_history (int, optional):
            CG will store this many intermediate solutions, reusing them when the trust radius is reduced
            instead of re-running CG. Each stored solution requires 2N memory. Defaults to 100.
        boundary_tol (float | None, optional):
            The trust region only increases when the suggested step's norm is at least `(1-boundary_tol)*trust_region`.
            This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.

        maxiter (int | None, optional):
            maximum number of CG iterations per step. Each iteration requires one backward pass if `hvp_method="fd_forward"`, two otherwise. Defaults to None.
        miniter (int, optional):
            minimal number of CG iterations. This prevents making no progress
            when the initial guess is already below tolerance. Defaults to 1.
        tol (float, optional):
            terminates CG when the norm of the residual is less than this value. Defaults to 1e-8.
        reg (float, optional): hessian regularization. Defaults to 1e-8.
        solver (str, optional): solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.
        adapt_tol (bool, optional):
            if True, whenever the trust radius collapses to the smallest representable number,
            the tolerance is multiplied by 0.1. Defaults to False.
        npc_terminate (bool, optional):
            whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.

        hvp_method (str, optional):
            either ``"fd_forward"`` to use the forward formula, which requires one backward pass per hessian-vector product, or ``"fd_central"`` to use the more accurate central formula, which requires two backward passes. ``"fd_forward"`` is usually accurate enough. Defaults to ``"fd_central"``.
        h (float, optional): finite difference step size. Defaults to 1e-3.

        inner (Chainable | None, optional):
            applies preconditioning to output of this module. Defaults to None.

    ### Examples:
    Trust-region Newton-CG:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NewtonCGSteihaug(),
    )
    ```

    ### Reference:
        Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
    """
    def __init__(
        self,
        # trust region settings
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 100,
        max_history: int = 100,
        boundary_tol: float = 1e-6, # tuned

        # cg settings
        maxiter: int | None = None,
        miniter: int = 1,
        tol: float = 1e-8,
        reg: float = 1e-8,
        solver: Literal['cg', "minres"] = 'cg',
        adapt_tol: bool = False,
        terminate_on_tr: bool = True,
        npc_terminate: bool = False,

        # hvp settings
        hvp_method: Literal["fd_forward", "fd_central"] = "fd_central",
        h: float = 1e-3, # tuned 1e-4 or 1e-3

        # inner
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner']
        super().__init__(defaults, inner=inner)

        self._num_hvps = 0
        self._num_hvps_last_step = 0


    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]
        hvp_method = fs['hvp_method']
        h = fs['h']

        # ---------------------- Hessian vector product function --------------------- #
        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
        objective.temp = H_mv

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        self._num_hvps_last_step = 0

        H_mv = objective.poptemp()
        params = TensorList(objective.params)
        fs = settings[0]

        tol = fs['tol'] * self.global_state.get('tol_mul', 1)
        solver = fs['solver'].lower().strip()

        reg=fs["reg"]
        maxiter=fs["maxiter"]
        max_attempts=fs["max_attempts"]
        init=fs["init"]
        npc_terminate=fs["npc_terminate"]
        miniter=fs["miniter"]
        max_history=fs["max_history"]


        # ------------------------------- trust region ------------------------------- #
        success = False
        d = None
        orig_params = [p.clone() for p in params]
        b = TensorList(objective.get_updates())
        solution = None
        closure = objective.closure
        assert closure is not None

        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            trust_radius = self.global_state.get('trust_radius', init)

            # -------------- make sure trust radius isn't too small or large ------------- #
            finfo = torch.finfo(orig_params[0].dtype)
            if trust_radius < finfo.tiny * 2:
                trust_radius = self.global_state['trust_radius'] = init

                if fs["adapt_tol"]:
                    self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1

                if fs["terminate_on_tr"]:
                    objective.should_terminate = True

            elif trust_radius > finfo.max / 2:
                trust_radius = self.global_state['trust_radius'] = init

            # ----------------------------------- solve ---------------------------------- #
            d = None
            if solution is not None and solution.history is not None:
                d = find_within_trust_radius(solution.history, trust_radius)

            if d is None:
                if solver == 'cg':
                    d, solution = cg(
                        A_mv=H_mv,
                        b=b,
                        tol=tol,
                        maxiter=maxiter,
                        reg=reg,
                        trust_radius=trust_radius,
                        miniter=miniter,
                        npc_terminate=npc_terminate,
                        history_size=max_history,
                    )

                elif solver == 'minres':
                    d = minres(A_mv=H_mv, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)

                else:
                    raise ValueError(f"unknown solver {solver}")

            # ---------------------------- update trust radius --------------------------- #
            self.global_state["trust_radius"], success = default_radius(
                params = params,
                closure = closure,
                f = tofloat(objective.get_loss(False)),
                g = b,
                H = H_mv,
                d = d,
                trust_radius = trust_radius,
                eta = fs["eta"],
                nplus = fs["nplus"],
                nminus = fs["nminus"],
                rho_good = fs["rho_good"],
                rho_bad = fs["rho_bad"],
                boundary_tol = fs["boundary_tol"],

                init = cast(int, None), # init isn't used because check_overflow=False
                state = cast(dict, None), # not used
                settings = cast(dict, None), # not used
                check_overflow = False, # this is checked manually to adapt tolerance
            )

        # --------------------------- assign new direction --------------------------- #
        assert d is not None
        if success:
            objective.updates = d

        else:
            objective.updates = params.zeros_like()

        self._num_hvps += self._num_hvps_last_step
        return objective

NoiseSign

Bases: torchzero.core.transform.TensorTransform

Outputs random tensors with sign copied from the update.
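Below is a NumPy sketch of the same operation for a single array. It assumes a normal distribution and treats variance as the variance of that distribution. That is an assumption, since torchzero's handling of variance is not documented here. noise_sign is a hypothetical name.

```python
import numpy as np


def noise_sign(update, rng, variance=None):
    """Hypothetical sketch: normal noise with the sign pattern of `update`."""
    # assumption: `variance` is the variance of the normal distribution (std = sqrt)
    std = 1.0 if variance is None else np.sqrt(variance)
    noise = rng.normal(0.0, std, size=update.shape)
    return np.copysign(noise, update)  # keep |noise|, take the sign of update


rng = np.random.default_rng(0)
out = noise_sign(np.array([1.0, -2.0, 3.0, -0.5]), rng)
```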

Source code in torchzero/modules/misc/misc.py
class NoiseSign(TensorTransform):
    """Outputs random tensors with sign copied from the update."""
    def __init__(self, distribution:Distributions = 'normal', variance:float | None = None):
        defaults = dict(distribution=distribution, variance=variance)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        variance = unpack_dicts(settings, 'variance')
        return TensorList(tensors).sample_like(settings[0]['distribution'], variance=variance).copysign_(tensors)
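The idea behind NoiseSign can be sketched without torchzero. The update keeps only its signs, and the magnitudes are replaced by random samples. This is a minimal pure-Python sketch on a flat list of floats. The helper name `noise_sign` is hypothetical, and treating `variance=None` as unit variance is an assumption, not something the source states.

```python
import math
import random

def noise_sign(update, variance=None, seed=None):
    """Hypothetical sketch: Gaussian noise whose signs are copied from `update`."""
    rng = random.Random(seed)
    # assumption: variance=None is treated as unit variance
    std = 1.0 if variance is None else math.sqrt(variance)
    # math.copysign keeps the magnitude of the first argument and the sign of the second
    return [math.copysign(rng.gauss(0.0, std), u) for u in update]

update = [0.5, -2.0, 3.0, -0.1]
noisy = noise_sign(update, variance=4.0, seed=0)
```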

Noop

Bases: torchzero.core.module.Module

Identity operator that accepts and ignores any arguments. It can also be used as the identity Hessian in trust region methods.

Source code in torchzero/modules/ops/utility.py
class Identity(Module):
    """Identity operator that accepts and ignores any arguments. It can also be used as the identity Hessian in trust region methods."""
    def __init__(self, *args, **kwargs): super().__init__()
    def update(self, objective): pass
    def apply(self, objective): return objective
    def get_H(self, objective):
        n = sum(p.numel() for p in objective.params)
        p = objective.params[0]
        return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)

Normalize

Bases: torchzero.core.transform.TensorTransform

Normalizes the update so that its norm equals norm_value.

Parameters:

  • norm_value (float, default: 1 ) –

    desired norm value.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • dim (int | Sequence[int] | str | None, default: None ) –

    dimensions along which the norm is calculated. If a list/tuple, each tensor is normalized along all dimensions in dim that it has. Can be set to "global" to normalize by the global norm of all tensors concatenated into a single vector. Defaults to None.

  • inverse_dims (bool, default: False ) –

    if True, the dim argument is inverted, and all other dimensions are normalized.

  • min_size (int, default: 1 ) –

    minimal size of a dimension to normalize along it. Defaults to 1.

  • target (str) –

    what this affects.

Examples: Gradient normalization:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Normalize(1),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)

Update normalization:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.Normalize(1),
    tz.m.LR(1e-2),
)
Source code in torchzero/modules/clipping/clipping.py
class Normalize(TensorTransform):
    """Normalizes the update.

    Args:
        norm_value (float): desired norm value.
        ord (float, optional): norm order. Defaults to 2.
        dim (int | Sequence[int] | str | None, optional):
            calculates norm along those dimensions.
            If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
            Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
            Defaults to None.
        inverse_dims (bool, optional):
            if True, the `dim` argument is inverted, and all other dimensions are normalized.
        min_size (int, optional):
            minimal size of a dimension to normalize along it. Defaults to 1.
        target (str, optional):
            what this affects.

    Examples:
    Gradient normalization:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Normalize(1),
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )
    ```

    Update normalization:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.Normalize(1),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(
        self,
        norm_value: float = 1,
        ord: Metrics = 2,
        dim: int | Sequence[int] | Literal["global"] | None = None,
        inverse_dims: bool = False,
        min_size: int = 1,
    ):
        defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        norm_value = NumberList(s['norm_value'] for s in settings)
        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])

        _clip_norm_(
            tensors_ = TensorList(tensors),
            min = None,
            max = None,
            norm_value = norm_value,
            ord = ord,
            dim = dim,
            inverse_dims=inverse_dims,
            min_size = min_size,
        )

        return tensors
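The two most common modes of Normalize can be illustrated on flat lists of floats. In the first, each tensor is rescaled by its own norm. In the second, all tensors are rescaled by one norm computed over their concatenation, as with dim="global". This is a hypothetical pure-Python sketch, not torchzero's `_clip_norm_`. Treating dim=None as per-tensor normalization is an assumption, and so is the `eps` guard against zero norms.

```python
import math

def lp_norm(values, ord=2.0):
    return sum(abs(v) ** ord for v in values) ** (1.0 / ord)

def normalize(tensors, norm_value=1.0, ord=2.0, global_norm=False, eps=1e-12):
    """Hypothetical sketch: rescale lists of floats so their norm equals norm_value."""
    if global_norm:
        # one norm over all tensors concatenated into a single vector
        total = lp_norm([v for t in tensors for v in t], ord)
        scale = norm_value / max(total, eps)
        return [[v * scale for v in t] for t in tensors]
    # otherwise every tensor is rescaled by its own norm
    return [[v * norm_value / max(lp_norm(t, ord), eps) for v in t] for t in tensors]

tensors = [[3.0, 4.0], [0.0, 12.0]]
per_tensor = normalize(tensors)                       # each tensor gets norm 1
global_scaled = normalize(tensors, global_norm=True)  # the concatenation gets norm 1
```

In per-tensor mode, every parameter's update has the same norm regardless of its size. In global mode, the relative scale between parameters is preserved.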

NormalizeByEMA

Bases: torchzero.modules.clipping.ema_clipping.ClipNormByEMA

Sets norm of the update to be the same as the norm of an exponential moving average of past updates.

Parameters:

  • beta (float, default: 0.99 ) –

    beta for the exponential moving average. Defaults to 0.99.

  • ord (float, default: 2 ) –

    order of the norm. Defaults to 2.

  • eps (float) –

    epsilon for division. Defaults to 1e-6.

  • tensorwise (bool, default: True ) –

    if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.

  • max_ema_growth (float | None, default: 1.5 ) –

    if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.

  • ema_init (str) –

    How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.

Source code in torchzero/modules/clipping/ema_clipping.py
class NormalizeByEMA(ClipNormByEMA):
    """Sets norm of the update to be the same as the norm of an exponential moving average of past updates.

    Args:
        beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
        ord (float, optional): order of the norm. Defaults to 2.
        eps (float, optional): epsilon for division. Defaults to 1e-6.
        tensorwise (bool, optional):
            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
        max_ema_growth (float | None, optional):
            if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
        ema_init (str, optional):
            How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
    """
    NORMALIZE = True

NORMALIZE class-attribute

NORMALIZE = True
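The description above suggests a rule that can be sketched in a few lines. Keep an exponential moving average of past updates, then rescale the current update to the norm of that average. This is a hypothetical, simplified pure-Python sketch of the idea, not the ClipNormByEMA implementation. It treats all values as one vector (tensorwise=False), and it intentionally omits max_ema_growth.

```python
import math

class NormalizeByEMASketch:
    """Hypothetical sketch: rescale each update to the norm of an EMA of past updates.

    max_ema_growth and tensorwise handling are intentionally omitted.
    """
    def __init__(self, beta=0.99, eps=1e-6, ema_init="zeros"):
        self.beta, self.eps, self.ema_init = beta, eps, ema_init
        self.ema = None

    def step(self, update):
        if self.ema is None:
            # "zeros" starts the EMA at zero, "update" starts it at the first update
            self.ema = [0.0] * len(update) if self.ema_init == "zeros" else list(update)
        self.ema = [self.beta * e + (1 - self.beta) * u for e, u in zip(self.ema, update)]
        ema_norm = math.sqrt(sum(e * e for e in self.ema))
        upd_norm = math.sqrt(sum(u * u for u in update))
        # eps guards the division when the update is (close to) zero
        return [u * ema_norm / (upd_norm + self.eps) for u in update]

s = NormalizeByEMASketch(beta=0.9, ema_init="update")
out1 = s.step([3.0, 4.0])
out2 = s.step([0.0, 1.0])
```

With ema_init="zeros", the first outputs are small, because the average starts at zero and needs several steps to build up.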


NystromPCG

Bases: torchzero.core.transform.Transform

Newton's method with a Nyström-preconditioned conjugate gradient solver.

Notes
  • This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

  • In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

Parameters:

  • rank (int, default: 100 ) –

    size of the sketch for preconditioning, this many hessian-vector products will be evaluated before running the conjugate gradient solver. Larger value improves the preconditioning and speeds up conjugate gradient.

  • maxiter (int | None, default: None ) –

    maximum number of conjugate gradient iterations. By default this is set to the number of dimensions of the objective function, which in exact arithmetic is enough for conjugate gradient to converge. Setting this to a small value can still produce good enough directions. Defaults to None.

  • tol (float, default: 1e-08 ) –

    relative tolerance for conjugate gradient solver. Defaults to 1e-8.

  • reg (float, default: 1e-06 ) –

    regularization parameter. Defaults to 1e-6.

  • hvp_method (str, default: 'batched_autograd' ) –

    Determines how Hessian-vector products are computed.

    • "batched_autograd" - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than "batched_autograd" but uses less memory.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    Defaults to "batched_autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • inner (Chainable | None, default: None ) –

    modules to apply hessian preconditioner to. Defaults to None.

  • seed (int | None, default: None ) –

    seed for random generator. Defaults to None.
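The two finite-difference options for hvp_method approximate a Hessian-vector product from gradient evaluations only. The forward formula is (g(x + h v) - g(x)) / h, and the central formula is (g(x + h v) - g(x - h v)) / (2h). This hypothetical pure-Python sketch applies both to the toy function f(x) = sum(x_i^4) / 4, whose Hessian is diag(3 x_i^2). It shows that the central formula is more accurate at the same h. It is not torchzero's implementation.

```python
def grad(x):
    # gradient of f(x) = sum(x_i**4) / 4
    return [xi ** 3 for xi in x]

def hvp_fd_forward(grad_fn, x, v, h=1e-3, g0=None):
    # Hv ~ (g(x + h v) - g(x)) / h; g(x) is usually already known, so one extra gradient
    g0 = grad_fn(x) if g0 is None else g0
    g1 = grad_fn([xi + h * vi for xi, vi in zip(x, v)])
    return [(a - b) / h for a, b in zip(g1, g0)]

def hvp_fd_central(grad_fn, x, v, h=1e-3):
    # Hv ~ (g(x + h v) - g(x - h v)) / (2h); two extra gradients, error O(h^2) instead of O(h)
    gp = grad_fn([xi + h * vi for xi, vi in zip(x, v)])
    gm = grad_fn([xi - h * vi for xi, vi in zip(x, v)])
    return [(a - b) / (2 * h) for a, b in zip(gp, gm)]

x, v = [1.0, 2.0], [1.0, 0.0]
exact = [3 * xi ** 2 * vi for xi, vi in zip(x, v)]  # H v with H = diag(3 x_i^2)
fwd = hvp_fd_forward(grad, x, v)
cen = hvp_fd_central(grad, x, v)
```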

Examples:

NystromPCG with backtracking line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.NystromPCG(10),
    tz.m.Backtracking()
)
Reference

Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820

Source code in torchzero/modules/second_order/nystrom.py
class NystromPCG(Transform):
    """Newton's method with a Nyström-preconditioned conjugate gradient solver.

    Notes:
        - This module requires a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients for calculating HVPs.
        The closure must accept a ``backward`` argument (refer to documentation).

        - In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

    Args:
        rank (int):
            size of the sketch for preconditioning, this many hessian-vector products will be evaluated before
            running the conjugate gradient solver. Larger value improves the preconditioning and speeds up
            conjugate gradient.
        maxiter (int | None, optional):
            maximum number of conjugate gradient iterations. By default this is set to the number of dimensions
            in the objective function, which in exact arithmetic is enough for conjugate gradient
            to converge. Setting this to a small value can still generate good enough directions.
            Defaults to None.
        tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-8.
        reg (float, optional): regularization parameter. Defaults to 1e-6.
        hvp_method (str, optional):
            Determines how Hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"batched_autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
        seed (int | None, optional): seed for random generator. Defaults to None.

    Examples:

    NystromPCG with backtracking line search

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NystromPCG(10),
        tz.m.Backtracking()
    )
    ```

    Reference:
        Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820

    """
    def __init__(
        self,
        rank: int = 100,
        maxiter=None,
        tol=1e-8,
        reg: float = 1e-6,
        update_freq: int = 1, # here update_freq is within update_states
        eigv_tol: float = 0,
        orthogonalize_method: OrthogonalizeMethod = 'qr',
        hvp_method: HVPMethod = "batched_autograd",
        h=1e-3,
        inner: Chainable | None = None,
        seed: int | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner']
        super().__init__(defaults, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        # ---------------------- Hessian vector product function --------------------- #
        # this should run on every update_states
        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=fs['hvp_method'], h=fs['h'], at_x0=True)
        objective.temp = H_mv

        # --------------------------- update preconditioner -------------------------- #
        step = self.increment_counter("step", 0)
        if step % fs["update_freq"] == 0:

            ndim = sum(t.numel() for t in objective.params)
            device = objective.params[0].device
            dtype = objective.params[0].dtype
            generator = self.get_generator(device, seed=fs['seed'])

            try:
                Omega = torch.randn(ndim, min(fs["rank"], ndim), device=device, dtype=dtype, generator=generator)
                HOmega = H_mm(orthogonalize(Omega, fs["orthogonalize_method"]))
                # compute the approximation
                L, Q = nystrom_approximation(
                    Omega=Omega,
                    AOmega=HOmega,
                    eigv_tol=fs["eigv_tol"],
                )

                self.global_state["L"] = L
                self.global_state["Q"] = Q

            except torch.linalg.LinAlgError as e:
                warnings.warn(f"Nystrom approximation failed with: {e}")

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        b = objective.get_updates()
        H_mv = objective.poptemp()
        fs = self.settings[objective.params[0]]

        # ----------------------------------- solve ---------------------------------- #
        if "L" not in self.global_state:
            # fallback on cg
            sol = cg(A_mv=H_mv, b=TensorList(b), tol=fs["tol"], reg=fs["reg"], maxiter=fs["maxiter"])
            objective.updates = sol.x
            return objective

        L = self.global_state["L"]
        Q = self.global_state["Q"]

        x = nystrom_pcg(L=L, Q=Q, A_mv=H_mv, b=torch.cat([t.ravel() for t in b]),
                        reg=fs['reg'], tol=fs["tol"], maxiter=fs["maxiter"])

        # -------------------------------- set update -------------------------------- #
        objective.updates = vec_to_tensors(x, reference=objective.params)
        return objective
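The solve step in NystromPCG is preconditioned conjugate gradient on the Hessian system. This is a minimal pure-Python sketch of PCG for a small symmetric positive definite matrix. A simple diagonal (Jacobi) preconditioner stands in for the Nyström preconditioner here, which is a deliberate swap for brevity. As in the documentation above, maxiter defaults to the dimension. All names are hypothetical and are not torchzero's `nystrom_pcg` or `cg`.

```python
def matvec(A, x):
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def pcg(A, b, apply_Pinv, tol=1e-8, maxiter=None):
    """Preconditioned conjugate gradient for symmetric positive definite A."""
    n = len(b)
    maxiter = n if maxiter is None else maxiter
    x = [0.0] * n
    r = list(b)                 # residual b - A x for x = 0
    z = apply_Pinv(r)           # preconditioned residual
    p = list(z)
    rz = dot(r, z)
    b_norm = dot(b, b) ** 0.5
    for _ in range(maxiter):
        Ap = matvec(A, p)
        alpha = rz / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        if dot(r, r) ** 0.5 <= tol * b_norm:  # relative tolerance
            break
        z = apply_Pinv(r)
        rz_new = dot(r, z)
        p = [zi + (rz_new / rz) * pi for zi, pi in zip(z, p)]
        rz = rz_new
    return x

A = [[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]]
b = [1.0, 2.0, 3.0]
jacobi = lambda r: [ri / A[i][i] for i, ri in enumerate(r)]
x = pcg(A, b, jacobi)
```

A better preconditioner, such as the Nyström approximation, clusters the eigenvalues of the preconditioned system, so the method needs fewer iterations than plain CG.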

NystromSketchAndSolve

Bases: torchzero.core.transform.Transform

Newton's method with a Nyström sketch-and-solve solver.

Notes
  • This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to documentation).

  • In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

  • If this is unstable, increase the reg parameter and tune the rank.

Parameters:

  • rank (int, default: 100 ) –

    size of the sketch, this many hessian-vector products will be evaluated per step.

  • reg (float | None, default: 0.01 ) –

    scale of identity matrix added to hessian. Note that if this is specified, nystrom sketch-and-solve is used to compute (Q diag(L) Q.T + reg*I)x = b. It is very unstable when reg is small, i.e. smaller than 1e-4. If this is None, (Q diag(L) Q.T)x = b is computed by simply taking the reciprocal of the eigenvalues. Defaults to 1e-2.

  • eigv_tol (float, default: 0 ) –

    all eigenvalues smaller than largest eigenvalue times eigv_tol are removed. Defaults to 0.

  • truncate (int | None, default: None ) –

    keeps top truncate eigenvalues. Defaults to None.

  • damping (float, default: 0 ) –

    scalar added to eigenvalues. Defaults to 0.

  • rdamping (float, default: 0 ) –

    scalar multiplied by largest eigenvalue and added to eigenvalues. Defaults to 0.

  • update_freq (int, default: 1 ) –

    frequency of updating preconditioner. Defaults to 1.

  • hvp_method (str, default: 'batched_autograd' ) –

    Determines how Hessian-vector products are computed.

    • "batched_autograd" - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than "batched_autograd" but uses less memory.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    Defaults to "batched_autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • inner (Chainable | None, default: None ) –

    modules to apply hessian preconditioner to. Defaults to None.

  • seed (int | None, default: None ) –

    seed for random generator. Defaults to None.
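The truncate, eigv_tol, damping and rdamping options act on the eigenpairs of the Nyström approximation. This hypothetical pure-Python sketch operates on (eigenvalue, eigenvector) pairs and follows a literal reading of the parameter descriptions above. The order of the steps (truncate, then filter, then damp) is an assumption, and the helper is not torchzero's `regularize_eigh`.

```python
def regularize_eigenpairs(pairs, truncate=None, eigv_tol=0.0, damping=0.0, rdamping=0.0):
    """Hypothetical sketch. pairs: list of (eigenvalue, eigenvector) tuples."""
    pairs = sorted(pairs, key=lambda p: p[0], reverse=True)
    if truncate is not None:
        pairs = pairs[:truncate]  # keep the `truncate` largest eigenvalues
    if not pairs:
        return []
    largest = pairs[0][0]
    # remove eigenvalues smaller than largest * eigv_tol
    pairs = [(lam, q) for lam, q in pairs if lam >= largest * eigv_tol]
    # damping is added directly, rdamping is scaled by the largest eigenvalue
    shift = damping + rdamping * largest
    return [(lam + shift, q) for lam, q in pairs]

pairs = [(5.0, "q1"), (0.001, "q2"), (2.0, "q3"), (1.0, "q4")]
out = regularize_eigenpairs(pairs, truncate=3, eigv_tol=0.3, damping=0.1, rdamping=0.01)
```

Read literally, eigv_tol=0 discards negative eigenvalues, because they are smaller than 0 times the largest eigenvalue.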

Examples: NystromSketchAndSolve with backtracking line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.NystromSketchAndSolve(100),
    tz.m.Backtracking()
)

Trust region NystromSketchAndSolve

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
)

References:

  • Frangella, Z., Rathore, P., Zhao, S., & Udell, M. (2024). SketchySGD: Reliable Stochastic Optimization via Randomized Curvature Estimates. SIAM Journal on Mathematics of Data Science, 6(4), 1173-1204.
  • Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752.

Source code in torchzero/modules/second_order/nystrom.py
class NystromSketchAndSolve(Transform):
    """Newton's method with a Nyström sketch-and-solve solver.

    Notes:
        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

        - In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

        - If this is unstable, increase the ``reg`` parameter and tune the rank.

    Args:
        rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
        reg (float | None, optional):
            scale of identity matrix added to hessian. Note that if this is specified, nystrom sketch-and-solve
            is used to compute ``(Q diag(L) Q.T + reg*I)x = b``. It is very unstable when ``reg`` is small,
            i.e. smaller than 1e-4. If this is None, ``(Q diag(L) Q.T)x = b`` is computed by simply taking
            the reciprocal of the eigenvalues. Defaults to 1e-2.
        eigv_tol (float, optional):
            all eigenvalues smaller than largest eigenvalue times ``eigv_tol`` are removed. Defaults to 0.
        truncate (int | None, optional):
            keeps top ``truncate`` eigenvalues. Defaults to None.
        damping (float, optional): scalar added to eigenvalues. Defaults to 0.
        rdamping (float, optional): scalar multiplied by largest eigenvalue and added to eigenvalues. Defaults to 0.
        update_freq (int, optional): frequency of updating preconditioner. Defaults to 1.
        hvp_method (str, optional):
            Determines how Hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"batched_autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
        seed (int | None, optional): seed for random generator. Defaults to None.


    Examples:
    NystromSketchAndSolve with backtracking line search

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NystromSketchAndSolve(100),
        tz.m.Backtracking()
    )
    ```

    Trust region NystromSketchAndSolve

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
    )
    ```

    References:
    - [Frangella, Z., Rathore, P., Zhao, S., & Udell, M. (2024). SketchySGD: Reliable Stochastic Optimization via Randomized Curvature Estimates. SIAM Journal on Mathematics of Data Science, 6(4), 1173-1204.](https://arxiv.org/pdf/2211.08597)
    - [Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752](https://arxiv.org/abs/2110.02820)

    """
    def __init__(
        self,
        rank: int = 100,
        reg: float | None = 1e-2,
        eigv_tol: float = 0,
        truncate: int | None = None,
        damping: float = 0,
        rdamping: float = 0,
        update_freq: int = 1,
        orthogonalize_method: OrthogonalizeMethod = 'qr',
        hvp_method: HVPMethod = "batched_autograd",
        h: float = 1e-3,
        inner: Chainable | None = None,
        seed: int | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["update_freq"]
        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        params = TensorList(objective.params)
        fs = settings[0]

        # ---------------------- Hessian vector product function --------------------- #
        hvp_method = fs['hvp_method']
        h = fs['h']
        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)

        # ---------------------------------- sketch ---------------------------------- #
        ndim = sum(t.numel() for t in objective.params)
        device = params[0].device
        dtype = params[0].dtype

        generator = self.get_generator(params[0].device, seed=fs['seed'])
        try:
            Omega = torch.randn([ndim, min(fs["rank"], ndim)], device=device, dtype=dtype, generator=generator)
            Omega = orthogonalize(Omega, fs["orthogonalize_method"])
            HOmega = H_mm(Omega)

            # compute the approximation
            L, Q = nystrom_approximation(
                Omega=Omega,
                AOmega=HOmega,
                eigv_tol=fs["eigv_tol"],
            )

            # regularize
            L, Q = regularize_eigh(
                L=L,
                Q=Q,
                truncate=fs["truncate"],
                tol=fs["eigv_tol"],
                damping=fs["damping"],
                rdamping=fs["rdamping"],
            )

            # store
            if L is not None:
                self.global_state["L"] = L
                self.global_state["Q"] = Q

        except torch.linalg.LinAlgError as e:
            warnings.warn(f"Nystrom approximation failed with: {e}")

    def apply_states(self, objective, states, settings):
        if "L" not in self.global_state:
            return objective

        fs = settings[0]
        updates = objective.get_updates()
        b=torch.cat([t.ravel() for t in updates])

        # ----------------------------------- solve ---------------------------------- #
        L = self.global_state["L"]
        Q = self.global_state["Q"]

        if fs["reg"] is None:
            x = Q @ ((Q.mH @ b) / L)
        else:
            x = nystrom_sketch_and_solve(L=L, Q=Q, b=b, reg=fs["reg"])

        # -------------------------------- set update -------------------------------- #
        objective.updates = vec_to_tensors(x, reference=objective.params)
        return objective

    def get_H(self, objective=...):
        if "L" not in self.global_state:
            return ScaledIdentity()

        L = self.global_state["L"]
        Q = self.global_state["Q"]
        return Eigendecomposition(L, Q)
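When Q has orthonormal columns, the regularized system (Q diag(L) Q^T + reg I) x = b can be solved in closed form. Inside span(Q), each coefficient is divided by L + reg. Outside it, the operator is just reg I. The reg=None branch mirrors `Q @ ((Q.mH @ b) / L)` in apply_states above. This is a hypothetical pure-Python sketch of one standard way to do the regularized solve. Whether `nystrom_sketch_and_solve` uses exactly this formula is not shown here.

```python
def sketch_and_solve(L, Q_cols, b, reg=None):
    """Hypothetical sketch: solve (Q diag(L) Q^T + reg*I) x = b, Q given as orthonormal columns."""
    n = len(b)
    coeffs = [sum(qi * bi for qi, bi in zip(q, b)) for q in Q_cols]      # Q^T b
    if reg is None:
        # invert eigenvalues only: x = Q ((Q^T b) / L); the part of b outside span(Q) is dropped
        scaled = [c / lam for c, lam in zip(coeffs, L)]
        return [sum(s * q[i] for s, q in zip(scaled, Q_cols)) for i in range(n)]
    proj = [sum(c * q[i] for c, q in zip(coeffs, Q_cols)) for i in range(n)]  # Q Q^T b
    scaled = [c / (lam + reg) for c, lam in zip(coeffs, L)]
    inside = [sum(s * q[i] for s, q in zip(scaled, Q_cols)) for i in range(n)]
    # outside span(Q) the matrix acts as reg * I
    return [a + (bi - pi) / reg for a, bi, pi in zip(inside, b, proj)]

Q_cols = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
L = [4.0, 2.0]
b = [8.0, 2.0, 3.0]
x = sketch_and_solve(L, Q_cols, b, reg=0.5)
```

The (b - Q Q^T b) / reg term also shows why a small reg is unstable: any component of b that the sketch misses is amplified by 1/reg.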

Ones

Bases: torchzero.core.module.Module

Outputs ones

Source code in torchzero/modules/ops/utility.py
class Ones(Module):
    """Outputs ones"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def apply(self, objective):
        objective.updates = [torch.ones_like(p) for p in objective.params]
        return objective

Online

Bases: torchzero.core.module.Module

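Allows certain modules, such as quasi-Newton methods, to be used for mini-batch optimization. On every step after the first, the wrapped module is first updated at the previous parameters and then at the current parameters, with both updates using the current closure (the same mini-batch).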

Examples:

Online L-BFGS with Backtracking line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Online(tz.m.LBFGS()),
    tz.m.Backtracking()
)

Online L-BFGS trust region

opt = tz.Optimizer(
    model.parameters(),
    tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
)

Source code in torchzero/modules/misc/multistep.py
class Online(Module):
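    """Allows certain modules, such as quasi-Newton methods, to be used for mini-batch optimization.

    On every step after the first, the wrapped module is first updated at the previous parameters and
    then at the current parameters, with both updates using the current closure (the same mini-batch).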

    Examples:

    Online L-BFGS with Backtracking line search
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Online(tz.m.LBFGS()),
        tz.m.Backtracking()
    )
    ```

    Online L-BFGS trust region
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
    )
    ```

    """
    def __init__(self, module: Module,):
        super().__init__()
        self.set_child('module', module)

    @torch.no_grad
    def update(self, objective):
        closure = objective.closure
        if closure is None: raise ValueError("Closure must be passed for Online")

        step = self.increment_counter("step", start = 0)

        params = TensorList(objective.params)
        p_cur = params.clone()
        p_prev = self.get_state(params, 'p_prev', cls=TensorList)

        module = self.children['module']
        var_c = objective.clone(clone_updates=False)

        # on 1st step just step and store previous params
        if step == 0:
            p_prev.copy_(params)

            module.update(var_c)
            objective.update_attrs_from_clone_(var_c)
            return

        # restore previous params and update
        prev_objective = Objective(params=params, closure=closure, model=objective.model, current_step=objective.current_step)
        params.set_(p_prev)
        module.reset_for_online()
        module.update(prev_objective)

        # restore current params and update
        params.set_(p_cur)
        p_prev.copy_(params)
        module.update(var_c)
        objective.update_attrs_from_clone_(var_c)

    @torch.no_grad
    def apply(self, objective):
        module = self.children['module']
        return module.apply(objective.clone(clone_updates=False))

    def get_H(self, objective):
        return self.children['module'].get_H(objective)
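A toy example shows why it helps to evaluate the previous parameters with the current closure. Assume the wrapped module builds secant curvature estimates y / s from gradient differences, as L-BFGS does. Each mini-batch loss here is the hypothetical quadratic 0.5 * a * x^2, with a different curvature a per batch. Mixing gradients from two batches gives a curvature estimate that matches neither batch. Using the same batch for both gradients recovers it exactly. This is an illustration only, not torchzero code.

```python
def grad(x, a):
    # gradient of the mini-batch loss f_a(x) = 0.5 * a * x**2 (curvature a differs per batch)
    return a * x

x_prev, x_cur = 1.0, 0.5
a_prev_batch, a_cur_batch = 2.0, 3.0  # curvature of the previous and the current mini-batch
s = x_cur - x_prev

# naive: previous gradient from the previous batch, current gradient from the current batch
y_mixed = grad(x_cur, a_cur_batch) - grad(x_prev, a_prev_batch)
# Online-style: both gradients come from the current batch
y_online = grad(x_cur, a_cur_batch) - grad(x_prev, a_cur_batch)

curv_mixed = y_mixed / s    # 1.0, matches neither batch
curv_online = y_online / s  # 3.0, the current batch's curvature
```

The cost is one extra closure evaluation per step at the previous parameters.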

OrthoGrad

Bases: torchzero.core.transform.TensorTransform

Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

Parameters:

  • eps (float, default: 1e-08 ) –

    epsilon added to the denominator for numerical stability. Defaults to 1e-8.

  • renormalize (bool, default: True ) –

    whether to rescale the projected gradient to the norm of the original gradient. Defaults to True.

Source code in torchzero/modules/adaptive/orthograd.py
class OrthoGrad(TensorTransform):
    """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

    Args:
        eps (float, optional): epsilon added to the denominator for numerical stability. Defaults to 1e-8.
        renormalize (bool, optional): whether to rescale the projected gradient to the norm of the original gradient. Defaults to True.
    """
    def __init__(self, eps: float = 1e-8, renormalize=True):
        defaults = dict(eps=eps, renormalize=renormalize)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        eps = settings[0]['eps']
        renormalize = settings[0]['renormalize']

        params = TensorList(params)
        target = TensorList(tensors)

        scale = params.dot(target)/(params.dot(params) + eps)
        if renormalize:
            norm = target.global_vector_norm()
            target -= params * scale
            target *= (norm / target.global_vector_norm())
            return target

        target -= params * scale
        return target
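The projection in multi_tensor_apply above can be restated on flat lists. The component of the update along the weights, (w·g) / (w·w + eps) · w, is subtracted. With renormalize=True, the result is then rescaled to the original norm. This is a pure-Python sketch of that logic under the assumption that all parameters are flattened into one vector, and the name `orthograd` is hypothetical. Like the source, it does not guard against a zero projected norm, which occurs when the gradient is exactly parallel to the weights.

```python
import math

def orthograd(params, grads, eps=1e-8, renormalize=True):
    """Hypothetical sketch: project the flattened gradient to be orthogonal to the flattened weights."""
    w_dot_g = sum(w * g for w, g in zip(params, grads))
    w_dot_w = sum(w * w for w in params)
    scale = w_dot_g / (w_dot_w + eps)
    projected = [g - scale * w for w, g in zip(params, grads)]
    if renormalize:
        # restore the original gradient norm after removing the radial component
        g_norm = math.sqrt(sum(g * g for g in grads))
        p_norm = math.sqrt(sum(p * p for p in projected))
        projected = [p * g_norm / p_norm for p in projected]
    return projected

w = [1.0, 0.0]
g = [1.0, 1.0]
out = orthograd(w, g)  # approximately [0, sqrt(2)]
```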

Orthogonalize

Bases: torchzero.core.transform.TensorTransform

Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.

To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False. The Muon page says that embeddings and classifier heads should not be orthogonalized. Usually only matrix parameters that are directly used in matmuls should be orthogonalized.

To make Muon, use Split to apply Orthogonalize to matrix parameters and Adam to 1D parameters.

Parameters:

  • adjust_lr (bool, default: False ) –

    Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.

  • dual_norm_correction (bool, default: False ) –

    enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.

  • method (str, default: 'newtonschulz' ) –

    orthogonalization method. Newton-Schulz is very fast; SVD is slower but can be more precise.

  • channel_first (bool, default: True ) –

    if True, orthogonalizes along the first two dimensions, otherwise along the last two. Other dimensions are considered batch dimensions.

Examples:

standard Muon with Adam fallback

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Split(
        # apply muon only to 2D+ parameters
        filter = lambda t: t.ndim >= 2,
        true = [
            tz.m.HeavyBall(),
            tz.m.Orthogonalize(),
            tz.m.LR(1e-2),
        ],
        false = tz.m.Adam()
    ),
    tz.m.LR(1e-2)
)

Reference

Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon

Source code in torchzero/modules/adaptive/muon.py
class Orthogonalize(TensorTransform):
    """Uses Newton-Schulz iteration or SVD to compute the zeroth matrix power (orthogonalization) of the update along its first two dimensions by default (see channel_first).

    To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False.
    The Muon page says that embeddings and classifier heads should not be orthogonalized.
    Usually only matrix parameters that are directly used in matmuls should be orthogonalized.

    To build Muon, use Split and apply Adam to 1D parameters.

    Args:
        adjust_lr (bool, optional):
            Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.
        dual_norm_correction (bool, optional):
            Enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
        method (str, optional):
            Orthogonalization method. Newton-Schulz is very fast; SVD is slower but can be more precise.
            Defaults to 'newtonschulz'.
        channel_first (bool, optional):
            If True, orthogonalizes along the first two dimensions, otherwise along the last two.
            All other dimensions are treated as batch dimensions. Defaults to True.

    ## Examples:

    standard Muon with Adam fallback
    ```py
    opt = tz.Optimizer(
        model.head.parameters(),
        tz.m.Split(
            # apply muon only to 2D+ parameters
            filter = lambda t: t.ndim >= 2,
            true = [
                tz.m.HeavyBall(),
                tz.m.Orthogonalize(),
                tz.m.LR(1e-2),
            ],
            false = tz.m.Adam()
        ),
        tz.m.LR(1e-2)
    )
    ```

    Reference:
        Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
    """
    def __init__(self, adjust_lr=False, dual_norm_correction=False,
                 method: OrthogonalizeMethod = 'newtonschulz', channel_first:bool=True):
        defaults = dict(orthogonalize=True, dual_norm_correction=dual_norm_correction, adjust_lr=adjust_lr, method=method.lower(), channel_first=channel_first)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        orthogonalize, dual_norm_correction, adjust_lr, method, channel_first = itemgetter(
            'orthogonalize', 'dual_norm_correction', 'adjust_lr', 'method', 'channel_first')(setting)

        if not orthogonalize: return tensor

        if _is_at_least_2d(tensor, channel_first=channel_first):

            X = _orthogonalize_format(tensor, method, channel_first=channel_first)

            if dual_norm_correction:
                X = _dual_norm_correction(X, tensor, channel_first=channel_first)

            if adjust_lr:
                X.mul_(adjust_lr_for_muon(1, param.shape, channel_first=channel_first))

            return X.view_as(param)

        return tensor

PSB

Bases: torchzero.modules.quasi_newton.quasi_newton._HessianUpdateStrategyDefaults

Powell's Symmetric Broyden Quasi-Newton method.

Note

A line search or a trust region is recommended.

Warning

This uses at least O(N^2) memory, where N is the number of parameters.

Reference

Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
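Powell's symmetric Broyden update is the smallest symmetric correction to the Hessian approximation B that satisfies the secant condition B_new s = y. Here s is the parameter step and y is the change in gradient. With r = y - B s, the textbook formula is:

B_new = B + (r s^T + s r^T)/(s^T s) - (r^T s) s s^T/(s^T s)^2

The dense sketch below implements that formula. It is written independently of torchzero's `psb_B_` and only illustrates the update and why storing the N x N matrix B costs O(N^2) memory.

```py
import torch

def psb_update(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Residual of the secant equation for the current approximation.
    r = y - B @ s
    ss = s @ s
    # Symmetric rank-two correction; the last term removes the doubly counted s s^T component.
    return (B + (torch.outer(r, s) + torch.outer(s, r)) / ss
            - (r @ s) * torch.outer(s, s) / ss**2)

B = torch.eye(3, dtype=torch.float64)
s = torch.tensor([1.0, 0.5, -0.2], dtype=torch.float64)
y = torch.tensor([2.0, 0.3, 0.1], dtype=torch.float64)
B_new = psb_update(B, s, y)  # satisfies B_new @ s == y and stays symmetric
```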

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class PSB(_HessianUpdateStrategyDefaults):
    """Powell's Symmetric Broyden Quasi-Newton method.

    Note:
        A line search or a trust region is recommended.

    Warning:
        This uses at least O(N^2) memory, where N is the number of parameters.

    Reference:
        Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
    """
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return psb_B_(B=B, s=s, y=y)

PSGDDenseNewton

Bases: torchzero.core.transform.Transform

Dense Hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

Parameters:

  • init_scale (float | None, default: None ) –

    initial scale of the preconditioner. If None, it is set on the fly from the first Hessian-vector product using a heuristic. Defaults to None.

  • lr_preconditioner (float, default: 0.1 ) –

    learning rate of the preconditioner. Defaults to 0.1.

  • betaL (float, default: 0.9 ) –

    EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.

  • damping (float, default: 1e-09 ) –

    adds small noise to the Hessian-vector product when updating the preconditioner. Defaults to 1e-9.

  • grad_clip_max_norm (float, default: inf ) –

    clips the norm of the preconditioned update. Defaults to float("inf").

  • update_probability (float, default: 1.0 ) –

    probability of updating the preconditioner on each step. Defaults to 1.0.

  • dQ (str, default: 'Q0.5EQ1.5' ) –

    geometry for preconditioner update. Defaults to "Q0.5EQ1.5".

  • hvp_method (Literal, default: 'autograd' ) –

    how to compute hessian-vector products. Defaults to 'autograd'.

  • h (float, default: 0.001 ) –

    if hvp_method is "fd_central" or "fd_forward", controls finite difference step size. Defaults to 1e-3.

  • distribution (Literal, default: 'normal' ) –

    distribution of the random probe vectors used for Hessian-vector products. Defaults to 'normal'.

  • inner (Chainable | None, default: None ) –

    if given, the preconditioner is applied to the output of this module. Defaults to None.
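When `init_scale` is None, the `update_states` source below initializes Q on the fly as scale * I, where:

scale = mean(v^2)^(1/4) * (mean(h^4) + damping^4)^(-1/8)

Here v is the random probe vector and h = Hv. For "QUAD4P", where Q stores P directly, the scale is squared. This minimal sketch reproduces only that formula on a toy Hessian H = 4I. `dense_init_scale` is a hypothetical name.

```py
import torch

def dense_init_scale(v: torch.Tensor, h: torch.Tensor, damping: float = 1e-9) -> torch.Tensor:
    # v: flattened random probe vector, h: flattened Hessian-vector product H @ v.
    return torch.mean(v * v) ** (1 / 4) * (torch.mean(h**4) + damping**4) ** (-1 / 8)

torch.manual_seed(0)
v = torch.randn(1000, dtype=torch.float64)
h = 4.0 * v                    # toy Hessian H = 4 * I, so H @ v = 4 * v
scale = dense_init_scale(v, h)
P_diag = scale**2              # P = Q^T Q = scale**2 * I
```

For Gaussian probes the resulting P is of the same order as the inverse curvature 1/4, which is a reasonable starting point for a Newton-type preconditioner.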

Examples:

Pure Dense Newton PSGD:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.DenseNewton(),
    tz.m.LR(1e-3),
)

Applying preconditioner to momentum:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.DenseNewton(inner=tz.m.EMA(0.9)),
    tz.m.LR(1e-3),
)
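The `apply_states` source below proceeds in four steps:

1. Flatten all updates into one column vector.
2. Multiply it by the preconditioner (P = Q^T Q for the default geometry).
3. Rescale it if its norm exceeds `grad_clip_max_norm`.
4. Write it back into parameter-shaped tensors.

The sketch below mirrors that flow with plain tensors. `apply_dense_preconditioner` is a hypothetical helper, not part of torchzero's API.

```py
import math
import torch

def apply_dense_preconditioner(Q, updates, max_norm=math.inf):
    # Concatenate every update tensor into a single column vector.
    g = torch.cat([t.ravel() for t in updates]).unsqueeze(1)
    pre = Q.T @ (Q @ g)  # P g with P = Q^T Q
    norm = torch.linalg.vector_norm(pre)
    if norm > max_norm:
        pre = pre * (max_norm / norm)  # clip the norm of the preconditioned vector
    # Split the flat vector back into tensors shaped like the inputs.
    chunks = torch.split(pre.squeeze(1), [t.numel() for t in updates])
    return [c.view_as(t) for c, t in zip(chunks, updates)]

updates = [torch.ones(2, 3), torch.ones(4)]
Q = 2.0 * torch.eye(10)
out = apply_dense_preconditioner(Q, updates, max_norm=1.0)
```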

Source code in torchzero/modules/adaptive/psgd/psgd_dense_newton.py
class PSGDDenseNewton(Transform):
    """Dense Hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, it is set on the fly from the first Hessian-vector product using a heuristic. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional):
            adds small noise to the Hessian-vector product when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_norm (float, optional): clips the norm of the preconditioned update. Defaults to float("inf").
        update_probability (float, optional): probability of updating the preconditioner on each step. Defaults to 1.0.
        dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
        hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
        h (float, optional):
            if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
            Defaults to 1e-3.
        distribution (Distributions, optional):
            distribution of the random probe vectors used for Hessian-vector products. Defaults to 'normal'.

        inner (Chainable | None, optional): if given, the preconditioner is applied to the output of this module. Defaults to None.

    ### Examples:

    Pure Dense Newton PSGD:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.DenseNewton(),
        tz.m.LR(1e-3),
    )
    ```

    Applying preconditioner to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.DenseNewton(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```
    """
    def __init__(
        self,
        init_scale: float | None = None,
        lr_preconditioner=0.1,
        betaL=0.9,
        damping=1e-9,
        grad_clip_max_norm=float("inf"),
        update_probability=1.0,
        dQ: Literal["QUAD4P", "QUAD", "QEP", "EQ", "QEQ", "Q0p5EQ1p5", "Q0.5EQ1.5"] = "Q0.5EQ1.5",

        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        distribution: Distributions = 'normal',

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)


    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        # -------------------------------- initialize -------------------------------- #
        if "Q" not in self.global_state:

            p = objective.params[0]
            dQ = fs["dQ"]
            init_scale = fs["init_scale"]

            if init_scale is None:
                self.global_state["Q"] = None

            else:
                n = sum(p.numel() for p in objective.params)
                if dQ == "QUAD4P":
                    init_scale *= init_scale
                self.global_state["Q"] = torch.eye(n, dtype=p.dtype, device=p.device) * init_scale

            self.global_state["L"] = lift2single(torch.zeros([], dtype=p.dtype, device=p.device)) # Lipschitz smoothness constant estimation for the psgd criterion

            if dQ == "QUAD4P":
                self.global_state["update_precond"] = update_precond_dense_quad4p
                self.global_state["precond_grad"] = lambda Q, g: Q @ g
                assert torch.finfo(p.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"

            elif dQ == "QUAD":
                self.global_state["update_precond"] = update_precond_dense_quad
                self.global_state["precond_grad"] = lambda Q, g: Q @ (Q @ g) # Q is symmetric; just save one transpose

            else:
                self.global_state["precond_grad"] = lambda Q, g: Q.T @ (Q @ g)
                if dQ == "QEP":
                    self.global_state["update_precond"] = update_precond_dense_qep
                elif dQ == "EQ":
                    self.global_state["update_precond"] = update_precond_dense_eq
                elif dQ == "QEQ":
                    self.global_state["update_precond"] = update_precond_dense_qeq
                else:
                    assert (dQ == "Q0p5EQ1p5") or (dQ == "Q0.5EQ1.5"), f"Invalid choice for dQ: '{dQ}'"
                    self.global_state["update_precond"] = update_precond_dense_q0p5eq1p5

        # ---------------------------------- update ---------------------------------- #
        Q = self.global_state["Q"]
        if (torch.rand([]) < fs["update_probability"]) or Q is None:

            # hessian-vector product
            vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
            Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])

            v = torch.cat([t.ravel() for t in vs]).unsqueeze(1)
            h = torch.cat([t.ravel() for t in Hvs]).unsqueeze(1)

            # initialize on the fly
            if Q is None:
                scale = (torch.mean(v*v))**(1/4) * (torch.mean(h**4) + fs["damping"]**4)**(-1/8)
                if fs["dQ"] == "QUAD4P": # Q actually is P in this case
                    scale *= scale
                Q = self.global_state["Q"] = torch.eye(len(v), dtype=v.dtype, device=v.device) * scale

            # update preconditioner
            self.global_state["update_precond"](
                Q=Q,
                L=self.global_state["L"],
                v=v,
                h=h,
                lr=fs["lr_preconditioner"],
                betaL=fs["betaL"],
                damping=fs["damping"],
            )

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()

        # cat grads
        g = torch.cat([t.ravel() for t in updates]).unsqueeze(1) # column vec
        pre_grad = self.global_state["precond_grad"](self.global_state["Q"], g)

        # norm clipping
        grad_clip_max_norm = settings[0]["grad_clip_max_norm"]
        if grad_clip_max_norm < float("inf"): # clip preconditioned gradient
            grad_norm = torch.linalg.vector_norm(pre_grad)
            if grad_norm > grad_clip_max_norm:
                pre_grad *= grad_clip_max_norm / grad_norm

        vec_to_tensors_(pre_grad, updates)
        return objective

PSGDKronNewton

Bases: torchzero.core.transform.Transform

Kronecker-factored (Kron) Hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)
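A Kron preconditioner keeps one factor per tensor dimension instead of one N x N matrix. Take a matrix-shaped gradient G of shape (m, n) with left and right factors Q_l and Q_r. Applying P = (Q_l^T Q_l) ⊗ (Q_r^T Q_r) to it then reduces to two small matrix products. The sketch below checks this identity against an explicit Kronecker product. It illustrates the general idea only and is not the einsum-based `precond_grad_kron` used in the source.

```py
import torch

torch.manual_seed(0)
m, n = 3, 4
G = torch.randn(m, n, dtype=torch.float64)
Ql = torch.randn(m, m, dtype=torch.float64)  # left factor, m x m
Qr = torch.randn(n, n, dtype=torch.float64)  # right factor, n x n
Pl, Pr = Ql.T @ Ql, Qr.T @ Qr                # symmetric positive semidefinite factors

# Factored form: stores only m^2 + n^2 entries.
factored = Pl @ G @ Pr
# Explicit form: the full (m*n) x (m*n) preconditioner applied to the row-major flattened G.
explicit = (torch.kron(Pl, Pr) @ G.reshape(-1)).reshape(m, n)
```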

Parameters:

  • max_dim (int, default: 10000 ) –

    dimensions with size larger than this use a diagonal preconditioner. Defaults to 10_000.

  • max_skew (float, default: 1.0 ) –

    if the memory used by a dense factor for a dimension (dim^2) is larger than the total number of elements in the parameter times max_skew, that dimension uses a diagonal preconditioner. Defaults to 1.0.

  • init_scale (float | None, default: None ) –

    initial scale of the preconditioner. If None, it is set on the fly from the first Hessian-vector product using a heuristic. Defaults to None.

  • lr_preconditioner (float, default: 0.1 ) –

    learning rate of the preconditioner. Defaults to 0.1.

  • betaL (float, default: 0.9 ) –

    EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.

  • damping (float, default: 1e-09 ) –

    adds small noise to the Hessian-vector product when updating the preconditioner. Defaults to 1e-9.

  • grad_clip_max_amp (float, default: inf ) –

    clips the average (root-mean-square) amplitude of the preconditioned update. Defaults to float("inf").

  • update_probability (float, default: 1.0 ) –

    probability of updating the preconditioner on each step. Defaults to 1.0.

  • dQ (str, default: 'Q0.5EQ1.5' ) –

    geometry for preconditioner update. Defaults to "Q0.5EQ1.5".

  • balance_probability (float, default: 0.01 ) –

    probability of balancing the dynamic ranges of the factors of Q on each step to avoid overflow or underflow. Defaults to 0.01.

  • hvp_method (Literal, default: 'autograd' ) –

    how to compute hessian-vector products. Defaults to 'autograd'.

  • h (float, default: 0.001 ) –

    if hvp_method is "fd_central" or "fd_forward", controls finite difference step size. Defaults to 1e-3.

  • distribution (Literal, default: 'normal' ) –

    distribution of the random probe vectors used for Hessian-vector products. Defaults to 'normal'.

  • inner (Chainable | None, default: None ) –

    if given, the preconditioner is applied to the output of this module. Defaults to None.
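Following the `max_dim` and `max_skew` descriptions above, each dimension of a parameter gets either a dense factor (dim x dim) or a diagonal factor. Below is a minimal sketch of that rule as documented. `factor_types` is hypothetical, and the library's `init_kron` may apply further conditions.

```py
import math

def factor_types(shape, max_dim=10_000, max_skew=1.0):
    numel = math.prod(shape)
    kinds = []
    for size in shape:
        # A dense factor for this dimension would store size**2 entries.
        if size > max_dim or size**2 > max_skew * numel:
            kinds.append("diag")
        else:
            kinds.append("dense")
    return kinds

print(factor_types((256, 256)))      # ['dense', 'dense']
print(factor_types((512, 256)))      # ['diag', 'dense'], since 512**2 > 512 * 256
print(factor_types((20_000, 20_000)))  # ['diag', 'diag'], since 20_000 > max_dim
```

With the default `max_skew=1.0`, only the longer side of a non-square matrix falls back to a diagonal factor. Raising `max_skew` allows more skewed shapes to keep dense factors.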

Examples:

Pure PSGD Kron Newton:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.KronNewton(),
    tz.m.LR(1e-3),
)

Applying preconditioner to momentum:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.KronNewton(inner=tz.m.EMA(0.9)),
    tz.m.LR(1e-3),
)

Source code in torchzero/modules/adaptive/psgd/psgd_kron_newton.py
class PSGDKronNewton(Transform):
    """Kronecker-factored (Kron) Hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        max_dim (int, optional): dimensions with size larger than this use a diagonal preconditioner. Defaults to 10_000.
        max_skew (float, optional):
            if the memory used by a dense factor for a dimension (dim^2) is larger than the total number of elements in the parameter times ``max_skew``, that dimension uses a diagonal preconditioner. Defaults to 1.0.
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, it is set on the fly from the first Hessian-vector product using a heuristic. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional): adds small noise to the Hessian-vector product when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_amp (float, optional): clips the average (root-mean-square) amplitude of the preconditioned update. Defaults to float("inf").
        update_probability (float, optional): probability of updating the preconditioner on each step. Defaults to 1.0.
        dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
        balance_probability (float, optional):
            probability of balancing the dynamic ranges of the factors of Q on each step to avoid overflow or underflow. Defaults to 0.01.
        hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
        h (float, optional):
            if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
            Defaults to 1e-3.
        distribution (Distributions, optional):
            distribution of the random probe vectors used for Hessian-vector products. Defaults to 'normal'.
        inner (Chainable | None, optional): if given, the preconditioner is applied to the output of this module. Defaults to None.


    ### Examples:

    Pure PSGD Kron Newton:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronNewton(),
        tz.m.LR(1e-3),
    )
    ```

    Applying preconditioner to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronNewton(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```
    """
    def __init__(
        self,
        max_dim: int = 10_000,
        max_skew: float = 1.0,
        init_scale: float | None = None,
        lr_preconditioner: float = 0.1,
        betaL: float = 0.9,
        damping: float = 1e-9,
        grad_clip_max_amp: float = float("inf"),
        update_probability: float= 1.0,
        dQ: Literal["QEP", "EQ", "QEQ", "QUAD",  "Q0.5EQ1.5", "Q0p5EQ1p5", "QUAD4P"] = "Q0.5EQ1.5",
        balance_probability: float = 0.01,

        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        distribution: Distributions = 'normal',

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)


    def _initialize_state(self, param, state, setting):
        assert "initialized" not in state
        state["initialized"] = True

        # initialize preconditioners
        if setting["init_scale"] is None:
            warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
            state["QLs_exprs"] = None
        else:
            state["QLs_exprs"] = init_kron(
                param.squeeze(),
                Scale=setting["init_scale"],
                max_size=setting["max_dim"],
                max_skew=setting["max_skew"],
                dQ=setting["dQ"],
            )

        dQ = setting["dQ"]
        if dQ == "QUAD4P":
            assert torch.finfo(param.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"
            state["update_precond"] = update_precond_kron_newton_quad4p
            state["precond_grad"] = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)

        else:
            state["precond_grad"] = precond_grad_kron
            # map each dQ geometry to its matching update function
            if dQ == "QEP":
                state["update_precond"] = update_precond_kron_newton_qep
            elif dQ == "EQ":
                state["update_precond"] = update_precond_kron_newton_eq
            elif dQ == "QEQ":
                state["update_precond"] = update_precond_kron_newton_qeq
            elif dQ == "QUAD":
                state["update_precond"] = update_precond_kron_newton_quad
            else:
                assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), f"Invalid choice for dQ: '{dQ}'"
                state["update_precond"] = update_precond_kron_newton_q0p5eq1p5

    @torch.no_grad
    def update_states(self, objective, states, settings):

        # initialize states
        for param, state, setting in zip(objective.params, states, settings):
            if "initialized" not in state:
                self._initialize_state(param, state, setting)

        fs = settings[0]

        uninitialized = any(state["QLs_exprs"] is None for state in states)
        if (torch.rand([]) < fs["update_probability"]) or uninitialized:

            # hessian-vector product
            vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
            Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])

            # initialize on the fly (why does it use vs?)
            if uninitialized:

                scale = (sum([torch.sum(torch.abs(v)**2) for v in vs])/sum([v.numel() for v in vs])) ** (1/4) # (mean(|v|^2))^(1/4)

                scale = scale * (max([torch.mean((torch.abs(h))**4) for h in Hvs]) + fs["damping"]**4) ** (-1/8) # (mean(|v|^2))^(1/4) * (mean(|h|^4))^(-1/8)

                for h, state, setting in zip(Hvs, states, settings):
                    if state["QLs_exprs"] is None:
                        state["QLs_exprs"] = init_kron(
                            h.squeeze(),
                            Scale=scale,
                            max_size=setting["max_dim"],
                            max_skew=setting["max_skew"],
                            dQ=setting["dQ"],
                        )

            # update preconditioner
            for v, h, state, setting in zip(vs, Hvs, states, settings):
                state["update_precond"](
                    *state["QLs_exprs"],
                    v.squeeze(),
                    h.squeeze(),
                    lr=setting["lr_preconditioner"],
                    betaL=setting["betaL"],
                    damping=setting["damping"],
                    balance_prob=setting["balance_probability"]
                )

    @torch.no_grad
    def apply_states(self, objective, states, settings):

        params = objective.params
        tensors = objective.get_updates()
        pre_tensors = []

        # precondition
        for param, tensor, state in zip(params, tensors, states):
            t = state["precond_grad"](
                *state["QLs_exprs"],
                tensor.squeeze(),
            )
            pre_tensors.append(t.view_as(param))

        # norm clipping
        grad_clip_max_amp = settings[0]["grad_clip_max_amp"]
        if grad_clip_max_amp < math.inf:
            pre_tensors = TensorList(pre_tensors)
            num_params = sum(t.numel() for t in pre_tensors)

            avg_amp = pre_tensors.dot(pre_tensors.conj()).div(num_params).sqrt()

            if avg_amp > grad_clip_max_amp:
                torch._foreach_mul_(pre_tensors, grad_clip_max_amp / avg_amp)

        objective.updates = pre_tensors
        return objective

PSGDKronWhiten

Bases: torchzero.core.transform.TensorTransform

Kronecker-factored (Kron) whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

Parameters:

  • max_dim (int, default: 10000 ) –

    dimensions with size larger than this use a diagonal preconditioner. Defaults to 10_000.

  • max_skew (float, default: 1.0 ) –

    if the memory used by a dense factor for a dimension (dim^2) is larger than the total number of elements in the parameter times max_skew, that dimension uses a diagonal preconditioner. Defaults to 1.0.

  • init_scale (float | None, default: None ) –

    initial scale of the preconditioner. If None, it is determined from the magnitude of the first incoming update (usually the gradient). Defaults to None.

  • lr_preconditioner (float, default: 0.1 ) –

    learning rate of the preconditioner. Defaults to 0.1.

  • betaL (float, default: 0.9 ) –

    EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.

  • damping (float, default: 1e-09 ) –

    adds small noise to the gradient when updating the preconditioner. Defaults to 1e-9.

  • grad_clip_max_amp (float, default: inf ) –

    clips the average (root-mean-square) amplitude of the preconditioned update. Defaults to float("inf").

  • update_probability (float, default: 1.0 ) –

    probability of updating the preconditioner on each step. Defaults to 1.0.

  • dQ (str, default: 'Q0.5EQ1.5' ) –

    geometry for preconditioner update. Defaults to "Q0.5EQ1.5".

  • balance_probability (float, default: 0.01 ) –

    probability of balancing the dynamic ranges of the factors of Q on each step to avoid overflow or underflow. Defaults to 0.01.

  • inner (Chainable | None, default: None ) –

    if given, the preconditioner is applied to the output of this module. Defaults to None.
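`grad_clip_max_amp` bounds the root-mean-square amplitude of the preconditioned update rather than its total norm. The amplitude is computed over all parameters together, as in the `multi_tensor_apply` source below. Below is a plain-tensor sketch of that rule. `clip_avg_amplitude` is a hypothetical name.

```py
import math
import torch

def clip_avg_amplitude(tensors, max_amp=math.inf):
    # Root-mean-square amplitude over every element of every tensor.
    num = sum(t.numel() for t in tensors)
    avg_amp = math.sqrt(sum(float((t.abs() ** 2).sum()) for t in tensors) / num)
    if avg_amp > max_amp:
        # Scale everything by the same factor so the RMS amplitude equals max_amp.
        return [t * (max_amp / avg_amp) for t in tensors]
    return tensors

ts = [torch.full((3,), 2.0), torch.full((5,), -2.0)]  # RMS amplitude is 2
clipped = clip_avg_amplitude(ts, max_amp=0.5)         # every element is scaled by 0.25
```

Because the threshold is an average per element, the same `grad_clip_max_amp` value behaves similarly regardless of model size, unlike a total-norm threshold.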

Examples:

Pure PSGD Kron:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.KronWhiten(),
    tz.m.LR(1e-3),
)

Momentum into preconditioner (whitens momentum):

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.EMA(0.9),
    tz.m.KronWhiten(),
    tz.m.LR(1e-3),
)

Updating the preconditioner from gradients and applying it to momentum:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.KronWhiten(inner=tz.m.EMA(0.9)),
    tz.m.LR(1e-3),
)

Source code in torchzero/modules/adaptive/psgd/psgd_kron_whiten.py
class PSGDKronWhiten(TensorTransform):
    """Kronecker-factored (Kron) whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        max_dim (int, optional): dimensions with size larger than this use a diagonal preconditioner. Defaults to 10_000.
        max_skew (float, optional):
            if the memory used by a dense factor for a dimension (dim^2) is larger than the total number of elements in the parameter times ``max_skew``, that dimension uses a diagonal preconditioner. Defaults to 1.0.
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, it is determined from the magnitude of the first incoming update (usually the gradient). Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional): adds small noise to the gradient when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_amp (float, optional): clips the average (root-mean-square) amplitude of the preconditioned update. Defaults to float("inf").
        update_probability (float, optional): probability of updating the preconditioner on each step. Defaults to 1.0.
        dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
        balance_probability (float, optional):
            probability of balancing the dynamic ranges of the factors of Q on each step to avoid overflow or underflow. Defaults to 0.01.

        inner (Chainable | None, optional): if given, the preconditioner is applied to the output of this module. Defaults to None.

    ### Examples:

    Pure PSGD Kron:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronWhiten(),
        tz.m.LR(1e-3),
    )
    ```

    Momentum into preconditioner (whitens momentum):
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.EMA(0.9),
        tz.m.KronWhiten(),
        tz.m.LR(1e-3),
    )
    ```

    Updating the preconditioner from gradients and applying it to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.KronWhiten(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```

    """
    def __init__(
        self,
        max_dim: int = 10_000,
        max_skew: float = 1.0,
        init_scale: float | None = None,
        lr_preconditioner: float = 0.1,
        betaL: float = 0.9,
        damping: float = 1e-9,
        grad_clip_max_amp: float = float("inf"),
        update_probability: float= 1.0,
        dQ: Literal["QEP", "EQ", "QEQ", "QUAD",  "Q0.5EQ1.5", "Q0p5EQ1p5", "QUAD4P"] = "Q0.5EQ1.5",
        balance_probability: float = 0.01,

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        # initialize preconditioners
        if setting["init_scale"] is None:
            # warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
            state["QLs_exprs"] = None
        else:
            state["QLs_exprs"] = init_kron(
                param.squeeze(),
                Scale=setting["init_scale"],
                max_size=setting["max_dim"],
                max_skew=setting["max_skew"],
                dQ=setting["dQ"],
            )

        dQ = setting["dQ"]
        if dQ == "QUAD4P":
            assert torch.finfo(param.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"
            state["update_precond"] = update_precond_kron_whiten_quad4p
            state["precond_grad"] = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)

        else:
            state["precond_grad"] = precond_grad_kron
            if dQ == "QEP":
                state["update_precond"] = update_precond_kron_whiten_qep
            elif dQ == "EQ":
                state["update_precond"] = update_precond_kron_whiten_eq
            elif dQ == "QEQ":
                state["update_precond"] = update_precond_kron_whiten_qeq
            elif dQ == "QUAD":
                state["update_precond"] = update_precond_kron_whiten_quad
            else:
                assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), f"Invalid choice for dQ: '{dQ}'"
                state["update_precond"] = update_precond_kron_whiten_q0p5eq1p5

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):

        # initialize on the fly if not initialized
        if any(state["QLs_exprs"] is None for state in states):

            scale = max([torch.mean((torch.abs(g))**4) for g in tensors])
            scale = (scale + settings[0]["damping"]**4)**(-1/8)

            for param, state, setting in zip(params, states, settings):
                if state["QLs_exprs"] is None:
                    state["QLs_exprs"] = init_kron(
                        param.squeeze(),
                        Scale=scale,
                        max_size=setting["max_dim"],
                        max_skew=setting["max_skew"],
                        dQ=setting["dQ"],
                    )


        # update preconditioners
        # (could also try per-parameter probability)
        if torch.rand([]) < settings[0]["update_probability"]: # update Q
            for tensor, state, setting in zip(tensors, states, settings):
                state["update_precond"](
                    *state["QLs_exprs"],
                    tensor.squeeze(),
                    lr=setting["lr_preconditioner"],
                    betaL=setting["betaL"],
                    damping=setting["damping"],
                    balance_prob=setting["balance_probability"]
                )

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):

        pre_tensors = []

        # precondition
        for param, tensor, state in zip(params, tensors, states):
            t = state["precond_grad"](
                *state["QLs_exprs"],
                tensor.squeeze(),
            )
            pre_tensors.append(t.view_as(param))

        # norm clipping
        grad_clip_max_amp = settings[0]["grad_clip_max_amp"]
        if grad_clip_max_amp < math.inf:
            pre_tensors = TensorList(pre_tensors)
            num_params = sum(t.numel() for t in pre_tensors)

            avg_amp = pre_tensors.dot(pre_tensors.conj()).div(num_params).sqrt()

            if avg_amp > grad_clip_max_amp:
                torch._foreach_mul_(pre_tensors, grad_clip_max_amp / avg_amp)

        return pre_tensors

PSGDLRANewton

Bases: torchzero.core.transform.Transform

Low-rank Hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

Parameters:

  • rank (int, default: 10 ) –

    The preconditioner has a diagonal part and a low-rank part; this setting controls the rank of the low-rank part. Defaults to 10.

  • init_scale (float | None, default: None ) –

    initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.

  • lr_preconditioner (float, default: 0.1 ) –

    learning rate of the preconditioner. Defaults to 0.1.

  • betaL (float, default: 0.9 ) –

    EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.

  • damping (float, default: 1e-09 ) –

    adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.

  • grad_clip_max_norm (float, default: inf ) –

    clips norm of the update. Defaults to float("inf").

  • update_probability (float, default: 1.0 ) –

    probability of updating preconditioner on each step. Defaults to 1.0.

  • hvp_method (Literal, default: 'autograd' ) –

    how to compute hessian-vector products. Defaults to 'autograd'.

  • h (float, default: 0.001 ) –

    if hvp_method is "fd_central" or "fd_forward", controls finite difference step size. Defaults to 1e-3.

  • distribution (Literal, default: 'normal' ) –

    distribution for random vectors for hessian-vector products. Defaults to 'normal'.

  • inner (Chainable | None, default: None ) –

    preconditioning will be applied to output of this module. Defaults to None.

Examples:

Pure LRA Newton PSGD:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.LRANewton(),
    tz.m.LR(1e-3),
)

Applying preconditioner to momentum:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.LRANewton(inner=tz.m.EMA(0.9)),
    tz.m.LR(1e-3),
)

Source code in torchzero/modules/adaptive/psgd/psgd_lra_newton.py
class PSGDLRANewton(Transform):
    """Low rank hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        rank (int, optional):
            Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional):
            adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_norm (float, optional): clips norm of the update. Defaults to float("inf").
        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
        hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
        h (float, optional):
            if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
            Defaults to 1e-3.
        distribution (Distributions, optional):
            distribution for random vectors for hessian-vector products. Defaults to 'normal'.

        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.

    ###Examples:

    Pure LRA Newton PSGD:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.LRANewton(),
        tz.m.LR(1e-3),
    )
    ```

    Applying preconditioner to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.LRANewton(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```
    """
    def __init__(
        self,
        rank: int = 10,
        init_scale: float | None = None,
        lr_preconditioner=0.1,
        betaL=0.9,
        damping=1e-9,
        grad_clip_max_norm=float("inf"),
        update_probability=1.0,

        hvp_method: HVPMethod = 'autograd',
        h: float = 1e-3,
        distribution: Distributions = 'normal',

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        # initialize
        if "UVd" not in self.global_state:
            p = torch.cat([t.ravel() for t in objective.params])
            _initialize_lra_state_(p, self.global_state, fs)

        UVd = self.global_state["UVd"]
        if (torch.rand([]) < fs["update_probability"]) or (UVd[2] is None):

            # hessian-vector product
            vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
            Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])

            v = torch.cat([t.ravel() for t in vs]).unsqueeze(1)
            h = torch.cat([t.ravel() for t in Hvs]).unsqueeze(1)

            if UVd[2] is None:
                UVd[2] = (torch.mean(v*v))**(1/4) * (torch.mean(h**4) + fs["damping"]**4)**(-1/8) * torch.ones_like(v)

            # update preconditioner
            update_precond_lra_newton(UVd=UVd, Luvd=self.global_state["Luvd"], v=v, h=h, lr=fs["lr_preconditioner"], betaL=fs["betaL"], damping=fs["damping"])


    @torch.no_grad
    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()

        g = torch.cat([t.ravel() for t in updates]).unsqueeze(1) # column vec
        pre_grad = precond_grad_lra(UVd=self.global_state["UVd"], g=g)

        # norm clipping
        grad_clip_max_norm = settings[0]["grad_clip_max_norm"]
        if grad_clip_max_norm < float("inf"): # clip preconditioned gradient
            grad_norm = torch.linalg.vector_norm(pre_grad)
            if grad_norm > grad_clip_max_norm:
                pre_grad *= grad_clip_max_norm / grad_norm

        vec_to_tensors_(pre_grad, updates)
        return objective

PSGDLRAWhiten

Bases: torchzero.core.transform.TensorTransform

Low rank whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

Parameters:

  • rank (int, default: 10 ) –

    The preconditioner has a diagonal part and a low-rank part; this setting controls the rank of the low-rank part. Defaults to 10.

  • init_scale (float | None, default: None ) –

    initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.

  • lr_preconditioner (float, default: 0.1 ) –

    learning rate of the preconditioner. Defaults to 0.1.

  • betaL (float, default: 0.9 ) –

    EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.

  • damping (float, default: 1e-09 ) –

    adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.

  • grad_clip_max_amp (float, default: inf ) –

    clips the root-mean-square amplitude of the update. Defaults to float("inf").

  • update_probability (float, default: 1.0 ) –

    probability of updating preconditioner on each step. Defaults to 1.0.

  • concat_params (bool, default: True ) –

    if True, treats all parameters as concatenated to a single vector. If False, each parameter is preconditioned separately. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioning will be applied to output of this module. Defaults to None.
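In the source below, `single_tensor_apply` clips the preconditioned update by its root-mean-square amplitude, `sqrt(mean(pre_grad**2))` (setting `grad_clip_max_amp`), rather than by its L2 norm. A minimal plain-Python sketch of that rule, with a hypothetical helper name:

```python
import math


def clip_rms_amplitude(update, max_amp):
    """Scale `update` so that its root-mean-square amplitude is at most `max_amp`."""
    amp = math.sqrt(sum(u * u for u in update) / len(update))
    if amp > max_amp:
        scale = max_amp / amp
        return [u * scale for u in update]
    return list(update)  # already within the limit; returned unchanged
```

Because the amplitude is a mean over entries, the same threshold means the same thing regardless of how many parameters are concatenated, unlike the L2-norm clip used by PSGDLRANewton.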

Examples:

Pure PSGD LRA:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.LRAWhiten(),
    tz.m.LR(1e-3),
)

Momentum into preconditioner (whitens momentum):

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.EMA(0.9),
    tz.m.LRAWhiten(),
    tz.m.LR(1e-3),
)

Updating the preconditioner from gradients and applying it to momentum:

optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.LRAWhiten(inner=tz.m.EMA(0.9)),
    tz.m.LR(1e-3),
)

Source code in torchzero/modules/adaptive/psgd/psgd_lra_whiten.py
class PSGDLRAWhiten(TensorTransform):
    """Low rank whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)

    Args:
        rank (int, optional):
            Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.
        init_scale (float | None, optional):
            initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
        lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
        betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
        damping (float, optional):
            adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
        grad_clip_max_amp (float, optional): clips the root-mean-square amplitude of the update. Defaults to float("inf").
        update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
        concat_params (bool, optional):
            if True, treats all parameters as concatenated to a single vector.
            If False, each parameter is preconditioned separately. Defaults to True.
        inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.

    ###Examples:

    Pure PSGD LRA:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.LRAWhiten(),
        tz.m.LR(1e-3),
    )
    ```

    Momentum into preconditioner (whitens momentum):
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.EMA(0.9),
        tz.m.LRAWhiten(),
        tz.m.LR(1e-3),
    )
    ```

    Updating the preconditioner from gradients and applying it to momentum:
    ```py
    optimizer = tz.Optimizer(
        model.parameters(),
        tz.m.LRAWhiten(inner=tz.m.EMA(0.9)),
        tz.m.LR(1e-3),
    )
    ```

    """
    def __init__(
        self,
        rank: int = 10,
        init_scale: float | None = None,
        lr_preconditioner=0.1,
        betaL=0.9,
        damping=1e-9,
        grad_clip_max_amp=float("inf"),
        update_probability=1.0,

        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults["inner"], defaults["self"]
        super().__init__(defaults, concat_params=concat_params, inner=inner)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        _initialize_lra_state_(tensor, state, setting)

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):

        g = tensor.ravel().unsqueeze(1) # column vector

        UVd = state["UVd"]
        if UVd[2] is None: # initialize d on the fly
            UVd[2] = (torch.mean(g**4) + setting["damping"]**4)**(-1/8) * torch.ones_like(g)

        if torch.rand([]) < setting["update_probability"]:  # update preconditioner
            update_precond_lra_whiten(
                UVd=UVd,
                Luvd=state["Luvd"],
                g=g,
                lr=setting["lr_preconditioner"],
                betaL=setting["betaL"],
                damping=setting["damping"],
            )

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):

        g = tensor.ravel().unsqueeze(1)
        pre_grad = precond_grad_lra(UVd=state["UVd"], g=g)

        # norm clipping
        grad_clip_max_amp = setting["grad_clip_max_amp"]
        if grad_clip_max_amp < float("inf"): # clip preconditioned gradient
            amp = torch.sqrt(torch.mean(pre_grad * pre_grad))
            if amp > grad_clip_max_amp:
                pre_grad *= grad_clip_max_amp/amp

        return pre_grad.view_as(tensor)

Params

Bases: torchzero.core.module.Module

Outputs a copy of the parameters as the update.

Source code in torchzero/modules/ops/utility.py
class Params(Module):
    """Outputs parameters"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def apply(self, objective):
        objective.updates = [p.clone() for p in objective.params]
        return objective

Pearson

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Pearson's Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class Pearson(_InverseHessianUpdateStrategyDefaults):
    """
    Pearson's Quasi-Newton method.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return pearson_H_(H=H, s=s, y=y)

PerturbWeights

Bases: torchzero.core.module.Module

Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.

Can be disabled for a parameter by setting perturb=False in the corresponding parameter group.

Parameters:

  • alpha (float, default: 0.1 ) –

    multiplier for perturbation magnitude. Defaults to 0.1.

  • relative (bool, default: True ) –

    whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.

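  • distribution (Distributions, default: 'normal' ) –

    distribution of the random perturbation: 'normal' (or 'gaussian'), 'uniform', or 'sphere'. Defaults to 'normal'.

  • metric (Metrics, default: 'mad' ) –

    metric used to measure the magnitude of the parameter when relative=True. Defaults to 'mad'.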
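The per-parameter logic in the source below (scale alpha, draw a perturbation, then evaluate the closure at the shifted weights and shift them back) can be sketched on plain lists. `make_perturbation` and `perturbed_loss` are hypothetical helpers, and the relative scale is assumed to be the mean absolute value of the parameter, as the `relative` description states; the library computes it with `evaluate_metric(p, metric)`, which may differ for other metrics.

```python
import math
import random


def make_perturbation(param, alpha=0.1, relative=True, distribution="normal", rng=random):
    """Random perturbation for one flat parameter, mirroring the three supported distributions."""
    if relative:
        # assumption: relative scale = mean absolute value of the parameter
        alpha *= sum(abs(p) for p in param) / len(param)
    if distribution in ("normal", "gaussian"):
        return [rng.gauss(0.0, 1.0) * alpha for _ in param]
    if distribution == "uniform":
        return [rng.uniform(-alpha, alpha) for _ in param]
    if distribution == "sphere":
        # random direction with L2 norm exactly alpha
        r = [rng.gauss(0.0, 1.0) for _ in param]
        norm = math.sqrt(sum(x * x for x in r))
        return [x * alpha / norm for x in r]
    raise ValueError(distribution)


def perturbed_loss(loss_fn, param, pert):
    """Evaluate loss_fn at param + pert, then restore param in place."""
    for i, d in enumerate(pert):
        param[i] += d
    try:
        return loss_fn(param)
    finally:
        for i, d in enumerate(pert):
            param[i] -= d
```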

Source code in torchzero/modules/misc/regularization.py
class PerturbWeights(Module):
    """
    Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.

    Can be disabled for a parameter by setting ``perturb=False`` in corresponding parameter group.

    Args:
        alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
        relative (bool, optional): whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.
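        distribution (Distributions, optional):
            distribution of the random perturbation. Defaults to "normal".
        metric (Metrics, optional):
            metric used to measure the magnitude of the parameter when ``relative=True``. Defaults to "mad".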
    """

    def __init__(
        self,
        alpha: float = 0.1,
        relative: bool = True,
        distribution: Distributions = "normal",
        metric: Metrics = "mad",
    ):
        defaults = dict(alpha=alpha, relative=relative, distribution=distribution, metric=metric, perturb=True)
        super().__init__(defaults)

    @torch.no_grad
    def update(self, objective):
        closure = objective.closure
        if closure is None: raise RuntimeError('PerturbWeights requires closure')
        params = TensorList(objective.params)

        # create perturbations
        perts = []
        for p in params:
            settings = self.settings[p]
            if not settings['perturb']:
                perts.append(torch.zeros_like(p))
                continue

            alpha = settings['alpha']
            if settings['relative']:
                alpha *= evaluate_metric(p, settings["metric"])

            distribution = self.settings[p]['distribution'].lower()
            if distribution in ('normal', 'gaussian'):
                perts.append(torch.randn_like(p).mul_(alpha))
            elif distribution == 'uniform':
                perts.append(torch.empty_like(p).uniform_(-alpha,alpha))
            elif distribution == 'sphere':
                r = torch.randn_like(p)
                perts.append((r * alpha) / torch.linalg.vector_norm(r)) # pylint:disable=not-callable
            else:
                raise ValueError(distribution)

        @torch.no_grad
        def perturbed_closure(backward=True):
            params.add_(perts)
            if backward:
                with torch.enable_grad(): loss = closure()
            else:
                loss = closure(False)
            params.sub_(perts)
            return loss

        objective.closure = perturbed_closure

PolakRibiere

Bases: torchzero.modules.conjugate_gradient.cg.ConguateGradientBase

Polak-Ribière-Polyak nonlinear conjugate gradient method.

Note

This method requires the step size to be determined by a line search, so place a line search such as tz.m.StrongWolfe(c2=0.1, a_init="first-order") after it.
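As a rough illustration, the sketch below runs nonlinear conjugate gradient with the textbook Polak-Ribière-Polyak coefficient, beta = g^T (g - g_prev) / ||g_prev||^2, on a small quadratic. Negative beta is clipped to zero (the "PR+" variant); we assume this is what `clip_beta=True` does, but have not checked `polak_ribiere_beta` itself. An exact line search, available in closed form for a quadratic, stands in for the StrongWolfe line search the note recommends. All names are hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def polak_ribiere_beta(g, g_prev, clip=True):
    """PRP beta = g^T (g - g_prev) / ||g_prev||^2, optionally clipped at zero (PR+)."""
    beta = dot(g, [gi - pi for gi, pi in zip(g, g_prev)]) / dot(g_prev, g_prev)
    return max(beta, 0.0) if clip else beta


def cg_on_quadratic(A, b, x, iters):
    """Nonlinear CG with PRP beta on f(x) = 0.5 x^T A x - b^T x (A symmetric positive definite)."""
    def grad(x):
        return [dot(row, x) - bi for row, bi in zip(A, b)]

    g = grad(x)
    d = [-gi for gi in g]  # first direction: steepest descent
    for _ in range(iters):
        Ad = [dot(row, d) for row in A]
        alpha = -dot(g, d) / dot(d, Ad)  # exact line search for a quadratic
        x = [xi + alpha * di for xi, di in zip(x, d)]
        g_new = grad(x)
        d = [-gn + polak_ribiere_beta(g_new, g) * di for gn, di in zip(g_new, d)]
        g = g_new
    return x
```

On a 2-D quadratic with exact line searches, two iterations reach the minimizer, e.g. `cg_on_quadratic([[3.0, 1.0], [1.0, 2.0]], [1.0, 1.0], [0.0, 0.0], 2)` lands at approximately [0.2, 0.4].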

Source code in torchzero/modules/conjugate_gradient/cg.py
class PolakRibiere(ConguateGradientBase):
    """Polak-Ribière-Polyak nonlinear conjugate gradient method.

    Note:
        This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
    """
    def __init__(self, clip_beta=True, restart_interval: int | None | Literal['auto'] = 'auto', inner: Chainable | None = None):
        super().__init__({}, clip_beta=clip_beta, restart_interval=restart_interval, inner=inner)

    def get_beta(self, p, g, prev_g, prev_d):
        return polak_ribiere_beta(g, prev_g)

PolyakStepSize

Bases: torchzero.core.transform.TensorTransform

Polyak's subgradient method with known or unknown f*.

Parameters:

  • f_star (float | None, default: 0 ) –

    minimum value of the objective function. If it is not known, set it to None. Defaults to 0.

  • y (float, default: 1 ) –

    when f_star is set to None, it is estimated as f_best - y, where f_best is the lowest loss seen so far. Defaults to 1.

  • y_decay (float, default: 0.001 ) –

    y is multiplied by (1 - y_decay) after each step. Defaults to 1e-3.

  • max (float | None, default: None ) –

    maximum possible step size. Defaults to None.

  • use_grad (bool, default: True ) –

    if True, the step size is computed from the dot product of the update and the gradient. Otherwise, the dot product of the update with itself is used. Defaults to True.

  • alpha (float, default: 1 ) –

    multiplier for the Polyak step size. Defaults to 1.
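The source below computes alpha = (loss - f_star) / (update . grad) and, when f_star is None, uses f_best - y with y shrinking by (1 - y_decay) each step. A plain-Python sketch of the same rule applied to gradient descent on f(x) = sum(x_i^2), where the update is the gradient itself; the function names are hypothetical and a fixed tiny threshold stands in for the library's `torch.finfo(...).tiny` check.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def polyak_step_size(loss, f_star, grad, update, max_step=None):
    """alpha = (f(x) - f*) / (update . grad), optionally capped at max_step."""
    denom = dot(update, grad)
    if denom <= 1e-300:  # vanishing denominator: treat as converged
        return 0.0
    alpha = (loss - f_star) / denom
    if max_step is not None:
        alpha = min(alpha, max_step)
    return alpha


def minimize_sum_of_squares(x, steps, f_star=None, y=1.0, y_decay=1e-3):
    """Gradient descent on f(x) = sum(x_i^2) with Polyak step sizes."""
    f_best = None
    for _ in range(steps):
        loss = dot(x, x)
        g = [2.0 * xi for xi in x]
        if f_best is None or loss < f_best:
            f_best = loss
        # unknown f*: estimate it as f_best - y, with y shrinking every step
        target = f_star if f_star is not None else f_best - y
        alpha = polyak_step_size(loss, target, g, g)
        x = [xi - alpha * gi for xi, gi in zip(x, g)]
        y *= 1.0 - y_decay
    return x
```

With the true f* = 0, the step size is always 1/4 for this function, so every step halves x: three steps from [1.0, -2.0] give [0.125, -0.25].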

Source code in torchzero/modules/step_size/adaptive.py
class PolyakStepSize(TensorTransform):
    """Polyak's subgradient method with known or unknown f*.

    Args:
        f_star (float | None, optional):
            minimal possible value of the objective function. If not known, set to ``None``. Defaults to 0.
        y (float, optional):
            when ``f_star`` is set to None, it is calculated as ``f_best - y``.
        y_decay (float, optional):
            ``y`` is multiplied by ``(1 - y_decay)`` after each step. Defaults to 1e-3.
        max (float | None, optional): maximum possible step size. Defaults to None.
        use_grad (bool, optional):
            if True, uses dot product of update and gradient to compute the step size.
            Otherwise, dot product of update with itself is used.
        alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
    """
    def __init__(self, f_star: float | None = 0, y: float = 1, y_decay: float = 1e-3, max: float | None = None, use_grad=True, alpha: float = 1, inner: Chainable | None = None):

        defaults = dict(alpha=alpha, max=max, f_star=f_star, y=y, y_decay=y_decay)
        super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        assert grads is not None and loss is not None
        tensors = TensorList(tensors)
        grads = TensorList(grads)

        # load variables
        max, f_star, y, y_decay = itemgetter('max', 'f_star', 'y', 'y_decay')(settings[0])
        y_val = self.global_state.get('y_val', y)
        f_best = self.global_state.get('f_best', None)

        # gg
        if self._uses_grad: gg = tensors.dot(grads)
        else: gg = tensors.dot(tensors)

        # store loss
        if f_best is None or loss < f_best: f_best = tofloat(loss)
        if f_star is None: f_star = f_best - y_val

        # calculate the step size
        if gg <= torch.finfo(gg.dtype).tiny * 2: alpha = 0 # converged
        else: alpha = (loss - f_star) / gg

        # clip
        if max is not None:
            if alpha > max: alpha = max

        # store state
        self.global_state['f_best'] = f_best
        self.global_state['y_val'] = y_val * (1 - y_decay)
        self.global_state['alpha'] = alpha

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        alpha = self.global_state.get('alpha', 1)
        if not _acceptable_alpha(alpha, tensors[0]): alpha = epsilon_step_size(TensorList(tensors))

        torch._foreach_mul_(tensors, alpha * unpack_dicts(settings, 'alpha', cls=NumberList))
        return tensors

    def get_H(self, objective):
        return _get_scaled_identity_H(self, objective)

Pow

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Take tensors to the power of exponent. exponent can be a number or a module.

If exponent is a module, this calculates tensors ^ exponent(tensors).

Source code in torchzero/modules/ops/binary.py
class Pow(BinaryOperationBase):
    """Take tensors to the power of ``exponent``. ``exponent`` can be a number or a module.

    If ``exponent`` is a module, this calculates ``tensors ^ exponent(tensors)``
    """
    def __init__(self, exponent: Chainable | float):
        super().__init__({}, exponent=exponent)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
        torch._foreach_pow_(update, exponent)
        return update

PowModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates input ** exponent. input and exponent can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class PowModules(MultiOperationBase):
    """Calculates ``input ** exponent``. ``input`` and ``exponent`` can be numbers or modules."""
    def __init__(self, input: Chainable | float, exponent: Chainable | float):
        defaults = {}
        super().__init__(defaults, input=input, exponent=exponent)

    @torch.no_grad
    def transform(self, objective: Objective, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
        if isinstance(input, (int,float)):
            assert isinstance(exponent, list)
            return input ** TensorList(exponent)

        torch._foreach_pow_(input, exponent) # elementwise input ** exponent
        return input

PowellRestart

Bases: torchzero.modules.restarts.restars.RestartStrategyBase

Powell's two restart criteria for conjugate gradient methods.

A restart clears the state of all modules in modules.

Parameters:

  • modules (Chainable | None) –

    modules to reset. If None, resets all modules.

  • cond1 (float | None, default: 0.2 ) –

    criterion that checks for nonconjugacy of the search directions. A restart is performed whenever |g^T g_{k+1}| >= cond1*||g_{k+1}||^2. The default value of 0.2 is suggested by Powell. Can be None to disable this criterion.

  • cond2 (float | None, default: 0.2 ) –

    criterion that checks if direction is not effectively downhill. Restart is performed if -1.2||g||^2 < d^Tg < -0.8||g||^2. Defaults to 0.2. Can be None to disable that criterion.

Reference

Powell, Michael James David. "Restart procedures for the conjugate gradient method." Mathematical programming 12.1 (1977): 241-254.
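The first criterion, as the source below implements it (with an absolute value), restarts when consecutive gradients are far from orthogonal. A plain-Python sketch with a hypothetical helper name:

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def powell_nonconjugacy_restart(g_prev, g, cond1=0.2):
    """True when |g_prev . g| >= cond1 * ||g||^2, i.e. the gradients have lost orthogonality."""
    return abs(dot(g_prev, g)) >= cond1 * dot(g, g)
```

For exact conjugate gradient on a quadratic, successive gradients are orthogonal, so a large overlap signals that conjugacy has been lost and the accumulated direction is no longer useful.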

Source code in torchzero/modules/restarts/restars.py
class PowellRestart(RestartStrategyBase):
    """Powell's two restart criteria for conjugate gradient methods.

    The restart clears all states of ``modules``.

    Args:
        modules (Chainable | None):
            modules to reset. If None, resets all modules.
        cond1 (float | None, optional):
            criterion that checks for nonconjugacy of the search directions.
            Restart is performed whenever |g^T g_{k+1}| >= cond1*||g_{k+1}||^2.
            The default condition value of 0.2 is suggested by Powell. Can be None to disable that criterion.
        cond2 (float | None, optional):
            criterion that checks if direction is not effectively downhill.
            Restart is performed if -1.2||g||^2 < d^Tg < -0.8||g||^2.
            Defaults to 0.2. Can be None to disable that criterion.

    Reference:
        Powell, Michael James David. "Restart procedures for the conjugate gradient method." Mathematical programming 12.1 (1977): 241-254.
    """
    def __init__(self, modules: Chainable | None, cond1:float | None = 0.2, cond2:float | None = 0.2):
        defaults=dict(cond1=cond1, cond2=cond2)
        super().__init__(defaults, modules)

    def should_reset(self, objective):
        g = TensorList(objective.get_grads())
        cond1 = self.defaults['cond1']; cond2 = self.defaults['cond2']

        # -------------------------------- initialize -------------------------------- #
        if 'initialized' not in self.global_state:
            self.global_state['initialized'] = 0
            g_prev = self.get_state(objective.params, 'g_prev', init=g)
            return False

        g_g = g.dot(g)

        reset = False
        # ------------------------------- 1st condition ------------------------------ #
        if cond1 is not None:
            g_prev = self.get_state(objective.params, 'g_prev', must_exist=True, cls=TensorList)
            g_g_prev = g_prev.dot(g)

            if g_g_prev.abs() >= cond1 * g_g:
                reset = True

        # ------------------------------- 2nd condition ------------------------------ #
        if (cond2 is not None) and (not reset):
            d_g = TensorList(objective.get_updates()).dot(g)
            if (-1-cond2) * g_g < d_g < (-1 + cond2) * g_g:
                reset = True

        # ------------------------------ clear on reset ------------------------------ #
        if reset:
            self.global_state.clear()
            self.clear_state_keys('g_prev')
            return True

        return False

Previous

Bases: torchzero.core.transform.TensorTransform

Returns the update from n steps back; for example, with n=1 it returns the previous update.
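The source below keeps a `deque(maxlen=n+1)` per parameter and returns its oldest entry. `PreviousSketch` is a hypothetical stand-in that applies the same logic to plain values, which makes the warm-up behaviour easy to see:

```python
from collections import deque


class PreviousSketch:
    """Returns the value passed in n calls ago (the oldest stored value during warm-up)."""

    def __init__(self, n=1):
        self.history = deque(maxlen=n + 1)  # current value plus n past values

    def __call__(self, value):
        self.history.append(value)  # drops the oldest entry once full
        return self.history[0]
```

During the first n steps the deque is not yet full, so the very first update is returned rather than zeros: with n=1, inputs 1, 2, 3, 4 produce 1, 1, 2, 3.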

Source code in torchzero/modules/misc/misc.py
class Previous(TensorTransform):
    """Maintains an update from n steps back, for example if n=1, returns previous update"""
    def __init__(self, n=1):
        defaults = dict(n=n)
        super().__init__(defaults=defaults)

        self.add_projected_keys("grad", "history")

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        n = setting['n']

        if 'history' not in state:
            state['history'] = deque(maxlen=n+1)

        state['history'].append(tensor)

        return state['history'][0]

PrintLoss

Bases: torchzero.core.module.Module

Prints the current loss, as returned by objective.get_loss(False).

Source code in torchzero/modules/misc/debug.py
class PrintLoss(Module):
    """Prints the current loss, as returned by ``objective.get_loss(False)``."""
    def __init__(self, text = 'loss = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def apply(self, objective):
        self.defaults["print_fn"](f'{self.defaults["text"]}{objective.get_loss(False)}')
        return objective

PrintParams

Bases: torchzero.core.module.Module

Prints current parameters.

Source code in torchzero/modules/misc/debug.py
class PrintParams(Module):
    """Prints current parameters."""
    def __init__(self, text = 'params = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def apply(self, objective):
        self.defaults["print_fn"](f'{self.defaults["text"]}{objective.params}')
        return objective

PrintShape

Bases: torchzero.core.module.Module

Prints shapes of the update.

Source code in torchzero/modules/misc/debug.py
class PrintShape(Module):
    """Prints shapes of the update."""
    def __init__(self, text = 'shapes = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def apply(self, objective):
        shapes = [u.shape for u in objective.updates] if objective.updates is not None else None
        self.defaults["print_fn"](f'{self.defaults["text"]}{shapes}')
        return objective

PrintUpdate

Bases: torchzero.core.module.Module

Prints current update.

Source code in torchzero/modules/misc/debug.py
class PrintUpdate(Module):
    """Prints current update."""
    def __init__(self, text = 'update = ', print_fn = print):
        defaults = dict(text=text, print_fn=print_fn)
        super().__init__(defaults)

    def apply(self, objective):
        self.defaults["print_fn"](f'{self.defaults["text"]}{objective.updates}')
        return objective

Prod

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs the product of inputs, each of which can be a module or a number.

Source code in torchzero/modules/ops/reduce.py
class Prod(ReduceOperationBase):
    """Outputs product of ``inputs`` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        # put tensor lists first so the running product is always a list of tensors
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, (int, float)))
        prod = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_mul_(prod, v)

        return prod

ProjectedGradientMethod

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Projected gradient method. Directly projects the gradient onto the subspace conjugate to past directions.

Notes
  • This method uses O(N^2) memory.
  • This requires step size to be determined via a line search, so put a line search like tz.m.StrongWolfe(c2=0.1, a_init="first-order") after this.
  • This is not the same as projected gradient descent.

Reference

Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171. (algorithm 5 in section 6)
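As an illustration of "projects onto the subspace conjugate to past directions", the sketch below assumes the classic projected-gradient update H <- H - (H y)(H y)^T / (y^T H y), starting from H = I, with search direction -H g. We have not checked this against the library's `projected_gradient_`, and the helper names are hypothetical. After each update H y = 0, so on a quadratic with Hessian A (where y = A s) the new direction -H g satisfies s^T A H g = (H A s)^T g = 0, i.e. it is A-conjugate to the previous step.

```python
def matvec(A, x):
    return [sum(a * b for a, b in zip(row, x)) for row in A]


def projected_gradient_update(H, y):
    """H <- H - (H y)(H y)^T / (y^T H y), assuming H is symmetric."""
    Hy = matvec(H, y)
    yHy = sum(a * b for a, b in zip(y, Hy))
    n = len(y)
    return [[H[i][j] - Hy[i] * Hy[j] / yHy for j in range(n)] for i in range(n)]
```

Each update removes one dimension, so after N linearly independent y vectors H becomes zero; this is why the method needs restarts (see restart_interval) and stores a dense N x N matrix.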

Source code in torchzero/modules/conjugate_gradient/cg.py
class ProjectedGradientMethod(HessianUpdateStrategy): # this doesn't maintain hessian
    """Projected gradient method. Directly projects the gradient onto subspace conjugate to past directions.

    Notes:
        - This method uses N^2 memory.
        - This requires step size to be determined via a line search, so put a line search like ``tz.m.StrongWolfe(c2=0.1, a_init="first-order")`` after this.
        - This is not the same as projected gradient descent.

    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.  (algorithm 5 in section 6)

    """

    def __init__(
        self,
        init_scale: float | Literal["auto"] = 1,
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None | Literal['auto'] = 'auto',
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        # inverse: bool = True,
        inner: Chainable | None = None,
    ):
        super().__init__(
            defaults=None,
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
        )



    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return projected_gradient_(H=H, y=y)

ProjectedNewtonRaphson

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Projected Newton-Raphson method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.

This one is Algorithm 7.

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class ProjectedNewtonRaphson(HessianUpdateStrategy):
    """
    Projected Newton Raphson method.

    Note:
        A line search is recommended.

    Warning:
        This uses at least O(N^2) memory.

    Reference:
        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.

        This method is Algorithm 7 in that paper.
    """
    def __init__(
        self,
        init_scale: float | Literal["auto"] = 'auto',
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None | Literal['auto'] = 'auto',
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        super().__init__(
            init_scale=init_scale,
            tol=tol,
            ptol = ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
        )

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
        H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y)
        state["R"] = R
        return H

    def reset_P(self, P, s, y, inverse, init_scale, state):
        assert inverse
        if 'R' not in state: state['R'] = torch.eye(P.size(-1), device=P.device, dtype=P.dtype)
        P.copy_(state["R"])

ProjectionBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for projections. This is an abstract class; to use it, subclass it and override project and unproject.

Parameters:

  • modules (Chainable) –

    modules that will be applied in the projected domain.

  • project_update (bool, default: True ) –

    whether to project the update. Defaults to True.

  • project_params (bool, default: False ) –

    whether to project the params. This is necessary for modules that use closure. Defaults to False.

  • project_grad (bool, default: False ) –

    whether to project the gradients (separately from update). Defaults to False.

  • defaults (dict[str, Any] | None, default: None ) –

    dictionary with defaults. Defaults to None.

Methods:

  • project

    Projects tensors. Note that this can be called multiple times per step with params, grads, and update.

  • unproject

    Unprojects tensors. Note that this can be called multiple times per step with params, grads, and update.

Source code in torchzero/modules/projections/projection.py
class ProjectionBase(Module, ABC):
    """
    Base class for projections.
    This is an abstract class; to use it, subclass it and override ``project`` and ``unproject``.

    Args:
        modules (Chainable): modules that will be applied in the projected domain.
        project_update (bool, optional): whether to project the update. Defaults to True.
        project_params (bool, optional):
            whether to project the params. This is necessary for modules that use closure. Defaults to False.
        project_grad (bool, optional): whether to project the gradients (separately from update). Defaults to False.
        defaults (dict[str, Any] | None, optional): dictionary with defaults. Defaults to None.
    """

    def __init__(
        self,
        modules: Chainable,
        project_update=True,
        project_params=False,
        project_grad=False,
        defaults: dict[str, Any] | None = None,
    ):
        super().__init__(defaults)
        self.set_child('modules', modules)
        self.global_state['current_step'] = 0
        self._project_update = project_update
        self._project_params = project_params
        self._project_grad = project_grad
        self._projected_params = None

        self._states: dict[str, list[dict[str, Any]]] = {}
        """per-parameter states for each projection target"""

    @abstractmethod
    def project(
        self,
        tensors: list[torch.Tensor],
        params: list[torch.Tensor],
        grads: list[torch.Tensor] | None,
        loss: torch.Tensor | None,
        states: list[dict[str, Any]],
        settings: list[ChainMap[str, Any]],
        current: str,
    ) -> Iterable[torch.Tensor]:
        """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""

    @abstractmethod
    def unproject(
        self,
        projected_tensors: list[torch.Tensor],
        params: list[torch.Tensor],
        grads: list[torch.Tensor] | None,
        loss: torch.Tensor | None,
        states: list[dict[str, Any]],
        settings: list[ChainMap[str, Any]],
        current: str,
    ) -> Iterable[torch.Tensor]:
        """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`.

        Args:
            projected_tensors (list[torch.Tensor]): projected tensors to unproject.
            params (list[torch.Tensor]): original, unprojected parameters.
            grads (list[torch.Tensor] | None): original, unprojected gradients
            loss (torch.Tensor | None): loss at initial point.
            states (list[dict[str, Any]]): list of state dictionaries per each UNPROJECTED tensor.
            settings (list[ChainMap[str, Any]]): list of setting dictionaries per each UNPROJECTED tensor.
            current (str): string representing what is being unprojected, e.g. "params", "grads" or "update".

        Returns:
            Iterable[torch.Tensor]: unprojected tensors of the same shape as params
        """

    def update(self, objective: Objective): raise RuntimeError("projections don't support update/apply")
    def apply(self, objective: Objective): raise RuntimeError("projections don't support update/apply")

    @torch.no_grad
    def step(self, objective: Objective):
        params = objective.params
        settings = [self.settings[p] for p in params]

        def _project(tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
            states = self._states.setdefault(current, [{} for _ in params])
            return list(self.project(
                tensors=tensors,
                params=params,
                grads=objective.grads,
                loss=objective.loss,
                states=states,
                settings=settings,
                current=current,
            ))

        projected_obj = objective.clone(clone_updates=False, parent=objective)

        closure = objective.closure

        # if this is True, update and grad were projected simultaneously under current="grads"
        # so update will have to be unprojected with current="grads"
        update_is_grad = False

        # if closure is provided and project_params=True, make new closure that evaluates projected params
        # that also means projected modules can evaluate grad/update at will, it shouldn't be computed here
        # but if it has already been computed, it should be projected
        if self._project_params and closure is not None:

            if self._project_update and objective.updates is not None:
                # project update only if it already exists
                projected_obj.updates = _project(objective.updates, current='update')

            else:
                # update will be set to gradients on objective.get_grads()
                # therefore projection will happen with current="grads"
                update_is_grad = True

            # project grad only if it already exists
            if self._project_grad and objective.grads is not None:
                projected_obj.grads = _project(objective.grads, current='grads')

        # otherwise update/grad needs to be calculated and projected here
        else:
            if self._project_update:
                if objective.updates is None:
                    # update is None, meaning it will be set to `grad`.
                    # we can project grad and use it for update
                    grad = objective.get_grads()
                    projected_obj.grads = _project(grad, current='grads')
                    projected_obj.updates = [g.clone() for g in projected_obj.grads]
                    del objective.updates
                    update_is_grad = True

                else:
                    # update exists so it needs to be projected
                    update = objective.get_updates()
                    projected_obj.updates = _project(update, current='update')
                    del update, objective.updates

            if self._project_grad and projected_obj.grads is None:
                # projected_obj.grads may have been projected simultaneously with update
                # but if that didn't happen, it is projected here
                grad = objective.get_grads()
                projected_obj.grads = _project(grad, current='grads')


        original_params = None
        if self._project_params:
            original_params = [p.clone() for p in objective.params]
            projected_params = _project(objective.params, current='params')

        else:
            # make fake params for correct shapes and state storage
            # they reuse update or grad storage for memory efficiency
            projected_params = projected_obj.updates if projected_obj.updates is not None else projected_obj.grads
            assert projected_params is not None

        if self._projected_params is None:
            # 1st step - create objects for projected_params. They have to remain the same python objects
            # to support per-parameter states which are stored by ids.
            self._projected_params = [p.view_as(p).requires_grad_() for p in projected_params]
        else:
            # set storage to new fake params while ID remains the same
            for empty_p, new_p in zip(self._projected_params, projected_params):
                empty_p.set_(new_p.view_as(new_p).requires_grad_()) # pyright: ignore[reportArgumentType]

        projected_params = self._projected_params
        # projected_settings = [self.settings[p] for p in projected_params]

        def _unproject(projected_tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
            states = self._states.setdefault(current, [{} for _ in params])
            return list(self.unproject(
                projected_tensors=projected_tensors,
                params=params,
                grads=objective.grads,
                loss=objective.loss,
                states=states,
                settings=settings,
                current=current,
            ))

        # project closure
        if self._project_params:
            projected_obj.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
                                                            params=params, projected_params=projected_params)

        elif closure is not None:
            projected_obj.closure = _FakeProjectedClosure(closure, project_fn=_project,
                                                          params=params, fake_params=projected_params)

        else:
            projected_obj.closure = None

        # ----------------------------------- step ----------------------------------- #
        projected_obj.params = projected_params
        projected_obj = self.children['modules'].step(projected_obj)

        # empty fake params storage
        # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
        if not self._project_params:
            for p in self._projected_params:
                set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))

        # --------------------------------- unproject -------------------------------- #
        unprojected_obj = projected_obj.clone(clone_updates=False)
        unprojected_obj.closure = objective.closure
        unprojected_obj.params = objective.params
        unprojected_obj.grads = objective.grads # this may also be set by projected_obj since it has objective as parent

        if self._project_update:
            assert projected_obj.updates is not None
            unprojected_obj.updates = _unproject(projected_obj.updates, current='grads' if update_is_grad else 'update')
            del projected_obj.updates

        del projected_obj

        # original params are stored if params are projected
        if original_params is not None:
            for p, o in zip(unprojected_obj.params, original_params):
                p.set_(o) # pyright: ignore[reportArgumentType]

        return unprojected_obj
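The branching in step decides which tensors pass through project, and under which current label the update is later unprojected. The pure-Python sketch below mirrors that control flow as a hypothetical helper, plan_projection. You can use it to check the combinations of project_update, project_params and project_grad quickly. It follows the comments in the listing and assumes the cloned objective starts without gradients of its own.

```python
def plan_projection(project_update: bool, project_params: bool, project_grad: bool,
                    has_closure: bool, update_exists: bool, grad_exists: bool):
    """Hypothetical helper mirroring ProjectionBase.step.

    Returns (projected, unproject_update_as): the `current` labels passed to `project`
    in call order, and the label used to unproject the update (None if it is not unprojected).
    """
    projected = []
    update_is_grad = False
    if project_params and has_closure:
        # projected modules may evaluate grads themselves through the projected closure
        if project_update and update_exists:
            projected.append("update")
        else:
            update_is_grad = True  # update will later be filled from the projected gradient
        if project_grad and grad_exists:
            projected.append("grads")
    else:
        grads_projected = False
        if project_update:
            if not update_exists:
                # gradient is computed, projected once, and cloned into the update
                projected.append("grads")
                grads_projected = True
                update_is_grad = True
            else:
                projected.append("update")
        if project_grad and not grads_projected:
            projected.append("grads")
    if project_params:
        projected.append("params")
    unproject_update_as = ("grads" if update_is_grad else "update") if project_update else None
    return projected, unproject_update_as


# Default settings, no existing update: the gradient is projected and reused as the update.
print(plan_projection(project_update=True, project_params=False, project_grad=False,
                      has_closure=True, update_exists=False, grad_exists=True))  # (['grads'], 'grads')
```

The key design point is that when the update is just the gradient, it is projected only once under "grads", and the same label is used to unproject it. This way the per-target states stay consistent.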

project

project(tensors: list[Tensor], params: list[Tensor], grads: list[Tensor] | None, loss: Tensor | None, states: list[dict[str, Any]], settings: list[ChainMap[str, Any]], current: str) -> Iterable[Tensor]

Projects tensors. Note that this can be called multiple times per step with params, grads, and update.

Source code in torchzero/modules/projections/projection.py
@abstractmethod
def project(
    self,
    tensors: list[torch.Tensor],
    params: list[torch.Tensor],
    grads: list[torch.Tensor] | None,
    loss: torch.Tensor | None,
    states: list[dict[str, Any]],
    settings: list[ChainMap[str, Any]],
    current: str,
) -> Iterable[torch.Tensor]:
    """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""

unproject

unproject(projected_tensors: list[Tensor], params: list[Tensor], grads: list[Tensor] | None, loss: Tensor | None, states: list[dict[str, Any]], settings: list[ChainMap[str, Any]], current: str) -> Iterable[Tensor]

Unprojects tensors. Note that this can be called multiple times per step with params, grads, and update.

Parameters:

  • projected_tensors (list[Tensor]) –

    projected tensors to unproject.

  • params (list[Tensor]) –

    original, unprojected parameters.

  • grads (list[Tensor] | None) –

    original, unprojected gradients

  • loss (Tensor | None) –

    loss at initial point.

  • states (list[dict[str, Any]]) –

    list of state dictionaries per each UNPROJECTED tensor.

  • settings (list[ChainMap[str, Any]]) –

    list of setting dictionaries per each UNPROJECTED tensor.

  • current (str) –

    string representing what is being unprojected, e.g. "params", "grads" or "update".

Returns:

  • Iterable[Tensor]

    Iterable[torch.Tensor]: unprojected tensors of the same shape as params

Source code in torchzero/modules/projections/projection.py
@abstractmethod
def unproject(
    self,
    projected_tensors: list[torch.Tensor],
    params: list[torch.Tensor],
    grads: list[torch.Tensor] | None,
    loss: torch.Tensor | None,
    states: list[dict[str, Any]],
    settings: list[ChainMap[str, Any]],
    current: str,
) -> Iterable[torch.Tensor]:
    """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`.

    Args:
        projected_tensors (list[torch.Tensor]): projected tensors to unproject.
        params (list[torch.Tensor]): original, unprojected parameters.
        grads (list[torch.Tensor] | None): original, unprojected gradients
        loss (torch.Tensor | None): loss at initial point.
        states (list[dict[str, Any]]): list of state dictionaries per each UNPROJECTED tensor.
        settings (list[ChainMap[str, Any]]): list of setting dictionaries per each UNPROJECTED tensor.
        current (str): string representing what is being unprojected, e.g. "params", "grads" or "update".

    Returns:
        Iterable[torch.Tensor]: unprojected tensors of the same shape as params
    """
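The project/unproject contract can be illustrated without torchzero. project maps the per-parameter tensors to some other representation, and unproject must map that representation back to tensors with the shapes of params. The sketch below is a hypothetical, framework-free stand-in. It uses flat Python lists instead of torch.Tensor and keeps only tensors, params and current from the real signatures. It concatenates everything into a single vector.

```python
class ConcatProjection:
    """Hypothetical projection: concatenate all tensors into one vector and split it back."""

    def project(self, tensors: list[list[float]], params: list[list[float]], current: str) -> list[list[float]]:
        # A single projected "tensor" holding every element, in parameter order.
        return [[x for t in tensors for x in t]]

    def unproject(self, projected_tensors: list[list[float]], params: list[list[float]], current: str) -> list[list[float]]:
        # Split the single vector using the sizes of the ORIGINAL params,
        # so the output has the same shapes as params.
        flat = projected_tensors[0]
        out, start = [], 0
        for p in params:
            out.append(flat[start:start + len(p)])
            start += len(p)
        return out


params = [[1.0, 2.0], [3.0]]
update = [[0.1, 0.2], [0.3]]
proj = ConcatProjection()
z = proj.project(update, params, current="update")
print(z)                                             # [[0.1, 0.2, 0.3]]
print(proj.unproject(z, params, current="update"))   # [[0.1, 0.2], [0.3]]
```

Because project and unproject may be called several times per step (for params, grads and update), a real subclass would use current and the per-target states to keep any random or cached projection data separate for each target.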

RCopySign

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns other(tensors) with sign copied from tensors.

Source code in torchzero/modules/ops/binary.py
class RCopySign(BinaryOperationBase):
    """Returns ``other(tensors)`` with sign copied from tensors."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [o.copysign_(u) for u, o in zip(update, other)]
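The argument order is easy to mix up: the magnitudes come from the output of other, and the signs come from the incoming tensors. A minimal plain-float sketch, where the function name is hypothetical and other stands for the already-computed output of the other module:

```python
import math


def rcopysign(update: list[float], other: list[float]) -> list[float]:
    """Magnitudes from `other`, signs from `update` (mirrors o.copysign_(u))."""
    return [math.copysign(o, u) for u, o in zip(update, other)]


print(rcopysign([-1.0, 2.0, -0.5], [3.0, -4.0, 5.0]))  # [-3.0, 4.0, -5.0]
```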

RDSA

Bases: torchzero.modules.grad_approximation.rfdm.RandomizedFDM

Gradient approximation via Random-direction stochastic approximation (RDSA) method.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.001 ) –

    finite difference step size used by formula. Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of random gradient samples. Defaults to 1.

  • formula (Literal, default: 'central2' ) –

    finite difference formula. Defaults to 'central2'.

  • distribution (Literal, default: 'gaussian' ) –

    distribution. Defaults to "gaussian".

  • pre_generate (bool, default: True ) –

    whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.

  • seed (int | None | Generator, default: None ) –

    Seed for random generator. Defaults to None.

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to "closure".

References

Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771

Source code in torchzero/modules/grad_approximation/rfdm.py
class RDSA(RandomizedFDM):
    """
    Gradient approximation via Random-direction stochastic approximation (RDSA) method.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): finite difference step size used by ``formula``. Defaults to 1e-3.
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
        distribution (Distributions, optional): distribution. Defaults to "gaussian".
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    References:
        Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771

    """
    def __init__(
        self,
        h: float = 1e-3,
        n_samples: int = 1,
        formula: _FD_Formula = "central2",
        distribution: Distributions = "gaussian",
        pre_generate: bool = True,
        return_approx_loss: bool = False,
        target: GradTarget = "closure",
        seed: int | None | torch.Generator = None,
    ):
        super().__init__(
            h=h, n_samples=n_samples, formula=formula, distribution=distribution,
            pre_generate=pre_generate, target=target, seed=seed, return_approx_loss=return_approx_loss,
        )
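With the default "central2" formula, the RDSA estimate averages central-difference directional derivatives along random Gaussian directions and multiplies each by its direction. This follows the general scheme in Chen (2021). The function below is a framework-free sketch, not torchzero's RandomizedFDM code. The name rdsa_grad and its arguments are hypothetical, and it does not model pre_generate, target or return_approx_loss.

```python
import random


def rdsa_grad(f, x, h=1e-3, n_samples=1, rng=None):
    """Central-difference RDSA estimate with Gaussian directions:

    g ~= mean_k [(f(x + h*u_k) - f(x - h*u_k)) / (2h)] * u_k,   u_k ~ N(0, I)
    """
    rng = rng or random.Random()
    g = [0.0] * len(x)
    for _ in range(n_samples):
        u = [rng.gauss(0.0, 1.0) for _ in x]
        fp = f([xi + h * ui for xi, ui in zip(x, u)])
        fm = f([xi - h * ui for xi, ui in zip(x, u)])
        d = (fp - fm) / (2 * h)  # estimate of the directional derivative along u
        for i, ui in enumerate(u):
            g[i] += d * ui / n_samples
    return g


# f(x) = ||x||^2 has gradient 2x; many samples average out the direction noise.
g = rdsa_grad(lambda v: sum(t * t for t in v), [1.0, 2.0], n_samples=20000, rng=random.Random(0))
print(g)  # close to the true gradient [2.0, 4.0]
```

With the default n_samples=1, each step costs two function evaluations and the estimate is unbiased (for Gaussian directions and a quadratic) but noisy. Increasing n_samples trades evaluations for lower variance.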

RDiv

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Divide other by tensors. other can be a number or a module.

If other is a module, this calculates other(tensors) / tensors

Source code in torchzero/modules/ops/binary.py
class RDiv(BinaryOperationBase):
    """Divide ``other`` by tensors. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``other(tensors) / tensors``
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        return other / TensorList(update)
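The "R" prefix means the operands are reversed: other is the numerator and the incoming tensors are the denominator. A plain-float sketch of both cases (a number, or the output of another module), with a hypothetical function name:

```python
def rdiv(update, other):
    """`other / update` elementwise; `other` is a number or a list produced by another module."""
    if isinstance(other, (int, float)):
        return [other / u for u in update]
    return [o / u for u, o in zip(update, other)]


print(rdiv([2.0, 4.0], 1.0))         # [0.5, 0.25]
print(rdiv([2.0, 4.0], [6.0, 2.0]))  # [3.0, 0.5]
```

RSub and RPow follow the same reversed order, computing other - tensors and other ** tensors respectively.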

RMSprop

Bases: torchzero.core.transform.TensorTransform

Divides gradient by EMA of gradient squares.

This implementation is identical to torch.optim.RMSprop.

Parameters:

  • smoothing (float, default: 0.99 ) –

    beta for exponential moving average of gradient squares. Defaults to 0.99.

  • eps (float, default: 1e-08 ) –

    epsilon for division. Defaults to 1e-8.

  • centered (bool, default: False ) –

    whether to center EMA of gradient squares using an additional EMA. Defaults to False.

  • debias (bool, default: False ) –

    applies Adam debiasing. Defaults to False.

  • amsgrad (bool, default: False ) –

    Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.

  • pow (float) –

    power used in second momentum power and root. Defaults to 2.

  • init (str, default: 'zeros' ) –

    how to initialize EMA, either "update" to use first update or "zeros". Defaults to "zeros".

  • inner (Chainable | None, default: None ) –

    Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.

Source code in torchzero/modules/adaptive/rmsprop.py
class RMSprop(TensorTransform):
    """Divides gradient by EMA of gradient squares.

    This implementation is identical to :code:`torch.optim.RMSprop`.

    Args:
        smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
        eps (float, optional): epsilon for division. Defaults to 1e-8.
        centered (bool, optional): whether to center EMA of gradient squares using an additional EMA. Defaults to False.
        debias (bool, optional): applies Adam debiasing. Defaults to False.
        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
        pow (float, optional): power used in second momentum power and root. Defaults to 2.
        init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "zeros".
        inner (Chainable | None, optional):
            Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
    """
    def __init__(
        self,
        smoothing: float = 0.99,
        eps: float = 1e-8,
        centered: bool = False,
        debias: bool = False,
        amsgrad: bool = False,
        init: Literal["zeros", "update"] = "zeros",

        inner: Chainable | None = None,
        exp_avg_sq_tfm: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults["inner"], defaults["exp_avg_sq_tfm"]
        super().__init__(defaults, inner=inner)

        self.set_child('exp_avg_sq', exp_avg_sq_tfm)
        self.add_projected_keys("grad", "exp_avg")
        self.add_projected_keys("grad_sq", "exp_avg_sq", "exp_avg_sq_max")

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        if setting["init"] == "zeros":
            state["exp_avg_sq"] = torch.zeros_like(tensor)
            if setting["centered"]: state["exp_avg"] = torch.zeros_like(tensor)
            if setting["amsgrad"]: state["exp_avg_sq_max"] = torch.zeros_like(tensor)

        else:
            state["exp_avg_sq"] = tensor ** 2
            if setting["centered"]: state["exp_avg"] = tensor.clone()
            if setting["amsgrad"]: state["exp_avg_sq_max"] = tensor ** 2

    @torch.no_grad
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        self.increment_counter("step", start = 0)
        fs = settings[0]

        exp_avg_sq = unpack_states(states, tensors, "exp_avg_sq", cls=TensorList)

        # update exponential average
        smoothing = NumberList(s["smoothing"] for s in settings)
        exp_avg_sq.mul_(smoothing).addcmul_(tensors, tensors, value=1-smoothing)

        # update mean estimate if centered
        if fs["centered"]:
            exp_avg = unpack_states(states, tensors, "exp_avg", cls=TensorList)
            exp_avg.lerp_(tensors, 1-smoothing)

        # amsgrad
        if fs["amsgrad"]:
            exp_avg_sq_max = unpack_states(states, tensors, "exp_avg_sq_max", cls=TensorList)
            exp_avg_sq_max.maximum_(exp_avg_sq)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        tensors = TensorList(tensors)
        step = self.global_state["step"] # 0 on 1st step
        eps = NumberList(s["eps"] for s in settings)
        fs = settings[0]

        # amsgrad reads the same state key that multi_tensor_update writes
        if fs["amsgrad"]: key = "exp_avg_sq_max"
        else: key = "exp_avg_sq"
        exp_avg_sq = TensorList(s[key] for s in states)

        # load mean estimate if centered
        exp_avg = None
        if fs['centered']:
            exp_avg = TensorList(s["exp_avg"] for s in states)

        # debias exp_avg_sq and exp_avg
        if fs["debias"]:
            smoothing = NumberList(s["smoothing"] for s in settings)
            bias_correction = 1 - (smoothing ** (step + 1))
            exp_avg_sq = exp_avg_sq / bias_correction

            if fs['centered']:
                assert exp_avg is not None
                exp_avg = exp_avg / bias_correction

        # apply transform to potentially debiased exp_avg_sq
        exp_avg_sq = TensorList(self.inner_step_tensors(
            "exp_avg_sq", exp_avg_sq, params=params, grads=grads, loss=loss, clone=True, must_exist=False
        ))

        # center
        if fs["centered"]:
            assert exp_avg is not None
            exp_avg_sq = exp_avg_sq.addcmul(exp_avg, exp_avg, value=-1)

        return tensors.div_(exp_avg_sq.sqrt().add_(eps))
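The update and apply steps above can be traced on plain floats. The sketch below follows the listed formulas for the EMA of squares, the optional centered mean, and the optional debiasing, starting from "zeros" initialization. It omits amsgrad and the exp_avg_sq_tfm hook, and rmsprop_step is a hypothetical name. It also clamps the centered variance at zero, because math.sqrt raises on tiny negative rounding errors, whereas torch would return NaN.

```python
import math


def rmsprop_step(g, state, smoothing=0.99, eps=1e-8, centered=False, debias=False):
    """One RMSprop update for a list of floats; `state` persists between calls."""
    n = len(g)
    state.setdefault("step", 0)
    v = state.setdefault("exp_avg_sq", [0.0] * n)  # "zeros" init
    m = state.setdefault("exp_avg", [0.0] * n)
    for i, gi in enumerate(g):
        v[i] = smoothing * v[i] + (1 - smoothing) * gi * gi  # EMA of squared gradients
        m[i] = smoothing * m[i] + (1 - smoothing) * gi       # EMA of gradients (used if centered)
    bc = 1 - smoothing ** (state["step"] + 1) if debias else 1.0
    state["step"] += 1
    out = []
    for i, gi in enumerate(g):
        denom_sq = v[i] / bc
        if centered:
            denom_sq = max(denom_sq - (m[i] / bc) ** 2, 0.0)
        out.append(gi / (math.sqrt(denom_sq) + eps))
    return out


print(rmsprop_step([2.0], {}))               # about [10.0]: zero-initialized EMA inflates early steps
print(rmsprop_step([2.0], {}, debias=True))  # about [1.0]
```

This shows why debias exists: with zeros initialization the first EMA value is only (1 - smoothing) * g², which makes the first steps roughly 1/sqrt(1 - smoothing) times too large.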

RPow

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Take other to the power of tensors. other can be a number or a module.

If other is a module, this calculates other(tensors) ** tensors

Source code in torchzero/modules/ops/binary.py
class RPow(BinaryOperationBase):
    """Take ``other`` to the power of tensors. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``other(tensors) ** tensors``
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
        torch._foreach_pow_(other, update)
        return other
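The two branches in transform treat a numeric other as a shared base raised to each element of the incoming tensors, and a module output as an elementwise base. A plain-float sketch with a hypothetical function name:

```python
def rpow(update, other):
    """`other ** update` elementwise; `other` may be a number or a list."""
    if isinstance(other, (int, float)):
        return [other ** u for u in update]  # scalar base, tensor exponents
    return [o ** u for u, o in zip(update, other)]


print(rpow([1.0, 2.0, 3.0], 2))      # [2.0, 4.0, 8.0]
print(rpow([2.0, 0.5], [3.0, 4.0]))  # [9.0, 2.0]
```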

RSub

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Subtract tensors from other. other can be a number or a module.

If other is a module, this calculates other(tensors) - tensors

Source code in torchzero/modules/ops/binary.py
class RSub(BinaryOperationBase):
    """Subtract tensors from ``other``. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``other(tensors) - tensors``
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        return other - TensorList(update)

Randn

Bases: torchzero.core.module.Module

Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1.

Source code in torchzero/modules/ops/utility.py
class Randn(Module):
    """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def apply(self, objective):
        objective.updates = [torch.randn_like(p) for p in objective.params]
        return objective

RandomHvp

Bases: torchzero.core.module.Module

Returns a Hessian-vector product Hz with a random vector z, or the elementwise product z * Hz when zHz=True.

Source code in torchzero/modules/misc/misc.py
class RandomHvp(Module):
    """Returns a Hessian-vector product ``Hz`` with a random vector ``z``, or the elementwise product ``z * Hz`` when ``zHz=True``."""

    def __init__(
        self,
        n_samples: int = 1,
        distribution: Distributions = "normal",
        update_freq: int = 1,
        zHz: bool = False,
        hvp_method: Literal["autograd", "fd_forward", "central"] = "autograd",
        h=1e-3,
        seed: int | None = None
    ):
        defaults = locals().copy()
        del defaults['self']
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective):
        params = TensorList(objective.params)

        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        D = None
        update_freq = self.defaults['update_freq']
        if step % update_freq == 0:

            D, _ = objective.hutchinson_hessian(
                rgrad = None,
                at_x0 = True,
                n_samples = self.defaults['n_samples'],
                distribution = self.defaults['distribution'],
                hvp_method = self.defaults['hvp_method'],
                h = self.defaults['h'],
                zHz = self.defaults["zHz"],
                generator = self.get_generator(params[0].device, self.defaults["seed"]),
            )

            if update_freq != 1:
                assert D is not None
                D_buf = self.get_state(params, "D", cls=TensorList)
                D_buf.set_(D)

        if D is None:
            D = self.get_state(params, "D", cls=TensorList)

        objective.updates = list(D)
        return objective
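The "central" hvp_method can be illustrated with a finite difference of gradients, Hz ≈ (∇f(x + h z) − ∇f(x − h z)) / (2h). This sketch assumes that formula, with h playing the role of the h argument. It uses the analytic gradient of a quadratic f(x) = ½ xᵀ A x so the exact answer A z is known. The names are hypothetical, and averaging over n_samples and the update_freq caching are left out.

```python
def grad_quadratic(A, x):
    """Gradient of f(x) = 0.5 * x^T A x for symmetric A, i.e. A x."""
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


def random_hvp(grad_fn, x, z, h=1e-3, zHz=False):
    """Central finite-difference Hessian-vector product.

    Hz ~= (grad(x + h z) - grad(x - h z)) / (2h); with zHz=True returns z * Hz elementwise.
    """
    gp = grad_fn([xi + h * zi for xi, zi in zip(x, z)])
    gm = grad_fn([xi - h * zi for xi, zi in zip(x, z)])
    Hz = [(a - b) / (2 * h) for a, b in zip(gp, gm)]
    return [zi * hzi for zi, hzi in zip(z, Hz)] if zHz else Hz


A = [[2.0, 1.0], [1.0, 3.0]]
x, z = [0.5, -1.0], [1.0, -1.0]
print(random_hvp(lambda v: grad_quadratic(A, v), x, z))            # approximately [1.0, -2.0] (= A z)
print(random_hvp(lambda v: grad_quadratic(A, v), x, z, zHz=True))  # approximately [1.0, 2.0]
```

Averaged over many random z with zero mean and unit variance, z * Hz estimates the Hessian diagonal, which is the Hutchinson estimator that the hutchinson_hessian call in the listing is named after.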

RandomReinitialize

Bases: torchzero.core.module.Module

On each step, trigger reinitialization with probability p_reinit; when it is triggered, each weight is reset to its initial value with probability p_weights.

This modifies the parameters directly. Place it as the first module.

Parameters:

  • p_reinit (float, default: 0.01 ) –

    probability to trigger reinitialization on each step. Defaults to 0.01.

  • p_weights (float, default: 0.1 ) –

    probability for each weight to be set to initial value when reinitialization is triggered. Defaults to 0.1.

  • store_every (int | None, default: None ) –

    if set, stores new initial values every this many steps. Defaults to None.

  • beta (float, default: 0 ) –

    whenever store_every is triggered, uses linear interpolation with this beta. If store_every=1, this can be set to some value close to 1 such as 0.999 to reinitialize to slow parameter EMA. Defaults to 0.

  • reset (bool, default: False ) –

    whether to reset states of other modules on reinitialization. Defaults to False.

  • seed (int | None, default: None ) –

    random seed.
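The reinitialization rule described by these parameters can be sketched on a flat list of floats. The per-weight masking is an assumption based on the p_weights description, because the listing on this page is cut off before that step. reset is not modeled, and the helper names are hypothetical.

```python
import random


def maybe_reinitialize(params, p_init, p_reinit=0.01, p_weights=0.1, rng=None):
    """With probability p_reinit, reset each weight to its stored initial value
    with probability p_weights. Modifies `params` in place; returns True if triggered."""
    rng = rng or random.Random()
    if rng.random() >= p_reinit:
        return False
    for i in range(len(params)):
        if rng.random() < p_weights:  # assumed per-weight mask
            params[i] = p_init[i]
    return True


def store_init(p_init, params, beta=0.0):
    """store_every update: p_init <- lerp(p_init, params, 1 - beta), i.e. an EMA of parameters."""
    for i, (a, p) in enumerate(zip(p_init, params)):
        p_init[i] = a + (1 - beta) * (p - a)


params, p_init = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
print(maybe_reinitialize(params, p_init, p_reinit=1.0, p_weights=1.0), params)  # True [0.0, 0.0, 0.0]
```

With beta=0, store_init simply copies the current parameters, so store_every alone moves the reset target forward. A beta close to 1 combined with store_every=1 turns the target into a slow EMA of the parameters, as the beta description says.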

Source code in torchzero/modules/weight_decay/reinit.py
class RandomReinitialize(Module):
    """On each step, trigger reinitialization with probability ``p_reinit``;
    when it is triggered, each weight is reset to its initial value with probability ``p_weights``.

    This modifies the parameters directly. Place it as the first module.

    Args:
        p_reinit (float, optional): probability to trigger reinitialization on each step. Defaults to 0.01.
        p_weights (float, optional): probability for each weight to be set to initial value when reinitialization is triggered. Defaults to 0.1.
        store_every (int | None, optional): if set, stores new initial values every this many steps. Defaults to None.
        beta (float, optional):
            whenever ``store_every`` is triggered, uses linear interpolation with this beta.
            If ``store_every=1``, this can be set to some value close to 1 such as 0.999
            to reinitialize to slow parameter EMA. Defaults to 0.
        reset (bool, optional): whether to reset states of other modules on reinitialization. Defaults to False.
        seed (int | None, optional): random seed.
    """

    def __init__(
        self,
        p_reinit: float = 0.01,
        p_weights: float = 0.1,
        store_every: int | None = None,
        beta: float = 0,
        reset: bool = False,
        seed: int | None = None,
    ):
        defaults = dict(p_weights=p_weights, p_reinit=p_reinit, store_every=store_every, beta=beta, reset=reset, seed=seed)
        super().__init__(defaults)

    def update(self, objective):
        # this stores initial values to per-parameter states
        p_init = self.get_state(objective.params, "p_init", init="params", cls=TensorList)

        # store new params every store_every steps
        step = self.global_state.get("step", 0)
        self.global_state["step"] = step + 1

        store_every = self.defaults["store_every"]
        if (store_every is not None and step % store_every == 0):
            beta = self.get_settings(objective.params, "beta", cls=NumberList)
            p_init.lerp_(objective.params, weight=(1 - beta))

    @torch.no_grad
    def apply(self, objective):
        p_reinit = self.defaults["p_reinit"]
        device = objective.params[0].device
        generator = self.get_generator(device, self.defaults["seed"])

        # determine whether to trigger reinitialization
        reinitialize = torch.rand(1, generator=generator, device=device) < p_reinit

        # reinitialize
        if reinitialize:
            params = TensorList(objective.params)
            p_init = self.get_state(params, "p_init", init=params)


            # mask with p_weights entries being True
            p_weights = self.get_settings(params, "p_weights")
            mask = params.bernoulli_like(p_weights, generator=generator).as_bool()

            # set weights at mask to their initialization
            params.masked_set_(mask, p_init)

            # reset
            if self.defaults["reset"]:
                objective.post_step_hooks.append(partial(_reset_except_self, self=self))

        return objective

RandomSample

Bases: torchzero.core.module.Module

Outputs tensors shaped like the parameters, filled with random numbers drawn from the distribution specified by distribution.

Source code in torchzero/modules/ops/utility.py
class RandomSample(Module):
    """Outputs tensors filled with random numbers from distribution depending on value of ``distribution``."""
    def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
        defaults = dict(distribution=distribution, variance=variance)
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective):
        distribution = self.defaults['distribution']
        variance = self.get_settings(objective.params, 'variance')
        objective.updates = TensorList(objective.params).sample_like(distribution=distribution, variance=variance)
        return objective

RandomStepSize

Bases: torchzero.core.transform.TensorTransform

Uses random global or layer-wise step size from low to high.

Parameters:

  • low (float, default: 0 ) –

    minimum learning rate. Defaults to 0.

  • high (float, default: 1 ) –

    maximum learning rate. Defaults to 1.

  • parameterwise (bool, default: False ) –

    if True, generates a random step size for each parameter separately; if False, generates one global random step size. Defaults to False.

  • seed (int | None, default: None ) –

    seed for the random number generator. Defaults to None.
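The difference between a global and a per-parameter random step size can be sketched with Python's random module, which the source also uses (random.Random(seed) and uniform). Plain lists of floats stand in for tensors, and scalar low/high are used for every parameter; the helper name random_step_size is hypothetical.

```python
import random


def random_step_size(updates, rng, low=0.0, high=1.0, parameterwise=False):
    """Scale per-parameter updates by a random step size drawn from [low, high].

    `updates` is a list of per-parameter lists of floats (a stand-in for tensors).
    Returns the scaled updates and the step sizes that were used.
    """
    if parameterwise:
        # one independent step size per parameter
        steps = [rng.uniform(low, high) for _ in updates]
    else:
        # a single global step size shared by all parameters
        steps = [rng.uniform(low, high)] * len(updates)
    scaled = [[u * s for u in upd] for upd, s in zip(updates, steps)]
    return scaled, steps
```

Seeding the generator once and reusing it across steps, as the module does via global_state, gives a reproducible sequence of step sizes.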

Source code in torchzero/modules/step_size/lr.py
class RandomStepSize(TensorTransform):
    """Uses random global or layer-wise step size from ``low`` to ``high``.

    Args:
        low (float, optional): minimum learning rate. Defaults to 0.
        high (float, optional): maximum learning rate. Defaults to 1.
        parameterwise (bool, optional):
            if True, generate random step size for each parameter separately,
            if False generate one global random step size. Defaults to False.
    """
    def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
        defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        s = settings[0]
        parameterwise = s['parameterwise']

        seed = s['seed']
        if 'generator' not in self.global_state:
            self.global_state['generator'] = random.Random(seed)
        generator: random.Random = self.global_state['generator']

        if parameterwise:
            low, high = unpack_dicts(settings, 'low', 'high')
            lr = [generator.uniform(l, h) for l, h in zip(low, high)]
        else:
            low = s['low']
            high = s['high']
            lr = generator.uniform(low, high)

        torch._foreach_mul_(tensors, lr)
        return tensors

RandomizedFDM

Bases: torchzero.modules.grad_approximation.grad_approximator.GradApproximator

Gradient approximation via a randomized finite-difference method.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.001 ) –

    finite difference step size. Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of random gradient samples. Defaults to 1.

  • formula (Literal, default: 'central' ) –

    finite difference formula. Defaults to 'central'.

  • distribution (Literal, default: 'rademacher' ) –

    distribution to sample random perturbation directions from. Defaults to "rademacher".

  • pre_generate (bool, default: True ) –

    whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.

  • seed (int | None | Generator, default: None ) –

    Seed for random generator. Defaults to None.

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to "closure".

Examples:

Simultaneous perturbation stochastic approximation (SPSA) method

SPSA is randomized FDM with a Rademacher distribution and a central formula.

spsa = tz.Optimizer(
    model.parameters(),
    tz.m.RandomizedFDM(formula="fd_central", distribution="rademacher"),
    tz.m.LR(1e-2)
)
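The SPSA estimator itself can be sketched in plain Python. This sketch uses the textbook form: draw a Rademacher direction u, estimate the directional derivative with a central difference, and use it times u as a gradient sample. The module instead pre-multiplies its perturbations by h (see PRE_MULTIPLY_BY_H), so this is an illustration of the technique, not its code; spsa_gradient is a hypothetical name.

```python
import random


def spsa_gradient(f, x, h=1e-3, n_samples=1, rng=None):
    """Randomized central finite-difference (SPSA-style) gradient estimate.

    For each sample, draw u with entries +1 or -1, estimate the directional
    derivative d = (f(x + h*u) - f(x - h*u)) / (2*h), and use d * u as a
    gradient sample. The samples are averaged.
    """
    rng = rng or random.Random()
    grad = [0.0] * len(x)
    for _ in range(n_samples):
        u = [rng.choice((-1.0, 1.0)) for _ in x]
        x_plus = [xi + h * ui for xi, ui in zip(x, u)]
        x_minus = [xi - h * ui for xi, ui in zip(x, u)]
        d = (f(x_plus) - f(x_minus)) / (2 * h)
        for i, ui in enumerate(u):
            grad[i] += d * ui / n_samples
    return grad
```

Each sample costs two function evaluations regardless of the dimension; the price is variance, which averaging more samples (n_samples) or momentum reduces.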

Random-direction stochastic approximation (RDSA) method

RDSA is randomized FDM, usually with a Gaussian distribution and a central formula.

rdsa = tz.Optimizer(
    model.parameters(),
    tz.m.RandomizedFDM(formula="fd_central", distribution="gaussian"),
    tz.m.LR(1e-2)
)

Gaussian smoothing method

Gaussian smoothing (GS) uses many Gaussian samples, possibly with a larger finite difference step size.

gs = tz.Optimizer(
    model.parameters(),
    tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
    tz.m.NewtonCG(hvp_method="forward"),
    tz.m.Backtracking()
)
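A Gaussian smoothing estimate can be sketched with a plain forward difference: it estimates the gradient of the smoothed function E[f(x + h*u)] with u ~ N(0, I). The forward-difference form here is our choice for illustration; it is not a statement of what formula="forward2" computes. gaussian_smoothing_gradient is a hypothetical name.

```python
import random


def gaussian_smoothing_gradient(f, x, h=0.1, n_samples=100, rng=None):
    """Forward-difference Gaussian smoothing gradient estimate.

    Each sample draws u ~ N(0, I), computes d = (f(x + h*u) - f(x)) / h and
    uses d * u as a gradient sample; the samples are averaged.
    """
    rng = rng or random.Random()
    f0 = f(x)  # shared baseline evaluation for all samples
    grad = [0.0] * len(x)
    for _ in range(n_samples):
        u = [rng.gauss(0.0, 1.0) for _ in x]
        d = (f([xi + h * ui for xi, ui in zip(x, u)]) - f0) / h
        for i, ui in enumerate(u):
            grad[i] += d * ui / n_samples
    return grad
```

A larger h smooths over more of the landscape, which is why GS often pairs many samples with a larger step size.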

RandomizedFDM with momentum

Momentum might help by reducing the variance of the estimated gradients.

momentum_spsa = tz.Optimizer(
    model.parameters(),
    tz.m.RandomizedFDM(),
    tz.m.HeavyBall(0.9),
    tz.m.LR(1e-3)
)

Source code in torchzero/modules/grad_approximation/rfdm.py
class RandomizedFDM(GradApproximator):
    """Gradient approximation via a randomized finite-difference method.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): finite difference step size. Defaults to 1e-3.
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central'.
        distribution (Distributions, optional): distribution to sample random perturbation directions from. Defaults to "rademacher".
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    Examples:
    #### Simultaneous perturbation stochastic approximation (SPSA) method

    SPSA is randomized FDM with rademacher distribution and central formula.
    ```py
    spsa = tz.Optimizer(
        model.parameters(),
        tz.m.RandomizedFDM(formula="fd_central", distribution="rademacher"),
        tz.m.LR(1e-2)
    )
    ```

    #### Random-direction stochastic approximation (RDSA) method

    RDSA is randomized FDM with usually gaussian distribution and central formula.
    ```
    rdsa = tz.Optimizer(
        model.parameters(),
        tz.m.RandomizedFDM(formula="fd_central", distribution="gaussian"),
        tz.m.LR(1e-2)
    )
    ```

    #### Gaussian smoothing method

    GS uses many gaussian samples with possibly a larger finite difference step size.
    ```
    gs = tz.Optimizer(
        model.parameters(),
        tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
        tz.m.NewtonCG(hvp_method="forward"),
        tz.m.Backtracking()
    )
    ```

    #### RandomizedFDM with momentum

    Momentum might help by reducing the variance of the estimated gradients.
    ```
    momentum_spsa = tz.Optimizer(
        model.parameters(),
        tz.m.RandomizedFDM(),
        tz.m.HeavyBall(0.9),
        tz.m.LR(1e-3)
    )
    ```
    """
    PRE_MULTIPLY_BY_H = True
    def __init__(
        self,
        h: float = 1e-3,
        n_samples: int = 1,
        formula: _FD_Formula = "central",
        distribution: Distributions = "rademacher",
        pre_generate: bool = True,
        return_approx_loss: bool = False,
        seed: int | None | torch.Generator = None,
        target: GradTarget = "closure",
    ):
        defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, pre_generate=pre_generate, seed=seed)
        super().__init__(defaults, return_approx_loss=return_approx_loss, target=target)


    def pre_step(self, objective):
        h = self.get_settings(objective.params, 'h')
        pre_generate = self.defaults['pre_generate']

        if pre_generate:
            n_samples = self.defaults['n_samples']
            distribution = self.defaults['distribution']

            params = TensorList(objective.params)
            generator = self.get_generator(params[0].device, self.defaults['seed'])
            perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

            # this is false for ForwardGradient where h isn't used and it subclasses this
            if self.PRE_MULTIPLY_BY_H:
                # perturbations are flattened sample-major, so repeat the per-parameter h values once per sample
                torch._foreach_mul_([p for l in perturbations for p in l], [v for _ in range(n_samples) for v in h])

            for param, prt in zip(params, zip(*perturbations)):
                self.state[param]['perturbations'] = prt

    @torch.no_grad
    def approximate(self, closure, params, loss):
        params = TensorList(params)
        loss_approx = None

        h = NumberList(self.settings[p]['h'] for p in params)
        n_samples = self.defaults['n_samples']
        distribution = self.defaults['distribution']
        fd_fn = _RFD_FUNCS[self.defaults['formula']]

        default = [None]*n_samples
        perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
        generator = self.get_generator(params[0].device, self.defaults['seed'])

        grad = None
        for i in range(n_samples):
            prt = perturbations[i]

            if prt[0] is None:
                prt = params.sample_like(distribution=distribution, generator=generator, variance=1).mul_(h)

            else: prt = TensorList(prt)

            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, f_0=loss)
            # here `d` is a numberlist of directional derivatives, due to per parameter `h` values.

            # support for per-sample values which gives better estimate
            if d[0].numel() > 1: d = d.map(torch.mean)

            if grad is None: grad = prt * d
            else: grad += prt * d

        assert grad is not None
        if n_samples > 1: grad.div_(n_samples)

        # mean if got per-sample values
        if loss is not None:
            if loss.numel() > 1:
                loss = loss.mean()

        if loss_approx is not None:
            if loss_approx.numel() > 1:
                loss_approx = loss_approx.mean()

        return grad, loss, loss_approx

PRE_MULTIPLY_BY_H class-attribute

PRE_MULTIPLY_BY_H = True

Whether generated perturbations are pre-multiplied by h. According to the source comment, this is False for ForwardGradient, which subclasses RandomizedFDM and does not use h.

Reciprocal

Bases: torchzero.core.transform.TensorTransform

Returns 1 / (input + eps), where eps defaults to 0.

Source code in torchzero/modules/ops/unary.py
class Reciprocal(TensorTransform):
    """Returns ``1 / input``"""
    def __init__(self, eps = 0):
        defaults = dict(eps = eps)
        super().__init__(defaults)
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        eps = [s['eps'] for s in settings]
        if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
        torch._foreach_reciprocal_(tensors)
        return tensors

ReduceOperationBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for reduction operations like Sum, Prod and Maximum. This is an abstract class: to use it, subclass it and override the transform method.
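The subclass-and-override pattern can be sketched in plain Python. This is a hypothetical stand-in, not torchzero's API: callables play the role of child modules (each receives a copy of the current update, like objective.clone(clone_updates=True)), constants are passed through unchanged, and updates are lists of floats. ReduceSketch and SumSketch are hypothetical names.

```python
from abc import ABC, abstractmethod
from typing import Any


class ReduceSketch(ABC):
    """Hypothetical stand-in for ReduceOperationBase on lists of floats."""

    def __init__(self, *operands: Any):
        # mirrors the "at least one operand must be a module" check
        if not any(callable(op) for op in operands):
            raise ValueError("At least one operand must be a module")
        self.operands = operands

    @abstractmethod
    def transform(self, update: list, *operands: Any) -> list:
        """Applies the operation to the processed operands."""
        raise NotImplementedError

    def step(self, update: list) -> list:
        # each "module" operand processes its own copy of the update
        processed = [op(list(update)) if callable(op) else op for op in self.operands]
        return self.transform(update, *processed)


class SumSketch(ReduceSketch):
    def transform(self, update, *operands):
        out = [0.0] * len(update)
        for op in operands:
            if isinstance(op, list):
                out = [a + b for a, b in zip(out, op)]  # element-wise sum with a module output
            else:
                out = [a + op for a in out]  # broadcast a constant operand
        return out
```

For example, SumSketch(lambda u: [2.0 * v for v in u], 1.0).step([1.0, 2.0]) returns [3.0, 5.0].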

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/reduce.py
class ReduceOperationBase(Module, ABC):
    """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = []
        for i, v in enumerate(operands):

            if isinstance(v, (Module, Sequence)):
                self.set_child(f'operand_{i}', v)
                self.operands.append(self.children[f'operand_{i}'])
            else:
                self.operands.append(v)

        if not self.children:
            raise ValueError('At least one operand must be a module')

    @abstractmethod
    def transform(self, objective: Objective, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    @torch.no_grad
    def step(self, objective: Objective) -> Objective:
        # pass cloned update to all module operands
        processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()

        for i, v in enumerate(self.operands):
            if f'operand_{i}' in self.children:
                v: Module
                updated_obj = v.step(objective.clone(clone_updates=True))
                processed_operands[i] = updated_obj.get_updates()
                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them

        transformed = self.transform(objective, *processed_operands)
        objective.updates = transformed
        return objective

transform

transform(objective: Objective, *operands: Any | list[Tensor]) -> list[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/reduce.py
@abstractmethod
def transform(self, objective: Objective, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError

Relative

Bases: torchzero.core.transform.TensorTransform

Multiplies the update by the absolute parameter values to make it relative to their magnitude. min_value is the minimum allowed multiplier, which avoids getting stuck at 0.
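The scaling amounts to multiplying each update entry by max(|param|, min_value), matching the abs().clamp_ call in the source. A plain-Python sketch on lists of floats (relative is a hypothetical name):

```python
def relative(update, params, min_value=1e-4):
    """Scale each update entry by max(|param|, min_value)."""
    return [u * max(abs(p), min_value) for u, p in zip(update, params)]
```

Without the min_value floor, a parameter that reaches exactly 0 would receive a zero update forever.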

Source code in torchzero/modules/misc/misc.py
class Relative(TensorTransform):
    """Multiplies update by absolute parameter values to make it relative to their magnitude, ``min_value`` is minimum allowed value to avoid getting stuck at 0."""
    def __init__(self, min_value:float = 1e-4):
        defaults = dict(min_value=min_value)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
        torch._foreach_mul_(tensors, mul)
        return tensors

RelativeWeightDecay

Bases: torchzero.core.transform.TensorTransform

Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of norm_input argument.

Parameters:

  • weight_decay (float, default: 0.1 ) –

    relative weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

  • norm_input (str, default: 'update' ) –

    determines what weight decay should be relative to: "update", "grad" or "params". Defaults to "update".

  • metric (Ords, default: 'mad' ) –

    metric (norm, etc.) that weight decay should be relative to. Defaults to 'mad' (mean absolute deviation).

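  • cautious (bool, default: False ) –

    if True, applies cautious weight decay (cautious_weight_decay_) instead of regular weight decay (weight_decay_). Defaults to False.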
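The relative scaling can be sketched in plain Python under two stated assumptions: the metric is the mean absolute value over all entries of the chosen source, and the penalty of order ord is |p|**ord / ord, so its gradient sign(p) * |p|**(ord - 1) is added to the update (which is later subtracted from the parameters). Neither assumption is taken from weight_decay_ itself; relative_weight_decay is a hypothetical name.

```python
def relative_weight_decay(update, params, weight_decay=0.1, ord=2, norm_input="update", grad=None):
    """Sketch of weight decay scaled by a global statistic of update, grad or params."""
    sources = {"update": update, "grad": grad, "params": params}
    if norm_input not in sources or sources[norm_input] is None:
        raise ValueError(norm_input)
    src = sources[norm_input]

    # assumed metric: mean absolute value over all entries (one global scalar)
    scale = sum(abs(v) for v in src) / len(src)
    wd = weight_decay * scale

    def penalty_grad(p):
        # assumed penalty |p|**ord / ord, whose derivative is sign(p) * |p|**(ord - 1)
        sign = (p > 0) - (p < 0)
        return sign * abs(p) ** (ord - 1)

    return [u + wd * penalty_grad(p) for u, p in zip(update, params)]
```

Because the decay strength is tied to the size of the update, it stays proportionate whether the module is placed before Adam (non-decoupled) or after it (decoupled), as in the examples below.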

Examples:

Adam with non-decoupled relative weight decay

opt = tz.Optimizer(
    model.parameters(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)

Adam with decoupled relative weight decay

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.LR(1e-3)
)

Source code in torchzero/modules/weight_decay/weight_decay.py
class RelativeWeightDecay(TensorTransform):
    """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of ``norm_input`` argument.

    Args:
        weight_decay (float): relative weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
        norm_input (str, optional):
            determines what should weight decay be relative to. "update", "grad" or "params".
            Defaults to "update".
        metric (Ords, optional):
            metric (norm, etc) that weight decay should be relative to.
            defaults to 'mad' (mean absolute deviation).
        target (Target, optional): what to set on var. Defaults to 'update'.

    ### Examples:

    Adam with non-decoupled relative weight decay
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.RelativeWeightDecay(1e-1),
        tz.m.Adam(),
        tz.m.LR(1e-3)
    )
    ```

    Adam with decoupled relative weight decay
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.RelativeWeightDecay(1e-1),
        tz.m.LR(1e-3)
    )
    ```
    """
    def __init__(
        self,
        weight_decay: float = 0.1,
        ord: int  = 2,
        norm_input: Literal["update", "grad", "params"] = "update",
        metric: Metrics = 'mad',
        cautious: bool = False,
    ):
        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric, cautious=cautious)
        super().__init__(defaults, uses_grad=norm_input == 'grad')

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)

        fs = settings[0]
        norm_input = fs['norm_input']

        if norm_input == 'update': src = TensorList(tensors)
        elif norm_input == 'grad':
            assert grads is not None
            src = TensorList(grads)
        elif norm_input == 'params':
            src = TensorList(params)
        else:
            raise ValueError(norm_input)

        norm = src.global_metric(fs['metric'])

        if fs["cautious"]:
            wd_ = cautious_weight_decay_
        else:
            wd_ = weight_decay_
        return wd_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * norm, fs["ord"])

RestartEvery

Bases: torchzero.modules.restarts.restars.RestartStrategyBase

Resets the state every n steps

Parameters:

  • modules (Chainable | None) –

    modules to reset. If None, resets all modules.

  • steps (int | Literal['ndim']) –

    number of steps between resets. Pass "ndim" to use the number of trainable parameters (those with requires_grad=True).
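The reset schedule in should_reset can be sketched as a plain counter. Here n_params stands in for the number of trainable parameters used by "ndim"; RestartEverySketch is a hypothetical name.

```python
class RestartEverySketch:
    """Reports a reset every `steps` steps (or every n_params steps for "ndim")."""

    def __init__(self, steps, n_params=None):
        self.steps = steps
        self.n_params = n_params  # stand-in for the number of trainable parameters
        self.step = 0

    def should_reset(self):
        self.step += 1
        n = self.n_params if self.steps == "ndim" else self.steps
        if self.step % n == 0:
            self.step = 0  # the module clears its global state, which also resets the counter
            return True
        return False
```

With steps=3, the first seven calls return False, False, True, False, False, True, False.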

Source code in torchzero/modules/restarts/restars.py
class RestartEvery(RestartStrategyBase):
    """Resets the state every n steps

    Args:
        modules (Chainable | None):
            modules to reset. If None, resets all modules.
        steps (int | Literal["ndim"]):
            number of steps between resets. "ndim" to use number of parameters.
    """
    def __init__(self, modules: Chainable | None, steps: int | Literal['ndim']):
        defaults = dict(steps=steps)
        super().__init__(defaults, modules)

    def should_reset(self, objective):
        step = self.global_state.get('step', 0) + 1
        self.global_state['step'] = step

        n = self.defaults['steps']
        if isinstance(n, str): n = sum(p.numel() for p in objective.params if p.requires_grad)

        # reset every n steps
        if step % n == 0:
            self.global_state.clear()
            return True

        return False

RestartOnStuck

Bases: torchzero.modules.restarts.restars.RestartStrategyBase

Resets the state when update (difference in parameters) is zero for multiple steps in a row.

Parameters:

  • modules (Chainable | None) –

    modules to reset. If None, resets all modules.

  • tol (float, default: None ) –

    step is considered failed when the maximum absolute parameter difference is less than or equal to this. Defaults to None (uses twice the smallest positive normal number representable in the parameter dtype).

  • n_tol (int, default: 10 ) –

    number of consecutive failed steps required to trigger a reset. Defaults to 10.
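The stuck-detection logic can be sketched in plain Python on a list of floats. sys.float_info.min is the smallest positive normal double, the float64 counterpart of torch.finfo(...).tiny used in the source. RestartOnStuckSketch is a hypothetical name.

```python
import sys


class RestartOnStuckSketch:
    """Reports a reset after n_tol consecutive steps with (almost) no parameter change."""

    def __init__(self, tol=None, n_tol=10):
        self.tol = tol if tol is not None else 2 * sys.float_info.min
        self.n_tol = n_tol
        self.prev = None  # parameters seen on the previous call
        self.n_bad = 0

    def should_reset(self, params):
        if self.prev is not None:
            max_change = max(abs(a - b) for a, b in zip(params, self.prev))
            # a tiny change counts as a failed step; any real progress resets the count
            self.n_bad = self.n_bad + 1 if max_change <= self.tol else 0
        self.prev = list(params)
        if self.n_bad >= self.n_tol:
            # like the module, the first call after a reset is not compared
            self.prev = None
            self.n_bad = 0
            return True
        return False
```

The first call only records the parameters, so a reset needs n_tol + 1 calls with unchanged parameters.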

Source code in torchzero/modules/restarts/restars.py
class RestartOnStuck(RestartStrategyBase):
    """Resets the state when update (difference in parameters) is zero for multiple steps in a row.

    Args:
        modules (Chainable | None):
            modules to reset. If None, resets all modules.
        tol (float, optional):
            step is considered failed when the maximum absolute parameter difference is less than or equal to this. Defaults to None (uses twice the smallest positive normal number representable in the parameter dtype)
        n_tol (int, optional):
            number of consecutive failed steps required to trigger a reset. Defaults to 10.

    """
    def __init__(self, modules: Chainable | None, tol: float | None = None, n_tol: int = 10):
        defaults = dict(tol=tol, n_tol=n_tol)
        super().__init__(defaults, modules)

    @torch.no_grad
    def should_reset(self, objective):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        params = TensorList(objective.params)
        tol = self.defaults['tol']
        if tol is None: tol = torch.finfo(params[0].dtype).tiny * 2
        n_tol = self.defaults['n_tol']
        n_bad = self.global_state.get('n_bad', 0)

        # calculate difference in parameters
        prev_params = self.get_state(params, 'prev_params', cls=TensorList)
        update = params - prev_params
        prev_params.copy_(params)

        # if update is too small, it is considered bad, otherwise n_bad is reset to 0
        if step > 0:
            if update.abs().global_max() <= tol:
                n_bad += 1

            else:
                n_bad = 0

        self.global_state['n_bad'] = n_bad

        # no progress, reset
        if n_bad >= n_tol:
            self.global_state.clear()
            return True

        return False

RestartStrategyBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for restart strategies.

On each update/step, this checks the reset condition and, if it is satisfied, resets the modules before updating or stepping.

Methods:

  • should_reset

    returns whether reset should occur

Source code in torchzero/modules/restarts/restars.py
class RestartStrategyBase(Module, ABC):
    """Base class for restart strategies.

    On each ``update``/``step`` this checks reset condition and if it is satisfied,
    resets the modules before updating or stepping.
    """
    def __init__(self, defaults: dict | None = None, modules: Chainable | None = None):
        if defaults is None: defaults = {}
        super().__init__(defaults)
        if modules is not None:
            self.set_child('modules', modules)

    @abstractmethod
    def should_reset(self, objective: Objective) -> bool:
        """returns whether reset should occur"""

    def _reset_on_condition(self, objective: Objective):
        modules = self.children.get('modules', None)

        if self.should_reset(objective):
            if modules is None:
                objective.post_step_hooks.append(partial(_reset_except_self, self=self))
            else:
                modules.reset()

        return modules

    @final
    def update(self, objective):
        modules = self._reset_on_condition(objective)
        if modules is not None:
            modules.update(objective)

    @final
    def apply(self, objective):
        # don't check here because it was checked in `update`
        modules = self.children.get('modules', None)
        if modules is None: return objective
        return modules.apply(objective.clone(clone_updates=False))

    @final
    def step(self, objective):
        modules = self._reset_on_condition(objective)
        if modules is None: return objective
        return modules.step(objective.clone(clone_updates=False))

should_reset

should_reset(objective: Objective) -> bool

returns whether reset should occur

Source code in torchzero/modules/restarts/restars.py
@abstractmethod
def should_reset(self, objective: Objective) -> bool:
    """returns whether reset should occur"""

Rprop

Bases: torchzero.core.transform.TensorTransform

Resilient propagation. The update magnitude is multiplied by nplus if the gradient did not change sign, or by nminus if it did. The update is then applied with the sign of the current gradient.

Additionally, if the gradient changes sign, the previous update for that weight is reverted, and on the next step the magnitude for that weight is left unchanged.

Compared to PyTorch, this also implements a backtracking update when the sign changes.

This implementation is identical to torch.optim.Rprop if backtrack is set to False.

Parameters:

  • nplus (float, default: 1.2 ) –

    multiplicative increase factor for when ascent didn't change sign (default: 1.2).

  • nminus (float, default: 0.5 ) –

    multiplicative decrease factor for when ascent changed sign (default: 0.5).

  • lb (float, default: 1e-06 ) –

    minimum step size, can be None (default: 1e-6)

  • ub (float, default: 50 ) –

    maximum step size, can be None (default: 50)

  • backtrack (bool, default: True ) –

    if True, when the ascent sign changes, undoes the last weight update; otherwise sets the update to 0. When this is False, this exactly matches PyTorch Rprop. (default: True)

  • alpha (float, default: 1 ) –

    initial per-parameter learning rate (default: 1).

Reference: Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning: The RPROP algorithm. In IEEE International Conference on Neural Networks (pp. 586-591). IEEE.
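The per-weight rule described above can be sketched in plain Python following the textbook Rprop with weight backtracking. It returns updates that are subtracted from the parameters; it illustrates the described behavior and is not the rprop_ function itself. RpropSketch is a hypothetical name.

```python
class RpropSketch:
    """Per-weight Rprop on lists of floats; step() returns updates to subtract from params."""

    def __init__(self, n, nplus=1.2, nminus=0.5, lb=1e-6, ub=50.0, alpha=1.0, backtrack=True):
        self.nplus, self.nminus, self.lb, self.ub = nplus, nminus, lb, ub
        self.backtrack = backtrack
        self.mag = [alpha] * n          # per-weight step magnitudes
        self.prev_grad = [0.0] * n
        self.prev_update = [0.0] * n

    def step(self, grad):
        updates = []
        for i, g in enumerate(grad):
            prod = g * self.prev_grad[i]
            if prod > 0:
                # same sign as last time: grow the magnitude
                self.mag[i] = min(self.mag[i] * self.nplus, self.ub)
            elif prod < 0:
                # sign flipped: shrink the magnitude, then undo the last update or skip this step
                self.mag[i] = max(self.mag[i] * self.nminus, self.lb)
                updates.append(-self.prev_update[i] if self.backtrack else 0.0)
                self.prev_update[i] = 0.0
                self.prev_grad[i] = 0.0  # so the magnitude is left unchanged on the next step
                continue
            sign = (g > 0) - (g < 0)
            u = sign * self.mag[i]
            updates.append(u)
            self.prev_update[i] = u
            self.prev_grad[i] = g
        return updates
```

With backtrack=False the sign-flip branch produces a zero update, which corresponds to the torch.optim.Rprop behavior mentioned above.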

Source code in torchzero/modules/adaptive/rprop.py
class Rprop(TensorTransform):
    """
    Resilient propagation. The update magnitude gets multiplied by ``nplus`` if gradient didn't change the sign,
    or ``nminus`` if it did. Then the update is applied with the sign of the current gradient.

    Additionally, if gradient changes sign, the update for that weight is reverted.
    Next step, magnitude for that weight won't change.

    Compared to pytorch this also implements backtracking update when sign changes.

    This implementation is identical to ``torch.optim.Rprop`` if ``backtrack`` is set to False.

    Args:
        nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
        nminus (float): multiplicative decrease factor for when ascent changed sign (default: 0.5).
        lb (float): minimum step size, can be None (default: 1e-6)
        ub (float): maximum step size, can be None (default: 50)
        backtrack (bool):
            if True, when ascent sign changes, undoes last weight update, otherwise sets update to 0.
            When this is False, this exactly matches pytorch Rprop. (default: True)
        alpha (float): initial per-parameter learning rate (default: 1).

    reference
        *Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning:
        The RPROP algorithm. In IEEE international conference on neural networks (pp. 586-591). IEEE.*
    """
    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb: float = 1e-6,
        ub: float = 50,
        backtrack=True,
        alpha: float = 1,
    ):
        defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, backtrack=backtrack)
        super().__init__(defaults, uses_grad=False)

        self.add_projected_keys("grad", "prev")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        nplus, nminus, lb, ub, alpha = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', 'alpha', cls=NumberList)
        prev, allowed, magnitudes = unpack_states(
            states, tensors,
            'prev','allowed','magnitudes',
            init=[torch.zeros_like, _bool_ones_like, torch.zeros_like],
            cls = TensorList,
        )

        tensors = rprop_(
            tensors_ = TensorList(tensors),
            prev_ = prev,
            allowed_ = allowed,
            magnitudes_ = magnitudes,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            alpha = alpha,
            backtrack=settings[0]['backtrack'],
            step=step,
        )

        return tensors

SAM

Bases: torchzero.core.transform.Transform

Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value. It performs two forward and backward passes per step.

This implementation modifies the closure to return loss and calculate gradients of the SAM objective. All modules after this will use the modified objective.

Note

This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients at two points on each step.

Parameters:

  • rho (float, default: 0.05 ) –

    Neighborhood size. Defaults to 0.05.

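  • p (float, default: 2 ) –

    order of the norm that defines the SAM neighborhood (a p-norm ball of radius rho). Defaults to 2.

  • eps (float, default: 1e-10 ) –

    lower bound on the denominator when computing the perturbation, to avoid division by zero. Defaults to 1e-10.

  • asam (bool, default: False ) –

    enables the ASAM variant, which makes the perturbation relative to weight magnitudes. ASAM requires a much larger rho, such as 0.5 or 1. The tz.m.ASAM class is identical to setting this argument to True, but it has a larger rho by default.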
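The perturbation that sam_closure adds to the parameters can be sketched in plain Python as below. This is an illustrative sketch under simplifying assumptions, not the module's code: it works on flat lists of floats instead of TensorLists, and sam_perturbation is a hypothetical name. It follows the formula in update_states, with q = p / (p - 1) so that 1/p + 1/q = 1; passing weights reproduces the ASAM scaling. A useful sanity check is that the result always has p-norm rho (for ASAM, after dividing each entry by |w|).

```python
import math


def sam_perturbation(grad, rho=0.05, p=2.0, eps=1e-10, weights=None):
    """Hypothetical helper: SAM (or ASAM, if weights are given) perturbation e.

    Follows e = rho * sign(g) * |g|**(q - 1) / (sum |s|**q) ** (1 / p),
    where 1/p + 1/q = 1 and s = g (SAM) or s = g * |w| (ASAM).
    """
    q = p / (p - 1)  # dual exponent, same as -p / (1 - p); q == 2 when p == 2
    abs_g = [abs(g) for g in grad]
    # ASAM measures the gradient relative to the weight magnitudes
    scaled = abs_g if weights is None else [a * abs(w) for a, w in zip(abs_g, weights)]
    denom = max(sum(s ** q for s in scaled) ** (1 / p), eps)
    e = [rho * math.copysign(1.0, g) * a ** (q - 1) / denom for g, a in zip(grad, abs_g)]
    if weights is not None:
        # ASAM: scale the perturbation by w**2
        e = [ei * w * w for ei, w in zip(e, weights)]
    return e


# p = 2 reduces to the familiar rho * g / ||g||_2
print(sam_perturbation([3.0, -4.0], rho=0.05))  # approximately [0.03, -0.04]
```

The closure then evaluates the loss and gradient at w + e and restores w, so every module after SAM sees the gradient taken at the perturbed point.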

Examples:

SAM-SGD:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SAM(),
    tz.m.LR(1e-2)
)

SAM-Adam:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SAM(),
    tz.m.Adam(),
    tz.m.LR(1e-2)
)
References

Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412.

Source code in torchzero/modules/adaptive/sam.py
class SAM(Transform):
    """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412

    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
    It performs two forward and backward passes per step.

    This implementation modifies the closure to return loss and calculate gradients
    of the SAM objective. All modules after this will use the modified objective.

    .. note::
        This module requires a closure passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients at two points on each step.

    Args:
        rho (float, optional): Neighborhood size. Defaults to 0.05.
        p (float, optional): order of the norm that defines the SAM neighborhood (a p-norm ball of radius ``rho``). Defaults to 2.
        eps (float, optional): lower bound on the denominator when computing the perturbation. Defaults to 1e-10.
        asam (bool, optional):
            enables the ASAM variant, which makes the perturbation relative to weight magnitudes.
            ASAM requires a much larger ``rho``, such as 0.5 or 1.
            The ``tz.m.ASAM`` class is identical to setting this argument to True, but
            it has a larger ``rho`` by default.

    ### Examples:

    SAM-SGD:

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SAM(),
        tz.m.LR(1e-2)
    )
    ```

    SAM-Adam:

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SAM(),
        tz.m.Adam(),
        tz.m.LR(1e-2)
    )
    ```

    References:
        [Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412.](https://arxiv.org/abs/2010.01412#page=3.16)
    """
    def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
        defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
        super().__init__(defaults)

    @torch.no_grad
    def update_states(self, objective, states, settings):

        params = objective.params
        closure = objective.closure
        zero_grad = objective.zero_grad
        if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
        p, rho = unpack_dicts(settings, 'p', 'rho', cls=NumberList)
        fs = settings[0]
        eps = fs['eps']
        asam = fs['asam']

        # 1/p + 1/q = 1
        # okay, authors of SAM paper, I will manually solve your equation
        # so q = -p/(1-p)
        q = -p / (1-p)
        # as a validation for 2 it is -2 / -1 = 2

        @torch.no_grad
        def sam_closure(backward=True):
            orig_grads = None
            if not backward:
                # if backward is False, make sure this doesn't modify gradients
                # to avoid issues
                orig_grads = [p.grad for p in params]

            # gradient at initial parameters
            zero_grad()
            with torch.enable_grad():
                closure()

            grad = TensorList(p.grad if p.grad is not None else torch.zeros_like(p) for p in params)
            grad_abs = grad.abs()

            # compute e
            term1 = grad.sign().mul_(rho)
            term2 = grad_abs.pow(q-1)

            if asam:
                grad_abs.mul_(torch._foreach_abs(params))

            denom = grad_abs.pow_(q).sum().pow(1/p)

            e = term1.mul_(term2).div_(denom.clip(min=eps))

            if asam:
                e.mul_(torch._foreach_pow(params, 2))

            # calculate loss and gradient approximation of inner problem
            torch._foreach_add_(params, e)
            if backward:
                zero_grad()
                with torch.enable_grad():
                    # this sets .grad attributes
                    sam_loss = closure()

            else:
                sam_loss = closure(False)

            # and restore initial parameters
            torch._foreach_sub_(params, e)

            if orig_grads is not None:
                for param,orig_grad in zip(params, orig_grads):
                    param.grad = orig_grad

            return sam_loss

        objective.closure = sam_closure

SG2

Bases: torchzero.core.transform.Transform

Second-order stochastic gradient.

2SPSA (second-order SPSA)

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SPSA(),
    tz.m.SG2(),
    tz.m.LR(1e-2),
)

SG2 with line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SG2(),
    tz.m.Backtracking()
)

SG2 with trust region

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.SG2(beta=0.75, n_samples=4)),
)
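Each sample in update_states perturbs the parameters by cd = h * Δ (Δ a Rademacher vector of ±1 entries), evaluates the gradient at x + cd and x - cd, and passes the gradient difference to sg2_. The body of sg2_ is not shown on this page, so the sketch below assumes the symmetrized simultaneous-perturbation Hessian estimate from Spall's second-order SPSA: Ĥ = ½(M + Mᵀ) with M_ij = δg_i / (2h Δ_j). The function names are hypothetical and plain Python lists stand in for tensors. On a quadratic, averaging over all sign vectors recovers the true Hessian exactly, which the usage at the end checks.

```python
import itertools


def matvec(A, v):
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def hessian_sample(grad_fn, x, delta, h=1e-2):
    """Hypothetical 2SPSA-style Hessian sample from one +-1 perturbation delta."""
    n = len(x)
    cd = [h * d for d in delta]  # perturbation, like cd in SG2.update_states
    g_p = grad_fn([xi + c for xi, c in zip(x, cd)])
    g_n = grad_fn([xi - c for xi, c in zip(x, cd)])
    delta_g = [a - b for a, b in zip(g_p, g_n)]
    # M_ij = delta_g_i / (2 h delta_j), then symmetrize
    M = [[delta_g[i] / (2 * h * delta[j]) for j in range(n)] for i in range(n)]
    return [[0.5 * (M[i][j] + M[j][i]) for j in range(n)] for i in range(n)]


def average_hessian(grad_fn, x, deltas, h=1e-2):
    """Average of hessian_sample over several perturbations (cf. n_samples)."""
    n = len(x)
    H = [[0.0] * n for _ in range(n)]
    for delta in deltas:
        Hi = hessian_sample(grad_fn, x, delta, h)
        for i in range(n):
            for j in range(n):
                H[i][j] += Hi[i][j] / len(deltas)
    return H


# Quadratic f(x) = 0.5 * x^T A x has gradient A x and Hessian A.
A = [[4.0, 1.0, 0.0], [1.0, 3.0, 0.5], [0.0, 0.5, 2.0]]
x0 = [0.3, -1.2, 0.7]
all_signs = [list(s) for s in itertools.product([-1.0, 1.0], repeat=3)]
H = average_hessian(lambda x: matvec(A, x), x0, all_signs)
```

A single sample is noisy (for this quadratic it equals ½(AΔΔᵀ + ΔΔᵀA), not A), which is why SG2 averages n_samples of them per step, adds damping, and blends the result into a running H.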

Source code in torchzero/modules/quasi_newton/sg2.py
class SG2(Transform):
    """Second-order stochastic gradient.

    2SPSA (second-order SPSA)
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SPSA(),
        tz.m.SG2(),
        tz.m.LR(1e-2),
    )
    ```

    SG2 with line search
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SG2(),
        tz.m.Backtracking()
    )
    ```

    SG2 with trust region
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.SG2(beta=0.75, n_samples=4)),
    )
    ```

    """

    def __init__(
        self,
        n_samples: int = 1,
        n_first_step_samples: int = 10,
        start_step: int = 10,
        beta: float | None = None,
        damping: float = 1e-4,
        h: float = 1e-2,
        seed=None,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(n_samples=n_samples, h=h, beta=beta, damping=damping, seed=seed, start_step=start_step, n_first_step_samples=n_first_step_samples)
        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]
        k = self.increment_counter("step", 0)

        params = TensorList(objective.params)
        closure = objective.closure
        if closure is None:
            raise RuntimeError("closure is required for SG2")
        generator = self.get_generator(params[0].device, self.defaults["seed"])

        h = unpack_dicts(settings, "h")
        x_0 = params.clone()
        n_samples = fs["n_samples"]
        if k == 0: n_samples = fs["n_first_step_samples"]
        H_hat = None

        # compute new approximation
        for i in range(n_samples):
            # generate perturbation
            cd = params.rademacher_like(generator=generator).mul_(h)

            # two sided hessian approximation
            params.add_(cd)
            closure()
            g_p = params.grad.fill_none_(params)

            params.copy_(x_0)
            params.sub_(cd)
            closure()
            g_n = params.grad.fill_none_(params)

            delta_g = g_p - g_n

            # restore params by copying x_0 back, so x_0 stays an
            # independent snapshot for the next sample
            params.copy_(x_0)

            # compute H hat
            H_i = sg2_(
                delta_g = delta_g.to_vec(),
                cd = cd.to_vec(),
            )

            if H_hat is None: H_hat = H_i
            else: H_hat += H_i

        assert H_hat is not None
        if n_samples > 1: H_hat /= n_samples

        # add damping
        if fs["damping"] != 0:
            reg = torch.eye(H_hat.size(0), device=H_hat.device, dtype=H_hat.dtype).mul_(fs["damping"])
            H_hat += reg

        # update H
        H = self.global_state.get("H", None)
        if H is None: H = H_hat
        else:
            beta = fs["beta"]
            if beta is None: beta = (k+1) / (k+2)
            H.lerp_(H_hat, 1-beta)

        self.global_state["H"] = H


    @torch.no_grad
    def apply_states(self, objective, states, settings):
        fs = settings[0]
        updates = objective.get_updates()

        H: torch.Tensor = self.global_state["H"]
        k = self.global_state["step"]
        if k < fs["start_step"]:
            # don't precondition yet
            # I guess we can try using trace to scale the update
            # because it will have horrible scaling otherwise
            torch._foreach_div_(updates, H.trace())
            return objective

        b = torch.cat([t.ravel() for t in updates])
        sol = torch.linalg.lstsq(H, b).solution # pylint:disable=not-callable

        vec_to_tensors_(sol, updates)
        return objective

    def get_H(self, objective=...):
        return Dense(self.global_state["H"])
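When beta is None, update_states uses beta = (k + 1) / (k + 2), so H.lerp_(H_hat, 1 - beta) blends in each new estimate with weight 1 / (k + 2). Assuming the update counter k starts at 0 (as the k == 0 check for n_first_step_samples suggests), the worked example below tracks the weight each H_hat ends up with in H. It is an illustration with a hypothetical helper, not part of the library.

```python
from fractions import Fraction


def sample_weights(num_steps):
    """Hypothetical helper: weight of each H_hat in H after num_steps updates with beta=None."""
    weights = []
    for k in range(num_steps):
        if k == 0:
            weights = [Fraction(1)]  # H is None on the first update, so H = H_hat
            continue
        beta = Fraction(k + 1, k + 2)  # default schedule when beta is None
        w = 1 - beta  # H.lerp_(H_hat, 1 - beta) == beta * H + (1 - beta) * H_hat
        weights = [wi * beta for wi in weights] + [w]
    return weights


for n in range(1, 5):
    print(n, [str(w) for w in sample_weights(n)])
```

After four updates the weights are 2/5, 1/5, 1/5, 1/5: the schedule behaves like a running mean, except that the first estimate (built from n_first_step_samples samples) keeps twice the weight of each later one.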

SOAP

Bases: torchzero.core.transform.TensorTransform

SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).

Parameters:

  • beta1 (float, default: 0.95 ) –

    beta for first momentum. Defaults to 0.95.

  • beta2 (float, default: 0.95 ) –

    beta for second momentum. Defaults to 0.95.

  • shampoo_beta (float | None, default: 0.95 ) –

    beta for the covariance matrix accumulators. If None, the matrices are summed as in Adagrad (which works worse). Defaults to 0.95.

  • precond_freq (int, default: 10 ) –

    How often to update the preconditioner. Defaults to 10.

  • merge_small (bool, default: True ) –

    Whether to merge small dims. Defaults to True.

  • max_dim (int, default: 4096 ) –

    dimensions of size max_dim or larger are not preconditioned. Defaults to 4096.

  • precondition_1d (bool, default: True ) –

    Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.

  • eps (float, default: 1e-08 ) –

    epsilon added to the square root of the second momentum before dividing the first momentum by it. Defaults to 1e-8.

  • debias (bool, default: True ) –

    enables Adam bias correction. Defaults to True.

  • proj_exp_avg (bool, default: True ) –

    if True, maintains the exponential average of gradients (momentum) in the projected space; if False, in the original space. Defaults to True.

  • alpha (float, default: 1 ) –

    learning rate. Defaults to 1.

  • inner (Chainable | None, default: None ) –

    output of this module is projected and Adam will run on it, but preconditioners are updated from original gradients.

Examples:

SOAP:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SOAP(),
    tz.m.LR(1e-3)
)
Stabilized SOAP:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SOAP(),
    tz.m.NormalizeByEMA(max_ema_growth=1.2),
    tz.m.LR(1e-2)
)
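
The core idea (rotate the gradient into the eigenbasis of the accumulated gradient covariance, run Adam there, rotate back) can be illustrated on a single parameter with two entries, where the eigenbasis of a symmetric 2×2 matrix has a closed form. This is a simplified sketch, not the module's implementation: TinySOAP and its helpers are hypothetical; the covariance update is assumed to be an exponential moving average of g gᵀ weighted by shampoo_beta; the basis is recomputed on every step (like precond_freq=1) with an exact eigendecomposition instead of the QR-based update in the source; only the first moment is re-projected when the basis changes; and bias correction and merge_small are omitted. As in the source, the first step only initializes the covariance and basis and returns the gradient clamped to [-0.1, 0.1], and the covariance is updated after the step.

```python
import math


def eigbasis_2x2(S):
    """Rotation Q whose columns are orthonormal eigenvectors of symmetric 2x2 S."""
    a, b, c = S[0][0], S[0][1], S[1][1]
    theta = 0.5 * math.atan2(2.0 * b, a - c)  # angle that zeroes the off-diagonal
    cs, sn = math.cos(theta), math.sin(theta)
    return [[cs, -sn], [sn, cs]]


def project(Q, v):
    """Q^T v: express v in the eigenbasis."""
    return [Q[0][0] * v[0] + Q[1][0] * v[1], Q[0][1] * v[0] + Q[1][1] * v[1]]


def project_back(Q, v):
    """Q v: express v in the original basis."""
    return [Q[0][0] * v[0] + Q[0][1] * v[1], Q[1][0] * v[0] + Q[1][1] * v[1]]


class TinySOAP:
    """Hypothetical SOAP-style update for one parameter with two entries."""

    def __init__(self, beta1=0.95, beta2=0.95, shampoo_beta=0.95, eps=1e-8):
        self.beta1, self.beta2 = beta1, beta2
        self.shampoo_beta, self.eps = shampoo_beta, eps
        self.GG = [[0.0, 0.0], [0.0, 0.0]]  # covariance accumulator
        self.Q = None                         # eigenbasis of GG
        self.m = [0.0, 0.0]                   # first moment (projected space)
        self.v = [0.0, 0.0]                   # second moment (projected space)

    def _accumulate(self, g):
        # assumed EMA of the outer product g g^T
        for i in range(2):
            for j in range(2):
                self.GG[i][j] += (1 - self.shampoo_beta) * (g[i] * g[j] - self.GG[i][j])

    def step(self, g):
        if self.Q is None:
            # first step: build GG and Q only, return a clamped gradient
            self._accumulate(g)
            self.Q = eigbasis_2x2(self.GG)
            return [max(-0.1, min(0.1, x)) for x in g]

        # Adam in the eigenbasis
        gp = project(self.Q, g)
        self.m = [m + (1 - self.beta1) * (x - m) for m, x in zip(self.m, gp)]
        self.v = [self.beta2 * v + (1 - self.beta2) * x * x for v, x in zip(self.v, gp)]
        d = [m / (math.sqrt(v) + self.eps) for m, v in zip(self.m, self.v)]
        direction = project_back(self.Q, d)

        # update the preconditioner after the step, then move m into the new basis
        Q_old = self.Q
        self._accumulate(g)
        self.Q = eigbasis_2x2(self.GG)
        self.m = project(self.Q, project_back(Q_old, self.m))
        return direction


opt = TinySOAP()
print(opt.step([3.0, 3.0]))  # [0.1, 0.1]
for _ in range(5):
    direction = opt.step([1.0, 1.0])
```

Because every gradient in the usage lies along [1, 1], it stays within one eigen-direction, so the resulting step is parallel to [1, 1] as well.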
Source code in torchzero/modules/adaptive/soap.py
class SOAP(TensorTransform):
    """SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).

    Args:
        beta1 (float, optional): beta for first momentum. Defaults to 0.95.
        beta2 (float, optional): beta for second momentum. Defaults to 0.95.
        shampoo_beta (float | None, optional):
            beta for the covariance matrix accumulators. If None, the matrices are summed as in Adagrad (which works worse). Defaults to 0.95.
        precond_freq (int, optional): How often to update the preconditioner. Defaults to 10.
        merge_small (bool, optional): Whether to merge small dims. Defaults to True.
        max_dim (int, optional): Dimensions of size ``max_dim`` or larger are not preconditioned. Defaults to 4096.
        precondition_1d (bool, optional):
            Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.
        eps (float, optional):
            epsilon added to the square root of the second momentum before dividing the first momentum by it. Defaults to 1e-8.
        debias (bool, optional):
            enables Adam bias correction. Defaults to True.
        proj_exp_avg (bool, optional):
            if True, maintains the exponential average of gradients (momentum) in the projected space.
            If False, maintains it in the original space. Defaults to True.
        alpha (float, optional):
            learning rate. Defaults to 1.
        inner (Chainable | None, optional):
            output of this module is projected and Adam will run on it, but preconditioners are updated
            from original gradients.

    ### Examples:
    SOAP:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SOAP(),
        tz.m.LR(1e-3)
    )
    ```
    Stabilized SOAP:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SOAP(),
        tz.m.NormalizeByEMA(max_ema_growth=1.2),
        tz.m.LR(1e-2)
    )
    ```
    """
    def __init__(
        self,
        beta1: float = 0.95,
        beta2: float = 0.95,
        shampoo_beta: float | None = 0.95,
        precond_freq: int = 10,
        merge_small: bool = True,
        max_dim: int = 4096,
        precondition_1d: bool = True,
        eps: float = 1e-8,
        debias: bool = True,
        proj_exp_avg: bool = True,
        alpha: float = 1,

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults["inner"]

        super().__init__(defaults)
        self.set_child("inner", inner)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        if setting["merge_small"]:
            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])

        state["exp_avg_proj"] = torch.zeros_like(tensor)
        state["exp_avg_sq_proj"] = torch.zeros_like(tensor)

        if tensor.ndim <= 1 and not setting["precondition_1d"]:
            state['GG'] = []

        else:
            max_dim = setting["max_dim"]
            state['GG'] = [
                torch.zeros(s, s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
            ]

        # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
        if all(i is None for i in state['GG']):
            state['GG'] = None

        # first covariance accumulation
        if state['GG'] is not None:
            update_soap_covariances_(tensor, GGs_=state['GG'], beta=setting["shampoo_beta"])

            # get projection matrix with first gradients with eigh
            try: state['Q'] = get_orthogonal_matrix(state['GG'])
            except torch.linalg.LinAlgError as e:
                warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
                state["GG"] = None

        state['step'] = 0


    # no update to avoid running merge_dims twice

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # note
        # do not modify tensors in-place
        # because they are used to update preconditioner at the end

        steps = [s["step"] for s in states]
        if any(s == 0 for s in steps):
            # skip the 1st update to avoid using the current gradient in the projection;
            # the update is clamped instead of zeroed to avoid issues with further modules
            for s in states: s["step"] += 1
            return TensorList(tensors).clamp(-0.1, 0.1)
            # return TensorList(tensors).zero_()

        fs = settings[0]
        merged_updates = [] # for when exp_avg is maintained unprojected
        merged_grads = [] # merged grads; used only to update the preconditioner covariances
        projected = []

        # -------------------------------- inner step -------------------------------- #
        updates = tensors
        has_inner = "inner" in self.children
        if has_inner:
            updates = self.inner_step_tensors("inner", updates, clone=True,
                                              params=params, grads=grads, loss=loss)

        # ---------------------------------- project --------------------------------- #
        for grad, update, state, setting in zip(tensors, updates, states, settings):
            if setting["merge_small"]:
                update, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(update, setting["max_dim"])
                if has_inner: # grad is a different tensor, merge it too
                    grad, _, _ = _merge_small_dims(grad, setting["max_dim"])
                else: # in this case update is still just grad
                    grad = update

            merged_updates.append(update)
            merged_grads.append(grad)

            if state['GG'] is not None:
                update = project(update, state['Q'])

            projected.append(update)


        # ------------------------ run adam in projected space ----------------------- #
        exp_avg_proj, exp_avg_sq_proj = unpack_states(states, projected, "exp_avg_proj", "exp_avg_sq_proj", must_exist=True, cls=TensorList)
        alpha, beta1, beta2, eps = unpack_dicts(settings, "alpha", "beta1", "beta2", "eps", cls=NumberList)

        # lerp exp_avg in projected space
        if fs["proj_exp_avg"]:
            exp_avg_proj.lerp_(projected, weight=1-beta1)

        # or lerp in original space and project
        else:
            exp_avg = exp_avg_proj
            exp_avg.lerp_(merged_updates, weight=1-beta1)
            exp_avg_proj = []
            for t, state, setting in zip(exp_avg, states, settings):
                if state['GG'] is not None:
                    t = project(t, state["Q"])
                exp_avg_proj.append(t)

        # lerp exp_avg_sq
        exp_avg_sq_proj.mul_(beta2).addcmul_(projected, projected, value=1-beta2)

        # adam direction
        denom = exp_avg_sq_proj.sqrt().add_(eps)
        dirs_proj = exp_avg_proj / denom

        # ------------------------------- project back ------------------------------- #
        dirs: list[torch.Tensor] = []
        for dir, state, setting in zip(dirs_proj, states, settings):
            if state['GG'] is not None:
                dir = project_back(dir, state['Q'])

            if setting["merge_small"]:
                dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])

            dirs.append(dir)

        # -------------------------- update preconditioners -------------------------- #
        # Update is done after the gradient step to avoid using current gradients in the projection.

        for grad, state, setting in zip(merged_grads, states, settings):
            if state['GG'] is not None:

                # lerp covariances
                update_soap_covariances_(grad, state['GG'], beta=setting["shampoo_beta"])

                # (state['step'] - 1) since we start updating on 2nd step
                if (state['step'] - 1) % setting['precond_freq'] == 0:

                    # unproject exp_avg before updating if it is maintained projected
                    exp_avg = None
                    if fs["proj_exp_avg"]:
                        exp_avg = project_back(state["exp_avg_proj"], state["Q"])

                    # update projection matrix and exp_avg_sq_proj
                    try:
                        state['Q'], state['exp_avg_sq_proj'] = get_orthogonal_matrix_QR(
                            state["exp_avg_sq_proj"], state['GG'], state['Q'])

                        # re-project exp_avg if it is maintained projected
                        if fs["proj_exp_avg"]:
                            assert exp_avg is not None
                            state["exp_avg_proj"] = project(exp_avg, state["Q"])

                    except torch.linalg.LinAlgError:
                        pass

            state["step"] += 1


        # ------------------------- bias-corrected step size ------------------------- #
        if fs["debias"]:
            steps1 = [s+1 for s in steps]
            bias_correction1 = 1.0 - beta1 ** steps1
            bias_correction2 = 1.0 - beta2 ** steps1
            alpha = alpha * (bias_correction2 ** .5) / bias_correction1

        torch._foreach_mul_(dirs, alpha)
        return dirs

SOAPBasis

Bases: torchzero.core.transform.TensorTransform

Run another optimizer in Shampoo eigenbases.

Note

The buffers of basis_opt are re-projected whenever the basis changes. The reprojection logic is not implemented for all modules. Some supported modules are:

Adagrad, Adam, Adan, Lion, MARSCorrection, MSAMMomentum, RMSprop, EMA, HeavyBall, NAG, ClipNormByEMA, ClipValueByEMA, NormalizeByEMA, ClipValueGrowth, CoordinateMomentum, CubicAdam.

Additionally, most modules with no internal buffers are supported, e.g. Cautious, Sign, ClipNorm, Orthogonalize, etc. However, modules that use weight values, such as WeightDecay, can't be supported, since weights can't be projected.

Also, if you use, say, EMA on the output of Pow(2), the exponential average will be reprojected as a gradient and not as squared gradients. Use modules like EMASquared or SqrtEMASquared to get correct reprojections.
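When the basis changes, the source unprojects each gradient-like buffer of basis_opt with the old Q and projects it again with the new Q. The sketch below uses a 2-D rotation as a stand-in for Q, and its helper names are hypothetical. It shows why this reprojection is exact for a buffer that is linear in the gradient (such as an EMA of gradients) but not for a buffer of squared gradients: squaring does not commute with rotation.

```python
import math


def rotation(theta):
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]


def project(Q, v):       # Q^T v
    return [Q[0][0] * v[0] + Q[1][0] * v[1], Q[0][1] * v[0] + Q[1][1] * v[1]]


def project_back(Q, v):  # Q v
    return [Q[0][0] * v[0] + Q[0][1] * v[1], Q[1][0] * v[0] + Q[1][1] * v[1]]


def reproject(buf, Q_old, Q_new):
    """Unproject with the old basis, then project with the new one."""
    return project(Q_new, project_back(Q_old, buf))


g = [1.0, 2.0]
Q_old, Q_new = rotation(0.3), rotation(1.1)

# linear buffer (e.g. a momentum holding 0.1 * g): reprojection is exact
m_old = [0.1 * x for x in project(Q_old, g)]
m_reprojected = reproject(m_old, Q_old, Q_new)
m_expected = [0.1 * x for x in project(Q_new, g)]

# squared buffer treated as a gradient: reprojection gives the wrong values
sq_old = [x * x for x in project(Q_old, g)]
sq_reprojected = reproject(sq_old, Q_old, Q_new)
sq_expected = [x * x for x in project(Q_new, g)]
```

Here sq_reprojected comes out near [3.54, 0.10] while the correctly squared projection is near [5.0, 0.0003], which is the mismatch the note above warns about.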

Parameters:

  • basis_opt (Chainable) –

    module or modules to run in Shampoo eigenbases.

  • shampoo_beta (float | None, default: 0.95 ) –

    beta for the covariance matrix accumulators. If None, the matrices are summed as in Adagrad (which works worse). Defaults to 0.95.

  • precond_freq (int, default: 10 ) –

    How often to update the preconditioner. Defaults to 10.

  • merge_small (bool, default: True ) –

    Whether to merge small dims. Defaults to True.

  • max_dim (int, default: 4096 ) –

    dimensions of size max_dim or larger are not preconditioned. Defaults to 4096.

  • precondition_1d (bool, default: True ) –

    Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.

  • inner (Chainable | None, default: None ) –

    output of this module is projected and basis_opt will run on it, but preconditioners are updated from original gradients.

Examples:

SOAP with MARS and AMSGrad:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SOAPBasis([tz.m.MARSCorrection(0.95), tz.m.Adam(0.95, 0.95, amsgrad=True)]),
    tz.m.LR(1e-3)
)

LaProp in Shampoo eigenbases (SOLP):

# we define LaProp through other modules, moved it out for brevity
laprop = (
    tz.m.RMSprop(0.95),
    tz.m.Debias(beta1=None, beta2=0.95),
    tz.m.EMA(0.95),
    tz.m.Debias(beta1=0.95, beta2=None),
)

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SOAPBasis(laprop),
    tz.m.LR(1e-3)
)

Lion in Shampoo eigenbases (works kinda well):

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SOAPBasis(tz.m.Lion()),
    tz.m.LR(1e-3)
)

Source code in torchzero/modules/basis/soap_basis.py
class SOAPBasis(TensorTransform):
    """
    Run another optimizer in Shampoo eigenbases.

    Note:
        The buffers of ``basis_opt`` are re-projected whenever the basis changes. The reprojection logic is not implemented for all modules. Some supported modules are:

        ``Adagrad``, ``Adam``, ``Adan``, ``Lion``, ``MARSCorrection``, ``MSAMMomentum``, ``RMSprop``, ``EMA``, ``HeavyBall``, ``NAG``, ``ClipNormByEMA``, ``ClipValueByEMA``, ``NormalizeByEMA``, ``ClipValueGrowth``, ``CoordinateMomentum``, ``CubicAdam``.

        Additionally, most modules with no internal buffers are supported, e.g. ``Cautious``, ``Sign``, ``ClipNorm``, ``Orthogonalize``, etc. However, modules that use weight values, such as ``WeightDecay``, can't be supported, since weights can't be projected.

        Also, if you use, say, ``EMA`` on the output of ``Pow(2)``, the exponential average will be reprojected as a gradient and not as squared gradients. Use modules like ``EMASquared`` or ``SqrtEMASquared`` to get correct reprojections.

    Args:
        basis_opt (Chainable): module or modules to run in Shampoo eigenbases.
        shampoo_beta (float | None, optional):
            beta for the covariance matrix accumulators. If None, the matrices are summed as in Adagrad (which works worse). Defaults to 0.95.
        precond_freq (int, optional): How often to update the preconditioner. Defaults to 10.
        merge_small (bool, optional): Whether to merge small dims. Defaults to True.
        max_dim (int, optional): Dimensions of size ``max_dim`` or larger are not preconditioned. Defaults to 4096.
        precondition_1d (bool, optional):
            Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.
        inner (Chainable | None, optional):
            output of this module is projected and ``basis_opt`` will run on it, but preconditioners are updated
            from original gradients.

    Examples:
    SOAP with MARS and AMSGrad:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SOAPBasis([tz.m.MARSCorrection(0.95), tz.m.Adam(0.95, 0.95, amsgrad=True)]),
        tz.m.LR(1e-3)
    )
    ```

    LaProp in Shampoo eigenbases (SOLP):
    ```python

    # we define LaProp through other modules, moved it out for brevity
    laprop = (
        tz.m.RMSprop(0.95),
        tz.m.Debias(beta1=None, beta2=0.95),
        tz.m.EMA(0.95),
        tz.m.Debias(beta1=0.95, beta2=None),
    )

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SOAPBasis(laprop),
        tz.m.LR(1e-3)
    )
    ```

    Lion in Shampoo eigenbases (works kinda well):
    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SOAPBasis(tz.m.Lion()),
        tz.m.LR(1e-3)
    )
    ```
    """
    def __init__(
        self,
        basis_opt: Chainable,
        shampoo_beta: float | None = 0.95,
        precond_freq: int = 10,
        merge_small: bool = True,
        max_dim: int = 4096,
        precondition_1d: bool = True,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults["inner"], defaults["basis_opt"]

        super().__init__(defaults)
        self.set_child("inner", inner)
        self.set_child("basis_opt", basis_opt)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        if setting["merge_small"]:
            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])

        state["exp_avg_proj"] = torch.zeros_like(tensor)
        state["exp_avg_sq_proj"] = torch.zeros_like(tensor)

        if tensor.ndim <= 1 and not setting["precondition_1d"]:
            state['GG'] = []

        else:
            max_dim = setting["max_dim"]
            state['GG'] = [
                torch.zeros(s, s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
            ]

        # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
        if all(i is None for i in state['GG']):
            state['GG'] = None

        # first covariance accumulation
        if state['GG'] is not None:
            update_soap_covariances_(tensor, GGs_=state['GG'], beta=setting["shampoo_beta"])

            # get projection matrix with first gradients with eigh
            try: state['Q'] = get_orthogonal_matrix(state['GG'])
            except torch.linalg.LinAlgError as e:
                warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
                state["GG"] = None

        state['step'] = 0


    # no update to avoid running merge_dims twice

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # note
        # do not modify tensors in-place
        # because they are used to update preconditioner at the end

        steps = [s["step"] for s in states]
        if any(s == 0 for s in steps):
            # skip the 1st update to avoid using the current gradient in the projection;
            # the update is clamped instead of zeroed to avoid issues with further modules
            for s in states: s["step"] += 1
            return TensorList(tensors).clamp(-0.1, 0.1)
            # return TensorList(tensors).zero_()

        merged_updates = [] # for when exp_avg is maintained unprojected
        merged_grads = [] # merged grads; used only to update the preconditioner covariances
        projected = []

        # -------------------------------- inner step -------------------------------- #
        updates = tensors
        has_inner = "inner" in self.children
        if has_inner:
            updates = self.inner_step_tensors("inner", updates, clone=True,
                                              params=params, grads=grads, loss=loss)

        # ---------------------------------- project --------------------------------- #
        for grad, update, state, setting in zip(tensors, updates, states, settings):
            if setting["merge_small"]:
                update, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(update, setting["max_dim"])
                if has_inner: # grad is a different tensor, merge it too
                    grad, _, _ = _merge_small_dims(grad, setting["max_dim"])
                else: # in this case update is still just grad
                    grad = update

            merged_updates.append(update)
            merged_grads.append(grad)

            if state['GG'] is not None:
                update = project(update, state['Q'])

            projected.append(update)


        # ------------------------ run opt in projected space ----------------------- #
        dirs_proj = self.inner_step_tensors("basis_opt", tensors=projected, clone=True, grads=projected)

        # ------------------------------- project back ------------------------------- #
        dirs: list[torch.Tensor] = []
        for dir, state, setting in zip(dirs_proj, states, settings):
            if state['GG'] is not None:
                dir = project_back(dir, state['Q'])

            if setting["merge_small"]:
                dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])

            dirs.append(dir)

        # -------------------------- update preconditioners -------------------------- #
        # Update is done after the gradient step to avoid using current gradients in the projection.

        grad_buffs = self.get_child_projected_buffers("basis_opt", "grad")
        grad_sq_buffs = self.get_child_projected_buffers("basis_opt", ["grad_sq", "grad_cu"])

        for i, (grad, state, setting) in enumerate(zip(merged_grads, states, settings)):
            if state['GG'] is not None:

                # lerp covariances
                update_soap_covariances_(grad, state['GG'], beta=setting["shampoo_beta"])

                # (state['step'] - 1) since we start updating on 2nd step
                if (state['step'] - 1) % setting['precond_freq'] == 0:
                    g_buffs = [b[i] for b in grad_buffs]
                    g_sq_buffs = [b[i] for b in grad_sq_buffs]

                    # unproject grad buffers before updating
                    g_buffs_unproj = [project_back(buff, state["Q"]) for buff in g_buffs]

                    # update projection matrix and basis_opt's squared-gradient buffers
                    try:
                        state['Q'], g_sq_buffs_new = get_orthogonal_matrix_QR(
                            g_sq_buffs, state['GG'], state['Q'])

                        for b_old, b_new in zip(g_sq_buffs, g_sq_buffs_new):
                            set_storage_(b_old, b_new)

                        # re-project grad buffers
                        for b_proj, b_unproj in zip(g_buffs, g_buffs_unproj):
                            set_storage_(b_proj, project(b_unproj, state["Q"]))

                    except torch.linalg.LinAlgError:
                        pass

            state["step"] += 1

        return dirs

SPSA

Bases: torchzero.modules.grad_approximation.rfdm.RandomizedFDM

Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.

Note

This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients, and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

Parameters:

  • h (float, default: 0.001 ) –

    finite difference step size if jvp_method is set to forward or central. Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of random gradient samples. Defaults to 1.

  • formula (Literal, default: 'central' ) –

    finite difference formula. Defaults to 'central'.

  • distribution (Literal, default: 'rademacher' ) –

    distribution of the random perturbation directions. Defaults to "rademacher".

  • pre_generate (bool, default: True ) –

    whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.

  • seed (int | None | Generator, default: None ) –

    Seed for random generator. Defaults to None.

  • target (Literal, default: 'closure' ) –

    what to set on var. Defaults to "closure".

References

Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
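A minimal pure-Python sketch of the two-measurement, central-difference SPSA estimate with Rademacher perturbations, ĝ = (f(x + hΔ) - f(x - hΔ)) / (2h) · Δ, averaged over several random directions (the role of `n_samples`). The helper names (`rademacher`, `spsa_central`, `sphere`) are hypothetical; this is not the module's implementation, which works on PyTorch tensors and supports other formulas and distributions.

```python
import random
from typing import Callable, List, Sequence

Vector = List[float]


def rademacher(n: int, rng: random.Random) -> Vector:
    """Draw n independent entries that are +1 or -1 with equal probability."""
    return [rng.choice((-1.0, 1.0)) for _ in range(n)]


def spsa_central(
    f: Callable[[Vector], float],
    x: Vector,
    h: float,
    deltas: Sequence[Vector],
) -> Vector:
    """Average central-difference SPSA estimates over the given perturbation directions."""
    grad = [0.0] * len(x)
    for delta in deltas:
        x_plus = [xi + h * di for xi, di in zip(x, delta)]
        x_minus = [xi - h * di for xi, di in zip(x, delta)]
        # two evaluations give the directional derivative along delta
        slope = (f(x_plus) - f(x_minus)) / (2 * h)
        # for Rademacher entries 1 / delta_i == delta_i
        for i, di in enumerate(delta):
            grad[i] += slope * di
    return [g / len(deltas) for g in grad]


def sphere(x: Vector) -> float:
    return sum(xi * xi for xi in x)


rng = random.Random(0)
x = [1.0, -2.0]
deltas = [rademacher(len(x), rng) for _ in range(4)]  # plays the role of n_samples=4
estimate = spsa_central(sphere, x, h=1e-3, deltas=deltas)
```

Averaging over all 2^n sign patterns recovers the exact gradient of a quadratic because E[Δ_i Δ_j] is 1 for i = j and 0 otherwise; with a few random directions the estimate is noisy but unbiased up to O(h^2) terms for smooth objectives.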

Source code in torchzero/modules/grad_approximation/rfdm.py
class SPSA(RandomizedFDM):
    """
    Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.

    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

    Args:
        h (float, optional): finite difference step size if jvp_method is set to `forward` or `central`. Defaults to 1e-3.
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
        formula (_FD_Formula, optional): finite difference formula. Defaults to 'central'.
        distribution (Distributions, optional): distribution of the random perturbation directions. Defaults to "rademacher".
        pre_generate (bool, optional):
            whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
        seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    References:
        Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
    """

SPSA1

Bases: torchzero.modules.grad_approximation.grad_approximator.GradApproximator

One-measurement variant of SPSA. Unlike standard two-measurement SPSA, the estimated gradient is often not a descent direction; however, its expectation is biased towards the descent direction. This variant of SPSA is therefore recommended only for a specific class of problems where the objective function changes on each evaluation, for example feedback control problems.

Parameters:

  • h (float, default: 0.001 ) –

    finite difference step size, recommended to set to same value as learning rate. Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of random samples. Defaults to 1.

  • eps (float, default: 1e-08 ) –

    measurement noise estimate. Defaults to 1e-8.

  • seed (int | None | Generator, default: None ) –

    random seed. Defaults to None.

  • target (Literal, default: 'closure' ) –

    what to set on closure. Defaults to "closure".

Reference

Spall, J. C. (1997). "A One-measurement Form of Simultaneous Perturbation Stochastic Approximation." Automatica.
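A minimal pure-Python sketch of the textbook one-measurement estimator from Spall's paper, ĝ_i = f(x + hΔ) / (h Δ_i) with Rademacher Δ, averaged over samples. The function names are hypothetical, and the module's own scaling by `h` and its use of `eps` (see `approximate` in the source below) may differ from this textbook form, so treat it as an illustration of the idea rather than of the implementation.

```python
import random
from typing import Callable, List, Sequence

Vector = List[float]


def spsa_one_measurement(
    f: Callable[[Vector], float],
    x: Vector,
    h: float,
    deltas: Sequence[Vector],
) -> Vector:
    """Average one-measurement SPSA estimates f(x + h*delta) / (h * delta_i)."""
    grad = [0.0] * len(x)
    for delta in deltas:
        # a single function evaluation per sample
        y = f([xi + h * di for xi, di in zip(x, delta)])
        for i, di in enumerate(delta):
            grad[i] += y / (h * di)
    return [g / len(deltas) for g in grad]


def shifted_linear(x: Vector) -> float:
    # true gradient is [3, -1]; the constant offset is what makes single estimates noisy
    return 3.0 * x[0] - x[1] + 10.0


rng = random.Random(0)
x = [0.5, 0.5]
deltas = [[rng.choice((-1.0, 1.0)) for _ in x] for _ in range(8)]
noisy_estimate = spsa_one_measurement(shifted_linear, x, h=1e-2, deltas=deltas)
```

With Δ = [1, 1] a single estimate is about [1102, 1102], dominated by the f(x)/h term, even though averaging over all four sign patterns returns exactly [3, -1]. This illustrates why an individual one-measurement estimate is often not a descent direction while its expectation still points the right way.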

Source code in torchzero/modules/grad_approximation/spsa1.py
class SPSA1(GradApproximator):
    """One-measurement variant of SPSA. Unlike standard two-measurement SPSA, the estimated
    gradient is often not a descent direction; however, its expectation is biased towards
    the descent direction. This variant of SPSA is therefore recommended only for a specific
    class of problems where the objective function changes on each evaluation,
    for example feedback control problems.

    Args:
        h (float, optional):
            finite difference step size, recommended to set to same value as learning rate. Defaults to 1e-3.
        n_samples (int, optional): number of random samples. Defaults to 1.
        eps (float, optional): measurement noise estimate. Defaults to 1e-8.
        seed (int | None | torch.Generator, optional): random seed. Defaults to None.
        target (GradTarget, optional): what to set on closure. Defaults to "closure".

    Reference:
        [Spall, J. C. (1997). "A One-measurement Form of Simultaneous Perturbation Stochastic Approximation." Automatica](https://www.jhuapl.edu/spsa/PDF-SPSA/automatica97_one_measSPSA.pdf).
    """

    def __init__(
        self,
        h: float = 1e-3,
        n_samples: int = 1,
        eps: float = 1e-8, # measurement noise
        pre_generate = False,
        seed: int | None | torch.Generator = None,
        target: GradTarget = "closure",
    ):
        defaults = dict(h=h, eps=eps, n_samples=n_samples, pre_generate=pre_generate, seed=seed)
        super().__init__(defaults, target=target)


    def pre_step(self, objective):

        if self.defaults['pre_generate']:

            params = TensorList(objective.params)
            generator = self.get_generator(params[0].device, self.defaults['seed'])

            n_samples = self.defaults['n_samples']
            h = self.get_settings(objective.params, 'h')

            perturbations = [params.rademacher_like(generator=generator) for _ in range(n_samples)]
            # perturbations are flattened sample-major, so the per-parameter h values are repeated once per sample
            torch._foreach_mul_([p for l in perturbations for p in l], list(h) * n_samples)

            for param, prt in zip(params, zip(*perturbations)):
                self.state[param]['perturbations'] = prt

    @torch.no_grad
    def approximate(self, closure, params, loss):
        generator = self.get_generator(params[0].device, self.defaults['seed'])

        params = TensorList(params)
        orig_params = params.clone() # store to avoid small changes due to float imprecision
        loss_approx = None

        h, eps = self.get_settings(params, "h", "eps", cls=NumberList)
        n_samples = self.defaults['n_samples']

        default = [None]*n_samples
        # perturbations are pre-multiplied by h
        perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))

        grad = None
        for i in range(n_samples):
            prt = perturbations[i]

            if prt[0] is None:
                prt = params.rademacher_like(generator=generator).mul_(h)

            else: prt = TensorList(prt)

            params += prt
            L = closure(False)
            params.copy_(orig_params)

            sample = prt * ((L + eps) / h)
            if grad is None: grad = sample
            else: grad += sample

        assert grad is not None
        # average the per-sample estimates
        if n_samples > 1: grad.div_(n_samples)

        return grad, loss, loss_approx

SR1

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Symmetric Rank 1. This works best with a trust region:

tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False))

Parameters:

  • init_scale (float | Literal['auto'], default: 'auto' ) –

    initial hessian matrix is set to identity times this.

    "auto" corresponds to a heuristic from [1] p.142-143.

    Defaults to "auto".

  • tol (float, default: 1e-32 ) –

    tolerance for denominator in SR1 update rule as in [1] p.146. Defaults to 1e-32.

  • ptol (float | None, default: 1e-32 ) –

    skips update if maximum difference between current and previous gradients is less than this, to avoid instability. Defaults to 1e-32.

  • ptol_restart (bool, default: False ) –

    whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.

  • restart_interval (int | None | Literal['auto'], default: None ) –

    interval between resetting the hessian approximation.

    "auto" corresponds to number of decision variables + 1.

    None - no resets.

    Defaults to None.

  • beta (float | None, default: None ) –

    momentum on H or B. Defaults to None.

  • update_freq (int, default: 1 ) –

    frequency of updating H or B. Defaults to 1.

  • scale_first (bool, default: False ) –

    whether to downscale the first step before the hessian approximation becomes available. Defaults to False.

  • scale_second (bool) –

    whether to downscale second step. Defaults to False.

  • concat_params (bool, default: True ) –

    If True, all parameters are treated as a single vector. If False, the update rule is applied to each parameter separately. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.

Examples:

SR1 with trust region

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
)

References:
[1] Nocedal, J., & Wright, S. J. Numerical Optimization. Springer.
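A minimal pure-Python sketch of the SR1 update applied to an inverse Hessian approximation H, with the inverse-form analogue of the skipping safeguard from [1]: the update is skipped when |u^T y| <= r ||u|| ||y||, where u = s - H y. The helper names and the threshold `r` are illustrative assumptions; the module's own `tol` and `ptol` checks may be applied differently.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(A: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def sr1_inverse_update(H: Matrix, s: Vector, y: Vector, r: float = 1e-8) -> Matrix:
    """SR1 update of an inverse Hessian approximation H.

    H_new = H + u u^T / (u^T y) with u = s - H y. The update is skipped (a copy
    of H is returned) when |u^T y| <= r * ||u|| * ||y||.
    """
    u = [si - hyi for si, hyi in zip(s, matvec(H, y))]
    denom = dot(u, y)
    if abs(denom) <= r * math.sqrt(dot(u, u)) * math.sqrt(dot(y, y)):
        return [row[:] for row in H]
    n = len(u)
    return [[H[i][j] + u[i] * u[j] / denom for j in range(n)] for i in range(n)]


H0 = [[1.0, 0.0], [0.0, 1.0]]
s = [1.0, 0.0]  # step x_{k+1} - x_k
y = [2.0, 1.0]  # gradient difference g_{k+1} - g_k
H1 = sr1_inverse_update(H0, s, y)
```

For these values H1 = [[2/3, -1/3], [-1/3, 2/3]] and H1 y equals s, the secant condition that every accepted SR1 update satisfies. Because the rank-one correction can make H indefinite, pairing SR1 with a trust region, as recommended above, is the usual safeguard.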
Source code in torchzero/modules/quasi_newton/quasi_newton.py
class SR1(_InverseHessianUpdateStrategyDefaults):
    """Symmetric Rank 1. This works best with a trust region:
    ```python
    tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False))
    ```

    Args:
        init_scale (float | Literal["auto"], optional):
            initial hessian matrix is set to identity times this.

            "auto" corresponds to a heuristic from [1] p.142-143.

            Defaults to "auto".
        tol (float, optional):
            tolerance for denominator in SR1 update rule as in [1] p.146. Defaults to 1e-32.
        ptol (float | None, optional):
            skips update if maximum difference between current and previous gradients is less than this, to avoid instability.
            Defaults to 1e-32.
        ptol_restart (bool, optional): whether to reset the hessian approximation when ptol tolerance is not met. Defaults to False.
        restart_interval (int | None | Literal["auto"], optional):
            interval between resetting the hessian approximation.

            "auto" corresponds to number of decision variables + 1.

            None - no resets.

            Defaults to None.
        beta (float | None, optional): momentum on H or B. Defaults to None.
        update_freq (int, optional): frequency of updating H or B. Defaults to 1.
        scale_first (bool, optional):
            whether to downscale the first step before the hessian approximation becomes available. Defaults to False.
        scale_second (bool, optional): whether to downscale second step. Defaults to False.
        concat_params (bool, optional):
            If True, all parameters are treated as a single vector.
            If False, the update rule is applied to each parameter separately. Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    ### Examples:

    SR1 with trust region
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
    )
    ```

    ### References:
        [1] Nocedal, J., & Wright, S. J. Numerical Optimization. Springer.
    """

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return sr1_(H=H, s=s, y=y, tol=setting['tol'])
    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
        return sr1_(H=B, s=y, y=s, tol=setting['tol'])

SSVM

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Self-scaling variable metric Quasi-Newton method.

Note

A line search is recommended.

Warning

This uses at least O(N^2) memory.

Reference

Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
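The excerpt does not spell out the update formula. For orientation, here is a hedged pure-Python sketch of a generic self-scaling Broyden-class update of the inverse Hessian (Oren-Luenberger form) with the common scaling gamma = s^T y / y^T H y. It is not the module's `ssvm_H_`, and it does not reproduce how the `switch` argument chooses the scaling parameters; the function names are hypothetical. Storing a dense n x n matrix like `H` below is presumably what the O(N^2) memory warning refers to.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(A: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def self_scaling_update(H: Matrix, s: Vector, y: Vector, theta: float = 1.0) -> Matrix:
    """Self-scaling Broyden-class update of a symmetric inverse Hessian approximation H.

    H_new = gamma * (H - (H y)(H y)^T / (y^T H y) + theta * (y^T H y) v v^T) + s s^T / (s^T y)
    with v = s / (s^T y) - H y / (y^T H y) and gamma = (s^T y) / (y^T H y).
    theta = 1 gives the scaled BFGS member, theta = 0 the scaled DFP member.
    """
    Hy = matvec(H, y)
    yHy = dot(y, Hy)
    sy = dot(s, y)
    gamma = sy / yHy  # rescales H before the rank-two correction
    v = [si / sy - hyi / yHy for si, hyi in zip(s, Hy)]
    n = len(s)
    return [
        [
            gamma * (H[i][j] - Hy[i] * Hy[j] / yHy + theta * yHy * v[i] * v[j])
            + s[i] * s[j] / sy
            for j in range(n)
        ]
        for i in range(n)
    ]


H0 = [[1.0, 0.0], [0.0, 1.0]]
s = [1.0, 0.0]
y = [2.0, 1.0]
H1 = self_scaling_update(H0, s, y)
```

Whatever gamma and theta are, the bracketed term maps y to zero, so every member of this family satisfies the secant condition H_new y = s; the scaling only changes how H behaves in directions orthogonal to y.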

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class SSVM(HessianUpdateStrategy):
    """
    Self-scaling variable metric Quasi-Newton method.

    Note:
        A line search is recommended.

    Warning:
        This uses at least O(N^2) memory.

    Reference:
        Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
    """
    def __init__(
        self,
        switch: tuple[float,float] | Literal[1,2,3,4] = 3,
        init_scale: float | Literal["auto"] = 'auto',
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None = None,
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        defaults = dict(switch=switch)
        super().__init__(
            defaults=defaults,
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
        )

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return ssvm_H_(H=H, s=s, y=y, g=g, switch=setting['switch'], tol=setting['tol'])

SVRG

Bases: torchzero.core.module.Module

Stochastic variance reduced gradient method (SVRG).

To use, put SVRG as the first module; it can be combined with any other modules. To reduce the variance of a gradient estimator, put the gradient estimator before SVRG.

First, it uses the first accum_steps batches to compute the full gradient at the initial parameters via gradient accumulation; the model is not updated during this phase.

Then it performs svrg_steps SVRG steps, each requiring two forward and backward passes.

After svrg_steps steps, it goes back to the full gradient computation step.

As an alternative to gradient accumulation, you can pass a "full_closure" argument to the step method. It should compute the full gradients, assign them to the .grad attributes of the parameters, and return the full loss.
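The alternating phases described above can be traced with a small, hypothetical pure-Python helper that mirrors the branching in `SVRG.update` shown in the source below, assuming no "full_closure" is passed. It is not part of torchzero. One detail it makes visible: the step that completes the accumulation already uses the SVRG closure, so each cycle consists of accum_steps - 1 skipped steps followed by svrg_steps SVRG steps.

```python
from typing import List, Optional


def svrg_schedule(n_steps: int, svrg_steps: int, accum_steps: Optional[int] = None) -> List[str]:
    """Label each optimizer step following the branches of SVRG.update (no full_closure).

    "accumulate"       the batch gradient is accumulated and the update is skipped
    "accumulate+svrg"  the last accumulation batch: the full gradient becomes available
                       and this same step already uses the SVRG closure
    "svrg"             an ordinary SVRG step
    """
    if accum_steps is None:
        accum_steps = svrg_steps
    labels: List[str] = []
    have_full_grad = False
    accum_count = 0
    svrg_count = 0
    for _ in range(n_steps):
        prefix = ""
        if not have_full_grad:
            accum_count += 1
            if accum_count < accum_steps:
                labels.append("accumulate")
                continue
            # enough batches: the averaged gradient becomes the full gradient at x_0
            have_full_grad = True
            accum_count = 0
            prefix = "accumulate+"
        svrg_count += 1
        labels.append(prefix + "svrg")
        if svrg_count >= svrg_steps:
            # full gradient is discarded and recomputed on the following steps
            have_full_grad = False
            svrg_count = 0
    return labels


schedule = svrg_schedule(8, svrg_steps=2, accum_steps=3)
```

With svrg_steps=2 and accum_steps=3 the eight steps are labelled accumulate, accumulate, accumulate+svrg, svrg, and then the same four labels again.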

Parameters:

  • svrg_steps (int) –

    number of steps before calculating full gradient. This can be set to length of the dataloader.

  • accum_steps (int | None, default: None ) –

    number of steps to accumulate the gradient for. Not used if "full_closure" is passed to the step method. If None, uses value of svrg_steps. Defaults to None.

  • reset_before_accum (bool, default: True ) –

    whether to reset all other modules when re-calculating full gradient. Defaults to True.

  • svrg_loss (bool, default: True ) –

    whether to replace loss with SVRG loss (calculated by same formula as SVRG gradient). Defaults to True.

  • alpha (float, default: 1 ) –

    multiplier for the g_full(x_0) - g_batch(x_0) term; it can be annealed linearly from 1 to 0 as suggested in https://arxiv.org/pdf/2311.05589#page=6. Defaults to 1.

Examples:

SVRG-LBFGS

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SVRG(len(dataloader)),
    tz.m.LBFGS(),
    tz.m.Backtracking(),
)

For extra variance reduction one can use Online versions of algorithms, although it won't always help.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SVRG(len(dataloader)),
    tz.m.Online(tz.m.LBFGS()),
    tz.m.Backtracking(),
)

Variance reduction can also be applied to gradient estimators.
opt = tz.Optimizer(
    model.parameters(),
    tz.m.SPSA(),
    tz.m.SVRG(100),
    tz.m.LR(1e-2),
)

Notes

The SVRG gradient is computed as g_b(x) - alpha * (g_b(x_0) - g_f(x_0)), where:

  • x is the current parameter vector
  • x_0 is the parameter vector at which the full gradient was computed
  • g_b refers to the mini-batch gradient at x or x_0
  • g_f refers to the full gradient at x_0.

The SVRG loss is computed using the same formula.
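To make the formula concrete, here is a hypothetical pure-Python sketch of the estimator (not torchzero code) on a toy finite sum of scalar quadratics f_i(x) = 0.5 (x - c_i)^2. Because every f_i has the same curvature here, the correction with alpha = 1 cancels the mini-batch noise exactly and every estimate equals the full gradient at x; for general objectives it only reduces the variance, most strongly while x stays close to x_0.

```python
from typing import List, Sequence


def svrg_gradient(
    g_b_x: Sequence[float],
    g_b_x0: Sequence[float],
    g_f_x0: Sequence[float],
    alpha: float = 1.0,
) -> List[float]:
    """Element-wise g_b(x) - alpha * (g_b(x_0) - g_f(x_0))."""
    return [gx - alpha * (gx0 - gf0) for gx, gx0, gf0 in zip(g_b_x, g_b_x0, g_f_x0)]


# Toy finite sum of scalar quadratics f_i(x) = 0.5 * (x - c_i)**2:
# g_i(x) = x - c_i and the full gradient is g_f(x) = x - mean(c).
c = [1.0, 4.0, -2.0, 5.0]
mean_c = sum(c) / len(c)


def g_batch(x: float, i: int) -> float:
    return x - c[i]


def g_full(x: float) -> float:
    return x - mean_c


x0, x = 0.0, 3.0
estimates = [svrg_gradient([g_batch(x, i)], [g_batch(x0, i)], [g_full(x0)])[0] for i in range(len(c))]
plain_sgd = [g_batch(x, i) for i in range(len(c))]
```

Here every entry of `estimates` is 1.0 (the full gradient at x = 3), while the plain mini-batch gradients in `plain_sgd` range from -2.0 to 5.0. Setting alpha = 0 recovers the plain mini-batch gradient g_b(x).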

Source code in torchzero/modules/variance_reduction/svrg.py
class SVRG(Module):
    """Stochastic variance reduced gradient method (SVRG).

    To use, put SVRG as the first module; it can be combined with any other modules.
    To reduce the variance of a gradient estimator, put the gradient estimator before SVRG.

    First, it uses the first ``accum_steps`` batches to compute the full gradient at the initial
    parameters via gradient accumulation; the model is not updated during this phase.

    Then it performs ``svrg_steps`` SVRG steps, each requiring two forward and backward passes.

    After ``svrg_steps`` steps, it goes back to the full gradient computation step.

    As an alternative to gradient accumulation, you can pass a "full_closure" argument to the ``step`` method.
    It should compute the full gradients, assign them to the ``.grad`` attributes of the parameters,
    and return the full loss.

    Args:
        svrg_steps (int): number of steps before calculating full gradient. This can be set to length of the dataloader.
        accum_steps (int | None, optional):
            number of steps to accumulate the gradient for. Not used if "full_closure" is passed to the ``step`` method. If None, uses value of ``svrg_steps``. Defaults to None.
        reset_before_accum (bool, optional):
            whether to reset all other modules when re-calculating full gradient. Defaults to True.
        svrg_loss (bool, optional):
            whether to replace loss with SVRG loss (calculated by same formula as SVRG gradient). Defaults to True.
        alpha (float, optional):
            multiplier for the ``g_full(x_0) - g_batch(x_0)`` term; it can be annealed linearly from 1 to 0 as suggested in https://arxiv.org/pdf/2311.05589#page=6. Defaults to 1.

    ## Examples:
    SVRG-LBFGS
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SVRG(len(dataloader)),
        tz.m.LBFGS(),
        tz.m.Backtracking(),
    )
    ```

    For extra variance reduction one can use Online versions of algorithms, although it won't always help.
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SVRG(len(dataloader)),
        tz.m.Online(tz.m.LBFGS()),
        tz.m.Backtracking(),
    )
    ```

    Variance reduction can also be applied to gradient estimators.
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SPSA(),
        tz.m.SVRG(100),
        tz.m.LR(1e-2),
    )
    ```
    ## Notes

    The SVRG gradient is computed as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``, where:
    - ``x`` is current parameters
    - ``x_0`` is initial parameters, where full gradient was computed
    - ``g_b`` refers to mini-batch gradient at ``x`` or ``x_0``
    - ``g_f`` refers to full gradient at ``x_0``.

    The SVRG loss is computed using the same formula.
    """
    def __init__(self, svrg_steps: int, accum_steps: int | None = None, reset_before_accum:bool=True, svrg_loss:bool=True, alpha:float=1):
        defaults = dict(svrg_steps = svrg_steps, accum_steps=accum_steps, reset_before_accum=reset_before_accum, svrg_loss=svrg_loss, alpha=alpha)
        super().__init__(defaults)


    @torch.no_grad
    def update(self, objective):
        params = objective.params
        closure = objective.closure
        assert closure is not None

        if "full_grad" not in self.global_state:

            # -------------------------- calculate full gradient ------------------------- #
            if "full_closure" in objective.storage:
                full_closure = objective.storage['full_closure']
                with torch.enable_grad():
                    full_loss = full_closure()
                    if all(p.grad is None for p in params):
                        warnings.warn("all gradients are None after evaluating full_closure.")

                    full_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                    self.global_state["full_loss"] = full_loss
                    self.global_state["full_grad"] = full_grad
                    self.global_state['x_0'] = [p.clone() for p in params]

                # current batch will be used for svrg update

            else:
                # accumulate gradients over n steps
                accum_steps = self.defaults['accum_steps']
                if accum_steps is None: accum_steps = self.defaults['svrg_steps']

                current_accum_step = self.global_state.get('current_accum_step', 0) + 1
                self.global_state['current_accum_step'] = current_accum_step

                # accumulate grads
                accumulator = self.get_state(params, 'accumulator')
                grad = objective.get_grads()
                torch._foreach_add_(accumulator, grad)

                # accumulate loss
                loss_accumulator = self.global_state.get('loss_accumulator', 0)
                loss_accumulator += tofloat(objective.loss)
                self.global_state['loss_accumulator'] = loss_accumulator

                # on nth step, use the accumulated gradient
                if current_accum_step >= accum_steps:
                    torch._foreach_div_(accumulator, accum_steps)
                    self.global_state["full_grad"] = accumulator
                    self.global_state["full_loss"] = loss_accumulator / accum_steps

                    self.global_state['x_0'] = [p.clone() for p in params]
                    self.clear_state_keys('accumulator')
                    del self.global_state['current_accum_step']

                # otherwise skip update until enough grads are accumulated
                else:
                    objective.updates = None
                    objective.stop = True
                    objective.skip_update = True
                    return


        svrg_steps = self.defaults['svrg_steps']
        current_svrg_step = self.global_state.get('current_svrg_step', 0) + 1
        self.global_state['current_svrg_step'] = current_svrg_step

        # --------------------------- SVRG gradient closure -------------------------- #
        x0 = self.global_state['x_0']
        gf_x0 = self.global_state["full_grad"]
        ff_x0 = self.global_state['full_loss']
        use_svrg_loss = self.defaults['svrg_loss']
        alpha = self.get_settings(params, 'alpha')
        alpha_0 = alpha[0]
        if all(a == 1 for a in alpha): alpha = None

        def svrg_closure(backward=True):
            # g_b(x) - α * (g_b(x_0) - g_f(x_0)) and same for loss
            with torch.no_grad():
                x = [p.clone() for p in params]

                if backward:
                    # f and g at x
                    with torch.enable_grad(): fb_x = closure()
                    gb_x = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

                    # f and g at x_0
                    torch._foreach_copy_(params, x0)
                    with torch.enable_grad(): fb_x0 = closure()
                    gb_x0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                    torch._foreach_copy_(params, x)

                    # g_svrg = gb_x - alpha * (gb_x0 - gf_x0)
                    correction = torch._foreach_sub(gb_x0, gf_x0)
                    if alpha is not None: torch._foreach_mul_(correction, alpha)
                    g_svrg = torch._foreach_sub(gb_x, correction)

                    f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
                    for p, g in zip(params, g_svrg):
                        p.grad = g

                    if use_svrg_loss: return f_svrg
                    return fb_x

            # no backward
            if use_svrg_loss:
                fb_x = closure(False)
                torch._foreach_copy_(params, x0)
                fb_x0 = closure(False)
                torch._foreach_copy_(params, x)
                f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
                return f_svrg

            return closure(False)

        objective.closure = svrg_closure

        # --- after svrg_steps steps reset so that new full gradient is calculated on next step --- #
        if current_svrg_step >= svrg_steps:
            del self.global_state['current_svrg_step']
            del self.global_state['full_grad']
            del self.global_state['full_loss']
            del self.global_state['x_0']
            if self.defaults['reset_before_accum']:
                objective.post_step_hooks.append(partial(_reset_except_self, self=self))

    def apply(self, objective): return objective

SaveBest

Bases: torchzero.core.module.Module

Saves the best parameters found so far, i.e. the ones with the lowest loss. Put this as the last module.

Adds the following attrs:

  • best_params - a list of tensors with best parameters.
  • best_loss - loss value with best_params.
  • load_best_params - a function that sets the parameters to the best parameters.

Examples

import torch
import torchzero as tz

def rosenbrock(x, y):
    return (1 - x)**2 + (100 * (y - x**2))**2

xy = torch.tensor((-1.1, 2.5), requires_grad=True)
opt = tz.Optimizer(
    [xy],
    tz.m.NAG(0.999),
    tz.m.LR(1e-6),
    tz.m.SaveBest()
)

# optimize for 1000 steps
for i in range(1000):
    loss = rosenbrock(*xy)
    opt.zero_grad()
    loss.backward()
    opt.step(loss=loss) # SaveBest needs closure or loss

# NAG overshot, but we saved the best params
print(f'{rosenbrock(*xy) = }') # >> 3.6583
print(f"{opt.attrs['best_loss'] = }") # >> 0.000627

# load best parameters
opt.attrs['load_best_params']()
print(f'{rosenbrock(*xy) = }') # >> 0.000627

Source code in torchzero/modules/misc/misc.py
class SaveBest(Module):
    """Saves the best parameters found so far, i.e. the ones with the lowest loss. Put this as the last module.

    Adds the following attrs:

    - ``best_params`` - a list of tensors with best parameters.
    - ``best_loss`` - loss value with ``best_params``.
    - ``load_best_params`` - a function that sets the parameters to the best parameters.

    ## Examples
    ```python
    def rosenbrock(x, y):
        return (1 - x)**2 + (100 * (y - x**2))**2

    xy = torch.tensor((-1.1, 2.5), requires_grad=True)
    opt = tz.Optimizer(
        [xy],
        tz.m.NAG(0.999),
        tz.m.LR(1e-6),
        tz.m.SaveBest()
    )

    # optimize for 1000 steps
    for i in range(1000):
        loss = rosenbrock(*xy)
        opt.zero_grad()
        loss.backward()
        opt.step(loss=loss) # SaveBest needs closure or loss

    # NAG overshot, but we saved the best params
    print(f'{rosenbrock(*xy) = }') # >> 3.6583
    print(f"{opt.attrs['best_loss'] = }") # >> 0.000627

    # load best parameters
    opt.attrs['load_best_params']()
    print(f'{rosenbrock(*xy) = }') # >> 0.000627
    ```
    """
    def __init__(self):
        super().__init__()

    @torch.no_grad
    def apply(self, objective):
        loss = tofloat(objective.get_loss(False))
        lowest_loss = self.global_state.get('lowest_loss', float("inf"))

        if loss < lowest_loss:
            self.global_state['lowest_loss'] = loss
            best_params = objective.attrs['best_params'] = [p.clone() for p in objective.params]
            objective.attrs['best_loss'] = loss
            objective.attrs['load_best_params'] = partial(_load_best_parameters, params=objective.params, best_params=best_params)

        return objective

ScalarProjection

Bases: torchzero.modules.projections.projection.ProjectionBase

Projection that splits all parameters into individual scalars.
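The project/unproject pair in the source below amounts to flattening every parameter into one list of individual scalars and later regrouping those scalars by the reference shapes (via `vec_to_tensors`). A pure-Python sketch, with tensors represented as flat row-major lists plus a shape, an assumption made to avoid a PyTorch dependency; the function names are hypothetical. The practical effect is presumably that wrapped modules see each scalar as its own tensor.

```python
from math import prod
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def project_to_scalars(flat_tensors: Sequence[Sequence[float]]) -> List[float]:
    """Concatenate every flattened tensor into a single list of scalars."""
    return [value for tensor in flat_tensors for value in tensor]


def unproject_from_scalars(scalars: Sequence[float], shapes: Sequence[Shape]) -> List[List[float]]:
    """Regroup scalars into consecutive chunks sized like the reference tensors."""
    out: List[List[float]] = []
    offset = 0
    for shape in shapes:
        n = prod(shape)  # prod(()) == 1 covers 0-d tensors
        out.append(list(scalars[offset:offset + n]))
        offset += n
    if offset != len(scalars):
        raise ValueError("number of scalars does not match the reference shapes")
    return out


shapes = [(2, 2), (3,)]
tensors = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]  # row-major flattened values
scalars = project_to_scalars(tensors)
restored = unproject_from_scalars(scalars, shapes)
```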

Source code in torchzero/modules/projections/projection.py
class ScalarProjection(ProjectionBase):
    """Projection that splits all parameters into individual scalars."""
    def __init__(
        self,
        modules: Chainable,
        project_update=True,
        project_params=True,
        project_grad=True,
    ):
        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        return [s for t in tensors for s in t.ravel().unbind(0)]

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        return vec_to_tensors(vec=torch.stack(projected_tensors), reference=params)

ScaleByGradCosineSimilarity

Bases: torchzero.core.transform.TensorTransform

Multiplies the update by cosine similarity with gradient. If cosine similarity is negative, naturally the update will be negated as well.

Parameters:

  • eps (float, default: 1e-06 ) –

    epsilon for division. Defaults to 1e-6.

Examples:

Scaled Adam

opt = tz.Optimizer(
    bench.parameters(),
    tz.m.Adam(),
    tz.m.ScaleByGradCosineSimilarity(),
    tz.m.LR(1e-2)
)
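The scaling itself is one dot product and two norms. A pure-Python sketch, treating `update` and `grad` as all parameters concatenated into single vectors (the source does this with `tensors.dot(grads)` and `global_vector_norm()`); the function name is hypothetical.

```python
import math
from typing import List, Sequence


def scale_by_cosine_similarity(
    update: Sequence[float], grad: Sequence[float], eps: float = 1e-6
) -> List[float]:
    """Multiply update by its cosine similarity with grad.

    The product of norms in the denominator is clipped from below at eps.
    """
    dot = sum(u * g for u, g in zip(update, grad))
    norms = math.sqrt(sum(u * u for u in update)) * math.sqrt(sum(g * g for g in grad))
    cos_sim = dot / max(norms, eps)
    return [u * cos_sim for u in update]


aligned = scale_by_cosine_similarity([1.0, 0.0], [1.0, 1.0])   # cos = 1/sqrt(2)
opposed = scale_by_cosine_similarity([1.0, 2.0], [-1.0, -2.0])  # cos = -1, update is negated
```

Clipping the denominator means a zero gradient gives a zero cosine and therefore a zero update, rather than a division by zero.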

Source code in torchzero/modules/momentum/cautious.py
class ScaleByGradCosineSimilarity(TensorTransform):
    """Multiplies the update by cosine similarity with gradient.
    If cosine similarity is negative, naturally the update will be negated as well.

    Args:
        eps (float, optional): epsilon for division. Defaults to 1e-6.

    ## Examples:

    Scaled Adam
    ```python
    opt = tz.Optimizer(
        bench.parameters(),
        tz.m.Adam(),
        tz.m.ScaleByGradCosineSimilarity(),
        tz.m.LR(1e-2)
    )
    ```
    """
    def __init__(
        self,
        eps: float = 1e-6,
    ):
        defaults = dict(eps=eps)
        super().__init__(defaults, uses_grad=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        eps = settings[0]['eps']
        tensors = TensorList(tensors)
        grads = TensorList(grads)
        cos_sim = tensors.dot(grads) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)

        return tensors.mul_(cos_sim)

ScaleLRBySignChange

Bases: torchzero.core.transform.TensorTransform

The per-element learning rate is multiplied by nplus if the sign of the update (or of the gradient, when use_grad=True) did not change since the previous step, or by nminus if it did.

This is part of the RProp update rule.

Parameters:

  • nplus (float, default: 1.2 ) –

    learning rate gets multiplied by nplus if ascent/gradient didn't change the sign

  • nminus (float, default: 0.5 ) –

    learning rate gets multiplied by nminus if ascent/gradient changed the sign

  • lb (float, default: 1e-06 ) –

    lower bound for lr.

  • ub (float, default: 50.0 ) –

    upper bound for lr.

  • alpha (float, default: 1.0 ) –

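    initial learning rate.

  • use_grad (bool, default: False ) –

    if True, sign changes are tracked on the gradient instead of the update.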
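A pure-Python sketch of one step of the per-element learning-rate adaptation described above. The actual work happens in `scale_by_sign_change_`, which is not shown here, so the handling of zero products, of the very first step, and of any masking of the update after a sign flip are assumptions; the function name is hypothetical.

```python
from typing import List, Sequence


def sign_change_scale(
    update: Sequence[float],
    prev: Sequence[float],
    lrs: List[float],
    nplus: float = 1.2,
    nminus: float = 0.5,
    lb: float = 1e-6,
    ub: float = 50.0,
) -> List[float]:
    """Adapt per-element learning rates in place, then return update * lrs."""
    for i, (cur, old) in enumerate(zip(update, prev)):
        agreement = cur * old
        if agreement > 0:      # same sign as on the previous step: grow
            lrs[i] *= nplus
        elif agreement < 0:    # sign flipped: shrink
            lrs[i] *= nminus
        # agreement == 0 leaves the rate unchanged (assumption of this sketch)
        lrs[i] = min(max(lrs[i], lb), ub)  # keep within [lb, ub]
    return [u * lr for u, lr in zip(update, lrs)]


lrs = [1.0, 1.0, 1.0]  # every rate starts at alpha
scaled = sign_change_scale([2.0, 3.0, 1.0], prev=[1.0, -1.0, 0.0], lrs=lrs)
```

After this call the rates are [1.2, 0.5, 1.0] and the scaled update is [2.4, 1.5, 1.0]. In the module, the previous values and the rates persist in per-parameter state between steps; a caller of this sketch would store `update` as the next step's `prev`.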

Source code in torchzero/modules/adaptive/rprop.py
class ScaleLRBySignChange(TensorTransform):
    """
    The per-element learning rate is multiplied by ``nplus`` if the sign of the update (or of the gradient,
    when ``use_grad=True``) did not change since the previous step, or by ``nminus`` if it did.

    This is part of the RProp update rule.

    Args:
        nplus (float): learning rate gets multiplied by ``nplus`` if ascent/gradient didn't change the sign
        nminus (float): learning rate gets multiplied by ``nminus`` if ascent/gradient changed the sign
        lb (float): lower bound for lr.
        ub (float): upper bound for lr.
        alpha (float): initial learning rate.
        use_grad (bool): if True, sign changes are tracked on the gradient instead of the update.

    """

    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb=1e-6,
        ub=50.0,
        alpha=1.0,
        use_grad=False,
    ):
        defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
        super().__init__(defaults, uses_grad=use_grad)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        tensors = TensorList(tensors)
        if self._uses_grad:
            assert grads is not None
            cur = TensorList(grads)
        else: cur = tensors

        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)

        if step == 0:
            lrs.set_(tensors.full_like([s['alpha'] for s in settings]))

        tensors = scale_by_sign_change_(
            tensors_ = tensors,
            cur = cur,
            prev_ = prev,
            lrs_ = lrs,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            step = step,
        )
        return tensors

ScaleModulesByCosineSimilarity

Bases: torchzero.core.module.Module

Scales the output of the main module by its cosine similarity to the output of the compare module.

Parameters:

  • main (Chainable) –

    main module or sequence of modules whose update will be scaled.

  • compare (Chainable) –

    module or sequence of modules to compare to

  • eps (float, default: 1e-06 ) –

    lower bound on the product of the two update norms in the denominator, to avoid division by zero. Defaults to 1e-6.
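The scaling is a single formula, cos_sim = dot(m, c) / max(||m|| * ||c||, eps), computed over all parameters as one concatenated vector. A minimal plain-Python sketch mirroring the computation in step(), with flat lists standing in for TensorList (the function name is illustrative only):

```python
import math


def scale_by_cosine_similarity(main, compare, eps=1e-6):
    """Scale `main` by its cosine similarity to `compare` (flat lists of floats)."""
    dot = sum(m * c for m, c in zip(main, compare))
    norm_main = math.sqrt(sum(m * m for m in main))
    norm_compare = math.sqrt(sum(c * c for c in compare))
    # eps bounds the denominator from below, as in .clip(min=eps)
    cos_sim = dot / max(norm_main * norm_compare, eps)
    return [m * cos_sim for m in main]


print(scale_by_cosine_similarity([3.0, 4.0], [3.0, 4.0]))    # [3.0, 4.0]   (same direction)
print(scale_by_cosine_similarity([3.0, 4.0], [-4.0, 3.0]))   # [0.0, 0.0]   (orthogonal)
print(scale_by_cosine_similarity([3.0, 4.0], [-3.0, -4.0]))  # [-3.0, -4.0] (opposite)
```

When the two updates point in opposite directions, the cosine similarity is negative and the main update is reversed, not just shrunk.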

Examples:

Adam scaled by similarity to RMSprop

opt = tz.Optimizer(
    bench.parameters(),
    tz.m.ScaleModulesByCosineSimilarity(
        main = tz.m.Adam(),
        compare = tz.m.RMSprop(0.999, debiased=True),
    ),
    tz.m.LR(1e-2)
)

Source code in torchzero/modules/momentum/cautious.py
class ScaleModulesByCosineSimilarity(Module):
    """Scales the output of ``main`` module by it's cosine similarity to the output
    of ``compare`` module.

    Args:
        main (Chainable): main module or sequence of modules whose update will be scaled.
        compare (Chainable): module or sequence of modules to compare to
        eps (float, optional): epsilon for division. Defaults to 1e-6.

    ## Examples:

    Adam scaled by similarity to RMSprop
    ```python
    opt = tz.Optimizer(
        bench.parameters(),
        tz.m.ScaleModulesByCosineSimilarity(
            main = tz.m.Adam(),
            compare = tz.m.RMSprop(0.999, debiased=True),
        ),
        tz.m.LR(1e-2)
    )
    ```
    """
    def __init__(
        self,
        main: Chainable,
        compare: Chainable,
        eps=1e-6,
    ):
        defaults = dict(eps=eps)
        super().__init__(defaults)

        self.set_child('main', main)
        self.set_child('compare', compare)

    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    @torch.no_grad
    def step(self, objective):
        main = self.children['main']
        compare = self.children['compare']

        main_var = main.step(objective.clone(clone_updates=True))
        objective.update_attrs_from_clone_(main_var)

        compare_var = compare.step(objective.clone(clone_updates=True))
        objective.update_attrs_from_clone_(compare_var)

        m = TensorList(main_var.get_updates())
        c = TensorList(compare_var.get_updates())
        eps = self.defaults['eps']

        cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)

        objective.updates = m.mul_(cos_sim)
        return objective

ScipyMinimizeScalar

Bases: torchzero.modules.line_search.line_search.LineSearchBase

Line search via scipy.optimize.minimize_scalar, which implements Brent's method, golden-section search and bounded Brent's method.

Parameters:

  • method (str | None, default: None ) –

    "brent", "golden" or "bounded". Defaults to None.

  • maxiter (int | None, default: None ) –

    maximum number of iterations of the scalar minimizer, passed to scipy as options["maxiter"]. Defaults to None.

  • bracket (Sequence | None, default: None ) –

    Either a triple (xa, xb, xc) satisfying xa < xb < xc and func(xb) < func(xa) and func(xb) < func(xc), or a pair (xa, xb) to be used as initial points for a downhill bracket search. Defaults to None.

  • bounds (Sequence | None, default: None ) –

    For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.

  • tol (float | None, default: None ) –

    Tolerance for termination. Defaults to None.

  • prev_init (bool, default: False ) –

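    if True, from the second step on, the step size found on the previous step replaces the last point of bracket (bracket defaults to (0, 1) when not given). Defaults to False.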

  • options (dict | None, default: None ) –

    A dictionary of solver options. Defaults to None.

For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
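With prev_init=True, the previous step size replaces the last point of bracket, and a non-finite or huge result is replaced with a step size of 0. The helpers below are a hypothetical plain-Python restatement of that logic from search(); they use the float64 limit from sys.float_info, whereas the module uses the maximum of the parameters' dtype.

```python
import math
import sys


def prev_init_bracket(bracket, x_prev):
    """Bracket actually passed to minimize_scalar when prev_init=True (sketch)."""
    if x_prev is None:          # first step: no previous step size to reuse
        return bracket
    if bracket is None:
        bracket = (0, 1)
    return (*bracket[:-1], x_prev)


def sanitize_step(x, limit=sys.float_info.max / 2):
    """Fall back to a zero step size for non-finite or huge results."""
    if not math.isfinite(x) or abs(x) >= limit:
        return 0
    return x


print(prev_init_bracket(None, 0.3))             # (0, 0.3)
print(prev_init_bracket((0.0, 0.5, 2.0), 0.3))  # (0.0, 0.5, 0.3)
print(sanitize_step(float("inf")))              # 0
```

The second call shows a caveat that follows from this logic: with a user-supplied triple, the substituted point may no longer satisfy xa < xb < xc.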

Source code in torchzero/modules/line_search/scipy.py
class ScipyMinimizeScalar(LineSearchBase):
    """Line search via :code:`scipy.optimize.minimize_scalar` which implements brent, golden search and bounded brent methods.

    Args:
        method (str | None, optional): "brent", "golden" or "bounded". Defaults to None.
        maxiter (int | None, optional): maximum number of function evaluations the line search is allowed to perform. Defaults to None.
        bracket (Sequence | None, optional):
            Either a triple (xa, xb, xc) satisfying xa < xb < xc and func(xb) < func(xa) and  func(xb) < func(xc), or a pair (xa, xb) to be used as initial points for a downhill bracket search. Defaults to None.
        bounds (Sequence | None, optional):
            For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
        tol (float | None, optional): Tolerance for termination. Defaults to None.
        prev_init (bool, optional): uses previous step size as initial guess for the line search.
        options (dict | None, optional): A dictionary of solver options. Defaults to None.

    For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html

    """
    def __init__(
        self,
        method: str | None = None,
        maxiter: int | None = None,
        bracket=None,
        bounds=None,
        tol: float | None = None,
        prev_init: bool = False,
        options=None,
    ):
        defaults = dict(method=method,bracket=bracket,bounds=bounds,tol=tol,options=options,maxiter=maxiter, prev_init=prev_init)
        super().__init__(defaults)

        import scipy.optimize
        self.scopt = scipy.optimize


    @torch.no_grad
    def search(self, update, var):
        objective = self.make_objective(var=var)
        method, bracket, bounds, tol, options, maxiter = itemgetter(
            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.defaults)

        if maxiter is not None:
            options = dict(options) if isinstance(options, Mapping) else {}
            options['maxiter'] = maxiter

        if self.defaults["prev_init"] and "x_prev" in self.global_state:
            if bracket is None: bracket = (0, 1)
            bracket = (*bracket[:-1], self.global_state["x_prev"])

        x = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options).x # pyright:ignore[reportAttributeAccessIssue]

        max = torch.finfo(var.params[0].dtype).max / 2
        if (not math.isfinite(x)) or abs(x) >= max: x = 0

        self.global_state['x_prev'] = x
        return x

Sequential

Bases: torchzero.core.module.Module

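On each step, steps with modules in sequence, steps times.

The update is the difference between the parameters before and after this inner loop.

Parameters:

  • modules (Iterable[Chainable]) –

    modules to step with, in order.

  • steps (int, default: 1 ) –

    number of inner steps taken with modules on each optimizer step.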
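The idea can be sketched in plain Python, with a hypothetical step_fn playing the role of one pass through the inner modules. The sign convention (params_after = params_before - update, i.e. the update is subtracted like a gradient) is an assumption of this sketch.

```python
def sequential_update(params, step_fn, steps=1):
    """Apply `step_fn` `steps` times and return the net change as an update (sketch).

    Convention assumed here: params_after = params_before - update.
    """
    before = list(params)
    current = list(params)
    for _ in range(steps):
        current = step_fn(current)
    return [b - a for b, a in zip(before, current)]


def gd_step(x, lr=0.1):
    """One gradient descent step on f(x) = sum(x_i ** 2), whose gradient is 2 * x."""
    return [xi - lr * 2 * xi for xi in x]


update = sequential_update([1.0], gd_step, steps=2)
print(update)  # approximately [0.36]: 1.0 -> 0.8 -> 0.64
```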

Source code in torchzero/modules/misc/multistep.py
class Sequential(Module):
    """On each step, this sequentially steps with ``modules`` ``steps`` times.

    The update is taken to be the parameter difference between parameters before and after the inner loop."""
    def __init__(self, modules: Iterable[Chainable], steps: int=1):
        defaults = dict(steps=steps)
        super().__init__(defaults)
        self.set_children_sequence(modules)

    @torch.no_grad
    def apply(self, objective):
        return _sequential_step(self, objective, sequential=True)

Shampoo

Bases: torchzero.core.transform.TensorTransform

Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

Notes

Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.

Shampoo is computationally very expensive; increase precond_freq if it is too slow.

SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as tz.m.SOAP.
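Grafting, as it is commonly described, keeps the direction of one update and borrows the norm of another. A minimal sketch over a single flat vector; whether tz.m.GraftModules grafts per tensor or over all parameters at once, and how it guards against division by zero, are assumptions here, and the helper name is hypothetical.

```python
import math


def graft(direction, magnitude, eps=1e-8):
    """Return `direction` rescaled to have the norm of `magnitude` (sketch)."""
    norm_direction = math.sqrt(sum(d * d for d in direction))
    norm_magnitude = math.sqrt(sum(m * m for m in magnitude))
    scale = norm_magnitude / max(norm_direction, eps)
    return [d * scale for d in direction]


# e.g. a Shampoo-like direction with an Adam-like step length
grafted = graft([3.0, 4.0], [0.0, 0.5])
print(grafted)  # approximately [0.3, 0.4], which has norm 0.5
```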

Parameters:

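  • reg (float, default: 1e-12 ) –

    regularization passed to the preconditioner update. Defaults to 1e-12.

  • precond_freq (int, default: 10 ) –

    preconditioner update frequency. Defaults to 10.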

  • matrix_power (float | None, default: None ) –

    overrides matrix exponent. By default uses -1/grad.ndim. Defaults to None.

  • merge_small (bool, default: True ) –

    whether to merge small dims on tensors. Defaults to True.

  • max_dim (int, default: 10000 ) –

    maximum dimension size for preconditioning. Defaults to 10_000.

  • precondition_1d (bool, default: True ) –

    whether to precondition 1d tensors. Defaults to True.

  • adagrad_eps (float, default: 1e-08 ) –

    epsilon for the Adagrad-style division used for tensors where Shampoo can't be applied. Defaults to 1e-8.

  • matrix_power_method (Literal, default: 'eigh_abs' ) –

    how to compute matrix power.

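  • beta (float | None, default: None ) –

    if None, accumulates a sum as in standard Shampoo; otherwise uses an EMA of preconditioners. Defaults to None.

  • beta_debias (bool, default: True ) –

    if True and beta is not None, multiplies the output by sqrt(1 - beta ** k), where k is incremented once every precond_freq steps. Defaults to True.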

  • inner (Chainable | None, default: None ) –

    module applied after updating preconditioners and before applying preconditioning. For example, if beta≈0.999 and inner=tz.m.EMA(0.9), this becomes Adam with a Shampoo preconditioner (ignoring debiasing). Defaults to None.
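Which dimensions get their own Kronecker-factor preconditioner is decided from the tensor shape. The sketch below restates the shape logic of single_tensor_initialize in plain Python, ignoring merge_small (which reshapes the tensor first). It follows the intent stated in the source comment: scalars, 1-D tensors with precondition_1d=False, and tensors whose dimensions are all too large (or of size 1) fall back to the Adagrad-style diagonal preconditioner. The function name is hypothetical.

```python
def plan_preconditioners(shape, max_dim=10_000, precondition_1d=True):
    """Per-dimension factor sizes (None = dimension not preconditioned), or
    "diagonal" when the Adagrad-style fallback would be used (sketch)."""
    if len(shape) <= 1 and not precondition_1d:
        return "diagonal"
    factors = [s if 1 < s < max_dim else None for s in shape]
    if all(f is None for f in factors):   # also true for a scalar, where factors == []
        return "diagonal"
    return factors


print(plan_preconditioners((64, 32)))        # [64, 32]: 64x64 and 32x32 factors
print(plan_preconditioners((20_000, 32)))    # [None, 32]: first dim too large
print(plan_preconditioners((20_000,)))       # diagonal
print(plan_preconditioners((128,), precondition_1d=False))  # diagonal
print(plan_preconditioners(()))              # diagonal
```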

Examples:

Shampoo grafted to Adam

opt = tz.Optimizer(
    model.parameters(),
    tz.m.GraftModules(
        direction = tz.m.Shampoo(),
        magnitude = tz.m.Adam(),
    ),
    tz.m.LR(1e-3)
)

Adam with Shampoo preconditioner

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
    tz.m.Debias(0.9, 0.999),
    tz.m.LR(1e-3)
)

Source code in torchzero/modules/adaptive/shampoo.py
class Shampoo(TensorTransform):
    """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).

    Notes:
        Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.

        Shampoo is a very computationally expensive optimizer, increase ``precond_freq`` if it is too slow.

        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as ``tz.m.SOAP``.

    Args:
        reg (float, optional): regularization passed to the preconditioner update. Defaults to 1e-12.
        precond_freq (int, optional): preconditioner update frequency. Defaults to 10.
        matrix_power (float | None, optional): overrides matrix exponent. By default uses ``-1/grad.ndim``. Defaults to None.
        merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
        max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 10_000.
        precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
        adagrad_eps (float, optional): epsilon for adagrad division for tensors where shampoo can't be applied. Defaults to 1e-8.
        matrix_power_method (MatrixPowerMethod, optional): how to compute matrix power.
        beta (float | None, optional):
            if None calculates sum as in standard Shampoo, otherwise uses EMA of preconditioners. Defaults to None.
        inner (Chainable | None, optional):
            module applied after updating preconditioners and before applying preconditioning.
            For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
            Defaults to None.

    Examples:
    Shampoo grafted to Adam

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.GraftModules(
            direction = tz.m.Shampoo(),
            magnitude = tz.m.Adam(),
        ),
        tz.m.LR(1e-3)
    )
    ```

    Adam with Shampoo preconditioner

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
        tz.m.Debias(0.9, 0.999),
        tz.m.LR(1e-3)
    )
    ```
    """
    def __init__(
        self,
        reg: float = 1e-12,
        precond_freq: int = 10,
        matrix_power: float | None = None,
        merge_small: bool = True,
        max_dim: int = 10_000,
        precondition_1d: bool = True,
        adagrad_eps: float = 1e-8,
        matrix_power_method: MatrixPowerMethod = "eigh_abs",
        beta: float | None = None,
        beta_debias: bool = True,

        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults["inner"]

        super().__init__(defaults, inner=inner)

    @torch.no_grad
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        if setting["merge_small"]:
            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])

        if tensor.ndim <= 1 and not setting["precondition_1d"]:
            state["accumulators"] = []

        else:
            max_dim = setting["max_dim"]
            state['accumulators'] = [
                torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
            ]
            state['preconditioners'] = [
                torch.eye(s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
            ]

        # either scalar parameter, 1d with precondition_1d=False, or too big, then diagonal preconditioner is used.
        if all(i is None for i in state['accumulators']):
            state['diagonal_accumulator'] = torch.zeros_like(tensor)

        state['step'] = 0
        state["num_GTG"] = 0

    @torch.no_grad
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        if setting["merge_small"]:
            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])

            if "inner" not in self.children:
                state["merged"] = tensor

        if 'diagonal_accumulator' in state:
            update_diagonal_(tensor, state['diagonal_accumulator'], beta=setting["beta"])
        else:
            update_shampoo_preconditioner_(
                tensor,
                accumulators_=state['accumulators'],
                preconditioners_=state['preconditioners'],
                step=state['step'],
                precond_freq=setting["precond_freq"],
                matrix_power=setting["matrix_power"],
                beta=setting["beta"],
                reg=setting["reg"],
                matrix_power_method=setting["matrix_power_method"],
            )

        if state["step"] % setting["precond_freq"] == 0:
            state["num_GTG"] += 1

        state["step"] += 1


    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):

        if setting["merge_small"]:
            if "inner" not in self.children:
                tensor = state.pop("merged")
            else:
                tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])

        if 'diagonal_accumulator' in state:
            dir = apply_diagonal_(tensor, state['diagonal_accumulator'], eps=setting["adagrad_eps"])
        else:
            dir = apply_shampoo_preconditioner(tensor, preconditioners_=state['preconditioners'])

        if setting["merge_small"]:
            dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])

        if setting['beta_debias'] and setting["beta"] is not None:
            bias_correction = 1 - (setting["beta"] ** state["num_GTG"])
            dir *= bias_correction ** 0.5

        return dir

ShorR

Bases: torchzero.modules.quasi_newton.quasi_newton.HessianUpdateStrategy

Shor’s r-algorithm.

Note
  • A line search such as [tz.m.StrongWolfe(a_init="quadratic", fallback=True), tz.m.Mul(1.2)] is required. Like conjugate gradient, ShorR has no automatic step size scaling, so setting a_init in the line search is recommended.

  • The line search should overstep slightly, so it can help to multiply the direction returned by the line search by a value slightly larger than 1, such as 1.2.

References

The original references, neither of which appears to be available online, are:

  • Shor, N. Z., Utilization of the Operation of Space Dilatation in the Minimization of Convex Functions, Kibernetika, No. 1, pp. 6-12, 1970.

  • Skokov, V. A., Note on Minimization Methods Employing Space Stretching, Kibernetika, No. 4, pp. 115-117, 1974.

An overview is available in Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720.

The reference by Skokov, V. A. describes a more efficient formula, which can be found in Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.
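The update can be illustrated with the standard formulation described in the Burke, Lewis and Overton overview, where H = B B^T and B is dilated by a coefficient in (0, 1) along the gradient difference y. The sketch assumes that the module's alpha plays that role; it is a plain-Python restatement of the textbook formula, not torchzero's shor_r_, and the function name is hypothetical.

```python
def shor_r_update(H, y, alpha=0.5):
    """Textbook r-algorithm update of a symmetric matrix H (sketch):

        H_new = H - (1 - alpha**2) * (H y)(H y)^T / (y^T H y)

    Space is contracted by a factor alpha along y (a gradient difference).
    """
    n = len(y)
    Hy = [sum(H[i][j] * y[j] for j in range(n)) for i in range(n)]
    yHy = sum(y[i] * Hy[i] for i in range(n))
    if yHy <= 0:  # degenerate direction: leave H unchanged
        return [row[:] for row in H]
    c = (1 - alpha ** 2) / yHy
    return [[H[i][j] - c * Hy[i] * Hy[j] for j in range(n)] for i in range(n)]


H = [[1.0, 0.0], [0.0, 1.0]]
print(shor_r_update(H, [1.0, 0.0]))  # [[0.25, 0.0], [0.0, 1.0]]
```

A useful check of the formula: after the update, H_new y equals alpha**2 * (H y), so the metric shrinks only along y.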

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class ShorR(HessianUpdateStrategy):
    """Shor’s r-algorithm.

    Note:
        - A line search such as ``[tz.m.StrongWolfe(a_init="quadratic", fallback=True), tz.m.Mul(1.2)]`` is required. Similarly to conjugate gradient, ShorR doesn't have an automatic step size scaling, so setting ``a_init`` in the line search is recommended.

        - The line search should try to overstep by a little, therefore it can help to multiply direction given by a line search by some value slightly larger than 1 such as 1.2.

    References:
        Those are the original references, but neither seem to be available online:
            - Shor, N. Z., Utilization of the Operation of Space Dilatation in the Minimization of Convex Functions, Kibernetika, No. 1, pp. 6-12, 1970.

            - Skokov, V. A., Note on Minimization Methods Employing Space Stretching, Kibernetika, No. 4, pp. 115-117, 1974.

        An overview is available in [Burke, James V., Adrian S. Lewis, and Michael L. Overton. "The Speed of Shor's R-algorithm." IMA Journal of numerical analysis 28.4 (2008): 711-720](https://sites.math.washington.edu/~burke/papers/reprints/60-speed-Shor-R.pdf).

        Reference by Skokov, V. A. describes a more efficient formula which can be found here [Ansari, Zafar A. Limited Memory Space Dilation and Reduction Algorithms. Diss. Virginia Tech, 1998.](https://camo.ici.ro/books/thesis/th.pdf)
    """

    def __init__(
        self,
        alpha=0.5,
        init_scale: float | Literal["auto"] = 1,
        tol: float = 1e-32,
        ptol: float | None = 1e-32,
        ptol_restart: bool = False,
        gtol: float | None = 1e-32,
        restart_interval: int | None | Literal['auto'] = None,
        beta: float | None = None,
        update_freq: int = 1,
        scale_first: bool = False,
        concat_params: bool = True,
        # inverse: bool = True,
        inner: Chainable | None = None,
    ):
        defaults = dict(alpha=alpha)
        super().__init__(
            defaults=defaults,
            init_scale=init_scale,
            tol=tol,
            ptol=ptol,
            ptol_restart=ptol_restart,
            gtol=gtol,
            restart_interval=restart_interval,
            beta=beta,
            update_freq=update_freq,
            scale_first=scale_first,
            concat_params=concat_params,
            inverse=True,
            inner=inner,
        )

    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        return shor_r_(H=H, y=y, alpha=setting['alpha'])

Sign

Bases: torchzero.core.transform.TensorTransform

Returns sign(input)

Source code in torchzero/modules/ops/unary.py
class Sign(TensorTransform):
    """Returns ``sign(input)``"""
    def __init__(self): super().__init__()
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_sign_(tensors)
        return tensors

SignConsistencyLRs

Bases: torchzero.core.transform.TensorTransform

Outputs per-weight learning rates based on consecutive sign consistency.

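The learning rate for a weight is multiplied by nplus when two consecutive update signs are the same; otherwise it is multiplied by nminus. The learning rates are bounded to the (lb, ub) range.

Parameters:

  • nplus (float, default: 1.2 ) –

    learning rate multiplier when the sign did not change.

  • nminus (float, default: 0.5 ) –

    learning rate multiplier when the sign changed.

  • lb (float | None, default: 1e-06 ) –

    lower bound for the learning rates.

  • ub (float | None, default: 50 ) –

    upper bound for the learning rates.

  • alpha (float, default: 1 ) –

    initial learning rate.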

Examples:

GD scaled by consecutive gradient sign consistency

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Mul(tz.m.SignConsistencyLRs()),
    tz.m.LR(1e-2)
)

Source code in torchzero/modules/adaptive/rprop.py
class SignConsistencyLRs(TensorTransform):
    """Outputs per-weight learning rates based on consecutive sign consistency.

    The learning rate for a weight is multiplied by ``nplus`` when two consecutive update signs are the same, otherwise it is multiplied by ``nminus``. The learning rates are bounded to be in ``(lb, ub)`` range.

    ### Examples:

    GD scaled by consecutive gradient sign consistency

    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Mul(tz.m.SignConsistencyLRs()),
        tz.m.LR(1e-2)
    )
    ```

"""
    def __init__(
        self,
        nplus: float = 1.2,
        nminus: float = 0.5,
        lb: float | None = 1e-6,
        ub: float | None = 50,
        alpha: float = 1,
    ):
        defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        target = TensorList(tensors)
        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)

        if step == 0:
            lrs.set_(target.full_like([s['alpha'] for s in settings]))

        target = sign_consistency_lrs_(
            tensors = target,
            prev_ = prev,
            lrs_ = lrs,
            nplus = nplus,
            nminus = nminus,
            lb = lb,
            ub = ub,
            step = step,
        )
        return target.clone()

SignConsistencyMask

Bases: torchzero.core.transform.TensorTransform

Outputs a mask of sign consistency of current and previous inputs.

The output is 1 for weights where the current and previous inputs have the same sign, and 0 otherwise (including where either input is zero).
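A plain-Python sketch of the mask semantics. It assumes the stored previous input starts at zero, in which case every entry of the first mask is 0; the function name is hypothetical.

```python
def sign_consistency_mask(current, prev):
    """1.0 where current and previous values share a (nonzero) sign, else 0.0 (sketch).

    The mask is built in a new list before `prev` is overwritten, so the
    returned mask never aliases the stored history.
    """
    mask = [1.0 if c * p > 0 else 0.0 for c, p in zip(current, prev)]
    prev[:] = current  # remember the current input for the next step
    return mask


prev = [0.0, 0.0, 0.0]  # assumed zero history before the first step
print(sign_consistency_mask([1.0, -2.0, 3.0], prev))  # [0.0, 0.0, 0.0]
print(sign_consistency_mask([2.0, 1.0, 0.0], prev))   # [1.0, 0.0, 0.0]
```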

Examples:

GD that skips update for weights where gradient sign changed compared to previous gradient.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Mul(tz.m.SignConsistencyMask()),
    tz.m.LR(1e-2)
)

Source code in torchzero/modules/adaptive/rprop.py
class SignConsistencyMask(TensorTransform):
    """
    Outputs a mask of sign consistency of current and previous inputs.

    The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.

    ### Examples:

    GD that skips update for weights where gradient sign changed compared to previous gradient.

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Mul(tz.m.SignConsistencyMask()),
        tz.m.LR(1e-2)
    )
    ```

    """
    def __init__(self):
        super().__init__()

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        prev = unpack_states(states, tensors, 'prev', cls=TensorList)
        # clone so that overwriting prev below does not also overwrite the mask
        mask = prev.mul_(tensors).gt_(0).clone()
        prev.copy_(tensors)
        return mask

SixthOrder3P

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Sixth-order iterative method.

Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.

Source code in torchzero/modules/second_order/multipoint.py
class SixthOrder3P(HigherOrderMethodBase):
    """Sixth-order iterative method.

    Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
    """
    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        def f(x): return evaluate(x, 1)[1]
        def f_j(x): return evaluate(x, 2)[1:]
        x_star = sixth_order_3p(x, f, f_j, setting['lstsq'])
        return x - x_star

SixthOrder3PM2

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45.

Source code in torchzero/modules/second_order/multipoint.py
class SixthOrder3PM2(HigherOrderMethodBase):
    """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        def f_j(x): return evaluate(x, 2)[1:]
        def f(x): return evaluate(x, 1)[1]
        x_star = sixth_order_3pm2(x, f, f_j, setting['lstsq'])
        return x - x_star

SixthOrder5P

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139.

Source code in torchzero/modules/second_order/multipoint.py
class SixthOrder5P(HigherOrderMethodBase):
    """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        def f_j(x): return evaluate(x, 2)[1:]
        x_star = sixth_order_5p(x, f_j, setting['lstsq'])
        return x - x_star

SophiaH

Bases: torchzero.core.transform.Transform

SophiaH optimizer from https://arxiv.org/abs/2305.14342

This is similar to Adam, but the second-moment estimate is replaced by an exponential moving average of randomized Hessian diagonal estimates, and the update is aggressively clipped.

Notes
  • In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply SophiaH preconditioning to another module's output.

  • This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate HVPs. The closure must accept a backward argument (refer to the documentation).
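The two ingredients, a Hutchinson estimate of the Hessian diagonal and a clipped, preconditioned momentum step, can be sketched in plain Python. The step follows the form given in the Sophia paper, clip(m / max(gamma * h, eps), clip), with precond_scale assumed to play the role of gamma; the beta2 EMA of the estimates and the update_freq schedule are omitted, and the helper names are hypothetical.

```python
import random


def hutchinson_diag(hvp, n, n_samples=1, rng=None):
    """Estimate diag(H) as the average of u * (H u) over random +-1 vectors u."""
    if rng is None:
        rng = random.Random(0)
    est = [0.0] * n
    for _ in range(n_samples):
        u = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        Hu = hvp(u)
        for i in range(n):
            est[i] += u[i] * Hu[i] / n_samples
    return est


def sophia_direction(g, m, h, beta1=0.96, precond_scale=1.0, clip=1.0, eps=1e-12):
    """Update momentum `m` in place and return the clipped, preconditioned direction."""
    out = []
    for i in range(len(g)):
        m[i] = beta1 * m[i] + (1 - beta1) * g[i]
        ratio = m[i] / max(precond_scale * h[i], eps)  # eps guards tiny or negative curvature
        out.append(max(-clip, min(clip, ratio)))
    return out


def hvp(u):
    # Quadratic with diagonal Hessian diag(2, 0.5): H u is elementwise scaling,
    # so a single Rademacher sample already recovers the diagonal exactly.
    return [2.0 * u[0], 0.5 * u[1]]


h = hutchinson_diag(hvp, 2)
d = sophia_direction([1.0, -4.0], [0.0, 0.0], h, beta1=0.0)
print(h)  # [2.0, 0.5]
print(d)  # [0.5, -1.0]: the second entry (-8.0) is clipped to -clip
```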

Parameters:

  • beta1 (float, default: 0.96 ) –

    first momentum. Defaults to 0.96.

  • beta2 (float, default: 0.99 ) –

    momentum for hessian diagonal estimate. Defaults to 0.99.

  • update_freq (int, default: 10 ) –

    frequency of updating hessian diagonal estimate via a hessian-vector product. Defaults to 10.

  • precond_scale (float, default: 1 ) –

    scale of the preconditioner. Defaults to 1.

  • clip (float, default: 1 ) –

    clips update to (-clip, clip). Defaults to 1.

  • eps (float, default: 1e-12 ) –

    clips the Hessian diagonal estimate to be no less than this value. Defaults to 1e-12.

  • hvp_method (str, default: 'autograd' ) –

    Determines how Hessian-vector products are computed.

    • "batched_autograd" - uses autograd with batched hessian-vector products. If a single hessian-vector product is evaluated, equivalent to "autograd". Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than "batched_autograd" but uses less memory.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    Defaults to "autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • n_samples (int, default: 1 ) –

    number of hessian-vector products with random vectors to evaluate each time when updating the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.

  • seed (int | None, default: None ) –

    seed for random vectors. Defaults to None.

  • inner (Chainable | None) –

    preconditioning is applied to the output of this module. Defaults to None.
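The two finite-difference options correspond to the standard forward and central difference formulas for a Hessian-vector product, H v ≈ (g(x + h v) - g(x)) / h and H v ≈ (g(x + h v) - g(x - h v)) / (2h). A plain-Python sketch with a hypothetical grad function; whether torchzero also rescales h or v (for example by the norm of v) is not covered here.

```python
def hvp_fd(grad_fn, x, v, h=1e-3, method="fd_central"):
    """Approximate H(x) @ v from gradient evaluations (sketch)."""
    x_plus = [xi + h * vi for xi, vi in zip(x, v)]
    if method == "fd_forward":
        # g(x) is normally already available, so this costs one extra gradient
        g0, g_plus = grad_fn(x), grad_fn(x_plus)
        return [(a - b) / h for a, b in zip(g_plus, g0)]
    if method == "fd_central":
        x_minus = [xi - h * vi for xi, vi in zip(x, v)]
        g_plus, g_minus = grad_fn(x_plus), grad_fn(x_minus)
        return [(a - b) / (2 * h) for a, b in zip(g_plus, g_minus)]
    raise ValueError(f"unknown method: {method}")


def grad(x):
    """Gradient of f(x) = x0**3 + x0 * x1."""
    return [3 * x[0] ** 2 + x[1], x[0]]


# The exact Hessian at (1, 2) is [[6, 1], [1, 0]], so H @ (1, 0) = (6, 1).
print(hvp_fd(grad, [1.0, 2.0], [1.0, 0.0], method="fd_central"))  # close to [6.0, 1.0]
print(hvp_fd(grad, [1.0, 2.0], [1.0, 0.0], method="fd_forward"))  # about [6.003, 1.0]
```

The central formula costs one more gradient evaluation but removes the first-order error visible in the forward result.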

Examples:

Using SophiaH:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SophiaH(),
    tz.m.LR(0.1)
)

SophiaH preconditioner can be applied to any other module by passing it to the inner argument. Turn off SophiaH's first momentum to get just the preconditioning. Here is an example of applying SophiaH preconditioning to Nesterov momentum (tz.m.NAG):

opt = tz.Optimizer(
    model.parameters(),
    tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
    tz.m.LR(0.1)
)

Source code in torchzero/modules/adaptive/sophia_h.py
class SophiaH(Transform):
    """SophiaH optimizer from https://arxiv.org/abs/2305.14342

    This is similar to Adam, but the second-moment estimate is replaced by an exponential moving average of randomized hessian diagonal estimates, and the update is aggressively clipped.

    Notes:
        - In most cases SophiaH should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply SophiaH preconditioning to another module's output.

        - This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        beta1 (float, optional): first momentum. Defaults to 0.96.
        beta2 (float, optional): momentum for hessian diagonal estimate. Defaults to 0.99.
        update_freq (int, optional):
            frequency of updating hessian diagonal estimate via a hessian-vector product. Defaults to 10.
        precond_scale (float, optional):
            scale of the preconditioner. Defaults to 1.
        clip (float, optional):
            clips update to (-clip, clip). Defaults to 1.
        eps (float, optional):
            clips hessian diagonal estimate to be no less than this value. Defaults to 1e-12.
        hvp_method (str, optional):
            Determines how Hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector product is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        n_samples (int, optional):
            number of hessian-vector products with random vectors to evaluate each time when updating
            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
        seed (int | None, optional): seed for random vectors. Defaults to None.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    ### Examples:

    Using SophiaH:

    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SophiaH(),
        tz.m.LR(0.1)
    )
    ```

    The SophiaH preconditioner can be applied to the output of any other module by passing that module to the ``inner`` argument.
    Set ``beta1=0`` to turn off SophiaH's first momentum and get just the preconditioning. Here is an example of applying
    SophiaH preconditioning to Nesterov momentum (``tz.m.NAG``):

    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.SophiaH(beta1=0, inner=tz.m.NAG(0.96)),
        tz.m.LR(0.1)
    )
    ```
    """
    def __init__(
        self,
        beta1: float = 0.96,
        beta2: float = 0.99,
        update_freq: int = 10,
        precond_scale: float = 1,
        clip: float = 1,
        eps: float = 1e-12,
        hvp_method: HVPMethod = 'autograd',
        distribution: Distributions = 'gaussian',
        h: float = 1e-3,
        n_samples = 1,
        zHz: bool = True,
        debias: bool = False,
        seed: int | None = None,

        exp_avg_tfm: Chainable | None = None,
        D_exp_avg_tfm: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['exp_avg_tfm'], defaults["D_exp_avg_tfm"]
        super().__init__(defaults)

        self.set_child('exp_avg', exp_avg_tfm)
        self.set_child('D_exp_avg', D_exp_avg_tfm)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        params = objective.params

        beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)

        exp_avg, D_exp_avg = unpack_states(states, params, 'exp_avg', 'D_exp_avg', cls=TensorList)

        step = self.increment_counter("step", start=0) # 0 on 1st update

        # ---------------------------- hutchinson hessian ---------------------------- #
        fs = settings[0]
        update_freq = fs['update_freq']

        if step % update_freq == 0:
            self.increment_counter("num_Ds", start=1)

            D, _ = objective.hutchinson_hessian(
                rgrad = None,
                at_x0 = True,
                n_samples = fs['n_samples'],
                distribution = fs['distribution'],
                hvp_method = fs['hvp_method'],
                h = fs['h'],
                zHz = fs["zHz"],
                generator = self.get_generator(params[0].device, fs["seed"]),
            )

            D_exp_avg.lerp_(D, weight=1-beta2)

        # --------------------------------- momentum --------------------------------- #
        tensors = objective.get_updates() # do this after hutchinson to not disturb autograd
        exp_avg.lerp_(tensors, 1-beta1)


    @torch.no_grad
    def apply_states(self, objective, states, settings):
        params = objective.params

        beta1, beta2, eps, precond_scale, clip = unpack_dicts(
            settings, 'beta1', 'beta2', 'eps', 'precond_scale', 'clip', cls=NumberList)

        exp_avg, D_exp_avg = unpack_states(states, params, 'exp_avg', 'D_exp_avg')

        # ---------------------------------- debias ---------------------------------- #
        if settings[0]["debias"]:
            bias_correction1 = 1.0 - (beta1 ** (self.global_state["step"] + 1))
            bias_correction2 = 1.0 - (beta2 ** self.global_state["num_Ds"])

            exp_avg = exp_avg / bias_correction1
            D_exp_avg = D_exp_avg / bias_correction2

        # -------------------------------- transforms -------------------------------- #
        exp_avg = TensorList(self.inner_step_tensors(
            "exp_avg", tensors=exp_avg, clone=True, objective=objective, must_exist=False))

        D_exp_avg = TensorList(self.inner_step_tensors(
            "D_exp_avg", tensors=D_exp_avg, clone=True, objective=objective, must_exist=False))

        # ------------------------------ compute update ------------------------------ #
        denom = D_exp_avg.lazy_mul(precond_scale).clip(min=eps)
        objective.updates = (exp_avg / denom).clip_(-clip, clip)
        return objective
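
Once both moving averages are available, apply_states reduces to an element-wise formula: update = clip(exp_avg / max(precond_scale * D_exp_avg, eps), -clip, clip), with optional Adam-style debiasing applied first. The sketch below restates only this arithmetic on plain tensors so the effect of the clipping is easy to see; sophia_h_direction is a hypothetical name, and the exp_avg_tfm / D_exp_avg_tfm child transforms are left out.

```python
import torch


def sophia_h_direction(exp_avg, d_exp_avg, *, beta1=0.96, beta2=0.99,
                       precond_scale=1.0, clip=1.0, eps=1e-12,
                       debias=False, n_updates=1, n_hessian_updates=1):
    """Hypothetical helper mirroring the arithmetic in SophiaH.apply_states.

    n_updates: number of momentum updates so far (the source's step + 1)
    n_hessian_updates: number of hessian diagonal estimates so far (num_Ds)
    """
    if debias:
        # Adam-style bias corrections for both moving averages
        exp_avg = exp_avg / (1.0 - beta1 ** n_updates)
        d_exp_avg = d_exp_avg / (1.0 - beta2 ** n_hessian_updates)

    # the denominator is clipped from below, so near-zero or negative curvature
    # estimates cannot cause a division by zero or flip the sign of the update
    denom = (d_exp_avg * precond_scale).clamp(min=eps)
    return (exp_avg / denom).clamp(-clip, clip)


m = torch.tensor([0.5, -0.5, 0.01, 0.3])  # exp_avg
D = torch.tensor([1.0, 0.1, 2.0, -4.0])   # D_exp_avg
print(sophia_h_direction(m, D))
```

Here the second entry (-0.5 / 0.1 = -5) is clipped to -1, and the fourth, whose curvature estimate is negative, is divided by eps and clipped to 1; the first and third are plain ratios (0.5 and 0.005).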

Split

Bases: torchzero.core.module.Module

Apply true modules to all parameters filtered by filter, apply false modules to all other parameters.

Parameters:

  • filter (Filter) –

    a filter that selects the tensors to be optimized by true. It can be:
    • a tensor or an iterable of tensors (e.g. encoder.parameters()).
    • a function that takes in a tensor and outputs a bool (e.g. lambda x: x.ndim >= 2).
    • a sequence of the above, which acts as "or": a tensor is selected if any entry selects it.

  • true (Chainable | None) –

    modules that are applied to tensors where filter is True.

  • false (Chainable | None) –

    modules that are applied to tensors where filter is False.

Examples:

Muon with Adam fallback using same hyperparams as https://github.com/KellerJordan/Muon

opt = tz.Optimizer(
    model.parameters(),
    tz.m.NAG(0.95),
    tz.m.Split(
        lambda p: p.ndim >= 2,
        true = tz.m.Orthogonalize(),
        false = [tz.m.Adam(0.9, 0.95), tz.m.Mul(1/66)],
    ),
    tz.m.LR(1e-2),
)
Source code in torchzero/modules/misc/split.py
class Split(Module):
    """Apply ``true`` modules to all parameters filtered by ``filter``, apply ``false`` modules to all other parameters.

    Args:
        filter (Filter):
            a filter that selects the tensors to be optimized by ``true``. It can be:
            - a tensor or an iterable of tensors (e.g. ``encoder.parameters()``).
            - a function that takes in a tensor and outputs a bool (e.g. ``lambda x: x.ndim >= 2``).
            - a sequence of the above, which acts as "or": a tensor is selected if any entry selects it.

        true (Chainable | None): modules that are applied to tensors where ``filter`` is ``True``.
        false (Chainable | None): modules that are applied to tensors where ``filter`` is ``False``.

    ### Examples:

    Muon with Adam fallback using same hyperparams as https://github.com/KellerJordan/Muon

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NAG(0.95),
        tz.m.Split(
            lambda p: p.ndim >= 2,
            true = tz.m.Orthogonalize(),
            false = [tz.m.Adam(0.9, 0.95), tz.m.Mul(1/66)],
        ),
        tz.m.LR(1e-2),
    )
    ```
    """
    def __init__(self, filter: Filter, true: Chainable | None, false: Chainable | None):
        defaults = dict(filter=filter)
        super().__init__(defaults)

        if true is not None: self.set_child('true', true)
        if false is not None: self.set_child('false', false)

    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    def step(self, objective):

        params = objective.params
        filter = _make_filter(self.settings[params[0]]['filter'])

        true_idxs = []
        false_idxs = []
        for i,p in enumerate(params):
            if filter(p): true_idxs.append(i)
            else: false_idxs.append(i)

        if 'true' in self.children and len(true_idxs) > 0:
            true = self.children['true']
            objective = _split(true, idxs=true_idxs, params=params, objective=objective)

        if 'false' in self.children and len(false_idxs) > 0:
            false = self.children['false']
            objective = _split(false, idxs=false_idxs, params=params, objective=objective)

        return objective
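
Split.step evaluates the filter once per parameter and routes the resulting index groups to the true and false children. The sketch below reproduces that partition for the filter forms listed in the parameter description. make_predicate is a hypothetical stand-in for the internal _make_filter (not shown in this listing), so details such as matching tensors by identity are assumptions.

```python
from collections.abc import Callable, Iterable

import torch


def make_predicate(filter_spec) -> Callable[[torch.Tensor], bool]:
    """Hypothetical stand-in for torchzero's _make_filter.

    Accepts a tensor, an iterable of tensors, a function, or a sequence of these.
    Tensors are matched by identity (an assumption of this sketch).
    """
    if isinstance(filter_spec, torch.Tensor):
        return lambda p: p is filter_spec
    if callable(filter_spec):
        return lambda p: bool(filter_spec(p))
    if isinstance(filter_spec, Iterable):
        predicates = [make_predicate(f) for f in filter_spec]
        # a sequence acts as "or": the tensor is selected if any entry selects it
        return lambda p: any(pred(p) for pred in predicates)
    raise TypeError(f"unsupported filter: {filter_spec!r}")


def split_indices(params, filter_spec):
    """Same partition as the loop in Split.step: returns (true_idxs, false_idxs)."""
    predicate = make_predicate(filter_spec)
    true_idxs, false_idxs = [], []
    for i, p in enumerate(params):
        (true_idxs if predicate(p) else false_idxs).append(i)
    return true_idxs, false_idxs


weight, bias, gain = torch.zeros(4, 3), torch.zeros(4), torch.ones(4)
print(split_indices([weight, bias, gain], lambda p: p.ndim >= 2))          # ([0], [1, 2])
print(split_indices([weight, bias, gain], [lambda p: p.ndim >= 2, gain]))  # ([0, 2], [1])
```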

Sqrt

Bases: torchzero.core.transform.TensorTransform

Returns sqrt(input)

Source code in torchzero/modules/ops/unary.py
class Sqrt(TensorTransform):
    """Returns ``sqrt(input)``"""
    def __init__(self): super().__init__()
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_sqrt_(tensors)
        return tensors

SqrtEMASquared

Bases: torchzero.core.transform.TensorTransform

Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

Parameters:

  • beta (float, default: 0.999 ) –

    momentum value. Defaults to 0.999.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • debiased (bool, default: False ) –

    whether to multiply the output by a debiasing term from the Adam method. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Methods:

  • SQRT_EMA_SQ_FN

    Updates exp_avg_sq_ with an EMA of squared tensors and calculates its square root, with optional AMSGrad and debiasing.

Source code in torchzero/modules/ops/higher_level.py
class SqrtEMASquared(TensorTransform):
    """Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
    def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
        super().__init__(defaults)
        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return self.SQRT_EMA_SQ_FN(
            TensorList(tensors),
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            max_exp_avg_sq_=max_exp_avg_sq,
            debiased=debiased,
            step=step,
            pow=pow,
        )

SQRT_EMA_SQ_FN

SQRT_EMA_SQ_FN(tensors: TensorList, exp_avg_sq_: TensorList, beta: float | NumberList, max_exp_avg_sq_: TensorList | None, debiased: bool, step: int, pow: float = 2, ema_sq_fn: Callable = ema_sq_)

Updates exp_avg_sq_ with an EMA of squared tensors and calculates its square root, with optional AMSGrad and debiasing.

Returns new tensors.

Source code in torchzero/modules/opt_utils.py
def sqrt_ema_sq_(
    tensors: TensorList,
    exp_avg_sq_: TensorList,
    beta: float | NumberList,
    max_exp_avg_sq_: TensorList | None,
    debiased: bool,
    step: int,
    pow: float = 2,
    ema_sq_fn: Callable = ema_sq_,
):
    """
    Updates `exp_avg_sq_` with an EMA of squared `tensors` and calculates its square root,
    with optional AMSGrad and debiasing.

    Returns new tensors.
    """
    exp_avg_sq_=ema_sq_fn(
        tensors=tensors,
        exp_avg_sq_=exp_avg_sq_,
        beta=beta,
        max_exp_avg_sq_=max_exp_avg_sq_,
        pow=pow,
    )

    sqrt_exp_avg_sq = root(exp_avg_sq_, pow, inplace=False)

    if debiased: sqrt_exp_avg_sq = debias_second_momentum(sqrt_exp_avg_sq, step=step, beta=beta, pow=pow, inplace=True)
    return sqrt_exp_avg_sq
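
ema_sq_, root and debias_second_momentum are not shown in this listing, so the sketch below assumes Adam-style behaviour for them: v ← beta·v + (1 - beta)·|x|^pow, an optional running maximum of v for AMSGrad, output v^(1/pow), and debiasing by dividing by (1 - beta^step)^(1/pow). sqrt_ema_sq_sketch is a hypothetical name.

```python
import torch


def sqrt_ema_sq_sketch(x, state, beta=0.999, pow=2.0, amsgrad=False, debiased=False):
    """Hypothetical re-implementation of sqrt_ema_sq_ under the assumptions above.

    ``state`` is a dict that keeps the running averages between calls.
    """
    step = state["step"] = state.get("step", 0) + 1
    v = state.setdefault("exp_avg_sq", torch.zeros_like(x))
    # assumed ema_sq_: in-place EMA of |x| ** pow
    v.mul_(beta).add_(x.abs().pow(pow), alpha=1 - beta)

    if amsgrad:
        # assumed AMSGrad: keep the running maximum of the EMA and use it instead
        v_max = state.setdefault("max_exp_avg_sq", torch.zeros_like(x))
        v_max.copy_(torch.maximum(v_max, v))
        v = v_max

    out = v.pow(1.0 / pow)  # root, not in place, so the stored state is untouched
    if debiased:
        # assumed debias_second_momentum: Adam's correction applied inside the root
        out = out / (1 - beta ** step) ** (1.0 / pow)
    return out


state = {}
print(sqrt_ema_sq_sketch(torch.tensor([2.0, -3.0]), state, debiased=True))
```

Under these assumptions the debiased output of the first call is |x| (about [2, 3] here), even though with beta=0.999 the raw moving average is only 0.1% of x**2.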

SqrtHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class SqrtHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return (loss+1e-12).sqrt()

SquareHomotopy

Bases: torchzero.modules.misc.homotopy.HomotopyBase

Source code in torchzero/modules/misc/homotopy.py
class SquareHomotopy(HomotopyBase):
    def __init__(self): super().__init__()
    def loss_transform(self, loss): return loss.square().copysign(loss)
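
Neither homotopy class documents its transform, but the expressions are short. (loss + 1e-12).sqrt() assumes a non-negative loss and keeps the gradient finite at zero, while loss.square().copysign(loss) squares the magnitude but keeps the sign. Both are increasing on their domain, so they reshape the loss without moving its minimizers. The snippet below only evaluates the same two expressions on plain tensors; how HomotopyBase applies loss_transform during optimization is not shown here, and the function names are hypothetical.

```python
import torch


def sqrt_transform(loss: torch.Tensor) -> torch.Tensor:
    # same expression as SqrtHomotopy.loss_transform; assumes loss >= 0,
    # the small offset keeps the gradient finite when loss == 0
    return (loss + 1e-12).sqrt()


def signed_square_transform(loss: torch.Tensor) -> torch.Tensor:
    # same expression as SquareHomotopy.loss_transform: loss**2 with the sign of loss
    return loss.square().copysign(loss)


losses = torch.tensor([-3.0, 0.0, 0.25, 4.0])
print(signed_square_transform(losses))
print(sqrt_transform(losses[1:]))  # only the non-negative losses
```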

StepSize

Bases: torchzero.core.transform.TensorTransform

This is the same as LR, except that the lr setting can be stored under any other name (set with key) to avoid clashes.

Source code in torchzero/modules/step_size/lr.py
class StepSize(TensorTransform):
    """This is the same as LR, except that the `lr` setting can be stored under any other name (set with `key`) to avoid clashes."""
    def __init__(self, step_size: float, key = 'step_size'):
        defaults={"key": key, key: step_size}
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
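
The constructor stores the value under the custom name and records that name under "key", so multi_tensor_apply resolves the value per tensor with s[s['key']]. The sketch below reproduces that lookup with plain dicts; make_defaults and apply_step_size are hypothetical names, and treating lazy_lr as an in-place multiplication by the value is an assumption, since its body is not shown here.

```python
import torch


def make_defaults(step_size: float, key: str = "step_size") -> dict:
    # same layout as StepSize.__init__: the value is stored under a configurable name
    return {"key": key, key: step_size}


def apply_step_size(tensors: list[torch.Tensor], settings: list[dict]) -> list[torch.Tensor]:
    # one settings dict per tensor, as in multi_tensor_apply; s["key"] names the entry to read
    values = [s[s["key"]] for s in settings]
    # assumption: lazy_lr amounts to scaling each tensor by its value (in place here)
    return [t.mul_(v) for t, v in zip(tensors, values)]


settings = [make_defaults(0.5, key="inner_lr"), make_defaults(0.1)]
print(settings[0])  # {'key': 'inner_lr', 'inner_lr': 0.5}
scaled = apply_step_size([torch.ones(2), torch.ones(3)], settings)
```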

StrongWolfe

Bases: torchzero.modules.line_search.line_search.LineSearchBase

Interpolation line search satisfying the strong Wolfe conditions.

Parameters:

  • c1 (float, default: 0.0001 ) –

    constant for the sufficient decrease (Armijo) condition. Defaults to 1e-4.

  • c2 (float, default: 0.9 ) –

    constant for the strong curvature condition. For conjugate gradient, set it to 0.1. Defaults to 0.9.

  • a_init (str, default: 'fixed' ) –

    strategy for choosing the initial step size guess:
    • "fixed" - uses a fixed value specified in the init_value argument.
    • "first-order" - assumes that the first-order change in the function at the current iterate will be the same as that obtained at the previous step.
    • "quadratic" - interpolates a quadratic to f(x_{-1}), f(x) and the directional derivative at x, and uses its minimizer.
    • "quadratic-clip" - same as "quadratic", but uses min(1, 1.01*alpha) as described in Numerical Optimization.
    • "previous" - uses the final step size found on the previous iteration.

    For 2nd order methods it is usually best to leave this at "fixed". For methods that do not produce well-scaled search directions, e.g. conjugate gradient, "first-order" or "quadratic-clip" are recommended. Defaults to 'fixed'.

  • a_max (float, default: 1000000000000.0 ) –

    upper bound for the proposed step sizes. Defaults to 1e12.

  • init_value (float, default: 1 ) –

    initial step size. Used when a_init="fixed", and with other strategies as fallback value. Defaults to 1.

  • maxiter (int, default: 25 ) –

    maximum number of line search iterations. Defaults to 25.

  • maxzoom (int, default: 10 ) –

    maximum number of zoom iterations. Defaults to 10.

  • maxeval (int | None, default: None ) –

    maximum number of function evaluations. Defaults to None.

  • tol_change (float, default: 1e-09 ) –

    tolerance, terminates on small brackets. Defaults to 1e-9.

  • interpolation (str, default: 'cubic' ) –

    What type of interpolation to use:
    • "bisection" - uses the middle point. This is robust, especially if the objective function is non-smooth, but it may need more function evaluations.
    • "quadratic" - minimizes a quadratic model; generally outperformed by "cubic".
    • "cubic" - minimizes a cubic model. This is the most widely used interpolation strategy.
    • "polynomial" - fits a polynomial to all points obtained during the line search.
    • "polynomial2" - alternative polynomial fit, where if a point is outside of bounds, a lower degree polynomial is tried. This may converge faster than "cubic" and "polynomial".

    Defaults to 'cubic'.

  • adaptive (bool, default: True ) –

    if True, the initial step size is halved whenever the line search fails to find an acceptable step size. When an acceptable step size is found, the initial step size is reset to the original value. Defaults to True.

  • fallback (bool, default: False ) –

    if True, when no point satisfies the strong Wolfe conditions, returns the point with the lowest value found, provided it is lower than the initial value, even though it doesn't satisfy the conditions. Defaults to False.

  • plus_minus (bool, default: False ) –

    if True, enables the plus-minus variant: if the search direction is not a descent direction (the directional derivative is positive), the line search is performed in the opposite direction. Defaults to False.

Examples:

Conjugate gradient method with strong Wolfe line search. Nocedal and Wright recommend setting c2 to 0.1 for CG. Since CG doesn't produce well-scaled directions, the initial step size can be determined from information about the previous step with a_init="first-order".

opt = tz.Optimizer(
    model.parameters(),
    tz.m.PolakRibiere(),
    tz.m.StrongWolfe(c2=0.1, a_init="first-order")
)

L-BFGS with strong Wolfe line search:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LBFGS(),
    tz.m.StrongWolfe()
)
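Writing φ(a) for the loss along the search direction, the strong Wolfe conditions with 0 < c1 < c2 < 1 are φ(a) ≤ φ(0) + c1·a·φ'(0) (sufficient decrease) and |φ'(a)| ≤ c2·|φ'(0)| (curvature). The sketch below checks both for a candidate step; it is a standalone illustration, not the internal _StrongWolfe class, and satisfies_strong_wolfe is a hypothetical name.

```python
def satisfies_strong_wolfe(phi_0: float, dphi_0: float, phi_a: float, dphi_a: float,
                           a: float, c1: float = 1e-4, c2: float = 0.9) -> bool:
    """Hypothetical helper: check the strong Wolfe conditions for step size ``a``.

    phi_0, dphi_0: value and directional derivative at a = 0 (dphi_0 < 0 for descent)
    phi_a, dphi_a: value and directional derivative at the candidate step ``a``
    """
    sufficient_decrease = phi_a <= phi_0 + c1 * a * dphi_0
    curvature = abs(dphi_a) <= c2 * abs(dphi_0)
    return sufficient_decrease and curvature


def phi(a: float) -> float:
    return (a - 1.0) ** 2  # minimum at a = 1


def dphi(a: float) -> float:
    return 2.0 * (a - 1.0)


for a in (0.01, 0.5, 1.0, 2.0):
    print(a, satisfies_strong_wolfe(phi(0.0), dphi(0.0), phi(a), dphi(a), a))
# prints: 0.01 False / 0.5 True / 1.0 True / 2.0 False
```

a = 0.01 fails the curvature condition (the slope is still almost as steep as at 0) and a = 2 fails sufficient decrease. With c2=0.1, as recommended for CG, a = 0.5 would also be rejected, because |φ'(0.5)| = 1 is larger than 0.1·2.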

Source code in torchzero/modules/line_search/strong_wolfe.py
class StrongWolfe(LineSearchBase):
    """Interpolation line search satisfying the strong Wolfe conditions.

    Args:
        c1 (float, optional): constant for the sufficient decrease (Armijo) condition. Defaults to 1e-4.
        c2 (float, optional): constant for the strong curvature condition. For conjugate gradient, set it to 0.1. Defaults to 0.9.
        a_init (str, optional):
            strategy for choosing the initial step size guess.
            - "fixed" - uses a fixed value specified in the `init_value` argument.
            - "first-order" - assumes that the first-order change in the function at the current iterate will be the same as that obtained at the previous step.
            - "quadratic" - interpolates a quadratic to f(x_{-1}), f(x) and the directional derivative at x, and uses its minimizer.
            - "quadratic-clip" - same as "quadratic", but uses min(1, 1.01*alpha) as described in Numerical Optimization.
            - "previous" - uses the final step size found on the previous iteration.

            For 2nd order methods it is usually best to leave this at "fixed".
            For methods that do not produce well-scaled search directions, e.g. conjugate gradient,
            "first-order" or "quadratic-clip" are recommended. Defaults to 'fixed'.
        a_max (float, optional): upper bound for the proposed step sizes. Defaults to 1e12.
        init_value (float, optional):
            initial step size. Used when ``a_init``="fixed", and with other strategies as fallback value. Defaults to 1.
        maxiter (int, optional): maximum number of line search iterations. Defaults to 25.
        maxzoom (int, optional): maximum number of zoom iterations. Defaults to 10.
        maxeval (int | None, optional): maximum number of function evaluations. Defaults to None.
        tol_change (float, optional): tolerance, terminates on small brackets. Defaults to 1e-9.
        interpolation (str, optional):
            What type of interpolation to use.
            - "bisection" - uses the middle point. This is robust, especially if the objective function is non-smooth, but it may need more function evaluations.
            - "quadratic" - minimizes a quadratic model; generally outperformed by "cubic".
            - "cubic" - minimizes a cubic model. This is the most widely used interpolation strategy.
            - "polynomial" - fits a polynomial to all points obtained during the line search.
            - "polynomial2" - alternative polynomial fit, where if a point is outside of bounds, a lower degree polynomial is tried.
            This may converge faster than "cubic" and "polynomial".

            Defaults to 'cubic'.
        adaptive (bool, optional):
            if True, the initial step size is halved whenever the line search fails to find an acceptable step size.
            When an acceptable step size is found, the initial step size is reset to the original value. Defaults to True.
        fallback (bool, optional):
            if True, when no point satisfies the strong Wolfe conditions, returns the point with the lowest value found,
            provided it is lower than the initial value, even though it doesn't satisfy the conditions. Defaults to False.
        plus_minus (bool, optional):
            if True, enables the plus-minus variant: if the search direction is not a descent direction (the directional
            derivative is positive), the line search is performed in the opposite direction. Defaults to False.


    ## Examples:

    Conjugate gradient method with strong Wolfe line search. Nocedal and Wright recommend setting c2 to 0.1 for CG. Since CG doesn't produce well-scaled directions, the initial step size can be determined from information about the previous step with ``a_init="first-order"``.

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.PolakRibiere(),
        tz.m.StrongWolfe(c2=0.1, a_init="first-order")
    )
    ```

    L-BFGS with strong Wolfe line search:
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LBFGS(),
        tz.m.StrongWolfe()
    )
    ```

    """
    def __init__(
        self,
        c1: float = 1e-4,
        c2: float = 0.9,
        a_init: Literal['first-order', 'quadratic', 'quadratic-clip', 'previous', 'fixed'] = 'fixed',
        a_max: float = 1e12,
        init_value: float = 1,
        maxiter: int = 25,
        maxzoom: int = 10,
        maxeval: int | None = None,
        tol_change: float = 1e-9,
        interpolation: Literal["quadratic", "cubic", "bisection", "polynomial", 'polynomial2'] = 'cubic',
        adaptive = True,
        fallback:bool = False,
        plus_minus = False,
    ):
        defaults=dict(init_value=init_value,init=a_init,a_max=a_max,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom, fallback=fallback,
                      maxeval=maxeval, adaptive=adaptive, interpolation=interpolation, plus_minus=plus_minus, tol_change=tol_change)
        super().__init__(defaults=defaults)

        self.global_state['initial_scale'] = 1.0

    @torch.no_grad
    def search(self, update, var):
        self._g_prev = self._f_prev = None
        objective = self.make_objective_with_derivative(var=var)

        init_value, init, c1, c2, a_max, maxiter, maxzoom, maxeval, interpolation, adaptive, plus_minus, fallback, tol_change = itemgetter(
            'init_value', 'init', 'c1', 'c2', 'a_max', 'maxiter', 'maxzoom',
            'maxeval', 'interpolation', 'adaptive', 'plus_minus', 'fallback', 'tol_change')(self.defaults)

        dir = as_tensorlist(var.get_updates())
        grad_list = var.get_grads()

        g_0 = -sum(t.sum() for t in torch._foreach_mul(grad_list, dir))
        f_0 = var.get_loss(False)
        dir_norm = dir.global_vector_norm()

        inverted = False
        if plus_minus and g_0 > 0:
            original_objective = objective
            def inverted_objective(a):
                l, g_a = original_objective(-a)
                return l, -g_a
            objective = inverted_objective
            inverted = True

        # --------------------- determine initial step size guess -------------------- #
        init = init.lower().strip()

        a_init = init_value
        if init == 'fixed':
            pass # use init_value

        elif init == 'previous':
            if 'a_prev' in self.global_state:
                a_init = self.global_state['a_prev']

        elif init == 'first-order':
            if 'g_prev' in self.global_state and g_0 < -torch.finfo(dir[0].dtype).tiny * 2:
                a_prev = self.global_state['a_prev']
                g_prev = self.global_state['g_prev']
                if g_prev < 0:
                    a_init = a_prev * g_prev / g_0

        elif init in ('quadratic', 'quadratic-clip'):
            if 'f_prev' in self.global_state and g_0 < -torch.finfo(dir[0].dtype).tiny * 2:
                f_prev = self.global_state['f_prev']
                if f_0 < f_prev:
                    a_init = 2 * (f_0 - f_prev) / g_0
                    if init == 'quadratic-clip': a_init = min(1, 1.01*a_init)
        else:
            raise ValueError(init)

        if adaptive:
            a_init *= self.global_state.get('initial_scale', 1)

        strong_wolfe = _StrongWolfe(
            f=objective,
            f_0=f_0,
            g_0=g_0,
            d_norm=dir_norm,
            a_init=a_init,
            a_max=a_max,
            c1=c1,
            c2=c2,
            maxiter=maxiter,
            maxzoom=maxzoom,
            maxeval=maxeval,
            tol_change=tol_change,
            interpolation=interpolation,
        )

        a, f_a, g_a = strong_wolfe.search()
        if inverted and a is not None: a = -a
        if f_a is not None and (f_a > f_0 or not math.isfinite(f_a)): a = None

        if fallback:
            if a is None or a==0 or not math.isfinite(a):
                lowest = min(strong_wolfe.history.items(), key=lambda x: x[1][0])
                if lowest[1][0] < f_0:
                    a = lowest[0]
                    f_a, g_a = lowest[1]
                    if inverted: a = -a

        if a is not None and a != 0 and math.isfinite(a):
            self.global_state['initial_scale'] = 1
            self.global_state['a_prev'] = a
            self.global_state['f_prev'] = f_0
            self.global_state['g_prev'] = g_0
            return a

        # fail
        if adaptive:
            self.global_state['initial_scale'] = self.global_state.get('initial_scale', 1) * 0.5
            finfo = torch.finfo(dir[0].dtype)
            if self.global_state['initial_scale'] < finfo.tiny * 2:
                self.global_state['initial_scale'] = init_value * 2

        return 0
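
The non-fixed a_init strategies in search reuse information from the last successful line search. With g_0 the directional derivative computed in search (negative for a descent direction), "first-order" sets a_init = a_prev·g_prev/g_0, so the predicted first-order decrease matches the previous one, and "quadratic" sets a_init = 2·(f_0 - f_prev)/g_0, the minimizer of the quadratic interpolating f_prev, f_0 and g_0 as in Numerical Optimization. The sketch below restates those branches with the same guards as a standalone function; initial_step_guess is a hypothetical name, and the adaptive initial_scale multiplier is omitted.

```python
import torch


def initial_step_guess(init: str, init_value: float, f_0: float, g_0: float,
                       prev: dict, dtype=torch.float32) -> float:
    """Hypothetical restatement of the a_init branches in StrongWolfe.search.

    prev may hold "a_prev", "f_prev" and "g_prev" from the last successful search.
    """
    tiny = torch.finfo(dtype).tiny * 2  # same guard as the source: g_0 must be clearly negative
    a_init = init_value
    if init == "fixed":
        pass
    elif init == "previous":
        a_init = prev.get("a_prev", init_value)
    elif init == "first-order":
        if "g_prev" in prev and g_0 < -tiny and prev["g_prev"] < 0:
            a_init = prev["a_prev"] * prev["g_prev"] / g_0
    elif init in ("quadratic", "quadratic-clip"):
        # only used when the loss decreased since the previous iteration
        if "f_prev" in prev and g_0 < -tiny and f_0 < prev["f_prev"]:
            a_init = 2 * (f_0 - prev["f_prev"]) / g_0
            if init == "quadratic-clip":
                a_init = min(1, 1.01 * a_init)
    else:
        raise ValueError(init)
    return a_init


prev = {"a_prev": 0.5, "f_prev": 2.0, "g_prev": -4.0}
print(initial_step_guess("first-order", 1.0, f_0=1.5, g_0=-1.0, prev=prev))  # 2.0
```

In this example the current direction is four times less steep than the previous one (g_0 = -1 versus g_prev = -4), so the first-order rule proposes a step four times longer than a_prev.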

Sub

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Subtracts other from tensors. other can be a number or a module.

If other is a module, this calculates tensors - other(tensors). In both cases other is scaled by alpha (default 1) before it is subtracted.

Source code in torchzero/modules/ops/binary.py
class Sub(BinaryOperationBase):
    """Subtract ``other`` from tensors. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``tensors - other(tensors)``.
    """
    def __init__(self, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.defaults['alpha'])
        else: torch._foreach_sub_(update, other, alpha=self.defaults['alpha'])
        return update
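
Sub.transform computes tensors - alpha * other in place with the multi-tensor torch._foreach_sub_, folding alpha into the scalar in the number case and passing it as alpha= in the tensor-list case. The sketch below applies the same two calls to plain lists of tensors; sub_like is a hypothetical name, and torch._foreach_* functions are private PyTorch APIs, used here only because the source uses them.

```python
import torch


def sub_like(update: list[torch.Tensor], other, alpha: float = 1.0) -> list[torch.Tensor]:
    """Hypothetical mirror of Sub.transform: update - alpha * other, in place."""
    if isinstance(other, (int, float)):
        # scalar case: alpha is folded into the scalar
        torch._foreach_sub_(update, other * alpha)
    else:
        # tensor-list case: alpha is passed to the fused multi-tensor kernel
        torch._foreach_sub_(update, other, alpha=alpha)
    return update


update = [torch.tensor([1.0, 2.0]), torch.tensor([3.0])]
sub_like(update, 0.5, alpha=2.0)  # each entry is reduced by 0.5 * 2.0 = 1.0
```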

SubModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates input - other. input and other can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class SubModules(MultiOperationBase):
    """Calculates ``input - other``. ``input`` and ``other`` can be numbers or modules."""
    def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, input=input, other=other)

    @torch.no_grad
    def transform(self, objective: Objective, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
        alpha = self.defaults['alpha']

        if isinstance(input, (int,float)):
            assert isinstance(other, list)
            return input - TensorList(other).mul_(alpha)

        if isinstance(other, (int, float)): torch._foreach_sub_(input, other * alpha)
        else: torch._foreach_sub_(input, other, alpha=alpha)
        return input
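
Unlike Sub, SubModules also accepts a number as input. A number cannot be updated in place, so in that branch new tensors input - alpha * other are built (the source uses TensorList arithmetic for this). The sketch below isolates that branch on plain lists of tensors; scalar_minus_tensors is a hypothetical name, and unlike the source, which scales other with an in-place mul_, it uses out-of-place operations so the caller's tensors are left unchanged.

```python
import torch


def scalar_minus_tensors(input: float, other: list[torch.Tensor], alpha: float = 1.0) -> list[torch.Tensor]:
    """Hypothetical mirror of the number-input branch of SubModules.transform."""
    # new tensors are allocated for the result; `other` is not modified
    return [input - alpha * o for o in other]


other = [torch.tensor([0.25, 2.0])]
result = scalar_minus_tensors(1.0, other, alpha=2.0)  # 1 - 2 * other
```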

SubspaceNewton

Bases: torchzero.core.transform.Transform

Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).

Parameters:

  • sketch_size (int, default: 100 ) –

    size of the random sketch. This many hessian-vector products will need to be evaluated each step.

  • sketch_type (str, default: 'common_directions' ) –
    • "common_directions" - uses past steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt (default).
    • "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
    • "rows" - samples random rows.
    • "topk" - samples top-rank rows with largest gradient magnitude.
    • "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis.
    • "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction.
  • damping (float, default: 0 ) –

    hessian damping (scale of identity matrix added to hessian). Defaults to 0.

  • hvp_method (str, default: 'batched_autograd' ) –

    How to compute the hessian-matrix product:
    • "batched_autograd" - uses batched autograd.
    • "autograd" - uses unbatched autograd.
    • "forward" - uses finite difference with the forward formula, performing 1 backward pass per Hvp.
    • "central" - uses finite difference with the more accurate central formula, performing 2 backward passes per Hvp.

    Defaults to "batched_autograd".

  • h (float, default: 0.01 ) –

    finite difference step size. Defaults to 1e-2.

  • use_lstsq (bool, default: False ) –

    whether to use least squares to solve Hx=g. Defaults to False.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • H_tfm (Callable | None) –

    optional hessian transform, which takes in two arguments: (hessian, gradient).

    It must return either a tuple (hessian, is_inverted), with the transformed hessian and a boolean that must be True if the transform inverted the hessian and False otherwise, or a single tensor, which is used as the update.

    Defaults to None.

  • eigval_fn (Callable | None, default: None ) –

    optional eigenvalues transform, for example torch.abs or lambda L: torch.clip(L, min=1e-8). If this is specified, eigendecomposition will be used to invert the hessian.

  • seed (int | None, default: None ) –

    seed for random generator. Defaults to None.

  • inner (Chainable | None, default: None ) –

    preconditions output of this module. Defaults to None.
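
The "forward" and "central" hvp_method options replace autograd with a finite difference of gradients along a vector v with step h: Hv ≈ (g(x + h·v) - g(x)) / h, which needs one extra gradient when g(x) is already known, or Hv ≈ (g(x + h·v) - g(x - h·v)) / (2h), which needs two. The sketch below illustrates both formulas in plain PyTorch; fd_hvp and grad_at are hypothetical names, and whether torchzero normalizes v or rescales h internally is not shown here, so the raw formulas are used.

```python
import torch


def grad_at(loss_fn, x):
    """Gradient of loss_fn at x via autograd (one backward pass)."""
    x = x.detach().requires_grad_(True)
    (g,) = torch.autograd.grad(loss_fn(x), x)
    return g


def fd_hvp(loss_fn, x, v, h=1e-2, central=False, g_x=None):
    """Hypothetical finite-difference approximation of H(x) @ v."""
    if central:
        # central formula: two gradient evaluations, error O(h**2)
        return (grad_at(loss_fn, x + h * v) - grad_at(loss_fn, x - h * v)) / (2 * h)
    # forward formula: reuses the gradient at x if it is already known,
    # so only one extra gradient evaluation is needed; error O(h)
    if g_x is None:
        g_x = grad_at(loss_fn, x)
    return (grad_at(loss_fn, x + h * v) - g_x) / h


quartic = lambda x: (x**4).sum() / 4  # gradient x**3, hessian diag(3 * x**2)
x = torch.tensor([1.0, -2.0], dtype=torch.float64)
v = torch.ones(2, dtype=torch.float64)
exact = 3 * x**2 * v
print(fd_hvp(quartic, x, v) - exact, fd_hvp(quartic, x, v, central=True) - exact)
```

For this function the forward error is of order h (about 0.03 and 0.06 in magnitude), while the central error is of order h**2 (1e-4), which matches the accuracy/cost trade-off described for hvp_method.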

Examples

RSN with line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.RSN(),
    tz.m.Backtracking()
)

RSN with trust region

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.RSN()),
)

References
  1. Gower, Robert, et al. "RSN: randomized subspace Newton." Advances in Neural Information Processing Systems 32 (2019).
  2. Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
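
With a basis S whose k columns span the subspace, a subspace Newton step solves the k×k system (SᵀHS + λI)y = Sᵀg and maps back with d = Sy, so only the k hessian-vector products HS are needed. The sketch below does this densely in PyTorch with a random orthonormal basis (similar in spirit to sketch_type="orthonormal") and an optional eigval_fn applied to the eigenvalues of the reduced hessian before inverting it. It is an illustration under stated assumptions, not torchzero's implementation: subspace_newton_direction is hypothetical, H is formed explicitly instead of through hvp_method, and damping is added to the reduced matrix, which equals damping the full hessian when S is orthonormal.

```python
import torch


def subspace_newton_direction(H, g, k, damping=0.0, eigval_fn=None, generator=None):
    """Hypothetical dense sketch of a subspace Newton direction (to be subtracted from x)."""
    n = g.numel()
    # random orthonormal basis with k columns via QR of a Gaussian matrix
    S, _ = torch.linalg.qr(torch.randn(n, k, generator=generator, dtype=g.dtype))
    HS = H @ S                                          # k hessian-vector products
    H_sub = S.T @ HS + damping * torch.eye(k, dtype=g.dtype)  # reduced k x k hessian
    g_sub = S.T @ g                                     # reduced gradient

    if eigval_fn is None:
        y = torch.linalg.solve(H_sub, g_sub)
    else:
        # eigendecomposition of the symmetric reduced hessian, transform the
        # eigenvalues (e.g. torch.abs), then invert: y = V diag(1 / L) V^T g_sub
        L, V = torch.linalg.eigh(H_sub)
        y = V @ ((V.T @ g_sub) / eigval_fn(L))
    return S @ y                                        # map back to the full space


H = torch.tensor([[4.0, 1.0, 0.0], [1.0, 3.0, 0.0], [0.0, 0.0, 2.0]], dtype=torch.float64)
g = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
print(subspace_newton_direction(H, g, k=3))  # full basis: same as torch.linalg.solve(H, g)
print(subspace_newton_direction(H, g, k=1))  # Newton step restricted to a random 1-D subspace
```

With k equal to the full dimension the result is the ordinary Newton direction H⁻¹g; smaller k trades accuracy for fewer hessian-vector products, which is the role of sketch_size.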
Source code in torchzero/modules/second_order/rsn.py
class SubspaceNewton(Transform):
    """Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).

    Args:
        sketch_size (int):
            size of the random sketch. This many hessian-vector products will need to be evaluated each step.
        sketch_type (str, optional):
            - "common_directions" - uses past steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt (default).
            - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
            - "rows" - samples random rows.
            - "topk" - samples top-rank rows with largest gradient magnitude.
            - "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis.
            - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction.
        damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
        hvp_method (str, optional):
            How to compute hessian-matrix product:
            - "batched_autograd" - uses batched autograd
            - "autograd" - uses unbatched autograd
            - "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
            - "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.

            Defaults to "batched_autograd".
        h (float, optional): finite difference step size. Defaults to 1e-2.
        use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to False.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        H_tfm (Callable | None, optional):
            optional hessian transforms, takes in two arguments - `(hessian, gradient)`.

            must return either a tuple: `(hessian, is_inverted)` with transformed hessian and a boolean value
            which must be True if transform inverted the hessian and False otherwise.

            Or it returns a single tensor which is used as the update.

            Defaults to None.
        eigval_fn (Callable | None, optional):
            optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
            If this is specified, eigendecomposition will be used to invert the hessian.
        seed (int | None, optional): seed for random generator. Defaults to None.
        inner (Chainable | None, optional): preconditions output of this module. Defaults to None.

    ### Examples

    RSN with line search
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.RSN(),
        tz.m.Backtracking()
    )
    ```

    RSN with trust region
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.RSN()),
    )
    ```


    References:
        1. [Gower, Robert, et al. "RSN: randomized subspace Newton." Advances in Neural Information Processing Systems 32 (2019).](https://arxiv.org/abs/1905.10874)
        2. Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
    """

    def __init__(
        self,
        sketch_size: int = 100,
        sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher", "rows", "topk"] = "common_directions",
        damping:float=0,
        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
        eigv_tol: float | None = None,
        truncate: int | None = None,
        update_freq: int = 1,
        precompute_inverse: bool = False,
        use_lstsq: bool = False,
        hvp_method: HVPMethod = "batched_autograd",
        h: float = 1e-2,
        seed: int | None = None,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["update_freq"]
        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]
        params = objective.params
        generator = self.get_generator(params[0].device, fs["seed"])

        ndim = sum(p.numel() for p in params)

        device=params[0].device
        dtype=params[0].dtype

        # sample sketch matrix S: (ndim, sketch_size)
        sketch_size = min(fs["sketch_size"], ndim)
        sketch_type = fs["sketch_type"]
        hvp_method = fs["hvp_method"]

        if sketch_type == "rademacher":
            S = _rademacher_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)

        elif sketch_type == 'orthonormal':
            S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)

        elif sketch_type == "rows":
            S = _row_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)

        elif sketch_type == "topk":
            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
            g = torch.cat([t.ravel() for t in g_list])
            S = _topk_rows(g, ndim, sketch_size, device=device, dtype=dtype, generator=generator)

        elif sketch_type == 'common_directions':
            # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
            g = torch.cat([t.ravel() for t in g_list])

            # initialize directions deque
            if "directions" not in self.global_state:

                g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
                if g_norm < torch.finfo(g.dtype).tiny * 2:
                    g = torch.randn_like(g)
                    g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable

                self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
                S = self.global_state["directions"][0].unsqueeze(1)

            # add new steepest descent direction orthonormal to existing columns
            else:
                S = torch.stack(tuple(self.global_state["directions"]), dim=1)
                p = g - S @ (S.T @ g)
                p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
                if p_norm > torch.finfo(p.dtype).tiny * 2:
                    p = p / p_norm
                    self.global_state["directions"].append(p)
                    S = torch.cat([S, p.unsqueeze(1)], dim=1)

        elif sketch_type == "mixed":
            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
            g = torch.cat([t.ravel() for t in g_list])

            # initialize state
            if "slow_ema" not in self.global_state:
                self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
                self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
                self.global_state["p_prev"] = torch.randn_like(g)

            # previous update direction
            p_cur = torch.cat([t.ravel() for t in params])
            prev_dir = p_cur - self.global_state["p_prev"]
            self.global_state["p_prev"] = p_cur

            # EMAs
            slow_ema = self.global_state["slow_ema"]
            fast_ema = self.global_state["fast_ema"]
            slow_ema.lerp_(g, 0.001)
            fast_ema.lerp_(g, 0.1)

            # form and orthogonalize sketching matrix
            S = torch.stack([g, slow_ema, fast_ema, prev_dir], dim=1)
            if sketch_size > 4:
                # four columns are already fixed, so fill the remaining sketch_size - 4 with random directions
                S_random = torch.randn(ndim, sketch_size - 4, device=device, dtype=dtype, generator=generator) / math.sqrt(ndim)
                S = torch.cat([S, S_random], dim=1)

            S = _qr_orthonormalize(S)

        else:
            raise ValueError(f'Unknown sketch_type {sketch_type}')

        # print(f'{S.shape = }')
        # I = torch.eye(S.size(1), device=S.device, dtype=S.dtype)
        # print(f'{torch.nn.functional.mse_loss(S.T @ S, I) = }')

        # form sketched hessian
        HS, _ = objective.hessian_matrix_product(S, rgrad=None, at_x0=True,
                                                 hvp_method=fs["hvp_method"], h=fs["h"])
        H_sketched = S.T @ HS

        # update state
        _newton_update_state_(
            state = self.global_state,
            H = H_sketched,
            damping = fs["damping"],
            eigval_fn = fs["eigval_fn"],
            eigv_tol = fs["eigv_tol"],
            truncate = fs["truncate"],
            precompute_inverse = fs["precompute_inverse"],
            use_lstsq = fs["use_lstsq"]
        )

        self.global_state["S"] = S

    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()
        fs = settings[0]

        S = self.global_state["S"]
        b = torch.cat([t.ravel() for t in updates])
        b_proj = S.T @ b

        d_proj = _newton_solve(b=b_proj, state=self.global_state, use_lstsq=fs["use_lstsq"])

        d = S @ d_proj
        vec_to_tensors_(d, updates)
        return objective

    def get_H(self, objective=...):
        if "H" in self.global_state:
            H_sketched = self.global_state["H"]

        else:
            L = self.global_state["L"]
            Q = self.global_state["Q"]
            H_sketched = Q @ L.diag_embed() @ Q.mH

        S: torch.Tensor = self.global_state["S"]
        return Sketched(S, H_sketched)
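With `sketch_type="common_directions"`, each step adds the normalized component of the new gradient that is orthogonal to the stored directions. The deque's `maxlen` (the sketch size) then discards the oldest direction. A plain-Python sketch of that update follows. `add_common_direction` is a hypothetical name. Unlike the source, the sketch does not replace a (near-)zero first gradient with a random vector.

```python
import math
import sys
from collections import deque
from typing import Deque, List

Vec = List[float]

def add_common_direction(directions: Deque[Vec], g: Vec) -> None:
    """Grow an orthonormal basis by the normalized part of g orthogonal to it.

    Mirrors the 'common_directions' branch: p = g - S (S^T g), then p / ||p||.
    """
    tiny = sys.float_info.min * 2                      # same kind of cutoff as the source
    coeffs = [sum(si * gi for si, gi in zip(s, g)) for s in directions]  # S^T g
    p = list(g)
    for c, s in zip(coeffs, directions):               # p = g - S (S^T g)
        p = [pi - c * si for pi, si in zip(p, s)]
    p_norm = math.sqrt(sum(pi * pi for pi in p))
    if p_norm > tiny:
        directions.append([pi / p_norm for pi in p])   # oldest direction drops out at maxlen


basis: Deque[Vec] = deque(maxlen=2)                    # maxlen plays the role of sketch_size
for grad in ([3.0, 4.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 5.0]):
    add_common_direction(basis, grad)
print(len(basis))                                      # 2
```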

Sum

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs the sum of inputs, which can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class Sum(ReduceOperationBase):
    """Outputs the sum of ``inputs``, which can be modules or numbers."""
    USE_MEAN = False
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        sum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_add_(sum, v)

        if self.USE_MEAN and len(sorted_inputs) > 1: torch._foreach_div_(sum, len(sorted_inputs))
        return sum

USE_MEAN class-attribute

USE_MEAN = False

SumOfSquares

Bases: torchzero.core.transform.Transform

Sets loss to be the sum of squares of values returned by the closure.

This is meant to be used to test least squares methods against ordinary minimization methods.

To use this, the closure should return a vector of residuals; the module minimizes the sum of their squares. The closure must still accept a backward argument: it is always passed False, but it is required.
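A sketch of the closure convention follows, using Rosenbrock's function written as two residuals. In real use the closure would compute a torch tensor of residuals from the model parameters. Plain floats are used here only to keep the example self-contained. `make_closure` and `sum_of_squares` are hypothetical helpers that mimic what the module does with the closure output.

```python
from typing import Callable, List

def make_closure(params: List[float]) -> Callable[[bool], List[float]]:
    def closure(backward: bool = True) -> List[float]:
        # SumOfSquares always calls this with backward=False,
        # but the argument must still be accepted.
        x, y = params
        # Rosenbrock as residuals: (1 - x)^2 + 100 (y - x^2)^2 = r1^2 + r2^2
        return [1 - x, 10 * (y - x ** 2)]
    return closure

def sum_of_squares(closure: Callable[[bool], List[float]]) -> float:
    # the loss that SumOfSquares builds from the closure output
    return sum(r * r for r in closure(False))


print(sum_of_squares(make_closure([1.0, 1.0])))   # 0.0 at the minimum
print(sum_of_squares(make_closure([0.0, 0.0])))   # 1.0
```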

Source code in torchzero/modules/least_squares/gn.py
class SumOfSquares(Transform):
    """Sets loss to be the sum of squares of values returned by the closure.

    This is meant to be used to test least squares methods against ordinary minimization methods.

    To use this, the closure should return a vector of residuals; the module minimizes the sum of their squares.
    The closure must still accept a ``backward`` argument: it is always passed False, but it is required.
    """
    def __init__(self):
        super().__init__()

    @torch.no_grad
    def update_states(self, objective, states, settings):
        closure = objective.closure

        if closure is not None:

            def sos_closure(backward=True):
                if backward:
                    objective.zero_grad()
                    with torch.enable_grad():
                        loss = closure(False)
                        loss = loss.pow(2).sum()
                        loss.backward()
                    return loss

                loss = closure(False)
                return loss.pow(2).sum()

            objective.closure = sos_closure

        if objective.loss is not None:
            objective.loss = objective.loss.pow(2).sum()

        if objective.loss_approx is not None:
            objective.loss_approx = objective.loss_approx.pow(2).sum()

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        return objective

Switch

Bases: torchzero.modules.misc.switch.Alternate

Switches to the next module after the number of steps given by steps.

Parameters:

  • steps (int | Iterable[int]) –

    Number of steps to perform with each module. If an iterable is given, it must have one entry fewer than the number of modules.

Examples:

Start with Adam, switch to L-BFGS after the 1000th step and to Truncated Newton on the 2000th step.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Switch(
        [tz.m.Adam(), tz.m.LR(1e-3)],
        [tz.m.LBFGS(), tz.m.Backtracking()],
        [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
        steps = (1000, 2000)
    )
)
Source code in torchzero/modules/misc/switch.py
class Switch(Alternate):
    """Switches to the next module after ``steps`` steps.

    Args:
        steps (int | Iterable[int]):
            Number of steps to perform with each module. If an iterable is given,
            it must have one entry fewer than the number of modules.

    ### Examples:

    Start with Adam, switch to L-BFGS after the 1000th step and to Truncated Newton on the 2000th step.

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Switch(
            [tz.m.Adam(), tz.m.LR(1e-3)],
            [tz.m.LBFGS(), tz.m.Backtracking()],
            [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
            steps = (1000, 2000)
        )
    )
    ```
    """

    LOOP = False
    def __init__(self, *modules: Chainable, steps: int | Iterable[int]):

        if isinstance(steps, Iterable):
            steps = list(steps)
            if len(steps) != len(modules) - 1:
                raise ValueError(f"steps must be the same length as modules minus 1, got {len(modules) = }, {len(steps) = }")

            steps.append(1)

        super().__init__(*modules, steps=steps)

LOOP class-attribute

LOOP = False

TerminateAfterNEvaluations

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateAfterNEvaluations(TerminationCriteriaBase):
    def __init__(self, maxevals:int):
        defaults = dict(maxevals=maxevals)
        super().__init__(defaults)

    def termination_criteria(self, objective):
        maxevals = self.defaults['maxevals']
        assert objective.modular is not None
        return objective.modular.num_evaluations >= maxevals

TerminateAfterNSeconds

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateAfterNSeconds(TerminationCriteriaBase):
    def __init__(self, seconds:float, sec_fn = time.time):
        defaults = dict(seconds=seconds, sec_fn=sec_fn)
        super().__init__(defaults)

    def termination_criteria(self, objective):
        max_seconds = self.defaults['seconds']
        sec_fn = self.defaults['sec_fn']

        if 'start' not in self.global_state:
            self.global_state['start'] = sec_fn()
            return False

        seconds_passed = sec_fn() - self.global_state['start']
        return seconds_passed >= max_seconds

TerminateAfterNSteps

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateAfterNSteps(TerminationCriteriaBase):
    def __init__(self, steps:int):
        defaults = dict(steps=steps)
        super().__init__(defaults)

    def termination_criteria(self, objective):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        max_steps = self.defaults['steps']
        return step >= max_steps

TerminateAll

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateAll(TerminationCriteriaBase):
    def __init__(self, *criteria: TerminationCriteriaBase):
        super().__init__()

        self.set_children_sequence(criteria)

    def termination_criteria(self, objective: Objective) -> bool:
        for c in self.get_children_sequence():
            if not cast(TerminationCriteriaBase, c).termination_criteria(objective): return False

        return True

TerminateAny

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateAny(TerminationCriteriaBase):
    def __init__(self, *criteria: TerminationCriteriaBase):
        super().__init__()

        self.set_children_sequence(criteria)

    def termination_criteria(self, objective: Objective) -> bool:
        for c in self.get_children_sequence():
            if cast(TerminationCriteriaBase, c).termination_criteria(objective): return True

        return False

TerminateByGradientNorm

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateByGradientNorm(TerminationCriteriaBase):
    def __init__(self, tol:float = 1e-8, n: int = 3, ord: Metrics = 2):
        defaults = dict(tol=tol, ord=ord)
        super().__init__(defaults, n=n)

    def termination_criteria(self, objective):
        tol = self.defaults['tol']
        ord = self.defaults['ord']
        return TensorList(objective.get_grads()).global_metric(ord) <= tol

TerminateByUpdateNorm

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Terminates when the norm of the update is at most tol. The update is calculated as the difference between the current and previous parameters.

Source code in torchzero/modules/termination/termination.py
class TerminateByUpdateNorm(TerminationCriteriaBase):
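    """Terminates when the norm of the update is at most ``tol``. The update is calculated as the difference between the current and previous parameters."""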
    def __init__(self, tol:float = 1e-8, n: int = 3, ord: Metrics = 2):
        defaults = dict(tol=tol, ord=ord)
        super().__init__(defaults, n=n)

    def termination_criteria(self, objective):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        tol = self.defaults['tol']
        ord = self.defaults['ord']

        p_prev = self.get_state(objective.params, 'p_prev', cls=TensorList)
        if step == 0:
            p_prev.copy_(objective.params)
            return False

        should_terminate = (p_prev - objective.params).global_metric(ord) <= tol
        p_prev.copy_(objective.params)
        return should_terminate

TerminateNever

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateNever(TerminationCriteriaBase):
    def __init__(self):
        super().__init__()

    def termination_criteria(self, objective): return False

TerminateOnLossReached

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateOnLossReached(TerminationCriteriaBase):
    def __init__(self, value: float):
        defaults = dict(value=value)
        super().__init__(defaults)

    def termination_criteria(self, objective):
        value = self.defaults['value']
        return objective.get_loss(False) <= value

TerminateOnNoImprovement

Bases: torchzero.modules.termination.termination.TerminationCriteriaBase

Source code in torchzero/modules/termination/termination.py
class TerminateOnNoImprovement(TerminationCriteriaBase):
    def __init__(self, tol:float = 1e-8, n: int = 10):
        defaults = dict(tol=tol)
        super().__init__(defaults, n=n)

    def termination_criteria(self, objective):
        tol = self.defaults['tol']

        f = tofloat(objective.get_loss(False))
        if 'f_min' not in self.global_state:
            self.global_state['f_min'] = f
            return False

        f_min = self.global_state['f_min']
        d = f_min - f
        should_terminate = d <= tol
        self.global_state['f_min'] = min(f, f_min)
        return should_terminate
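The rule in `termination_criteria` compares each loss with the best loss seen so far, not with the previous loss. A plain-Python mirror of that logic follows; the class name `NoImprovement` is hypothetical. In the module, a `True` result only stops optimization after `n` consecutive hits (10 by default). The base class handles that counting.

```python
from typing import Optional

class NoImprovement:
    """Plain-Python mirror of TerminateOnNoImprovement.termination_criteria."""

    def __init__(self, tol: float = 1e-8):
        self.tol = tol
        self.f_min: Optional[float] = None

    def __call__(self, f: float) -> bool:
        if self.f_min is None:            # first call only records the loss
            self.f_min = f
            return False
        improvement = self.f_min - f      # positive if f beats the best loss so far
        self.f_min = min(f, self.f_min)
        return improvement <= self.tol


check = NoImprovement()
print([check(f) for f in [5.0, 4.0, 4.0, 3.5, 6.0]])  # [False, False, True, False, True]
```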

TerminationCriteriaBase

Bases: torchzero.core.module.Module

Source code in torchzero/modules/termination/termination.py
class TerminationCriteriaBase(Module):
    def __init__(self, defaults:dict | None = None, n: int = 1):
        if defaults is None: defaults = {}
        safe_dict_update_(defaults, {"_n": n})
        super().__init__(defaults)

    @abstractmethod
    def termination_criteria(self, objective: Objective) -> bool:
        ...

    @final
    def should_terminate(self, objective: Objective) -> bool:
        n_bad = self.global_state.get('_n_bad', 0)
        n = self.defaults['_n']

        if self.termination_criteria(objective):
            n_bad += 1
            if n_bad >= n:
                self.global_state['_n_bad'] = 0
                return True

        else:
            n_bad = 0

        self.global_state['_n_bad'] = n_bad
        return False


    def update(self, objective):
        objective.should_terminate = self.should_terminate(objective)
        if objective.should_terminate: self.global_state['_n_bad'] = 0

    def apply(self, objective):
        return objective
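`should_terminate` only fires after the criterion has held on `n` consecutive checks. Any check that fails the criterion resets the streak. The same counter is shown below in plain Python. The `Patience` class is a hypothetical illustration, not a torchzero class.

```python
from typing import Callable

class Patience:
    """Plain-Python mirror of TerminationCriteriaBase.should_terminate."""

    def __init__(self, criterion: Callable[..., bool], n: int = 1):
        self.criterion = criterion
        self.n = n
        self.n_bad = 0

    def should_terminate(self, *args) -> bool:
        if self.criterion(*args):
            self.n_bad += 1
            if self.n_bad >= self.n:
                self.n_bad = 0            # counter resets once termination fires
                return True
        else:
            self.n_bad = 0                # a failed check resets the streak
        return False


p = Patience(lambda hit: hit, n=3)
print([p.should_terminate(h) for h in [True, True, False, True, True, True]])
# [False, False, False, False, False, True]
```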

ThomasOptimalMethod

Bases: torchzero.modules.quasi_newton.quasi_newton._InverseHessianUpdateStrategyDefaults

Thomas's "optimal" Quasi-Newton method.

Note

a line search is recommended.

Warning

this uses at least O(N^2) memory.

Reference

Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.

Source code in torchzero/modules/quasi_newton/quasi_newton.py
class ThomasOptimalMethod(_InverseHessianUpdateStrategyDefaults):
    """
    Thomas's "optimal" Quasi-Newton method.

    Note:
        a line search is recommended.

    Warning:
        this uses at least O(N^2) memory.

    Reference:
        Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975.
    """
    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
        if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
        H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y)
        return H

    def reset_P(self, P, s, y, inverse, init_scale, state):
        super().reset_P(P, s, y, inverse, init_scale, state)
        for st in self.state.values():
            st.pop("R", None)

Threshold

Bases: torchzero.modules.ops.binary.BinaryOperationBase

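Outputs tensors thresholded against threshold. If update_above is True, values above threshold are kept and all other values are set to value; otherwise, values below threshold are kept and all other values are set to value.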

Source code in torchzero/modules/ops/binary.py
class Threshold(BinaryOperationBase):
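    """Outputs tensors thresholded against ``threshold``. If ``update_above`` is True, values above ``threshold`` are kept and all other values are set to ``value``; otherwise, values below ``threshold`` are kept and all other values are set to ``value``."""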
    def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
        defaults = dict(update_above=update_above)
        super().__init__(defaults, threshold=threshold, value=value)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
        update_above = self.defaults['update_above']
        update = TensorList(update)
        if update_above:
            if isinstance(value, list): return update.where(update>threshold, value)
            return update.masked_fill_(update<=threshold, value)

        if isinstance(value, list): return update.where(update<threshold, value)
        return update.masked_fill_(update>=threshold, value)
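For a scalar `threshold` and `value`, `transform` reduces to the element-wise rule below. The plain-list function `apply_threshold` is a hypothetical illustration. The module itself works on lists of tensors and also accepts modules for `threshold` and `value`.

```python
from typing import List

def apply_threshold(update: List[float], threshold: float, value: float,
                    update_above: bool) -> List[float]:
    if update_above:
        # keep entries strictly above the threshold, replace the rest
        return [u if u > threshold else value for u in update]
    # keep entries strictly below the threshold, replace the rest
    return [u if u < threshold else value for u in update]


print(apply_threshold([-2.0, 0.5, 3.0], threshold=1.0, value=0.0, update_above=True))   # [0.0, 0.0, 3.0]
print(apply_threshold([-2.0, 0.5, 3.0], threshold=1.0, value=0.0, update_above=False))  # [-2.0, 0.5, 0.0]
```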

To

Bases: torchzero.modules.projections.projection.ProjectionBase

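Casts parameters, gradients and updates to the specified device and dtype for the wrapped modules, and casts the results back to their original device and dtype.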

Source code in torchzero/modules/projections/cast.py
class To(ProjectionBase):
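    """Casts parameters, gradients and updates to the specified device and dtype for the wrapped ``modules``, and casts the results back to their original device and dtype."""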
    def __init__(self, modules: Chainable, dtype: torch.dtype | None, device:torch.types.Device | None = None):
        defaults = dict(dtype=dtype, device=device)
        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=defaults)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        casted = []
        for tensor, state, setting in zip(tensors,states, settings):
            state['dtype'] = tensor.dtype
            state['device'] = tensor.device
            tensor = tensor.to(dtype=setting['dtype'], device=setting['device'])
            casted.append(tensor)
        return casted

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        uncasted = []
        for tensor, state in zip(projected_tensors, states):
            tensor = tensor.to(dtype=state['dtype'], device=state['device'])
            uncasted.append(tensor)
        return uncasted

TrustCG

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Trust region via Steihaug-Toint Conjugate Gradient method.

Note

If you wish to use the exact hessian, use the matrix-free tz.m.NewtonCGSteihaug, which only uses hessian-vector products. While passing tz.m.Newton to this module is possible, it is usually less efficient.

Parameters:

  • hess_module (Chainable) –

    A module that maintains a hessian approximation (not the hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • eta (float, default: 0.0 ) –

    if the ratio of actual to predicted reduction is larger than this, the step is accepted. When hess_module is GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.0001 ) –

    if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • reg (float, default: 0 ) –

    regularization parameter for conjugate gradient. Defaults to 0.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • boundary_tol (float | None, default: 1e-06 ) –

    The trust region only increases when the suggested step's norm is at least (1-boundary_tol)*trust_region. This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.

  • prefer_exact (bool, default: True ) –

    when an exact solution can be easily calculated without CG (e.g. the hessian is stored as a scaled identity), uses the exact solution. If False, always uses CG. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to the output of this module. Defaults to None.

Examples:

Trust-SR1

opt = tz.Optimizer(
    model.parameters(),
    tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
)
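The accept/reject and radius-update rule described by the parameters above can be sketched as a pure function using TrustCG's defaults. This illustrates the documented behaviour only; it is not the code of the module's `radius_strategy`. In particular, treating `boundary_tol=None` as "no boundary check" is an assumption, and `update_radius` is a hypothetical name.

```python
from typing import Optional, Tuple

def update_radius(radius: float, rho: float, step_norm: float,
                  eta: float = 0.0, nplus: float = 3.5, nminus: float = 0.25,
                  rho_good: float = 0.99, rho_bad: float = 1e-4,
                  boundary_tol: Optional[float] = 1e-6) -> Tuple[float, bool]:
    """Return (new_radius, accepted), where rho = actual / predicted reduction."""
    accepted = rho > eta
    if rho < rho_bad:
        radius *= nminus                      # poor model agreement: shrink
    elif rho > rho_good:
        # grow only if the step reached the boundary (assumption: None disables the check)
        if boundary_tol is None or step_norm >= (1 - boundary_tol) * radius:
            radius *= nplus
    return radius, accepted


print(update_radius(1.0, rho=1.0, step_norm=1.0))    # (3.5, True)
print(update_radius(1.0, rho=1.0, step_norm=0.5))    # (1.0, True): interior step, no growth
print(update_radius(1.0, rho=-1.0, step_norm=1.0))   # (0.25, False)
```

According to the max_attempts description, a rejected step is retried with the reduced radius. If the limit is exceeded, a zero update is returned.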
Source code in torchzero/modules/trust_region/trust_cg.py
class TrustCG(TrustRegionBase):
    """Trust region via Steihaug-Toint Conjugate Gradient method.

    .. note::

        If you wish to use the exact hessian, use the matrix-free :code:`tz.m.NewtonCGSteihaug`
        which only uses hessian-vector products. While passing ``tz.m.Newton`` to this
        module is possible, it is usually less efficient.

    Args:
        hess_module (Chainable):
            A module that maintains a hessian approximation (not the hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by `nplus`. Defaults to 0.99.
        rho_bad (float, optional):
            if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by `nminus`. Defaults to 1e-4.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        reg (float, optional): regularization parameter for conjugate gradient. Defaults to 0.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        boundary_tol (float | None, optional):
            The trust region only increases when the suggested step's norm is at least `(1-boundary_tol)*trust_region`.
            This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.
        prefer_exact (bool, optional):
            when an exact solution can be easily calculated without CG (e.g. the hessian is stored as a scaled identity),
            uses the exact solution. If False, always uses CG. Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to the output of this module. Defaults to None.

    Examples:
        Trust-SR1

        .. code-block:: python

            opt = tz.Optimizer(
                model.parameters(),
                tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
            )
    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        boundary_tol: float | None = 1e-6, # tuned
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        reg: float = 0,
        maxiter: int | None = None,
        miniter: int = 1,
        cg_tol: float = 1e-8,
        prefer_exact: bool = True,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol, maxiter=maxiter, miniter=miniter)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        if settings['prefer_exact'] and isinstance(H, linear_operator.ScaledIdentity):
            return H.solve_bounded(g, radius)

        x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], maxiter=settings["maxiter"], miniter=settings["miniter"], tol=settings["cg_tol"])
        return x
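`trust_solve` delegates to torchzero's `cg` helper with `trust_radius=radius`. The following standalone sketch shows the textbook Steihaug-Toint iteration this corresponds to, as presented for example in Nocedal and Wright, Numerical Optimization. It runs CG on Hx = g and stops at the trust region boundary when an iterate would leave the region or when negative curvature is detected. Like `trust_solve`, it solves Hx = g rather than Hx = -g. The function names and the `maxiter` default of n are assumptions. The sketch also omits `reg`, `miniter` and the `ScaledIdentity` shortcut.

```python
import math
from typing import Callable, List, Optional

Vec = List[float]

def _dot(u: Vec, v: Vec) -> float:
    return sum(a * b for a, b in zip(u, v))

def _to_boundary(x: Vec, p: Vec, radius: float) -> Vec:
    # positive root tau of ||x + tau p||^2 = radius^2; c <= 0 because x is inside
    a, b, c = _dot(p, p), 2 * _dot(x, p), _dot(x, x) - radius ** 2
    tau = (-b + math.sqrt(b * b - 4 * a * c)) / (2 * a)
    return [xi + tau * pi for xi, pi in zip(x, p)]

def steihaug_cg(matvec: Callable[[Vec], Vec], g: Vec, radius: float,
                tol: float = 1e-8, maxiter: Optional[int] = None) -> Vec:
    """Approximately solve Hx = g subject to ||x|| <= radius."""
    x = [0.0] * len(g)
    r = list(g)                                    # residual g - Hx at x = 0
    p = list(r)
    rr = _dot(r, r)
    if math.sqrt(rr) <= tol:
        return x
    for _ in range(maxiter if maxiter is not None else len(g)):
        Hp = matvec(p)
        pHp = _dot(p, Hp)
        if pHp <= 0:                               # negative curvature: go to the boundary
            return _to_boundary(x, p, radius)
        alpha = rr / pHp
        x_next = [xi + alpha * pi for xi, pi in zip(x, p)]
        if math.sqrt(_dot(x_next, x_next)) >= radius:  # iterate would leave the region
            return _to_boundary(x, p, radius)
        x = x_next
        r = [ri - alpha * hi for ri, hi in zip(r, Hp)]
        rr_next = _dot(r, r)
        if math.sqrt(rr_next) <= tol:              # converged inside the region
            break
        p = [ri + (rr_next / rr) * pi for ri, pi in zip(r, p)]
        rr = rr_next
    return x


def scaled_identity(v: Vec) -> Vec:                # H = 2 I
    return [2 * vi for vi in v]

print(steihaug_cg(scaled_identity, [2.0, 0.0], radius=10.0))  # [1.0, 0.0], interior Newton step
print(steihaug_cg(scaled_identity, [2.0, 0.0], radius=0.5))   # [0.5, 0.0], clipped to the boundary
```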

TrustRegionBase

Bases: torchzero.core.module.Module, abc.ABC

Methods:

  • trust_region_apply

    Solves the trust region subproblem and outputs Objective with the solution direction.

  • trust_region_update

    updates the state of this module after H or B have been updated, if necessary

  • trust_solve

    Solve Hx=g with a trust region penalty/bound defined by radius

Source code in torchzero/modules/trust_region/trust_region.py
class TrustRegionBase(Module, ABC):
    def __init__(
        self,
        defaults: dict | None,
        hess_module: Chainable,
        # suggested default values:
        # Gould, Nicholas IM, et al. "Sensitivity of trust-region algorithms to their parameters." 4OR 3.3 (2005): 227-241.
        # which I found from https://github.com/patrick-kidger/optimistix/blob/c1dad7e75fc35bd5a4977ac3a872991e51e83d2c/optimistix/_solver/trust_region.py#L113-200
        eta: float, # 0.0
        nplus: float, # 3.5
        nminus: float, # 0.25
        rho_good: float, # 0.99
        rho_bad: float, # 1e-4
        boundary_tol: float | None, # None or 1e-1
        init: float, # 1
        max_attempts: int, # 10
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS, # "default"
        radius_fn: Callable | None, # torch.linalg.vector_norm
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        if isinstance(radius_strategy, str): radius_strategy = _RADIUS_STRATEGIES[radius_strategy]
        if defaults is None: defaults = {}

        safe_dict_update_(
            defaults,
            dict(eta=eta, nplus=nplus, nminus=nminus, rho_good=rho_good, rho_bad=rho_bad, init=init,
                 update_freq=update_freq, max_attempts=max_attempts, radius_strategy=radius_strategy,
                 boundary_tol=boundary_tol)
        )

        super().__init__(defaults)

        self._radius_fn = radius_fn
        self.set_child('hess_module', hess_module)

        if inner is not None:
            self.set_child('inner', inner)

    @abstractmethod
    def trust_solve(
        self,
        f: float,
        g: torch.Tensor,
        H: LinearOperator,
        radius: float,
        params: list[torch.Tensor],
        closure: Callable,
        settings: Mapping[str, Any],
    ) -> torch.Tensor:
        """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
        ... # pylint:disable=unnecessary-ellipsis

    def trust_region_update(self, objective: Objective, H: LinearOperator | None) -> None:
        """updates the state of this module after H or B have been updated, if necessary"""

    def trust_region_apply(self, objective: Objective, tensors:list[torch.Tensor], H: LinearOperator | None) -> Objective:
        """Solves the trust region subproblem and outputs ``Objective`` with the solution direction."""
        assert H is not None

        params = TensorList(objective.params)
        settings = self.settings[params[0]]
        g = _flatten_tensors(tensors)

        max_attempts = settings['max_attempts']

        # loss at x_0
        loss = objective.loss
        closure = objective.closure
        if closure is None: raise RuntimeError("Trust region requires closure")
        if loss is None: loss = objective.get_loss(False)
        loss = tofloat(loss)

        # trust region step and update
        success = False
        d = None
        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            trust_radius = self.global_state.get('trust_radius', settings['init'])

            # solve Hx=g
            d = self.trust_solve(f=loss, g=g, H=H, radius=trust_radius, params=params, closure=closure, settings=settings)

            # update trust radius
            radius_strategy: _RadiusStrategy = settings['radius_strategy']
            self.global_state["trust_radius"], success = radius_strategy(
                params=params,
                closure=closure,
                d=d,
                f=loss,
                g=g,
                H=H,
                trust_radius=trust_radius,

                eta=settings["eta"],
                nplus=settings["nplus"],
                nminus=settings["nminus"],
                rho_good=settings["rho_good"],
                rho_bad=settings["rho_bad"],
                boundary_tol=settings["boundary_tol"],
                init=settings["init"],

                state=self.global_state,
                settings=settings,
                radius_fn=self._radius_fn,
            )

        assert d is not None
        if success: objective.updates = vec_to_tensors(d, params)
        else: objective.updates = params.zeros_like()

        return objective


    @final
    @torch.no_grad
    def update(self, objective):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        if step % self.defaults["update_freq"] == 0:

            hessian_module = self.children['hess_module']
            hessian_module.update(objective)
            H = hessian_module.get_H(objective)
            self.global_state["H"] = H

            self.trust_region_update(objective, H=H)


    @final
    @torch.no_grad
    def apply(self, objective):
        H = self.global_state.get('H', None)

        # -------------------------------- inner step -------------------------------- #
        objective = self.inner_step("inner", objective, must_exist=False)

        # ----------------------------------- apply ---------------------------------- #
        return self.trust_region_apply(objective=objective, tensors=objective.get_updates(), H=H)
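`trust_region_apply` delegates step acceptance and the radius update to a radius strategy that receives `eta`, `nplus`, `nminus`, `rho_good`, `rho_bad` and `boundary_tol`. The strategies in `_RADIUS_STRATEGIES` are not shown here. The following is therefore only a sketch of the textbook ratio test those parameter names suggest, using the default values from the comments in `__init__`. It compares the actual reduction with the reduction predicted by the quadratic model, shrinks the radius on a poor ratio, and expands it on a very good ratio when the step reached the boundary. The sketch also assumes the torchzero convention that the update `d` is subtracted from the parameters. `reduction_ratio` and `update_radius` are hypothetical names.

```python
import torch


def reduction_ratio(f, x: torch.Tensor, d: torch.Tensor, g: torch.Tensor, H: torch.Tensor) -> float:
    """rho = actual / predicted reduction for the step x -> x - d (assumed sign convention)."""
    actual = f(x) - f(x - d)
    # model: m(x - d) = f(x) - g.d + 0.5 d.H.d, so the predicted reduction is g.d - 0.5 d.H.d
    predicted = g @ d - 0.5 * (d @ (H @ d))
    return (actual / predicted).item()


def update_radius(rho, step_norm, radius, eta=0.0, nplus=3.5, nminus=0.25,
                  rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-6):
    """Return (new_radius, accepted) via a classic ratio test (a sketch, not torchzero's code)."""
    accepted = rho > eta
    if rho < rho_bad:
        # the model predicted the change poorly: shrink the trust region
        radius = radius * nminus
    elif rho > rho_good:
        # expand only if the step was (nearly) on the boundary; None disables this check
        on_boundary = boundary_tol is None or step_norm >= radius * (1 - boundary_tol)
        if on_boundary:
            radius = radius * nplus
    return radius, accepted
```

On an exact quadratic, the full Newton step gives a ratio of 1. Whether the radius then grows depends on whether that step touched the boundary.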

trust_region_apply

trust_region_apply(objective: Objective, tensors: list[Tensor], H: LinearOperator | None) -> Objective

Solves the trust region subproblem and outputs Objective with the solution direction.

trust_region_update

trust_region_update(objective: Objective, H: LinearOperator | None) -> None

Updates the state of this module after H (or B) has been updated, if necessary.

trust_solve

trust_solve(f: float, g: Tensor, H: LinearOperator, radius: float, params: list[Tensor], closure: Callable, settings: Mapping[str, Any]) -> Tensor

Solve Hx=g with a trust region penalty/bound defined by radius

TwoPointNewton

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Two-point Newton method with a frozen derivative and third-order convergence.

Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73.

Source code in torchzero/modules/second_order/multipoint.py
class TwoPointNewton(HigherOrderMethodBase):
    """two-point Newton method with frozen derivative with third order convergence.

    Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        def f(x): return evaluate(x, 1)[1]
        def f_j(x): return evaluate(x, 2)[1:]
        x_star = two_point_newton(x, f, f_j, setting['lstsq'])
        return x - x_star
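The `two_point_newton` helper is not shown in this section. As an illustration only, the sketch below implements the classic two-step Newton iteration with a frozen derivative. It evaluates the Hessian once per iteration and reuses it for a second correction, a scheme known to converge with third order. It may differ in detail from the Sharma and Kumar method cited above. The function name and the explicit `grad_fn`/`hess_fn` arguments are hypothetical. Note that the module returns `x - x_star` as its update, because torchzero subtracts updates from the parameters.

```python
import torch


def frozen_two_step_newton(grad_fn, hess_fn, x: torch.Tensor) -> torch.Tensor:
    """One iteration: a Newton step, then a second step that reuses the same Hessian."""
    H = hess_fn(x)  # evaluated once per iteration ("frozen")
    y = x - torch.linalg.solve(H, grad_fn(x))  # ordinary Newton step
    return y - torch.linalg.solve(H, grad_fn(y))  # extra correction with the frozen H
```

On f(x) = exp(x) - 2x, whose minimizer is ln 2, one such iteration from x = 1 lands closer to the minimizer than a single plain Newton step. The extra correction costs one gradient evaluation but no new Hessian.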

UnaryLambda

Bases: torchzero.core.transform.TensorTransform

Applies fn to input tensors.

fn must accept and return a list of tensors.

Source code in torchzero/modules/ops/unary.py
class UnaryLambda(TensorTransform):
    """Applies ``fn`` to input tensors.

    ``fn`` must accept and return a list of tensors.
    """
    def __init__(self, fn):
        defaults = dict(fn=fn)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        return settings[0]['fn'](tensors)

UnaryParameterwiseLambda

Bases: torchzero.core.transform.TensorTransform

Applies fn to each input tensor.

fn must accept and return a tensor.

Source code in torchzero/modules/ops/unary.py
class UnaryParameterwiseLambda(TensorTransform):
    """Applies ``fn`` to each input tensor.

    ``fn`` must accept and return a tensor.
    """
    def __init__(self, fn):
        defaults = dict(fn=fn)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        return setting['fn'](tensor)
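The two lambda modules differ only in what `fn` receives. `UnaryLambda` calls it once with the whole list of update tensors, while `UnaryParameterwiseLambda` calls it separately for each tensor. A function that needs a global quantity, such as the norm over all parameters, therefore only fits the list-based form. The two functions below are hypothetical examples of each contract, to be passed as `UnaryLambda(normalize_global)` and `UnaryParameterwiseLambda(clamp_per_tensor)`.

```python
import torch


def normalize_global(tensors: list[torch.Tensor]) -> list[torch.Tensor]:
    """List in, list out: rescale all tensors by their joint L2 norm (UnaryLambda contract)."""
    norm = torch.sqrt(sum((t * t).sum() for t in tensors))
    return [t / norm.clamp(min=1e-12) for t in tensors]


def clamp_per_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Tensor in, tensor out: clamp each element to [-1, 1] (UnaryParameterwiseLambda contract)."""
    return tensor.clamp(-1.0, 1.0)
```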

Uniform

Bases: torchzero.core.module.Module

Outputs tensors filled with random numbers from uniform distribution between low and high.

Source code in torchzero/modules/ops/utility.py
class Uniform(Module):
    """Outputs tensors filled with random numbers from uniform distribution between ``low`` and ``high``."""
    def __init__(self, low: float, high: float):
        defaults = dict(low=low, high=high)
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective):
        low,high = self.get_settings(objective.params, 'low','high')
        objective.updates = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(objective.params, low, high)]
        return objective

UpdateGradientSignConsistency

Bases: torchzero.core.transform.TensorTransform

Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

Parameters:

  • normalize (bool, default: False ) –

    If True, divides the mask by its global mean so that it averages to 1. Defaults to False.

  • eps (float, default: 1e-06 ) –

    epsilon for normalization. Defaults to 1e-6.

Source code in torchzero/modules/momentum/cautious.py
class UpdateGradientSignConsistency(TensorTransform):
    """Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.

    Args:
        normalize (bool, optional):
            renormalize update after masking. Defaults to False.
        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
    """
    def __init__(self, normalize = False, eps=1e-6):

        defaults = dict(normalize=normalize, eps=eps)
        super().__init__(defaults, uses_grad=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        normalize, eps = itemgetter('normalize', 'eps')(settings[0])

        mask = (TensorList(tensors).mul_(grads)).gt_(0)
        if normalize: mask = mask / mask.global_mean().clip(min = eps) # pyright: ignore[reportOperatorIssue]

        return mask
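The mask logic can be reproduced on a single flat tensor to see what `normalize=True` does. The elementwise product of update and gradient is positive exactly where their signs agree. Dividing by the mask's mean then rescales it so that it averages to 1. The module takes that mean globally over all tensors, while this standalone sketch (with a hypothetical function name) uses one tensor.

```python
import torch


def sign_consistency_mask(update: torch.Tensor, grad: torch.Tensor,
                          normalize: bool = False, eps: float = 1e-6) -> torch.Tensor:
    # 1.0 where update and gradient have the same (nonzero) sign, 0.0 otherwise
    mask = (update * grad).gt(0).to(update.dtype)
    if normalize:
        # divide by the mean so the mask averages to 1 (eps guards against an all-zero mask)
        mask = mask / mask.mean().clamp(min=eps)
    return mask
```

For update `[1, -2, 3, -4]` and gradient `[1, 2, -3, -4]`, the mask is `[1, 0, 0, 1]`. Normalized, it becomes `[2, 0, 0, 2]`.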

UpdateSign

Bases: torchzero.core.transform.TensorTransform

Outputs gradient with sign copied from the update.

Source code in torchzero/modules/misc/misc.py
class UpdateSign(TensorTransform):
    """Outputs gradient with sign copied from the update."""
    def __init__(self):
        super().__init__(uses_grad=True)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
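`Tensor.copysign` keeps the magnitude of the tensor it is called on and takes the sign from its argument. `UpdateSign` therefore returns the gradient's magnitudes pointed in the update's directions, for example:

```python
import torch

grad = torch.tensor([1.0, -2.0, 3.0])
update = torch.tensor([-0.1, -5.0, 0.2])
# magnitudes come from grad, signs come from update
out = grad.copysign(update)
```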

UpdateToNone

Bases: torchzero.core.module.Module

Sets the updates attribute of the objective to None.

Source code in torchzero/modules/ops/utility.py
class UpdateToNone(Module):
    """Sets the ``updates`` attribute of ``objective`` to None."""
    def __init__(self): super().__init__()
    def apply(self, objective):
        objective.updates = None
        return objective

VectorProjection

Bases: torchzero.modules.projections.projection.ProjectionBase

Projection that concatenates all parameters into a single vector.

Source code in torchzero/modules/projections/projection.py
class VectorProjection(ProjectionBase):
    """projection that concatenates all parameters into a vector"""
    def __init__(
        self,
        modules: Chainable,
        project_update=True,
        project_params=True,
        project_grad=True,
    ):
        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        return [torch.cat([t.ravel() for t in tensors])]

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        return vec_to_tensors(vec=projected_tensors[0], reference=params)
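`project` flattens every tensor and concatenates the results into one vector. `unproject` splits that vector back into tensors shaped like the parameters. The `vec_to_tensors` helper is not shown here, so the round trip below is a plain-PyTorch sketch of the same idea with hypothetical function names. For parameters specifically, PyTorch also provides `torch.nn.utils.parameters_to_vector` and `torch.nn.utils.vector_to_parameters`.

```python
import torch


def flatten(tensors: list[torch.Tensor]) -> torch.Tensor:
    # same as VectorProjection.project: one long vector
    return torch.cat([t.ravel() for t in tensors])


def unflatten(vec: torch.Tensor, reference: list[torch.Tensor]) -> list[torch.Tensor]:
    # split into consecutive chunks, each reshaped like its reference tensor
    chunks = torch.split(vec, [r.numel() for r in reference])
    return [c.view_as(r) for c, r in zip(chunks, reference)]
```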

ViewAsReal

Bases: torchzero.modules.projections.projection.ProjectionBase

View complex tensors as real tensors. Doesn't affect tensors that are already real.

Source code in torchzero/modules/projections/cast.py
class ViewAsReal(ProjectionBase):
    """View complex tensors as real tensors. Doesn't affect tensors that are already real."""
    def __init__(self, modules: Chainable):
        super().__init__(modules, project_update=True, project_params=True, project_grad=True, defaults=None)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        views = []
        for tensor, state in zip(tensors,states):
            is_complex = torch.is_complex(tensor)
            state['is_complex'] = is_complex
            if is_complex: tensor = torch.view_as_real(tensor)
            views.append(tensor)
        return views

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        un_views = []
        for tensor, state in zip(projected_tensors, states):
            if state['is_complex']: tensor = torch.view_as_complex(tensor)
            un_views.append(tensor)
        return un_views
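`torch.view_as_real` exposes a complex tensor as a real view with an extra trailing dimension of size 2, holding the real and imaginary parts. `torch.view_as_complex` reverses it. The per-tensor `is_complex` flag stored in `project` lets `unproject` convert back only the tensors that were complex to begin with. A quick round trip:

```python
import torch

z = torch.tensor([1 + 2j, 3 - 4j])  # complex64
r = torch.view_as_real(z)  # float32 view of shape (2, 2): [[1, 2], [3, -4]]
z_back = torch.view_as_complex(r)  # back to the original complex values
```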

Warmup

Bases: torchzero.core.transform.TensorTransform

Learning rate warmup: linearly increases the learning rate multiplier from start_lr to end_lr over the number of steps given by steps.

Parameters:

  • steps (int, default: 100 ) –

    number of steps to perform warmup for. Defaults to 100.

  • start_lr (float, default: 1e-05 ) –

    initial learning rate multiplier on first step. Defaults to 1e-5.

  • end_lr (float, default: 1 ) –

    learning rate multiplier at the end and after warmup. Defaults to 1.

Example

Adam with 1000 steps warmup

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-2),
    tz.m.Warmup(steps=1000)
)
Source code in torchzero/modules/step_size/lr.py
class Warmup(TensorTransform):
    """Learning rate warmup, linearly increases learning rate multiplier from ``start_lr`` to ``end_lr`` over ``steps`` steps.

    Args:
        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
        start_lr (float, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
        end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.

    Example:
        Adam with 1000 steps warmup

        .. code-block:: python

            opt = tz.Optimizer(
                model.parameters(),
                tz.m.Adam(),
                tz.m.LR(1e-2),
                tz.m.Warmup(steps=1000)
            )

    """
    def __init__(self, steps = 100, start_lr = 1e-5, end_lr:float = 1):
        defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
        num_steps = settings[0]['steps']
        step = self.global_state.get('step', 0)

        tensors = lazy_lr(
            TensorList(tensors),
            lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
            inplace=True
        )
        self.global_state['step'] = step + 1
        return tensors
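The `_warmup_lr` helper is not shown in this section. The sketch below assumes it implements the linear interpolation described above: multiplier `start_lr` at step 0, reaching `end_lr` after `steps` steps and staying there afterwards. `warmup_multiplier` is a hypothetical name.

```python
def warmup_multiplier(step: int, start_lr: float = 1e-5, end_lr: float = 1.0, steps: int = 100) -> float:
    # fraction of warmup completed, capped at 1 once warmup is over
    frac = min(step / steps, 1.0)
    return start_lr + (end_lr - start_lr) * frac
```

With the defaults, the multiplier starts at 1e-5, is close to 0.5 halfway through, and stays at 1 from step 100 onwards.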

WarmupNormClip

Bases: torchzero.core.transform.TensorTransform

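Warmup via clipping of the update norm: the maximal allowed norm increases linearly from start_norm to end_norm over the number of steps given by steps, after which clipping is disabled.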

Parameters:

  • start_norm (float, default: 1e-05 ) –

    maximal norm on the first step. Defaults to 1e-5.

  • end_norm (float, default: 1 ) –

    maximal norm on the last step. After that, norm clipping is disabled. Defaults to 1.

  • steps (int, default: 100 ) –

    number of steps to perform warmup for. Defaults to 100.

Example

Adam with 1000 steps norm clip warmup

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.WarmupNormClip(steps=1000),
    tz.m.LR(1e-2),
)
Source code in torchzero/modules/step_size/lr.py
class WarmupNormClip(TensorTransform):
    """Warmup via clipping of the update norm.

    Args:
        start_norm (float, optional): maximal norm on the first step. Defaults to 1e-5.
        end_norm (float, optional): maximal norm on the last step. After that, norm clipping is disabled. Defaults to 1.
        steps (int, optional): number of steps to perform warmup for. Defaults to 100.

    Example:
        Adam with 1000 steps norm clip warmup

        .. code-block:: python

            opt = tz.Optimizer(
                model.parameters(),
                tz.m.Adam(),
                tz.m.WarmupNormClip(steps=1000),
                tz.m.LR(1e-2),
            )
    """
    def __init__(self, steps = 100, start_norm = 1e-5, end_norm:float = 1):
        defaults = dict(start_norm=start_norm,end_norm=end_norm, steps=steps)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
        num_steps = settings[0]['steps']
        step = self.global_state.get('step', 0)
        if step > num_steps: return tensors

        tensors = TensorList(tensors)
        norm = tensors.global_vector_norm()
        current_max_norm = _warmup_lr(step, start_norm[0], end_norm[0], num_steps)
        if norm > current_max_norm:
            tensors.mul_(current_max_norm / norm)

        self.global_state['step'] = step + 1
        return tensors
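On each warmup step, the whole update is rescaled only when its global L2 norm exceeds the current cap, so small updates pass through unchanged. Below is a standalone sketch of a single clipping step with the cap supplied directly. The function name is hypothetical, and the global norm is assumed to be the same quantity the module obtains from `global_vector_norm`.

```python
import torch


def clip_to_max_norm(tensors: list[torch.Tensor], max_norm: float) -> list[torch.Tensor]:
    # global L2 norm over all tensors taken together
    norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in tensors]))
    if norm > max_norm:
        # scale every tensor by the same factor so the global norm equals max_norm
        return [t * (max_norm / norm) for t in tensors]
    return tensors
```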

WeightDecay

Bases: torchzero.core.transform.TensorTransform

Weight decay.

Parameters:

  • weight_decay (float) –

    weight decay scale.

  • ord (int, default: 2 ) –

    order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

Examples:

Adam with non-decoupled weight decay

opt = tz.Optimizer(
    model.parameters(),
    tz.m.WeightDecay(1e-3),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)

Adam with decoupled weight decay that still scales with learning rate

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.WeightDecay(1e-3),
    tz.m.LR(1e-3)
)

Adam with fully decoupled weight decay that doesn't scale with learning rate

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
    tz.m.WeightDecay(1e-6)
)

Source code in torchzero/modules/weight_decay/weight_decay.py
class WeightDecay(TensorTransform):
    """Weight decay.

    Args:
        weight_decay (float): weight decay scale.
        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.

    ### Examples:

    Adam with non-decoupled weight decay
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.WeightDecay(1e-3),
        tz.m.Adam(),
        tz.m.LR(1e-3)
    )
    ```

    Adam with decoupled weight decay that still scales with learning rate
    ```python

    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.WeightDecay(1e-3),
        tz.m.LR(1e-3)
    )
    ```

    Adam with fully decoupled weight decay that doesn't scale with learning rate
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-3),
        tz.m.WeightDecay(1e-6)
    )
    ```

    """
    def __init__(self, weight_decay: float, ord: int = 2):

        defaults = dict(weight_decay=weight_decay, ord=ord)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)
        ord = settings[0]['ord']

        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)
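`weight_decay_` is not shown here. Consider the usual penalty `(weight_decay / ord) * sum(|p| ** ord)`. The term added to the update is its gradient, `weight_decay * sign(p) * |p| ** (ord - 1)`, which reduces to `weight_decay * p` for L2 and `weight_decay * sign(p)` for L1. The sketch assumes that form (the function name is hypothetical), and torchzero's exact scaling may differ.

```python
import torch


def add_weight_decay(update: torch.Tensor, param: torch.Tensor, weight_decay: float, ord: int = 2) -> torch.Tensor:
    # gradient of (weight_decay / ord) * sum(|p| ** ord), added to the incoming update
    if ord == 1:
        penalty_grad = param.sign()
    elif ord == 2:
        penalty_grad = param
    else:
        penalty_grad = param.sign() * param.abs().pow(ord - 1)
    return update + weight_decay * penalty_grad
```

Where this term enters the chain is what distinguishes the three examples above. Before `Adam`, it is rescaled by Adam's preconditioning. After `Adam` but before `LR`, it is only scaled by the learning rate. After `LR`, it is applied as is.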

WeightDropout

Bases: torchzero.core.module.Module

Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.

Dropout can be disabled for a parameter by setting use_dropout=False in corresponding parameter group.

Parameters:

  • p (float, default: 0.5 ) –

    probability that each weight is replaced with 0. Defaults to 0.5.

Source code in torchzero/modules/misc/regularization.py
class WeightDropout(Module):
    """
    Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.

    Dropout can be disabled for a parameter by setting ``use_dropout=False`` in corresponding parameter group.

    Args:
        p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
    """
    def __init__(self, p: float = 0.5):
        defaults = dict(p=p, use_dropout=True)
        super().__init__(defaults)

    @torch.no_grad
    def update(self, objective):
        closure = objective.closure
        if closure is None: raise RuntimeError('WeightDropout requires closure')
        params = TensorList(objective.params)

        # create masks
        mask = []
        for p in params:
            prob = self.settings[p]['p']
            use_dropout = self.settings[p]['use_dropout']
            if use_dropout: mask.append(_bernoulli_like(p, prob))
            else: mask.append(torch.ones_like(p))

        # create a closure that evaluates masked parameters
        @torch.no_grad
        def dropout_closure(backward=True):
            orig_params = params.clone()
            params.mul_(mask)
            if backward:
                with torch.enable_grad(): loss = closure()
            else:
                loss = closure(False)
            params.copy_(orig_params)
            return loss

        objective.closure = dropout_closure
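The closure-wrapping pattern can be reproduced without torchzero. The wrapper zeroes the masked weights, evaluates the original closure (with gradients enabled when `backward=True`), then restores the weights. As a result, the gradients are computed at the masked point but the stored parameters are unchanged. The sketch builds the mask with `torch.rand_like` so that each weight is zeroed with probability `p`. How `_bernoulli_like` orients the probability is not shown here, so this is an assumption. Unlike the module, the sketch restores the parameters in a `finally` block, so they survive an exception in the closure. Both function names are hypothetical.

```python
import torch


def make_dropout_mask(param: torch.Tensor, p: float) -> torch.Tensor:
    # 1 keeps a weight, 0 drops it; each entry is dropped with probability p
    return (torch.rand_like(param) >= p).to(param.dtype)


def masked_closure(params, masks, closure):
    @torch.no_grad()
    def dropout_closure(backward=True):
        originals = [q.clone() for q in params]
        for q, m in zip(params, masks):
            q.mul_(m)  # evaluate at the masked weights
        try:
            if backward:
                with torch.enable_grad():
                    loss = closure()
            else:
                loss = closure(False)
        finally:
            for q, o in zip(params, originals):
                q.copy_(o)  # restore the unmasked weights
        return loss

    return dropout_closure
```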

WeightedAveraging

Bases: torchzero.core.transform.TensorTransform

Weighted average of past len(weights) updates.

Parameters:

  • weights (Sequence[float]) –

    a sequence of weights from oldest to newest.

Source code in torchzero/modules/momentum/averaging.py
class WeightedAveraging(TensorTransform):
    """Weighted average of past ``len(weights)`` updates.

    Args:
        weights (Sequence[float]): a sequence of weights from oldest to newest.
    """
    def __init__(self, weights: Sequence[float] | torch.Tensor | Any):
        defaults = dict(weights = tolist(weights))
        super().__init__(defaults=defaults)

        self.add_projected_keys("grad", "history")

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        weights = setting['weights']

        if 'history' not in state:
            state['history'] = deque(maxlen=len(weights))

        history = state['history']
        history.append(tensor)
        if len(history) != len(weights):
            weights = weights[-len(history):]

        average = None
        for i, (h, w) in enumerate(zip(history, weights)):
            if average is None: average = h * (w / len(history))
            else:
                if w == 0: continue
                average += h * (w / len(history))

        assert average is not None
        return average
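Per the source above, the output is divided by the number of stored updates, not by the sum of the weights. While the history is still shorter than `weights`, only the newest weights are used. Below is a standalone sketch of the same rule for a single tensor (the class name is hypothetical):

```python
from collections import deque

import torch


class WeightedAverager:
    """Hypothetical standalone version of the WeightedAveraging rule for one tensor."""

    def __init__(self, weights):
        self.weights = list(weights)  # oldest to newest
        self.history = deque(maxlen=len(self.weights))

    def step(self, update: torch.Tensor) -> torch.Tensor:
        self.history.append(update)
        # while the history is short, pair it with the newest weights
        w = self.weights[-len(self.history):]
        return sum(wi * h for wi, h in zip(w, self.history)) / len(self.history)
```

With weights `[1, 2, 3]` and updates 1, 2, 3, 4, the outputs are 3, (2*1 + 3*2)/2 = 4, 14/3 and 20/3.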

WeightedMean

Bases: torchzero.modules.ops.reduce.WeightedSum

Outputs weighted mean of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class WeightedMean(WeightedSum):
    """Outputs weighted mean of ``inputs`` that can be modules or numbers."""
    USE_MEAN = True

USE_MEAN class-attribute

USE_MEAN = True

If True, the weighted sum is divided by the number of inputs.

WeightedSum

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs a weighted sum of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class WeightedSum(ReduceOperationBase):
    """Outputs a weighted sum of ``inputs`` that can be modules or numbers."""
    USE_MEAN = False
    def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
        weights = list(weights)
        if len(inputs) != len(weights):
            raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
        defaults = dict(weights=weights)
        super().__init__(defaults=defaults, *inputs)

    @torch.no_grad
    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        # sort (input, weight) pairs together so that tensor lists come before numbers
        # while every weight stays attached to its own input
        pairs = sorted(zip(inputs, self.defaults['weights']), key=lambda x: isinstance(x[0], (int, float)))
        sum = cast(list, pairs[0][0])
        torch._foreach_mul_(sum, pairs[0][1])
        for v, w in pairs[1:]:
            if isinstance(v, (int, float)): torch._foreach_add_(sum, v*w)
            else: torch._foreach_add_(sum, v, alpha=w)

        if self.USE_MEAN and len(pairs) > 1: torch._foreach_div_(sum, len(pairs))
        return sum
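Module inputs arrive as lists of tensors, while plain numbers are broadcast over them. The sketch below computes the same combination on ordinary tensors and keeps each weight paired with its own input. When `mean=True`, it divides by the number of inputs, mirroring `USE_MEAN`. The function name is hypothetical, and it requires at least one tensor-list input.

```python
import torch


def weighted_combine(inputs, weights, mean: bool = False) -> list[torch.Tensor]:
    if len(inputs) != len(weights):
        raise ValueError("number of inputs must match number of weights")
    # tensor-list inputs first, numbers last, each weight kept with its input
    pairs = sorted(zip(inputs, weights), key=lambda p: isinstance(p[0], (int, float)))
    first, w0 = pairs[0]
    if isinstance(first, (int, float)):
        raise TypeError("at least one input must be a list of tensors")
    total = [t * w0 for t in first]
    for v, w in pairs[1:]:
        if isinstance(v, (int, float)):
            total = [t + v * w for t in total]  # numbers are broadcast
        else:
            total = [t + vi * w for t, vi in zip(total, v)]
    if mean and len(pairs) > 1:
        total = [t / len(pairs) for t in total]
    return total
```

For inputs `(0.5, [tensor([1, 2])])` with weights `(4, 2)`, the weighted sum is `[4, 6]` and the mean variant is `[2, 3]`.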

USE_MEAN class-attribute

USE_MEAN = False

If True (as in WeightedMean), the weighted sum is divided by the number of inputs.

Wrap

Bases: torchzero.core.module.Module

Wraps a pytorch optimizer to use it as a module.

Note

Custom param groups are supported only via set_param_groups; settings passed to Optimizer are applied to all parameters.

Parameters:

  • opt_fn (Callable[..., Optimizer] | Optimizer) –

    function that takes in parameters and returns the optimizer, for example torch.optim.Adam or lambda parameters: torch.optim.Adam(parameters, lr=1e-3)

  • *args
  • **kwargs

    Extra args to be passed to opt_fn. The function is called as opt_fn(parameters, *args, **kwargs).

  • use_param_groups (bool, default: True ) –

    Whether to pass settings passed to Optimizer to the wrapped optimizer.

    Note that the settings of the first parameter in each of the wrapped optimizer's param groups are applied to the whole group, so per-parameter settings within a group are ignored.

Example:

wrapping pytorch_optimizer.StableAdamW

from pytorch_optimizer import StableAdamW
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Wrap(StableAdamW, lr=1),
    tz.m.Cautious(),
    tz.m.LR(1e-2)
)
Source code in torchzero/modules/wrappers/optim_wrapper.py
class Wrap(Module):
    """
    Wraps a pytorch optimizer to use it as a module.

    Note:
        Custom param groups are supported only by ``set_param_groups``, settings passed to Optimizer will be applied to all parameters.

    Args:
        opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
            function that takes in parameters and returns the optimizer, for example ``torch.optim.Adam``
            or ``lambda parameters: torch.optim.Adam(parameters, lr=1e-3)``
        *args:
        **kwargs:
            Extra args to be passed to opt_fn. The function is called as ``opt_fn(parameters, *args, **kwargs)``.
        use_param_groups:
            Whether to pass settings passed to Optimizer to the wrapped optimizer.

            Note that settings to the first parameter are used for all parameters,
            so if you specified per-parameter settings, they will be ignored.

    ### Example:
    wrapping pytorch_optimizer.StableAdamW

    ```python

    from pytorch_optimizer import StableAdamW
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Wrap(StableAdamW, lr=1),
        tz.m.Cautious(),
        tz.m.LR(1e-2)
    )
    ```

    """

    def __init__(
        self,
        opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer,
        *args,
        use_param_groups: bool = True,
        **kwargs,
    ):
        defaults = dict(use_param_groups=use_param_groups)
        super().__init__(defaults=defaults)

        self._opt_fn = opt_fn
        self._opt_args = args
        self._opt_kwargs = kwargs
        self._custom_param_groups = None

        self.optimizer: torch.optim.Optimizer | None = None
        if isinstance(self._opt_fn, torch.optim.Optimizer) or not callable(self._opt_fn):
            self.optimizer = self._opt_fn

    def set_param_groups(self, param_groups):
        self._custom_param_groups = _make_param_groups(param_groups, differentiable=False)
        return super().set_param_groups(param_groups)

    @torch.no_grad
    def apply(self, objective):
        params = objective.params

        # initialize opt on 1st step
        if self.optimizer is None:
            assert callable(self._opt_fn)
            param_groups = params if self._custom_param_groups is None else self._custom_param_groups
            self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)

        # set optimizer per-parameter settings
        if self.defaults["use_param_groups"] and objective.modular is not None:
            for group in self.optimizer.param_groups:
                first_param = group['params'][0]
                setting = self.settings[first_param]

                # settings passed in `set_param_groups` are the highest priority
                # schedulers will override defaults but not settings passed in `set_param_groups`
                # this is consistent with how Optimizer does it.
                if self._custom_param_groups is not None:
                    setting = {k: v for k, v in setting.items() if k not in self._custom_param_groups[0]}

                group.update(setting)

        # set grad to update
        orig_grad = [p.grad for p in params]
        for p, u in zip(params, objective.get_updates()):
            p.grad = u

        # if this is last module, simply use optimizer to update parameters
        if objective.modular is not None and self is objective.modular.modules[-1]:
            self.optimizer.step()

            # restore grad
            for p, g in zip(params, orig_grad):
                p.grad = g

            objective.stop = True; objective.skip_update = True
            return objective

        # this is not the last module, meaning update is difference in parameters
        # and passed to next module
        params_before_step = [p.clone() for p in params]
        self.optimizer.step() # step and update params
        for p, g in zip(params, orig_grad):
            p.grad = g
        objective.updates = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
        for p, o in zip(params, params_before_step):
            p.set_(o) # pyright: ignore[reportArgumentType]

        return objective

    def reset(self):
        super().reset()
        assert self.optimizer is not None
        for g in self.optimizer.param_groups:
            for p in g['params']:
                state = self.optimizer.state[p]
                state.clear()
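When `Wrap` is not the last module, `apply` recovers the update as the difference between the parameters before and after `optimizer.step()`, then restores the parameters. The sketch below reproduces only that arithmetic without any dependencies. It assumes a hypothetical `ToySGD` class standing in for a real `torch.optim.Optimizer` (its `step` takes the gradients explicitly instead of reading `p.grad`), and plain Python lists stand in for tensors; `extract_update` is likewise a hypothetical helper, not part of torchzero.

```python
from typing import List


class ToySGD:
    """Hypothetical stand-in for torch.optim.SGD: params -= lr * grads."""

    def __init__(self, lr: float) -> None:
        self.lr = lr

    def step(self, params: List[float], grads: List[float]) -> None:
        # Updates params in place, like a real optimizer's step().
        for i, g in enumerate(grads):
            params[i] -= self.lr * g


def extract_update(opt: ToySGD, params: List[float], update: List[float]) -> List[float]:
    """Mimics Wrap.apply when Wrap is not the last module."""
    params_before_step = list(params)  # like [p.clone() for p in params]
    opt.step(params, update)           # the incoming update plays the role of p.grad
    # new update = params before the step minus params after the step
    new_update = [b - a for b, a in zip(params_before_step, params)]
    params[:] = params_before_step     # like p.set_(o): undo the step
    return new_update


params = [1.0, -2.0, 0.5]
update = [0.1, 0.2, -0.3]
print(extract_update(ToySGD(lr=1.0), params, update))
print(params)  # restored to the original values
```

With `lr=1` the extracted update is exactly what the wrapped optimizer would apply at unit step size, which is presumably why the docstring example passes `lr=1` to `StableAdamW` and sets the actual step size with a later `tz.m.LR(1e-2)`.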

Zeros

Bases: torchzero.core.module.Module

Outputs zeros

Source code in torchzero/modules/ops/utility.py
class Zeros(Module):
    """Outputs zeros"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def apply(self, objective):
        objective.updates = [torch.zeros_like(p) for p in objective.params]
        return objective

clip_grad_norm_

clip_grad_norm_(params: Iterable[Tensor], max_norm: float | None, ord: Union[Literal['mad', 'std', 'var', 'sum', 'l0', 'l1', 'l2', 'l3', 'l4', 'linf'], float, Tensor] = 2, dim: Union[int, Sequence[int], Literal['global'], NoneType] = None, inverse_dims: bool = False, min_size: int = 2, min_norm: float | None = None)

Clips gradient of an iterable of parameters to specified norm value. Gradients are modified in-place.

Parameters:

  • params (Iterable[Tensor]) –

    parameters with gradients to clip.

  • max_norm (float) –

    value to clip norm to.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

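  • dim (int | Sequence[int] | str | None, default: None ) –

    calculates norm along those dimensions. If list/tuple, tensors are normalized along all dimensions in dim that they have. Can be set to "global" to normalize by global norm of all gradients concatenated to a vector. Defaults to None.

  • inverse_dims (bool, default: False ) –

    if True, the dim argument is inverted, and all other dimensions are normalized.

  • min_size (int, default: 2 ) –

    minimal size of a dimension to normalize along it. Defaults to 2.

  • min_norm (float | None, default: None ) –

    minimum value to clip norm to. Defaults to None.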

Source code in torchzero/modules/clipping/clipping.py
def clip_grad_norm_(
    params: Iterable[torch.Tensor],
    max_norm: float | None,
    ord: Metrics = 2,
    dim: int | Sequence[int] | Literal["global"] | None = None,
    inverse_dims: bool = False,
    min_size: int = 2,
    min_norm: float | None = None,
):
    """Clips gradient of an iterable of parameters to specified norm value.
    Gradients are modified in-place.

    Args:
        params (Iterable[torch.Tensor]): parameters with gradients to clip.
        max_norm (float): value to clip norm to.
        ord (float, optional): norm order. Defaults to 2.
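        dim (int | Sequence[int] | str | None, optional):
            calculates norm along those dimensions.
            If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
            Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
            Defaults to None.
        inverse_dims (bool, optional):
            if True, the `dim` argument is inverted, and all other dimensions are normalized. Defaults to False.
        min_size (int, optional):
            minimal size of a dimension to normalize along it. Defaults to 2.
        min_norm (float | None, optional):
            minimum value to clip norm to. Defaults to None.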
    """
    grads = TensorList(p.grad for p in params if p.grad is not None)
    _clip_norm_(grads, min=min_norm, max=max_norm, norm_value=None, ord=ord, dim=dim, inverse_dims=inverse_dims, min_size=min_size)
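For `dim="global"`, the docstring describes the norm as that of all gradients concatenated into one vector. A dependency-free sketch of L2 (`ord=2`) clipping in that mode follows, with nested lists standing in for gradient tensors. `clip_global_norm_` is a hypothetical helper, not the torchzero implementation (which delegates to the internal `_clip_norm_`), and it ignores `min_norm`, `min_size` and `inverse_dims`.

```python
import math
from typing import List


def clip_global_norm_(grads: List[List[float]], max_norm: float) -> float:
    """Scales all gradients in place so their concatenated L2 norm is at most max_norm.

    Returns the norm measured before clipping.
    """
    total = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total > max_norm:
        # one shared scale factor for every tensor keeps the overall direction
        scale = max_norm / total
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= scale
    return total


grads = [[3.0, 0.0], [0.0, 4.0]]  # global L2 norm is 5
clip_global_norm_(grads, max_norm=1.0)
print(grads)
```

Gradients whose global norm is already below `max_norm` are left unchanged, which is the difference from `normalize_grads_`.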

clip_grad_value_

clip_grad_value_(params: Iterable[Tensor], value: float)

Clips gradient of an iterable of parameters at specified value. Gradients are modified in-place.

Parameters:

  • params (Iterable[Tensor]) –

    iterable of tensors with gradients to clip.

  • value (float or int) –

    maximum allowed absolute value of gradient entries; gradients are clamped to [-value, value].

Source code in torchzero/modules/clipping/clipping.py
def clip_grad_value_(params: Iterable[torch.Tensor], value: float):
    """Clips gradient of an iterable of parameters at specified value.
    Gradients are modified in-place.

    Args:
        params (Iterable[Tensor]): iterable of tensors with gradients to clip.
        value (float or int): maximum allowed absolute value of gradient entries;
            gradients are clamped to `[-value, value]`.
    """
    grads = [p.grad for p in params if p.grad is not None]
    torch._foreach_clamp_min_(grads, -value)
    torch._foreach_clamp_max_(grads, value)
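The two `_foreach` calls clamp every entry from below at `-value` and from above at `value`. The same element-wise rule in plain Python, with nested lists standing in for gradient tensors (`clip_value_` is a hypothetical helper for illustration):

```python
from typing import List


def clip_value_(grads: List[List[float]], value: float) -> None:
    """Clamps every gradient entry in place to the interval [-value, value]."""
    for grad in grads:
        for i, g in enumerate(grad):
            # clamp_min(-value) followed by clamp_max(value)
            grad[i] = min(max(g, -value), value)


grads = [[-5.0, 0.3], [2.0, -0.1]]
clip_value_(grads, value=1.0)
print(grads)  # [[-1.0, 0.3], [1.0, -0.1]]
```

Unlike norm clipping, this acts on each entry independently, so it can change the direction of the gradient.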

decay_weights_

decay_weights_(params: Iterable[Tensor], weight_decay: float | NumberList, ord: int = 2)

directly decays weights in-place

Source code in torchzero/modules/weight_decay/weight_decay.py
@torch.no_grad
def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberList, ord:int=2):
    """directly decays weights in-place"""
    params = TensorList(params)
    weight_decay_(params, params, -weight_decay, ord)

normalize_grads_

normalize_grads_(params: Iterable[Tensor], norm_value: float, ord: Union[Literal['mad', 'std', 'var', 'sum', 'l0', 'l1', 'l2', 'l3', 'l4', 'linf'], float, Tensor] = 2, dim: Union[int, Sequence[int], Literal['global'], NoneType] = None, inverse_dims: bool = False, min_size: int = 1)

Normalizes gradient of an iterable of parameters to specified norm value. Gradients are modified in-place.

Parameters:

  • params (Iterable[Tensor]) –

    parameters with gradients to normalize.

  • norm_value (float) –

    target norm value that gradients are rescaled to.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • dim (int | Sequence[int] | str | None, default: None ) –

    calculates norm along those dimensions. If list/tuple, tensors are normalized along all dimensions in dim that they have. Can be set to "global" to normalize by global norm of all gradients concatenated to a vector. Defaults to None.

  • inverse_dims (bool, default: False ) –

    if True, the dim argument is inverted, and all other dimensions are normalized.

  • min_size (int, default: 1 ) –

    minimal size of a dimension to normalize along it. Defaults to 1.

Source code in torchzero/modules/clipping/clipping.py
def normalize_grads_(
    params: Iterable[torch.Tensor],
    norm_value: float,
    ord: Metrics = 2,
    dim: int | Sequence[int] | Literal["global"] | None = None,
    inverse_dims: bool = False,
    min_size: int = 1,
):
    """Normalizes gradient of an iterable of parameters to specified norm value.
    Gradients are modified in-place.

    Args:
        params (Iterable[torch.Tensor]): parameters with gradients to normalize.
        norm_value (float): target norm value that gradients are rescaled to.
        ord (float, optional): norm order. Defaults to 2.
        dim (int | Sequence[int] | str | None, optional):
            calculates norm along those dimensions.
            If list/tuple, tensors are normalized along all dimensions in `dim` that they have.
            Can be set to "global" to normalize by global norm of all gradients concatenated to a vector.
            Defaults to None.
        inverse_dims (bool, optional):
            if True, the `dim` argument is inverted, and all other dimensions are normalized.
        min_size (int, optional):
            minimal size of a dimension to normalize along it. Defaults to 1.
    """
    grads = TensorList(p.grad for p in params if p.grad is not None)
    _clip_norm_(grads, min=None, max=None, norm_value=norm_value, ord=ord, dim=dim, inverse_dims=inverse_dims, min_size=min_size)
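Normalization rescales gradients to `norm_value` whether their norm is larger or smaller, whereas `clip_grad_norm_` only scales down. A dependency-free sketch for `dim="global"` with `ord=2`, using nested lists as stand-in tensors; `normalize_global_` is hypothetical, and leaving all-zero gradients untouched is an assumption of this sketch, not documented torchzero behavior.

```python
import math
from typing import List


def normalize_global_(grads: List[List[float]], norm_value: float) -> None:
    """Rescales gradients in place so their concatenated L2 norm equals norm_value."""
    total = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total == 0.0:
        return  # assumption: nothing to rescale, so zero gradients are left as-is
    scale = norm_value / total
    for grad in grads:
        for i in range(len(grad)):
            grad[i] *= scale


grads = [[0.3, 0.0], [0.0, 0.4]]  # global L2 norm is 0.5
normalize_global_(grads, norm_value=2.0)  # scaled up by a factor of 4
print(grads)
```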

orthogonalize_grads_

orthogonalize_grads_(params: Iterable[Tensor], dual_norm_correction=False, method: Literal['newtonschulz', 'ns5', 'polar_express', 'svd', 'qr', 'eigh'] = 'newtonschulz', channel_first: bool = True)

Computes the zeroth power / orthogonalization of gradients of an iterable of parameters.

This sets gradients in-place. Applies along first 2 dims (expected to be out_channels, in_channels).

Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.

Parameters:

  • params (abc.Iterable[torch.Tensor]) –

    parameters that hold gradients to orthogonalize.

  • dual_norm_correction (bool, default: False ) –

    enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.

  • method (str, default: 'newtonschulz' ) –

    orthogonalization method. Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.

  • channel_first (bool, default: True ) –

    if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions are considered batch dimensions.

Source code in torchzero/modules/adaptive/muon.py
def orthogonalize_grads_(
    params: Iterable[torch.Tensor],
    dual_norm_correction=False,
    method: OrthogonalizeMethod = "newtonschulz",
    channel_first:bool=True,
):
    """Computes the zeroth power / orthogonalization of gradients of an iterable of parameters.

    This sets gradients in-place. Applies along first 2 dims (expected to be `out_channels, in_channels`).

    Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
        dual_norm_correction (bool, optional):
            enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
        method (str, optional):
            Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.
        channel_first (bool, optional):
            if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions
            are considered batch dimensions.
    """
    for p in params:
        if (p.grad is not None) and _is_at_least_2d(p.grad, channel_first=channel_first):
            X = _orthogonalize_format(p.grad, method=method, channel_first=channel_first)
            if dual_norm_correction: X = _dual_norm_correction(X, p.grad, channel_first=False)
            p.grad.set_(X.view_as(p)) # pyright:ignore[reportArgumentType]
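The default `method="newtonschulz"` approximates the orthogonal polar factor U Vᵀ of a gradient matrix G = U S Vᵀ without an SVD. The sketch below shows the classic cubic Newton-Schulz iteration X ← 1.5 X − 0.5 X Xᵀ X on a single 2D matrix in plain Python. This is an illustration under assumptions: torchzero's `_orthogonalize_format` may use different coefficients, step counts, batching and dtype handling, and `newton_schulz`, `matmul` and `transpose` are hypothetical helpers.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product of a (n x k) and b (k x m)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def newton_schulz(g: Matrix, steps: int = 20) -> Matrix:
    """Approximates the orthogonal polar factor U V^T of g = U S V^T."""
    fro = math.sqrt(sum(x * x for row in g for x in row))
    # dividing by the Frobenius norm puts all singular values in (0, 1]
    x = [[v / fro for v in row] for row in g]
    for _ in range(steps):
        xxtx = matmul(matmul(x, transpose(x)), x)
        # cubic step: each singular value s maps to 1.5 s - 0.5 s^3, which converges to 1
        x = [[1.5 * a - 0.5 * b for a, b in zip(ra, rb)] for ra, rb in zip(x, xxtx)]
    return x


q = newton_schulz([[2.0, 1.0], [0.0, 1.0]])
print(matmul(q, transpose(q)))  # approximately the 2x2 identity
```

Each step only needs matrix products, which is why Newton-Schulz runs fast on GPUs compared with an SVD.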

orthograd_

orthograd_(params: Iterable[Tensor], eps: float = 1e-30)

Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

Parameters:

  • params (Iterable[Tensor]) –

    parameters that hold gradients to apply ⟂Grad to.

  • eps (float, default: 1e-30 ) –

    epsilon added to the denominator for numerical stability (default: 1e-30)

reference https://arxiv.org/abs/2501.04697
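The source applies the projection g ← g − (w·g) / (w·w + eps) · w, which removes the component of the gradient that points along the weights. A plain-Python sketch of that formula on a single flattened vector follows. Whether `TensorList.dot` reduces over all parameters at once or per tensor is not stated in this section, so treating `w` and `g` as one flattened vector is an assumption, and `orthograd` is a hypothetical helper.

```python
from typing import List


def orthograd(w: List[float], g: List[float], eps: float = 1e-30) -> List[float]:
    """Returns g minus its projection onto w: g - (w.g / (w.w + eps)) * w."""
    wg = sum(a * b for a, b in zip(w, g))
    ww = sum(a * a for a in w)
    coef = wg / (ww + eps)  # eps keeps this finite when w is all zeros
    return [b - coef * a for a, b in zip(w, g)]


w = [1.0, 2.0, 2.0]
g = [3.0, 0.0, 1.0]
g_perp = orthograd(w, g)
print(sum(a * b for a, b in zip(w, g_perp)))  # close to 0
```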

Source code in torchzero/modules/adaptive/orthograd.py
def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
    """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

    Args:
        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to apply ⟂Grad to.
        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)

    reference
        https://arxiv.org/abs/2501.04697
    """
    params = TensorList(params).with_grad()
    grad = params.grad
    grad -= (params.dot(grad)/(params.dot(params) + eps)) * params