speechbrain.utils.edit_distance 模块
编辑距离和 WER 计算。
- 作者
Aku Rouhe 2020
Salima Mdhaffar 2021
摘要
函数
计算一个批次的词错误率和相关计数。 |
|
从编辑操作表中获取编辑距离对齐。 |
|
计算编辑操作表中最短编辑路径中的编辑操作次数。 |
|
a 和 b 之间的编辑操作表。 |
|
查找词错误率最高的 K 位说话人。 |
|
查找词错误率最高的 K 个话语。 |
|
按说话人分组计算词错误率和另一关键信息。 |
|
计算每个话语的丰富 WER 信息。 |
|
|
|
根据 details_by_utterance 的输出计算汇总统计信息 |
参考
- speechbrain.utils.edit_distance.accumulatable_wer_stats(refs, hyps, stats={}, equality_comparator: ~typing.Callable[[str, str], bool] = <function _str_equals>)[source]
计算一个批次的词错误率和相关计数。
也可以通过将输出传递回函数调用以便处理下一个批次,从而累积多个批次的计数。
- 参数:
refs (iterable) – 参考序列的批次。
hyps (iterable) – 假设序列的批次。
stats (collections.Counter) – 正在运行的统计信息。将此函数的输出作为此参数传递回来以累积计数。最好是自己初始化统计信息;此时应使用空的 collections.Counter()。
equality_comparator (Callable[[str, str], bool]) – 用于检查两个词是否相等的函数。
- 返回值:
更新后的运行统计信息,包含以下键:
“WER” - 词错误率
“insertions” - 插入次数
“deletions” - 删除次数
“substitutions” - 替换次数
“num_ref_tokens” - 参考 Token 数量
- 返回类型:
示例
>>> import collections >>> batches = [[[[1,2,3],[4,5,6]], [[1,2,4],[5,6]]], ... [[[7,8], [9]], [[7,8], [10]]]] >>> stats = collections.Counter() >>> for batch in batches: ... refs, hyps = batch ... stats = accumulatable_wer_stats(refs, hyps, stats) >>> print("%WER {WER:.2f}, {num_ref_tokens} ref tokens".format(**stats)) %WER 33.33, 9 ref tokens
- speechbrain.utils.edit_distance.op_table(a, b, equality_comparator: ~typing.Callable[[str, str], bool] = <function _str_equals>)[source]
a 和 b 之间的编辑操作表。
计算编辑操作表,主要用于计算词错误率。该表的大小为
[|a|+1, |b|+1]
,表中每个点(i, j)
都有一个编辑操作。可以通过确定性地回溯这些编辑操作,找到从a[:i-1]
到b[:j-1]
的最短编辑路径。索引为零 (i=0
或j=0
) 对应于空序列。该算法本身是众所周知的,参见
请注意,在某些情况下,存在多条有效的编辑操作路径,它们会产生相同的最小编辑距离。
- 参数:
- 返回值:
列表的列表,矩阵,编辑操作表。
- 返回类型:
示例
>>> ref = [1,2,3] >>> hyp = [1,2,4] >>> for row in op_table(ref, hyp): ... print(row) ['=', 'I', 'I', 'I'] ['D', '=', 'I', 'I'] ['D', 'D', '=', 'I'] ['D', 'D', 'D', 'S']
- speechbrain.utils.edit_distance.alignment(table)[source]
从编辑操作表中获取编辑距离对齐。
回溯通过调用
table(a, b)
生成的编辑操作表,并收集 a 到 b 的编辑距离对齐。对齐显示 a 中的哪个 Token 对应 b 中的哪个 Token。请注意,对齐是单调的,一对零或一对一。- 参数:
table (list) – 来自
op_table(a, b)
的编辑操作表。- 返回值:
模式:
[(str <edit-op>, int-or-None <i>, int-or-None <j>),]
编辑操作列表,以及对应到 a 和 b 的索引。有关编辑操作,请参阅 EDIT_SYMBOLS 字典。i 索引 a,j 索引 b,索引可以是 None,表示与任何内容都不对齐。- 返回类型:
示例
>>> # table for a=[1,2,3], b=[1,2,4]: >>> table = [['I', 'I', 'I', 'I'], ... ['D', '=', 'I', 'I'], ... ['D', 'D', '=', 'I'], ... ['D', 'D', 'D', 'S']] >>> print(alignment(table)) [('=', 0, 0), ('=', 1, 1), ('S', 2, 2)]
- speechbrain.utils.edit_distance.count_ops(table)[source]
计算编辑操作表中最短编辑路径中的编辑操作次数。
回溯由 table(a, b) 生成的编辑操作表,并计算最短编辑路径中的插入、删除和替换次数。此信息通常用于语音识别中,以分别报告不同错误类型的数量。
- 参数:
table (list) – 来自
op_table(a, b)
的编辑操作表。- 返回值:
编辑操作的计数,包含以下键:
“insertions”
“deletions”
“substitutions”
注意:并非所有键都可能在输出中显式出现,但对于缺少的键,collections.Counter 将返回 0。
- 返回类型:
示例
>>> table = [['I', 'I', 'I', 'I'], ... ['D', '=', 'I', 'I'], ... ['D', 'D', '=', 'I'], ... ['D', 'D', 'D', 'S']] >>> print(count_ops(table)) Counter({'substitutions': 1})
- speechbrain.utils.edit_distance.wer_details_for_batch(ids, refs, hyps, compute_alignments=False, equality_comparator: ~typing.Callable[[str, str], bool] = <function _str_equals>)[source]
wer_details_by_utterance
的便捷批处理接口。wer_details_by_utterance
可以处理缺失的假设,但有时(例如使用贪婪解码进行 CTC 训练)它们不是必需的,在这种情况下,这是一个便捷的接口。- 参数:
- 返回值:
参见
wer_details_by_utterance
- 返回类型:
示例
>>> ids = [['utt1'], ['utt2']] >>> refs = [[['a','b','c']], [['d','e']]] >>> hyps = [[['a','b','d']], [['d','e']]] >>> wer_details = [] >>> for ids_batch, refs_batch, hyps_batch in zip(ids, refs, hyps): ... details = wer_details_for_batch(ids_batch, refs_batch, hyps_batch) ... wer_details.extend(details) >>> print(wer_details[0]['key'], ":", ... "{:.2f}".format(wer_details[0]['WER'])) utt1 : 33.33
- speechbrain.utils.edit_distance.wer_details_by_utterance(ref_dict, hyp_dict, compute_alignments=False, scoring_mode='strict', equality_comparator: ~typing.Callable[[str, str], bool] = <function _str_equals>)[source]
计算每个话语的丰富 WER 信息。
此信息随后可用于计算汇总详情 (WER, SER)。
- 参数:
ref_dict (dict) – 应可通过话语 ID 索引,并为每个话语 ID 返回参考 Token 的可迭代对象。
hyp_dict (dict) – 应可通过话语 ID 索引,并为每个话语 ID 返回假设 Token 的可迭代对象。
compute_alignments (bool) – 是否也应保存对齐。这也会保存 Token 本身,因为打印对齐可能需要它们。
scoring_mode ({'strict', 'all', 'present'}) –
如何处理缺失的假设(hyp_dict 中找不到参考话语 ID)。
‘strict’:对缺失的假设引发错误。
‘all’:将缺失的假设记为空。
‘present’:仅对现有假设进行评分。
equality_comparator (Callable[[str, str], bool]) – 用于检查两个词是否相等的函数。
- 返回值:
一个列表,每个参考话语有一个条目。每个条目都是一个字典,包含以下键:
“key”:话语 ID
“scored”:(bool) 该话语是否被评分。
“hyp_absent”:(bool) 如果未找到假设,则为 True。
“hyp_empty”:(bool) 如果假设被视为空(无论是由于它为空,还是未找到且 mode 为 ‘all’),则为 True。
“num_edits”:(int) 总编辑次数。
“num_ref_tokens”:(int) 参考 Token 数量。
“WER”:(float) 该话语的词错误率。
“insertions”:(int) 插入次数。
“deletions”:(int) 删除次数。
“substitutions”:(int) 替换次数。
“alignment”:如果 compute_alignments 为 True,则为对齐列表,参见
speechbrain.utils.edit_distance.alignment
。如果 compute_alignments 为 False,则为 None。“ref_tokens”:(iterable) 参考 Token,仅在计算对齐时保存,否则为 None。
“hyp_tokens”:(iterable) 假设 Token,仅在计算对齐时保存,否则为 None。
- 返回类型:
- 抛出:
KeyError – 如果评分模式为 ‘strict’ 且未找到假设。
- speechbrain.utils.edit_distance.wer_summary(details_by_utterance)[source]
根据 details_by_utterance 的输出计算汇总统计信息
摘要统计信息,例如 WER
- 参数:
details_by_utterance (list) – 参见 wer_details_by_utterance 的输出
- 返回值:
字典,包含以下键:
“WER”:(float) 词错误率。
“SER”:(float) 句子错误率(至少包含一个错误的话语百分比)。
“num_edits”:(int) 总编辑次数。
“num_scored_tokens”:(int) 评分的参考话语中的总 Token 数量(缺失的假设在使用 ‘all’ 评分模式时可能仍被评分)。
“num_erroneous_sents”:(int) 至少包含一个错误的话语总数。
“num_scored_sents”:(int) 已评分的话语总数。
“num_absent_sents”:(int) 未找到的假设数量。
“num_ref_sents”:(int) 所有参考话语的数量。
“insertions”:(int) 总插入次数。
“deletions”:(int) 总删除次数。
“substitutions”:(int) 总替换次数。
注意:在某些情况下,插入、删除和替换的数量存在歧义。我们的目标是复现 Kaldi compute_wer 的结果。
- 返回类型:
- speechbrain.utils.edit_distance.wer_details_by_speaker(details_by_utterance, utt2spk)[source]
按说话人分组计算词错误率和另一关键信息。
- 参数:
- 返回值:
将说话人 ID 映射到统计信息的字典,包含以下键:
“speaker”:说话人 ID,
“num_edits”:(int) 该说话人的总编辑次数。
“insertions”:(int) 该说话人的插入次数。
“dels”:(int) 该说话人的删除次数。
“subs”:(int) 该说话人的替换次数。
“num_scored_tokens”:(int) 该说话人评分的参考 Token 数量(缺失的假设在使用 ‘all’ 评分模式时可能仍被评分)。
“num_scored_sents”:(int) 该说话人已评分的话语数量。
“num_erroneous_sents”:(int) 该说话人至少包含一个错误的话语数量。
“num_absent_sents”:(int) 该说话人未找到假设的话语数量。
“num_ref_sents”:(int) 该说话人的总话语数量。
- 返回类型: