菜单
本页目录

三、Jbeil modules

1.memory.py

实现记忆模块,用于存储和处理节点的长期信息。TGN模型依赖这个模块来管理节点的记忆。

注:基本上是对记忆的管理,没有实现更新等功能,那个是在memory_updater.py (记忆更新模块)中。

  1. 类初始化 __init__

    • 功能:初始化 Memory 类,定义节点数量、记忆维度、输入维度、消息维度、设备和组合方法等属性。

    • 关键代码:

      class Memory(nn.Module):
          def __init__(self, n_nodes, memory_dimension, input_dimension, message_dimension=None,
                       device="cpu", combination_method='sum'):
              super(Memory, self).__init__()
              self.n_nodes = n_nodes
              self.memory_dimension = memory_dimension
              self.input_dimension = input_dimension
              self.message_dimension = message_dimension
              self.device = device
              self.combination_method = combination_method
              self.__init_memory__()
      
  2. 初始化记忆 __init_memory__

    • 功能:将记忆初始化为全零。在每个epoch开始时调用此函数。

    • 关键代码:

      def __init_memory__(self):
          self.memory = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device),
                                     requires_grad=False)
          self.last_update = nn.Parameter(torch.zeros(self.n_nodes).to(self.device),
                                          requires_grad=False)
          self.messages = defaultdict(list)
      
  3. 存储原始消息 store_raw_messages

    • 功能:存储原始消息,将消息添加到对应节点的消息列表中。

    • 关键代码:

      def store_raw_messages(self, nodes, node_id_to_messages):
          for node in nodes:
              self.messages[node].extend(node_id_to_messages[node])
      
  4. 获取记忆 get_memory

    • 功能:获取指定节点的记忆。

    • 关键代码:

      def get_memory(self, node_idxs):
          return self.memory[node_idxs, :]
      
  5. 设置记忆 set_memory

    • 功能:设置指定节点的记忆为给定值。

    • 关键代码:

      def set_memory(self, node_idxs, values):
          self.memory[node_idxs, :] = values
      
  6. 获取最后更新 get_last_update

    • 功能:获取指定节点的最后更新时间。

    • 关键代码:

      def get_last_update(self, node_idxs):
          return self.last_update[node_idxs]
      
  7. 备份记忆 backup_memory

    • 功能:备份当前记忆和消息。

    • 关键代码:

      def backup_memory(self):
          messages_clone = {}
          for k, v in self.messages.items():
              messages_clone[k] = [(x[0].clone(), x[1].clone()) for x in v]
          return self.memory.data.clone(), self.last_update.data.clone(), messages_clone
      
  8. 恢复记忆 restore_memory

    • 功能:从备份中恢复记忆和消息。

    • 关键代码:

      def restore_memory(self, memory_backup):
          self.memory.data, self.last_update.data = memory_backup[0].clone(), memory_backup[1].clone()
          self.messages = defaultdict(list)
          for k, v in memory_backup[2].items():
              self.messages[k] = [(x[0].clone(), x[1].clone()) for x in v]
      
  9. 分离记忆 detach_memory

    • 功能:分离当前记忆和消息,使其不参与反向传播。

    • 关键代码:

      def detach_memory(self):
          self.memory.detach_()
          for k, v in self.messages.items():
              new_node_messages = []
              for message in v:
                  new_node_messages.append((message[0].detach(), message[1]))
              self.messages[k] = new_node_messages
      
  10. 清除消息 clear_messages

    • 功能:清除指定节点的消息。

    • 关键代码:

      def clear_messages(self, nodes):
          for node in nodes:
              self.messages[node] = []
      

2.message_aggregator.py

实现消息聚合器模块,用于将相同节点ID的消息聚合在一起。该模块提供了两种聚合策略:保留最后一个消息和取平均值。消息聚合器在TGN模型中用于处理节点之间传递的消息,确保在同一个节点上所有消息可以被有效地聚合。

该文件包含一个抽象类 MessageAggregator 和两个具体的实现类 LastMessageAggregatorMeanMessageAggregator,以及一个工厂函数 get_message_aggregator

理论简述

在TGN模型中,每个节点在不同的时间点可以接收多个消息。这些消息可能来自不同的源节点,表示不同的交互事件。例如,一个节点在不同时间点与多个其他节点进行交互,或者同一个源节点在不同时间点向目标节点发送消息。因此,多个消息的聚合实际上是指在同一个节点上收集和处理这些来自不同时间点或不同源节点的消息。

聚合消息的情景解释

假设有以下场景:

  • 节点A在时间点1接收到来自节点B的消息。
  • 节点A在时间点2接收到来自节点C的消息。
  • 节点A在时间点3再次接收到来自节点B的消息。

在这种情况下,节点A有多个消息需要聚合。消息聚合器的作用就是在这些不同时间点接收到的消息中,根据一定的策略(例如保留最后一个消息或取平均值)来聚合消息。

聚合器的具体作用

  1. LastMessageAggregator:只保留每个节点的最后一个消息。例如,在上面的例子中,节点A将保留在时间点3从节点B接收到的消息,因为这是最后一个消息。
  2. MeanMessageAggregator:将每个节点接收到的所有消息取平均值。例如,在上面的例子中,节点A将计算从节点B和节点C接收到的所有消息的平均值。

