Operations

This subpackage contains operations such as adding and subtracting module outputs, grafting, tracking the maximum, and so on.
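
The modules below are meant to be chained inside a tz.Optimizer, in the same way as the Graft example further down this page. A minimal sketch of that pattern (assuming import torchzero as tz and an existing model, as in that example; the particular combination of Sign and LR is illustrative):

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Sign(),    # replace the update with its elementwise sign
    tz.m.LR(1e-3),  # scale by the learning rate
)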

Classes:

  • Abs

    Returns abs(input)

  • AccumulateMaximum

    Accumulates maximum of all past updates.

  • AccumulateMean

    Accumulates mean of all past updates.

  • AccumulateMinimum

    Accumulates minimum of all past updates.

  • AccumulateProduct

    Accumulates product of all past updates.

  • AccumulateSum

    Accumulates sum of all past updates.

  • Add

    Add other to tensors. other can be a number or a module.

  • BinaryOperationBase

    Base class for operations that use update as the first operand. This is an abstract class, subclass it and override transform method to use it.

  • CenteredEMASquared

    Maintains a centered exponential moving average of squared updates. This also maintains an additional exponential moving average of un-squared updates, the square of which is subtracted from the EMA.

  • CenteredSqrtEMASquared

    Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.

  • Clip

    Clip tensors to be in the (min, max) range. min and max can be None, numbers or modules.

  • ClipModules

    Calculates input(tensors).clip(min, max). min and max can be numbers or modules.

  • Clone

    Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations

  • CopyMagnitude

    Returns other(tensors) with sign copied from tensors.

  • CopySign

    Returns tensors with sign copied from other(tensors).

  • CustomUnaryOperation

    Applies getattr(tensor, name) to each tensor

  • Debias

    Multiplies the update by an Adam debiasing term based on first and/or second momentum.

  • Debias2

    Multiplies the update by an Adam debiasing term based on the second momentum.

  • Div

    Divide tensors by other. other can be a number or a module.

  • DivModules

    Calculates input / other. input and other can be numbers or modules.

  • EMASquared

    Maintains an exponential moving average of squared updates.

  • Exp

    Returns exp(input)

  • Fill

    Outputs tensors filled with value

  • Grad

    Outputs the gradient

  • GradToNone

    Sets grad attribute to None on objective.

  • Graft

    Outputs direction output rescaled to have the same norm as magnitude output.

  • GraftInputToOutput

    Outputs tensors rescaled to have the same norm as magnitude(tensors).

  • GraftOutputToInput

    Outputs magnitude(tensors) rescaled to have the same norm as tensors

  • GramSchimdt

    Outputs tensors made orthogonal to other(tensors) via Gram-Schmidt.

  • Identity

    Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods.

  • LerpModules

    Does a linear interpolation of input(tensors) and end(tensors) based on a scalar weight.

  • Maximum

    Outputs maximum(tensors, other(tensors))

  • MaximumModules

    Outputs elementwise maximum of inputs that can be modules or numbers.

  • Mean

    Outputs a mean of inputs that can be modules or numbers.

  • Minimum

    Outputs minimum(tensors, other(tensors))

  • MinimumModules

    Outputs elementwise minimum of inputs that can be modules or numbers.

  • Mul

    Multiply tensors by other. other can be a number or a module.

  • MultiOperationBase

    Base class for operations that use operands. This is an abstract class, subclass it and override transform method to use it.

  • NanToNum

    Convert nan, inf and -inf to numbers.

  • Negate

    Returns - input

  • Noop

    Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods.

  • Ones

    Outputs ones

  • Params

    Outputs parameters

  • Pow

    Take tensors to the power of exponent. exponent can be a number or a module.

  • PowModules

    Calculates input ** exponent. input and other can be numbers or modules.

  • Prod

    Outputs product of inputs that can be modules or numbers.

  • RCopySign

    Returns other(tensors) with sign copied from tensors.

  • RDiv

    Divide other by tensors. other can be a number or a module.

  • RPow

    Take other to the power of tensors. other can be a number or a module.

  • RSub

    Subtract tensors from other. other can be a number or a module.

  • Randn

    Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1.

  • RandomSample

    Outputs tensors filled with random numbers from distribution depending on value of distribution.

  • Reciprocal

    Returns 1 / input

  • ReduceOperationBase

    Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override transform method to use it.

  • Sign

    Returns sign(input)

  • Sqrt

    Returns sqrt(input)

  • SqrtEMASquared

    Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

  • Sub

    Subtract other from tensors. other can be a number or a module.

  • SubModules

    Calculates input - other. input and other can be numbers or modules.

  • Sum

    Outputs sum of inputs that can be modules or numbers.

  • Threshold

    Outputs tensors thresholded such that values above threshold are set to value.

  • UnaryLambda

    Applies fn to input tensors.

  • UnaryParameterwiseLambda

    Applies fn to each input tensor.

  • Uniform

    Outputs tensors filled with random numbers from uniform distribution between low and high.

  • UpdateToNone

    Sets update attribute to None on var.

  • WeightedMean

    Outputs weighted mean of inputs that can be modules or numbers.

  • WeightedSum

    Outputs a weighted sum of inputs that can be modules or numbers.

  • Zeros

    Outputs zeros

Abs

Bases: torchzero.core.transform.TensorTransform

Returns abs(input)

Source code in torchzero/modules/ops/unary.py
class Abs(TensorTransform):
    """Returns ``abs(input)``"""
    def __init__(self): super().__init__()
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_abs_(tensors)
        return tensors

AccumulateMaximum

Bases: torchzero.core.transform.TensorTransform

Accumulates maximum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMaximum(TensorTransform):
    """Accumulates maximum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0):
        defaults = dict(decay=decay)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return maximum.maximum_(tensors).lazy_mul(decay, clone=True)

AccumulateMean

Bases: torchzero.core.transform.TensorTransform

Accumulates mean of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMean(TensorTransform):
    """Accumulates mean of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0):
        defaults = dict(decay=decay)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
        mean = unpack_states(states, tensors, 'mean', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return mean.add_(tensors).lazy_mul(decay, clone=True).div_(step)

AccumulateMinimum

Bases: torchzero.core.transform.TensorTransform

Accumulates minimum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMinimum(TensorTransform):
    """Accumulates minimum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0):
        defaults = dict(decay=decay)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return minimum.minimum_(tensors).lazy_mul(decay, clone=True)

AccumulateProduct

Bases: torchzero.core.transform.TensorTransform

Accumulates product of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Target, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateProduct(TensorTransform):
    """Accumulates product of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        prod = unpack_states(states, tensors, 'prod', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return prod.mul_(tensors).lazy_mul(decay, clone=True)

AccumulateSum

Bases: torchzero.core.transform.TensorTransform

Accumulates sum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateSum(TensorTransform):
    """Accumulates sum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0):
        defaults = dict(decay=decay)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        sum = unpack_states(states, tensors, 'sum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return sum.add_(tensors).lazy_mul(decay, clone=True)
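
Reading the Accumulate* sources above: each step updates the stored accumulator in place and returns a scaled copy, so decay scales the output rather than the stored state. A rough per-tensor sketch of AccumulateSum's step (unpack_states, TensorList and lazy_mul are simplified away):

import torch

def accumulate_sum_step(state: dict, update: torch.Tensor, decay: float = 0.0) -> torch.Tensor:
    # in-place running sum, mirroring sum.add_(tensors) in the source above
    if "sum" not in state:
        state["sum"] = torch.zeros_like(update)
    state["sum"] += update
    # the returned copy is scaled by (1 - decay); the stored sum is left untouched
    return state["sum"] * (1 - decay)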

Add

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Add other to tensors. other can be a number or a module.

If other is a module, this calculates tensors + other(tensors)

Source code in torchzero/modules/ops/binary.py
class Add(BinaryOperationBase):
    """Add ``other`` to tensors. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``tensors + other(tensors)``
    """
    def __init__(self, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.defaults['alpha'])
        else: torch._foreach_add_(update, other, alpha=self.defaults['alpha'])
        return update
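
As a usage sketch (same tz.Optimizer chaining as the Graft example below; tz.m.Randn is taken from this page, and using it as a noise source here is purely illustrative):

# assumes: import torchzero as tz; `model` is an existing nn.Module
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Add(other=tz.m.Randn(), alpha=0.01),  # update + 0.01 * standard-normal noise
    tz.m.LR(1e-2),
)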

BinaryOperationBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for operations that use update as the first operand. This is an abstract class, subclass it and override transform method to use it.

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/binary.py
class BinaryOperationBase(Module, ABC):
    """Base class for operations that use update as the first operand. This is an abstract class, subclass it and override `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = {}
        for k,v in operands.items():

            if isinstance(v, (Module, Sequence)):
                self.set_child(k, v)
                self.operands[k] = self.children[k]
            else:
                self.operands[k] = v

    @abstractmethod
    def transform(self, objective: Objective, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    @torch.no_grad
    def step(self, objective: Objective) -> Objective:
        # pass cloned update to all module operands
        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()

        for k,v in self.operands.items():
            if k in self.children:
                v: Module
                updated_obj = v.step(objective.clone(clone_updates=True))
                processed_operands[k] = updated_obj.get_updates()
                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them

        transformed = self.transform(objective, update=objective.get_updates(), **processed_operands)
        objective.updates = list(transformed)
        return objective

transform

transform(objective: Objective, update: list[Tensor], **operands: Any | list[Tensor]) -> Iterable[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/binary.py
@abstractmethod
def transform(self, objective: Objective, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError
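
Subclassing follows the same pattern as Add or Mul above: pass module operands to __init__ as keyword arguments and implement transform, which receives each operand either as a number or as that module's output tensors. A sketch of a custom binary operation (the class below is hypothetical, not part of torchzero):

import torch
from torchzero.modules.ops.binary import BinaryOperationBase

class AddSquared(BinaryOperationBase):
    """Hypothetical example: adds ``other(tensors) ** 2`` to the update."""
    def __init__(self, other):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update, other):
        if isinstance(other, (int, float)):
            torch._foreach_add_(update, other ** 2)
        else:
            # `other` is the operand module's output as a list of tensors
            torch._foreach_add_(update, torch._foreach_mul(other, other))
        return update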

CenteredEMASquared

Bases: torchzero.core.transform.TensorTransform

Maintains a centered exponential moving average of squared updates. This also maintains an additional exponential moving average of un-squared updates, square of which is subtracted from the EMA.

Parameters:

  • beta (float, default: 0.99 ) –

    momentum value. Defaults to 0.99.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Source code in torchzero/modules/ops/higher_level.py
class CenteredEMASquared(TensorTransform):
    """Maintains a centered exponential moving average of squared updates. This also maintains an additional
    exponential moving average of un-squared updates, square of which is subtracted from the EMA.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2):
        defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return centered_ema_sq_(
            TensorList(tensors),
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            max_exp_avg_sq_=max_exp_avg_sq,
            pow=pow,
        ).clone()
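
In plain per-tensor terms, "centered" means the square of an EMA of raw updates is subtracted from the EMA of squared updates, giving a running variance estimate. A minimal sketch of one step, following the description above (pow and amsgrad handling in centered_ema_sq_ is not reproduced):

import torch

def centered_ema_sq_step(exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor,
                         update: torch.Tensor, beta: float = 0.99) -> torch.Tensor:
    exp_avg.lerp_(update, 1 - beta)              # EMA of raw updates
    exp_avg_sq.lerp_(update * update, 1 - beta)  # EMA of squared updates
    return exp_avg_sq - exp_avg * exp_avg        # subtract the squared mean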

CenteredSqrtEMASquared

Bases: torchzero.core.transform.TensorTransform

Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root. This also maintains an additional exponential moving average of un-squared updates, square of which is subtracted from the EMA.

Parameters:

  • beta (float, default: 0.99 ) –

    momentum value. Defaults to 0.99.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • debiased (bool, default: False ) –

    whether to multiply the output by a debiasing term from the Adam method. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Source code in torchzero/modules/ops/higher_level.py
class CenteredSqrtEMASquared(TensorTransform):
    """Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.
    This also maintains an additional exponential moving average of un-squared updates, square of which is subtracted from the EMA.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2):
        defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return sqrt_centered_ema_sq_(
            TensorList(tensors),
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            debiased=debiased,
            step=step,
            max_exp_avg_sq_=max_exp_avg_sq,
            pow=pow,
        )

Clip

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Clip tensors to be in the (min, max) range. min and max can be None, numbers or modules.

If min and max are modules, this calculates tensors.clip(min(tensors), max(tensors)).

Source code in torchzero/modules/ops/binary.py
class Clip(BinaryOperationBase):
    """clip tensors to be in  ``(min, max)`` range. ``min`` and ``max`: can be None, numbers or modules.

    If ``min`` and ``max``  are modules, this calculates ``tensors.clip(min(tensors), max(tensors))``.
    """
    def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
        super().__init__({}, min=min, max=max)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
        return TensorList(update).clamp_(min=min,  max=max)
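
A usage sketch (same chaining assumption as the other examples on this page): clamp the update elementwise before applying a learning rate.

# assumes: import torchzero as tz; `model` is an existing nn.Module
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Clip(min=-1, max=1),  # elementwise clamp of the update
    tz.m.LR(1e-2),
)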

ClipModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates input(tensors).clip(min, max). min and max can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class ClipModules(MultiOperationBase):
    """Calculates ``input(tensors).clip(min, max)``. ``min`` and ``max`` can be numbers or modules."""
    def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
        defaults = {}
        super().__init__(defaults, input=input, min=min, max=max)

    @torch.no_grad
    def transform(self, objective: Objective, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
        return TensorList(input).clamp_(min=min, max=max)

Clone

Bases: torchzero.core.module.Module

Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations

Source code in torchzero/modules/ops/utility.py
class Clone(Module):
    """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def apply(self, objective):
        objective.updates = [u.clone() for u in objective.get_updates()]
        return objective

CopyMagnitude

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns other(tensors) with sign copied from tensors.

Source code in torchzero/modules/ops/binary.py
class RCopySign(BinaryOperationBase):
    """Returns ``other(tensors)`` with sign copied from tensors."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [o.copysign_(u) for u, o in zip(update, other)]

CopySign

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns tensors with sign copied from other(tensors).

Source code in torchzero/modules/ops/binary.py
class CopySign(BinaryOperationBase):
    """Returns tensors with sign copied from ``other(tensors)``."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [u.copysign_(o) for u, o in zip(update, other)]

CustomUnaryOperation

Bases: torchzero.core.transform.TensorTransform

Applies getattr(tensor, name) to each tensor

Source code in torchzero/modules/ops/unary.py
class CustomUnaryOperation(TensorTransform):
    """Applies ``getattr(tensor, name)`` to each tensor
    """
    def __init__(self, name: str):
        defaults = dict(name=name)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        return getattr(tensors, settings[0]['name'])()

Debias

Bases: torchzero.core.transform.TensorTransform

Multiplies the update by an Adam debiasing term based on first and/or second momentum.

Parameters:

  • beta1 (float | None, default: None ) –

    first momentum, should be the same as first momentum used in modules before. Defaults to None.

  • beta2 (float | None, default: None ) –

    second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.

  • alpha (float, default: 1 ) –

    learning rate. Defaults to 1.

  • pow (float, default: 2 ) –

    power, assumes absolute value is used. Defaults to 2.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/higher_level.py
class Debias(TensorTransform):
    """Multiplies the update by an Adam debiasing term based first and/or second momentum.

    Args:
        beta1 (float | None, optional):
            first momentum, should be the same as first momentum used in modules before. Defaults to None.
        beta2 (float | None, optional):
            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
        alpha (float, optional): learning rate. Defaults to 1.
        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2):
        defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        pow = settings[0]['pow']
        alpha, beta1, beta2 = unpack_dicts(settings, 'alpha', 'beta1', 'beta2', cls=NumberList)

        return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)
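
For reference, the usual Adam bias-correction multiplies the update by alpha * (1 - beta2**step)**(1/pow) / (1 - beta1**step), dropping either factor when the corresponding beta is None. Whether torchzero's debias helper implements exactly this expression is an assumption; the sketch below only illustrates the standard formula:

def adam_debias_term(step: int, beta1: float | None, beta2: float | None,
                     alpha: float = 1.0, pow: float = 2) -> float:
    term = alpha
    if beta1 is not None:
        term /= 1 - beta1 ** step                  # first-moment bias correction
    if beta2 is not None:
        term *= (1 - beta2 ** step) ** (1 / pow)   # second-moment correction (sqrt when pow=2)
    return term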

Debias2

Bases: torchzero.core.transform.TensorTransform

Multiplies the update by an Adam debiasing term based on the second momentum.

Parameters:

  • beta (float | None, default: 0.999 ) –

    second (squared) momentum, should be the same as second momentum used in modules before. Defaults to 0.999.

  • pow (float, default: 2 ) –

    power, assumes absolute value is used. Defaults to 2.

  • target (Target) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/higher_level.py
class Debias2(TensorTransform):
    """Multiplies the update by an Adam debiasing term based on the second momentum.

    Args:
        beta (float | None, optional):
            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, beta: float = 0.999, pow: float = 2,):
        defaults = dict(beta=beta, pow=pow)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        pow = settings[0]['pow']
        beta = NumberList(s['beta'] for s in settings)
        return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)

Div

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Divide tensors by other. other can be a number or a module.

If other is a module, this calculates tensors / other(tensors)

Source code in torchzero/modules/ops/binary.py
class Div(BinaryOperationBase):
    """Divide tensors by ``other``. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``tensors / other(tensors)``
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        torch._foreach_div_(update, other)
        return update

DivModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates input / other. input and other can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class DivModules(MultiOperationBase):
    """Calculates ``input / other``. ``input`` and ``other`` can be numbers or modules."""
    def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
        defaults = {}
        if other_first: super().__init__(defaults, other=other, input=input)
        else: super().__init__(defaults, input=input, other=other)

    @torch.no_grad
    def transform(self, objective: Objective, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
        if isinstance(input, (int,float)):
            assert isinstance(other, list)
            return input / TensorList(other)

        torch._foreach_div_(input, other)
        return input

EMASquared

Bases: torchzero.core.transform.TensorTransform

Maintains an exponential moving average of squared updates.

Parameters:

  • beta (float, default: 0.999 ) –

    momentum value. Defaults to 0.999.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Methods:

  • EMA_SQ_FN

    Updates exp_avg_sq_ with EMA of squared tensors; if max_exp_avg_sq_ is not None, updates it with the maximum of the EMA.

Source code in torchzero/modules/ops/higher_level.py
class EMASquared(TensorTransform):
    """Maintains an exponential moving average of squared updates.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    EMA_SQ_FN: staticmethod = staticmethod(ema_sq_)

    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()

EMA_SQ_FN

EMA_SQ_FN(tensors: TensorList, exp_avg_sq_: TensorList, beta: float | NumberList, max_exp_avg_sq_: TensorList | None, pow: float = 2)

Updates exp_avg_sq_ with EMA of squared tensors, if max_exp_avg_sq_ is not None, updates it with maximum of EMA.

Returns exp_avg_sq_ or max_exp_avg_sq_.

Source code in torchzero/modules/opt_utils.py
def ema_sq_(
    tensors: TensorList,
    exp_avg_sq_: TensorList,
    beta: float | NumberList,
    max_exp_avg_sq_: TensorList | None,
    pow: float = 2,
):
    """
    Updates `exp_avg_sq_` with EMA of squared `tensors`, if `max_exp_avg_sq_` is not None, updates it with maximum of EMA.

    Returns `exp_avg_sq_` or `max_exp_avg_sq_`.
    """
    lerp_power_(tensors=tensors, exp_avg_pow_=exp_avg_sq_,beta=beta,pow=pow)

    # AMSGrad
    if max_exp_avg_sq_ is not None:
        max_exp_avg_sq_.maximum_(exp_avg_sq_)
        exp_avg_sq_ = max_exp_avg_sq_

    return exp_avg_sq_
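
The lerp_power_ helper called above is not shown on this page; for pow=2 its per-tensor effect is the usual EMA-of-squares recurrence (a sketch of the assumed behaviour, not the helper itself):

import torch

def ema_sq_step(exp_avg_sq: torch.Tensor, update: torch.Tensor, beta: float) -> torch.Tensor:
    # exp_avg_sq <- beta * exp_avg_sq + (1 - beta) * update**2
    return exp_avg_sq.lerp_(update * update, 1 - beta)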

Exp

Bases: torchzero.core.transform.TensorTransform

Returns exp(input)

Source code in torchzero/modules/ops/unary.py
class Exp(TensorTransform):
    """Returns ``exp(input)``"""
    def __init__(self): super().__init__()
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_exp_(tensors)
        return tensors

Fill

Bases: torchzero.core.module.Module

Outputs tensors filled with value

Source code in torchzero/modules/ops/utility.py
class Fill(Module):
    """Outputs tensors filled with ``value``"""
    def __init__(self, value: float):
        defaults = dict(value=value)
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective):
        objective.updates = [torch.full_like(p, self.settings[p]['value']) for p in objective.params]
        return objective

Grad

Bases: torchzero.core.module.Module

Outputs the gradient

Source code in torchzero/modules/ops/utility.py
class Grad(Module):
    """Outputs the gradient"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def apply(self, objective):
        objective.updates = [g.clone() for g in objective.get_grads()]
        return objective

GradToNone

Bases: torchzero.core.module.Module

Sets grad attribute to None on objective.

Source code in torchzero/modules/ops/utility.py
class GradToNone(Module):
    """Sets ``grad`` attribute to None on ``objective``."""
    def __init__(self): super().__init__()
    def apply(self, objective):
        objective.grads = None
        return objective

Graft

Bases: torchzero.modules.ops.multi.MultiOperationBase

Outputs direction output rescaled to have the same norm as magnitude output.

Parameters:

  • direction (Chainable) –

    module to use the direction from

  • magnitude (Chainable) –

    module to use the magnitude from

  • tensorwise (bool, default: True ) –

    whether to calculate norm per-tensor or globally. Defaults to True.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • eps (float, default: 1e-06 ) –

    clips denominator to be no less than this value. Defaults to 1e-6.

  • strength (float, default: 1 ) –

    strength of grafting. Defaults to 1.

Example:

Shampoo grafted to Adam

opt = tz.Optimizer(
    model.parameters(),
    tz.m.GraftModules(
        direction = tz.m.Shampoo(),
        magnitude = tz.m.Adam(),
    ),
    tz.m.LR(1e-3)
)

Reference

Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803.

Source code in torchzero/modules/ops/multi.py
class Graft(MultiOperationBase):
    """Outputs ``direction`` output rescaled to have the same norm as ``magnitude`` output.

    Args:
        direction (Chainable): module to use the direction from
        magnitude (Chainable): module to use the magnitude from
        tensorwise (bool, optional): whether to calculate norm per-tensor or globally. Defaults to True.
        ord (float, optional): norm order. Defaults to 2.
        eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
        strength (float, optional): strength of grafting. Defaults to 1.

    ### Example:

    Shampoo grafted to Adam
    ```python
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.GraftModules(
            direction = tz.m.Shampoo(),
            magnitude = tz.m.Adam(),
        ),
        tz.m.LR(1e-3)
    )
    ```

    Reference:
        [Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803.](https://arxiv.org/pdf/2002.11803)
    """
    def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:Metrics=2, eps:float = 1e-6, strength:float=1):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
        super().__init__(defaults, direction=direction, magnitude=magnitude)

    @torch.no_grad
    def transform(self, objective, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.defaults)
        return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
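
In per-tensor terms (tensorwise=True, ord=2), grafting keeps direction's direction and gives it magnitude's norm. A sketch of the core computation described above (eps guards the denominator; strength handling is omitted):

import torch

def graft(direction: torch.Tensor, magnitude: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # rescale `direction` to have the same L2 norm as `magnitude`
    return direction * (magnitude.norm() / direction.norm().clip(min=eps))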

GraftInputToOutput

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs tensors rescaled to have the same norm as magnitude(tensors).

Source code in torchzero/modules/ops/binary.py
class GraftInputToOutput(BinaryOperationBase):
    """Outputs ``tensors`` rescaled to have the same norm as ``magnitude(tensors)``."""
    def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, magnitude=magnitude)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
        return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)

GraftOutputToInput

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs magnitude(tensors) rescaled to have the same norm as tensors

Source code in torchzero/modules/ops/binary.py
class GraftOutputToInput(BinaryOperationBase):
    """Outputs ``magnitude(tensors)`` rescaled to have the same norm as ``tensors``"""

    def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, direction=direction)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], direction: list[torch.Tensor]):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
        return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)

GramSchimdt

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs tensors made orthogonal to other(tensors) via Gram-Schmidt.

Source code in torchzero/modules/ops/binary.py
class GramSchimdt(BinaryOperationBase):
    """outputs tensors made orthogonal to ``other(tensors)`` via Gram-Schmidt."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        update = TensorList(update); other = TensorList(other)
        min = torch.finfo(update[0].dtype).tiny * 2
        return update - (other*update) / (other*other).clip(min=min)

Identity

Bases: torchzero.core.module.Module

Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods.

Source code in torchzero/modules/ops/utility.py
class Identity(Module):
    """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
    def __init__(self, *args, **kwargs): super().__init__()
    def update(self, objective): pass
    def apply(self, objective): return objective
    def get_H(self, objective):
        n = sum(p.numel() for p in objective.params)
        p = objective.params[0]
        return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)

LerpModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Does a linear interpolation of input(tensors) and end(tensors) based on a scalar weight.

The output is given by output = input(tensors) + weight * (end(tensors) - input(tensors))

Source code in torchzero/modules/ops/multi.py
class LerpModules(MultiOperationBase):
    """Does a linear interpolation of ``input(tensors)`` and ``end(tensors)`` based on a scalar ``weight``.

    The output is given by ``output = input(tensors) + weight * (end(tensors) - input(tensors))``
    """
    def __init__(self, input: Chainable, end: Chainable, weight: float):
        defaults = dict(weight=weight)
        super().__init__(defaults, input=input, end=end)

    @torch.no_grad
    def transform(self, objective: Objective, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
        torch._foreach_lerp_(input, end, weight=self.defaults['weight'])
        return input

Maximum

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs maximum(tensors, other(tensors))

Source code in torchzero/modules/ops/binary.py
class Maximum(BinaryOperationBase):
    """Outputs ``maximum(tensors, other(tensors))``"""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        torch._foreach_maximum_(update, other)
        return update

MaximumModules

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs elementwise maximum of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class MaximumModules(ReduceOperationBase):
    """Outputs elementwise maximum of ``inputs`` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        maximum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_maximum_(maximum, v)

        return maximum

Mean

Bases: torchzero.modules.ops.reduce.Sum

Outputs a mean of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class Mean(Sum):
    """Outputs a mean of ``inputs`` that can be modules or numbers."""
    USE_MEAN = True

USE_MEAN class-attribute

USE_MEAN = True


Minimum

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs minimum(tensors, other(tensors))

Source code in torchzero/modules/ops/binary.py
class Minimum(BinaryOperationBase):
    """Outputs ``minimum(tensors, other(tensors))``"""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        torch._foreach_minimum_(update, other)
        return update

MinimumModules

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs elementwise minimum of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class MinimumModules(ReduceOperationBase):
    """Outputs elementwise minimum of ``inputs`` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        minimum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_minimum_(minimum, v)

        return minimum

Mul

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Multiply tensors by other. other can be a number or a module.

If other is a module, this calculates tensors * other(tensors)

Source code in torchzero/modules/ops/binary.py
class Mul(BinaryOperationBase):
    """Multiply tensors by ``other``. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``tensors * other(tensors)``
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        torch._foreach_mul_(update, other)
        return update

MultiOperationBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for operations that use operands. This is an abstract class, subclass it and override transform method to use it.

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/multi.py
class MultiOperationBase(Module, ABC):
    """Base class for operations that use operands. This is an abstract class, subclass it and override `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = {}
        for k,v in operands.items():

            if isinstance(v, (Module, Sequence)):
                self.set_child(k, v)
                self.operands[k] = self.children[k]
            else:
                self.operands[k] = v

        if not self.children:
            raise ValueError('At least one operand must be a module')

    @abstractmethod
    def transform(self, objective: Objective, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    @torch.no_grad
    def step(self, objective: Objective) -> Objective:
        # pass cloned update to all module operands
        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()

        for k,v in self.operands.items():
            if k in self.children:
                v: Module
                updated_obj = v.step(objective.clone(clone_updates=True))
                processed_operands[k] = updated_obj.get_updates()
                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them

        transformed = self.transform(objective, **processed_operands)
        objective.updates = transformed
        return objective

transform

transform(objective: Objective, **operands: Any | list[Tensor]) -> list[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/multi.py
@abstractmethod
def transform(self, objective: Objective, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError

NanToNum

Bases: torchzero.core.transform.TensorTransform

Convert nan, inf and -inf to numbers.

Parameters:

  • nan (optional, default: None ) –

    the value to replace NaNs with. Default is zero.

  • posinf (optional, default: None ) –

    if a Number, the value to replace positive infinity values with. If None, positive infinity values are replaced with the greatest finite value representable by input's dtype. Default is None.

  • neginf (optional, default: None ) –

    if a Number, the value to replace negative infinity values with. If None, negative infinity values are replaced with the lowest finite value representable by input's dtype. Default is None.

Source code in torchzero/modules/ops/unary.py
class NanToNum(TensorTransform):
    """Convert ``nan``, ``inf`` and `-`inf`` to numbers.

    Args:
        nan (optional): the value to replace NaNs with. Default is zero.
        posinf (optional): if a Number, the value to replace positive infinity values with.
            If None, positive infinity values are replaced with the greatest finite value
            representable by input's dtype. Default is None.
        neginf (optional): if a Number, the value to replace negative infinity values with.
            If None, negative infinity values are replaced with the lowest finite value
            representable by input's dtype. Default is None.
    """
    def __init__(self, nan=None, posinf=None, neginf=None):
        defaults = dict(nan=nan, posinf=posinf, neginf=neginf)
        super().__init__(defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
        return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
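
A usage sketch (same chaining assumption as the other examples): sanitize the update before the learning-rate step, replacing NaNs with 0 and leaving posinf/neginf as None so infinities become the dtype's finite extremes.

# assumes: import torchzero as tz; `model` is an existing nn.Module
opt = tz.Optimizer(
    model.parameters(),
    tz.m.NanToNum(nan=0.0),
    tz.m.LR(1e-3),
)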

Negate

Bases: torchzero.core.transform.TensorTransform

Returns - input

Source code in torchzero/modules/ops/unary.py
class Negate(TensorTransform):
    """Returns ``- input``"""
    def __init__(self): super().__init__()
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_neg_(tensors)
        return tensors

Noop

Bases: torchzero.core.module.Module

Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods.

Source code in torchzero/modules/ops/utility.py
class Identity(Module):
    """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
    def __init__(self, *args, **kwargs): super().__init__()
    def update(self, objective): pass
    def apply(self, objective): return objective
    def get_H(self, objective):
        n = sum(p.numel() for p in objective.params)
        p = objective.params[0]
        return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)

Ones

Bases: torchzero.core.module.Module

Outputs ones

Source code in torchzero/modules/ops/utility.py
class Ones(Module):
    """Outputs ones"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def apply(self, objective):
        objective.updates = [torch.ones_like(p) for p in objective.params]
        return objective

Params

Bases: torchzero.core.module.Module

Outputs parameters

Source code in torchzero/modules/ops/utility.py
class Params(Module):
    """Outputs parameters"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def apply(self, objective):
        objective.updates = [p.clone() for p in objective.params]
        return objective

Pow

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Take tensors to the power of exponent. exponent can be a number or a module.

If exponent is a module, this calculates tensors ^ exponent(tensors)

Source code in torchzero/modules/ops/binary.py
class Pow(BinaryOperationBase):
    """Take tensors to the power of ``exponent``. ``exponent`` can be a number or a module.

    If ``exponent`` is a module, this calculates ``tensors ^ exponent(tensors)``
    """
    def __init__(self, exponent: Chainable | float):
        super().__init__({}, exponent=exponent)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
        torch._foreach_pow_(update, exponent)
        return update

PowModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates input ** exponent. input and other can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class PowModules(MultiOperationBase):
    """Calculates ``input ** exponent``. ``input`` and ``other`` can be numbers or modules."""
    def __init__(self, input: Chainable | float, exponent: Chainable | float):
        defaults = {}
        super().__init__(defaults, input=input, exponent=exponent)

    @torch.no_grad
    def transform(self, objective: Objective, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
        if isinstance(input, (int,float)):
            assert isinstance(exponent, list)
            return input ** TensorList(exponent)

        torch._foreach_pow_(input, exponent)
        return input

Prod

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs product of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class Prod(ReduceOperationBase):
    """Outputs product of ``inputs`` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        prod = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_mul_(prod, v)

        return prod

RCopySign

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns other(tensors) with sign copied from tensors.

Source code in torchzero/modules/ops/binary.py
class RCopySign(BinaryOperationBase):
    """Returns ``other(tensors)`` with sign copied from tensors."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [o.copysign_(u) for u, o in zip(update, other)]

RDiv

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Divide other by tensors. other can be a number or a module.

If other is a module, this calculates other(tensors) / tensors

Source code in torchzero/modules/ops/binary.py
class RDiv(BinaryOperationBase):
    """Divide ``other`` by tensors. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``other(tensors) / tensors``
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        return other / TensorList(update)

RPow

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Take other to the power of tensors. other can be a number or a module.

If other is a module, this calculates other(tensors) ^ tensors

Source code in torchzero/modules/ops/binary.py
class RPow(BinaryOperationBase):
    """Take ``other`` to the power of tensors. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``other(tensors) ^ tensors``
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
        torch._foreach_pow_(other, update)
        return other

RSub

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Subtract tensors from other. other can be a number or a module.

If other is a module, this calculates other(tensors) - tensors

Source code in torchzero/modules/ops/binary.py
class RSub(BinaryOperationBase):
    """Subtract tensors from ``other``. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates ``other(tensors) - tensors``
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        return other - TensorList(update)

Randn

Bases: torchzero.core.module.Module

Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1.

Source code in torchzero/modules/ops/utility.py
class Randn(Module):
    """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def apply(self, objective):
        objective.updates = [torch.randn_like(p) for p in objective.params]
        return objective

RandomSample

Bases: torchzero.core.module.Module

Outputs tensors filled with random numbers from distribution depending on value of distribution.

Source code in torchzero/modules/ops/utility.py
class RandomSample(Module):
    """Outputs tensors filled with random numbers from distribution depending on value of ``distribution``."""
    def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
        defaults = dict(distribution=distribution, variance=variance)
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective):
        distribution = self.defaults['distribution']
        variance = self.get_settings(objective.params, 'variance')
        objective.updates = TensorList(objective.params).sample_like(distribution=distribution, variance=variance)
        return objective

Reciprocal

Bases: torchzero.core.transform.TensorTransform

Returns 1 / input

Source code in torchzero/modules/ops/unary.py
class Reciprocal(TensorTransform):
    """Returns ``1 / input``"""
    def __init__(self, eps = 0):
        defaults = dict(eps = eps)
        super().__init__(defaults)
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        eps = [s['eps'] for s in settings]
        if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
        torch._foreach_reciprocal_(tensors)
        return tensors

ReduceOperationBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override transform method to use it.

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/reduce.py
class ReduceOperationBase(Module, ABC):
    """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = []
        for i, v in enumerate(operands):

            if isinstance(v, (Module, Sequence)):
                self.set_child(f'operand_{i}', v)
                self.operands.append(self.children[f'operand_{i}'])
            else:
                self.operands.append(v)

        if not self.children:
            raise ValueError('At least one operand must be a module')

    @abstractmethod
    def transform(self, objective: Objective, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    def update(self, objective): raise RuntimeError
    def apply(self, objective): raise RuntimeError

    @torch.no_grad
    def step(self, objective: Objective) -> Objective:
        # pass cloned update to all module operands
        processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()

        for i, v in enumerate(self.operands):
            if f'operand_{i}' in self.children:
                v: Module
                updated_obj = v.step(objective.clone(clone_updates=True))
                processed_operands[i] = updated_obj.get_updates()
                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them

        transformed = self.transform(objective, *processed_operands)
        objective.updates = transformed
        return objective

transform

transform(objective: Objective, *operands: Any | list[Tensor]) -> list[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/reduce.py
@abstractmethod
def transform(self, objective: Objective, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError
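
A custom reduction only needs to subclass ReduceOperationBase and override transform. The sketch below is illustrative (ElementwiseMax is a hypothetical name, not part of the library) and assumes the imports used by the built-in reductions:

import torch

class ElementwiseMax(ReduceOperationBase):
    """Hypothetical reduction: elementwise maximum of all operands."""
    def __init__(self, *operands):
        super().__init__({}, *operands)

    @torch.no_grad
    def transform(self, objective, *operands):
        # put a tensor-list operand first so it can be modified in place
        ops = sorted(operands, key=lambda x: isinstance(x, (int, float)))
        result = ops[0]
        for v in ops[1:]:
            if isinstance(v, (int, float)): torch._foreach_clamp_min_(result, v)
            else: torch._foreach_maximum_(result, v)
        return result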

Sign

Bases: torchzero.core.transform.TensorTransform

Returns sign(input)

Source code in torchzero/modules/ops/unary.py
class Sign(TensorTransform):
    """Returns ``sign(input)``"""
    def __init__(self): super().__init__()
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_sign_(tensors)
        return tensors

Sqrt

Bases: torchzero.core.transform.TensorTransform

Returns sqrt(input)

Source code in torchzero/modules/ops/unary.py
class Sqrt(TensorTransform):
    """Returns ``sqrt(input)``"""
    def __init__(self): super().__init__()
    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        torch._foreach_sqrt_(tensors)
        return tensors

SqrtEMASquared

Bases: torchzero.core.transform.TensorTransform

Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

Parameters:

  • beta (float, default: 0.999 ) –

    momentum value. Defaults to 0.999.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • debiased (bool, default: False ) –

    whether to multiply the output by a debiasing term from the Adam method. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Methods:

  • SQRT_EMA_SQ_FN

    Updates exp_avg_sq_ with an EMA of squared tensors and returns its square root, with optional AMSGrad and debiasing.

Source code in torchzero/modules/ops/higher_level.py
class SqrtEMASquared(TensorTransform):
    """Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
    def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
        super().__init__(defaults)


    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return self.SQRT_EMA_SQ_FN(
            TensorList(tensors),
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            max_exp_avg_sq_=max_exp_avg_sq,
            debiased=debiased,
            step=step,
            pow=pow,
        )

SQRT_EMA_SQ_FN

SQRT_EMA_SQ_FN(tensors: TensorList, exp_avg_sq_: TensorList, beta: float | NumberList, max_exp_avg_sq_: TensorList | None, debiased: bool, step: int, pow: float = 2, ema_sq_fn: Callable = ema_sq_)

Updates exp_avg_sq_ with an EMA of squared tensors and calculates its square root, with optional AMSGrad and debiasing.

Returns new tensors.

Source code in torchzero/modules/opt_utils.py
def sqrt_ema_sq_(
    tensors: TensorList,
    exp_avg_sq_: TensorList,
    beta: float | NumberList,
    max_exp_avg_sq_: TensorList | None,
    debiased: bool,
    step: int,
    pow: float = 2,
    ema_sq_fn: Callable = ema_sq_,
):
    """
    Updates `exp_avg_sq_` with an EMA of squared `tensors` and calculates its square root,
    with optional AMSGrad and debiasing.

    Returns new tensors.
    """
    exp_avg_sq_=ema_sq_fn(
        tensors=tensors,
        exp_avg_sq_=exp_avg_sq_,
        beta=beta,
        max_exp_avg_sq_=max_exp_avg_sq_,
        pow=pow,
    )

    sqrt_exp_avg_sq = root(exp_avg_sq_, pow, inplace=False)

    if debiased: sqrt_exp_avg_sq = debias_second_momentum(sqrt_exp_avg_sq, step=step, beta=beta, pow=pow, inplace=True)
    return sqrt_exp_avg_sq
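
Numerically this is the Adam-style second-moment denominator. A single-tensor sketch in plain PyTorch of the same computation (illustrative only; the real helper works on TensorLists and additionally supports AMSGrad through max_exp_avg_sq_):

import torch

def sqrt_ema_sq_sketch(t, exp_avg_sq, beta=0.999, step=1, debiased=False):
    # EMA of squared values: v <- beta * v + (1 - beta) * t^2
    exp_avg_sq.mul_(beta).addcmul_(t, t, value=1 - beta)
    denom = exp_avg_sq.sqrt()
    if debiased:
        # Adam bias correction for the second moment
        denom = denom / (1 - beta ** step) ** 0.5
    return denom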

Sub

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Subtract other from tensors. other can be a number or a module.

If other is a module, this calculates tensors - other(tensors)

Source code in torchzero/modules/ops/binary.py
class Sub(BinaryOperationBase):
    """Subtract ``other`` from tensors. ``other`` can be a number or a module.

    If ``other`` is a module, this calculates :code:`tensors - other(tensors)`
    """
    def __init__(self, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, other=other)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.defaults['alpha'])
        else: torch._foreach_sub_(update, other, alpha=self.defaults['alpha'])
        return update
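
As a concrete illustration of the alpha-scaled subtraction, here is the per-tensor arithmetic in plain PyTorch (illustrative values, not library API):

import torch

update = [torch.tensor([1.0, 2.0]), torch.tensor([3.0])]
other  = [torch.tensor([0.5, 0.5]), torch.tensor([1.0])]

# equivalent of Sub(other, alpha=0.1): update <- update - 0.1 * other
torch._foreach_sub_(update, other, alpha=0.1)
# update is now [tensor([0.9500, 1.9500]), tensor([2.9000])]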

SubModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates input - other. input and other can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class SubModules(MultiOperationBase):
    """Calculates ``input - other``. ``input`` and ``other`` can be numbers or modules."""
    def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, input=input, other=other)

    @torch.no_grad
    def transform(self, objective: Objective, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
        alpha = self.defaults['alpha']

        if isinstance(input, (int,float)):
            assert isinstance(other, list)
            return input - TensorList(other).mul_(alpha)

        if isinstance(other, (int, float)): torch._foreach_sub_(input, other * alpha)
        else: torch._foreach_sub_(input, other, alpha=alpha)
        return input

Sum

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs the sum of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class Sum(ReduceOperationBase):
    """Outputs sum of ``inputs`` that can be modules or numbers."""
    USE_MEAN = False
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        sum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_add_(sum, v)

        if self.USE_MEAN and len(sorted_inputs) > 1: torch._foreach_div_(sum, len(sorted_inputs))
        return sum

USE_MEAN class-attribute

USE_MEAN = False
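
Numerically, reducing a module output with a scalar operand is plain elementwise addition; an illustrative sketch of what Sum(some_module, 0.5) computes for one parameter (some_module is a hypothetical operand):

import torch

module_out = [torch.tensor([1.0, -2.0])]  # update produced by the module operand
torch._foreach_add_(module_out, 0.5)
# module_out is now [tensor([1.5000, -1.5000])]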


Threshold

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs tensors with values on one side of threshold replaced by value: if update_above is True, values at or below the threshold are replaced; otherwise values at or above it are replaced.

Source code in torchzero/modules/ops/binary.py
class Threshold(BinaryOperationBase):
    """Outputs tensors thresholded such that values above ``threshold`` are set to ``value``."""
    def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
        defaults = dict(update_above=update_above)
        super().__init__(defaults, threshold=threshold, value=value)

    @torch.no_grad
    def transform(self, objective, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
        update_above = self.defaults['update_above']
        update = TensorList(update)
        if update_above:
            if isinstance(value, list): return update.where(update>threshold, value)
            return update.masked_fill_(update<=threshold, value)

        if isinstance(value, list): return update.where(update<threshold, value)
        return update.masked_fill_(update>=threshold, value)
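
The per-tensor logic is equivalent to a torch.where against the threshold. A minimal sketch for a scalar threshold and value (illustrative only):

import torch

u = torch.tensor([-2.0, 0.5, 3.0])
threshold, value = 1.0, 0.0

# update_above=True: keep entries above the threshold, replace the rest
above = torch.where(u > threshold, u, torch.full_like(u, value))   # tensor([0.0, 0.0, 3.0])

# update_above=False: keep entries below the threshold, replace the rest
below = torch.where(u < threshold, u, torch.full_like(u, value))   # tensor([-2.0, 0.5, 0.0])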

UnaryLambda

Bases: torchzero.core.transform.TensorTransform

Applies fn to input tensors.

fn must accept and return a list of tensors.

Source code in torchzero/modules/ops/unary.py
class UnaryLambda(TensorTransform):
    """Applies ``fn`` to input tensors.

    ``fn`` must accept and return a list of tensors.
    """
    def __init__(self, fn):
        defaults = dict(fn=fn)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        return settings[0]['fn'](tensors)
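
For example, an elementwise clipping transform can be expressed directly with UnaryLambda; the snippet below only constructs the module (illustrative usage):

# fn receives the list of update tensors and must return a list of tensors
clip = UnaryLambda(lambda tensors: [t.clamp(-1.0, 1.0) for t in tensors])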

UnaryParameterwiseLambda

Bases: torchzero.core.transform.TensorTransform

Applies fn to each input tensor.

fn must accept and return a tensor.

Source code in torchzero/modules/ops/unary.py
class UnaryParameterwiseLambda(TensorTransform):
    """Applies ``fn`` to each input tensor.

    ``fn`` must accept and return a tensor.
    """
    def __init__(self, fn):
        defaults = dict(fn=fn)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        return setting['fn'](tensor)

Uniform

Bases: torchzero.core.module.Module

Outputs tensors filled with random numbers from a uniform distribution between low and high.

Source code in torchzero/modules/ops/utility.py
class Uniform(Module):
    """Outputs tensors filled with random numbers from uniform distribution between ``low`` and ``high``."""
    def __init__(self, low: float, high: float):
        defaults = dict(low=low, high=high)
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective):
        low,high = self.get_settings(objective.params, 'low','high')
        objective.updates = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(objective.params, low, high)]
        return objective

UpdateToNone

Bases: torchzero.core.module.Module

Sets updates attribute to None on objective.

Source code in torchzero/modules/ops/utility.py
class UpdateToNone(Module):
    """Sets ``update`` attribute to None on ``var``."""
    def __init__(self): super().__init__()
    def apply(self, objective):
        objective.updates = None
        return objective

WeightedMean

Bases: torchzero.modules.ops.reduce.WeightedSum

Outputs a weighted mean of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class WeightedMean(WeightedSum):
    """Outputs weighted mean of ``inputs`` that can be modules or numbers."""
    USE_MEAN = True

USE_MEAN class-attribute

USE_MEAN = True


WeightedSum

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs a weighted sum of inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class WeightedSum(ReduceOperationBase):
    """Outputs a weighted sum of ``inputs`` that can be modules or numbers."""
    USE_MEAN = False
    def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
        weights = list(weights)
        if len(inputs) != len(weights):
            raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
        defaults = dict(weights=weights)
        super().__init__(defaults, *inputs)

    @torch.no_grad
    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        weights = self.defaults['weights']
        # sort so a tensor-list operand comes first, keeping each weight paired with its input
        sorted_pairs = sorted(zip(inputs, weights), key=lambda p: isinstance(p[0], float))
        sum = cast(list, sorted_pairs[0][0])
        torch._foreach_mul_(sum, sorted_pairs[0][1])
        for v, w in sorted_pairs[1:]:
            if isinstance(v, (int, float)): torch._foreach_add_(sum, v*w)
            else: torch._foreach_add_(sum, v, alpha=w)

        if self.USE_MEAN and len(sorted_pairs) > 1: torch._foreach_div_(sum, len(sorted_pairs))
        return sum

USE_MEAN class-attribute

USE_MEAN = False
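
The reduction itself is ordinary per-element arithmetic. A plain-PyTorch sketch of a weighted sum of two operand outputs for one parameter (illustrative values):

import torch

a = torch.tensor([1.0, 2.0])    # output of the first operand
b = torch.tensor([10.0, 20.0])  # output of the second operand
w = [0.9, 0.1]

weighted_sum  = w[0] * a + w[1] * b   # tensor([1.9000, 3.8000])
weighted_mean = weighted_sum / 2      # what WeightedMean outputs (divides by the number of inputs)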


Zeros

Bases: torchzero.core.module.Module

Outputs zeros

Source code in torchzero/modules/ops/utility.py
class Zeros(Module):
    """Outputs zeros"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def apply(self, objective):
        objective.updates = [torch.zeros_like(p) for p in objective.params]
        return objective