speechbrain.utils.checkpoints 模块

此模块实现了检查点保存器和加载器。

实验中的检查点通常需要保存许多不同内容的状态:模型参数、优化器参数、当前 epoch 等。检查点的保存格式是一个目录,其中每个独立的可保存项都有自己的文件。此外,一个特殊文件保存了关于检查点的元信息(默认仅为创建时间,但你可以指定你想要的任何其他信息,例如验证损失)。

检查点系统的接口要求你指定要保存的内容。这种方法很灵活,并且与你的实验实际运行方式无关。

接口要求你为每个要保存的内容指定名称。该名称用于在恢复时将正确的参数文件分配给正确的对象。

默认的保存和加载方法仅针对 torch.nn.Modules(及其子类)和 torch.optim.Optimizers 添加。如果这些方法不适用于你的对象,你可以为特定实例或类指定自己的保存和/或加载方法。

示例

>>> # Toy example Module:
>>> class Recoverable(torch.nn.Module):
...     def __init__(self, param):
...         super().__init__()
...         self.param = torch.nn.Parameter(torch.tensor([param]))
...     def forward(self, x):
...         return x * self.param
>>> model = Recoverable(1.)
>>> tempdir = getfixture('tmpdir')
>>> # In simple cases, the module aims to have a terse syntax,
>>> # consisting of three steps.
>>> # 1. Specifying where to save checkpoints and what is included in a
>>> # checkpoint:
>>> checkpointer = Checkpointer(tempdir, {"network": model})
>>> # 2. Recover from the latest checkpoint, if one is found:
>>> checkpointer.recover_if_possible()
>>> # Run your experiment:
>>> data = [(0.1, 0.9), (0.3, 0.8)]
>>> for example, target in data:
...     loss = (model(example) - target)**2
...     # 3. Save checkpoints, and keep by default just one, the newest:
...     ckpt = checkpointer.save_and_keep_only()
作者
  • Aku Rouhe 2020

  • Adel Moumen 2024

摘要

Checkpoint

描述一个已保存检查点的 NamedTuple

Checkpointer

保存检查点并从中恢复。

函数

average_checkpoints

平均多个检查点中的参数。

average_state_dicts

从 state_dicts 的迭代器生成平均 state_dict。

ckpt_recency

检查点重要性指标:新近度。

get_default_hook

查找用于给定对象的默认保存/加载钩子。

hook_on_loading_state_dict_checkpoint

加载 state_dict 检查点时调用的钩子。

map_old_state_dict_weights

根据提供的映射,映射旧状态字典中的键。

mark_as_loader

方法装饰器,标记给定方法为检查点加载钩子。

mark_as_saver

方法装饰器,标记给定方法为检查点保存钩子。

mark_as_transfer

方法装饰器,标记给定方法为参数迁移钩子。

register_checkpoint_hooks

类装饰器,注册加载、保存和迁移钩子。

torch_parameter_transfer

非严格的 Torch Module state_dict 加载。

torch_patched_state_dict_load

使用 torch.load() 从给定路径加载 state_dict,并调用 SpeechBrain state_dict 加载钩子,例如应用键名修补规则以实现兼容性。

torch_recovery

立即从给定路径加载 torch.nn.Module state_dict。

torch_save

将对象的参数保存到路径。

参考

speechbrain.utils.checkpoints.map_old_state_dict_weights(state_dict: Dict[str, Tensor], mapping: Dict[str, str]) Dict[str, Tensor][源码]

根据提供的映射,映射旧状态字典中的键。

注意:此函数将重新映射包含旧键的所有 state_dict 键。例如,如果 state_dict 是 {‘model.encoder.layer.0.atn.self.query.weight’: …} 并且映射是 {‘.atn’: ‘.attn’},则结果 state_dict 将是 {‘model.encoder.layer.0.attn.self.query.weight’: …}。

由于这实际上是批量子串替换,因此部分键匹配(例如在某个层名称中间)也会生效,所以要小心避免误报。

参数:
  • state_dict (dict) – 要映射的旧状态字典。

  • mapping (dict) – 指定旧键和新键之间映射的字典。

返回:

修改后的带有映射键的状态字典。

返回类型:

dict

speechbrain.utils.checkpoints.hook_on_loading_state_dict_checkpoint(state_dict: Dict[str, Tensor]) Dict[str, Tensor][源码]

加载 state_dict 检查点时调用的钩子。

