
Second order

This subpackage contains "true" second-order methods that use exact second-order information.

See also

  • Quasi-newton - quasi-Newton methods that estimate the Hessian using gradient information.

Classes:

  • ImprovedNewton

    Improved Newton's Method (INM).

  • InverseFreeNewton

    Inverse-free Newton's method.

  • Newton

    Exact Newton's method via autograd.

  • NewtonCG

    Newton's method with a matrix-free conjugate gradient or minimal-residual solver.

  • NewtonCGSteihaug

    Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.

  • NystromPCG

    Newton's method with a Nyström-preconditioned conjugate gradient solver.

  • NystromSketchAndSolve

    Newton's method with a Nyström sketch-and-solve solver.

  • SixthOrder3P

    Sixth-order iterative method.

  • SixthOrder3PM2

    Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45.

  • SixthOrder5P

    Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139.

  • SubspaceNewton

    Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).

  • TwoPointNewton

    Two-point Newton method with frozen derivative and third-order convergence.

ImprovedNewton

Bases: torchzero.core.transform.Transform

Improved Newton's Method (INM).

Reference

Saheya, B., et al. "A new Newton-like method for solving nonlinear equations." SpringerPlus 5.1 (2016): 1269.
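
A minimal usage sketch, following the same chaining pattern as the Newton examples below (the combination with tz.m.Backtracking is illustrative, not prescribed by the INM paper):

opt = tz.Optimizer(
    model.parameters(),
    tz.m.ImprovedNewton(),
    tz.m.Backtracking()
)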

Source code in torchzero/modules/second_order/inm.py
class ImprovedNewton(Transform):
    """Improved Newton's Method (INM).

    Reference:
        [Saheya, B., et al. "A new Newton-like method for solving nonlinear equations." SpringerPlus 5.1 (2016): 1269.](https://d-nb.info/1112813721/34)
    """

    def __init__(
        self,
        damping: float = 0,
        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
        update_freq: int = 1,
        precompute_inverse: bool | None = None,
        use_lstsq: bool = False,
        hessian_method: HessianMethod = "batched_autograd",
        h: float = 1e-3,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["update_freq"]
        super().__init__(defaults, update_freq=update_freq, inner=inner, )

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        _, f_list, J = objective.hessian(
            hessian_method=fs['hessian_method'],
            h=fs['h'],
            at_x0=True
        )
        if f_list is None: f_list = objective.get_grads()

        f = torch.cat([t.ravel() for t in f_list])
        J = _eigval_fn(J, fs["eigval_fn"])

        x_list = TensorList(objective.params)
        f_list = TensorList(objective.get_grads())
        x_prev, f_prev = unpack_states(states, objective.params, "x_prev", "f_prev", cls=TensorList)

        # initialize on 1st step, do Newton step
        if "H" not in self.global_state:
            x_prev.copy_(x_list)
            f_prev.copy_(f_list)
            P = J

        # INM update
        else:
            s_list = x_list - x_prev
            y_list = f_list - f_prev
            x_prev.copy_(x_list)
            f_prev.copy_(f_list)

            P = inm(f, J, s=s_list.to_vec(), y=y_list.to_vec())

        # update state
        precompute_inverse = fs["precompute_inverse"]
        if precompute_inverse is None:
            precompute_inverse = fs["__update_freq"] >= 10

        _newton_update_state_(
            H=P,
            state = self.global_state,
            damping = fs["damping"],
            eigval_fn = fs["eigval_fn"],
            precompute_inverse = precompute_inverse,
            use_lstsq = fs["use_lstsq"]
        )

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()
        fs = settings[0]

        b = torch.cat([t.ravel() for t in updates])
        sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])

        vec_to_tensors_(sol, updates)
        return objective


    def get_H(self,objective=...):
        return _newton_get_H(self.global_state)

InverseFreeNewton

Bases: torchzero.core.transform.Transform

Inverse-free Newton's method.

Reference

Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.
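
The method maintains an approximation Y of the inverse Hessian without ever solving a linear system: Y is initialized as Hᵀ / (‖H‖₁‖H‖∞) and refined with the Newton–Schulz-style iteration Y ← Y(2I − HY). A standalone sketch of that iteration on a small matrix (illustrative only, not the module's code):

import torch

H = torch.tensor([[4.0, 1.0], [1.0, 3.0]])  # example symmetric positive-definite matrix
Y = H.T / (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf')))  # initial guess
I = torch.eye(2)
for _ in range(10):
    Y = Y @ (2 * I - H @ Y)  # each step roughly squares the error; Y -> H^{-1}
print(torch.allclose(Y, torch.linalg.inv(H), atol=1e-5))  # True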

Source code in torchzero/modules/second_order/ifn.py
class InverseFreeNewton(Transform):
    """Inverse-free newton's method

    Reference
        [Massalski, Marcin, and Magdalena Nockowska-Rosiak. "INVERSE-FREE NEWTON'S METHOD." Journal of Applied Analysis & Computation 15.4 (2025): 2238-2257.](https://www.jaac-online.com/article/doi/10.11948/20240428)
    """
    def __init__(
        self,
        update_freq: int = 1,
        hessian_method: HessianMethod = "batched_autograd",
        h: float = 1e-3,
        inner: Chainable | None = None,
    ):
        defaults = dict(hessian_method=hessian_method, h=h)
        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        _, _, H = objective.hessian(
            hessian_method=fs['hessian_method'],
            h=fs['h'],
            at_x0=True
        )

        self.global_state["H"] = H

        # inverse free part
        if 'Y' not in self.global_state:
            num = H.T
            denom = (torch.linalg.norm(H, 1) * torch.linalg.norm(H, float('inf'))) # pylint:disable=not-callable

            finfo = torch.finfo(H.dtype)
            self.global_state['Y'] = num.div_(denom.clip(min=finfo.tiny * 2, max=finfo.max / 2))

        else:
            Y = self.global_state['Y']
            I2 = torch.eye(Y.size(0), device=Y.device, dtype=Y.dtype).mul_(2)
            I2 -= H @ Y
            self.global_state['Y'] = Y @ I2


    def apply_states(self, objective, states, settings):
        Y = self.global_state["Y"]
        g = torch.cat([t.ravel() for t in objective.get_updates()])
        objective.updates = vec_to_tensors(Y@g, objective.params)
        return objective

    def get_H(self,objective=...):
        return DenseWithInverse(A = self.global_state["H"], A_inv=self.global_state["Y"])

Newton

Bases: torchzero.core.transform.Transform

Exact Newton's method via autograd.

Newton's method produces a direction that jumps to the stationary point of the quadratic approximation of the target function. The update rule is (H + yI)⁻¹g, where H is the Hessian, g is the gradient, and y is the damping parameter.

g can be the output of another module if it is specified via the inner argument.

Note

In most cases Newton should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

Note

This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients to compute the Hessian. The closure must accept a backward argument (refer to the documentation).
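
The closure typically looks like the sketch below, where the backward flag controls whether gradients are recomputed (loss_fn, model, inputs and targets are placeholders; see the torchzero documentation for the exact closure contract):

def closure(backward=True):
    loss = loss_fn(model(inputs), targets)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

loss = opt.step(closure)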

Parameters:

  • damping (float, default: 0 ) –

    Tikhonov regularizer value. Defaults to 0.

  • eigval_fn (Callable | None, default: None ) –

    function to apply to eigenvalues, for example torch.abs or lambda L: torch.clip(L, min=1e-8). If this is specified, eigendecomposition will be used to invert the hessian.

  • update_freq (int, default: 1 ) –

    updates hessian every update_freq steps.

  • precompute_inverse (bool, default: None ) –

    if True, whenever hessian is computed, also computes the inverse. This is more efficient when update_freq is large. If None, this is True if update_freq >= 10.

  • use_lstsq ((bool, Optional), default: False ) –

    if True, least squares will be used to solve the linear system, this can prevent it from exploding when hessian is indefinite. If False, tries cholesky, if it fails tries LU, and then least squares. If eigval_fn is specified, eigendecomposition is always used and this argument is ignored.

  • hessian_method (str, default: 'batched_autograd' ) –

    Determines how hessian is computed.

    • "batched_autograd" - uses autograd to compute ndim batched hessian-vector products. Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd to compute ndim hessian-vector products using for loop. Slower than "batched_autograd" but uses less memory.
    • "functional_revrev" - uses torch.autograd.functional with "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to "autograd".
    • "functional_fwdrev" - uses torch.autograd.functional with vectorized "forward-over-reverse" strategy. Faster than "functional_fwdrev" but uses more memory ("batched_autograd" seems to be faster)
    • "func" - uses torch.func.hessian which uses "forward-over-reverse" strategy. This method is the fastest and is recommended, however it is more restrictive and fails with some operators which is why it isn't the default.
    • "gfd_forward" - computes ndim hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "gfd_central" - computes ndim hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
    • "fd" - uses function values to estimate gradient and hessian via finite difference. This uses less evaluations than chaining "gfd_*" after tz.m.FDM.
    • "thoad" - uses thoad library, can be significantly faster than pytorch but limited operator coverage.

    Defaults to "batched_autograd".

  • h (float, default: 0.001 ) –

    finite-difference step size, used when the Hessian is computed via finite differences.

  • inner (Chainable | None, default: None ) –

    modules to apply hessian preconditioner to. Defaults to None.

See also

  • tz.m.NewtonCG: uses a matrix-free conjugate gradient solver and hessian-vector products. Useful for large-scale problems as it doesn't form the full Hessian.
  • tz.m.NewtonCGSteihaug: trust region version of tz.m.NewtonCG.
  • tz.m.ImprovedNewton: Newton with additional rank one correction to the hessian, can be faster than Newton.
  • tz.m.InverseFreeNewton: an inverse-free variant of Newton's method.
  • tz.m.quasi_newton: large collection of quasi-newton methods that estimate the hessian.

Notes

Implementation details

(H + yI)⁻¹g is calculated by solving the linear system (H + yI)x = g. The linear system is solved via Cholesky decomposition; if that fails, LU decomposition; and if that fails, least squares. Least squares can be forced by setting use_lstsq=True.

Additionally, if eigval_fn is specified, an eigendecomposition of the Hessian is computed, eigval_fn is applied to the eigenvalues, and (H + yI)⁻¹ is computed from the eigenvectors and transformed eigenvalues. This is generally more computationally expensive, but not by much.
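
A rough sketch of that fallback cascade on an explicit matrix (illustrative; the module performs the equivalent internally):

import torch

def solve_newton_system(H, g, damping=0.0):
    A = H + damping * torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    try:
        # Cholesky: fastest, but requires positive definiteness
        L = torch.linalg.cholesky(A)
        return torch.cholesky_solve(g.unsqueeze(-1), L).squeeze(-1)
    except torch.linalg.LinAlgError:
        pass
    try:
        # LU: works for any non-singular matrix
        return torch.linalg.solve(A, g)
    except torch.linalg.LinAlgError:
        # least squares: always produces a solution
        return torch.linalg.lstsq(A, g.unsqueeze(-1)).solution.squeeze(-1)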

Handling non-convexity

Standard Newton's method does not handle non-convexity well without modifications. This is because it jumps to the stationary point, which may be a maximum of the quadratic approximation.

A modification that handles non-convexity is to force the eigenvalues to be positive, for example by setting eigval_fn = lambda L: L.abs().clip(min=1e-4).
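
With an eigval_fn, the modified direction roughly amounts to the following sketch (not the module's internal code):

import torch

def modified_newton_direction(H, g, eigval_fn=lambda L: L.abs().clip(min=1e-4)):
    L, Q = torch.linalg.eigh(H)   # H = Q diag(L) Qᵀ
    L = eigval_fn(L)              # e.g. force eigenvalues to be positive
    return Q @ ((Q.T @ g) / L)    # apply (Q diag(L) Qᵀ)⁻¹ to g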

Examples:

Newton's method with backtracking line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Newton(),
    tz.m.Backtracking()
)

Newton's method for non-convex optimization.

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Newton(eigval_fn = lambda L: L.abs().clip(min=1e-4)),
    tz.m.Backtracking()
)

Newton preconditioning applied to momentum

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Newton(inner=tz.m.EMA(0.9)),
    tz.m.LR(0.1)
)
Source code in torchzero/modules/second_order/newton.py
class Newton(Transform):
    """Exact Newton's method via autograd.

    Newton's method produces a direction jumping to the stationary point of quadratic approximation of the target function.
    The update rule is given by ``(H + yI)⁻¹g``, where ``H`` is the hessian and ``g`` is the gradient, ``y`` is the ``damping`` parameter.

    ``g`` can be the output of another module if it is specified via the ``inner`` argument.

    Note:
        In most cases Newton should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

    Note:
        This module requires a closure to be passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients to compute the hessian.
        The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        damping (float, optional): tikhonov regularizer value. Defaults to 0.
        eigval_fn (Callable | None, optional):
            function to apply to eigenvalues, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
            If this is specified, eigendecomposition will be used to invert the hessian.
        update_freq (int, optional):
            updates hessian every ``update_freq`` steps.
        precompute_inverse (bool, optional):
            if ``True``, whenever hessian is computed, also computes the inverse. This is more efficient
            when ``update_freq`` is large. If ``None``, this is ``True`` if ``update_freq >= 10``.
        use_lstsq (bool, Optional):
            if True, least squares will be used to solve the linear system, this can prevent it from exploding
            when hessian is indefinite. If False, tries cholesky, if it fails tries LU, and then least squares.
            If ``eigval_fn`` is specified, eigendecomposition is always used and this argument is ignored.
        hessian_method (str):
            Determines how hessian is computed.

            - ``"batched_autograd"`` - uses autograd to compute ``ndim`` batched hessian-vector products. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd to compute ``ndim`` hessian-vector products using for loop. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"functional_revrev"`` - uses ``torch.autograd.functional`` with "reverse-over-reverse" strategy and a for-loop. This is generally equivalent to ``"autograd"``.
            - ``"functional_fwdrev"`` - uses ``torch.autograd.functional`` with vectorized "forward-over-reverse" strategy. Faster than ``"functional_fwdrev"`` but uses more memory (``"batched_autograd"`` seems to be faster)
            - ``"func"`` - uses ``torch.func.hessian`` which uses "forward-over-reverse" strategy. This method is the fastest and is recommended, however it is more restrictive and fails with some operators which is why it isn't the default.
            - ``"gfd_forward"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"gfd_central"`` - computes ``ndim`` hessian-vector products via gradient finite difference using a more accurate central formula which requires two gradient evaluations per hessian-vector product.
            - ``"fd"`` - uses function values to estimate gradient and hessian via finite difference. This uses less evaluations than chaining ``"gfd_*"`` after ``tz.m.FDM``.
            - ``"thoad"`` - uses ``thoad`` library, can be significantly faster than pytorch but limited operator coverage.

            Defaults to ``"batched_autograd"``.
        h (float, optional):
            finite difference step size if hessian is computed via finite-difference.
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.

    # See also

    * ``tz.m.NewtonCG``: uses a matrix-free conjugate gradient solver and hessian-vector products.
    useful for large scale problems as it doesn't form the full hessian.
    * ``tz.m.NewtonCGSteihaug``: trust region version of ``tz.m.NewtonCG``.
    * ``tz.m.ImprovedNewton``: Newton with additional rank one correction to the hessian, can be faster than Newton.
    * ``tz.m.InverseFreeNewton``: an inverse-free variant of Newton's method.
    * ``tz.m.quasi_newton``: large collection of quasi-newton methods that estimate the hessian.

    # Notes

    ## Implementation details

    ``(H + yI)⁻¹g`` is calculated by solving the linear system ``(H + yI)x = g``.
    The linear system is solved via cholesky decomposition, if that fails, LU decomposition, and if that fails, least squares. Least squares can be forced by setting ``use_lstsq=True``.

    Additionally, if ``eigval_fn`` is specified, eigendecomposition of the hessian is computed,
    ``eigval_fn`` is applied to the eigenvalues, and ``(H + yI)⁻¹`` is computed using the computed eigenvectors and transformed eigenvalues. This is generally more computationally expensive but not by much.

    ## Handling non-convexity

    Standard Newton's method does not handle non-convexity well without some modifications.
    This is because it jumps to the stationary point, which may be a maximum of the quadratic approximation.

    A modification to handle non-convexity is to modify the eigenvalues to be positive,
    for example by setting ``eigval_fn = lambda L: L.abs().clip(min=1e-4)``.

    # Examples:

    Newton's method with backtracking line search

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Newton(),
        tz.m.Backtracking()
    )
    ```

    Newton's method for non-convex optimization.

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Newton(eigval_fn = lambda L: L.abs().clip(min=1e-4)),
        tz.m.Backtracking()
    )
    ```

    Newton preconditioning applied to momentum

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Newton(inner=tz.m.EMA(0.9)),
        tz.m.LR(0.1)
    )
    ```

    """
    def __init__(
        self,
        damping: float = 0,
        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
        update_freq: int = 1,
        precompute_inverse: bool | None = None,
        use_lstsq: bool = False,
        hessian_method: HessianMethod = "batched_autograd",
        h: float = 1e-3,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['update_freq'], defaults["inner"]
        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        precompute_inverse = fs["precompute_inverse"]
        if precompute_inverse is None:
            precompute_inverse = fs["__update_freq"] >= 10

        __, _, H = objective.hessian(hessian_method=fs["hessian_method"], h=fs["h"], at_x0=True)

        _newton_update_state_(
            state = self.global_state,
            H=H,
            damping = fs["damping"],
            eigval_fn = fs["eigval_fn"],
            precompute_inverse = precompute_inverse,
            use_lstsq = fs["use_lstsq"]
        )

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()
        fs = settings[0]

        b = torch.cat([t.ravel() for t in updates])
        sol = _newton_solve(b=b, state=self.global_state, use_lstsq=fs["use_lstsq"])

        vec_to_tensors_(sol, updates)
        return objective

    def get_H(self,objective=...):
        return _newton_get_H(self.global_state)

NewtonCG

Bases: torchzero.core.transform.Transform

Newton's method with a matrix-free conjugate gradient or minimal-residual solver.

Notes
  • In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

  • This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to the documentation).

Warning

CG may fail if the Hessian is not positive-definite.
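
NewtonCG never forms the Hessian; it only needs Hessian-vector products, which autograd obtains with a double backward pass. A self-contained sketch of such a product (the module computes these internally according to hvp_method):

import torch

def hvp(f, x, v):
    # Hessian-vector product via double backward: H @ v = ∇ₓ(∇ₓf(x) · v)
    g = torch.autograd.grad(f(x), x, create_graph=True)[0]
    return torch.autograd.grad(g @ v, x)[0]

x = torch.randn(5, requires_grad=True)
v = torch.randn(5)
print(hvp(lambda x: (x ** 4).sum(), x, v))  # equals 12 * x**2 * v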

Parameters:

  • maxiter (int | None, default: None ) –

    Maximum number of iterations for the conjugate gradient solver. By default, this is set to the number of dimensions in the objective function, which is the theoretical upper bound for CG convergence. Setting this to a smaller value (truncated Newton) can still generate good search directions. Defaults to None.

  • tol (float, default: 1e-08 ) –

    Relative tolerance for the conjugate gradient solver to determine convergence. Defaults to 1e-8.

  • reg (float, default: 1e-08 ) –

    Regularization parameter (damping) added to the Hessian diagonal. This helps ensure the system is positive-definite. Defaults to 1e-8.

  • hvp_method (str, default: 'autograd' ) –

    Determines how Hessian-vector products are evaluated.

    • "autograd" - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    For NewtonCG "batched_autograd" is equivalent to "autograd". Defaults to "autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • warm_start (bool, default: False ) –

    If True, the conjugate gradient solver is initialized with the solution from the previous optimization step. This can accelerate convergence, especially in truncated Newton methods. Defaults to False.

  • inner (Chainable | None, default: None ) –

    NewtonCG will attempt to apply preconditioning to the output of this module.

Examples:

Newton-CG with a backtracking line search:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.NewtonCG(),
    tz.m.Backtracking()
)

Truncated Newton method (useful for large-scale problems):

opt = tz.Optimizer(
    model.parameters(),
    tz.m.NewtonCG(maxiter=10),
    tz.m.Backtracking()
)

Source code in torchzero/modules/second_order/newton_cg.py
class NewtonCG(Transform):
    """Newton's method with a matrix-free conjugate gradient or minimial-residual solver.

    Notes:
        * In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

        * This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Warning:
        CG may fail if hessian is not positive-definite.

    Args:
        maxiter (int | None, optional):
            Maximum number of iterations for the conjugate gradient solver.
            By default, this is set to the number of dimensions in the
            objective function, which is the theoretical upper bound for CG
            convergence. Setting this to a smaller value (truncated Newton)
            can still generate good search directions. Defaults to None.
        tol (float, optional):
            Relative tolerance for the conjugate gradient solver to determine
            convergence. Defaults to 1e-4.
        reg (float, optional):
            Regularization parameter (damping) added to the Hessian diagonal.
            This helps ensure the system is positive-definite. Defaults to 1e-8.
        hvp_method (str, optional):
            Determines how Hessian-vector products are evaluated.

            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            For NewtonCG ``"batched_autograd"`` is equivalent to ``"autograd"``. Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        warm_start (bool, optional):
            If ``True``, the conjugate gradient solver is initialized with the
            solution from the previous optimization step. This can accelerate
            convergence, especially in truncated Newton methods.
            Defaults to False.
        inner (Chainable | None, optional):
            NewtonCG will attempt to apply preconditioning to the output of this module.

    Examples:
    Newton-CG with a backtracking line search:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NewtonCG(),
        tz.m.Backtracking()
    )
    ```

    Truncated Newton method (useful for large-scale problems):
    ```
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NewtonCG(maxiter=10),
        tz.m.Backtracking()
    )
    ```

    """
    def __init__(
        self,
        maxiter: int | None = None,
        tol: float = 1e-8,
        reg: float = 1e-8,
        hvp_method: HVPMethod = "autograd",
        solver: Literal['cg', 'minres'] = 'cg',
        npc_terminate: bool = False,
        h: float = 1e-3, # tuned 1e-4 or 1e-3
        miniter:int = 1,
        warm_start=False,
        warm_beta:float=0,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner']
        super().__init__(defaults, inner=inner)

        self._num_hvps = 0
        self._num_hvps_last_step = 0

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]
        hvp_method = fs['hvp_method']
        h = fs['h']

        # ---------------------- Hessian vector product function --------------------- #
        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
        objective.temp = H_mv

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        self._num_hvps_last_step = 0
        H_mv = objective.poptemp()

        fs = settings[0]
        tol = fs['tol']
        reg = fs['reg']
        maxiter = fs['maxiter']
        solver = fs['solver'].lower().strip()
        warm_start = fs['warm_start']
        npc_terminate = fs["npc_terminate"]

        # ---------------------------------- run cg ---------------------------------- #
        x0 = None
        if warm_start:
            x0 = unpack_states(states, objective.params, 'prev_x', cls=TensorList)

        b = TensorList(objective.get_updates())

        if solver == 'cg':
            d, _ = cg(A_mv=H_mv, b=b, x0=x0, tol=tol, maxiter=maxiter,
                      miniter=fs["miniter"], reg=reg, npc_terminate=npc_terminate)

        elif solver == 'minres':
            d = minres(A_mv=H_mv, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)

        else:
            raise ValueError(f"Unknown solver {solver}")

        if warm_start:
            assert x0 is not None
            x0.lerp_(d, weight = 1-fs["warm_beta"])

        objective.updates = d
        self._num_hvps += self._num_hvps_last_step
        return objective

NewtonCGSteihaug

Bases: torchzero.core.transform.Transform

Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.

Notes
  • In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

  • This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to the documentation).

Parameters:

  • eta (float, default: 0.0 ) –

    if the ratio of actual to predicted reduction is larger than this, the step is accepted. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if the ratio of actual to predicted reduction is larger than this, the trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.0001 ) –

    if the ratio of actual to predicted reduction is less than this, the trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • max_attempts (int, default: 100 ) –

    maximum number of trust radius reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 100.

  • max_history (int, default: 100 ) –

    CG will store this many intermediate solutions, reusing them when trust radius is reduced instead of re-running CG. Each solution storage requires 2N memory. Defaults to 100.

  • boundary_tol (float | None, default: 1e-06 ) –

    The trust region only increases when the suggested step's norm is at least (1-boundary_tol)*trust_region. This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.

  • maxiter (int | None, default: None ) –

    maximum number of CG iterations per step. Each iteration requires one backward pass if hvp_method="fd_forward", two otherwise. Defaults to None.

  • miniter (int, default: 1 ) –

    minimal number of CG iterations. This prevents making no progress when the initial guess is already below tolerance. Defaults to 1.

  • tol (float, default: 1e-08 ) –

    terminates CG when the norm of the residual is less than this value. Defaults to 1e-8.

  • reg (float, default: 1e-08 ) –

    hessian regularization. Defaults to 1e-8.

  • solver (str, default: 'cg' ) –

    solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.

  • adapt_tol (bool, default: True ) –

    if True, whenever trust radius collapses to smallest representable number, the tolerance is multiplied by 0.1. Defaults to True.

  • npc_terminate (bool, default: False ) –

    whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.

  • hvp_method (str, default: 'fd_central' ) –

    either "fd_forward" to use forward formula which requires one backward pass per hessian-vector product, or "fd_central" to use a more accurate central formula which requires two backward passes. "fd_forward" is usually accurate enough. Defaults to "fd_forward".

  • h (float, default: 0.001 ) –

    finite difference step size. Defaults to 1e-3.

  • inner (Chainable | None, default: None ) –

    applies preconditioning to output of this module. Defaults to None.

Examples:

Trust-region Newton-CG:

opt = tz.Optimizer(
    model.parameters(),
    tz.m.NewtonCGSteihaug(),
)
Reference:
Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
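
Schematically, the trust-region bookkeeping controlled by eta, nplus, nminus, rho_good, rho_bad and boundary_tol works roughly as follows (a simplified sketch, not the library's implementation):

def update_trust_radius(rho, step_norm, trust_radius,
                        eta=0.0, nplus=3.5, nminus=0.25,
                        rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-6):
    # rho is the ratio of actual to predicted reduction for the proposed step
    accepted = rho > eta
    if rho < rho_bad:
        trust_radius *= nminus  # poor model fit: shrink the region
    elif rho > rho_good and step_norm >= (1 - boundary_tol) * trust_radius:
        trust_radius *= nplus   # good fit and step on the boundary: grow the region
    return trust_radius, accepted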
Source code in torchzero/modules/second_order/newton_cg.py
class NewtonCGSteihaug(Transform):
    """Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.

    Notes:
        * In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

        * This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

    Args:
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        max_attempts (max_attempts, optional):
            maximum number of trust radius reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        max_history (int, optional):
            CG will store this many intermediate solutions, reusing them when trust radius is reduced
            instead of re-running CG. Each solution storage requires 2N memory. Defaults to 100.
        boundary_tol (float | None, optional):
            The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
            This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-2.

        maxiter (int | None, optional):
            maximum number of CG iterations per step. Each iteration requires one backward pass if `hvp_method="fd_forward"`, two otherwise. Defaults to None.
        miniter (int, optional):
            minimal number of CG iterations. This prevents making no progress
            when the initial guess is already below tolerance. Defaults to 1.
        tol (float, optional):
            terminates CG when norm of the residual is less than this value. Defaults to 1e-8.
        reg (float, optional): hessian regularization. Defaults to 1e-8.
        solver (str, optional): solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.
        adapt_tol (bool, optional):
            if True, whenever trust radius collapses to smallest representable number,
            the tolerance is multiplied by 0.1. Defaults to True.
        npc_terminate (bool, optional):
            whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.

        hvp_method (str, optional):
            either ``"fd_forward"`` to use forward formula which requires one backward pass per hessian-vector product, or ``"fd_central"`` to use a more accurate central formula which requires two backward passes. ``"fd_forward"`` is usually accurate enough. Defaults to ``"fd_forward"``.
        h (float, optional): finite difference step size. Defaults to 1e-3.

        inner (Chainable | None, optional):
            applies preconditioning to output of this module. Defaults to None.

    ### Examples:
    Trust-region Newton-CG:

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NewtonCGSteihaug(),
    )
    ```

    ### Reference:
        Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
    """
    def __init__(
        self,
        # trust region settings
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 100,
        max_history: int = 100,
        boundary_tol: float = 1e-6, # tuned

        # cg settings
        maxiter: int | None = None,
        miniter: int = 1,
        tol: float = 1e-8,
        reg: float = 1e-8,
        solver: Literal['cg', "minres"] = 'cg',
        adapt_tol: bool = True,
        npc_terminate: bool = False,

        # hvp settings
        hvp_method: Literal["fd_forward", "fd_central"] = "fd_central",
        h: float = 1e-3, # tuned 1e-4 or 1e-3

        # inner
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner']
        super().__init__(defaults, inner=inner)

        self._num_hvps = 0
        self._num_hvps_last_step = 0


    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]
        hvp_method = fs['hvp_method']
        h = fs['h']

        # ---------------------- Hessian vector product function --------------------- #
        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
        objective.temp = H_mv

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        self._num_hvps_last_step = 0

        H_mv = objective.poptemp()
        params = TensorList(objective.params)
        fs = settings[0]

        tol = fs['tol'] * self.global_state.get('tol_mul', 1)
        solver = fs['solver'].lower().strip()

        reg=fs["reg"]
        maxiter=fs["maxiter"]
        max_attempts=fs["max_attempts"]
        init=fs["init"]
        npc_terminate=fs["npc_terminate"]
        miniter=fs["miniter"]
        max_history=fs["max_history"]
        adapt_tol=fs["adapt_tol"]


        # ------------------------------- trust region ------------------------------- #
        success = False
        d = None
        orig_params = [p.clone() for p in params]
        b = TensorList(objective.get_updates())
        solution = None
        closure = objective.closure
        assert closure is not None

        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            trust_radius = self.global_state.get('trust_radius', init)

            # -------------- make sure trust radius isn't too small or large ------------- #
            finfo = torch.finfo(orig_params[0].dtype)
            if trust_radius < finfo.tiny * 2:
                trust_radius = self.global_state['trust_radius'] = init
                if adapt_tol:
                    self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1

            elif trust_radius > finfo.max / 2:
                trust_radius = self.global_state['trust_radius'] = init

            # ----------------------------------- solve ---------------------------------- #
            d = None
            if solution is not None and solution.history is not None:
                d = find_within_trust_radius(solution.history, trust_radius)

            if d is None:
                if solver == 'cg':
                    d, solution = cg(
                        A_mv=H_mv,
                        b=b,
                        tol=tol,
                        maxiter=maxiter,
                        reg=reg,
                        trust_radius=trust_radius,
                        miniter=miniter,
                        npc_terminate=npc_terminate,
                        history_size=max_history,
                    )

                elif solver == 'minres':
                    d = minres(A_mv=H_mv, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)

                else:
                    raise ValueError(f"unknown solver {solver}")

            # ---------------------------- update trust radius --------------------------- #
            self.global_state["trust_radius"], success = default_radius(
                params = params,
                closure = closure,
                f = tofloat(objective.get_loss(False)),
                g = b,
                H = H_mv,
                d = d,
                trust_radius = trust_radius,
                eta = fs["eta"],
                nplus = fs["nplus"],
                nminus = fs["nminus"],
                rho_good = fs["rho_good"],
                rho_bad = fs["rho_bad"],
                boundary_tol = fs["boundary_tol"],

                init = cast(int, None), # init isn't used because check_overflow=False
                state = cast(dict, None), # not used
                settings = cast(dict, None), # not used
                check_overflow = False, # this is checked manually to adapt tolerance
            )

        # --------------------------- assign new direction --------------------------- #
        assert d is not None
        if success:
            objective.updates = d

        else:
            objective.updates = params.zeros_like()

        self._num_hvps += self._num_hvps_last_step
        return objective

NystromPCG

Bases: torchzero.core.transform.Transform

Newton's method with a Nyström-preconditioned conjugate gradient solver.

Notes
  • This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to the documentation).

  • In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

Parameters:

  • rank (int) –

    size of the sketch for preconditioning, this many hessian-vector products will be evaluated before running the conjugate gradient solver. Larger value improves the preconditioning and speeds up conjugate gradient.

  • maxiter (int | None, default: None ) –

    maximum number of iterations. By default this is set to the number of dimensions in the objective function, which guarantees convergence of conjugate gradient in exact arithmetic. Setting this to a small value can still generate good enough directions. Defaults to None.

  • tol (float, default: 1e-08 ) –

    relative tolerance for conjugate gradient solver. Defaults to 1e-8.

  • reg (float, default: 1e-06 ) –

    regularization parameter. Defaults to 1e-6.

  • hvp_method (str, default: 'batched_autograd' ) –

    Determines how Hessian-vector products are computed.

    • "batched_autograd" - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than "batched_autograd" but uses less memory.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    Defaults to "autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • inner (Chainable | None, default: None ) –

    modules to apply hessian preconditioner to. Defaults to None.

  • seed (int | None, default: None ) –

    seed for random generator. Defaults to None.

Examples:

NystromPCG with backtracking line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.NystromPCG(10),
    tz.m.Backtracking()
)
Reference

Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
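
A simplified sketch of how a randomized Nyström approximation H ≈ Q diag(L) Qᵀ used for preconditioning can be built from rank Hessian-matrix products (A_mm is a placeholder callable; this is not the library's nystrom_approximation function):

import torch

def nystrom_approx(A_mm, ndim, rank, dtype=torch.float64):
    Omega, _ = torch.linalg.qr(torch.randn(ndim, rank, dtype=dtype))  # orthonormal test matrix
    Y = A_mm(Omega)                                                   # rank matrix-vector products with H
    nu = torch.finfo(dtype).eps * torch.linalg.norm(Y)                # small shift for numerical stability
    Y = Y + nu * Omega
    C = torch.linalg.cholesky(Omega.T @ Y)
    B = torch.linalg.solve_triangular(C, Y.T, upper=False).T          # B = Y C⁻ᵀ, so B Bᵀ ≈ H + νI
    Q, S, _ = torch.linalg.svd(B, full_matrices=False)
    L = (S ** 2 - nu).clamp(min=0)                                    # remove the shift
    return L, Q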

Source code in torchzero/modules/second_order/nystrom.py
class NystromPCG(Transform):
    """Newton's method with a Nyström-preconditioned conjugate gradient solver.

    Notes:
        - This module requires a closure to be passed to the optimizer step,
        as it needs to re-evaluate the loss and gradients for calculating HVPs.
        The closure must accept a ``backward`` argument (refer to documentation).

        - In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

    Args:
        rank (int):
            size of the sketch for preconditioning, this many hessian-vector products will be evaluated before
            running the conjugate gradient solver. Larger value improves the preconditioning and speeds up
            conjugate gradient.
        maxiter (int | None, optional):
            maximum number of iterations. By default this is set to the number of dimensions
            in the objective function, which is supposed to be enough for conjugate gradient
            to have guaranteed convergence. Setting this to a small value can still generate good enough directions.
            Defaults to None.
        tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
        reg (float, optional): regularization parameter. Defaults to 1e-8.
        hvp_method (str, optional):
            Determines how Hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
        seed (int | None, optional): seed for random generator. Defaults to None.

    Examples:

    NystromPCG with backtracking line search

    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NystromPCG(10),
        tz.m.Backtracking()
    )
    ```

    Reference:
        Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820

    """
    def __init__(
        self,
        rank: int,
        maxiter=None,
        tol=1e-8,
        reg: float = 1e-6,
        update_freq: int = 1, # here update_freq is within update_states
        eigv_tol: float = 0,
        orthogonalize_method: OrthogonalizeMethod = 'qr',
        hvp_method: HVPMethod = "batched_autograd",
        h=1e-3,
        inner: Chainable | None = None,
        seed: int | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner']
        super().__init__(defaults, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]

        # ---------------------- Hessian vector product function --------------------- #
        # this should run on every update_states
        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=fs['hvp_method'], h=fs['h'], at_x0=True)
        objective.temp = H_mv

        # --------------------------- update preconditioner -------------------------- #
        step = self.increment_counter("step", 0)
        if step % fs["update_freq"] == 0:

            ndim = sum(t.numel() for t in objective.params)
            device = objective.params[0].device
            dtype = objective.params[0].dtype
            generator = self.get_generator(device, seed=fs['seed'])

            try:
                Omega = torch.randn(ndim, min(fs["rank"], ndim), device=device, dtype=dtype, generator=generator)
                HOmega = H_mm(orthogonalize(Omega, fs["orthogonalize_method"]))
                # compute the approximation
                L, Q = nystrom_approximation(
                    Omega=Omega,
                    AOmega=HOmega,
                    eigv_tol=fs["eigv_tol"],
                )

                self.global_state["L"] = L
                self.global_state["Q"] = Q

            except torch.linalg.LinAlgError as e:
                warnings.warn(f"Nystrom approximation failed with: {e}")

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        b = objective.get_updates()
        H_mv = objective.poptemp()
        fs = self.settings[objective.params[0]]

        # ----------------------------------- solve ---------------------------------- #
        if "L" not in self.global_state:
            # fallback on cg
            sol = cg(A_mv=H_mv, b=TensorList(b), tol=fs["tol"], reg=fs["reg"], maxiter=fs["maxiter"])
            objective.updates = sol.x
            return objective

        L = self.global_state["L"]
        Q = self.global_state["Q"]

        x = nystrom_pcg(L=L, Q=Q, A_mv=H_mv, b=torch.cat([t.ravel() for t in b]),
                        reg=fs['reg'], tol=fs["tol"], maxiter=fs["maxiter"])

        # -------------------------------- set update -------------------------------- #
        objective.updates = vec_to_tensors(x, reference=objective.params)
        return objective

NystromSketchAndSolve

Bases: torchzero.core.transform.Transform

Newton's method with a Nyström sketch-and-solve solver.

Notes
  • This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a backward argument (refer to the documentation).

  • In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.

  • If this is unstable, increase the reg parameter and tune the rank.

Parameters:

  • rank (int) –

    size of the sketch, this many hessian-vector products will be evaluated per step.

  • reg (float | None, default: 0.01 ) –

    scale of identity matrix added to the hessian. Note that if this is specified, Nyström sketch-and-solve is used to compute (Q diag(L) Q.T + reg*I)x = b. It is very unstable when reg is small, i.e. smaller than 1e-4. If this is None, (Q diag(L) Q.T)x = b is computed by simply taking the reciprocal of the eigenvalues. Defaults to 1e-2.

  • eigv_tol (float, default: 0 ) –

    all eigenvalues smaller than the largest eigenvalue times eigv_tol are removed. Defaults to 0.

  • truncate (int | None, default: None ) –

    keeps top truncate eigenvalues. Defaults to None.

  • damping (float, default: 0 ) –

    scalar added to eigenvalues. Defaults to 0.

  • rdamping (float, default: 0 ) –

    scalar multiplied by largest eigenvalue and added to eigenvalues. Defaults to 0.

  • update_freq (int, default: 1 ) –

    frequency of updating preconditioner. Defaults to 1.

  • hvp_method (str, default: 'batched_autograd' ) –

    Determines how Hessian-vector products are computed.

    • "batched_autograd" - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than "autograd" but uses more memory.
    • "autograd" - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than "batched_autograd" but uses less memory.
    • "fd_forward" - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
    • "fd_central" - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

    Defaults to "autograd".

  • h (float, default: 0.001 ) –

    The step size for finite difference if hvp_method is "fd_forward" or "fd_central". Defaults to 1e-3.

  • inner (Chainable | None, default: None ) –

    modules to apply hessian preconditioner to. Defaults to None.

  • seed (int | None, default: None ) –

    seed for random generator. Defaults to None.

Examples:

NystromSketchAndSolve with backtracking line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.NystromSketchAndSolve(100),
    tz.m.Backtracking()
)

Trust region NystromSketchAndSolve

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
)

References:

  • Frangella, Z., Rathore, P., Zhao, S., & Udell, M. (2024). SketchySGD: Reliable Stochastic Optimization via Randomized Curvature Estimates. SIAM Journal on Mathematics of Data Science, 6(4), 1173-1204.
  • Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized Nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752.
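
Because Q has orthonormal columns, the regularized system (Q diag(L) Qᵀ + reg·I)x = b referred to by the reg parameter has a closed-form solution; a sketch of that solve (not the library's nystrom_sketch_and_solve function):

import torch

def sketch_and_solve(L, Q, b, reg):
    # split b into the part inside span(Q) and the orthogonal remainder
    Qtb = Q.T @ b
    return Q @ (Qtb / (L + reg)) + (b - Q @ Qtb) / reg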

Source code in torchzero/modules/second_order/nystrom.py
class NystromSketchAndSolve(Transform):
    """Newton's method with a Nyström sketch-and-solve solver.

    Notes:
        - This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

        - In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

        - If this is unstable, increase the ``reg`` parameter and tune the rank.

    Args:
        rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
        reg (float | None, optional):
            scale of identity matrix added to hessian. Note that if this is specified, nystrom sketch-and-solve
            is used to compute ``(Q diag(L) Q.T + reg*I)x = b``. It is very unstable when ``reg`` is small,
            i.e. smaller than 1e-4. If this is None,``(Q diag(L) Q.T)x = b`` is computed by simply taking
            reciprocal of eigenvalues. Defaults to 1e-3.
        eigv_tol (float, optional):
            all eigenvalues smaller than largest eigenvalue times ``eigv_tol`` are removed. Defaults to None.
        truncate (int | None, optional):
            keeps top ``truncate`` eigenvalues. Defaults to None.
        damping (float, optional): scalar added to eigenvalues. Defaults to 0.
        rdamping (float, optional): scalar multiplied by largest eigenvalue and added to eigenvalues. Defaults to 0.
        update_freq (int, optional): frequency of updating preconditioner. Defaults to 1.
        hvp_method (str, optional):
            Determines how Hessian-vector products are computed.

            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products to compute the preconditioner. Faster than ``"autograd"`` but uses more memory.
            - ``"autograd"`` - uses autograd hessian-vector products, uses a for loop to compute the preconditioner. Slower than ``"batched_autograd"`` but uses less memory.
            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.

            Defaults to ``"autograd"``.
        h (float, optional):
            The step size for finite difference if ``hvp_method`` is
            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
        inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
        seed (int | None, optional): seed for random generator. Defaults to None.


    Examples:
    NystromSketchAndSolve with backtracking line search

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.NystromSketchAndSolve(100),
        tz.m.Backtracking()
    )
    ```

    Trust region NystromSketchAndSolve

    ```py
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LevenbergMarquadt(tz.m.NystromSketchAndSolve(100)),
    )
    ```

    References:
    - [Frangella, Z., Rathore, P., Zhao, S., & Udell, M. (2024). SketchySGD: Reliable Stochastic Optimization via Randomized Curvature Estimates. SIAM Journal on Mathematics of Data Science, 6(4), 1173-1204.](https://arxiv.org/pdf/2211.08597)
    - [Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752](https://arxiv.org/abs/2110.02820)

    """
    def __init__(
        self,
        rank: int,
        reg: float | None = 1e-2,
        eigv_tol: float = 0,
        truncate: int | None = None,
        damping: float = 0,
        rdamping: float = 0,
        update_freq: int = 1,
        orthogonalize_method: OrthogonalizeMethod = 'qr',
        hvp_method: HVPMethod = "batched_autograd",
        h: float = 1e-3,
        inner: Chainable | None = None,
        seed: int | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["update_freq"]
        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        params = TensorList(objective.params)
        fs = settings[0]

        # ---------------------- Hessian vector product function --------------------- #
        hvp_method = fs['hvp_method']
        h = fs['h']
        _, H_mv, H_mm = objective.tensor_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)

        # ---------------------------------- sketch ---------------------------------- #
        ndim = sum(t.numel() for t in objective.params)
        device = params[0].device
        dtype = params[0].dtype

        generator = self.get_generator(params[0].device, seed=fs['seed'])
        try:
            Omega = torch.randn([ndim, min(fs["rank"], ndim)], device=device, dtype=dtype, generator=generator)
            Omega = orthogonalize(Omega, fs["orthogonalize_method"])
            HOmega = H_mm(Omega)

            # compute the approximation
            L, Q = nystrom_approximation(
                Omega=Omega,
                AOmega=HOmega,
                eigv_tol=fs["eigv_tol"],
            )

            # regularize
            L, Q = regularize_eigh(
                L=L,
                Q=Q,
                truncate=fs["truncate"],
                tol=fs["eigv_tol"],
                damping=fs["damping"],
                rdamping=fs["rdamping"],
            )

            # store
            if L is not None:
                self.global_state["L"] = L
                self.global_state["Q"] = Q

        except torch.linalg.LinAlgError as e:
            warnings.warn(f"Nystrom approximation failed with: {e}")

    def apply_states(self, objective, states, settings):
        if "L" not in self.global_state:
            return objective

        fs = settings[0]
        updates = objective.get_updates()
        b=torch.cat([t.ravel() for t in updates])

        # ----------------------------------- solve ---------------------------------- #
        L = self.global_state["L"]
        Q = self.global_state["Q"]

        if fs["reg"] is None:
            x = Q @ ((Q.mH @ b) / L)
        else:
            x = nystrom_sketch_and_solve(L=L, Q=Q, b=b, reg=fs["reg"])

        # -------------------------------- set update -------------------------------- #
        objective.updates = vec_to_tensors(x, reference=objective.params)
        return objective

    def get_H(self, objective=...):
        if "L" not in self.global_state:
            return ScaledIdentity()

        L = self.global_state["L"]
        Q = self.global_state["Q"]
        return Eigendecomposition(L, Q)

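For reference, when ``reg`` is not ``None`` the regularized system ``(Q diag(L) Q.T + reg*I)x = b`` has a simple closed form once the Nyström eigendecomposition ``(L, Q)`` is available: invert ``L + reg`` inside ``span(Q)`` and divide by ``reg`` on the orthogonal complement. The snippet below is a minimal standalone sketch of that formula; it assumes ``Q`` has orthonormal columns and is not necessarily identical to the library's ``nystrom_sketch_and_solve`` helper.

```python
import torch

def nystrom_solve_sketch(L: torch.Tensor, Q: torch.Tensor, b: torch.Tensor, reg: float) -> torch.Tensor:
    # Solves (Q diag(L) Q^T + reg*I) x = b, assuming Q has orthonormal columns.
    Qtb = Q.T @ b                       # coordinates of b inside span(Q)
    x_range = Q @ (Qtb / (L + reg))     # eigenvalues become L + reg inside span(Q)
    x_perp = (b - Q @ Qtb) / reg        # the operator is reg*I on the orthogonal complement
    return x_range + x_perp
```

Setting ``reg=None`` in the module corresponds to dropping the complement term and dividing by ``L`` alone, which is the ``Q @ ((Q.mH @ b) / L)`` branch in ``apply_states``.
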
SixthOrder3P

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Sixth-order iterative method.

Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.

Source code in torchzero/modules/second_order/multipoint.py
class SixthOrder3P(HigherOrderMethodBase):
    """Sixth-order iterative method.

    Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
    """
    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        def f(x): return evaluate(x, 1)[1]
        def f_j(x): return evaluate(x, 2)[1:]
        x_star = sixth_order_3p(x, f, f_j, setting['lstsq'])
        return x - x_star

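The multipoint classes on this page do not include usage examples in their docstrings. As a hedged sketch following the same pattern as the other examples on this page (assuming ``tz`` and ``model`` as elsewhere), SixthOrder3P can be combined with a line search in the same way:

```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.SixthOrder3P(),
    tz.m.Backtracking(),  # optional line search, mirroring the other examples
)
```

SixthOrder3PM2, SixthOrder5P and TwoPointNewton below take the same ``lstsq`` and ``derivatives_method`` arguments and can be substituted directly.
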
SixthOrder3PM2

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45.

Source code in torchzero/modules/second_order/multipoint.py
class SixthOrder3PM2(HigherOrderMethodBase):
    """Wang, Xiaofeng, and Yang Li. "An efficient sixth-order Newton-type method for solving nonlinear systems." Algorithms 10.2 (2017): 45."""
    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        def f_j(x): return evaluate(x, 2)[1:]
        def f(x): return evaluate(x, 1)[1]
        x_star = sixth_order_3pm2(x, f, f_j, setting['lstsq'])
        return x - x_star

SixthOrder5P

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139.

Source code in torchzero/modules/second_order/multipoint.py
class SixthOrder5P(HigherOrderMethodBase):
    """Argyros, Ioannis K., et al. "Extended convergence for two sixth order methods under the same weak conditions." Foundations 3.1 (2023): 127-139."""
    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        def f_j(x): return evaluate(x, 2)[1:]
        x_star = sixth_order_5p(x, f_j, setting['lstsq'])
        return x - x_star

SubspaceNewton

Bases: torchzero.core.transform.Transform

Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).

Parameters:

  • sketch_size (int) –

    size of the random sketch. This many hessian-vector products will need to be evaluated each step.

  • sketch_type (str, default: 'common_directions' ) –
    • "common_directions" - uses the history of steepest descent directions as the basis [2], orthonormalized online using Gram-Schmidt (default).
    • "orthonormal" - random orthonormal basis. Orthonormality is necessary for linear-operator-based modules such as trust region, but it can be slower to compute.
    • "rows" - samples random rows.
    • "topk" - selects the rows with the largest gradient magnitudes.
    • "rademacher" - scaled random Rademacher basis, approximately orthonormal when the dimension is large.
    • "mixed" - random orthonormal basis with four directions set to the gradient, slow and fast gradient EMAs, and the previous update direction.
  • damping (float, default: 0 ) –

    hessian damping (scale of identity matrix added to hessian). Defaults to 0.

  • hvp_method (str, default: 'batched_autograd' ) –

    How to compute the hessian-matrix product:

    • "batched_autograd" - uses batched autograd.
    • "autograd" - uses unbatched autograd.
    • "forward" - uses finite differences with the forward formula, performing 1 backward pass per Hvp.
    • "central" - uses finite differences with the more accurate central formula, performing 2 backward passes per Hvp.

    Defaults to "batched_autograd".

  • h (float, default: 0.01 ) –

    finite difference step size. Defaults to 1e-2.

  • use_lstsq (bool, default: True ) –

    whether to use least squares to solve Hx=g. Defaults to True.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • H_tfm (Callable | None) –

    optional hessian transform; takes two arguments, (hessian, gradient).

    It must return either a tuple (hessian, is_inverted), where the boolean must be True if the transform inverted the hessian and False otherwise, or a single tensor which is used directly as the update.

    Defaults to None.

  • eigval_fn (Callable | None, default: None ) –

    optional eigenvalues transform, for example torch.abs or lambda L: torch.clip(L, min=1e-8). If this is specified, eigendecomposition will be used to invert the hessian.

  • seed (int | None, default: None ) –

    seed for random generator. Defaults to None.

  • inner (Chainable | None, default: None ) –

    preconditions output of this module. Defaults to None.

Examples

RSN with line search

opt = tz.Optimizer(
    model.parameters(),
    tz.m.RSN(),
    tz.m.Backtracking()
)

RSN with trust region

opt = tz.Optimizer(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.RSN()),
)

References
  1. Gower, Robert, et al. "RSN: randomized subspace Newton." Advances in Neural Information Processing Systems 32 (2019).
  2. Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
Source code in torchzero/modules/second_order/rsn.py
class SubspaceNewton(Transform):
    """Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).

    Args:
        sketch_size (int):
            size of the random sketch. This many hessian-vector products will need to be evaluated each step.
        sketch_type (str, optional):
            - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt (default).
            - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
            - "rows" - samples random rows.
            - "topk" - samples top-rank rows with largest gradient magnitude.
            - "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis.
            - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction.
        damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
        hvp_method (str, optional):
            How to compute hessian-matrix product:
            - "batched_autograd" - uses batched autograd
            - "autograd" - uses unbatched autograd
            - "forward" - uses finite difference with forward formula, performing 1 backward pass per Hvp.
            - "central" - uses finite difference with a more accurate central formula, performing 2 backward passes per Hvp.

            Defaults to "batched_autograd".
        h (float, optional): finite difference step size. Defaults to 1e-2.
        use_lstsq (bool, optional): whether to use least squares to solve ``Hx=g``. Defaults to True.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        H_tfm (Callable | None, optional):
            optional hessian transform; takes two arguments - `(hessian, gradient)`.

            It must return either a tuple `(hessian, is_inverted)`, where the boolean must be True
            if the transform inverted the hessian and False otherwise, or a single tensor
            which is used directly as the update.

            Defaults to None.
        eigval_fn (Callable | None, optional):
            optional eigenvalues transform, for example ``torch.abs`` or ``lambda L: torch.clip(L, min=1e-8)``.
            If this is specified, eigendecomposition will be used to invert the hessian.
        seed (int | None, optional): seed for random generator. Defaults to None.
        inner (Chainable | None, optional): preconditions output of this module. Defaults to None.

    ### Examples

    RSN with line search
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.RSN(),
        tz.m.Backtracking()
    )
    ```

    RSN with trust region
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.RSN()),
    )
    ```


    References:
        1. [Gower, Robert, et al. "RSN: randomized subspace Newton." Advances in Neural Information Processing Systems 32 (2019).](https://arxiv.org/abs/1905.10874)
        2. Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
    """

    def __init__(
        self,
        sketch_size: int,
        sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher", "rows", "topk"] = "common_directions",
        damping:float=0,
        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
        update_freq: int = 1,
        precompute_inverse: bool = False,
        use_lstsq: bool = True,
        hvp_method: HVPMethod = "batched_autograd",
        h: float = 1e-2,
        seed: int | None = None,
        inner: Chainable | None = None,
    ):
        defaults = locals().copy()
        del defaults['self'], defaults['inner'], defaults["update_freq"]
        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        fs = settings[0]
        params = objective.params
        generator = self.get_generator(params[0].device, fs["seed"])

        ndim = sum(p.numel() for p in params)

        device=params[0].device
        dtype=params[0].dtype

        # sample sketch matrix S: (ndim, sketch_size)
        sketch_size = min(fs["sketch_size"], ndim)
        sketch_type = fs["sketch_type"]
        hvp_method = fs["hvp_method"]

        if sketch_type == "rademacher":
            S = _rademacher_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)

        elif sketch_type == 'orthonormal':
            S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)

        elif sketch_type == "rows":
            S = _row_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)

        elif sketch_type == "topk":
            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
            g = torch.cat([t.ravel() for t in g_list])
            S = _topk_rows(g, ndim, sketch_size, device=device, dtype=dtype, generator=generator)

        elif sketch_type == 'common_directions':
            # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
            g = torch.cat([t.ravel() for t in g_list])

            # initialize directions deque
            if "directions" not in self.global_state:

                g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
                if g_norm < torch.finfo(g.dtype).tiny * 2:
                    g = torch.randn_like(g)
                    g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable

                self.global_state["directions"] = deque([g / g_norm], maxlen=sketch_size)
                S = self.global_state["directions"][0].unsqueeze(1)

            # add new steepest descent direction orthonormal to existing columns
            else:
                S = torch.stack(tuple(self.global_state["directions"]), dim=1)
                p = g - S @ (S.T @ g)
                p_norm = torch.linalg.vector_norm(p) # pylint:disable=not-callable
                if p_norm > torch.finfo(p.dtype).tiny * 2:
                    p = p / p_norm
                    self.global_state["directions"].append(p)
                    S = torch.cat([S, p.unsqueeze(1)], dim=1)

        elif sketch_type == "mixed":
            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
            g = torch.cat([t.ravel() for t in g_list])

            # initialize state
            if "slow_ema" not in self.global_state:
                self.global_state["slow_ema"] = torch.randn_like(g) * 1e-2
                self.global_state["fast_ema"] = torch.randn_like(g) * 1e-2
                self.global_state["p_prev"] = torch.randn_like(g)

            # previous update direction
            p_cur = torch.cat([t.ravel() for t in params])
            prev_dir = p_cur - self.global_state["p_prev"]
            self.global_state["p_prev"] = p_cur

            # EMAs
            slow_ema = self.global_state["slow_ema"]
            fast_ema = self.global_state["fast_ema"]
            slow_ema.lerp_(g, 0.001)
            fast_ema.lerp_(g, 0.1)

            # form and orthogonalize sketching matrix
            S = torch.stack([g, slow_ema, fast_ema, prev_dir], dim=1)
            if sketch_size > 4:
                S_random = torch.randn(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator) / math.sqrt(ndim)
                S = torch.cat([S, S_random], dim=1)

            S = _qr_orthonormalize(S)

        else:
            raise ValueError(f'Unknown sketch_type {sketch_type}')

        # print(f'{S.shape = }')
        # I = torch.eye(S.size(1), device=S.device, dtype=S.dtype)
        # print(f'{torch.nn.functional.mse_loss(S.T @ S, I) = }')

        # form sketched hessian
        HS, _ = objective.hessian_matrix_product(S, rgrad=None, at_x0=True,
                                                 hvp_method=fs["hvp_method"], h=fs["h"])
        H_sketched = S.T @ HS

        # update state
        _newton_update_state_(
            state = self.global_state,
            H = H_sketched,
            damping = fs["damping"],
            eigval_fn = fs["eigval_fn"],
            precompute_inverse = fs["precompute_inverse"],
            use_lstsq = fs["use_lstsq"]

        )

        self.global_state["S"] = S

    def apply_states(self, objective, states, settings):
        updates = objective.get_updates()
        fs = settings[0]

        S = self.global_state["S"]
        b = torch.cat([t.ravel() for t in updates])
        b_proj = S.T @ b

        d_proj = _newton_solve(b=b_proj, state=self.global_state, use_lstsq=fs["use_lstsq"])

        d = S @ d_proj
        vec_to_tensors_(d, updates)
        return objective

    def get_H(self, objective=...):
        if "H" in self.global_state:
            H_sketched = self.global_state["H"]

        else:
            L = self.global_state["L"]
            Q = self.global_state["Q"]
            H_sketched = Q @ L.diag_embed() @ Q.mH

        S: torch.Tensor = self.global_state["S"]
        return Sketched(S, H_sketched)

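To make the projection explicit outside of torchzero: a subspace Newton step forms the sketched system ``(Sᵀ H S) d = Sᵀ g`` and maps the solution back as ``S d``, mirroring ``update_states`` (which builds ``S.T @ HS``) and ``apply_states`` (which solves against ``S.T @ b``). The following is a minimal self-contained sketch using plain autograd on a toy quadratic; names such as ``subspace_newton_step`` are illustrative only.

```python
import torch

def subspace_newton_step(closure, x: torch.Tensor, S: torch.Tensor, damping: float = 0.0) -> torch.Tensor:
    """One subspace Newton step: solve (S^T H S + damping*I) d = S^T g, return S @ d."""
    x = x.detach().requires_grad_(True)
    loss = closure(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)

    # Hessian-matrix product H @ S, one column at a time
    HS = torch.stack(
        [torch.autograd.grad(g @ S[:, i], x, retain_graph=True)[0] for i in range(S.shape[1])],
        dim=1,
    )

    H_sketched = S.T @ HS + damping * torch.eye(S.shape[1])
    d = torch.linalg.solve(H_sketched, S.T @ g.detach())
    return S @ d  # direction in the full space; subtract it from x to take the step

# toy quadratic with a random 2-dimensional sketch
A = torch.randn(10, 10)
A = A @ A.T + torch.eye(10)
closure = lambda x: 0.5 * x @ A @ x
S, _ = torch.linalg.qr(torch.randn(10, 2))
step = subspace_newton_step(closure, torch.randn(10), S)
```

In the module itself the sketch ``S`` additionally persists in ``self.global_state`` so that ``apply_states`` and ``get_H`` can reuse it.
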
TwoPointNewton

Bases: torchzero.modules.second_order.multipoint.HigherOrderMethodBase

Two-point Newton method with a frozen derivative and third-order convergence.

Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73.

Source code in torchzero/modules/second_order/multipoint.py
class TwoPointNewton(HigherOrderMethodBase):
    """two-point Newton method with frozen derivative with third order convergence.

    Sharma, Janak Raj, and Deepak Kumar. "A fast and efficient composite Newton–Chebyshev method for systems of nonlinear equations." Journal of Complexity 49 (2018): 56-73."""
    def __init__(self, lstsq: bool=False, derivatives_method: DerivativesMethod = 'batched_autograd'):
        defaults=dict(lstsq=lstsq)
        super().__init__(defaults=defaults, derivatives_method=derivatives_method)

    @torch.no_grad
    def one_iteration(self, x, evaluate, objective, setting):
        def f(x): return evaluate(x, 1)[1]
        def f_j(x): return evaluate(x, 2)[1:]
        x_star = two_point_newton(x, f, f_j, setting['lstsq'])
        return x - x_star
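
For intuition, the classical two-point (composite) Newton scheme with a frozen derivative performs ``y = x - J(x)^{-1} F(x)`` followed by ``x_new = y - J(x)^{-1} F(y)``, reusing the Jacobian factorization from the first half-step; this is what gives third-order convergence at roughly the cost of one extra residual evaluation per iteration. Below is a minimal standalone sketch for a generic system ``F``, illustrating the textbook scheme; the library's ``two_point_newton`` helper may differ in details.

```python
import torch

def two_point_newton_sketch(F, x: torch.Tensor) -> torch.Tensor:
    # y = x - J(x)^{-1} F(x);  x_new = y - J(x)^{-1} F(y), reusing the LU factors of J(x)
    J = torch.autograd.functional.jacobian(F, x)
    LU, pivots = torch.linalg.lu_factor(J)
    y = x - torch.linalg.lu_solve(LU, pivots, F(x).unsqueeze(-1)).squeeze(-1)
    return y - torch.linalg.lu_solve(LU, pivots, F(y).unsqueeze(-1)).squeeze(-1)

# toy system: F(x) = x**3 - 1 elementwise, with root at x = 1
F = lambda x: x**3 - 1
x = torch.full((3,), 2.0)
for _ in range(5):
    x = two_point_newton_sketch(F, x)
```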