ISubGVQA.models.masking
Classes
MaskingModel: A PyTorch module for masking nodes in a graph based on various sampling methods.
Functions
Module Contents
- class ISubGVQA.models.masking.MaskingModel(dim_nodes, dim_questions, masking_threshold=0.3, use_topk=False, sample_k=None, sampler_type=None, nb_samples=1, alpha=1.0, beta=10.0, tau=1.0)
Bases: torch.nn.Module
A PyTorch module for masking nodes in a graph based on various sampling methods.
Args:
- dim_nodes (int): Dimension of node features.
- dim_questions (int): Dimension of question features.
- masking_threshold (float, optional): Threshold for masking. Defaults to 0.3.
- use_topk (bool, optional): Whether to use top-k pooling. Defaults to False.
- sample_k (int, optional): Number of samples for the sampler. Defaults to None.
- sampler_type (str, optional): Type of sampler to use ('imle', 'aimle', 'simple', 'gumbel'). Defaults to None.
- nb_samples (int, optional): Number of samples for I-MLE samplers. Defaults to 1.
- alpha (float, optional): Alpha parameter for I-MLE samplers. Defaults to 1.0.
- beta (float, optional): Beta parameter for I-MLE samplers. Defaults to 10.0.
- tau (float, optional): Tau parameter for I-MLE samplers. Defaults to 1.0.
- Methods:
- reset_parameters():
Resets the parameters of the neural networks.
- forward(x, u, batch, edge_index, size=None, use_all_instrs=True):
Forward pass of the model.
Args:
- x (Tensor): Node features.
- u (Tensor): Question features.
- batch (Tensor): Batch indices.
- edge_index (Tensor): Edge indices.
- size (int, optional): Size of the batch. Defaults to None.
- use_all_instrs (bool, optional): Whether to use all instructions. Defaults to True.
- Returns:
Tensor: Masked node features.
- use_topk = False
- sample_k = None
- sampler_type = None
- masking_threshold = 0
- dim_nodes
- dim_questions
- gate_nn
- node_nn
- ques_nn
- reset_parameters()
- forward(x, u, batch, edge_index, size=None, use_all_instrs=True)
- ISubGVQA.models.masking.topk_sampling(gate, masking_threshold, batch)
- ISubGVQA.models.masking.get_imle_samplers(sample_k, beta=10, alpha=1.0, tau=1.0, noise_scale=0.3, nb_samples=1, device=None)
- ISubGVQA.models.masking.get_aimle_samplers(sample_k, alpha=1.0, tau=1.0, noise_scale=0.3, nb_samples=1, device=None)
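The functions above carry no docstrings, so the following is only a rough sketch of the two ideas their signatures suggest: selecting the highest-scoring nodes of each graph in a batch, and drawing several k-hot masks from noise-perturbed scores. It is not the ISubGVQA implementation. The names `topk_mask_per_graph` and `perturbed_topk_samples` are hypothetical, plain Python lists stand in for tensors, `ratio` is assumed to mean the fraction of nodes kept per graph, and Gumbel noise is used as a stand-in for whatever noise distribution the samplers actually use. Only the forward selection step is shown; gradient estimation is out of scope.

```python
import math
import random


def topk_mask_per_graph(gate, ratio, batch):
    """Return a 0/1 mask keeping the highest-scoring nodes of each graph.

    gate  : list of float, one relevance score per node
    ratio : fraction of nodes kept per graph (assumed interpretation)
    batch : list of int, graph index of each node
    """
    # Group node positions by the graph they belong to.
    nodes_by_graph = {}
    for node, graph in enumerate(batch):
        nodes_by_graph.setdefault(graph, []).append(node)

    mask = [0] * len(gate)
    for nodes in nodes_by_graph.values():
        k = max(1, math.ceil(ratio * len(nodes)))  # keep at least one node
        # Rank this graph's nodes by descending score and keep the first k.
        ranked = sorted(nodes, key=lambda n: gate[n], reverse=True)
        for node in ranked[:k]:
            mask[node] = 1
    return mask


def perturbed_topk_samples(scores, k, noise_scale=0.3, nb_samples=1, seed=None):
    """Draw nb_samples k-hot masks by adding scaled Gumbel noise to the
    scores and keeping the k largest perturbed values."""
    rng = random.Random(seed)
    samples = []
    for _ in range(nb_samples):
        noisy = []
        for s in scores:
            # Clamp the uniform draw away from 0 and 1 so both logs are finite.
            u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
            gumbel = -math.log(-math.log(u))
            noisy.append(s + noise_scale * gumbel)
        keep = sorted(range(len(scores)), key=lambda i: noisy[i], reverse=True)[:k]
        mask = [0] * len(scores)
        for i in keep:
            mask[i] = 1
        samples.append(mask)
    return samples


if __name__ == "__main__":
    gate = [0.9, 0.1, 0.4, 0.8, 0.2]
    batch = [0, 0, 0, 1, 1]  # nodes 0-2 form graph 0, nodes 3-4 form graph 1
    print(topk_mask_per_graph(gate, 0.3, batch))  # [1, 0, 0, 1, 0]
    print(perturbed_topk_samples(gate, k=2, nb_samples=3, seed=0))
```

With `noise_scale=0.0` the second function reduces to a plain top-k selection, which makes it easy to check the sampling path against the deterministic one.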