此钩子在加载 state_dict 检查点时调用。可用于在加载到模型之前修改 state_dict。

默认情况下,此钩子会将旧的 state_dict 键映射到新的键。

参数:

state_dict (dict) – 要加载的 state_dict。

返回:

修改后的 state_dict。

返回类型:

dict

speechbrain.utils.checkpoints.torch_recovery(obj, path, end_of_epoch)[源码]

立即从给定路径加载 torch.nn.Module state_dict。

这可以设置为 torch.nn.Modules 的默认值:>>> DEFAULT_LOAD_HOOKS[torch.nn.Module] = torch_recovery

参数:
  • obj (torch.nn.Module) – 要加载参数的实例。

  • path (str, pathlib.Path) – 加载路径。

  • end_of_epoch (bool) – 恢复是否来自 epoch 结束检查点。

speechbrain.utils.checkpoints.torch_patched_state_dict_load(path, device='cpu')[源码]

使用 torch.load() 从给定路径加载 state_dict,并调用 SpeechBrain state_dict 加载钩子,例如应用键名修补规则以实现兼容性。

state_dict 不会进一步预处理,也不会应用到模型中,请参阅 torch_recovery()torch_parameter_transfer()

参数:
  • path (str, pathlib.Path) – 加载路径。

  • device (str) – 加载的 state_dict 张量应驻留的设备。这被转发到 torch.load();详情请参阅其文档。

返回类型:

加载的状态字典。

speechbrain.utils.checkpoints.torch_save(obj, path)[源码]

将对象的参数保存到路径。

torch.nn.Modules 的默认保存钩子。用于保存 torch.nn.Module state_dicts。

参数:
  • obj (torch.nn.Module) – 要保存的实例。

  • path (str, pathlib.Path) – 保存路径。

speechbrain.utils.checkpoints.torch_parameter_transfer(obj, path)[源码]

非严格的 Torch Module state_dict 加载。

从 path 加载一组参数到 obj。如果 obj 中某些层的参数未找到,则仅记录警告。path 中某些参数在 obj 中未找到对应层,也同样记录警告。

参数:
  • obj (torch.nn.Module) – 要加载参数的实例。

  • path (str) – 加载路径。

speechbrain.utils.checkpoints.mark_as_saver(method)[源码]

方法装饰器,标记给定方法为检查点保存钩子。

示例请参阅 register_checkpoint_hooks。

参数:

method (callable) – 要装饰的类方法。必须可以使用位置参数签名 (instance, path) 调用。例如:def saver(self, path) 满足此条件。

返回类型:

已装饰的方法,标记为检查点保存器。

注意

这不会添加钩子(通过方法装饰器不可能),你还必须使用 @register_checkpoint_hooks 装饰类。只能添加一个方法作为钩子。

speechbrain.utils.checkpoints.mark_as_loader(method)[源码]

方法装饰器,标记给定方法为检查点加载钩子。

参数:

method (callable) – 要装饰的类方法。必须可以使用位置参数签名 (instance, path, end_of_epoch) 调用。例如:def loader(self, path, end_of_epoch): 满足此条件。

返回类型:

已装饰的方法,注册为检查点加载器。

注意

这不会添加钩子(通过方法装饰器不可能),你还必须使用 @register_checkpoint_hooks 装饰类。只能添加一个方法作为钩子。

speechbrain.utils.checkpoints.mark_as_transfer(method)[源码]

方法装饰器,标记给定方法为参数迁移钩子。

参数:

method (callable) – 要装饰的类方法。必须可以使用位置参数签名 (instance, path) 调用。例如:def loader(self, path): 满足此条件。

返回类型:

已装饰的方法,注册为迁移方法。

注意

这不会添加钩子(通过方法装饰器不可能),你还必须使用 @register_checkpoint_hooks 装饰类。只能添加一个方法作为钩子。

注意

Pretrainer 会优先使用迁移钩子而不是加载钩子。然而,如果没有注册迁移钩子,Pretrainer 将使用加载钩子。

speechbrain.utils.checkpoints.register_checkpoint_hooks(cls, save_on_main_only=True)[源码]

类装饰器,注册加载、保存和迁移钩子。

钩子必须已用 mark_as_loader 和 mark_as_saver 标记,并可能已用 mark_as_transfer 标记。

参数:
  • cls (class) – 要装饰的类

  • save_on_main_only (bool) – 默认情况下,保存器仅在单个进程上运行。此参数提供了在所有进程上运行保存器的选项,这对于一些保存器是必需的,这些保存器需要在保存之前收集数据。

