speechbrain.utils.Accuracy 模块
计算准确率。
作者 * Jianyuan Zhong 2020
概述
类
用于计算总体一步前向预测准确率的模块。 |
函数
计算批量预测对数概率和目标的准确率。 |
参考
- speechbrain.utils.Accuracy.Accuracy(log_probabilities, targets, length=None)[source]
计算批量预测对数概率和目标的准确率。
- 参数:
log_probabilities (torch.Tensor) – 预测对数概率 (batch_size, time, feature)。
targets (torch.Tensor) – 目标 (batch_size, time)。
length (torch.Tensor) – 目标的长度 (batch_size,)。
- 返回:
numerator (float) – 正确样本的数量
denominator (float) – 样本总数
示例
>>> probs = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2]]).unsqueeze(0) >>> acc = Accuracy(torch.log(probs), torch.tensor([1, 1, 0]).unsqueeze(0), torch.tensor([2/3])) >>> print(acc) (1.0, 2.0)
- class speechbrain.utils.Accuracy.AccuracyStats[source]
基类:
object
用于计算总体一步前向预测准确率的模块。
示例
>>> probs = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2]]).unsqueeze(0) >>> stats = AccuracyStats() >>> stats.append(torch.log(probs), torch.tensor([1, 1, 0]).unsqueeze(0), torch.tensor([2/3])) >>> acc = stats.summarize() >>> print(acc) 0.5