四、Jbeil model
1.time_encoding
这个代码定义了一个名为TimeEncode
的PyTorch模块,该模块实现了时间编码功能,具体参考了TGAT(Temporal Graph Attention Network)的时间编码方法。以下是代码的详细解释:
类定义和初始化
class TimeEncode(torch.nn.Module):
# Time Encoding proposed by TGAT
def __init__(self, dimension):
super(TimeEncode, self).__init__()
self.dimension = dimension
self.w = torch.nn.Linear(1, dimension)
self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
.float().reshape(dimension, -1))
self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float())
前向传播
def forward(self, t):
# t has shape [batch_size, seq_len]
# Add dimension at the end to apply linear layer --> [batch_size, seq_len, 1]
t = t.unsqueeze(dim=2)
# output has shape [batch_size, seq_len, dimension]
output = torch.cos(self.w(t))
return output
forward
方法
-
forward(self, t)
: 前向传播方法,接受一个时间张量t
作为输入。t
: 输入时间张量,形状为[batch_size, seq_len]
,其中batch_size
是批次大小,seq_len
是序列长度。
-
t = t.unsqueeze(dim=2)
: 在时间张量的最后增加一个维度,使其形状变为[batch_size, seq_len, 1]
,这样可以将其输入到线性层中。 -
output = torch.cos(self.w(t))
:self.w(t)
: 将时间张量输入到线性层,得到形状为[batch_size, seq_len, dimension]
的输出。torch.cos(...)
: 对线性层的输出应用cos
函数,得到时间编码。
-
return output
: 返回时间编码,形状为[batch_size, seq_len, dimension]
。
这个TimeEncode
模块通过一个线性变换和余弦函数将时间信息编码到一个高维空间中,以便用于时间序列或图网络中的时间依赖任务。这种时间编码方法在TGAT模型中被提出,用于捕捉时间信息对节点或边特征的影响。
2.temporal_attention
这段代码定义了一个TemporalAttentionLayer
类,该类实现了时间注意力机制,用于给定一个节点及其邻居的特征和边的时间戳,返回该节点的时间嵌入。
前向传播
forward
方法
-
forward(self, src_node_features, src_time_features, neighbors_features, neighbors_time_features, edge_features, neighbors_padding_mask)
: 前向传播方法,接受多个输入参数并返回注意力输出和权重。src_node_features
: 形状为[batch_size, n_node_features]
的节点特征张量。src_time_features
: 形状为[batch_size, 1, time_dim]
的时间特征张量。neighbors_features
: 形状为[batch_size, n_neighbors, n_node_features]
的邻居节点特征张量。neighbors_time_features
: 形状为[batch_size, n_neighbors, time_dim]
的邻居时间特征张量。edge_features
: 形状为[batch_size, n_neighbors, n_edge_features]
的边特征张量。neighbors_padding_mask
: 形状为[batch_size, n_neighbors]
的邻居填充掩码张量。
-
src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)
: 在节点特征张量的维度1增加一个维度,使其形状变为[batch_size, 1, n_node_features]
。 -
query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
: 将节点特征和时间特征在维度2上连接,形成查询张量query
,形状为[batch_size, 1, query_dim]
。 -
key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)
: 将邻居特征、边特征和时间特征在维度2上连接,形成键张量key
,形状为[batch_size, n_neighbors, key_dim]
。 -
query = query.permute([1, 0, 2])
: 将查询张量的维度顺序调整为[1, batch_size, query_dim]
,以适应多头注意力机制的输入要求。 -
key = key.permute([1, 0, 2])
: 将键张量的维度顺序调整为[n_neighbors, batch_size, key_dim]
。 -
invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)
: 计算无效邻居掩码,即那些没有有效邻居的源节点。 -
neighbors_padding_mask[invalid_neighborhood_mask.squeeze(), 0] = False
: 对于没有有效邻居的源节点,将其第一个邻居设置为有效邻居。 -
attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key, key_padding_mask=neighbors_padding_mask)
: 使用多头注意力机制计算注意力输出和权重。 -
attn_output = attn_output.squeeze()
: 移除多余的维度。 -
attn_output_weights = attn_output_weights.squeeze()
: 移除多余的维度。 -
attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
: 将没有有效邻居的源节点的注意力输出填充为0。 -
attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)
: 将没有有效邻居的源节点的注意力权重填充为0。 -
attn_output = self.merger(attn_output, src_node_features)
: 使用MergeLayer
将注意力输出和源节点特征进行融合。 -
return attn_output, attn_output_weights
: 返回注意力输出和权重。
TemporalAttentionLayer
类实现了时间注意力机制,通过考虑节点特征、邻居特征、边特征和时间特征,生成节点的时间嵌入。该类的核心在于使用多头注意力机制来计算注意力权重和输出,并将结果与原始节点特征进行融合,以得到最终的节点嵌入。
3.TGN
类和组件
TGN
类
TGN类是整个模型的核心,实现了Temporal Graph Network的主要功能。
初始化方法 __init__
- 初始化TGN的各种组件和参数。
- 主要参数包括邻居查找器(neighbor_finder)、节点和边的特征(node_features, edge_features)、设备(device)、层数(n_layers)、头数(n_heads)等。
- 设置时间编码器(TimeEncode)、内存(Memory)以及用于处理消息和内存更新的各种模块(message_aggregator, message_function, memory_updater)。
compute_temporal_embeddings
方法
- 计算给定源节点、目标节点和负采样节点的时间嵌入。
- 使用嵌入模块计算节点的时间嵌入,考虑到内存和时间差异。
- 如果使用内存,则在计算嵌入前更新内存,并在计算嵌入后更新内存中的消息。
compute_edge_probabilities
方法
- 计算给定源节点、目标节点和负采样节点之间边的概率。
- 首先调用
compute_temporal_embeddings
计算节点的时间嵌入,然后将嵌入输入到MLP(全连接层)中计算边的概率。
update_memory
方法
- 更新给定节点的内存。
- 先聚合消息,再计算消息,然后使用聚合的消息更新内存。
get_updated_memory
方法
- 获取更新后的内存状态。
- 与
update_memory
类似,但返回更新后的内存和最后更新的时间戳。
get_raw_messages
方法
- 获取原始消息,准备更新内存。
- 包含源节点和目标节点的嵌入、边特征和时间编码。
set_neighbor_finder
方法
- 设置新的邻居查找器,并更新嵌入模块中的邻居查找器。
逻辑流程
-
初始化:
__init__
方法设置模型的参数和组件,包括时间编码器、内存和消息处理模块。
-
计算嵌入:
compute_temporal_embeddings
方法根据时间信息和邻居信息计算节点的嵌入。- 如果使用内存,则在计算嵌入前更新内存,并在计算后更新内存。
-
计算边的概率:
compute_edge_probabilities
方法使用计算得到的节点嵌入,通过全连接层计算边的概率。
-
更新内存:
update_memory
方法根据新的消息更新节点的内存。get_updated_memory
方法返回更新后的内存状态。
-
获取原始消息:
get_raw_messages
方法准备内存更新所需的原始消息,包括节点嵌入和时间编码。
-
设置邻居查找器:
set_neighbor_finder
方法允许动态更新邻居查找器,以便在不同的时间点使用不同的邻居信息。
这个代码结构使得TGN能够处理动态图数据,通过时间嵌入和内存机制来捕捉节点和边的动态特性。