返回类型:

已注册钩子的装饰类

示例

>>> @register_checkpoint_hooks
... class CustomRecoverable:
...     def __init__(self, param):
...         self.param = int(param)
...
...     @mark_as_saver
...     def save(self, path):
...         with open(path, "w", encoding="utf-8") as fo:
...             fo.write(str(self.param))
...
...     @mark_as_loader
...     def load(self, path, end_of_epoch):
...         del end_of_epoch  # Unused here
...         with open(path, encoding="utf-8") as fi:
...             self.param = int(fi.read())
speechbrain.utils.checkpoints.get_default_hook(obj, default_hooks)[源码]

查找用于给定对象的默认保存/加载钩子。

遵循方法解析顺序(Method Resolution Order),即如果对象的类本身没有注册钩子,也会搜索对象继承的类。

参数:
  • obj (instance) – 类的实例。

  • default_hooks (dict) – 从类到(检查点钩子)函数的映射。

返回类型:

正确的方法,如果未注册任何方法则返回 None。

示例

>>> a = torch.nn.Module()
>>> get_default_hook(a, DEFAULT_SAVE_HOOKS) == torch_save
True
class speechbrain.utils.checkpoints.Checkpoint(path, meta, paramfiles)

基类:tuple

描述一个已保存检查点的 NamedTuple

要从众多检查点中选择一个加载,首先根据此命名元组过滤和排序检查点。Checkpointers 在 path 中放置 pathlib.Path,在 meta 中放置 dict。你可以在保存检查点时向 meta 添加任何你想要的任何信息,例如验证损失。meta 中唯一的默认键是 "unixtime"。Checkpoint.paramfiles 是从可恢复名称到参数文件路径的字典。

meta

字段编号 1 的别名

paramfiles

字段编号 2 的别名

path

字段编号 0 的别名

speechbrain.utils.checkpoints.ckpt_recency(ckpt)[源码]

检查点重要性指标:新近度。

此函数也可以作为一个示例,说明如何创建检查点重要性键函数。这是一个命名函数,但如你所见,在紧急情况下它也可以轻松地实现为 lambda 函数。

class speechbrain.utils.checkpoints.Checkpointer(checkpoints_dir, recoverables=None, custom_load_hooks=None, custom_save_hooks=None, allow_partial_load=False)[源码]

基类:object

保存检查点并从中恢复。

参数:
  • checkpoints_dir (str, pathlib.Path) – 检查点保存目录的路径。

  • recoverables (mapping, optional) – 要恢复的对象。它们需要一个(唯一的)名称:这用于将检查点中的参数连接到正确的可恢复项。该名称也用于对象的保存文件的文件名。这些对象也可以通过 add_recoverable 或 add_recoverables 添加,或者直接修改 checkpointer.recoverables。

  • custom_load_hooks (mapping, optional) – 从名称(与 recoverables 中的名称相同)到函数或方法的映射。为特定对象设置自定义加载钩子。函数/方法必须可以使用位置参数签名 (instance, path) 调用。例如:def loader(self, path) 满足此条件。

  • custom_save_hooks (mapping, optional) – 从名称(与 recoverables 中的名称相同)到函数或方法的映射。为特定对象设置自定义保存钩子。函数/方法必须可以使用位置参数签名 (instance, path) 调用。例如:def saver(self, path) 满足此条件。

  • allow_partial_load (bool, optional) – 如果为 True,则允许加载检查点时,即使并非所有已注册的可恢复项都找到保存文件。在这种情况下,只加载找到的保存文件。如果为 False,加载此类检查点将引发 RuntimeError。(default: False)

示例

>>> import torch
>>> #SETUP:
>>> tempdir = getfixture('tmpdir')
>>> class Recoverable(torch.nn.Module):
...     def __init__(self, param):
...         super().__init__()
...         self.param = torch.nn.Parameter(torch.tensor([param]))
...     def forward(self, x):
...         return x * self.param
>>> recoverable = Recoverable(1.)
>>> recoverables = {'recoverable': recoverable}
>>> # SETUP DONE.
>>> checkpointer = Checkpointer(tempdir, recoverables)
>>> first_ckpt = checkpointer.save_checkpoint()
>>> recoverable.param.data = torch.tensor([2.])
>>> loaded_ckpt = checkpointer.recover_if_possible()
>>> # Parameter has been loaded:
>>> assert recoverable.param.data == torch.tensor([1.])
>>> # With this call, by default, oldest checkpoints are deleted:
>>> checkpointer.save_and_keep_only()
>>> assert first_ckpt not in checkpointer.list_checkpoints()
add_recoverable(name, obj, custom_load_hook=None, custom_save_hook=None, optional_load=False)[源码]

