"""
Stochastic modulation layers for symmetry-enforcing normalizing flows.

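This module provides stochastic transformation layers that enforce various
symmetries in the output distribution of normalizing flows. These include:

- Z2Modulation: Discrete sign flip symmetry
- ZNModulation: Discrete N-fold rotation symmetry
- U1Modulation: Continuous U(1) rotation symmetry
- HubbardModulation: Combined symmetries for the Hubbard model

The Broken* variants (BrokenZ2Modulation, BrokenZNModulation,
BrokenZ2powNModulation, BrokenU1Modulation) replace the uniform distribution
over group elements with a learnable one. They are meant for targets whose
symmetry is explicitly broken.

Each modulation layer draws a random group element independently of the
input values and applies it to the input. Alongside the transformed sample,
it returns the log-probability of the drawn element. This enables
REINFORCE-style gradient estimation for the learnable variants.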
"""

from typing import Tuple
import torch
import torch.nn as nn
import math
import omegaconf
import nflows.transforms as transforms

import logging

logger = logging.getLogger("SESaMo")


class StochasticModulation(nn.Module):
    """
    Common interface for stochastic symmetry-modulation layers.

    A modulation layer draws a random symmetry transformation, applies it to
    its input and reports the log-probability of the drawn transformation.
    That log-probability is what REINFORCE-style gradient estimators need.

    Subclasses implement ``transform`` (and, where meaningful,
    ``inverse_transform``). Each returns ``(output, log_modprob)``.

    Attributes
    ----------
    name : str
        Identifier for the modulation type.
    log_modprob
        Set by ``forward``/``reverse``: the (non-negated) log-probability of
        the most recently applied transformation.
    """

    name = "base"

    def __init__(self, **kwargs):
        """Set up the underlying ``nn.Module``; keyword arguments are ignored."""
        super().__init__()

    def __call__(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run ``transform`` directly.

        Note: this override replaces ``nn.Module.__call__``, so calling the
        layer as ``layer(x)`` neither goes through ``forward`` nor runs module
        hooks. The returned log-probability is NOT negated.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        tuple
            ``(transformed_x, log_modprob)``.
        """
        return self.transform(x)

    def _store_and_negate(self, result):
        """Unpack ``(out, log_modprob)``, remember ``log_modprob``, return it negated."""
        out, log_modprob = result
        self.log_modprob = log_modprob
        return out, -log_modprob

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply ``transform`` and return the negated log-probability.

        The flow subtracts the log-determinant of every layer. Returning
        ``-log_modprob`` therefore makes the modulation *add* its
        log-probability:
        log_prob = prior_log_prob + log_modprob - log_det1 - ... - log_detN

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        tuple
            ``(transformed_x, -log_modprob)``.
        """
        return self._store_and_negate(self.transform(x))

    def reverse(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply ``inverse_transform`` and return the negated log-probability.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        tuple
            ``(inverse_transformed_x, -log_modprob)``.
        """
        return self._store_and_negate(self.inverse_transform(x))

    def transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward modulation; subclasses must override this.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        tuple
            ``(transformed_x, log_modprob)``.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Inverse modulation; subclasses must override this to support ``reverse``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        tuple
            ``(inverse_transformed_x, log_modprob)``.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError



# ==============================
# DISCRETE SYMMETRIES
# ==============================


class Z2Modulation(StochasticModulation):
    """
    Z2 symmetry modulation layer.
    
    Randomly flips the sign of the entire configuration with 50% probability,
    enforcing Z2 (parity) symmetry in the output distribution.
    
    The transformation is: x -> +/- x with equal probability.
    """
    
    name = "z2_stochmod"
    
    def __init__(self, **kwargs):
        """Initialize the Z2 modulation layer (extra keyword arguments are ignored)."""
        super(Z2Modulation, self).__init__()
        logger.info("Initialized Z2 Stochastic Modulation")

    def transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, float]:
        """
        Apply a random global sign flip to each sample to enforce Z2 symmetry.
        
        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch, ...). It does not need to be
            contiguous.
            
        Returns
        -------
        tuple
            (transformed_x, log_modprob). transformed_x has the shape and
            dtype of x. log_modprob = log(0.5), returned as a Python float
            shared by every sample, because each sign is drawn with
            probability 1/2.
        """
        N = x.shape[0]

        # One fair coin per sample; u = 1 means "flip the sign".
        bernoulli = torch.distributions.Bernoulli(
            torch.tensor(0.5, device=x.device, dtype=x.dtype)
        )
        u = bernoulli.sample((N,))
        random_sign = 1 - 2 * u  # +1 (u=0) or -1 (u=1)

        # Broadcast the per-sample sign over all trailing dimensions. This
        # avoids x.view(N, -1), which fails for non-contiguous inputs.
        sign_shape = (N,) + (1,) * (x.dim() - 1)
        x = x * random_sign.reshape(sign_shape)

        # Both signs are equally likely.
        log_modprob = math.log(0.5)

        return x, log_modprob
    


class BrokenZ2Modulation(StochasticModulation):
    """
    Broken Z2 symmetry modulation layer.
    
    Flips the sign with a learnable probability, allowing the model to
    learn asymmetric distributions when the Z2 symmetry is explicitly broken.
    
    The flip probability is p = exp(b), where b <= 0 is a learnable parameter.
    When b = log(0.5), the symmetry is exact.
    
    Parameters
    ----------
    init_breaking : float, optional
        Initial value for the breaking parameter (must be <= 0).
        Default is log(0.5) for exact symmetry.
    """
    
    name = "broken_z2_stochmod"
    
    def __init__(self, init_breaking: float = math.log(0.5), **kwargs):
        """
        Initialize the broken Z2 modulation layer.
        
        Parameters
        ----------
        init_breaking : float, optional
            Initial breaking parameter (must be <= 0). Default is log(0.5).
            
        Raises
        ------
        ValueError
            If init_breaking > 0.
        """
        if init_breaking > 0:
            raise ValueError(f"init_breaking should be smaller or equal to 0, got {init_breaking}") 

        super(BrokenZ2Modulation, self).__init__()
        # b in p_flip = exp(b); a 0-dim learnable parameter.
        self.breaking = nn.Parameter(torch.tensor(init_breaking, dtype=torch.float32), requires_grad=True)

        logger.info(f"Initialized Broken Z2 Stochastic Modulation with breaking parameter {init_breaking:.3f}")

    def transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply a sign flip with learnable probability p = exp(b) to each sample.
        
        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch, ...). It does not need to be
            contiguous.
            
        Returns
        -------
        tuple
            (transformed_x, log_modprob). transformed_x has the shape and
            dtype of x. log_modprob has shape (batch,) and lives on x's
            device. It carries the gradient with respect to the breaking
            parameter.
        """
        N = x.shape[0]
        
        # Project b back onto b <= 0 so that exp(b) is a valid probability.
        # This is done on .data, so the projection is not tracked by autograd.
        self.breaking.data.clamp_max_(0)

        # Draw one flip decision per sample: u = 1 (flip) with probability exp(b).
        dist = torch.distributions.Bernoulli(torch.exp(self.breaking))
        u = dist.sample((N,))

        # Evaluate log p(u) on the device where the sample was drawn (the
        # parameter's device), then move the result next to x.
        log_modprob = dist.log_prob(u).to(x.device)

        # +1 (u=0) or -1 (u=1), cast to x's dtype so low-precision inputs are
        # not promoted to float32, then broadcast over trailing dimensions.
        random_sign = (1 - 2 * u).to(device=x.device, dtype=x.dtype)
        sign_shape = (N,) + (1,) * (x.dim() - 1)
        x = x * random_sign.reshape(sign_shape)

        return x, log_modprob



class ZNModulation(StochasticModulation):
    """
    Z_N symmetry modulation layer.
    
    Applies a random discrete rotation from the N-fold rotation group,
    enforcing Z_N symmetry in 2D. The rotation angle is 2*pi*k/N for
    k uniformly sampled from {0, 1, ..., N-1}.
    
    Parameters
    ----------
    n : int, optional
        Order of the symmetry group (number of rotations). Default is 8.
    """
    
    name = "zn_stochmod"
    
    def __init__(self, n: int = 8, **kwargs):
        """
        Initialize the Z_N modulation layer.
        
        Parameters
        ----------
        n : int, optional
            Order of the Z_N group. Default is 8.

        Raises
        ------
        ValueError
            If n < 1.
        """
        if n < 1:
            raise ValueError(f"n should be a positive integer, got {n}")
        super(ZNModulation, self).__init__()
        self.n = n
        logger.info(f"Initialized ZN Stochastic Modulation with n={n}")

    def transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, float]:
        """
        Rotate every sample by a uniformly drawn multiple of 2*pi/n.
        
        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch, N_t, 2) representing 2D vectors.
            
        Returns
        -------
        tuple
            (rotated_x, log_modprob), where log_modprob = -log(n) as a
            Python float.
            
        Raises
        ------
        ValueError
            If the input shape is not (batch, N_t, 2).
        """
        if x.dim() != 3 or x.shape[2] != 2:
            raise ValueError(f"Input tensor should have shape (batch, Nt, 2) but got {x.shape}")

        # One rotation index k per sample. Shape (batch, 1) broadcasts over N_t.
        k = torch.randint(0, self.n, (x.shape[0], 1), device=x.device)

        # Compute the angle at least in x's precision. Float64 inputs thus get
        # float64 cos/sin, and the rotation stays norm-preserving to double
        # precision. Lower-precision inputs use the default dtype.
        angle_dtype = torch.promote_types(x.dtype, torch.get_default_dtype())
        angle = k.to(angle_dtype) * (2 * math.pi / self.n)
        cos_a, sin_a = torch.cos(angle), torch.sin(angle)

        # Standard 2D rotation of (x[..., 0], x[..., 1]).
        x0, x1 = x[:, :, 0], x[:, :, 1]
        x = torch.stack([x0 * cos_a - x1 * sin_a, x0 * sin_a + x1 * cos_a], dim=2)

        # Uniform over the n group elements.
        log_modprob = -math.log(self.n)

        return x, log_modprob
    

class BrokenZNModulation(StochasticModulation):
    """
    Broken Z_N symmetry modulation layer.
    
    Applies discrete rotations from Z_N with learnable probabilities,
    allowing the model to adapt when the symmetry is explicitly broken.
    The rotation index k is drawn from a categorical distribution whose
    logits (``breaking``) are learnable. They are initialized uniform, which
    is the exact Z_N-symmetric distribution.
    
    Parameters
    ----------
    n : int, optional
        Order of the symmetry group. Default is 8.
    """
    
    name = "broken_zn_stochmod"
    
    def __init__(self, n: int = 8, **kwargs):
        """
        Initialize the broken Z_N modulation layer.
        
        Parameters
        ----------
        n : int, optional
            Order of the Z_N group. Default is 8.

        Raises
        ------
        ValueError
            If n < 1.
        """
        if n < 1:
            raise ValueError(f"n should be a positive integer, got {n}")
        super(BrokenZNModulation, self).__init__()
        self.n = n
        
        # Uniform logits -log(n): softmax gives probability 1/n per rotation.
        init_breaking = -math.log(n) * torch.ones(n, dtype=torch.float32)
        self.breaking = nn.Parameter(init_breaking, requires_grad=True)

        logger.info(f"Initialized Broken ZN Stochastic Modulation with n={n}")

    def transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Rotate every sample by 2*pi*k/n, with k drawn from the learnable distribution.
        
        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch, N_t, 2).
            
        Returns
        -------
        tuple
            (rotated_x, log_modprob). log_modprob has shape (batch,), lives
            on x's device and carries the gradient with respect to
            ``breaking``.
            
        Raises
        ------
        ValueError
            If the input shape is not (batch, N_t, 2).
        """
        if x.dim() != 3 or x.shape[2] != 2:
            raise ValueError(f"Input tensor should have shape (batch, N_t, 2) but got {x.shape}")

        # Draw k with shape (batch, 1) on the parameter's device.
        dist = torch.distributions.Categorical(logits=self.breaking)
        k = dist.sample((x.shape[0], 1))

        # Evaluate log p(k) where the sample was drawn, then move it next to x.
        log_modprob = dist.log_prob(k).squeeze(1).to(x.device)

        # Compute the angle at least in x's precision (see ZNModulation).
        angle_dtype = torch.promote_types(x.dtype, torch.get_default_dtype())
        angle = k.to(device=x.device, dtype=angle_dtype) * (2 * math.pi / self.n)
        cos_a, sin_a = torch.cos(angle), torch.sin(angle)

        # Standard 2D rotation of (x[..., 0], x[..., 1]).
        x0, x1 = x[:, :, 0], x[:, :, 1]
        x = torch.stack([x0 * cos_a - x1 * sin_a, x0 * sin_a + x1 * cos_a], dim=2)

        return x, log_modprob
    
    

class BrokenZ2powNModulation(StochasticModulation):
    """
    Broken (Z_2)^N symmetry modulation layer.
    
    Applies learnable sign-flip patterns along the last input axis. A pattern
    is a subset of the first ``n_dims`` sites of that axis, so there are
    2^n_dims patterns. The site at index ``n_dims`` is never flipped by this
    layer. One pattern per sample is drawn from a single categorical
    distribution over all 2^n_dims patterns. Flips of different sites are
    therefore not independent in general. The logits (``flip_log_prob``) are
    learnable and initialized uniform, which is the exact
    (Z_2)^n_dims-symmetric distribution.

    For n_dims >= 1, the input's last axis must have exactly ``n_dims + 1``
    entries, because the (batch, n_dims + 1) flip mask is written into it.
    
    Parameters
    ----------
    n_dims : int, optional
        Number of flippable sites. Default is 2.
    """
    
    name = "brokenz2pown_stochmod"
    
    def __init__(self, n_dims: int = 2, **kwargs):
        """
        Initialize the broken (Z_2)^N modulation layer.
        
        Parameters
        ----------
        n_dims : int, optional
            Number of flippable sites. Default is 2.

        Raises
        ------
        ValueError
            If n_dims < 0.
        """
        if n_dims < 0:
            raise ValueError(f"n_dims should be non-negative, got {n_dims}")

        super(BrokenZ2powNModulation, self).__init__()

        num_patterns = 2 ** n_dims

        # One logit per flip pattern. The uniform value -n_dims*log(2) equals
        # log(1 / 2**n_dims), i.e. exact symmetry at initialization.
        log_probs_init = -n_dims * math.log(2) * torch.ones(num_patterns, dtype=torch.float32)
        self.flip_log_prob = nn.Parameter(log_probs_init, requires_grad=True)

        # Plain-list copy of the logits for tensorboard logging (only for
        # fewer than 100 patterns; None otherwise).
        self.breaking_list = None
        if num_patterns < 100:
            self.breaking_list = self.flip_log_prob.tolist()

        # Pattern i flips the sites whose bits are set in i.
        self.flip_directions = self.get_flip_directions(n_dims)

        # Boolean lookup table (num_patterns, n_dims + 1). The last column is
        # always False: site n_dims is never flipped.
        flip_directions_tensor = torch.zeros((num_patterns, n_dims + 1), dtype=torch.bool)
        for i, dims in enumerate(self.flip_directions):
            flip_directions_tensor[i, dims] = True

        # Buffer, so it follows the module across .to(device) and state_dict.
        self.register_buffer("flip_directions_tensor", flip_directions_tensor)

        logger.info(f"Initialized BrokenFlipSignsZN with {2**n_dims-1} parameter(s)")
        
    def get_flip_directions(self, n_dims: int) -> list:
        """
        Enumerate all flip patterns as lists of site indices.
        
        Parameters
        ----------
        n_dims : int
            Number of flippable sites.
            
        Returns
        -------
        list
            List of 2^n_dims entries; entry i lists the sites j whose bit j is
            set in i. Example for n_dims=2: [[], [0], [1], [0, 1]]
        """
        flip_directions = [
            [j for j in range(n_dims) if i & (1 << j)]
            for i in range(2 ** n_dims)
        ]

        if 2**n_dims < 10:
            logger.info(f"Initialized flip directions: {flip_directions}")

        return flip_directions

    def transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply one randomly drawn flip pattern per sample.
        
        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch, N_t, N_x), with N_x == n_dims + 1
            when n_dims >= 1.
            
        Returns
        -------
        tuple
            (transformed_x, log_modprob). transformed_x keeps x's dtype.
            log_modprob has shape (batch,), lives on x's device and carries
            the gradient with respect to ``flip_log_prob``.
            
        Raises
        ------
        ValueError
            If x is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(f"x should be of shape (Batch, N_t, N_x), got {x.shape}")

        N = x.shape[0]
        Nx = x.shape[2]

        # Refresh the logging copy of the logits.
        if len(self.flip_log_prob) < 100:
            self.breaking_list = self.flip_log_prob.tolist()

        # Draw one pattern index per sample on the parameter's device.
        dist = torch.distributions.Categorical(logits=self.flip_log_prob)
        pattern = dist.sample((N,))
        log_modprob = dist.log_prob(pattern).to(x.device)

        # (N, n_dims + 1) boolean mask: True where the site's sign is flipped.
        # The buffer is indexed on its own device, then the mask moves to x.
        flip_mask = self.flip_directions_tensor[pattern].to(x.device)

        # Per-sample signs, shared over the N_t axis, in x's dtype so
        # low-precision inputs are not promoted.
        random_sign = torch.ones((N, 1, Nx), device=x.device, dtype=x.dtype)
        random_sign[:, 0].masked_fill_(flip_mask, -1)

        x = x * random_sign

        return x, log_modprob
    

class HubbardModulation(StochasticModulation):
    """
    Stochastic modulation layer for the Hubbard model (spin basis).

    Composes two sign-flip layers:

    1. a learnable (broken) flip pattern over the first ``nx - 1`` sites
       (``BrokenZ2powNModulation(nx - 1)``), followed by
    2. an exact global Z_2 sign flip (``Z2Modulation``).

    The returned log-probability is the sum of the two layers'
    log-probabilities.

    Parameters
    ----------
    nx : int, optional
        Number of spatial sites. Default is 2.
    """

    name = "hubbard_stochmod"

    def __init__(self, nx: int = 2, **kwargs):
        """
        Build both sub-layers.

        Parameters
        ----------
        nx : int, optional
            Number of spatial sites. Default is 2.
        """
        super().__init__(**kwargs)

        # Exact global flip.
        self.z2modulation = Z2Modulation()

        # Learnable relative flips: 2**(nx - 1) patterns.
        self.brokenzpownmodulation = BrokenZ2powNModulation(nx - 1)

        # Logging copy of the learnable logits.
        self.breaking_list = self.brokenzpownmodulation.flip_log_prob.tolist()

    def transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the broken relative flips, then the exact global flip.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch, N_t, N_x).

        Returns
        -------
        tuple
            (transformed_x, log_modprob), with log_modprob the sum of both
            sub-layers' log-probabilities.
        """
        broken = self.brokenzpownmodulation
        self.breaking_list = broken.flip_log_prob.tolist()

        x, log_modprob_broken = broken.transform(x)
        x, log_modprob_z2 = self.z2modulation.transform(x)

        return x, log_modprob_z2 + log_modprob_broken

    
    
    

# ==============================
# CONTINUOUS SYMMETRIES
# ==============================


class U1Modulation(StochasticModulation):
    """
    U(1) symmetry modulation layer.

    Lifts a one-component field to two components by rotating it through an
    angle drawn uniformly from [0, 2*pi) per sample. This enforces continuous
    U(1) rotation symmetry of the output distribution.
    """

    name = "u1_stochmod"

    def __init__(self, **kwargs):
        """Initialize the U(1) modulation layer; keyword arguments are ignored."""
        super(U1Modulation, self).__init__()
        logger.info("Initialized U1 Stochastic Modulation")

    def transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Map phi -> (phi*cos(theta), phi*sin(theta)) with theta ~ U[0, 2*pi).

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch, 1, N_t, N_x).

        Returns
        -------
        tuple
            (rotated_x, log_modprob): rotated_x has shape
            (batch, 2, N_t, N_x); log_modprob = -log(2*pi), a Python float.

        Raises
        ------
        ValueError
            If the input shape is not (batch, 1, N_t, N_x).
        """
        if not (x.dim() == 4 and x.shape[1] == 1):
            raise ValueError(f"Input tensor should have shape (batch, 1, N_t, N_x) but got {x.shape}")

        # One angle per sample, shaped (batch, 1, 1) to broadcast over (N_t, N_x).
        theta = torch.rand((x.shape[0], 1, 1), device=x.device) * 2 * math.pi

        field = x[:, 0]
        x = torch.stack([field * torch.cos(theta), field * torch.sin(theta)], dim=1)

        # Uniform density on [0, 2*pi).
        return x, -math.log(2 * math.pi)
    


class BrokenU1Modulation(StochasticModulation):
    """
    Broken U(1) symmetry modulation layer.

    Like U1Modulation, but the rotation angle is not uniform. A uniform
    variable on [0, 1) is pushed through a learnable rational quadratic
    spline and then scaled by 2*pi. The model can thus learn a non-uniform
    angle distribution when U(1) symmetry is broken.
    """

    name = "broken_u1_stochmod"

    def __init__(self, **kwargs):
        """
        Build the one-feature rational quadratic spline (8 bins) used to
        warp the uniform angle variable. Keyword arguments are ignored.
        """
        super(BrokenU1Modulation, self).__init__()

        self.spline = transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
            features=1,
            hidden_features=5,
            num_bins=8,
        )

        logger.info("Initialized Broken U1 Stochastic Modulation with RQS")

    def transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Rotate each sample by an angle drawn from the learned distribution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch, 1, N_t, N_x).

        Returns
        -------
        tuple
            (rotated_x, log_modprob): rotated_x has shape
            (batch, 2, N_t, N_x). log_modprob = -log|det spline| - log(2*pi)
            is the log-density of the drawn angle.

        Raises
        ------
        ValueError
            If the input shape is not (batch, 1, N_t, N_x).
        """
        if not (x.dim() == 4 and x.shape[1] == 1):
            raise ValueError(f"Input tensor should have shape (batch, 1, N_t, N_x) but got {x.shape}")

        # Warp u ~ U[0, 1) through the spline, then scale to an angle.
        u = torch.rand((x.shape[0], 1), device=x.device)
        warped, log_det_spline = self.spline(u)
        theta = warped.reshape(-1, 1, 1) * 2 * math.pi

        field = x[:, 0]
        x = torch.stack([field * torch.cos(theta), field * torch.sin(theta)], dim=1)

        # Change of variables u -> theta (see eq. 43 in paper).
        log_modprob = -log_det_spline - math.log(2 * math.pi)

        return x, log_modprob

