"""
Canonicalization layers for symmetry projection.

This module provides deterministic transformations that project samples
to a canonical cell of the symmetry. Unlike stochastic modulation,
these layers apply a deterministic transformation based on the sample's
position, mapping it to a canonical representative.

This approach is an alternative to stochastic modulation that doesn't
require REINFORCE gradients but may have less flexibility.
"""

import torch
import torch.nn as nn
from torch import Tensor
from math import pi as pi
from typing import Tuple
import logging

logger = logging.getLogger("SESaMo")


class Canonicalization(nn.Module):
    """
    Base class for canonicalization layers.

    Canonicalization layers project samples to a canonical cell of the
    symmetry using a deterministic transformation. The inverse
    transformation restores samples to their original position.

    The layer is stateful: calls alternate between ``transform`` and
    ``inverse_transform``. The first call canonicalizes, the next call
    undoes it, and so on.

    Attributes
    ----------
    name : str
        Identifier for the canonicalization type.
    forwarded : bool
        True after a successful ``transform`` whose inverse has not yet
        been applied; the next call will then run ``inverse_transform``.
    """

    name = "base"

    def __init__(self, **kwargs):
        """
        Initialize the canonicalization layer.

        Parameters
        ----------
        **kwargs
            Accepted for signature compatibility and ignored.
        """
        super().__init__()
        self.forwarded = False
        logger.info(f"Initialized {self.__class__.__name__} symmetry")

    def __call__(self, x: Tensor) -> Tuple[Tensor, float]:
        """
        Apply forward or reverse transform based on state.

        This overrides ``nn.Module.__call__`` and dispatches straight to
        :meth:`forward`.

        Parameters
        ----------
        x : Tensor
            Input tensor.

        Returns
        -------
        tuple
            (transformed_x, log_det) where log_det is always 0.
        """
        return self.forward(x)

    def forward(self, x: Tensor) -> Tuple[Tensor, float]:
        """
        Apply ``transform`` or ``inverse_transform`` depending on state.

        If the previous call canonicalized a batch, this call applies the
        inverse; otherwise it canonicalizes ``x``.

        Parameters
        ----------
        x : Tensor
            Input tensor.

        Returns
        -------
        tuple
            (transformed_x, 0.0).
        """
        if self.forwarded:
            # Reset before inverting so that a failing inverse never leaves
            # the layer stuck in inverse mode.
            self.forwarded = False
            return self.inverse_transform(x), 0.

        # Only commit the state change once the transform has succeeded.
        # Flipping the flag first would make the next call run
        # inverse_transform against state that was never (re)computed if
        # transform raised, e.g. on an invalid input shape.
        canonical = self.transform(x)
        self.forwarded = True
        return canonical, 0.

    def transform(self, x: Tensor) -> Tensor:
        """
        Apply the forward canonicalization.

        Must be implemented by subclasses.

        Parameters
        ----------
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            Canonicalized tensor.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError("Transform method not implemented")

    def inverse_transform(self, x: Tensor) -> Tensor:
        """
        Apply the inverse canonicalization.

        Must be implemented by subclasses.

        Parameters
        ----------
        x : Tensor
            Input tensor.

        Returns
        -------
        Tensor
            De-canonicalized tensor.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError("Inverse transform method not implemented")


class Z2Canonicalization(Canonicalization):
    """
    Z2 symmetry canonicalization.

    Projects samples to have non-negative total sum by flipping the sign
    of a sample if its sum is negative. Stores the applied sign factors
    for inversion.
    """

    name = "z2_canon"

    def __init__(self, **kwargs):
        """
        Initialize Z2 canonicalization.

        Stores signs from forward pass for use in inverse transformation.

        Parameters
        ----------
        **kwargs
            Accepted for signature compatibility and ignored.
        """
        super().__init__()
        # Per-sample factors (+1 or -1) from the most recent transform().
        self.signs = None

    def transform(self, x: Tensor) -> Tensor:
        """
        Canonicalize by ensuring a non-negative total sum.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, ...).

        Returns
        -------
        Tensor
            Tensor in which every sample with a negative total sum has
            been multiplied by -1; all other samples are unchanged.
        """
        # Sum over everything except the batch axis. reshape (rather than
        # view) also accepts non-contiguous inputs.
        per_sample_sum = x.reshape(x.shape[0], -1).sum(-1)

        # Flip only strictly negative sums. Using sign() here would give a
        # factor of 0 for a sample whose sum is exactly zero, which would
        # zero out the sample and make the transform non-invertible.
        ones = torch.ones_like(per_sample_sum)
        factors = torch.where(per_sample_sum < 0, -ones, ones)

        # Shape (batch, 1, ..., 1) so the factor broadcasts over each sample.
        self.signs = factors.view(-1, *[1] * (x.dim() - 1))

        return x * self.signs

    def inverse_transform(self, x: Tensor) -> Tensor:
        """
        Restore original signs.

        The factors are +/-1, so multiplying by them a second time undoes
        the canonicalization.

        Parameters
        ----------
        x : Tensor
            Canonicalized tensor with the same batch layout as the input
            of the preceding ``transform`` call.

        Returns
        -------
        Tensor
            Original tensor with restored signs.
        """
        return x * self.signs



