"""
Decoder modules for BSNP.
Maps from processed latent representation to predictive distribution parameters.
"""

import torch
import torch.nn as nn
from typing import Tuple, Optional
import math


class GaussianDecoder(nn.Module):
    """
    Decoder that outputs parameters of a Gaussian distribution.

    Two independent pointwise (kernel_size=1) Conv1d heads map the latent
    features to a mean μ and a log-variance log(σ²) at each grid location.
    The log-variance is converted to a standard deviation with a floor of
    ``min_sigma``. Used for standard regression tasks.
    """

    def __init__(
        self,
        input_channels: int,
        output_dim: int = 1,
        hidden_channels: Optional[int] = None,
        min_sigma: float = 1e-4
    ):
        """
        Args:
            input_channels: Number of input channels from backbone
            output_dim: Dimension of output field (e.g., 1 for scalar)
            hidden_channels: Optional hidden layer size; defaults to
                ``input_channels`` when None
            min_sigma: Minimum allowed standard deviation
        """
        super().__init__()

        self.input_channels = input_channels
        self.output_dim = output_dim
        self.min_sigma = min_sigma

        if hidden_channels is None:
            hidden_channels = input_channels

        # Mean head: pointwise conv -> SiLU -> pointwise conv
        self.mean_decoder = nn.Sequential(
            nn.Conv1d(input_channels, hidden_channels, kernel_size=1),
            nn.SiLU(),
            nn.Conv1d(hidden_channels, output_dim, kernel_size=1)
        )

        # Log-variance head: same architecture, separate weights
        self.logvar_decoder = nn.Sequential(
            nn.Conv1d(input_channels, hidden_channels, kernel_size=1),
            nn.SiLU(),
            nn.Conv1d(hidden_channels, output_dim, kernel_size=1)
        )

    def forward(self, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode to Gaussian parameters.

        Args:
            h: Latent features, shape (batch, channels, grid_size)

        Returns:
            mean: Predicted mean, shape (batch, output_dim, grid_size)
            sigma: Predicted std dev, shape (batch, output_dim, grid_size),
                always >= ``min_sigma``
        """
        # Decode mean
        mean = self.mean_decoder(h)

        # Decode log-variance and convert to std dev.
        # The floor is applied in variance space, so min_sigma must be squared
        # for the resulting standard deviation to be bounded by min_sigma
        # itself: sqrt(var + min_sigma^2) >= min_sigma.
        logvar = self.logvar_decoder(h)
        sigma = torch.sqrt(torch.exp(logvar) + self.min_sigma ** 2)

        return mean, sigma


class GaussianDecoder2D(nn.Module):
    """
    2D version of Gaussian decoder for 2D problems.

    Two independent pointwise (kernel_size=1) Conv2d heads map the latent
    features to a mean and a log-variance at each pixel. The log-variance is
    converted to a standard deviation with a floor of ``min_sigma``.
    """

    def __init__(
        self,
        input_channels: int,
        output_dim: int = 1,
        hidden_channels: Optional[int] = None,
        min_sigma: float = 1e-4
    ):
        """
        Args:
            input_channels: Number of input channels from backbone
            output_dim: Dimension of output field (e.g., 1 for scalar)
            hidden_channels: Optional hidden layer size; defaults to
                ``input_channels`` when None
            min_sigma: Minimum allowed standard deviation
        """
        super().__init__()

        self.input_channels = input_channels
        self.output_dim = output_dim
        self.min_sigma = min_sigma

        if hidden_channels is None:
            hidden_channels = input_channels

        # Mean head: pointwise conv -> SiLU -> pointwise conv
        self.mean_decoder = nn.Sequential(
            nn.Conv2d(input_channels, hidden_channels, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(hidden_channels, output_dim, kernel_size=1)
        )

        # Log-variance head: same architecture, separate weights
        self.logvar_decoder = nn.Sequential(
            nn.Conv2d(input_channels, hidden_channels, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(hidden_channels, output_dim, kernel_size=1)
        )

    def forward(self, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode to Gaussian parameters.

        Args:
            h: Latent features, shape (batch, channels, height, width)

        Returns:
            mean: shape (batch, output_dim, height, width)
            sigma: shape (batch, output_dim, height, width),
                always >= ``min_sigma``
        """
        mean = self.mean_decoder(h)
        logvar = self.logvar_decoder(h)
        # The floor is added in variance space, so min_sigma is squared to
        # make the standard deviation bounded below by min_sigma itself.
        sigma = torch.sqrt(torch.exp(logvar) + self.min_sigma ** 2)

        return mean, sigma


class ParameterConditionedDecoder(nn.Module):
    """
    Decoder that is conditioned on PDE parameters λ.

    Uses FiLM (Feature-wise Linear Modulation) to condition the decoding
    on the parameter encoding: a linear layer maps the encoded parameters to
    a per-channel scale γ and shift β, which modulate the latent features
    before they are passed to the mean and log-variance heads. The mean and
    variance branches use separate FiLM generators.
    """

    def __init__(
        self,
        input_channels: int,
        parameter_dim: int,
        output_dim: int = 1,
        hidden_channels: Optional[int] = None,
        min_sigma: float = 1e-4
    ):
        """
        Args:
            input_channels: Number of input channels from backbone
            parameter_dim: Dimension of encoded parameters
            output_dim: Output field dimension
            hidden_channels: Hidden layer size; defaults to
                ``input_channels`` when None
            min_sigma: Minimum allowed standard deviation
        """
        super().__init__()

        self.input_channels = input_channels
        self.parameter_dim = parameter_dim
        self.output_dim = output_dim
        self.min_sigma = min_sigma

        if hidden_channels is None:
            hidden_channels = input_channels

        # FiLM parameter generators: each emits [gamma | beta] concatenated
        # along the last axis, hence 2 * input_channels outputs.
        self.film_mean = nn.Linear(parameter_dim, input_channels * 2)
        self.film_var = nn.Linear(parameter_dim, input_channels * 2)

        # Decoders (pointwise conv -> SiLU -> pointwise conv)
        self.mean_decoder = nn.Sequential(
            nn.Conv1d(input_channels, hidden_channels, kernel_size=1),
            nn.SiLU(),
            nn.Conv1d(hidden_channels, output_dim, kernel_size=1)
        )

        self.logvar_decoder = nn.Sequential(
            nn.Conv1d(input_channels, hidden_channels, kernel_size=1),
            nn.SiLU(),
            nn.Conv1d(hidden_channels, output_dim, kernel_size=1)
        )

    def apply_film(
        self,
        h: torch.Tensor,
        gamma: torch.Tensor,
        beta: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply FiLM modulation: γ * h + β

        Args:
            h: Features, shape (batch, channels, length)
            gamma: Scale, shape (batch, channels)
            beta: Shift, shape (batch, channels)

        Returns:
            Modulated features, same shape as h
        """
        # Add a trailing axis so scale/shift broadcast over the length axis
        gamma = gamma.unsqueeze(-1)  # (batch, channels, 1)
        beta = beta.unsqueeze(-1)

        return gamma * h + beta

    def forward(
        self,
        h: torch.Tensor,
        lambda_encoded: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode with parameter conditioning.

        Args:
            h: Latent features, shape (batch, channels, grid_size)
            lambda_encoded: Encoded parameters, shape (batch, parameter_dim)

        Returns:
            mean: shape (batch, output_dim, grid_size)
            sigma: shape (batch, output_dim, grid_size),
                always >= ``min_sigma``
        """
        # Generate FiLM parameters; first half is gamma, second half is beta
        film_params_mean = self.film_mean(lambda_encoded)
        gamma_mean, beta_mean = torch.chunk(film_params_mean, 2, dim=-1)

        film_params_var = self.film_var(lambda_encoded)
        gamma_var, beta_var = torch.chunk(film_params_var, 2, dim=-1)

        # Apply FiLM and decode
        h_mean = self.apply_film(h, gamma_mean, beta_mean)
        mean = self.mean_decoder(h_mean)

        h_var = self.apply_film(h, gamma_var, beta_var)
        logvar = self.logvar_decoder(h_var)
        # The floor is added in variance space, so min_sigma is squared to
        # make the standard deviation bounded below by min_sigma itself.
        sigma = torch.sqrt(torch.exp(logvar) + self.min_sigma ** 2)

        return mean, sigma


class InterpolatingDecoder(nn.Module):
    """
    Decoder that interpolates from grid to arbitrary query points.

    Runs a base decoder on the grid, then linearly interpolates the decoded
    mean and sigma to the requested query coordinates. Only the 1D path is
    implemented here; ``spatial_dim`` is stored but not otherwise used.
    """

    def __init__(
        self,
        base_decoder: nn.Module,
        spatial_dim: int = 1
    ):
        """
        Args:
            base_decoder: Base decoder (e.g., GaussianDecoder); called as
                ``base_decoder(h)`` and expected to return ``(mean, sigma)``
            spatial_dim: Spatial dimension
        """
        super().__init__()

        self.base_decoder = base_decoder
        self.spatial_dim = spatial_dim

    def interpolate_1d(
        self,
        field: torch.Tensor,
        x_grid: torch.Tensor,
        x_query: torch.Tensor
    ) -> torch.Tensor:
        """
        1D linear interpolation.

        The grid is treated as uniformly spaced between its minimum and
        maximum coordinate (only those two values of ``x_grid`` are used).
        Queries outside that range are clamped to the boundary values.

        Args:
            field: Field values on grid, shape (batch, dim, grid_size)
            x_grid: Grid coordinates, shape (batch, grid_size, 1)
            x_query: Query coordinates, shape (batch, n_query, 1)

        Returns:
            Interpolated values, shape (batch, dim, n_query)
        """
        # grid_sample expects input (N, C, H_in, W_in) and a sampling grid
        # (N, H_out, W_out, 2) whose last axis is (x, y): x indexes the WIDTH
        # (last) axis of the input and y indexes the HEIGHT axis.
        # For 1D we therefore put the spatial axis on W and use a dummy H=1:
        #   input (N, C, 1, grid_size), grid (N, 1, n_query, 2).

        # Normalize coordinates to [-1, 1]
        x_min = x_grid.min(dim=1, keepdim=True)[0]  # (batch, 1, 1)
        x_max = x_grid.max(dim=1, keepdim=True)[0]

        # The epsilon guards against division by zero for a degenerate grid
        x_norm = 2 * (x_query - x_min) / (x_max - x_min + 1e-8) - 1

        # Add dummy height dimension to field: (batch, dim, 1, grid_size)
        field_2d = field.unsqueeze(2)

        # Create sampling grid: (batch, 1, n_query, 2).
        # It must share dtype/device with the input for grid_sample.
        batch_size, n_query = x_norm.shape[0], x_norm.shape[1]
        grid = torch.zeros(
            batch_size, 1, n_query, 2,
            device=field.device, dtype=field.dtype
        )
        grid[:, 0, :, 0] = x_norm[..., 0]  # x coordinates: (batch, n_query)
        # y coordinates stay 0: with H_in == 1 every y maps to the single row

        # Interpolate (bilinear reduces to linear along the width axis)
        interpolated = torch.nn.functional.grid_sample(
            field_2d,
            grid,
            mode='bilinear',
            padding_mode='border',
            align_corners=True
        )

        # Remove height dimension: (batch, dim, 1, n_query) -> (batch, dim, n_query)
        return interpolated.squeeze(2)

    def forward(
        self,
        h: torch.Tensor,
        x_grid: torch.Tensor,
        x_query: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode and interpolate to query points.

        Args:
            h: Latent features on grid, shape (batch, channels, grid_size)
            x_grid: Grid coordinates, shape (batch, grid_size, 1)
            x_query: Query coordinates, shape (batch, n_query, 1)

        Returns:
            mean: Interpolated mean, shape (batch, output_dim, n_query)
            sigma: Interpolated sigma, shape (batch, output_dim, n_query)
        """
        # Decode on grid
        mean_grid, sigma_grid = self.base_decoder(h)

        # Interpolate to query points
        mean_query = self.interpolate_1d(mean_grid, x_grid, x_query)
        sigma_query = self.interpolate_1d(sigma_grid, x_grid, x_query)

        return mean_query, sigma_query


def build_decoder(
    spatial_dim: int,
    input_channels: int,
    output_dim: int = 1,
    parameter_dim: Optional[int] = None,
    min_sigma: float = 1e-4,
    use_parameter_conditioning: bool = False
) -> nn.Module:
    """
    Construct the decoder matching the requested configuration.

    Without parameter conditioning, a GaussianDecoder (1D) or
    GaussianDecoder2D (2D) is returned. With parameter conditioning, only
    the 1D ParameterConditionedDecoder is available.

    Args:
        spatial_dim: Spatial dimension (1 or 2)
        input_channels: Input channels from backbone
        output_dim: Output field dimension
        parameter_dim: Parameter encoding dimension (if conditioning)
        min_sigma: Minimum std dev
        use_parameter_conditioning: Whether to use parameter conditioning

    Returns:
        Decoder module

    Raises:
        ValueError: If conditioning is requested without ``parameter_dim``,
            or if ``spatial_dim`` is unsupported for the unconditioned case.
        NotImplementedError: If conditioning is requested for a spatial
            dimension other than 1.
    """
    # Keyword arguments shared by every decoder variant
    shared_kwargs = dict(
        input_channels=input_channels,
        output_dim=output_dim,
        min_sigma=min_sigma,
    )

    # Unconditioned decoders: pick by spatial dimension
    if not use_parameter_conditioning:
        if spatial_dim == 1:
            return GaussianDecoder(**shared_kwargs)
        if spatial_dim == 2:
            return GaussianDecoder2D(**shared_kwargs)
        raise ValueError(f"Unsupported spatial dimension: {spatial_dim}")

    # Conditioned decoder: the parameter dimension is checked first
    if parameter_dim is None:
        raise ValueError("parameter_dim required for parameter conditioning")

    if spatial_dim == 1:
        return ParameterConditionedDecoder(
            parameter_dim=parameter_dim,
            **shared_kwargs,
        )

    raise NotImplementedError(
        "Parameter-conditioned decoder not implemented for 2D"
    )