Momentum

This subpackage contains momentum methods and exponential moving averages.

Classes:

- `Averaging` – Average of the past `history_size` updates.
- `Cautious` – Negates the update for parameters where the update and gradient signs are inconsistent.
- `EMA` – Maintains an exponential moving average of the update.
- `HeavyBall` – Polyak's momentum (heavy-ball method).
- `IntermoduleCautious` – Negates the update of the `main` module where its sign doesn't match the output of the `compare` module.
- `MedianAveraging` – Median of the past `history_size` updates.
- `NAG` – Nesterov accelerated gradient method (Nesterov momentum).
- `ScaleByGradCosineSimilarity` – Multiplies the update by its cosine similarity with the gradient.
- `ScaleModulesByCosineSimilarity` – Scales the output of the `main` module by its cosine similarity to the output of the `compare` module.
- `UpdateGradientSignConsistency` – Compares update and gradient signs; the output has 1s where the signs match and 0s where they don't.
- `WeightedAveraging` – Weighted average of the past `len(weights)` updates.
Averaging

Bases: torchzero.core.transform.TensorwiseTransform

Average of the past `history_size` updates.

Parameters:

- `history_size` (int) – number of past updates to average.
- `target` (Literal, default: 'update') – target. Defaults to 'update'.
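The rolling mean this transform maintains can be sketched in plain Python (a scalar sketch under the assumption that torchzero applies the same rule per tensor; `make_averager` is a hypothetical helper, not torchzero's API):

```python
from collections import deque

def make_averager(history_size):
    """Return a stateful step function computing the mean of the
    last `history_size` updates (illustrative scalar version)."""
    history = deque(maxlen=history_size)  # oldest entries are evicted
    def step(update):
        history.append(update)
        return sum(history) / len(history)
    return step
```

With `history_size=2`, feeding updates 1.0, 3.0, 5.0 yields 1.0, 2.0, 4.0: once the buffer is full, only the two most recent updates contribute.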
Source code in torchzero/modules/momentum/averaging.py
Cautious

Bases: torchzero.core.transform.Transform

Negates the update for parameters where the update and gradient signs are inconsistent. Optionally renormalizes the update by the number of parameters that are not masked. This is meant to be used after any momentum-based modules.

Parameters:

- `normalize` (bool, default: False) – renormalize the update after masking. Only has an effect when `mode` is 'zero'. Defaults to False.
- `eps` (float, default: 1e-06) – epsilon for normalization. Defaults to 1e-6.
- `mode` (str, default: 'zero') – what to do with updates whose signs are inconsistent:
    - "zero" – set them to zero (as in the paper)
    - "grad" – set them to the gradient (same as using the update magnitude with the gradient sign)
    - "backtrack" – negate them
Examples:

Cautious Adam
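The code for the "Cautious Adam" example did not survive extraction; judging from the other examples on this page, it would presumably chain `tz.m.Adam()` with `tz.m.Cautious()` before the learning-rate module. The cautious rule itself can be sketched in dependency-free Python (lists standing in for per-parameter tensors; `cautious` is a hypothetical helper, and the normalization factor is an assumption based on the paper):

```python
def cautious(update, grad, mode="zero", normalize=False, eps=1e-6):
    """Mask update entries whose sign disagrees with the gradient."""
    mask = [1.0 if u * g > 0 else 0.0 for u, g in zip(update, grad)]
    if mode == "grad":
        # replace inconsistent entries with the gradient
        out = [u if m else g for u, g, m in zip(update, grad, mask)]
    elif mode == "backtrack":
        # negate inconsistent entries
        out = [u if m else -u for u, m in zip(update, mask)]
    else:  # "zero": zero out inconsistent entries (as in the paper)
        out = [u * m for u, m in zip(update, mask)]
        if normalize:
            # rescale by the fraction of entries that survived the mask
            out = [u * len(mask) / (sum(mask) + eps) for u in out]
    return out
```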
References
Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
Source code in torchzero/modules/momentum/cautious.py
EMA

Bases: torchzero.core.transform.Transform

Maintains an exponential moving average of the update.

Parameters:

- `momentum` (float, default: 0.9) – momentum (beta). Defaults to 0.9.
- `dampening` (float, default: 0) – momentum dampening. Defaults to 0.
- `debiased` (bool, default: False) – whether to debias the EMA as in Adam. Defaults to False.
- `lerp` (bool, default: True) – whether to use linear interpolation. Defaults to True.
- `ema_init` (str, default: 'zeros') – initial value for the EMA, "zeros" or "update".
- `target` (Literal, default: 'update') – target to apply the EMA to. Defaults to 'update'.
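A minimal sketch of one EMA step, assuming the common formulations (the exact formulas torchzero uses may differ; `ema_update` and `debias` are hypothetical helpers):

```python
def ema_update(ema, u, beta=0.9, dampening=0.0, lerp=True):
    """One EMA step on a scalar.
    lerp=True:  ema <- ema + (1 - beta) * (u - ema)      (linear interpolation)
    lerp=False: ema <- beta * ema + (1 - dampening) * u  (accumulating buffer)
    """
    if lerp:
        return ema + (1 - beta) * (u - ema)
    return beta * ema + (1 - dampening) * u

def debias(ema, beta, step):
    """Adam-style bias correction for a zero-initialized EMA (step starts at 1)."""
    return ema / (1 - beta ** step)
```

With `ema_init="zeros"`, the first few averages are biased toward zero, which is what `debiased=True` corrects: after one step with `u = 1.0`, the raw EMA is 0.1 but the debiased value recovers 1.0.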
Source code in torchzero/modules/momentum/momentum.py
HeavyBall

Bases: torchzero.modules.momentum.momentum.EMA

Polyak's momentum (heavy-ball method).

Parameters:

- `momentum` (float, default: 0.9) – momentum (beta). Defaults to 0.9.
- `dampening` (float, default: 0) – momentum dampening. Defaults to 0.
- `debiased` (bool, default: False) – whether to debias the EMA as in Adam. Defaults to False.
- `lerp` (bool, default: False) – whether to use linear interpolation; if True, this becomes an exponential moving average. Defaults to False.
- `ema_init` (str, default: 'update') – initial value for the EMA, "zeros" or "update".
- `target` (Literal, default: 'update') – target to apply the EMA to. Defaults to 'update'.
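For intuition, one heavy-ball step can be sketched as a plain accumulating velocity buffer (scalar sketch; `heavy_ball_step` is a hypothetical helper, and torchzero's module produces only the update, with the learning rate applied by a separate `LR` module):

```python
def heavy_ball_step(p, v, grad, lr=0.01, momentum=0.9):
    """One heavy-ball step: accumulate velocity, then descend along it."""
    v = momentum * v + grad  # non-lerp accumulation (lerp=False)
    return p - lr * v, v
```

Because the buffer is accumulated rather than interpolated, the effective step can grow up to `1 / (1 - momentum)` times the raw gradient step, unlike `EMA` with `lerp=True`.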
Source code in torchzero/modules/momentum/momentum.py
IntermoduleCautious

Bases: torchzero.core.module.Module

Negates the update of the `main` module where its sign doesn't match the output of the `compare` module.

Parameters:

- `main` (Chainable) – main module or sequence of modules whose update will be cautioned.
- `compare` (Chainable) – module or sequence of modules to compare the sign to.
- `normalize` (bool, default: False) – renormalize the update after masking. Defaults to False.
- `eps` (float, default: 1e-06) – epsilon for normalization. Defaults to 1e-6.
- `mode` (str, default: 'zero') – what to do with updates whose signs are inconsistent:
    - "zero" – set them to zero (as in the paper)
    - "grad" – set them to the gradient (same as using the update magnitude with the gradient sign)
    - "backtrack" – negate them
Source code in torchzero/modules/momentum/cautious.py
MedianAveraging

Bases: torchzero.core.transform.TensorwiseTransform

Median of the past `history_size` updates.

Parameters:

- `history_size` (int) – number of past updates to take the median over.
- `target` (Literal, default: 'update') – target. Defaults to 'update'.
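The difference from `Averaging` is only the reduction, which makes the result robust to a single outlier update. A scalar sketch under the same assumptions as above (`make_median_averager` is a hypothetical helper):

```python
import statistics
from collections import deque

def make_median_averager(history_size):
    """Return a stateful step function computing the median of the
    last `history_size` updates."""
    history = deque(maxlen=history_size)
    def step(update):
        history.append(update)
        return statistics.median(history)
    return step
```

With `history_size=3`, an outlier update of 10.0 between updates of 1.0 and 2.0 leaves the median at 2.0, whereas the mean would jump to about 4.3.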
Source code in torchzero/modules/momentum/averaging.py
NAG

Bases: torchzero.core.transform.Transform

Nesterov accelerated gradient method (Nesterov momentum).

Parameters:

- `momentum` (float, default: 0.9) – momentum (beta). Defaults to 0.9.
- `dampening` (float, default: 0) – momentum dampening. Defaults to 0.
- `lerp` (bool, default: False) – whether to use linear interpolation; if True, this becomes similar to an exponential moving average. Defaults to False.
- `target` (Literal, default: 'update') – target to apply the momentum to. Defaults to 'update'.
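One way to write a Nesterov step is the PyTorch-style formulation, where the buffer is updated first and the returned update looks one step ahead along the momentum direction. A scalar sketch assuming that formulation (torchzero's exact formula may differ; `nag_update` is a hypothetical helper):

```python
def nag_update(buf, grad, momentum=0.9, dampening=0.0):
    """One Nesterov-momentum step: update the buffer, then return the
    look-ahead update grad + momentum * buf."""
    buf = momentum * buf + (1 - dampening) * grad
    update = grad + momentum * buf
    return update, buf
```

Compared with `HeavyBall`, the first step is already amplified: with `momentum=0.9` and a gradient of 1.0, NAG's first update is 1.9 rather than 1.0.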
Source code in torchzero/modules/momentum/momentum.py
ScaleByGradCosineSimilarity

Bases: torchzero.core.transform.Transform

Multiplies the update by its cosine similarity with the gradient. If the cosine similarity is negative, the update is naturally negated as well.

Parameters:

- `eps` (float, default: 1e-06) – epsilon for division. Defaults to 1e-6.
Examples:

Scaled Adam:

```python
opt = tz.Modular(
    bench.parameters(),
    tz.m.Adam(),
    tz.m.ScaleByGradCosineSimilarity(),
    tz.m.LR(1e-2)
)
```
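The scaling itself can be sketched numerically (a flat-vector sketch; whether torchzero computes the similarity globally or per tensor is not stated here, and `scale_by_grad_cosine_similarity` is a hypothetical helper):

```python
import math

def scale_by_grad_cosine_similarity(update, grad, eps=1e-6):
    """Scale the whole update by cos(update, grad)."""
    dot = sum(u * g for u, g in zip(update, grad))
    norm_u = math.sqrt(sum(u * u for u in update))
    norm_g = math.sqrt(sum(g * g for g in grad))
    cos = dot / (norm_u * norm_g + eps)  # eps guards against zero norms
    return [u * cos for u in update]
```

An update aligned with the gradient passes through almost unchanged, while an anti-aligned update is negated, as the class docstring notes.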
Source code in torchzero/modules/momentum/cautious.py
ScaleModulesByCosineSimilarity

Bases: torchzero.core.module.Module

Scales the output of the `main` module by its cosine similarity to the output of the `compare` module.

Parameters:

- `main` (Chainable) – main module or sequence of modules whose update will be scaled.
- `compare` (Chainable) – module or sequence of modules to compare to.
- `eps` (float, default: 1e-06) – epsilon for division. Defaults to 1e-6.
Examples:

Adam scaled by similarity to RMSprop:

```python
opt = tz.Modular(
    bench.parameters(),
    tz.m.ScaleModulesByCosineSimilarity(
        main = tz.m.Adam(),
        compare = tz.m.RMSprop(0.999, debiased=True),
    ),
    tz.m.LR(1e-2)
)
```
Source code in torchzero/modules/momentum/cautious.py
UpdateGradientSignConsistency

Bases: torchzero.core.transform.Transform

Compares update and gradient signs. The output has 1s where the signs match, and 0s where they don't.

Parameters:

- `normalize` (bool, default: False) – renormalize the update after masking. Defaults to False.
- `eps` (float, default: 1e-06) – epsilon for normalization. Defaults to 1e-6.
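The mask this transform produces (the same one `Cautious` applies internally) can be sketched as follows; treating zero entries as a mismatch is an assumption, and `sign_consistency` is a hypothetical helper:

```python
def sign_consistency(update, grad):
    """Elementwise mask: 1 where update and gradient agree in sign, else 0."""
    return [1.0 if u * g > 0 else 0.0 for u, g in zip(update, grad)]
```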
Source code in torchzero/modules/momentum/cautious.py
WeightedAveraging

Bases: torchzero.core.transform.TensorwiseTransform

Weighted average of the past `len(weights)` updates.

Parameters:

- `weights` (Sequence[float]) – a sequence of weights, from oldest to newest.
- `target` (Literal, default: 'update') – target. Defaults to 'update'.
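A scalar sketch of the weighted average, under two assumptions not stated above: the result is normalized by the sum of the active weights, and while the history is still shorter than `weights`, the newest weights are used (`make_weighted_averager` is a hypothetical helper):

```python
from collections import deque

def make_weighted_averager(weights):
    """Weighted average of the last len(weights) updates,
    weights ordered oldest to newest."""
    weights = list(weights)
    history = deque(maxlen=len(weights))
    def step(update):
        history.append(update)
        w = weights[-len(history):]  # newest weights while warming up
        return sum(wi * u for wi, u in zip(w, history)) / sum(w)
    return step
```

With `weights=[1.0, 3.0]`, updates of 2.0 then 4.0 give 2.0 and then `(1*2 + 3*4) / 4 = 3.5`, weighting the newest update three times as heavily as the oldest.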