class ZNCanonicalization(Canonicalization):
    """
    Z_N symmetry canonicalization.

    Projects 2D samples to a canonical pizza slice by rotating each sample
    by a multiple of 2*pi/N chosen from the angle of its summed vector.
    The canonical region has angular width 2*pi/N.

    Parameters
    ----------
    n : int, optional
        Order of the Z_N group. Default is 8.
    """

    name = "zn_canon"

    def __init__(self, n: int = 8, **kwargs):
        """
        Initialize Z_N canonicalization.

        Parameters
        ----------
        n : int, optional
            Order of Z_N group. Default is 8.
        **kwargs
            Accepted for signature compatibility and ignored.
        """
        super().__init__()
        self.n = n
        logger.info(f"Initialized Z{n} canonicalization")

    @staticmethod
    def _rotate(x: Tensor, angle: Tensor) -> Tensor:
        """Rotate the last-axis 2-vectors of ``x`` by ``angle`` (batch, 1)."""
        cos_a = torch.cos(angle)
        sin_a = torch.sin(angle)
        first = x[:, :, 0]
        second = x[:, :, 1]
        return torch.stack(
            [first * cos_a - second * sin_a,
             first * sin_a + second * cos_a],
            dim=2,
        )

    def transform(self, x: Tensor) -> Tensor:
        """
        Rotate samples to canonical pizza slice at angle 0.

        The rotation angle applied to each sample is stored in
        ``self.angle`` for use by ``inverse_transform``.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, N_t, 2).

        Returns
        -------
        Tensor
            Rotated tensor in canonical region.

        Raises
        ------
        ValueError
            If input shape is not (batch, N_t, 2).
        """
        if x.dim() != 3 or x.shape[2] != 2:
            raise ValueError(f"Input tensor should have shape (batch, Nt, 2) but got {x.shape}")

        # Direction of each sample's summed 2-vector.
        x_sum = x.sum(dim=1)
        angle_x = torch.atan2(x_sum[:, 1], x_sum[:, 0])  # in [-pi, pi]

        # Nearest multiple of 2*pi/n, negated so the rotation moves the
        # summed vector towards angle 0.
        snapped = -(angle_x / (2*pi/self.n)).round() * 2*pi/self.n

        # The snapped angle is piecewise constant in x, so it carries no
        # useful gradient. Detaching keeps the cached tensor from holding
        # on to the autograd graph of x between calls.
        self.angle = snapped.detach().unsqueeze(1)

        return self._rotate(x, self.angle)

    def inverse_transform(self, x: Tensor) -> Tensor:
        """
        Rotate samples back to original angle.

        Applies the opposite of the rotation stored by the preceding
        ``transform`` call.

        Parameters
        ----------
        x : Tensor
            Canonicalized tensor with the same batch size as the input of
            the preceding ``transform`` call.

        Returns
        -------
        Tensor
            Original tensor with restored angles.
        """
        return self._rotate(x, -self.angle)
    
    
    

class Z2powNCanonicalization(Canonicalization):
    """
    (Z_2)^N symmetry canonicalization.

    Projects samples to the first quadrant by flipping signs
    independently in each dimension so that every per-dimension sum is
    non-negative.
    """

    name = "z2pown_canon"

    def __init__(self, **kwargs):
        """
        Initialize (Z_2)^N canonicalization.

        Stores per-dimension signs for use in inverse transformation.

        Parameters
        ----------
        **kwargs
            Accepted for signature compatibility and ignored.
        """
        super().__init__()
        # Per-sample, per-dimension factors (+1 or -1) from the most
        # recent transform().
        self.signs = None

    def transform(self, x: Tensor) -> Tensor:
        """
        Canonicalize by ensuring a non-negative sum in each dimension.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, N_t, N_x).

        Returns
        -------
        Tensor
            Tensor in which every (sample, dimension) column whose sum over
            axis 1 is negative has been multiplied by -1.

        Raises
        ------
        ValueError
            If input shape is not (batch, N_t, N_x).
        """
        if x.dim() != 3:
            raise ValueError(f"Input tensor should have shape (batch, Nt, N_x) but got {x.shape}")

        # Sum over axis 1 for every (sample, dimension) pair.
        column_sum = x.sum(dim=1)

        # Flip only strictly negative sums. Using sign() here would give a
        # factor of 0 for a column whose sum is exactly zero, which would
        # wipe out that column and make the transform non-invertible.
        ones = torch.ones_like(column_sum)
        factors = torch.where(column_sum < 0, -ones, ones)

        # Shape (batch, 1, N_x) so the factors broadcast along axis 1.
        self.signs = factors.unsqueeze(1)

        return x * self.signs

    def inverse_transform(self, x: Tensor) -> Tensor:
        """
        Restore original per-dimension signs.

        The factors are +/-1, so multiplying by them a second time undoes
        the canonicalization.

        Parameters
        ----------
        x : Tensor
            Canonicalized tensor with the same batch size and last
            dimension as the input of the preceding ``transform`` call.

        Returns
        -------
        Tensor
            Original tensor with restored signs.
        """
        return x * self.signs