from typing import Tuple

import torch
from jaxtyping import Float, Int
from torch import Tensor, sparse
from torch_geometric.data import Batch

from src.utils.sparse_utils import make_sparse_matrix


def get_fully_connected_index(data: Batch) -> Int[Tensor, "2 n_full"]:
    """Compute the fully connected index.

    The edge set is first extended with one self loop per root node via
    ``add_root_self_loop``. It is then turned into a sparse matrix ``A``, and the sparsity
    pattern of ``A @ A.T`` is returned. Entry ``(i, j)`` of that product can only be non-zero
    when rows ``i`` and ``j`` of ``A`` share a non-zero column. Assuming ``make_sparse_matrix``
    uses ``edge_index[0]`` as row indices, this pairs every two nodes that have an edge into a
    common node. Only the sparsity pattern is returned; the product's values are discarded.

    NOTE(review): this covers all node pairs of a graph only if every node has an edge into its
    graph's root. Confirm against the dataset.
    NOTE(review): the root self loops carry value 0. Root-to-node pairs therefore rely on the
    sparse product keeping entries whose computed value is zero. Confirm for the torch
    version/backend in use.

    :param data: PyG batch object.
    :return: Fully connected index of shape ``(2, n_full)``.
    """
    edge_index, edge_attr = add_root_self_loop(data)
    sparse_matrix = make_sparse_matrix(edge_index, edge_attr)
    product = sparse.mm(sparse_matrix, sparse_matrix.T)
    # ``indices()`` raises on an uncoalesced COO tensor. ``coalesce()`` returns the tensor
    # itself when it is already coalesced, so this guard is cheap.
    return product.coalesce().indices()


def get_gnn_index(data: Batch) -> Int[Tensor, "2 n_gnn"]:
    """Compute the gnn index consisting of the edges and their reverse.

    The batch edges are written into a sparse matrix ``A``, and the sparsity pattern of
    ``A + A.T`` is returned: every edge together with its reversed counterpart. Coalescing
    merges duplicate coordinates, so an edge whose reverse already exists appears only once.

    :param data: PyG batch object.
    :return: Symmetrised edge index of shape ``(2, n_gnn)``. It contains the original edges
        plus their reverses, with duplicate coordinates merged.
    """
    sparse_matrix = make_sparse_matrix(data.edge_index, data.edge_attr)
    symmetric = sparse_matrix + sparse_matrix.T
    # Sparse addition can leave duplicate coordinates, and ``indices()`` requires a coalesced
    # tensor.
    return symmetric.coalesce().indices()


def add_root_self_loop(
    data: Batch,
) -> Tuple[Int[Tensor, "2 n_edges"], Float[Tensor, "n_edges *attr_shape"]]:
    """Extend the edge index and edge attribute with self loops added to the root nodes.

    The root of each graph in the batch is taken to be its last node. ``data.ptr`` holds the
    cumulative node offsets of the batched graphs, so ``ptr[1:] - 1`` indexes the final node of
    every graph. One ``(root, root)`` edge with an all-zero attribute is appended per graph,
    after the original edges.

    NOTE(review): treating the last node as the root presumably relies on how the dataset
    orders nodes. Confirm.

    :param data: PyG batch object.
    :return: Edge index and edge attribute with self loops added to the root nodes. The
        attribute keeps the dtype, device and trailing shape of ``data.edge_attr``.
    """
    root = data.ptr[1:] - 1
    n_roots = root.numel()

    root_self_loop_index = torch.stack([root, root])

    # Build the zero attributes from the existing ones so that ``torch.cat`` neither silently
    # promotes the dtype (e.g. float16/int -> float32) nor mixes devices. This also keeps
    # multi-dimensional edge attributes working.
    edge_attr = data.edge_attr
    root_self_loop_value = edge_attr.new_zeros((n_roots, *edge_attr.shape[1:]))

    edge_index = torch.cat([data.edge_index, root_self_loop_index], dim=1)
    edge_attr = torch.cat([edge_attr, root_self_loop_value])
    return edge_index, edge_attr
