ISubGVQA.sampling.methods.tensor_utils
======================================

.. py:module:: ISubGVQA.sampling.methods.tensor_utils


Functions
---------

.. autoapisummary::

   ISubGVQA.sampling.methods.tensor_utils.batched_edge_index_to_batched_adj
   ISubGVQA.sampling.methods.tensor_utils.self_defined_softmax
   ISubGVQA.sampling.methods.tensor_utils.weighted_cross_entropy
   ISubGVQA.sampling.methods.tensor_utils.non_merge_coalesce
   ISubGVQA.sampling.methods.tensor_utils.batch_repeat_edge_index


Module Contents
---------------

.. py:function:: batched_edge_index_to_batched_adj(data: torch_geometric.data.Data)

   Args:
       data: should be the original batch, i.e. without ensembles

   Returns:
       batched edge_index
       ([0, 0, 0, 1, 1, 1, ...], [0, 0, 1, 0, 1, 1, ...], [1, 2, 1, 0, 0, 2, ...]) with self loops


.. py:function:: self_defined_softmax(scores, mask)

   A specific function

   Args:
       scores: B, N, N, E
       mask: same shape as scores

   Returns:


.. py:function:: weighted_cross_entropy(pred, true)

   Weighted cross-entropy for unbalanced classes.
   https://github.com/rampasek/GraphGPS/blob/main/graphgps/loss/weighted_cross_entropy.py


.. py:function:: non_merge_coalesce(edge_index, edge_attr, edge_weight, num_nodes, is_sorted: bool = False, sort_by_row: bool = True)

.. py:function:: batch_repeat_edge_index(edge_index: torch.Tensor, num_nodes: int, repeats: int)
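
The idea behind a weighted cross-entropy for unbalanced classes can be sketched in plain Python. This is an illustration only, not the code of ``weighted_cross_entropy``. It assumes the weighting scheme commonly used for this kind of loss: each class present in the batch is weighted by the fraction of samples that do not belong to it, absent classes get weight zero, and the loss is the weighted mean of the per-sample negative log-likelihoods. The names ``class_weights``, ``log_softmax`` and ``weighted_cross_entropy_sketch`` are hypothetical helpers introduced here, and the sketch operates on Python lists rather than tensors.

.. code-block:: python

   import math
   from collections import Counter


   def class_weights(labels, num_classes):
       """Weight each class by the fraction of samples that do NOT belong to it."""
       total = len(labels)
       counts = Counter(labels)
       # Classes absent from the batch get weight 0 so they cannot affect the loss.
       return [
           (total - counts[c]) / total if counts[c] > 0 else 0.0
           for c in range(num_classes)
       ]


   def log_softmax(logits):
       """Numerically stable log-softmax over one row of logits."""
       peak = max(logits)
       log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
       return [x - log_norm for x in logits]


   def weighted_cross_entropy_sketch(pred, true):
       """Weighted mean of per-sample negative log-likelihoods.

       pred: list of per-sample logit rows; true: list of integer class labels.
       """
       weights = class_weights(true, num_classes=len(pred[0]))
       weighted_nll = 0.0
       weight_sum = 0.0
       for logits, label in zip(pred, true):
           weighted_nll += -weights[label] * log_softmax(logits)[label]
           weight_sum += weights[label]
       # Assumption of this sketch: if only one class is present, every weight
       # is 0, so return 0.0 instead of dividing by zero.
       return weighted_nll / weight_sum if weight_sum > 0 else 0.0


   labels = [0, 0, 0, 1]
   print(class_weights(labels, num_classes=3))  # [0.25, 0.75, 0.0]

With three samples of class 0 and one of class 1, the rare class receives three times the weight of the frequent one, so errors on it contribute more to the loss.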