Open In Colab | Execute, view, or download this notebook on GitHub

Checkpointing

By checkpointing, we mean saving the model and all other necessary state information (such as optimizer parameters, which epoch, and which iteration) at a particular point in time. For experiments, there are two main motivations for this:

  • Resumption. Continuing an experiment midway. A compute-cluster job may run out of time or memory, or some simple error may stop the experiment script before it finishes. In that case, all progress not saved to disk is lost.

  • Early stopping. During training, performance should be monitored on a separate validation set, which provides an estimate of generalization. As training progresses, we expect the validation error to decrease at first. However, if training goes on too long, the validation error may start to rise again (due to overfitting). After training, we should go back to the model parameters that performed best on the validation set.

In addition, it is important to save the trained model parameters so that the model can be used outside the experiment script.

What the SpeechBrain Checkpointer does

The SpeechBrain Checkpointer simply orchestrates checkpointing. It keeps track of everything that should go into a checkpoint, how each entity is saved, and where checkpoints should be stored, and it handles loading and saving in one central place.

The Checkpointer itself does not actually write anything to disk. It either finds a suitable saving "hook" based on the entity's type (taking class inheritance into account), or you can supply custom hooks.

Install dependencies

%%capture
# Installing SpeechBrain via pip
BRANCH = 'develop'
!python -m pip install git+https://github.com/speechbrain/speechbrain.git@$BRANCH
import speechbrain as sb
import torch
from speechbrain.utils.checkpoints import Checkpointer

SpeechBrain checkpointing in a nutshell

Run the following code block multiple times. Each time you run it, it trains one epoch and then ends. Running the block again is similar to restarting the experiment script. Once all ten epochs have been completed, the epoch loop is depleted and the else branch recovers the best model.

# You have a model, an optimizer and an epoch counter:
model = torch.nn.Linear(1, 1, False)
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
epoch_counter = sb.utils.epoch_loop.EpochCounter(10)
# Create a checkpointer:
checkpoint_dir = "./nutshell_checkpoints"
checkpointer = Checkpointer(checkpoint_dir,
                            recoverables = {"mdl": model,
                                            "opt": optimizer,
                                            "epochs": epoch_counter})
# Now, before running the training epochs, you want to recover,
# if that is possible (if checkpoints have already been saved.)
# By default, the most recent checkpoint is loaded.
checkpointer.recover_if_possible()
# Then we run an epoch loop:
for epoch in epoch_counter:
    print(f"Starting epoch {epoch}.")
    # Training:
    optimizer.zero_grad()
    prediction = model(torch.tensor([1.]))
    loss = (prediction - torch.tensor([1.]))**2
    loss.backward()
    optimizer.step()
    print(f"Model prediction={prediction.item()}, loss={loss.item()}")
    # And finally at the end, save an end-of-epoch checkpoint:
    checkpointer.save_and_keep_only(meta={"loss":loss.item()})
    # Now, let's "crash" this code block:
    break
else:
    # After training (epoch loop is depleted),
    # we want to recover the best model:
    print("Epoch loop has finished.")
    checkpointer.recover_if_possible(min_key="loss")
    print(f"Best model parameter: {model.weight.data}")
    print(f"Achieved on epoch {epoch_counter.current}.")
# You can use this cell to reset, by deleting all checkpoints:
checkpointer.delete_checkpoints(num_to_keep=0)

What does a checkpoint look like?

The Checkpointer is given a top-level directory, where all checkpoints are saved:

checkpoint_dir = "./full_example_checkpoints"
checkpointer = Checkpointer(checkpoint_dir)

Each checkpoint should contain a number of things, such as model parameters and training progress.

# You have a model, an optimizer and an epoch counter:
model = torch.nn.Linear(1, 1, True)
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
epoch_counter = sb.utils.epoch_loop.EpochCounter(10)

Each entity that needs to be saved is added to the Checkpointer separately, with a unique key, i.e. a name:

checkpointer.add_recoverable("mdl", model)
checkpointer.add_recoverables({"opt": optimizer, "epoch": epoch_counter})

When saving a checkpoint, the Checkpointer creates a subdirectory in the top-level directory. That subdirectory represents the saved checkpoint. Inside the newly created directory, each entity handed to the Checkpointer gets its own file.

ckpt = checkpointer.save_checkpoint()
print("The checkpoint directory was:", ckpt.path)
for key, filepath in ckpt.paramfiles.items():
    print("The entity with key", key, "was saved to:", filepath)

What is inside each file?

That depends on the entity. The Checkpointer looks up a saving "hook" by type (taking class inheritance into account) and calls that hook with the object to save and a filepath.

Torch entities (Module, Optimizer) already have default save and load hooks:

torch_hook = sb.utils.checkpoints.get_default_hook(torch.nn.Linear(1,1), sb.utils.checkpoints.DEFAULT_SAVE_HOOKS)
print(torch_hook.__doc__)
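
For a torch Module, the default save hook simply stores the module's state_dict. Assuming that default hook was used for the "mdl" entity saved above (i.e. no custom hook was supplied), you can inspect the raw parameter file directly:

# A minimal sketch, assuming the default torch saving hook was used for "mdl",
# i.e. the file contains the module's state_dict saved with torch.save:
raw_state = torch.load(ckpt.paramfiles["mdl"], map_location="cpu")
print(raw_state)  # e.g. an OrderedDict with "weight" (and "bias") tensors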

Classes can register their own default save and load hooks:

@sb.utils.checkpoints.register_checkpoint_hooks
class Duck:
    def __init__(self):
        self.quacks = 0

    def quack(self):
        print("Quack!")
        self.quacks += 1
        print(f"I have already quacked {self.quacks} times.")

    @sb.utils.checkpoints.mark_as_saver
    def save(self, path):
        with open(path, "w") as fo:
            fo.write(str(self.quacks))

    @sb.utils.checkpoints.mark_as_loader
    def load(self, path, end_of_epoch):
        # Irrelevant for ducks:
        del end_of_epoch
        with open(path) as fi:
            self.quacks = int(fi.read())

duck = Duck()
duckpointer = Checkpointer("./duckpoints", {"ducky": duck})
duckpointer.recover_if_possible()
duck.quack()
_ = duckpointer.save_checkpoint()
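
Instead of registering class-level hooks, you can also pass one-off custom hooks for a single entity. The sketch below assumes the custom_save_hook and custom_load_hook keyword arguments of add_recoverable, and uses a plain Python dict of statistics as a hypothetical example entity:

# A minimal sketch, assuming add_recoverable accepts custom_save_hook /
# custom_load_hook; the "stats" dict here is purely illustrative.
stats = {"seen_examples": 0}

def save_stats(obj, path):
    with open(path, "w") as fo:
        fo.write(str(obj["seen_examples"]))

def load_stats(obj, path, end_of_epoch):
    del end_of_epoch  # Not needed here, just like for the ducks
    with open(path) as fi:
        obj["seen_examples"] = int(fi.read())

duckpointer.add_recoverable("stats", stats,
                            custom_save_hook=save_stats,
                            custom_load_hook=load_stats)
_ = duckpointer.save_checkpoint()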

Meta info

Checkpoints also store a dictionary of meta information. You can put e.g. the validation loss or some other metric there. By default, only the Unix time is saved.

# Following from the cells of "What does a checkpoint look like?"
checkpointer.save_checkpoint(meta={"loss": 15.5, "validation-type": "fast", "num-examples": 3})
ckpt = checkpointer.save_checkpoint(meta={"loss": 14.4, "validation-type": "full"})
print(ckpt.meta)

This meta info can be used to load the best checkpoint, rather than just the most recent one:

ckpt = checkpointer.recover_if_possible(min_key="loss")
print(ckpt.meta)

More advanced filters are also available:

checkpointer.save_checkpoint(meta={"loss": 12.1, "validation-type": "fast", "num-examples": 2})
ckpt = checkpointer.recover_if_possible(importance_key=lambda ckpt: -ckpt.meta["loss"]/ckpt.meta["num-examples"],
                                        ckpt_predicate=lambda ckpt: ckpt.meta.get("validation-type") == "fast")
print(ckpt.meta)

Keeping a limited number of checkpoints

Neural network models can be very large these days, and we don't need to keep every checkpoint. Checkpoints can be deleted explicitly, and the same kinds of filters as in recovery can be used:

checkpointer.delete_checkpoints(num_to_keep=1, ckpt_predicate=lambda ckpt: "validation-type" not in ckpt.meta)

For convenience, there is also a method that saves and deletes in one go:

checkpointer.save_and_keep_only(meta={"loss": 13.1, "validation-type": "full"},
                                num_to_keep = 2,
                                ckpt_predicate=lambda ckpt: ckpt.meta.get("validation-type") == "full")
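
If the cleanup should always spare the best-performing checkpoint, save_and_keep_only can also be told (via its min_keys / max_keys arguments) to keep the checkpoints that minimize or maximize a given meta key, in addition to the most recent ones:

# Keep the most recent checkpoint and, additionally, the one with the lowest
# "loss" in its meta info (a common pattern in SpeechBrain recipes):
checkpointer.save_and_keep_only(meta={"loss": 11.9, "validation-type": "full"},
                                num_to_keep=1,
                                min_keys=["loss"])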

Pretraining / parameter transfer

Transferring parameters from a pretrained model is not the same as recovery, although they share some similarities.

Finding the best checkpoint

The first step in parameter transfer is to find the ideal set of parameters. You can use the Checkpointer for this: point an empty Checkpointer at the experiment's top-level checkpoint directory, and then find a checkpoint according to your criteria.

ckpt_finder = Checkpointer(checkpoint_dir)
best_ckpt = ckpt_finder.find_checkpoint(min_key="loss",
                                        ckpt_predicate=lambda ckpt: ckpt.meta.get("validation-type") == "full")
best_paramfile = best_ckpt.paramfiles["mdl"]
print("The best parameters were stored in:", best_paramfile)

Transferring the parameters

There is no general recipe for parameter transfer, and in many cases you may need to write some custom code to connect the incoming parameters to the new model.

SpeechBrain provides an almost trivial implementation of parameter transfer to another torch Module: it simply loads the layers that match (by name) and ignores saved parameters for which no matching layer is found.

finetune_mdl = torch.nn.Linear(1,1,False) #This one doesn't have bias!
with torch.no_grad():
    print("Before:", finetune_mdl(torch.tensor([1.])))
    sb.utils.checkpoints.torch_parameter_transfer(finetune_mdl, best_paramfile)
    print("And after:", finetune_mdl(torch.tensor([1.])))

Orchestrating transfer

SpeechBrain has a parameter-transfer orchestrator similar to the Checkpointer: speechbrain.utils.parameter_transfer.Pretrainer. Its main purpose is to implement parameter downloading and loading for speechbrain.pretrained.Pretrained subclasses (such as EncoderDecoderASR) and to help write recipes that are easy to share.

Like the Checkpointer, the Pretrainer handles the mapping of parameter files to instances and calls the transfer code (implemented as hooks, similar to checkpoint loading).
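
As a rough sketch of how this looks in practice, you map each loadable instance to a parameter file (a local path here, but Pretrainer can also fetch from a URL or a HuggingFace repo), collect the files, and then load them. The snippet below reuses best_paramfile from the checkpoint-finding step above and is only meant to illustrate the call pattern:

# A minimal sketch of Pretrainer usage; the parameter source is assumed to be
# the local file found earlier, but it could also be a URL or HuggingFace repo.
from speechbrain.utils.parameter_transfer import Pretrainer

pretrain_mdl = torch.nn.Linear(1, 1, False)
pretrainer = Pretrainer(collect_in="./pretrainer_collected",
                        loadables={"mdl": pretrain_mdl},
                        paths={"mdl": str(best_paramfile)})
pretrainer.collect_files()   # gather the parameter files into collect_in
pretrainer.load_collected()  # call the transfer hooks on each loadable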

Citing SpeechBrain

If you use SpeechBrain in research or business, please cite it using the following BibTeX entries:

@misc{speechbrainV1,
  title={Open-Source Conversational AI with {SpeechBrain} 1.0},
  author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
  year={2024},
  eprint={2407.00463},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
  year={2021},
  eprint={2106.04624},
  archivePrefix={arXiv},
  primaryClass={eess.AS},
  note={arXiv:2106.04624}
}