ISubGVQA.models.att_pooling

Classes

GlobalAttention

GlobalAttention is a neural network module that applies attention mechanisms to graph data.

Module Contents

class ISubGVQA.models.att_pooling.GlobalAttention(num_node_features, num_out_features)

Bases: torch.nn.Module

GlobalAttention is a neural network module that applies attention mechanisms to graph data.

Args:

num_node_features (int): The number of input features per node.
num_out_features (int): The number of output features.

Methods:
reset_parameters():

Resets the parameters of the neural network layers.

forward(x, u, batch, size=None, return_mask=False, node_mask=None):

Forward pass of the GlobalAttention module.

Args:

x (Tensor): Node feature matrix with shape [num_nodes, num_node_features].
u (Tensor): Global feature matrix with shape [batch_size, num_out_features].
batch (Tensor): Batch vector which assigns each node to a specific example in the batch.
size (int, optional): The number of examples in the batch. If None, it is inferred from the batch vector.
return_mask (bool, optional): If True, returns the attention mask along with the output.
node_mask (Tensor, optional): Mask to apply on the node features.

Returns:

Tensor: The output feature matrix with shape [batch_size, num_out_features].
Tensor (optional): The attention mask with shape [num_nodes, 1], returned only if return_mask is True.
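The roles of batch, size, and node_mask in forward can be illustrated with a dependency-free sketch in plain Python. It rests on several assumptions that this page does not confirm: a scalar score per node has already been computed, scores are normalised with a softmax within each graph, and a False entry in node_mask excludes that node. attention_pool is a hypothetical helper written for illustration, not part of ISubGVQA, and it uses lists where the module uses tensors.

```python
import math


def attention_pool(scores, feats, batch, size=None, node_mask=None):
    """Softmax the per-node scores within each graph, then sum weighted features.

    scores: one float per node (assumed to be precomputed gate scores).
    feats: one equal-length list of floats per node.
    batch: graph index of each node, like the batch vector of forward.
    node_mask: optional list of bool; False drops a node (an assumption
        about the mask semantics, not taken from the module).
    """
    if size is None:
        # Infer the number of graphs from the batch vector.
        size = max(batch) + 1 if batch else 0
    keep = node_mask if node_mask is not None else [True] * len(scores)

    # Per-graph maximum over kept nodes, for a numerically stable softmax.
    peak = [-math.inf] * size
    for s, b, k in zip(scores, batch, keep):
        if k:
            peak[b] = max(peak[b], s)

    # Unnormalised weights (zero for masked nodes) and per-graph normalisers.
    weights = [math.exp(s - peak[b]) if k else 0.0
               for s, b, k in zip(scores, batch, keep)]
    total = [0.0] * size
    for w, b in zip(weights, batch):
        total[b] += w

    # Attention per node; a fully masked graph gets all-zero attention.
    attn = [w / total[b] if total[b] > 0 else 0.0
            for w, b in zip(weights, batch)]

    # Weighted sum of node features, one output row per graph.
    dim = len(feats[0]) if feats else 0
    pooled = [[0.0] * dim for _ in range(size)]
    for a, f, b in zip(attn, feats, batch):
        for j, v in enumerate(f):
            pooled[b][j] += a * v
    return pooled, attn


# Three nodes: the first two belong to graph 0, the third to graph 1.
scores = [0.0, 0.0, 1.0]
feats = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
batch = [0, 0, 1]

pooled, attn = attention_pool(scores, feats, batch)
masked_pooled, masked_attn = attention_pool(
    scores, feats, batch, node_mask=[True, False, True]
)
```

Here pooled has one row per graph and attn has one weight per node, mirroring the [batch_size, num_out_features] output and the per-node attention mask described above. Masking the second node moves all of graph 0's attention onto the first node.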

__repr__():

Returns a string representation of the GlobalAttention module.

gate_nn
node_nn
ques_nn
reset_parameters()
forward(x, u, batch, size=None, return_mask=False, node_mask=None)
__repr__()