Operations

This subpackage contains operations such as adding and subtracting module outputs, grafting, tracking the maximum, etc.

Classes:

  • Abs

    Returns `abs(input)`

  • AccumulateMaximum

    Accumulates maximum of all past updates.

  • AccumulateMean

    Accumulates mean of all past updates.

  • AccumulateMinimum

    Accumulates minimum of all past updates.

  • AccumulateProduct

    Accumulates product of all past updates.

  • AccumulateSum

    Accumulates sum of all past updates.

  • Add

    Add `other` to tensors. `other` can be a number or a module.

  • BinaryOperationBase

    Base class for operations that use update as the first operand. This is an abstract class, subclass it and override transform method to use it.

  • CenteredEMASquared

    Maintains a centered exponential moving average of squared updates. This also maintains an additional exponential moving average of un-squared updates, the square of which is subtracted from the EMA.

  • CenteredSqrtEMASquared

    Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.

  • Clip

    Clip tensors to be in the `(min, max)` range. `min` and `max` can be None, numbers or modules.

  • ClipModules

    Calculates `input(tensors).clip(min, max)`. `min` and `max` can be numbers or modules.

  • Clone

    Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations

  • CopyMagnitude

    Returns `other(tensors)` with sign copied from tensors.

  • CopySign

    Returns tensors with sign copied from `other(tensors)`.

  • CustomUnaryOperation

    Applies `getattr(tensor, name)` to each tensor.

  • Debias

    Multiplies the update by an Adam debiasing term based on the first and/or second momentum.

  • Debias2

    Multiplies the update by an Adam debiasing term based on the second momentum.

  • Div

    Divide tensors by `other`. `other` can be a number or a module.

  • DivModules

    Calculates `input / other`. `input` and `other` can be numbers or modules.

  • EMASquared

    Maintains an exponential moving average of squared updates.

  • Exp

    Returns `exp(input)`

  • Fill

    Outputs tensors filled with `value`

  • Grad

    Outputs the gradient

  • GradToNone

    Sets the `grad` attribute to None on `var`.

  • Graft

    Outputs tensors rescaled to have the same norm as `magnitude(tensors)`.

  • GraftModules

    Outputs the `direction` output rescaled to have the same norm as the `magnitude` output.

  • GraftToUpdate

    Outputs `magnitude(tensors)` rescaled to have the same norm as tensors.

  • GramSchimdt

    Outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt.

  • Identity

    Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods.

  • LerpModules

    Does a linear interpolation of `input(tensors)` and `end(tensors)` based on a scalar `weight`.

  • Maximum

    Outputs `maximum(tensors, other(tensors))`

  • MaximumModules

    Outputs elementwise maximum of `inputs` that can be modules or numbers.

  • Mean

    Outputs a mean of `inputs` that can be modules or numbers.

  • Minimum

    Outputs `minimum(tensors, other(tensors))`

  • MinimumModules

    Outputs elementwise minimum of `inputs` that can be modules or numbers.

  • Mul

    Multiply tensors by `other`. `other` can be a number or a module.

  • MultiOperationBase

    Base class for operations that use operands. This is an abstract class, subclass it and override transform method to use it.

  • NanToNum

    Convert nan, inf and -inf to numbers.

  • Negate

    Returns `-input`

  • Noop

    Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods.

  • Ones

    Outputs ones

  • Params

    Outputs parameters

  • Pow

    Take tensors to the power of `exponent`. `exponent` can be a number or a module.

  • PowModules

    Calculates `input ** exponent`. `input` and `exponent` can be numbers or modules.

  • Prod

    Outputs product of `inputs` that can be modules or numbers.

  • RCopySign

    Returns `other(tensors)` with sign copied from tensors.

  • RDiv

    Divide `other` by tensors. `other` can be a number or a module.

  • RGraft

    Outputs `magnitude(tensors)` rescaled to have the same norm as tensors.

  • RPow

    Take `other` to the power of tensors. `other` can be a number or a module.

  • RSub

    Subtract tensors from `other`. `other` can be a number or a module.

  • Randn

    Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1.

  • RandomSample

    Outputs tensors filled with random numbers from a distribution depending on the value of `distribution`.

  • Reciprocal

    Returns `1 / input`

  • ReduceOperationBase

    Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override transform method to use it.

  • Sign

    Returns `sign(input)`

  • Sqrt

    Returns `sqrt(input)`

  • SqrtEMASquared

    Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

  • Sub

    Subtract `other` from tensors. `other` can be a number or a module.

  • SubModules

    Calculates `input - other`. `input` and `other` can be numbers or modules.

  • Sum

    Outputs sum of `inputs` that can be modules or numbers.

  • Threshold

    Outputs tensors thresholded such that values above `threshold` are set to `value`.

  • UnaryLambda

    Applies `fn` to input tensors.

  • UnaryParameterwiseLambda

    Applies `fn` to each input tensor.

  • Uniform

    Outputs tensors filled with random numbers from a uniform distribution between `low` and `high`.

  • UpdateToNone

    Sets the `update` attribute to None on `var`.

  • WeightedMean

    Outputs weighted mean of `inputs` that can be modules or numbers.

  • WeightedSum

    Outputs weighted sum of `inputs` that can be modules or numbers.

  • Zeros

    Outputs zeros

Abs

Bases: torchzero.core.transform.Transform

Returns `abs(input)`

Source code in torchzero/modules/ops/unary.py
class Abs(Transform):
    """Returns :code:`abs(input)`"""
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        torch._foreach_abs_(tensors)
        return tensors

AccumulateMaximum

Bases: torchzero.core.transform.Transform

Accumulates maximum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMaximum(Transform):
    """Accumulates maximum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target: Target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return maximum.maximum_(tensors).lazy_mul(decay, clone=True)

AccumulateMean

Bases: torchzero.core.transform.Transform

Accumulates mean of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMean(Transform):
    """Accumulates mean of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target: Target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
        mean = unpack_states(states, tensors, 'mean', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return mean.add_(tensors).lazy_mul(decay, clone=True).div_(step)

AccumulateMinimum

Bases: torchzero.core.transform.Transform

Accumulates minimum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateMinimum(Transform):
    """Accumulates minimum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target: Target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return minimum.minimum_(tensors).lazy_mul(decay, clone=True)

AccumulateProduct

Bases: torchzero.core.transform.Transform

Accumulates product of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateProduct(Transform):
    """Accumulates product of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target: Target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        prod = unpack_states(states, tensors, 'prod', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return prod.mul_(tensors).lazy_mul(decay, clone=True)

AccumulateSum

Bases: torchzero.core.transform.Transform

Accumulates sum of all past updates.

Parameters:

  • decay (float, default: 0 ) –

    decays the accumulator. Defaults to 0.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/accumulate.py
class AccumulateSum(Transform):
    """Accumulates sum of all past updates.

    Args:
        decay (float, optional): decays the accumulator. Defaults to 0.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, decay: float = 0, target: Target = 'update',):
        defaults = dict(decay=decay)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        sum = unpack_states(states, tensors, 'sum', cls=TensorList)
        decay = [1-s['decay'] for s in settings]
        return sum.add_(tensors).lazy_mul(decay, clone=True)
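
A usage sketch, assuming these accumulators are exposed under `tz.m` the same way `GraftModules` is in the example further down (the `torchzero`/`tz` import name is assumed):

import torch
import torchzero as tz  # assumed import name

model = torch.nn.Linear(4, 1)

# Each step applies the (optionally decayed) running sum of all updates seen
# so far, scaled by the learning rate.
opt = tz.Modular(
    model.parameters(),
    tz.m.AccumulateSum(decay=0.01),
    tz.m.LR(1e-3),
)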

Add

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Add `other` to tensors. `other` can be a number or a module.

If `other` is a module, this calculates `tensors + other(tensors)`.

Source code in torchzero/modules/ops/binary.py
class Add(BinaryOperationBase):
    """Add :code:`other` to tensors. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`tensors + other(tensors)`
    """
    def __init__(self, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.defaults['alpha'])
        else: torch._foreach_add_(update, other, alpha=self.defaults['alpha'])
        return update
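
A usage sketch, assuming `Add` and `Grad` are exposed under `tz.m` like the other operations on this page:

import torch
import torchzero as tz  # assumed import name

model = torch.nn.Linear(4, 1)

# `other` is a module here, so this computes update + 0.5 * gradient
opt = tz.Modular(
    model.parameters(),
    tz.m.Add(tz.m.Grad(), alpha=0.5),
    tz.m.LR(1e-2),
)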

BinaryOperationBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for operations that use update as the first operand. This is an abstract class, subclass it and override transform method to use it.

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/binary.py
class BinaryOperationBase(Module, ABC):
    """Base class for operations that use update as the first operand. This is an abstract class, subclass it and override `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = {}
        for k,v in operands.items():

            if isinstance(v, (Module, Sequence)):
                self.set_child(k, v)
                self.operands[k] = self.children[k]
            else:
                self.operands[k] = v

    @abstractmethod
    def transform(self, var: Var, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    @torch.no_grad
    def step(self, var: Var) -> Var:
        # pass cloned update to all module operands
        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()

        for k,v in self.operands.items():
            if k in self.children:
                v: Module
                updated_var = v.step(var.clone(clone_update=True))
                processed_operands[k] = updated_var.get_update()
                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them

        transformed = self.transform(var, update=var.get_update(), **processed_operands)
        var.update = list(transformed)
        return var

transform

transform(var: Var, update: list[Tensor], **operands: Any | list[Tensor]) -> Iterable[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/binary.py
@abstractmethod
def transform(self, var: Var, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError
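
A minimal subclass sketch (illustrative only; the import path is taken from the "Source code in" note above, and the operation itself is hypothetical):

import torch
from torchzero.modules.ops.binary import BinaryOperationBase

class AddSquared(BinaryOperationBase):
    """Hypothetical op: returns update + other(tensors) ** 2."""
    def __init__(self, other):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update, other):
        # `other` holds the child module's output; update += other * other elementwise
        torch._foreach_addcmul_(update, other, other)
        return update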

CenteredEMASquared

Bases: torchzero.core.transform.Transform

Maintains a centered exponential moving average of squared updates. This also maintains an additional exponential moving average of un-squared updates, the square of which is subtracted from the EMA.

Parameters:

  • beta (float, default: 0.99 ) –

    momentum value. Defaults to 0.99.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Source code in torchzero/modules/ops/higher_level.py
class CenteredEMASquared(Transform):
    """Maintains a centered exponential moving average of squared updates. This also maintains an additional
    exponential moving average of un-squared updates, square of which is subtracted from the EMA.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2):
        defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return centered_ema_sq_(
            TensorList(tensors),
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            max_exp_avg_sq_=max_exp_avg_sq,
            pow=pow,
        ).clone()
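
Per tensor, the recurrence is roughly as follows (an illustrative sketch; the actual logic, including `amsgrad` and `pow`, lives in `centered_ema_sq_`):

import torch

def centered_ema_sq_step(u, exp_avg, exp_avg_sq, beta=0.99):
    exp_avg.lerp_(u, 1 - beta)         # EMA of raw updates
    exp_avg_sq.lerp_(u * u, 1 - beta)  # EMA of squared updates
    return exp_avg_sq - exp_avg ** 2   # centered second-moment estimate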

CenteredSqrtEMASquared

Bases: torchzero.core.transform.Transform

Maintains a centered exponential moving average of squared updates and outputs its optionally debiased square root. This also maintains an additional exponential moving average of un-squared updates, the square of which is subtracted from the EMA.

Parameters:

  • beta (float, default: 0.99 ) –

    momentum value. Defaults to 0.99.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • debiased (bool, default: False ) –

    whether to multiply the output by a debiasing term from the Adam method. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Source code in torchzero/modules/ops/higher_level.py
class CenteredSqrtEMASquared(Transform):
    """Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.
    This also maintains an additional exponential moving average of un-squared updates, square of which is subtracted from the EMA.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2):
        defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return sqrt_centered_ema_sq_(
            TensorList(tensors),
            exp_avg_=exp_avg,
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            debiased=debiased,
            step=step,
            max_exp_avg_sq_=max_exp_avg_sq,
            pow=pow,
        )

Clip

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Clip tensors to be in the `(min, max)` range. `min` and `max` can be None, numbers or modules.

If `min` and `max` are modules, this calculates `tensors.clip(min(tensors), max(tensors))`.

Source code in torchzero/modules/ops/binary.py
class Clip(BinaryOperationBase):
    """clip tensors to be in  :code:`(min, max)` range. :code:`min` and :code:`max: can be None, numbers or modules.

    If code:`min` and :code:`max`:  are modules, this calculates :code:`tensors.clip(min(tensors), max(tensors))`.
    """
    def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
        super().__init__({}, min=min, max=max)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
        return TensorList(update).clamp_(min=min,  max=max)
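
Usage sketch (assuming `Clip` is available under `tz.m`): clamp every update element to [-0.1, 0.1] before the learning rate is applied.

import torch
import torchzero as tz  # assumed import name

model = torch.nn.Linear(4, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.Clip(min=-0.1, max=0.1),
    tz.m.LR(1e-2),
)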

ClipModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates `input(tensors).clip(min, max)`. `min` and `max` can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class ClipModules(MultiOperationBase):
    """Calculates :code:`input(tensors).clip(min, max)`. :code:`min` and :code:`max` can be numbers or modules."""
    def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
        defaults = {}
        super().__init__(defaults, input=input, min=min, max=max)

    @torch.no_grad
    def transform(self, var: Var, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
        return TensorList(input).clamp_(min=min, max=max)

Clone

Bases: torchzero.core.module.Module

Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations

Source code in torchzero/modules/ops/utility.py
class Clone(Module):
    """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def step(self, var):
        var.update = [u.clone() for u in var.get_update()]
        return var

CopyMagnitude

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns `other(tensors)` with sign copied from tensors.

Source code in torchzero/modules/ops/binary.py
class RCopySign(BinaryOperationBase):
    """Returns :code:`other(tensors)` with sign copied from tensors."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [o.copysign_(u) for u, o in zip(update, other)]

CopySign

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns tensors with sign copied from `other(tensors)`.

Source code in torchzero/modules/ops/binary.py
class CopySign(BinaryOperationBase):
    """Returns tensors with sign copied from :code:`other(tensors)`."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [u.copysign_(o) for u, o in zip(update, other)]

CustomUnaryOperation

Bases: torchzero.core.transform.Transform

Applies `getattr(tensor, name)` to each tensor.

Source code in torchzero/modules/ops/unary.py
class CustomUnaryOperation(Transform):
    """Applies :code:`getattr(tensor, name)` to each tensor
    """
    def __init__(self, name: str, target: "Target" = 'update'):
        defaults = dict(name=name)
        super().__init__(defaults=defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        return getattr(tensors, settings[0]['name'])()

Debias

Bases: torchzero.core.transform.Transform

Multiplies the update by an Adam debiasing term based on the first and/or second momentum.

Parameters:

  • beta1 (float | None, default: None ) –

    first momentum, should be the same as first momentum used in modules before. Defaults to None.

  • beta2 (float | None, default: None ) –

    second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.

  • alpha (float, default: 1 ) –

    learning rate. Defaults to 1.

  • pow (float, default: 2 ) –

    power, assumes absolute value is used. Defaults to 2.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/higher_level.py
class Debias(Transform):
    """Multiplies the update by an Adam debiasing term based first and/or second momentum.

    Args:
        beta1 (float | None, optional):
            first momentum, should be the same as first momentum used in modules before. Defaults to None.
        beta2 (float | None, optional):
            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
        alpha (float, optional): learning rate. Defaults to 1.
        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2, target: Target = 'update',):
        defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        pow = settings[0]['pow']
        alpha, beta1, beta2 = unpack_dicts(settings, 'alpha', 'beta1', 'beta2', cls=NumberList)

        return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)
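
Written out for a single step, the applied factor corresponds to Adam's bias correction, roughly as sketched below (illustrative only; the real computation is in the `debias` helper):

def debias_factor(step, beta1=None, beta2=None, alpha=1.0, pow=2):
    # alpha * (1 - beta2**step)**(1/pow) / (1 - beta1**step); either factor
    # is dropped when the corresponding beta is None.
    factor = alpha
    if beta1 is not None:
        factor = factor / (1 - beta1 ** step)
    if beta2 is not None:
        factor = factor * (1 - beta2 ** step) ** (1 / pow)
    return factor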

Debias2

Bases: torchzero.core.transform.Transform

Multiplies the update by an Adam debiasing term based on the second momentum.

Parameters:

  • beta (float | None, default: 0.999 ) –

    second (squared) momentum, should be the same as the second momentum used in preceding modules. Defaults to 0.999.

  • pow (float, default: 2 ) –

    power, assumes absolute value is used. Defaults to 2.

  • target (Literal, default: 'update' ) –

    target. Defaults to 'update'.

Source code in torchzero/modules/ops/higher_level.py
class Debias2(Transform):
    """Multiplies the update by an Adam debiasing term based on the second momentum.

    Args:
        beta (float | None, optional):
            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
        target (Target, optional): target. Defaults to 'update'.
    """
    def __init__(self, beta: float = 0.999, pow: float = 2, target: Target = 'update',):
        defaults = dict(beta=beta, pow=pow)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        pow = settings[0]['pow']
        beta = NumberList(s['beta'] for s in settings)
        return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)

Div

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Divide tensors by `other`. `other` can be a number or a module.

If `other` is a module, this calculates `tensors / other(tensors)`.

Source code in torchzero/modules/ops/binary.py
class Div(BinaryOperationBase):
    """Divide tensors by :code:`other`. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`tensors / other(tensors)`
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        torch._foreach_div_(update, other)
        return update
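
Because `other` may be a module, `Div` composes naturally with the EMA operations on this page; an RMSprop-like update can be sketched as follows (assuming these modules are exposed under `tz.m`; the `beta` argument name is assumed from `EMASquared`):

import torch
import torchzero as tz  # assumed import name

model = torch.nn.Linear(4, 1)

# update / sqrt(EMA(update^2)), using the SqrtEMASquared module listed above
opt = tz.Modular(
    model.parameters(),
    tz.m.Div(tz.m.SqrtEMASquared(beta=0.99)),
    tz.m.LR(1e-2),
)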

DivModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates `input / other`. `input` and `other` can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class DivModules(MultiOperationBase):
    """Calculates :code:`input / other`. :code:`input` and :code:`other` can be numbers or modules."""
    def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
        defaults = {}
        if other_first: super().__init__(defaults, other=other, input=input)
        else: super().__init__(defaults, input=input, other=other)

    @torch.no_grad
    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
        if isinstance(input, (int,float)):
            assert isinstance(other, list)
            return input / TensorList(other)

        torch._foreach_div_(input, other)
        return input

EMASquared

Bases: torchzero.core.transform.Transform

Maintains an exponential moving average of squared updates.

Parameters:

  • beta (float, default: 0.999 ) –

    momentum value. Defaults to 0.999.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Methods:

  • EMA_SQ_FN

    Updates exp_avg_sq_ with an EMA of squared tensors; if max_exp_avg_sq_ is not None, updates it with the maximum of the EMA.

Source code in torchzero/modules/ops/higher_level.py
class EMASquared(Transform):
    """Maintains an exponential moving average of squared updates.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    EMA_SQ_FN: staticmethod = staticmethod(ema_sq_)

    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()

EMA_SQ_FN

EMA_SQ_FN(tensors: TensorList, exp_avg_sq_: TensorList, beta: float | NumberList, max_exp_avg_sq_: TensorList | None, pow: float = 2)

Updates `exp_avg_sq_` with an EMA of squared `tensors`; if `max_exp_avg_sq_` is not None, updates it with the maximum of the EMA.

Returns exp_avg_sq_ or max_exp_avg_sq_.

Source code in torchzero/modules/functional.py
def ema_sq_(
    tensors: TensorList,
    exp_avg_sq_: TensorList,
    beta: float | NumberList,
    max_exp_avg_sq_: TensorList | None,
    pow: float = 2,
):
    """
    Updates `exp_avg_sq_` with EMA of squared `tensors`, if `max_exp_avg_sq_` is not None, updates it with maximum of EMA.

    Returns `exp_avg_sq_` or `max_exp_avg_sq_`.
    """
    lerp_power_(tensors=tensors, exp_avg_pow_=exp_avg_sq_,beta=beta,pow=pow)

    # AMSGrad
    if max_exp_avg_sq_ is not None:
        max_exp_avg_sq_.maximum_(exp_avg_sq_)
        exp_avg_sq_ = max_exp_avg_sq_

    return exp_avg_sq_

Exp

Bases: torchzero.core.transform.Transform

Returns `exp(input)`

Source code in torchzero/modules/ops/unary.py
class Exp(Transform):
    """Returns :code:`exp(input)`"""
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        torch._foreach_exp_(tensors)
        return tensors

Fill

Bases: torchzero.core.module.Module

Outputs tensors filled with `value`

Source code in torchzero/modules/ops/utility.py
class Fill(Module):
    """Outputs tensors filled with :code:`value`"""
    def __init__(self, value: float):
        defaults = dict(value=value)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        var.update = [torch.full_like(p, self.settings[p]['value']) for p in var.params]
        return var

Grad

Bases: torchzero.core.module.Module

Outputs the gradient

Source code in torchzero/modules/ops/utility.py
class Grad(Module):
    """Outputs the gradient"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def step(self, var):
        var.update = [g.clone() for g in var.get_grad()]
        return var

GradToNone

Bases: torchzero.core.module.Module

Sets the `grad` attribute to None on `var`.

Source code in torchzero/modules/ops/utility.py
class GradToNone(Module):
    """Sets :code:`grad` attribute to None on :code:`var`."""
    def __init__(self): super().__init__()
    def step(self, var):
        var.grad = None
        return var

Graft

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs tensors rescaled to have the same norm as `magnitude(tensors)`.

Source code in torchzero/modules/ops/binary.py
class Graft(BinaryOperationBase):
    """Outputs tensors rescaled to have the same norm as :code:`magnitude(tensors)`."""
    def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, magnitude=magnitude)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
        return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
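
Usage sketch (a hypothetical composition, assuming `Sign`, `Graft` and `Grad` are exposed under `tz.m`): keep the sign-descent direction but rescale it to the gradient's norm.

import torch
import torchzero as tz  # assumed import name

model = torch.nn.Linear(4, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.Sign(),              # direction: sign of the update
    tz.m.Graft(tz.m.Grad()),  # magnitude: norm of the gradient
    tz.m.LR(1e-2),
)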

GraftModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Outputs the `direction` output rescaled to have the same norm as the `magnitude` output.

Parameters:

  • direction (Chainable) –

    module to use the direction from

  • magnitude (Chainable) –

    module to use the magnitude from

  • tensorwise (bool, default: True ) –

    whether to calculate norm per-tensor or globally. Defaults to True.

  • ord (float, default: 2 ) –

    norm order. Defaults to 2.

  • eps (float, default: 1e-06 ) –

    clips denominator to be no less than this value. Defaults to 1e-6.

  • strength (float, default: 1 ) –

    strength of grafting. Defaults to 1.

Example

Shampoo grafted to Adam

opt = tz.Modular(
    model.parameters(),
    tz.m.GraftModules(
        direction = tz.m.Shampoo(),
        magnitude = tz.m.Adam(),
    ),
    tz.m.LR(1e-3)
)
Reference

Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803

Source code in torchzero/modules/ops/multi.py
class GraftModules(MultiOperationBase):
    """Outputs :code:`direction` output rescaled to have the same norm as :code:`magnitude` output.

    Args:
        direction (Chainable): module to use the direction from
        magnitude (Chainable): module to use the magnitude from
        tensorwise (bool, optional): whether to calculate norm per-tensor or globally. Defaults to True.
        ord (float, optional): norm order. Defaults to 2.
        eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
        strength (float, optional): strength of grafting. Defaults to 1.

    Example:
        Shampoo grafted to Adam

        .. code-block:: python

            opt = tz.Modular(
                model.parameters(),
                tz.m.GraftModules(
                    direction = tz.m.Shampoo(),
                    magnitude = tz.m.Adam(),
                ),
                tz.m.LR(1e-3)
            )

    Reference:
        Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803
    """
    def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:Metrics=2, eps:float = 1e-6, strength:float=1):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
        super().__init__(defaults, direction=direction, magnitude=magnitude)

    @torch.no_grad
    def transform(self, var, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.defaults)
        return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)

GraftToUpdate

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs `magnitude(tensors)` rescaled to have the same norm as tensors.

Source code in torchzero/modules/ops/binary.py
class RGraft(BinaryOperationBase):
    """Outputs :code:`magnitude(tensors)` rescaled to have the same norm as tensors"""

    def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, direction=direction)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
        return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)

GramSchimdt

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt.

Source code in torchzero/modules/ops/binary.py
class GramSchimdt(BinaryOperationBase):
    """outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
        update = TensorList(update); other = TensorList(other)
        min = torch.finfo(update[0].dtype).tiny * 2
        return update - (other*update) / (other*other).clip(min=min)

Identity

Bases: torchzero.core.module.Module

Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods.

Source code in torchzero/modules/ops/utility.py
class Identity(Module):
    """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
    def __init__(self, *args, **kwargs): super().__init__()
    def step(self, var): return var
    def get_H(self, var):
        n = sum(p.numel() for p in var.params)
        p = var.params[0]
        return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)

LerpModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Does a linear interpolation of `input(tensors)` and `end(tensors)` based on a scalar `weight`.

The output is given by `output = input(tensors) + weight * (end(tensors) - input(tensors))`.

Source code in torchzero/modules/ops/multi.py
class LerpModules(MultiOperationBase):
    """Does a linear interpolation of :code:`input(tensors)` and :code:`end(tensors)` based on a scalar :code:`weight`.

    The output is given by :code:`output = input(tensors) + weight * (end(tensors) - input(tensors))`
    """
    def __init__(self, input: Chainable, end: Chainable, weight: float):
        defaults = dict(weight=weight)
        super().__init__(defaults, input=input, end=end)

    @torch.no_grad
    def transform(self, var: Var, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
        torch._foreach_lerp_(input, end, weight=self.defaults['weight'])
        return input
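
Usage sketch (hypothetical, assuming `Identity` and `Sign` are exposed under `tz.m`): with `input=Identity` and `end=Sign`, the output is 0.7 * update + 0.3 * sign(update).

import torch
import torchzero as tz  # assumed import name

model = torch.nn.Linear(4, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.LerpModules(input=tz.m.Identity(), end=tz.m.Sign(), weight=0.3),
    tz.m.LR(1e-2),
)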

Maximum

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs `maximum(tensors, other(tensors))`

Source code in torchzero/modules/ops/binary.py
class Maximum(BinaryOperationBase):
    """Outputs :code:`maximum(tensors, other(tensors))`"""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
        torch._foreach_maximum_(update, other)
        return update

MaximumModules

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs elementwise maximum of `inputs` that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class MaximumModules(ReduceOperationBase):
    """Outputs elementwise maximum of :code:`inputs` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        maximum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_maximum_(maximum, v)

        return maximum

Mean

Bases: torchzero.modules.ops.reduce.Sum

Outputs a mean of `inputs` that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class Mean(Sum):
    """Outputs a mean of :code:`inputs` that can be modules or numbers."""
    USE_MEAN = True

USE_MEAN class-attribute

USE_MEAN = True


Minimum

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs `minimum(tensors, other(tensors))`

Source code in torchzero/modules/ops/binary.py
class Minimum(BinaryOperationBase):
    """Outputs :code:`minimum(tensors, other(tensors))`"""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
        torch._foreach_minimum_(update, other)
        return update

MinimumModules

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs elementwise minimum of `inputs` that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class MinimumModules(ReduceOperationBase):
    """Outputs elementwise minimum of :code:`inputs` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        minimum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_minimum_(minimum, v)

        return minimum

Mul

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Multiply tensors by `other`. `other` can be a number or a module.

If `other` is a module, this calculates `tensors * other(tensors)`.

Source code in torchzero/modules/ops/binary.py
class Mul(BinaryOperationBase):
    """Multiply tensors by :code:`other`. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`tensors * other(tensors)`
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        torch._foreach_mul_(update, other)
        return update

MultiOperationBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for operations that use operands. This is an abstract class, subclass it and override transform method to use it.

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/multi.py
class MultiOperationBase(Module, ABC):
    """Base class for operations that use operands. This is an abstract class, subclass it and override `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = {}
        for k,v in operands.items():

            if isinstance(v, (Module, Sequence)):
                self.set_child(k, v)
                self.operands[k] = self.children[k]
            else:
                self.operands[k] = v

        if not self.children:
            raise ValueError('At least one operand must be a module')

    @abstractmethod
    def transform(self, var: Var, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    @torch.no_grad
    def step(self, var: Var) -> Var:
        # pass cloned update to all module operands
        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()

        for k,v in self.operands.items():
            if k in self.children:
                v: Module
                updated_var = v.step(var.clone(clone_update=True))
                processed_operands[k] = updated_var.get_update()
                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them

        transformed = self.transform(var, **processed_operands)
        var.update = transformed
        return var

transform

transform(var: Var, **operands: Any | list[Tensor]) -> list[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/multi.py
@abstractmethod
def transform(self, var: Var, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError

NanToNum

Bases: torchzero.core.transform.Transform

Convert nan, inf and -inf to numbers.

Parameters:

  • nan (optional, default: None ) –

    the value to replace NaNs with. Default is zero.

  • posinf (optional, default: None ) –

    if a Number, the value to replace positive infinity values with. If None, positive infinity values are replaced with the greatest finite value representable by input's dtype. Default is None.

  • neginf (optional, default: None ) –

    if a Number, the value to replace negative infinity values with. If None, negative infinity values are replaced with the lowest finite value representable by input's dtype. Default is None.

Source code in torchzero/modules/ops/unary.py
class NanToNum(Transform):
    """Convert `nan`, `inf` and `-inf` to numbers.

    Args:
        nan (optional): the value to replace NaNs with. Default is zero.
        posinf (optional): if a Number, the value to replace positive infinity values with.
            If None, positive infinity values are replaced with the greatest finite value
            representable by input's dtype. Default is None.
        neginf (optional): if a Number, the value to replace negative infinity values with.
            If None, negative infinity values are replaced with the lowest finite value
            representable by input's dtype. Default is None.
    """
    def __init__(self, nan=None, posinf=None, neginf=None, target: "Target" = 'update'):
        defaults = dict(nan=nan, posinf=posinf, neginf=neginf)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
        return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]

Negate

Bases: torchzero.core.transform.Transform

Returns `-input`

Source code in torchzero/modules/ops/unary.py
class Negate(Transform):
    """Returns :code:`- input`"""
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        torch._foreach_neg_(tensors)
        return tensors

Noop

Bases: torchzero.core.module.Module

Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods.

Source code in torchzero/modules/ops/utility.py
class Identity(Module):
    """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
    def __init__(self, *args, **kwargs): super().__init__()
    def step(self, var): return var
    def get_H(self, var):
        n = sum(p.numel() for p in var.params)
        p = var.params[0]
        return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)

Ones

Bases: torchzero.core.module.Module

Outputs ones

Source code in torchzero/modules/ops/utility.py
class Ones(Module):
    """Outputs ones"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def step(self, var):
        var.update = [torch.ones_like(p) for p in var.params]
        return var

Params

Bases: torchzero.core.module.Module

Outputs parameters

Source code in torchzero/modules/ops/utility.py
class Params(Module):
    """Outputs parameters"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def step(self, var):
        var.update = [p.clone() for p in var.params]
        return var

Pow

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Take tensors to the power of `exponent`. `exponent` can be a number or a module.

If `exponent` is a module, this calculates `tensors ^ exponent(tensors)`.

Source code in torchzero/modules/ops/binary.py
class Pow(BinaryOperationBase):
    """Take tensors to the power of :code:`exponent`. :code:`exponent` can be a number or a module.

    If :code:`exponent` is a module, this calculates :code:`tensors ^ exponent(tensors)`
    """
    def __init__(self, exponent: Chainable | float):
        super().__init__({}, exponent=exponent)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
        torch._foreach_pow_(update, exponent)
        return update

PowModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates `input ** exponent`. `input` and `exponent` can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class PowModules(MultiOperationBase):
    """Calculates :code:`input ** exponent`. :code:`input` and :code:`other` can be numbers or modules."""
    def __init__(self, input: Chainable | float, exponent: Chainable | float):
        defaults = {}
        super().__init__(defaults, input=input, exponent=exponent)

    @torch.no_grad
    def transform(self, var: Var, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
        if isinstance(input, (int,float)):
            assert isinstance(exponent, list)
            return input ** TensorList(exponent)

        torch._foreach_pow_(input, exponent)
        return input

Prod

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs product of `inputs` that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class Prod(ReduceOperationBase):
    """Outputs product of :code:`inputs` that can be modules or numbers."""
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        prod = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_mul_(prod, v)

        return prod

RCopySign

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Returns `other(tensors)` with sign copied from tensors.

Source code in torchzero/modules/ops/binary.py
class RCopySign(BinaryOperationBase):
    """Returns :code:`other(tensors)` with sign copied from tensors."""
    def __init__(self, other: Chainable):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
        return [o.copysign_(u) for u, o in zip(update, other)]

RDiv

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Divide `other` by tensors. `other` can be a number or a module.

If `other` is a module, this calculates `other(tensors) / tensors`.

Source code in torchzero/modules/ops/binary.py
class RDiv(BinaryOperationBase):
    """Divide :code:`other` by tensors. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`other(tensors) / tensors`
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        return other / TensorList(update)

RGraft

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs `magnitude(tensors)` rescaled to have the same norm as tensors.

Source code in torchzero/modules/ops/binary.py
class RGraft(BinaryOperationBase):
    """Outputs :code:`magnitude(tensors)` rescaled to have the same norm as tensors"""

    def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
        super().__init__(defaults, direction=direction)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
        return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)

RPow

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Take `other` to the power of tensors. `other` can be a number or a module.

If `other` is a module, this calculates `other(tensors) ^ tensors`.

Source code in torchzero/modules/ops/binary.py
class RPow(BinaryOperationBase):
    """Take :code:`other` to the power of tensors. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`other(tensors) ^ tensors`
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
        torch._foreach_pow_(other, update)
        return other

RSub

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Subtract tensors from `other`. `other` can be a number or a module.

If `other` is a module, this calculates `other(tensors) - tensors`.

Source code in torchzero/modules/ops/binary.py
class RSub(BinaryOperationBase):
    """Subtract tensors from :code:`other`. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`other(tensors) - tensors`
    """
    def __init__(self, other: Chainable | float):
        super().__init__({}, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        return other - TensorList(update)

Randn

Bases: torchzero.core.module.Module

Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1.

Source code in torchzero/modules/ops/utility.py
class Randn(Module):
    """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def step(self, var):
        var.update = [torch.randn_like(p) for p in var.params]
        return var

RandomSample

Bases: torchzero.core.module.Module

Outputs tensors filled with random numbers from a distribution depending on the value of `distribution`.

Source code in torchzero/modules/ops/utility.py
class RandomSample(Module):
    """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
    def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
        defaults = dict(distribution=distribution, variance=variance)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        distribution = self.defaults['distribution']
        variance = self.get_settings(var.params, 'variance')
        var.update = TensorList(var.params).sample_like(distribution=distribution, variance=variance)
        return var

Reciprocal

Bases: torchzero.core.transform.Transform

Returns :code:1 / input

Source code in torchzero/modules/ops/unary.py
class Reciprocal(Transform):
    """Returns :code:`1 / input`"""
    def __init__(self, eps = 0, target: "Target" = 'update'):
        defaults = dict(eps = eps)
        super().__init__(defaults, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        eps = [s['eps'] for s in settings]
        if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
        torch._foreach_reciprocal_(tensors)
        return tensors
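
The eps setting is added to the input before the reciprocal is taken, which guards against division by zero. A minimal illustration of the element-wise effect (plain PyTorch, not the library's internals):

import torch

x = torch.tensor([0.0, 2.0, 4.0])
eps = 1e-8
out = 1.0 / (x + eps)   # the eps term keeps the zero entry finite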

ReduceOperationBase

Bases: torchzero.core.module.Module, abc.ABC

Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override transform method to use it.

Methods:

  • transform

    applies the operation to operands

Source code in torchzero/modules/ops/reduce.py
class ReduceOperationBase(Module, ABC):
    """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override `transform` method to use it."""
    def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
        super().__init__(defaults=defaults)

        self.operands = []
        for i, v in enumerate(operands):

            if isinstance(v, (Module, Sequence)):
                self.set_child(f'operand_{i}', v)
                self.operands.append(self.children[f'operand_{i}'])
            else:
                self.operands.append(v)

        if not self.children:
            raise ValueError('At least one operand must be a module')

    @abstractmethod
    def transform(self, var: Var, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError

    @torch.no_grad
    def step(self, var: Var) -> Var:
        # pass cloned update to all module operands
        processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()

        for i, v in enumerate(self.operands):
            if f'operand_{i}' in self.children:
                v: Module
                updated_var = v.step(var.clone(clone_update=True))
                processed_operands[i] = updated_var.get_update()
                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them

        transformed = self.transform(var, *processed_operands)
        var.update = transformed
        return var
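
Since this base is meant to be subclassed, a minimal sketch may help. The class below is hypothetical and only illustrates the pattern: pass modules or numbers to the constructor and combine the resolved operands in transform (here with an element-wise maximum):

import torch

class MaxOfOperands(ReduceOperationBase):
    """Hypothetical reduction: element-wise maximum of all operands."""
    def __init__(self, *inputs):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var, *operands):
        # start from the first tensor-list operand so the accumulator is a list of tensors
        operands = sorted(operands, key=lambda x: isinstance(x, (int, float)))
        result = [t.clone() for t in operands[0]]
        for other in operands[1:]:
            if isinstance(other, (int, float)):
                result = [t.clamp(min=other) for t in result]                 # max with a scalar
            else:
                result = [torch.maximum(t, o) for t, o in zip(result, other)]  # max with another tensor list
        return result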

transform

transform(var: Var, *operands: Any | list[Tensor]) -> list[Tensor]

applies the operation to operands

Source code in torchzero/modules/ops/reduce.py
@abstractmethod
def transform(self, var: Var, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
    """applies the operation to operands"""
    raise NotImplementedError

Sign

Bases: torchzero.core.transform.Transform

Returns :code:sign(input)

Source code in torchzero/modules/ops/unary.py
class Sign(Transform):
    """Returns :code:`sign(input)`"""
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        torch._foreach_sign_(tensors)
        return tensors

Sqrt

Bases: torchzero.core.transform.Transform

Returns :code:sqrt(input)

Source code in torchzero/modules/ops/unary.py
class Sqrt(Transform):
    """Returns :code:`sqrt(input)`"""
    def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        torch._foreach_sqrt_(tensors)
        return tensors

SqrtEMASquared

Bases: torchzero.core.transform.Transform

Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

Parameters:

  • beta (float, default: 0.999 ) –

    momentum value. Defaults to 0.999.

  • amsgrad (bool, default: False ) –

    whether to maintain maximum of the exponential moving average. Defaults to False.

  • debiased (bool, default: False ) –

    whether to multiply the output by a debiasing term from the Adam method. Defaults to False.

  • pow (float, default: 2 ) –

    power, absolute value is always used. Defaults to 2.

Methods:

  • SQRT_EMA_SQ_FN

    Updates exp_avg_sq_ with EMA of squared tensors and calculates its square root, with optional AMSGrad and debiasing.

Source code in torchzero/modules/ops/higher_level.py
class SqrtEMASquared(Transform):
    """Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

    Args:
        beta (float, optional): momentum value. Defaults to 0.999.
        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
        pow (float, optional): power, absolute value is always used. Defaults to 2.
    """
    SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
    def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
        super().__init__(defaults, uses_grad=False)


    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
        beta = NumberList(s['beta'] for s in settings)

        if amsgrad:
            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
        else:
            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
            max_exp_avg_sq = None

        return self.SQRT_EMA_SQ_FN(
            TensorList(tensors),
            exp_avg_sq_=exp_avg_sq,
            beta=beta,
            max_exp_avg_sq_=max_exp_avg_sq,
            debiased=debiased,
            step=step,
            pow=pow,
        )
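
Conceptually this is the Adam second-moment accumulator. A plain-tensor sketch of one step for a single tensor, assuming pow=2 and amsgrad=False (illustrative only, not the library's functional code):

import torch

def sqrt_ema_sq_step(update: torch.Tensor, exp_avg_sq: torch.Tensor,
                     beta: float, step: int, debiased: bool) -> torch.Tensor:
    exp_avg_sq.mul_(beta).addcmul_(update, update, value=1 - beta)   # EMA of squared updates
    out = exp_avg_sq.sqrt()
    if debiased:
        out = out / (1 - beta ** step) ** 0.5   # Adam-style bias correction of the second moment
    return out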

SQRT_EMA_SQ_FN

SQRT_EMA_SQ_FN(tensors: TensorList, exp_avg_sq_: TensorList, beta: float | NumberList, max_exp_avg_sq_: TensorList | None, debiased: bool, step: int, pow: float = 2, ema_sq_fn: Callable = ema_sq_)

Updates exp_avg_sq_ with EMA of squared tensors and calculates its square root, with optional AMSGrad and debiasing.

Returns new tensors.

Source code in torchzero/modules/functional.py
def sqrt_ema_sq_(
    tensors: TensorList,
    exp_avg_sq_: TensorList,
    beta: float | NumberList,
    max_exp_avg_sq_: TensorList | None,
    debiased: bool,
    step: int,
    pow: float = 2,
    ema_sq_fn: Callable = ema_sq_,
):
    """
    Updates `exp_avg_sq_` with EMA of squared `tensors` and calculates its square root,
    with optional AMSGrad and debiasing.

    Returns new tensors.
    """
    exp_avg_sq_=ema_sq_fn(
        tensors=tensors,
        exp_avg_sq_=exp_avg_sq_,
        beta=beta,
        max_exp_avg_sq_=max_exp_avg_sq_,
        pow=pow,
    )

    sqrt_exp_avg_sq = root(exp_avg_sq_, pow, inplace=False)

    if debiased: sqrt_exp_avg_sq = debias_second_momentum(sqrt_exp_avg_sq, step=step, beta=beta, pow=pow, inplace=True)
    return sqrt_exp_avg_sq

Sub

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Subtract :code:other from tensors. :code:other can be a number or a module.

If :code:other is a module, this calculates :code:tensors - other(tensors)

Source code in torchzero/modules/ops/binary.py
class Sub(BinaryOperationBase):
    """Subtract :code:`other` from tensors. :code:`other` can be a number or a module.

    If :code:`other` is a module, this calculates :code:`tensors - other(tensors)`
    """
    def __init__(self, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, other=other)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.defaults['alpha'])
        else: torch._foreach_sub_(update, other, alpha=self.defaults['alpha'])
        return update
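
In both branches the result is tensors - alpha * other; a tiny illustration on plain tensors:

import torch

update = [torch.tensor([1.0, 2.0])]
other  = [torch.tensor([0.5, 0.5])]
alpha  = 2.0
result = [u - alpha * o for u, o in zip(update, other)]   # -> [tensor([0., 1.])]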

SubModules

Bases: torchzero.modules.ops.multi.MultiOperationBase

Calculates :code:input - other. :code:input and :code:other can be numbers or modules.

Source code in torchzero/modules/ops/multi.py
class SubModules(MultiOperationBase):
    """Calculates :code:`input - other`. :code:`input` and :code:`other` can be numbers or modules."""
    def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, input=input, other=other)

    @torch.no_grad
    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
        alpha = self.defaults['alpha']

        if isinstance(input, (int,float)):
            assert isinstance(other, list)
            return input - TensorList(other).mul_(alpha)

        if isinstance(other, (int, float)): torch._foreach_sub_(input, other * alpha)
        else: torch._foreach_sub_(input, other, alpha=alpha)
        return input

Sum

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs sum of :code:inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class Sum(ReduceOperationBase):
    """Outputs sum of :code:`inputs` that can be modules or numbers."""
    USE_MEAN = False
    def __init__(self, *inputs: Chainable | float):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        sum = cast(list, sorted_inputs[0])
        if len(sorted_inputs) > 1:
            for v in sorted_inputs[1:]:
                torch._foreach_add_(sum, v)

        if self.USE_MEAN and len(sorted_inputs) > 1: torch._foreach_div_(sum, len(sorted_inputs))
        return sum

USE_MEAN class-attribute

USE_MEAN = False


Threshold

Bases: torchzero.modules.ops.binary.BinaryOperationBase

Outputs tensors thresholded against :code:threshold. If :code:update_above is True, values at or below :code:threshold are replaced with :code:value; otherwise values at or above it are.

Source code in torchzero/modules/ops/binary.py
class Threshold(BinaryOperationBase):
    """Outputs tensors thresholded such that values above :code:`threshold` are set to :code:`value`."""
    def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
        defaults = dict(update_above=update_above)
        super().__init__(defaults, threshold=threshold, value=value)

    @torch.no_grad
    def transform(self, var, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
        update_above = self.defaults['update_above']
        update = TensorList(update)
        if update_above:
            if isinstance(value, list): return update.where_(update>threshold, value)
            return update.masked_fill_(update<=threshold, value)

        if isinstance(value, list): return update.where_(update<threshold, value)
        return update.masked_fill_(update>=threshold, value)
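
The update_above flag decides which side of the threshold keeps the original update. An equivalent restatement with plain torch.where, taking threshold and value as tensors (illustrative only):

import torch

def threshold_sketch(update: torch.Tensor, threshold: torch.Tensor,
                     value: torch.Tensor, update_above: bool) -> torch.Tensor:
    if update_above:
        # keep the update where it is above the threshold, use `value` elsewhere
        return torch.where(update > threshold, update, value)
    # keep the update where it is below the threshold, use `value` elsewhere
    return torch.where(update < threshold, update, value)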

UnaryLambda

Bases: torchzero.core.transform.Transform

Applies :code:fn to input tensors.

:code:fn must accept and return a list of tensors.

Source code in torchzero/modules/ops/unary.py
class UnaryLambda(Transform):
    """Applies :code:`fn` to input tensors.

    :code:`fn` must accept and return a list of tensors.
    """
    def __init__(self, fn, target: "Target" = 'update'):
        defaults = dict(fn=fn)
        super().__init__(defaults=defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        return settings[0]['fn'](tensors)
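
For example, since fn receives the whole list of update tensors and must return a list, a clipping transform can be written directly (the lambda below is just an illustration):

clip_to_unit = UnaryLambda(fn=lambda tensors: [t.clamp(-1.0, 1.0) for t in tensors])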

UnaryParameterwiseLambda

Bases: torchzero.core.transform.TensorwiseTransform

Applies :code:fn to each input tensor.

:code:fn must accept and return a tensor.

Source code in torchzero/modules/ops/unary.py
class UnaryParameterwiseLambda(TensorwiseTransform):
    """Applies :code:`fn` to each input tensor.

    :code:`fn` must accept and return a tensor.
    """
    def __init__(self, fn, target: "Target" = 'update'):
        defaults = dict(fn=fn)
        super().__init__(uses_grad=False, defaults=defaults, target=target)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        return setting['fn'](tensor)

Uniform

Bases: torchzero.core.module.Module

Outputs tensors filled with random numbers from a uniform distribution between :code:low and :code:high.

Source code in torchzero/modules/ops/utility.py
class Uniform(Module):
    """Outputs tensors filled with random numbers from uniform distribution between :code:`low` and :code:`high`."""
    def __init__(self, low: float, high: float):
        defaults = dict(low=low, high=high)
        super().__init__(defaults)

    @torch.no_grad
    def step(self, var):
        low,high = self.get_settings(var.params, 'low','high')
        var.update = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(var.params, low, high)]
        return var

UpdateToNone

Bases: torchzero.core.module.Module

Sets :code:update attribute to None on :code:var.

Source code in torchzero/modules/ops/utility.py
class UpdateToNone(Module):
    """Sets :code:`update` attribute to None on :code:`var`."""
    def __init__(self): super().__init__()
    def step(self, var):
        var.update = None
        return var

WeightedMean

Bases: torchzero.modules.ops.reduce.WeightedSum

Outputs weighted mean of :code:inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class WeightedMean(WeightedSum):
    """Outputs weighted mean of :code:`inputs` that can be modules or numbers."""
    USE_MEAN = True

USE_MEAN class-attribute

USE_MEAN = True


WeightedSum

Bases: torchzero.modules.ops.reduce.ReduceOperationBase

Outputs a weighted sum of :code:inputs that can be modules or numbers.

Source code in torchzero/modules/ops/reduce.py
class WeightedSum(ReduceOperationBase):
    USE_MEAN = False
    def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
        """Outputs a weighted sum of :code:`inputs` that can be modules or numbers."""
        weights = list(weights)
        if len(inputs) != len(weights):
            raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
        defaults = dict(weights=weights)
        super().__init__(defaults, *inputs)

    @torch.no_grad
    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
        sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
        weights = self.defaults['weights']
        sum = cast(list, sorted_inputs[0])
        torch._foreach_mul_(sum, weights[0])
        if len(sorted_inputs) > 1:
            for v, w in zip(sorted_inputs[1:], weights[1:]):
                if isinstance(v, (int, float)): torch._foreach_add_(sum, v*w)
                else: torch._foreach_add_(sum, v, alpha=w)

        if self.USE_MEAN and len(sorted_inputs) > 1: torch._foreach_div_(sum, len(sorted_inputs))
        return sum
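
Numerically the output is the sum over i of weights[i] * input_i, divided by the number of inputs when USE_MEAN is set (as in WeightedMean). Restated on plain tensors (illustrative):

import torch

inputs  = [[torch.tensor([1.0, 2.0])], [torch.tensor([3.0, 4.0])]]   # two "module outputs", one tensor each
weights = [0.25, 0.75]
out = [sum(w * t for w, t in zip(weights, per_param)) for per_param in zip(*inputs)]
# -> [tensor([2.5000, 3.5000])]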

USE_MEAN class-attribute

USE_MEAN = False


Zeros

Bases: torchzero.core.module.Module

Outputs zeros

Source code in torchzero/modules/ops/utility.py
class Zeros(Module):
    """Outputs zeros"""
    def __init__(self):
        super().__init__({})
    @torch.no_grad
    def step(self, var):
        var.update = [torch.zeros_like(p) for p in var.params]
        return var