
Neural Network Adapters for Faster, Low-Memory Fine-Tuning

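This tutorial covers the SpeechBrain implementation of adapters such as LoRA. It shows how to integrate SpeechBrain-implemented adapters, custom adapters, and adapters from libraries such as PEFT into a pre-trained model.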

Prerequisites

Introduction and Background

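As pre-trained models become larger and more capable, there is growing interest in methods for adapting them to specific tasks in a memory-efficient way and within a reasonable time. One such technique is to freeze the original parameters and insert a small number of additional parameters, called "adapters", into the original model. Adapters can often match the performance of full fine-tuning while training only a small fraction of the parameters, which means faster and more memory-efficient fine-tuning [1]. One popular technique of this kind is Low-Rank Adaptation (LoRA) [2].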
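
To make the "small fraction of the parameters" claim concrete, here is a back-of-the-envelope sketch. It assumes the standard LoRA formulation from [2], in which a frozen weight matrix of shape (d_out, d_in) is augmented with two trainable low-rank factors of shapes (rank, d_in) and (d_out, rank). The layer size and the function names are hypothetical, chosen only for illustration.

```python
def full_linear_params(d_in: int, d_out: int) -> int:
    """Number of weights in a dense linear layer (bias ignored)."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable weights added by a LoRA adapter of the given rank."""
    return rank * d_in + d_out * rank


# Hypothetical layer size, for illustration only.
d_in, d_out, rank = 1024, 1024, 8
full = full_linear_params(d_in, d_out)
extra = lora_params(d_in, d_out, rank)
fraction = extra / full
print(f"full: {full}, LoRA: {extra}, fraction: {fraction:.4%}")
```

The fraction shrinks as the layer grows, because the dense layer scales with the product of its dimensions while the adapter scales with their sum.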

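On the software side, HuggingFace has released a popular adapter library called PEFT [3]. Our implementation includes some of the features of this library, as well as the ability to integrate PEFT adapters into a SpeechBrain model. We start with a basic YAML example, which you can try right away if you prefer to learn by experimentation.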

Relevant bibliography

  1. N. Houlsby, A. Giurgiu, S. Jastrzebski, B. Morrone, Q. De Laroussilhe, A. Gesmundo, M. Attariyan, and S. Gelly, “Parameter-efficient transfer learning for NLP.” In International Conference on Machine Learning, 2019.

  2. E.J. Hu, P. Wallis, Z. Allen-Zhu, Y. Li, S. Wang, L. Wang, and W. Chen, “LoRA: Low-rank adaptation of large language models.” In International Conference on Learning Representations, 2021.

  3. S. Mangrulkar, S. Gugger, L. Debut, Y. Belkada, S. Paul, and B. Bossan, “PEFT: State-of-the-art parameter-efficient fine-tuning methods.” GitHub Repository, 2022.

TL;DR

The short version of this tutorial is that you should use a section like the following in your HyperPyYAML file to create a model with adapters:

adapted_model: !new:speechbrain.nnet.adapters.AdaptedModel
    model_to_adapt: !ref <model>
    adapter_class: !name:speechbrain.nnet.adapters.LoRA
    all_linear: True
    unfrozen_layers: ["conv_1d_*"]
    adapter_kwargs:
        rank: 8

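Adding this section to the YAML takes the already-defined key model and adds a LoRA adapter to every linear layer, with the keyword argument rank=8. The all_linear and all_conv arguments add adapters to all linear layers or all convolution layers, respectively. By default, this class freezes the parameters of all non-adapted layers; the unfrozen_layers argument can be used to name layers that should also be trained, at the cost of more trainable parameters. The target_layers argument can be used to select specific layers to adapt. Both of these arguments support unix-style wildcards through Python's fnmatch library.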
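
To get a feel for how unix-style wildcards select layers, the sketch below applies fnmatch to a list of layer names. The layer names and the select_layers helper are hypothetical; this only illustrates the pattern syntax and is not how AdaptedModel is implemented internally.

```python
from fnmatch import fnmatch

# Hypothetical layer names, loosely in the style of named_modules().
layer_names = [
    "conv_1d_0",
    "conv_1d_1",
    "rnn.linear_in",
    "dnn.linear_0",
    "dnn.linear_1",
]


def select_layers(names, patterns):
    """Return the names that match at least one unix-style pattern."""
    return [n for n in names if any(fnmatch(n, p) for p in patterns)]


# "*" matches any run of characters, "?" matches exactly one character.
unfrozen = select_layers(layer_names, ["conv_1d_*"])
targeted = select_layers(layer_names, ["dnn.linear_?"])
print(unfrozen)
print(targeted)
```

Note that, unlike shell globbing on paths, the "*" wildcard in fnmatch also matches dots, so a pattern such as "dnn.*" would cover every layer nested under dnn.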

If the TL;DR is not enough and you would like a more detailed walk-through of an example, continue to the next section.

Detailed Tutorial

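We will demonstrate how to use an adapter on a template recipe that includes everything needed for a full training run. The first step is to pre-train a model so that we can add adapters to it later.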

!git clone --depth 1 --branch v1.0.2 https://github.com/speechbrain/speechbrain.git
!python -m pip install -e speechbrain
# In order to use speechbrain in this repo we have to add it to the path
import os, sys

sys.path.append(os.path.join(os.getcwd(), 'speechbrain'))
%cd speechbrain/templates/speech_recognition/ASR
/home/pplantinga/Documents/Repositories/speechbrain/docs/tutorials/nn/speechbrain/templates/speech_recognition/ASR
/home/pplantinga/Documents/Repositories/uvenv/lib/python3.12/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]
!python train.py train.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
WARNING:speechbrain.utils.train_logger:torchvision is not available - cannot save figures
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/autocast.py:68: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/CRDNN_BPE_960h_LM/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.fetching - Fetch lm.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch lm.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch tokenizer.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch tokenizer.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch asr.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch asr.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(path, map_location=device)
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/core.py:793: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled)
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 173.0M
* Total Number of Parameters: 173.0M
* Trainable Parameters represent 100.0000% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|████████████████████████| 760/760 [08:28<00:00,  1.49it/s, train_loss=1.35]
100%|█████████████████████████████████████████| 545/545 [01:28<00:00,  6.18it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 1.35 - valid loss: 1.31, valid CER: 7.71, valid WER: 20.06
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/CRDNN_BPE_960h_LM/2602/save/CKPT+2024-10-08+11-08-06+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/CRDNN_BPE_960h_LM/2602/save/CKPT+2024-10-08+11-08-06+00
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(path, map_location=device)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/nnet/schedulers.py:240: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  data = torch.load(path)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/processing/features.py:1311: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  stats = torch.load(path, map_location=device)
100%|███████████████████████████████████████| 1310/1310 [09:25<00:00,  2.32it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.30, test CER: 5.75, test WER: 17.57
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest

Inference

To show that this worked, let's first run inference on a single file. This code is adapted from transcribe_file.py:

from speechbrain.inference.ASR import EncoderDecoderASR
from speechbrain.utils.fetching import fetch, LocalStrategy

# Ensure all the needed files end up in the same place to load with the transcriber
save_dir = os.path.abspath("results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest")
fetch("lm.ckpt", "speechbrain/asr-crdnn-rnnlm-librispeech", save_dir, local_strategy=LocalStrategy.SYMLINK)
fetch("tokenizer.ckpt", "speechbrain/asr-crdnn-rnnlm-librispeech", save_dir, local_strategy=LocalStrategy.SYMLINK)
fetch("inference.yaml", os.getcwd(), save_dir, local_strategy=LocalStrategy.SYMLINK)

transcriber = EncoderDecoderASR.from_hparams(source=save_dir, hparams_file="inference.yaml")
speech_file = "../data/LibriSpeech/dev-clean-2/1272/135031/1272-135031-0015.flac"
transcriber.transcribe_file(speech_file)
INFO:speechbrain.utils.fetching:Fetch lm.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
INFO:speechbrain.utils.fetching:Fetch tokenizer.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/autocast.py:68: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: lm, tokenizer, model, normalizer
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(path, map_location=device)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/processing/features.py:1311: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  stats = torch.load(path, map_location=device)
'THE METAL FOREST IS IN THE GREAT DOMED CAVERN THE LARGEST IN ALL OUR DOMINIONS REPLIED CALICO ⁇ '

Adding Adapters

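Now that we have shown that the model at least works, let's add adapters. In essence, we need a new yaml file that adds adapters to the model, and then we train with that file. To get there, we start from the old yaml file and patch every part that has to change in order to train the adapted model.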

%%writefile train_lora.patch
--- train.yaml	2024-10-07 19:23:49.839501714 -0400
+++ train_lora.yaml	2024-10-07 19:25:40.340933091 -0400
@@ -30,7 +30,7 @@
 NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
 RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
 
-output_folder: !ref results/CRDNN_BPE_960h_LM/<seed>
+output_folder: !ref results/crdnn_lora/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
@@ -41,7 +41,7 @@
 # speechbrain HuggingFace repository. However, a local path pointing to a
 # directory containing the lm.ckpt and tokenizer.ckpt may also be specified
 # instead. E.g if you want to use your own LM / tokenizer.
-pretrained_path: speechbrain/asr-crdnn-rnnlm-librispeech
+pretrained_path: results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest
 
 
 # Path where data manifest files will be stored. The data manifest files are created by the
@@ -481,10 +481,9 @@
     ctc_lin: !ref <ctc_lin>
     seq_lin: !ref <seq_lin>
     normalize: !ref <normalize>
-    lm_model: !ref <lm_model>
 
 # Gathering all the submodels in a single model object.
-model: !new:torch.nn.ModuleList
+model_pretrained: !new:torch.nn.ModuleList
     - - !ref <encoder>
       - !ref <embedding>
       - !ref <decoder>
@@ -629,8 +628,31 @@
     loadables:
         lm: !ref <lm_model>
         tokenizer: !ref <tokenizer>
-        model: !ref <model>
+        model: !ref <model_pretrained>
     paths:
         lm: !ref <pretrained_path>/lm.ckpt
         tokenizer: !ref <pretrained_path>/tokenizer.ckpt
-        model: !ref <pretrained_path>/asr.ckpt
+        model: !ref <pretrained_path>/model.ckpt
+
+new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
+    model_to_adapt: !ref <encoder>
+    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    all_linear: True
+    manual_adapter_insertion: True
+    adapter_kwargs:
+        rank: 8
+
+new_decoder: !new:speechbrain.nnet.adapters.AdaptedModel
+    model_to_adapt: !ref <decoder>
+    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    all_linear: True
+    manual_adapter_insertion: True
+    adapter_kwargs:
+        rank: 8
+
+model: !new:torch.nn.ModuleList
+    - - !ref <new_encoder>
+      - !ref <embedding>
+      - !ref <new_decoder>
+      - !ref <ctc_lin>
+      - !ref <seq_lin>
Overwriting train_lora.patch
!patch train.yaml -i train_lora.patch -o train_lora.yaml
patching file train_lora.yaml (read from train.yaml)

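Because we use the pretrainer to load the pretrained parameters, we have to add code to the training script so that the adapters are inserted only after the pretrained parameters have been loaded.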

This is the reason for manual_adapter_insertion: True in the yaml file, and for the following short change to the training code:

%%writefile train_lora_py.patch
--- train.py	2024-10-07 14:57:21.534381751 -0400
+++ train_lora.py	2024-10-07 19:33:12.839895913 -0400
@@ -473,6 +473,8 @@
     # the path given in the YAML file). The tokenizer is loaded at the same time.
     hparams["pretrainer"].collect_files()
     hparams["pretrainer"].load_collected()
+    hparams["new_encoder"].insert_adapters()
+    hparams["new_decoder"].insert_adapters()
 
     # Trainer initialization
     asr_brain = ASR(
Overwriting train_lora_py.patch
!patch train.py -i train_lora_py.patch -o train_lora.py
patching file train_lora.py (read from train.py)
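
The ordering in the patch above (load the pretrained weights, then insert the adapters) can be motivated with a toy sketch. The custom adapter later in this tutorial stores the original layer under pretrained_module, so wrapping a layer changes its parameter names. The sketch assumes the same holds for any adapter inserted this way, uses plain sets of strings in place of real state dicts, and all key names are hypothetical. It illustrates the naming mismatch only and is not SpeechBrain's actual loading code.

```python
def wrap_with_adapter(param_names, layer):
    """Rename a layer's parameters as if it were wrapped by an adapter.

    Assumes the adapter keeps the original layer in an attribute called
    `pretrained_module` and adds two new (hypothetical) adapter weights.
    """
    renamed = set()
    for name in param_names:
        prefix, _, leaf = name.rpartition(".")
        if prefix == layer:
            renamed.add(f"{layer}.pretrained_module.{leaf}")
        else:
            renamed.add(name)
    renamed.add(f"{layer}.adapter_down_proj.weight")
    renamed.add(f"{layer}.adapter_up_proj.weight")
    return renamed


# Hypothetical checkpoint contents, saved before any adapter existed.
checkpoint_keys = {"dnn.linear.weight", "dnn.linear.bias"}

# Case 1: load first, then insert adapters (the order used in the patch).
model_keys_before = {"dnn.linear.weight", "dnn.linear.bias"}
unmatched_if_loaded_first = checkpoint_keys - model_keys_before

# Case 2: insert adapters first, then try to load the same checkpoint.
model_keys_after = wrap_with_adapter(model_keys_before, "dnn.linear")
unmatched_if_adapted_first = checkpoint_keys - model_keys_after

print(sorted(unmatched_if_loaded_first))
print(sorted(unmatched_if_adapted_first))
```

In the first case every checkpoint key has a home in the model; in the second, none of them do, which is the situation the manual insertion avoids.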

Training the Adapted Model

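Training proceeds exactly as before, using the updated lora files. The adapted model is designed to be a drop-in replacement. Note that the number of trainable parameters drops to 1.8M, about 1.5% of the 120.0M parameters reported for the adapted model and roughly 1% of the 173.0M parameters reported in the original run.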

!python train_lora.py train_lora.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
WARNING:speechbrain.utils.train_logger:torchvision is not available - cannot save figures
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/autocast.py:68: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/crdnn_lora/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(path, map_location=device)
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/core.py:793: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled)
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 1.8M
* Total Number of Parameters: 120.0M
* Trainable Parameters represent 1.4807% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|███████████████████████████| 760/760 [04:09<00:00,  3.04it/s, train_loss=1]
100%|█████████████████████████████████████████| 545/545 [01:40<00:00,  5.42it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 1.00 - valid loss: 1.26, valid CER: 7.29, valid WER: 19.16
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_lora/2602/save/CKPT+2024-10-08+11-23-53+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/crdnn_lora/2602/save/CKPT+2024-10-08+11-23-53+00
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(path, map_location=device)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/nnet/schedulers.py:240: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  data = torch.load(path)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/processing/features.py:1311: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  stats = torch.load(path, map_location=device)
100%|███████████████████████████████████████| 1310/1310 [12:55<00:00,  1.69it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.26, test CER: 5.62, test WER: 17.05
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_lora/2602/save/CKPT+latest
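
As a quick sanity check on the model statistics in the two training logs above, the sketch below recomputes the trainable share from the rounded figures (1.8M trainable, 120.0M total after adaptation, 173.0M total before). Because the logged counts are rounded, the result only approximately matches the 1.4807% printed by SpeechBrain.

```python
def percent(part: float, whole: float) -> float:
    """Return `part` as a percentage of `whole`."""
    return 100.0 * part / whole


# Rounded values copied from the training logs above (in millions).
trainable_lora = 1.8
total_lora = 120.0
total_original = 173.0

share_of_adapted = percent(trainable_lora, total_lora)
share_of_original = percent(trainable_lora, total_original)
print(f"{share_of_adapted:.2f}% of the adapted model")
print(f"{share_of_original:.2f}% of the original model")
```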

Custom Adapters

We designed this so that you can swap the SpeechBrain adapter for a PEFT adapter:

new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
    model_to_adapt: !ref <encoder>
-   adapter_class: !name:speechbrain.nnet.adapters.LoRA
+   adapter_class: !name:peft.tuners.lora.layer.Linear
    manual_adapter_insertion: True
    adapter_kwargs:
-       rank: 16
+       r: 16
+       adapter_name: lora

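But this trains essentially the same thing as before, so there is no need to go through it all again. Perhaps more interesting is designing a custom adapter: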

%%file pool_lora.py

import torch

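class PoolLoRA(torch.nn.Module):
    """LoRA variant that average-pools the input features before the
    low-rank down-projection, shrinking the down-projection by `stride`.

    Expects `target_module` to expose a 2D `weight` of shape
    (output_size, input_size), e.g. a torch.nn.Linear.
    """

    def __init__(self, target_module, stride=2, rank=16, alpha=1.0):
        super().__init__()

        input_size = target_module.weight.data.shape[1]
        output_size = target_module.weight.data.shape[0]

        # Disable gradient for pretrained module
        self.pretrained_module = target_module
        for param in self.pretrained_module.parameters():
            param.requires_grad = False
        device = target_module.weight.device

        # Pool over the feature (last) dimension, then project down to `rank`
        self.adapter_down_scale = torch.nn.AvgPool1d(kernel_size=stride)
        self.adapter_down_proj = torch.nn.Linear(
            input_size // stride, rank, bias=False, device=device
        )
        self.adapter_up_proj = torch.nn.Linear(
            rank, output_size, bias=False, device=device
        )
        # Zero-init the up-projection so the adapter starts as a no-op
        self.adapter_up_proj.weight.data.fill_(0.0)

        self.scaling = alpha / rank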

    def forward(self, x: torch.Tensor):
        """Applies the LoRA Adapter.

        Arguments
        ---------
        x: torch.Tensor
            Input tensor to the adapter module.

        Returns
        -------
        The linear outputs
        """
        x_pretrained = self.pretrained_module(x)

        x_downsample = self.adapter_down_proj(self.adapter_down_scale(x))
        x_pool_lora = self.adapter_up_proj(x_downsample)
        
        return x_pretrained + x_pool_lora * self.scaling
Overwriting pool_lora.py
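
To follow what PoolLoRA.forward computes without needing PyTorch, here is a dependency-free sketch of the same arithmetic on a single feature vector: non-overlapping average pooling with window stride, a down-projection to rank dimensions, an up-projection, and a scaled residual add. The helper names and the tiny sizes are made up for illustration. The zero-initialized up-projection mirrors fill_(0.0) in the class above, which is why the adapted layer starts out reproducing the pretrained output.

```python
def matvec(matrix, vector):
    """Multiply a list-of-rows matrix by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def avg_pool(vector, stride):
    """Average non-overlapping windows of length `stride` (remainder dropped)."""
    usable = len(vector) - len(vector) % stride
    return [
        sum(vector[i:i + stride]) / stride
        for i in range(0, usable, stride)
    ]


def pool_lora_forward(x, w_pretrained, w_down, w_up, stride, alpha):
    """Pretrained output plus the scaled, pooled low-rank update."""
    rank = len(w_down)
    pretrained_out = matvec(w_pretrained, x)
    update = matvec(w_up, matvec(w_down, avg_pool(x, stride)))
    scale = alpha / rank
    return [p + u * scale for p, u in zip(pretrained_out, update)]


# Tiny hypothetical sizes: 4 inputs, 2 outputs, rank 1, stride 2.
x = [1.0, 3.0, 2.0, 6.0]
w_pretrained = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
w_down = [[1.0, 1.0]]            # shape (rank, 4 // stride)
w_up_init = [[0.0], [0.0]]       # zero-initialized, shape (2, rank)
w_up_trained = [[0.5], [-1.0]]   # made-up values after some training

at_init = pool_lora_forward(x, w_pretrained, w_down, w_up_init, 2, 1.0)
later = pool_lora_forward(x, w_pretrained, w_down, w_up_trained, 2, 1.0)
print(at_init, later)
```

With stride 2 the down-projection sees half as many inputs as in plain LoRA, which is what lets the recipe below use rank 16 at a parameter cost close to that of rank 8.
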
%%writefile train_pool_lora.patch
--- train_lora.yaml	2024-10-07 22:44:02.767830301 -0400
+++ train_pool_lora.yaml	2024-10-07 22:41:30.602641301 -0400
@@ -30,7 +30,7 @@
 NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
 RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
 
-output_folder: !ref results/crdnn_lora/<seed>
+output_folder: !ref results/crdnn_pool_lora/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
@@ -636,19 +636,21 @@
 
 new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
     model_to_adapt: !ref <encoder>
-    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    adapter_class: !name:pool_lora.PoolLoRA
     all_linear: True
     manual_adapter_insertion: True
     adapter_kwargs:
-        rank: 8
+        stride: 2
+        rank: 16
 
 new_decoder: !new:speechbrain.nnet.adapters.AdaptedModel
     model_to_adapt: !ref <decoder>
-    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    adapter_class: !name:pool_lora.PoolLoRA
     all_linear: True
     manual_adapter_insertion: True
     adapter_kwargs:
-        rank: 8
+        stride: 2
+        rank: 16
 
 model: !new:torch.nn.ModuleList
     - - !ref <new_encoder>
Overwriting train_pool_lora.patch
!patch train_lora.yaml -i train_pool_lora.patch -o train_pool_lora.yaml
patching file train_pool_lora.yaml (read from train_lora.yaml)
!python train_lora.py train_pool_lora.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
WARNING:speechbrain.utils.train_logger:torchvision is not available - cannot save figures
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/autocast.py:68: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/crdnn_pool_lora/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(path, map_location=device)
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/core.py:793: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled)
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 1.8M
* Total Number of Parameters: 120.0M
* Trainable Parameters represent 1.5210% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|████████████████████████| 760/760 [04:19<00:00,  2.93it/s, train_loss=0.98]
100%|█████████████████████████████████████████| 545/545 [01:44<00:00,  5.24it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 9.80e-01 - valid loss: 1.26, valid CER: 7.18, valid WER: 18.92
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_pool_lora/2602/save/CKPT+2024-10-08+11-43-00+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/crdnn_pool_lora/2602/save/CKPT+2024-10-08+11-43-00+00
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(path, map_location=device)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/nnet/schedulers.py:240: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  data = torch.load(path)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/processing/features.py:1311: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  stats = torch.load(path, map_location=device)
100%|███████████████████████████████████████| 1310/1310 [14:07<00:00,  1.55it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.25, test CER: 5.61, test WER: 16.99
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_pool_lora/2602/save/CKPT+latest
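
For a side-by-side view of the three runs, the sketch below computes the relative test WER reduction from the values logged above (17.57 for the first run, 17.05 with LoRA, 16.99 with PoolLoRA). Both adapter runs start from the checkpoint produced by the first run, and each result comes from a single one-epoch run with one seed, so the differences should be read as indicative only.

```python
# Test WER values copied from the three logs above.
test_wer = {
    "baseline": 17.57,
    "lora": 17.05,
    "pool_lora": 16.99,
}


def relative_reduction(reference: float, value: float) -> float:
    """Relative reduction of `value` with respect to `reference`, in percent."""
    return 100.0 * (reference - value) / reference


for name in ("lora", "pool_lora"):
    gain = relative_reduction(test_wer["baseline"], test_wer[name])
    print(f"{name}: {gain:.2f}% relative WER reduction")
```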

Conclusion

That's it! Thanks for following along. Go forth and make cool adapters.

Citing SpeechBrain

If you use SpeechBrain in your research or business, please cite it using the following BibTeX entries:

@misc{speechbrainV1,
  title={Open-Source Conversational AI with {SpeechBrain} 1.0},
  author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
  year={2024},
  eprint={2407.00463},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
  year={2021},
  eprint={2106.04624},
  archivePrefix={arXiv},
  primaryClass={eess.AS},
  note={arXiv:2106.04624}
}