speechbrain.core 模块
用于运行实验的核心 SpeechBrain 代码。
- 作者
Peter Plantinga 2020, 2023
Abdel Heba 2020
Mirco Ravanelli 2020
Aku Rouhe 2021
Andreas Nautsch 2022
Sylvain de Langen 2023
Adel Moumen 2023, 2024
摘要
类
Brain 类抽象了数据循环的细节。 |
|
用于跟踪实验阶段的简单枚举。 |
函数
创建输出文件夹和相关的实验文件。 |
|
解析实验的命令行参数。 |
参考
- speechbrain.core.create_experiment_directory(experiment_directory, hyperparams_to_save=None, overrides={}, log_config='/home/docs/checkouts/readthedocs.org/user_builds/speechbrain/checkouts/latest/speechbrain/log-config.yaml', save_env_desc=True)[source]
创建输出文件夹和相关的实验文件。
- 参数:
- speechbrain.core.parse_arguments(arg_list=None)[source]
解析实验的命令行参数。
- 参数:
arg_list (list, None) – 要解析的参数列表。如果未给出,则从
sys.argv[1:]
读取。- 返回:
param_file (str) – 参数文件的位置。
run_opts (dict) – 运行选项,例如分布式、设备等。
overrides (dict) – 传递给
load_hyperpyyaml
的覆盖项。
示例
>>> argv = ['hyperparams.yaml', '--device', 'cuda:1', '--seed', '10'] >>> filename, run_opts, overrides = parse_arguments(argv) >>> filename 'hyperparams.yaml' >>> run_opts["device"] 'cuda:1' >>> overrides 'seed: 10'
- class speechbrain.core.Stage(*values)[source]
基类:
Enum
用于跟踪实验阶段的简单枚举。
- TRAIN = 1
- VALID = 2
- TEST = 3
- class speechbrain.core.Brain(modules=None, opt_class=None, hparams=None, run_opts=None, checkpointer=None)[source]
基类:
object
Brain 类抽象了数据循环的细节。
Brain
类的主要目的是实现fit()
方法,该方法迭代 epoch 和数据集,以将一组模块“拟合”到一组数据。为了使用
fit()
方法,应该继承Brain
类并覆盖任何默认行为不符合用例的方法。对于简单的用例(例如,使用单个数据集训练单个模型),只需要覆盖以下方法:compute_forward()
compute_objectives()
下面的示例说明了如何覆盖这两个方法。
对于更复杂的用例,例如需要更新多个模块,可以覆盖以下方法:
fit_batch()
evaluate_batch()
- 参数:
modules (dict of str:torch.nn.Module pairs) – 如果这些模块具有可训练参数,则默认情况下会将其传递给优化器,并且会在其上调用
train()
/eval()
。opt_class (torch.optim class) – 一个 torch 优化器构造函数,仅接受参数列表(例如 lambda 或偏函数定义)。默认情况下,在
fit()
方法开始时,会将modules
中的所有模块传递给它。可以通过覆盖configure_optimizers()
方法来更改此行为。hparams (dict) – 每个键值对应包含一个字符串键和在覆盖方法中使用的超参数。这些超参数可通过
hparams
属性以“点”表示法访问:例如 self.hparams.model(x)。run_opts –
用于更改运行时环境的一组选项,包括
- debug (bool)
如果为
True
,将仅对所有数据集迭代少量批次,以确保代码运行而不崩溃。- debug_batches (int)
调试模式下运行的批次数,默认为
2
。- debug_epochs (int)
调试模式下运行的 epoch 数,默认为
2
。如果传入非正数,则运行所有 epoch。- debug_persistently (bool)
在调试模式期间保留存储的数据(不使用 /tmp),默认为
False
。- jit (bool)
启用使用 jit 编译所有模块,默认为
False
。- jit_module_keys (str 列表)
modules
中应使用 jit 编译的键列表。- compile (bool)
启用使用 torch.compile 编译所有模块,默认为
False
。- compile_module_keys (str 列表)
modules
中应使用torch.compile
编译的键列表。如果torch.compile
不可用,则会引发错误。- compile_mode (str)
可以是
default
,reduce-overhead
,max-autotune
之一,默认为reduce-overhead
。- compile_using_fullgraph (bool)
是否允许将模型分解为多个子图,默认为
False
。- compile_using_dynamic_shape_tracing (bool)
使用动态形状跟踪进行编译,默认为
False
。- distributed_backend (str)
可以是
nccl
,gloo
,mpi
之一。- device (str)
执行计算的位置。
- precision (str)
可以是
fp32
,fp16
,bf16
之一。- eval_precision (str)
可以是
fp32
,fp16
,bf16
之一。- auto_mix_prec (bool)
如果为
True
,则使用自动混合精度 (fp16)。仅在使用 cuda 时激活。注意:这是一个已弃用的功能,将来将被移除。- bfloat16_mix_prec (bool)
如果为
True
,则使用自动混合精度 (bf16)。仅在使用 cuda 时激活。注意:这是一个已弃用的功能,将来将被移除。- max_grad_norm (float)
fit_batch()
的默认实现使用此值调用clip_grad_norm_
。默认值:5
。- skip_nonfinite_grads (bool)
如果为
True
,则在梯度非有限时(例如 NaN, Inf)将其设置为零。默认值:False
。- nonfinite_patience (int)
在停止之前忽略非有限损失的次数。默认值:
3
。- noprogressbar (bool)
训练时是否关闭进度条。默认值:
False
。- ckpt_interval_minutes (float)
保存 epoch 内检查点的时间间隔,以分钟为单位,默认值:
15.0
。如果非正,则不保存。- ckpt_interval_steps (int)
保存 epoch 内检查点的步数间隔。如果非正,则不保存。默认值:
0
。
- checkpointerspeechbrain.Checkpointer
默认情况下,将使用此对象加载检查点,并将优化器添加到其中,以便在中断时继续训练。
示例
>>> from torch.optim import SGD >>> class SimpleBrain(Brain): ... def compute_forward(self, batch, stage): ... return self.modules.model(batch[0]) ... def compute_objectives(self, predictions, batch, stage): ... return torch.nn.functional.l1_loss(predictions, batch[0]) >>> model = torch.nn.Linear(in_features=10, out_features=10) >>> brain = SimpleBrain({"model": model}, opt_class=lambda x: SGD(x, 0.1)) >>> brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],))
- compute_forward(batch, stage)[source]
前向传递,由子类覆盖。
- 参数:
batch (torch.Tensor or tensors) – 来自 dataloader 的元素,包含用于处理的输入。
stage (Stage) – 实验阶段:Stage.TRAIN, Stage.VALID, Stage.TEST
- 返回:
所有处理完成后的输出。直接传递给
compute_objectives()
。- 返回类型:
torch.Tensor or torch.Tensors
- compute_objectives(predictions, batch, stage)[source]
计算损失,由子类覆盖。
- 参数:
predictions (torch.Tensor or torch.Tensors) – 要评估的输出张量或张量集合。直接来自
compute_forward()
。batch (torch.Tensor or tensors) – 来自 dataloader 的元素,包含用于比较的目标。
stage (Stage) – 实验阶段:Stage.TRAIN, Stage.VALID, Stage.TEST
- 返回:
loss – 计算出的损失张量。
- 返回类型:
torch.Tensor
- make_dataloader(dataset, stage, ckpt_prefix='dataloader-', **loader_kwargs)[source]
为数据集创建 DataLoaders。
如果
fit()
和evaluate()
只接收数据集,则使用此方法。或者,也可以在 Brain 子类外部调用此方法。在这种情况下,应将 DataLoader 而不是数据集传递给
fit()
。Stage.TRAIN DataLoader 得到特殊处理。它有用于 shuffle 和 drop_last 的额外参数。在 DDP 中,会创建一个 DistributedSampler(除非数据集是 IterableDataset)。
注意
一些重要的 DataLoader 参数通过 **loader_kwargs 传递,例如 batch_size, num_workers, pin_memory。
注意
默认情况下,
evaluate()
指定 ckpt_prefix=None 以阻止将测试 DataLoader 添加到 checkpointer。如果你需要在保存检查点后(例如在测试时,训练检查点之后)添加一个可恢复项,并且仍然能够合理地恢复,你应该指定allow_partial_load=True
。- 参数:
dataset (Dataset) – 用于创建数据加载器的数据集。如果数据集是 DynamicItemDataset,则默认使用 PaddedBatch 作为 collate_fn,除非在 loader_kwargs 中指定。
stage (Stage) – 实验阶段:Stage.TRAIN, Stage.VALID, Stage.TEST
ckpt_prefix (str, None) – 用于 SaveableDataLoader 检查点名称的前缀。阶段名称将添加到此前缀以创建完整键。设置为 None 则不保存 DataLoader。
**loader_kwargs (dict) – DataLoader 的附加关键字参数。例如 batch_size, num_workers, pin_memory。
- 返回类型:
输入数据集的 DataLoader
- on_fit_start()[source]
在
fit()
开始时调用,如果distributed_count > 0
且后端是 ddp,则在多个进程上调用。默认实现编译 jit 模块,初始化优化器,并加载最新检查点以继续训练。
- init_optimizers()[source]
在
on_fit_start()
期间调用,在参数完全配置后(例如 DDP, jit)初始化优化器。此方法的默认实现依赖于在初始化时传入一个优化器类,该类仅接受参数列表(例如 lambda 或偏函数定义)。这将创建一个优化所有可训练参数的单个优化器。
如果存在多个优化器,请覆盖此方法。
- zero_grad(set_to_none=False)[source]
将所有优化过的
torch.Tensor``s 的梯度设为零,如果 ``set_to_none=False
(默认),否则设为 None。将梯度设为 None 可以节省内存,例如在
evaluate()
期间,从而可能可以使用更大的批次。
- on_evaluate_start(max_key=None, min_key=None)[source]
在
evaluate()
开始时调用。默认实现根据存储的指标加载性能最佳的检查点进行评估。
- fit_batch(batch)[source]
拟合一个批次,重写此方法以执行多次更新。
默认实现依赖于一些具有特定行为的方法的定义:
compute_forward()
compute_objectives()
optimizers_step()
也依赖于在初始化时传递了优化器。
- 参数:
batch (list of torch.Tensors) – 用于训练的数据批次。默认实现假定此批次包含两个元素:输入和目标。
- 返回类型:
分离的 loss
- check_loss_isfinite(loss)[source]
检查损失是否有限。
如果损失不是有限的,则记录一条有用的消息并增加
nonfinite_count
。如果nonfinite_count
超过--nonfinite_patience
阈值,则停止训练并引发错误。当损失变为 NaN 或 inf,而参数和梯度仍是有限的时,此检查特别有用。它有助于防止训练期间陷入无限循环。
- 参数:
loss (tensor) – 在调用
backward()
之后、优化器调用step()
之前的损失 tensor。
- on_fit_batch_start(batch, should_step)[source]
在
fit_batch()
开始时调用。在 AMP 上下文管理器下不调用此方法。不要假定输入批次会自动转换为较低精度(例如 fp16)。
- 参数:
batch (list of torch.Tensors) – 用于训练的数据批次。默认实现假定此批次包含两个元素:输入和目标。
should_step (boolean) – 是否调用了 optimizer.step()。
- evaluate_batch(batch, stage)[source]
评估一个批次,重写此方法以使用与训练不同的过程。
默认实现依赖于一些具有特定行为的方法的定义:
compute_forward()
compute_objectives()
- fit(epoch_counter, train_set, valid_set=None, progressbar=None, train_loader_kwargs={}, valid_loader_kwargs={})[source]
迭代 epoch 和数据集以优化目标。
依赖于可以(或应该)被重写的多个方法的存在。以下方法被使用并期望具有特定行为:
fit_batch()
evaluate_batch()
update_average()
如果初始化时 `distributed_count > 0` 并且 `distributed_backend` 是 ddp,则通常会处理多进程逻辑,例如将训练数据分割为每个设备的子集,并且仅在主进程上保存检查点。
- 参数:
epoch_counter (`iterable`) – 每次调用都应返回一个表示 epoch 计数的整数。
train_set (`Dataset`, `DataLoader`) – 用于训练的数据集。如果给定的是 Dataset,则会自动创建一个 DataLoader。如果给定的是 DataLoader,则直接使用它。
valid_set (`Dataset`, `DataLoader`) – 用于验证的数据集。如果给定的是 Dataset,则会自动创建一个 DataLoader。如果给定的是 DataLoader,则直接使用它。
progressbar (bool) – 是否在进度条中显示每个 epoch 的进度。
train_loader_kwargs (dict) – 传递给
make_dataloader()
以创建 train_loader 的关键字参数(如果 train_set 是 Dataset 而非 DataLoader)。例如,batch_size, num_workers。所有 DataLoader 的关键字参数都有效。valid_loader_kwargs (dict) – 传递给
make_dataloader()
以创建 valid_loader 的关键字参数(如果 valid_set 是 Dataset 而非 DataLoader)。例如,batch_size, num_workers。所有 DataLoader 的关键字参数都有效。
- 返回类型:
无返回值
- evaluate(test_set, max_key=None, min_key=None, progressbar=None, test_loader_kwargs={})[source]
迭代 test_set 并评估 brain 的性能。默认情况下,加载性能最佳的检查点(根据 checkpointer 记录)。
- 参数:
test_set (`Dataset`, `DataLoader`) – 如果给定的是 DataLoader,则直接迭代。否则传递给
self.make_dataloader()
。max_key (str) – 用于查找最佳检查点的键,传递给
on_evaluate_start()
。min_key (str) – 用于查找最佳检查点的键,传递给
on_evaluate_start()
。progressbar (bool) – 是否在进度条中显示进度。
test_loader_kwargs (dict) – 如果
test_set
不是 DataLoader,则传递给make_dataloader()
的关键字参数。注意:loader_kwargs["ckpt_prefix"]
会自动覆盖为None
(以便测试 DataLoader 不会添加到 checkpointer 中)。
- 返回类型:
平均测试 loss