speechbrain.nnet.unet 模块

用于扩散模型的 UNet 模型实现

改编自 OpenAI guided diffusion,并进行了一些修改和新增功能 https://github.com/openai/guided-diffusion

MIT 许可证

版权所有 (c) 2021 OpenAI

特此免费授予获得本软件及相关文档文件(“软件”)副本的任何人,不受限制地处理本软件的权利,包括但不限于使用、复制、修改、合并、出版、分发、再许可和/或销售软件副本的权利,并允许向其提供软件的人这样做,但须遵守以下条件

上述版权声明和本许可声明应包含在软件的所有副本或重要部分中。

本软件按“原样”提供,不提供任何明示或暗示的担保,包括但不限于对适销性、特定用途适用性和不侵权的担保。在任何情况下,作者或版权持有人均不对因本软件或本软件的使用或其他交易而产生或与之相关的任何索赔、损害或其他责任承担责任,无论是在合同行为、侵权行为或其他方面。

作者
  • Artem Ploujnikov 2022

摘要

AttentionBlock

一个允许空间位置相互注意的注意力块。

AttentionPool2d

二维注意力池化

DecoderUNetModel

带注意力和时间步 embedding 的半 UNet 模型。

Downsample

带有可选卷积的下采样层。

DownsamplingPadding

一个应用下采样因子所需 padding 的 wrapper 模块

EmbeddingProjection

一个计算 embedding 向量投影到指定维度的简单模块

EncoderUNetModel

带注意力和时间步 embedding 的半 UNet 模型。

QKVAttention

一个执行 QKV 注意力并以不同顺序分割的模块。

ResBlock

一个可以可选更改通道数量的残差块。

TimestepBlock

任何 forward() 方法将时间步 embedding 作为第二个参数的模块。

TimestepEmbedSequential

一个顺序模块,将时间步 embedding 作为额外输入传递给支持它的子模块。

UNetModel

带有注意力和时间步 embedding 的完整 UNet 模型。

UNetNormalizingAutoencoder

一个基于 UNet 的变分自编码器 (VAE) 便捷类 - 用于构建潜在扩散模型

Upsample

带有可选卷积的上采样层。

函数

avg_pool_nd

创建一个 1D, 2D 或 3D 平均池化模块。

build_emb_proj

构建一个用于 embedding 投影的 embedding 模块字典

conv_nd

创建一个 1D, 2D 或 3D 卷积模块。

fixup

将模块参数归零并返回。

timestep_embedding

创建正弦时间步 embedding。

参考

speechbrain.nnet.unet.fixup(module, use_fixup_init=True)[source]

将模块参数归零并返回。

参数:
  • module (torch.nn.Module) – 一个模块

  • use_fixup_init (bool) – 是否将参数归零。如果设置为 false,则该函数不做任何操作

返回类型:

固定后的模块

speechbrain.nnet.unet.conv_nd(dims, *args, **kwargs)[source]

创建一个 1D, 2D 或 3D 卷积模块。

参数:
  • dims (int) – 维度数量

  • *args (tuple)

  • **kwargs (dict) – 任何剩余参数都传递给构造函数

返回类型:

构建的 Conv 层

speechbrain.nnet.unet.avg_pool_nd(dims, *args, **kwargs)[source]

创建一个 1D, 2D 或 3D 平均池化模块。

speechbrain.nnet.unet.timestep_embedding(timesteps, dim, max_period=10000)[source]

创建正弦时间步 embedding。

参数:
  • timesteps (torch.Tensor) – 一个 1-D Tensor,包含 N 个索引,每个 batch 元素一个。这些索引可以是小数。

  • dim (int) – 输出的维度。

  • max_period (int) – 控制 embedding 的最小频率。

返回:

result – 一个 [N x dim] Tensor,包含位置 embedding。

返回类型:

torch.Tensor

class speechbrain.nnet.unet.AttentionPool2d(spatial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int | None = None)[source]

基类: Module

二维注意力池化

改编自 CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

参数:
  • spatial_dim (int) – 空间维度的大小

  • embed_dim (int) – embedding 维度

  • num_heads_channels (int) – 注意力头的数量

  • output_dim (int) – 输出维度

