speechbrain.utils.parameter_transfer 模块

用于最简单参数迁移场景的便捷函数。

使用 speechbrain.utils.checkpoints.Checkpointer 来查找检查点和参数文件的路径。

作者
  • Aku Rouhe 2020

  • Andreas Nautsch 2023

  • Adel Moumen 2023

摘要

Pretrainer

协调预训练过程

参考

class speechbrain.utils.parameter_transfer.Pretrainer(collect_in=None, loadables=None, paths=None, custom_hooks=None, conditions=None)[源代码]

基类:object

协调预训练过程

首先(可选地)从某些来源(本地目录、HuggingFace 仓库、基础 URL)收集文件到 collect_in 目录(如果指定)。

然后,为每个文件调用加载 hook。

参数
  • collect_in (str or Path, 可选) – 文件收集到的目录路径。如果为 None,则文件将从缓存或直接引用(如果是 URL 则会失败)。不会有一个集中存放所有文件的目标目录。

  • loadables (映射) – 从可加载键到对象的映射。这将键与实际对象实例连接起来。

  • paths (映射) – 从可加载键到文件路径的映射。路径的最后一部分被视为文件名,其余部分被视为“源”,可以是目录路径或像 Huggingface hub ID 这样的魔术源。例如:sb/asr-crdnn-libri/lm.ckpt -> source=sb/asr-crdnn-libri, file=lm.ckpt 注意,在收集时,您可以指定一个默认源,用于所有未指定路径的可加载项。

  • custom_hooks (映射) – 从可加载键到参数迁移 hook 函数的映射。如果您想使用自定义加载函数,请在此处指定。

  • conditions (映射) – 可选的从可加载键到条件值的映射,对于仅在特定标志打开时加载某些元素非常有用。

set_collect_in(path)[源代码]

更改收集路径

add_loadables(loadables)[源代码]

从给定的映射更新 loadables 字典。

参数

loadables (映射) – 从可加载键到对象的映射

add_paths(paths)[源代码]

更新不同可加载项的路径。

在收集参数时,优先使用这里的路径。注意,在收集时,您可以指定一个默认源,用于所有未指定路径的可加载项。

参数

paths (映射) – 从可加载键到文件路径的映射。路径的最后一部分被视为文件名,其余部分被视为“源”,可以是目录路径或像 Huggingface hub ID 这样的魔术源。例如:sb/asr-crdnn-libri/lm.ckpt -> source=sb/asr-crdnn-libri, file=lm.ckpt

add_custom_hooks(custom_hooks)[源代码]

更新自定义 hook。

在加载参数时,这里的 hook 优先于类默认值。

参数

custom_hooks (映射) – 从可加载键到参数迁移 hook 函数的映射。如果您想使用自定义加载函数,请在此处指定。

add_conditions(conditions)[源代码]

更新条件。

参数

conditions (映射) – 从可加载键到条件值的映射,对于仅在特定标志打开时加载某些元素非常有用。

static split_path(path)[源代码]

将路径分割为源和文件名

除了常规路径外,这也处理 URL 和 Huggingface hub 路径。

参数

path (str)

返回

  • str – 源

  • str – 文件名

collect_files(default_source=None, use_auth_token=False, local_strategy: LocalStrategy = LocalStrategy.SYMLINK)[源代码]

从已知路径获取参数,并使用 fallback default_source

实际的参数文件可能位于其他位置,但这确保了在 self.collect_in 目录中有一个符号链接。符号链接总是使用文件名中的 loadable key。这种标准化使得在例如分布式设置中协调预训练更加容易。

如果您将所有内容整齐地组织在一个位置,例如 Huggingface hub 仓库,请使用 default_source。

参数
  • default_source (str or Path or FetchSource) – 这用于所有尚未指定路径的可加载项。例如,如果可加载项的键为 "asr",则要查找的文件是 <default_source>/asr.ckpt

  • use_auth_token (bool (默认值: False)) – 如果为 True,则使用 Huggingface 的 auth_token 从 HuggingFace Hub 加载私有模型,默认为 False,因为大多数模型是公共的。

  • local_strategy (speechbrain.utils.fetching.LocalStrategy) – 要使用的获取策略,它控制远程文件获取在符号链接和复制方面的行为。如果未指定 collect_in 目录,则忽略。有关更多详细信息,请参阅 speechbrain.utils.fetching.fetch()

返回

从可加载键到本地路径的映射,可以从该路径加载可加载项的参数。此映射在此类中未使用,但可能有帮助。

返回类型

dict

is_loadable(name)[源代码]

如果没有定义条件或者指定的可加载项的条件为真,则返回 True

参数

name (str) – 可加载项的名称

返回

is_loadable – 该项目是否应该被加载

返回类型

bool

load_collected()[源代码]

加载已收集的文件。