"""
Encoder modules for BSNP.
Implements observation encoder φ_y, parameter encoder φ_λ, and latent encoder φ_z.
"""

import torch
import torch.nn as nn
import math
from typing import Tuple, Optional


class ObservationEncoder(nn.Module):
    """
    Observation encoder φ_y: R^{d_u} → R^{d_y}

    Turns raw observation values y into feature vectors using a plain MLP:
    a chain of Linear layers with the chosen activation between them and
    no activation after the final Linear layer.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dims: Tuple[int, ...] = (64, 64),
        activation: str = 'swish'
    ):
        """
        Args:
            input_dim: Size of each observation vector (d_u)
            output_dim: Size of each encoded feature vector (d_y)
            hidden_dims: Widths of the hidden layers, in order; may be empty,
                in which case the network is a single Linear layer
            activation: Name of the activation placed between Linear layers
                (see ``_get_activation`` for the accepted names)
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        # Widths from input to output; each consecutive pair is one Linear.
        widths = [input_dim, *hidden_dims, output_dim]
        final_idx = len(widths) - 2

        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            # The output layer stays linear (no trailing activation).
            if idx != final_idx:
                modules.append(self._get_activation(activation))

        self.network = nn.Sequential(*modules)

    def _get_activation(self, name: str) -> nn.Module:
        """
        Build a new activation module from its (case-insensitive) name.

        Accepted names: 'relu', 'swish' / 'silu', 'gelu', 'tanh'.

        Raises:
            ValueError: If ``name`` is not one of the accepted names.
        """
        constructors = {
            'relu': nn.ReLU,
            'swish': nn.SiLU,
            'silu': nn.SiLU,
            'gelu': nn.GELU,
            'tanh': nn.Tanh,
        }
        key = name.lower()
        if key not in constructors:
            raise ValueError(f"Unknown activation: {name}")
        return constructors[key]()

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        """
        Run observations through the MLP.

        Args:
            y: Observation values, shape (batch, n_points, d_u)

        Returns:
            Encoded features, shape (batch, n_points, d_y)
        """
        encoded = self.network(y)
        return encoded


class ParameterEncoder(nn.Module):
    """
    Parameter encoder φ_λ: R^{d_λ} → R^{d_λ_enc}

    MLP that maps a vector of PDE parameters λ to a feature vector.
    Hidden Linear layers are each followed by the chosen activation;
    the output Linear layer is left without one.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dims: Tuple[int, ...] = (64, 64),
        activation: str = 'swish'
    ):
        """
        Args:
            input_dim: Number of PDE parameters (d_λ)
            output_dim: Size of the encoded feature vector (d_λ_enc)
            hidden_dims: Widths of the hidden layers, in order; may be empty
            activation: Name of the activation used after each hidden layer
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        # Hidden stack: each hidden width contributes Linear + activation.
        stack = []
        prev_width = input_dim
        for width in hidden_dims:
            stack.append(nn.Linear(prev_width, width))
            stack.append(self._get_activation(activation))
            prev_width = width

        # Output projection, deliberately without an activation.
        stack.append(nn.Linear(prev_width, output_dim))

        self.network = nn.Sequential(*stack)

    def _get_activation(self, name: str) -> nn.Module:
        """
        Build a new activation module from its (case-insensitive) name.

        Accepted names: 'relu', 'swish' / 'silu', 'gelu', 'tanh'.

        Raises:
            ValueError: If ``name`` is not one of the accepted names.
        """
        key = name.lower()
        if key == 'relu':
            return nn.ReLU()
        if key in ('swish', 'silu'):
            return nn.SiLU()
        if key == 'gelu':
            return nn.GELU()
        if key == 'tanh':
            return nn.Tanh()
        raise ValueError(f"Unknown activation: {name}")

    def forward(self, lambda_params: torch.Tensor) -> torch.Tensor:
        """
        Run PDE parameters through the MLP.

        Args:
            lambda_params: PDE parameters, shape (batch, d_λ)

        Returns:
            Encoded features, shape (batch, d_λ_enc)
        """
        encoded = self.network(lambda_params)
        return encoded