示例

>>> attn_pool = AttentionPool2d(
...     spatial_dim=64,
...     embed_dim=16,
...     num_heads_channels=2,
...     output_dim=4
... )
>>> x = torch.randn(4, 1, 64, 64)
>>> x_pool = attn_pool(x)
>>> x_pool.shape
torch.Size([4, 4])
forward(x)[source]

计算注意力前向传播

参数:

x (torch.Tensor) – 要进行注意力的张量

返回:

result – 注意力输出

返回类型:

torch.Tensor

class speechbrain.nnet.unet.TimestepBlock(*args, **kwargs)[source]

基类: Module

任何 forward() 方法将时间步 embedding 作为第二个参数的模块。

abstractmethod forward(x, emb=None)[source]

将模块应用于 x,并给定 emb 时间步 embedding。

参数:
  • x (torch.Tensor) – 数据张量

  • emb (torch.Tensor) – embedding 张量

class speechbrain.nnet.unet.TimestepEmbedSequential(*args: Module)[source]
class speechbrain.nnet.unet.TimestepEmbedSequential(arg: OrderedDict[str, Module])

基类: Sequential, TimestepBlock

一个顺序模块,将时间步 embedding 作为额外输入传递给支持它的子模块。

示例

>>> from speechbrain.nnet.linear import Linear
>>> class MyBlock(TimestepBlock):
...     def __init__(self, input_size, output_size, emb_size):
...         super().__init__()
...         self.lin = Linear(
...             n_neurons=output_size,
...             input_size=input_size
...         )
...         self.emb_proj = Linear(
...             n_neurons=output_size,
...             input_size=emb_size,
...         )
...     def forward(self, x, emb):
...         return self.lin(x) + self.emb_proj(emb)
>>> tes = TimestepEmbedSequential(
...     MyBlock(128, 64, 16),
...     Linear(
...         n_neurons=32,
...         input_size=64
...     )
... )
>>> x = torch.randn(4, 10, 128)
>>> emb = torch.randn(4, 10, 16)
>>> out = tes(x, emb)
>>> out.shape
torch.Size([4, 10, 32])
forward(x, emb=None)[source]

计算适用的时间步 embedding 的顺序传播

参数:
  • x (torch.Tensor) – 数据张量

  • emb (torch.Tensor) – 时间步 embedding

返回类型:

处理后的输入

class speechbrain.nnet.unet.Upsample(channels, use_conv, dims=2, out_channels=None)[source]

基类: Module

带有可选卷积的上采样层。

参数:
  • channels (torch.Tensor) – 输入和输出中的通道数。

  • use_conv (bool) – 一个布尔值,指示是否应用卷积。

  • dims (int) – 指示信号是 1D, 2D 还是 3D。如果是 3D,则上采样发生在内部两个维度。

  • out_channels (int) – 输出通道数。如果为 None,则与输入通道数相同。

示例

>>> ups = Upsample(channels=4, use_conv=True, dims=2, out_channels=8)
>>> x = torch.randn(8, 4, 32, 32)
>>> x_up = ups(x)
>>> x_up.shape
torch.Size([8, 8, 64, 64])
forward(x)[source]

计算上采样传播

参数:

x (torch.Tensor) – 层输入

返回:

result – 上采样输出

返回类型:

torch.Tensor

class speechbrain.nnet.unet.Downsample(channels, use_conv, dims=2, out_channels=None)[source]

基类: Module

带有可选卷积的下采样层。

参数:
  • channels (int) – 输入和输出中的通道数。

  • use_conv (bool) – 一个布尔值,指示是否应用卷积。

  • dims (int) – 确定信号是 1D、2D 还是 3D。如果是 3D,则下采样发生在内部两个维度。

  • out_channels (int) – 输出通道数。如果为 None,则与输入通道数相同。

示例

>>> ups = Downsample(channels=4, use_conv=True, dims=2, out_channels=8)
>>> x = torch.randn(8, 4, 32, 32)
>>> x_up = ups(x)
>>> x_up.shape
torch.Size([8, 8, 16, 16])
forward(x)[source]

计算下采样过程

参数:

x (torch.Tensor) – 层输入

