import torch
from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch


class GinkgoCut(nn.Module):
    """Ginkgo cut feature.

    Computes five per-node scalar features from the node features ``data.x`` and the edge
    structure of a PyG batch (see :meth:`forward`).
    """

    def __init__(self, pt_min_sqrt: float) -> None:
        """Initialize the Ginkgo cut feature.

        :param pt_min_sqrt: Square root of the cut threshold; :meth:`forward` compares the
            cut against ``pt_min_sqrt**2``.
        """
        super().__init__()
        self.cut_threshold = pt_min_sqrt**2

    @property
    def d(self) -> int:
        """Return dimension of Ginkgo cut feature.

        :return: Dimension of Ginkgo cut feature, i.e. the number of columns returned by
            :meth:`forward`.
        """
        return 5

    def forward(self, data: Batch) -> Float[Tensor, "n_nodes 5"]:
        """Compute the Ginkgo cut feature. The cut property decides if a cut is plausible given the
        parameters of the simulator.

        Columns of the returned tensor (one row per node):

        0. ``cut = x[:, 0]**2 - ||x[:, 1:]||**2``.
        1. 1 where ``cut > cut_threshold`` or ``data.parent_mask`` is False, else 0.
        2. ``sqrt(cut)``, with negative ``cut`` mapped to 0.
        3. Column 2 minus, for every edge ``k`` with ``edge_index[:, k] == (j, i)``,
           ``sqrt_cut[j] * edge_attr[k]`` accumulated onto node ``i``.
        4. 1 where column 3 is strictly positive, else 0.

        :param data: PyG batch object. Reads ``x``, ``parent_mask``, ``edge_index`` and
            ``edge_attr``; ``edge_attr`` is expected to hold one scalar per edge.
        :return: Floating-point Ginkgo cut feature of shape ``(n_nodes, 5)``.
        """
        assert "x" in data, "Node features are required for Ginkgo cut feature."
        x = data.x
        # Sum of squares directly instead of squaring a norm (avoids a sqrt round trip).
        cut = x[:, 0] ** 2 - (x[:, 1:] ** 2).sum(dim=1)
        cut_higher_than_threshold = (cut > self.cut_threshold) | ~data.parent_mask

        # Clamp before sqrt so the discarded branch of ``where`` never takes the sqrt of a
        # negative number (which would leak NaN into gradients). Keeping ``where`` preserves
        # the mapping of NaN cuts to 0.
        cut_sqrt = torch.where(cut >= 0, torch.sqrt(cut.clamp(min=0)), torch.zeros_like(cut))
        dtype = cut_sqrt.dtype

        src, dst = data.edge_index[0], data.edge_index[1]
        # Flatten so an (n_edges, 1) edge_attr cannot broadcast to (n_edges, n_edges).
        p_parent = data.edge_attr.reshape(-1)
        # index_add requires source and destination to share a dtype.
        source = (-cut_sqrt[src] * p_parent).to(dtype)
        # Out-of-place index_add leaves cut_sqrt untouched without an explicit clone.
        cut_sqrt_diff = cut_sqrt.index_add(0, dst, source)
        cut_diff_larger_than_zero = cut_sqrt_diff > 0

        features = (
            cut,
            cut_higher_than_threshold,
            cut_sqrt,
            cut_sqrt_diff,
            cut_diff_larger_than_zero,
        )
        # Cast the boolean columns explicitly rather than relying on stack's type promotion;
        # stacking along dim=1 yields a contiguous (n_nodes, 5) tensor without a transpose.
        return torch.stack([feature.to(dtype) for feature in features], dim=1)
