speechbrain.inference.interfaces 模块
定义用于使用预训练模型进行简单推断的接口
- 作者
Aku Rouhe 2021
Peter Plantinga 2021
Loren Lugosch 2020
Mirco Ravanelli 2020
Titouan Parcollet 2021
Abdel Heba 2021
Andreas Nautsch 2022, 2023
Pooneh Mousavi 2023
Sylvain de Langen 2023
Adel Moumen 2023
Pradnya Kandarkar 2023
摘要
类
一个用于预训练模型的 Mixin,使得可以指定编码管道和解码管道 |
|
接收一个训练好的模型,并在新数据上进行预测。 |
函数
从外部源获取并加载接口 |
参考
- speechbrain.inference.interfaces.foreign_class(source, hparams_file='hyperparams.yaml', pymodule_file='custom.py', classname='CustomInterface', overrides={}, overrides_must_match=True, savedir=None, use_auth_token=False, download_only=False, huggingface_cache_dir=None, local_strategy: LocalStrategy = LocalStrategy.SYMLINK, **kwargs)[source]
从外部源获取并加载接口
源可以是文件系统上的位置,也可以是在线/huggingface 的位置
pymodule 文件应包含一个具有给定类名的类。将返回该类的一个实例。目的是在该文件中有一个自定义的 Pretrained 子类。在加载 Hyperparams YAML 文件之前,pymodule 文件也会添加到 Python 路径中,因此它可以包含所需的任何自定义实现。
hyperparams 文件应包含一个“modules”键,它是一个用于计算的 torch 模块字典。
hyperparams 文件应包含一个“pretrainer”键,它是一个 speechbrain.utils.parameter_transfer.Pretrainer
- 参数:
source (str 或 Path 或 FetchSource) – 用于查找模型的位置。详见
speechbrain.utils.fetching.fetch
。hparams_file (str) – 用于构建推断所需模块的超参数文件名称。必须包含两个键:“modules”和“pretrainer”,如描述所示。
pymodule_file (str) – 应该获取的 Python 文件名称。
classname (str) – 将创建并返回其实例的类名称
overrides (dict) – 加载时对 hparams 文件所做的任何更改。
overrides_must_match (bool) – 当覆盖项与 yaml_stream 中对应的键不匹配时,是否抛出错误。
savedir (str 或 Path) – 存放预训练材料的位置。如果未指定,则使用缓存。
use_auth_token (bool (默认值: False)) – 如果为 True,则将使用 Huggingface 的 auth_token 从 HuggingFace Hub 加载私有模型,默认值为 False,因为大多数模型是公共的。
download_only (bool (默认值: False)) – 如果为 True,则跳过类和实例的创建。
huggingface_cache_dir (str) – HuggingFace 缓存路径;如果为 None -> “~/.cache/huggingface” (默认值: None)
local_strategy (speechbrain.utils.fetching.LocalStrategy) – 使用的获取策略,控制远程文件获取在符号链接和复制方面的行为。详见
speechbrain.utils.fetching.fetch()
。**kwargs (dict) – 转发给类构造函数的参数。
- 返回值:
给定 pymodule 文件中具有给定类名的类实例。
- 返回类型:
- class speechbrain.inference.interfaces.Pretrained(modules=None, hparams=None, run_opts=None, freeze_params=True)[source]
基类:
Module
接收一个训练好的模型,并在新数据上进行预测。
这是一个处理一些通用样板的基础类。它有意地具有与
Brain
相似的接口 - 这些基础类处理类似的事情。Pretrained 的子类应实现预训练系统如何运行的实际逻辑,并添加具有描述性名称的方法(例如 ASR 的 transcribe_file())。
Pretrained 是一个 torch.nn.Module,因此 .to() 或 .eval() 等方法可以正常工作。子类应提供合适的 forward() 实现:按照惯例,它应该是一个接受批量音频信号并运行完整模型(如果适用)的方法。
- 参数:
modules (dict of str:torch.nn.Module pairs) – 构成学习系统的 Torch 模块。可以特殊处理这些模块(放在正确的设备上,冻结等)。这些模块可以通过
self.mods
作为属性访问,例如 self.mods.model(x)hparams (dict) – 每个 key:value 对都应该由一个字符串键和在覆盖方法中使用的超参数组成。这些可以通过
hparams
属性,使用“点”符号访问:例如,self.hparams.model(x)。run_opts (dict) –
从命令行解析的选项。详见
speechbrain.parse_arguments()
。此处支持的列表device
data_parallel_count
data_parallel_backend
distributed_launch
distributed_backend
jit
jit_module_keys
compule
compile_module_keys
compile_mode
compile_using_fullgraph
compile_using_dynamic_shape_tracing
freeze_params (bool) – 是否冻结(requires_grad=False)参数。通常在推断中需要冻结参数。也会对所有模块调用 .eval()。
- HPARAMS_NEEDED = []
- MODULES_NEEDED = []
- load_audio(path, savedir=None)[source]
使用此模型的输入规范加载音频文件
使用语音模型时,使用与训练模型时相同类型的数据非常重要。例如,这意味着使用相同的采样率和通道数。但是,可以将采样率较高的文件转换为较低的文件(下采样)。同样,将立体声文件混音为单声道也很简单。路径可以是本地路径、网络 URL 或指向 huggingface 仓库的链接。
- classmethod from_hparams(source, hparams_file='hyperparams.yaml', pymodule_file='custom.py', overrides={}, savedir=None, use_auth_token=False, revision=None, download_only=False, huggingface_cache_dir=None, overrides_must_match=True, local_strategy: LocalStrategy = LocalStrategy.SYMLINK, **kwargs)[source]
根据 HyperPyYAML 文件从外部源获取和加载
源可以是文件系统上的位置,也可以是在线/huggingface 的位置
您可以使用 pymodule_file 来包含所需的任何自定义实现:如果该文件存在,其位置会在加载 Hyperparams YAML 之前添加到 sys.path 中,因此可以在 YAML 中引用它。
hyperparams 文件应包含一个“modules”键,它是一个用于计算的 torch 模块字典。
hyperparams 文件应包含一个“pretrainer”键,它是一个 speechbrain.utils.parameter_transfer.Pretrainer
- 参数:
source (str) – 用于查找模型的位置。详见
speechbrain.utils.fetching.fetch
。hparams_file (str) – 用于构建推断所需模块的超参数文件名称。必须包含两个键:“modules”和“pretrainer”,如描述所示。
pymodule_file (str) – 可以获取一个 Python 文件。这允许包含任何自定义实现。文件的位置会在加载 hyperparams YAML 文件之前添加到 sys.path 中,因此可以在 YAML 中引用它。这是可选的,但有一个默认值:“custom.py”。如果找不到默认文件,则会忽略;但如果您指定不同的文件名,而文件未找到,则会引发错误。
overrides (dict) – 加载时对 hparams 文件所做的任何更改。
savedir (str 或 Path) – 存放预训练材料的位置。如果未指定,则使用缓存。
use_auth_token (bool (默认值: False)) – 如果为 True,则将使用 Huggingface 的 auth_token 从 HuggingFace Hub 加载私有模型,默认值为 False,因为大多数模型是公共的。
revision (str) – 对应于 HuggingFace Hub 模型版本的模型修订版本。如果您希望将代码固定到 HuggingFace 上托管的特定模型版本,这特别有用。
download_only (bool (默认值: False)) – 如果为 True,则跳过类和实例的创建。
huggingface_cache_dir (str) – HuggingFace 缓存路径;如果为 None -> “~/.cache/huggingface” (默认值: None)
overrides_must_match (bool) – 覆盖项是否必须与文件中已有的参数匹配。
local_strategy (LocalStrategy, 可选) – 处理本地文件的策略。(默认值:
LocalStrategy.SYMLINK
)**kwargs (dict) – 转发给类构造函数的参数。
- 返回类型:
cls 的实例