返回:

result – 下采样的输出

返回类型:

torch.Tensor

class speechbrain.nnet.unet.ResBlock(channels, emb_channels, dropout, out_channels=None, use_conv=False, dims=2, up=False, down=False, norm_num_groups=32, use_fixup_init=True)[source]

基类: TimestepBlock

一个可以可选更改通道数量的残差块。

参数:
  • channels (int) – 输入通道的数量。

  • emb_channels (int) – 时间步嵌入通道的数量。

  • dropout (float) – dropout 比率。

  • out_channels (int) – 如果指定,则为输出通道的数量。

  • use_conv (bool) – 如果为 True 且指定了 out_channels,则使用空间卷积而不是较小的 1x1 卷积来更改跳跃连接中的通道。

  • dims (int) – 确定信号是 1D、2D 还是 3D。

  • up (bool) – 如果为 True,则使用此块进行上采样。

  • down (bool) – 如果为 True,则使用此块进行下采样。

  • norm_num_groups (int) – 组归一化的分组数

  • use_fixup_init (bool) – 是否使用 FixUp 初始化

示例

>>> res = ResBlock(
...     channels=4,
...     emb_channels=8,
...     dropout=0.1,
...     norm_num_groups=2,
...     use_conv=True,
... )
>>> x = torch.randn(2, 4, 32, 32)
>>> emb = torch.randn(2, 8)
>>> res_out = res(x, emb)
>>> res_out.shape
torch.Size([2, 4, 32, 32])
forward(x, emb=None)[source]

将块应用于 torch.Tensor,以时间步嵌入为条件。

参数:
  • x (torch.Tensor) – 一个 [N x C x …] 特征张量。

  • emb (torch.Tensor) – 一个 [N x emb_channels] 时间步嵌入张量。

返回:

result – 一个 [N x C x …] 输出张量。

返回类型:

torch.Tensor

class speechbrain.nnet.unet.AttentionBlock(channels, num_heads=1, num_head_channels=-1, norm_num_groups=32, use_fixup_init=True)[source]

基类: Module

一个注意力块,允许空间位置相互注意。最初移植自此处,但已调整适用于 N-d 情况。 https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66

参数:
  • channels (int) – 通道数量

  • num_heads (int) – 注意力头数量

  • num_head_channels (int) – 每个注意力头中的通道数量

  • norm_num_groups (int) – 用于组归一化的分组数

  • use_fixup_init (bool) – 是否使用 FixUp 初始化

示例

>>> attn = AttentionBlock(
...     channels=8,
...     num_heads=4,
...     num_head_channels=4,
...     norm_num_groups=2
... )
>>> x = torch.randn(4, 8, 16, 16)
>>> out = attn(x)
>>> out.shape
torch.Size([4, 8, 16, 16])
forward(x)[source]

完成前向传播

参数:

x (torch.Tensor) – 要进行注意操作的数据

返回:

result – 已应用注意的数据

返回类型:

torch.Tensor

class speechbrain.nnet.unet.QKVAttention(n_heads)[source]

基类: Module

一个执行 QKV 注意力并以不同顺序分割的模块。

参数:

n_heads (int) – 注意力头数量。

示例

>>> attn = QKVAttention(4)
>>> n = 4
>>> c = 8
>>> h = 64
>>> w = 16
>>> qkv = torch.randn(4, (3 * h * c), w)
>>> out = attn(qkv)
>>> out.shape
torch.Size([4, 512, 16])
forward(qkv)[source]

应用 QKV 注意力。

参数:

qkv (torch.Tensor) – 一个包含 Q、K 和 V 的 [N x (3 * H * C) x T] 张量。

返回:

result – 注意力后的 [N x (H * C) x T] 张量。

返回类型:

torch.Tensor

speechbrain.nnet.unet.build_emb_proj(emb_config, proj_dim=None, use_emb=None)[source]

构建一个用于 embedding 投影的 embedding 模块字典

参数:
  • emb_config (dict) – 配置字典

  • proj_dim (int) – 目标投影维度

  • use_emb (dict) – 一个可选的字典,包含用于开启和关闭嵌入的“开关”

返回:

result – 一个 ModuleDict,其中包含每个嵌入的模块

