菜单
本页目录

五、Jbeil utils

1.utils

这个文件实现了几个核心功能:网络结构的定义、早停机制、负边采样器、邻居查找器和邻居采样器。

  1. 定义了一些神经网络层,用于处理输入特征并进行非线性变换。
  2. 实现了一个早停机制,用于防止过拟合。
  3. 定义了一个随机边采样器,用于从给定的节点列表中随机采样边。
  4. 构建了一个邻居查找器,用于查找给定节点的邻居并进行邻居采样。

MergeLayer 类

class MergeLayer(torch.nn.Module):
  def __init__(self, dim1, dim2, dim3, dim4):
    super().__init__()
    self.fc1 = torch.nn.Linear(dim1 + dim2, dim3)
    self.fc2 = torch.nn.Linear(dim3, dim4)
    self.act = torch.nn.ReLU()

    torch.nn.init.xavier_normal_(self.fc1.weight)
    torch.nn.init.xavier_normal_(self.fc2.weight)

  def forward(self, x1, x2):
    x = torch.cat([x1, x2], dim=1)
    h = self.act(self.fc1(x))
    return self.fc2(h)

这个类定义了一个简单的神经网络层,它将两个输入向量拼接起来,然后通过两层全连接层和一个ReLU激活函数。主要用于将两个向量合并并进行非线性变换。

MLP 类

class MLP(torch.nn.Module):
  def __init__(self, dim, drop=0.3):
    super().__init__()
    self.fc_1 = torch.nn.Linear(dim, 80)
    self.fc_2 = torch.nn.Linear(80, 10)
    self.fc_3 = torch.nn.Linear(10, 1)
    self.act = torch.nn.ReLU()
    self.dropout = torch.nn.Dropout(p=drop, inplace=False)

  def forward(self, x):
    x = self.act(self.fc_1(x))
    x = self.dropout(x)
    x = self.act(self.fc_2(x))
    x = self.dropout(x)
    return self.fc_3(x).squeeze(dim=1)

这个类定义了一个多层感知器(MLP),包括三层全连接层和ReLU激活函数,以及Dropout层用于防止过拟合。

EarlyStopMonitor 类

class EarlyStopMonitor(object):
  def __init__(self, max_round=3, higher_better=True, tolerance=1e-10):
    self.max_round = max_round
    self.num_round = 0
    self.epoch_count = 0
    self.best_epoch = 0
    self.last_best = None
    self.higher_better = higher_better
    self.tolerance = tolerance

  def early_stop_check(self, curr_val):
    if not self.higher_better:
      curr_val *= -1
    if self.last_best is None:
      self.last_best = curr_val
    elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance:
      self.last_best = curr_val
      self.num_round = 0
      self.best_epoch = self.epoch_count
    else:
      self.num_round += 1

    self.epoch_count += 1

    return self.num_round >= self.max_round

该类实现了一个早停机制,用于在训练过程中监控验证集性能,并在性能不再提升时提前停止训练,以防止过拟合。

RandEdgeSampler 类

class RandEdgeSampler(object):
  def __init__(self, src_list, dst_list, seed=None):
    self.seed = seed
    self.src_list = np.unique(src_list)
    self.dst_list = np.unique(dst_list)

    if seed is not None:
      self.random_state = np.random.RandomState(self.seed)

  def sample(self, size):
    if self.seed is None:
      src_index = np.random.randint(0, len(self.src_list), size)
      dst_index = np.random.randint(0, len(self.dst_list), size)
    else:
      src_index = self.random_state.randint(0, len(self.src_list), size)
      dst_index = self.random_state.randint(0, len(self.dst_list), size)
    return self.src_list[src_index], self.dst_list[dst_index]

  def reset_random_state(self):
    self.random_state = np.random.RandomState(self.seed)

该类用于从给定的源和目标节点列表中随机采样一组边。它可以使用一个固定的种子来保证采样的可重复性。

get_neighbor_finder 函数

