ISubGVQA.sampling.methods.tensor_utils
Functions
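- batched_edge_index_to_batched_adj(data): Build a batched edge index with self loops from the original batch.
- self_defined_softmax(scores, mask): Softmax over scores with a mask of the same shape.
- weighted_cross_entropy(pred, true): Weighted cross-entropy for unbalanced classes.
- non_merge_coalesce(edge_index, edge_attr, edge_weight, num_nodes, is_sorted, sort_by_row)
- batch_repeat_edge_index(edge_index, num_nodes, repeats)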
Module Contents
- ISubGVQA.sampling.methods.tensor_utils.batched_edge_index_to_batched_adj(data: torch_geometric.data.Data)
  Args:
    data: the original batch, i.e. without ensembles.
  Returns:
    Batched edge_index with self loops, as three index rows (graph index, row, column), e.g.
    ([0, 0, 0, 1, 1, 1, …], [0, 0, 1, 0, 1, 1, …], [1, 2, 1, 0, 0, 2, …]).
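A minimal sketch of the conversion described above, assuming the three returned rows are (graph index, node index local to that graph, node index local to that graph) and that rows are sorted lexicographically, as the example suggests. To avoid depending on torch_geometric, the hypothetical batched_adj_index_sketch takes the batch's global edge_index and batch vector directly instead of a Data object; dropping pre-existing self loops before adding one per node is also an assumption.

```python
import torch


def batched_adj_index_sketch(edge_index: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: global (2, E) edge_index + batch vector -> (3, E') index."""
    # Assumption: remove existing self loops so every node gets exactly one.
    keep = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, keep]

    # Index of the first node of each graph in the concatenated batch.
    counts = torch.bincount(batch)
    start = torch.cumsum(counts, dim=0) - counts

    # Graph id of each edge and its endpoints relative to that graph.
    edge_b = batch[edge_index[0]]
    edge_local = edge_index - start[edge_b]

    # One self loop per node, also in local coordinates.
    nodes = torch.arange(batch.numel(), device=batch.device)
    loop_local = nodes - start[batch]

    b = torch.cat([edge_b, batch])
    row = torch.cat([edge_local[0], loop_local])
    col = torch.cat([edge_local[1], loop_local])

    # Sort lexicographically by (graph, row, col) via a single integer key.
    m = int(counts.max())
    order = torch.argsort((b * m + row) * m + col)
    return torch.stack([b[order], row[order], col[order]])
```

For two 3-node graphs with global edges 0→1, 0→2, 4→3 and 4→5, the result indexes a dense (B, N, N) adjacency, for example `adj[b, row, col] = 1`.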
- ISubGVQA.sampling.methods.tensor_utils.self_defined_softmax(scores, mask)
  A softmax over scores that takes a mask of the same shape.
  Args:
    scores: tensor of shape (B, N, N, E).
    mask: tensor with the same shape as scores.
  Returns:
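The docstring does not say which axis is normalized or how masked entries are handled. A common pattern for a softmax with a same-shape mask is sketched below: masked positions are set to -inf before the softmax and forced to zero afterwards, so fully masked slices yield zeros instead of NaN. The name masked_softmax_sketch and its default dim=2 (normalizing over j in scores[b, i, j, e]) are assumptions for illustration, not the module's code.

```python
import torch


def masked_softmax_sketch(scores: torch.Tensor, mask: torch.Tensor, dim: int = 2) -> torch.Tensor:
    """Hypothetical masked softmax; `dim` is an assumed normalization axis."""
    mask = mask.bool()
    # Masked positions get -inf so they receive zero probability.
    filled = scores.masked_fill(~mask, float("-inf"))
    probs = torch.softmax(filled, dim=dim)
    # A fully masked slice is all -inf and softmaxes to NaN; every entry there is
    # masked, so zeroing masked positions also removes those NaNs.
    return probs.masked_fill(~mask, 0.0)
```

Zeroing after the softmax, rather than adding a large negative constant, keeps fully masked slices exactly zero.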
- ISubGVQA.sampling.methods.tensor_utils.weighted_cross_entropy(pred, true)
  Weighted cross-entropy for unbalanced classes.
  See https://github.com/rampasek/GraphGPS/blob/main/graphgps/loss/weighted_cross_entropy.py
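A minimal multiclass sketch of class-weighted cross-entropy, assuming the weighting used in the linked GraphGPS loss: class c gets weight (V - n_c) / V, where V is the number of samples and n_c the count of class c in `true`, so rarer classes weigh more. This reflects our reading of that file, not necessarily this module's code. The name weighted_cross_entropy_sketch is hypothetical, and it returns only the loss (the binary-logit case is omitted).

```python
import torch
import torch.nn.functional as F


def weighted_cross_entropy_sketch(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Hypothetical: pred is (V, C) logits, true is (V,) integer class labels."""
    num_samples = true.size(0)
    num_classes = pred.size(1)
    # Per-class counts, including classes absent from this batch.
    counts = torch.bincount(true, minlength=num_classes)
    # Frequent classes get smaller weights; absent classes are zeroed.
    weight = (num_samples - counts).float() / num_samples
    weight = weight * (counts > 0).float()
    # With weights, the mean reduction divides by the sum of per-sample weights.
    return F.cross_entropy(pred, true, weight=weight)
```

Note that if every sample belongs to a single class, that class's weight is 0 and the weighted mean is undefined (NaN).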
- ISubGVQA.sampling.methods.tensor_utils.non_merge_coalesce(edge_index, edge_attr, edge_weight, num_nodes, is_sorted: bool = False, sort_by_row: bool = True)
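non_merge_coalesce has no docstring. Going only by its name and parameters, one plausible reading is a coalesce that sorts edges (by row, or by column when sort_by_row=False) and permutes edge_attr and edge_weight to match, but keeps duplicate edges instead of merging them. The hypothetical non_merge_coalesce_sketch below implements that reading and nothing more. It uses a stable sort so duplicates keep their original relative order.

```python
import torch


def non_merge_coalesce_sketch(edge_index, edge_attr, edge_weight, num_nodes,
                              is_sorted=False, sort_by_row=True):
    """Hypothetical: sort edges like coalesce, but keep duplicate edges."""
    if is_sorted:
        return edge_index, edge_attr, edge_weight
    row, col = edge_index
    # Encode (row, col) or (col, row) as one integer so one sort orders both.
    key = row * num_nodes + col if sort_by_row else col * num_nodes + row
    _, perm = torch.sort(key, stable=True)
    edge_index = edge_index[:, perm]
    # Carry per-edge features along with their edges.
    if edge_attr is not None:
        edge_attr = edge_attr[perm]
    if edge_weight is not None:
        edge_weight = edge_weight[perm]
    return edge_index, edge_attr, edge_weight
```

Unlike a merging coalesce, the number of edges never changes, so per-edge tensors computed elsewhere stay aligned after the permutation.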
- ISubGVQA.sampling.methods.tensor_utils.batch_repeat_edge_index(edge_index: torch.Tensor, num_nodes: int, repeats: int)
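batch_repeat_edge_index is also undocumented. Judging only by its name and signature, a plausible reading is that it tiles a single graph's edge_index `repeats` times, shifting copy k by k * num_nodes so the copies form a disjoint batch. The hypothetical batch_repeat_edge_index_sketch below assumes exactly that, with all edges of copy 0 first, then copy 1, and so on.

```python
import torch


def batch_repeat_edge_index_sketch(edge_index: torch.Tensor, num_nodes: int, repeats: int) -> torch.Tensor:
    """Hypothetical: tile a (2, E) edge_index into (2, repeats * E) with node offsets."""
    # Offset for each copy: 0, num_nodes, 2 * num_nodes, ...
    offsets = torch.arange(repeats, device=edge_index.device) * num_nodes
    # Broadcast (1, 2, E) + (R, 1, 1) -> (R, 2, E), one shifted copy per repeat.
    repeated = edge_index.unsqueeze(0) + offsets.view(-1, 1, 1)
    # Reorder to (2, R, E) and flatten so copies are concatenated along the edge axis.
    return repeated.permute(1, 0, 2).reshape(2, -1)
```

For edge_index [[0, 1], [1, 0]], num_nodes=2 and repeats=3, this yields [[0, 1, 2, 3, 4, 5], [1, 0, 3, 2, 5, 4]].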