from jaxtyping import Float, Int
import torch
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.sparse_utils import sparse_mode, make_sparse_matrix, pad_values_of_sparse_matrix


class Mode(nn.Module):
    """Probabilistic Mode Features.

    Produces a 7-dimensional feature row for every node pair listed in a fully
    connected edge index. All features are derived from the per-node ``mode``
    returned by ``sparse_mode(data.edge_index, data.edge_attr)``. That value is
    used below as a node index, one per node.
    """

    def __init__(self) -> None:
        """Initialize Mode Features (the module has no parameters)."""
        super().__init__()

    @property
    def d(self) -> int:
        """Return dimension of Mode Features.

        :return: Dimension of Mode Features (number of columns produced by
            :meth:`forward`).
        """
        return 7

    def forward(
        self, data: Batch, fully_connected_index: Int[Tensor, "2 n_fully_connected_edges"]
    ) -> Float[Tensor, "n_fully_connected_edges 7"]:
        """Compute the Mode Features.

        Let ``M`` be the sparse node-to-mode indicator matrix, with
        ``M[v, mode[v]] = 1`` for every node ``v`` whose mode is non-zero.
        For every pair ``(i, j)`` in ``fully_connected_index``, the returned
        row holds, in order:

        0. ``M[i, j]``
        1. ``M.T[i, j]``
        2. ``(M @ M.T)[i, j]``
        3. number of nodes whose mode is ``i``
        4. number of nodes whose mode is ``j``
        5. number of nodes that share ``i``'s mode
        6. number of nodes that share ``j``'s mode

        :param data: PyG batch object; ``edge_index``, ``edge_attr`` and
            ``num_nodes`` are read.
        :param fully_connected_index: Integer node-index pairs of the fully
            connected graph (row 0: start nodes, row 1: end nodes).
        :return: Mode Features, one row per column of ``fully_connected_index``.
        """
        mode = sparse_mode(data.edge_index, data.edge_attr)
        node = torch.arange(data.num_nodes, device=mode.device)

        # A mode of 0 is treated as "no mode". Such nodes are dropped from the
        # indicator matrix, and bucket 0 is zeroed in the counts below.
        # NOTE(review): this sentinel collides with node index 0 being a
        # genuine mode; confirm against sparse_mode's contract.
        mask = mode != 0
        index = torch.stack([node, mode], dim=0)[:, mask]
        values = torch.ones_like(mode[mask], dtype=torch.float)

        mode_matrix = make_sparse_matrix(index, values)
        mode_feature = pad_values_of_sparse_matrix(mode_matrix, fully_connected_index)
        mode_transposed_feature = pad_values_of_sparse_matrix(mode_matrix.T, fully_connected_index)

        # Each node contributes at most one (node, mode) entry. Assuming
        # make_sparse_matrix keeps those entries as-is, (M @ M.T)[i, j] is
        # non-zero exactly when i and j share the same non-zero mode.
        mode_matrix_dot_prod = torch.sparse.mm(mode_matrix, mode_matrix.T).coalesce()
        mode_dot_prod_feature = pad_values_of_sparse_matrix(mode_matrix_dot_prod, fully_connected_index)

        start, end = fully_connected_index[0], fully_connected_index[1]

        # mode_counts[m] = number of nodes whose mode is m.
        mode_counts = torch.bincount(mode, minlength=data.num_nodes)
        # Discard the "no mode" bucket. Slicing (instead of [0]) is a no-op
        # rather than an IndexError when the batch has no nodes.
        mode_counts[:1] = 0
        mode_counts_start = mode_counts[start]
        mode_counts_end = mode_counts[end]

        # Per node: how many nodes share its mode. This is 0 for nodes without
        # a mode, because bucket 0 was zeroed above.
        mode_counts_of_mode = mode_counts[mode]
        mode_counts_of_mode_start = mode_counts_of_mode[start]
        mode_counts_of_mode_end = mode_counts_of_mode[end]

        # Stack along dim=1 so each pair is one row. This is equivalent to
        # stacking on dim 0 and transposing, but yields a contiguous tensor.
        # Integer counts are promoted to the float dtype of the first features.
        return torch.stack(
            [
                mode_feature,
                mode_transposed_feature,
                mode_dot_prod_feature,
                mode_counts_start,
                mode_counts_end,
                mode_counts_of_mode_start,
                mode_counts_of_mode_end,
            ],
            dim=1,
        )
