Weight decay¶
This subpackage contains weight decay modules.
Classes:

- CautiousWeightDecay – Cautious weight decay (https://arxiv.org/pdf/2510.12402).
- DirectWeightDecay – Directly applies weight decay to parameters.
- RandomReinitialize – On each step, with probability p_reinit, triggers reinitialization.
- RelativeWeightDecay – Weight decay relative to the mean absolute value of the update, gradient or parameters, depending on the value of the norm_input argument.
- WeightDecay – Weight decay.

Functions:

- decay_weights_ – Directly decays weights in-place.
CautiousWeightDecay ¶
Bases: torchzero.core.transform.TensorTransform
Cautious weight decay (https://arxiv.org/pdf/2510.12402).
Weight decay that is only applied where the update sign matches the weight decay sign.
Parameters:

- weight_decay (float) – weight decay scale.
- ord (int, default: 2) – order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
- target (Target) – what to set on var. Defaults to 'update'.
Examples:¶
Adam with non-decoupled cautious weight decay
```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.CautiousWeightDecay(1e-3),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)
```
Adam with decoupled cautious weight decay that still scales with learning rate
```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.CautiousWeightDecay(1e-3),
    tz.m.LR(1e-3)
)
```
Adam with fully decoupled cautious weight decay that doesn't scale with learning rate
```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
    tz.m.CautiousWeightDecay(1e-6)
)
```
Source code in torchzero/modules/weight_decay/weight_decay.py
DirectWeightDecay ¶
Bases: torchzero.core.module.Module
Directly applies weight decay to parameters.
Parameters:

- weight_decay (float) – weight decay scale.
- ord (int, default: 2) – order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
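A minimal usage sketch, assuming DirectWeightDecay is chained like the other weight decay modules in this subpackage (the decay scale and the Adam/LR settings are illustrative):

```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.DirectWeightDecay(1e-4),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)
```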
Source code in torchzero/modules/weight_decay/weight_decay.py
RandomReinitialize ¶
Bases: torchzero.core.module.Module
On each step, with probability p_reinit, triggers reinitialization, whereby each weight is reset to its initial value with probability p_weights.
This modifies the parameters directly. Place it as the first module.
Parameters:

- p_reinit (float, default: 0.01) – probability to trigger reinitialization on each step. Defaults to 0.01.
- p_weights (float, default: 0.1) – probability for each weight to be reset to its initial value when reinitialization is triggered. Defaults to 0.1.
- store_every (int | None, default: None) – if set, stores new initial values every this many steps. Defaults to None.
- beta (float, default: 0) – whenever store_every is triggered, uses linear interpolation with this beta. If store_every=1, this can be set to a value close to 1, such as 0.999, to reinitialize to a slow parameter EMA. Defaults to 0.
- reset (bool, default: False) – whether to reset states of other modules on reinitialization. Defaults to False.
- seed (int | None, default: None) – random seed.
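A minimal usage sketch with the documented defaults; per the note above, the module is placed first in the chain (the Adam and LR settings are illustrative):

```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.RandomReinitialize(p_reinit=0.01, p_weights=0.1),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)
```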
Source code in torchzero/modules/weight_decay/reinit.py
RelativeWeightDecay ¶
Bases: torchzero.core.transform.TensorTransform
Weight decay relative to the mean absolute value of the update, gradient or parameters, depending on the value of the norm_input argument.
Parameters:

- weight_decay (float, default: 0.1) – relative weight decay scale.
- ord (int, default: 2) – order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
- norm_input (str, default: 'update') – determines what the weight decay should be relative to: "update", "grad" or "params". Defaults to "update".
- metric (Ords, default: 'mad') – metric (norm, etc.) that the weight decay should be relative to. Defaults to 'mad' (mean absolute deviation).
- target (Target) – what to set on var. Defaults to 'update'.
Examples:¶
Adam with non-decoupled relative weight decay
```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)
```
Adam with decoupled relative weight decay
```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.LR(1e-3)
)
```
Source code in torchzero/modules/weight_decay/weight_decay.py
WeightDecay ¶
Bases: torchzero.core.transform.TensorTransform
Weight decay.
Parameters:

- weight_decay (float) – weight decay scale.
- ord (int, default: 2) – order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
- target (Target) – what to set on var. Defaults to 'update'.
Examples:¶
Adam with non-decoupled weight decay
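The code for this example and the two that follow is a minimal sketch mirroring the CautiousWeightDecay examples above, with WeightDecay substituted (the decay scales are illustrative):

```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.WeightDecay(1e-3),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)
```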
Adam with decoupled weight decay that still scales with learning rate
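```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.WeightDecay(1e-3),
    tz.m.LR(1e-3)
)
```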
Adam with fully decoupled weight decay that doesn't scale with learning rate
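```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
    tz.m.WeightDecay(1e-6)
)
```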
Source code in torchzero/modules/weight_decay/weight_decay.py
decay_weights_ ¶
Directly decays weights in-place.
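A conceptual sketch of in-place weight decay in plain PyTorch; the actual decay_weights_ signature is not shown here, so the function below is a hypothetical stand-in for the L2 case:

```python
import torch

@torch.no_grad()
def l2_decay_(params, weight_decay: float):
    # hypothetical stand-in, not the torchzero API:
    # subtract weight_decay * p from each parameter in-place,
    # i.e. the gradient of the penalty (weight_decay / 2) * ||p||^2
    for p in params:
        p.sub_(p, alpha=weight_decay)
```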