六、Jbeil evaluation
评估边缘预测和节点分类模型的性能。
评估包括计算平均精确度、AUC、召回率、精确度、混淆矩阵等指标。
1.eval_edge_prediction
函数
这个函数用于评估边缘预测模型的性能。
参数
model
: 要评估的模型。negative_edge_sampler
: 负样本采样器,用于生成负样本。data
: 包含评估数据的对象。n_neighbors
: 每个节点的邻居数量。batch_size
: 批处理大小,默认为200。
步骤
-
重置负样本采样器的随机状态:确保每次评估时使用相同的负样本。
negative_edge_sampler.reset_random_state()
-
计算批次数量:根据数据量和批处理大小计算需要处理的批次数量。
num_test_instance = len(data.sources) num_test_batch = math.ceil(num_test_instance / TEST_BATCH_SIZE)
-
逐批处理数据:对每个批次的数据进行处理,计算正负样本的边缘概率。
for k in range(num_test_batch): ... pos_prob, neg_prob = model.compute_edge_probabilities(sources_batch, destinations_batch, negative_samples, timestamps_batch, edge_idxs_batch, n_neighbors)
-
计算评估指标:使用
sklearn
计算AP、AUC、召回率、精确度、混淆矩阵等指标。val_ap.append(average_precision_score(true_label, pred_score)) val_auc.append(roc_auc_score(true_label, pred_score)) val_recall.append(recall_score(true_label, pred_score)) val_precision.append(precision_score(true_label, pred_score)) tn, fp, fn, tp = confusion_matrix(true_label, pred_score).ravel() val_fp.append(fp) val_fn.append(fn) val_tp.append(tp) val_tn.append(tn)
-
返回平均值:返回所有批次评估指标的平均值。
return np.mean(val_ap), np.mean(val_auc), np.mean(val_recall), np.mean(val_precision), np.mean(val_fp), np.mean(val_fn), np.mean(val_tp), np.mean(val_tn)
2.eval_node_classification
函数
这个函数用于评估节点分类模型的性能。
参数
tgn
: Temporal Graph Network (TGN) 模型。decoder
: 用于节点分类的解码器。data
: 包含评估数据的对象。edge_idxs
: 边缘索引。batch_size
: 批处理大小。n_neighbors
: 每个节点的邻居数量。
步骤
-
计算批次数量:根据数据量和批处理大小计算需要处理的批次数量。
num_instance = len(data.sources) num_batch = math.ceil(num_instance / batch_size)
-
逐批处理数据:对每个批次的数据进行处理,计算节点的嵌入向量,并使用解码器进行分类。
for k in range(num_batch): ... source_embedding, destination_embedding, _ = tgn.compute_temporal_embeddings(sources_batch, destinations_batch, destinations_batch, timestamps_batch, edge_idxs_batch, n_neighbors) pred_prob_batch = decoder(source_embedding).sigmoid() pred_prob[s_idx: e_idx] = pred_prob_batch.cpu().numpy()
-
计算AUC-ROC:使用
sklearn
计算AUC-ROC值。auc_roc = roc_auc_score(data.labels, pred_prob)
-
返回AUC-ROC值。
return auc_roc
评估方法总结
-
边缘预测评估:
eval_edge_prediction
函数通过计算正负样本的边缘概率,评估模型在边缘预测任务上的性能。评估指标包括AP、AUC、召回率、精确度和混淆矩阵。 -
节点分类评估:
eval_node_classification
函数通过计算节点的嵌入向量,并使用解码器进行分类,评估模型在节点分类任务上的性能。评估指标为AUC-ROC。
这两个评估方法都使用批处理来处理数据,并且在每个批次之间保持模型的记忆状态,以便后续批次能够访问前几批次的交互信息。这种方法提高了评估的准确性和效率。