import torch.nn as nn
from pyro.distributions.util import copy_docs_from
from pyro.distributions.torch_transform import TransformModule
from torch.distributions import Transform, constraints
import torch.distributions as tdist
import torch.nn.functional as F
from torch import nn
import torch
import math


class _MaskedLinear(nn.Linear):
    """Linear layer whose weight matrix is multiplied elementwise by a fixed 0/1 mask."""

    def __init__(self, in_features, out_features, mask):
        super().__init__(in_features, out_features)
        # The mask is fully determined by the layer sizes, so it is kept out of the
        # state_dict (persistent=False); the state_dict keys match a plain nn.Linear.
        self.register_buffer('weight_mask', mask, persistent=False)

    def forward(self, x):
        return F.linear(x, self.weight * self.weight_mask, self.bias)


class MaskedAutoregressiveTransform1d(nn.Module):
    """
    A masked autoregressive affine transform using MADE-style weight masking.

    Computes ``y = x * scale + shift`` where ``shift[..., i]`` and ``scale[..., i]``
    depend only on ``x[..., :i]``. The autoregressive structure is enforced by
    masking the weights of every linear layer (MADE degree assignment), so the
    Jacobian dy/dx is lower triangular with diagonal ``scale`` and
    log|det J| = sum(log(scale)).

    This replaces the PyBlaze implementation.

    Parameters
    ----------
    dim: int
        Dimension of the input.
    *hidden_dims: int
        Widths of the hidden layers of the conditioner network.
    constrain_scale: bool, default: True
        If True, ``scale = sigmoid(s) + 0.5`` (bounded in (0.5, 1.5));
        otherwise ``scale = exp(s)``.
    activation: torch.nn.Module, optional
        Activation placed after each hidden layer. Defaults to a new
        ``nn.LeakyReLU()`` per instance.
    """

    def __init__(self, dim, *hidden_dims, constrain_scale=True, activation=None):
        super().__init__()
        if activation is None:
            # Build a fresh module per instance rather than sharing one default-argument object.
            activation = nn.LeakyReLU()
        self.dim = dim
        self.constrain_scale = constrain_scale
        self.activation = activation

        # MADE degrees: input i (0-based) has degree i + 1. A hidden unit of degree k
        # may see units of degree <= k; output i (degree i + 1) may only see units of
        # degree < i + 1, i.e. it ends up depending on x[:i] only.
        input_degrees = torch.arange(1, dim + 1)
        # Hidden degrees cycle through [1, dim - 1]. For dim == 1 they are all 0, so
        # hidden units see no inputs and the single output is a learned constant.
        low = min(dim - 1, 1)
        span = max(dim - 1, 1)

        layers = []
        in_dim = dim
        prev_degrees = input_degrees
        for h_dim in hidden_dims:
            degrees = torch.arange(h_dim) % span + low
            mask = (degrees.unsqueeze(1) >= prev_degrees.unsqueeze(0)).float()
            layers.append(_MaskedLinear(in_dim, h_dim, mask))
            layers.append(self.activation)
            in_dim = h_dim
            prev_degrees = degrees

        # Output layer produces shift (first half) and scale (second half); both halves
        # carry the input degrees, and the strict ">" keeps output i off x[i:].
        output_degrees = input_degrees.repeat(2)
        output_mask = (output_degrees.unsqueeze(1) > prev_degrees.unsqueeze(0)).float()
        layers.append(_MaskedLinear(in_dim, 2 * dim, output_mask))

        self.net = nn.Sequential(*layers)
        # Output-on-input dependency pattern (strictly lower triangular). Kept as a
        # registered buffer so the module's state_dict keys are unchanged.
        self.register_buffer('mask', self._create_mask())

    def _create_mask(self):
        """Return the (dim, dim) mask with ``mask[i, j] = 1`` iff ``j < i``."""
        return torch.tril(torch.ones(self.dim, self.dim), diagonal=-1)

    def forward(self, x):
        """
        Apply the transform.

        Parameters
        ----------
        x: torch.Tensor
            Input whose last axis has size ``dim`` (e.g. ``[batch_size, dim]``).

        Returns
        -------
        tuple(torch.Tensor, torch.Tensor)
            The transformed input (same shape as ``x``) and the log-determinant
            of the Jacobian, summed over the last axis.
        """
        shift, raw_scale = torch.chunk(self.net(x), 2, dim=-1)

        if self.constrain_scale:
            scale = torch.sigmoid(raw_scale) + 0.5
            log_scale = torch.log(scale)
        else:
            scale = torch.exp(raw_scale)
            # log(exp(s)) == s; use s directly to avoid overflow/underflow in the log.
            log_scale = raw_scale

        transformed = x * scale + shift
        # Triangular Jacobian with diagonal `scale`.
        log_det = torch.sum(log_scale, dim=-1)
        return transformed, log_det


