import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
import numpy as np
from typing import Tuple, Optional
import pywt # Added for WaveletExpert

class FourierExpert(nn.Module):
    """
    Fourier Expert for global frequency domain processing.

    Applies FFT-based global context modeling to capture
    long-range spatial dependencies and frequency patterns.

    Pipeline (see ``forward``): 2-D FFT -> centre the spectrum -> stack the
    real/imaginary parts along the channel axis -> 1x1-conv bottleneck ->
    blend radially low-/high-pass masked copies with learnable per-channel
    gains -> inverse FFT -> 7x7 spatial attention gate.

    Args:
        channels: Number of input (and output) feature channels.
        reduction_ratio: Divisor for the bottleneck width of the frequency
            projection. The width is ``channels // reduction_ratio`` but
            never less than 1.

    Raises:
        ValueError: If ``reduction_ratio`` is smaller than 1.
    """

    def __init__(self, channels: int, reduction_ratio: int = 4):
        super().__init__()
        if reduction_ratio < 1:
            raise ValueError(
                f"reduction_ratio must be >= 1, got {reduction_ratio}"
            )
        self.channels = channels
        self.reduction_ratio = reduction_ratio

        # Clamp the bottleneck to at least one channel: with
        # channels < reduction_ratio the integer division yields 0, which
        # would create an empty (information-free) projection.
        hidden_channels = max(1, channels // reduction_ratio)

        # Frequency domain processing.
        # Input is [real, imag] stacked on the channel axis, hence 2 * channels.
        self.freq_proj = nn.Sequential(
            nn.Conv2d(channels * 2, hidden_channels, 1, 1, 0),  # Complex -> Real projection
            nn.GELU(),
            nn.Conv2d(hidden_channels, channels, 1, 1, 0)
        )

        # Learnable per-channel gains for the low- and high-frequency bands
        self.low_freq_filter = nn.Parameter(torch.randn(1, channels, 1, 1) * 0.1)
        self.high_freq_filter = nn.Parameter(torch.randn(1, channels, 1, 1) * 0.1)

        # Spatial attention for frequency features (single-channel gate in (0, 1))
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(channels, 1, 7, 1, 3),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply Fourier-based global processing.

        Args:
            x: Feature map of shape ``(B, C, H, W)`` with ``C == channels``.

        Returns:
            Tensor of shape ``(B, C, H, W)``: the real part of the inverse
            FFT of the filtered features, scaled by the spatial attention map.
        """
        B, C, H, W = x.shape

        # Apply FFT and move the zero-frequency bin to the centre of the map
        x_fft = torch.fft.fft2(x, dim=(-2, -1))
        x_fft_shifted = torch.fft.fftshift(x_fft, dim=(-2, -1))

        # Combine real and imaginary parts for processing -> [B, 2C, H, W]
        x_complex = torch.cat([x_fft_shifted.real, x_fft_shifted.imag], dim=1)

        # Process in frequency domain (1x1 convs keep the per-bin layout)
        freq_features = self.freq_proj(x_complex)

        # Distance of every bin from the spectrum centre (frequency magnitude),
        # built by broadcasting an (H, 1) row offset against a (1, W) column offset.
        center_h, center_w = H // 2, W // 2
        row_offset = torch.arange(H, device=x.device).unsqueeze(1) - center_h
        col_offset = torch.arange(W, device=x.device).unsqueeze(0) - center_w
        freq_dist = torch.sqrt(row_offset ** 2 + col_offset ** 2)
        freq_dist = freq_dist.float() / max(H, W)  # Normalize

        # Low and high frequency masks: exponential radial decay and its complement
        low_freq_mask = torch.exp(-freq_dist * 3)
        high_freq_mask = 1 - low_freq_mask

        # Apply frequency filtering; the (H, W) masks broadcast over batch and channel
        low_freq_features = freq_features * low_freq_mask.unsqueeze(0).unsqueeze(0)
        high_freq_features = freq_features * high_freq_mask.unsqueeze(0).unsqueeze(0)

        # Weighted combination with the learnable per-channel gains
        filtered_features = (
            low_freq_features * self.low_freq_filter +
            high_freq_features * self.high_freq_filter
        )

        # Convert back to spatial domain.
        # For simplicity, filtered_features is treated as the real part of the
        # spectrum and the imaginary part is set to zero for the inverse transform.
        filtered_fft = torch.complex(filtered_features, torch.zeros_like(filtered_features))
        filtered_fft_shifted = torch.fft.ifftshift(filtered_fft, dim=(-2, -1))

        spatial_output = torch.fft.ifft2(filtered_fft_shifted, dim=(-2, -1)).real

        # Apply spatial attention
        spatial_weights = self.spatial_attn(spatial_output)
        return spatial_output * spatial_weights

class WaveletExpert(nn.Module):
    """
    Wavelet Expert for multi-scale decomposition and reconstruction.

    Uses Discrete Wavelet Transform (DWT) to capture multi-scale
    patterns and edge information at different resolutions.

    NOTE: the transforms are computed by PyWavelets on detached NumPy copies
    of the data (``x.detach().cpu().numpy()``), so no gradient flows from the
    wavelet coefficients back to the module input. Only the convolutional
    processors, the scale attention and the fusion layer are trained.

    Args:
        channels: Number of input (and output) feature channels.
        wavelet: Wavelet name passed to PyWavelets.
        levels: Number of decomposition levels. ``0`` is accepted and skips
            the decomposition entirely.
    """

    def __init__(
        self,
        channels: int,
        wavelet: str = 'db4',
        levels: int = 2
    ):
        super().__init__()
        self.channels = channels
        self.wavelet = wavelet
        self.levels = levels

        def make_processors() -> nn.ModuleList:
            # One conv-GELU-conv stack per decomposition level
            return nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(channels, channels, 3, 1, 1),
                    nn.GELU(),
                    nn.Conv2d(channels, channels, 3, 1, 1)
                ) for _ in range(levels)
            ])

        # Learnable wavelet coefficients processing (one list per subband)
        self.ll_processor = make_processors()
        self.lh_processor = make_processors()
        self.hl_processor = make_processors()
        self.hh_processor = make_processors()

        # Scale attention weights: one logit per level plus one for the
        # deepest approximation; softmax-normalised in forward().
        self.scale_attention = nn.Parameter(torch.ones(levels + 1) / (levels + 1))

        # Cross-scale fusion of the (levels + 1) concatenated feature maps
        self.fusion_conv = nn.Conv2d(channels * (levels + 1), channels, 1, 1, 0)

    def dwt_2d(self, x: torch.Tensor) -> Tuple[torch.Tensor, list]:
        """Apply a multi-level 2D DWT using PyWavelets.

        Args:
            x: Tensor of shape ``(B, C, H, W)``.

        Returns:
            ``(ll, coeffs_list)`` where ``coeffs_list[k]`` is the tuple
            ``(ll, lh, hl, hh)`` of float32 tensors for level ``k`` (on
            ``x.device``) and ``ll`` is the approximation of the deepest
            level. With ``levels == 0`` the list is empty and ``ll`` is a
            detached float32 view of ``x``. All outputs are detached from
            the autograd graph.
        """
        def as_tensor(arr) -> torch.Tensor:
            return torch.from_numpy(np.ascontiguousarray(arr)).float().to(x.device)

        # Convert to numpy for wavelet transform
        approx_np = x.detach().cpu().numpy()

        coeffs_list = []
        # Bound up-front so that levels == 0 returns a valid tensor
        ll_tensor = x.detach().float()

        for _ in range(self.levels):
            # A single call transforms every (batch, channel) plane at once:
            # the DWT is taken over the last two axes only.
            approx_np, (lh_np, hl_np, hh_np) = pywt.dwt2(
                approx_np, self.wavelet, mode='periodization', axes=(-2, -1)
            )

            ll_tensor = as_tensor(approx_np)
            coeffs_list.append(
                (ll_tensor, as_tensor(lh_np), as_tensor(hl_np), as_tensor(hh_np))
            )

        return ll_tensor, coeffs_list

    def idwt_2d(self, ll: torch.Tensor, coeffs_list: list) -> torch.Tensor:
        """Apply inverse 2D DWT.

        Args:
            ll: Approximation coefficients of the deepest level, ``(B, C, h, w)``.
            coeffs_list: Per-level ``(ll, lh, hl, hh)`` tuples as returned by
                ``dwt_2d``. Only the detail bands are read; the stored ``ll``
                entries are ignored because each level's approximation is
                rebuilt from the level below.

        Returns:
            Reconstructed float32 tensor on ``ll.device``, detached from the
            autograd graph.
        """
        # Start reconstruction from the deepest level
        approx_np = ll.detach().cpu().numpy()

        for _, lh_tensor, hl_tensor, hh_tensor in reversed(coeffs_list):
            lh_np = lh_tensor.detach().cpu().numpy()
            hl_np = hl_tensor.detach().cpu().numpy()
            hh_np = hh_tensor.detach().cpu().numpy()

            # When a level had an odd spatial size, the reconstruction coming
            # from the level below is larger than this level's detail bands;
            # crop it so the inverse transform receives matching shapes.
            # This is a no-op whenever the shapes already agree.
            approx_np = approx_np[..., :lh_np.shape[-2], :lh_np.shape[-1]]

            # Reconstruct all (batch, channel) planes in one call
            approx_np = pywt.idwt2(
                (approx_np, (lh_np, hl_np, hh_np)),
                self.wavelet,
                mode='periodization',
                axes=(-2, -1)
            )

        return torch.from_numpy(np.ascontiguousarray(approx_np)).float().to(ll.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply wavelet-based multi-scale processing.

        Each level's four subbands are passed through their processors,
        summed, and bilinearly resized to the input resolution. The deepest
        approximation contributes one more map. The maps are weighted by the
        softmax of ``scale_attention``, concatenated on the channel axis and
        fused by a 1x1 convolution. ``idwt_2d`` is not used here.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` with ``C == channels``.

        Returns:
            Tensor of shape ``(B, C, H, W)``.
        """
        target_size = x.shape[-2:]

        # Decompose using DWT
        ll, coeffs_list = self.dwt_2d(x)

        multi_scale_features = []
        deepest_ll_processed = None

        for level, (ll_level, lh_level, hl_level, hh_level) in enumerate(coeffs_list):
            # Process each subband
            ll_processed = self.ll_processor[level](ll_level)
            lh_processed = self.lh_processor[level](lh_level)
            hl_processed = self.hl_processor[level](hl_level)
            hh_processed = self.hh_processor[level](hh_level)

            # Collect features at this scale (upsample to original size)
            scale_feature = ll_processed + lh_processed + hl_processed + hh_processed
            scale_feature = F.interpolate(
                scale_feature, size=target_size, mode='bilinear', align_corners=False
            )
            multi_scale_features.append(scale_feature)

            deepest_ll_processed = ll_processed

        # Add the coarsest approximation (deepest-level LL band). `ll` is the
        # same tensor the last loop iteration fed to ll_processor[-1], so that
        # result is reused instead of running the processor a second time.
        # If the approximation already has the input resolution it is used raw.
        if deepest_ll_processed is not None and ll.shape[-2:] != target_size:
            coarse = deepest_ll_processed
        else:
            coarse = ll
        multi_scale_features.append(
            F.interpolate(coarse, size=target_size, mode='bilinear', align_corners=False)
        )

        # Apply scale attention (softmax computed once for all scales)
        scale_weights = torch.softmax(self.scale_attention, dim=0)
        weighted_features = [
            feat * scale_weights[i] for i, feat in enumerate(multi_scale_features)
        ]

        # Fuse multi-scale features
        fused_features = torch.cat(weighted_features, dim=1)
        return self.fusion_conv(fused_features)

class FrequencyAnalyzer(nn.Module):
    """
    Advanced frequency domain analyzer for blur pattern recognition.

    Combines Fourier and Wavelet analysis for comprehensive
    frequency domain understanding of blur characteristics.

    Each input feature map is reduced to a single-channel map by its own
    1x1-conv bottleneck; the two maps are concatenated and fused into one
    sigmoid-activated confidence map.

    Args:
        channels: Number of channels of both input feature maps. The
            bottleneck width is ``channels // 4`` but never less than 1.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels

        # Clamp the bottleneck to at least one channel: for channels < 4 the
        # integer division yields 0, which would create empty convolutions.
        hidden_channels = max(1, channels // 4)

        # Fourier-based blur detector
        self.fourier_analyzer = nn.Sequential(
            nn.Conv2d(channels, hidden_channels, 1, 1, 0),
            nn.GELU(),
            nn.Conv2d(hidden_channels, 1, 1, 1, 0)
        )

        # Wavelet-based edge detector
        self.wavelet_analyzer = nn.Sequential(
            nn.Conv2d(channels, hidden_channels, 1, 1, 0),
            nn.GELU(),
            nn.Conv2d(hidden_channels, 1, 1, 1, 0)
        )

        # Fusion layer: 2 stacked single-channel maps -> 1 map in (0, 1)
        self.fusion = nn.Sequential(
            nn.Conv2d(2, 16, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(16, 1, 1, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, fourier_features: torch.Tensor, wavelet_features: torch.Tensor) -> torch.Tensor:
        """Analyze frequency characteristics for blur estimation.

        Args:
            fourier_features: Tensor of shape ``(B, channels, H, W)``.
            wavelet_features: Tensor of shape ``(B, channels, H, W)``; batch
                and spatial sizes must match ``fourier_features`` because the
                two analyses are concatenated along the channel axis.

        Returns:
            Tensor of shape ``(B, 1, H, W)`` holding the sigmoid output of
            the fusion layer.
        """
        # Analyze Fourier features for blur patterns
        fourier_blur = self.fourier_analyzer(fourier_features)

        # Analyze Wavelet features for edge information
        wavelet_edges = self.wavelet_analyzer(wavelet_features)

        # Combine analyses
        combined = torch.cat([fourier_blur, wavelet_edges], dim=1)
        return self.fusion(combined)