speechbrain.nnet.unet 模块
用于扩散模型的 UNet 模型实现
改编自 OpenAI guided diffusion,并进行了一些修改和新增功能 https://github.com/openai/guided-diffusion
MIT 许可证
版权所有 (c) 2021 OpenAI
特此免费授予获得本软件及相关文档文件(“软件”)副本的任何人,不受限制地处理本软件的权利,包括但不限于使用、复制、修改、合并、出版、分发、再许可和/或销售软件副本的权利,并允许向其提供软件的人这样做,但须遵守以下条件
上述版权声明和本许可声明应包含在软件的所有副本或重要部分中。
本软件按“原样”提供,不提供任何明示或暗示的担保,包括但不限于对适销性、特定用途适用性和不侵权的担保。在任何情况下,作者或版权持有人均不对因本软件或本软件的使用或其他交易而产生或与之相关的任何索赔、损害或其他责任承担责任,无论是在合同行为、侵权行为或其他方面。
- 作者
Artem Ploujnikov 2022
摘要
类
一个允许空间位置相互注意的注意力块。 |
|
二维注意力池化 |
|
带注意力和时间步 embedding 的半 UNet 模型。 |
|
带有可选卷积的下采样层。 |
|
一个应用下采样因子所需 padding 的 wrapper 模块 |
|
一个计算 embedding 向量投影到指定维度的简单模块 |
|
带注意力和时间步 embedding 的半 UNet 模型。 |
|
一个执行 QKV 注意力并以不同顺序分割的模块。 |
|
一个可以可选更改通道数量的残差块。 |
|
任何 forward() 方法将时间步 embedding 作为第二个参数的模块。 |
|
一个顺序模块,将时间步 embedding 作为额外输入传递给支持它的子模块。 |
|
带有注意力和时间步 embedding 的完整 UNet 模型。 |
|
一个基于 UNet 的变分自编码器 (VAE) 便捷类 - 用于构建潜在扩散模型 |
|
带有可选卷积的上采样层。 |
函数
创建一个 1D, 2D 或 3D 平均池化模块。 |
|
构建一个用于 embedding 投影的 embedding 模块字典 |
|
创建一个 1D, 2D 或 3D 卷积模块。 |
|
将模块参数归零并返回。 |
|
创建正弦时间步 embedding。 |
参考
- speechbrain.nnet.unet.fixup(module, use_fixup_init=True)[source]
将模块参数归零并返回。
- 参数:
module (torch.nn.Module) – 一个模块
use_fixup_init (bool) – 是否将参数归零。如果设置为 false,则该函数不做任何操作
- 返回类型:
固定后的模块
- speechbrain.nnet.unet.timestep_embedding(timesteps, dim, max_period=10000)[source]
创建正弦时间步 embedding。
- class speechbrain.nnet.unet.AttentionPool2d(spatial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int | None = None)[source]
基类:
Module
二维注意力池化
改编自 CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
- 参数:
示例
>>> attn_pool = AttentionPool2d( ... spatial_dim=64, ... embed_dim=16, ... num_heads_channels=2, ... output_dim=4 ... ) >>> x = torch.randn(4, 1, 64, 64) >>> x_pool = attn_pool(x) >>> x_pool.shape torch.Size([4, 4])
- class speechbrain.nnet.unet.TimestepBlock(*args, **kwargs)[source]
基类:
Module
任何 forward() 方法将时间步 embedding 作为第二个参数的模块。
- class speechbrain.nnet.unet.TimestepEmbedSequential(*args: Module)[source]
- class speechbrain.nnet.unet.TimestepEmbedSequential(arg: OrderedDict[str, Module])
基类:
Sequential
,TimestepBlock
一个顺序模块,将时间步 embedding 作为额外输入传递给支持它的子模块。
示例
>>> from speechbrain.nnet.linear import Linear >>> class MyBlock(TimestepBlock): ... def __init__(self, input_size, output_size, emb_size): ... super().__init__() ... self.lin = Linear( ... n_neurons=output_size, ... input_size=input_size ... ) ... self.emb_proj = Linear( ... n_neurons=output_size, ... input_size=emb_size, ... ) ... def forward(self, x, emb): ... return self.lin(x) + self.emb_proj(emb) >>> tes = TimestepEmbedSequential( ... MyBlock(128, 64, 16), ... Linear( ... n_neurons=32, ... input_size=64 ... ) ... ) >>> x = torch.randn(4, 10, 128) >>> emb = torch.randn(4, 10, 16) >>> out = tes(x, emb) >>> out.shape torch.Size([4, 10, 32])
- class speechbrain.nnet.unet.Upsample(channels, use_conv, dims=2, out_channels=None)[source]
基类:
Module
带有可选卷积的上采样层。
- 参数:
示例
>>> ups = Upsample(channels=4, use_conv=True, dims=2, out_channels=8) >>> x = torch.randn(8, 4, 32, 32) >>> x_up = ups(x) >>> x_up.shape torch.Size([8, 8, 64, 64])
- class speechbrain.nnet.unet.Downsample(channels, use_conv, dims=2, out_channels=None)[source]
基类:
Module
带有可选卷积的下采样层。
- 参数:
示例
>>> ups = Downsample(channels=4, use_conv=True, dims=2, out_channels=8) >>> x = torch.randn(8, 4, 32, 32) >>> x_up = ups(x) >>> x_up.shape torch.Size([8, 8, 16, 16])
- class speechbrain.nnet.unet.ResBlock(channels, emb_channels, dropout, out_channels=None, use_conv=False, dims=2, up=False, down=False, norm_num_groups=32, use_fixup_init=True)[source]
基类:
TimestepBlock
一个可以可选更改通道数量的残差块。
- 参数:
channels (int) – 输入通道的数量。
emb_channels (int) – 时间步嵌入通道的数量。
dropout (float) – dropout 比率。
out_channels (int) – 如果指定,则为输出通道的数量。
use_conv (bool) – 如果为 True 且指定了 out_channels,则使用空间卷积而不是较小的 1x1 卷积来更改跳跃连接中的通道。
dims (int) – 确定信号是 1D、2D 还是 3D。
up (bool) – 如果为 True,则使用此块进行上采样。
down (bool) – 如果为 True,则使用此块进行下采样。
norm_num_groups (int) – 组归一化的分组数
use_fixup_init (bool) – 是否使用 FixUp 初始化
示例
>>> res = ResBlock( ... channels=4, ... emb_channels=8, ... dropout=0.1, ... norm_num_groups=2, ... use_conv=True, ... ) >>> x = torch.randn(2, 4, 32, 32) >>> emb = torch.randn(2, 8) >>> res_out = res(x, emb) >>> res_out.shape torch.Size([2, 4, 32, 32])
- class speechbrain.nnet.unet.AttentionBlock(channels, num_heads=1, num_head_channels=-1, norm_num_groups=32, use_fixup_init=True)[source]
基类:
Module
一个注意力块,允许空间位置相互注意。最初移植自此处,但已调整适用于 N-d 情况。 https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66。
- 参数:
示例
>>> attn = AttentionBlock( ... channels=8, ... num_heads=4, ... num_head_channels=4, ... norm_num_groups=2 ... ) >>> x = torch.randn(4, 8, 16, 16) >>> out = attn(x) >>> out.shape torch.Size([4, 8, 16, 16])
- class speechbrain.nnet.unet.QKVAttention(n_heads)[source]
基类:
Module
一个执行 QKV 注意力并以不同顺序分割的模块。
- 参数:
n_heads (int) – 注意力头数量。
示例
>>> attn = QKVAttention(4) >>> n = 4 >>> c = 8 >>> h = 64 >>> w = 16 >>> qkv = torch.randn(4, (3 * h * c), w) >>> out = attn(qkv) >>> out.shape torch.Size([4, 512, 16])
- speechbrain.nnet.unet.build_emb_proj(emb_config, proj_dim=None, use_emb=None)[source]
构建一个用于 embedding 投影的 embedding 模块字典
- class speechbrain.nnet.unet.UNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, emb_dim=None, cond_emb=None, use_cond_emb=None, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, use_fixup_init=True)[source]
基类:
Module
带有注意力和时间步 embedding 的完整 UNet 模型。
- 参数:
in_channels (int) – 输入 torch.Tensor 中的通道数。
model_channels (int) – 模型的基通道数。
out_channels (int) – 输出 torch.Tensor 中的通道数。
num_res_blocks (int) – 每次下采样的残差块数量。
attention_resolutions (int) – 发生注意的下采样率集合。可以是集合、列表或元组。例如,如果包含 4,则在 4 倍下采样时将使用注意力。
dropout (float) – dropout 概率。
channel_mult (int) – UNet 各层通道乘数。
conv_resample (bool) – 如果为 True,则使用学习到的卷积进行上采样和下采样
dims (int) – 确定信号是 1D、2D 还是 3D。
emb_dim (int) – 时间嵌入维度(默认为 model_channels * 4)
cond_emb (dict) –
模型将以此为条件的嵌入
Example: {
- ”speaker”: {
“emb_dim”: 256
}, “label”: {
”emb_dim”: 12
}
}
use_cond_emb (dict) –
一个字典,其键对应于 cond_emb 中的键,值对应于用于开启和关闭嵌入的布尔值。这与 hparams 文件结合使用非常有用,可以通过简单的开关来开启和关闭嵌入
Example: {“speaker”: False, “label”: True}
num_heads (int) – 每个注意力层中的注意力头数量。
num_head_channels (int) – 如果指定,则忽略 num_heads,转而使用每个注意力头的固定通道宽度。
num_heads_upsample (int) – 与 num_heads 一起使用,为上采样设置不同的头数量。已弃用。
norm_num_groups (int) – 归一化中的分组数,默认为 32
resblock_updown (bool) – 是否使用残差块进行上/下采样。
use_fixup_init (bool) – 是否使用 FixUp 初始化
示例
>>> model = UNetModel( ... in_channels=3, ... model_channels=32, ... out_channels=1, ... num_res_blocks=1, ... attention_resolutions=[1] ... ) >>> x = torch.randn(4, 3, 16, 32) >>> ts = torch.tensor([10, 100, 50, 25]) >>> out = model(x, ts) >>> out.shape torch.Size([4, 1, 16, 32])
- class speechbrain.nnet.unet.EncoderUNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, pool=None, attention_pool_dim=None, out_kernel_size=3, use_fixup_init=True)[source]
基类:
Module
带有注意力和时间步嵌入的半 UNet 模型。用法详见 UNetModel。
- 参数:
in_channels (int) – 输入 torch.Tensor 中的通道数。
model_channels (int) – 模型的基通道数。
out_channels (int) – 输出 torch.Tensor 中的通道数。
num_res_blocks (int) – 每次下采样的残差块数量。
attention_resolutions (int) – 发生注意的下采样率集合。可以是集合、列表或元组。例如,如果包含 4,则在 4 倍下采样时将使用注意力。
dropout (float) – dropout 概率。
channel_mult (int) – UNet 各层通道乘数。
conv_resample (bool) – 如果为 True,则使用学习到的卷积进行上采样和下采样
dims (int) – 确定信号是 1D、2D 还是 3D。
num_heads (int) – 每个注意力层中的注意力头数量。
num_head_channels (int) – 如果指定,则忽略 num_heads,转而使用每个注意力头的固定通道宽度。
num_heads_upsample (int) – 与 num_heads 一起使用,为上采样设置不同的头数量。已弃用。
norm_num_groups (int) – 归一化中的分组数,默认为 32。
resblock_updown (bool) – 是否使用残差块进行上/下采样。
pool (str) – 要使用的池化类型,取值包括: [“adaptive”, “attention”, “spatial”, “spatial_v2”].
attention_pool_dim (int) – 应用注意力池化的维度。
out_kernel_size (int) – 输出卷积的核大小
use_fixup_init (bool) – 是否使用 FixUp 初始化
示例
>>> model = EncoderUNetModel( ... in_channels=3, ... model_channels=32, ... out_channels=1, ... num_res_blocks=1, ... attention_resolutions=[1] ... ) >>> x = torch.randn(4, 3, 16, 32) >>> ts = torch.tensor([10, 100, 50, 25]) >>> out = model(x, ts) >>> out.shape torch.Size([4, 1, 2, 4])
- class speechbrain.nnet.unet.EmbeddingProjection(emb_dim, proj_dim)[source]
基类:
Module
一个计算 embedding 向量投影到指定维度的简单模块
示例
>>> mod_emb_proj = EmbeddingProjection(emb_dim=16, proj_dim=64) >>> emb = torch.randn(4, 16) >>> emb_proj = mod_emb_proj(emb) >>> emb_proj.shape torch.Size([4, 64])
- class speechbrain.nnet.unet.DecoderUNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, resblock_updown=False, norm_num_groups=32, out_kernel_size=3, use_fixup_init=True)[source]
基类:
Module
带有注意力和时间步嵌入的半 UNet 模型。用法详见 UNet。
- 参数:
in_channels (int) – 输入 torch.Tensor 中的通道数。
model_channels (int) – 模型的基通道数。
out_channels (int) – 输出 torch.Tensor 中的通道数。
num_res_blocks (int) – 每次下采样的残差块数量。
attention_resolutions (int) – 发生注意的下采样率集合。可以是集合、列表或元组。例如,如果包含 4,则在 4 倍下采样时将使用注意力。
dropout (float) – dropout 概率。
channel_mult (int) – UNet 各层通道乘数。
conv_resample (bool) – 如果为 True,则使用学习到的卷积进行上采样和下采样
dims (int) – 确定信号是 1D、2D 还是 3D。
num_heads (int) – 每个注意力层中的注意力头数量。
num_head_channels (int) –
- 如果指定,则忽略 num_heads 并改用
每个注意力头的固定通道宽度。
num_heads_upsample (int) –
- 与 num_heads 一起使用,为上采样设置不同数量
的头。已弃用。
resblock_updown (bool) – 是否使用残差块进行上/下采样。
norm_num_groups (int) – 归一化中使用的分组数,默认为 32
out_kernel_size (int) – 输出核大小,默认为 3
use_fixup_init (bool) – 是否使用 FixUp 初始化
示例
>>> model = DecoderUNetModel( ... in_channels=1, ... model_channels=32, ... out_channels=3, ... num_res_blocks=1, ... attention_resolutions=[1] ... ) >>> x = torch.randn(4, 1, 2, 4) >>> ts = torch.tensor([10, 100, 50, 25]) >>> out = model(x, ts) >>> out.shape torch.Size([4, 3, 16, 32])
- class speechbrain.nnet.unet.DownsamplingPadding(factor, len_dim=2, dims=None)[source]
基类:
Module
一个应用下采样因子所需 padding 的 wrapper 模块
示例
>>> padding = DownsamplingPadding(factor=4, dims=[1, 2], len_dim=1) >>> x = torch.randn(4, 7, 14) >>> length = torch.tensor([1., 0.8, 1., 0.7]) >>> x, length_new = padding(x, length) >>> x.shape torch.Size([4, 8, 16]) >>> length_new tensor([0.8750, 0.7000, 0.8750, 0.6125])
- class speechbrain.nnet.unet.UNetNormalizingAutoencoder(in_channels, model_channels, encoder_out_channels, latent_channels, encoder_num_res_blocks, encoder_attention_resolutions, decoder_num_res_blocks, decoder_attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, out_kernel_size=3, len_dim=2, out_mask_value=0.0, latent_mask_value=0.0, use_fixup_norm=False, downsampling_padding=None)[source]
-
一个基于 UNet 的变分自编码器 (VAE) 便捷类 - 用于构建潜在扩散模型
- 参数:
in_channels (int) – 输入通道数量
model_channels (int) – UNet 编码器和解码器的卷积层中的通道数量
encoder_out_channels (int) – 编码器将输出的通道数量
latent_channels (int) – 潜在空间中的通道数量
encoder_num_res_blocks (int) – 编码器中的残差块数量
encoder_attention_resolutions (list) – 在编码器中应用注意力层的分辨率
decoder_num_res_blocks (int) – 解码器中的残差块数量
decoder_attention_resolutions (list) – 在编码器中应用注意力层的分辨率
dropout (float) – dropout 概率
channel_mult (tuple) – 每层的通道乘数
dims (int) – 要使用的卷积维度(1、2 或 3)
num_heads (int) – 注意力头数量
num_head_channels (int) – 注意力头中的通道数量
num_heads_upsample (int) – 上采样头的数量
norm_num_groups (int) – 归一化组数量,默认为 32
resblock_updown (bool) – 是否使用残差块进行上采样和下采样
out_kernel_size (int) – 输出卷积层的核大小(如果适用)
len_dim (int) – 输出的大小。
out_mask_value (float) – 掩蔽输出时填充的值。
latent_mask_value (float) – 掩蔽潜在变量时填充的值。
use_fixup_norm (bool) – 是否使用 FixUp 归一化
downsampling_padding (int) – 下采样中应用的填充量,默认为 2 ** len(channel_mult)
示例
>>> unet_ae = UNetNormalizingAutoencoder( ... in_channels=1, ... model_channels=4, ... encoder_out_channels=16, ... latent_channels=3, ... encoder_num_res_blocks=1, ... encoder_attention_resolutions=[], ... decoder_num_res_blocks=1, ... decoder_attention_resolutions=[], ... norm_num_groups=2, ... ) >>> x = torch.randn(4, 1, 32, 32) >>> x_enc = unet_ae.encode(x) >>> x_enc.shape torch.Size([4, 3, 4, 4]) >>> x_dec = unet_ae.decode(x_enc) >>> x_dec.shape torch.Size([4, 1, 32, 32])