以在 GitHub 上执行或查看/下载此 Notebook
Brain 类
深度学习的一个基本方面是多次迭代数据集并更新模型参数,这通常被称为“训练循环”。为了简化和组织这个过程,SpeechBrain 提供了一个通用的框架,即 “Brain” 类,它实现在 speechbrain/core.py
中。在每个 recipe 中,该类都被子类化,并且其方法被覆盖以根据该 recipe 的特定要求定制实现。
Brain 类的核心方法是 fit()
,负责迭代数据集、更新模型以及管理训练循环。要利用 fit()
,子类中必须至少定义两个方法:compute_forward()
和 compute_objectives()
。这些方法处理用于生成预测的模型计算以及梯度计算所需的损失项计算。
让我们看一个最小示例来说明这一点
%%capture
# Installing SpeechBrain via pip
BRANCH = 'develop'
!python -m pip install git+https://github.com/speechbrain/speechbrain.git@$BRANCH
# Clone SpeechBrain repository
!git clone https://github.com/speechbrain/speechbrain/
import torch
import speechbrain as sb
class SimpleBrain(sb.Brain):
def compute_forward(self, batch, stage):
return self.modules.model(batch["input"])
def compute_objectives(self, predictions, batch, stage):
return torch.nn.functional.l1_loss(predictions, batch["target"])
model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain({"model": model}, opt_class=lambda x: torch.optim.SGD(x, 0.1))
data = [{"input": torch.rand(10, 10), "target": torch.rand(10, 10)}]
brain.fit(range(10), data)
只需大约 10 行代码,我们就可以成功训练一个神经网络模型。这种效率之所以能实现,是因为 Brain 类处理了训练中的复杂细节,例如管理 train()
和 eval()
状态或计算和应用梯度。此外,该类的灵活性允许通过向子类添加方法来覆盖流程的每个步骤。这意味着即使是复杂的训练过程,例如生成对抗网络 (GAN) 中涉及的过程,也可以无缝地集成到 Brain 类中。
在本教程中,我们将首先阐述 Brain 类的参数。随后,我们将深入探讨 fit()
方法,对其进行逐步分解,并重点介绍必要时可以覆盖的部分。对该类参数和 fit()
方法的这些理解构成了理解 Brain 类功能和灵活性的基础。
Brain 类的参数
Brain 类只接受 5 个参数,但每个参数都可能有点复杂,因此我们在此详细解释它们。相关代码仅是 __init__
定义
def __init__(
self,
modules=None,
opt_class=None,
hparams=None,
run_opts=None,
checkpointer=None,
):
modules 参数
第一个参数接受一个 torch 模块字典。Brain 类接收此字典并将其转换为 Torch ModuleDict。这提供了一种方便的方式,可以将所有参数移动到正确设备,调用 train()
和 eval()
,并在必要时将模块包装在适当的分布式包装器中。
opt_class 参数
Brain 类接受 pytorch 优化器的函数定义。选择它作为输入而不是预构建的 pytorch 优化器的原因是,如果需要,Brain 类会自动处理将模块参数包装在分布式包装器中。这需要在参数传递给优化器构造函数之前发生。
要传递 pytorch 优化器构造函数,可以使用 lambda,如本教程开头的示例所示。然而,更方便的选择是 SpeechBrain 中大多数 recipe 使用的选项:使用 HyperPyYAML 定义构造函数。!name:
标签的作用类似于 lambda,创建一个可用于生成优化器的新构造函数。
optimizer: !name:torch.optim.Adam
lr: 0.1
当然,有时需要零个或多个优化器。在需要多个优化器的情况下,可以覆盖 init_optimizers
方法来单独初始化每个优化器。
hparams 参数
Brain 类算法可能依赖于一组应易于从外部控制的超参数,此参数接受一个字典,所有内部方法都可以使用“点表示法”访问该字典。示例如下
class SimpleBrain(sb.Brain):
def compute_forward(self, batch, stage):
return self.modules.model(batch["input"])
def compute_objectives(self, predictions, batch, stage):
term1 = torch.nn.functional.l1_loss(predictions, batch["target1"])
term2 = torch.nn.functional.mse_loss(predictions, batch["target2"])
return self.hparams.weight1 * term1 + self.hparams.weight2 * term2
hparams = {"weight1": 0.7, "weight2": 0.3}
model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain(
modules={"model": model},
opt_class=lambda x: torch.optim.SGD(x, 0.1),
hparams=hparams,
)
data = [{
"input": torch.rand(10, 10),
"target1": torch.rand(10, 10),
"target2": torch.rand(10, 10),
}]
brain.fit(range(10), data)
run_opts 参数
有大量选项可以控制 fit()
方法的执行细节,所有这些选项都可以通过此参数传递。一些示例包括启用调试模式、执行设备和分布式执行选项。完整列表请参阅 [添加文档链接]。
checkpointer 参数
最后,如果你将 SpeechBrain checkpointer 传递给 Brain 类,则会自动调用几个操作
优化器参数会添加到 checkpointer 中。
在训练开始时,会加载最新的检查点并从该点恢复训练。如果训练已完成,则此操作将直接结束训练步骤并转到评估。
在训练期间,默认每 15 分钟保存一次检查点(这可以通过
run_opts
中的选项更改或禁用)。在评估开始时,会加载“最佳”检查点,最佳检查点由检查点中记录的指标的最低或最高分数确定。
fit() 方法
此方法功能很多,但实际上只占用大约 ~100 行代码,因此通过阅读代码本身即可理解。我们按章节对其进行分解,解释每个部分的功能。首先,让我们简要介绍一下参数
def fit(
self,
epoch_counter,
train_set,
valid_set=None,
progressbar=None,
train_loader_kwargs={},
valid_loader_kwargs={},
):
epoch_counter 参数接受一个迭代器,因此当
fit()
被调用时,外层循环会迭代此变量。此参数是与EpochCounter
类共同设计的,该类支持存储 epoch 循环状态。使用此参数,我们可以从实验停止的地方重新开始。train_set 和 valid_set 参数接受一个 Torch Dataset 或 DataLoader,它将加载训练所需的张量。如果未传递 DataLoader,则会自动构建一个(参见下一节)。
progressbar 参数控制是否显示一个
tqdm
进度条,以显示每个 epoch 的数据集处理进度。train_loader_kwargs 和 valid_loader_kwargs 会传递给
make_dataloader
方法以创建 DataLoader(参见下一节)。
Fit 结构
了解了参数之后,我们可以开始研究此方法的结构。这里有一个简单的图表,展示了 fit()
中所有可覆盖的调用。在本教程的其余部分,我们将逐一介绍这些调用。
make_dataloader
fit() 方法的第一步是确保数据处于适合迭代的格式。train_set 和 valid_set 以及它们各自的关键字参数都会被传递。以下是实际代码
if not isinstance(train_set, DataLoader):
train_set = self.make_dataloader(
train_set, stage=sb.Stage.TRAIN, **train_loader_kwargs
)
if valid_set is not None and not isinstance(valid_set, DataLoader):
valid_set = self.make_dataloader(
valid_set,
stage=sb.Stage.VALID,
ckpt_prefix=None,
**valid_loader_kwargs,
)
默认情况下,此方法处理 DataLoader 创建的潜在复杂性,例如为分布式执行创建 DistributedSampler。与 fit()
调用中的所有其他方法一样,这可以通过在 Brain 的子类定义中创建 make_dataloader
方法来覆盖。
on_fit_start
除了 dataloader 之外,训练开始前还需要进行一些设置。以下是相关代码
self.on_fit_start()
if progressbar is None:
progressbar = self.progressbar
on_fit_start 方法负责一些重要的事情,最容易通过共享代码来解释
def on_fit_start(self):
self._compile_jit()
self._wrap_distributed()
self.init_optimizers()
if self.checkpointer is not None:
self.checkpointer.recover_if_possible(
device=torch.device(self.device)
)
基本上,此方法确保 torch 模块得到适当准备,包括 jit 编译、分布式包装以及使用所有相关参数初始化优化器。如果存在 checkpointer,优化器初始化还会将优化器参数添加到 checkpointer 中。最后,如果训练被中断,此方法会加载最新的检查点以恢复训练。
on_stage_start
下一部分开始 epoch 迭代,并准备迭代训练数据。要调整准备工作,可以覆盖 on_stage_start
方法,这将允许进行诸如创建容器来存储训练统计信息之类的操作。
for epoch in epoch_counter:
self.on_stage_start(Stage.TRAIN, epoch)
self.modules.train()
self.nonfinite_count = 0
if self.train_sampler is not None and hasattr(
self.train_sampler, "set_epoch"
):
self.train_sampler.set_epoch(epoch)
last_ckpt_time = time.time()
训练循环
本教程中最长的代码块专门用于训练和验证数据循环。但是,它们实际上只做三件重要的事情
在 DataLoader 中的每个批次上调用
fit_batch()
。跟踪平均损失并报告。
选择性地定期保存检查点,以便可以恢复训练。
以下是代码
enable = progressbar and sb.utils.distributed.if_main_process()
with tqdm(
train_set, initial=self.step, dynamic_ncols=True, disable=not enable,
) as t:
for batch in t:
self.step += 1
loss = self.fit_batch(batch)
self.avg_train_loss = self.update_average(
loss, self.avg_train_loss
)
t.set_postfix(train_loss=self.avg_train_loss)
if self.debug and self.step == self.debug_batches:
break
if (
self.checkpointer is not None
and self.ckpt_interval_minutes > 0
and time.time() - last_ckpt_time
>= self.ckpt_interval_minutes * 60.0
):
run_on_main(self._save_intra_epoch_ckpt)
last_ckpt_time = time.time()
也许最重要的一步是调用 fit_batch(batch)
,此处显示其精简版本
def fit_batch(self, batch):
outputs = self.compute_forward(batch, Stage.TRAIN)
loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
loss.backward()
if self.check_gradients(loss):
self.optimizer.step()
self.optimizer.zero_grad()
return loss.detach().cpu()
此方法调用了拟合最重要的两个方法,compute_forward
和 compute_objectives
,要使用 Brain 类,这两个方法都必须被覆盖。然后进行损失反向传播,并在应用更新之前检查梯度是否存在非有限值和过大的范数(默认情况下会自动剪裁过大的范数)。
on_stage_end
在训练循环结束时,会调用 on_stage_end
方法以进行潜在的清理操作,例如报告训练统计信息。
self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch)
self.avg_train_loss = 0.0
self.step = 0
验证循环
与训练循环非常相似,验证循环迭代处理 dataloader,一次处理一批数据。但是,此循环调用的是 evaluate_batch
而不是 fit_batch
,它不会进行梯度反向传播或应用任何更新。
if valid_set is not None:
self.on_stage_start(Stage.VALID, epoch)
self.modules.eval()
avg_valid_loss = 0.0
with torch.no_grad():
for batch in tqdm(
valid_set, dynamic_ncols=True, disable=not enable
):
self.step += 1
loss = self.evaluate_batch(batch, stage=Stage.VALID)
avg_valid_loss = self.update_average(
loss, avg_valid_loss
)
if self.debug and self.step == self.debug_batches:
break
on_stage_end
此方法与训练阶段的方法相同,但这次仅在单个进程上执行,因为该过程通常涉及写入文件。常见用途包括:更新学习率、保存检查点和记录 epoch 的统计信息。
self.step = 0
run_on_main(
self.on_stage_end,
args=[Stage.VALID, avg_valid_loss, epoch],
)
最后一件事情是简单检查调试模式,以便只运行几个 epoch。
if self.debug and epoch == self.debug_epochs:
break
恭喜,你现在了解了 fit()
方法的工作原理,以及它为何是运行实验的有用工具。训练模型的所有部分都被分解,恼人的细节都已处理,同时通过覆盖 Brain 类的任何部分仍然可以获得完全的灵活性。
evaluate() 方法
此方法以与 fit()
方法的验证数据非常相似的方式迭代处理测试数据,包括调用 on_stage_start
和 on_stage_end
。另外一个被调用的方法是 on_evaluate_start()
方法,它默认会加载用于评估的最佳检查点。
结论
Brain 类,特别是 fit()
方法,受到了其他流行的 Python 统计和机器学习库的启发,特别是 numpy、scipy、keras 和 PyTorch Lightning。
当我们添加关于 Brain 类更高级用法的教程时,我们会在这里添加链接。一些计划中的教程示例
使用 Brain 类编写 GAN
使用 Brain 类进行分布式训练
Brain 类的非基于梯度的用法
引用 SpeechBrain
如果你在研究或商业中使用了 SpeechBrain,请使用以下 BibTeX 条目引用它
@misc{speechbrainV1,
title={Open-Source Conversational AI with {SpeechBrain} 1.0},
author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
year={2024},
eprint={2407.00463},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
title={{SpeechBrain}: A General-Purpose Speech Toolkit},
author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
year={2021},
eprint={2106.04624},
archivePrefix={arXiv},
primaryClass={eess.AS},
note={arXiv:2106.04624}
}