ISubGVQA.models.masking
=======================

.. py:module:: ISubGVQA.models.masking


Classes
-------

.. autoapisummary::

   ISubGVQA.models.masking.MaskingModel


Functions
---------

.. autoapisummary::

   ISubGVQA.models.masking.topk_sampling
   ISubGVQA.models.masking.get_imle_samplers
   ISubGVQA.models.masking.get_aimle_samplers


Module Contents
---------------

.. py:class:: MaskingModel(dim_nodes, dim_questions, masking_threshold=0.3, use_topk=False, sample_k=None, sampler_type=None, nb_samples=1, alpha=1.0, beta=10.0, tau=1.0)

   Bases: :py:obj:`torch.nn.Module`

   A PyTorch module for masking nodes in a graph based on various sampling methods.

   Args:
       dim_nodes (int): Dimension of node features.
       dim_questions (int): Dimension of question features.
       masking_threshold (float, optional): Threshold for masking. Defaults to 0.3.
       use_topk (bool, optional): Whether to use top-k pooling. Defaults to False.
       sample_k (int, optional): Number of samples for the sampler. Defaults to None.
       sampler_type (str, optional): Type of sampler to use ('imle', 'aimle', 'simple', 'gumbel'). Defaults to None.
       nb_samples (int, optional): Number of samples for I-MLE samplers. Defaults to 1.
       alpha (float, optional): Alpha parameter for I-MLE samplers. Defaults to 1.0.
       beta (float, optional): Beta parameter for I-MLE samplers. Defaults to 10.0.
       tau (float, optional): Tau parameter for I-MLE samplers. Defaults to 1.0.

   Methods:
       reset_parameters():
           Resets the parameters of the neural networks.

       forward(x, u, batch, edge_index, size=None, use_all_instrs=True):
           Forward pass of the model.

           Args:
               x (Tensor): Node features.
               u (Tensor): Question features.
               batch (Tensor): Batch indices.
               edge_index (Tensor): Edge indices.
               size (int, optional): Size of the batch. Defaults to None.
               use_all_instrs (bool, optional): Whether to use all instructions. Defaults to True.

           Returns:
               Tensor: Masked node features.


   .. py:attribute:: use_topk
      :value: False


   .. py:attribute:: sample_k
      :value: None


   .. py:attribute:: sampler_type
      :value: None


   .. py:attribute:: masking_threshold
      :value: 0


   .. py:attribute:: dim_nodes


   .. py:attribute:: dim_questions


   .. py:attribute:: gate_nn


   .. py:attribute:: node_nn


   .. py:attribute:: ques_nn


   .. py:method:: reset_parameters()


   .. py:method:: forward(x, u, batch, edge_index, size=None, use_all_instrs=True)


.. py:function:: topk_sampling(gate, masking_threshold, batch)


.. py:function:: get_imle_samplers(sample_k, beta=10, alpha=1.0, tau=1.0, noise_scale=0.3, nb_samples=1, device=None)


.. py:function:: get_aimle_samplers(sample_k, alpha=1.0, tau=1.0, noise_scale=0.3, nb_samples=1, device=None)
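
The ``topk_sampling`` entry above documents only its signature. The sketch below illustrates, in plain Python so that it runs without PyTorch, what per-graph top-k node selection over a batched graph typically looks like. It rests on assumptions this page does not confirm: ``gate`` holds one score per node, ``batch`` maps each node to its graph index (the PyTorch Geometric convention), and ``masking_threshold`` is the fraction of nodes kept in each graph, rounded up. The optional noise term is the standard Gumbel-top-k trick, included only as a stand-in for what a ``'gumbel'`` sampler might do. The helper names ``topk_node_mask`` and ``_gumbel_noise`` are hypothetical.

.. code-block:: python

   import math
   import random
   from collections import defaultdict
   from typing import List, Optional, Sequence


   def _gumbel_noise(rng: random.Random) -> float:
       """Draw one Gumbel(0, 1) sample as -log(-log(U)) with U ~ Uniform(0, 1)."""
       u = rng.random()
       while u == 0.0:  # random() may return 0.0, where log is undefined
           u = rng.random()
       return -math.log(-math.log(u))


   def topk_node_mask(
       gate: Sequence[float],
       ratio: float,
       batch: Sequence[int],
       gumbel_tau: Optional[float] = None,
       rng: Optional[random.Random] = None,
   ) -> List[bool]:
       """Hypothetical sketch: mark the ceil(ratio * n_g) best-scoring nodes of each graph g.

       gate       -- one score per node (assumed layout)
       ratio      -- fraction of nodes kept per graph (assumed meaning of masking_threshold)
       batch      -- graph index of every node, PyTorch Geometric style
       gumbel_tau -- if given (> 0), rank by gate + tau * Gumbel noise, which samples a
                     k-subset without replacement from softmax(gate / tau) instead of
                     taking the deterministic top-k
       """
       if len(gate) != len(batch):
           raise ValueError("gate and batch need one entry per node")
       if gumbel_tau is not None and gumbel_tau <= 0:
           raise ValueError("gumbel_tau must be positive")

       scores = list(gate)
       if gumbel_tau is not None:
           rng = rng or random.Random()
           scores = [s + gumbel_tau * _gumbel_noise(rng) for s in scores]

       # Group node indices by the graph they belong to.
       nodes_per_graph = defaultdict(list)
       for node, graph in enumerate(batch):
           nodes_per_graph[graph].append(node)

       mask = [False] * len(scores)
       for nodes in nodes_per_graph.values():
           k = math.ceil(ratio * len(nodes))
           # Highest score first; ties broken by node index so the result is deterministic.
           ranked = sorted(nodes, key=lambda i: (-scores[i], i))
           for i in ranked[:k]:
               mask[i] = True
       return mask


   # Two graphs in one batch: nodes 0-3 belong to graph 0, nodes 4-6 to graph 1.
   mask = topk_node_mask([0.9, 0.1, 0.5, 0.7, 0.2, 0.8, 0.3], 0.5, [0, 0, 0, 0, 1, 1, 1])
   print(mask)  # [True, False, False, True, False, True, True]

Hard top-k selection gives no useful gradient with respect to the scores. I-MLE (Implicit Maximum Likelihood Estimation) and its adaptive variant AIMLE are gradient estimators for this kind of discrete choice, which is presumably the role of ``get_imle_samplers`` and ``get_aimle_samplers``; the sketch does not attempt to reproduce them.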