speechbrain.utils.autocast 模块
此模块实现了与 torch.autocast
(即自动混合精度)一起使用的实用工具和抽象。
- 作者
Sylvain de Langen 2023
Adel Moumen 2025
总结
类
自动混合精度 (AMP) 配置。 |
|
一个上下文管理器,当在 GPU 上运行且指定的数据类型不是 float32 时,有条件地启用 |
函数
用于前向方法的装饰器,默认情况下会禁用 autocast 并将所有浮点张量参数转换为指定的 dtype(非常类似于 |
参考
- class speechbrain.utils.autocast.AMPConfig(dtype: dtype)[source]
基类:
object
自动混合精度 (AMP) 配置。
- 参数:
dtype (torch.dtype) – 用于 AMP 的 dtype。
- dtype: dtype
- class speechbrain.utils.autocast.TorchAutocast(*args, **kwargs)[source]
基类:
object
一个上下文管理器,当在 GPU 上运行且指定的数据类型不是 float32 时,有条件地启用
torch.autocast
。如果期望的数据类型是 float32,则绕过 autocasting,此上下文管理器表现为无操作。此管理器包装了
torch.autocast
,当在 GPU 上运行并且指定的数据类型不是 float32 时,它会自动启用 autocasting。如果期望的数据类型是 float32,则绕过 autocasting,此上下文管理器表现为无操作。- 参数:
*args (tuple) – 转发给
torch.autocast
的位置参数。请参阅 PyTorch 文档:https://pytorch.ac.cn/docs/stable/amp.html#torch.autocast**kwargs (dict) – 转发给
torch.autocast
的关键字参数。通常包括用于指定所需精度的dtype
参数。更多详细信息请参阅 PyTorch 文档。
- __enter__()[source]
进入 autocast 上下文。
- 返回:
进入底层 autocast 上下文管理器的结果。
- 返回类型:
context
- 抛出:
RuntimeError – 如果在进入 autocast 上下文时发生错误,并且上下文提供了 ‘device’ 和 ‘fast_dtype’ 属性,则会抛出包含额外诊断信息的 RuntimeError。
- speechbrain.utils.autocast.fwd_default_precision(fwd: Callable | None = None, cast_inputs: dtype | None = torch.float32)[source]
用于前向方法的装饰器,默认情况下会禁用 autocast 并将所有浮点张量参数转换为指定的 dtype(非常类似于
torch.cuda.amp.custom_fwd
)。被包装的 forward 方法将获得一个额外的
force_allow_autocast
关键字参数。当设置为True
时,函数将忽略cast_inputs
并且不会禁用 autocast,就像没有指定此装饰器一样。(因此,模块可以指定一个默认推荐的精度,用户可以在需要时覆盖该行为。)请注意,截至 PyTorch 2.1.1,这仅影响 CUDA AMP。非 CUDA AMP 不受影响,并且不会进行任何输入张量转换!未来此函数可能会支持此用例。
当 autocast 未激活时,此装饰器不会改变任何行为。
- 参数:
fwd (Optional[Callable]) –
要包装的函数。如果省略,则返回装饰器的部分应用,例如允许
new_decorator = fwd_default_precision(cast_inputs=torch.float32)
。提示:如果您直接装饰一个函数,此参数已经隐式指定了。
cast_inputs (Optional[torch.dtype]) –
如果不是
None
(默认值为torch.float32
),则包装函数的所有浮点输入都将被转换为指定的类型。注意:当 autocasting 启用时,与 autocast 兼容的操作的输出张量可能是 autocast 数据类型。在不转换输入的情况下禁用 autocast 不会改变这一事实,因此即使在禁用 autocast 的区域内也可能发生低精度操作,此参数有助于避免这种情况(如果需要)。
- 返回类型:
被包装的函数