class LatentEncoder(nn.Module):
    """
    Latent encoder φ_z: (R^{d_x} × R^{d_y}) → R^{d_z}

    Takes a location and its encoded observation, joins them along the
    feature axis and maps the result to a latent feature vector with an
    MLP. The per-point outputs are what gets aggregated onto the latent
    grid elsewhere in the model.
    """

    def __init__(
        self,
        spatial_dim: int,
        observation_dim: int,
        output_dim: int,
        hidden_dims: Tuple[int, ...] = (64, 64),
        activation: str = 'swish'
    ):
        """
        Args:
            spatial_dim: Size of each location vector (d_x)
            observation_dim: Size of each encoded observation vector (d_y)
            output_dim: Size of each latent feature vector (d_z)
            hidden_dims: Widths of the hidden layers, in order; may be empty
            activation: Name of the activation placed between Linear layers
        """
        super().__init__()

        self.spatial_dim = spatial_dim
        self.observation_dim = observation_dim
        self.output_dim = output_dim

        # The MLP consumes the concatenation [x, y], hence the summed width.
        widths = [spatial_dim + observation_dim, *hidden_dims, output_dim]

        stages = []
        for pos in range(len(widths) - 1):
            # An activation separates consecutive Linear layers, so it goes
            # in front of every Linear except the very first one.
            if pos > 0:
                stages.append(self._get_activation(activation))
            stages.append(nn.Linear(widths[pos], widths[pos + 1]))

        self.network = nn.Sequential(*stages)

    def _get_activation(self, name: str) -> nn.Module:
        """
        Build a new activation module from its (case-insensitive) name.

        Accepted names: 'relu', 'swish' / 'silu', 'gelu', 'tanh'.

        Raises:
            ValueError: If ``name`` is not one of the accepted names.
        """
        factory = {
            'relu': nn.ReLU,
            'swish': nn.SiLU,
            'silu': nn.SiLU,
            'gelu': nn.GELU,
            'tanh': nn.Tanh,
        }.get(name.lower())
        if factory is None:
            raise ValueError(f"Unknown activation: {name}")
        return factory()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Encode (location, observation) pairs.

        Args:
            x: Spatial locations, shape (batch, n_points, d_x)
            y: Encoded observations, shape (batch, n_points, d_y)

        Returns:
            Latent features, shape (batch, n_points, d_z)
        """
        # Join along the last (feature) axis: locations first, then observations.
        joint = torch.cat((x, y), dim=-1)
        return self.network(joint)


class SetEncoder(nn.Module):
    """
    Two-stage encoder for context sets.

    Stages, applied per context point:
    1. Observation encoding: y → φ_y(y)
    2. Location-observation encoding: (x, φ_y(y)) → φ_z(x, φ_y(y))

    The result is one latent feature vector per context point.
    """

    def __init__(
        self,
        spatial_dim: int,
        observation_dim: int,
        latent_dim: int,
        observation_encoder_dim: int = 64,
        hidden_dims: Tuple[int, ...] = (64, 64),
        activation: str = 'swish'
    ):
        """
        Args:
            spatial_dim: Size of each location vector
            observation_dim: Size of each raw observation vector
            latent_dim: Size of each output latent vector
            observation_encoder_dim: Output width of φ_y (and the observation
                input width of φ_z)
            hidden_dims: Hidden layer widths, shared by both encoders
            activation: Activation name, shared by both encoders
        """
        super().__init__()

        # Both sub-encoders use the same MLP hyper-parameters.
        mlp_config = dict(hidden_dims=hidden_dims, activation=activation)

        # φ_y: raw observations → observation features
        self.observation_encoder = ObservationEncoder(
            input_dim=observation_dim,
            output_dim=observation_encoder_dim,
            **mlp_config
        )

        # φ_z: (locations, observation features) → latent features
        self.latent_encoder = LatentEncoder(
            spatial_dim=spatial_dim,
            observation_dim=observation_encoder_dim,
            output_dim=latent_dim,
            **mlp_config
        )

    def forward(
        self,
        x_context: torch.Tensor,
        y_context: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode a context set into per-point latent features.

        Args:
            x_context: Context locations, shape (batch, n_context, d_x)
            y_context: Context observations, shape (batch, n_context, d_u)

        Returns:
            Latent features, shape (batch, n_context, d_z)
        """
        # φ_z(x, φ_y(y)): observations are encoded first, then paired with locations.
        return self.latent_encoder(x_context, self.observation_encoder(y_context))