注册一个可恢复项,并可能带有自定义钩子。

参数:
  • name (str) – 可恢复项的唯一名称。用于将保存文件映射到对象。

  • obj (instance) – 要恢复的对象。

  • custom_load_hook (callable, optional) – 用于加载对象保存文件时调用。函数/方法必须可以使用位置参数签名 (instance, path) 调用。例如:def load(self, path) 满足此条件。

  • custom_save_hook (callable, optional) – 用于保存对象参数时调用。函数/方法必须可以使用位置参数签名 (instance, path) 调用。例如:def saver(self, path) 满足此条件。

  • optional_load (bool, optional) – 如果为 True,则允许从检查点可选地加载对象。如果检查点缺少指定对象,则不引发错误。这在不同训练配置之间的过渡中特别有用,例如将精度从浮点 32 更改为 16。例如,假设你有一个训练检查点,其中不包含 scaler 对象。如果你打算继续使用浮点 16 进行预训练,其中需要 scaler 对象,将其标记为可选可防止加载错误。如果不将其标记为可选,尝试从在浮点 32 中训练的检查点加载 scaler 对象将失败,因为该检查点中不存在 scaler 对象。

add_recoverables(recoverables)[源码]

从给定映射更新 recoverables 字典。

参数:

recoverables (mapping) – 要恢复的对象。它们需要一个(唯一的)名称:这用于将检查点中的参数连接到正确的可恢复项。该名称也用于对象的参数保存文件的文件名。

save_checkpoint(meta={}, end_of_epoch=True, name=None, verbosity=20)[源码]

保存检查点。

整个检查点成为一个目录。将每个已注册对象的参数保存到单独的文件中。还会添加一个元文件。默认情况下,元文件只包含 unixtime(自 Unix 纪元以来的秒数),但你可以自己添加任何相关信息。元信息稍后用于选择要加载的检查点。

end_of_epoch 的值会保存到 meta 中。这会影响 epoch 计数器和数据集迭代器如何加载其状态。

对于多进程保存,在某些情况下我们可能需要在多个进程上运行保存代码(例如 FSDP,我们需要在保存前收集参数)。这通过在主进程上创建一个保存文件夹并将其通信给所有进程,然后让每个保存器/加载器方法控制是否在单个进程还是所有进程上保存来实现。

参数:
  • meta (mapping, optional) – 添加到检查点元文件中的映射。默认包含键 "unixtime"。

  • end_of_epoch (bool, optional) – 检查点是否在 epoch 结束时。默认为 True。可能会影响加载。

  • name (str, optional) – 为你的检查点指定自定义名称。名称仍会添加前缀。如果未给定名称,则从时间戳和随机唯一 ID 创建名称。

  • verbosity (logging level) – 设置此保存操作的日志级别。

返回:

namedtuple [参见上文],已保存的检查点,除非在非主进程上运行,在这种情况下返回 None。

返回类型:

Checkpoint

save_and_keep_only(meta={}, end_of_epoch=True, name=None, num_to_keep=1, keep_recent=True, importance_keys=[], max_keys=[], min_keys=[], ckpt_predicate=None, verbosity=20)[源码]

保存检查点,然后删除最不重要的检查点。

这实质上是将 save_checkpoint()delete_checkpoints() 合并到一个调用中,提供了简洁的语法。

参数:
  • meta (mapping, optional) – 添加到检查点元文件中的映射。默认包含键 "unixtime"。

  • end_of_epoch (bool, optional) – 检查点是否在 epoch 结束时。默认为 True。可能会影响加载。

  • name (str, optional) – 为你的检查点指定自定义名称。名称仍会添加前缀。如果未给定名称,则从时间戳和随机唯一 ID 创建名称。

  • num_to_keep (int, optional) – 要保留的检查点数量。默认为 1。这将删除过滤后剩余的所有检查点。必须 >=0。

  • keep_recent (bool, optional) – 是否保留最近的 num_to_keep 个检查点。

  • importance_keys (list, optional) – 用于排序的键函数列表(参见内置函数 sorted)。每个可调用对象定义一个排序顺序,并为每个可调用对象保留 num_to_keep 个检查点。将保留键值最高的检查点。这些函数会传入 Checkpoint 命名元组(参见上文)。

  • max_keys (list, optional) – 将保留该键值*最高*的检查点列表。

  • min_keys (list, optional) – 将保留该键值*最低*的检查点列表。

  • ckpt_predicate (callable, optional) – 用此参数排除某些检查点不被删除。在任何排序之前,检查点列表会用此谓词进行过滤。只有 ckpt_predicate 为 True 的检查点才可能被删除。该函数使用 Checkpoint 命名元组(参见上文)调用。

  • verbosity (int) – 日志级别,默认为 logging.INFO

