import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import torchzero as tz
from visualbench import FunctionDescent, test_functions

torch.manual_seed(0)
10. Variance reduction
10.1 Online methods
Many methods, for example all quasi-Newton and conjugate gradient methods, use differences between consecutive gradients, which makes them unsuitable for mini-batch optimization. If consecutive gradients are computed on different mini-batches, their difference also includes the difference between the mini-batches, which most often causes such methods to fail.
Denote by $\nabla f(x_t, \xi_t)$ the gradient at parameters $x_t$ on mini-batch $\xi_t$. When a method uses the difference between consecutive gradients $\Delta g_t = \nabla f(x_{t}) - \nabla f(x_{t-1})$, it expects both gradients to come from the same deterministic objective function. With mini-batching, however, it receives $$\Delta g_t = \nabla f(x_{t}, \xi_t) - \nabla f(x_{t-1}, \xi_{t-1}),$$ where the two gradients are computed on different mini-batches $\xi_t$ and $\xi_{t-1}$, i.e. on different sub-samples of the objective function. A natural solution is, after receiving a new mini-batch $\xi_{t}$, to spend an extra backward pass evaluating the gradient on the current mini-batch at the previous parameters, $\nabla f(x_{t-1}, \xi_{t})$, and to use it instead of the gradient on the previous mini-batch at the previous parameters, $\nabla f(x_{t-1}, \xi_{t-1})$: $$\Delta g_t = \nabla f(x_{t}, \xi_t) - \nabla f(x_{t-1}, \xi_{t}).$$
This way, both gradients used to calculate the gradient difference are computed on the same mini-batch $\xi_t$. Algorithms modified to handle mini-batching like this are often called online (e.g. Online-LBFGS). In torchzero, a method becomes online by wrapping it in the tz.m.Online module. To be compatible with tz.m.Online, a module needs to define a reset_for_online method; most relevant modules define it.
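A minimal sketch of why the extra gradient evaluation helps, assuming a hypothetical one-dimensional objective $f(x, \xi) = 1.5 (x - \xi)^2$ in which the mini-batch $\xi$ only shifts the minimum, so the true curvature is 3 for every mini-batch. The ratio $y/s$ of gradient difference to step is the curvature estimate a secant (one-dimensional quasi-Newton) method would use. The helpers naive_pair and online_pair are illustrative names, not torchzero API.

```python
import random

# Hypothetical toy objective: f(x, xi) = 1.5 * (x - xi) ** 2, where the
# "mini-batch" xi shifts the minimum. Its true curvature is 3 for every xi.
CURVATURE = 3.0

def grad(x, xi):
    """Gradient of the toy objective at x on mini-batch xi."""
    return CURVATURE * (x - xi)

def naive_pair(x_prev, x, xi_prev, xi):
    """(s, y) pair from gradients computed on two different mini-batches."""
    return x - x_prev, grad(x, xi) - grad(x_prev, xi_prev)

def online_pair(x_prev, x, xi):
    """Online (s, y) pair: one extra gradient evaluation re-uses the current
    mini-batch xi at the previous parameters."""
    return x - x_prev, grad(x, xi) - grad(x_prev, xi)

rng = random.Random(0)
x_prev, x = 0.0, 0.1                            # two consecutive iterates
xi_prev, xi = rng.gauss(0, 1), rng.gauss(0, 1)  # two consecutive mini-batches

s, y = naive_pair(x_prev, x, xi_prev, xi)
naive_est = y / s   # 3 - 30 * (xi - xi_prev): dominated by mini-batch noise
s, y = online_pair(x_prev, x, xi)
online_est = y / s  # equals the true curvature 3 up to rounding
print(f"naive: {naive_est:.3f}, online: {online_est:.3f}")
```

The naive estimate can even come out negative, which would make a quasi-Newton update step in a wrong direction; the online pair cancels the mini-batch shift exactly here because the toy objective is quadratic.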
For visualization we use the Booth function, where the function value and gradient are evaluated at the current point plus a random perturbation to emulate mini-batching.
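The perturbation scheme can be sketched as follows. This assumes the standard Booth function $(x + 2y - 7)^2 + (2x + y - 5)^2$ and Gaussian perturbations; noisy_booth is a hypothetical helper, and visualbench's set_noise may implement the noise differently.

```python
import random

def booth(x, y):
    """Standard Booth function; global minimum 0 at (1, 3)."""
    return (x + 2 * y - 7) ** 2 + (2 * x + y - 5) ** 2

def booth_grad(x, y):
    """Analytic gradient of the Booth function."""
    return (10 * x + 8 * y - 34, 8 * x + 10 * y - 38)

def noisy_booth(x, y, noise, rng):
    """Hypothetical mini-batch emulation: evaluate value and gradient at the
    current point plus a Gaussian perturbation with std `noise`.
    Every call draws a new perturbation, like sampling a new mini-batch."""
    dx, dy = rng.gauss(0, noise), rng.gauss(0, noise)
    return booth(x + dx, y + dy), booth_grad(x + dx, y + dy)

rng = random.Random(0)
v1, _ = noisy_booth(1.0, 3.0, noise=1.0, rng=rng)
v2, _ = noisy_booth(1.0, 3.0, noise=1.0, rng=rng)
print(v1, v2)  # two different values at the minimizer, like two mini-batches
```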
fig, ax = plt.subplots(ncols=2, figsize=(12,6))
ax = np.ravel(ax)
func = FunctionDescent('booth').set_noise(1)
optimizer = tz.Optimizer(func.parameters(), tz.m.LBFGS(), tz.m.LR(1e-1),)
func.run(optimizer, max_steps=100)
func.plot(log_contour=True, ax=ax[0])
ax[0].set_title("LBFGS")
func = FunctionDescent('booth').set_noise(1)
optimizer = tz.Optimizer(func.parameters(), tz.m.Online(tz.m.LBFGS()), tz.m.LR(1e-1),)
func.run(optimizer, max_steps=100)
func.plot(log_contour=True, ax=ax[1])
ax[1].set_title("Online-LBFGS")
plt.show()
finished in 0.3s., reached loss = 128.554
finished in 0.4s., reached loss = 0.000679
10.2 SVRG
The stochastic variance reduced gradient method (SVRG) is a variance-reduction method for convex optimization that uses a "snapshot" of the full-batch gradient. The algorithm is:
1. Compute the full-batch gradient, for example via gradient accumulation. If new samples are generated on the fly (so there is no "entire dataset"), it may be sufficient to compute the gradient over a large number of samples. Denote the full gradient by $\nabla f_{\text{full}}(\tilde{x})$; it is called the snapshot. Note that it is evaluated at parameters $\tilde{x}$; at the beginning, $\tilde{x}=x_0$.
2. With the full gradient at $\tilde{x}$ available, start optimizing. The variance-reduced gradient at parameters $x_t$ and mini-batch $\xi_t$ is computed as:
   $$ \nabla f_{\text{SVRG}}(x_{t}, \xi_{t}) = \nabla f(x_{t}, \xi_{t}) + \alpha (\nabla f_{\text{full}}(\tilde{x}) - \nabla f(\tilde{x}, \xi_{t})) $$
   Here $\alpha \in (0, 1]$ is a hyperparameter that controls the amount of variance reduction; it is usually set to 1.
   Use any gradient-based method to optimize for some number of steps with variance-reduced gradients computed by this formula. Typically one epoch of variance-reduced optimization is performed, so that all mini-batches are used.
3. After the specified number of steps, set $\tilde{x} \leftarrow x_t$ and go back to step 1.
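A minimal sketch of the three steps on a synthetic problem, not torchzero's implementation: 1-D least squares with per-sample losses $f_i(w) = \frac{1}{2}(a_i w - b_i)^2$, mini-batches of size 1, and plain SGD standing in for "any gradient-based method". The data, step size and epoch count are arbitrary choices made for illustration.

```python
import random

# Hypothetical finite-sum problem: per-sample loss 0.5 * (a_i * w - b_i) ** 2,
# full loss = mean over all samples.
rng = random.Random(0)
N = 50
a = [rng.uniform(0.5, 1.5) for _ in range(N)]
b = [2.0 * ai + rng.gauss(0, 0.5) for ai in a]
w_star = sum(ai * bi for ai, bi in zip(a, b)) / sum(ai * ai for ai in a)

def grad_i(w, i):
    """Gradient of the i-th sample loss (a mini-batch of size 1)."""
    return a[i] * (a[i] * w - b[i])

def full_grad(w):
    """Full-batch gradient: mean of the per-sample gradients."""
    return sum(grad_i(w, i) for i in range(N)) / N

def svrg(w0, lr=0.1, epochs=15, svrg_steps=N, alpha=1.0, seed=1):
    sampler = random.Random(seed)
    w = w0
    for _ in range(epochs):
        # Step 1: snapshot of the full gradient at w_tilde.
        w_tilde = w
        g_snapshot = full_grad(w_tilde)
        # Step 2: SGD steps with variance-reduced gradients.
        for _ in range(svrg_steps):
            i = sampler.randrange(N)
            g = grad_i(w, i) + alpha * (g_snapshot - grad_i(w_tilde, i))
            w -= lr * g
        # Step 3: the next epoch re-takes the snapshot at the current w.
    return w

w = svrg(0.0)
print(w, w_star)
```

With svrg_steps=N each inner loop draws as many mini-batches as there are samples, matching the "one epoch" choice above. Plain SGD with the same constant step size would keep fluctuating around w_star, because the per-sample gradients do not vanish at the optimum.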
The correction term is the difference between the full and mini-batch gradients at $\tilde{x}$, which estimates the difference between the full and mini-batch gradients at $x_t$ as long as $x_t$ stays close to $\tilde{x}$: $$ \nabla f_{\text{full}}(\tilde{x}) - \nabla f(\tilde{x}, \xi_{t}) \approx \nabla f_{\text{full}}(x_t) - \nabla f(x_t, \xi_{t}) $$ So with $\alpha = 1$ the SVRG gradient approximates the full-batch gradient: $$ \nabla f_{\text{SVRG}}(x_{t}, \xi_{t}) = \nabla f(x_{t}, \xi_{t}) + \nabla f_{\text{full}}(\tilde{x}) - \nabla f(\tilde{x}, \xi_{t}) \approx \nabla f(x_{t}, \xi_{t}) + \nabla f_{\text{full}}(x_t) - \nabla f(x_t, \xi_{t}) = \nabla f_{\text{full}}(x_t) $$
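Two properties behind this derivation can be checked numerically on hypothetical 1-D least-squares data (per-sample loss $\frac{1}{2}(a_i w - b_i)^2$). Under uniform sampling, averaging the SVRG gradient over all mini-batches gives exactly $\nabla f_{\text{full}}(x_t)$ for any $\alpha$, since the correction term averages to zero; and when $x_t$ is close to $\tilde{x}$, its variance is far smaller than that of the plain mini-batch gradient.

```python
import random
import statistics

# Hypothetical data: per-sample loss 0.5 * (a_i * w - b_i) ** 2.
rng = random.Random(0)
N = 50
a = [rng.uniform(0.5, 1.5) for _ in range(N)]
b = [2.0 * ai + rng.gauss(0, 0.5) for ai in a]

def grad_i(w, i):
    """Mini-batch (single-sample) gradient."""
    return a[i] * (a[i] * w - b[i])

def full_grad(w):
    """Full-batch gradient."""
    return sum(grad_i(w, i) for i in range(N)) / N

def svrg_grad(w, i, w_tilde, g_snapshot, alpha=1.0):
    """Variance-reduced gradient from the SVRG formula."""
    return grad_i(w, i) + alpha * (g_snapshot - grad_i(w_tilde, i))

w_tilde = 1.5
g_snapshot = full_grad(w_tilde)
w = w_tilde + 0.01  # parameters a few steps after the snapshot

plain = [grad_i(w, i) for i in range(N)]
reduced = [svrg_grad(w, i, w_tilde, g_snapshot) for i in range(N)]

# Averaged over all mini-batches, both estimators equal the full gradient at w...
print(statistics.fmean(plain), statistics.fmean(reduced), full_grad(w))
# ...but the SVRG estimator has much smaller variance near the snapshot.
print(statistics.pvariance(plain), statistics.pvariance(reduced))
```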
Computing the variance-reduced gradient requires an extra backward pass to compute $\nabla f(\tilde{x}, \xi_{t})$, the gradient on the current mini-batch at the snapshot parameters.
To use SVRG in torchzero, put tz.m.SVRG as the first module. It includes optional gradient accumulation for calculating the full gradient and has two main parameters:
- accum_steps determines the number of steps over which gradients are accumulated when calculating the full-batch gradient (step 1);
- svrg_steps determines the number of optimization steps with variance-reduced gradients to perform (step 2) before restarting (step 3).
By default, accum_steps is set to the same value as svrg_steps, and a good value for it is the length of the training dataloader. If you don't want to use gradient accumulation and prefer to compute full gradients manually, you can also pass a full_closure argument to the step method.
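As a rough sketch of what step 1 computes when gradients are accumulated over accum_steps mini-batches, assuming mean-reduced losses: each mini-batch gradient is weighted by its batch size, so a smaller final batch does not skew the result. The data and helpers below are hypothetical and say nothing about how tz.m.SVRG accumulates internally.

```python
import random

# Hypothetical dataset: per-sample loss 0.5 * (a_i * w - b_i) ** 2.
rng = random.Random(0)
N, BATCH = 50, 8
a = [rng.uniform(0.5, 1.5) for _ in range(N)]
b = [2.0 * ai + rng.gauss(0, 0.5) for ai in a]

def batch_grad(w, idx):
    """Gradient of the mean loss over the samples in `idx`."""
    return sum(a[i] * (a[i] * w - b[i]) for i in idx) / len(idx)

def accumulate_full_grad(w, batches):
    """Accumulate mini-batch gradients over one pass (one step per batch),
    weighting each by its size so the result equals the full-batch mean."""
    total, count = 0.0, 0
    for idx in batches:
        total += batch_grad(w, idx) * len(idx)
        count += len(idx)
    return total / count

# 7 batches, the last one holding only 2 samples: one pass over this
# "dataloader" corresponds to accum_steps = len(batches).
batches = [list(range(k, min(k + BATCH, N))) for k in range(0, N, BATCH)]
w = 0.7
print(accumulate_full_grad(w, batches), batch_grad(w, list(range(N))))
```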
For the Booth function we perform 100 gradient accumulation steps and then 100 LBFGS steps with variance-reduced gradients.
fig, ax = plt.subplots(ncols=2, figsize=(12,6))
ax = np.ravel(ax)
func = FunctionDescent('booth').set_noise(1)
optimizer = tz.Optimizer(func.parameters(), tz.m.LBFGS(), tz.m.LR(1e-1),)
func.run(optimizer, max_steps=200)
func.plot(log_contour=True, ax=ax[0])
ax[0].set_title("LBFGS")
func = FunctionDescent('booth').set_noise(1)
optimizer = tz.Optimizer(func.parameters(), tz.m.SVRG(100), tz.m.LBFGS(), tz.m.LR(1e-1),)
func.run(optimizer, max_steps=200)
func.plot(log_contour=True, ax=ax[1], line_alpha=0.5)
ax[1].set_title("SVRG-LBFGS")
plt.show()
finished in 0.8s., reached loss = 127.376
finished in 0.6s., reached loss = 0.0716