import numpy as np
import matplotlib.pyplot as plt
import torch
torch.manual_seed(0)
from torch import nn
import torch.nn.functional as F
import torchzero as tz
from visualbench import FunctionDescent, test_functions
12. Gradient-free methods¶
12.1 Introduction¶
Gradient-free methods use only function values for minimization, which makes them suitable for problems where gradients are unknown, not useful, or too expensive to compute.
There is an enormous number of gradient-free methods, and one notebook is definitely not enough to cover them all. Torchzero only implements a few (for now), but it also provides wrappers for other gradient-free optimization libraries that implement many more; here I will mostly focus on the methods available in torchzero.
12.2 Gradient approximations¶
When gradients are not available, one strategy is to estimate them using function values and use any of the gradient-based methods using the approximated gradients.
12.2.1 Finite difference estimator¶
The finite difference estimator (FDM) loops over each parameter, adds a small perturbation to it and evaluates the function value, so it requires at least $n$ evaluations to estimate the gradient, where $n$ is the number of parameters. There are various finite difference formulas - 2-point forward/backward, 3-point central, 3-point forward/backward, 4-point central, etc.
The 3-point central formula is widely used. To estimate the gradient of the $i$-th parameter $x_i$, it evaluates $f(x_i - h)$ and $f(x_i + h)$, where $h$ is a hyperparameter controlling the accuracy of the estimation. If precision were infinite, a smaller $h$ would always mean a more accurate estimate, but due to finite precision $h$ can't be too small. If $h$ is large, it has the effect of smoothing the function, which is called implicit filtering and can be useful for functions with a very rough surface. The formula for the gradient of the $i$-th parameter $g_i$ is the following:
$$ \hat{g_i} = \frac{f(x_i + h) - f(x_i - h)}{2 \cdot h} $$
This formula has to be run for each parameter, therefore $2n$ evaluations are required.
The 2-point forward formula is the following: $$ \hat{g_i} = \frac{f(x_i + h) - f(x_i)}{h} $$
This formula is less accurate, but the unperturbed value $f(x_i)$ is the same for every parameter, so it only has to be evaluated once and reused, bringing the total to $n+1$ evaluations.
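To make the formulas concrete, here is a minimal NumPy sketch of both estimators (plain NumPy for illustration, not torchzero's FDM module; the test function and step size are arbitrary):
import numpy as np
def fdm_central(f, x, h=1e-5):
    # 3-point central differences: 2n function evaluations
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2 * h)
    return g
def fdm_forward(f, x, h=1e-5):
    # 2-point forward differences: n+1 evaluations, f(x) is evaluated once and reused
    fx = f(x)
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = h
        g[i] = (f(x + e) - fx) / h
    return g
# example: 2D rosenbrock, the true gradient at (0, 0) is (-2, 0)
rosen = lambda x: (1 - x[0])**2 + 100 * (x[1] - x[0]**2)**2
x0 = np.zeros(2)
print(fdm_central(rosen, x0), fdm_forward(rosen, x0))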
In torchzero, to use FDM-approximated gradients, add tz.m.FDM as the first module.
fig, ax = plt.subplots(ncols=3, figsize=(15,5))
ax = np.ravel(ax)
func = FunctionDescent('rosen')
optimizer = tz.Optimizer(func.parameters(), tz.m.FDM(formula='forward2'), tz.m.BFGS(), tz.m.Backtracking())
func.run(optimizer, max_steps=50)
func.plot(log_contour=True, ax=ax[0])
ax[0].set_title("2-point forward - BFGS")
func = FunctionDescent('rosen')
optimizer = tz.Optimizer(func.parameters(), tz.m.FDM(formula='central3'), tz.m.BFGS(), tz.m.Backtracking())
func.run(optimizer, max_steps=50)
func.plot(log_contour=True, ax=ax[1])
ax[1].set_title("3-point central - BFGS")
func = FunctionDescent('rosen')
optimizer = tz.Optimizer(func.parameters(), tz.m.BFGS(), tz.m.Backtracking())
func.run(optimizer, max_steps=50)
func.plot(log_contour=True, ax=ax[2])
ax[2].set_title("true gradient - BFGS")
plt.show()
finished in 0.2s., reached loss = 0.0424 finished in 0.2s., reached loss = 3e-08 finished in 0.2s., reached loss = 0
12.2.2 Randomized finite difference estimator¶
The randomized finite difference estimator (RFDM) perturbs all parameters at once in a random direction.
$$ \hat{g} = p \, \frac{f(x + hp) - f(x - hp)}{2h} $$
Here $h$ controls the accuracy of the approximation, and $p$ is a random perturbation with zero mean and unit variance. This formula is the direct equivalent of the 3-point central formula, and other formulas can be used too.
It essentially estimates the directional derivative in the direction $p$ and multiplies $p$ by it, which is a very rough estimate of the full gradient.
Often $p$ is sampled from the Rademacher distribution, so every entry has a 50% chance of being 1 and a 50% chance of being -1, which gives the very popular Simultaneous Perturbation Stochastic Approximation (SPSA) method. The SPSA formula is often written with $p$ in the denominator (since multiplying and dividing by 1 and -1 are equivalent).
If $p$ isn't sampled from the Rademacher distribution, we get the Random Direction Stochastic Approximation (RDSA) method.
It is possible to compute $\hat{g}$ multiple times with different random perturbations $p$ and take the average. That average is an estimate of the gradient in the subspace spanned by the perturbations, which is a slightly better estimate. The gaussian smoothing method averages multiple $p$ sampled from a gaussian distribution, and $h$ can be made larger, which has the effect of smoothing the function.
RFDM doesn't suffer from having to perform $n$ evaluations per step just to approximate the gradient like FDM does. However, the approximation is very rough and won't work with methods that rely on gradient differences, such as conjugate gradient and quasi-newton methods. It does work with momentum and adaptive methods.
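To make this concrete, here is a minimal NumPy sketch of the SPSA estimate, with optional averaging over several perturbations (again just an illustration, not torchzero's RandomizedFDM implementation):
import numpy as np
def spsa_grad(f, x, h=1e-3, n_samples=1, rng=np.random.default_rng(0)):
    # average the randomized central-difference estimate over n_samples perturbations
    g = np.zeros_like(x)
    for _ in range(n_samples):
        p = rng.choice([-1.0, 1.0], size=x.size)  # Rademacher perturbation
        g += p * (f(x + h * p) - f(x - h * p)) / (2 * h)
    return g / n_samples
rosen = lambda x: (1 - x[0])**2 + 100 * (x[1] - x[0]**2)**2
x0 = np.zeros(2)
print(spsa_grad(rosen, x0), spsa_grad(rosen, x0, n_samples=100))  # true gradient is (-2, 0)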
In torchzero, to use RFDM-approximated gradients, add tz.m.RandomizedFDM as the first module, or use one of its subclasses - tz.m.SPSA, tz.m.RDSA and tz.m.GaussianSmoothing.
fig, ax = plt.subplots(ncols=3, figsize=(15,5))
ax = np.ravel(ax)
func = FunctionDescent('rosen')
optimizer = tz.Optimizer(func.parameters(), tz.m.SPSA(seed=0), tz.m.Adam(0.9, 0.95), tz.m.LR(2e-1))
func.run(optimizer, max_steps=500)
func.plot(log_contour=True, ax=ax[0])
ax[0].set_title("SPSA-Adam")
func = FunctionDescent('rosen')
optimizer = tz.Optimizer(func.parameters(), tz.m.RDSA(seed=1), tz.m.Adam(0.9, 0.95), tz.m.LR(2e-1))
func.run(optimizer, max_steps=500)
func.plot(log_contour=True, ax=ax[1])
ax[1].set_title("RDSA-Adam")
func = FunctionDescent('rosen')
optimizer = tz.Optimizer(func.parameters(), tz.m.GaussianSmoothing(n_samples=10, seed=0), tz.m.Adam(0.9, 0.95), tz.m.LR(2e-1))
func.run(optimizer, max_steps=500)
func.plot(log_contour=True, ax=ax[2])
ax[2].set_title("GaussianSmoothing-Adam")
plt.show()
finished in 0.9s., reached loss = 0.112 finished in 0.9s., reached loss = 0.00129 finished in 2.7s., reached loss = 0.000152
12.2.3 MeZO¶
MeZO is a version of SPSA which uses the same formula, but the random perturbation $p$ is never stored in memory; instead only the seed used to generate it is stored, and $p$ is regenerated from that seed whenever it is needed. It was proposed for fine-tuning large language models when very limited memory is available. SPSA requires $2n$ extra memory, MeZO requires $n$ extra memory, where $n$ is the number of parameters. Theoretically it could use almost no extra memory by generating and subtracting $p$ gradually, weight by weight, rather than all at once, but I think that would be hard to implement in a way that isn't slow.
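The core trick is just re-seeding the random number generator, roughly as in this sketch (an illustration of the idea, not torchzero's MeZO implementation):
import torch
def perturb_(params, seed, alpha):
    # regenerate the same perturbation p from the seed and add alpha * p in place,
    # so p itself never has to be stored alongside the parameters
    gen = torch.Generator().manual_seed(seed)
    for param in params:
        p = torch.randn(param.shape, generator=gen)
        param.add_(p, alpha=alpha)
params = [torch.zeros(3), torch.zeros(2)]
seed, h = 42, 1e-3
perturb_(params, seed, h)      # move to x + h*p, evaluate f here
perturb_(params, seed, -2 * h) # move to x - h*p, evaluate f here
perturb_(params, seed, h)      # restore the original parameters x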
func = FunctionDescent('rosen')
optimizer = tz.Optimizer(
func.parameters(),
tz.m.MeZO(),
tz.m.LR(1e-3),
)
func.run(optimizer, max_steps=2000)
func.plot()
finished in 3.5s., reached loss = 0.0593
12.2.4 Forward gradient¶
Instead of having to tune the finite difference parameter $h$, the exact directional derivative in the direction $p$ can be calculated with forward-mode autograd via a jacobian-vector product (Jvp). Jvps can be cheap; in PyTorch they are experimental but still use far less memory than backward passes.
The formula is $$ \hat{g} = p \cdot \nabla f(x)^T p $$
Here $p$ is a random vector and $\nabla f(x)^T p$ is the directional derivative in the direction $p$, which can be computed without forming the full gradient $\nabla f(x)$.
This approximation has been called "Forward gradient".
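A minimal sketch of the idea using torch.func.jvp (illustrative only, not how tz.m.ForwardGradient is implemented internally):
import torch
def forward_gradient(f, x, generator=None):
    # directional derivative of f at x along a random direction p, times p
    p = torch.randn(x.shape, generator=generator)
    _, dir_deriv = torch.func.jvp(f, (x,), (p,))  # exact Jvp, no finite differences
    return dir_deriv * p
rosen = lambda x: (1 - x[0])**2 + 100 * (x[1] - x[0]**2)**2
x0 = torch.zeros(2)
print(forward_gradient(rosen, x0))  # rough estimate of the true gradient (-2, 0)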
func = FunctionDescent('rosen')
optimizer = tz.Optimizer(
func.parameters(),
tz.m.ForwardGradient(seed=1),
tz.m.Adam(0.9, 0.95),
tz.m.LR(2e-1),
)
func.run(optimizer, max_steps=500)
func.plot()
finished in 1.0s., reached loss = 0.0937
12.2.5 2nd order SPSA¶
2SG is a quasi-newton method that is able to work with stochastic gradients, including SPSA gradient estimates; the resulting method is called 2SPSA, or second-order SPSA. With true stochastic gradients it requires three gradient evaluations per step; with SPSA gradient estimates it uses four function evaluations per step.
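For intuition, here is a rough NumPy sketch of the simultaneous-perturbation hessian estimate that 2SPSA is built on, written from the textbook formula rather than from torchzero's SG2 implementation, so treat all names and constants as illustrative:
import numpy as np
def sp_hessian_estimate(f, x, h=1e-3, rng=np.random.default_rng(0)):
    # one-sided SPSA gradient estimate at point y using perturbation d (2 evaluations)
    def spsa_grad(y, d):
        return (f(y + h * d) - f(y)) / (h * d)
    d1 = rng.choice([-1.0, 1.0], x.size)  # perturbation for the gradient difference
    d2 = rng.choice([-1.0, 1.0], x.size)  # independent perturbation for the inner estimates
    dg = (spsa_grad(x + h * d1, d2) - spsa_grad(x - h * d1, d2)) / (2 * h)
    H = np.outer(dg, 1.0 / d1)            # outer-product estimate from 4 function evaluations
    return 0.5 * (H + H.T)                # symmetrize
# averaging many estimates on a quadratic approximately recovers its hessian diag(2, 8)
f = lambda x: x[0]**2 + 4 * x[1]**2
x0 = np.array([1.0, 2.0])
print(np.mean([sp_hessian_estimate(f, x0) for _ in range(2000)], axis=0))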
func = FunctionDescent('booth')
optimizer = tz.Optimizer(
func.parameters(),
tz.m.SPSA(),
tz.m.SG2(seed=0),
tz.m.Warmup(10),
tz.m.LR(1e-1),
)
func.run(optimizer, max_steps=200)
func.plot(log_contour=True)
finished in 0.9s., reached loss = 2.27e-13
12.2.6 Finite difference hessian-vector products¶
Certain second order optimizers, such as NewtonCG and sketched Newton, do not require the full hessian - they only need hessian-vector products (Hvps), which can be computed efficiently via autograd. But in many cases autograd fails because some rules aren't implemented, or the gradients are computed without autograd; then hessian-vector products can be estimated with one or two extra gradient evaluations using finite difference formulas. And, of course, we can use finite difference gradient approximations too, leading to a zeroth-order approximation to Newton's method.
The forward formula for estimating a hessian-vector product $Hv$ with a vector $v$ is: $$ \widehat{Hv} = \frac{\nabla f(x + hv) - \nabla f(x)}{h} $$
Here $h$ controls the accuracy of the approximation. By pre-computing $\nabla f(x)$, it can be re-used for every subsequent hessian-vector product, so each Hvp costs only one extra gradient evaluation.
A more accurate central formula requires two extra gradient evaluations per Hvp: $$ \widehat{Hv} = \frac{\nabla f(x + hv) - \nabla f(x - hv)}{2h} $$
When using these formulas, $\nabla f(x)$ itself can be a finite difference approximation.
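Here is a minimal sketch of the central formula using autograd gradients (illustrative; a finite difference gradient estimate could be substituted for torch.func.grad):
import torch
def hvp_fd_central(f, x, v, h=1e-3):
    # central finite-difference hessian-vector product: two extra gradient evaluations
    grad = torch.func.grad(f)
    return (grad(x + h * v) - grad(x - h * v)) / (2 * h)
f = lambda x: x[0]**2 + 4 * x[1]**2  # hessian is diag(2, 8)
x0 = torch.tensor([1.0, 2.0])
v = torch.tensor([1.0, 1.0])
print(hvp_fd_central(f, x0, v))  # approximately (2., 8.)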
In torchzero, modules that use hessian-vector products have a hvp_method argument, which can be set to "forward" to use the forward formula or "central" to use the central formula. By default it is usually set to "autograd", which uses automatic differentiation.
func = FunctionDescent('rosen')
optimizer = tz.Optimizer(
func.parameters(),
tz.m.FDM(),
tz.m.NewtonCG(hvp_method='fd_central'),
tz.m.Backtracking(),
)
func.run(optimizer, max_steps=20)
func.plot()
finished in 0.3s., reached loss = 3.99e-08
12.3 Other libraries¶
There are many python libraries that implement various zeroth-order methods, including evolutionary algorithms, bayesian optimization, direct search, etc. Torchzero implements wrappers for a few of them, allowing them to be used as pytorch optimizers.
Since I haven't studied those methods, I won't give a detailed description unless I implement them as torchzero modules.
12.3.1 scipy.optimize.minimize¶
scipy.optimize.minimize implements the following local optimization methods: 'nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp', 'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov'. Most of the first and second order methods are also implemented in torchzero, but the zeroth-order ones - Nelder-Mead, Powell's method, COBYLA and COBYQA - are not (as of yet).
Those are all local search methods, suitable for functions with a single (global) minimum.
The wrapper for it is torchzero.optim.wrappers.scipy.ScipyMinimize. Note that scipy doesn't support performing a single step - it optimizes until a stopping criterion is reached. Therefore a single step with ScipyMinimize performs a full minimization. The nevergrad wrapper (described later) includes some of those methods and supports performing a single step.
from torchzero.optim.wrappers.scipy import ScipyMinimize
fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(14,14))
ax = np.ravel(ax)
func = FunctionDescent('rosen')
optimizer = ScipyMinimize(func.parameters(), method='nelder-mead')
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[0])
ax[0].set_title("Nelder-Mead")
func = FunctionDescent('rosen')
optimizer = ScipyMinimize(func.parameters(), method='powell')
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[1])
ax[1].set_title("Powell's method")
func = FunctionDescent('rosen')
optimizer = ScipyMinimize(func.parameters(), method='cobyla')
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[2])
ax[2].set_title("COBYLA")
func = FunctionDescent('rosen')
optimizer = ScipyMinimize(func.parameters(), method='cobyqa')
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[3])
ax[3].set_title("COBYQA")
plt.show()
finished in 0.1s., reached loss = 6.88e-10 finished in 0.3s., reached loss = 1.42e-14 finished in 2.8s., reached loss = 0.124 finished in 0.5s., reached loss = 1.78e-12
12.3.2 Other scipy.optimize methods¶
scipy.optimize also implements differential evolution, dual annealing, SHGO, DIRECT, basin-hopping and brute search. All of them are global optimization methods suitable for functions with many local minima, and all of them except basin-hopping require box bounds to be specified.
All of those methods also optimize until a stopping criterion is reached, so usually a single step should be performed; if you want better control over the steps, use the nevergrad wrapper.
from torchzero.optim.wrappers.scipy import ScipyBrute, ScipyDE, ScipyDIRECT, ScipySHGO, ScipyDualAnnealing, ScipyBasinHopping
fig, ax = plt.subplots(ncols=3, nrows=2, figsize=(18,13))
ax = np.ravel(ax)
func = FunctionDescent('rosen')
optimizer = ScipyDE(func.parameters(), lb=-3, ub=3, seed=0)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[0], line_alpha=0.2)
ax[0].set_title("ScipyDE")
func = FunctionDescent('rosen')
optimizer = ScipyDualAnnealing(func.parameters(), lb=-3, ub=3, rng=0)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[1], line_alpha=0.2)
ax[1].set_title("ScipyDualAnnealing")
func = FunctionDescent('rosen')
optimizer = ScipySHGO(func.parameters(), lb=-3, ub=3, iters=100)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[2], line_alpha=0.2)
ax[2].set_title("ScipySHGO")
func = FunctionDescent('rosen')
optimizer = ScipyDIRECT(func.parameters(), lb=-3, ub=3)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[3], line_alpha=0.2)
ax[3].set_title("ScipyDIRECT")
func = FunctionDescent('rosen')
optimizer = ScipyBasinHopping(func.parameters(), niter=1000, rng=0)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[4], line_alpha=0.2)
ax[4].set_title("ScipyBasinHopping")
func = FunctionDescent('rosen')
optimizer = ScipyBrute(func.parameters(), lb=-3, ub=3)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[5], line_alpha=0.2)
ax[5].set_title("ScipyBrute")
plt.show()
finished in 1.0s., reached loss = 0 finished in 1.4s., reached loss = 3.59e-13 finished in 2.0s., reached loss = 3.55e-15 finished in 0.2s., reached loss = 3.23e-10 finished in 27.1s., reached loss = 0 finished in 0.1s., reached loss = 5.13e-10
12.3.3 NLopt¶
NLopt is another optimization library with many gradient-based and gradient-free methods. The algorithms are listed here: https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/.
The wrapper for NLopt is torchzero.optim.wrappers.nlopt.NLOptWrapper. Like scipy, a single step performs a full minimization. Make sure to pass some stopping criterion to NLOptWrapper, such as maxeval. Also, some methods require bounds to be specified - without bounds they return the initial point.
from torchzero.optim.wrappers.nlopt import NLOptWrapper
fig, ax = plt.subplots(ncols=3, nrows=4, figsize=(18,22))
ax = np.ravel(ax)
func = FunctionDescent('rosen')
optimizer = NLOptWrapper(func.parameters(), 'GN_DIRECT_L', lb=-3, ub=3, maxeval=1000)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[0], line_alpha=0.2)
ax[0].set_title("GN_DIRECT_L")
func = FunctionDescent('rosen')
optimizer = NLOptWrapper(func.parameters(), 'GN_CRS2_LM', lb=-3, ub=3, maxeval=1000)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[1], line_alpha=0.2)
ax[1].set_title("GN_CRS2_LM")
func = FunctionDescent('rosen')
optimizer = NLOptWrapper(func.parameters(), 'GN_AGS', lb=-3, ub=3, maxeval=5000)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[2], line_alpha=0.2)
ax[2].set_title("GN_AGS")
func = FunctionDescent('rosen')
optimizer = NLOptWrapper(func.parameters(), 'GN_ISRES', lb=-3, ub=3, maxeval=5000)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[3], line_alpha=0.2)
ax[3].set_title("GN_ISRES")
func = FunctionDescent('rosen')
optimizer = NLOptWrapper(func.parameters(), 'GN_ESCH', lb=-3, ub=3, maxeval=1000)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[4], line_alpha=0.2)
ax[4].set_title("GN_ESCH")
func = FunctionDescent('rosen')
optimizer = NLOptWrapper(func.parameters(), 'LN_COBYLA', maxeval=1000)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[5], line_alpha=0.2)
ax[5].set_title("LN_COBYLA")
func = FunctionDescent('rosen')
optimizer = NLOptWrapper(func.parameters(), 'LN_BOBYQA', maxeval=1000)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[6], line_alpha=0.2)
ax[6].set_title("LN_BOBYQA")
func = FunctionDescent('rosen')
optimizer = NLOptWrapper(func.parameters(), 'LN_NEWUOA', maxeval=1000)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[7], line_alpha=0.2)
ax[7].set_title("LN_NEWUOA")
func = FunctionDescent('rosen')
optimizer = NLOptWrapper(func.parameters(), 'LN_PRAXIS', maxeval=1000)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[8], line_alpha=0.2)
ax[8].set_title("LN_PRAXIS")
func = FunctionDescent('rosen')
optimizer = NLOptWrapper(func.parameters(), 'LN_NELDERMEAD', maxeval=1000)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[9], line_alpha=0.2)
ax[9].set_title("LN_NELDERMEAD")
func = FunctionDescent('rosen')
optimizer = NLOptWrapper(func.parameters(), 'LN_SBPLX', maxeval=1000)
func.run(optimizer, max_steps=1)
func.plot(log_contour=True, ax=ax[10], line_alpha=0.2)
ax[10].set_title("LN_SBPLX")
plt.show()
finished in 0.3s., reached loss = 3.43e-09 finished in 0.3s., reached loss = 3.25e-09 finished in 1.2s., reached loss = 0.00429 finished in 1.2s., reached loss = 0.0047 finished in 0.3s., reached loss = 0.005 finished in 0.2s., reached loss = 0.0234 finished in 0.1s., reached loss = 0 finished in 0.1s., reached loss = 8.88e-14 finished in 0.2s., reached loss = 0 finished in 0.1s., reached loss = 0 finished in 0.1s., reached loss = 2.88e-13
12.3.4 Nevergrad¶
Nevergrad implements a large number of zeroth-order algorithms, including wrappers for NLopt and scipy. All algorithms are listed here: https://facebookresearch.github.io/nevergrad/optimizers_ref.html#optimizer-api.
The wrapper is torchzero.optim.wrappers.nevergrad.NevergradWrapper. Here the step method actually performs a single step, i.e. a single objective function evaluation, so it may be more convenient to use as a pytorch optimizer. Some methods in nevergrad require a budget to be specified - the maximum number of objective function evaluations - and will raise an exception if it is not given.
import nevergrad as ng
from torchzero.optim.wrappers.nevergrad import NevergradWrapper
fig, ax = plt.subplots(ncols=3, nrows=3, figsize=(18,18))
ax = np.ravel(ax)
func = FunctionDescent('rosen')
optimizer = NevergradWrapper(func.parameters(), ng.optimizers.CMA, budget=1001)
func.run(optimizer, max_steps=1000)
func.plot(log_contour=True, ax=ax[0], line_alpha=0.2)
ax[0].set_title("CMA")
func = FunctionDescent('rosen')
optimizer = NevergradWrapper(func.parameters(), ng.optimizers.DiagonalCMA, budget=1001)
func.run(optimizer, max_steps=1000)
func.plot(log_contour=True, ax=ax[1], line_alpha=0.2)
ax[1].set_title("DiagonalCMA")
func = FunctionDescent('rosen')
optimizer = NevergradWrapper(func.parameters(), ng.optimizers.cGA, budget=10001)
func.run(optimizer, max_steps=10000)
func.plot(log_contour=True, ax=ax[2], line_alpha=0.2)
ax[2].set_title("CGA")
func = FunctionDescent('rosen')
optimizer = NevergradWrapper(func.parameters(), ng.optimizers.ES, budget=10001)
func.run(optimizer, max_steps=10000)
func.plot(log_contour=True, ax=ax[3], line_alpha=0.2)
ax[3].set_title("ES")
func = FunctionDescent('rosen')
optimizer = NevergradWrapper(func.parameters(), ng.optimizers.PSO, budget=2001)
func.run(optimizer, max_steps=2000)
func.plot(log_contour=True, ax=ax[4], line_alpha=0.2)
ax[4].set_title("PSO")
func = FunctionDescent('rosen')
optimizer = NevergradWrapper(func.parameters(), ng.optimizers.EDA, budget=2001)
func.run(optimizer, max_steps=2000)
func.plot(log_contour=True, ax=ax[5], line_alpha=0.2)
ax[5].set_title("EDA")
func = FunctionDescent('rosen')
optimizer = NevergradWrapper(func.parameters(), ng.optimizers.HammersleySearch, budget=1001, lb=-3, ub=3)
func.run(optimizer, max_steps=1000)
func.plot(log_contour=True, ax=ax[6], line_alpha=0.2)
ax[6].set_title("HammersleySearch")
func = FunctionDescent('rosen')
optimizer = NevergradWrapper(func.parameters(), ng.optimizers.DiscreteLenglerOnePlusOne, budget=2001)
func.run(optimizer, max_steps=2000)
func.plot(log_contour=True, ax=ax[7], line_alpha=0.2)
ax[7].set_title("DiscreteLenglerOnePlusOne")
func = FunctionDescent('rosen')
optimizer = NevergradWrapper(func.parameters(), ng.optimizers.Portfolio, budget=2001)
func.run(optimizer, max_steps=2000)
func.plot(log_contour=True, ax=ax[8], line_alpha=0.2)
ax[8].set_title("Portfolio")
plt.show()
finished in 2.0s., reached loss = 0 finished in 2.6s., reached loss = 3.189 finished in 19.9s., reached loss = 0.0562 finished in 10.8s., reached loss = 0.000435
LossTooLargeWarning: Clipping very high value inf in tell (rescale the cost function?).
finished in 2.4s., reached loss = 0.00188 finished in 3.2s., reached loss = 3.689 finished in 1.8s., reached loss = 0.0398 finished in 2.2s., reached loss = 0.0335 finished in 8.5s., reached loss = 7.61e-05