
Wrappers

This subpackage contains wrappers for other libraries.

Classes:

  • Wrap

    Wraps a pytorch optimizer to use it as a module.

Wrap

Bases: torchzero.core.module.Module

Wraps a pytorch optimizer to use it as a module.

Note

Custom param groups are supported only via set_param_groups; settings passed to Optimizer will be applied to all parameters.
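
A minimal sketch of per-group settings via set_param_groups, assuming torch-style param group dicts (model.head and model.body are placeholders for submodules of your own model, and torch.optim.SGD stands in for any optimizer):

import torch
import torchzero as tz

wrap = tz.m.Wrap(torch.optim.SGD, lr=0.1)
# per-group settings must go through set_param_groups; settings passed to
# the Optimizer would be applied to all parameters instead
wrap.set_param_groups([
    {"params": model.head.parameters(), "lr": 1e-2},
    {"params": model.body.parameters(), "lr": 1e-3},
])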

Parameters:

  • opt_fn (Callable[..., Optimizer] | Optimizer) –

    Function that takes in the parameters and returns an optimizer, for example torch.optim.Adam or lambda parameters: torch.optim.Adam(parameters, lr=1e-3). An already-constructed optimizer instance may also be passed directly (see the sketch after this list).

  • *args
  • **kwargs

    Extra args to be passed to opt_fn. The function is called as opt_fn(parameters, *args, **kwargs).

  • use_param_groups (bool, default: True ) –

    Whether to forward settings passed to the Optimizer to the wrapped optimizer's param groups.

    Note that the settings of the first parameter in each group are used for the whole group, so per-parameter settings will be ignored.
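
A minimal sketch of the accepted forms for opt_fn, based on the signature above (torch.optim.Adam stands in for any optimizer, and model is a placeholder for your own torch.nn.Module):

import torch
import torchzero as tz

# an optimizer class; extra args and kwargs are forwarded when it is constructed
tz.m.Wrap(torch.optim.Adam, lr=1e-3)

# a callable that builds the optimizer from the parameters
tz.m.Wrap(lambda parameters: torch.optim.Adam(parameters, lr=1e-3))

# an already-constructed optimizer instance
tz.m.Wrap(torch.optim.Adam(model.parameters(), lr=1e-3))

# disable forwarding of settings passed to Optimizer
tz.m.Wrap(torch.optim.Adam, lr=1e-3, use_param_groups=False)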

Example:

wrapping pytorch_optimizer.StableAdamW

import torchzero as tz
from pytorch_optimizer import StableAdamW

opt = tz.Optimizer(
    model.parameters(),
    tz.m.Wrap(StableAdamW, lr=1),
    tz.m.Cautious(),
    tz.m.LR(1e-2)
)
Source code in torchzero/modules/wrappers/optim_wrapper.py
class Wrap(Module):
    """
    Wraps a pytorch optimizer to use it as a module.

    Note:
        Custom param groups are supported only via ``set_param_groups``; settings passed to Optimizer will be applied to all parameters.

    Args:
        opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
            function that takes in the parameters and returns an optimizer, for example ``torch.optim.Adam``
            or ``lambda parameters: torch.optim.Adam(parameters, lr=1e-3)``. An already-constructed
            optimizer instance may also be passed directly.
        *args:
        **kwargs:
            Extra args to be passed to opt_fn. The function is called as ``opt_fn(parameters, *args, **kwargs)``.
        use_param_groups:
            Whether to forward settings passed to the Optimizer to the wrapped optimizer's param groups.

            Note that the settings of the first parameter in each group are used for the whole group,
            so per-parameter settings will be ignored.

    ### Example:
    wrapping pytorch_optimizer.StableAdamW

    ```python

    import torchzero as tz
    from pytorch_optimizer import StableAdamW
    opt = tz.Optimizer(
        model.parameters(),
        tz.m.Wrap(StableAdamW, lr=1),
        tz.m.Cautious(),
        tz.m.LR(1e-2)
    )
    ```

    """

    def __init__(
        self,
        opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer,
        *args,
        use_param_groups: bool = True,
        **kwargs,
    ):
        defaults = dict(use_param_groups=use_param_groups)
        super().__init__(defaults=defaults)

        self._opt_fn = opt_fn
        self._opt_args = args
        self._opt_kwargs = kwargs
        self._custom_param_groups = None

        self.optimizer: torch.optim.Optimizer | None = None
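        # an already-constructed optimizer (or any other non-callable) is used directly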
        if isinstance(self._opt_fn, torch.optim.Optimizer) or not callable(self._opt_fn):
            self.optimizer = self._opt_fn

    def set_param_groups(self, param_groups):
        self._custom_param_groups = _make_param_groups(param_groups, differentiable=False)
        return super().set_param_groups(param_groups)

    @torch.no_grad
    def apply(self, objective):
        params = objective.params

        # initialize opt on 1st step
        if self.optimizer is None:
            assert callable(self._opt_fn)
            param_groups = params if self._custom_param_groups is None else self._custom_param_groups
            self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)

        # set optimizer per-parameter settings
        if self.defaults["use_param_groups"] and objective.modular is not None:
            for group in self.optimizer.param_groups:
                first_param = group['params'][0]
                setting = self.settings[first_param]

                # settings passed in `set_param_groups` are the highest priority
                # schedulers will override defaults but not settings passed in `set_param_groups`
                # this is consistent with how Optimizer does it.
                if self._custom_param_groups is not None:
                    setting = {k: v for k, v in setting.items() if k not in self._custom_param_groups[0]}

                group.update(setting)

        # set grad to update
        orig_grad = [p.grad for p in params]
        for p, u in zip(params, objective.get_updates()):
            p.grad = u

        # if this is last module, simply use optimizer to update parameters
        if objective.modular is not None and self is objective.modular.modules[-1]:
            self.optimizer.step()

            # restore grad
            for p, g in zip(params, orig_grad):
                p.grad = g

            objective.stop = True
            objective.skip_update = True
            return objective

        # this is not the last module, meaning update is difference in parameters
        # and passed to next module
        params_before_step = [p.clone() for p in params]
        self.optimizer.step() # step and update params
        for p, g in zip(params, orig_grad):
            p.grad = g
        objective.updates = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
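        # restore original parameter values; the next module receives the update instead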
        for p, o in zip(params, params_before_step):
            p.set_(o) # pyright: ignore[reportArgumentType]

        return objective

    def reset(self):
        super().reset()
        assert self.optimizer is not None
        for g in self.optimizer.param_groups:
            for p in g['params']:
                state = self.optimizer.state[p]
                state.clear()