import torch
from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch


class GinkgoAngle(nn.Module):
    """Ginkgo angle feature.

    Derives eight kinematic quantities from each node's four-momentum. Every
    row of ``data.x`` is read as ``(E, p_x, p_y, p_z)``. The module holds no
    learnable parameters.
    """

    def __init__(self) -> None:
        """Initialize the Ginkgo angle feature."""
        super().__init__()

    @property
    def d(self) -> int:
        """Return dimension of Ginkgo angle feature.

        :return: Number of feature columns produced by :meth:`forward`.
        """
        return 8

    def forward(self, data: Batch) -> Float[Tensor, "n_nodes 8"]:
        """Compute the Ginkgo angle feature.

        The output columns are, in order: ``|p|``, the three components of
        the unit vector ``p / |p|``, the transverse momentum ``p_T``, the
        azimuthal angle ``phi``, the rapidity ``y`` and the polar angle
        ``theta``.

        A node with zero 3-momentum has no defined direction, so its
        unit-vector and ``theta`` entries are NaN.

        :param data: PyG batch object whose ``x`` holds ``(E, p_x, p_y, p_z)``
            in its first four columns.
        :return: Ginkgo angle feature with one row per node and ``d`` columns.
        """
        E = data.x[:, 0]
        # Only the three momentum components enter the norm; any further
        # columns of ``data.x`` are ignored.
        p = data.x[:, 1:4]
        p_x = p[:, 0]
        p_y = p[:, 1]
        p_z = p[:, 2]

        # see ginkgo implementation
        p_norm = torch.norm(p, p=2, dim=1)
        p_normalized = p / p_norm.unsqueeze(1)

        # according to https://s3.cern.ch/inspire-prod-files-6/6904a3576c84c5d1f05a1f171cac3695
        p_T = torch.norm(p[:, :2], p=2, dim=1)
        # atan2 keeps the quadrant (range (-pi, pi]) and is defined when
        # p_x == 0, unlike atan of a ratio.
        phi = torch.atan2(p_y, p_x)
        # The ratio is clamped from below at 1e-3, so y never drops below
        # 0.5 * log(1e-3).
        y = 0.5 * torch.log(((E + p_z) / (E - p_z)).clamp(min=1e-3))
        # Rounding can push p_z / |p| marginally outside [-1, 1], where acos
        # returns NaN; clamp before taking the angle.
        cos_theta = p_normalized[:, 2].clamp(min=-1.0, max=1.0)
        theta = torch.acos(cos_theta)
        return torch.stack(
            [
                p_norm,
                p_normalized[:, 0],
                p_normalized[:, 1],
                p_normalized[:, 2],
                p_T,
                phi,
                y,
                theta,
            ],
            dim=1,
        )