返回类型:

torch.nn.ModuleDict

class speechbrain.nnet.unet.UNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, emb_dim=None, cond_emb=None, use_cond_emb=None, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, use_fixup_init=True)[source]

基类: Module

带有注意力和时间步 embedding 的完整 UNet 模型。

参数:
  • in_channels (int) – 输入 torch.Tensor 中的通道数。

  • model_channels (int) – 模型的基通道数。

  • out_channels (int) – 输出 torch.Tensor 中的通道数。

  • num_res_blocks (int) – 每次下采样的残差块数量。

  • attention_resolutions (int) – 发生注意的下采样率集合。可以是集合、列表或元组。例如,如果包含 4,则在 4 倍下采样时将使用注意力。

  • dropout (float) – dropout 概率。

  • channel_mult (int) – UNet 各层通道乘数。

  • conv_resample (bool) – 如果为 True,则使用学习到的卷积进行上采样和下采样

  • dims (int) – 确定信号是 1D、2D 还是 3D。

  • emb_dim (int) – 时间嵌入维度(默认为 model_channels * 4)

  • cond_emb (dict) –

    模型将以此为条件的嵌入

    Example: {

    ”speaker”: {

    “emb_dim”: 256

    }, “label”: {

    ”emb_dim”: 12

    }

    }

  • use_cond_emb (dict) –

    一个字典,其键对应于 cond_emb 中的键,值对应于用于开启和关闭嵌入的布尔值。这与 hparams 文件结合使用非常有用,可以通过简单的开关来开启和关闭嵌入

    Example: {“speaker”: False, “label”: True}

  • num_heads (int) – 每个注意力层中的注意力头数量。

  • num_head_channels (int) – 如果指定,则忽略 num_heads,转而使用每个注意力头的固定通道宽度。

  • num_heads_upsample (int) – 与 num_heads 一起使用,为上采样设置不同的头数量。已弃用。

  • norm_num_groups (int) – 归一化中的分组数,默认为 32

  • resblock_updown (bool) – 是否使用残差块进行上/下采样。

  • use_fixup_init (bool) – 是否使用 FixUp 初始化

示例

>>> model = UNetModel(
...    in_channels=3,
...    model_channels=32,
...    out_channels=1,
...    num_res_blocks=1,
...    attention_resolutions=[1]
... )
>>> x = torch.randn(4, 3, 16, 32)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 1, 16, 32])
forward(x, timesteps, cond_emb=None)[source]

将模型应用于输入批次。

参数:
  • x (torch.Tensor) – 一个 [N x C x …] 输入张量。

  • timesteps (torch.Tensor) – 1 维的时间步批次。

  • cond_emb (dict) – 一个字符串到张量的条件嵌入字典(支持多个嵌入)

返回:

result – 一个 [N x C x …] 输出张量。

返回类型:

torch.Tensor

diffusion_forward(x, timesteps, cond_emb=None, length=None, out_mask_value=None, latent_mask_value=None)[source]

适合由 diffusion 包装的前向函数。对于此模型,length/out_mask_value/latent_mask_value 未使用且被丢弃。详见 forward()

class speechbrain.nnet.unet.EncoderUNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, pool=None, attention_pool_dim=None, out_kernel_size=3, use_fixup_init=True)[source]

基类: Module

带有注意力和时间步嵌入的半 UNet 模型。用法详见 UNetModel。

参数:
  • in_channels (int) – 输入 torch.Tensor 中的通道数。

  • model_channels (int) – 模型的基通道数。

  • out_channels (int) – 输出 torch.Tensor 中的通道数。

  • num_res_blocks (int) – 每次下采样的残差块数量。

  • attention_resolutions (int) – 发生注意的下采样率集合。可以是集合、列表或元组。例如,如果包含 4,则在 4 倍下采样时将使用注意力。

  • dropout (float) – dropout 概率。

  • channel_mult (int) – UNet 各层通道乘数。

  • conv_resample (bool) – 如果为 True,则使用学习到的卷积进行上采样和下采样

  • dims (int) – 确定信号是 1D、2D 还是 3D。

  • num_heads (int) – 每个注意力层中的注意力头数量。

  • num_head_channels (int) – 如果指定,则忽略 num_heads,转而使用每个注意力头的固定通道宽度。

  • num_heads_upsample (int) – 与 num_heads 一起使用,为上采样设置不同的头数量。已弃用。

  • norm_num_groups (int) – 归一化中的分组数,默认为 32。

  • resblock_updown (bool) – 是否使用残差块进行上/下采样。

  • pool (str) – 要使用的池化类型,取值包括: [“adaptive”, “attention”, “spatial”, “spatial_v2”].

  • attention_pool_dim (int) – 应用注意力池化的维度。

  • out_kernel_size (int) – 输出卷积的核大小

  • use_fixup_init (bool) – 是否使用 FixUp 初始化

