ISubGVQA.sampling.methods.simple_scheme
Module Contents
- ISubGVQA.sampling.methods.simple_scheme.LARGE_NUMBER = 10000000000.0
- ISubGVQA.sampling.methods.simple_scheme.logsigmoid(x)
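The page lists logsigmoid(x) without a docstring. If it computes log(sigmoid(x)), as the name suggests (the body is not shown, so this is an assumption), a numerically stable scalar version can be sketched as follows. The real module presumably works on torch tensors, where torch.nn.functional.logsigmoid gives the same quantity; this plain-Python sketch only illustrates the stable formulation.

```python
import math

def logsigmoid(x: float) -> float:
    """Stable log(sigmoid(x)) for a Python float (illustrative sketch).

    Uses log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|)). The exponent is
    never positive, so exp() cannot overflow for large |x|.
    """
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))

print(logsigmoid(0.0))      # -log(2), about -0.6931
print(logsigmoid(-1000.0))  # -1000.0; naive log(1 / (1 + exp(1000))) overflows
```

Splitting on the sign of x avoids both overflow for large negative inputs and loss of precision for large positive ones, where the result approaches 0.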
- class ISubGVQA.sampling.methods.simple_scheme.EdgeSIMPLEBatched(k, device, policy, val_ensemble=1, train_ensemble=1, logits_activation=None)
Bases: torch.nn.Module
- k
- device
- policy
- layer_configs
- adj = None
- val_ensemble = 1
- train_ensemble = 1
- logits_activation = None
- forward(scores, train=True)
- validation(scores)
During inference, the sampling stochasticity must be marginalized out, so the method either takes the top-k once or samples multiple times.
- Args:
scores: shape B x N x N x E
- Returns:
mask: shape B x N x N x (E x VE), where VE is val_ensemble
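The validation docstring describes two inference modes: a single top-k, or several samples. The following is a minimal pure-Python sketch of that choice, under several assumptions. Scores for one graph and one channel are flattened into a plain list of candidate-edge scores. val_ensemble == 1 selects the deterministic top-k. Larger values draw val_ensemble k-subsets and concatenate the masks, mirroring the (E x VE) last dimension. Each sample uses the Gumbel-top-k trick as a generic stand-in sampler. The module's own SIMPLE-based sampler is not shown here and may differ. The names topk_mask, gumbel_topk_mask and validation_masks are hypothetical.

```python
import math
import random
from typing import List, Optional

def topk_mask(scores: List[float], k: int) -> List[float]:
    """Deterministic k-hot mask: 1.0 at the k largest scores, 0.0 elsewhere."""
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    mask = [0.0] * len(scores)
    for i in top:
        mask[i] = 1.0
    return mask

def gumbel_topk_mask(scores: List[float], k: int, rng: random.Random) -> List[float]:
    """Sample a k-subset: add Gumbel(0, 1) noise to each score, then take top-k.

    Stand-in sampler (Gumbel-top-k trick), not necessarily the module's method.
    """
    noisy = []
    for s in scores:
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noisy.append(s - math.log(-math.log(u)))  # s + Gumbel noise
    return topk_mask(noisy, k)

def validation_masks(scores: List[float], k: int, val_ensemble: int = 1,
                     rng: Optional[random.Random] = None) -> List[float]:
    """Hypothetical inference step: top-k once, or val_ensemble sampled masks.

    Sampled masks are concatenated, so the output length is
    len(scores) * val_ensemble (compare the E x VE last dimension).
    """
    if val_ensemble == 1:
        return topk_mask(scores, k)
    rng = rng if rng is not None else random.Random(0)
    out: List[float] = []
    for _ in range(val_ensemble):
        out.extend(gumbel_topk_mask(scores, k, rng))
    return out

scores = [0.1, 2.0, -1.0, 0.5]
print(validation_masks(scores, k=2))                        # [0.0, 1.0, 0.0, 1.0]
print(len(validation_masks(scores, k=2, val_ensemble=3)))  # 12
```

The deterministic path gives a reproducible mask from one pass. The multi-sample path instead exposes several plausible k-subsets, so downstream layers can average over the sampler's randomness. The source does not say which path the real class uses for a given val_ensemble.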