torchzero basics¶
import torch
torch.manual_seed(0)
from torch import nn
from torch.nn import functional as F
import torchzero as tz
Performing optimization¶
In torchzero the optimization algorithm is represented as a sequence of modules, where each module is a distinct step in the optimization process.
To construct an optimizer, pass the modules to a Modular
object, which can be used as a drop-in replacement for any PyTorch optimizer. All modules are available within the torchzero.m
namespace.
model = nn.Sequential(nn.Linear(10, 10), nn.ELU(), nn.Linear(10, 1))
inputs = torch.randn(100,10)
targets = torch.randn(100, 1)
optimizer = tz.Modular(
    model.parameters(),
    tz.m.ClipValue(1),
    tz.m.Adam(),
    tz.m.WeightDecay(1e-2),
    tz.m.LR(1e-1)
)
Here is what happens:
1. The gradient is passed to the ClipValue(1) module, which returns the gradient with magnitudes clipped to be no larger than 1.
2. The clipped gradient is passed to Adam(), which updates the Adam momentum buffers and returns the Adam update.
3. The Adam update is passed to WeightDecay(), which adds a weight decay penalty to the Adam update. Since we placed it after Adam, the weight decay is decoupled. By moving WeightDecay() before Adam(), we get coupled weight decay instead (see the sketch after this list).
4. Finally, the update is passed to LR(0.1), which multiplies it by the learning rate of 0.1.
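For instance, here is a minimal sketch of the coupled variant, using the same modules but with WeightDecay() placed before Adam() (the variable name optimizer_coupled is just for illustration):
optimizer_coupled = tz.Modular(
    model.parameters(),
    tz.m.ClipValue(1),
    tz.m.WeightDecay(1e-2),  # before Adam, so the penalty is added to the clipped gradient (coupled)
    tz.m.Adam(),
    tz.m.LR(1e-1)
)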
The optimization loop is the same as with any other PyTorch optimizer:
for i in range(1, 101):
    preds = model(inputs)
    loss = F.mse_loss(preds, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 20 == 0: print(f"step: {i}, loss: {loss.item():.4f}")
step: 20, loss: 0.5673
step: 40, loss: 0.2976
step: 60, loss: 0.1544
step: 80, loss: 0.1057
step: 100, loss: 0.0926
LR schedulers¶
An LR scheduler works the same way as with any other PyTorch optimizer, as long as you add an LR
module, which is where the scheduling happens.
model = nn.Sequential(nn.Linear(10, 10), nn.ELU(), nn.Linear(10, 1))
inputs = torch.randn(100,10)
targets = torch.randn(100, 1)
optimizer = tz.Modular(
    model.parameters(),
    tz.m.ClipValue(1),
    tz.m.Adam(),
    tz.m.WeightDecay(1e-2),
    tz.m.LR(1e-1)
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=2e-1, total_steps=100, cycle_momentum=False)
for i in range(1, 101):
    preds = model(inputs)
    loss = F.mse_loss(preds, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    if i % 20 == 1:
        print(f"step: {i}, loss: {loss.item():.4f}, lr: {optimizer.param_groups[0]['lr']:.4f}")
step: 1, loss: 1.0595, lr: 0.0086
step: 21, loss: 0.6385, lr: 0.1661
step: 41, loss: 0.3592, lr: 0.1858
step: 61, loss: 0.1495, lr: 0.1134
step: 81, loss: 0.0904, lr: 0.0309
Per-parameter settings¶
Per-parameter settings are specified in param groups, in the same way as in PyTorch optimizers.
When a module is created, you can pass various settings to it, such as Adam(beta1=0.95, beta2=0.99).
Those settings can be overridden using param groups, which let you specify custom settings for specific layers only. Param groups should be a sequence of dictionaries, with each dictionary representing one parameter group. Each dictionary must have a "params"
key with an iterable of parameters, plus other keys with custom settings for that param group.
model = nn.Sequential(nn.Linear(10, 10), nn.ELU(), nn.Linear(10, 1))
param_groups = [
    {"params": model[0].parameters(), "lr": 1e-2}, # 1st linear
    {"params": model[2].parameters(), "lr": 1e-1, "beta2": 0.95} # 2nd linear
]
optimizer = tz.Modular(
    param_groups,
    tz.m.ClipValue(1),
    tz.m.Adam(beta2=0.99),
    tz.m.WeightDecay(1e-2),
    tz.m.LR(1e-1)
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=[1e-2, 1e-1], total_steps=100, cycle_momentum=False)
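Since Modular acts as a drop-in replacement for a PyTorch optimizer, you can inspect optimizer.param_groups to check which settings each group ended up with. This is a hedged sketch, assuming per-group keys such as "beta2" are stored on the groups the same way standard PyTorch optimizers store them:
for i, group in enumerate(optimizer.param_groups):
    # "lr" is read by the LR module (and updated by the scheduler);
    # "beta2" was overridden only for the second group
    print(f"group {i}: lr={group['lr']}, beta2={group.get('beta2', 'module default')}")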
Advanced optimization¶
Certain modules require a closure, for example line searches, trust region methods, gradient estimators, and optimizers that rely on extra autograd. The closure is similar to the one required by L-BFGS in PyTorch; however, in torchzero it takes an additional backward
argument with a default value of True.
The closure evaluates and returns the loss. If backward=True
, it should also call optimizer.zero_grad()
and loss.backward().
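As a minimal sketch of that contract (reusing model, inputs and targets from the earlier examples), such a closure could look like this:
def closure(backward=True):
    preds = model(inputs)
    loss = F.mse_loss(preds, targets)
    if backward:
        # populate .grad attributes so modules that need gradients can read them
        optimizer.zero_grad()
        loss.backward()
    # with backward=False only the loss is evaluated and returned
    return loss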
For example, we can use Newton's method with cubic regularization to greatly speed up small-scale optimization.
model = nn.Sequential(nn.Linear(10, 10), nn.ELU(), nn.Linear(10, 1))
inputs = torch.randn(100,10)
targets = torch.randn(100, 1)
optimizer = tz.Modular(
    model.parameters(),
    tz.m.CubicRegularization(tz.m.Newton()),
)
for i in range(1, 51):
    def closure(backward=True):
        preds = model(inputs)
        loss = F.mse_loss(preds, targets)
        if backward:
            optimizer.zero_grad()
            loss.backward()
        return loss

    loss = optimizer.step(closure)
    if i % 10 == 0:
        print(f"step: {i}, loss: {loss.item():.4f}")
step: 10, loss: 0.6304
step: 20, loss: 0.1783
step: 30, loss: 0.0160
step: 40, loss: 0.0004
step: 50, loss: 0.0000