"""
MLPGaussian: Variance-aware MLP for uncertainty propagation.

Supports two propagation modes:
1. SDP (Stable Distribution Propagation) - fast, optimal for small variance [DEFAULT]
2. Moment Matching - accurate, better for large variance

Variance Propagation:
- Linear: diagonal of the quadratic form W·Cov[x]·W^T; with independent inputs, Var[y_i] = Σ_j W_ij²·Var[x_j] (Kingma et al. 2015)
- ReLU (SDP): Optimal piecewise rule (fast)
- ReLU (Moment): Exact truncated Gaussian moments (Hernández-Lobato & Adams 2015)
- GELU/SiLU (SDP): Local linearization via derivatives
- GELU/SiLU (Moment): Approximate moment matching

References:
- Hernández-Lobato & Adams (2015): "Probabilistic Backpropagation for Scalable Learning of BNNs"
- Kingma et al. (2015): "Variational Dropout and the Local Reparameterization Trick"
- SDP approach: Local linearization (standard in uncertainty propagation)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, List
import math


class MLPGaussian(nn.Module):
    """
    MLP that propagates Gaussian uncertainty through layers.

    Supports two propagation modes:
    - 'sdp': Stable Distribution Propagation (fast, optimal for small variance) [DEFAULT]
    - 'moment_matching': Analytical moments (exact for ReLU; exact mean and
      linearized variance for GELU; heuristic approximation for SiLU)

    Variances are tracked per element (diagonal covariance); correlations
    between features are not propagated.

    API:
        forward(x_mean, x_var=None) -> (out_mean, out_var)

    If x_var=None, behaves like standard MLP (backward compatible).
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        hidden_units: Optional[List[int]] = None,
        activation: str = 'gelu',
        dropout: float = 0.0,
        bias: bool = True,
        propagation_mode: str = 'sdp'
    ):
        """
        Initialize variance-aware MLP.

        Args:
            dim_in: Input dimension
            dim_out: Output dimension
            hidden_units: Hidden layer sizes (None or empty = direct projection).
                The sequence is copied, so later changes by the caller do not
                affect this module.
            activation: Activation function ('relu', 'gelu', 'silu'), matched
                case-insensitively. Any other value applies no nonlinearity.
            dropout: Dropout rate. In training mode one Bernoulli keep-mask is
                applied to both the mean and the variance of each hidden layer.
            bias: Use bias in linear layers
            propagation_mode: 'sdp' (fast, default) or 'moment_matching' (accurate)

        Raises:
            ValueError: If propagation_mode is not 'sdp' or 'moment_matching'.
        """
        super().__init__()

        self.dim_in = dim_in
        self.dim_out = dim_out
        # Copy so the stored attribute cannot drift from the layers built below.
        self.hidden_units = list(hidden_units) if hidden_units is not None else []
        self.activation = activation.lower()
        self.dropout_rate = dropout
        self.bias = bias
        self.propagation_mode = propagation_mode.lower()

        # Validate propagation mode
        if self.propagation_mode not in ['sdp', 'moment_matching']:
            raise ValueError(f"propagation_mode must be 'sdp' or 'moment_matching', got '{propagation_mode}'")

        # One Linear per consecutive pair of dimensions.
        layer_dims = [dim_in] + self.hidden_units + [dim_out]
        self.layers = nn.ModuleList([
            nn.Linear(d_in, d_out, bias=bias)
            for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])
        ])

        # Only built for p > 0; used as-is on the mean-only (x_var=None) path.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def _propagate_linear(
        self,
        mean: torch.Tensor,
        var: Optional[torch.Tensor],
        weight: torch.Tensor,
        bias: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Propagate Gaussian through linear layer: y = Wx + b

        Mean: E[y] = W·E[x] + b
        Variance: diagonal of W·Cov[x]·W^T with Cov[x] = diag(var)

        Args:
            mean: Input mean (..., d_in)
            var: Input variance (..., d_in) [diagonal covariance] or None
            weight: Linear weight (d_out, d_in)
            bias: Linear bias (d_out,) or None

        Returns:
            (output_mean, output_variance); variance is None when var is None.
        """
        out_mean = F.linear(mean, weight, bias)

        if var is None:
            return out_mean, None

        # Var[y_i] = sum_j W_ij^2 · var_j  (bias adds no variance)
        out_var = F.linear(var, weight.pow(2), None)
        return out_mean, out_var

    # ==================================================================
    # SDP: Activation derivatives for local linearization
    # ==================================================================
    @staticmethod
    def _relu_derivative(x: torch.Tensor) -> torch.Tensor:
        """d/dx ReLU(x) = 1 for x > 0, 0 otherwise"""
        return (x > 0).to(x.dtype)

    @staticmethod
    def _silu_derivative(x: torch.Tensor) -> torch.Tensor:
        """d/dx SiLU(x) = sigmoid(x) + x·sigmoid(x)·(1 - sigmoid(x))"""
        sig = torch.sigmoid(x)
        return sig + x * sig * (1.0 - sig)

    @staticmethod
    def _gelu_derivative(x: torch.Tensor) -> torch.Tensor:
        """
        Exact derivative of the erf-based GELU(x) = x·Φ(x):
        GELU'(x) = Φ(x) + x·φ(x), with Φ/φ the standard normal CDF/PDF.
        """
        sqrt_2 = math.sqrt(2.0)
        sqrt_2_over_pi = math.sqrt(2.0 / math.pi)

        erf_term = torch.erf(x / sqrt_2)
        exp_term = torch.exp(-0.5 * x.pow(2))

        # 0.5·sqrt(2/π)·exp(-x²/2) == φ(x)
        return 0.5 * (1.0 + erf_term) + 0.5 * x * sqrt_2_over_pi * exp_term

    # ==================================================================
    # SDP: Propagation methods (fast, local linearization)
    # ==================================================================
    def _propagate_relu_sdp(
        self,
        mean: torch.Tensor,
        var: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        ReLU SDP rule (optimal for TV distance):
        - if μ ≥ 0: (μ, var) stays (ReLU is identity)
        - if μ < 0: (0, 0)      (point mass at zero)
        """
        out_mean = F.relu(mean)

        if var is None:
            return out_mean, None

        mask = (mean >= 0).to(mean.dtype)
        out_var = var * mask

        return out_mean, out_var

    def _propagate_gelu_sdp(
        self,
        mean: torch.Tensor,
        var: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        GELU SDP rule (local linearization):
        mean_out = GELU(μ)
        var_out  = (GELU'(μ))² · var
        """
        out_mean = F.gelu(mean)

        if var is None:
            return out_mean, None

        d = self._gelu_derivative(mean)
        out_var = (d.pow(2) * var).clamp(min=0.0)

        return out_mean, out_var

    def _propagate_silu_sdp(
        self,
        mean: torch.Tensor,
        var: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        SiLU SDP rule (local linearization):
        mean_out = SiLU(μ)
        var_out  = (SiLU'(μ))² · var
        """
        out_mean = F.silu(mean)

        if var is None:
            return out_mean, None

        d = self._silu_derivative(mean)
        out_var = (d.pow(2) * var).clamp(min=0.0)

        return out_mean, out_var

    # ==================================================================
    # Moment Matching: Propagation methods (accurate, analytical)
    # ==================================================================
    def _propagate_relu(
        self,
        mean: torch.Tensor,
        var: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Propagate Gaussian through ReLU using analytical moment matching.

        For x ~ N(μ, σ²), compute moments of ReLU(x) using truncated Gaussian formulas:
        - E[ReLU(x)] = σ·φ(α) + μ·Φ(α)
        - Var[ReLU(x)] = E[ReLU(x)²] - E[ReLU(x)]²

        where α = μ/σ, φ = standard normal PDF, Φ = standard normal CDF

        Args:
            mean: Input mean
            var: Input variance (same shape as mean) or None

        Returns:
            (output_mean, output_variance)
        """
        if var is None:
            return F.relu(mean), None

        # eps keeps σ and α finite when var == 0.
        eps = 1e-8
        std = (var + eps).sqrt()
        alpha = mean / (std + eps)

        phi = torch.exp(-0.5 * alpha.pow(2)) / math.sqrt(2 * math.pi)
        Phi = 0.5 * (1 + torch.erf(alpha / math.sqrt(2)))

        # First moment: E[ReLU(x)] = σ·φ(α) + μ·Φ(α)
        out_mean = std * phi + mean * Phi

        # Second moment: E[ReLU(x)²] = (μ² + σ²)·Φ(α) + μ·σ·φ(α)
        second_moment = (mean.pow(2) + var) * Phi + mean * std * phi

        # Clamp guards against tiny negatives from cancellation.
        out_var = (second_moment - out_mean.pow(2)).clamp(min=0)

        return out_mean, out_var

    def _propagate_gelu(
        self,
        mean: torch.Tensor,
        var: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Propagate Gaussian through GELU(x) = x·Φ(x) by moment matching.

        For X ~ N(μ, σ²) and s = sqrt(1 + σ²):
        - Mean (exact): E[X·Φ(X)] = μ·Φ(μ/s) + σ²·φ(μ/s)/s
        - Variance (statistical linearization): σ²·E[GELU'(X)]², with
          E[GELU'(X)] = Φ(μ/s) + μ·φ(μ/s)/s³.
          This is the variance of the best linear-in-X fit, a lower bound on the
          true variance; it equals the SDP value as σ → 0.

        The exact second moment E[X²Φ(X)²] needs bivariate normal CDFs and is
        not computed.

        Args:
            mean: Input mean
            var: Input variance (same shape as mean) or None

        Returns:
            (output_mean, output_variance)
        """
        if var is None:
            return F.gelu(mean), None

        var = var.clamp(min=0.0)

        # With Z ~ N(0,1) independent of X, Φ(X) = P(Z ≤ X), so
        # E[Φ(X)] = Φ(μ/s) and E[φ(X)] = φ(μ/s)/s. s ≥ 1, so no eps is needed.
        s = (1.0 + var).sqrt()
        z = mean / s
        pdf = torch.exp(-0.5 * z.pow(2)) / math.sqrt(2 * math.pi)
        cdf = 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))

        # Stein's lemma: E[X·Φ(X)] = μ·E[Φ(X)] + σ²·E[φ(X)]
        out_mean = mean * cdf + var * pdf / s

        # E[X·φ(X)] = μ·φ(μ/s)/s³ (product of Gaussians)
        slope = cdf + mean * pdf / s.pow(3)
        out_var = var * slope.pow(2)

        return out_mean, out_var

    def _propagate_silu(
        self,
        mean: torch.Tensor,
        var: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Propagate Gaussian through SiLU(x) = x·sigmoid(x) (heuristic moments).

        mean_out ≈ μ·s + 0.5·σ²·s·(1 - s)        with s = sigmoid(μ)
        var_out  ≈ (μ² + σ²)·s² - mean_out²      (clamped at 0)

        Both reduce to (SiLU(μ), 0) at σ² = 0.

        Args:
            mean: Input mean
            var: Input variance (same shape as mean) or None

        Returns:
            (output_mean, output_variance)
        """
        if var is None:
            return F.silu(mean), None

        sigmoid_mean = torch.sigmoid(mean)

        # NOTE(review): the exact second-order Taylor term would be
        # 0.5·σ²·SiLU''(μ) with SiLU''(μ) = s(1-s)(2 + μ(1-2s)); the factor used
        # here is smaller. Kept unchanged pending validation.
        out_mean = mean * sigmoid_mean + 0.5 * var * sigmoid_mean * (1 - sigmoid_mean)

        second_moment = (mean.pow(2) + var) * sigmoid_mean.pow(2)
        out_var = (second_moment - out_mean.pow(2)).clamp(min=0)

        return out_mean, out_var

    def _apply_activation(
        self,
        mean: torch.Tensor,
        var: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Apply activation with variance propagation.
        Switches between SDP and moment matching based on propagation_mode.
        Unrecognized activation names pass (mean, var) through unchanged.

        Args:
            mean: Input mean
            var: Input variance (or None for mean-only)

        Returns:
            (output_mean, output_variance)
        """
        if self.propagation_mode == 'sdp':
            if self.activation == 'relu':
                return self._propagate_relu_sdp(mean, var)
            elif self.activation == 'gelu':
                return self._propagate_gelu_sdp(mean, var)
            elif self.activation == 'silu':
                return self._propagate_silu_sdp(mean, var)
            else:
                return mean, var
        else:  # moment_matching
            if self.activation == 'relu':
                return self._propagate_relu(mean, var)
            elif self.activation == 'gelu':
                return self._propagate_gelu(mean, var)
            elif self.activation == 'silu':
                return self._propagate_silu(mean, var)
            else:
                return mean, var

    def _apply_dropout(
        self,
        mean: torch.Tensor,
        var: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Apply inverted dropout to a (mean, variance) pair in training mode.

        One Bernoulli keep-mask m (keep probability 1-p) is shared by both
        statistics. Conditional on m, y = m·x / (1-p) has
            E[y]   = m·E[x] / (1-p)
            Var[y] = m·Var[x] / (1-p)²
        so dropped units carry zero mean and zero variance.

        Returns the inputs unchanged when dropout is disabled or in eval mode;
        uses nn.Dropout directly when var is None.
        """
        if self.dropout is None or not self.training:
            return mean, var
        if var is None:
            return self.dropout(mean), None

        p = self.dropout_rate
        if p >= 1.0:
            # Every unit is dropped; avoids the 0/0 of scaling by 1/(1-p).
            return torch.zeros_like(mean), torch.zeros_like(var)

        keep = 1.0 - p
        # Match mean's dtype so half/bfloat16 activations are not promoted.
        mask = (torch.rand_like(mean) > p).to(mean.dtype)
        return mean * mask / keep, var * mask / (keep * keep)

    def forward(
        self,
        x_mean: torch.Tensor,
        x_var: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass with variance propagation.

        Args:
            x_mean: Input mean of shape (..., dim_in)
            x_var: Input variance of the same shape, or None

        Returns:
            (output_mean, output_variance)
            - If x_var=None, output_variance=None (mean-only mode)
            - Otherwise, both mean and variance propagated
        """
        mean, var = x_mean, x_var
        last = len(self.layers) - 1

        for i, layer in enumerate(self.layers):
            mean, var = self._propagate_linear(mean, var, layer.weight, layer.bias)

            # Hidden layers only: the output layer is linear and never dropped.
            if i < last:
                mean, var = self._apply_activation(mean, var)
                mean, var = self._apply_dropout(mean, var)

        return mean, var
