API reference for core module

Modules:

  • chain
  • functional
  • modular
  • module
  • reformulation
  • transform

Classes:

  • Chain

    Chain modules, mostly used internally

  • Module

    Abstract base class for optimizer modules.

  • Optimizer

    Chains multiple modules into an optimizer.

  • TensorTransform

    TensorTransform is a Transform that doesn't use Objective; instead it operates on lists of tensors directly.

  • Transform

    Transform is a Module with only optional children.

Functions:

  • maybe_chain

    Returns a single module directly if only one is provided, otherwise wraps them in a Chain.

  • step

    Steps with a sequence of modules; doesn't apply post-step hooks.

Attributes:

  • Chainable

    Represents a PEP 604 union type; here, a Module or a sequence of Modules.

Chainable module-attribute

Chainable = torchzero.core.module.Module | collections.abc.Sequence[torchzero.core.module.Module]

Represents a PEP 604 union type (e.g. int | str); here it is a single Module or a sequence of Modules, which maybe_chain wraps into a Chain.

Chain

Bases: torchzero.core.module.Module

Chain modules, mostly used internally

Source code in torchzero/core/chain.py
class Chain(Module):
    """Chain modules, mostly used internally"""
    def __init__(self, *modules: Module | Iterable[Module]):
        super().__init__()
        flat_modules: list[Module] = flatten(modules)
        for i, module in enumerate(flat_modules):
            self.set_child(f'module_{i}', module)

    def update(self, objective):
        if len(self.children) > 1:
            raise RuntimeError("can't call `update` on Chain with more than one child, as `update` and `apply` have to be called sequentially. Use the `step` method instead of update-apply.")

        if len(self.children) == 0: return
        return self.children['module_0'].update(objective)

    def apply(self, objective):
        if len(self.children) > 1:
            raise RuntimeError("can't call `update` on Chain with more than one child, as `update` and `apply` have to be called sequentially. Use the `step` method instead of update-apply.")

        if len(self.children) == 0: return objective
        return self.children['module_0'].apply(objective)

    def step(self, objective):
        children = [self.children[f'module_{i}'] for i in range(len(self.children))]
        return _chain_step(objective, children)

    def __repr__(self):
        s = self.__class__.__name__
        if self.children:
            if s == 'Chain': s = 'C' # to shorten it
            s = f'{s}({", ".join(str(m) for m in self.children.values())})'
        return s
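
For illustration, a minimal sketch of how maybe_chain and Chain fit together (assuming Module, Chain and maybe_chain are importable from torchzero.core; the Identity module below is hypothetical):

```python
from torchzero.core import Chain, Module, maybe_chain

class Identity(Module):
    """Hypothetical no-op module, used only for this sketch."""
    def apply(self, objective):
        return objective

a, b = Identity(), Identity()

maybe_chain(a)               # a single module is returned as-is
chain = maybe_chain([a, b])  # a sequence is wrapped in a Chain
print(chain)                 # prints something like C(Identity, Identity)
```

Chain.step then runs its children sequentially, while update/apply are only allowed when there is at most one child.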

Module

Bases: abc.ABC

Abstract base class for optimizer modules.

Modules represent distinct steps or transformations within the optimization process (e.g., momentum, line search, gradient accumulation).

A module does not store parameters, but it maintains per-parameter state and per-parameter settings where tensors are used as keys (same as torch.optim.Optimizer state.)

Parameters:

  • defaults (dict[str, Any] | None, default: None ) –

    a dict containing default values of optimization options (used when a parameter group doesn't specify them).

Methods:

  • apply

    Updates objective using the internal state of this module.

  • get_H

    returns a LinearOperator corresponding to hessian or hessian approximation.

  • get_generator

    If seed=None, returns None.

  • get_state

    Returns values of per-parameter state for a given key.

  • increment_counter

    Increments a counter stored in global state; the first value returned is start.

  • inner_step

    Passes objective to child and returns it.

  • inner_step_tensors

    Steps with child module. Can be used to apply transforms to any internal buffers.

  • reset

    Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state.

  • reset_for_online

    Resets buffers that depend on previous evaluation, such as previous gradient and loss, which may become inaccurate due to mini-batching.

  • set_param_groups

    Set custom parameter groups with per-parameter settings that this module will use.

  • state_dict

    Returns a state dict containing per-parameter state, settings, global state and the state dicts of children.

  • step

    Perform a step with this module. Calls update, then apply.

  • update

    Updates internal state of this module. This should not modify objective.update.

Source code in torchzero/core/module.py
class Module(ABC):
    """Abstract base class for an optimizer modules.

    Modules represent distinct steps or transformations within the optimization
    process (e.g., momentum, line search, gradient accumulation).

    A module does not store parameters, but it maintains per-parameter state and per-parameter settings
    where tensors are used as keys (same as torch.optim.Optimizer state.)

    Args:
        defaults (dict[str, Any] | None):
            a dict containing default values of optimization options (used when a parameter group doesn't specify them).
"""
    def __init__(self, defaults: dict[str, Any] | None = None):
        if defaults is None: defaults = {}
        if any(isinstance(v, Module) for v in defaults.values()): raise RuntimeError("Passed a module to defaults")
        self.defaults: dict[str, Any] = defaults

        # settings are stored like state in per-tensor defaultdict, with per-parameter overrides possible
        # 0 - this module specific per-parameter setting overrides set via `set_param_groups` - highest priority
        # 1 - global per-parameter setting overrides in param_groups passed to Optimizer - medium priority
        # 2 - `defaults` - lowest priority
        self.settings: defaultdict[torch.Tensor, ChainMap[str, Any]] = defaultdict(lambda: ChainMap({}, {}, self.defaults))
        """per-parameter settings."""

        self.state: defaultdict[torch.Tensor, dict[str, Any]] = defaultdict(dict)
        """Per-parameter state (e.g., momentum buffers)."""

        self.global_state: dict[str, Any] = {}
        """Global state for things that are not per-parameter."""

        self.children: dict[str, Module] = {}
        """A dictionary of child modules."""

        self._overridden_keys = set()
        """tracks keys overridden with ``set_param_groups``, only used to not give a warning"""


    def set_param_groups(self, param_groups: Params):
        """Set custom parameter groups with per-parameter settings that this module will use."""
        param_groups = _make_param_groups(param_groups, differentiable=False)
        for group in param_groups:
            settings = group.copy()
            params = settings.pop('params')
            if not settings: continue
            self._overridden_keys.update(settings.keys())

            for param in params:
                self.settings[param].maps[0].update(settings) # set module-specific per-parameter settings
        return self

    def set_child(self, key: str, module: "Module | Sequence[Module] | None"):
        if key in self.children:
            warnings.warn(f"set_child overwriting child `{key}`")

        if module is None: return

        from .chain import maybe_chain
        self.children[key] = maybe_chain(module)

    def set_children_sequence(self, modules: "Iterable[Module | Sequence[Module]]", prefix = 'module_'):
        from .chain import maybe_chain

        modules = list(modules)
        for i, m in enumerate(modules):
            self.set_child(f'{prefix}{i}', maybe_chain(m))

    def get_children_sequence(self, prefix = 'module_'):
        return [self.children[f'{prefix}{i}'] for i in range(len(self.children)) if f'{prefix}{i}' in self.children]

    def inner_step(
        self,
        key: str,
        objective: "Objective",
        must_exist: bool = True,
    ) -> "Objective":
        """Passes ``objective`` to child and returns it."""
        child = self.children.get(key, None)

        if child is None:
            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
            return objective

        return child.step(objective)


    def inner_step_tensors(
        self,
        key: str,
        tensors: list[torch.Tensor],
        clone: bool,
        params: Iterable[torch.Tensor] | None = None,
        grads: Sequence[torch.Tensor] | None = None,
        loss: torch.Tensor | None = None,
        closure: Callable | None = None,
        objective: "Objective | None" = None,
        must_exist: bool = True
    ) -> list[torch.Tensor]:
        """Steps with child module. Can be used to apply transforms to any internal buffers.

        If ``objective`` is specified, other attributes shouldn't be specified.

        Args:
            key (str): Child module key.
            tensors (Sequence[torch.Tensor]): tensors to pass to child module.
            clone (bool):
                If ``key`` exists, whether to clone ``tensors`` to avoid modifying buffers in-place.
                If ``key`` doesn't exist, ``tensors`` are always returned without cloning
            params (Iterable[torch.Tensor] | None, optional): pass None if ``tensors`` have different shape. Defaults to None.
            grads (Sequence[torch.Tensor] | None, optional): grads. Defaults to None.
            loss (torch.Tensor | None, optional): loss. Defaults to None.
            closure (Callable | None, optional): closure. Defaults to None.
            must_exist (bool, optional): if True, if ``key`` doesn't exist, raises ``KeyError``. Defaults to True.
        """

        child = self.children.get(key, None)

        if child is None:
            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
            return tensors

        if clone: tensors = [t.clone() for t in tensors]
        return step_tensors(modules=child, tensors=tensors, params=params, grads=grads,
                            loss=loss, closure=closure, objective=objective)


    def __repr__(self):
        s = self.__class__.__name__
        if self.children:
            s = f'{s}('
            for k,v in self.children.items():
                s = f'{s}{k}={v}, '
            s = f'{s[:-2]})'
        return s

    @overload
    def get_settings(self, params: Sequence[torch.Tensor], key: str, *,
                     cls: type[ListLike] = list) -> ListLike: ...
    @overload
    def get_settings(self, params: Sequence[torch.Tensor], key: list[str] | tuple[str,...], *,
                     cls: type[ListLike] = list) -> list[ListLike]: ...
    @overload
    def get_settings(self, params: Sequence[torch.Tensor], key: str, key2: str, *keys: str,
                     cls: type[ListLike] = list) -> list[ListLike]: ...

    def get_settings(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None,
                     *keys: str, cls: type[ListLike] = list) -> ListLike | list[ListLike]:
        return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]


    @overload
    def get_state(self, params: Sequence[torch.Tensor], key: str, *,
                   must_exist: bool = False, init: Init = torch.zeros_like,
                   cls: type[ListLike] = list) -> ListLike: ...
    @overload
    def get_state(self, params: Sequence[torch.Tensor], key: list[str] | tuple[str,...], *,
                   must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                   cls: type[ListLike] = list) -> list[ListLike]: ...
    @overload
    def get_state(self, params: Sequence[torch.Tensor], key: str, key2: str, *keys: str,
                   must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                   cls: type[ListLike] = list) -> list[ListLike]: ...

    def get_state(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
                   must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
                   cls: type[ListLike] = list) -> ListLike | list[ListLike]:
        """Returns values of per-parameter state for a given key.
        If key doesn't exist, create it with inits.

        This functions like `operator.itemgetter`, returning a single value if called with a single key,
        or a tuple if called with multiple keys.

        If you want to force it to return a tuple even with a single key, pass a list/tuple of 1 or more keys.

        ```python
        exp_avg = self.get_state(params, "exp_avg")
        # returns cls (a list by default)

        exp_avg, exp_avg_sq = self.get_state(params, "exp_avg", "exp_avg_sq")
        # returns a list of cls

        exp_avg = self.get_state(params, ["exp_avg"])
        # always returns a list of cls, even with a single key
        ```

        Args:
            *keys (str):
                the keys to look for in each parameter's state.
                if a single key is specified, this returns a single value or cls,
                otherwise this returns a list of values or cls per each key.
            params (Iterable[torch.Tensor]): parameters to return the states for.
            must_exist (bool, optional):
                If a key doesn't exist in state, if True, raises a KeyError, if False, creates the value
                using `init` argument (default = False).
            init (Init | Sequence[Init], optional):
                how to initialize a key if it doesn't exist.

                can be
                - Callable like torch.zeros_like
                - string - "param" or "grad" to use cloned params or cloned grads.
                - anything else other than list/tuples will be used as-is, tensors will be cloned.
                - list/tuple of values per each parameter, only if got a single key.
                - list/tuple of values per each key, only if got multiple keys.

                if multiple `keys` are specified, inits is per-key!

                Defaults to torch.zeros_like.
            cls (type[ListLike], optional):
                MutableSequence class to return, this only has effect when state_keys is a list/tuple. Defaults to list.

        Returns:
            - if state_keys has a single key and keys has a single key, return a single value.
            - if state_keys has a single key and keys has multiple keys, return a list of values.
            - if state_keys has multiple keys and keys has a single key, return cls.
            - if state_keys has multiple keys and keys has multiple keys, return list of cls.
        """
        return get_state_vals(self.state, params, key, key2, *keys, must_exist=must_exist, init=init, cls=cls) # pyright:ignore[reportArgumentType]

    def clear_state_keys(self, *keys:str):
        for s in self.state.values():
            for k in keys:
                if k in s: del s[k]

    @overload
    def store(self, params: Sequence[torch.Tensor], keys: str, values: Sequence): ...
    @overload
    def store(self, params: Sequence[torch.Tensor], keys: Sequence[str], values: Sequence[Sequence]): ...
    def store(self, params: Sequence[torch.Tensor], keys: str | Sequence[str], values: Sequence):
        if isinstance(keys, str):
            for p,v in zip(params, values):
                state = self.state[p]
                state[keys] = v
            return

        for p, *p_v in zip(params, *values):
            state = self.state[p]
            for k,v in zip(keys, p_v): state[k] = v

    def state_dict(self):
        """state dict"""
        packed_state = {id(k):v for k,v in self.state.items()}
        packed_settings = {id(k):v for k,v in self.settings.items()}

        state_dict = {
            "state": packed_state,
            "settings":
                {
                    "local": {k:v.maps[0] for k,v in packed_settings.items()},
                    "global": {k:v.maps[1] for k,v in packed_settings.items()},
                    "defaults": {k:v.maps[2] for k,v in packed_settings.items()},
                },
            "global_state": self.global_state,
            "extra": self._extra_pack(),
            "children": {k: v.state_dict() for k, v in self.children.items()}
        }
        return state_dict

    def _load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
        """loads state_dict, ``id_to_tensor`` is passed by ``Optimizer``"""
        # load state
        state = state_dict['state']
        self.state.clear()
        self.state.update({id_to_tensor[k]:v for k,v in state.items()})

        # load settings
        settings = state_dict['settings']
        self.settings.clear()
        for k, v in settings['local'].items(): self.settings[id_to_tensor[k]].maps[0].update(v)
        for k, v in settings['global'].items(): self.settings[id_to_tensor[k]].maps[1].update(v)
        for k, v in settings['defaults'].items(): self.settings[id_to_tensor[k]].maps[2].update(v)

        # load global state
        self.global_state.clear()
        self.global_state.update(state_dict['global_state'])

        # children
        for k, v in state_dict['children'].items():
            if k in self.children: self.children[k]._load_state_dict(v, id_to_tensor)
            else: warnings.warn(f'State dict for {self} has child {k}, which is missing in {self}')

        # extra info
        self._extra_unpack(state_dict['extra'])

    def get_generator(self, device: torch.types.Device, seed: int | None):
        """If ``seed=None``, returns ``None``.

        Otherwise, if generator on this device and with this seed hasn't been created,
        creates it and stores in global state.

        Returns ``torch.Generator``."""
        if seed is None: return None

        if device is None: device_obj = torch.get_default_device()
        else: device_obj = torch.device(device)
        key = f"__generator-{seed}-{device_obj.type}:{device_obj.index}"

        if key not in self.global_state:
            self.global_state[key] = torch.Generator(device_obj).manual_seed(seed)

        return self.global_state[key]

    def increment_counter(self, key: str, start: int):
        """first value is ``start``"""
        value = self.global_state.get(key, start - 1) + 1
        self.global_state[key] = value
        return value

    # ---------------------------- OVERRIDABLE METHODS --------------------------- #
    def update(self, objective:"Objective") -> None:
        """Updates internal state of this module. This should not modify ``objective.update``.

        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.

        ``update`` is guaranteed to be called at least once before ``apply``.

        Args:
            objective (Objective): ``Objective`` object
        """

    @abstractmethod
    def apply(self, objective: "Objective") -> "Objective":
        """Updates ``objective`` using the internal state of this module.

        If ``update`` method is defined, ``apply`` shouldn't modify the internal state of this module if possible.

        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.

        ``update`` is guaranteed to be called at least once before ``apply``.

        Args:
            objective (Objective): ``Objective`` object
        """
        # if apply is empty, it should be defined explicitly.
        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `apply`.")

    def step(self, objective: "Objective") -> "Objective":
        """Perform a step with this module. Calls ``update``, then ``apply``."""
        self.update(objective)
        return self.apply(objective)

    def get_H(self, objective: "Objective") -> LinearOperator | None:
        """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
        The hessian approximation is assumed to be for all parameters concatenated to a vector."""
        # if this method is not defined it searches in children
        # this should be overwritten to return None if child params are different from this modules params
        H = None
        for k,v in self.children.items():
            H_v = v.get_H(objective)

            if (H is not None) and (H_v is not None):
                raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")

            if H_v is not None: H = H_v

        return H

    def reset(self):
        """Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state."""
        self.state.clear()

        generator = self.global_state.get("generator", None)
        self.global_state.clear()
        if generator is not None: self.global_state["generator"] = generator

        for c in self.children.values(): c.reset()

    def reset_for_online(self):
        """Resets buffers that depend on previous evaluation, such as previous gradient and loss,
        which may become inaccurate due to mini-batching.

        ``Online`` module calls ``reset_for_online``,
        then it calls ``update`` with previous parameters,
        then it calls ``update`` with current parameters,
        and then ``apply``.
        """
        for c in self.children.values(): c.reset_for_online()

    def _extra_pack(self) -> dict:
        """extra information to store in ``state_dict`` of this optimizer.
        Will be passed to ``_extra_unpack`` when loading the ``state_dict``."""
        return {}

    def _extra_unpack(self, d: dict):
        """``_extra_pack`` return will be passed to this method when loading ``state_dict``.
        This method is called after loading the rest of the state dict"""
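
As a usage illustration of the API above, here is a hedged sketch of a custom module: a heavy-ball momentum written against the attributes that the source in this section actually uses (objective.params, objective.get_updates(), objective.updates). The module itself is hypothetical and not part of torchzero:

```python
import torch
from torchzero.core import Module

class HeavyBall(Module):
    """Hypothetical Polyak-momentum module, for illustration only."""
    def __init__(self, momentum: float = 0.9):
        super().__init__(defaults=dict(momentum=momentum))

    def apply(self, objective):
        params = objective.params
        updates = objective.get_updates()                 # current update direction
        velocities = self.get_state(params, "velocity")   # per-parameter state, zeros by default
        momenta = self.get_settings(params, "momentum")   # per-parameter settings

        for u, v, m in zip(updates, velocities, momenta):
            v.mul_(m).add_(u)   # v = m * v + u
            u.copy_(v)          # write the new direction back into the update

        return objective
```

Because all logic lives in apply, no update method is needed; splitting the state update into update would additionally let meta-modules such as tz.m.Online re-run it.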

apply

apply(objective: Objective) -> Objective

Updates objective using the internal state of this module.

If update method is defined, apply shouldn't modify the internal state of this module if possible.

Specifying update and apply methods is optional and allows certain meta-modules to be used, such as tz.m.Online or trust regions. Alternatively, define all logic within the apply method.

update is guaranteed to be called at least once before apply.

Parameters:

  • objective (Objective) –

    Objective object

Source code in torchzero/core/module.py
@abstractmethod
def apply(self, objective: "Objective") -> "Objective":
    """Updates ``objective`` using the internal state of this module.

    If ``update`` method is defined, ``apply`` shouldn't modify the internal state of this module if possible.

    Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
    such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.

    ``update`` is guaranteed to be called at least once before ``apply``.

    Args:
        objective (Objective): ``Objective`` object
    """
    # if apply is empty, it should be defined explicitly.
    raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `apply`.")

get_H

get_H(objective: Objective) -> LinearOperator | None

returns a LinearOperator corresponding to hessian or hessian approximation. The hessian approximation is assumed to be for all parameters concatenated to a vector.

Source code in torchzero/core/module.py
def get_H(self, objective: "Objective") -> LinearOperator | None:
    """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
    The hessian approximation is assumed to be for all parameters concatenated to a vector."""
    # if this method is not defined it searches in children
    # this should be overwritten to return None if child params are different from this modules params
    H = None
    for k,v in self.children.items():
        H_v = v.get_H(objective)

        if (H is not None) and (H_v is not None):
            raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")

        if H_v is not None: H = H_v

    return H

get_generator

get_generator(device: Union[device, str, int, NoneType], seed: int | None)

If seed=None, returns None.

Otherwise, if generator on this device and with this seed hasn't been created, creates it and stores in global state.

Returns torch.Generator.

Source code in torchzero/core/module.py
def get_generator(self, device: torch.types.Device, seed: int | None):
    """If ``seed=None``, returns ``None``.

    Otherwise, if generator on this device and with this seed hasn't been created,
    creates it and stores in global state.

    Returns ``torch.Generator``."""
    if seed is None: return None

    if device is None: device_obj = torch.get_default_device()
    else: device_obj = torch.device(device)
    key = f"__generator-{seed}-{device_obj.type}:{device_obj.index}"

    if key not in self.global_state:
        self.global_state[key] = torch.Generator(device_obj).manual_seed(seed)

    return self.global_state[key]
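
A short, hedged sketch of how a module might use get_generator for reproducible noise (the AddNoise module and its settings keys are hypothetical):

```python
import torch
from torchzero.core import Module

class AddNoise(Module):
    """Hypothetical module that perturbs the update with seeded Gaussian noise."""
    def __init__(self, alpha: float = 1e-3, seed: int | None = 0):
        super().__init__(defaults=dict(alpha=alpha, seed=seed))

    def apply(self, objective):
        for p, u in zip(objective.params, objective.get_updates()):
            setting = self.settings[p]
            gen = self.get_generator(p.device, setting["seed"])  # None when seed is None
            noise = torch.randn(u.shape, device=u.device, dtype=u.dtype, generator=gen)
            u.add_(noise, alpha=setting["alpha"])
        return objective
```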

get_state

get_state(params: Sequence[Tensor], key: str | list[str] | tuple[str, ...], key2: str | None = None, *keys: str, must_exist: bool = False, init: Any | Sequence[Any] = zeros_like, cls: type[~ListLike] = list) -> Union[~ListLike, list[~ListLike]]

Returns values of per-parameter state for a given key. If key doesn't exist, create it with inits.

This functions like operator.itemgetter, returning a single value if called with a single key, or a tuple if called with multiple keys.

If you want to force it to return a tuple even with a single key, pass a list/tuple of 1 or more keys.

exp_avg = self.get_state(params, "exp_avg")
# returns cls (a list by default)

exp_avg, exp_avg_sq = self.get_state(params, "exp_avg", "exp_avg_sq")
# returns a list of cls

exp_avg = self.get_state(params, ["exp_avg"])
# always returns a list of cls, even with a single key

Parameters:

  • *keys (str) –

    the keys to look for in each parameter's state. If a single key is specified, this returns a single value or cls; otherwise this returns a list of values or cls per key.

  • params (Iterable[Tensor]) –

    parameters to return the states for.

  • must_exist (bool, default: False ) –

    If a key doesn't exist in state, if True, raises a KeyError, if False, creates the value using init argument (default = False).

  • init (Any | Sequence[Any], default: zeros_like ) –

    how to initialize a key if it doesn't exist.

    can be
    - Callable like torch.zeros_like
    - string - "param" or "grad" to use cloned params or cloned grads.
    - anything else other than list/tuples will be used as-is, tensors will be cloned.
    - list/tuple of values per each parameter, only if got a single key.
    - list/tuple of values per each key, only if got multiple keys.

    if multiple keys are specified, inits is per-key!

    Defaults to torch.zeros_like.

  • cls (type[ListLike], default: list ) –

    MutableSequence class to return, this only has effect when state_keys is a list/tuple. Defaults to list.

Returns:

  • Union[~ListLike, list[~ListLike]] –

    - if state_keys has a single key and keys has a single key, return a single value.
    - if state_keys has a single key and keys has multiple keys, return a list of values.
    - if state_keys has multiple keys and keys has a single key, return cls.
    - if state_keys has multiple keys and keys has multiple keys, return list of cls.

Source code in torchzero/core/module.py
def get_state(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
               must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
               cls: type[ListLike] = list) -> ListLike | list[ListLike]:
    """Returns values of per-parameter state for a given key.
    If key doesn't exist, create it with inits.

    This functions like `operator.itemgetter`, returning a single value if called with a single key,
    or a tuple if called with multiple keys.

    If you want to force it to return a tuple even with a single key, pass a list/tuple of 1 or more keys.

    ```python
    exp_avg = self.get_state(params, "exp_avg")
    # returns cls (a list by default)

    exp_avg, exp_avg_sq = self.get_state(params, "exp_avg", "exp_avg_sq")
    # returns a list of cls

    exp_avg = self.get_state(params, ["exp_avg"])
    # always returns a list of cls, even with a single key
    ```

    Args:
        *keys (str):
            the keys to look for in each parameter's state.
            if a single key is specified, this returns a single value or cls,
            otherwise this returns a list of values or cls per each key.
        params (Iterable[torch.Tensor]): parameters to return the states for.
        must_exist (bool, optional):
            If a key doesn't exist in state, if True, raises a KeyError, if False, creates the value
            using `init` argument (default = False).
        init (Init | Sequence[Init], optional):
            how to initialize a key if it doesn't exist.

            can be
            - Callable like torch.zeros_like
            - string - "param" or "grad" to use cloned params or cloned grads.
            - anything else other than list/tuples will be used as-is, tensors will be cloned.
            - list/tuple of values per each parameter, only if got a single key.
            - list/tuple of values per each key, only if got multiple keys.

            if multiple `keys` are specified, inits is per-key!

            Defaults to torch.zeros_like.
        cls (type[ListLike], optional):
            MutableSequence class to return, this only has effect when state_keys is a list/tuple. Defaults to list.

    Returns:
        - if state_keys has a single key and keys has a single key, return a single value.
        - if state_keys has a single key and keys has multiple keys, return a list of values.
        - if state_keys has multiple keys and keys has a single key, return cls.
        - if state_keys has multiple keys and keys has multiple keys, return list of cls.
    """
    return get_state_vals(self.state, params, key, key2, *keys, must_exist=must_exist, init=init, cls=cls) # pyright:ignore[reportArgumentType]

increment_counter

increment_counter(key: str, start: int)

Increments a counter stored in global state; the first value returned is start.

Source code in torchzero/core/module.py
def increment_counter(self, key: str, start: int):
    """first value is ``start``"""
    value = self.global_state.get(key, start - 1) + 1
    self.global_state[key] = value
    return value
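
For clarity, a hedged sketch of the counter semantics, mirroring how TensorTransform below uses it for one-time initialization (the Counting module is hypothetical):

```python
from torchzero.core import Module

class Counting(Module):
    """Hypothetical module that tracks how many times apply has run."""
    def apply(self, objective):
        n = self.increment_counter("num_applies", start=0)  # returns 0, 1, 2, ... on successive calls
        if n == 0:
            pass  # one-time initialization could go here
        return objective
```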

inner_step

inner_step(key: str, objective: Objective, must_exist: bool = True) -> Objective

Passes objective to child and returns it.

Source code in torchzero/core/module.py
def inner_step(
    self,
    key: str,
    objective: "Objective",
    must_exist: bool = True,
) -> "Objective":
    """Passes ``objective`` to child and returns it."""
    child = self.children.get(key, None)

    if child is None:
        if must_exist: raise KeyError(f"child `{key}` doesn't exist")
        return objective

    return child.step(objective)

inner_step_tensors

inner_step_tensors(key: str, tensors: list[Tensor], clone: bool, params: Iterable[Tensor] | None = None, grads: Sequence[Tensor] | None = None, loss: Tensor | None = None, closure: Callable | None = None, objective: Objective | None = None, must_exist: bool = True) -> list[Tensor]

Steps with child module. Can be used to apply transforms to any internal buffers.

If objective is specified, other attributes shouldn't be specified.

Parameters:

  • key (str) –

    Child module key.

  • tensors (Sequence[Tensor]) –

    tensors to pass to child module.

  • clone (bool) –

    If key exists, whether to clone tensors to avoid modifying buffers in-place. If key doesn't exist, tensors are always returned without cloning

  • params (Iterable[Tensor] | None, default: None ) –

    pass None if tensors have different shape. Defaults to None.

  • grads (Sequence[Tensor] | None, default: None ) –

    grads. Defaults to None.

  • loss (Tensor | None, default: None ) –

    loss. Defaults to None.

  • closure (Callable | None, default: None ) –

    closure. Defaults to None.

  • must_exist (bool, default: True ) –

    if True, if key doesn't exist, raises KeyError. Defaults to True.

Source code in torchzero/core/module.py
def inner_step_tensors(
    self,
    key: str,
    tensors: list[torch.Tensor],
    clone: bool,
    params: Iterable[torch.Tensor] | None = None,
    grads: Sequence[torch.Tensor] | None = None,
    loss: torch.Tensor | None = None,
    closure: Callable | None = None,
    objective: "Objective | None" = None,
    must_exist: bool = True
) -> list[torch.Tensor]:
    """Steps with child module. Can be used to apply transforms to any internal buffers.

    If ``objective`` is specified, other attributes shouldn't be specified.

    Args:
        key (str): Child module key.
        tensors (Sequence[torch.Tensor]): tensors to pass to child module.
        clone (bool):
            If ``key`` exists, whether to clone ``tensors`` to avoid modifying buffers in-place.
            If ``key`` doesn't exist, ``tensors`` are always returned without cloning
        params (Iterable[torch.Tensor] | None, optional): pass None if ``tensors`` have different shape. Defaults to None.
        grads (Sequence[torch.Tensor] | None, optional): grads. Defaults to None.
        loss (torch.Tensor | None, optional): loss. Defaults to None.
        closure (Callable | None, optional): closure. Defaults to None.
        must_exist (bool, optional): if True, if ``key`` doesn't exist, raises ``KeyError``. Defaults to True.
    """

    child = self.children.get(key, None)

    if child is None:
        if must_exist: raise KeyError(f"child `{key}` doesn't exist")
        return tensors

    if clone: tensors = [t.clone() for t in tensors]
    return step_tensors(modules=child, tensors=tensors, params=params, grads=grads,
                        loss=loss, closure=closure, objective=objective)

reset

reset()

Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state.

Source code in torchzero/core/module.py
def reset(self):
    """Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state."""
    self.state.clear()

    generator = self.global_state.get("generator", None)
    self.global_state.clear()
    if generator is not None: self.global_state["generator"] = generator

    for c in self.children.values(): c.reset()

reset_for_online

reset_for_online()

Resets buffers that depend on previous evaluation, such as previous gradient and loss, which may become inaccurate due to mini-batching.

Online module calls reset_for_online, then it calls update with previous parameters, then it calls update with current parameters, and then apply.

Source code in torchzero/core/module.py
def reset_for_online(self):
    """Resets buffers that depend on previous evaluation, such as previous gradient and loss,
    which may become inaccurate due to mini-batching.

    ``Online`` module calls ``reset_for_online``,
    then it calls ``update`` with previous parameters,
    then it calls ``update`` with current parameters,
    and then ``apply``.
    """
    for c in self.children.values(): c.reset_for_online()

set_param_groups

set_param_groups(param_groups: Iterable[Tensor | tuple[str, Tensor] | Mapping[str, Any]])

Set custom parameter groups with per-parameter settings that this module will use.

Source code in torchzero/core/module.py
def set_param_groups(self, param_groups: Params):
    """Set custom parameter groups with per-parameter settings that this module will use."""
    param_groups = _make_param_groups(param_groups, differentiable=False)
    for group in param_groups:
        settings = group.copy()
        params = settings.pop('params')
        if not settings: continue
        self._overridden_keys.update(settings.keys())

        for param in params:
            self.settings[param].maps[0].update(settings) # set module-specific per-parameter settings
    return self
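
A hedged sketch of the intended pattern: when a setting should only affect one module, pass the override to that module directly rather than through the optimizer-level param groups (the Scale module and its alpha key are hypothetical):

```python
import torch
from torchzero.core import Module

class Scale(Module):
    """Hypothetical module that multiplies the update by a per-parameter 'alpha'."""
    def __init__(self, alpha: float = 1.0):
        super().__init__(defaults=dict(alpha=alpha))

    def apply(self, objective):
        alphas = self.get_settings(objective.params, "alpha")
        for u, a in zip(objective.get_updates(), alphas):
            u.mul_(a)
        return objective

head = [torch.nn.Parameter(torch.zeros(3))]
scale = Scale(alpha=1.0)
# `head` gets alpha=0.1 at the highest priority; every other parameter keeps the default 1.0
scale.set_param_groups([{"params": head, "alpha": 0.1}])
```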

state_dict

state_dict()

Returns a state dict containing per-parameter state, settings, global state and the state dicts of children.

Source code in torchzero/core/module.py
def state_dict(self):
    """state dict"""
    packed_state = {id(k):v for k,v in self.state.items()}
    packed_settings = {id(k):v for k,v in self.settings.items()}

    state_dict = {
        "state": packed_state,
        "settings":
            {
                "local": {k:v.maps[0] for k,v in packed_settings.items()},
                "global": {k:v.maps[1] for k,v in packed_settings.items()},
                "defaults": {k:v.maps[2] for k,v in packed_settings.items()},
            },
        "global_state": self.global_state,
        "extra": self._extra_pack(),
        "children": {k: v.state_dict() for k, v in self.children.items()}
    }
    return state_dict

step

step(objective: Objective) -> Objective

Perform a step with this module. Calls update, then apply.

Source code in torchzero/core/module.py
def step(self, objective: "Objective") -> "Objective":
    """Perform a step with this module. Calls ``update``, then ``apply``."""
    self.update(objective)
    return self.apply(objective)

update

update(objective: Objective) -> None

Updates internal state of this module. This should not modify objective.update.

Specifying update and apply methods is optional and allows certain meta-modules to be used, such as tz.m.Online or trust regions. Alternatively, define all logic within the apply method.

update is guaranteed to be called at least once before apply.

Parameters:

  • objective (Objective) –

    Objective object

Source code in torchzero/core/module.py
def update(self, objective:"Objective") -> None:
    """Updates internal state of this module. This should not modify ``objective.update``.

    Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
    such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.

    ``update`` is guaranteed to be called at least once before ``apply``.

    Args:
        objective (Objective): ``Objective`` object
    """

Optimizer

Bases: torch.optim.optimizer.Optimizer

Chains multiple modules into an optimizer.

Parameters:

  • params (Iterable | Module) –

    An iterable of parameters to optimize (typically model.parameters()), an iterable of parameter group dicts, or a torch.nn.Module instance.

  • *modules (Module) –

    A sequence of Module instances that define the optimization algorithm steps.

Source code in torchzero/core/modular.py
class Optimizer(torch.optim.Optimizer):
    """Chains multiple modules into an optimizer.

    Args:
        params (Params | torch.nn.Module): An iterable of parameters to optimize
            (typically `model.parameters()`), an iterable of parameter group dicts,
            or a `torch.nn.Module` instance.
        *modules (Module): A sequence of `Module` instances that define the
            optimization algorithm steps.
    """
    # this is specifically for lr schedulers
    param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]

    def __init__(self, params: Params | torch.nn.Module, *modules: Module):
        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Optimizer`")
        self.model: torch.nn.Module | None = None
        """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
        if isinstance(params, torch.nn.Module):
            self.model = params
            params = params.parameters()

        self.modules = modules
        """Top-level modules providedduring initialization."""

        self.flat_modules = flatten_modules(self.modules)
        """A flattened list of all modules including all children."""

        param_groups = _make_param_groups(params, differentiable=False)
        self._per_parameter_global_settings: dict[torch.Tensor, list[MutableMapping[str, Any]]] = {}
        """Maps each parameter tensor to a list of per-module global settings.
        Each element in the list is ChainDict's 2nd map of a module."""

        # make sure there is no more than a single learning rate module
        lr_modules = [m for m in self.flat_modules if 'lr' in m.defaults]
        if len(lr_modules) > 1:
            warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to compounding of learning rate multiplication with per-parameter learning rates and schedulers.')

        # iterate over all per-parameter settings overrides and check if they are applied at most once
        for group in param_groups:
            for k in group:
                if k in ('params', 'lr'): continue
                modules_with_k = [m for m in self.flat_modules if k in m.defaults and k not in m._overridden_keys]
                if len(modules_with_k) > 1:
                    warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to only set `{k}` to one of them, use `module.set_param_groups(params)`')

        # defaults for schedulers
        defaults = {}
        for m in self.flat_modules: defaults.update(m.defaults)
        super().__init__(param_groups, defaults=defaults)

        # note - this is what super().__init__(param_groups, defaults=defaults) does:

        # self.defaults = defaults
        # for param_group in param_groups:
        #     self.add_param_group(param_group)

        # add_param_group adds a ChainMap where defaults are lowest priority,
        # and entries specifed in param_groups or scheduler are higher priority.
        # pytorch schedulers do group["lr"] = new_lr, which sets higher priority key.
        # in each module, settings passed to that module by calling set_param_groups are highest priority

        self.current_step = 0
        """global step counter for the optimizer."""

        self.num_evaluations = 0
        """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""

        # reformulations will change the closure to return a different loss (e.g. a sqrt homotopy, gaussian homotopy)
        # we want to return original loss so this attribute is used
        self._closure_return = None
        """on each step, first time a closure is evaluated, this attribute is set to the returned value. `step` method returns this."""

        self.attrs = {}
        """custom attributes that can be set by modules, for example EMA of weights or best so far"""

        self.should_terminate = False
        """is set to True by termination criteria modules."""

    def add_param_group(self, param_group: dict[str, Any]):
        proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
        self.param_groups.append(ChainMap(proc_param_group, self.defaults))
        # setting param_group[key] = value sets it to first map (the `proc_param_group`).
        # therefore lr schedulers override defaults, but not settings passed to individual modules
        # by `set_param_groups` .

        for p in proc_param_group['params']:
            # updates global per-parameter setting overrides (medium priority)
            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.flat_modules]

    def state_dict(self):
        all_params = [p for g in self.param_groups for p in g['params']]
        id_to_idx = {id(p): i for i,p in enumerate(all_params)}

        groups = []
        for g in self.param_groups:
            g = g.copy()
            g['params'] = [id_to_idx[id(p)] for p in g['params']]
            groups.append(g)

        state_dict = {
            "idx_to_id": {v:k for k,v in id_to_idx.items()},
            "params": all_params,
            "groups": groups,
            "defaults": self.defaults,
            "modules": {i: m.state_dict() for i, m in enumerate(self.flat_modules)}
        }
        return state_dict

    def load_state_dict(self, state_dict: dict):
        self.defaults.clear()
        self.defaults.update(state_dict['defaults'])

        idx_to_param = dict(enumerate(state_dict['params']))
        groups = []
        for g in state_dict['groups']:
            g = g.copy()
            g['params'] = [idx_to_param[p] for p in g['params']]
            groups.append(g)

        self.param_groups.clear()
        for group in groups:
            self.add_param_group(group)

        id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
        for m, sd in zip(self.flat_modules, state_dict['modules'].values()):
            m._load_state_dict(sd, id_to_tensor)


    def step(self, closure=None, loss=None, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride]
        # clear closure return from previous step
        self._closure_return = None

        # propagate global per-parameter setting overrides
        for g in self.param_groups:
            settings = dict(g.maps[0]) # ignore defaults
            params = settings.pop('params')
            if not settings: continue

            for p in params:
                if not p.requires_grad: continue
                for map in self._per_parameter_global_settings[p]: map.update(settings)

        # create Objective
        params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]

        counter_closure = None
        if closure is not None:
            counter_closure = _EvalCounterClosure(self, closure)

        objective = Objective(
            params=params, closure=counter_closure, model=self.model,
            current_step=self.current_step, modular=self, loss=loss, storage=kwargs
        )

        # step with all modules
        objective = step(objective, self.modules)

        # apply update to parameters unless `objective.skip_update = True`
        # this does:
        # if not objective.skip_update:
        #   torch._foreach_sub_(objective.params, objective.get_updates())
        objective.update_parameters()

        # update attributes
        self.attrs.update(objective.attrs)
        if objective.should_terminate is not None:
            self.should_terminate = objective.should_terminate

        self.current_step += 1

        # apply hooks
        # this does:
        # for hook in objective.post_step_hooks:
        #     hook(objective, modules)
        objective.apply_post_step_hooks(self.modules)

        # return the first closure evaluation return
        # could return loss if it was passed but that's pointless
        return self._closure_return

    def __repr__(self):
        return f'Optimizer({", ".join(str(m) for m in self.modules)})'
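
An end-to-end sketch of driving the optimizer with a closure. The module names under tz.m and the closure's backward flag follow torchzero's usual conventions but are assumptions here; Optimizer is assumed to be re-exported at the package root (otherwise import it from torchzero.core). Any chain of modules is driven the same way:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# modules are applied in order; names are illustrative
opt = tz.Optimizer(model, tz.m.Adam(), tz.m.LR(1e-2))

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    loss = opt.step(closure)  # returns the first closure evaluation of this step
```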

TensorTransform

Bases: torchzero.core.transform.Transform

TensorTransform is a Transform that doesn't use Objective, instead it operates on lists of tensors directly.

This has a concat_params setting which is used in quite a few modules; for example, it is optional in all full-matrix methods like Quasi-Newton or full-matrix Adagrad.

To use, subclass this and override one of single_tensor_update or multi_tensor_update, and one of single_tensor_apply or multi_tensor_apply.

For copying:

multi tensor:

def multi_tensor_initialize(self, tensors, params, grads, loss, states, settings):
    ...
def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
    ...
def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
    ...

single tensor:

def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
    ...
def single_tensor_update(self, tensor, param, grad, loss, state, setting):
    ...
def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
    ...

Methods:

  • multi_tensor_apply

    Updates tensors and returns it. This shouldn't modify state if possible.

  • multi_tensor_initialize

    initialize states before first update.

  • multi_tensor_update

    Updates states. This should not modify tensor.

  • single_tensor_apply

    Updates tensor and returns it. This shouldn't modify state if possible.

  • single_tensor_initialize

    initialize state before first update.

  • single_tensor_update

    Updates state. This should not modify tensor.

Source code in torchzero/core/transform.py
class TensorTransform(Transform):
    """``TensorTransform`` is a ``Transform`` that doesn't use ``Objective``, instead it operates
    on lists of tensors directly.

    This has a ``concat_params`` setting which is used in quite a few modules, for example it is optional
    in all full-matrix methods like Quasi-Newton or full-matrix Adagrad.

    To use, subclass this and override one of ``single_tensor_update`` or ``multi_tensor_update``,
    and one of ``single_tensor_apply`` or ``multi_tensor_apply``.

    For copying:

    multi tensor:
    ```
    def multi_tensor_initialize(self, tensors, params, grads, loss, states, settings):
        ...
    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
        ...
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        ...
    ```

    single tensor:

    ```
    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
        ...
    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
        ...
    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        ...
    ```
    """
    def __init__(
        self,
        defaults: dict[str, Any] | None = None,
        update_freq: int = 1,
        concat_params: bool = False,
        uses_grad: bool = False,
        uses_loss: bool = False,
        inner: "Chainable | None" = None,
    ):
        super().__init__(defaults, update_freq=update_freq, inner=inner)

        self._concat_params = concat_params
        self._uses_grad = uses_grad
        self._uses_loss = uses_loss

    # ------------------------------- single tensor ------------------------------ #
    def single_tensor_initialize(
        self,
        tensor: torch.Tensor,
        param: torch.Tensor,
        grad: torch.Tensor | None,
        loss: torch.Tensor | None,
        state: dict[str, Any],
        setting: Mapping[str, Any],
    ) -> None:
        """initialize ``state`` before first ``update``.
        """

    def single_tensor_update(
        self,
        tensor: torch.Tensor,
        param: torch.Tensor,
        grad: torch.Tensor | None,
        loss: torch.Tensor | None,
        state: dict[str, Any],
        setting: Mapping[str, Any],
    ) -> None:
        """Updates ``state``. This should not modify ``tensor``.
        """

    def single_tensor_apply(
        self,
        tensor: torch.Tensor,
        param: torch.Tensor,
        grad: torch.Tensor | None,
        loss: torch.Tensor | None,
        state: dict[str, Any],
        setting: Mapping[str, Any],
    ) -> torch.Tensor:
        """Updates ``tensor`` and returns it. This shouldn't modify ``state`` if possible.
        """
        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `single_tensor_apply`.")

    # ------------------------------- multi tensor ------------------------------- #
    def multi_tensor_initialize(
        self,
        tensors: list[torch.Tensor],
        params: list[torch.Tensor],
        grads: list[torch.Tensor] | None,
        loss: torch.Tensor | None,
        states: list[dict[str, Any]],
        settings: Sequence[Mapping[str, Any]],
    ) -> None:
        """initialize ``states`` before first ``update``.
        By default calls ``single_tensor_initialize`` on all tensors.
        """
        if grads is None:
            grads = cast(list, [None] * len(tensors))

        for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
            self.single_tensor_initialize(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)

    def multi_tensor_update(
        self,
        tensors: list[torch.Tensor],
        params: list[torch.Tensor],
        grads: list[torch.Tensor] | None,
        loss: torch.Tensor | None,
        states: list[dict[str, Any]],
        settings: Sequence[Mapping[str, Any]],
    ) -> None:
        """Updates ``states``. This should not modify ``tensor``.
        By default calls ``single_tensor_update`` on all tensors.
        """

        if grads is None:
            grads = cast(list, [None] * len(tensors))

        for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
            self.single_tensor_update(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)

    def multi_tensor_apply(
        self,
        tensors: list[torch.Tensor],
        params: list[torch.Tensor],
        grads: list[torch.Tensor] | None,
        loss: torch.Tensor | None,
        states: list[dict[str, Any]],
        settings: Sequence[Mapping[str, Any]],
    ) -> Sequence[torch.Tensor]:
        """Updates ``tensors`` and returns it. This shouldn't modify ``state`` if possible.
         By default calls ``single_tensor_apply`` on all tensors.
         """

        if grads is None:
            grads = cast(list, [None] * len(tensors))

        ret = []
        for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
            u = self.single_tensor_apply(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)
            ret.append(u)

        return ret

    def _get_grads_loss(self, objective: "Objective"):
        """evaluates grads and loss only if needed"""

        if self._uses_grad: grads = objective.get_grads()
        else: grads = None # better explicitly set to None rather than objective.grads because it shouldn't be used

        if self._uses_loss: loss = objective.get_loss(backward=True)
        else: loss = None

        return grads, loss

    @torch.no_grad
    def _get_cat_updates_params_grads(self, objective: "Objective", grads: list[torch.Tensor] | None):
        assert self._concat_params

        cat_updates = [torch.cat([u.ravel() for u in objective.get_updates()])]
        cat_params = [torch.cat([p.ravel() for p in objective.params])]

        if grads is None: cat_grads = None
        else: cat_grads = [torch.cat([g.ravel() for g in grads])]

        return cat_updates, cat_params, cat_grads

    def _gather_tensors(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]):
        """returns everything for ``multi_tensor_*``. Concatenates if ```self._concat_params``.
        evaluates grads and loss if ``self._uses_grad`` and ``self._uses_loss``"""

        # evaluate grads and loss if `self._uses_grad` and `self._uses_loss`
        grads, loss = self._get_grads_loss(objective)

        # gather all things
        # concatenate everything to a vec if `self._concat_params`
        if self._concat_params:
            tensors, params, grads = self._get_cat_updates_params_grads(objective, grads)
            states = [states[0]]; settings = [settings[0]]

        # or take original values
        else:
            tensors=objective.get_updates()
            params = objective.params

        return tensors, params, grads, loss, states, settings

    @final
    def update_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None:
        tensors, params, grads, loss, states, settings = self._gather_tensors(objective, states, settings)

        # initialize before the first update
        num_updates = self.increment_counter("__num_updates", 0)
        if num_updates == 0:
            self.multi_tensor_initialize(
                tensors=tensors,
                params=params,
                grads=grads,
                loss=loss,
                states=states,
                settings=settings
            )

        # update
        self.multi_tensor_update(
            tensors=tensors,
            params=params,
            grads=grads,
            loss=loss,
            states=states,
            settings=settings
        )

    @final
    def apply_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> "Objective":
        tensors, params, grads, loss, states, settings = self._gather_tensors(objective, states, settings)
        # note: _gather_tensors will re-cat again if `_concat_params`, this is necessary because objective
        # may have been modified in functional logic, there is no way to know if that happened

        # apply
        ret = self.multi_tensor_apply(
            tensors=tensors,
            params=params,
            grads=grads,
            loss=loss,
            states=states,
            settings=settings
        )

        # uncat if needed and set objective.updates and return objective
        if self._concat_params:
            objective.updates = vec_to_tensors(ret[0], objective.params)

        else:
            objective.updates = list(ret)

        return objective


    # make sure _concat_params, _uses_grad and _uses_loss are saved in `state_dict`
    def _extra_pack(self):
        return {
            "__concat_params": self._concat_params,
            "__uses_grad": self._uses_grad,
            "__uses_loss": self._uses_loss,
        }

    def _extra_unpack(self, d):
        self._concat_params = d["__concat_params"]
        self._uses_grad = d["__uses_grad"]
        self._uses_loss = d["__uses_loss"]

multi_tensor_apply

multi_tensor_apply(tensors: list[Tensor], params: list[Tensor], grads: list[Tensor] | None, loss: Tensor | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> Sequence[Tensor]

Updates tensors and returns them. This shouldn't modify state if possible. By default calls single_tensor_apply on all tensors.

Source code in torchzero/core/transform.py
def multi_tensor_apply(
    self,
    tensors: list[torch.Tensor],
    params: list[torch.Tensor],
    grads: list[torch.Tensor] | None,
    loss: torch.Tensor | None,
    states: list[dict[str, Any]],
    settings: Sequence[Mapping[str, Any]],
) -> Sequence[torch.Tensor]:
    """Updates ``tensors`` and returns it. This shouldn't modify ``state`` if possible.
     By default calls ``single_tensor_apply`` on all tensors.
     """

    if grads is None:
        grads = cast(list, [None] * len(tensors))

    ret = []
    for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
        u = self.single_tensor_apply(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)
        ret.append(u)

    return ret
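
Subclasses that can process the whole tensor list at once may override multi_tensor_apply directly instead of the per-tensor hook, which makes it possible to use the fused torch._foreach_* kernels. A minimal sketch; the ForeachScale name, its factor setting, and the assumption that the constructor accepts a defaults dict (as Module does) are illustrative, not part of the library:

import torch

class ForeachScale(TensorTransform):
    """Illustrative transform that scales every update tensor by a constant factor."""
    def __init__(self, factor: float = 0.1):
        # assumption: the base constructor accepts a `defaults` dict like `Module` does
        super().__init__(defaults=dict(factor=factor))

    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # one fused kernel over the whole list instead of a Python loop per tensor
        return torch._foreach_mul(tensors, settings[0]["factor"])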

multi_tensor_initialize

multi_tensor_initialize(tensors: list[Tensor], params: list[Tensor], grads: list[Tensor] | None, loss: Tensor | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None

Initializes states before the first update. By default calls single_tensor_initialize on all tensors.

Source code in torchzero/core/transform.py
def multi_tensor_initialize(
    self,
    tensors: list[torch.Tensor],
    params: list[torch.Tensor],
    grads: list[torch.Tensor] | None,
    loss: torch.Tensor | None,
    states: list[dict[str, Any]],
    settings: Sequence[Mapping[str, Any]],
) -> None:
    """initialize ``states`` before first ``update``.
    By default calls ``single_tensor_initialize`` on all tensors.
    """
    if grads is None:
        grads = cast(list, [None] * len(tensors))

    for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
        self.single_tensor_initialize(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)

multi_tensor_update

multi_tensor_update(tensors: list[Tensor], params: list[Tensor], grads: list[Tensor] | None, loss: Tensor | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None

Updates states. This should not modify tensors. By default calls single_tensor_update on all tensors.

Source code in torchzero/core/transform.py
def multi_tensor_update(
    self,
    tensors: list[torch.Tensor],
    params: list[torch.Tensor],
    grads: list[torch.Tensor] | None,
    loss: torch.Tensor | None,
    states: list[dict[str, Any]],
    settings: Sequence[Mapping[str, Any]],
) -> None:
    """Updates ``states``. This should not modify ``tensor``.
    By default calls ``single_tensor_update`` on all tensors.
    """

    if grads is None:
        grads = cast(list, [None] * len(tensors))

    for tensor, param, grad, state, setting in zip(tensors, params, grads, states, settings):
        self.single_tensor_update(tensor=tensor, param=param, grad=grad, loss=loss, state=state, setting=setting)

single_tensor_apply

single_tensor_apply(tensor: Tensor, param: Tensor, grad: Tensor | None, loss: Tensor | None, state: dict[str, Any], setting: Mapping[str, Any]) -> Tensor

Updates tensor and returns it. This shouldn't modify state if possible.

Source code in torchzero/core/transform.py
def single_tensor_apply(
    self,
    tensor: torch.Tensor,
    param: torch.Tensor,
    grad: torch.Tensor | None,
    loss: torch.Tensor | None,
    state: dict[str, Any],
    setting: Mapping[str, Any],
) -> torch.Tensor:
    """Updates ``tensor`` and returns it. This shouldn't modify ``state`` if possible.
    """
    raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `single_tensor_apply`.")

single_tensor_initialize

single_tensor_initialize(tensor: Tensor, param: Tensor, grad: Tensor | None, loss: Tensor | None, state: dict[str, Any], setting: Mapping[str, Any]) -> None

Initializes state before the first update.

Source code in torchzero/core/transform.py
def single_tensor_initialize(
    self,
    tensor: torch.Tensor,
    param: torch.Tensor,
    grad: torch.Tensor | None,
    loss: torch.Tensor | None,
    state: dict[str, Any],
    setting: Mapping[str, Any],
) -> None:
    """initialize ``state`` before first ``update``.
    """

single_tensor_update

single_tensor_update(tensor: Tensor, param: Tensor, grad: Tensor | None, loss: Tensor | None, state: dict[str, Any], setting: Mapping[str, Any]) -> None

Updates state. This should not modify tensor.

Source code in torchzero/core/transform.py
def single_tensor_update(
    self,
    tensor: torch.Tensor,
    param: torch.Tensor,
    grad: torch.Tensor | None,
    loss: torch.Tensor | None,
    state: dict[str, Any],
    setting: Mapping[str, Any],
) -> None:
    """Updates ``state``. This should not modify ``tensor``.
    """

Transform

Bases: torchzero.core.module.Module

Transform is a Module with only optional children.

Transform is more flexible in that, as long as there are no children, it can use a custom list of states and settings instead of self.state and self.settings.

To use, subclass this and override update_states and apply_states.

Methods:

  • apply_states

    Updates objective using states.

  • update_states

    Updates states. This should not modify objective.update.

Source code in torchzero/core/transform.py
class Transform(Module):
    """``Transform`` is a ``Module`` with only optional children.

    ``Transform`` is more flexible in that, as long as there are no children, it can use a custom list of states
    and settings instead of ``self.state`` and ``self.settings``.

    To use, subclass this and override ``update_states`` and ``apply_states``.
    """
    def __init__(self, defaults: dict[str, Any] | None = None, update_freq: int = 1, inner: "Chainable | None" = None):

        # store update_freq in defaults so that it is scheduleable
        if defaults is None: defaults = {}
        safe_dict_update_(defaults, {"__update_freq": update_freq})

        super().__init__(defaults)

        self._objective = None
        if inner is not None:
            self.set_child("inner", inner)

    # settings shouldn't mutate, so they are typed as Sequence[Mapping]
    def update_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None:
        """Updates ``states``. This should not modify ``objective.update``."""

    @abstractmethod
    def apply_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> "Objective":
        """Updates ``objective`` using ``states``."""

    def _get_states_settings(self, objective: "Objective") -> tuple[list, tuple]:
        # itemgetter is faster,
        # but when there is a single param it returns the value itself rather than a tuple, so that case is handled below
        getter = itemgetter(*objective.params)
        is_single = len(objective.params) == 1
        states = getter(self.state)
        settings = getter(self.settings)

        if is_single:
            states = [states, ]
            settings = (settings, )

        else:
            states = list(states) # itemgetter returns tuple

        return states, settings

    @final
    def update(self, objective:"Objective"):
        step = self.increment_counter("__step", 0)

        if step % self.settings[objective.params[0]]["__update_freq"] == 0:
            states, settings = self._get_states_settings(objective)
            self.update_states(objective=objective, states=states, settings=settings)

    @final
    def apply(self, objective: "Objective"):

        # inner step
        if "inner" in self.children:
            inner = self.children["inner"]
            objective = inner.step(objective)

        # apply and return
        states, settings = self._get_states_settings(objective)
        return self.apply_states(objective=objective, states=states, settings=settings)

apply_states

apply_states(objective: Objective, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> Objective

Updates objective using states.

Source code in torchzero/core/transform.py
@abstractmethod
def apply_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> "Objective":
    """Updates ``objective`` using ``states``."""

update_states

update_states(objective: Objective, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None

Updates states. This should not modify objective.update.

Source code in torchzero/core/transform.py
def update_states(self, objective: "Objective", states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> None:
    """Updates ``states``. This should not modify ``objective.update``."""

maybe_chain

maybe_chain(*modules: Module | Sequence[Module]) -> Module

Returns a single module directly if only one is provided, otherwise wraps them in a Chain.

Source code in torchzero/core/chain.py
def maybe_chain(*modules: Chainable) -> Module:
    """Returns a single module directly if only one is provided, otherwise wraps them in a ``Chain``."""
    flat_modules: list[Module] = flatten(modules)
    if len(flat_modules) == 1:
        return flat_modules[0]
    return Chain(*flat_modules)
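
For example, reusing the illustrative SignUpdate transform sketched above:

single = maybe_chain(SignUpdate())                   # a single module is returned unchanged
chained = maybe_chain(SignUpdate(), SignUpdate())    # multiple modules are wrapped in a Chain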

step

step(objective: Objective, modules: Module | Sequence[Module])

Steps with the given modules on objective. Note: this does not apply hooks.

Source code in torchzero/core/functional.py
def step(objective: "Objective", modules: "Module | Sequence[Module]"):
    """doesn't apply hooks!"""
    if not isinstance(modules, Sequence):
        modules = (modules, )

    if len(modules) == 0:
        raise RuntimeError("`modules` is an empty sequence")

    # if closure is None, assume backward has been called and gather grads
    if objective.closure is None:
        objective.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in objective.params]

    # step and return
    return _chain_step(objective, modules)
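
A hedged usage sketch; it assumes objective is an Objective constructed elsewhere (building one is not covered on this page) and reuses the illustrative transforms from above:

def run_one_step(objective):
    # if `objective.closure` is None, `step` gathers `p.grad` from the parameters,
    # so `backward()` must already have been called on the loss
    return step(objective, [EMAUpdate(), SignUpdate()])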