注意

与 save_checkpoint 不同,此函数不返回任何内容,因为我们无法保证保存的检查点实际上不会被删除。

find_checkpoint(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None)[源码]

从所有可用检查点中选择一个特定检查点。

如果 importance_key, max_keymin_key 都没有使用,则返回最新的检查点。这些参数中最多只能使用一个。

大多数功能实际上在 find_checkpoints() 中实现,但此函数作为有用的接口保留。

参数:
  • importance_key (callable, optional) – 用于排序的键函数。将选择返回值最高的检查点。该函数会使用 Checkpoint 命名元组调用。

  • max_key (str, optional) – 将返回该键值最高的检查点。只考虑包含此键的检查点!

  • min_key (str, optional) – 将返回该键值最低的检查点。只考虑包含此键的检查点!

  • ckpt_predicate (callable, optional) – 在排序之前,检查点列表会用此谓词进行过滤。参见内置函数 filter。该函数会使用 Checkpoint 命名元组(参见上文)调用。默认情况下,考虑所有检查点。

返回:

  • Checkpoint – 如果找到。

  • None – 如果过滤后没有检查点存在/剩余。

speechbrain.utils.checkpoints.find_checkpoints(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None, max_num_checkpoints=None)[源码]

选择多个检查点。

如果 importance_key, max_keymin_key 都没有使用,则返回最新的检查点。这些参数中最多只能使用一个。

参数:
  • importance_key (callable, optional) – 用于排序的键函数。将选择返回值最高的检查点。该函数会使用 Checkpoint 命名元组调用。

  • max_key (str, optional) – 将返回该键值最高的检查点。只考虑包含此键的检查点!

  • min_key (str, optional) – 将返回该键值最低的检查点。只考虑包含此键的检查点!

  • ckpt_predicate (callable, optional) – 在排序之前,检查点列表会用此谓词进行过滤。参见内置函数 filter。该函数会使用 Checkpoint 命名元组(参见上文)调用。默认情况下,考虑所有检查点。

  • max_num_checkpoints (int, None) – 要返回的最大检查点数量,或 None 返回所有找到的检查点。

返回:

包含最多指定数量 Checkpoint 的列表。

返回类型:

list

recover_if_possible(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None)[源码]

选择一个检查点并从该检查点恢复,如果找到了检查点。

如果未找到检查点,则不运行恢复。

如果 importance_key, max_keymin_key 都没有使用,则返回最新的检查点。这些参数中最多只能使用一个。

参数:
  • importance_key (callable, optional) – 用于排序的键函数。将加载返回值最高的检查点。该函数会使用 Checkpoint 命名元组调用。

  • max_key (str, optional) – 将加载该键值最高的检查点。只考虑包含此键的检查点!

  • min_key (str, optional) – 将加载该键值最低的检查点。只考虑包含此键的检查点!

  • ckpt_predicate (callable, optional) – 在排序之前,检查点列表会用此谓词进行过滤。参见内置函数 filter。该函数会使用 Checkpoint 命名元组(参见上文)调用。默认情况下,考虑所有检查点。

返回:

  • Checkpoint – 如果找到。

  • None – 如果过滤后没有检查点存在/剩余。

load_checkpoint(checkpoint)[源码]

加载指定的检查点。

参数:

checkpoint (Checkpoint) – 要加载的检查点。

list_checkpoints()[源码]

列出检查点目录中的所有检查点。

返回:

Checkpoint 命名元组列表(参见上文)。

返回类型:

list

delete_checkpoints(*, num_to_keep=1, min_keys=None, max_keys=None, importance_keys=[<function ckpt_recency>], ckpt_predicate=None, verbosity=20)[源码]

删除最不重要的检查点。