def get_neighbor_finder(data, uniform, max_node_idx=None):
  max_node_idx = max(data.sources.max(), data.destinations.max()) if max_node_idx is None else max_node_idx
  adj_list = [[] for _ in range(max_node_idx + 1)]
  for source, destination, edge_idx, timestamp in zip(data.sources, data.destinations,
                                                      data.edge_idxs,
                                                      data.timestamps):
    adj_list[source].append((destination, edge_idx, timestamp))
    adj_list[destination].append((source, edge_idx, timestamp))

  return NeighborFinder(adj_list, uniform=uniform)

该函数创建一个邻居查找器对象(NeighborFinder),用于查找给定节点的邻居。它根据输入的数据构建一个邻接列表,其中每个节点都记录了与其相连的节点、边索引和时间戳。

NeighborFinder 类

class NeighborFinder:
  def __init__(self, adj_list, uniform=False, seed=None):
    self.node_to_neighbors = []
    self.node_to_edge_idxs = []
    self.node_to_edge_timestamps = []

    for neighbors in adj_list:
      sorted_neighhbors = sorted(neighbors, key=lambda x: x[2])
      self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighhbors]))
      self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighhbors]))
      self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighhbors]))

    self.uniform = uniform

    if seed is not None:
      self.seed = seed
      self.random_state = np.random.RandomState(self.seed)

  def find_before(self, src_idx, cut_time):
    i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time)
    return self.node_to_neighbors[src_idx][:i], self.node_to_edge_idxs[src_idx][:i], self.node_to_edge_timestamps[src_idx][:i]

  def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20):
    assert (len(source_nodes) == len(timestamps))

    tmp_n_neighbors = n_neighbors if n_neighbors > 0 else 1
    neighbors = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.int32)
    edge_times = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.float32)
    edge_idxs = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.int32)

    for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)):
      source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node, timestamp)

      if len(source_neighbors) > 0 and n_neighbors > 0:
        if self.uniform:
          sampled_idx = np.random.randint(0, len(source_neighbors), n_neighbors)
          neighbors[i, :] = source_neighbors[sampled_idx]
          edge_times[i, :] = source_edge_times[sampled_idx]
          edge_idxs[i, :] = source_edge_idxs[sampled_idx]

          pos = edge_times[i, :].argsort()
          neighbors[i, :] = neighbors[i, :][pos]
          edge_times[i, :] = edge_times[i, :][pos]
          edge_idxs[i, :] = edge_idxs[i, :][pos]
        else:
          source_edge_times = source_edge_times[-n_neighbors:]
          source_neighbors = source_neighbors[-n_neighbors:]
          source_edge_idxs = source_edge_idxs[-n_neighbors:]

          neighbors[i, n_neighbors - len(source_neighbors):] = source_neighbors
          edge_times[i, n_neighbors - len(source_edge_times):] = source_edge_times
          edge_idxs[i, n_neighbors - len(source_edge_idxs):] = source_edge_idxs

    return neighbors, edge_idxs, edge_times

该类实现了一个邻居查找器,用于查找给定节点的邻居。它可以基于时间戳查找邻居,并根据不同的采样策略(如统一采样)进行采样。主要方法包括:

  • find_before: 查找给定节点在特定时间点之前的所有邻居。
  • get_temporal_neighbor: 获取一组节点在给定时间戳之前的临时邻居。

2.preprocess_data.py

这个preprocess_data文件的主要功能是预处理输入数据,并生成适合于时间图神经网络(TGN)使用的格式化数据。主要步骤包括数据的读取、特征提取、索引重排和保存。

  1. trunc(values, decs=0): 用于截断浮点数到指定的小数位数。
  2. preprocess(data_name): 读取数据文件并提取出用户、项目、时间戳、标签和特征等信息。
  3. reindex(df, bipartite=True): 重新排列数据框的索引,并处理二部图(bipartite)和非二部图的情况。
  4. run(data_name, bipartite=True): 调用上述函数来处理输入数据,并将处理后的数据保存到指定的路径中。

