ISubGVQA.models.isubgvqa

Classes

ISubGVQA

ISubGVQA is a PyTorch neural network module designed for Visual Question Answering (VQA) tasks.

Module Contents

class ISubGVQA.models.isubgvqa.ISubGVQA(args, use_imle=False, use_masking=True, use_instruction=True, use_mgat=False, mgat_masks=None, use_topk=False, interpretable_mode=True, concat_instr=False, embed_cat=True)

Bases: torch.nn.Module

ISubGVQA is a PyTorch neural network module for Visual Question Answering (VQA). It combines a scene graph encoder, a question encoder, and multi-head graph attention networks to jointly process the visual (scene graph) and textual (question) inputs.

Args:

    args (Namespace): Configuration arguments.
    use_imle (bool, optional): Flag to use IMLE. Defaults to False.
    use_masking (bool, optional): Flag to use masking. Defaults to True.
    use_instruction (bool, optional): Flag to use instruction. Defaults to True.
    use_mgat (bool, optional): Flag to use multi-head graph attention. Defaults to False.
    mgat_masks (optional): Masks for MGAT. Defaults to None.
    use_topk (bool, optional): Flag to use top-k sampling. Defaults to False.
    interpretable_mode (bool, optional): Flag for interpretable mode. Defaults to True.
    concat_instr (bool, optional): Flag to concatenate instructions. Defaults to False.
    embed_cat (bool, optional): Flag to embed categories. Defaults to True.
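
The documentation describes use_topk only as a flag for top-k sampling. As a rough illustration of per-graph top-k selection over scene-graph nodes, here is a minimal sketch. The helper name topk_node_mask, the node scores, and the value of k are hypothetical and are not taken from the ISubGVQA source.

```python
import torch


def topk_node_mask(scores: torch.Tensor, batch: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: keep the k highest-scoring nodes of each graph.

    scores: [num_nodes] float tensor of (assumed) node relevance scores.
    batch:  [num_nodes] long tensor mapping each node to its graph index.
    """
    mask = torch.zeros_like(scores, dtype=torch.bool)
    for g in torch.unique(batch):
        # Indices of the nodes that belong to graph g.
        idx = (batch == g).nonzero(as_tuple=True)[0]
        # A graph with fewer than k nodes keeps all of its nodes.
        kk = min(k, idx.numel())
        top = torch.topk(scores[idx], kk).indices
        mask[idx[top]] = True
    return mask


scores = torch.tensor([0.1, 0.9, 0.5, 0.3, 0.8])
batch = torch.tensor([0, 0, 0, 1, 1])
mask = topk_node_mask(scores, batch, k=2)
print(mask.tolist())  # [False, True, True, True, True]
```

Selecting per graph rather than over the whole batch keeps every scene graph represented, no matter how its scores compare with those of other graphs in the batch.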

Attributes:

    args (Namespace): Configuration arguments.
    n_train_steps (int): Number of training steps.
    n_valid_steps (int): Number of validation steps.
    use_imle (bool): Flag to use IMLE.
    use_instruction (bool): Flag to use instruction.
    use_masking (bool): Flag to use masking.
    use_mgat (bool): Flag to use multi-head graph attention.
    interpretable_mode (bool): Flag for interpretable mode.
    concat_instr (bool): Flag to concatenate instructions.
    embed_cat (bool): Flag to embed categories.
    text_sampling (bool): Flag for text sampling.
    general_hidden_dim (int): General hidden dimension size.
    scene_graph_encoder (SceneGraphEncoder): Scene graph encoder module.
    text_emb_dim (int): Text embedding dimension size.
    text_vocab_embedding (torch.nn.Embedding): Text vocabulary embedding.
    question_hidden_dim (int): Question hidden dimension size.
    question_encoder (QuestionEncoder): Question encoder module.
    text_sampler (EdgeSIMPLEBatched, optional): Text sampler module.
    qsts_att_keys (torch.nn.Sequential, optional): Attention keys for questions.
    qsts_att_query (torch.nn.Sequential, optional): Attention query for questions.
    program_decoder (QuestionDecoder): Program decoder module.
    gat_seq (MGAT): Multi-head graph attention module.
    graph_global_attention_pooling (GlobalAttention): Global attention pooling layer.
    qsts_reduction (torch.nn.Sequential): Question reduction layer.
    instr_reduction (torch.nn.Sequential): Instruction reduction layer.
    embedding (torch.nn.Sequential): Embedding layer.
    logit_fc (torch.nn.Linear): Final classification layer.
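
The attribute list suggests a data flow: encode the scene graph, encode the question, pool node features into one vector per graph with global attention pooling, fuse that vector with the question representation, and classify with logit_fc. The toy module below is a minimal sketch of that flow under heavy assumptions. The class name ToySubgraphVQA is invented. Every layer choice is hypothetical: a linear layer stands in for SceneGraphEncoder and a GRU for QuestionEncoder, and there is no MGAT, program decoder, or IMLE sampling. Every dimension is likewise an assumption. It is not the ISubGVQA implementation.

```python
import torch
import torch.nn as nn


class ToySubgraphVQA(nn.Module):
    """Hypothetical, heavily simplified stand-in for the ISubGVQA data flow."""

    def __init__(self, node_dim=16, vocab_size=100, text_emb_dim=32,
                 hidden_dim=32, num_answers=10):
        super().__init__()
        # Stand-in for scene_graph_encoder: a per-node linear projection.
        self.scene_graph_encoder = nn.Linear(node_dim, hidden_dim)
        # Stand-ins for text_vocab_embedding and question_encoder (token id 0 = padding).
        self.text_vocab_embedding = nn.Embedding(vocab_size, text_emb_dim, padding_idx=0)
        self.question_encoder = nn.GRU(text_emb_dim, hidden_dim, batch_first=True)
        # Gate network for global attention pooling: one score per node.
        self.gate_nn = nn.Linear(hidden_dim, 1)
        # Fusion layer and final classifier, loosely mirroring embedding and logit_fc.
        self.embedding = nn.Sequential(nn.Linear(2 * hidden_dim, hidden_dim), nn.ELU())
        self.logit_fc = nn.Linear(hidden_dim, num_answers)

    def global_attention_pool(self, x, batch, num_graphs):
        # Softmax over the gate scores of each graph's nodes, then a weighted sum of features.
        pooled = []
        for g in range(num_graphs):
            xg = x[batch == g]                                # [n_g, H]
            weights = torch.softmax(self.gate_nn(xg), dim=0)  # [n_g, 1], sums to 1
            pooled.append((weights * xg).sum(dim=0))          # [H]
        return torch.stack(pooled)                            # [B, H]

    def forward(self, node_feats, batch, questions, qsts_att_mask):
        num_graphs = questions.size(0)
        x = torch.relu(self.scene_graph_encoder(node_feats))  # [N, H]
        states, _ = self.question_encoder(self.text_vocab_embedding(questions))  # [B, L, H]
        m = qsts_att_mask.unsqueeze(-1).float()                # [B, L, 1]
        q = (states * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)  # masked mean over tokens
        g = self.global_attention_pool(x, batch, num_graphs)   # [B, H]
        return self.logit_fc(self.embedding(torch.cat([g, q], dim=-1)))  # [B, num_answers]


torch.manual_seed(0)
model = ToySubgraphVQA()
node_feats = torch.randn(5, 16)
batch = torch.tensor([0, 0, 0, 1, 1])
questions = torch.tensor([[4, 7, 9, 0], [3, 5, 0, 0]])
qsts_att_mask = (questions != 0).long()
logits = model(node_feats, batch, questions, qsts_att_mask)
print(logits.shape)  # torch.Size([2, 10])
```

The pooling step follows the usual global attention pooling recipe (a per-graph softmax over gate scores, then a weighted sum) with an identity feature network. It is written as a plain loop for readability rather than speed.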

Methods:
forward(node_embeddings, edge_index, edge_embeddings, batch, questions, qsts_att_mask, return_masks=False, explainer=False, explainer_stage=False, expl_bypass_x=False, scene_graphs=None):

Forward pass of the model.

Args:

    node_embeddings (torch.Tensor): Node embeddings.
    edge_index (torch.Tensor): Edge indices.
    edge_embeddings (torch.Tensor): Edge embeddings.
    batch (torch.Tensor): Batch indices.
    questions (torch.Tensor): Encoded questions.
    qsts_att_mask (torch.Tensor): Attention mask for questions.
    return_masks (bool, optional): Flag to return masks. Defaults to False.
    explainer (bool, optional): Flag for explainer mode. Defaults to False.
    explainer_stage (bool, optional): Stage for explainer. Defaults to False.
    expl_bypass_x (bool, optional): Bypass for explainer. Defaults to False.
    scene_graphs (optional): Scene graphs. Defaults to None.
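
The argument descriptions do not state tensor shapes. Suppose ISubGVQA follows the PyTorch Geometric batching convention, which its use of GlobalAttention suggests but the source does not confirm. Then node_embeddings is [num_nodes, F], edge_index is [2, num_edges] with node ids indexing the merged batch, edge_embeddings is [num_edges, D], and batch assigns each node to its graph. The hypothetical helper collate_scene_graphs below sketches how such inputs are typically built from several scene graphs. The feature sizes and the padding token id 0 are assumptions.

```python
import torch


def collate_scene_graphs(graphs):
    """Hypothetical helper: merge scene graphs into one disjoint graph (PyG-style).

    Each graph is a tuple (node_feats [n_i, F], edge_index [2, e_i], edge_feats [e_i, D]).
    """
    xs, eis, eas, batch = [], [], [], []
    offset = 0
    for g, (x, ei, ea) in enumerate(graphs):
        xs.append(x)
        eis.append(ei + offset)  # shift node ids so they index into the merged node list
        eas.append(ea)
        batch.append(torch.full((x.size(0),), g, dtype=torch.long))
        offset += x.size(0)
    return torch.cat(xs), torch.cat(eis, dim=1), torch.cat(eas), torch.cat(batch)


g0 = (torch.randn(3, 8), torch.tensor([[0, 1], [1, 2]]), torch.randn(2, 4))
g1 = (torch.randn(2, 8), torch.tensor([[0], [1]]), torch.randn(1, 4))
node_embeddings, edge_index, edge_embeddings, batch = collate_scene_graphs([g0, g1])

questions = torch.tensor([[4, 7, 9], [3, 5, 0]])  # token ids, 0 = padding (assumed)
qsts_att_mask = (questions != 0).long()           # 1 for real tokens, 0 for padding

print(edge_index.tolist())  # [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())       # [0, 0, 0, 1, 1]
```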

Returns:

    torch.Tensor: Model logits.
    torch.Tensor: IMLE mask.
    torch.Tensor: MGAT gate.
    torch.Tensor: Node logits for each layer.
    torch.Tensor: Text mask.
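
The source does not specify the shapes of the returned tensors. Suppose the model logits are [batch_size, num_answers] and the IMLE mask holds one 0/1 entry per node. Under those assumptions, predicted answers and the nodes kept in each scene graph could be read out as in the sketch below. All tensors here are made-up placeholders, not model output.

```python
import torch

# Placeholder tensors standing in for the first two returned values (assumed shapes).
logits = torch.tensor([[0.2, 1.5, -0.3],
                       [2.0, 0.1, 0.4]])             # [batch_size=2, num_answers=3]
imle_mask = torch.tensor([0.0, 1.0, 1.0, 0.0, 1.0])  # one entry per node (assumed)
batch = torch.tensor([0, 0, 0, 1, 1])                # node-to-graph assignment
answers = torch.tensor([1, 2])                       # ground-truth answer ids

pred = logits.argmax(dim=-1)                         # predicted answer id per question
accuracy = (pred == answers).float().mean().item()

# Node indices kept by the mask, grouped by scene graph.
kept = {g: (imle_mask.bool() & (batch == g)).nonzero(as_tuple=True)[0].tolist()
        for g in batch.unique().tolist()}
print(pred.tolist(), accuracy, kept)  # [1, 0] 0.5 {0: [1, 2], 1: [4]}
```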

args
n_train_steps = 0
n_valid_steps = 0
use_imle = False
use_instruction = True
use_masking = True
use_mgat = False
interpretable_mode = True
concat_instr = False
embed_cat = True
text_sampling
general_hidden_dim
scene_graph_encoder
text_emb_dim = 512
text_vocab_embedding
question_hidden_dim
question_encoder
program_decoder
gat_seq
graph_global_attention_pooling
qsts_reduction
instr_reduction
embedding
logit_fc
forward(node_embeddings, edge_index, edge_embeddings, batch, questions, qsts_att_mask, return_masks=False, explainer=False, explainer_stage=False, expl_bypass_x=False, scene_graphs=None)