由于定义重要性有多种方式(例如最低 WER、最低损失),用户应提供一个排序键函数列表,每个函数定义一个特定的重要性顺序。本质上,每个重要性键函数提取一个重要性指标(越高越重要)。对于这些顺序中的每一个,都会保留 num_to_keep 个检查点。但是,如果不同顺序保留的检查点之间存在重叠,则不会保留额外的检查点,因此保留的总检查点数量可能小于

num_to_keep * len(importance_keys)
参数:
  • num_to_keep (int, optional) – 要保留的检查点数量。默认为 10。你可以选择保留 0 个。这将删除过滤后剩余的所有检查点。必须 >=0

  • min_keys (list, optional) – 表示 meta 中键的字符串列表。将保留这些键中值最低的检查点,最多 num_to_keep 个。

  • max_keys (list, optional) – 表示 meta 中键的字符串列表。将保留这些键中值最高的检查点,最多 num_to_keep 个。

  • importance_keys (list, optional) – 用于排序的键函数列表(参见内置函数 sorted)。每个可调用函数定义一个排序顺序,并为每个可调用函数保留 num_to_keep 个检查点。明确地说,将保留键值最高的检查点。这些函数使用 Checkpoint 命名元组(参见上文)调用。另请参阅默认值 (ckpt_recency,上文)。默认设置会删除除最新检查点外的所有检查点。

  • ckpt_predicate (callable, optional) – 用此参数排除某些检查点不被删除。在任何排序之前,检查点列表会用此谓词进行过滤。只有 ckpt_predicate 为 True 的检查点才可能被删除。该函数使用 Checkpoint 命名元组(参见上文)调用。

  • verbosity (logging level) – 设置此删除操作的日志级别。

注意

必须使用关键字参数调用,作为你知道自己在做什么的标志。删除是永久性的。

speechbrain.utils.checkpoints.average_state_dicts(state_dicts)[源码]

从 state_dicts 的迭代器生成平均 state_dict。

请注意,在任何时候,这都会在内存中保留两个 state_dicts,这是最低内存需求。

参数:

state_dicts (iterator, list) – 要平均的 state_dicts。

返回:

平均后的 state_dict。

返回类型:

state_dict

speechbrain.utils.checkpoints.average_checkpoints(checkpoint_list, recoverable_name, parameter_loader=<function load>, averager=<function average_state_dicts>)[源码]

平均多个检查点中的参数。

使用 Checkpointer.find_checkpoints() 获取要平均的检查点列表。研究表明,对训练中最后几个检查点的参数进行平均有时可以提高性能。

默认的加载器和平均器适用于标准的 PyTorch 模块。

参数:
  • checkpoint_list (列表) – 要平均的检查点列表。

  • recoverable_name (字符串) – 可恢复对象的名称,其参数将被加载并平均。

  • parameter_loader (函数) – 一个接受单个参数(参数文件的路径)并从该文件中加载参数的函数。默认情况下,使用 torch.load,它产生 state_dict 字典。

  • averager (函数) – 一个函数,它接受一个迭代器,该迭代器遍历 parameter_loader 加载的每个检查点的参数,并产生它们的平均值。请注意,该函数是使用迭代器调用的,因此长度最初未知;实现应该在参数集被生成时简单地计数其数量。请参见上面的 average_state_dicts 以获取示例。它是默认的平均器,并对 state_dicts 进行平均。

返回:

平均器函数的输出。

返回类型:

Any

示例

>>> # Consider this toy Module again:
>>> class Recoverable(torch.nn.Module):
...     def __init__(self, param):
...         super().__init__()
...         self.param = torch.nn.Parameter(torch.tensor([param]))
...     def forward(self, x):
...         return x * self.param
>>> # Now let's make some checkpoints:
>>> model = Recoverable(1.)
>>> tempdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(tempdir, {"model": model})
>>> for new_param in range(10):
...     model.param.data = torch.tensor([float(new_param)])
...     _ = checkpointer.save_checkpoint()  # Suppress output with assignment
>>> # Let's average the 3 latest checkpoints
>>> # (parameter values 7, 8, 9 -> avg=8)
>>> ckpt_list = checkpointer.find_checkpoints(max_num_checkpoints = 3)
>>> averaged_state = average_checkpoints(ckpt_list, "model")
>>> # Now load that state in the normal way:
>>> _ = model.load_state_dict(averaged_state)  # Suppress output
>>> model.param.data
tensor([8.])