import torch
from jaxtyping import Bool, Float, Int
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.adjacency_utils import get_adj, get_adj_row_indices
from src.utils.sparse_utils import make_sparse_matrix, pad_values_of_sparse_matrix


class IndexSelection(nn.Module):
    """Select edges of a fully connected edge list using several top-k criteria.

    For every candidate edge ``(src, dst)`` five scores are compared against a
    per-row threshold (the ``(k + 1)``-th largest score of row ``src``):

    * the edge attribute, thresholded on the adjacency matrix rows,
    * the edge attribute, thresholded on the transposed adjacency matrix rows,
    * the dot product of adjacency rows,
    * the dot product of ``data.p_anc`` rows,
    * an angle score ``1 - acos(cos)`` between rows of ``data.x[:, 1:]``.

    A candidate edge is selected when at least one score is strictly greater
    than its threshold.
    """

    def __init__(self, n: int, k: int) -> None:
        """Store the selection hyper-parameters.

        :param n: Size argument forwarded unchanged to ``get_adj``.
        :param k: Number of entries per row that may pass each criterion.
        """
        super().__init__()
        self.n = n
        # The threshold is the (k + 1)-th largest score of a row and the
        # comparison is strict, so at most k entries per row pass (barring ties).
        self.k = k + 1

    @property
    def d(self) -> int:
        """Return the dimension reported by this module.

        NOTE(review): ``forward`` returns a 1-D boolean mask, so it is not clear
        from this file what the constant 2 refers to -- confirm with the callers.

        :return: The constant 2.
        """
        return 2

    def forward(
        self, data: Batch, fully_connected_index: Int[Tensor, "2 n_fully_connected_edges"]
    ) -> Bool[Tensor, "n_fully_connected_edges"]:
        """Compute the boolean selection mask over the fully connected edges.

        :param data: PyG batch object; ``edge_index``, ``edge_attr``, ``p_anc``
            and ``x`` are read from it.
        :param fully_connected_index: Source (row 0) and destination (row 1)
            node indices of the candidate edges.
        :return: Mask that is ``True`` for every candidate edge selected by at
            least one criterion.
        """
        src, dst = fully_connected_index[0], fully_connected_index[1]

        # Edge attributes laid out on the candidate edge list. Computed once and
        # shared by the first three criteria.
        edge_matrix = make_sparse_matrix(data.edge_index, data.edge_attr)
        adj_feature = pad_values_of_sparse_matrix(edge_matrix, fully_connected_index)

        adj = get_adj(data, self.n)
        batch_index, node_index = get_adj_row_indices(data)

        # Criterion 1: edge attribute above the threshold of its source row.
        adj_rows = adj[batch_index, node_index, :]
        adj_selector = adj_feature > self._kth_largest_per_row(adj_rows, self.k)[src]

        # Criterion 2: same edge attribute, threshold taken from the transposed
        # adjacency matrix.
        # NOTE(review): the threshold is still looked up by the *source* node and
        # compared with the un-transposed edge value. A column-wise top-k would
        # be expected to index by the destination node (or to compare the
        # reversed edge value). Behaviour is kept as-is; confirm the intent.
        adj_t_rows = adj.transpose(1, 2)[batch_index, node_index, :]
        adj_transpose_selector = (
            adj_feature > self._kth_largest_per_row(adj_t_rows, self.k)[src]
        )

        # Criterion 3: dot product between rows of the padded adjacency matrix.
        adj_dot_prod_selector = self._dot_prod_selector(adj_feature, fully_connected_index)

        # Criterion 4: dot product between rows of the padded ``p_anc`` matrix.
        p_anc_matrix = make_sparse_matrix(data.edge_index, data.p_anc)
        p_anc_feature = pad_values_of_sparse_matrix(p_anc_matrix, fully_connected_index)
        p_anc_selector = self._dot_prod_selector(p_anc_feature, fully_connected_index)

        # Criterion 5: angle score between the rows of ``data.x[:, 1:]``
        # (presumably momentum components -- TODO confirm).
        p = data.x[:, 1:]
        # ``normalize`` clamps the norm away from zero, so all-zero rows yield a
        # finite score instead of NaN.
        p_unit = nn.functional.normalize(p, p=2, dim=1)
        p_cos = torch.matmul(p_unit, p_unit.T)
        # Clamp so rounding errors cannot push the cosine outside acos' domain.
        p_angle = 1 - torch.acos(torch.clamp(p_cos, min=-1 + 1e-8, max=1 - 1e-8))
        p_angle = p_angle[src, dst]
        # NOTE(review): pairs absent from ``fully_connected_index`` become 0 in the
        # dense matrix and therefore outrank negative angle scores in the top-k.
        p_angle_matrix = make_sparse_matrix(fully_connected_index, p_angle).to_dense()
        p_angle_selector = p_angle > self._kth_largest_per_row(p_angle_matrix, self.k)[src]

        criteria = torch.stack(
            [
                adj_selector,
                adj_transpose_selector,
                adj_dot_prod_selector,
                p_anc_selector,
                p_angle_selector,
            ],
        )
        return criteria.any(dim=0)

    @staticmethod
    def _kth_largest_per_row(matrix: Tensor, k: int) -> Tensor:
        """Return the ``k``-th largest value of every row of a 2-D tensor.

        ``k`` is clamped to the number of columns so that matrices narrower than
        ``k`` yield the row minimum instead of raising inside ``torch.topk``.

        :param matrix: Dense 2-D tensor of scores.
        :param k: 1-based rank of the value to return.
        :return: 1-D tensor with one threshold per row.
        """
        top_values, _ = torch.topk(matrix, min(k, matrix.size(1)), dim=1)
        return top_values[:, -1]

    def _dot_prod_selector(self, values: Tensor, fully_connected_index: Tensor) -> Tensor:
        """Select candidate edges by the dot product of their endpoint rows.

        :param values: One value per candidate edge, in the order of
            ``fully_connected_index``.
        :param fully_connected_index: Source/destination indices of the edges.
        :return: Boolean mask over the candidate edges.
        """
        src, dst = fully_connected_index[0], fully_connected_index[1]
        matrix = make_sparse_matrix(fully_connected_index, values)
        dot_prod = torch.sparse.mm(matrix, matrix.T).to_dense()
        # Gather the scores from the dense product rather than from the sparse
        # values: the sparse product's pattern and ordering are not guaranteed
        # to match ``fully_connected_index``.
        dot_prod_feature = dot_prod[src, dst]
        return dot_prod_feature > self._kth_largest_per_row(dot_prod, self.k)[src]