代码说明

  1. 抽象类 MessageAggregator

    • 功能:定义了消息聚合器的基本结构和方法,包括 aggregate 方法和 group_by_id 方法。

    • 关键代码:

      class MessageAggregator(torch.nn.Module):
          """
          消息聚合器模块的抽象类,给定一批节点ID和对应的消息,聚合具有相同节点ID的消息。
          """
          def __init__(self, device):
              super(MessageAggregator, self).__init__()
              self.device = device
      
          def aggregate(self, node_ids, messages):
              """
              给定节点ID列表和相同长度的消息列表,使用一种可能的策略聚合相同ID的不同消息。
              :param node_ids: 长度为 batch_size 的节点ID列表
              :param messages: 形状为 [batch_size, message_length] 的张量
              :param timestamps: 形状为 [batch_size] 的张量
              :return: 形状为 [n_unique_node_ids, message_length] 的张量,包含聚合后的消息
              """
              pass
      
          def group_by_id(self, node_ids, messages, timestamps):
              """
              根据节点ID对消息进行分组,返回一个字典,字典的键是节点ID,值是包含消息和时间戳的列表。
              :param node_ids: 节点ID列表
              :param messages: 消息张量
              :param timestamps: 时间戳张量
              :return: 一个字典,键是节点ID,值是消息和时间戳的列表
              """
              node_id_to_messages = defaultdict(list)
              for i, node_id in enumerate(node_ids):
                  node_id_to_messages[node_id].append((messages[i], timestamps[i]))
              return node_id_to_messages
      
  2. 具体实现类 LastMessageAggregator

    • 功能:实现了 aggregate 方法,只保留每个节点的最后一个消息。

    • 关键代码:

      class LastMessageAggregator(MessageAggregator):
          """
          最后消息聚合器,只保留每个节点的最后一个消息。
          """
          def __init__(self, device):
              super(LastMessageAggregator, self).__init__(device)
      
          def aggregate(self, node_ids, messages):
              """
              只保留每个节点的最后一个消息
              :param node_ids: 节点ID列表
              :param messages: 消息张量
              :return: 待更新的节点ID列表,聚合后的消息张量,聚合后的时间戳张量
              """
              unique_node_ids = np.unique(node_ids)  # 获取唯一的节点ID
              unique_messages = []  # 用于存储聚合后的唯一消息
              unique_timestamps = []  # 用于存储聚合后的唯一时间戳
      
              to_update_node_ids = []  # 待更新的节点ID列表
      
              for node_id in unique_node_ids:
                  if len(messages[node_id]) > 0:
                      to_update_node_ids.append(node_id)  # 记录待更新的节点ID
                      unique_messages.append(messages[node_id][-1][0])  # 取最后一个消息
                      unique_timestamps.append(messages[node_id][-1][1])  # 取最后一个时间戳
      
              unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
              unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []
      
              return to_update_node_ids, unique_messages, unique_timestamps
      
      
  3. 具体实现类 MeanMessageAggregator

    • 功能:实现了 aggregate 方法,将每个节点的消息取平均值进行聚合。

    • 关键代码:

      class MeanMessageAggregator(MessageAggregator):
          """
          平均消息聚合器,将每个节点的消息取平均值进行聚合。
          """
          def __init__(self, device):
              super(MeanMessageAggregator, self).__init__(device)
      
          def aggregate(self, node_ids, messages):
              """
              通过取每个节点消息的平均值进行聚合
              :param node_ids: 节点ID列表
              :param messages: 消息张量
              :return: 待更新的节点ID列表,聚合后的消息张量,聚合后的时间戳张量
              """
              # 获取唯一的节点ID
              unique_node_ids = np.unique(node_ids)
              unique_messages = []
              unique_timestamps = []
      
              to_update_node_ids = []
              n_messages = 0
      
              for node_id in unique_node_ids:
                  if len(messages[node_id]) > 0:
                      n_messages += len(messages[node_id])  # 计算消息的数量
                      to_update_node_ids.append(node_id)
                      # 取消息的平均值
                      unique_messages.append(torch.mean(torch.stack([m[0] for m in messages[node_id]]), dim=0))
                      # 取最后一个时间戳,代表最新的消息时间
                      unique_timestamps.append(messages[node_id][-1][1])
      
              unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
              unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []
      
              return to_update_node_ids, unique_messages, unique_timestamps
      
      
  4. 工厂函数 get_message_aggregator

    • 功能:根据给定的类型返回相应的消息聚合器实例。

    • 关键代码:

      def get_message_aggregator(aggregator_type, device):
          """
          根据给定的类型返回相应的消息聚合器实例。
          :param aggregator_type: 聚合器类型(last或mean)
          :param device: 设备(cpu或gpu)
          :return: 消息聚合器实例
          """
          if aggregator_type == "last":
              return LastMessageAggregator(device=device)
          elif aggregator_type == "mean":
              return MeanMessageAggregator(device=device)
          else:
              raise ValueError("消息聚合器 {} 未实现".format(aggregator_type))
      

3.message_function.py

