import torch
import torch.nn as nn
from omegaconf import ListConfig


class MLP(nn.Module):
    """
    A multi-layer perceptron (MLP) with ReLU activation function and optional layer normalization and dropout.

    Each hidden layer is ``Linear -> [LayerNorm] -> ReLU -> [Dropout]``; a final
    ``Linear`` maps to ``output_size``.

    Args:
        input_size: Number of input features (after flattening, see ``forward``).
        hidden_layer_size: Width of a single hidden layer (int), an iterable of
            widths (list, tuple, omegaconf ``ListConfig``, ...), or ``None`` / ``0``
            for no hidden layers (a single linear map).
        output_size: Number of output features.
        layer_norm: Insert ``LayerNorm`` after each hidden ``Linear``.
        drop_out: Insert ``Dropout`` after each hidden activation.
        drop_out_p: Dropout probability.
        bias: Whether the ``Linear`` layers use a bias term.
    """
    def __init__(self, input_size, hidden_layer_size, output_size, layer_norm=True, drop_out=True, drop_out_p=0.3, bias=True):
        super(MLP, self).__init__()

        if hidden_layer_size is None:
            hidden_sizes = []
        elif isinstance(hidden_layer_size, int):
            # A bare int means one hidden layer of that width. Without this
            # branch an int would be ignored and the model would silently
            # degenerate to a single linear layer.
            hidden_sizes = [hidden_layer_size] if hidden_layer_size else []
        else:
            # Any iterable of widths (list, tuple, ListConfig, ...).
            hidden_sizes = list(hidden_layer_size)

        layers = []
        in_features = input_size
        for out_features in hidden_sizes:
            layers.append(nn.Linear(in_features, out_features, bias=bias))
            if layer_norm:
                layers.append(nn.LayerNorm(out_features))
            layers.append(nn.ReLU())
            if drop_out:
                layers.append(nn.Dropout(drop_out_p))
            in_features = out_features

        layers.append(nn.Linear(in_features, output_size, bias=bias))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP; inputs with more than 2 dims are flattened from dim 1 onward."""
        if len(x.shape) > 2:
            x = torch.flatten(x, start_dim=1)
        return self.model(x)


class mu_X_Given_Z_Estimator(nn.Module):
    """Regression network for the conditional mean E[X | Z].

    The trunk ``self.shared`` is a stack of
    ``Linear -> [LayerNorm] -> ReLU -> [Dropout]`` stages, one per hidden width;
    ``self.output`` is a final linear read-out.

    Args:
        input_dim: Number of features in ``z``.
        hidden_size: Width of a single hidden layer (int) or an iterable of widths.
        output_size: Number of output features.
        layer_norm: Insert ``LayerNorm`` after each hidden ``Linear``.
        drop_out: Insert ``Dropout`` after each hidden activation.
        drop_out_p: Dropout probability.
    """
    def __init__(self, input_dim=19, hidden_size=128, output_size=1,
                 layer_norm=True, drop_out=True, drop_out_p=0.3):
        super().__init__()
        widths = [hidden_size] if isinstance(hidden_size, int) else list(hidden_size)

        trunk = []
        fan_in = input_dim
        for fan_out in widths:
            stage = [nn.Linear(fan_in, fan_out)]
            if layer_norm:
                stage.append(nn.LayerNorm(fan_out))
            stage.append(nn.ReLU())
            if drop_out:
                stage.append(nn.Dropout(drop_out_p))
            trunk.extend(stage)
            fan_in = fan_out

        self.shared = nn.Sequential(*trunk)
        self.output = nn.Linear(fan_in, output_size)

    def forward(self, z):
        """Return the predicted E[X | Z=z] for each row of ``z``."""
        features = self.shared(z)
        return self.output(features)


class GMMN_Estimator(nn.Module):
    """
    Generative Moment Matching Network for estimating P(X|Z).
    Maps noise eta ~ N(0, I) and conditioning variable z to samples from P(X|Z=z).

    The generator takes ``cat([eta, z], dim=-1)`` and applies hidden stages
    ``Linear -> [LayerNorm] -> ReLU -> [Dropout]`` followed by a final ``Linear``.

    Args:
        input_dim: Number of features in ``z``.
        hidden_size: Width of a single hidden layer (int) or an iterable of widths.
        output_size: Number of features in a generated ``x`` sample.
        drop_out: Insert ``Dropout`` after each hidden activation.
        drop_out_p: Dropout probability.
        noise_dim: Number of noise features in ``eta``.
        layer_norm: Insert ``LayerNorm`` after each hidden ``Linear``.
    """
    def __init__(self, input_dim=19, hidden_size=128, output_size=1,
                 drop_out=False, drop_out_p=0.3, noise_dim=16, layer_norm=True):
        super().__init__()
        self.input_dim = input_dim
        self.output_size = output_size
        self.noise_dim = noise_dim

        if isinstance(hidden_size, int):
            hidden_sizes = [hidden_size]
        else:
            hidden_sizes = list(hidden_size)

        layers = []
        prev_size = noise_dim + input_dim
        for h_size in hidden_sizes:
            layers.append(nn.Linear(prev_size, h_size))
            if layer_norm:
                layers.append(nn.LayerNorm(h_size))
            layers.append(nn.ReLU())
            if drop_out:
                # Honor the dropout options (they were previously accepted but
                # ignored). The default drop_out=False keeps the default
                # architecture and state_dict keys unchanged.
                layers.append(nn.Dropout(drop_out_p))
            prev_size = h_size

        layers.append(nn.Linear(prev_size, output_size))
        self.generator = nn.Sequential(*layers)

    def forward(self, eta, z):
        """Generate samples from noise ``eta`` conditioned on ``z`` (concatenated on the last dim)."""
        return self.generator(torch.cat([eta, z], dim=-1))

    def sample(self, z, n_samples=1):
        """
        Draw samples from P(X|Z=z).

        Returns a tensor of shape ``(batch, output_size)`` when ``n_samples == 1``,
        otherwise ``(batch, n_samples, output_size)`` (see ``sample_multiple``).
        """
        if n_samples == 1:
            eta = torch.randn(z.shape[0], self.noise_dim, device=z.device)
            return self.forward(eta, z)
        return self.sample_multiple(z, n_samples)

    def sample_multiple(self, z, M):
        """
        Draw ``M`` independent samples per row of ``z``.

        Each row of ``z`` is repeated ``M`` times consecutively, paired with fresh
        noise, and the result is reshaped to ``(batch, M, output_size)``.
        """
        batch_size = z.shape[0]
        z_repeated = z.repeat_interleave(M, dim=0)
        eta = torch.randn(batch_size * M, self.noise_dim, device=z.device)
        samples = self.forward(eta, z_repeated)
        return samples.view(batch_size, M, self.output_size)


class NormalizedGMMN_Estimator(nn.Module):
    """
    Wrapper around GMMN_Estimator that applies z-score normalization.

    This class handles normalization of inputs (Z) and outputs (X):
    - Training should happen in normalized space: normalize data with
      ``normalize_z`` / ``normalize_x`` and draw samples with
      ``sample_normalized`` / ``sample_multiple_normalized``.
    - ``sample`` / ``sample_multiple`` take raw z, normalize it, and return
      denormalized x.
    Until ``fit_normalization`` has been called, every normalize/denormalize
    helper is the identity.

    The statistics are stored as buffers, so a fitted model's ``state_dict``
    carries them. ``load_state_dict`` restores them even into a freshly
    constructed, unfitted model.

    Usage:
        model = NormalizedGMMN_Estimator(input_dim=19, output_size=1)
        model.fit_normalization(Z_data, X_data)  # Compute normalization stats

        # Training (inputs and outputs in normalized space)
        z_norm = model.normalize_z(z)
        x_fake = model.sample_multiple_normalized(z_norm, M=5)

        # Sampling (automatically normalizes/denormalizes)
        x_sample = model.sample(z, n_samples=1)
    """

    # Names of the normalization-statistic buffers set by fit_normalization.
    _STAT_NAMES = ('x_mean', 'x_std', 'z_mean', 'z_std')

    def __init__(self, input_dim=19, hidden_size=128, output_size=1,
                 drop_out=False, drop_out_p=0.3, noise_dim=16, layer_norm=True):
        super().__init__()

        # Core GMMN model
        self.gmmn = GMMN_Estimator(
            input_dim=input_dim,
            hidden_size=hidden_size,
            output_size=output_size,
            drop_out=drop_out,
            drop_out_p=drop_out_p,
            noise_dim=noise_dim,
            layer_norm=layer_norm
        )

        # Normalization statistics, populated by fit_normalization() or by
        # load_state_dict(). They are registered as None so an unfitted model's
        # state_dict contains no statistics.
        self.register_buffer('x_mean', None)
        self.register_buffer('x_std', None)
        self.register_buffer('z_mean', None)
        self.register_buffer('z_std', None)
        self._normalization_fitted = False

        # Expose these for compatibility
        self.input_dim = input_dim
        self.output_size = output_size
        self.noise_dim = noise_dim

    def fit_normalization(self, z_data, x_data, eps=1e-8):
        """
        Compute normalization statistics from data.

        Means and (unbiased) standard deviations are taken over dim 0 with
        ``keepdim=True``. They are moved to the model's device and computed
        without autograd tracking.

        Args:
            z_data: Conditioning variables (n, input_dim)
            x_data: Target variables (n, output_size)
            eps: Small constant added to each std to avoid division by zero

        Raises:
            ValueError: If ``z_data`` or ``x_data`` has fewer than 2 rows, where
                the unbiased standard deviation is undefined (NaN).
        """
        if z_data.shape[0] < 2 or x_data.shape[0] < 2:
            raise ValueError(
                "fit_normalization needs at least 2 rows in both z_data and x_data "
                f"(got {z_data.shape[0]} and {x_data.shape[0]})"
            )

        device = next(self.parameters()).device
        # Statistics are constants, not part of any computation graph.
        with torch.no_grad():
            z_data = z_data.to(device)
            x_data = x_data.to(device)

            self.x_mean = x_data.mean(dim=0, keepdim=True)
            self.x_std = x_data.std(dim=0, keepdim=True) + eps
            self.z_mean = z_data.mean(dim=0, keepdim=True)
            self.z_std = z_data.std(dim=0, keepdim=True) + eps

        self._normalization_fitted = True

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """
        Allow loading saved normalization statistics into an unfitted model.

        nn.Module's loader skips buffers whose current value is None. Without
        this override a fitted checkpoint would be rejected as having
        unexpected keys, or (with strict=False) its statistics would be
        silently dropped. Placeholders with the saved shape/dtype are created
        first, then the default loader copies the values in.
        """
        device = next(self.parameters()).device
        for name in self._STAT_NAMES:
            key = prefix + name
            if key in state_dict and getattr(self, name) is None:
                saved = state_dict[key]
                setattr(self, name, torch.empty(saved.shape, dtype=saved.dtype, device=device))

        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)

        # The flag is a plain attribute (not saved), so re-derive it from the buffers.
        self._normalization_fitted = all(
            getattr(self, name) is not None for name in self._STAT_NAMES
        )

    def normalize_x(self, x):
        """Normalize X using z-score (identity if not fitted)."""
        if not self._normalization_fitted:
            return x
        return (x - self.x_mean) / self.x_std

    def denormalize_x(self, x_normalized):
        """Denormalize X from z-score (identity if not fitted)."""
        if not self._normalization_fitted:
            return x_normalized
        return x_normalized * self.x_std + self.x_mean

    def normalize_z(self, z):
        """Normalize Z using z-score (identity if not fitted)."""
        if not self._normalization_fitted:
            return z
        return (z - self.z_mean) / self.z_std

    def denormalize_z(self, z_normalized):
        """Denormalize Z from z-score (identity if not fitted)."""
        if not self._normalization_fitted:
            return z_normalized
        return z_normalized * self.z_std + self.z_mean

    def forward(self, eta, z):
        """
        Forward pass through GMMN.
        Assumes z is already normalized if normalization is fitted.
        """
        return self.gmmn.forward(eta, z)

    def sample(self, z, n_samples=1):
        """
        Sample from P(X|Z).
        Automatically normalizes input z and denormalizes output x.

        Args:
            z: Conditioning variables (batch_size, input_dim)
            n_samples: Number of samples per z

        Returns:
            x_samples: Samples in original (denormalized) space; shape
                (batch_size, output_size) if n_samples == 1, else
                (batch_size, n_samples, output_size)
        """
        z_normalized = self.normalize_z(z)
        x_normalized = self.gmmn.sample(z_normalized, n_samples)
        return self.denormalize_x(x_normalized)

    def sample_multiple(self, z, M):
        """
        Sample M samples from P(X|Z) for each z.
        Automatically normalizes input z and denormalizes output x.

        Args:
            z: Conditioning variables (batch_size, input_dim)
            M: Number of samples per z

        Returns:
            x_samples: Samples in original space (batch_size, M, output_size)
        """
        z_normalized = self.normalize_z(z)
        x_normalized = self.gmmn.sample_multiple(z_normalized, M)
        return self.denormalize_x(x_normalized)

    def sample_normalized(self, z_normalized, n_samples=1):
        """
        Sample from normalized space (for training).
        Input and output are both in normalized space.

        Args:
            z_normalized: Pre-normalized conditioning variables
            n_samples: Number of samples

        Returns:
            x_normalized: Samples in normalized space
        """
        return self.gmmn.sample(z_normalized, n_samples)

    def sample_multiple_normalized(self, z_normalized, M):
        """
        Sample M samples in normalized space (for training).

        Args:
            z_normalized: Pre-normalized conditioning variables
            M: Number of samples per z

        Returns:
            x_normalized: Samples in normalized space (batch_size, M, output_size)
        """
        return self.gmmn.sample_multiple(z_normalized, M)


class MMDEMLP(MLP):
    """
    MMDEMLP is an extension of the base MLP (Multi-Layer Perceptron) for DAVT.

    This class implements a custom forward operation that compares two inputs
    and computes log(1 + tanh(g(x) - g(y))) for the DAVT e-value calculation.
    """

    # math.log(2.0); used by the numerically stable form of log(1 + tanh(d)).
    _LOG_TWO = 0.6931471805599453

    def __init__(self, input_size, hidden_layer_size, output_size, layer_norm=True,
                 drop_out=True, drop_out_p=0.3, bias=True, flatten=True):
        """
        Initialize the MMDEMLP model.

        Args:
        - input_size (int): Size of input layer
        - hidden_layer_size (int or list): Size(s) of hidden layer(s)
        - output_size (int): Size of output layer
        - layer_norm (bool): Whether to apply layer normalization
        - drop_out (bool): Whether to apply dropout
        - drop_out_p (float): Dropout probability
        - bias (bool): Whether to use bias in linear layers
        - flatten (bool): Whether to flatten input tensors with more than 2 dims.
          If False, such inputs are treated as a stack of samples along the last
          dim and g is averaged over them.
        """
        super(MMDEMLP, self).__init__(
            input_size, hidden_layer_size, output_size,
            layer_norm, drop_out, drop_out_p, bias
        )
        self.sigma = torch.nn.Tanh()
        self.flatten = flatten

    def forward(self, x, y) -> torch.Tensor:
        """
        Forward pass for the MMDEMLP model.

        Args:
        - x (torch.Tensor): First input tensor (z = [a, b, c])
        - y (torch.Tensor): Second input tensor (tau_z = [tilde_a, b, c])

        Returns:
        - torch.Tensor: log(1 + tanh(g(x) - g(y)))
        """
        if len(x.shape) > 2 or len(y.shape) > 2:
            if self.flatten:
                g_x = self.model(torch.flatten(x, start_dim=1))
                g_y = self.model(torch.flatten(y, start_dim=1))
            else:
                # Average g over the samples stacked along the last dimension.
                num_samples = x.shape[-1]
                g_x, g_y = 0, 0
                for i in range(num_samples):
                    g_x += self.model(torch.flatten(x[..., i], start_dim=1)) / num_samples
                    g_y += self.model(torch.flatten(y[..., i], start_dim=1)) / num_samples
        else:
            g_x = self.model(x)
            g_y = self.model(y)

        diff = g_x - g_y
        if isinstance(self.sigma, nn.Tanh):
            # 1 + tanh(d) == 2 * sigmoid(2d), so log(1 + tanh(d)) ==
            # log(2) + logsigmoid(2d). The direct form underflows to log(0) = -inf
            # (with NaN gradients) once tanh saturates at -1 for large negative d.
            return self._LOG_TWO + nn.functional.logsigmoid(2.0 * diff)
        # self.sigma was replaced by the caller: keep the literal formula.
        return torch.log(1 + self.sigma(diff))
