ISubGVQA.models.att_pooling
Classes
GlobalAttention is a neural network module that applies attention mechanisms to graph data.
Module Contents
- class ISubGVQA.models.att_pooling.GlobalAttention(num_node_features, num_out_features)
Bases: torch.nn.Module
GlobalAttention is a neural network module that applies attention mechanisms to graph data.
Args:
num_node_features (int): The number of input features per node.
num_out_features (int): The number of output features.
- Methods:
- reset_parameters():
Resets the parameters of the neural network layers.
- forward(x, u, batch, size=None, return_mask=False, node_mask=None):
Forward pass of the GlobalAttention module.
Args:
x (Tensor): Node feature matrix with shape [num_nodes, num_node_features].
u (Tensor): Global feature matrix with shape [batch_size, num_out_features].
batch (Tensor): Batch vector which assigns each node to a specific example in the batch.
size (int, optional): The number of examples in the batch. If None, it is inferred from the batch vector.
return_mask (bool, optional): If True, returns the attention mask along with the output.
node_mask (Tensor, optional): Mask to apply on the node features.
- Returns:
Tensor: The output feature matrix with shape [batch_size, num_out_features].
Tensor (optional): The attention mask with shape [num_nodes, 1] if return_mask is True.
- __repr__():
Returns a string representation of the GlobalAttention module.
- gate_nn
- node_nn
- ques_nn
- reset_parameters()
- forward(x, u, batch, size=None, return_mask=False, node_mask=None)
- __repr__()
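The roles of batch and size in forward can be illustrated with a small, dependency-free sketch of attention pooling over a batch vector. This is not the module's implementation. The function names infer_size and attention_pool are hypothetical, the per-node gate scores are taken as given (how GlobalAttention derives them from x and u through gate_nn, node_nn and ques_nn is not documented here), and plain Python lists stand in for tensors. The sketch assumes the common convention that the number of examples is the largest batch index plus one, and that attention weights are normalised with a softmax inside each example.

```python
import math


def infer_size(batch):
    """Number of examples implied by a batch vector (assumed: largest index + 1)."""
    return max(batch) + 1 if batch else 0


def attention_pool(x, gate, batch, size=None):
    """Softmax the gate scores within each example, then sum the weighted node rows.

    x:     node feature rows, shape [num_nodes][num_features]
    gate:  one raw attention score per node (hypothetical input)
    batch: example index of each node
    size:  number of examples; inferred from batch when None
    """
    if size is None:
        size = infer_size(batch)

    # Subtract the per-example maximum before exponentiating, for numerical stability.
    peak = [-math.inf] * size
    for score, b in zip(gate, batch):
        peak[b] = max(peak[b], score)
    exp = [math.exp(score - peak[b]) for score, b in zip(gate, batch)]

    # Normalise so the weights of the nodes in each example sum to 1.
    total = [0.0] * size
    for e, b in zip(exp, batch):
        total[b] += e
    weights = [e / total[b] for e, b in zip(exp, batch)]

    # Weighted sum of node rows per example gives one output row per example.
    num_features = len(x[0]) if x else 0
    out = [[0.0] * num_features for _ in range(size)]
    for row, w, b in zip(x, weights, batch):
        for j, value in enumerate(row):
            out[b][j] += w * value
    return out, weights


# Three nodes: the first two belong to example 0, the third to example 1.
x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
gate = [0.0, 0.0, 5.0]
batch = [0, 0, 1]
pooled, weights = attention_pool(x, gate, batch)
print(pooled)   # one row per example
print(weights)  # one weight per node
```

Here the flat weights list plays the role of the per-node attention mask that forward returns when return_mask is True, which the documentation above describes as having shape [num_nodes, 1]. Passing size explicitly is useful when trailing examples in the batch contain no nodes, since their rows cannot be inferred from the batch vector alone.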