torchzero basics¶
import torch
torch.manual_seed(0)
from torch import nn
from torch.nn import functional as F
import torchzero as tz
Performing optimization¶
In torchzero, the optimization algorithm is represented as a sequence of modules.
To construct an optimizer, pass the modules to a Modular object, which can be used as a drop-in replacement for any PyTorch optimizer. All modules are available within the torchzero.m namespace.
model = nn.Sequential(nn.Linear(10, 10), nn.ELU(), nn.Linear(10, 1))
inputs = torch.randn(100,10)
targets = torch.randn(100, 1)
optimizer = tz.Modular(
    model.parameters(),
    tz.m.ClipValue(1),
    tz.m.Adam(),
    tz.m.WeightDecay(1e-2),
    tz.m.LR(1e-1)
)
Here is what happens:

1. The gradient is passed to the ClipValue(1) module, which returns the gradient with magnitudes clipped to be no larger than 1.
2. The clipped gradient is passed to Adam(), which updates the Adam momentum buffers and returns the Adam update.
3. The Adam update is passed to WeightDecay(), which adds a weight decay penalty to the Adam update. Since we placed it after Adam, the weight decay is decoupled. By moving WeightDecay() before Adam(), we get coupled weight decay (see the sketch after this list).
4. Finally, the update is passed to LR(0.1), which multiplies it by the learning rate of 0.1.
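The coupled variant mentioned in step 3 is just a reordering of the same modules. A minimal sketch, reusing the model defined above:

# Sketch: coupled weight decay, obtained by placing WeightDecay before Adam so
# the penalty passes through Adam's momentum and variance buffers instead of
# being added to the finished Adam update.
coupled_optimizer = tz.Modular(
    model.parameters(),
    tz.m.ClipValue(1),
    tz.m.WeightDecay(1e-2),  # before Adam -> coupled
    tz.m.Adam(),
    tz.m.LR(1e-1)
)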
The optimization loop is the same as with any other PyTorch optimizer:
for i in range(1, 101):
    preds = model(inputs)
    loss = F.mse_loss(preds, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 20 == 0: print(f"step: {i}, loss: {loss.item():.4f}")
step: 20, loss: 0.5673
step: 40, loss: 0.2976
step: 60, loss: 0.1544
step: 80, loss: 0.1057
step: 100, loss: 0.0926
LR schedulers¶
An LR scheduler works just as it does with any PyTorch optimizer, as long as you add an LR module; that module is where the scheduling happens.
model = nn.Sequential(nn.Linear(10, 10), nn.ELU(), nn.Linear(10, 1))
inputs = torch.randn(100,10)
targets = torch.randn(100, 1)
optimizer = tz.Modular(
    model.parameters(),
    tz.m.ClipValue(1),
    tz.m.Adam(),
    tz.m.WeightDecay(1e-2),
    tz.m.LR(1e-1)
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=2e-1, total_steps=100, cycle_momentum=False)
for i in range(1, 101):
    preds = model(inputs)
    loss = F.mse_loss(preds, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    if i % 20 == 1:
        print(f"step: {i}, loss: {loss.item():.4f}, lr: {optimizer.param_groups[0]['lr']:.4f}")
step: 1, loss: 1.0595, lr: 0.0086
step: 21, loss: 0.6385, lr: 0.1661
step: 41, loss: 0.3592, lr: 0.1858
step: 61, loss: 0.1495, lr: 0.1134
step: 81, loss: 0.0904, lr: 0.0309
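Other standard PyTorch schedulers should attach the same way, since a scheduler only rewrites the lr entries in optimizer.param_groups, which the LR module then applies. A minimal sketch with cosine annealing, assuming the model, inputs and targets from above:

# Sketch: the same loop with a different torch.optim scheduler.
optimizer = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-1))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for i in range(1, 101):
    loss = F.mse_loss(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # updates "lr" in optimizer.param_groups; tz.m.LR applies it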
Per-parameter settings¶
Per-parameter settings are specified in param groups, in the same way as in PyTorch optimizers. If a module has a setting, such as "beta2" in Adam, it will use the value provided in the parameter groups (0.95 for the second linear layer in the example below). If a setting isn't provided in a group, the module uses the value passed on initialization, so the first linear layer will have beta2=0.99.
model = nn.Sequential(nn.Linear(10, 10), nn.ELU(), nn.Linear(10, 1))
param_groups = [
    {"params": model[0].parameters(), "lr": 1e-2},                # 1st linear
    {"params": model[2].parameters(), "lr": 1e-1, "beta2": 0.95}  # 2nd linear
]
optimizer = tz.Modular(
    param_groups,
    tz.m.ClipValue(1),
    tz.m.Adam(beta2=0.99),
    tz.m.WeightDecay(1e-2),
    tz.m.LR(1e-1)
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=[1e-2, 1e-1], total_steps=100, cycle_momentum=False)
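Because param_groups follows the usual PyTorch layout (the scheduler example above indexes it the same way), per-group settings can be read, and presumably adjusted, in place. A small sketch:

# Sketch: inspect and tweak per-group settings through optimizer.param_groups.
for i, group in enumerate(optimizer.param_groups):
    print(f"group {i}: lr={group['lr']}")

# Manually lowering the first group's lr, the same mechanism OneCycleLR uses.
optimizer.param_groups[0]["lr"] = 5e-3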
Advanced optimization¶
Certain modules require a closure, for example line searches, trust region methods, gradient estimators, and optimizers that rely on extra autograd. The closure is similar to the one required by L-BFGS in PyTorch; however, in torchzero it takes an additional backward argument with a default value of True.
The closure evaluates and returns the loss. If backward=True, it should also call optimizer.zero_grad() and loss.backward().
For example, we can use Newton's method with cubic regularization to greatly speed up small-scale optimization.
model = nn.Sequential(nn.Linear(10, 10), nn.ELU(), nn.Linear(10, 1))
inputs = torch.randn(100,10)
targets = torch.randn(100, 1)
optimizer = tz.Modular(
    model.parameters(),
    tz.m.CubicRegularization(tz.m.Newton()),
)
for i in range(1, 51):
    def closure(backward=True):
        preds = model(inputs)
        loss = F.mse_loss(preds, targets)
        if backward:
            optimizer.zero_grad()
            loss.backward()
        return loss

    loss = optimizer.step(closure)
    if i % 10 == 0:
        print(f"step: {i}, loss: {loss.item():.4f}")
step: 10, loss: 0.6304
step: 20, loss: 0.1783
step: 30, loss: 0.0160
step: 40, loss: 0.0004
step: 50, loss: 0.0000