speechbrain.lobes.models.transformer.TransformerLM 模块
Transformer 语言模型的实现。
作者 * Jianyuan Zhong * Samuele Cornell
摘要
类
这是 Transformer 语言模型的实现。 |
参考
- class speechbrain.lobes.models.transformer.TransformerLM.TransformerLM(vocab, d_model=512, nhead=8, num_encoder_layers=12, num_decoder_layers=0, d_ffn=2048, dropout=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, positional_encoding='fixed_abs_sine', normalize_before=False, d_embedding=None, max_length=2500, causal=True, attention_type='regularMHA', decoder_use_memory=False)[source]
-
这是 Transformer 语言模型的实现。
该架构基于论文 "Attention Is All You Need":https://arxiv.org/pdf/1706.03762.pdf
- 参数:
vocab (int) – Embedding 词汇表大小
d_model (int) – 编码器/解码器输入中预期的特征数量(默认为 512)。
nhead (int) – 多头注意力模型中的头数(默认为 8)。
num_encoder_layers (int) – 编码器中的子编码器层数(默认为 12)。
num_decoder_layers (int) – 解码器中的子解码器层数(默认为 0)。
d_ffn (int) – 前馈网络模型的维度(默认为 2048)。
dropout (float) – Dropout 值(默认为 0.1)。
activation (torch class) – 编码器/解码器中间层的激活函数,relu 或 gelu(默认为 relu)。
positional_encoding (str) – 位置编码类型,默认为 "fixed_abs_sine"
normalize_before (bool) – 是否在每层之前进行归一化。
d_embedding (int) – Embedding 大小,如果为 None 则使用 d_model。
max_length (int) – 最大序列长度,默认为 2500 个 token。
causal (bool) – 在解码中是否考虑未来信息,默认为 True。
attention_type (str) – 使用的注意力类型,可以是 "regularMHA" 或 "RelPosMHAXL"
decoder_use_memory (bool) – 是否在解码器中使用隐藏状态
示例
>>> src = torch.randint(0, 720, [8, 120]) >>> net = TransformerLM(720, 512, 8, 1, 0, 1024, activation=torch.nn.GELU) >>> enc_out = net.forward(src) >>> print(enc_out.shape) torch.Size([8, 120, 720])