class BatchNormTransform1d(nn.Module):
    """
    A batch normalization transform that also reports its log-determinant.

    Wraps ``nn.BatchNorm1d``, which maps each feature j as
    ``y_j = (x_j - mean_j) / sqrt(var_j + eps) * weight_j + bias_j``.
    Treating the normalization statistics as constants, the per-sample
    log|det J| is ``sum_j [log|weight_j| - 0.5 * log(var_j + eps)]``. ``var`` is
    the biased batch variance when batch statistics are used (training mode or
    no running statistics) and the running variance otherwise.

    This replaces the PyBlaze implementation.
    """

    def __init__(self, dim, momentum=0.1):
        super().__init__()
        self.bn = nn.BatchNorm1d(dim, momentum=momentum)

    def forward(self, x):
        """
        Normalize ``x`` and return ``(y, log_det)``.

        Parameters
        ----------
        x: torch.Tensor
            Input of shape ``[batch_size, dim]`` (statistics are taken over axis 0).

        Returns
        -------
        tuple(torch.Tensor, torch.Tensor)
            The normalized input and the log-determinant, shape ``[batch_size]``.
        """
        y = self.bn(x)
        bn = self.bn

        # Mirror which variance BatchNorm1d itself normalized with.
        if bn.training or bn.running_var is None:
            var = x.var(dim=0, unbiased=False)
        else:
            var = bn.running_var

        log_scale = -0.5 * torch.log(var + bn.eps)
        if bn.weight is not None:  # None when affine=False
            log_scale = log_scale + torch.log(torch.abs(bn.weight))

        # The Jacobian is the same diagonal for every sample in the batch.
        log_det = log_scale.sum().expand(x.shape[0])
        return y, log_det


class RadialTransform(nn.Module):
    """
    A radial flow transformation (Rezende & Mohamed, 2015).

    Maps ``x -> x + beta * h(r) * (x - z_0)`` with ``r = ||x - z_0||`` and
    ``h(r) = 1 / (alpha + r)``. The raw parameters are reparameterized as
    ``alpha = softplus(self.alpha) > 0`` and
    ``beta = -alpha + softplus(self.beta) > -alpha``. This keeps ``alpha + r > 0``
    and satisfies the invertibility condition ``beta >= -alpha``.

    This replaces the PyBlaze implementation.
    """

    def __init__(self, dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.randn(1))
        self.beta = nn.Parameter(torch.randn(1))
        self.z_0 = nn.Parameter(torch.randn(dim))
        self.dim = dim

    def forward(self, x):
        """
        Apply the radial flow.

        Parameters
        ----------
        x: torch.Tensor
            Input whose last axis has size ``dim`` (e.g. ``[batch_size, dim]``).

        Returns
        -------
        tuple(torch.Tensor, torch.Tensor)
            The transformed input and the per-sample log-determinant
            (input shape without the last axis).
        """
        diff = x - self.z_0
        r = torch.norm(diff, p=2, dim=-1, keepdim=True)

        # alpha must be positive (so h stays finite) and beta >= -alpha (so the map
        # is invertible and both log terms below have positive arguments).
        alpha = F.softplus(self.alpha)
        beta = -alpha + F.softplus(self.beta)

        h = 1 / (alpha + r)
        beta_h = beta * h
        output = x + beta_h * diff

        # log|det J| = (d - 1) * log(1 + beta*h) + log(1 + beta*h + beta*h'(r)*r),
        # with h'(r) = -h**2; the last argument simplifies to 1 + alpha*beta*h**2.
        log_det = (self.dim - 1) * torch.log1p(beta_h) + torch.log1p(alpha * beta * h ** 2)
        return output, log_det.squeeze(-1)