示例

>>> model = EncoderUNetModel(
...    in_channels=3,
...    model_channels=32,
...    out_channels=1,
...    num_res_blocks=1,
...    attention_resolutions=[1]
... )
>>> x = torch.randn(4, 3, 16, 32)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 1, 2, 4])
forward(x, timesteps=None)[source]

将模型应用于输入批次。

参数:
  • x (torch.Tensor) – 一个 [N x C x …] 输入张量。

  • timesteps (torch.Tensor) – 1 维的时间步批次。

返回:

result – 一个 [N x K] 输出张量。

返回类型:

torch.Tensor

class speechbrain.nnet.unet.EmbeddingProjection(emb_dim, proj_dim)[source]

基类: Module

一个计算 embedding 向量投影到指定维度的简单模块

参数:
  • emb_dim (int) – 原始嵌入维度

  • proj_dim (int) – 目标投影空间的维度

示例

>>> mod_emb_proj = EmbeddingProjection(emb_dim=16, proj_dim=64)
>>> emb = torch.randn(4, 16)
>>> emb_proj = mod_emb_proj(emb)
>>> emb_proj.shape
torch.Size([4, 64])
forward(emb)[source]

计算前向传播

参数:

emb (torch.Tensor) – 原始嵌入张量

返回:

result – 目标嵌入空间

返回类型:

torch.Tensor

class speechbrain.nnet.unet.DecoderUNetModel(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, resblock_updown=False, norm_num_groups=32, out_kernel_size=3, use_fixup_init=True)[source]

基类: Module

带有注意力和时间步嵌入的半 UNet 模型。用法详见 UNet。

参数:
  • in_channels (int) – 输入 torch.Tensor 中的通道数。

  • model_channels (int) – 模型的基通道数。

  • out_channels (int) – 输出 torch.Tensor 中的通道数。

  • num_res_blocks (int) – 每次下采样的残差块数量。

  • attention_resolutions (int) – 发生注意的下采样率集合。可以是集合、列表或元组。例如,如果包含 4,则在 4 倍下采样时将使用注意力。

  • dropout (float) – dropout 概率。

  • channel_mult (int) – UNet 各层通道乘数。

  • conv_resample (bool) – 如果为 True,则使用学习到的卷积进行上采样和下采样

  • dims (int) – 确定信号是 1D、2D 还是 3D。

  • num_heads (int) – 每个注意力层中的注意力头数量。

  • num_head_channels (int) –

    如果指定,则忽略 num_heads 并改用

    每个注意力头的固定通道宽度。

  • num_heads_upsample (int) –

    与 num_heads 一起使用,为上采样设置不同数量

    的头。已弃用。

  • resblock_updown (bool) – 是否使用残差块进行上/下采样。

  • norm_num_groups (int) – 归一化中使用的分组数,默认为 32

  • out_kernel_size (int) – 输出核大小,默认为 3

  • use_fixup_init (bool) – 是否使用 FixUp 初始化

示例

>>> model = DecoderUNetModel(
...    in_channels=1,
...    model_channels=32,
...    out_channels=3,
...    num_res_blocks=1,
...    attention_resolutions=[1]
... )
>>> x = torch.randn(4, 1, 2, 4)
>>> ts = torch.tensor([10, 100, 50, 25])
>>> out = model(x, ts)
>>> out.shape
torch.Size([4, 3, 16, 32])
forward(x, timesteps=None)[source]

将模型应用于输入批次。

