Trust region

This subpackage contains trust region methods.

See also

  • Step size - step size selection methods like Barzilai-Borwein and Polyak's step size.
  • Line search - line search methods.

Classes:

CubicRegularization

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Cubic regularization.

Parameters:

  • hess_module (Module | None) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • eta (float, default: 0.0 ) –

    if ratio of actual to predicted reduction is larger than this, step is accepted. When hess_module is GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.0001 ) –

    if ratio of actual to predicted reduction is less than this, trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • maxiter (int, default: 100 ) –

    maximum number of iterations when solving the cubic subproblem. Defaults to 100.

  • eps (float, default: 1e-08 ) –

    epsilon for the solver, defaults to 1e-8.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

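  • fallback (bool) –

    if True, when hess_module maintains a hessian inverse, which can't be inverted efficiently, it will be inverted anyway. When False (default), a RuntimeError will be raised instead. Note that the __init__ signature in the source listing for this class does not include a fallback argument.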

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to output of this module. Defaults to None.
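
The interplay of eta, nplus, nminus, rho_good and rho_bad described above can be sketched in a few lines. This is a minimal illustration of the parameter descriptions only: `update_radius` is a hypothetical helper, not the library's radius strategy, and it ignores `boundary_tol` and any other logic the real strategy may contain.

```python
def update_radius(rho, radius, eta=0.0, nplus=3.5, nminus=0.25,
                  rho_good=0.99, rho_bad=1e-4):
    """Return (new_radius, accepted) for a reduction ratio rho.

    rho is (actual reduction) / (reduction predicted by the quadratic model).
    Hypothetical helper that mirrors the parameter descriptions only.
    """
    if rho > rho_good:
        new_radius = radius * nplus    # model was accurate: expand
    elif rho < rho_bad:
        new_radius = radius * nminus   # model was poor: shrink
    else:
        new_radius = radius            # in between: keep the radius
    accepted = rho > eta               # the step is taken only if rho exceeds eta
    return new_radius, accepted


print(update_radius(1.0, 1.0))    # (3.5, True)
print(update_radius(0.5, 1.0))    # (1.0, True)
print(update_radius(-0.5, 1.0))   # (0.25, False)
```

With the defaults shown here, almost any step that decreases the loss is accepted, while the radius grows only when the model prediction is nearly exact.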

Examples:

Cubic regularized newton

.. code-block:: python

    opt = tz.Modular(
        model.parameters(),
        tz.m.CubicRegularization(tz.m.Newton()),
    )

Source code in torchzero/modules/trust_region/cubic_regularization.py
class CubicRegularization(TrustRegionBase):
    """Cubic regularization.

    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        maxiter (int, optional): maximum number of iterations when solving the cubic subproblem. Defaults to 100.
        eps (float, optional): epsilon for the solver, defaults to 1e-8.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        fallback (bool, optional):
            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.


    Examples:
        Cubic regularized newton

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.CubicRegularization(tz.m.Newton()),
            )

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        maxiter: int = 100,
        eps: float = 1e-8,
        check_decrease:bool=False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(maxiter=maxiter, eps=eps, check_decrease=check_decrease)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        params = TensorList(params)

        loss_at_params_plus_x_fn = None
        if settings['check_decrease']:
            def closure_plus_x(x):
                x_unflat = vec_to_tensors(x, params)
                params.add_(x_unflat)
                loss_x = closure(False)
                params.sub_(x_unflat)
                return loss_x
            loss_at_params_plus_x_fn = closure_plus_x


        d, _ = ls_cubic_solver(f=f, g=g, H=H, M=1/radius, loss_at_params_plus_x_fn=loss_at_params_plus_x_fn,
                               it_max=settings['maxiter'], epsilon=settings['eps'])
        return d.neg_()
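
To make the role of `M=1/radius` concrete, here is a one-dimensional sketch of a cubic-regularized step. It assumes the model g*d + 0.5*h*d**2 + (M/6)*|d|**3; the exact scaling used inside `ls_cubic_solver` is not shown in the listing above, so the constant 1/6 is an assumption, and `cubic_step_1d` is a hypothetical name.

```python
import math


def cubic_step_1d(g, h, radius):
    """Update u (to be subtracted from the parameter) for a 1-D cubic model.

    Minimizes g*d + 0.5*h*d**2 + (M/6)*|d|**3 over d with M = 1/radius and
    returns u = -d. The M/6 scaling is an assumption made for illustration.
    """
    M = 1.0 / radius
    # The stationarity condition reduces to (M/2)*t**2 + h*t - |g| = 0
    # for the step length t = |d|; take its non-negative root.
    t = (-h + math.sqrt(h * h + 2.0 * M * abs(g))) / M
    return math.copysign(t, g)


newton = 2.0 / 4.0
for radius in (0.1, 1.0, 100.0):
    u = cubic_step_1d(2.0, 4.0, radius)
    print(radius, u, u < newton)  # the step approaches the Newton step as radius grows
```

A smaller radius means a larger cubic penalty and therefore a shorter step, which is how the trust-region machinery controls step length here without a hard norm bound.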

Dogleg

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Dogleg trust region algorithm.

Parameters:

  • hess_module (Module | None) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • eta (float, default: 0.0 ) –

    if ratio of actual to predicted reduction is larger than this, step is accepted. When hess_module is GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 2 ) –

    increase factor on successful steps. Defaults to 2.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.75 ) –

    if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.25 ) –

    if ratio of actual to predicted reduction is less than this, trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to output of this module. Defaults to None.

Source code in torchzero/modules/trust_region/dogleg.py
class Dogleg(TrustRegionBase):
    """Dogleg trust region algorithm.


    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 2.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 2,
        nminus: float = 0.25,
        rho_good: float = 0.75,
        rho_bad: float = 0.25,
        boundary_tol: float | None = None,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict()
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        if radius > 2: radius = self.global_state['radius'] = 2
        eps = torch.finfo(g.dtype).tiny * 2

        gHg = g.dot(H.matvec(g))
        if gHg <= eps:
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        p_cauchy = (g.dot(g) / gHg) * g
        p_newton = H.solve(g)

        a = p_newton - p_cauchy
        b = p_cauchy

        aa = a.dot(a)
        if aa < eps:
            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable

        ab = a.dot(b)
        bb = b.dot(b)
        c = bb - radius**2
        discriminant = (2*ab)**2 - 4*aa*c
        beta = (-2*ab + torch.sqrt(discriminant.clip(min=0))) / (2 * aa)
        return p_cauchy + beta * (p_newton - p_cauchy)
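
For orientation, the following is a dependency-free sketch of the textbook dogleg step on a 2x2 problem. It follows the same sign convention as the listing (it returns an update that is subtracted from the parameters), but it is not the library code: unlike the listing, it returns the full Newton step when that step fits inside the region, and all helper names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def solve2x2(H, g):
    """Solve H x = g for a 2x2 matrix with Cramer's rule."""
    (a, b), (c, d) = H
    det = a * d - b * c
    return [(d * g[0] - b * g[1]) / det, (a * g[1] - c * g[0]) / det]


def dogleg_step(g, H, radius):
    """Textbook dogleg update u with ||u|| <= radius (H assumed positive definite)."""
    gnorm = math.sqrt(dot(g, g))
    gHg = dot(g, [dot(row, g) for row in H])
    if gHg <= 0:
        # no positive curvature along g: go to the boundary along g
        return [radius / gnorm * gi for gi in g]
    p_newton = solve2x2(H, g)
    if math.sqrt(dot(p_newton, p_newton)) <= radius:
        return p_newton                               # Newton step fits
    p_cauchy = [dot(g, g) / gHg * gi for gi in g]
    if math.sqrt(dot(p_cauchy, p_cauchy)) >= radius:
        return [radius / gnorm * gi for gi in g]      # truncated steepest descent
    # intersect the segment p_cauchy -> p_newton with the boundary:
    # solve ||p_cauchy + beta * a||^2 = radius^2 for beta in [0, 1]
    a = [n - c for n, c in zip(p_newton, p_cauchy)]
    aa, ab = dot(a, a), dot(a, p_cauchy)
    c = dot(p_cauchy, p_cauchy) - radius ** 2
    beta = (-ab + math.sqrt(max(ab * ab - aa * c, 0.0))) / aa
    return [pc + beta * ai for pc, ai in zip(p_cauchy, a)]


H = [[2.0, 0.0], [0.0, 1.0]]
g = [2.0, 1.0]
print(dogleg_step(g, H, radius=5.0))   # [1.0, 1.0], the Newton step
print(dogleg_step(g, H, radius=1.3))   # a point on the boundary, norm 1.3
```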

LevenbergMarquardt

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Levenberg-Marquardt trust region algorithm.

Parameters:

  • hess_module (Module | None) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • y (float, default: 0 ) –

    when y=0, the identity matrix is added to the hessian; when y=1, the diagonal of the hessian approximation is added. Values in between interpolate between the two. This should only be used with Gauss-Newton. Defaults to 0.

  • eta (float, default: 0.0 ) –

    if ratio of actual to predicted reduction is larger than this, step is accepted. When hess_module is Newton or GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.0001 ) –

    if ratio of actual to predicted reduction is less than this, trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • fallback (bool, default: False ) –

    if True, when hess_module maintains a hessian inverse, which can't be inverted efficiently, it will be inverted anyway. When False (default), a RuntimeError will be raised instead.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to output of this module. Defaults to None.

Examples:

Gauss-Newton with Levenberg-Marquardt trust-region

.. code-block:: python

    opt = tz.Modular(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
    )

LM-SR1

.. code-block:: python

    opt = tz.Modular(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
    )

First order trust region (hessian is assumed to be identity)

.. code-block:: python

    opt = tz.Modular(
        model.parameters(),
        tz.m.LevenbergMarquardt(tz.m.Identity()),
    )

Source code in torchzero/modules/trust_region/levenberg_marquardt.py
class LevenbergMarquardt(TrustRegionBase):
    """Levenberg-Marquardt trust region algorithm.


    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set ``inverse=False`` when constructing them.
        y (float, optional):
            when ``y=0``, the identity matrix is added to the hessian; when ``y=1``, the diagonal of the hessian
            approximation is added. Values in between interpolate between the two. This should only be used
            with Gauss-Newton. Defaults to 0.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When ``hess_module`` is ``Newton`` or ``GaussNewton``, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        fallback (bool, optional):
            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        Gauss-Newton with Levenberg-Marquardt trust-region

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
            )

        LM-SR1

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
            )

        First order trust region (hessian is assumed to be identity)

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.LevenbergMarquardt(tz.m.Identity()),
            )

    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        y: float = 0,
        fallback: bool = False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(y=y, fallback=fallback)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            boundary_tol=None,
            radius_fn=None,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        y = settings['y']

        if isinstance(H, linear_operator.DenseInverse):
            if settings['fallback']:
                H = H.to_dense()
            else:
                raise RuntimeError(
                    f"{self.children['hess_module']} maintains a hessian inverse. "
                    "LevenbergMarquardt requires the hessian, not the inverse. "
                    "If that module is a quasi-newton module, pass `inverse=False` on initialization. "
                    "Or pass `fallback=True` to LevenbergMarquardt to allow inverting the hessian inverse, "
                    "however that can be inefficient and unstable."
                )

        reg = 1/radius
        if y == 0:
            return H.add_diagonal(reg).solve(g)

        diag = H.diagonal()
        diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
        if y != 1: diag = (diag*y) + (1-y)
        return H.add_diagonal(diag*reg).solve(g)
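
The damping performed in `trust_solve` can be reproduced on a 2x2 example without any dependencies. The sketch below solves (H + reg * D) u = g with reg = 1/radius, where D moves from the identity (y=0) to diag(H) (y=1). `lm_step` is a hypothetical helper, and it omits the listing's replacement of tiny diagonal entries by 1.

```python
def lm_step(g, H, radius, y=0.0):
    """Solve (H + reg * D) u = g for a 2x2 H, with reg = 1 / radius.

    D is diagonal with entries y * H[i][i] + (1 - y), so y=0 adds reg * I
    and y=1 adds reg * diag(H).
    """
    reg = 1.0 / radius
    damp = [(H[i][i] * y + (1.0 - y)) * reg for i in range(2)]
    a, b = H[0][0] + damp[0], H[0][1]
    c, d = H[1][0], H[1][1] + damp[1]
    det = a * d - b * c
    # Cramer's rule for the damped 2x2 system
    return [(d * g[0] - b * g[1]) / det, (a * g[1] - c * g[0]) / det]


H = [[2.0, 0.0], [0.0, 1.0]]
g = [2.0, 1.0]
print(lm_step(g, H, radius=1.0, y=0.0))   # identity damping
print(lm_step(g, H, radius=1.0, y=1.0))   # [0.5, 0.5], diagonal damping
print(lm_step(g, H, radius=1e6))          # close to the Newton step [1, 1]
```

A small radius gives heavy damping and a short, gradient-like step; a large radius leaves the Newton or Gauss-Newton step almost unchanged.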

TrustCG

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Trust region via Steihaug-Toint Conjugate Gradient method.

.. note::

    If you wish to use the exact hessian, use the matrix-free ``tz.m.NewtonCGSteihaug``,
    which only uses hessian-vector products. While passing ``tz.m.Newton`` to this module
    is possible, it is usually less efficient.

Parameters:

  • hess_module (Module | None) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • eta (float, default: 0.0 ) –

    if ratio of actual to predicted reduction is larger than this, step is accepted. When hess_module is GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.0001 ) –

    if ratio of actual to predicted reduction is less than this, trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • reg (float, default: 0 ) –

    regularization parameter for conjugate gradient. Defaults to 0.

  • max_attempts (int, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • boundary_tol (float | None, default: 1e-06 ) –

    The trust region only increases when the suggested step's norm is at least (1-boundary_tol)*trust_region. This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.

  • prefer_exact (bool, default: True ) –

    when exact solution can be easily calculated without CG (e.g. hessian is stored as scaled identity), uses the exact solution. If False, always uses CG. Defaults to True.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to output of this module. Defaults to None.

Examples:

Trust-SR1

.. code-block:: python

    opt = tz.Modular(
        model.parameters(),
        tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
    )

Source code in torchzero/modules/trust_region/trust_cg.py
class TrustCG(TrustRegionBase):
    """Trust region via Steihaug-Toint Conjugate Gradient method.

    .. note::

        If you wish to use exact hessian, use the matrix-free :code:`tz.m.NewtonCGSteihaug`
        which only uses hessian-vector products. While passing ``tz.m.Newton`` to this
        is possible, it is usually less efficient.

    Args:
        hess_module (Module | None, optional):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
        init (float, optional): Initial trust region value. Defaults to 1.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        reg (float, optional): regularization parameter for conjugate gradient. Defaults to 0.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        boundary_tol (float | None, optional):
            The trust region only increases when the suggested step's norm is at least `(1-boundary_tol)*trust_region`.
            This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-6.
        prefer_exact (bool, optional):
            when exact solution can be easily calculated without CG (e.g. hessian is stored as scaled identity),
            uses the exact solution. If False, always uses CG. Defaults to True.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        Trust-SR1

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
            )
    """
    def __init__(
        self,
        hess_module: Chainable,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        boundary_tol: float | None = 1e-6, # tuned
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        reg: float = 0,
        maxiter: int | None = None,
        miniter: int = 1,
        cg_tol: float = 1e-8,
        prefer_exact: bool = True,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol, maxiter=maxiter, miniter=miniter)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            radius_fn=torch.linalg.vector_norm,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        if settings['prefer_exact'] and isinstance(H, linear_operator.ScaledIdentity):
            return H.solve_bounded(g, radius)

        x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], maxiter=settings["maxiter"], miniter=settings["miniter"], tol=settings["cg_tol"])
        return x
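
The `cg` routine called above is not shown on this page. As a rough guide to what a Steihaug-Toint solver does, here is a plain-Python sketch that approximately solves Hx = g subject to ||x|| <= radius. It is an illustration under stated assumptions: `steihaug_cg` and `to_boundary` are hypothetical names, and the `reg` and `miniter` options are left out.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def to_boundary(x, d, radius):
    """Return x + tau * d with tau >= 0 chosen so that the result has norm `radius`."""
    dd, xd, xx = dot(d, d), dot(x, d), dot(x, x)
    tau = (-xd + math.sqrt(xd * xd + dd * (radius ** 2 - xx))) / dd
    return [xi + tau * di for xi, di in zip(x, d)]


def steihaug_cg(matvec, g, radius, tol=1e-8, maxiter=None):
    """Approximately solve H x = g subject to ||x|| <= radius."""
    n = len(g)
    x = [0.0] * n
    r = list(g)                  # residual g - H x
    d = list(g)                  # search direction
    rr = dot(r, r)
    for _ in range(maxiter or n):
        if math.sqrt(rr) <= tol:
            break
        Hd = matvec(d)
        dHd = dot(d, Hd)
        if dHd <= 0:
            return to_boundary(x, d, radius)    # non-positive curvature
        alpha = rr / dHd
        x_next = [xi + alpha * di for xi, di in zip(x, d)]
        if math.sqrt(dot(x_next, x_next)) >= radius:
            return to_boundary(x, d, radius)    # the CG step leaves the region
        x = x_next
        r = [ri - alpha * hi for ri, hi in zip(r, Hd)]
        rr_next = dot(r, r)
        d = [ri + (rr_next / rr) * di for ri, di in zip(r, d)]
        rr = rr_next
    return x


H = [[2.0, 0.0], [0.0, 1.0]]


def matvec(v):
    return [dot(row, v) for row in H]


print(steihaug_cg(matvec, [2.0, 1.0], radius=5.0))   # interior solution near [1, 1]
print(steihaug_cg(matvec, [2.0, 1.0], radius=1.3))   # stopped on the boundary
```

Only products with H are needed, which is why the note above recommends a matrix-free variant when the exact hessian is used.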

TrustRegionBase

Bases: torchzero.core.module.Module, abc.ABC

Methods:

  • trust_region_apply

    Solves the trust region subproblem and outputs Var with the solution direction.

  • trust_region_update

    updates the state of this module after H or B have been updated, if necessary

  • trust_solve

    Solve Hx=g with a trust region penalty/bound defined by radius

Source code in torchzero/modules/trust_region/trust_region.py
class TrustRegionBase(Module, ABC):
    def __init__(
        self,
        defaults: dict | None,
        hess_module: Chainable,
        # suggested default values:
        # Gould, Nicholas IM, et al. "Sensitivity of trust-region algorithms to their parameters." 4OR 3.3 (2005): 227-241.
        # which I found from https://github.com/patrick-kidger/optimistix/blob/c1dad7e75fc35bd5a4977ac3a872991e51e83d2c/optimistix/_solver/trust_region.py#L113-200
        eta: float, # 0.0
        nplus: float, # 3.5
        nminus: float, # 0.25
        rho_good: float, # 0.99
        rho_bad: float, # 1e-4
        boundary_tol: float | None, # None or 1e-1
        init: float, # 1
        max_attempts: int, # 10
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS, # "default"
        radius_fn: Callable | None, # torch.linalg.vector_norm
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        if isinstance(radius_strategy, str): radius_strategy = _RADIUS_STRATEGIES[radius_strategy]
        if defaults is None: defaults = {}

        safe_dict_update_(
            defaults,
            dict(eta=eta, nplus=nplus, nminus=nminus, rho_good=rho_good, rho_bad=rho_bad, init=init,
                 update_freq=update_freq, max_attempts=max_attempts, radius_strategy=radius_strategy,
                 boundary_tol=boundary_tol)
        )

        super().__init__(defaults)

        self._radius_fn = radius_fn
        self.set_child('hess_module', hess_module)

        if inner is not None:
            self.set_child('inner', inner)

    @abstractmethod
    def trust_solve(
        self,
        f: float,
        g: torch.Tensor,
        H: LinearOperator,
        radius: float,
        params: list[torch.Tensor],
        closure: Callable,
        settings: Mapping[str, Any],
    ) -> torch.Tensor:
        """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
        ... # pylint:disable=unnecessary-ellipsis

    def trust_region_update(self, var: Var, H: LinearOperator | None) -> None:
        """updates the state of this module after H or B have been updated, if necessary"""

    def trust_region_apply(self, var: Var, tensors:list[torch.Tensor], H: LinearOperator | None) -> Var:
        """Solves the trust region subproblem and outputs ``Var`` with the solution direction."""
        assert H is not None

        params = TensorList(var.params)
        settings = self.settings[params[0]]
        g = _flatten_tensors(tensors)

        max_attempts = settings['max_attempts']

        # loss at x_0
        loss = var.loss
        closure = var.closure
        if closure is None: raise RuntimeError("Trust region requires closure")
        if loss is None: loss = var.get_loss(False)
        loss = tofloat(loss)

        # trust region step and update
        success = False
        d = None
        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            trust_radius = self.global_state.get('trust_radius', settings['init'])

            # solve Hx=g
            d = self.trust_solve(f=loss, g=g, H=H, radius=trust_radius, params=params, closure=closure, settings=settings)

            # update trust radius
            radius_strategy: _RadiusStrategy = settings['radius_strategy']
            self.global_state["trust_radius"], success = radius_strategy(
                params=params,
                closure=closure,
                d=d,
                f=loss,
                g=g,
                H=H,
                trust_radius=trust_radius,

                eta=settings["eta"],
                nplus=settings["nplus"],
                nminus=settings["nminus"],
                rho_good=settings["rho_good"],
                rho_bad=settings["rho_bad"],
                boundary_tol=settings["boundary_tol"],
                init=settings["init"],

                state=self.global_state,
                settings=settings,
                radius_fn=self._radius_fn,
            )

        assert d is not None
        if success: var.update = vec_to_tensors(d, params)
        else: var.update = params.zeros_like()

        return var


    @final
    @torch.no_grad
    def update(self, var):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        if step % self.defaults["update_freq"] == 0:

            hessian_module = self.children['hess_module']
            hessian_module.update(var)
            H = hessian_module.get_H(var)
            self.global_state["H"] = H

            self.trust_region_update(var, H=H)


    @final
    @torch.no_grad
    def apply(self, var):
        H = self.global_state.get('H', None)

        # -------------------------------- inner step -------------------------------- #
        update = var.get_update()
        if 'inner' in self.children:
            update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)

        # ----------------------------------- apply ---------------------------------- #
        return self.trust_region_apply(var=var, tensors=update, H=H)
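
The retry loop in `trust_region_apply` can be imitated on a scalar problem. The sketch below assumes a first-order model (H = 1), for which the bounded subproblem has a closed-form solution, and it inlines a simplified radius rule based on the parameter descriptions; `trust_region_attempts` is a hypothetical function, not part of the library.

```python
def trust_region_attempts(f, x, grad, radius, max_attempts=10, eta=0.0,
                          nplus=3.5, nminus=0.25, rho_good=0.99, rho_bad=1e-4):
    """One outer step of a 1-D trust region with H assumed to be 1.

    Returns (new_x, new_radius). If every attempt is rejected, x is returned
    unchanged, which corresponds to a zero update.
    """
    f_old = f(x)
    for _ in range(max_attempts):
        # maximizer of the predicted reduction grad*u - 0.5*u**2 with |u| <= radius
        u = max(-radius, min(radius, grad))
        predicted = grad * u - 0.5 * u * u
        if predicted <= 0:
            break                          # zero gradient: nothing to do
        rho = (f_old - f(x - u)) / predicted
        if rho > rho_good:
            radius *= nplus
        elif rho < rho_bad:
            radius *= nminus
        if rho > eta:
            return x - u, radius           # step accepted
    return x, radius                       # all attempts rejected


f = lambda x: 100.0 * x * x
# The first attempt overshoots with radius 4, the second succeeds with radius 1.
print(trust_region_attempts(f, 1.0, 200.0, 4.0))                  # (0.0, 1.0)
# With a single attempt the step is rejected and x stays where it was.
print(trust_region_attempts(f, 1.0, 200.0, 4.0, max_attempts=1))  # (1.0, 1.0)
```

As in the listing, the shrunken radius is kept even when the step is rejected, so the next call starts from a more conservative region.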

trust_region_apply

trust_region_apply(var: Var, tensors: list[Tensor], H: LinearOperator | None) -> Var

Solves the trust region subproblem and outputs Var with the solution direction.

Source code in torchzero/modules/trust_region/trust_region.py
def trust_region_apply(self, var: Var, tensors:list[torch.Tensor], H: LinearOperator | None) -> Var:
    """Solves the trust region subproblem and outputs ``Var`` with the solution direction."""
    assert H is not None

    params = TensorList(var.params)
    settings = self.settings[params[0]]
    g = _flatten_tensors(tensors)

    max_attempts = settings['max_attempts']

    # loss at x_0
    loss = var.loss
    closure = var.closure
    if closure is None: raise RuntimeError("Trust region requires closure")
    if loss is None: loss = var.get_loss(False)
    loss = tofloat(loss)

    # trust region step and update
    success = False
    d = None
    while not success:
        max_attempts -= 1
        if max_attempts < 0: break

        trust_radius = self.global_state.get('trust_radius', settings['init'])

        # solve Hx=g
        d = self.trust_solve(f=loss, g=g, H=H, radius=trust_radius, params=params, closure=closure, settings=settings)

        # update trust radius
        radius_strategy: _RadiusStrategy = settings['radius_strategy']
        self.global_state["trust_radius"], success = radius_strategy(
            params=params,
            closure=closure,
            d=d,
            f=loss,
            g=g,
            H=H,
            trust_radius=trust_radius,

            eta=settings["eta"],
            nplus=settings["nplus"],
            nminus=settings["nminus"],
            rho_good=settings["rho_good"],
            rho_bad=settings["rho_bad"],
            boundary_tol=settings["boundary_tol"],
            init=settings["init"],

            state=self.global_state,
            settings=settings,
            radius_fn=self._radius_fn,
        )

    assert d is not None
    if success: var.update = vec_to_tensors(d, params)
    else: var.update = params.zeros_like()

    return var

trust_region_update

trust_region_update(var: Var, H: LinearOperator | None) -> None

updates the state of this module after H or B have been updated, if necessary

Source code in torchzero/modules/trust_region/trust_region.py
def trust_region_update(self, var: Var, H: LinearOperator | None) -> None:
    """updates the state of this module after H or B have been updated, if necessary"""

trust_solve

trust_solve(f: float, g: Tensor, H: LinearOperator, radius: float, params: list[Tensor], closure: Callable, settings: Mapping[str, Any]) -> Tensor

Solve Hx=g with a trust region penalty/bound defined by radius

Source code in torchzero/modules/trust_region/trust_region.py
@abstractmethod
def trust_solve(
    self,
    f: float,
    g: torch.Tensor,
    H: LinearOperator,
    radius: float,
    params: list[torch.Tensor],
    closure: Callable,
    settings: Mapping[str, Any],
) -> torch.Tensor:
    """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
    ... # pylint:disable=unnecessary-ellipsis