speechbrain.utils.Accuracy 模块

计算准确率。

作者 * Jianyuan Zhong 2020

概述

AccuracyStats

用于计算总体一步前向预测准确率的模块。

函数

Accuracy

计算批量预测对数概率和目标的准确率。

参考

speechbrain.utils.Accuracy.Accuracy(log_probabilities, targets, length=None)[source]

计算批量预测对数概率和目标的准确率。

参数:
  • log_probabilities (torch.Tensor) – 预测对数概率 (batch_size, time, feature)。

  • targets (torch.Tensor) – 目标 (batch_size, time)。

  • length (torch.Tensor) – 目标的长度 (batch_size,)。

返回:

  • numerator (float) – 正确样本的数量

  • denominator (float) – 样本总数

示例

>>> probs = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2]]).unsqueeze(0)
>>> acc = Accuracy(torch.log(probs), torch.tensor([1, 1, 0]).unsqueeze(0), torch.tensor([2/3]))
>>> print(acc)
(1.0, 2.0)
class speechbrain.utils.Accuracy.AccuracyStats[source]

基类: object

用于计算总体一步前向预测准确率的模块。

示例

>>> probs = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2]]).unsqueeze(0)
>>> stats = AccuracyStats()
>>> stats.append(torch.log(probs), torch.tensor([1, 1, 0]).unsqueeze(0), torch.tensor([2/3]))
>>> acc = stats.summarize()
>>> print(acc)
0.5
append(log_probabilities, targets, length=None)[source]

此函数用于根据当前批次的预测和目标更新统计信息。

参数:
  • log_probabilities (torch.Tensor) – 预测对数概率 (batch_size, time, feature)。

  • targets (torch.Tensor) – 目标 (batch_size, time)。

  • length (torch.Tensor) – 目标的长度 (batch_size,)。

summarize()[source]

计算准确率指标。