import warnings

import torch
from jaxtyping import Float, Int
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.sparse_utils import make_sparse_matrix, pad_values_of_sparse_matrix


class Ginkgo(nn.Module):
    """Pairwise kinematic edge features for Ginkgo-style particle jets.

    For every node pair ``(i, j)`` listed in ``fully_connected_index`` the module
    returns six features derived from the nodes' energy and momentum:

    0. cosine of the opening angle between the momentum vectors,
    1. the opening angle itself (radians),
    2. rapidity difference ``y_i - y_j``,
    3. its square,
    4. azimuthal difference ``phi_i - phi_j`` wrapped into ``[-pi, pi]``,
    5. its square.

    The module has no learnable parameters.
    """

    def __init__(self) -> None:
        """Initialize the (parameter-free) edge feature module."""
        super().__init__()

    @property
    def d(self) -> int:
        """Return the number of features produced per edge.

        :return: Feature dimension (6).
        """
        return 6

    def forward(
        self, data: Batch, fully_connected_index: Int[Tensor, "2 n_fully_connected_edges"]
    ) -> Float[Tensor, "n_edges 6"]:
        """Compute the pairwise kinematic features for the given node pairs.

        :param data: PyG batch object. Column 0 of ``data.x`` is read as the energy
            and columns ``1:`` as the momentum vector, whose first three entries are
            used as ``(p_x, p_y, p_z)`` (presumably ``data.x`` has exactly 4 columns —
            TODO confirm against the dataset).
        :param fully_connected_index: Edge index of shape ``(2, n_edges)``; row 0 holds
            the first node ``i`` and row 1 the second node ``j`` of each pair.
        :return: Tensor of shape ``(n_edges, 6)`` with the features listed in the
            class docstring, in that column order.
        """
        energy = data.x[:, 0]
        momentum = data.x[:, 1:]
        p_x, p_y, p_z = momentum[:, 0], momentum[:, 1], momentum[:, 2]
        src, dst = fully_connected_index[0], fully_connected_index[1]

        # Opening angle between momenta (see Ginkgo implementation). Computed per
        # edge rather than via a dense n x n matrix: a batch holds many graphs, so
        # n**2 can be far larger than the number of edges actually requested.
        # `normalize` divides by max(||p||, 1e-12), so a zero momentum gives cos = 0
        # instead of 0/0 = NaN; for nonzero momenta it equals p / ||p||.
        unit = nn.functional.normalize(momentum, p=2, dim=1)
        p_cos = (unit[src] * unit[dst]).sum(dim=1)
        # Clamp into acos's domain: rounding can push |cos| slightly above 1.
        # NOTE: for float32 inputs 1 - 1e-8 rounds to 1.0, so this effectively
        # clamps to [-1, 1] and does not keep acos away from its singular points.
        p_angle = torch.acos(torch.clamp(p_cos, min=-1 + 1e-8, max=1 - 1e-8))

        # Rapidity and azimuth according to
        # https://s3.cern.ch/inspire-prod-files-6/6904a3576c84c5d1f05a1f171cac3695
        # NOTE(review): the ratio is clamped only from below (y >= 0.5 * ln(1e-3)),
        # so large positive rapidities stay unbounded — confirm this is intended.
        y = 0.5 * torch.log(((energy + p_z) / (energy - p_z)).clamp(min=1e-3))
        # atan2 gives the azimuth from the x-axis over the full (-pi, pi] range and
        # is defined at p_x = p_y = 0, unlike atan(p_x / p_y), which drops the quadrant.
        phi = torch.atan2(p_y, p_x)

        delta_y = y[src] - y[dst]
        # The azimuth is periodic: wrap the raw difference back into [-pi, pi].
        raw_delta_phi = phi[src] - phi[dst]
        delta_phi = torch.atan2(torch.sin(raw_delta_phi), torch.cos(raw_delta_phi))

        features = torch.stack(
            [p_cos, p_angle, delta_y, delta_y**2, delta_phi, delta_phi**2],
            dim=1,
        )
        if torch.isnan(features).any():
            warnings.warn("NaN in Ginkgo edge features", RuntimeWarning)
        return features
