ISubGVQA.models.mgat

Classes

MGAT

Module Contents

class ISubGVQA.models.mgat.MGAT(channels, num_ins, dropout=0.0, heads=4, use_instr=False, masking_thresholds=None, use_topk: bool = False, interpretable_mode: bool = True, concat_instr: bool = False, use_all_instrs: bool = False, use_global_mask: bool = False, node_classification: bool = False, sampler_type: str = None, sample_k: int = None, nb_samples: int = 1, alpha=1.0, beta=10.0, tau=1.0)

Bases: torch.nn.Module

masking_thresholds = None
use_global_mask = False
node_classification = False
heads = 4
use_instr = False
use_topk = False
interpretable_mode = True
use_all_instrs = False
convs
x_proj
bns
dropout = 0.0
node_logits
reset_parameters()
forward(x, edge_index, instr_vectors, global_language_feats, edge_attr, batch, return_masks=False, explainer=False, explainer_stage=False, expl_bypass_x=False)