Weight decay¶
This subpackage contains weight decay modules.
Classes:

- DirectWeightDecay – Directly applies weight decay to parameters.
- RandomReinitialize – On each step, with probability p_reinit, triggers reinitialization, whereby some weights are reset to their initial values.
- RelativeWeightDecay – Weight decay relative to the mean absolute value of the update, gradient, or parameters, depending on the norm_input argument.
- WeightDecay – Weight decay.

Functions:

- decay_weights_ – Directly decays weights in-place.
DirectWeightDecay ¶
Bases: torchzero.core.module.Module
Directly applies weight decay to parameters.
Parameters:

- weight_decay (float) – weight decay scale.
- ord (int, default: 2) – order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
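Examples:¶

A minimal sketch of L2 decay applied directly to the parameters before the Adam step, following the optimizer-chain pattern of the RelativeWeightDecay examples below; the decay scale 1e-2 is illustrative, and `model` is assumed to be an existing torch.nn.Module:

```python
import torchzero as tz

opt = tz.Optimizer(
    model.parameters(),
    tz.m.DirectWeightDecay(1e-2),  # shrinks the parameters in place on each step
    tz.m.Adam(),
    tz.m.LR(1e-3)
)
```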
Source code in torchzero/modules/weight_decay/weight_decay.py
RandomReinitialize ¶
Bases: torchzero.core.module.Module
On each step, with probability p_reinit, triggers reinitialization, whereby each weight is reset to its initial value with probability p_weights.
This modifies the parameters directly. Place it as the first module.
Parameters:

- p_reinit (float, default: 0.01) – probability to trigger reinitialization on each step. Defaults to 0.01.
- p_weights (float, default: 0.1) – probability for each weight to be set to its initial value when reinitialization is triggered. Defaults to 0.1.
- store_every (int | None, default: None) – if set, stores new initial values every this many steps. Defaults to None.
- beta (float, default: 0) – whenever store_every is triggered, uses linear interpolation with this beta. If store_every=1, this can be set to a value close to 1, such as 0.999, to reinitialize to a slow parameter EMA. Defaults to 0.
- reset (bool, default: False) – whether to reset the states of other modules on reinitialization. Defaults to False.
- seed (int | None, default: None) – random seed.
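Examples:¶

A minimal sketch with illustrative values; per the note above, RandomReinitialize modifies the parameters directly, so it is placed as the first module (`model` is assumed to be an existing torch.nn.Module):

```python
import torchzero as tz

opt = tz.Optimizer(
    model.parameters(),
    tz.m.RandomReinitialize(p_reinit=0.01, p_weights=0.1),  # first module: modifies params directly
    tz.m.Adam(),
    tz.m.LR(1e-3)
)
```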
Source code in torchzero/modules/weight_decay/reinit.py
RelativeWeightDecay ¶
Bases: torchzero.core.transform.TensorTransform
Weight decay relative to the mean absolute value of the update, gradient, or parameters, depending on the norm_input argument.
Parameters:

- weight_decay (float, default: 0.1) – relative weight decay scale.
- ord (int, default: 2) – order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
- norm_input (str, default: 'update') – determines what the weight decay should be relative to: "update", "grad" or "params". Defaults to "update".
- metric (Ords, default: 'mad') – metric (norm, etc.) that the weight decay should be relative to. Defaults to 'mad' (mean absolute deviation).
- target (Target) – what to set on var. Defaults to 'update'.
Examples:¶

Adam with non-decoupled relative weight decay:

```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.Adam(),
    tz.m.LR(1e-3)
)
```

Adam with decoupled relative weight decay:

```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.RelativeWeightDecay(1e-1),
    tz.m.LR(1e-3)
)
```
Source code in torchzero/modules/weight_decay/weight_decay.py
WeightDecay ¶
Bases: torchzero.core.transform.TensorTransform
Weight decay.
Parameters:

- weight_decay (float) – weight decay scale.
- ord (int, default: 2) – order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
- target (Target) – what to set on var. Defaults to 'update'.
Examples:¶
Adam with non-decoupled weight decay:
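A sketch mirroring the RelativeWeightDecay examples above (the 1e-2 decay scale is illustrative); placed before Adam, the penalty passes through Adam's gradient transformation:

```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.WeightDecay(1e-2),  # added to raw gradients, before Adam
    tz.m.Adam(),
    tz.m.LR(1e-3)
)
```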
Adam with decoupled weight decay that still scales with the learning rate:
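A sketch with illustrative values; placed after Adam but before LR, the decay bypasses Adam's preconditioning while still being multiplied by the learning rate:

```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.WeightDecay(1e-2),  # decoupled from Adam, still scaled by LR
    tz.m.LR(1e-3)
)
```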
Adam with fully decoupled weight decay that doesn't scale with the learning rate:
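A sketch with illustrative values; placed after LR, the decay term is not multiplied by the learning rate, so a much smaller scale is typically appropriate:

```python
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
    tz.m.WeightDecay(1e-5)  # applied after LR, independent of the learning rate
)
```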
Source code in torchzero/modules/weight_decay/weight_decay.py
decay_weights_ ¶
Directly decays weights in-place.
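A usage sketch; both the import path (guessed from the source file listed for the classes above) and the signature (assumed to mirror DirectWeightDecay's parameters) are assumptions, so check torchzero/modules/weight_decay/weight_decay.py for the actual interface:

```python
# hypothetical import path and signature; see note above
from torchzero.modules.weight_decay.weight_decay import decay_weights_

# shrinks each parameter in place by the weight decay penalty
decay_weights_(model.parameters(), weight_decay=1e-2)
```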