speechbrain.utils.checkpoints 模块
此模块实现了检查点保存器和加载器。
实验中的检查点通常需要保存许多不同内容的状态:模型参数、优化器参数、当前 epoch 等。检查点的保存格式是一个目录,其中每个独立的可保存项都有自己的文件。此外,一个特殊文件保存了关于检查点的元信息(默认仅为创建时间,但你可以指定你想要的任何其他信息,例如验证损失)。
检查点系统的接口要求你指定要保存的内容。这种方法很灵活,并且与你的实验实际运行方式无关。
接口要求你为每个要保存的内容指定名称。该名称用于在恢复时将正确的参数文件分配给正确的对象。
默认的保存和加载方法仅针对 torch.nn.Modules(及其子类)和 torch.optim.Optimizers 添加。如果这些方法不适用于你的对象,你可以为特定实例或类指定自己的保存和/或加载方法。
示例
>>> # Toy example Module:
>>> class Recoverable(torch.nn.Module):
... def __init__(self, param):
... super().__init__()
... self.param = torch.nn.Parameter(torch.tensor([param]))
... def forward(self, x):
... return x * self.param
>>> model = Recoverable(1.)
>>> tempdir = getfixture('tmpdir')
>>> # In simple cases, the module aims to have a terse syntax,
>>> # consisting of three steps.
>>> # 1. Specifying where to save checkpoints and what is included in a
>>> # checkpoint:
>>> checkpointer = Checkpointer(tempdir, {"network": model})
>>> # 2. Recover from the latest checkpoint, if one is found:
>>> checkpointer.recover_if_possible()
>>> # Run your experiment:
>>> data = [(0.1, 0.9), (0.3, 0.8)]
>>> for example, target in data:
... loss = (model(example) - target)**2
... # 3. Save checkpoints, and keep by default just one, the newest:
... ckpt = checkpointer.save_and_keep_only()
- 作者
Aku Rouhe 2020
Adel Moumen 2024
摘要
类
描述一个已保存检查点的 NamedTuple |
|
保存检查点并从中恢复。 |
函数
平均多个检查点中的参数。 |
|
从 state_dicts 的迭代器生成平均 state_dict。 |
|
检查点重要性指标:新近度。 |
|
查找用于给定对象的默认保存/加载钩子。 |
|
加载 state_dict 检查点时调用的钩子。 |
|
根据提供的映射,映射旧状态字典中的键。 |
|
方法装饰器,标记给定方法为检查点加载钩子。 |
|
方法装饰器,标记给定方法为检查点保存钩子。 |
|
方法装饰器,标记给定方法为参数迁移钩子。 |
|
类装饰器,注册加载、保存和迁移钩子。 |
|
非严格的 Torch Module state_dict 加载。 |
|
使用 |
|
立即从给定路径加载 torch.nn.Module state_dict。 |
|
将对象的参数保存到路径。 |
参考
- speechbrain.utils.checkpoints.map_old_state_dict_weights(state_dict: Dict[str, Tensor], mapping: Dict[str, str]) Dict[str, Tensor] [源码]
根据提供的映射,映射旧状态字典中的键。
注意:此函数将重新映射包含旧键的所有 state_dict 键。例如,如果 state_dict 是 {‘model.encoder.layer.0.atn.self.query.weight’: …} 并且映射是 {‘.atn’: ‘.attn’},则结果 state_dict 将是 {‘model.encoder.layer.0.attn.self.query.weight’: …}。
由于这实际上是批量子串替换,因此部分键匹配(例如在某个层名称中间)也会生效,所以要小心避免误报。
- speechbrain.utils.checkpoints.hook_on_loading_state_dict_checkpoint(state_dict: Dict[str, Tensor]) Dict[str, Tensor] [源码]
加载 state_dict 检查点时调用的钩子。
此钩子在加载 state_dict 检查点时调用。可用于在加载到模型之前修改 state_dict。
默认情况下,此钩子会将旧的 state_dict 键映射到新的键。
- speechbrain.utils.checkpoints.torch_recovery(obj, path, end_of_epoch)[源码]
立即从给定路径加载 torch.nn.Module state_dict。
这可以设置为 torch.nn.Modules 的默认值:>>> DEFAULT_LOAD_HOOKS[torch.nn.Module] = torch_recovery
- 参数:
obj (torch.nn.Module) – 要加载参数的实例。
path (str, pathlib.Path) – 加载路径。
end_of_epoch (bool) – 恢复是否来自 epoch 结束检查点。
- speechbrain.utils.checkpoints.torch_patched_state_dict_load(path, device='cpu')[源码]
使用
从给定路径加载torch.load()
,并调用 SpeechBrainstate_dict
加载钩子,例如应用键名修补规则以实现兼容性。state_dict
不会进一步预处理,也不会应用到模型中,请参阅state_dict
torch_recovery()
或torch_parameter_transfer()
。- 参数:
path (str, pathlib.Path) – 加载路径。
device (str) – 加载的
张量应驻留的设备。这被转发到state_dict
;详情请参阅其文档。torch.load()
- 返回类型:
加载的状态字典。
- speechbrain.utils.checkpoints.torch_save(obj, path)[源码]
将对象的参数保存到路径。
torch.nn.Modules 的默认保存钩子。用于保存 torch.nn.Module state_dicts。
- 参数:
obj (torch.nn.Module) – 要保存的实例。
path (str, pathlib.Path) – 保存路径。
- speechbrain.utils.checkpoints.torch_parameter_transfer(obj, path)[源码]
非严格的 Torch Module state_dict 加载。
从 path 加载一组参数到 obj。如果 obj 中某些层的参数未找到,则仅记录警告。path 中某些参数在 obj 中未找到对应层,也同样记录警告。
- 参数:
obj (torch.nn.Module) – 要加载参数的实例。
path (str) – 加载路径。
- speechbrain.utils.checkpoints.mark_as_saver(method)[源码]
方法装饰器,标记给定方法为检查点保存钩子。
示例请参阅 register_checkpoint_hooks。
- 参数:
method (callable) – 要装饰的类方法。必须可以使用位置参数签名 (instance, path) 调用。例如:def saver(self, path) 满足此条件。
- 返回类型:
已装饰的方法,标记为检查点保存器。
注意
这不会添加钩子(通过方法装饰器不可能),你还必须使用 @register_checkpoint_hooks 装饰类。只能添加一个方法作为钩子。
- speechbrain.utils.checkpoints.mark_as_loader(method)[源码]
方法装饰器,标记给定方法为检查点加载钩子。
- 参数:
method (callable) – 要装饰的类方法。必须可以使用位置参数签名 (instance, path, end_of_epoch) 调用。例如:
满足此条件。def loader(self, path, end_of_epoch):
- 返回类型:
已装饰的方法,注册为检查点加载器。
注意
这不会添加钩子(通过方法装饰器不可能),你还必须使用 @register_checkpoint_hooks 装饰类。只能添加一个方法作为钩子。
- speechbrain.utils.checkpoints.mark_as_transfer(method)[源码]
方法装饰器,标记给定方法为参数迁移钩子。
- 参数:
method (callable) – 要装饰的类方法。必须可以使用位置参数签名 (instance, path) 调用。例如:
满足此条件。def loader(self, path):
- 返回类型:
已装饰的方法,注册为迁移方法。
注意
这不会添加钩子(通过方法装饰器不可能),你还必须使用 @register_checkpoint_hooks 装饰类。只能添加一个方法作为钩子。
注意
会优先使用迁移钩子而不是加载钩子。然而,如果没有注册迁移钩子,Pretrainer 将使用加载钩子。Pretrainer
- speechbrain.utils.checkpoints.register_checkpoint_hooks(cls, save_on_main_only=True)[源码]
类装饰器,注册加载、保存和迁移钩子。
钩子必须已用 mark_as_loader 和 mark_as_saver 标记,并可能已用 mark_as_transfer 标记。
- 参数:
cls (class) – 要装饰的类
save_on_main_only (bool) – 默认情况下,保存器仅在单个进程上运行。此参数提供了在所有进程上运行保存器的选项,这对于一些保存器是必需的,这些保存器需要在保存之前收集数据。
- 返回类型:
已注册钩子的装饰类
示例
>>> @register_checkpoint_hooks ... class CustomRecoverable: ... def __init__(self, param): ... self.param = int(param) ... ... @mark_as_saver ... def save(self, path): ... with open(path, "w", encoding="utf-8") as fo: ... fo.write(str(self.param)) ... ... @mark_as_loader ... def load(self, path, end_of_epoch): ... del end_of_epoch # Unused here ... with open(path, encoding="utf-8") as fi: ... self.param = int(fi.read())
- speechbrain.utils.checkpoints.get_default_hook(obj, default_hooks)[源码]
查找用于给定对象的默认保存/加载钩子。
遵循方法解析顺序(Method Resolution Order),即如果对象的类本身没有注册钩子,也会搜索对象继承的类。
- 参数:
obj (instance) – 类的实例。
default_hooks (dict) – 从类到(检查点钩子)函数的映射。
- 返回类型:
正确的方法,如果未注册任何方法则返回 None。
示例
>>> a = torch.nn.Module() >>> get_default_hook(a, DEFAULT_SAVE_HOOKS) == torch_save True
- class speechbrain.utils.checkpoints.Checkpoint(path, meta, paramfiles)
基类:
tuple
描述一个已保存检查点的 NamedTuple
要从众多检查点中选择一个加载,首先根据此命名元组过滤和排序检查点。Checkpointers 在 path 中放置 pathlib.Path,在 meta 中放置 dict。你可以在保存检查点时向 meta 添加任何你想要的任何信息,例如验证损失。meta 中唯一的默认键是 "unixtime"。Checkpoint.paramfiles 是从可恢复名称到参数文件路径的字典。
- meta
字段编号 1 的别名
- paramfiles
字段编号 2 的别名
- path
字段编号 0 的别名
- speechbrain.utils.checkpoints.ckpt_recency(ckpt)[源码]
检查点重要性指标:新近度。
此函数也可以作为一个示例,说明如何创建检查点重要性键函数。这是一个命名函数,但如你所见,在紧急情况下它也可以轻松地实现为 lambda 函数。
- class speechbrain.utils.checkpoints.Checkpointer(checkpoints_dir, recoverables=None, custom_load_hooks=None, custom_save_hooks=None, allow_partial_load=False)[源码]
基类:
object
保存检查点并从中恢复。
- 参数:
checkpoints_dir (str, pathlib.Path) – 检查点保存目录的路径。
recoverables (mapping, optional) – 要恢复的对象。它们需要一个(唯一的)名称:这用于将检查点中的参数连接到正确的可恢复项。该名称也用于对象的保存文件的文件名。这些对象也可以通过 add_recoverable 或 add_recoverables 添加,或者直接修改 checkpointer.recoverables。
custom_load_hooks (mapping, optional) – 从名称(与 recoverables 中的名称相同)到函数或方法的映射。为特定对象设置自定义加载钩子。函数/方法必须可以使用位置参数签名 (instance, path) 调用。例如:
满足此条件。def loader(self, path)
custom_save_hooks (mapping, optional) – 从名称(与 recoverables 中的名称相同)到函数或方法的映射。为特定对象设置自定义保存钩子。函数/方法必须可以使用位置参数签名 (instance, path) 调用。例如:def saver(self, path) 满足此条件。
allow_partial_load (bool, optional) – 如果为 True,则允许加载检查点时,即使并非所有已注册的可恢复项都找到保存文件。在这种情况下,只加载找到的保存文件。如果为 False,加载此类检查点将引发 RuntimeError。(default: False)
示例
>>> import torch >>> #SETUP: >>> tempdir = getfixture('tmpdir') >>> class Recoverable(torch.nn.Module): ... def __init__(self, param): ... super().__init__() ... self.param = torch.nn.Parameter(torch.tensor([param])) ... def forward(self, x): ... return x * self.param >>> recoverable = Recoverable(1.) >>> recoverables = {'recoverable': recoverable} >>> # SETUP DONE. >>> checkpointer = Checkpointer(tempdir, recoverables) >>> first_ckpt = checkpointer.save_checkpoint() >>> recoverable.param.data = torch.tensor([2.]) >>> loaded_ckpt = checkpointer.recover_if_possible() >>> # Parameter has been loaded: >>> assert recoverable.param.data == torch.tensor([1.]) >>> # With this call, by default, oldest checkpoints are deleted: >>> checkpointer.save_and_keep_only() >>> assert first_ckpt not in checkpointer.list_checkpoints()
- add_recoverable(name, obj, custom_load_hook=None, custom_save_hook=None, optional_load=False)[源码]
注册一个可恢复项,并可能带有自定义钩子。
- 参数:
name (str) – 可恢复项的唯一名称。用于将保存文件映射到对象。
obj (instance) – 要恢复的对象。
custom_load_hook (callable, optional) – 用于加载对象保存文件时调用。函数/方法必须可以使用位置参数签名 (instance, path) 调用。例如:def load(self, path) 满足此条件。
custom_save_hook (callable, optional) – 用于保存对象参数时调用。函数/方法必须可以使用位置参数签名 (instance, path) 调用。例如:def saver(self, path) 满足此条件。
optional_load (bool, optional) – 如果为 True,则允许从检查点可选地加载对象。如果检查点缺少指定对象,则不引发错误。这在不同训练配置之间的过渡中特别有用,例如将精度从浮点 32 更改为 16。例如,假设你有一个训练检查点,其中不包含
对象。如果你打算继续使用浮点 16 进行预训练,其中需要scaler
对象,将其标记为可选可防止加载错误。如果不将其标记为可选,尝试从在浮点 32 中训练的检查点加载scaler
对象将失败,因为该检查点中不存在scaler
对象。scaler
- add_recoverables(recoverables)[源码]
从给定映射更新 recoverables 字典。
- 参数:
recoverables (mapping) – 要恢复的对象。它们需要一个(唯一的)名称:这用于将检查点中的参数连接到正确的可恢复项。该名称也用于对象的参数保存文件的文件名。
- save_checkpoint(meta={}, end_of_epoch=True, name=None, verbosity=20)[源码]
保存检查点。
整个检查点成为一个目录。将每个已注册对象的参数保存到单独的文件中。还会添加一个元文件。默认情况下,元文件只包含 unixtime(自 Unix 纪元以来的秒数),但你可以自己添加任何相关信息。元信息稍后用于选择要加载的检查点。
end_of_epoch 的值会保存到 meta 中。这会影响 epoch 计数器和数据集迭代器如何加载其状态。
对于多进程保存,在某些情况下我们可能需要在多个进程上运行保存代码(例如 FSDP,我们需要在保存前收集参数)。这通过在主进程上创建一个保存文件夹并将其通信给所有进程,然后让每个保存器/加载器方法控制是否在单个进程还是所有进程上保存来实现。
- 参数:
- 返回:
namedtuple [参见上文],已保存的检查点,除非在非主进程上运行,在这种情况下返回 None。
- 返回类型:
- save_and_keep_only(meta={}, end_of_epoch=True, name=None, num_to_keep=1, keep_recent=True, importance_keys=[], max_keys=[], min_keys=[], ckpt_predicate=None, verbosity=20)[源码]
保存检查点,然后删除最不重要的检查点。
这实质上是将
和save_checkpoint()
合并到一个调用中,提供了简洁的语法。delete_checkpoints()
- 参数:
meta (mapping, optional) – 添加到检查点元文件中的映射。默认包含键 "unixtime"。
end_of_epoch (bool, optional) – 检查点是否在 epoch 结束时。默认为 True。可能会影响加载。
name (str, optional) – 为你的检查点指定自定义名称。名称仍会添加前缀。如果未给定名称,则从时间戳和随机唯一 ID 创建名称。
num_to_keep (int, optional) – 要保留的检查点数量。默认为 1。这将删除过滤后剩余的所有检查点。必须 >=0。
keep_recent (bool, optional) – 是否保留最近的
个检查点。num_to_keep
importance_keys (list, optional) – 用于排序的键函数列表(参见内置函数 sorted)。每个可调用对象定义一个排序顺序,并为每个可调用对象保留 num_to_keep 个检查点。将保留键值最高的检查点。这些函数会传入 Checkpoint 命名元组(参见上文)。
max_keys (list, optional) – 将保留该键值*最高*的检查点列表。
min_keys (list, optional) – 将保留该键值*最低*的检查点列表。
ckpt_predicate (callable, optional) – 用此参数排除某些检查点不被删除。在任何排序之前,检查点列表会用此谓词进行过滤。只有 ckpt_predicate 为 True 的检查点才可能被删除。该函数使用 Checkpoint 命名元组(参见上文)调用。
verbosity (int) – 日志级别,默认为 logging.INFO
注意
与 save_checkpoint 不同,此函数不返回任何内容,因为我们无法保证保存的检查点实际上不会被删除。
- find_checkpoint(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None)[源码]
从所有可用检查点中选择一个特定检查点。
如果
,importance_key
和max_key
都没有使用,则返回最新的检查点。这些参数中最多只能使用一个。min_key
大多数功能实际上在
中实现,但此函数作为有用的接口保留。find_checkpoints()
- 参数:
importance_key (callable, optional) – 用于排序的键函数。将选择返回值最高的检查点。该函数会使用 Checkpoint 命名元组调用。
max_key (str, optional) – 将返回该键值最高的检查点。只考虑包含此键的检查点!
min_key (str, optional) – 将返回该键值最低的检查点。只考虑包含此键的检查点!
ckpt_predicate (callable, optional) – 在排序之前,检查点列表会用此谓词进行过滤。参见内置函数 filter。该函数会使用 Checkpoint 命名元组(参见上文)调用。默认情况下,考虑所有检查点。
- 返回:
Checkpoint – 如果找到。
None – 如果过滤后没有检查点存在/剩余。
- speechbrain.utils.checkpoints.find_checkpoints(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None, max_num_checkpoints=None)[源码]
选择多个检查点。
如果
,importance_key
和max_key
都没有使用,则返回最新的检查点。这些参数中最多只能使用一个。min_key
- 参数:
importance_key (callable, optional) – 用于排序的键函数。将选择返回值最高的检查点。该函数会使用 Checkpoint 命名元组调用。
max_key (str, optional) – 将返回该键值最高的检查点。只考虑包含此键的检查点!
min_key (str, optional) – 将返回该键值最低的检查点。只考虑包含此键的检查点!
ckpt_predicate (callable, optional) – 在排序之前,检查点列表会用此谓词进行过滤。参见内置函数 filter。该函数会使用 Checkpoint 命名元组(参见上文)调用。默认情况下,考虑所有检查点。
max_num_checkpoints (int, None) – 要返回的最大检查点数量,或 None 返回所有找到的检查点。
- 返回:
包含最多指定数量 Checkpoint 的列表。
- 返回类型:
- recover_if_possible(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None)[源码]
选择一个检查点并从该检查点恢复,如果找到了检查点。
如果未找到检查点,则不运行恢复。
如果
,importance_key
和max_key
都没有使用,则返回最新的检查点。这些参数中最多只能使用一个。min_key
- 参数:
importance_key (callable, optional) – 用于排序的键函数。将加载返回值最高的检查点。该函数会使用 Checkpoint 命名元组调用。
max_key (str, optional) – 将加载该键值最高的检查点。只考虑包含此键的检查点!
min_key (str, optional) – 将加载该键值最低的检查点。只考虑包含此键的检查点!
ckpt_predicate (callable, optional) – 在排序之前,检查点列表会用此谓词进行过滤。参见内置函数 filter。该函数会使用 Checkpoint 命名元组(参见上文)调用。默认情况下,考虑所有检查点。
- 返回:
Checkpoint – 如果找到。
None – 如果过滤后没有检查点存在/剩余。
- load_checkpoint(checkpoint)[源码]
加载指定的检查点。
- 参数:
checkpoint (Checkpoint) – 要加载的检查点。
- delete_checkpoints(*, num_to_keep=1, min_keys=None, max_keys=None, importance_keys=[<function ckpt_recency>], ckpt_predicate=None, verbosity=20)[源码]
删除最不重要的检查点。
由于定义重要性有多种方式(例如最低 WER、最低损失),用户应提供一个排序键函数列表,每个函数定义一个特定的重要性顺序。本质上,每个重要性键函数提取一个重要性指标(越高越重要)。对于这些顺序中的每一个,都会保留 num_to_keep 个检查点。但是,如果不同顺序保留的检查点之间存在重叠,则不会保留额外的检查点,因此保留的总检查点数量可能小于
num_to_keep * len(importance_keys)
- 参数:
num_to_keep (int, optional) – 要保留的检查点数量。默认为 10。你可以选择保留 0 个。这将删除过滤后剩余的所有检查点。必须 >=0
min_keys (list, optional) – 表示 meta 中键的字符串列表。将保留这些键中值最低的检查点,最多 num_to_keep 个。
max_keys (list, optional) – 表示 meta 中键的字符串列表。将保留这些键中值最高的检查点,最多 num_to_keep 个。
importance_keys (list, optional) – 用于排序的键函数列表(参见内置函数 sorted)。每个可调用函数定义一个排序顺序,并为每个可调用函数保留 num_to_keep 个检查点。明确地说,将保留键值最高的检查点。这些函数使用 Checkpoint 命名元组(参见上文)调用。另请参阅默认值 (ckpt_recency,上文)。默认设置会删除除最新检查点外的所有检查点。
ckpt_predicate (callable, optional) – 用此参数排除某些检查点不被删除。在任何排序之前,检查点列表会用此谓词进行过滤。只有 ckpt_predicate 为 True 的检查点才可能被删除。该函数使用 Checkpoint 命名元组(参见上文)调用。
verbosity (logging level) – 设置此删除操作的日志级别。
注意
必须使用关键字参数调用,作为你知道自己在做什么的标志。删除是永久性的。
- speechbrain.utils.checkpoints.average_state_dicts(state_dicts)[源码]
从 state_dicts 的迭代器生成平均 state_dict。
请注意,在任何时候,这都会在内存中保留两个 state_dicts,这是最低内存需求。
- 参数:
state_dicts (iterator, list) – 要平均的 state_dicts。
- 返回:
平均后的 state_dict。
- 返回类型:
state_dict
- speechbrain.utils.checkpoints.average_checkpoints(checkpoint_list, recoverable_name, parameter_loader=<function load>, averager=<function average_state_dicts>)[源码]
平均多个检查点中的参数。
使用 Checkpointer.find_checkpoints() 获取要平均的检查点列表。研究表明,对训练中最后几个检查点的参数进行平均有时可以提高性能。
默认的加载器和平均器适用于标准的 PyTorch 模块。
- 参数:
checkpoint_list (列表) – 要平均的检查点列表。
recoverable_name (字符串) – 可恢复对象的名称,其参数将被加载并平均。
parameter_loader (函数) – 一个接受单个参数(参数文件的路径)并从该文件中加载参数的函数。默认情况下,使用 torch.load,它产生 state_dict 字典。
averager (函数) – 一个函数,它接受一个迭代器,该迭代器遍历 parameter_loader 加载的每个检查点的参数,并产生它们的平均值。请注意,该函数是使用迭代器调用的,因此长度最初未知;实现应该在参数集被生成时简单地计数其数量。请参见上面的 average_state_dicts 以获取示例。它是默认的平均器,并对 state_dicts 进行平均。
- 返回:
平均器函数的输出。
- 返回类型:
Any
示例
>>> # Consider this toy Module again: >>> class Recoverable(torch.nn.Module): ... def __init__(self, param): ... super().__init__() ... self.param = torch.nn.Parameter(torch.tensor([param])) ... def forward(self, x): ... return x * self.param >>> # Now let's make some checkpoints: >>> model = Recoverable(1.) >>> tempdir = getfixture('tmpdir') >>> checkpointer = Checkpointer(tempdir, {"model": model}) >>> for new_param in range(10): ... model.param.data = torch.tensor([float(new_param)]) ... _ = checkpointer.save_checkpoint() # Suppress output with assignment >>> # Let's average the 3 latest checkpoints >>> # (parameter values 7, 8, 9 -> avg=8) >>> ckpt_list = checkpointer.find_checkpoints(max_num_checkpoints = 3) >>> averaged_state = average_checkpoints(ckpt_list, "model") >>> # Now load that state in the normal way: >>> _ = model.load_state_dict(averaged_state) # Suppress output >>> model.param.data tensor([8.])