speechbrain.utils.epoch_loop 模块
实现了一个可保存检查点的 epoch 计数器(循环),可选地集成了提前停止功能。
- 作者
Aku Rouhe 2020
Davide Borra 2021
总结
类
一个可以保存和恢复其状态的 epoch 计数器。 |
|
一个可以保存和恢复其状态的 epoch 计数器,通过跟踪目标评估指标集成了提前停止功能。 |
参考
- class speechbrain.utils.epoch_loop.EpochCounter(limit)[source]
基类:
object
一个可以保存和恢复其状态的 epoch 计数器。
将其用作 epoch 的迭代器。请注意,此迭代器提供从 [1 … limit] 的数字,而不是像 range(limit) 那样从 [0 … limit-1]。
- 参数:
limit (int) – 最大 epoch 数
示例
>>> from speechbrain.utils.checkpoints import Checkpointer >>> tmpdir = getfixture('tmpdir') >>> epoch_counter = EpochCounter(10) >>> recoverer = Checkpointer(tmpdir, {"epoch": epoch_counter}) >>> recoverer.recover_if_possible() >>> # Now after recovery, >>> # the epoch starts from where it left off! >>> for epoch in epoch_counter: ... # Run training... ... ckpt = recoverer.save_checkpoint()
- class speechbrain.utils.epoch_loop.EpochCounterWithStopper(limit, limit_to_stop, limit_warmup, direction)[source]
基类:
EpochCounter
一个可以保存和恢复其状态的 epoch 计数器,通过跟踪目标评估指标集成了提前停止功能。
- 参数:
示例
>>> limit = 10 >>> limit_to_stop = 5 >>> limit_warmup = 2 >>> direction = "min" >>> epoch_counter = EpochCounterWithStopper(limit, limit_to_stop, limit_warmup, direction) >>> for epoch in epoch_counter: ... # Run training... ... # Track a validation metric, (insert calculation here) ... current_valid_metric = 0 ... # Update epoch counter so that we stop at the appropriate time ... epoch_counter.update_metric(current_valid_metric) ... print(epoch) 1 2 3 4 5 6 7 8