"""
Convolutional backbone network for BSNP.
Implements the CNN architecture that processes the aggregated latent representation.
"""

import torch
import torch.nn as nn
from typing import Tuple, Optional


class ResidualBlock(nn.Module):
    """
    Residual block with two 1D convolutional layers.
    
    Architecture (BatchNorm after each conv only when ``use_batch_norm``):
        x -> Conv -> [BN] -> Activation -> Conv -> [BN] -> + -> Activation
        |                                                  |
        +--------------------------------------------------+

    Both convolutions use ``padding='same'`` so the output length always
    equals the input length, which the residual addition requires. A manual
    ``kernel_size // 2`` padding only preserves the length for odd kernels;
    for even kernels it lengthens the sequence by one and the addition fails.
    """
    
    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        activation: str = 'swish',
        use_batch_norm: bool = False
    ):
        """
        Args:
            channels: Number of channels (input and output)
            kernel_size: Size of convolutional kernels (odd or even)
            activation: Activation function name
                ('relu', 'swish'/'silu', 'gelu', 'tanh'; case-insensitive)
            use_batch_norm: Whether to use batch normalization

        Raises:
            ValueError: If ``activation`` is not a recognised name.
        """
        super().__init__()
        
        layers = []
        
        # First conv layer; 'same' keeps the sequence length for any kernel size.
        layers.append(nn.Conv1d(channels, channels, kernel_size, padding='same'))
        if use_batch_norm:
            layers.append(nn.BatchNorm1d(channels))
        layers.append(self._get_activation(activation))
        
        # Second conv layer (no activation here: it is applied after the skip sum).
        layers.append(nn.Conv1d(channels, channels, kernel_size, padding='same'))
        if use_batch_norm:
            layers.append(nn.BatchNorm1d(channels))
        
        self.conv_block = nn.Sequential(*layers)
        # Activation applied to (x + conv_block(x)).
        self.activation = self._get_activation(activation)
    
    def _get_activation(self, name: str) -> nn.Module:
        """
        Return a fresh activation module for ``name`` (case-insensitive).

        Raises:
            ValueError: If ``name`` is not one of relu, swish, silu, gelu, tanh.
        """
        factories = {
            'relu': nn.ReLU,
            'swish': nn.SiLU,
            'silu': nn.SiLU,
            'gelu': nn.GELU,
            'tanh': nn.Tanh,
        }
        try:
            factory = factories[name.lower()]
        except KeyError:
            raise ValueError(f"Unknown activation: {name}") from None
        return factory()
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with residual connection.
        
        Args:
            x: Input tensor, shape (batch, channels, length)
        
        Returns:
            Output tensor, same shape as input
        """
        return self.activation(x + self.conv_block(x))


class ResidualBlock2D(nn.Module):
    """
    2D Residual block for 2D problems (e.g., Navier-Stokes).

    Architecture (BatchNorm after each conv only when ``use_batch_norm``):
        x -> Conv2d -> [BN] -> Activation -> Conv2d -> [BN] -> + -> Activation
        |                                                      |
        +------------------------------------------------------+

    Both convolutions use ``padding='same'`` so height and width are preserved
    for any kernel size. With ``kernel_size // 2`` padding, even kernels would
    enlarge the map by one pixel per conv and break the residual addition.
    """
    
    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        activation: str = 'swish',
        use_batch_norm: bool = False
    ):
        """
        Args:
            channels: Number of channels (input and output)
            kernel_size: Size of convolutional kernels (odd or even)
            activation: Activation function name
                ('relu', 'swish'/'silu', 'gelu', 'tanh'; case-insensitive)
            use_batch_norm: Whether to use batch normalization

        Raises:
            ValueError: If ``activation`` is not a recognised name.
        """
        super().__init__()
        
        layers = []
        
        # First conv layer; 'same' keeps (height, width) for any kernel size.
        layers.append(nn.Conv2d(channels, channels, kernel_size, padding='same'))
        if use_batch_norm:
            layers.append(nn.BatchNorm2d(channels))
        layers.append(self._get_activation(activation))
        
        # Second conv layer (activation is applied after the skip sum).
        layers.append(nn.Conv2d(channels, channels, kernel_size, padding='same'))
        if use_batch_norm:
            layers.append(nn.BatchNorm2d(channels))
        
        self.conv_block = nn.Sequential(*layers)
        # Activation applied to (x + conv_block(x)).
        self.activation = self._get_activation(activation)
    
    def _get_activation(self, name: str) -> nn.Module:
        """
        Return a fresh activation module for ``name`` (case-insensitive).

        Raises:
            ValueError: If ``name`` is not one of relu, swish, silu, gelu, tanh.
        """
        factories = {
            'relu': nn.ReLU,
            'swish': nn.SiLU,
            'silu': nn.SiLU,
            'gelu': nn.GELU,
            'tanh': nn.Tanh,
        }
        try:
            factory = factories[name.lower()]
        except KeyError:
            raise ValueError(f"Unknown activation: {name}") from None
        return factory()
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor, shape (batch, channels, height, width)
        Returns:
            Output tensor, same shape
        """
        return self.activation(x + self.conv_block(x))


