import torch
import torch.nn as nn
from torch import Tensor
import math
from .groups import PLATONIC_GROUPS


class APE(nn.Module):
    """
    Absolute position encoding built from random sinusoidal (Fourier) features.

    A bank of frequency vectors is sampled once, at construction time, from a
    standard normal distribution and scaled by `freq_sigma`. An input position
    is projected onto every frequency vector, and the cosines and sines of
    those projections are concatenated into the embedding.

    Args:
        embed_dim (int): Size of the output embedding. Has to be even, since one
                         half of the channels holds cosines and the other half sines.
        freq_sigma (float): Scale applied to the standard-normal frequency samples.
        spatial_dims (int): Length of the trailing coordinate axis of the input
                            positions (e.g., 3 for x, y, z).
        learned_freqs (bool): If True, the frequencies are stored as an
                              `nn.Parameter`; otherwise they are registered as a buffer.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    def __init__(
        self,
        embed_dim: int,
        freq_sigma: float,
        spatial_dims: int = 3,
        learned_freqs: bool = False,
    ):
        super().__init__()

        # Guard clause: cos/sin halves must split the embedding evenly.
        if embed_dim % 2:
            raise ValueError(f"embed_dim must be an even number, but got {embed_dim}.")

        self.spatial_dims = spatial_dims
        self.freq_sigma = freq_sigma
        self.embed_dim = embed_dim
        self.num_frequencies = embed_dim // 2

        # One Gaussian frequency vector per column.
        # Shape: [spatial_dims, num_frequencies]
        initial_freqs = freq_sigma * torch.randn(spatial_dims, self.num_frequencies)

        if not learned_freqs:
            # Fixed features: stored with the module state but never trained.
            self.register_buffer("freqs", initial_freqs)
        else:
            # Assigning an nn.Parameter registers it under the attribute name.
            self.freqs = nn.Parameter(initial_freqs)

    def forward(self, pos: Tensor) -> Tensor:
        """
        Embed coordinates with the sinusoidal feature bank.

        Args:
            pos (Tensor): Positions of shape (..., spatial_dims); any number of
                          leading dimensions is accepted.

        Returns:
            Tensor: Embedding of shape (..., embed_dim), laid out as all cosine
                    features followed by all sine features.
        """
        # Dot product of every position with every frequency vector:
        # (..., d) x (d, f) -> (..., f)
        phase = torch.einsum('...d,df->...f', pos, self.freqs)

        # (..., f) cosines followed by (..., f) sines -> (..., 2*f) == (..., embed_dim)
        return torch.cat((phase.cos(), phase.sin()), dim=-1)


class PlatonicAPE(nn.Module):
    """
    Group-Equivariant Absolute Position Encoding (PlatonicAPE).

    This module extends Absolute Position Encoding (APE) using sinusoidal features
    to be equivariant to the discrete rotational symmetry groups of the Platonic
    solids (T, O, I).

    The principle is to generate G distinct position embeddings, where G is the size
    of the symmetry group. It starts with a single set of base random frequencies.
    Each of the G matrices in the group is then applied to these base frequencies,
    creating G sets of "rotated" frequencies. A sinusoidal embedding is computed
    for each set, and the resulting G embedding vectors are concatenated to form
    the final output.

    The group object looked up in `PLATONIC_GROUPS` must expose `G` (the group
    size) and `elements` (a tensor that `forward` treats as having shape
    (G, spatial_dims, spatial_dims)).

    Args:
        embed_dim (int): The total dimension of the output embedding. Must be divisible
                         by the group size G, and the result (embed_dim/G) must be even.
        solid_name (str): Case-insensitive key into `PLATONIC_GROUPS` selecting the
                          symmetry group (e.g. 'tetrahedron', 'octahedron', 'icosahedron').
        freq_sigma (float): Standard deviation for sampling the initial random frequencies.
        spatial_dims (int): The number of spatial dimensions of the input positions (e.g., 3).
                            Must equal the size of the group's matrices.
        learned_freqs (bool): If True, the base frequencies become learnable parameters.

    Raises:
        ValueError: If `solid_name` is not a known group, if `embed_dim` is not
            divisible by G, if `embed_dim / G` is odd, or if `spatial_dims` does not
            match the size of the group's matrices.
    """
    def __init__(
        self,
        embed_dim: int,
        solid_name: str,
        freq_sigma: float,
        spatial_dims: int = 3,
        learned_freqs: bool = False,
    ):
        super().__init__()

        # --- Group Setup ---
        try:
            self.group = PLATONIC_GROUPS[solid_name.lower()]
        except KeyError:
            # `from None`: the KeyError adds nothing beyond this message, so keep
            # it out of the traceback.
            raise ValueError(
                f"Unknown solid '{solid_name}'. Available options are {list(PLATONIC_GROUPS.keys())}"
            ) from None
        self.num_G = self.group.G
        self.register_buffer('group_elements', self.group.elements.to(torch.float32))

        # --- Dimension Setup ---
        self.embed_dim = embed_dim
        self.spatial_dims = spatial_dims

        if self.embed_dim % self.num_G != 0:
            raise ValueError(f"embed_dim ({self.embed_dim}) must be divisible by group size G ({self.num_G}).")
        self.embed_dim_g = self.embed_dim // self.num_G

        if self.embed_dim_g % 2 != 0:
            raise ValueError(f"embed_dim per group element ({self.embed_dim_g}) must be an even number.")
        self.num_frequencies_g = self.embed_dim_g // 2

        # Fail at construction instead of deep inside the einsum in `forward`:
        # the group matrices are multiplied with the (spatial_dims, f_g)
        # frequency matrix, so their trailing size has to agree with spatial_dims.
        group_dim = self.group_elements.shape[-1]
        if group_dim != self.spatial_dims:
            raise ValueError(
                f"spatial_dims ({self.spatial_dims}) must match the dimension of the "
                f"group's rotation matrices ({group_dim})."
            )

        # --- Base Frequency Initialization ---
        # Frequencies are defined once and then rotated by the group elements.
        # Shape: [spatial_dims, num_frequencies_per_group]
        freqs = torch.randn(self.spatial_dims, self.num_frequencies_g) * freq_sigma
        if learned_freqs:
            self.register_parameter("freqs", nn.Parameter(freqs))
        else:
            self.register_buffer("freqs", freqs)

    def forward(self, pos: Tensor) -> Tensor:
        """
        Compute the group-equivariant sinusoidal position embeddings.

        Args:
            pos (Tensor): Position tensor of shape (..., spatial_dims).

        Returns:
            Tensor: The calculated position embedding of shape (..., embed_dim).
                    The last axis consists of G consecutive chunks of size
                    embed_dim / G, one per group element, each laid out as
                    [cos features, sin features].
        """
        # 1. --- Rotate the base frequencies using the group elements ---
        # group_elements shape: (g, d, d) | freqs shape: (d, f_g)
        # -> freqs_rotated shape: (g, d, f_g)
        # where g=num_G, d=spatial_dims, f_g=num_frequencies_g
        # Recomputed on every call so that learned frequencies stay in the graph.
        freqs_rotated = torch.einsum('gij, jf -> gif', self.group_elements, self.freqs)

        # 2. --- Project positions onto all sets of rotated frequencies ---
        # This computes the dot product for each of the G frequency sets.
        # pos shape: (..., d) | freqs_rotated shape: (g, d, f_g)
        # -> angles shape: (..., g, f_g)
        angles = torch.einsum('...d, gdf -> ...gf', pos, freqs_rotated)

        # 3. --- Compute sinusoidal features for each group element ---
        # embedding_grouped shape: (..., g, 2*f_g) which is (..., num_G, embed_dim_g)
        embedding_grouped = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

        # 4. --- Flatten the group-wise embeddings into a single vector ---
        # Merge the last two axes: (..., num_G, embed_dim_g) -> (..., embed_dim).
        return embedding_grouped.flatten(start_dim=-2)