speechbrain.inference.interfaces 模块

定义用于使用预训练模型进行简单推断的接口

作者
  • Aku Rouhe 2021

  • Peter Plantinga 2021

  • Loren Lugosch 2020

  • Mirco Ravanelli 2020

  • Titouan Parcollet 2021

  • Abdel Heba 2021

  • Andreas Nautsch 2022, 2023

  • Pooneh Mousavi 2023

  • Sylvain de Langen 2023

  • Adel Moumen 2023

  • Pradnya Kandarkar 2023

摘要

EncodeDecodePipelineMixin

一个用于预训练模型的 Mixin,使得可以指定编码管道和解码管道

Pretrained

接收一个训练好的模型,并在新数据上进行预测。

函数

foreign_class

从外部源获取并加载接口

参考

speechbrain.inference.interfaces.foreign_class(source, hparams_file='hyperparams.yaml', pymodule_file='custom.py', classname='CustomInterface', overrides={}, overrides_must_match=True, savedir=None, use_auth_token=False, download_only=False, huggingface_cache_dir=None, local_strategy: LocalStrategy = LocalStrategy.SYMLINK, **kwargs)[source]

从外部源获取并加载接口

源可以是文件系统上的位置,也可以是在线/huggingface 的位置

pymodule 文件应包含一个具有给定类名的类。将返回该类的一个实例。目的是在该文件中有一个自定义的 Pretrained 子类。在加载 Hyperparams YAML 文件之前,pymodule 文件也会添加到 Python 路径中,因此它可以包含所需的任何自定义实现。

hyperparams 文件应包含一个“modules”键,它是一个用于计算的 torch 模块字典。

hyperparams 文件应包含一个“pretrainer”键,它是一个 speechbrain.utils.parameter_transfer.Pretrainer

参数:
  • source (strPathFetchSource) – 用于查找模型的位置。详见 speechbrain.utils.fetching.fetch

  • hparams_file (str) – 用于构建推断所需模块的超参数文件名称。必须包含两个键:“modules”和“pretrainer”,如描述所示。

  • pymodule_file (str) – 应该获取的 Python 文件名称。

  • classname (str) – 将创建并返回其实例的类名称

  • overrides (dict) – 加载时对 hparams 文件所做的任何更改。

  • overrides_must_match (bool) – 当覆盖项与 yaml_stream 中对应的键不匹配时,是否抛出错误。

  • savedir (strPath) – 存放预训练材料的位置。如果未指定,则使用缓存。

  • use_auth_token (bool (默认值: False)) – 如果为 True,则将使用 Huggingface 的 auth_token 从 HuggingFace Hub 加载私有模型,默认值为 False,因为大多数模型是公共的。

  • download_only (bool (默认值: False)) – 如果为 True,则跳过类和实例的创建。

  • huggingface_cache_dir (str) – HuggingFace 缓存路径;如果为 None -> “~/.cache/huggingface” (默认值: None)

  • local_strategy (speechbrain.utils.fetching.LocalStrategy) – 使用的获取策略,控制远程文件获取在符号链接和复制方面的行为。详见 speechbrain.utils.fetching.fetch()

  • **kwargs (dict) – 转发给类构造函数的参数。

返回值:

给定 pymodule 文件中具有给定类名的类实例。

返回类型:

object

class speechbrain.inference.interfaces.Pretrained(modules=None, hparams=None, run_opts=None, freeze_params=True)[source]

基类: Module

接收一个训练好的模型,并在新数据上进行预测。

这是一个处理一些通用样板的基础类。它有意地具有与 Brain 相似的接口 - 这些基础类处理类似的事情。

Pretrained 的子类应实现预训练系统如何运行的实际逻辑,并添加具有描述性名称的方法(例如 ASR 的 transcribe_file())。

Pretrained 是一个 torch.nn.Module,因此 .to() 或 .eval() 等方法可以正常工作。子类应提供合适的 forward() 实现:按照惯例,它应该是一个接受批量音频信号并运行完整模型(如果适用)的方法。

参数:
  • modules (dict of str:torch.nn.Module pairs) – 构成学习系统的 Torch 模块。可以特殊处理这些模块(放在正确的设备上,冻结等)。这些模块可以通过 self.mods 作为属性访问,例如 self.mods.model(x)

  • hparams (dict) – 每个 key:value 对都应该由一个字符串键和在覆盖方法中使用的超参数组成。这些可以通过 hparams 属性,使用“点”符号访问:例如,self.hparams.model(x)。

  • run_opts (dict) –

    从命令行解析的选项。详见 speechbrain.parse_arguments()。此处支持的列表

    • device

    • data_parallel_count

    • data_parallel_backend

    • distributed_launch

    • distributed_backend

    • jit

    • jit_module_keys

    • compule

    • compile_module_keys

    • compile_mode

    • compile_using_fullgraph

    • compile_using_dynamic_shape_tracing

  • freeze_params (bool) – 是否冻结(requires_grad=False)参数。通常在推断中需要冻结参数。也会对所有模块调用 .eval()。