class ConvBackbone1D(nn.Module):
    """
    Residual CNN trunk for one-dimensional problems (Poisson, Burgers).

    A pointwise (kernel_size=1) convolution lifts the aggregated latent grid
    to ``hidden_channels``. A chain of ``num_blocks`` ResidualBlock modules
    then refines it without altering the channel count.
    """

    def __init__(
        self,
        input_channels: int,
        hidden_channels: int = 64,
        num_blocks: int = 6,
        kernel_size: int = 3,
        activation: str = 'swish',
        use_batch_norm: bool = False
    ):
        """
        Args:
            input_channels: Channels of the incoming grid (latent_dim + density)
            hidden_channels: Width of every residual block
            num_blocks: How many residual blocks to stack
            kernel_size: Convolution kernel size inside the residual blocks
            activation: Activation name forwarded to each residual block
            use_batch_norm: Whether residual blocks include BatchNorm1d
        """
        super().__init__()

        # Keep the configuration around for introspection.
        self.input_channels, self.hidden_channels, self.num_blocks = (
            input_channels, hidden_channels, num_blocks
        )

        # Pointwise lift into the hidden channel space (created first so that
        # parameter initialisation order matches the block list below).
        self.input_projection = nn.Conv1d(input_channels, hidden_channels, kernel_size=1)

        block_kwargs = dict(
            channels=hidden_channels,
            kernel_size=kernel_size,
            activation=activation,
            use_batch_norm=use_batch_norm,
        )
        self.residual_blocks = nn.ModuleList(
            ResidualBlock(**block_kwargs) for _ in range(num_blocks)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the latent grid through the projection and the residual stack.

        Args:
            x: Tensor of shape (batch, input_channels, grid_size)

        Returns:
            Tensor of shape (batch, hidden_channels, grid_size)
        """
        out = self.input_projection(x)
        for residual in self.residual_blocks:
            out = residual(out)
        return out


class ConvBackbone2D(nn.Module):
    """
    Residual CNN trunk for two-dimensional problems (Navier-Stokes).

    Same layout as the 1D trunk: a 1x1 convolution to ``hidden_channels``,
    followed by ``num_blocks`` ResidualBlock2D modules.
    """

    def __init__(
        self,
        input_channels: int,
        hidden_channels: int = 64,
        num_blocks: int = 6,
        kernel_size: int = 3,
        activation: str = 'swish',
        use_batch_norm: bool = False
    ):
        """
        Args:
            input_channels: Channels of the incoming 2D grid
            hidden_channels: Width of every residual block
            num_blocks: How many residual blocks to stack
            kernel_size: Convolution kernel size inside the residual blocks
            activation: Activation name forwarded to each residual block
            use_batch_norm: Whether residual blocks include BatchNorm2d
        """
        super().__init__()

        # Keep the configuration around for introspection.
        self.input_channels, self.hidden_channels, self.num_blocks = (
            input_channels, hidden_channels, num_blocks
        )

        # 1x1 lift into the hidden channel space (constructed before the
        # residual blocks to keep the original initialisation order).
        self.input_projection = nn.Conv2d(input_channels, hidden_channels, kernel_size=1)

        block_kwargs = dict(
            channels=hidden_channels,
            kernel_size=kernel_size,
            activation=activation,
            use_batch_norm=use_batch_norm,
        )
        self.residual_blocks = nn.ModuleList(
            ResidualBlock2D(**block_kwargs) for _ in range(num_blocks)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the 2D latent grid through the projection and residual stack.

        Args:
            x: Tensor of shape (batch, input_channels, height, width)

        Returns:
            Tensor of shape (batch, hidden_channels, height, width)
        """
        out = self.input_projection(x)
        for residual in self.residual_blocks:
            out = residual(out)
        return out

class UNetBackbone2D(nn.Module):
    """
    U-Net style backbone for 2D problems.
    Better for capturing multi-scale features.

    Encoder: ``num_levels`` stages of residual blocks. Between stages there is
    a stride-2 convolution that halves the resolution and doubles the channels.
    Decoder: stride-2 transposed convolutions undo each downsampling step. The
    upsampled map is concatenated with a 1x1-projected encoder skip tensor of
    the same resolution, then refined by residual blocks.

    Inputs whose height/width are not divisible by ``2 ** (num_levels - 1)``
    are supported. The stride-2 path floors odd sizes, so each upsampled map
    is zero-padded (or cropped) to the skip tensor's spatial size before
    concatenation. Each spatial dimension must still be at least
    ``2 ** (num_levels - 1)`` for the downsampling convolutions to produce a
    non-empty output.
    """
    
    def __init__(
        self,
        input_channels: int,
        hidden_channels: int = 64,
        num_levels: int = 3,
        num_blocks_per_level: int = 2,
        kernel_size: int = 3,
        activation: str = 'swish'
    ):
        """
        Args:
            input_channels: Number of input channels
            hidden_channels: Base number of hidden channels
            num_levels: Number of resolution levels (num_levels - 1 downsamplings)
            num_blocks_per_level: Residual blocks per level
            kernel_size: Kernel size
            activation: Activation function
        """
        super().__init__()
        
        self.num_levels = num_levels
        
        # Initial projection
        self.input_projection = nn.Conv2d(
            input_channels, hidden_channels, kernel_size=1
        )
        
        # Channel count at each encoder level, used to size the skip projections.
        self.encoder_channels = []
        
        # Encoder (downsampling path)
        self.encoder_blocks = nn.ModuleList()
        self.downsample = nn.ModuleList()
        
        channels = hidden_channels
        for level in range(num_levels):
            self.encoder_channels.append(channels)
            
            # Residual blocks at this level
            blocks = nn.ModuleList([
                ResidualBlock2D(channels, kernel_size, activation)
                for _ in range(num_blocks_per_level)
            ])
            self.encoder_blocks.append(blocks)
            
            # Downsampling (except last level): halve resolution, double channels.
            if level < num_levels - 1:
                self.downsample.append(
                    nn.Conv2d(channels, channels * 2, kernel_size=2, stride=2)
                )
                channels *= 2
        
        # Decoder (upsampling path), built deepest level first.
        self.decoder_blocks = nn.ModuleList()
        self.upsample = nn.ModuleList()
        self.skip_projection = nn.ModuleList()
        
        for level in range(num_levels - 1, 0, -1):
            # Upsampling halves the channel count.
            up_channels = channels // 2
            skip_channels = self.encoder_channels[level - 1]
            
            # Upsampling
            self.upsample.append(
                nn.ConvTranspose2d(channels, up_channels, kernel_size=2, stride=2)
            )
            
            # Project skip connection to match upsampled channels; after
            # concatenation the decoder input has up_channels * 2 channels.
            self.skip_projection.append(
                nn.Conv2d(skip_channels, up_channels, kernel_size=1)
            )
            
            combined_channels = up_channels * 2
            
            # The first block also squeezes combined_channels back to up_channels.
            blocks = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(combined_channels, up_channels, kernel_size=1),
                    ResidualBlock2D(up_channels, kernel_size, activation)
                ) if i == 0 else ResidualBlock2D(up_channels, kernel_size, activation)
                for i in range(num_blocks_per_level)
            ])
            self.decoder_blocks.append(blocks)
            
            channels = up_channels
        
        # Final projection back to hidden_channels
        self.output_projection = nn.Conv2d(
            channels, hidden_channels, kernel_size=1
        )
    
    @staticmethod
    def _match_spatial_size(h: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """
        Pad (bottom/right, with zeros) or crop ``h`` to ``reference``'s (H, W).

        Stride-2 downsampling floors odd sizes, so the transposed-conv output
        can be one pixel smaller than the encoder skip tensor. Negative pad
        amounts crop, covering any other mismatch.
        """
        dh = reference.shape[-2] - h.shape[-2]
        dw = reference.shape[-1] - h.shape[-1]
        if dh == 0 and dw == 0:
            return h
        return nn.functional.pad(h, (0, dw, 0, dh))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        U-Net forward pass.
        
        Args:
            x: Input, shape (batch, input_channels, height, width)
        
        Returns:
            Output, shape (batch, hidden_channels, height, width)
        """
        h = self.input_projection(x)
        
        # Encoder path; keep each level's output (before downsampling) as a skip.
        skip_connections = []
        
        for level in range(self.num_levels):
            for block in self.encoder_blocks[level]:
                h = block(h)
            
            if level < self.num_levels - 1:
                skip_connections.append(h)
                h = self.downsample[level](h)
        
        # Decoder path consumes skips deepest-first, matching construction order.
        for level, skip in enumerate(reversed(skip_connections)):
            h = self.upsample[level](h)
            # Align spatial size with the skip (needed for odd input sizes).
            h = self._match_spatial_size(h, skip)
            
            skip = self.skip_projection[level](skip)
            h = torch.cat([h, skip], dim=1)
            
            for block in self.decoder_blocks[level]:
                h = block(h)
        
        return self.output_projection(h)


def build_conv_backbone(
    spatial_dim: int,
    input_channels: int,
    hidden_channels: int = 64,
    num_blocks: int = 6,
    kernel_size: int = 3,
    activation: str = 'swish',
    use_unet: bool = False,
    use_batch_norm: bool = False,
    num_levels: int = 3,
    num_blocks_per_level: int = 2
) -> nn.Module:
    """
    Factory function to build appropriate backbone for problem dimension.
    
    Args:
        spatial_dim: Spatial dimension (1 or 2)
        input_channels: Number of input channels
        hidden_channels: Number of hidden channels
        num_blocks: Number of residual blocks (plain backbones only; the
            U-Net uses ``num_levels`` / ``num_blocks_per_level`` instead)
        kernel_size: Kernel size
        activation: Activation function
        use_unet: Whether to use U-Net architecture (2D only; ignored when
            ``spatial_dim == 1``)
        use_batch_norm: Whether residual blocks use batch normalization
            (plain backbones only; the U-Net has no batch-norm option)
        num_levels: U-Net resolution levels
        num_blocks_per_level: U-Net residual blocks per level
    
    Returns:
        Convolutional backbone module

    Raises:
        ValueError: If ``spatial_dim`` is not 1 or 2.
    """
    if spatial_dim == 1:
        return ConvBackbone1D(
            input_channels=input_channels,
            hidden_channels=hidden_channels,
            num_blocks=num_blocks,
            kernel_size=kernel_size,
            activation=activation,
            use_batch_norm=use_batch_norm
        )
    if spatial_dim == 2:
        if use_unet:
            return UNetBackbone2D(
                input_channels=input_channels,
                hidden_channels=hidden_channels,
                num_levels=num_levels,
                num_blocks_per_level=num_blocks_per_level,
                kernel_size=kernel_size,
                activation=activation
            )
        return ConvBackbone2D(
            input_channels=input_channels,
            hidden_channels=hidden_channels,
            num_blocks=num_blocks,
            kernel_size=kernel_size,
            activation=activation,
            use_batch_norm=use_batch_norm
        )
    raise ValueError(f"Unsupported spatial dimension: {spatial_dim}")