class NormalizingFlow(nn.Module):
    """
    A normalizing flow consisting of a given number of predefined transform layer types.

    Each transform is a module mapping ``x -> (y, log_det)``. The flow applies
    them in order and accumulates the per-sample log-determinants.

    This replaces the PyBlaze implementation.
    """

    def __init__(self, transforms):
        """
        Initializes a new normalizing flow with a list of transforms.

        Parameters
        ----------
        transforms: list
            List of transform modules to apply in sequence. Each must return an
            ``(output, log_det)`` tuple.
        """
        super().__init__()
        self.transforms = nn.ModuleList(transforms)

    def forward(self, x):
        """
        Applies all transforms in sequence and returns the result along with
        the log determinant of the Jacobian.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor; axis 0 is the batch axis.

        Returns
        -------
        tuple(torch.Tensor, torch.Tensor)
            The transformed input and the summed log determinant, shape
            ``[batch_size]`` with the same dtype as ``x``.
        """
        log_det_sum = torch.zeros(x.shape[0], device=x.device, dtype=x.dtype)

        for transform in self.transforms:
            x, log_det = transform(x)
            if log_det.dim() > 1:
                # Reduce every non-batch axis so log_det has shape [batch_size].
                log_det = log_det.reshape(log_det.shape[0], -1).sum(dim=1)
            log_det_sum = log_det_sum + log_det

        return x, log_det_sum

    @staticmethod
    def create(dim, maf_layers=0, radial_layers=0, flow_size=0.5, n_hidden=2,
               batch_norm=False, activation=None):
        """
        Factory method to create a normalizing flow with specified parameters.

        MAF layers (each optionally followed by batch normalization) come first,
        then the radial layers.

        Parameters
        ----------
        dim: int
            The dimension of the input.
        maf_layers: int, default: 0
            The number of MAF layers.
        radial_layers: int, default: 0
            The number of radial transform layers.
        flow_size: float, default: 0.5
            Width of each MAF hidden layer as a multiple of ``dim``; the result is
            truncated to an int and clamped to at least 1.
        n_hidden: int, default: 2
            The number of hidden layers for MAF transforms.
        batch_norm: bool, default: False
            Whether to apply batch normalization after every MAF layer. When set,
            the MAF scale is left unconstrained (``exp``).
        activation: torch.nn.Module, optional
            The activation function to use for MAF layers. Defaults to a new
            ``nn.LeakyReLU()`` per call.

        Returns
        -------
        NormalizingFlow
            The assembled flow.
        """
        if activation is None:
            # Avoid a module instance shared across every call via the default argument.
            activation = nn.LeakyReLU()
        hidden_size = max(1, int(dim * flow_size))
        transforms = []

        for _ in range(maf_layers):
            transforms.append(MaskedAutoregressiveTransform1d(
                dim, *([hidden_size] * n_hidden), constrain_scale=not batch_norm,
                activation=activation,
            ))
            if batch_norm:
                transforms.append(BatchNormTransform1d(dim, momentum=0.5))

        transforms.extend(RadialTransform(dim) for _ in range(radial_layers))

        return NormalizingFlow(transforms)


