菜单
本页目录

六、Jbeil evaluation

评估边缘预测和节点分类模型的性能。

评估包括计算平均精确度、AUC、召回率、精确度、混淆矩阵等指标。

1.eval_edge_prediction 函数

这个函数用于评估边缘预测模型的性能。

参数

  • model: 要评估的模型。
  • negative_edge_sampler: 负样本采样器,用于生成负样本。
  • data: 包含评估数据的对象。
  • n_neighbors: 每个节点的邻居数量。
  • batch_size: 批处理大小,默认为200。

步骤

  1. 重置负样本采样器的随机状态:确保每次评估时使用相同的负样本。

    negative_edge_sampler.reset_random_state()
    
  2. 计算批次数量:根据数据量和批处理大小计算需要处理的批次数量。

    num_test_instance = len(data.sources)
    num_test_batch = math.ceil(num_test_instance / TEST_BATCH_SIZE)
    
  3. 逐批处理数据:对每个批次的数据进行处理,计算正负样本的边缘概率。

    for k in range(num_test_batch):
        ...
        pos_prob, neg_prob = model.compute_edge_probabilities(sources_batch, destinations_batch,
                                                              negative_samples, timestamps_batch,
                                                              edge_idxs_batch, n_neighbors)
    
  4. 计算评估指标:使用 sklearn 计算AP、AUC、召回率、精确度、混淆矩阵等指标。

    val_ap.append(average_precision_score(true_label, pred_score))
    val_auc.append(roc_auc_score(true_label, pred_score))
    val_recall.append(recall_score(true_label, pred_score))
    val_precision.append(precision_score(true_label, pred_score))
    tn, fp, fn, tp = confusion_matrix(true_label, pred_score).ravel()
    val_fp.append(fp)
    val_fn.append(fn)
    val_tp.append(tp)
    val_tn.append(tn)
    
  5. 返回平均值:返回所有批次评估指标的平均值。

    return np.mean(val_ap), np.mean(val_auc), np.mean(val_recall), np.mean(val_precision), np.mean(val_fp), np.mean(val_fn), np.mean(val_tp), np.mean(val_tn)
    

2.eval_node_classification 函数

这个函数用于评估节点分类模型的性能。

参数

  • tgn: Temporal Graph Network (TGN) 模型。
  • decoder: 用于节点分类的解码器。
  • data: 包含评估数据的对象。
  • edge_idxs: 边缘索引。
  • batch_size: 批处理大小。
  • n_neighbors: 每个节点的邻居数量。

步骤

  1. 计算批次数量:根据数据量和批处理大小计算需要处理的批次数量。

    num_instance = len(data.sources)
    num_batch = math.ceil(num_instance / batch_size)
    
  2. 逐批处理数据:对每个批次的数据进行处理,计算节点的嵌入向量,并使用解码器进行分类。

    for k in range(num_batch):
        ...
        source_embedding, destination_embedding, _ = tgn.compute_temporal_embeddings(sources_batch,
                                                                                    destinations_batch,
                                                                                    destinations_batch,
                                                                                    timestamps_batch,
                                                                                    edge_idxs_batch,
                                                                                    n_neighbors)
        pred_prob_batch = decoder(source_embedding).sigmoid()
        pred_prob[s_idx: e_idx] = pred_prob_batch.cpu().numpy()
    
  3. 计算AUC-ROC:使用 sklearn 计算AUC-ROC值。

    auc_roc = roc_auc_score(data.labels, pred_prob)
    
  4. 返回AUC-ROC值

    return auc_roc
    

评估方法总结

  1. 边缘预测评估eval_edge_prediction 函数通过计算正负样本的边缘概率,评估模型在边缘预测任务上的性能。评估指标包括AP、AUC、召回率、精确度和混淆矩阵。

  2. 节点分类评估eval_node_classification 函数通过计算节点的嵌入向量,并使用解码器进行分类,评估模型在节点分类任务上的性能。评估指标为AUC-ROC。

这两个评估方法都使用批处理来处理数据,并且在每个批次之间保持模型的记忆状态,以便后续批次能够访问前几批次的交互信息。这种方法提高了评估的准确性和效率。