speechbrain.utils.hpopt 模块
用于超参数优化的工具。此包装器对 Oríon 有可选依赖。
https://orion.readthedocs.io/en/stable/ https://github.com/Epistimio/orion
- 作者
Artem Ploujnikov 2021
摘要
类
一个通用的超参数拟合报告器,将结果以 JSON 格式输出到任意数据流,可供第三方工具读取 |
|
一个方便的上下文管理器,可以有条件地为 recipe 启用超参数优化。 |
|
超参数拟合报告器的基类 |
|
基于 Orion 的结果报告器实现 |
函数
尝试获取指定 mode 的报告器,如果不可用则回退到通用报告器 |
|
返回当前超参数优化 trial 的 ID,主要用于实验文件夹的命名。 |
|
用于注册超参数优化 mode 报告器实现的装饰器 |
|
初始化超参数优化上下文 |
|
如果当前报告器可用,则使用它报告结果。 |
参考
- speechbrain.utils.hpopt.hpopt_mode(mode)[source]
用于注册超参数优化 mode 报告器实现的装饰器
- 参数:
mode (str) – 要注册的 mode
- 返回:
f – 注册并返回报告器类的可调用函数
- 返回类型:
callable
示例
>>> @hpopt_mode("raw") ... class RawHyperparameterOptimizationReporter(HyperparameterOptimizationReporter): ... def __init__(self, *args, **kwargs): ... super().__init__( *args, **kwargs) ... def report_objective(self, result): ... objective = result[self.objective_key] ... print(f"Objective: {objective}")
>>> reporter = get_reporter("raw", objective_key="error") >>> result = {"error": 1.2, "train_loss": 7.2} >>> reporter.report_objective(result) Objective: 1.2
- class speechbrain.utils.hpopt.HyperparameterOptimizationReporter(objective_key)[source]
基类:
object
超参数拟合报告器的基类
- 参数:
objective_key (str) – 用于作为 objective 的 result 字典中的键
- property is_available
确定此报告器是否可用
- property trial_id
此 trial 的唯一 ID(用于文件夹命名)
- class speechbrain.utils.hpopt.GenericHyperparameterOptimizationReporter(reference_date=None, output=None, *args, **kwargs)[source]
基类:
HyperparameterOptimizationReporter
一个通用的超参数拟合报告器,将结果以 JSON 格式输出到任意数据流,可供第三方工具读取
- 参数:
reference_date (datetime.datetime) – 用于创建 trial ID 的日期
output (stream) – 用于报告结果的流
*args (tuple) – 转发给父类的参数
**kwargs (dict) – 转发给父类的关键字参数
- report_objective(result)[source]
报告超参数优化的 objective。
- 参数:
result (dict) – 包含运行结果的字典。
示例
>>> reporter = GenericHyperparameterOptimizationReporter( ... objective_key="error" ... ) >>> result = {"error": 1.2, "train_loss": 7.2} >>> reporter.report_objective(result) {"error": 1.2, "train_loss": 7.2, "objective": 1.2}
- property trial_id
此 trial 的唯一 ID(主要用于文件夹命名)
示例
>>> import datetime >>> reporter = GenericHyperparameterOptimizationReporter( ... objective_key="error", ... reference_date=datetime.datetime(2021, 1, 3) ... ) >>> print(reporter.trial_id) 20210103000000000000
- class speechbrain.utils.hpopt.OrionHyperparameterOptimizationReporter(objective_key)[source]
基类:
HyperparameterOptimizationReporter
基于 Orion 的结果报告器实现
- 参数:
objective_key (str) – 用于作为 objective 的 result 字典中的键
- property trial_id
此 trial 的唯一 ID(主要用于文件夹命名)
- property is_available
确定 Orion 是否可用。要使其可用,需要安装该库,并且至少需要设置 ORION_EXPERIMENT_NAME, ORION_EXPERIMENT_VERSION 或 ORION_TRIAL_ID 中的一个
- speechbrain.utils.hpopt.get_reporter(mode, *args, **kwargs)[source]
尝试获取指定 mode 的报告器,如果不可用则回退到通用报告器
- 参数:
- 返回:
reporter – 报告器实例
- 返回类型:
示例
>>> reporter = get_reporter("generic", objective_key="error") >>> result = {"error": 3.4, "train_loss": 1.2} >>> reporter.report_objective(result) {"error": 3.4, "train_loss": 1.2, "objective": 3.4}
- class speechbrain.utils.hpopt.HyperparameterOptimizationContext(reporter_args=None, reporter_kwargs=None)[source]
基类:
object
一个方便的上下文管理器,可以有条件地为 recipe 启用超参数优化。
示例
>>> ctx = HyperparameterOptimizationContext( ... reporter_args=[], ... reporter_kwargs={"objective_key": "error"} ... )
- parse_arguments(arg_list, pass_hpopt_args=None, pass_trial_id=True)[source]
为超参数优化增强的 speechbrain.parse_arguments 版本。
如果提供了名为 ‘hpopt’ 的参数,将启用超参数优化和报告。
如果参数值对应文件名,它将被读取为 hyperpyyaml 文件,并且内容将添加到“overrides”中。这对于超参数优化期间和完整训练期间某些超参数值不同的情况(例如 epoch 数、保存文件等)非常有用。
- 参数:
- 返回:
param_file (str) – 参数文件的位置。
run_opts (dict) – 运行选项,例如 distributed, device 等。
overrides (dict) – 要传递给
load_hyperpyyaml
的 overrides。
示例
>>> ctx = HyperparameterOptimizationContext() >>> arg_list = ["hparams.yaml", "--x", "1", "--y", "2"] >>> hparams_file, run_opts, overrides = ctx.parse_arguments(arg_list) >>> print(f"File: {hparams_file}, Overrides: {overrides}") File: hparams.yaml, Overrides: {'x': 1, 'y': 2}
- speechbrain.utils.hpopt.hyperparameter_optimization(*args, **kwargs)[source]
初始化超参数优化上下文
- 参数:
- 返回类型:
示例
>>> import sys >>> with hyperparameter_optimization(objective_key="error", output=sys.stdout) as hp_ctx: ... result = {"error": 3.5, "train_loss": 2.1} ... report_result(result) ... {"error": 3.5, "train_loss": 2.1, "objective": 3.5}