"""
Neural network models for regular grids

This module provides UNet and Fourier Neural Operator (FNO) 
implementations for solving PDEs on regular grids.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple, List

from .base import Activation


class UNetBlock(nn.Module):
    """
    Double-convolution building block used by the UNet below.

    Runs two 3x3 convolutions (padding 1, so height and width are kept).
    Each convolution is followed by an optional batch normalisation and
    then by the activation built from ``Activation(activation)``.
    NOTE(review): ``Activation`` comes from ``.base``; it is presumably a
    factory that maps a name such as "relu" to a module - confirm there.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 activation: str = "relu", use_batchnorm: bool = True):
        """
        Parameters:
        -----------
            in_channels: channels of the incoming feature map
            out_channels: channels produced by both convolutions
            activation: name forwarded to ``Activation``
            use_batchnorm: when False the norm layers become ``nn.Identity``
        """
        super().__init__()

        def make_norm() -> nn.Module:
            # One fresh layer per call so the two stages never share statistics.
            if use_batchnorm:
                return nn.BatchNorm2d(out_channels)
            return nn.Identity()

        # Attribute order is kept as-is so state_dict keys and seeded
        # initialisation stay identical.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        self.activation = Activation(activation)

        self.bn1 = make_norm()
        self.bn2 = make_norm()

    def forward(self, x):
        """Apply conv -> norm -> activation twice and return the result."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.activation(norm(conv(x)))
        return x


class UNet(nn.Module):
    """
    UNet architecture for regular grids

    This implementation follows the classical UNet architecture with
    an encoder-decoder structure and skip connections: every encoder
    level is a ``UNetBlock`` followed by 2x2 max pooling, a bottleneck
    doubles the deepest width, and every decoder level upsamples with a
    stride-2 transposed convolution, concatenates the matching encoder
    output and applies another ``UNetBlock``.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 features: Optional[List[int]] = None,
                 activation: str = "relu", use_batchnorm: bool = True,
                 dropout: float = 0.0):
        """
        Parameters:
        -----------
            in_channels: channels of the input tensor
            out_channels: channels of the output tensor
            features: channel width of each encoder level, shallowest
                first; defaults to [64, 128, 256, 512]
            activation: activation name forwarded to every ``UNetBlock``
            use_batchnorm: forwarded to every ``UNetBlock``
            dropout: probability for the ``nn.Dropout2d`` applied after
                each pooling step; values <= 0 disable dropout
        Raises:
        -------
            ValueError: if ``features`` is empty
        """
        super().__init__()

        # Build the default per instance and copy caller lists, so no list
        # object is shared between models (or with the caller).
        features = [64, 128, 256, 512] if features is None else list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.features = features
        self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()

        # Encoder (downsampling path)
        self.encoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Track the previous width explicitly; looking it up by value
        # (features.index) picks the wrong level when widths repeat.
        previous = in_channels
        for feature in features:
            self.encoder.append(UNetBlock(previous, feature, activation, use_batchnorm))
            previous = feature

        # Bottleneck
        self.bottleneck = UNetBlock(features[-1], features[-1] * 2, activation, use_batchnorm)

        # Decoder (upsampling path)
        self.decoder = nn.ModuleList()
        self.upconvs = nn.ModuleList()

        previous = features[-1] * 2  # channels leaving the bottleneck
        for feature in reversed(features):
            self.upconvs.append(
                nn.ConvTranspose2d(previous, feature, kernel_size=2, stride=2)
            )
            # The block sees the skip connection (feature channels)
            # concatenated with the upsampled tensor (feature channels).
            self.decoder.append(
                UNetBlock(feature * 2, feature, activation, use_batchnorm)
            )
            previous = feature

        # Final output layer
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        -----------
            x: torch.FloatTensor [batch_size, in_channels, height, width]
        Returns:
        --------
            y: torch.FloatTensor [batch_size, out_channels, height, width]
        """
        # Encoder path: remember each level's output before pooling
        skip_connections = []
        for encoder_block in self.encoder:
            x = encoder_block(x)
            skip_connections.append(x)
            x = self.pool(x)
            x = self.dropout(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder path: consume the skips deepest-first
        levels = zip(self.upconvs, self.decoder, reversed(skip_connections))
        for upconv, decoder_block, skip_connection in levels:
            x = upconv(x)

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip connection; resize to match before concatenating.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(x, size=skip_connection.shape[2:],
                                  mode='bilinear', align_corners=False)

            x = torch.cat([skip_connection, x], dim=1)
            x = decoder_block(x)

        # Final output
        return self.final_conv(x)


class SpectralConv2d(nn.Module):
    """
    2D Spectral convolution for Fourier Neural Operator

    Transforms the input with a real 2D FFT, multiplies the retained
    low-frequency coefficients by learnable complex weights (stored as
    real tensors with a trailing dimension of size 2 holding the real and
    imaginary parts), zeroes every other coefficient and transforms back.
    """

    def __init__(self, in_channels: int, out_channels: int, modes1: int, modes2: int):
        """
        Parameters:
        -----------
            in_channels: channels of the input tensor
            out_channels: channels of the output tensor
            modes1: number of Fourier modes kept along the height axis
                (per block; one block at each end of the axis)
            modes2: number of Fourier modes kept along the width axis
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes in first dimension
        self.modes2 = modes2  # Number of Fourier modes in second dimension

        self.scale = (1 / (in_channels * out_channels))

        # Learnable weights for Fourier modes: weights1 acts on the first
        # modes1 rows of the spectrum, weights2 on the last modes1 rows.
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2)
        )
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2)
        )

    def compl_mul2d(self, input, weights):
        """
        Complex multiplication for 2D tensors

        Both arguments carry real and imaginary parts in their last
        dimension. ``input`` is [batch, in_channels, x, y, 2] and
        ``weights`` is [in_channels, out_channels, x, y, 2]; the result is
        [batch, out_channels, x, y, 2], summed over the input channels.
        """
        input_real, input_imag = input[..., 0], input[..., 1]
        weight_real, weight_imag = weights[..., 0], weights[..., 1]

        # Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        real = torch.einsum("bixy,ioxy->boxy", input_real, weight_real) - \
               torch.einsum("bixy,ioxy->boxy", input_imag, weight_imag)
        imag = torch.einsum("bixy,ioxy->boxy", input_real, weight_imag) + \
               torch.einsum("bixy,ioxy->boxy", input_imag, weight_real)

        return torch.stack([real, imag], dim=-1)

    def forward(self, x):
        """
        Parameters:
        -----------
            x: torch.FloatTensor [batch_size, in_channels, height, width]
        Returns:
        --------
            y: torch.FloatTensor [batch_size, out_channels, height, width]

        If the grid offers fewer frequencies than ``modes1`` / ``modes2``,
        only the available ones are used instead of failing with a shape
        mismatch. When both mode counts fit, nothing changes.
        """
        batch_size = x.shape[0]
        height, width = x.size(-2), x.size(-1)
        freq_cols = width // 2 + 1  # rfft2 keeps only non-negative width frequencies

        # Never ask for more modes than the spectrum actually has.
        m1 = min(self.modes1, height)
        m2 = min(self.modes2, freq_cols)

        # Compute Fourier coefficients as a real tensor with trailing (re, im)
        x_ft = torch.fft.rfft2(x, dim=[-2, -1])
        x_ft = torch.stack([x_ft.real, x_ft.imag], dim=-1)

        # Every coefficient outside the retained blocks stays zero
        out_ft = torch.zeros(batch_size, self.out_channels, height, freq_cols, 2,
                             device=x.device, dtype=x.dtype)

        # Leading rows of the spectrum, paired with the leading weight rows
        out_ft[:, :, :m1, :m2] = self.compl_mul2d(
            x_ft[:, :, :m1, :m2], self.weights1[:, :, :m1, :m2]
        )

        # Trailing rows of the spectrum. The weight slice is aligned to the
        # END of weights2 so that the last spectrum row keeps using the last
        # weight row, exactly as in the unclamped case. An explicit start
        # index is used because "-0:" would select every row.
        row_start = height - m1
        out_ft[:, :, row_start:, :m2] = self.compl_mul2d(
            x_ft[:, :, row_start:, :m2],
            self.weights2[:, :, self.modes1 - m1:, :m2]
        )

        # Convert back to complex and perform inverse FFT at the input size
        out_ft_complex = torch.complex(out_ft[..., 0], out_ft[..., 1])
        x = torch.fft.irfft2(out_ft_complex, s=(height, width), dim=[-2, -1])

        return x


