ISubGVQA.sampling.methods.simple

Attributes

DISABLE

MODE

Classes

Layer

Functions

levelwiseSL(levels, idx2primesub, data, theta)

levelwiseMars(levels, idx2primesub, data, theta, parents)

log1mexp(x)

levelOrder(beta)

gumbel_keys(w, time_sampled)

sample_subset(w, k, time_sampled)

Module Contents

ISubGVQA.sampling.methods.simple.DISABLE = False
ISubGVQA.sampling.methods.simple.MODE = 'default'
ISubGVQA.sampling.methods.simple.levelwiseSL(levels: List[torch.Tensor], idx2primesub: torch.Tensor, data: torch.Tensor, theta: torch.Tensor)
ISubGVQA.sampling.methods.simple.levelwiseMars(levels: List[torch.Tensor], idx2primesub: torch.Tensor, data: torch.Tensor, theta: torch.Tensor, parents: torch.Tensor)
ISubGVQA.sampling.methods.simple.log1mexp(x)
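
The entry above carries no docstring. By its name, log1mexp presumably evaluates log(1 - exp(x)) in a numerically stable way. The following is a minimal pure-Python sketch of that standard technique, not the module's implementation: the name log1mexp_sketch, the scalar interface, and the sign convention (x < 0) are assumptions, and the real function most likely operates on torch tensors and may expect the opposite sign.

```python
import math


def log1mexp_sketch(x: float) -> float:
    """Stable log(1 - exp(x)) for x < 0 (hypothetical helper, assumed sign convention)."""
    if x >= 0:
        raise ValueError("x must be negative so that 1 - exp(x) lies in (0, 1)")
    if x > -math.log(2.0):
        # exp(x) is close to 1: expm1 avoids cancellation in 1 - exp(x).
        return math.log(-math.expm1(x))
    # exp(x) is small: log1p keeps precision in log(1 - tiny).
    return math.log1p(-math.exp(x))
```

Switching branches at -log(2) is the usual choice, because each formula is most accurate on its own side of that point.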
ISubGVQA.sampling.methods.simple.levelOrder(beta)
Return type: List[List[int]]
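
The type of beta is not documented here; only the List[List[int]] return type is. As an illustration of what a level-order grouping with that return shape could look like, here is a sketch over a hypothetical adjacency dictionary. The names level_order_sketch, children, and root are assumptions made for the example and do not come from the module.

```python
from collections import deque
from typing import Dict, List


def level_order_sketch(children: Dict[int, List[int]], root: int) -> List[List[int]]:
    """Group node ids of a tree by depth, root first (hypothetical input format)."""
    levels: List[List[int]] = []
    frontier = deque([root])
    while frontier:
        level: List[int] = []
        # Drain exactly the nodes that were queued for the current depth.
        for _ in range(len(frontier)):
            node = frontier.popleft()
            level.append(node)
            frontier.extend(children.get(node, []))
        levels.append(level)
    return levels
```

For children = {0: [1, 2], 1: [3], 2: [4, 5]} and root 0, the result is [[0], [1, 2], [3, 4, 5]]. The sketch assumes a tree; a graph with shared children would need a visited set.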

ISubGVQA.sampling.methods.simple.gumbel_keys(w, time_sampled)
ISubGVQA.sampling.methods.simple.sample_subset(w, k, time_sampled)
Args:
w (Tensor): Float tensor of weights for each element. In gumbel mode these are interpreted as log probabilities.
k (int): Number of elements in the subset sample.
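
The reference to a gumbel mode suggests the Gumbel top-k trick: perturb each log weight with independent Gumbel(0, 1) noise and keep the k largest keys, which draws a size-k subset without replacement. The sketch below shows that technique in pure Python. It is not the module's code: the names gumbel_keys_sketch and sample_subset_sketch, the rng argument, and the index-list return value are assumptions, and the undocumented time_sampled parameter is not modeled.

```python
import math
import random
from typing import List


def gumbel_keys_sketch(log_w: List[float], rng: random.Random) -> List[float]:
    """Add independent Gumbel(0, 1) noise to each log weight (hypothetical helper)."""
    keys = []
    for lw in log_w:
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        keys.append(lw - math.log(-math.log(u)))
    return keys


def sample_subset_sketch(log_w: List[float], k: int, rng: random.Random) -> List[int]:
    """Return the sorted indices of the k largest perturbed keys."""
    if not 0 <= k <= len(log_w):
        raise ValueError("k must be between 0 and len(log_w)")
    keys = gumbel_keys_sketch(log_w, rng)
    top = sorted(range(len(log_w)), key=lambda i: keys[i], reverse=True)[:k]
    return sorted(top)
```

An element with log weight -inf keeps a key of -inf and is never selected as long as at least k elements have finite weight. A differentiable version would typically replace the hard top-k step with a relaxation; whether the module does so is not stated here.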

class ISubGVQA.sampling.methods.simple.Layer(n, k, device, root='./simple_configs')
id = 0
parents
levels = []
true_indices
literal_indices
literal_mask
pos_literals
idx2primesub
__call__(log_probs, k)
log_pr(log_probs)
sample(lit_weights, k, time_sampled=1)