import torch
import torch.nn as nn
import math
import omegaconf
import nflows.transforms as transforms

import logging
logger = logging.getLogger("symmetry")

# Registry that maps a configuration name to its stochastic-modulation class.
STOCHMODS = {}


def register_stochmod(name):
    """
    Args:
        name: key under which the decorated class is stored in STOCHMODS
              (this file uses string names such as "z2_stochmod")

    Returns:
        A class decorator that writes the class into STOCHMODS[name] and
        returns the class unchanged.

    Description:
        Registering a second class under an existing name overwrites the
        earlier entry.
    """
    def _store(cls):
        STOCHMODS[name] = cls
        return cls

    return _store



class StochasticModulation(nn.Module):
    """
    Description:
        Base class for stochastic modulation layers. Subclasses implement
        `transform` (and, where available, `inverse_transform`); both hooks
        return a pair (x, log_modprob). `forward` and `reverse` call the
        matching hook and hand back the log-probability with its sign flipped.

        Reason for the sign flip: in flow.sample_with_logpprob the log_prob is
        accumulated as
            log_prob = prior_log_prob - log_det1 - log_det2 - ... - log_detN
        so every layer's second return value is subtracted. Returning
        -log_modprob compensates for that minus sign, giving
            log_prob = prior_log_prob + log_modprob - log_det1 - ... - log_detN
        (see eq. (12) in the paper).
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but not used by the base class.
        super().__init__()

    def forward(self, x):
        """Apply `transform` and return (x, -log_modprob)."""
        out, log_p = self.transform(x)
        return out, -log_p

    def reverse(self, x):
        """Apply `inverse_transform` and return (x, -log_modprob)."""
        out, log_p = self.inverse_transform(x)
        return out, -log_p

    def transform(self, x):
        """Hook to be overridden; the base implementation always raises."""
        raise NotImplementedError

    def inverse_transform(self, x):
        """Hook to be overridden; the base implementation always raises."""
        raise NotImplementedError



# ==============================
# DISCRETE SYMMETRIES
# ==============================

@register_stochmod("z2_stochmod")
class Z2Modulation(StochasticModulation):
    def __init__(self, **kwargs):
        """
        Description:
            This class implements a Z2 symmetry where the sign of the configuration is flipped with 50% probability.
            Keyword arguments are accepted but not used.
        """
        super().__init__()

        logger.info("Initialized Z2 Stochastic Modulation")

    def transform(self, x):
        """
        Args:
            x (torch.Tensor): input tensor of shape (Batch, N_t, N_x)

        Returns:
            x (torch.Tensor): output tensor; each batch element is multiplied by +1 or -1
            log_modprob (float): log of modulation probability p_S, always log(0.5)
        """
        assert x.dim() == 3, f"dim of x should be 3, got {x.dim()}"

        batch_size = x.shape[0]

        # One Bernoulli draw per batch element; the probability tensor already
        # lives on x.device, so the samples do as well.
        bernoulli = torch.distributions.Bernoulli(torch.tensor(0.5, device=x.device, dtype=x.dtype))
        u = bernoulli.sample((batch_size, 1, 1))  # either 0 or 1 with 50% probability
        random_sign = -2 * u + 1  # either 1 (u=0) or -1 (u=1)

        # The (Batch, 1, 1) sign broadcasts over the whole configuration.
        x = x * random_sign

        # Both outcomes are equally likely, so the log-probability is constant.
        log_modprob = math.log(0.5)

        return x, log_modprob
    


@register_stochmod("brokenz2_stochmod")
class BrokenZ2Modulation(StochasticModulation):
    def __init__(self, flip_direction=None, init_breaking=math.log(0.5), **kwargs):
        """
        Args:
            flip_direction (list, ListConfig or None): indices along the last axis (N_x) whose sign is flipped,
                use [] or None to flip all entries
            init_breaking (float): initial value of the breaking parameter, should be smaller or equal to 0

        Description:
            Flips the sign of the input tensor with a probability depending on the breaking parameter.
            The entries that should be flipped are specified in the flip_direction list.
            E.g., if flip_direction = [0], the tensor x = [1, 2, 3] -> [-1, 2, 3] with a probability of e^b
            The breaking parameter is a learnable parameter that is initialized to init_breaking.
            The probability of flipping the sign is given by p = e^b where b is the breaking parameter.
            Additional keyword arguments are accepted but not used.
        """
        assert init_breaking <= 0, f"init_breaking should be smaller or equal to 0, got {init_breaking}"
        assert flip_direction is None or isinstance(flip_direction, (list, omegaconf.listconfig.ListConfig)), f"flip_direction should be a list or None, got {flip_direction}"

        super().__init__()
        self.breaking = nn.Parameter(torch.tensor(init_breaking, dtype=torch.float32), requires_grad=True)
        self.flip_direction = flip_direction

        logger.info(f"Initialized Broken Z2 Stochastic Modulation with breaking parameter {init_breaking:.3f} and flip direction {flip_direction}")

    def transform(self, x):
        """
        Args:
            x (torch.Tensor): input tensor of shape (Batch, N_t, N_x)

        Returns:
            x (torch.Tensor): output tensor with flipped signs
            log_modprob (torch.Tensor): log of modulation probability p_S, one value per batch element;
                also stored on the instance as self.log_modprob
        """
        # Check the rank first so that a wrongly shaped input reports this
        # message instead of failing on x.shape[2] below.
        assert x.dim() == 3, f"dim of input tensor x should be 3, got {x.dim()}"

        flip_all = self.flip_direction is None or len(self.flip_direction) == 0
        if not flip_all:
            assert max(self.flip_direction) < x.shape[2], f"dim {self.flip_direction} is out of bounds for the N_x = {x.shape[2]}"

        # the breaking parameter should be smaller or equal to 0, otherwise
        # e^b would exceed 1 and no longer be a probability
        self.breaking.data.clamp_max_(0)

        # sample from bernoulli distribution to randomly flip the sign
        dist = torch.distributions.Bernoulli(torch.exp(self.breaking))
        u = dist.sample((x.shape[0],)).to(x.device)  # either 0 or 1 with probability p = e^b
        random_sign = -2 * u + 1  # either 1 (u=0) or -1 (u=1)

        # create random sign tensor
        if flip_all:
            # one sign per batch element, broadcast over the whole configuration
            random_sign_full = random_sign.unsqueeze(1).unsqueeze(2)
        else:
            # +1 everywhere except the selected last-axis entries, which all
            # share the batch element's sampled sign
            random_sign_full = torch.ones((x.shape[0], 1, x.shape[2]), device=x.device)
            random_sign_full[:, 0, list(self.flip_direction)] = random_sign.unsqueeze(1)

        # apply signs
        x = x * random_sign_full

        # compute log of modulation probability (depends on the learnable breaking parameter)
        self.log_modprob = dist.log_prob(u)

        return x, self.log_modprob



@register_stochmod("zn_stochmod")
class ZNModulation(StochasticModulation):
    def __init__(self, n=8, **kwargs):
        """
        Args:
            n (int): order of the Z_n symmetry, e.g. Z_8 has 8 possible rotations

        Description:
            Implements a Z_n symmetry, i.e. a discrete rotation symmetry in a 2D plane
            with n equally likely rotation angles 2*pi*k/n, k = 0, ..., n-1.
            Additional keyword arguments are accepted but not used.
        """
        super().__init__()
        self.n = n

        logger.info(f"Initialized ZN Stochastic Modulation with n={n}")

    def transform(self, x):
        """
        Args:
            x (torch.Tensor): input tensor of shape (Batch, N_t, 2)
        Returns:
            x (torch.Tensor): input rotated in the plane spanned by the two last-axis components
            log_modprob (float): log of modulation probability p_S, always log(1/n)
        """
        assert x.shape[2] == 2, f"3rd dim of x should be 2, i.e., (batch, nt, 2), but got {x.shape}"

        # draw one rotation index per batch element
        batch = x.shape[0]
        step = torch.randint(0, self.n, (batch, 1), device=x.device)
        theta = 2 * math.pi * step / self.n

        # evaluate the trigonometric factors once and reuse them
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)

        # standard 2D rotation applied to the two components along axis 2
        first = x.select(2, 0)
        second = x.select(2, 1)
        rotated = torch.stack(
            (first * cos_t - second * sin_t, first * sin_t + second * cos_t),
            dim=2,
        )

        # every rotation index is equally likely
        return rotated, math.log(1 / self.n)
    
    
@register_stochmod("broken_zn_stochmod")
class BrokenZNModulation(StochasticModulation):
    def __init__(self, n=8, **kwargs):
        """
        Args:
            n (int): order of the Z_n symmetry, e.g. Z_8 has 8 possible rotations

        Description:
            Implements a broken Z_n symmetry: a discrete rotation in a 2D plane with n possible
            angles whose probabilities are learnable. The logits are stored in `self.breaking`
            and start out equal (-log(n) each), i.e. as the unbroken, uniform case.
            Additional keyword arguments are accepted but not used.
        """
        super().__init__()
        self.n = n

        uniform_logits = torch.ones(n, dtype=torch.float32) * -math.log(n)
        self.breaking = nn.Parameter(uniform_logits, requires_grad=True)

        logger.info(f"Initialized Broken ZN Stochastic Modulation with n={n}")

    def transform(self, x):
        """
        Args:
            x (torch.Tensor): input tensor of shape (Batch, N_t, 2)
        Returns:
            x (torch.Tensor): input rotated in the plane spanned by the two last-axis components
            log_modprob (torch.Tensor): log of modulation probability p_S, one value per batch element;
                also stored on the instance as self.log_modprob
        """
        assert x.shape[2] == 2, f"3rd dim of x should be 2, i.e., (batch, nt, 2), but got {x.shape}"

        # draw one rotation index per batch element from the learnable categorical
        rotation_dist = torch.distributions.Categorical(logits=self.breaking)
        step = rotation_dist.sample((x.shape[0], 1)).to(x.device)
        theta = 2 * math.pi * step / self.n

        # evaluate the trigonometric factors once and reuse them
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)

        # standard 2D rotation applied to the two components along axis 2
        first = x.select(2, 0)
        second = x.select(2, 1)
        rotated = torch.stack(
            (first * cos_t - second * sin_t, first * sin_t + second * cos_t),
            dim=2,
        )

        # log-probability of the sampled index, flattened from (Batch, 1) to (Batch,)
        self.log_modprob = rotation_dist.log_prob(step).squeeze(1)

        return rotated, self.log_modprob
    
    

# ==============================
# CONTINUOUS SYMMETRIES
# ==============================

@register_stochmod("u1_stochmod")
class U1Modulation(StochasticModulation):
    def __init__(self, **kwargs):
        """
        Description:
            Implements a U1 symmetry, i.e. a continuous rotation symmetry in a 2D plane,
            with the angle drawn uniformly from [0, 2*pi).
            Keyword arguments are accepted but not used.
        """
        super().__init__()

        logger.info("Initialized U1 Stochastic Modulation")

    def transform(self, x):
        """
        Args:
            x (torch.Tensor): input tensor of shape (Batch, 1, N_t, N_x)
        Returns:
            x (torch.Tensor): tensor of shape (Batch, 2, N_t, N_x) holding the single input channel
                multiplied by cos(angle) and by sin(angle)
            log_modprob (float): log of modulation probability p_S, always -log(2*pi)
        """
        assert x.dim() == 4 and x.shape[1] == 1, f"x should be of shape (Batch, 1, N_t, N_x), got {x.shape}"

        # one uniformly distributed angle per batch element
        phase = torch.rand((x.shape[0], 1, 1), device=x.device) * 2 * math.pi

        # turn the single channel into its two rotated components
        field = x[:, 0]
        rotated = torch.stack((field * torch.cos(phase), field * torch.sin(phase)), dim=1)

        # uniform density on an interval of length 2*pi
        return rotated, -math.log(2 * math.pi)
    


@register_stochmod("brokenu1_stochmod")
class BrokenU1Modulation(StochasticModulation):
    def __init__(self, lat_shape, **kwargs):
        """
        Args:
            lat_shape: required by the signature but not used by this class

        Description:
            Implements a broken U1 symmetry: a continuous rotation in a 2D plane in which
            some angles are preferred over others. The preference is learned by passing a
            uniform sample through a rational quadratic spline (RQS) before scaling it to an angle.
            Additional keyword arguments are accepted but not used.
        """
        super().__init__()

        # Rational Quadratic Spline (RQS) acting on a single scalar feature
        self.spline = transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
            features=1, hidden_features=5, num_bins=8
        )

        logger.info("Initialized Broken U1 Stochastic Modulation with RQS")

    def transform(self, x):
        """
        Args:
            x (torch.Tensor): input tensor of shape (Batch, 1, N_t, N_x)
        Returns:
            x (torch.Tensor): tensor of shape (Batch, 2, N_t, N_x) holding the single input channel
                multiplied by cos(angle) and by sin(angle)
            log_modprob (torch.Tensor): log of modulation probability p_S;
                also stored on the instance as self.log_modprob
        """
        assert x.dim() == 4 and x.shape[1] == 1, f"x should be of shape (Batch, 1, N_t, N_x), got {x.shape}"

        # one uniform draw per batch element, warped by the learnable spline
        base = torch.rand((x.shape[0], 1), device=x.device)
        warped, log_det_spline = self.spline(base)
        phase = warped.reshape(-1, 1, 1) * 2 * math.pi

        # turn the single channel into its two rotated components
        field = x[:, 0]
        rotated = torch.stack((field * torch.cos(phase), field * torch.sin(phase)), dim=1)

        # density of the angle: spline Jacobian plus the 2*pi scaling, see eq. (43)
        self.log_modprob = -log_det_spline - math.log(2 * math.pi)

        return rotated, self.log_modprob

