from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
import pywt

from .attention import BlurGuidedCrossAttention
from .frequency_experts import FourierExpert, WaveletExpert
from .fusion import BlurGatedMixer
from ..utils.layer_utils import LayerNorm2d

class SplitTransformMerge(nn.Module):
    """Local feature mixer that transforms one channel slice and keeps the rest.

    The leading ``split_channels`` channels are passed through
    conv -> GELU -> norm -> conv, the trailing channels are left untouched,
    and the two parts are concatenated, projected by a 1x1 convolution and
    added back onto the input.
    """

    def __init__(self, channels: int, split_ratio: float = 0.5):
        super().__init__()
        transformed = int(channels * split_ratio)

        self.channels = channels
        self.split_channels = transformed
        self.remaining_channels = channels - transformed

        # Transform branch: two 3x3 convolutions (stride 1, padding 1) with a
        # LayerNorm2d applied between them in forward().
        self.transform_conv1 = nn.Conv2d(
            transformed, transformed, kernel_size=3, stride=1, padding=1
        )
        self.transform_conv2 = nn.Conv2d(
            transformed, transformed, kernel_size=3, stride=1, padding=1
        )
        self.transform_norm = LayerNorm2d(transformed)

        # Pointwise projection applied to the re-assembled channel stack.
        self.merge_conv = nn.Conv2d(
            channels, channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the projected split/transform/concat result."""
        cut = self.split_channels

        # Split along dim 1 into the part to transform and the pass-through part.
        head, passthrough = x[:, :cut], x[:, cut:]

        # Transform: conv1 -> GELU -> norm -> conv2.
        head = self.transform_conv2(
            self.transform_norm(F.gelu(self.transform_conv1(head)))
        )

        # Merge: concatenate in the original channel order, then project.
        projected = self.merge_conv(torch.cat((head, passthrough), dim=1))

        # Residual connection around the whole module.
        return x + projected

class DRAMBlock(nn.Module):
    """
    DRAM Block: Dual-branch Restoration with Adaptive Multi-scale processing.

    Runs up to three parallel "expert" branches on the normalized input:
    - SplitTransformMerge (always present)
    - FourierExpert (when ``enable_frequency_experts`` is set)
    - WaveletExpert (when ``enable_wavelet`` is set)

    The branch outputs are fused by a BlurGatedMixer when a blur map is
    given (plain mean otherwise), optionally refined by
    BlurGuidedCrossAttention, and added to the input. A pointwise
    feed-forward network with its own residual follows.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int = 8,
        window_size: int = 8,
        enable_frequency_experts: bool = True,
        enable_wavelet: bool = True
    ):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.enable_frequency_experts = enable_frequency_experts
        self.enable_wavelet = enable_wavelet

        # Local branch, always active.
        self.stm = SplitTransformMerge(channels)

        # Optional branches; the attributes only exist when enabled.
        if enable_frequency_experts:
            self.fourier_expert = FourierExpert(channels)
        if enable_wavelet:
            self.wavelet_expert = WaveletExpert(channels)

        # Cross attention, applied in forward() only when both a blur map and
        # depth features are supplied.
        self.blur_attention = BlurGuidedCrossAttention(
            channels=channels,
            num_heads=num_heads,
            window_size=window_size
        )

        # The mixer needs to know how many branch outputs it will receive:
        # the STM branch plus one per enabled optional expert.
        branch_count = 1 + bool(enable_frequency_experts) + bool(enable_wavelet)
        self.feature_mixer = BlurGatedMixer(
            channels=channels,
            num_branches=branch_count
        )

        # Pre-norms for the expert stage and the feed-forward stage.
        self.norm1 = LayerNorm2d(channels)
        self.norm2 = LayerNorm2d(channels)

        # Pointwise feed-forward: 1x1 conv to 2x channels, GELU, 1x1 conv back.
        hidden = channels * 2
        self.ffn = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(hidden, channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(
        self,
        x: torch.Tensor,
        blur_map: Optional[torch.Tensor] = None,
        depth_features: Optional[list] = None
    ) -> torch.Tensor:
        """Run the expert stage and the feed-forward stage, each with a residual.

        Args:
            x: Input feature tensor.
            blur_map: Optional blur map. When given, the expert outputs are
                fused by ``feature_mixer``; otherwise they are averaged.
            depth_features: Optional depth features. Cross attention is only
                applied when both ``blur_map`` and ``depth_features`` are given.

        Returns:
            The updated feature tensor.
        """
        normed = self.norm1(x)

        # Evaluate the branches in a fixed order: STM, Fourier, wavelet.
        branches = [self.stm(normed)]
        if self.enable_frequency_experts:
            branches.append(self.fourier_expert(normed))
        if self.enable_wavelet:
            branches.append(self.wavelet_expert(normed))

        has_blur = blur_map is not None

        # Fuse the branch outputs: blur-gated when possible, plain mean otherwise.
        if has_blur:
            fused = self.feature_mixer(branches, blur_map)
        else:
            fused = torch.stack(branches).mean(dim=0)

        # Cross attention needs both the blur map and the depth features.
        if has_blur and depth_features is not None:
            fused = self.blur_attention(fused, blur_map, depth_features)

        # First residual: expert stage.
        after_experts = x + fused

        # Second residual: feed-forward stage on the re-normalized features.
        return after_experts + self.ffn(self.norm2(after_experts))

class AdaptiveDRAMBlock(nn.Module):
    """
    Adaptive DRAM Block with early exit capability.

    Wraps a DRAMBlock and adds two small scoring heads that operate on its
    output:
    - ``confidence_head`` (only built when ``enable_early_exit`` is True):
      global average pool -> 1x1 conv -> GELU -> 1x1 conv -> sigmoid.
    - ``complexity_analyzer`` (always built):
      3x3 conv -> GELU -> 1x1 conv -> sigmoid.

    Both scores are reduced with ``mean()`` to a single scalar over the whole
    input, so exit decisions are made per call, not per sample.
    """

    # Complexity scores below this value count as "low complexity".
    _LOW_COMPLEXITY = 0.3
    # Mean absolute blur above this value vetoes an exit in get_exit_decision().
    _HIGH_BLUR = 0.5
    # From this stage on, high confidence alone is enough to exit.
    _MIN_EXIT_STAGE = 2

    def __init__(
        self,
        channels: int,
        num_heads: int = 8,
        window_size: int = 8,
        enable_early_exit: bool = True,
        confidence_threshold: float = 0.8
    ):
        super().__init__()
        self.channels = channels
        self.enable_early_exit = enable_early_exit
        self.confidence_threshold = confidence_threshold

        # Core DRAM block
        self.dram_block = DRAMBlock(
            channels=channels,
            num_heads=num_heads,
            window_size=window_size
        )

        # Confidence estimation head.
        # max(1, ...) keeps the hidden width non-zero for small channel counts,
        # where the integer division would otherwise yield a 0-channel conv.
        if enable_early_exit:
            confidence_hidden = max(1, channels // 4)
            self.confidence_head = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(channels, confidence_hidden, 1, 1, 0),
                nn.GELU(),
                nn.Conv2d(confidence_hidden, 1, 1, 1, 0),
                nn.Sigmoid()
            )

        # Content complexity analyzer (same non-zero width guard as above).
        complexity_hidden = max(1, channels // 8)
        self.complexity_analyzer = nn.Sequential(
            nn.Conv2d(channels, complexity_hidden, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(complexity_hidden, 1, 1, 1, 0),
            nn.Sigmoid()
        )

    def forward(
        self,
        x: torch.Tensor,
        blur_map: Optional[torch.Tensor] = None,
        depth_features: Optional[list] = None,
        return_exit_info: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]:
        """
        Forward pass with optional early exit information.

        Args:
            x: Input feature tensor, forwarded to the wrapped DRAMBlock.
            blur_map: Optional blur map, forwarded to the DRAMBlock. When
                given, its mean absolute value is averaged with the
                complexity score.
            depth_features: Optional depth features, forwarded to the DRAMBlock.
            return_exit_info: When True, return ``(features, exit_info)``
                instead of the features alone.

        Returns:
            - ``processed_features`` when ``return_exit_info`` is False.
            - ``(processed_features, exit_info)`` when it is True. ``exit_info``
              is an empty dict if early exit is disabled; otherwise it holds
              ``'confidence'`` (float), ``'complexity'`` (float) and
              ``'should_exit'`` (bool).

        Note:
            ``should_exit`` here is the simple rule "confidence above the
            threshold and complexity score below 0.3". The stage- and
            blur-aware rule lives in :meth:`get_exit_decision`.
        """
        # Process through DRAM block
        processed_features = self.dram_block(x, blur_map, depth_features)

        if not return_exit_info:
            return processed_features

        exit_info: Dict[str, Any] = {}

        if self.enable_early_exit:
            # Only Python scalars leave this section, so there is no point in
            # recording an autograd graph for the scoring heads.
            with torch.no_grad():
                confidence = self.confidence_head(processed_features).mean()
                complexity_score = self.complexity_analyzer(processed_features).mean()

                # Combine blur severity (if available)
                if blur_map is not None:
                    avg_blur = torch.mean(torch.abs(blur_map))
                    complexity_score = (complexity_score + avg_blur) / 2

            confidence_value = confidence.item()
            complexity_value = complexity_score.item()

            exit_info = {
                'confidence': confidence_value,
                'complexity': complexity_value,
                # Decide on the Python floats so the flag is a real bool, not a
                # 0-dim tensor (which `and` on tensors would produce).
                'should_exit': bool(
                    confidence_value > self.confidence_threshold
                    and complexity_value < self._LOW_COMPLEXITY
                ),
            }

        return processed_features, exit_info

    def get_exit_decision(
        self,
        features: torch.Tensor,
        blur_map: Optional[torch.Tensor] = None,
        stage: int = 0
    ) -> Dict[str, Any]:
        """Get detailed exit decision information.

        Args:
            features: Feature tensor to score with the confidence and
                complexity heads.
            blur_map: Optional blur map; a mean absolute value above 0.5
                overrides a positive exit decision.
            stage: Index of the current processing stage; from stage 2 on,
                high confidence is sufficient regardless of complexity.

        Returns:
            ``{'should_exit': False, 'reason': 'early_exit_disabled'}`` when
            early exit is disabled. Otherwise a dict with ``'should_exit'``
            (bool), ``'reason'`` (str) and ``'factors'`` (the scalar inputs
            to the decision, including ``'blur_severity'`` when a blur map
            was given).
        """
        if not self.enable_early_exit:
            return {'should_exit': False, 'reason': 'early_exit_disabled'}

        # Scores are consumed as Python floats only; skip autograd bookkeeping.
        with torch.no_grad():
            confidence = self.confidence_head(features).mean().item()
            complexity = self.complexity_analyzer(features).mean().item()

        decision_factors = {
            'confidence': confidence,
            'complexity': complexity,
            'stage': stage,
            'confidence_threshold': self.confidence_threshold
        }

        # Decision logic
        if confidence > self.confidence_threshold:
            if complexity < self._LOW_COMPLEXITY:  # Low complexity
                decision = True
                reason = 'high_confidence_low_complexity'
            elif stage >= self._MIN_EXIT_STAGE:  # Allow exit after minimum stages
                decision = True
                reason = 'high_confidence_sufficient_stages'
            else:
                decision = False
                reason = 'insufficient_processing_stages'
        else:
            decision = False
            reason = 'low_confidence'

        # Include blur information if available
        if blur_map is not None:
            avg_blur = torch.mean(torch.abs(blur_map)).item()
            decision_factors['blur_severity'] = avg_blur

            # High blur vetoes an otherwise positive decision.
            if avg_blur > self._HIGH_BLUR and decision:
                decision = False
                reason = 'high_blur_severity'

        return {
            'should_exit': decision,
            'reason': reason,
            'factors': decision_factors
        }