ISubGVQA.models.masking

Classes

MaskingModel

A PyTorch module for masking nodes in a graph based on various sampling methods.

Functions

topk_sampling(gate, masking_threshold, batch)

get_imle_samplers(sample_k[, beta, alpha, tau, ...])

get_aimle_samplers(sample_k[, alpha, tau, ...])

Module Contents

class ISubGVQA.models.masking.MaskingModel(dim_nodes, dim_questions, masking_threshold=0.3, use_topk=False, sample_k=None, sampler_type=None, nb_samples=1, alpha=1.0, beta=10.0, tau=1.0)

Bases: torch.nn.Module

A PyTorch module for masking nodes in a graph based on various sampling methods.

Args:

dim_nodes (int): Dimension of node features.
dim_questions (int): Dimension of question features.
masking_threshold (float, optional): Threshold for masking. Defaults to 0.3.
use_topk (bool, optional): Whether to use top-k pooling. Defaults to False.
sample_k (int, optional): Number of samples for the sampler. Defaults to None.
sampler_type (str, optional): Type of sampler to use ('imle', 'aimle', 'simple', 'gumbel'). Defaults to None.
nb_samples (int, optional): Number of samples for I-MLE samplers. Defaults to 1.
alpha (float, optional): Alpha parameter for I-MLE samplers. Defaults to 1.0.
beta (float, optional): Beta parameter for I-MLE samplers. Defaults to 10.0.
tau (float, optional): Tau parameter for I-MLE samplers. Defaults to 1.0.
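The 'gumbel' option for sampler_type is not documented further on this page. One common way such an option can turn per-node scores into a hard keep/drop mask is a straight-through Gumbel-softmax, sketched below with torch.nn.functional.gumbel_softmax. The function name gumbel_node_mask and the two-class (drop/keep) layout are assumptions for illustration, not the ISubGVQA implementation.

```python
import torch
import torch.nn.functional as F


def gumbel_node_mask(node_logits: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: sample a hard keep-mask per node (straight-through Gumbel-softmax).

    node_logits: [num_nodes] unnormalised log-odds of keeping each node.
    Returns a [num_nodes] tensor of (numerically) 0/1 values; the backward pass
    uses the gradient of the relaxed soft sample.
    """
    # Two-class logits per node: column 0 = "drop", column 1 = "keep".
    logits = torch.stack([torch.zeros_like(node_logits), node_logits], dim=-1)
    # hard=True: one-hot sample in the forward pass, soft gradient in the backward pass.
    one_hot = F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)
    return one_hot[:, 1]


torch.manual_seed(0)
scores = torch.randn(6, requires_grad=True)
mask = gumbel_node_mask(scores, tau=1.0)
x = torch.randn(6, 4)
masked_x = x * mask.unsqueeze(-1)  # zero out the rows of dropped nodes
masked_x.sum().backward()          # gradients still reach `scores`
```

Lowering tau makes the soft relaxation closer to the hard sample, at the cost of higher-variance gradients.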

Methods:

reset_parameters():

Resets the parameters of the neural networks.

forward(x, u, batch, edge_index, size=None, use_all_instrs=True):

Forward pass of the model.

Args:

x (Tensor): Node features.
u (Tensor): Question features.
batch (Tensor): Batch indices.
edge_index (Tensor): Edge indices.
size (int, optional): Size of the batch. Defaults to None.
use_all_instrs (bool, optional): Whether to use all instructions. Defaults to True.

Returns:

Tensor: Masked node features.
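The attributes gate_nn, node_nn and ques_nn listed for this class suggest that the forward pass scores each node against the question before masking it. A minimal sketch of that idea follows, assuming linear layers, an element-wise product of the projected node and question features, and a sigmoid gate. The class QuestionConditionedGate and its hidden size are hypothetical, and the sketch ignores edge_index, size and use_all_instrs.

```python
import torch
import torch.nn as nn


class QuestionConditionedGate(nn.Module):
    """Hypothetical sketch: score each node against its graph's question, then gate it.

    Reuses the attribute names node_nn, ques_nn and gate_nn, but the layer types,
    shapes and the product-then-score design are assumptions.
    """

    def __init__(self, dim_nodes: int, dim_questions: int, hidden: int = 32):
        super().__init__()
        self.node_nn = nn.Linear(dim_nodes, hidden)
        self.ques_nn = nn.Linear(dim_questions, hidden)
        self.gate_nn = nn.Linear(hidden, 1)

    def reset_parameters(self):
        for layer in (self.node_nn, self.ques_nn, self.gate_nn):
            layer.reset_parameters()

    def forward(self, x, u, batch):
        # x: [num_nodes, dim_nodes], u: [num_graphs, dim_questions],
        # batch: [num_nodes] graph index of every node.
        h = torch.relu(self.node_nn(x) * self.ques_nn(u)[batch])  # question copied to each of its nodes
        gate = torch.sigmoid(self.gate_nn(h)).squeeze(-1)          # [num_nodes], values in (0, 1)
        return x * gate.unsqueeze(-1), gate


model = QuestionConditionedGate(dim_nodes=8, dim_questions=16)
x = torch.randn(5, 8)
u = torch.randn(2, 16)
batch = torch.tensor([0, 0, 0, 1, 1])
masked_x, gate = model(x, u, batch)
```

In this sketch the soft gate could then be handed to one of the samplers (top-k, I-MLE, Gumbel) to obtain a hard mask.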

use_topk = False
sample_k = None
sampler_type = None
masking_threshold = 0
dim_nodes
dim_questions
gate_nn
node_nn
ques_nn
reset_parameters()
forward(x, u, batch, edge_index, size=None, use_all_instrs=True)
ISubGVQA.models.masking.topk_sampling(gate, masking_threshold, batch)
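topk_sampling has no docstring. Assuming gate holds one score per node, batch is a PyG-style vector mapping every node to its graph, and masking_threshold acts as the fraction of nodes to keep in each graph, a per-graph top-k mask could look like the sketch below. topk_mask_per_graph and keep_ratio are hypothetical names, and the real function may instead threshold the scores directly.

```python
import math

import torch


def topk_mask_per_graph(gate: torch.Tensor, keep_ratio: float, batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: keep the ceil(keep_ratio * n_i) highest-scoring nodes of each graph i.

    gate:  [num_nodes] node scores.
    batch: [num_nodes] graph index of every node.
    Returns a float 0/1 mask of shape [num_nodes].
    """
    mask = torch.zeros_like(gate)
    for g in torch.unique(batch):
        idx = (batch == g).nonzero(as_tuple=True)[0]     # node indices belonging to graph g
        k = max(1, math.ceil(keep_ratio * idx.numel()))  # keep at least one node per graph
        top = torch.topk(gate[idx], k).indices           # positions within graph g
        mask[idx[top]] = 1.0
    return mask


gate = torch.tensor([0.9, 0.1, 0.5, 0.3, 0.8, 0.2, 0.7])
batch = torch.tensor([0, 0, 0, 0, 1, 1, 1])
mask = topk_mask_per_graph(gate, keep_ratio=0.3, batch=batch)
# mask.tolist() → [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0]
```

Graph 0 has four nodes, so ceil(0.3 * 4) = 2 nodes are kept; graph 1 has three, so one is kept. The Python loop keeps the sketch readable; a vectorised version would typically sort scores within each graph instead.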
ISubGVQA.models.masking.get_imle_samplers(sample_k, beta=10, alpha=1.0, tau=1.0, noise_scale=0.3, nb_samples=1, device=None)
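The parameters of get_imle_samplers (alpha, beta, tau, nb_samples) match those of implicit maximum likelihood estimation (I-MLE). I-MLE samples a discrete state by perturb-and-MAP in the forward pass. In the backward pass it computes a second MAP state on the target theta' = alpha * theta - beta * dL/dz and uses the difference of the two states as the gradient. The self-contained sketch below applies that estimator to a k-subset (top-k) MAP with Sum-of-Gamma noise. It is a hedged illustration, not what the factory actually returns. It assumes a dense [batch, n] score layout, uses tau as the noise temperature, draws a single noise sample (nb_samples=1) and leaves noise_scale out. IMLETopK, topk_map and sum_of_gamma_noise are hypothetical names.

```python
import math

import torch
from torch.distributions import Gamma


def sum_of_gamma_noise(shape, k: int, tau: float = 1.0, nb_iterations: int = 10) -> torch.Tensor:
    """Sum-of-Gamma noise for k-subset perturb-and-MAP: (tau / k) * (sum_i Gamma(1/k, rate=i/k) - log s)."""
    total = torch.zeros(shape)
    for i in range(1, nb_iterations + 1):
        total += Gamma(torch.tensor(1.0 / k), torch.tensor(i / k)).sample(shape)
    return tau * (total - math.log(nb_iterations)) / k


def topk_map(theta: torch.Tensor, k: int) -> torch.Tensor:
    """MAP state of a k-subset distribution: 1.0 at the k largest entries of each row."""
    idx = torch.topk(theta, k, dim=-1).indices
    return torch.zeros_like(theta).scatter_(-1, idx, 1.0)


class IMLETopK(torch.autograd.Function):
    """Hypothetical I-MLE wrapper: perturb-and-MAP forward, target-distribution backward."""

    @staticmethod
    def forward(ctx, theta, k, alpha, beta, tau):
        noise = sum_of_gamma_noise(theta.shape, k, tau)
        z = topk_map(theta + noise, k)
        ctx.save_for_backward(theta, z)
        ctx.noise, ctx.k, ctx.alpha, ctx.beta = noise, k, alpha, beta
        return z

    @staticmethod
    def backward(ctx, grad_z):
        theta, z = ctx.saved_tensors
        # Target distribution: theta' = alpha * theta - beta * dL/dz, with the same noise.
        theta_target = ctx.alpha * theta - ctx.beta * grad_z
        z_target = topk_map(theta_target + ctx.noise, ctx.k)
        # I-MLE gradient estimate: difference between the sampled and target MAP states.
        return z - z_target, None, None, None, None


torch.manual_seed(0)
theta = torch.randn(2, 6, requires_grad=True)  # 2 graphs x 6 candidate nodes (dense layout assumed)
z = IMLETopK.apply(theta, 2, 1.0, 10.0, 1.0)   # k=2, alpha=1.0, beta=10.0, tau=1.0
loss = (z * torch.randn(2, 6)).sum()
loss.backward()
```

Reusing the same noise for both MAP calls means the gradient only reflects how the loss would shift the selected subset. Because both states contain exactly k ones, every row of theta.grad sums to zero.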
ISubGVQA.models.masking.get_aimle_samplers(sample_k, alpha=1.0, tau=1.0, noise_scale=0.3, nb_samples=1, device=None)
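get_aimle_samplers takes the same arguments as get_imle_samplers except beta. One plausible reading, consistent with adaptive I-MLE (AIMLE), is that the perturbation strength is adjusted during training instead of being fixed. The controller below is a hypothetical illustration of that idea only. It raises beta when the I-MLE gradient changes too few components and lowers it when it changes too many. The update rule, AdaptiveBeta and target_changes are assumptions, not the rule used by ISubGVQA or any library.

```python
import torch


class AdaptiveBeta:
    """Hypothetical AIMLE-style controller for the I-MLE perturbation strength beta.

    Assumption (not taken from ISubGVQA): beta is nudged after every backward pass so
    that the I-MLE gradient changes about `target_changes` components per example.
    """

    def __init__(self, beta: float = 0.0, step: float = 1e-3, target_changes: float = 1.0):
        self.beta = beta
        self.step = step
        self.target_changes = target_changes

    def update(self, imle_grad: torch.Tensor) -> float:
        # Average number of components where the sampled and target MAP states differ.
        changes = (imle_grad != 0).float().sum(dim=-1).mean().item()
        # Too few changes: raise beta (stronger shift toward the target). Too many: lower it.
        self.beta = max(0.0, self.beta + self.step * (self.target_changes - changes))
        return self.beta


ctrl = AdaptiveBeta(beta=0.0, step=0.5, target_changes=2.0)
first = ctrl.update(torch.zeros(4, 6))                                # no component changed
second = ctrl.update(torch.tensor([[1., -1., 1., -1., 0., 0.]] * 4))  # 4 changes per row
# first → 1.0, second → 0.0
```

A fixed beta that is too small often yields all-zero I-MLE gradients (the target MAP equals the sample), which is the failure mode an adaptive scheme of this kind tries to avoid.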