speechbrain.utils.epoch_loop 模块

实现了一个可保存检查点的 epoch 计数器(循环),可选地集成了提前停止功能。

作者
  • Aku Rouhe 2020

  • Davide Borra 2021

总结

EpochCounter

一个可以保存和恢复其状态的 epoch 计数器。

EpochCounterWithStopper

一个可以保存和恢复其状态的 epoch 计数器,通过跟踪目标评估指标集成了提前停止功能。

参考

class speechbrain.utils.epoch_loop.EpochCounter(limit)[source]

基类: object

一个可以保存和恢复其状态的 epoch 计数器。

将其用作 epoch 的迭代器。请注意,此迭代器提供从 [1 … limit] 的数字,而不是像 range(limit) 那样从 [0 … limit-1]。

参数:

limit (int) – 最大 epoch 数

示例

>>> from speechbrain.utils.checkpoints import Checkpointer
>>> tmpdir = getfixture('tmpdir')
>>> epoch_counter = EpochCounter(10)
>>> recoverer = Checkpointer(tmpdir, {"epoch": epoch_counter})
>>> recoverer.recover_if_possible()
>>> # Now after recovery,
>>> # the epoch starts from where it left off!
>>> for epoch in epoch_counter:
...     # Run training...
...     ckpt = recoverer.save_checkpoint()
class speechbrain.utils.epoch_loop.EpochCounterWithStopper(limit, limit_to_stop, limit_warmup, direction)[source]

基类: EpochCounter

一个可以保存和恢复其状态的 epoch 计数器,通过跟踪目标评估指标集成了提前停止功能。

参数:
  • limit (int) – 最大 epoch 数

  • limit_to_stop (int) – 性能没有改进的最大连续 epoch 数

  • limit_warmup (int) – 在开始检查提前停止之前等待的 epoch 数

  • direction ("max" or "min") – 优化目标评估指标的方向

示例

>>> limit = 10
>>> limit_to_stop = 5
>>> limit_warmup = 2
>>> direction = "min"
>>> epoch_counter = EpochCounterWithStopper(limit, limit_to_stop, limit_warmup, direction)
>>> for epoch in epoch_counter:
...     # Run training...
...     # Track a validation metric, (insert calculation here)
...     current_valid_metric = 0
...     # Update epoch counter so that we stop at the appropriate time
...     epoch_counter.update_metric(current_valid_metric)
...     print(epoch)
1
2
3
4
5
6
7
8
__next__()[source]

如果达到条件,则停止迭代。

update_metric(current_metric)[source]

更新状态以反映相关评估指标的最新值。

注意:应在每次验证循环中仅调用一次。

参数:

current_metric (float) – 用于做出停止决定的评估指标。