Skip to content

Experimental

This subpackage contains various horrible atrocities that are generally less tested.

These are various ideas of mine, plus some other modules that I decided not to move to other sub-packages for whatever reason. They are generally less tested.

Classes:

  • BlockPartition

    Splits parameters into blocks (for now it flattens them and splits them into chunks).

  • CoordinateMomentum

    Maintains a momentum buffer, on each step each value in the buffer has p chance to be updated with the new value.

  • CubicAdam

    Adam which has 3rd momentum and minimizes a cubic polynomial.

  • CurveBall

    CurveBall method from https://arxiv.org/pdf/1805.08095#page=4.09.

  • FFTProjection

    Project update into Fourier space of real-valued inputs.

  • GradMin

    Reformulates the objective to minimize sum of gradient magnitudes via autograd. This is not expected to be practical.

  • HigherOrderNewton

    A basic arbitrary order newton's method with optional trust region and proximal penalty.

  • InfinityNormTrustRegion

    Trust region with L-infinity norm via scipy.optimize.lsq_linear.

  • NewtonNewton

    Applies Newton-like preconditioning to Newton step.

  • NewtonSolver

    Matrix-free Newton with any custom solver (this is for testing; use NewtonCG or NystromPCG instead).

  • ReduceOutwardLR

    When update sign matches weight sign, the learning rate for that weight is multiplied by mul.

  • ScipyNewtonCG

    NewtonCG with scipy solvers (any from scipy.sparse.linalg)

  • SubspaceCubicAdam

    Runs cubic Adam in low rank eigenbasis.

  • TensorizeProjection

    flattens and concatenates all parameters into a vector and then reshapes it into a tensor

BlockPartition

Bases: torchzero.modules.projections.projection.ProjectionBase

Splits parameters into blocks (for now it flattens them and splits them into chunks).

Source code in torchzero/modules/experimental/structural_projections.py
class BlockPartition(ProjectionBase):
    """Splits parameters into blocks (for now it flattens them and splits them into chunks).

    Each tensor with more than ``max_size`` elements is flattened and split into chunks.
    Tensors with ``max_size`` or fewer elements are passed through unchanged.

    Args:
        modules (Chainable): modules that will optimize the projected tensors.
        max_size (int): maximum number of elements in a block.
        batched (bool, optional):
            if True, each large tensor is flattened, zero-padded to a multiple of ``max_size``
            and reshaped into a single ``(num_chunks, max_size)`` tensor.
            If False, each large tensor is flattened and split into a list of 1D chunks
            via ``torch.chunk``. Defaults to False.
        project_update (bool, optional): forwarded to ``ProjectionBase``. Defaults to True.
        project_params (bool, optional): forwarded to ``ProjectionBase``. Defaults to False.
        project_grad (bool, optional): forwarded to ``ProjectionBase``. Defaults to False.
    """
    def __init__(self, modules: Chainable, max_size: int, batched: bool = False, project_update=True, project_params=False, project_grad=False):
        defaults = dict(max_size=max_size, batched=batched)
        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        partitioned = []
        for p, t in zip(params, tensors):
            setting = self.settings[p]
            max_size = setting['max_size']
            n = t.numel()

            # small tensors are passed through unchanged
            if n <= max_size:
                partitioned.append(t)
                continue

            # reshape rather than view so that non-contiguous tensors are supported
            t_flat = t.reshape(-1)
            num_chunks = math.ceil(n / max_size)

            if setting['batched']:
                padded_size = num_chunks * max_size
                if padded_size > n:
                    # zero-pad so that the flat tensor splits evenly into num_chunks rows
                    padding = torch.zeros(padded_size - n, dtype=t_flat.dtype, device=t_flat.device)
                    t_flat = torch.cat([t_flat, padding])
                partitioned.append(t_flat.view(num_chunks, -1))

            else:
                partitioned.extend(t_flat.chunk(num_chunks))

        return partitioned

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        ti = iter(projected_tensors)
        unprojected = []
        for p in params:
            setting = self.settings[p]
            n = p.numel()

            # mirrors ``project``: small tensors were passed through as a single tensor,
            # this also covers 0-dim and empty tensors which can't go through ``torch.cat``.
            if n <= setting['max_size']:
                unprojected.append(next(ti).reshape_as(p))

            elif setting['batched']:
                # one (num_chunks, max_size) tensor, drop the zero padding
                unprojected.append(next(ti).reshape(-1)[:n].view_as(p))

            else:
                # consume 1D chunks until all n elements of this parameter are collected
                chunks = []
                t_n = 0
                while t_n < n:
                    t = next(ti)
                    chunks.append(t.reshape(-1))
                    t_n += t.numel()

                assert t_n == n, f"chunks have {t_n} elements in total, expected {n}"
                unprojected.append(torch.cat(chunks).view_as(p))

        return unprojected

CoordinateMomentum

Bases: torchzero.core.transform.TensorTransform

Maintains a momentum buffer, on each step each value in the buffer has p chance to be updated with the new value.

Parameters:

  • p (float, default: 0.1 ) –

    probability that each individual value of the momentum buffer is updated with the new value on a given step. Defaults to 0.1.

Source code in torchzero/modules/experimental/coordinate_momentum.py
class CoordinateMomentum(TensorTransform):
    """Maintains a momentum buffer, on each step each value in the buffer has ``p`` chance to be updated with the new value.

    Args:
        p (float, optional):
            probability that each individual value of the momentum buffer is updated with the
            corresponding value of the new update on a given step. Defaults to 0.1.
    """
    def __init__(self, p: float = 0.1):
        defaults = dict(p=p)
        super().__init__(defaults)

        # the per-parameter "velocity" buffer is registered together with the "grad" key
        # NOTE(review): presumably so that projections also project this buffer — confirm.
        self.add_projected_keys("grad", "velocity")

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # per-parameter update probabilities
        p = NumberList(s['p'] for s in settings)
        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
        # NOTE(review): presumably coordinate_momentum_ returns the in-place updated velocity
        # buffer; the clone keeps downstream modules from modifying stored state. TODO confirm.
        return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()

CubicAdam

Bases: torchzero.core.transform.TensorTransform

Adam which has 3rd momentum and minimizes a cubic polynomial.

Source code in torchzero/modules/experimental/cubic_adam.py
class CubicAdam(TensorTransform):
    """Adam which has 3rd momentum and minimizes a cubic polynomial.

    Three running buffers are kept per parameter (``exp_avg``, ``exp_avg_sq`` and ``exp_avg_cu``);
    the update rule itself is implemented in ``cubic_adam_``.

    Args:
        beta1 (float, optional): passed to ``cubic_adam_`` as ``beta1`` (presumably for ``exp_avg``). Defaults to 0.9.
        beta2 (float, optional): passed to ``cubic_adam_`` as ``beta2`` (presumably for ``exp_avg_sq``). Defaults to 0.99.
        beta3 (float, optional): passed to ``cubic_adam_`` as ``beta3`` (presumably for ``exp_avg_cu``). Defaults to 0.99.
        eps (float, optional): epsilon passed to ``cubic_adam_``. Defaults to 1e-8.
        debiased (bool, optional): passed to ``cubic_adam_``; read from the first parameter only. Defaults to True.
        alpha (float, optional): passed to ``cubic_adam_``. Defaults to 1.
        mode (str, optional): passed to ``cubic_adam_``; read from the first parameter only. Defaults to 'signed_cbrt'.
    """
    def __init__(
        self,
        beta1: float = 0.9,
        beta2: float = 0.99,
        beta3: float = 0.99,
        eps: float = 1e-8,
        debiased: bool = True,
        alpha: float = 1.,
        mode: _cubic_adam_mode = 'signed_cbrt',
    ):
        super().__init__(dict(
            beta1=beta1, beta2=beta2, beta3=beta3, eps=eps,
            debiased=debiased, alpha=alpha, mode=mode,
        ))

        # register (key, state buffer) pairs in a fixed order
        key_pairs = (
            ("grad", "exp_avg"),
            ("grad_sq", "exp_avg_sq"),
            ("grad_cu", "exp_avg_cu"),
        )
        for key, buffer_name in key_pairs:
            self.add_projected_keys(key, buffer_name)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # a single step counter shared by all parameters
        step = self.global_state.get('step', 0) + 1
        self.global_state['step'] = step

        beta1, beta2, beta3, eps, alpha = unpack_dicts(
            settings, 'beta1', 'beta2', 'beta3', 'eps', 'alpha', cls=NumberList
        )
        buffers = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'exp_avg_cu', cls=TensorList)
        exp_avg, exp_avg_sq, exp_avg_cu = buffers

        # non-numeric settings are taken from the first parameter
        first = settings[0]
        return cubic_adam_(
            tensors=TensorList(tensors),
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            exp_avg_cu_=exp_avg_cu,
            beta1=beta1,
            beta2=beta2,
            beta3=beta3,
            eps=eps,
            alpha=alpha,
            step=step,
            debiased=first['debiased'],
            mode=first["mode"],
        )

CurveBall

Bases: torchzero.core.transform.Transform

CurveBall method from https://arxiv.org/pdf/1805.08095#page=4.09.

For now this implementation does not include automatic ρ, α and β hyper-parameters in closed form, therefore it is expected to underperform compared to official implementation (https://github.com/jotaf98/pytorch-curveball/tree/master) so I moved this to experimental.

Parameters:

  • precond_lr (float, default: 0.001 ) –

    learning rate for updating preconditioned gradients. Defaults to 1e-3.

  • momentum (float, default: 0.9 ) –

    decay rate for preconditioned gradients. Defaults to 0.9.

  • hvp_method (str, default: 'autograd' ) –

    how to calculate hessian vector products. Defaults to "autograd".

  • h (float, default: 0.001 ) –

    finite difference step size for when hvp_method is set to finite difference. Defaults to 1e-3.

  • reg (float, default: 1 ) –

    hessian regularization. Defaults to 1.

  • inner (Chainable | None, default: None ) –

    Inner modules. Defaults to None.

Source code in torchzero/modules/experimental/curveball.py
class CurveBall(Transform):
    """CurveBall method from https://arxiv.org/pdf/1805.08095#page=4.09.

    This implementation does not (yet) compute the ρ, α and β hyper-parameters automatically in closed form,
    so it is expected to underperform compared to the official implementation
    (https://github.com/jotaf98/pytorch-curveball/tree/master); that is why it lives in experimental.

    A closure is required, since hessian-vector products are evaluated at the current point.

    Args:
        precond_lr (float, optional): learning rate for updating preconditioned gradients. Defaults to 1e-3.
        momentum (float, optional): decay rate for preconditioned gradients. Defaults to 0.9.
        hvp_method (str, optional): how to calculate hessian vector products. Defaults to "autograd".
        h (float, optional): finite difference step size for when hvp_method is set to finite difference. Defaults to 1e-3.
        reg (float, optional): hessian regularization, ``reg * z`` is added to the hessian-vector product. Defaults to 1.
        inner (Chainable | None, optional): Inner modules. Defaults to None.
    """
    def __init__(
        self,
        precond_lr: float = 1e-3,
        momentum: float = 0.9,
        hvp_method: HVPMethod = "autograd",
        h: float = 1e-3,
        reg: float = 1,
        inner: Chainable | None = None,
    ):
        super().__init__(dict(precond_lr=precond_lr, momentum=momentum, hvp_method=hvp_method, h=h, reg=reg))
        self.set_child('inner', inner)

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        params = objective.params
        first = settings[0]
        precond_lr, momentum, reg = unpack_dicts(settings, 'precond_lr', 'momentum', 'reg', cls=NumberList)

        # hessian-vector products below need the closure
        assert objective.closure is not None

        # the stored 'Hz' buffer is unpacked but not used: the product is recomputed every step
        z, _stored_hz = unpack_states(states, params, 'z', 'Hz', cls=TensorList)
        hz, _ = objective.hessian_vector_product(
            z, rgrad=None, at_x0=True, hvp_method=first['hvp_method'], h=first['h']
        )

        # (H + reg * I) z, accumulated in place into the freshly computed product
        regularized_hz = TensorList(hz).add_(z * reg)

        objective = self.inner_step("inner", objective, must_exist=False)
        direction = TensorList(objective.get_updates())

        z = curveball(direction, z, regularized_hz, momentum=momentum, precond_lr=precond_lr)
        objective.updates = z.neg()
        return objective

FFTProjection

Bases: torchzero.modules.projections.projection.ProjectionBase

Project update into Fourier space of real-valued inputs.

Parameters:

  • modules (Chainable) –

    modules that will optimize the projected update.

  • one_d (bool, default: False ) –
    • If True, uses 1d fft on parameters concatenated into a vector.
    • If False, uses n-dimensional fft on each parameter (default).
  • norm (str, default: None ) –

    Normalization mode.

    • "forward" - normalize by 1/n
    • "backward" - no normalization
    • "ortho" - normalize by 1/sqrt(n) (making the FFT orthonormal)

    Calling the backward transform (torch.fft.irfft) with the same normalization mode will apply an overall normalization of 1/n between the two transforms. This is required to make torch.fft.irfft the exact inverse.

    Default is "backward" (no normalization).

    The actual torch.fft.rfft default is None, so it is None here too; None is equivalent to "backward".

Source code in torchzero/modules/experimental/fft.py
class FFTProjection(ProjectionBase):
    # norm description adapted from the pytorch docstring
    """Project update into Fourier space of real-valued inputs.

    Args:
        modules (Chainable): modules that will optimize the projected update.
        one_d (bool, optional):
            * If True, uses 1d fft on parameters concatenated into a vector.
            * If False, uses n-dimensional fft on each parameter (default). Tensors with
              fewer than two elements are passed through unchanged.
        norm (str | None, optional):
            Normalization mode.

            * "forward" - normalize by 1/n
            * "backward" - no normalization
            * "ortho" - normalize by 1/sqrt(n) (making the FFT orthonormal)

            Calling the backward transform (:func:`~torch.fft.irfft`) with the same
            normalization mode will apply an overall normalization of ``1/n`` between
            the two transforms. This is required to make :func:`~torch.fft.irfft`
            the exact inverse.

            Defaults to None, which torch.fft treats the same as "backward" (no normalization).
        project_update (bool, optional): forwarded to ``ProjectionBase``. Defaults to True.
        project_params (bool, optional): forwarded to ``ProjectionBase``. Defaults to False.
        project_grad (bool, optional): forwarded to ``ProjectionBase``. Defaults to False.
    """

    def __init__(
        self,
        modules: Chainable,
        one_d: bool = False,
        norm=None,
        project_update=True,
        project_params=False,
        project_grad=False,
    ):
        defaults = dict(one_d=one_d, norm=norm)
        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        setting = settings[0]
        one_d = setting['one_d']
        norm = setting['norm']

        # 1d fft, concatenate all parameters into a vector and calculate fft
        if one_d:
            # reshape rather than view so that non-contiguous tensors are supported
            vec = torch.cat([t.reshape(-1) for t in tensors])
            # irfft can't recover an odd original length on its own, so it is stored
            self.global_state['length'] = len(vec)
            return [torch.view_as_real(torch.fft.rfft(vec, norm=norm))] # pylint:disable=not-callable

        # multidimensional fft for each parameter
        return [torch.view_as_real(torch.fft.rfftn(t, norm=norm)) if t.numel() > 1 else t for t in tensors] # pylint:disable=not-callable

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        setting = settings[0]
        one_d = setting['one_d']
        norm = setting['norm']

        if one_d:
            # view_as_complex needs a contiguous trailing dimension of size 2,
            # inner modules may have returned a non-contiguous tensor
            vec = torch.view_as_complex(projected_tensors[0].contiguous())
            unprojected_vec = torch.fft.irfft(vec, n=self.global_state['length'], norm=norm) # pylint:disable=not-callable
            return vec_to_tensors(unprojected_vec, reference=params)

        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(projected_tensors, params)] # pylint:disable=not-callable

GradMin

Bases: torchzero.core.reformulation.Reformulation

Reformulates the objective to minimize sum of gradient magnitudes via autograd. This is not expected to be practical.

Parameters:

  • loss_term (float, default: 0 ) –

    adds loss value times this to sum of gradient magnitudes; None or 0 disables the loss term. Defaults to 0.

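  • relative (str | None, default: None ) –

    "loss_to_grad" multiplies loss_term by the gradient magnitude term, making it relative to it (only when loss_term is nonzero); "grad_to_loss" multiplies the gradient magnitude term by the loss. Defaults to None (disabled).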

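  • graft (str | None, default: None ) –

    "grad_to_loss" rescales the gradient magnitude term to have the same value as the loss; "loss_to_grad" rescales the loss term to have the same value as the gradient magnitude term (only when loss_term is nonzero). Defaults to None (disabled).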

  • square (bool, default: False ) –

    whether to use sum of squared gradient magnitudes, if False uses absolute values. Defaults to False.

  • mean (bool, default: True ) –

    whether to use mean, if False uses sum. Defaults to True.

  • maximize_grad (bool, default: False ) –

    whether to maximize gradient magnitudes instead of minimizing. Defaults to False.

  • create_graph (bool, default: False ) –

    whether to create graph. Defaults to False.

  • modify_loss (bool, default: True ) –

    whether to modify the loss value to make line searches minimize new objective. Defaults to True.

Source code in torchzero/modules/experimental/gradmin.py
class GradMin(Reformulation):
    """Reformulates the objective to minimize sum of gradient magnitudes via autograd. This is not expected to be practical.

    The new objective is built from the gradient magnitudes (``|g|`` or ``g**2``), reduced by mean or sum,
    optionally combined with a multiple of the original loss.

    Args:
        modules (Chainable): modules that will optimize the reformulated objective.
        loss_term (float | None, optional):
            adds loss value times this to the gradient magnitude term. None or 0 disables the loss term. Defaults to 0.
        relative (str | None, optional):
            * "loss_to_grad" - ``loss_term`` is multiplied by the gradient magnitude term (only when ``loss_term`` is nonzero).
            * "grad_to_loss" - the gradient magnitude term is multiplied by the loss.
            * None - disabled. Defaults to None.
        graft (str | None, optional):
            * "grad_to_loss" - the gradient magnitude term is rescaled to have the same value as the loss.
            * "loss_to_grad" - the loss term is rescaled to have the same value as the gradient magnitude term
              (only when ``loss_term`` is nonzero).
            * None - disabled. Defaults to None.
        square (bool, optional): whether to use sum of squared gradient magnitudes, if False uses absolute values. Defaults to False.
        mean (bool, optional): whether to use mean, if False uses sum. Defaults to True.
        maximize_grad (bool, optional): whether to maximize gradient magnitudes instead of minimizing (negates the objective). Defaults to False.
        create_graph (bool, optional): ``create_graph`` for the gradient of the new objective. Defaults to False.
        modify_loss (bool, optional):
            whether to return the new objective as the loss value to make line searches minimize it.
            If False, the original loss is returned. Defaults to True.
    """
    def __init__(
        self,
        modules: Chainable,
        loss_term: float | None = 0,
        relative: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
        graft: Literal['loss_to_grad', 'grad_to_loss'] | None = None,
        square=False,
        mean=True,
        maximize_grad=False,
        create_graph=False,
        modify_loss: bool = True,
    ):
        if (relative is not None) and (graft is not None):
            warnings.warn('both relative and graft are set, they will clash with each other', stacklevel=2)
        defaults = dict(loss_term=loss_term, relative=relative, graft=graft, square=square, mean=mean, maximize_grad=maximize_grad, create_graph=create_graph, modify_loss=modify_loss)
        super().__init__(defaults, modules=modules)

    @torch.no_grad
    def closure(self, backward, closure, params, objective):
        settings = self.settings[params[0]]
        loss_term = settings['loss_term']
        relative = settings['relative']
        graft = settings['graft']
        square = settings['square']
        maximize_grad = settings['maximize_grad']
        create_graph = settings['create_graph']
        modify_loss = settings['modify_loss']
        mean = settings['mean']

        with torch.enable_grad():
            for p in params: p.grad = None
            loss = closure(False)
            # create_graph=True so that the gradient magnitudes can be differentiated again
            grads = TensorList(torch.autograd.grad(loss, params, create_graph=True))

            grads = grads ** 2 if square else grads.abs()
            f = grads.global_mean() if mean else grads.global_sum()

            # the rescaling ratios are detached: they change the value, gradients flow through the original term
            if graft == 'grad_to_loss': f = f * (loss.detach()/f.detach()).detach()
            if relative == 'grad_to_loss': f = f * loss

            if loss_term is not None and loss_term != 0:
                if relative == 'loss_to_grad': loss_term = loss_term * f
                loss_part = loss
                if graft == 'loss_to_grad': loss_part = loss * (f.detach()/loss.detach()).detach()
                f = f + loss_part * loss_term

            if maximize_grad: f = -f
            if modify_loss: loss = f

            grad = None
            if backward:
                for p in params: p.grad = None
                grad = TensorList(torch.autograd.grad(f, params, create_graph=create_graph))

        return loss, grad

HigherOrderNewton

Bases: torchzero.core.module.Module

A basic arbitrary order newton's method with optional trust region and proximal penalty.

This constructs an nth order taylor approximation via autograd and minimizes it with scipy.optimize.minimize trust region newton solvers with optional proximal penalty.

The hessian of taylor approximation is easier to evaluate, plus it can be evaluated in a batched mode, so it can be more efficient in very specific instances.

Notes
  • In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the inner argument if you wish to apply Newton preconditioning to another module's output.
  • This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating higher order derivatives. The closure must accept a backward argument (refer to documentation).
  • this uses roughly O(N^order) memory and solving the subproblem is very expensive.
  • "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.

Args:

order (int, optional):
    Order of the method, number of taylor series terms (orders of derivatives) used to approximate the function. Defaults to 4.
trust_method (str | None, optional):
    Method used for trust region.
    - "bounds" - the model is minimized within bounds defined by trust region.
    - "proximal" - the model is minimized with penalty for going too far from current point.
    - "none" - disables trust region.

    Defaults to 'bounds'.
nplus (float, optional): trust region multiplier on very good steps. Defaults to 3.5.
nminus (float, optional): trust region multiplier on bad steps. Defaults to 0.25.
rho_good (float, optional):
    if ratio of actual to predicted loss reduction is larger than this, the trust region may be increased. Defaults to 0.99.
rho_bad (float, optional):
    if ratio of actual to predicted loss reduction is smaller than this, the trust region is decreased. Defaults to 1e-4.
init (float | None, optional):
    initial trust region size. If None, defaults to 1 on trust_method="bounds" and 0.1 otherwise. Defaults to None.
eta (float, optional):
    a step is accepted if ratio of actual to predicted loss reduction is larger than this. Defaults to 1e-6.
max_attempts (int, optional):
    maximum number of attempts per step; a zero update is returned if no step is accepted. Defaults to 10.
boundary_tol (float, optional):
    the trust region is only increased when the step is within this relative tolerance of the trust region boundary. Defaults to 1e-2.
de_iters (int | None, optional):
    If this is specified, the model is minimized via differential evolution first to possibly escape local minima,
    then it is passed to scipy.optimize.minimize. Defaults to None.
derivatives_method (str, optional): method used to compute derivatives. Defaults to "batched_autograd".
Source code in torchzero/modules/experimental/higher_order_newton.py
class HigherOrderNewton(Module):
    """A basic arbitrary order newton's method with optional trust region and proximal penalty.

    This constructs an nth order taylor approximation via autograd and minimizes it with
    ``scipy.optimize.minimize`` trust region newton solvers with optional proximal penalty.

    The hessian of taylor approximation is easier to evaluate, plus it can be evaluated in a batched mode,
    so it can be more efficient in very specific instances.

    Notes:
        - In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd.
        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating higher order derivatives. The closure must accept a ``backward`` argument (refer to documentation).
        - this uses roughly O(N^order) memory and solving the subproblem is very expensive.
        - "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.

    Args:
        order (int, optional):
            Order of the method, number of taylor series terms (orders of derivatives) used to approximate the function. Defaults to 4.
        trust_method (str | None, optional):
            Method used for trust region.
            - "bounds" - the model is minimized within bounds defined by trust region.
            - "proximal" - the model is minimized with penalty ``1 / trust_region`` for going too far from current point.
            - "none" or None - disables trust region, the first proposed step is always accepted.

            Defaults to 'bounds'.
        nplus (float, optional):
            trust region multiplier on very good steps (ratio above ``rho_good``; with "bounds" the step
            must also be close to the boundary, see ``boundary_tol``). Defaults to 3.5.
        nminus (float, optional): trust region multiplier on bad steps (ratio below ``rho_bad``). Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted loss reduction is larger than this, the trust region may be increased. Defaults to 0.99.
        rho_bad (float, optional):
            if ratio of actual to predicted loss reduction is smaller than this, the trust region is multiplied by ``nminus``. Defaults to 1e-4.
        init (float | None, optional):
            initial trust region size. If None, defaults to 1 on ``trust_method="bounds"`` and 0.1 otherwise. Defaults to None.
        eta (float, optional):
            the proposed step is accepted if ratio of actual to predicted loss reduction is larger than this. Defaults to 1e-6.
        max_attempts (int, optional):
            maximum number of subproblem solves per step. If no step is accepted, a zero update is returned. Defaults to 10.
        boundary_tol (float, optional):
            with "bounds", the trust region is only increased when ``(trust_region - ||step||) / trust_region <= boundary_tol``.
            Defaults to 1e-2.
        de_iters (int | None, optional):
            If this is specified, the model is minimized via differential evolution first to possibly escape local minima,
            then it is passed to scipy.optimize.minimize. Defaults to None.
        derivatives_method (DerivativesMethod, optional):
            method passed to ``objective.derivatives``. Defaults to "batched_autograd".

    Raises:
        RuntimeError: if no closure was passed to the optimizer step.
        ValueError: if ``trust_method`` is not one of the supported values.
    """
    def __init__(
        self,
        order: int = 4,
        trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        init: float | None = None,
        eta: float = 1e-6,
        max_attempts = 10,
        boundary_tol: float = 1e-2,
        de_iters: int | None = None,
        derivatives_method: DerivativesMethod = "batched_autograd",
    ):
        if init is None:
            if trust_method == 'bounds': init = 1
            else: init = 0.1

        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad, derivatives_method=derivatives_method)
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective):
        params = TensorList(objective.params)
        closure = objective.closure
        if closure is None: raise RuntimeError('HigherOrderNewton requires closure')

        settings = self.defaults
        order = settings['order']
        nplus = settings['nplus']
        nminus = settings['nminus']
        eta = settings['eta']
        init = settings['init']
        de_iters = settings['de_iters']
        max_attempts = settings['max_attempts']
        boundary_tol = settings['boundary_tol']
        rho_good = settings['rho_good']
        rho_bad = settings['rho_bad']

        # normalize and validate the trust method once, before the expensive derivative computation
        trust_method = settings['trust_method']
        trust_method = 'none' if trust_method is None else trust_method.lower()
        if trust_method not in ('none', 'bounds', 'proximal'):
            raise ValueError(trust_method)

        # ------------------------ calculate grad and hessian ------------------------ #
        loss, *derivatives = objective.derivatives(order=order, at_x0=True, method=settings["derivatives_method"])

        x0 = torch.cat([p.ravel() for p in params])
        finfo = torch.finfo(x0.dtype)

        success = False
        x_star = None
        while not success:
            max_attempts -= 1
            if max_attempts < 0: break

            # load trust region value
            trust_value = self.global_state.get('trust_region', init)

            # reset it if it is too small, or so large that multiplying by nplus could overflow
            if trust_value < finfo.tiny*2 or trust_value > finfo.max / (2*nplus):
                trust_value = self.global_state['trust_region'] = settings['init']

            # determine tr and prox values
            if trust_method == 'none':
                trust_region = None
                prox = 0
            elif trust_method == 'bounds':
                trust_region = trust_value
                prox = 0
            else: # 'proximal' - penalty weight is the inverse of the trust region size
                trust_region = None
                prox = 1 / trust_value

            # minimize the model
            x_star, expected_loss = _poly_minimize(
                trust_region=trust_region,
                prox=prox,
                de_iters=de_iters,
                c=loss.item(),
                x=x0,
                derivatives=derivatives,
            )

            if trust_method == 'none':
                # without a trust region the first proposed step is always accepted
                success = True
                continue

            # update trust region
            pred_reduction = loss - expected_loss

            # evaluate loss at the proposed point, then restore the original parameters
            vec_to_tensors_(x_star, params)
            loss_star = closure(False)
            vec_to_tensors_(x0, params)
            reduction = loss - loss_star

            rho = reduction / (max(pred_reduction, finfo.tiny * 2)) # pyright:ignore[reportArgumentType]

            # failed step
            if rho < rho_bad:
                self.global_state['trust_region'] = trust_value * nminus

            # very good step
            elif rho > rho_good:
                step = (x_star - x0)
                magn = torch.linalg.vector_norm(step) # pylint:disable=not-callable
                if trust_method == 'proximal' or (trust_value - magn) / trust_value <= boundary_tol:
                    # close to boundary
                    self.global_state['trust_region'] = trust_value * nplus

            # if the ratio is high enough then accept the proposed step
            success = rho > eta

        if success and x_star is not None:
            # updates are subtracted from parameters, so x0 - x_star moves them to x_star
            difference = vec_to_tensors(x0 - x_star, params)
            objective.updates = list(difference)
        else:
            # no step was accepted within max_attempts (or max_attempts was not positive)
            objective.updates = params.zeros_like()

        return objective

InfinityNormTrustRegion

Bases: torchzero.modules.trust_region.trust_region.TrustRegionBase

Trust region with L-infinity norm via scipy.optimize.lsq_linear.

Parameters:

  • hess_module (Module | None) –

    A module that maintains a hessian approximation (not hessian inverse!). This includes all full-matrix quasi-newton methods, tz.m.Newton and tz.m.GaussNewton. When using quasi-newton methods, set inverse=False when constructing them.

  • eta (float, default: 0.0 ) –

    if ratio of actual to predicted reduction is larger than this, step is accepted. When hess_module is GaussNewton, this can be set to 0. Defaults to 0.0.

  • nplus (float, default: 3.5 ) –

    increase factor on successful steps. Defaults to 3.5.

  • nminus (float, default: 0.25 ) –

    decrease factor on unsuccessful steps. Defaults to 0.25.

  • rho_good (float, default: 0.99 ) –

    if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by nplus.

  • rho_bad (float, default: 0.0001 ) –

    if ratio of actual to predicted reduction is less than this, trust region size is multiplied by nminus.

  • init (float, default: 1 ) –

    Initial trust region value. Defaults to 1.

  • update_freq (int, default: 1 ) –

    frequency of updating the hessian. Defaults to 1.

  • max_attempts (max_attempts, default: 10 ) –

    maximum number of trust region size reductions per step. A zero update vector is returned when this limit is exceeded. Defaults to 10.

  • boundary_tol (float | None, default: None ) –

    The trust region only increases when suggested step's norm is at least (1-boundary_tol)*trust_region. This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-2.

  • tol (float | None, default: 1e-10 ) –

    tolerance for least squares solver.

  • fallback (bool) –

    if True, when hess_module maintains hessian inverse which can't be inverted efficiently, it will be inverted anyway. When False (default), a RuntimeError will be raised instead.

  • inner (Chainable | None, default: None ) –

    preconditioning is applied to output of this module. Defaults to None.

Examples:

BFGS with infinity-norm trust region

.. code-block:: python

opt = tz.Optimizer(
    model.parameters(),
    tz.m.InfinityNormTrustRegion(hess_module=tz.m.BFGS(inverse=False)),
)
Source code in torchzero/modules/experimental/l_infinity.py
class InfinityNormTrustRegion(TrustRegionBase):
    """Trust region with L-infinity norm via ``scipy.optimize.lsq_linear``.

    On each attempt the step ``d`` is obtained by solving the box-constrained least-squares
    problem ``min_d ||H d - g||_2`` subject to ``-radius <= d_i <= radius``, where ``H`` is the
    hessian approximation maintained by ``hess_module`` and ``g`` is the right-hand side passed
    in by ``TrustRegionBase``. If ``lsq_linear`` raises ``numpy.linalg.LinAlgError``,
    ``hess_module`` is reset and ``g`` itself is returned, scaled down so that its largest
    absolute entry does not exceed ``radius``.

    Args:
        hess_module (Module):
            A module that maintains a hessian approximation (not hessian inverse!).
            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
            When using quasi-newton methods, set `inverse=False` when constructing them.
        prefer_dense (bool, optional):
            if ``True`` and the hessian approximation is dense, it is converted to a float64 numpy array
            before solving; otherwise it is passed to ``lsq_linear`` as a scipy linear operator.
            Defaults to True.
        tol (float | None, optional):
            tolerance passed to ``lsq_linear``. If ``None``, scipy's default tolerance is used.
            Defaults to 1e-10.
        eta (float, optional):
            if ratio of actual to predicted reduction is larger than this, step is accepted.
            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.0.
        nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
        rho_good (float, optional):
            if ratio of actual to predicted reduction is larger than this, trust region size is multiplied
            by `nplus`. Defaults to 0.99.
        rho_bad (float, optional):
            if ratio of actual to predicted reduction is less than this, trust region size is multiplied
            by `nminus`. Defaults to 1e-4.
        boundary_tol (float | None, optional):
            The trust region only increases when suggested step's norm is at least `(1-boundary_tol)*trust_region`.
            This prevents increasing trust region when solution is not on the boundary. Defaults to None.
        init (float, optional): Initial trust region value. Defaults to 1.
        max_attempts (int, optional):
            maximum number of trust region size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
        radius_strategy (_RadiusStrategy | str, optional):
            forwarded to ``TrustRegionBase``. Defaults to ``'default'``.
        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
        inner (Chainable | None, optional): preconditioning is applied to output of this module. Defaults to None.

    Examples:
        BFGS with infinity-norm trust region

        .. code-block:: python

            opt = tz.Optimizer(
                model.parameters(),
                tz.m.InfinityNormTrustRegion(hess_module=tz.m.BFGS(inverse=False)),
            )
    """
    def __init__(
        self,
        hess_module: Module,
        prefer_dense:bool=True,
        tol: float = 1e-10,
        eta: float= 0.0,
        nplus: float = 3.5,
        nminus: float = 0.25,
        rho_good: float = 0.99,
        rho_bad: float = 1e-4,
        boundary_tol: float | None = None,
        init: float = 1,
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(tol=tol, prefer_dense=prefer_dense)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
            eta=eta,
            nplus=nplus,
            nminus=nminus,
            rho_good=rho_good,
            rho_bad=rho_bad,
            boundary_tol=boundary_tol,
            init=init,
            max_attempts=max_attempts,
            radius_strategy=radius_strategy,
            update_freq=update_freq,
            inner=inner,

            radius_fn=torch.amax,
        )

    def trust_solve(self, f, g, H, radius, params, closure, settings):
        """Solve ``min ||H d - g||`` with ``|d_i| <= radius`` and return ``d`` as a tensor like ``g``."""
        if settings['prefer_dense'] and H.is_dense():
            # convert to array if possible to avoid many conversions
            # between torch and numpy, plus it seems that it uses
            # a better solver
            A = H.to_tensor().numpy(force=True).astype(np.float64)
        else:
            # memory efficient linear operator (is this still faster on CUDA?)
            A = H.scipy_linop()

        # the tolerance is stored under "tol" (see __init__); None means scipy's default
        tol = settings['tol']
        lsq_kwargs = {} if tol is None else {'tol': tol}

        try:
            d_np = lsq_linear(
                A,
                g.numpy(force=True).astype(np.float64),
                bounds=(-radius, radius),
                **lsq_kwargs,
            ).x
        except np.linalg.LinAlgError:
            self.children['hess_module'].reset()
            # fall back to g clipped into the L-infinity ball: the bound applies to
            # absolute values, so negative entries must be measured too
            g_max = g.abs().amax()
            if g_max > radius:
                g = g * (radius / g_max)
            return g

        return torch.as_tensor(d_np, device=g.device, dtype=g.dtype)

NewtonNewton

Bases: torchzero.core.transform.Transform

Applies Newton-like preconditioning to Newton step.

This is a method that I thought of and then it worked. Here is how it works:

  1. Calculate newton step by solving Hx=g

  2. Calculate jacobian of x wrt parameters and call it H2

  3. Solve H2 x2 = x for x2.

  4. Optionally, repeat (if order is higher than 3.)

Source code in torchzero/modules/experimental/newtonnewton.py
class NewtonNewton(Transform):
    """Applies Newton-like preconditioning to Newton step.

    This is a method that I thought of and then it worked. Here is how it works:

    1. Calculate newton step by solving Hx=g

    2. Calculate jacobian of x wrt parameters and call it H2

    3. Solve H2 x2 = x for x2.

    4. Optionally, repeat (if order is higher than 3.)

    Implementation: the matrices ``H, H2, ...`` computed along the way (each with ``reg`` added
    to its diagonal) are multiplied together into a single matrix ``P = H @ H2 @ ...``. The
    update is then obtained by solving ``P x = update``. When the update is ``g``, this is the
    same as solving ``H2 x2 = H^-1 g``, i.e. step 3 above. With ``order=2`` only the first
    matrix (the jacobian of the gradient) is computed, which gives a plain Newton step.

    A closure is required; a ``RuntimeError`` is raised without one.

    Args:
        reg (float, optional): value added to the diagonal of every computed matrix. Defaults to 1e-6.
        order (int, optional):
            the loop runs for levels ``2..order``, so ``order - 1`` matrices are computed. Defaults to 3.
        vectorize (bool, optional): passed as ``batched`` to ``jacobian_wrt``. Defaults to True.
        update_freq (int, optional):
            forwarded to ``Transform`` (presumably how often ``P`` is recomputed). Defaults to 1.
        inner (Chainable | None, optional): forwarded to ``Transform``. Defaults to None.
    """
    def __init__(
        self,
        reg: float = 1e-6,
        order: int = 3,
        vectorize: bool = True,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
        defaults = dict(order=order, reg=reg, vectorize=vectorize)
        super().__init__(defaults, update_freq=update_freq, inner=inner)

    @torch.no_grad
    def update_states(self, objective, states, settings):
        """Evaluate the closure, then build and store the matrix product ``P`` in ``global_state``."""
        fs = settings[0]

        params = TensorList(objective.params)
        closure = objective.closure
        if closure is None: raise RuntimeError('NewtonNewton requires closure')

        reg = fs['reg']
        vectorize = fs['vectorize']
        order = fs['order']

        # ------------------------ calculate grad and hessian ------------------------ #
        # P accumulates the product of every matrix computed below (H @ H2 @ ...)
        P = None
        with torch.enable_grad():
            loss = objective.loss = objective.loss_approx = closure(False)
            # create_graph=True so the gradient itself can be differentiated below
            g_list = torch.autograd.grad(loss, params, create_graph=True)
            objective.grads = list(g_list)

            # xp is the vector differentiated at each level: first the flattened gradient,
            # afterwards the solution of the previous level's linear system
            xp = torch.cat([t.ravel() for t in g_list])
            I = torch.eye(xp.numel(), dtype=xp.dtype, device=xp.device)

            for o in range(2, order + 1):
                is_last = o == order
                # the graph only needs to be kept if another level will differentiate through it
                H_list = jacobian_wrt([xp], params, create_graph=not is_last, batched=vectorize)
                with torch.no_grad() if is_last else nullcontext():
                    H = flatten_jacobian(H_list)
                    if reg != 0: H = H + I * reg
                    if P is None: P = H
                    else: P = P @ H

                    if not is_last:
                        # solve H x = xp, trying each solver helper in turn
                        # (cholesky -> lu -> least squares) until one returns a result
                        x = _try_cholesky_solve(H, xp)
                        if x is None: x = _try_lu_solve(H, xp)
                        if x is None: x = _least_squares_solve(H, xp)
                        xp = x.squeeze()

        self.global_state["P"] = P

    @torch.no_grad
    def apply_states(self, objective, states, settings):
        """Solve ``P x = update`` and write ``x`` back into the updates via ``vec_to_tensors_``."""
        updates = objective.get_updates()
        P = self.global_state['P']
        b = torch.cat([t.ravel() for t in updates])

        # same solver fallback chain as in update_states
        sol = _try_cholesky_solve(P, b)
        if sol is None: sol = _try_lu_solve(P, b)
        if sol is None: sol = _least_squares_solve(P, b)

        vec_to_tensors_(sol, updates)
        return objective

    @torch.no_grad
    def get_H(self, objective=...):
        """Return the stored matrix product ``P`` wrapped in ``Dense``."""
        return Dense(self.global_state["P"])

NewtonSolver

Bases: torchzero.core.module.Module

Matrix-free Newton via any custom solver (this is for testing; use NewtonCG or NystromPCG instead).

Source code in torchzero/modules/experimental/newton_solver.py
class NewtonSolver(Module):
    """Matrix-free Newton via any custom solver (this is for testing, use NewtonCG or NystromPCG).

    The Newton system ``(H + reg*I) x = b`` is solved approximately. Here ``b`` is the update
    produced by ``inner`` (or the incoming update when ``inner`` is None). The solver minimizes
    the least-squares residual ``mean(((H + reg*I) x - b)^2)`` using an optimizer created by
    ``solver``. Hessian-vector products come from ``objective.list_Hvp_function``; the number
    used per step is recorded in ``_num_hvps_last_step`` and accumulated in ``_num_hvps``.

    Args:
        solver (Callable):
            called with the tensors being solved for and must return an object with a
            ``step(closure)`` method that returns the loss. The closure accepts ``backward``.
            Defaults to torchzero's LBFGS.
        maxiter (int | None, optional):
            maximum number of ``solver.step`` calls per Newton step. If None, the total number of
            elements in the update is used. Defaults to None.
        maxiter1 (int | None, optional):
            if not None, replaces ``maxiter`` on the first step (when the solver is constructed).
        tol (float | None, optional):
            solving stops once ``loss / initial_loss < tol``. If None, ``maxiter`` steps are always
            run. Defaults to 1e-3.
        reg (float, optional): value added to the diagonal of the hessian. Defaults to 0.
        warm_start (bool, optional): start from the previous step's solution. Defaults to True.
        hvp_method (HVPMethod, optional): forwarded to ``list_Hvp_function``. Defaults to "autograd".
        reset_solver (bool, optional):
            construct a new solver on every step instead of reusing the first one. Defaults to False.
        h (float, optional): finite difference step size forwarded to ``list_Hvp_function``. Defaults to 1e-3.
        inner (Chainable | None, optional): module whose output is used as ``b``. Defaults to None.

    Raises:
        RuntimeError: if no closure is provided.
    """
    def __init__(
        self,
        solver: Callable[[list[torch.Tensor]], Any] = lambda p: Optimizer(p, LBFGS()),
        maxiter=None,
        maxiter1=None,
        tol:float | None=1e-3,
        reg: float = 0,
        warm_start=True,
        hvp_method: HVPMethod = "autograd",
        reset_solver: bool = False,
        h: float= 1e-3,

        inner: Chainable | None = None,
    ):
        # explicit dict: locals() in a method that uses zero-arg super() also contains the
        # implicit ``__class__`` cell, which must not end up in the settings
        defaults = dict(
            solver=solver,
            maxiter=maxiter,
            maxiter1=maxiter1,
            tol=tol,
            reg=reg,
            warm_start=warm_start,
            hvp_method=hvp_method,
            reset_solver=reset_solver,
            h=h,
        )
        super().__init__(defaults)

        self.set_child("inner", inner)

        self._num_hvps = 0
        self._num_hvps_last_step = 0

    @torch.no_grad
    def apply(self, objective):

        params = TensorList(objective.params)
        closure = objective.closure
        if closure is None: raise RuntimeError('NewtonSolver requires closure')

        settings = self.settings[params[0]]
        solver_cls = settings['solver']
        maxiter = settings['maxiter']
        maxiter1 = settings['maxiter1']
        tol = settings['tol']
        reg = settings['reg']
        hvp_method = settings['hvp_method']
        warm_start = settings['warm_start']
        h = settings['h']
        reset_solver = settings['reset_solver']

        self._num_hvps_last_step = 0

        # ---------------------- Hessian vector product function --------------------- #
        _, raw_H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)

        def H_mv(v):
            # counts every product and applies the diagonal regularization
            self._num_hvps_last_step += 1
            Hv = raw_H_mv(v)
            if reg != 0: Hv = Hv + v * reg
            return Hv

        # -------------------------------- inner step -------------------------------- #
        objective = self.inner_step("inner", objective, must_exist=False)
        b = TensorList(objective.get_updates())

        # ------------------------------- initial guess ------------------------------- #
        x0 = None
        if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
        if x0 is None: x = b.zeros_like().requires_grad_(True)
        else: x = x0.clone().requires_grad_(True)

        # ---------------------------------- solver ---------------------------------- #
        if 'solver' not in self.global_state:
            if maxiter1 is not None: maxiter = maxiter1
            solver = self.global_state['solver'] = solver_cls(x)
            self.global_state['x'] = x

        elif reset_solver:
            solver = self.global_state['solver'] = solver_cls(x)

        else:
            # reuse the existing solver: swap the storage of the tensors it already
            # optimizes for the new starting point so its internal state stays attached
            solver_params = self.global_state['x']
            solver_params.set_(x)
            x = solver_params
            solver = self.global_state['solver']

        def lstsq_closure(backward=True):
            residual = H_mv(x).detach() - b
            loss = residual.pow(2).global_mean()
            if backward:
                with torch.no_grad():
                    # d/dx mean((Ax - b)^2) = (2/n) A^T (Ax - b), and A = H + reg*I is symmetric
                    H_residual = H_mv(residual)
                    n = residual.global_numel()
                    x.set_grad_((2.0 / n) * H_residual)

            return loss

        if maxiter is None: maxiter = b.global_numel()
        initial_loss = lstsq_closure(False) if tol is not None else None # skip unnecessary closure if tol is None
        if initial_loss is None or initial_loss > torch.finfo(b[0].dtype).eps:
            for _ in range(maxiter):
                loss = solver.step(lstsq_closure)
                assert loss is not None
                if initial_loss is not None and loss/initial_loss < tol: break

        if warm_start:
            assert x0 is not None
            x0.copy_(x)

        objective.updates = x.detach()
        self._num_hvps += self._num_hvps_last_step
        return objective

ReduceOutwardLR

Bases: torchzero.core.transform.TensorTransform

When update sign matches weight sign, the learning rate for that weight is multiplied by mul.

This means updates that move weights towards zero have higher learning rates.

Warning

This sounded good but after testing turns out it sucks.

Source code in torchzero/modules/experimental/reduce_outward_lr.py
class ReduceOutwardLR(TensorTransform):
    """Scales down the learning rate of steps that push weights away from zero.

    Each weight is compared element-wise with a reference vector. The reference is the
    incoming update, or the gradient when ``use_grad=True``. Update elements where
    ``weight * reference < 0`` are multiplied by ``mul``; with ``invert=True`` the
    condition becomes ``weight * reference > 0``.

    Assume the update is subtracted from the parameters (it is treated as an ascent
    direction). Then the default condition selects steps whose sign matches the weight's
    sign, i.e. steps that move the weight away from zero. Steps towards zero therefore keep
    a relatively higher learning rate.

    Args:
        mul (float, optional): multiplier for the selected elements. Defaults to 0.5.
        use_grad (bool, optional): compare signs against the gradient instead of the update. Defaults to False.
        invert (bool, optional): flip the sign comparison. Defaults to False.

    Warning:
        This sounded good but after testing turns out it sucks.
    """
    def __init__(self, mul=0.5, use_grad=False, invert=False):
        super().__init__(dict(mul=mul, use_grad=use_grad, invert=invert), uses_grad=use_grad)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        weights = TensorList(params)
        upd = TensorList(tensors)

        factors = [cfg['mul'] for cfg in settings]
        invert = settings[0]['invert']

        reference = grads if self._uses_grad else upd
        assert reference is not None

        product = weights * reference
        selected = (product > 0) if invert else (product < 0)

        upd.masked_set_(selected, upd * factors)
        return upd

ScipyNewtonCG

Bases: torchzero.core.module.Module

NewtonCG with scipy solvers (any from scipy.sparse.linalg)

Source code in torchzero/modules/experimental/scipy_newton_cg.py
class ScipyNewtonCG(Module):
    """NewtonCG with scipy solvers (any from scipy.sparse.linalg).

    The hessian is exposed to scipy as a ``LinearOperator`` whose matvec computes a
    hessian-vector product. The system ``H x = b`` is then solved with ``solver``, where
    ``b`` is the update produced by ``inner`` (or the incoming update when ``inner`` is None).
    NaNs in ``b`` are replaced via ``nan_to_num`` before solving.

    Args:
        solver (Callable, optional):
            a function with the ``scipy.sparse.linalg`` iterative-solver interface, called as
            ``solver(A, b, x0=x0, **kwargs)``. It may return ``x`` or a tuple whose first element is
            ``x``. Defaults to ``gcrotmk``.
        hvp_method (str, optional):
            how hessian-vector products are computed: ``"autograd"``, ``"fd_forward"`` (alias
            ``"forward"``) or ``"fd_central"`` (alias ``"central"``). Defaults to "autograd".
        h (float, optional): finite difference step size. Defaults to 1e-3.
        warm_start (bool, optional): use the previous solution as ``x0``. Defaults to False.
        inner (Chainable | None, optional): module whose output is used as ``b``. Defaults to None.
        kwargs (dict | None, optional): extra keyword arguments forwarded to ``solver``. Defaults to None.

    Raises:
        RuntimeError: if no closure is provided.
        ValueError: if ``hvp_method`` is not recognized.
    """
    def __init__(
        self,
        solver = gcrotmk,
        hvp_method: Literal["fd_forward", "fd_central", "autograd", "forward", "central"] = "autograd",
        h: float = 1e-3,
        warm_start=False,
        inner: Chainable | None = None,
        kwargs: dict | None = None,
    ):
        defaults = dict(hvp_method=hvp_method, solver=solver, h=h, warm_start=warm_start)
        super().__init__(defaults,)

        if inner is not None:
            self.set_child('inner', inner)

        self._num_hvps = 0
        self._num_hvps_last_step = 0

        if kwargs is None: kwargs = {}
        self._kwargs = kwargs

    @torch.no_grad
    def apply(self, objective):
        params = TensorList(objective.params)
        closure = objective.closure
        if closure is None: raise RuntimeError('ScipyNewtonCG requires closure')

        fs = self.settings[params[0]]
        hvp_method = fs['hvp_method']
        solver = fs['solver']
        h = fs['h']
        warm_start = fs['warm_start']

        # the type hint advertises "fd_forward"/"fd_central"; keep the short names working too
        hvp_method = {'forward': 'fd_forward', 'central': 'fd_central'}.get(hvp_method, hvp_method)

        self._num_hvps_last_step = 0
        # ---------------------- Hessian vector product function --------------------- #
        device = params[0].device; dtype=params[0].dtype
        if hvp_method == 'autograd':
            grad = objective.get_grads(create_graph=True)

            def hvp(x):
                with torch.enable_grad():
                    return torch.autograd.grad(grad, params, x, retain_graph=True)

        else:

            with torch.enable_grad():
                grad = objective.get_grads()

            if hvp_method == 'fd_forward':
                def hvp(x):
                    return hvp_fd_forward(closure, params, x, h=h, g_0=grad)[1]

            elif hvp_method == 'fd_central':
                def hvp(x):
                    return hvp_fd_central(closure, params, x, h=h)[1]

            else:
                raise ValueError(hvp_method)

        def H_mm(x_np):
            # scipy calls this with a flat numpy vector; reshape it like the params and back
            self._num_hvps_last_step += 1
            x = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), grad)
            return torch.cat([t.ravel() for t in hvp(x)]).numpy(force=True)

        ndim = sum(p.numel() for p in params)
        # the hessian is symmetric, so the same product serves as rmatvec
        H = LinearOperator(shape=(ndim,ndim), matvec=H_mm, rmatvec=H_mm) # type:ignore

        # -------------------------------- inner step -------------------------------- #
        objective = self.inner_step("inner", objective, must_exist=False)
        b = TensorList(objective.get_updates())

        # ---------------------------------- run cg ---------------------------------- #
        x0 = None
        if warm_start: x0 = self.global_state.get('x_prev', None) # initialized to 0 which is default anyway

        x_np = solver(H, b.to_vec().nan_to_num().numpy(force=True), x0=x0, **self._kwargs)
        # scipy.sparse.linalg solvers return (x, info)
        if isinstance(x_np, tuple): x_np = x_np[0]

        if warm_start:
            self.global_state['x_prev'] = x_np

        objective.updates = vec_to_tensors(torch.as_tensor(x_np, device=device, dtype=dtype), params)

        self._num_hvps += self._num_hvps_last_step
        return objective

SubspaceCubicAdam

Bases: torchzero.modules.adaptive.lre_optimizers.LREOptimizerBase

Runs cubic Adam in low rank eigenbasis.

Source code in torchzero/modules/experimental/cubic_adam.py
class SubspaceCubicAdam(LREOptimizerBase):
    """Runs cubic Adam in low rank eigenbasis.

    The incoming vector is projected onto the basis ``Q`` (``Q.T @ g``). The cubic Adam update
    is computed there with bias correction, using first, second and third moment buffers kept
    in ``state``, and the result is mapped back with ``Q @ direction``. When the basis changes,
    ``reproject`` carries the moment buffers over using ``C = Q_new.T @ Q_old``.

    Args:
        beta1 (float, optional): decay of the first moment. Defaults to 0.9.
        beta2 (float, optional): decay of the second moment. Defaults to 0.95.
        beta3 (float, optional): decay of the third moment. Defaults to 0.95.
        eps (float, optional): epsilon forwarded to ``cubic_adam_``. Defaults to 1e-8.
        mode (_cubic_adam_mode, optional): forwarded to ``cubic_adam_``. Defaults to 'signed_cbrt'.
        cautious (bool, optional): stored on the instance; ``step`` does not currently read it. Defaults to False.
        exact_reproject (bool, optional): forwarded as ``exact`` to ``_squared_reproject``. Defaults to True.
    """
    def __init__(self, beta1=0.9, beta2=0.95, beta3=0.95, eps=1e-8, mode: _cubic_adam_mode = 'signed_cbrt', cautious:bool=False, exact_reproject:bool=True):
        self.beta1, self.beta2, self.beta3 = beta1, beta2, beta3
        self.eps = eps
        self.cautious = cautious
        self.mode: _cubic_adam_mode = mode
        self.exact_reproject = exact_reproject

    def step(self, g, L, Q, state):
        g_proj = Q.T @ g

        # lazily create the moment buffers in the projected space
        if "exp_avg" not in state:
            for key in ("exp_avg", "exp_avg_sq", "exp_avg_cu"):
                state[key] = torch.zeros_like(g_proj)
            state["current_step"] = 1

        step_idx = state["current_step"]
        direction = cubic_adam_(
            tensors=TensorList([g_proj]),
            exp_avg_=TensorList([state["exp_avg"]]),
            exp_avg_sq_=TensorList([state["exp_avg_sq"]]),
            exp_avg_cu_=TensorList([state["exp_avg_cu"]]),
            alpha=1,
            beta1=self.beta1,
            beta2=self.beta2,
            beta3=self.beta3,
            eps=self.eps,
            debiased=True,
            step=step_idx,
            mode=self.mode,
        )[0]
        state["current_step"] = step_idx + 1

        return Q @ direction

    def reproject(self, L_old, Q_old, L_new, Q_new, state):
        # nothing to carry over before the first step
        if "exp_avg" in state:
            change_of_basis = Q_new.T @ Q_old

            state["exp_avg"] = change_of_basis @ state["exp_avg"]
            state["exp_avg_sq"] = _squared_reproject(change_of_basis, state["exp_avg_sq"], exact=self.exact_reproject)
            # element-wise cubes of the basis change approximate the third-moment reprojection
            state["exp_avg_cu"] = change_of_basis.pow(3) @ state["exp_avg_cu"]

TensorizeProjection

Bases: torchzero.modules.projections.projection.ProjectionBase

flattens and concatenates all parameters into a vector and then reshapes it into a tensor

Source code in torchzero/modules/experimental/structural_projections.py
class TensorizeProjection(ProjectionBase):
    """Flattens and concatenates all parameters into a vector and then reshapes it into a tensor.

    If the total number of elements is smaller than ``max_side``, a single 1D vector is produced.
    Otherwise the vector is zero-padded and reshaped to ``[dim_size] * ndims``. Here ``ndims`` is
    the smallest integer with ``max_side ** ndims >= numel`` and ``dim_size`` is the smallest
    integer with ``dim_size ** ndims >= numel``. The number of padding elements is stored in
    ``global_state['remainder']`` and stripped again in ``unproject``.

    Args:
        modules (Chainable): modules applied in the projected space.
        max_side (int): maximum size of each dimension of the projected tensor (must be at least 2).
        project_update (bool, optional): forwarded to ``ProjectionBase``. Defaults to True.
        project_params (bool, optional): forwarded to ``ProjectionBase``. Defaults to False.
        project_grad (bool, optional): forwarded to ``ProjectionBase``. Defaults to False.
    """
    def __init__(self, modules: Chainable, max_side: int, project_update=True, project_params=False, project_grad=False):
        defaults = dict(max_side=max_side)
        super().__init__(modules, defaults=defaults, project_update=project_update, project_params=project_params, project_grad=project_grad)

    @staticmethod
    def _tensorized_dims(num_elems: int, max_side: int) -> list[int]:
        """Return the shape ``[dim_size] * ndims`` used to hold ``num_elems`` elements.

        Uses exact integer arithmetic: ``math.log`` can overshoot by one ulp
        (e.g. ``math.log(125, 5) == 3.0000000000000004``), which would add an extra dimension.
        """
        if max_side < 2:
            raise ValueError(f"max_side must be at least 2, got {max_side}")

        # smallest ndims such that max_side ** ndims >= num_elems
        ndims = 1
        while max_side ** ndims < num_elems:
            ndims += 1

        # smallest dim_size such that dim_size ** ndims >= num_elems (float guess, then exact fix-up)
        dim_size = max(1, round(num_elems ** (1 / ndims)))
        while dim_size ** ndims < num_elems:
            dim_size += 1
        while dim_size > 1 and (dim_size - 1) ** ndims >= num_elems:
            dim_size -= 1

        return [dim_size] * ndims

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        max_side = self.settings[params[0]]['max_side']
        # reshape instead of view so non-contiguous inputs work; build a plain list so that
        # appending padding never triggers element-wise TensorList arithmetic
        flat = [t.reshape(-1) for t in tensors]
        num_elems = sum(f.numel() for f in flat)

        if num_elems < max_side:
            self.global_state['remainder'] = 0
            # return 1d
            return [torch.cat(flat)]

        # determine appropriate shape to reshape into
        dims = self._tensorized_dims(num_elems, max_side)

        # add few extra zeros to vec to match a reshapable size
        remainder = math.prod(dims) - num_elems
        if remainder > 0:
            flat.append(torch.zeros(remainder, dtype=flat[0].dtype, device=flat[0].device))
        self.global_state['remainder'] = remainder

        # concatenate and reshape
        return [torch.cat(flat).view(dims)]

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        remainder = self.global_state['remainder']
        vec = projected_tensors[0].reshape(-1)
        # drop the zero padding added in project
        if remainder > 0: vec = vec[:-remainder]
        return vec_to_tensors(vec, params)