参数:
  • x (torch.Tensor) – 一个 [N x C x …] 输入张量。

  • timesteps (torch.Tensor) – 1 维的时间步批次。

返回:

result – 一个 [N x K] 输出张量。

返回类型:

torch.Tensor

class speechbrain.nnet.unet.DownsamplingPadding(factor, len_dim=2, dims=None)[source]

基类: Module

一个应用下采样因子所需 padding 的 wrapper 模块

参数:
  • factor (int) – 下采样/可整除因子

  • len_dim (int) – 长度变化的维度的索引

  • dims (list) – 要包含在填充中的维度列表

示例

>>> padding = DownsamplingPadding(factor=4, dims=[1, 2], len_dim=1)
>>> x = torch.randn(4, 7, 14)
>>> length = torch.tensor([1., 0.8, 1., 0.7])
>>> x, length_new = padding(x, length)
>>> x.shape
torch.Size([4, 8, 16])
>>> length_new
tensor([0.8750, 0.7000, 0.8750, 0.6125])
forward(x, length=None)[source]

应用填充

参数:
  • x (torch.Tensor) – 样本

  • length (torch.Tensor) – 长度张量

返回:

  • x_pad (torch.Tensor) – 填充后的张量

  • lens (torch.Tensor) – 新的调整后长度(如果适用)

class speechbrain.nnet.unet.UNetNormalizingAutoencoder(in_channels, model_channels, encoder_out_channels, latent_channels, encoder_num_res_blocks, encoder_attention_resolutions, decoder_num_res_blocks, decoder_attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), dims=2, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, norm_num_groups=32, resblock_updown=False, out_kernel_size=3, len_dim=2, out_mask_value=0.0, latent_mask_value=0.0, use_fixup_norm=False, downsampling_padding=None)[source]

基类: NormalizingAutoencoder

一个基于 UNet 的变分自编码器 (VAE) 便捷类 - 用于构建潜在扩散模型

参数:
  • in_channels (int) – 输入通道数量

  • model_channels (int) – UNet 编码器和解码器的卷积层中的通道数量

  • encoder_out_channels (int) – 编码器将输出的通道数量

  • latent_channels (int) – 潜在空间中的通道数量

  • encoder_num_res_blocks (int) – 编码器中的残差块数量

  • encoder_attention_resolutions (list) – 在编码器中应用注意力层的分辨率

  • decoder_num_res_blocks (int) – 解码器中的残差块数量

  • decoder_attention_resolutions (list) – 在编码器中应用注意力层的分辨率

  • dropout (float) – dropout 概率

  • channel_mult (tuple) – 每层的通道乘数

  • dims (int) – 要使用的卷积维度(1、2 或 3)

  • num_heads (int) – 注意力头数量

  • num_head_channels (int) – 注意力头中的通道数量

  • num_heads_upsample (int) – 上采样头的数量

  • norm_num_groups (int) – 归一化组数量,默认为 32

  • resblock_updown (bool) – 是否使用残差块进行上采样和下采样

  • out_kernel_size (int) – 输出卷积层的核大小(如果适用)

  • len_dim (int) – 输出的大小。

  • out_mask_value (float) – 掩蔽输出时填充的值。

  • latent_mask_value (float) – 掩蔽潜在变量时填充的值。

  • use_fixup_norm (bool) – 是否使用 FixUp 归一化

  • downsampling_padding (int) – 下采样中应用的填充量,默认为 2 ** len(channel_mult)

示例

>>> unet_ae = UNetNormalizingAutoencoder(
...     in_channels=1,
...     model_channels=4,
...     encoder_out_channels=16,
...     latent_channels=3,
...     encoder_num_res_blocks=1,
...     encoder_attention_resolutions=[],
...     decoder_num_res_blocks=1,
...     decoder_attention_resolutions=[],
...     norm_num_groups=2,
... )
>>> x = torch.randn(4, 1, 32, 32)
>>> x_enc = unet_ae.encode(x)
>>> x_enc.shape
torch.Size([4, 3, 4, 4])
>>> x_dec = unet_ae.decode(x_enc)
>>> x_dec.shape
torch.Size([4, 1, 32, 32])