Operations
This subpackage contains operations like adding modules, subtracting, grafting, tracking the maximum, etc.
Classes:
- Abs – Returns abs(input).
- AccumulateMaximum – Accumulates the maximum of all past updates.
- AccumulateMean – Accumulates the mean of all past updates.
- AccumulateMinimum – Accumulates the minimum of all past updates.
- AccumulateProduct – Accumulates the product of all past updates.
- AccumulateSum – Accumulates the sum of all past updates.
- Add – Adds other to tensors. other can be a number or a module.
- BinaryOperationBase – Base class for operations that use the update as the first operand. This is an abstract class; subclass it and override the transform method to use it.
- CenteredEMASquared – Maintains a centered exponential moving average of squared updates.
- CenteredSqrtEMASquared – Maintains a centered exponential moving average of squared updates and outputs its optionally debiased square root.
- Clip – Clips tensors to the [min, max] range. min and max can be None, numbers or modules.
- ClipModules – Calculates input(tensors).clip(min, max). min and max can be numbers or modules.
- Clone – Clones the input. Useful for storing an intermediate result so that it is not affected by in-place operations.
- CopyMagnitude – Returns other(tensors) with the sign copied from tensors.
- CopySign – Returns tensors with the sign copied from other(tensors).
- CustomUnaryOperation – Applies getattr(tensor, name) to each tensor.
- Debias – Multiplies the update by an Adam debiasing term based on the first and/or second momentum.
- Debias2 – Multiplies the update by an Adam debiasing term based on the second momentum.
- Div – Divides tensors by other. other can be a number or a module.
- DivModules – Calculates input / other. input and other can be numbers or modules.
- EMASquared – Maintains an exponential moving average of squared updates.
- Exp – Returns exp(input).
- Fill – Outputs tensors filled with value.
- Grad – Outputs the gradient.
- GradToNone – Sets the grad attribute to None on objective.
- Graft – Outputs the direction output rescaled to have the same norm as the magnitude output.
- GraftInputToOutput – Outputs tensors rescaled to have the same norm as magnitude(tensors).
- GraftOutputToInput – Outputs magnitude(tensors) rescaled to have the same norm as tensors.
- GramSchimdt – Outputs tensors made orthogonal to other(tensors) via Gram-Schmidt.
- Identity – Identity operator that is argument-insensitive. It can also be used as the identity Hessian for trust region methods.
- LerpModules – Linearly interpolates between input(tensors) and end(tensors) based on a scalar weight.
- Maximum – Outputs maximum(tensors, other(tensors)).
- MaximumModules – Outputs the elementwise maximum of inputs, which can be modules or numbers.
- Mean – Outputs the mean of inputs, which can be modules or numbers.
- Minimum – Outputs minimum(tensors, other(tensors)).
- MinimumModules – Outputs the elementwise minimum of inputs, which can be modules or numbers.
- Mul – Multiplies tensors by other. other can be a number or a module.
- MultiOperationBase – Base class for operations that use multiple operands. This is an abstract class; subclass it and override the transform method to use it.
- NanToNum – Converts nan, inf and -inf to numbers.
- Negate – Returns -input.
- Noop – Identity operator that is argument-insensitive. It can also be used as the identity Hessian for trust region methods.
- Ones – Outputs ones.
- Params – Outputs the parameters.
- Pow – Raises tensors to the power of exponent. exponent can be a number or a module.
- PowModules – Calculates input ** exponent. input and exponent can be numbers or modules.
- Prod – Outputs the product of inputs, which can be modules or numbers.
- RCopySign – Returns other(tensors) with the sign copied from tensors.
- RDiv – Divides other by tensors. other can be a number or a module.
- RPow – Raises other to the power of tensors. other can be a number or a module.
- RSub – Subtracts tensors from other. other can be a number or a module.
- Randn – Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1.
- RandomSample – Outputs tensors filled with random numbers from the distribution specified by distribution.
- Reciprocal – Returns 1 / input.
- ReduceOperationBase – Base class for reduction operations like Sum, Prod and MaximumModules. This is an abstract class; subclass it and override the transform method to use it.
- Sign – Returns sign(input).
- Sqrt – Returns sqrt(input).
- SqrtEMASquared – Maintains an exponential moving average of squared updates and outputs its optionally debiased square root.
- Sub – Subtracts other from tensors. other can be a number or a module.
- SubModules – Calculates input - other. input and other can be numbers or modules.
- Sum – Outputs the sum of inputs, which can be modules or numbers.
- Threshold – Outputs tensors thresholded such that values above threshold are set to value.
- UnaryLambda – Applies fn to the input tensors.
- UnaryParameterwiseLambda – Applies fn to each input tensor.
- Uniform – Outputs tensors filled with random numbers from a uniform distribution between low and high.
- UpdateToNone – Sets the update attribute to None on var.
- WeightedMean – Outputs the weighted mean of inputs, which can be modules or numbers.
- WeightedSum – Outputs a weighted sum of inputs, which can be modules or numbers.
- Zeros – Outputs zeros.
Abs
Bases: torchzero.core.transform.TensorTransform
Returns abs(input)
Source code in torchzero/modules/ops/unary.py
AccumulateMaximum
Bases: torchzero.core.transform.TensorTransform
Accumulates the maximum of all past updates.
Parameters:
- decay (float, default: 0) – decays the accumulator. Defaults to 0.
- target (Target) – target. Defaults to 'update'.
Source code in torchzero/modules/ops/accumulate.py
AccumulateMean
Bases: torchzero.core.transform.TensorTransform
Accumulates the mean of all past updates.
Parameters:
- decay (float, default: 0) – decays the accumulator. Defaults to 0.
- target (Target) – target. Defaults to 'update'.
Source code in torchzero/modules/ops/accumulate.py
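A minimal pure-Python sketch of an incremental mean, assuming decay=0 and using lists of floats in place of tensors. `running_mean` is a hypothetical helper, not part of torchzero; it only shows how a mean of all past updates can be kept without storing every update.

```python
def running_mean(updates: list[list[float]]) -> list[float]:
    """Elementwise mean of all updates, computed incrementally (hypothetical helper)."""
    mean: list[float] = []
    for step, update in enumerate(updates, start=1):
        if step == 1:
            mean = list(update)  # the first update initializes the accumulator
        else:
            # mean_t = mean_{t-1} + (x_t - mean_{t-1}) / t
            mean = [m + (x - m) / step for m, x in zip(mean, update)]
    return mean


print(running_mean([[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]]))  # [3.0, 2.0]
```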
AccumulateMinimum
Bases: torchzero.core.transform.TensorTransform
Accumulates the minimum of all past updates.
Parameters:
- decay (float, default: 0) – decays the accumulator. Defaults to 0.
- target (Target) – target. Defaults to 'update'.
Source code in torchzero/modules/ops/accumulate.py
AccumulateProduct
Bases: torchzero.core.transform.TensorTransform
Accumulates the product of all past updates.
Parameters:
- decay (float, default: 0) – decays the accumulator. Defaults to 0.
- target (Target, default: 'update') – target. Defaults to 'update'.
Source code in torchzero/modules/ops/accumulate.py
AccumulateSum
Bases: torchzero.core.transform.TensorTransform
Accumulates the sum of all past updates.
Parameters:
- decay (float, default: 0) – decays the accumulator. Defaults to 0.
- target (Target) – target. Defaults to 'update'.
Source code in torchzero/modules/ops/accumulate.py
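A sketch of a decayed running sum. The exact decay rule is not stated above, so this assumes the accumulator is multiplied by (1 - decay) before each new update is added; `decayed_sum` is a hypothetical helper and lists of floats stand in for tensors.

```python
def decayed_sum(updates: list[list[float]], decay: float = 0.0) -> list[float]:
    """Running elementwise sum with an ASSUMED (1 - decay) shrink per step."""
    acc = [0.0] * len(updates[0])
    for update in updates:
        # Assumed decay rule: shrink the old accumulator, then add the new update
        acc = [(1.0 - decay) * a + x for a, x in zip(acc, update)]
    return acc


print(decayed_sum([[1.0, 2.0], [3.0, 4.0]]))             # [4.0, 6.0]
print(decayed_sum([[1.0, 2.0], [3.0, 4.0]], decay=0.5))  # [3.5, 5.0]
```

With decay=0 this reduces to a plain sum of all past updates.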
Add
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Add other to tensors. other can be a number or a module.
If other is a module, this calculates tensors + other(tensors)
Source code in torchzero/modules/ops/binary.py
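The "number or module" rule shared by Add, Sub, Mul, Div and the other binary operations can be sketched in plain Python. Here a module is modeled as a hypothetical callable mapping the current tensors to new tensors, and tensors are lists of floats; `resolve` and `add` are illustrative helpers, not torchzero's implementation.

```python
from typing import Callable, Union

Tensors = list[float]
Operand = Union[float, Callable[[Tensors], Tensors]]


def resolve(other: Operand, tensors: Tensors) -> Tensors:
    """Turn `other` into values matching `tensors` (hypothetical helper)."""
    if callable(other):
        # A module is applied to the same input: other(tensors)
        return other(tensors)
    # A number is broadcast to every element
    return [float(other)] * len(tensors)


def add(tensors: Tensors, other: Operand) -> Tensors:
    """tensors + other, or tensors + other(tensors) when other is callable."""
    return [t + o for t, o in zip(tensors, resolve(other, tensors))]


print(add([1.0, -2.0], 0.5))                             # [1.5, -1.5]
print(add([1.0, -2.0], lambda ts: [2 * t for t in ts]))  # [3.0, -6.0]
```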
BinaryOperationBase
Bases: torchzero.core.module.Module, abc.ABC
Base class for operations that use the update as the first operand. This is an abstract class; subclass it and override the transform method to use it.
Methods:
- transform – applies the operation to operands
Source code in torchzero/modules/ops/binary.py
transform
transform(objective: Objective, update: list[Tensor], **operands: Any | list[Tensor]) -> Iterable[Tensor]
Applies the operation to operands.
CenteredEMASquared
Bases: torchzero.core.transform.TensorTransform
Maintains a centered exponential moving average of squared updates. This also maintains an additional exponential moving average of un-squared updates, the square of which is subtracted from the EMA.
Parameters:
- beta (float, default: 0.99) – momentum value. Defaults to 0.99.
- amsgrad (bool, default: False) – whether to maintain the maximum of the exponential moving average. Defaults to False.
- pow (float, default: 2) – power; the absolute value is always used. Defaults to 2.
Source code in torchzero/modules/ops/higher_level.py
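A minimal scalar sketch of the centering described above, assuming pow=2, amsgrad=False, EMAs initialized at zero and the usual `beta * old + (1 - beta) * new` update; torchzero's initialization may differ. `centered_ema_sq` is a hypothetical helper.

```python
def centered_ema_sq(values: list[float], beta: float = 0.99) -> float:
    """Centered EMA of squared values for a single scalar (hypothetical helper)."""
    exp_avg = 0.0     # EMA of un-squared updates
    exp_avg_sq = 0.0  # EMA of squared updates
    for x in values:
        exp_avg = beta * exp_avg + (1 - beta) * x
        exp_avg_sq = beta * exp_avg_sq + (1 - beta) * x * x
    # Centering: subtract the square of the plain EMA (a variance-like estimate)
    return exp_avg_sq - exp_avg ** 2


print(centered_ema_sq([2.0, 2.0], beta=0.5))  # 0.75
```

Centering turns the second-moment estimate into something closer to a variance estimate, which is the idea behind centered RMSprop.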
CenteredSqrtEMASquared
Bases: torchzero.core.transform.TensorTransform
Maintains a centered exponential moving average of squared updates and outputs its optionally debiased square root. This also maintains an additional exponential moving average of un-squared updates, the square of which is subtracted from the EMA.
Parameters:
- beta (float, default: 0.99) – momentum value. Defaults to 0.99.
- amsgrad (bool, default: False) – whether to maintain the maximum of the exponential moving average. Defaults to False.
- debiased (bool, default: False) – whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
- pow (float, default: 2) – power; the absolute value is always used. Defaults to 2.
Source code in torchzero/modules/ops/higher_level.py
Clip
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Clips tensors to the [min, max] range. min and max can be None, numbers or modules.
If min and max are modules, this calculates tensors.clip(min(tensors), max(tensors)).
Source code in torchzero/modules/ops/binary.py
ClipModules
Bases: torchzero.modules.ops.multi.MultiOperationBase
Calculates input(tensors).clip(min, max). min and max can be numbers or modules.
Source code in torchzero/modules/ops/multi.py
Clone
Bases: torchzero.core.module.Module
Clones the input. This may be useful for storing an intermediate result and making sure it is not affected by in-place operations.
Source code in torchzero/modules/ops/utility.py
CopyMagnitude
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Returns other(tensors) with sign copied from tensors.
Source code in torchzero/modules/ops/binary.py
CopySign
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Returns tensors with sign copied from other(tensors).
Source code in torchzero/modules/ops/binary.py
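Elementwise, CopySign behaves like `math.copysign(t, o)` (magnitude of t, sign of o), while CopyMagnitude and RCopySign, as described above, behave like `math.copysign(o, t)`. A scalar sketch in which lists of floats stand in for tensors and `other_out` plays the role of other(tensors); both helper names are hypothetical.

```python
import math


def copy_sign(tensors: list[float], other_out: list[float]) -> list[float]:
    # CopySign: magnitude from tensors, sign from other(tensors)
    return [math.copysign(t, o) for t, o in zip(tensors, other_out)]


def copy_magnitude(tensors: list[float], other_out: list[float]) -> list[float]:
    # CopyMagnitude / RCopySign: magnitude from other(tensors), sign from tensors
    return [math.copysign(o, t) for t, o in zip(tensors, other_out)]


print(copy_sign([3.0, -1.0], [-0.5, 2.0]))       # [-3.0, 1.0]
print(copy_magnitude([3.0, -1.0], [-0.5, 2.0]))  # [0.5, -2.0]
```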
CustomUnaryOperation
Bases: torchzero.core.transform.TensorTransform
Applies getattr(tensor, name) to each tensor
Source code in torchzero/modules/ops/unary.py
Debias
Bases: torchzero.core.transform.TensorTransform
Multiplies the update by an Adam debiasing term based on the first and/or second momentum.
Parameters:
- beta1 (float | None, default: None) – first momentum; should be the same as the first momentum used by preceding modules. Defaults to None.
- beta2 (float | None, default: None) – second (squared) momentum; should be the same as the second momentum used by preceding modules. Defaults to None.
- alpha (float, default: 1) – learning rate. Defaults to 1.
- pow (float, default: 2) – power; assumes the absolute value is used. Defaults to 2.
- target (Target) – target. Defaults to 'update'.
Source code in torchzero/modules/ops/higher_level.py
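In Adam, the bias-corrected step is m_hat / sqrt(v_hat) with m_hat = m / (1 - beta1^t) and v_hat = v / (1 - beta2^t), so the combined multiplier is sqrt(1 - beta2^t) / (1 - beta1^t). The sketch below computes that factor; generalizing the square root to a `1 / power` root (mirroring `pow`) and multiplying by alpha are assumptions about how the parameters combine, and `adam_debias_term` is a hypothetical helper.

```python
from typing import Optional


def adam_debias_term(step: int, beta1: Optional[float] = None,
                     beta2: Optional[float] = None, alpha: float = 1.0,
                     power: float = 2.0) -> float:
    """Adam-style bias-correction multiplier at 1-based `step` (sketch)."""
    term = alpha
    if beta1 is not None:
        term /= 1 - beta1 ** step                  # first-moment correction
    if beta2 is not None:
        term *= (1 - beta2 ** step) ** (1 / power)  # second-moment correction
    return term


print(adam_debias_term(1, beta1=0.9, beta2=0.999))       # about 0.316
print(adam_debias_term(10_000, beta1=0.9, beta2=0.999))  # close to 1.0
```

The correction matters most in early steps; as t grows both factors approach 1.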
Debias2
Bases: torchzero.core.transform.TensorTransform
Multiplies the update by an Adam debiasing term based on the second momentum.
Parameters:
- beta (float | None, default: 0.999) – second (squared) momentum; should be the same as the second momentum used by preceding modules. Defaults to 0.999.
- pow (float, default: 2) – power; assumes the absolute value is used. Defaults to 2.
- target (Target) – target. Defaults to 'update'.
Source code in torchzero/modules/ops/higher_level.py
Div
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Divide tensors by other. other can be a number or a module.
If other is a module, this calculates tensors / other(tensors)
Source code in torchzero/modules/ops/binary.py
DivModules
Bases: torchzero.modules.ops.multi.MultiOperationBase
Calculates input / other. input and other can be numbers or modules.
Source code in torchzero/modules/ops/multi.py
EMASquared
Bases: torchzero.core.transform.TensorTransform
Maintains an exponential moving average of squared updates.
Parameters:
- beta (float, default: 0.999) – momentum value. Defaults to 0.999.
- amsgrad (bool, default: False) – whether to maintain the maximum of the exponential moving average. Defaults to False.
- pow (float, default: 2) – power; the absolute value is always used. Defaults to 2.
Methods:
- EMA_SQ_FN – Updates exp_avg_sq_ with the EMA of squared tensors; if max_exp_avg_sq_ is not None, also updates it with the maximum of the EMA.
Source code in torchzero/modules/ops/higher_level.py
EMA_SQ_FN
EMA_SQ_FN(tensors: TensorList, exp_avg_sq_: TensorList, beta: float | NumberList, max_exp_avg_sq_: TensorList | None, pow: float = 2)
Updates exp_avg_sq_ with the EMA of squared tensors. If max_exp_avg_sq_ is not None, also updates it with the maximum of the EMA.
Returns exp_avg_sq_ or max_exp_avg_sq_.
Source code in torchzero/modules/opt_utils.py
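A pure-Python sketch of the update EMA_SQ_FN describes, with lists of floats instead of TensorList. The trailing underscores suggest in-place buffers, so the sketch mutates them; `ema_sq_step` is a hypothetical name, and `power` mirrors the `pow` argument.

```python
from typing import Optional


def ema_sq_step(tensors: list[float], exp_avg_sq_: list[float], beta: float,
                max_exp_avg_sq_: Optional[list[float]] = None,
                power: float = 2.0) -> list[float]:
    """One EMA-of-|x|^power step with optional AMSGrad-style max (sketch)."""
    for i, t in enumerate(tensors):
        exp_avg_sq_[i] = beta * exp_avg_sq_[i] + (1 - beta) * abs(t) ** power
    if max_exp_avg_sq_ is None:
        return exp_avg_sq_
    # AMSGrad: keep the elementwise running maximum of the EMA
    for i, v in enumerate(exp_avg_sq_):
        max_exp_avg_sq_[i] = max(max_exp_avg_sq_[i], v)
    return max_exp_avg_sq_


buf = [0.0, 0.0]
print(ema_sq_step([2.0, -1.0], buf, beta=0.5))  # [2.0, 0.5]
```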
Exp
Bases: torchzero.core.transform.TensorTransform
Returns exp(input)
Source code in torchzero/modules/ops/unary.py
Fill
Bases: torchzero.core.module.Module
Outputs tensors filled with value
Source code in torchzero/modules/ops/utility.py
Grad
Bases: torchzero.core.module.Module
Outputs the gradient
Source code in torchzero/modules/ops/utility.py
GradToNone
Bases: torchzero.core.module.Module
Sets grad attribute to None on objective.
Source code in torchzero/modules/ops/utility.py
Graft
Bases: torchzero.modules.ops.multi.MultiOperationBase
Outputs the direction output rescaled to have the same norm as the magnitude output.
Parameters:
- direction (Chainable) – module to use the direction from.
- magnitude (Chainable) – module to use the magnitude from.
- tensorwise (bool, default: True) – whether to calculate the norm per tensor or globally. Defaults to True.
- ord (float, default: 2) – norm order. Defaults to 2.
- eps (float, default: 1e-06) – clips the denominator to be no less than this value. Defaults to 1e-6.
- strength (float, default: 1) – strength of grafting. Defaults to 1.
Example:
Shampoo grafted to Adam:
```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Graft(
        direction=tz.m.Shampoo(),
        magnitude=tz.m.Adam(),
    ),
    tz.m.LR(1e-3),
)
```
Source code in torchzero/modules/ops/multi.py
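A pure-Python sketch of tensorwise grafting (tensorwise=True), with each tensor flattened to a list of floats. It assumes strength=1 and leaves out how `strength` blends the result, since that is not described above; `lp_norm` and `graft` are hypothetical helpers.

```python
def lp_norm(values: list[float], order: float = 2.0) -> float:
    """Lp norm of a flat list of floats."""
    return sum(abs(x) ** order for x in values) ** (1.0 / order)


def graft(direction: list[list[float]], magnitude: list[list[float]],
          order: float = 2.0, eps: float = 1e-6) -> list[list[float]]:
    """Rescale each direction tensor to the norm of the matching magnitude tensor."""
    grafted = []
    for d, m in zip(direction, magnitude):
        # eps clips the denominator so an all-zero direction does not divide by zero
        scale = lp_norm(m, order) / max(lp_norm(d, order), eps)
        grafted.append([x * scale for x in d])
    return grafted


print(graft([[3.0, 4.0]], [[0.0, 10.0]]))  # [[6.0, 8.0]]
```

The output keeps the direction of `direction` and takes its size from `magnitude`, which is why grafting Shampoo onto Adam uses Shampoo's preconditioned direction with Adam's step size.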
GraftInputToOutput
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Outputs tensors rescaled to have the same norm as magnitude(tensors).
Source code in torchzero/modules/ops/binary.py
GraftOutputToInput
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Outputs magnitude(tensors) rescaled to have the same norm as tensors.
Source code in torchzero/modules/ops/binary.py
GramSchimdt
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Outputs tensors made orthogonal to other(tensors) via Gram-Schmidt.
Source code in torchzero/modules/ops/binary.py
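A single Gram-Schmidt step removes from tensors its projection onto other(tensors): t - (t·o / o·o) o. The sketch treats everything as one flattened vector; whether torchzero projects per tensor or globally is not stated above, and `orthogonalize` is a hypothetical helper.

```python
def orthogonalize(tensors: list[float], other_out: list[float],
                  eps: float = 1e-12) -> list[float]:
    """Remove the component of `tensors` along `other_out` (one Gram-Schmidt step)."""
    dot_to = sum(t * o for t, o in zip(tensors, other_out))
    dot_oo = sum(o * o for o in other_out)
    coef = dot_to / max(dot_oo, eps)  # eps guards against a zero `other_out`
    return [t - coef * o for t, o in zip(tensors, other_out)]


print(orthogonalize([1.0, 1.0], [1.0, 0.0]))  # [0.0, 1.0]
```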
Identity
Bases: torchzero.core.module.Module
Identity operator that is argument-insensitive. It can also be used as the identity Hessian for trust region methods.
Source code in torchzero/modules/ops/utility.py
LerpModules
Bases: torchzero.modules.ops.multi.MultiOperationBase
Does a linear interpolation of input(tensors) and end(tensors) based on a scalar weight.
The output is given by output = input(tensors) + weight * (end(tensors) - input(tensors))
Source code in torchzero/modules/ops/multi.py
Maximum
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Outputs maximum(tensors, other(tensors))
Source code in torchzero/modules/ops/binary.py
MaximumModules
Bases: torchzero.modules.ops.reduce.ReduceOperationBase
Outputs elementwise maximum of inputs that can be modules or numbers.
Source code in torchzero/modules/ops/reduce.py
Mean
Bases: torchzero.modules.ops.reduce.Sum
Outputs a mean of inputs that can be modules or numbers.
Source code in torchzero/modules/ops/reduce.py
USE_MEAN (class attribute, bool)
Minimum
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Outputs minimum(tensors, other(tensors))
Source code in torchzero/modules/ops/binary.py
MinimumModules
Bases: torchzero.modules.ops.reduce.ReduceOperationBase
Outputs elementwise minimum of inputs that can be modules or numbers.
Source code in torchzero/modules/ops/reduce.py
Mul
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Multiply tensors by other. other can be a number or a module.
If other is a module, this calculates tensors * other(tensors)
Source code in torchzero/modules/ops/binary.py
MultiOperationBase
Bases: torchzero.core.module.Module, abc.ABC
Base class for operations that use multiple operands. This is an abstract class; subclass it and override the transform method to use it.
Methods:
- transform – applies the operation to operands
Source code in torchzero/modules/ops/multi.py
transform
NanToNum
Bases: torchzero.core.transform.TensorTransform
Converts nan, inf and -inf to numbers.
Parameters:
- nan (optional, default: None) – the value to replace NaNs with. If None, NaNs are replaced with zero.
- posinf (optional, default: None) – if a number, the value to replace positive infinity values with. If None, positive infinity values are replaced with the greatest finite value representable by the input's dtype. Default is None.
- neginf (optional, default: None) – if a number, the value to replace negative infinity values with. If None, negative infinity values are replaced with the lowest finite value representable by the input's dtype. Default is None.
Source code in torchzero/modules/ops/unary.py
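The replacement rules above match `torch.nan_to_num`. A pure-Python sketch for Python floats (float64), where `sys.float_info.max` is the greatest finite value; `nan_to_num` here is a hypothetical helper operating on lists, not the torch function.

```python
import math
import sys
from typing import Optional


def nan_to_num(values: list[float], nan: Optional[float] = None,
               posinf: Optional[float] = None,
               neginf: Optional[float] = None) -> list[float]:
    """Replace nan/inf/-inf following the rules documented above (sketch)."""
    big = sys.float_info.max  # greatest finite float64
    out = []
    for x in values:
        if math.isnan(x):
            out.append(0.0 if nan is None else nan)
        elif x == math.inf:
            out.append(big if posinf is None else posinf)
        elif x == -math.inf:
            out.append(-big if neginf is None else neginf)
        else:
            out.append(x)  # finite values pass through unchanged
    return out


print(nan_to_num([float("nan"), float("inf"), float("-inf"), 1.5]))
```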
Negate
Bases: torchzero.core.transform.TensorTransform
Returns -input
Source code in torchzero/modules/ops/unary.py
Noop
Bases: torchzero.core.module.Module
Identity operator that is argument-insensitive. It can also be used as the identity Hessian for trust region methods.
Source code in torchzero/modules/ops/utility.py
Ones
Bases: torchzero.core.module.Module
Outputs ones
Source code in torchzero/modules/ops/utility.py
Params
Bases: torchzero.core.module.Module
Outputs parameters
Source code in torchzero/modules/ops/utility.py
Pow
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Raises tensors to the power of exponent. exponent can be a number or a module.
If exponent is a module, this calculates tensors ** exponent(tensors)
Source code in torchzero/modules/ops/binary.py
PowModules
Bases: torchzero.modules.ops.multi.MultiOperationBase
Calculates input ** exponent. input and exponent can be numbers or modules.
Source code in torchzero/modules/ops/multi.py
Prod
Bases: torchzero.modules.ops.reduce.ReduceOperationBase
Outputs product of inputs that can be modules or numbers.
Source code in torchzero/modules/ops/reduce.py
RCopySign
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Returns other(tensors) with sign copied from tensors.
Source code in torchzero/modules/ops/binary.py
RDiv
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Divide other by tensors. other can be a number or a module.
If other is a module, this calculates other(tensors) / tensors
Source code in torchzero/modules/ops/binary.py
RPow
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Raises other to the power of tensors. other can be a number or a module.
If other is a module, this calculates other(tensors) ** tensors
Source code in torchzero/modules/ops/binary.py
RSub
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Subtract tensors from other. other can be a number or a module.
If other is a module, this calculates other(tensors) - tensors
Source code in torchzero/modules/ops/binary.py
Randn
Bases: torchzero.core.module.Module
Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1.
Source code in torchzero/modules/ops/utility.py
RandomSample
Bases: torchzero.core.module.Module
Outputs tensors filled with random numbers from the distribution specified by distribution.
Source code in torchzero/modules/ops/utility.py
Reciprocal
Bases: torchzero.core.transform.TensorTransform
Returns 1 / input
Source code in torchzero/modules/ops/unary.py
ReduceOperationBase
Bases: torchzero.core.module.Module, abc.ABC
Base class for reduction operations like Sum, Prod and MaximumModules. This is an abstract class; subclass it and override the transform method to use it.
Methods:
- transform – applies the operation to operands
Source code in torchzero/modules/ops/reduce.py
transform
Sign
Bases: torchzero.core.transform.TensorTransform
Returns sign(input)
Source code in torchzero/modules/ops/unary.py
Sqrt
Bases: torchzero.core.transform.TensorTransform
Returns sqrt(input)
Source code in torchzero/modules/ops/unary.py
SqrtEMASquared
Bases: torchzero.core.transform.TensorTransform
Maintains an exponential moving average of squared updates and outputs its optionally debiased square root.
Parameters:
- beta (float, default: 0.999) – momentum value. Defaults to 0.999.
- amsgrad (bool, default: False) – whether to maintain the maximum of the exponential moving average. Defaults to False.
- debiased (bool, default: False) – whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
- pow (float, default: 2) – power; the absolute value is always used. Defaults to 2.
Methods:
- SQRT_EMA_SQ_FN – Updates exp_avg_sq_ with the EMA of squared tensors and calculates its square root, with optional AMSGrad and debiasing.
Source code in torchzero/modules/ops/higher_level.py
SQRT_EMA_SQ_FN
SQRT_EMA_SQ_FN(tensors: TensorList, exp_avg_sq_: TensorList, beta: float | NumberList, max_exp_avg_sq_: TensorList | None, debiased: bool, step: int, pow: float = 2, ema_sq_fn: Callable = ema_sq_)
Updates exp_avg_sq_ with the EMA of squared tensors and calculates its square root, with optional AMSGrad and debiasing.
Returns new tensors.
Source code in torchzero/modules/opt_utils.py
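A sketch of the square-root step with optional Adam-style debiasing, dividing the EMA by (1 - beta^step) before taking the root. It assumes pow=2, a 1-based `step`, and omits AMSGrad; `sqrt_ema_sq_step` is a hypothetical helper and lists of floats stand in for TensorList.

```python
def sqrt_ema_sq_step(tensors: list[float], exp_avg_sq_: list[float],
                     beta: float, step: int, debiased: bool = False) -> list[float]:
    """Update the EMA of squares in place and return its (debiased) sqrt (sketch)."""
    for i, t in enumerate(tensors):
        exp_avg_sq_[i] = beta * exp_avg_sq_[i] + (1 - beta) * t * t
    if debiased:
        correction = 1 - beta ** step  # Adam-style bias correction
        return [(v / correction) ** 0.5 for v in exp_avg_sq_]
    return [v ** 0.5 for v in exp_avg_sq_]  # new list; the buffer is not replaced


print(sqrt_ema_sq_step([3.0], [0.0], beta=0.5, step=1, debiased=True))  # [3.0]
```

On the first step the debiased output recovers |x| exactly, which illustrates why debiasing is used: the zero-initialized EMA otherwise underestimates the second moment early on.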
Sub
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Subtracts other from tensors. other can be a number or a module.
If other is a module, this calculates tensors - other(tensors)
Source code in torchzero/modules/ops/binary.py
SubModules
Bases: torchzero.modules.ops.multi.MultiOperationBase
Calculates input - other. input and other can be numbers or modules.
Source code in torchzero/modules/ops/multi.py
Sum
Bases: torchzero.modules.ops.reduce.ReduceOperationBase
Outputs sum of inputs that can be modules or numbers.
Source code in torchzero/modules/ops/reduce.py
USE_MEAN (class attribute, bool)
Threshold
Bases: torchzero.modules.ops.binary.BinaryOperationBase
Outputs tensors thresholded such that values above threshold are set to value.
Source code in torchzero/modules/ops/binary.py
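A scalar sketch of the rule as worded above, assuming "above" means strictly greater and that threshold and value are plain numbers. Note that this is the mirror image of `torch.nn.functional.threshold`, which keeps elements greater than threshold and replaces the rest, so check the source if the direction matters. `threshold_above` is a hypothetical helper.

```python
def threshold_above(tensors: list[float], threshold: float, value: float) -> list[float]:
    # Values strictly above `threshold` are replaced by `value`; others pass through
    return [value if t > threshold else t for t in tensors]


print(threshold_above([0.5, 2.0, 3.0], threshold=1.0, value=1.0))  # [0.5, 1.0, 1.0]
```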
UnaryLambda
Bases: torchzero.core.transform.TensorTransform
Applies fn to input tensors.
fn must accept and return a list of tensors.
Source code in torchzero/modules/ops/unary.py
UnaryParameterwiseLambda
Bases: torchzero.core.transform.TensorTransform
Applies fn to each input tensor.
fn must accept and return a tensor.
Source code in torchzero/modules/ops/unary.py
Uniform
Bases: torchzero.core.module.Module
Outputs tensors filled with random numbers from a uniform distribution between low and high.
Source code in torchzero/modules/ops/utility.py
UpdateToNone
Bases: torchzero.core.module.Module
Sets update attribute to None on var.
Source code in torchzero/modules/ops/utility.py
WeightedMean
Bases: torchzero.modules.ops.reduce.WeightedSum
Outputs weighted mean of inputs that can be modules or numbers.
Source code in torchzero/modules/ops/reduce.py
USE_MEAN (class attribute, bool)
WeightedSum
Bases: torchzero.modules.ops.reduce.ReduceOperationBase
Outputs a weighted sum of inputs that can be modules or numbers.
Source code in torchzero/modules/ops/reduce.py
USE_MEAN (class attribute, bool)
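A sketch of a weighted sum and weighted mean over inputs that have already been evaluated to lists of floats. Normalizing the mean by the sum of the weights is the standard definition and is assumed here; the exact normalization used by WeightedMean is not stated above. Both helper names are hypothetical.

```python
def weighted_sum(inputs: list[list[float]], weights: list[float]) -> list[float]:
    """Elementwise sum of w_k * x_k over all inputs."""
    return [sum(w * x[i] for x, w in zip(inputs, weights))
            for i in range(len(inputs[0]))]


def weighted_mean(inputs: list[list[float]], weights: list[float]) -> list[float]:
    """Weighted sum divided by the total weight (ASSUMED normalization)."""
    total = sum(weights)
    return [s / total for s in weighted_sum(inputs, weights)]


print(weighted_sum([[1.0, 2.0], [3.0, 4.0]], [1.0, 3.0]))   # [10.0, 14.0]
print(weighted_mean([[1.0, 2.0], [3.0, 4.0]], [1.0, 3.0]))  # [2.5, 3.5]
```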
Zeros
Bases: torchzero.core.module.Module
Outputs zeros