ISubGVQA.sampling.methods.simple_scheme

Attributes

LARGE_NUMBER

Classes

EdgeSIMPLEBatched

Functions

logsigmoid(x)

Module Contents

ISubGVQA.sampling.methods.simple_scheme.LARGE_NUMBER = 10000000000.0
ISubGVQA.sampling.methods.simple_scheme.logsigmoid(x)
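The listing does not show the body of logsigmoid(x). Assuming it computes log(sigmoid(x)), the sketch below is a numerically stable scalar version. It is not the module's code. The naive form -log(1 + exp(-x)) overflows for large negative x. Rewriting it as min(x, 0) - log1p(exp(-|x|)) avoids that. For tensors, PyTorch provides torch.nn.functional.logsigmoid.

```python
import math


def logsigmoid(x: float) -> float:
    # Assumed semantics: log(sigmoid(x)) = -log(1 + exp(-x)).
    # exp(-|x|) never overflows, and log1p stays accurate near 0.
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))
```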
class ISubGVQA.sampling.methods.simple_scheme.EdgeSIMPLEBatched(k, device, policy, val_ensemble=1, train_ensemble=1, logits_activation=None)

Bases: torch.nn.Module

k
device
policy
layer_configs
adj = None
val_ensemble = 1
train_ensemble = 1
logits_activation = None
forward(scores, train=True)
validation(scores)

During inference, the stochasticity must be marginalized out, so we either take the top-k once or sample multiple times.

Args:

scores: shape B x N x N x E

Returns:

mask: shape B x N x N x (E x VE), where VE is val_ensemble
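The docstring names two inference strategies: a single top-k, or several samples. The NumPy sketch below illustrates both, and several parts of it are assumptions:

- The scores are unnormalized logits.
- k edges are kept per (batch, channel) over the flattened N x N grid.
- Ensemble copies are concatenated along the last axis.
- Sampling uses Gumbel-top-k as a stand-in. This is not necessarily the scheme EdgeSIMPLEBatched implements.

The name edge_mask is hypothetical and not part of the module.

```python
import numpy as np


def edge_mask(scores, k, ensemble=1, sample=False, rng=None):
    """Hypothetical helper (not part of the module): k-edge mask per (batch, channel).

    scores: B x N x N x E array, assumed to be unnormalized logits.
    Returns a 0/1 mask of shape B x N x N x (E * ensemble).
    """
    B, N, _, E = scores.shape
    k = min(k, N * N)                      # cannot pick more edges than exist
    flat = scores.reshape(B, N * N, E)     # flatten the N x N edge grid
    if rng is None:
        rng = np.random.default_rng()

    def pick_topk(keys):
        # Indices of the k largest keys along the edge axis, per (batch, channel).
        idx = np.argpartition(-keys, kth=k - 1, axis=1)[:, :k, :]
        m = np.zeros(keys.shape)
        np.put_along_axis(m, idx, 1.0, axis=1)
        return m.reshape(B, N, N, E)

    if sample:
        # Gumbel-top-k (stand-in sampler): perturbing the logits with Gumbel
        # noise and keeping the top k draws k edges without replacement.
        masks = [pick_topk(flat + rng.gumbel(size=flat.shape)) for _ in range(ensemble)]
    else:
        # Top-k is deterministic, so one pass suffices; repeat it per copy.
        masks = [pick_topk(flat)] * ensemble
    # Assumption: ensemble copies sit consecutively on the last axis.
    return np.concatenate(masks, axis=-1)
```

With sample=False every ensemble copy is identical, so computing top-k once is enough. With sample=True each copy is an independent draw, and averaging the copies approximates the marginal edge probabilities.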