函数详细解释

trunc(values, decs=0)

def trunc(values, decs=0):
    return np.trunc(values*10**decs)/(10**decs)

这个函数用于截断浮点数到指定的小数位数。参数values是需要截断的浮点数,decs是小数位数。

preprocess(data_name)

def preprocess(data_name):
  u_list, i_list, ts_list, label_list = [], [], [], []
  feat_l = []
  idx_list = []

  with open(data_name) as f:
    s = next(f)
    for idx, line in enumerate(f):
      e = line.strip().split(',')
      u = int(e[0])
      i = int(e[1])
      ts = float(e[2])
      label = float(e[3])
      feat = np.array([float(x) for x in e[4:]])

      u_list.append(u)
      i_list.append(i)
      ts_list.append(ts)
      label_list.append(label)
      idx_list.append(idx)
      feat_l.append(feat)

  return pd.DataFrame({'u': u_list,
                       'i': i_list,
                       'ts': ts_list,
                       'label': label_list,
                       'idx': idx_list}), np.array(feat_l)

这个函数读取数据文件,并将数据解析成用户(u)、项目(i)、时间戳(ts)、标签(label)和特征(feat)的列表。然后将这些列表转化为一个DataFrame和一个特征矩阵。

reindex(df, bipartite=True)

def reindex(df, bipartite=True):
  new_df = df.copy()
  if bipartite:
    assert (df.u.max() - df.u.min() + 1 == len(df.u.unique()))
    assert (df.i.max() - df.i.min() + 1 == len(df.i.unique()))

    upper_u = df.u.max() + 1
    new_i = df.i + upper_u

    new_df.i = new_i
    new_df.u += 1
    new_df.i += 1
    new_df.idx += 1
  else:
    new_df.u += 1
    new_df.i += 1
    new_df.idx += 1

  return new_df

这个函数用于重新排列数据框的索引。如果是二部图(bipartite=True),它会将项目索引偏移,使用户和项目的索引不重叠。对于非二部图,只是简单地将索引加1。

run(data_name, bipartite=True)(不重要)

def run(data_name, bipartite=True):
  Path("data/").mkdir(parents=True, exist_ok=True)
  PATH = './data/{}.csv'.format(data_name)
  OUT_DF = './data/ml_{}.csv'.format(data_name)
  OUT_FEAT = './data/ml_{}.npy'.format(data_name)
  OUT_NODE_FEAT = './data/ml_{}_node.npy'.format(data_name)

  df, feats = preprocess(PATH)
  new_df = reindex(df, bipartite)
  empty = np.zeros(feats.shape[1])[np.newaxis, :]
  feat = np.vstack([empty, feats])
  max_idx = max(new_df.u.max(), new_df.i.max())
  rand_feat = np.zeros((max_idx + 1, 10))

  new_df.to_csv(OUT_DF)
  np.save(OUT_FEAT, feat)
  np.save(OUT_NODE_FEAT, rand_feat)

这个函数是预处理数据的主要入口。它创建必要的文件夹,调用preprocess函数读取和处理数据,然后调用reindex函数重新排列索引,最后将处理后的数据保存为CSV和Numpy数组文件。

脚本入口

parser = argparse.ArgumentParser('Interface for TGN data preprocessing')
parser.add_argument('--data', type=str, help='Dataset name (eg. auth or pivoting)')
parser.add_argument('--bipartite', action='store_true', help='Whether the graph is bipartite')

args = parser.parse_args()

run(args.data, bipartite=args.bipartite)

这个部分是脚本的入口,使用argparse解析命令行参数,并调用run函数进行数据预处理。

函数和类之间的依赖关系

  1. run函数调用preprocess函数来读取和初步处理数据。
  2. run函数调用reindex函数来重新排列索引。
  3. run函数保存处理后的数据到文件中。