speechbrain.utils.kmeans 模块

用于训练 kmeans 模型的实用程序。

作者
  • Pooneh Mousavi 2023

摘要

函数

accumulate_and_extract_features

提取特征(SSL 模型的输出)并在 CPU 上累积,用于聚类。

fetch_kmeans_model

返回一个具有指定参数的 k-means 聚类模型。

process_chunks

以指定大小的块处理数据。

save_model

保存 Kmeans 模型。

train

训练 Kmeans 模型。

参考

speechbrain.utils.kmeans.accumulate_and_extract_features(batch, features_list, ssl_model, ssl_layer_num, device)[源代码]

提取特征(SSL 模型的输出)并在 CPU 上累积,用于聚类。

参数:
  • batch (tensor) – 单个数据批次。

  • features_list (list) – 累积特征列表。

  • ssl_model (torch.nn.Module) – 用于提取聚类特征的 SSL 模型。

  • ssl_layer_num (int) – 指定应使用 ssl_model 的哪个层的输出。

  • device (str) – cpucuda 设备。

speechbrain.utils.kmeans.fetch_kmeans_model(n_clusters, init, max_iter, batch_size, tol, max_no_improvement, n_init, reassignment_ratio, random_state, checkpoint_path)[源代码]

返回一个具有指定参数的 k-means 聚类模型。

参数:
  • n_clusters (MiniBatchKMeans) – 要形成的簇的数量以及要生成的质心的数量。

  • init (int) – 初始化方法:{‘k-means++’’, ‘’random’’}

  • max_iter (int) – 在停止之前对完整数据集的最大迭代次数,与任何早期停止准则启发式方法无关。

  • batch_size (int) – 迷你批次的尺寸。

  • tol (float) – 控制基于平滑的、方差归一化的平均中心平方位置变化的相对中心变化的早期停止。

  • max_no_improvement (int) – 控制基于未在平滑惯性上产生改进的连续迷你批次数的早期停止。

  • n_init (int) – 尝试的随机初始化次数

  • reassignment_ratio (float) – 控制中心重新分配的最大计数比例。

  • random_state (int) – 确定质心初始化和随机重新分配的随机数生成。

  • checkpoint_path (str) – 保存模型的路径。

返回:

一个具有指定参数的 k-means 聚类模型。

返回类型:

MiniBatchKMeans

speechbrain.utils.kmeans.process_chunks(data, chunk_size, model)[源代码]

以指定大小的块处理数据。

参数:
  • data (list) – 要处理的整数列表。

  • chunk_size (int) – 每个块的大小。

  • model (MiniBatchKMeans) – 用于训练的初始 kmeans 模型。

speechbrain.utils.kmeans.train(model, train_set, ssl_model, save_path, ssl_layer_num, kmeans_batch_size=1000, device='cpu', checkpoint_interval=10)[源代码]

训练 Kmeans 模型。

参数:
  • model (MiniBatchKMeans) – 用于训练的初始 kmeans 模型。

  • train_set (Dataloader) – 训练数据的批次。

  • ssl_model (torch.nn.Module) – 用于提取聚类特征的 SSL 模型。

  • save_path (string) – 保存内部检查点和 dataloader 的路径。

  • ssl_layer_num (int) – 指定应使用 ssl_model 的哪个层的输出。

  • kmeans_batch_size (int) – 迷你批次的尺寸。

  • device (str) – cpucuda 设备。

  • checkpoint_interval (int) – 确定在哪些迭代时保存检查点。

speechbrain.utils.kmeans.save_model(model, checkpoint_path)[源代码]

保存 Kmeans 模型。

参数:
  • model (MiniBatchKMeans) – 要保存的 kmeans 模型。

  • checkpoint_path (str) – 保存模型的路径。