speechbrain.utils.semdist 模块
提供用于 SemDist 指标的度量类。
作者 * Sylvain de Langen 2024
摘要
类
实现 SemDist 指标的基类,用于估计每对目标文本和预测文本之间的单一余弦相似度的变体。 |
|
使用提供的 HuggingFace Transformers 文本编码器计算 SemDist 指标。 |
参考
- class speechbrain.utils.semdist.BaseSemDistStats(embed_function: Callable[[List[str]], Tensor], scale: float = 1000.0, batch_size: int = 64)[source]
基类:
MetricStats
实现 SemDist 指标的基类,用于估计每对目标文本和预测文本之间的单一余弦相似度的变体。SemDist 指标在论文 Evaluating User Perception of Speech Recognition System Quality with Semantic Distance Metric 中有描述。
- 参数:
- summarize(field=None)[source]
总结 SemDist 指标得分。执行实际的嵌入函数调用和 SemDist 计算。
所有字段的完整集合:-
semdist
:所有 utterance 的平均 SemDist,乘以可选在初始化时指定的缩放因子。
此外,此函数会为每对句子填充一个
scores
列表。该列表的每个条目都是一个字典,包含以下字段:-key
:utterance 的 ID。-semdist
:该 utterance 的 SemDist,乘以缩放因子。- 参数:
field (str, optional) – 如果只对其中一个字段感兴趣,则返回该字段。如果指定,将返回一个单独的
float
值,否则返回一个字典。- 返回值:
dict from str to float, if
field is None
– 上述文档中描述的字段字典。float, if
field is not None
– 由field
选择的单个字段。
- class speechbrain.utils.semdist.SemDistStats(lm, : Literal['meanpool', 'cls'] = 'meanpool', *args, **kwargs)[source]
基类:
BaseSemDistStats
使用提供的 HuggingFace Transformers 文本编码器计算 SemDist 指标。
- 参数:
lm (speechbrain.integrations.huggingface.TextEncoder) – 用作 LM 的 HF Transformers 分词器和文本编码器包装器。
method ("meanpool" or "cls") –
"meanpool"
(默认):计算所有上下文化的嵌入(不包括填充标记)的均值。"cls"
:仅使用第一个上下文化的嵌入,在使用类似 BERT 的分词器时,它是[CLS]
标记,通常旨在捕获分类信息。
*args – 传递给基类构造函数的额外位置参数。
**kwargs – 传递给基类构造函数的额外关键字参数。