ISubGVQA.sampling.methods.tensor_utils

Functions

batched_edge_index_to_batched_adj(data)

Convert a batch's edge_index into a batched edge_index with self loops.

self_defined_softmax(scores, mask)

Masked softmax over scores.

weighted_cross_entropy(pred, true)

Weighted cross-entropy for unbalanced classes.

non_merge_coalesce(edge_index, edge_attr, edge_weight, ...)

batch_repeat_edge_index(edge_index, num_nodes, repeats)

Module Contents

ISubGVQA.sampling.methods.tensor_utils.batched_edge_index_to_batched_adj(data: torch_geometric.data.Data)
Args:

data: should be the original batch, i.e. without ensembles

Returns:

batched edge_index with three index rows ([0, 0, 0, 1, 1, 1, …], [0, 0, 1, 0, 1, 1, …], [1, 2, 1, 0, 0, 2, …]), including self loops
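A minimal, dependency-free sketch of the index layout described above. It assumes PyG-style batching (nodes of each graph stored contiguously, batch[i] giving the graph of node i) and reads the three rows as (graph index, local source node, local target node). The helper name to_batched_adj_index is hypothetical, and the lexicographic ordering and de-duplication are assumptions, not this module's confirmed behavior.

```python
from typing import Dict, List, Sequence, Set, Tuple


def to_batched_adj_index(
    edge_index: Sequence[Sequence[int]], batch: Sequence[int]
) -> List[List[int]]:
    """Hypothetical sketch: global PyG-style edges -> (graph, row, col) rows with self loops."""
    # Global id of the first node of each graph (assumes contiguous nodes per graph).
    first_node: Dict[int, int] = {}
    for node, graph in enumerate(batch):
        first_node.setdefault(graph, node)

    triples: Set[Tuple[int, int, int]] = set()
    # Re-express every edge with node ids local to its own graph.
    for src, dst in zip(edge_index[0], edge_index[1]):
        graph = batch[src]
        offset = first_node[graph]
        triples.add((graph, src - offset, dst - offset))
    # Add one self loop per node; the set drops duplicates of loops already present.
    for node, graph in enumerate(batch):
        local = node - first_node[graph]
        triples.add((graph, local, local))

    # Sort lexicographically, then transpose into three index rows.
    return [list(row) for row in zip(*sorted(triples))]


rows = to_batched_adj_index([[0, 0, 4], [1, 2, 5]], [0, 0, 0, 1, 1, 1])
```

Under this reading, each column (g, r, c) addresses one entry adj[g][r][c] of a dense B×N×N adjacency.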

ISubGVQA.sampling.methods.tensor_utils.self_defined_softmax(scores, mask)

Masked softmax over scores.

Args:

scores: shape (B, N, N, E)
mask: same shape as scores

Returns:
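One plausible reading of self_defined_softmax is a masked softmax: masked-out positions receive probability 0 and the remaining scores are normalized to sum to 1. The docstring does not say which axis of (B, N, N, E) is normalized, so this sketch works on a single 1-D slice. Returning all zeros for a fully masked slice is an assumption made here to avoid division by zero; the name masked_softmax is hypothetical.

```python
import math
from typing import List, Sequence


def masked_softmax(scores: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Hypothetical sketch: softmax over the entries where mask is True, 0 elsewhere."""
    kept = [s for s, m in zip(scores, mask) if m]
    if not kept:
        # Fully masked slice: return zeros instead of dividing by zero (assumed behavior).
        return [0.0] * len(scores)
    # Subtract the largest unmasked score for numerical stability.
    peak = max(kept)
    exps = [math.exp(s - peak) if m else 0.0 for s, m in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_softmax([1.0, 2.0, 3.0], [True, False, True])
```

For a (B, N, N, E) tensor, the same operation would be applied independently along whichever axis is normalized.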

ISubGVQA.sampling.methods.tensor_utils.weighted_cross_entropy(pred, true)

Weighted cross-entropy for unbalanced classes.

Reference: https://github.com/rampasek/GraphGPS/blob/main/graphgps/loss/weighted_cross_entropy.py
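A sketch of one common class-weighting scheme for the multiclass case: with V labels in the batch and n_c of them in class c, class c gets weight (V - n_c) / V, so rarer classes count more and classes absent from the batch get weight 0. The loss is the weighted mean of per-sample negative log-softmax, matching how a weighted mean-reduced NLL divides by the sum of the target weights. Whether this module uses exactly this scheme is an assumption; the zero-weight guard is a choice made for this sketch.

```python
import math
from collections import Counter
from typing import List, Sequence


def weighted_cross_entropy(pred: Sequence[Sequence[float]], true: Sequence[int]) -> float:
    """Sketch: inverse-frequency-style weighted cross-entropy on logits (multiclass)."""
    n = len(true)
    n_classes = len(pred[0])
    counts = Counter(true)
    # Weight of class c is (V - n_c) / V; classes missing from the batch get 0.
    weight: List[float] = [
        (n - counts[c]) / n if counts[c] > 0 else 0.0 for c in range(n_classes)
    ]
    num = 0.0
    den = 0.0
    for logits, y in zip(pred, true):
        # Stable log-sum-exp; log_z - logits[y] equals -log_softmax(logits)[y].
        peak = max(logits)
        log_z = peak + math.log(sum(math.exp(v - peak) for v in logits))
        num += weight[y] * (log_z - logits[y])
        den += weight[y]
    # Weighted mean over samples; guard the all-zero-weight case (single class in batch).
    return num / den if den > 0 else 0.0


loss = weighted_cross_entropy([[0.0, 0.0], [0.0, 2.0], [0.0, 2.0]], [0, 1, 1])
```

Here class 0 appears once in three labels (weight 2/3) and class 1 twice (weight 1/3), so the single class-0 sample carries as much total weight as both class-1 samples.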

ISubGVQA.sampling.methods.tensor_utils.non_merge_coalesce(edge_index, edge_attr, edge_weight, num_nodes, is_sorted: bool = False, sort_by_row: bool = True)
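Judging only from its name and parameters, non_merge_coalesce may sort edges the way a coalesce does while keeping duplicate edges instead of merging them. The sketch below illustrates that reading under stated assumptions: the hypothetical sort_edges_without_merge orders edges by (row, col), or by (col, row) when sort_by_row is False, using a linearized key primary * num_nodes + secondary. It permutes edge_attr and edge_weight alongside and returns the inputs unchanged when is_sorted is True. None of this is confirmed by the docstring.

```python
from typing import List, Optional, Sequence, Tuple


def sort_edges_without_merge(
    edge_index: Sequence[Sequence[int]],
    edge_attr: Optional[Sequence],
    edge_weight: Optional[Sequence[float]],
    num_nodes: int,
    is_sorted: bool = False,
    sort_by_row: bool = True,
) -> Tuple[List[List[int]], Optional[list], Optional[list]]:
    """Hypothetical sketch: coalesce-style sort that keeps duplicate edges."""
    if is_sorted:
        return [list(edge_index[0]), list(edge_index[1])], edge_attr, edge_weight
    row, col = edge_index
    primary, secondary = (row, col) if sort_by_row else (col, row)
    # Linearized key; sorted() is stable, so duplicates keep their input order.
    perm = sorted(range(len(row)), key=lambda i: primary[i] * num_nodes + secondary[i])
    new_index = [[row[i] for i in perm], [col[i] for i in perm]]
    new_attr = None if edge_attr is None else [edge_attr[i] for i in perm]
    new_weight = None if edge_weight is None else [edge_weight[i] for i in perm]
    return new_index, new_attr, new_weight


idx, attr, w = sort_edges_without_merge(
    [[1, 0, 1, 0], [0, 1, 0, 0]], ["a", "b", "c", "d"], [1.0, 2.0, 3.0, 4.0], num_nodes=2
)
```

The duplicate edge (1, 0) survives twice here, whereas torch_geometric.utils.coalesce would merge duplicate entries into one.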

ISubGVQA.sampling.methods.tensor_utils.batch_repeat_edge_index(edge_index: torch.Tensor, num_nodes: int, repeats: int)
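The signature suggests building a batch of repeats identical copies of one graph: copy k would reuse edge_index with every node id shifted by k * num_nodes, the usual way disjoint graphs are stacked into one edge_index. This list-based sketch shows that reading; the hypothetical name repeat_edge_index and the copy-major output order (all edges of copy 0, then copy 1, and so on) are assumptions.

```python
from typing import List, Sequence


def repeat_edge_index(
    edge_index: Sequence[Sequence[int]], num_nodes: int, repeats: int
) -> List[List[int]]:
    """Hypothetical sketch: stack `repeats` disjoint copies of one graph's edges."""
    rows: List[int] = []
    cols: List[int] = []
    for k in range(repeats):
        # Shift node ids so copy k occupies nodes [k * num_nodes, (k + 1) * num_nodes).
        offset = k * num_nodes
        rows.extend(r + offset for r in edge_index[0])
        cols.extend(c + offset for c in edge_index[1])
    return [rows, cols]


stacked = repeat_edge_index([[0, 1], [1, 0]], num_nodes=2, repeats=3)
```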