ISubGVQA.models.mgat
Classes
Module Contents
- class ISubGVQA.models.mgat.MGAT(channels, num_ins, dropout=0.0, heads=4, use_instr=False, masking_thresholds=None, use_topk: bool = False, interpretable_mode: bool = True, concat_instr: bool = False, use_all_instrs: bool = False, use_global_mask: bool = False, node_classification: bool = False, sampler_type: str = None, sample_k: int = None, nb_samples: int = 1, alpha=1.0, beta=10.0, tau=1.0)
Bases: torch.nn.Module
- masking_thresholds = None
- use_global_mask = False
- node_classification = False
- heads = 4
- use_instr = False
- use_topk = False
- interpretable_mode = True
- use_all_instrs = False
- convs
- x_proj
- bns
- dropout = 0.0
- node_logits
- reset_parameters()
- forward(x, edge_index, instr_vectors, global_language_feats, edge_attr, batch, return_masks=False, explainer=False, explainer_stage=False, expl_bypass_x=False)
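This page does not document the input format of forward, or how masking_thresholds and use_topk choose which nodes to keep. The argument names edge_index and batch suggest PyTorch Geometric conventions. Under those conventions, several graphs are merged into one disjoint graph, edges are stored as a 2 x E list of [sources, targets], and a per-node batch vector records which graph each node came from. The masking options suggest selecting nodes by score, either above a threshold or top-k per graph.

The pure-Python sketch below illustrates both ideas under those assumptions. batch_graphs and select_nodes are hypothetical helpers and are not part of ISubGVQA. The real MGAT works on tensors and may apply its masks differently.

```python
from collections import defaultdict
from typing import List, Optional, Tuple

Edge = Tuple[int, int]


def batch_graphs(
    num_nodes: List[int], edge_lists: List[List[Edge]]
) -> Tuple[List[List[int]], List[int]]:
    """Merge graphs into one disjoint graph (hypothetical helper).

    Returns (edge_index, batch): edge_index is [sources, targets] with node
    ids shifted by each graph's offset; batch[i] is the graph id of node i.
    """
    if len(num_nodes) != len(edge_lists):
        raise ValueError("expected one edge list per graph")
    sources: List[int] = []
    targets: List[int] = []
    batch: List[int] = []
    offset = 0
    for graph_id, (n, edges) in enumerate(zip(num_nodes, edge_lists)):
        for src, dst in edges:
            if not (0 <= src < n and 0 <= dst < n):
                raise ValueError(f"edge {(src, dst)} out of range in graph {graph_id}")
            # Shift local ids so nodes of different graphs never collide.
            sources.append(src + offset)
            targets.append(dst + offset)
        batch.extend([graph_id] * n)
        offset += n
    return [sources, targets], batch


def select_nodes(
    scores: List[float],
    batch: List[int],
    threshold: Optional[float] = None,
    k: Optional[int] = None,
) -> List[bool]:
    """Boolean keep-mask over nodes (hypothetical helper).

    Pass exactly one of `threshold` (keep scores above it) or `k`
    (keep the k highest-scoring nodes of each graph in `batch`).
    """
    if (threshold is None) == (k is None):
        raise ValueError("pass exactly one of threshold or k")
    if len(scores) != len(batch):
        raise ValueError("scores and batch must have the same length")
    if threshold is not None:
        return [s > threshold for s in scores]
    if k < 0:
        raise ValueError("k must be non-negative")

    # Group node indices by the graph they belong to.
    nodes_per_graph = defaultdict(list)
    for node, graph_id in enumerate(batch):
        nodes_per_graph[graph_id].append(node)
    keep = [False] * len(scores)
    for nodes in nodes_per_graph.values():
        # Stable sort by descending score: ties keep the lower node index.
        for node in sorted(nodes, key=lambda n: -scores[n])[:k]:
            keep[node] = True
    return keep


# Usage: two graphs with 3 and 2 nodes.
edge_index, batch = batch_graphs([3, 2], [[(0, 1), (1, 2)], [(0, 1)]])
print(edge_index)  # [[0, 1, 3], [1, 2, 4]]
print(batch)       # [0, 0, 0, 1, 1]
scores = [0.9, 0.1, 0.5, 0.3, 0.8]
print(select_nodes(scores, batch, threshold=0.4))  # [True, False, True, False, True]
print(select_nodes(scores, batch, k=1))            # [True, False, False, False, True]
```

Offsetting node ids lets one call process several graphs at once. Per-graph top-k depends on the batch vector, so that each graph keeps its own k nodes rather than k nodes across the whole batch.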