class Density(nn.Module):
    """Class for representing data of different classes as a collection of normalizing flows.

    Each mixture element (e.g. class) c has its own flow f_c, and
    ``log p(z | c) = log N(f_c(z); 0, I) + log|det J_{f_c}(z)|``.
    """

    def __init__(self, dim_latent, num_mixture_elements, radial_layers=10, maf_layers=0, gaussian_layers=0, use_batched_flow=False):
        """Initializes a new flow-based density estimator

        Args:
            dim_latent (int): dimension of latent space
            num_mixture_elements (int): number of mixture elements (e.g. classes)
            radial_layers (int, optional): number of radial transformation in flow. Defaults to 10.
            maf_layers (int, optional): number of masked autoregressive flows. Defaults to 0.
            gaussian_layers (int, optional): accepted for API compatibility but not used by
                this implementation. Defaults to 0.
            use_batched_flow (bool, optional): accepted for API compatibility but not used by
                this implementation. Defaults to False.
        """
        super().__init__()

        self.dim_latent = dim_latent
        self.num_mixture_elements = num_mixture_elements

        # One independent flow per mixture element, each acting on dim_latent inputs.
        self.flows = nn.ModuleList([
            NormalizingFlow.create(
                dim_latent, maf_layers=maf_layers, radial_layers=radial_layers,
                flow_size=1.0, batch_norm=False
            )
            for _ in range(num_mixture_elements)
        ])

        # Standard-normal base distribution parameters; buffers follow the module's device.
        self.register_buffer('loc', torch.zeros(dim_latent))
        self.register_buffer('scale', torch.ones(dim_latent))

    def evidence(self, a, b, epsilon=1e-10):
        """Numerically stable ``log(sum(exp(a + b)) + ...)`` over the last axis.

        With ``a = log p(z | c)`` and ``b = log p(c)`` this is the log marginal
        likelihood ``log p(z) = log sum_c p(z, c)``.

        Args:
            a (Tensor): first log-term, e.g. [batch_size, num_mixture_elements]
            b (Tensor): second log-term, broadcastable with ``a``
            epsilon (float, optional): added inside the log for numerical stability. Defaults to 1e-10.

        Returns:
            Tensor: evidence, ``a + b`` reduced over the last axis
        """
        log_probs = a + b

        # Shift by the per-row maximum so the largest exponent is exp(0) = 1.
        max_log_probs = torch.max(log_probs, -1, keepdim=True)[0]
        # A row that is entirely -inf would otherwise give (-inf) - (-inf) = NaN.
        max_log_probs = torch.where(
            torch.isfinite(max_log_probs), max_log_probs, torch.zeros_like(max_log_probs)
        )
        stabled_log_probs = log_probs - max_log_probs
        # squeeze(-1) (not squeeze()) keeps the batch axis when batch_size == 1.
        return torch.log(torch.sum(torch.exp(stabled_log_probs), -1) + epsilon) + max_log_probs.squeeze(-1)

    def forward(self, inputs, labels=None):
        """Compute ``log p(inputs | c)`` for every mixture element c.

        Args:
            inputs (Tensor): latent codes, [batch_size, dim_latent]
            labels (Tensor, optional): int64 mixture indices, [batch_size]

        Returns:
            Tensor [batch_size, num_mixture_elements] of log-probabilities, or, when
            ``labels`` is given, ``(log_prob, (ce, nll))`` where ``ce`` is
            ``-log p(inputs | labels)`` and ``nll`` is ``-evidence`` with a zero
            log-prior; both have shape [batch_size].
        """
        batch_size = inputs.shape[0]

        # Instantiate the base distribution from the buffers (already on the right device).
        base_dist = tdist.Normal(self.loc, self.scale)

        log_prob = torch.zeros(
            (batch_size, self.num_mixture_elements), device=inputs.device, dtype=inputs.dtype
        )
        for c, flow in enumerate(self.flows):
            # log p(z | c) = log N(f_c(z)) + log|det J_{f_c}(z)|
            recon, log_det = flow(inputs)
            log_prob[:, c] = base_dist.log_prob(recon).sum(dim=1) + log_det

        if labels is not None:
            # squeeze(1) keeps shape [batch_size] even when batch_size == 1.
            ce = -torch.gather(log_prob, 1, labels.unsqueeze(1)).squeeze(1)
            # Zero log-prior for every element (as in the original formulation).
            nll = -self.evidence(log_prob, torch.zeros_like(log_prob))
            return log_prob, (ce, nll)

        return log_prob