class MultiDimensionalEncoder(nn.Module):
    """
    Context-set encoder with optional Fourier-feature lifting of locations.

    Each coordinate of a location can be expanded into sinusoidal features
    before being paired with the encoded observation, which is useful for
    2D/3D problems. The pipeline per context point is:

    1. Locations: x → positional_encoding(x) (identity when disabled)
    2. Observations: y → observation_encoder(y)
    3. Fusion: combined_encoder(encoded x, encoded y) → latent features
    """

    def __init__(
        self,
        spatial_dim: int,
        observation_dim: int,
        latent_dim: int,
        observation_encoder_dim: int = 64,
        hidden_dims: Tuple[int, ...] = (64, 64),
        activation: str = 'swish',
        use_positional_encoding: bool = True,
        num_frequencies: int = 10
    ):
        """
        Args:
            spatial_dim: Dimension of spatial domain
            observation_dim: Dimension of observations
            latent_dim: Output latent dimension
            observation_encoder_dim: Output width of the observation encoder
            hidden_dims: Hidden layer sizes, shared by both sub-encoders
            activation: Activation function name, shared by both sub-encoders
            use_positional_encoding: Whether to use Fourier features for locations
            num_frequencies: Number of frequencies for positional encoding
        """
        super().__init__()

        self.spatial_dim = spatial_dim
        self.use_positional_encoding = use_positional_encoding
        self.num_frequencies = num_frequencies

        # Width of the location features fed to the combined encoder:
        # the raw coordinates plus one sin and one cos per (coordinate, frequency).
        if use_positional_encoding:
            pos_dim = spatial_dim * (2 * num_frequencies + 1)
        else:
            pos_dim = spatial_dim

        # Observation encoder
        self.observation_encoder = ObservationEncoder(
            input_dim=observation_dim,
            output_dim=observation_encoder_dim,
            hidden_dims=hidden_dims,
            activation=activation
        )

        # Combined encoder over (encoded locations, encoded observations)
        self.combined_encoder = LatentEncoder(
            spatial_dim=pos_dim,
            observation_dim=observation_encoder_dim,
            output_dim=latent_dim,
            hidden_dims=hidden_dims,
            activation=activation
        )

        # Frequencies 2^k * pi for k = 0 .. num_frequencies - 1, stored as a
        # buffer so they follow the module across devices and are not trained.
        if use_positional_encoding:
            # math.pi is used rather than torch.pi (presumably because
            # torch.pi is not available in every torch version).
            self.register_buffer(
                'frequencies',
                2.0 ** torch.linspace(0, num_frequencies - 1, num_frequencies) * math.pi
            )

    def positional_encoding(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply sinusoidal positional encoding (Fourier features).

        The output concatenates, along the last axis: the raw coordinates,
        then all sin features, then all cos features. Within the sin and cos
        groups the features are ordered coordinate-major (all frequencies of
        coordinate 0, then all frequencies of coordinate 1, ...).

        Args:
            x: Locations, shape (..., d_x)

        Returns:
            Encoded locations, shape (..., d_x * (2 * num_frequencies + 1)),
            or ``x`` itself when positional encoding is disabled.
        """
        if not self.use_positional_encoding:
            return x

        # Broadcast every coordinate against every frequency.
        x_expanded = x[..., :, None] * self.frequencies  # (..., d_x, num_freq)
        sin_features = torch.sin(x_expanded)
        cos_features = torch.cos(x_expanded)

        # Flatten the (d_x, num_freq) axes and concatenate original + sin + cos.
        leading_shape = x.shape[:-1]
        encoded = torch.cat([
            x,
            sin_features.reshape(*leading_shape, -1),
            cos_features.reshape(*leading_shape, -1)
        ], dim=-1)

        return encoded

    def forward(
        self,
        x_context: torch.Tensor,
        y_context: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode context set with positional encoding.

        Args:
            x_context: Context locations, shape (batch, n_context, d_x)
            y_context: Context observations, shape (batch, n_context, d_u)

        Returns:
            Latent features, shape (batch, n_context, d_z)
        """
        # Apply positional encoding to locations
        x_encoded = self.positional_encoding(x_context)

        # Encode observations
        y_encoded = self.observation_encoder(y_context)

        # Call the combined encoder as a module rather than reaching into its
        # internal `.network`: this keeps the concatenation logic in one place
        # and lets hooks registered on `combined_encoder` run.
        latent_features = self.combined_encoder(x_encoded, y_encoded)

        return latent_features