speechbrain.utils.autocast 模块

此模块实现了与 torch.autocast(即自动混合精度)一起使用的实用工具和抽象。

作者
  • Sylvain de Langen 2023

  • Adel Moumen 2025

总结

AMPConfig

自动混合精度 (AMP) 配置。

TorchAutocast

一个上下文管理器,当在 GPU 上运行且指定的数据类型不是 float32 时,有条件地启用 torch.autocast。如果期望的数据类型是 float32,则绕过 autocasting,此上下文管理器表现为无操作。

函数

fwd_default_precision

用于前向方法的装饰器,默认情况下会禁用 autocast 并将所有浮点张量参数转换为指定的 dtype(非常类似于 torch.cuda.amp.custom_fwd)。

参考

class speechbrain.utils.autocast.AMPConfig(dtype: dtype)[source]

基类: object

自动混合精度 (AMP) 配置。

参数:

dtype (torch.dtype) – 用于 AMP 的 dtype。

dtype: dtype
classmethod from_name(name)[source]

从字符串名称创建 AMPConfig。

参数:

name (str) – 要创建的 AMPConfig 的名称。必须是 fp32fp16bf16 中的一个。

返回:

对应于名称的 AMPConfig。

返回类型:

AMPConfig

class speechbrain.utils.autocast.TorchAutocast(*args, **kwargs)[source]

基类: object

一个上下文管理器,当在 GPU 上运行且指定的数据类型不是 float32 时,有条件地启用 torch.autocast。如果期望的数据类型是 float32,则绕过 autocasting,此上下文管理器表现为无操作。

此管理器包装了 torch.autocast,当在 GPU 上运行并且指定的数据类型不是 float32 时,它会自动启用 autocasting。如果期望的数据类型是 float32,则绕过 autocasting,此上下文管理器表现为无操作。

参数:
__enter__()[source]

进入 autocast 上下文。

返回:

进入底层 autocast 上下文管理器的结果。

返回类型:

context

抛出:

RuntimeError – 如果在进入 autocast 上下文时发生错误,并且上下文提供了 ‘device’ 和 ‘fast_dtype’ 属性,则会抛出包含额外诊断信息的 RuntimeError。

__exit__(exc_type, exc_val, exc_tb)[source]

退出 autocast 上下文。

参数:
  • exc_type (type) – 如果发生异常,则为异常类型,否则为 None。

  • exc_val (Exception) – 如果发生异常,则为异常实例,否则为 None。

  • exc_tb (traceback) – 如果发生异常,则为 traceback 对象,否则为 None。

返回:

退出底层 autocast 上下文管理器的结果。

返回类型:

bool 或 None

speechbrain.utils.autocast.fwd_default_precision(fwd: Callable | None = None, cast_inputs: dtype | None = torch.float32)[source]

用于前向方法的装饰器,默认情况下会禁用 autocast 并将所有浮点张量参数转换为指定的 dtype(非常类似于 torch.cuda.amp.custom_fwd)。

被包装的 forward 方法将获得一个额外的 force_allow_autocast 关键字参数。当设置为 True 时,函数将忽略 cast_inputs 并且不会禁用 autocast,就像没有指定此装饰器一样。(因此,模块可以指定一个默认推荐的精度,用户可以在需要时覆盖该行为。)

请注意,截至 PyTorch 2.1.1,这影响 CUDA AMP。非 CUDA AMP 不受影响,并且不会进行任何输入张量转换!未来此函数可能会支持此用例。

当 autocast 激活时,此装饰器不会改变任何行为。

参数:
  • fwd (Optional[Callable]) –

    要包装的函数。如果省略,则返回装饰器的部分应用,例如允许 new_decorator = fwd_default_precision(cast_inputs=torch.float32)

    提示:如果您直接装饰一个函数,此参数已经隐式指定了。

  • cast_inputs (Optional[torch.dtype]) –

    如果不是 None(默认值为 torch.float32),则包装函数的所有浮点输入都将被转换为指定的类型。

    注意:当 autocasting 启用时,与 autocast 兼容的操作的输出张量可能是 autocast 数据类型。在不转换输入的情况下禁用 autocast 不会改变这一事实,因此即使在禁用 autocast 的区域内也可能发生低精度操作,此参数有助于避免这种情况(如果需要)。

返回类型:

被包装的函数