"""
Regularization modules for symmetry enforcement.

This module provides regularization penalty terms that encourage samples
to lie within a canonical region of the symmetry orbit. These penalties
help stabilize training by preventing samples from drifting too far
from the canonical cell.

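The penalty function uses a sigmoid-based soft constraint:
    penalty(lamb) = A * sigmoid(B * lamb) * (lamb > 0)
where A controls the penalty magnitude, B controls the steepness of the
sigmoid, and the indicator (lamb > 0) makes the penalty exactly zero
whenever lamb is not positive.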
"""

from typing import Tuple
import torch
import torch.nn as nn
from math import pi as pi
import logging
import math

logger = logging.getLogger("SESaMo")


class Regularization(nn.Module):
    """
    Common base for all symmetry regularizers.

    A regularizer maps a batch of samples to one penalty value per sample.
    The penalty is built from a gated sigmoid (see ``penalty_term``);
    subclasses decide which signed quantity is fed into it by overriding
    ``regularization``.

    Parameters
    ----------
    A : float, optional
        Scale of the penalty. Default is 1000.
    B : float, optional
        Steepness of the sigmoid. Default is 100.
    **kwargs
        Accepted and ignored by this class.

    Attributes
    ----------
    name : str
        Short identifier of the regularizer type.
    penalty_size : float
        The value passed as ``A``.
    penalty_gradient : float
        The value passed as ``B``.
    """

    name = "base"

    def __init__(self, A: float = 1000, B: float = 100, **kwargs):
        """
        Store the two penalty hyper-parameters and log the configuration.

        Parameters
        ----------
        A : float, optional
            Scale of the penalty. Default is 1000.
        B : float, optional
            Steepness of the sigmoid. Default is 100.
        """
        super().__init__()
        self.penalty_gradient = B
        self.penalty_size = A
        logger.info(f"Initialized {self.__class__.__name__} regularization with A={A}, B={B}")

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the penalty for a batch of samples.

        Calling the module is routed to ``regularization`` (not to
        ``forward``), so ``module(x)`` yields penalties.

        Parameters
        ----------
        x : torch.Tensor
            Samples to be penalized.

        Returns
        -------
        torch.Tensor
            Whatever ``regularization`` returns for ``x``; subclasses
            document this as one penalty per sample, shape (batch,).
        """
        penalty = self.regularization(x)
        return penalty

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, float]:
        """
        Pass the input through untouched.

        Parameters
        ----------
        x : torch.Tensor
            Any tensor.

        Returns
        -------
        tuple
            ``(x, 0.0)``: the same tensor object and a zero
            log-determinant.
        """
        return (x, 0.0)

    def reverse(self, x: torch.Tensor) -> Tuple[torch.Tensor, float]:
        """
        Inverse of ``forward``; likewise the identity.

        Parameters
        ----------
        x : torch.Tensor
            Any tensor.

        Returns
        -------
        tuple
            ``(x, 0.0)``: the same tensor object and a zero
            log-determinant.
        """
        return (x, 0.0)

    def penalty_term(self, lamb: torch.Tensor) -> torch.Tensor:
        """
        Evaluate ``A * sigmoid(B * lamb)`` gated by ``lamb > 0``.

        Parameters
        ----------
        lamb : torch.Tensor
            Signed violation measure; values <= 0 count as "inside" the
            canonical region, values > 0 as "outside".

        Returns
        -------
        torch.Tensor
            Element-wise penalty with the shape of ``lamb``: exactly zero
            where ``lamb <= 0`` and ``A * sigmoid(B * lamb)`` elsewhere.
        """
        # Hard gate: 1.0 where the constraint is violated, 0.0 otherwise.
        outside_mask = (lamb > 0).float()
        soft_step = torch.sigmoid(self.penalty_gradient * lamb)
        return self.penalty_size * soft_step * outside_mask

    def regularization(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map samples to penalties; abstract in the base class.

        Parameters
        ----------
        x : torch.Tensor
            Samples to be penalized.

        Returns
        -------
        torch.Tensor
            One penalty per sample, shape (batch,), in subclasses.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Regularization method not implemented")


class Z2Regularization(Regularization):
    """
    Regularizer for a global Z2 (sign-flip) symmetry.

    The canonical region is the set of samples whose entries add up to a
    non-negative number; a sample with a negative total is penalized.
    """

    name = "z2_reg"

    def regularization(self, x: torch.Tensor) -> torch.Tensor:
        """
        Penalize samples whose total sum is negative.

        Parameters
        ----------
        x : torch.Tensor
            Samples of shape (batch, *spatial_dims).

        Returns
        -------
        torch.Tensor
            One penalty per sample. Shape: (batch,).
        """
        batch_size = x.shape[0]
        # Collapse every non-batch axis, then total each sample.
        per_sample_total = x.reshape(batch_size, -1).sum(-1)
        # penalty_term fires on positive arguments, hence the sign flip.
        return self.penalty_term(-per_sample_total)



class ZNRegularization(Regularization):
    """
    Z_N symmetry regularization.

    Each sample is reduced to a 2-vector by summing over its second axis.
    The sample is penalized when that vector lies outside a wedge
    ("pizza slice") with its apex at the origin:

    * ``offset=False``: the wedge is centred on the positive first axis
      and has opening angle ``2*pi/n`` (edges at ``+-pi/n``).
    * ``offset=True``: the wedge is bounded by ``sum_1 > 0`` and
      ``sum_0 > sum_1``, i.e. the 45-degree slice between the positive
      first axis and the diagonal. This branch does not depend on ``n``.

    Parameters
    ----------
    n : int
        Order of the Z_N group.
    offset : bool, optional
        Whether to use the offset canonical region. Default is False.
    A : float, optional
        Penalty magnitude. Default is 1000.
    B : float, optional
        Penalty gradient. Default is 100.
    """

    name = "zn_reg"

    def __init__(self, n: int, offset: bool = False, A: float = 1000, B: float = 100, **kwargs):
        """
        Initialize Z_N regularization.

        Parameters
        ----------
        n : int
            Order of the Z_N group.
        offset : bool, optional
            Whether to use the offset canonical region. Default is False.
        A : float, optional
            Penalty magnitude. Default is 1000.
        B : float, optional
            Penalty gradient. Default is 100.
        """
        super().__init__(A=A, B=B, **kwargs)
        self.n = n
        self.offset = offset
        logger.info(f"Initialized Z{n} regularization with offset={offset}")

    def regularization(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute Z_N regularization penalty.

        Two signed distances ``d1`` and ``d2`` to the wedge's edge lines
        are computed (positive on the inner side of each edge); each edge
        contributes ``penalty_term(-d)``, so a sample is penalty-free only
        if it is on the inner side of both edges.

        Parameters
        ----------
        x : torch.Tensor
            Input samples of shape (batch, N_t, 2).

        Returns
        -------
        torch.Tensor
            Penalty values. Shape: (batch,).

        Raises
        ------
        ValueError
            If input shape is not (batch, N_t, 2).
        """
        if x.dim() != 3 or x.shape[2] != 2:
            raise ValueError(f"ZNRegularization expects input of shape (batch, N_t, 2), but got {x.shape}")

        x_sum = x.sum(dim=1)
        sum_0 = x_sum[:, 0]
        sum_1 = x_sum[:, 1]

        if self.offset:
            # NOTE(review): fixed 45-degree wedge, independent of self.n --
            # confirm whether this is meant to be used for n = 8 only.
            d1 = sum_0 - sum_1
            d2 = sum_1
        else:
            # Edge lines sit at angles +-pi/n. Their unit normals give the
            # signed perpendicular distances below. The sin/cos form is
            # equal to (tan(h)*x -+ y) / sqrt(1 + tan(h)**2) but stays
            # well-scaled for n = 2, where tan(h) blows up.
            half_angle = pi / self.n
            sin_h = math.sin(half_angle)
            cos_h = math.cos(half_angle)
            d1 = sin_h * sum_0 - cos_h * sum_1
            d2 = sin_h * sum_0 + cos_h * sum_1

        return self.penalty_term(-d1) + self.penalty_term(-d2)
    

class Z2powNRegularization(Regularization):
    """
    Regularizer for a (Z_2)^N symmetry, i.e. an independent sign flip per
    spatial site.

    For every index along the last axis the entries are summed over the
    middle axis; each of those sums that is negative adds its own penalty
    to the sample's total.
    """

    name = "z2pown_reg"

    def regularization(self, x: torch.Tensor) -> torch.Tensor:
        """
        Sum the per-site penalties of each sample.

        Parameters
        ----------
        x : torch.Tensor
            Samples of shape (batch, N_t, N_x).

        Returns
        -------
        torch.Tensor
            One penalty per sample. Shape: (batch,).

        Raises
        ------
        ValueError
            If ``x`` is not 3-dimensional.
        """
        if x.ndim != 3:
            raise ValueError(f"Z2powNRegularization expects input of shape (batch, N_t, N_x), but got {x.shape}")

        # Reduce over axis 1 -> shape (batch, N_x).
        site_sums = torch.sum(x, dim=1)
        # Negative site sums are the violations, so flip the sign first.
        site_penalties = self.penalty_term(-site_sums)
        return site_penalties.sum(-1)