ISubGVQA.models.isubgvqa
========================

.. py:module:: ISubGVQA.models.isubgvqa


Classes
-------

.. autoapisummary::

   ISubGVQA.models.isubgvqa.ISubGVQA


Module Contents
---------------

.. py:class:: ISubGVQA(args, use_imle=False, use_masking=True, use_instruction=True, use_mgat=False, mgat_masks=None, use_topk=False, interpretable_mode=True, concat_instr=False, embed_cat=True)

   Bases: :py:obj:`torch.nn.Module`


   ISubGVQA is a PyTorch neural network module designed for Visual Question Answering (VQA) tasks.
   It integrates various components such as scene graph encoding, question encoding, and multi-head
   graph attention networks to process and interpret visual and textual data.

   Args:
       args (Namespace): Configuration arguments.
       use_imle (bool, optional): Flag to use IMLE. Defaults to False.
       use_masking (bool, optional): Flag to use masking. Defaults to True.
       use_instruction (bool, optional): Flag to use instruction. Defaults to True.
       use_mgat (bool, optional): Flag to use multi-head graph attention. Defaults to False.
       mgat_masks (optional): Masks for MGAT. Defaults to None.
       use_topk (bool, optional): Flag to use top-k sampling. Defaults to False.
       interpretable_mode (bool, optional): Flag for interpretable mode. Defaults to True.
       concat_instr (bool, optional): Flag to concatenate instructions. Defaults to False.
       embed_cat (bool, optional): Flag to embed categories. Defaults to True.

   Attributes:
       args (Namespace): Configuration arguments.
       n_train_steps (int): Number of training steps.
       n_valid_steps (int): Number of validation steps.
       use_imle (bool): Flag to use IMLE.
       use_instruction (bool): Flag to use instruction.
       use_masking (bool): Flag to use masking.
       use_mgat (bool): Flag to use multi-head graph attention.
       interpretable_mode (bool): Flag for interpretable mode.
       concat_instr (bool): Flag to concatenate instructions.
       embed_cat (bool): Flag to embed categories.
       text_sampling (bool): Flag for text sampling.
       general_hidden_dim (int): General hidden dimension size.
       scene_graph_encoder (SceneGraphEncoder): Scene graph encoder module.
       text_emb_dim (int): Text embedding dimension size.
       text_vocab_embedding (torch.nn.Embedding): Text vocabulary embedding.
       question_hidden_dim (int): Question hidden dimension size.
       question_encoder (QuestionEncoder): Question encoder module.
       text_sampler (EdgeSIMPLEBatched, optional): Text sampler module.
       qsts_att_keys (torch.nn.Sequential, optional): Attention keys for questions.
       qsts_att_query (torch.nn.Sequential, optional): Attention query for questions.
       program_decoder (QuestionDecoder): Program decoder module.
       gat_seq (MGAT): Multi-head graph attention module.
       graph_global_attention_pooling (GlobalAttention): Global attention pooling layer.
       qsts_reduction (torch.nn.Sequential): Question reduction layer.
       instr_reduction (torch.nn.Sequential): Instruction reduction layer.
       embedding (torch.nn.Sequential): Embedding layer.
       logit_fc (torch.nn.Linear): Final classification layer.

   Methods:
       forward(node_embeddings, edge_index, edge_embeddings, batch, questions, qsts_att_mask, return_masks=False, explainer=False, explainer_stage=False, expl_bypass_x=False, scene_graphs=None):
           Forward pass of the model.

           Args:
               node_embeddings (torch.Tensor): Node embeddings.
               edge_index (torch.Tensor): Edge indices.
               edge_embeddings (torch.Tensor): Edge embeddings.
               batch (torch.Tensor): Batch indices.
               questions (torch.Tensor): Encoded questions.
               qsts_att_mask (torch.Tensor): Attention mask for questions.
               return_masks (bool, optional): Flag to return masks. Defaults to False.
               explainer (bool, optional): Flag for explainer mode. Defaults to False.
               explainer_stage (bool, optional): Stage for explainer. Defaults to False.
               expl_bypass_x (bool, optional): Bypass for explainer. Defaults to False.
               scene_graphs (optional): Scene graphs. Defaults to None.

           Returns:
               torch.Tensor: Model logits.
               torch.Tensor: IMLE mask.
               torch.Tensor: MGAT gate.
               torch.Tensor: Node logits layers.
               torch.Tensor: Mask text.


   .. py:attribute:: args


   .. py:attribute:: n_train_steps
      :value: 0


   .. py:attribute:: n_valid_steps
      :value: 0


   .. py:attribute:: use_imle
      :value: False


   .. py:attribute:: use_instruction
      :value: True


   .. py:attribute:: use_masking
      :value: True


   .. py:attribute:: use_mgat
      :value: False


   .. py:attribute:: interpretable_mode
      :value: True


   .. py:attribute:: concat_instr
      :value: False


   .. py:attribute:: embed_cat
      :value: True


   .. py:attribute:: text_sampling


   .. py:attribute:: general_hidden_dim


   .. py:attribute:: scene_graph_encoder


   .. py:attribute:: text_emb_dim
      :value: 512


   .. py:attribute:: text_vocab_embedding


   .. py:attribute:: question_hidden_dim


   .. py:attribute:: question_encoder


   .. py:attribute:: program_decoder


   .. py:attribute:: gat_seq


   .. py:attribute:: graph_global_attention_pooling


   .. py:attribute:: qsts_reduction


   .. py:attribute:: instr_reduction


   .. py:attribute:: embedding


   .. py:attribute:: logit_fc


   .. py:method:: forward(node_embeddings, edge_index, edge_embeddings, batch, questions, qsts_att_mask, return_masks=False, explainer=False, explainer_stage=False, expl_bypass_x=False, scene_graphs=None)
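
The ``Returns`` list for ``forward`` names five values (model logits, IMLE mask, MGAT gate, node logits layers, mask text). The sketch below shows one way a caller might read them by name instead of by position. It is an illustration only: ``ForwardOutputs`` and ``name_forward_outputs`` are hypothetical helpers that are not part of ``ISubGVQA``, and it assumes the call returns all five values as a tuple in the documented order. Whether that holds for every flag combination (for example ``return_masks=False``) is not stated here, so check the source before relying on it.

```python
from typing import Any, NamedTuple, Sequence


class ForwardOutputs(NamedTuple):
    """Hypothetical container; field order follows the documented Returns list."""

    logits: Any              # model logits
    imle_mask: Any           # IMLE mask
    mgat_gate: Any           # MGAT gate
    node_logits_layers: Any  # node logits layers
    mask_text: Any           # mask text


def name_forward_outputs(outputs: Sequence[Any]) -> ForwardOutputs:
    """Wrap a five-element return value so its fields can be read by name.

    Assumes `outputs` holds exactly the five documented values, in order.
    """
    if len(outputs) != 5:
        raise ValueError(f"expected 5 outputs, got {len(outputs)}")
    return ForwardOutputs(*outputs)


# Placeholder strings stand in for tensors here; in real use the argument
# would be the value returned by calling the model.
example = name_forward_outputs(
    ("logits", "imle_mask", "mgat_gate", "node_logits_layers", "mask_text")
)
answer_scores = example.logits
gate = example.mgat_gate
```

The length check makes a mismatch between this assumption and the actual return value fail loudly rather than silently mislabel a tensor.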