import torch
from jaxtyping import Float, Int
from torch import Tensor, nn
from torch_geometric.data import Batch


class GinkgoCut(nn.Module):
    """Pairwise Ginkgo cut edge features.

    For every candidate edge ``(i, j)`` the two node feature vectors are summed
    into a "parent" vector and compared with the two "children" through the
    quantity ``t = v[0] ** 2 - ||v[1:]|| ** 2`` (clamped at zero), which the
    code treats as a squared invariant mass.
    """

    def __init__(self, pt_min_sqrt: float) -> None:
        """Initialize the Ginkgo cut feature.

        :param pt_min_sqrt: Cut value; the parent's squared invariant mass is
            compared against ``pt_min_sqrt ** 2``.
        """
        super().__init__()
        self.cut_threshold = pt_min_sqrt**2

    @property
    def d(self) -> int:
        """Return dimension of Ginkgo cut feature.

        :return: Number of feature columns produced by :meth:`forward` (7).
        """
        return 7

    @staticmethod
    def _invariant_mass_sq(p: Tensor) -> Tensor:
        """Return ``p[..., 0] ** 2 - ||p[..., 1:]|| ** 2`` clamped at zero."""
        # Clamp because the difference can come out negative (numerical noise or
        # vectors whose spatial part dominates), and the result is fed to sqrt.
        t = p[..., 0] ** 2 - torch.norm(p[..., 1:], dim=-1) ** 2
        return t.clamp(min=0)

    def forward(
        self,
        data: Batch,
        fully_connected_index: Int[Tensor, "2 n_fully_connected_edges"],
    ) -> Float[Tensor, "n_fully_connected_edges 7"]:
        """Compute the Ginkgo cut features for each requested edge.

        Column ``0`` of ``data.x`` is used as the leading component and
        columns ``1:`` as the remaining components of each node vector
        (presumably energy and momentum -- TODO confirm against the dataset).

        :param data: PyG batch object; must provide node features ``x``.
        :param fully_connected_index: Integer edge index of shape ``(2, E)``;
            row 0 holds the "left" node and row 1 the "right" node of each edge.
        :return: Tensor of shape ``(E, 7)`` with columns, per edge ``(i, j)``:
            ``t_parent`` (squared mass of ``x[i] + x[j]``),
            ``t_parent > pt_min_sqrt ** 2`` as 0/1,
            ``t_right`` (``t`` of node ``j``), ``t_left`` (``t`` of node ``i``),
            ``sqrt(t_parent) - sqrt(t_left) - sqrt(t_right)``,
            ``(sqrt(t_parent) - sqrt(t_left)) ** 2`` and
            ``(sqrt(t_parent) - sqrt(t_right)) ** 2``.
        :raises ValueError: If ``data`` has no node features ``x``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if "x" not in data:
            raise ValueError("Node features are required for Ginkgo cut feature.")
        x = data.x
        src, dst = fully_connected_index[0], fully_connected_index[1]

        # Gather only the requested pairs: O(n_edges) memory instead of a dense
        # (n_nodes, n_nodes, n_features) tensor spanning every graph in the batch.
        t_parent = self._invariant_mass_sq(x[src] + x[dst])
        # Match the dtype of the other columns so the stack below is homogeneous.
        t_parent_threshold = (t_parent > self.cut_threshold).to(t_parent.dtype)

        t_child = self._invariant_mass_sq(x)
        t_left = t_child[src]
        t_right = t_child[dst]

        sqrt_parent = torch.sqrt(t_parent)
        sqrt_left = torch.sqrt(t_left)
        sqrt_right = torch.sqrt(t_right)
        diff = sqrt_parent - sqrt_left - sqrt_right
        diff_left = (sqrt_parent - sqrt_left) ** 2
        diff_right = (sqrt_parent - sqrt_right) ** 2

        return torch.stack(
            [
                t_parent,
                t_parent_threshold,
                t_right,
                t_left,
                diff,
                diff_left,
                diff_right,
            ],
            dim=1,
        )