class FNOBlock(nn.Module):
    """
    Fourier Neural Operator block

    Sums two branches computed from the same input - a ``SpectralConv2d``
    and a 1x1 convolution - and passes the sum through the activation
    built from ``Activation(activation)``.
    NOTE(review): ``Activation`` comes from ``.base``; presumably it maps
    the given name to an activation module - confirm there.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 modes1: int, modes2: int, activation: str = "gelu"):
        """
        Parameters:
        -----------
            in_channels: channels of the input tensor
            out_channels: channels produced by both branches
            modes1: forwarded to ``SpectralConv2d``
            modes2: forwarded to ``SpectralConv2d``
            activation: name forwarded to ``Activation``
        """
        super().__init__()

        # Construction order is preserved so state_dict keys and seeded
        # initialisation stay the same.
        self.spectral_conv = SpectralConv2d(in_channels, out_channels, modes1, modes2)
        self.linear = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.activation = Activation(activation)

    def forward(self, x):
        """Return activation(spectral_conv(x) + linear(x))."""
        spectral_branch = self.spectral_conv(x)
        pointwise_branch = self.linear(x)
        return self.activation(spectral_branch + pointwise_branch)


class FourierNeuralOperator(nn.Module):
    """
    Fourier Neural Operator for learning operators on regular grids

    Based on "Fourier Neural Operator for Parametric Partial Differential Equations"
    by Li et al. (2021)

    The input is zero-padded on the bottom/right, lifted point-wise to
    ``width`` channels, passed through ``num_layers`` ``FNOBlock`` layers,
    cropped back to the original grid and projected point-wise to
    ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 modes1: int = 12, modes2: int = 12, width: int = 64,
                 num_layers: int = 4, activation: str = "gelu",
                 padding: int = 8):
        """
        Parameters:
        -----------
            in_channels: channels of the input field
            out_channels: channels of the output field
            modes1: Fourier modes per block along height, for every layer
            modes2: Fourier modes along width, for every layer
            width: hidden channel count used by all FNO layers
            num_layers: number of ``FNOBlock`` layers
            activation: activation name used in the blocks and the output head
            padding: rows/columns of zeros appended before the FNO layers
                and cropped afterwards; values <= 0 disable padding
        """
        super().__init__()

        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        self.padding = padding

        # Input projection
        self.fc0 = nn.Linear(in_channels, self.width)

        # FNO layers
        self.fno_layers = nn.ModuleList([
            FNOBlock(self.width, self.width, self.modes1, self.modes2, activation)
            for _ in range(num_layers)
        ])

        # Output projection
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, out_channels)
        self.activation = Activation(activation)

    def forward(self, x):
        """
        Parameters:
        -----------
            x: torch.FloatTensor [batch_size, in_channels, height, width]
               or [batch_size, height, width, in_channels]
        Returns:
        --------
            y: torch.FloatTensor [batch_size, out_channels, height, width]
               or [batch_size, height, width, out_channels]
               (same layout as the input)
        Raises:
        -------
            ValueError: if ``x`` is not 4-D, or neither axis 1 nor the last
                axis has ``in_channels`` entries
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input tensor, got {x.dim()} dimensions"
            )

        # The layout is decided by locating the axis whose length equals the
        # configured input channel count. Comparing axis lengths with each
        # other mislabels ordinary inputs, because a grid is usually wider
        # than it has channels.
        in_channels = self.fc0.in_features
        if x.shape[-1] == in_channels:
            # Channel-last. When axis 1 matches as well the layout is
            # ambiguous; channel-last is kept for that tie, as before.
            x = x.permute(0, 3, 1, 2)  # Convert to [batch, channels, height, width]
            channels_last = True
        elif x.shape[1] == in_channels:
            channels_last = False
        else:
            raise ValueError(
                f"cannot locate the channel axis: expected {in_channels} channels "
                f"on axis 1 or axis -1, got shape {tuple(x.shape)}"
            )

        # Store original size for cropping after the FNO layers
        grid_height, grid_width = x.shape[-2:]

        # Zero-pad the right and bottom edges
        if self.padding > 0:
            x = F.pad(x, (0, self.padding, 0, self.padding))

        # nn.Linear acts on the last axis, so lift in channel-last layout
        x = x.permute(0, 2, 3, 1)  # [batch, height, width, channels]
        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)  # [batch, channels, height, width]

        # Apply FNO layers
        for fno_layer in self.fno_layers:
            x = fno_layer(x)

        # Remove padding: the layers keep the spatial size, so dropping the
        # appended rows/columns restores the original grid exactly.
        if self.padding > 0:
            x = x[..., :grid_height, :grid_width]

        # Output projection, again in channel-last layout
        x = x.permute(0, 2, 3, 1)  # [batch, height, width, channels]
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)

        # Hand the result back in the caller's layout
        if channels_last:
            return x  # Keep as [batch, height, width, channels]
        return x.permute(0, 3, 1, 2)  # Convert to [batch, channels, height, width]


def create_unet(in_channels: int, out_channels: int, **kwargs) -> UNet:
    """
    Factory function to create UNet model

    Parameters:
    -----------
        in_channels: passed to ``UNet`` as its first positional argument
        out_channels: passed to ``UNet`` as its second positional argument
        **kwargs: forwarded unchanged to the ``UNet`` constructor
    Returns:
    --------
        The newly constructed ``UNet`` instance
    """
    model = UNet(in_channels, out_channels, **kwargs)
    return model


def create_fno(in_channels: int, out_channels: int, **kwargs) -> FourierNeuralOperator:
    """
    Factory function to create FNO model

    Parameters:
    -----------
        in_channels: passed to ``FourierNeuralOperator`` as its first
            positional argument
        out_channels: passed to ``FourierNeuralOperator`` as its second
            positional argument
        **kwargs: forwarded unchanged to the ``FourierNeuralOperator``
            constructor
    Returns:
    --------
        The newly constructed ``FourierNeuralOperator`` instance
    """
    model = FourierNeuralOperator(in_channels, out_channels, **kwargs)
    return model