message_function.py 定义了用于计算节点之间交互事件的消息函数。这些函数在TGN模型中用于处理节点之间的消息传递和特征计算。

该文件包含一个抽象类 MessageFunction 和两个具体的实现类 MLPMessageFunctionIdentityMessageFunction,以及一个工厂函数 get_message_function

这部分可以扩充MLP多层感知机的知识,能写跟多理论公式

理论简述

在TGN模型中,节点之间的交互事件会生成消息,这些消息可以包含节点的特征、时间戳、边的特征等。为了将这些原始消息转换为用于模型训练和推断的最终消息,我们需要使用消息函数。消息函数负责接收原始消息并进行必要的计算,生成最终的消息表示。

消息函数的具体作用

  1. MLPMessageFunction:使用多层感知器 (MLP) 处理原始消息,将其转换为目标消息维度。
  2. IdentityMessageFunction:直接返回原始消息,不进行任何处理。

代码说明

  1. 抽象类 MessageFunction

    • 功能:定义了消息函数的基本结构和方法,包括 compute_message 方法。

    • 关键代码:

      from torch import nn
      
      class MessageFunction(nn.Module):
          """
          计算给定交互事件消息的模块。
          """
          def compute_message(self, raw_messages):
              """
              计算给定原始消息的最终消息。
              :param raw_messages: 原始消息
              :return: 计算后的消息
              """
              return None
      
  2. 具体实现类 MLPMessageFunction

    • 功能:使用多层感知器 (MLP) 处理原始消息,将其转换为目标消息维度。

    • 关键代码:

      class MLPMessageFunction(MessageFunction):
          def __init__(self, raw_message_dimension, message_dimension):
              super(MLPMessageFunction, self).__init__()
              # 定义一个多层感知器(MLP)用于消息计算
              self.mlp = nn.Sequential(
                  nn.Linear(raw_message_dimension, raw_message_dimension // 2),  # 全连接层,将输入维度减半
                  nn.ReLU(),  # ReLU激活函数
                  nn.Linear(raw_message_dimension // 2, message_dimension),  # 全连接层,将维度转换为目标消息维度
              )
      
          def compute_message(self, raw_messages):
              """
              计算给定原始消息的最终消息。
              :param raw_messages: 原始消息
              :return: 通过MLP计算后的消息
              """
              messages = self.mlp(raw_messages)
              return messages
      
  3. 具体实现类 IdentityMessageFunction

    • 功能:直接返回原始消息,不进行任何处理。

    • 关键代码:

      class IdentityMessageFunction(MessageFunction):
          def compute_message(self, raw_messages):
              """
              直接返回原始消息,不进行任何处理。
              :param raw_messages: 原始消息
              :return: 原始消息
              """
              return raw_messages
      
  4. 工厂函数 get_message_function

    • 功能:根据给定的类型返回相应的消息函数实例。

    • 关键代码:

      def get_message_function(module_type, raw_message_dimension, message_dimension):
          """
          根据给定的模块类型返回相应的消息函数实例。
          :param module_type: 模块类型 ("mlp" 或 "identity")
          :param raw_message_dimension: 原始消息的维度
          :param message_dimension: 目标消息的维度
          :return: 消息函数实例
          """
          if module_type == "mlp":
              return MLPMessageFunction(raw_message_dimension, message_dimension)
          elif module_type == "identity":
              return IdentityMessageFunction()
          else:
              raise ValueError("消息函数 {} 未实现".format(module_type))
      

4.memory_updater.py

实现记忆更新模块,用于更新节点的记忆状态。该模块提供了不同的记忆更新策略,包括使用 GRU 和 RNN 进行记忆更新。记忆更新模块在 TGN 模型中用于处理节点之间的交互消息,并更新节点的记忆状态。

该文件包含一个抽象类 MemoryUpdater 和两个具体的实现类 GRUMemoryUpdaterRNNMemoryUpdater,以及一个工厂函数 get_memory_updater

理论简述

在TGN模型中,节点的记忆状态会随着时间和节点之间的交互事件不断更新。记忆更新模块的作用是接收这些交互消息,并更新节点的记忆状态。记忆更新可以使用不同的策略,例如 GRU 或 RNN,这些策略会影响记忆的更新方式和效果。

GRU 单元能够有效捕捉序列数据中的长期依赖关系,适用于处理较长序列。RNN 单元相对简单,但在处理长序列时容易出现梯度消失问题。

记忆更新的具体作用

  1. GRUMemoryUpdater:使用 GRU 单元更新节点的记忆状态,能够有效地捕捉长期依赖关系。
  2. RNNMemoryUpdater:使用 RNN 单元更新节点的记忆状态,相对简单但对长期依赖的捕捉效果较差。

代码说明

  1. 抽象类 MemoryUpdater

    • 功能:定义了记忆更新的基本结构和方法,包括 update_memory 方法。

    • 关键代码:

      from torch import nn
      
      class MemoryUpdater(nn.Module):
          def update_memory(self, unique_node_ids, unique_messages, timestamps):
              """
              更新记忆状态的方法。
              :param unique_node_ids: 唯一的节点ID列表
              :param unique_messages: 对应的消息列表
              :param timestamps: 时间戳列表
              """
              pass
      
  2. 具体实现类基类 SequenceMemoryUpdater

    • 功能

      • 作为具体记忆更新类的基类,实现了基于序列的记忆更新方法。
      • memory_updater.py 文件中,具体实现类 GRUMemoryUpdaterRNNMemoryUpdater 通过继承基类 SequenceMemoryUpdater 来实现记忆更新的功能。基类 SequenceMemoryUpdater 提供了记忆更新的基本逻辑和方法,而具体实现类通过定义特定的记忆更新单元(GRU 或 RNN)来实现实际的记忆更新操作。
    • 关键代码:

      class SequenceMemoryUpdater(MemoryUpdater):
          def __init__(self, memory, message_dimension, memory_dimension, device):
              super(SequenceMemoryUpdater, self).__init__()
              self.memory = memory  # 记忆模块实例
              self.layer_norm = torch.nn.LayerNorm(memory_dimension)  # 层归一化,用于规范化记忆状态
              self.message_dimension = message_dimension  # 消息的维度
              self.device = device  # 设备(如 'cpu' 或 'gpu')
      
          def update_memory(self, unique_node_ids, unique_messages, timestamps):
              """
              更新节点的记忆状态。
              :param unique_node_ids: 唯一的节点ID列表
              :param unique_messages: 对应的消息列表
              :param timestamps: 时间戳列表
              """
              if len(unique_node_ids) <= 0:
                  return  # 如果没有节点需要更新,则返回
      
              # 确保更新的时间戳不早于当前记忆的最后更新时间
              assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to update memory to time in the past"
      
              memory = self.memory.get_memory(unique_node_ids)  # 获取当前节点的记忆状态
              self.memory.last_update[unique_node_ids] = timestamps  # 更新最后更新时间戳
      
              updated_memory = self.memory_updater(unique_messages, memory)  # 通过记忆更新单元计算新的记忆状态
      
              self.memory.set_memory(unique_node_ids, updated_memory)  # 设置新的记忆状态
      
          def get_updated_memory(self, unique_node_ids, unique_messages, timestamps):
              """
              获取更新后的记忆状态。
              :param unique_node_ids: 唯一的节点ID列表
              :param unique_messages: 对应的消息列表
              :param timestamps: 时间戳列表
              :return: 更新后的记忆状态和最后更新时间戳
              """
              if len(unique_node_ids) <= 0:
                  return self.memory.memory.data.clone(), self.memory.last_update.data.clone()  # 如果没有节点需要更新,则返回当前记忆状态的副本
      
              # 确保更新的时间戳不早于当前记忆的最后更新时间
              assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to update memory to time in the past"
      
              updated_memory = self.memory.memory.data.clone()  # 克隆当前记忆状态
              updated_memory[unique_node_ids] = self.memory_updater(unique_messages, updated_memory[unique_node_ids])  # 更新指定节点的记忆状态
      
              updated_last_update = self.memory.last_update.data.clone()  # 克隆最后更新时间戳
              updated_last_update[unique_node_ids] = timestamps  # 更新指定节点的最后更新时间戳
      
              return updated_memory, updated_last_update  # 返回更新后的记忆状态和最后更新时间戳
      
      
  3. 具体实现类 GRUMemoryUpdater

    • 功能:使用 GRU 单元更新节点的记忆状态。

    • 关键代码:

      class GRUMemoryUpdater(SequenceMemoryUpdater):
          def __init__(self, memory, message_dimension, memory_dimension, device):
              super(GRUMemoryUpdater, self).__init__(memory, message_dimension, memory_dimension, device)
      
              # 使用 GRUCell 作为记忆更新单元
              self.memory_updater = nn.GRUCell(input_size=message_dimension,
                                               hidden_size=memory_dimension)
      
      
  4. 具体实现类 RNNMemoryUpdater

    • 功能:使用 RNN 单元更新节点的记忆状态。

    • 关键代码:

      class RNNMemoryUpdater(SequenceMemoryUpdater):
          def __init__(self, memory, message_dimension, memory_dimension, device):
              super(RNNMemoryUpdater, self).__init__(memory, message_dimension, memory_dimension, device)
      
              # 使用 RNNCell 作为记忆更新单元
              self.memory_updater = nn.RNNCell(input_size=message_dimension,
                                               hidden_size=memory_dimension)
      
  5. 工厂函数 get_memory_updater

    • 功能:根据给定的类型返回相应的记忆更新实例。

    • 关键代码:

      def get_memory_updater(module_type, memory, message_dimension, memory_dimension, device):
          """
          根据给定的模块类型返回相应的记忆更新实例。
          :param module_type: 模块类型 ("gru" 或 "rnn")
          :param memory: Memory 实例
          :param message_dimension: 消息的维度
          :param memory_dimension: 记忆的维度
          :param device: 设备
          :return: 记忆更新实例
          """
          if module_type == "gru":
              return GRUMemoryUpdater(memory, message_dimension, memory_dimension, device)
          elif module_type == "rnn":
              return RNNMemoryUpdater(memory, message_dimension, memory_dimension, device)
          else:
              raise ValueError("记忆更新模块 {} 未实现".format(module_type))
      

5.embedding_module.py

实现嵌入模块,用于计算节点和边的嵌入表示。该模块提供了多种嵌入策略,包括图注意力嵌入、图求和嵌入、时间嵌入和身份嵌入。嵌入模块在TGN模型中用于生成节点和边的动态表示,以捕捉随时间变化的网络结构和节点属性。

该文件包含一个抽象类 EmbeddingModule 和四个具体的实现类 IdentityEmbeddingTimeEmbeddingGraphSumEmbeddingGraphAttentionEmbedding,以及一个工厂函数 get_embedding_module

理论简述

在TGN模型中,嵌入模块的作用是生成节点和边的动态嵌入表示。这些嵌入表示能够捕捉随时间变化的网络结构和节点属性,从而支持模型在时间上的推理和预测。嵌入模块提供了多种策略,以满足不同的需求:

  1. IdentityEmbedding:直接返回节点的记忆状态,不进行任何处理。
  2. TimeEmbedding:结合时间信息生成节点的嵌入表示。
  3. GraphSumEmbedding:通过求和邻居节点和边的特征生成节点的嵌入表示。
  4. GraphAttentionEmbedding:使用图注意力机制生成节点的嵌入表示。

代码说明

  1. 抽象类 EmbeddingModule

    • 功能:定义了嵌入模块的基本结构和方法,包括 compute_embedding 方法。

    • 关键代码:

      import torch
      from torch import nn
      import numpy as np
      import math
      
      from model.temporal_attention import TemporalAttentionLayer
      
      class EmbeddingModule(nn.Module):
          def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                       n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                       dropout):
              """
              初始化嵌入模块
              :param node_features: 节点特征
              :param edge_features: 边特征
              :param memory: 记忆模块
              :param neighbor_finder: 邻居查找模块
              :param time_encoder: 时间编码器
              :param n_layers: 层数
              :param n_node_features: 节点特征维度
              :param n_edge_features: 边特征维度
              :param n_time_features: 时间特征维度
              :param embedding_dimension: 嵌入维度
              :param device: 设备(如 'cpu' 或 'gpu')
              :param dropout: dropout率
              """
              super(EmbeddingModule, self).__init__()
              self.node_features = node_features  # 节点特征
              self.edge_features = edge_features  # 边特征
              self.neighbor_finder = neighbor_finder  # 邻居查找模块
              self.time_encoder = time_encoder  # 时间编码器
              self.n_layers = n_layers  # 层数
              self.n_node_features = n_node_features  # 节点特征维度
              self.n_edge_features = n_edge_features  # 边特征维度
              self.n_time_features = n_time_features  # 时间特征维度
              self.dropout = dropout  # dropout率
              self.embedding_dimension = embedding_dimension  # 嵌入维度
              self.device = device  # 设备(如 'cpu' 或 'gpu')
      
          def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                                use_time_proj=True):
              """
              计算嵌入表示
              :param memory: 记忆模块
              :param source_nodes: 源节点
              :param timestamps: 时间戳
              :param n_layers: 层数
              :param n_neighbors: 每层考虑的邻居数量
              :param time_diffs: 时间差异
              :param use_time_proj: 是否使用时间投影
              """
              pass
      
      
  2. 具体实现类 IdentityEmbedding

    • 功能:直接返回节点的记忆状态,不进行任何处理。

    • 关键代码:

      class IdentityEmbedding(EmbeddingModule):
          def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                                use_time_proj=True):
              """
              计算节点的嵌入表示,直接返回节点的记忆状态。
              :param memory: 记忆模块
              :param source_nodes: 源节点
              :param timestamps: 时间戳
              :param n_layers: 层数
              :param n_neighbors: 邻居数量
              :param time_diffs: 时间差异
              :param use_time_proj: 是否使用时间投影
              :return: 节点的记忆状态
              """
              return memory[source_nodes, :]
      
      
  3. 具体实现类 TimeEmbedding

    • 功能:结合时间信息生成节点的嵌入表示。

    • 关键代码:

      class TimeEmbedding(EmbeddingModule):
          def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                       n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                       n_heads=2, dropout=0.1, use_memory=True, n_neighbors=1):
              """
              初始化时间嵌入模块
              :param node_features: 节点特征
              :param edge_features: 边特征
              :param memory: 记忆模块
              :param neighbor_finder: 邻居查找模块
              :param time_encoder: 时间编码器
              :param n_layers: 层数
              :param n_node_features: 节点特征维度
              :param n_edge_features: 边特征维度
              :param n_time_features: 时间特征维度
              :param embedding_dimension: 嵌入维度
              :param device: 设备(如 'cpu' 或 'gpu')
              :param n_heads: 注意力头数量
              :param dropout: dropout率
              :param use_memory: 是否使用记忆模块
              :param n_neighbors: 每层考虑的邻居数量
              """
              super(TimeEmbedding, self).__init__(node_features, edge_features, memory,
                                                  neighbor_finder, time_encoder, n_layers,
                                                  n_node_features, n_edge_features, n_time_features,
                                                  embedding_dimension, device, dropout)
      
              class NormalLinear(nn.Linear):
                  # 从Jodie代码中借用
                  def reset_parameters(self):
                      stdv = 1. / math.sqrt(self.weight.size(1))
                      self.weight.data.normal_(0, stdv)
                      if self.bias is not None:
                          self.bias.data.normal_(0, stdv)
      
              # 定义用于时间差异嵌入的线性层
              self.embedding_layer = NormalLinear(1, self.n_node_features)
      
          def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                                use_time_proj=True):
              """
              计算节点的时间嵌入表示
              :param memory: 记忆模块,用于存储和更新节点的记忆状态
              :param source_nodes: 源节点,表示需要计算嵌入的节点
              :param timestamps: 时间戳,表示节点的时间信息
              :param n_layers: 层数,表示模型的层数
              :param n_neighbors: 每层考虑的邻居数量,默认值为20
              :param time_diffs: 时间差异,表示节点之间的时间差异
              :param use_time_proj: 是否使用时间投影,默认值为True
              :return: 源节点的嵌入表示
              """
              # 计算时间差异的嵌入表示,并将其与记忆状态相乘
              source_embeddings = memory[source_nodes, :] * (1 + self.embedding_layer(time_diffs.unsqueeze(1)))
              return source_embeddings
      
      
  4. 具体实现类 GraphEmbedding

    • 功能:作为图嵌入类的基类,实现了递归计算节点嵌入的基本逻辑。

    • 关键代码:

      class GraphEmbedding(EmbeddingModule):
          def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                       n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                       n_heads=2, dropout=0.1, use_memory=True):
              """
              初始化图嵌入模块
              :param node_features: 节点特征
              :param edge_features: 边特征
              :param memory: 记忆模块
              :param neighbor_finder: 邻居查找模块
              :param time_encoder: 时间编码器
              :param n_layers: 层数
              :param n_node_features: 节点特征维度
              :param n_edge_features: 边特征维度
              :param n_time_features: 时间特征维度
              :param embedding_dimension: 嵌入维度
              :param device: 设备(如 'cpu' 或 'gpu')
              :param n_heads: 注意力头数量
              :param dropout: dropout率
              :param use_memory: 是否使用记忆模块
              """
              super(GraphEmbedding, self).__init__(node_features, edge_features, memory,
                                                   neighbor_finder, time_encoder, n_layers,
                                                   n_node_features, n_edge_features, n_time_features,
                                                   embedding_dimension, device, dropout)
      
              self.use_memory = use_memory  # 是否使用记忆模块
              self.device = device  # 设备(如 'cpu' 或 'gpu')
      
          def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                                use_time_proj=True):
              """
              递归实现当前层的时序图注意力层
              :param memory: 记忆模块
              :param source_nodes: 节点ID列表
              :param timestamps: 时间戳列表
              :param n_layers: 时序卷积层的数量
              :param n_neighbors: 每层时序卷积层考虑的邻居数
              :param time_diffs: 时间差异
              :param use_time_proj: 是否使用时间投影
              :return: 源节点的嵌入表示
              """
              assert (n_layers >= 0)
      
              source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)  # 将源节点转换为tensor
              timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)  # 将时间戳转换为tensor并扩展维度
      
              # 获取源节点的时间嵌入表示
              source_nodes_time_embedding = self.time_encoder(torch.zeros_like(timestamps_torch))
      
              # 获取源节点的特征
              source_node_features = self.node_features[source_nodes_torch, :]
      
              if self.use_memory:
                  # 如果使用记忆模块,将记忆状态与节点特征相加
                  source_node_features = memory[source_nodes, :] + source_node_features
      
              if n_layers == 0:
                  # 如果层数为0,直接返回节点特征
                  return source_node_features
              else:
                  # 获取邻居节点、边索引和边时间
                  neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
                      source_nodes,
                      timestamps,
                      n_neighbors=n_neighbors)
      
                  neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)  # 将邻居节点转换为tensor
                  edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)  # 将边索引转换为tensor
                  edge_deltas = timestamps[:, np.newaxis] - edge_times  # 计算时间差异
                  edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)  # 将时间差异转换为tensor
      
                  neighbors = neighbors.flatten()  # 展平邻居节点
                  neighbor_embeddings = self.compute_embedding(memory,
                                                               neighbors,
                                                               np.repeat(timestamps, n_neighbors),
                                                               n_layers=n_layers - 1,
                                                               n_neighbors=n_neighbors)  # 递归计算邻居节点的嵌入表示
      
                  effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1  # 有效的邻居数
                  neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)  # 重塑邻居嵌入表示的形状
                  edge_time_embeddings = self.time_encoder(edge_deltas_torch)  # 获取边的时间嵌入表示
      
                  edge_features = self.edge_features[edge_idxs, :]  # 获取边的特征
      
                  mask = neighbors_torch == 0  # 创建掩码,用于忽略没有邻居的情况
      
                  # 聚合源节点和邻居节点的嵌入表示
                  source_embedding = self.aggregate(n_layers, source_node_features,
                                                    source_nodes_time_embedding,
                                                    neighbor_embeddings,
                                                    edge_time_embeddings,
                                                    edge_features,
                                                    mask)
      
                  return source_embedding
      
          def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding,
                        neighbor_embeddings,
                        edge_time_embeddings, edge_features, mask):
              """
              聚合源节点和邻居节点的嵌入表示
              :param n_layers: 时序卷积层的数量
              :param source_node_features: 源节点特征
              :param source_nodes_time_embedding: 源节点时间嵌入表示
              :param neighbor_embeddings: 邻居节点嵌入表示
              :param edge_time_embeddings: 边的时间嵌入表示
              :param edge_features: 边特征
              :param mask: 掩码,用于忽略没有邻居的情况
              :return: 聚合后的嵌入表示
              """
              return None  # 聚合逻辑在具体的实现类中定义
      
      
  5. 具体实现类 GraphSumEmbedding

    • 功能:通过求和邻居节点和边的特征生成节点的嵌入表示。

    • 关键代码:

      class GraphSumEmbedding(GraphEmbedding):
          def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                       n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                       n_heads=2, dropout=0.1, use_memory=True):
              """
              初始化图嵌入模块(使用加和的方式聚合邻居节点嵌入)
              :param node_features: 节点特征
              :param edge_features: 边特征
              :param memory: 记忆模块
              :param neighbor_finder: 邻居查找模块
              :param time_encoder: 时间编码器
              :param n_layers: 层数
              :param n_node_features: 节点特征维度
              :param n_edge_features: 边特征维度
              :param n_time_features: 时间特征维度
              :param embedding_dimension: 嵌入维度
              :param device: 设备(如 'cpu' 或 'gpu')
              :param n_heads: 注意力头数量
              :param dropout: dropout率
              :param use_memory: 是否使用记忆模块
              """
              super(GraphSumEmbedding, self).__init__(node_features=node_features,
                                                      edge_features=edge_features,
                                                      memory=memory,
                                                      neighbor_finder=neighbor_finder,
                                                      time_encoder=time_encoder, n_layers=n_layers,
                                                      n_node_features=n_node_features,
                                                      n_edge_features=n_edge_features,
                                                      n_time_features=n_time_features,
                                                      embedding_dimension=embedding_dimension,
                                                      device=device,
                                                      n_heads=n_heads, dropout=dropout,
                                                      use_memory=use_memory)
              # 定义多层线性变换模块
              self.linear_1 = torch.nn.ModuleList([torch.nn.Linear(embedding_dimension + n_time_features +
                                                                   n_edge_features, embedding_dimension)
                                                   for _ in range(n_layers)])
              self.linear_2 = torch.nn.ModuleList(
                  [torch.nn.Linear(embedding_dimension + n_node_features + n_time_features,
                                   embedding_dimension) for _ in range(n_layers)])
      
          def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                        neighbor_embeddings, edge_time_embeddings, edge_features, mask):
              """
              聚合源节点和邻居节点的嵌入表示
              :param n_layer: 当前层数
              :param source_node_features: 源节点特征
              :param source_nodes_time_embedding: 源节点时间嵌入表示
              :param neighbor_embeddings: 邻居节点嵌入表示
              :param edge_time_embeddings: 边的时间嵌入表示
              :param edge_features: 边特征
              :param mask: 掩码,用于忽略没有邻居的情况
              :return: 聚合后的嵌入表示
              """
              # 将邻居嵌入、时间嵌入和边特征拼接
              neighbors_features = torch.cat([neighbor_embeddings, edge_time_embeddings, edge_features], dim=2)
              # 对拼接后的特征进行线性变换
              neighbor_embeddings = self.linear_1[n_layer - 1](neighbors_features)
              # 通过加和并使用ReLU激活函数进行聚合
              neighbors_sum = torch.nn.functional.relu(torch.sum(neighbor_embeddings, dim=1))
      
              # 将源节点特征和时间嵌入拼接
              source_features = torch.cat([source_node_features, source_nodes_time_embedding.squeeze()], dim=1)
              # 将聚合后的邻居特征与源节点特征拼接
              source_embedding = torch.cat([neighbors_sum, source_features], dim=1)
              # 进行第二次线性变换
              source_embedding = self.linear_2[n_layer - 1](source_embedding)
      
              return source_embedding
      
      
  6. 具体实现类 GraphAttentionEmbedding

    • 功能:使用图注意力机制生成节点的嵌入表示。

    • 关键代码:

      class GraphAttentionEmbedding(GraphEmbedding):
          def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                       n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                       n_heads=2, dropout=0.1, use_memory=True):
              """
              初始化图注意力嵌入模块
              :param node_features: 节点特征
              :param edge_features: 边特征
              :param memory: 记忆模块
              :param neighbor_finder: 邻居查找模块
              :param time_encoder: 时间编码器
              :param n_layers: 层数
              :param n_node_features: 节点特征维度
              :param n_edge_features: 边特征维度
              :param n_time_features: 时间特征维度
              :param embedding_dimension: 嵌入维度
              :param device: 设备(如 'cpu' 或 'gpu')
              :param n_heads: 注意力头数量
              :param dropout: dropout率
              :param use_memory: 是否使用记忆模块
              """
              super(GraphAttentionEmbedding, self).__init__(node_features, edge_features, memory,
                                                            neighbor_finder, time_encoder, n_layers,
                                                            n_node_features, n_edge_features,
                                                            n_time_features,
                                                            embedding_dimension, device,
                                                            n_heads, dropout,
                                                            use_memory)
      
              # 初始化多层注意力模型
              self.attention_models = torch.nn.ModuleList([TemporalAttentionLayer(
                n_node_features=n_node_features,
                n_neighbors_features=n_node_features,
                n_edge_features=n_edge_features,
                time_dim=n_time_features,
                n_head=n_heads,
                dropout=dropout,
                output_dimension=n_node_features)
                for _ in range(n_layers)])
      
          def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                        neighbor_embeddings,
                        edge_time_embeddings, edge_features, mask):
              """
              聚合源节点和邻居节点的嵌入表示
              :param n_layer: 当前层数
              :param source_node_features: 源节点特征
              :param source_nodes_time_embedding: 源节点时间嵌入表示
              :param neighbor_embeddings: 邻居节点嵌入表示
              :param edge_time_embeddings: 边的时间嵌入表示
              :param edge_features: 边特征
              :param mask: 掩码,用于忽略没有邻居的情况
              :return: 聚合后的嵌入表示
              """
              # 获取当前层的注意力模型
              attention_model = self.attention_models[n_layer - 1]
      
              # 使用注意力模型聚合特征
              source_embedding, _ = attention_model(source_node_features,
                                                    source_nodes_time_embedding,
                                                    neighbor_embeddings,
                                                    edge_time_embeddings,
                                                    edge_features,
                                                    mask)
      
              return source_embedding
      
      
  7. 工厂函数 get_embedding_module

    • 功能:根据给定的类型返回相应的嵌入模块实例。

    • 关键代码:

      def get_embedding_module(module_type, node_features, edge_features, memory, neighbor_finder,
                               time_encoder, n_layers, n_node_features, n_edge_features, n_time_features,
                               embedding_dimension, device,
                               n_heads=2, dropout=0.1, n_neighbors=None,
                               use_memory=True):
          """
          根据给定的模块类型返回相应的嵌入模块实例。
          :param module_type: 模块类型(如 "graph_attention", "graph_sum", "identity", "time")
          :param node_features: 节点特征
          :param edge_features: 边特征
          :param memory: 记忆模块
          :param neighbor_finder: 邻居查找模块
          :param time_encoder: 时间编码器
          :param n_layers: 时序卷积层的数量
          :param n_node_features: 节点特征的维度
          :param n_edge_features: 边特征的维度
          :param n_time_features: 时间特征的维度
          :param embedding_dimension: 嵌入维度
          :param device: 设备(如 'cpu' 或 'gpu')
          :param n_heads: 注意力头的数量(仅适用于注意力嵌入模块)
          :param dropout: dropout率
          :param n_neighbors: 每层时序卷积层考虑的邻居数
          :param use_memory: 是否使用记忆模块
          :return: 嵌入模块实例
          """
          if module_type == "graph_attention":
              return GraphAttentionEmbedding(node_features=node_features,
                                             edge_features=edge_features,
                                             memory=memory,
                                             neighbor_finder=neighbor_finder,
                                             time_encoder=time_encoder,
                                             n_layers=n_layers,
                                             n_node_features=n_node_features,
                                             n_edge_features=n_edge_features,
                                             n_time_features=n_time_features,
                                             embedding_dimension=embedding_dimension,
                                             device=device,
                                             n_heads=n_heads, dropout=dropout, use_memory=use_memory)
          elif module_type == "graph_sum":
              return GraphSumEmbedding(node_features=node_features,
                                       edge_features=edge_features,
                                       memory=memory,
                                       neighbor_finder=neighbor_finder,
                                       time_encoder=time_encoder,
                                       n_layers=n_layers,
                                       n_node_features=n_node_features,
                                       n_edge_features=n_edge_features,
                                       n_time_features=n_time_features,
                                       embedding_dimension=embedding_dimension,
                                       device=device,
                                       n_heads=n_heads, dropout=dropout, use_memory=use_memory)
      
          elif module_type == "identity":
              return IdentityEmbedding(node_features=node_features,
                                       edge_features=edge_features,
                                       memory=memory,
                                       neighbor_finder=neighbor_finder,
                                       time_encoder=time_encoder,
                                       n_layers=n_layers,
                                       n_node_features=n_node_features,
                                       n_edge_features=n_edge_features,
                                       n_time_features=n_time_features,
                                       embedding_dimension=embedding_dimension,
                                       device=device,
                                       dropout=dropout)
          elif module_type == "time":
              return TimeEmbedding(node_features=node_features,
                                   edge_features=edge_features,
                                   memory=memory,
                                   neighbor_finder=neighbor_finder,
                                   time_encoder=time_encoder,
                                   n_layers=n_layers,
                                   n_node_features=n_node_features,
                                   n_edge_features=n_edge_features,
                                   n_time_features=n_time_features,
                                   embedding_dimension=embedding_dimension,
                                   device=device,
                                   dropout=dropout,
                                   n_neighbors=n_neighbors)
          else:
              raise ValueError("Embedding Module {} not supported".format(module_type))