HPARAMS_NEEDED = []
MODULES_NEEDED = []
load_audio(path, savedir=None)[source]

使用此模型的输入规范加载音频文件

使用语音模型时,使用与训练模型时相同类型的数据非常重要。例如,这意味着使用相同的采样率和通道数。但是,可以将采样率较高的文件转换为较低的文件(下采样)。同样,将立体声文件混音为单声道也很简单。路径可以是本地路径、网络 URL 或指向 huggingface 仓库的链接。

classmethod from_hparams(source, hparams_file='hyperparams.yaml', pymodule_file='custom.py', overrides={}, savedir=None, use_auth_token=False, revision=None, download_only=False, huggingface_cache_dir=None, overrides_must_match=True, local_strategy: LocalStrategy = LocalStrategy.SYMLINK, **kwargs)[source]

根据 HyperPyYAML 文件从外部源获取和加载

源可以是文件系统上的位置,也可以是在线/huggingface 的位置

您可以使用 pymodule_file 来包含所需的任何自定义实现:如果该文件存在,其位置会在加载 Hyperparams YAML 之前添加到 sys.path 中,因此可以在 YAML 中引用它。

hyperparams 文件应包含一个“modules”键,它是一个用于计算的 torch 模块字典。

hyperparams 文件应包含一个“pretrainer”键,它是一个 speechbrain.utils.parameter_transfer.Pretrainer

参数:
  • source (str) – 用于查找模型的位置。详见 speechbrain.utils.fetching.fetch

  • hparams_file (str) – 用于构建推断所需模块的超参数文件名称。必须包含两个键:“modules”和“pretrainer”,如描述所示。

  • pymodule_file (str) – 可以获取一个 Python 文件。这允许包含任何自定义实现。文件的位置会在加载 hyperparams YAML 文件之前添加到 sys.path 中,因此可以在 YAML 中引用它。这是可选的,但有一个默认值:“custom.py”。如果找不到默认文件,则会忽略;但如果您指定不同的文件名,而文件未找到,则会引发错误。

  • overrides (dict) – 加载时对 hparams 文件所做的任何更改。

  • savedir (strPath) – 存放预训练材料的位置。如果未指定,则使用缓存。

  • use_auth_token (bool (默认值: False)) – 如果为 True,则将使用 Huggingface 的 auth_token 从 HuggingFace Hub 加载私有模型,默认值为 False,因为大多数模型是公共的。

  • revision (str) – 对应于 HuggingFace Hub 模型版本的模型修订版本。如果您希望将代码固定到 HuggingFace 上托管的特定模型版本,这特别有用。

  • download_only (bool (默认值: False)) – 如果为 True,则跳过类和实例的创建。

  • huggingface_cache_dir (str) – HuggingFace 缓存路径;如果为 None -> “~/.cache/huggingface” (默认值: None)

  • overrides_must_match (bool) – 覆盖项是否必须与文件中已有的参数匹配。

  • local_strategy (LocalStrategy, 可选) – 处理本地文件的策略。(默认值: LocalStrategy.SYMLINK

  • **kwargs (dict) – 转发给类构造函数的参数。

返回类型:

cls 的实例

class speechbrain.inference.interfaces.EncodeDecodePipelineMixin[source]

基类: object

一个用于预训练模型的 Mixin,使得可以指定编码管道和解码管道

create_pipelines()[source]

初始化编码和解码管道

to_dict(data)[source]

将填充批次转换为字典,其他数据类型保持不变

参数:

data (object) – 字典或填充批次

返回值:

results – 字典

返回类型:

dict

property batch_inputs

确定输入管道是按批次还是按单个示例操作(true 表示按批次)

返回值:

batch_inputs

返回类型:

bool

property input_use_padded_data

如果启用,原始 PaddedData 实例将传递给模型。如果禁用,将仅使用 .data

返回值:

result – 是否按原样使用填充数据

返回类型:

bool

property batch_outputs

确定输出管道是按批次还是按单个示例操作(true 表示按批次)

返回值:

batch_outputs

返回类型:

bool

encode_input(input)[source]

使用管道对输入进行编码

参数:

input (dict) – 原始输入

返回值:

结果

返回类型:

object

decode_output(output)[source]

解码原始模型输出

参数:

output (tuple) – 原始模型输出

返回值:

result – 管道的输出

返回类型:

dictlist