import torch
import torch.nn as nn
import torch.nn.functional as F
from hybrid_ada_depth_model import (
    NAFBlock, LayerNorm2d, DPTHead, ReversibleDecoder
)
from utils import resize_image
import math

class DINOv2EncoderFixed(nn.Module):
    """ViT-style encoder that exposes intermediate patch-token feature maps.

    The image is cut into ``patch_size`` x ``patch_size`` patches by a strided
    convolution, a class token is prepended, a learnable position embedding is
    added, and the sequence is run through a stack of ``TransformerBlock``s.
    All weights are initialised here; this class does not load a checkpoint.

    Args:
        model_name: One of ``'vits'``, ``'vitb'``, ``'vitl'``, ``'vitg'``;
            selects embedding width, head count and depth.
        patch_size: Side length of a square patch in pixels.
        img_size: Reference image side used to size the stored position
            embedding grid (``img_size // patch_size`` per side; the default
            518 with patch size 14 gives a 37 x 37 grid). Other input sizes
            are served by interpolating this grid.

    Raises:
        ValueError: If ``model_name`` is not a supported variant.
    """

    # (embed_dim, num_heads, num_layers) for every supported variant.
    _CONFIGS = {
        'vits': (384, 6, 12),
        'vitb': (768, 12, 12),
        'vitl': (1024, 16, 24),
        'vitg': (1536, 24, 40),
    }

    def __init__(self, model_name='vitl', patch_size=14, img_size=518):
        super().__init__()
        if model_name not in self._CONFIGS:
            # Fail early and clearly instead of hitting a missing-attribute
            # error a few lines further down.
            raise ValueError(
                f"Unknown model_name {model_name!r}; "
                f"expected one of {sorted(self._CONFIGS)}"
            )
        self.model_name = model_name
        self.patch_size = patch_size
        self.embed_dim, self.num_heads, self.num_layers = self._CONFIGS[model_name]

        self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        # A single registered position embedding (class slot + base grid).
        # Being a module parameter it is trained, saved in the state dict and
        # moved together with the module by ``.to()`` / ``.cuda()``.
        grid = max(img_size // patch_size, 1)
        self.base_grid = (grid, grid)
        self.pos_embed = nn.Parameter(torch.zeros(1, grid * grid + 1, self.embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.blocks = nn.ModuleList([
            TransformerBlock(self.embed_dim, self.num_heads)
            for _ in range(self.num_layers)
        ])

        self.norm = nn.LayerNorm(self.embed_dim)

    def interpolate_pos_embed(self, x):
        """Return the position embedding resized to the patch grid of ``x``.

        Args:
            x: Image batch of shape ``(B, C, H, W)``; only ``H``, ``W`` and
                the device are used.

        Returns:
            Tensor of shape ``(1, 1 + (H // patch_size) * (W // patch_size),
            embed_dim)`` on ``x``'s device. The class-token slot is copied
            unchanged; the patch slots are bicubically interpolated from the
            stored base grid when the grid sizes differ.
        """
        B, C, H, W = x.shape
        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        cls_pos = self.pos_embed[:, :1]
        patch_pos = self.pos_embed[:, 1:]
        base_h, base_w = self.base_grid

        if (patch_h, patch_w) != (base_h, base_w):
            # (1, N, D) -> (1, D, base_h, base_w) so the grid can be resized
            # as an image, then back to a token sequence.
            patch_pos = patch_pos.reshape(1, base_h, base_w, self.embed_dim).permute(0, 3, 1, 2)
            patch_pos = F.interpolate(
                patch_pos, size=(patch_h, patch_w), mode='bicubic', align_corners=False
            )
            patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, patch_h * patch_w, self.embed_dim)

        return torch.cat((cls_pos, patch_pos), dim=1).to(x.device)

    def get_intermediate_layers(self, x, indices, return_class_token=False):
        """Run the encoder and collect patch-token maps after selected blocks.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``.
            indices: Container of zero-based block indices whose outputs
                should be returned.
            return_class_token: Accepted for call compatibility; currently
                ignored, only patch-token maps are returned.

        Returns:
            List of tensors of shape ``(B, embed_dim, H // patch_size,
            W // patch_size)``, one per matching block, in block order.
        """
        B, C, H, W = x.shape
        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        # Resolve the position embedding from the real input so it lives on
        # the same device as the tokens (no throw-away dummy tensor needed).
        pos_embed = self.interpolate_pos_embed(x)

        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed

        features = []
        for i, block in enumerate(self.blocks):
            x = block(x)
            if i in indices:
                # Drop the class token and fold the sequence back into a map.
                patch_tokens = x[:, 1:].reshape(B, patch_h, patch_w, -1).permute(0, 3, 1, 2)
                features.append(patch_tokens)

        return features

class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: residual self-attention, then residual MLP.

    Args:
        embed_dim: Token embedding width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim)
        )

    def forward(self, x):
        """Apply the block to a ``(batch, tokens, embed_dim)`` sequence.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Normalise once and reuse the result for query, key and value; the
        # three separate calls produced identical tensors at triple the cost.
        normed = self.norm1(x)
        # The attention map is never used, so skip materialising it.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

class SplitTransformMerge(nn.Module):
    """Channel-expanding bottleneck built from three convolutions.

    A 1x1 convolution doubles the channel count, a depthwise 3x3 convolution
    wrapped in GELUs processes the widened tensor, and a final 1x1
    convolution projects back to ``channels``.
    """

    def __init__(self, channels):
        super().__init__()
        wide = channels * 2
        self.split = nn.Conv2d(channels, wide, kernel_size=1, stride=1, padding=0)
        self.transform = nn.Sequential(
            nn.GELU(),
            # groups == channel count, i.e. one 3x3 filter per channel.
            nn.Conv2d(wide, wide, kernel_size=3, stride=1, padding=1, groups=wide),
            nn.GELU(),
        )
        self.merge = nn.Conv2d(wide, channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return a tensor with the same shape as the ``(B, channels, H, W)`` input."""
        widened = self.split(x)
        processed = self.transform(widened)
        return self.merge(processed)

class FourierExpert(nn.Module):
    """Mix channels in the 2-D frequency domain with a learned complex matrix."""

    def __init__(self, channels):
        super().__init__()
        # Complex (channels x channels) mixing matrix, scaled to small values.
        init = torch.randn(channels, channels, dtype=torch.cfloat)
        self.weight = nn.Parameter(init * 0.02)

    def forward(self, x):
        """Transform ``x`` of shape ``(B, C, H, W)``; the output has the same shape."""
        _, _, height, width = x.shape
        spectrum = torch.fft.rfft2(x, norm='ortho')
        # Same mixing matrix at every frequency bin: contract the channel axis.
        mixed = torch.einsum('bchw,cd->bdhw', spectrum, self.weight)
        # Passing the original size restores odd widths exactly.
        return torch.fft.irfft2(mixed, s=(height, width), norm='ortho')

class WaveletExpert(nn.Module):
    """Learned analysis / synthesis pair built from stride-2 convolutions.

    A 2x2 stride-2 convolution maps ``channels`` to ``4 * channels`` at half
    resolution, a depthwise 3x3 convolution filters those bands, and a 2x2
    stride-2 transposed convolution maps back to ``channels`` at the input
    resolution.
    """

    def __init__(self, channels):
        super().__init__()
        self.dwt = nn.Conv2d(channels, channels * 4, 2, 2, bias=False)
        self.conv = nn.Conv2d(channels * 4, channels * 4, 3, 1, 1, groups=channels * 4)
        self.idwt = nn.ConvTranspose2d(channels * 4, channels, 2, 2, bias=False)

    def forward(self, x):
        """Return a tensor with exactly the spatial size of ``x``.

        Args:
            x: Tensor of shape ``(B, channels, H, W)``; ``H`` and ``W`` may
                be odd.
        """
        height, width = x.shape[-2:]
        pad_h, pad_w = height % 2, width % 2
        needs_pad = bool(pad_h or pad_w)
        if needs_pad:
            # The stride-2 pair only round-trips even sizes; without padding
            # an odd side comes back one pixel short and can no longer be
            # concatenated with the other branches.
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        y = self.dwt(x)
        y = self.conv(y)
        y = self.idwt(y)

        if needs_pad:
            y = y[..., :height, :width]
        return y

class BlurGatedMixer(nn.Module):
    """Fuse three feature branches using a gate derived from the blur map."""

    def __init__(self, channels):
        super().__init__()
        # 1x1 conv that folds the three concatenated branches back to `channels`.
        self.conv = nn.Conv2d(channels * 3, channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_core, x_fft, x_wav, blur_map):
        """Blend the branches and project them back to ``channels``.

        The gate is the sigmoid of the spatially averaged ``blur_map`` and is
        broadcast over the feature maps: ``x_core`` is scaled by ``1 - gate``,
        ``x_fft`` by ``gate``, and ``x_wav`` is passed through unscaled.
        """
        pooled_blur = F.adaptive_avg_pool2d(blur_map, 1)
        gate = torch.sigmoid(pooled_blur)
        branches = (x_core * (1 - gate), x_fft * gate, x_wav)
        stacked = torch.cat(branches, dim=1)
        return self.conv(stacked)

class BlurGuidedCrossAttention(nn.Module):
    """Multi-head spatial attention whose Q/K/V are computed from features plus blur.

    The blur map is resized to the feature resolution and appended as one
    extra input channel before the 1x1 Q/K/V projection. Attention is taken
    across all ``H * W`` positions.
    """

    def __init__(self, channels, heads=4):
        super().__init__()
        self.heads = heads
        head_dim = channels // heads
        self.scale = head_dim ** -0.5
        # "+ 1" input channel carries the blur map.
        self.qkv = nn.Conv2d(channels + 1, channels * 3, kernel_size=1, stride=1, padding=0)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, blur_map):
        """Attend over spatial positions of ``x`` (``(B, C, H, W)``); shape is preserved."""
        B, C, H, W = x.shape
        head_dim = C // self.heads
        tokens = H * W

        blur = F.interpolate(blur_map, size=(H, W), mode='bilinear', align_corners=False)
        projected = self.qkv(torch.cat([x, blur], dim=1))

        # Each of q, k, v: (B, heads, head_dim, tokens).
        q, k, v = projected.reshape(B, 3, self.heads, head_dim, tokens).unbind(dim=1)

        # (B, heads, tokens, tokens) similarity between every pair of positions.
        scores = (q.transpose(-2, -1) @ k) * self.scale
        weights = scores.softmax(dim=-1)

        # (B, heads, tokens, head_dim) -> (B, heads, head_dim, tokens) -> image layout.
        context = weights @ v.transpose(-2, -1)
        out = context.transpose(-2, -1).reshape(B, C, H, W)
        return self.proj(out)

class BlurBasedEarlyExit(nn.Module):
    """Per-sample early-exit decision from blur level and a small classifier.

    Args:
        channels: Channel count of the feature map fed to the classifier.
        threshold: A sample may exit only if its mean blur value is strictly
            below this.
        confidence_threshold: A sample may exit only if the classifier's
            probability for class 1 is strictly above this.
    """

    def __init__(self, channels, threshold=0.3, confidence_threshold=0.7):
        super().__init__()
        self.threshold = threshold
        self.confidence_threshold = confidence_threshold
        self.exit_classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 4, 1),
            nn.ReLU(),
            nn.Conv2d(channels // 4, 2, 1),
            nn.Flatten()
        )

    def forward(self, x, blur_map):
        """Decide, per sample, whether processing may stop here.

        Args:
            x: Feature map of shape ``(B, channels, H, W)``.
            blur_map: Blur map of shape ``(B, C_blur, H', W')``.

        Returns:
            Tuple ``(should_exit, exit_probs)`` where ``should_exit`` is a
            boolean tensor of shape ``(B,)`` and ``exit_probs`` is the
            softmax output of shape ``(B, 2)``.
        """
        # Mean blur per sample. Reducing explicitly (instead of a bare
        # squeeze) keeps the batch dimension for B == 1 and collapses any
        # extra blur channels, so the result is always shape (B,).
        avg_blur = F.adaptive_avg_pool2d(blur_map, 1).flatten(1).mean(dim=1)

        exit_logits = self.exit_classifier(x)
        exit_probs = F.softmax(exit_logits, dim=1)

        # Exit only when blur is low AND the classifier is confident.
        low_blur = avg_blur < self.threshold
        confident = exit_probs[:, 1] > self.confidence_threshold
        should_exit = low_blur & confident

        return should_exit, exit_probs

class AdaptiveDRAMBlock(nn.Module):
    """Residual restoration block with an optional early-exit head.

    The input goes through ``SplitTransformMerge``; the result feeds a
    Fourier and a wavelet expert, the three are fused by ``BlurGatedMixer``,
    refined by ``BlurGuidedCrossAttention``, added back to the input and
    normalised with ``LayerNorm2d``.

    Args:
        channels: Channel count of the feature map.
        enable_early_exit: Whether to build the early-exit head.
        exit_threshold: Blur threshold forwarded to ``BlurBasedEarlyExit``.
    """

    def __init__(self, channels, enable_early_exit=True, exit_threshold=0.3):
        super().__init__()
        self.enable_early_exit = enable_early_exit
        self.stm = SplitTransformMerge(channels)
        self.fft_expert = FourierExpert(channels)
        self.wav_expert = WaveletExpert(channels)
        self.mixer = BlurGatedMixer(channels)
        self.attn = BlurGuidedCrossAttention(channels)
        self.norm = LayerNorm2d(channels)

        if enable_early_exit:
            self.early_exit = BlurBasedEarlyExit(channels, exit_threshold)

    def forward(self, x, blur_map, return_exit_info=False):
        """Process ``x`` and optionally report the early-exit decision.

        Args:
            x: Feature map of shape ``(B, channels, H, W)``.
            blur_map: Blur map for the same batch.
            return_exit_info: If true, return ``(output, exit_info)``.

        Returns:
            ``output`` alone, or ``(output, exit_info)``. ``exit_info`` is
            ``None`` when early exit is disabled, otherwise a dict with
            ``'should_exit'``, ``'exit_probs'`` and ``'avg_blur'`` (mean blur
            per sample, shape ``(B,)``).
        """
        core = self.stm(x)
        fft_out = self.fft_expert(core)
        wav_out = self.wav_expert(core)
        mixed = self.mixer(core, fft_out, wav_out, blur_map)
        attn_out = self.attn(mixed, blur_map)
        output = self.norm(x + attn_out)

        if not return_exit_info:
            # Nobody will look at the exit decision, so do not spend a
            # classifier forward pass computing it.
            return output

        exit_info = None
        if self.enable_early_exit:
            should_exit, exit_probs = self.early_exit(output, blur_map)
            exit_info = {
                'should_exit': should_exit,
                'exit_probs': exit_probs,
                # Explicit per-sample reduction keeps the batch dimension
                # even when B == 1 (a bare squeeze would drop it).
                'avg_blur': F.adaptive_avg_pool2d(blur_map, 1).flatten(1).mean(dim=1),
            }

        return output, exit_info

class BMEBlock(nn.Module):
    """Reduce a feature map to a single-channel map.

    Pipeline: depthwise 3x3 convolution -> GELU -> 1x1 convolution to one
    channel. Spatial size is preserved.
    """

    def __init__(self, in_channels):
        super().__init__()
        depthwise = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=3, stride=1, padding=1, groups=in_channels,
        )
        to_single_channel = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.block = nn.Sequential(depthwise, nn.GELU(), to_single_channel)

    def forward(self, x):
        """Map ``(B, in_channels, H, W)`` to ``(B, 1, H, W)``."""
        out = self.block(x)
        return out

class BlurMapEstimation(nn.Module):
    """Fuse per-scale single-channel maps into one blur map.

    One ``BMEBlock`` is created per entry of ``encoder_channels``. Their
    outputs are resized to the resolution of the first feature map,
    concatenated, and merged by a 1x1 convolution.
    """

    def __init__(self, encoder_channels=(64,128,256,512)):
        super().__init__()
        self.blocks = nn.ModuleList([BMEBlock(width) for width in encoder_channels])
        num_scales = len(encoder_channels)
        self.fuse = nn.Conv2d(num_scales, 1, kernel_size=1)

    def forward(self, features):
        """Return a ``(B, 1, H0, W0)`` map, ``H0 x W0`` being ``features[0]``'s size.

        Args:
            features: Sequence of feature maps; entry ``i`` is processed by
                ``self.blocks[i]``.
        """
        target_size = features[0].shape[-2:]
        per_scale = []
        for idx, feat in enumerate(features):
            single = self.blocks[idx](feat)
            resized = F.interpolate(
                single, size=target_size, mode='bilinear', align_corners=False
            )
            per_scale.append(resized)
        stacked = torch.cat(per_scale, dim=1)
        return self.fuse(stacked)

class SimpleDecoder(nn.Module):
    """Three-stage upsampling decoder ending in a 3-channel convolution.

    Channel widths go ``8w -> 4w -> 2w -> w``. Every stage doubles the
    resolution with a transposed convolution, optionally adds a skip
    connection, and refines the result with a ``NAFBlock``.
    """

    def __init__(self, width=64):
        super().__init__()
        stage_widths = [width * factor for factor in (8, 4, 2, 1)]

        self.ups = nn.ModuleList()
        self.decoders = nn.ModuleList()

        for c_in, c_out in zip(stage_widths[:-1], stage_widths[1:]):
            self.ups.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2))
            self.decoders.append(NAFBlock(c_out))

        self.final_conv = nn.Conv2d(width, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x, skip_connections):
        """Decode ``x`` using skip connections taken from the end of the list.

        Args:
            x: Deepest feature map.
            skip_connections: Sequence of skip tensors; stage ``k`` adds the
                ``k``-th element counted from the end. Stages beyond the
                length of the sequence receive no skip.

        Returns:
            Output of the final 3-channel convolution.
        """
        available = len(skip_connections)
        for stage, (upsample, refine) in enumerate(zip(self.ups, self.decoders)):
            x = upsample(x)
            if stage < available:
                # Walk the skips back to front: deepest skip first.
                x = x + skip_connections[available - 1 - stage]
            x = refine(x)

        return self.final_conv(x)

class DRAMNet(nn.Module):
    """Restoration network with a depth branch and optional early exit.

    A ViT encoder plus ``DPTHead`` predicts a depth map. In parallel, a
    convolutional encoder produces multi-scale features from which a blur map
    is estimated; a stack of ``AdaptiveDRAMBlock``s refines the deepest
    features, and a ``SimpleDecoder`` produces the 3-channel output.

    Args:
        width: Base channel width of the convolutional encoder/decoder.
        depth_encoder: ViT variant name passed to ``DINOv2EncoderFixed``.
        num_dram_blocks: Number of ``AdaptiveDRAMBlock``s.
        enable_early_exit: Whether to build exit heads and per-stage decoders.
        exit_threshold: Blur threshold passed to each block's exit head.
    """

    # Transformer blocks tapped for the depth head. The indices must exist
    # in the chosen backbone, so they depend on its depth (12/12/24/40
    # layers); unknown names fall back to the 24-layer set.
    _DEFAULT_DEPTH_LAYERS = [4, 11, 17, 23]
    _DEPTH_LAYER_INDICES = {
        'vits': [2, 5, 8, 11],
        'vitb': [2, 5, 8, 11],
        'vitl': [4, 11, 17, 23],
        'vitg': [9, 19, 29, 39],
    }

    def __init__(self, width=64, depth_encoder='vitl', num_dram_blocks=4, enable_early_exit=True, exit_threshold=0.3):
        super().__init__()
        self.enable_early_exit = enable_early_exit
        self.num_dram_blocks = num_dram_blocks

        self.encoder = DINOv2EncoderFixed(model_name=depth_encoder)
        self.depth_layer_indices = list(
            self._DEPTH_LAYER_INDICES.get(depth_encoder, self._DEFAULT_DEPTH_LAYERS)
        )
        self.depth_head = DPTHead(self.encoder.embed_dim)
        self.initial_conv = nn.Conv2d(3, width, 3, 1, 1)
        encoder_channels = [width, width*2, width*4, width*8]
        self.enc_blocks = nn.ModuleList([nn.Sequential(*[NAFBlock(c) for _ in range(1)]) for c in encoder_channels])
        self.downs = nn.ModuleList([nn.Conv2d(encoder_channels[i], encoder_channels[i]*2, 2,2) for i in range(3)])

        # Adaptive DRAM blocks with early exit
        self.dram_blocks = nn.ModuleList([
            AdaptiveDRAMBlock(encoder_channels[-1], enable_early_exit, exit_threshold)
            for _ in range(num_dram_blocks)
        ])

        # One lightweight decoder per stage so any stage can emit an image.
        if enable_early_exit:
            self.early_decoders = nn.ModuleList([
                SimpleDecoder(width) for _ in range(num_dram_blocks)
            ])

        self.decoder = SimpleDecoder(width)
        self.bme = BlurMapEstimation(encoder_channels)

    def forward(self, x, return_exit_info=False):
        """Run the network on an image batch.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``.
            return_exit_info: If true (and early exit is enabled), exit
                decisions are evaluated after every block. In eval mode the
                forward pass stops at the first block where every sample in
                the batch votes to exit.

        Returns:
            Dict with ``'output'``, ``'depth'`` and ``'blur_map'``. When
            ``return_exit_info`` is true it also holds ``'exit_stage'``
            (index of the exiting block, or ``len(dram_blocks)`` if none),
            ``'exit_decisions'`` and ``'early_outputs'`` (per visited block:
            a decoded image, or ``None`` if no sample wanted to exit).
        """
        # Depth branch. Use the encoder's own patch size so the grid passed
        # to the head always matches the token maps the encoder produced.
        depth_feat = self.encoder.get_intermediate_layers(
            x, self.depth_layer_indices, return_class_token=True
        )
        patch = self.encoder.patch_size
        patch_h, patch_w = x.shape[-2] // patch, x.shape[-1] // patch
        depth_map = self.depth_head(depth_feat, patch_h, patch_w)
        depth_map = F.interpolate(depth_map, size=x.shape[-2:], mode='bilinear', align_corners=False)

        # Convolutional encoder; keep every scale for skips and blur estimation.
        h = self.initial_conv(x)
        feats = []
        for enc, down in zip(self.enc_blocks[:-1], self.downs):
            h = enc(h)
            feats.append(h)
            h = down(h)
        h = self.enc_blocks[-1](h)
        feats.append(h)

        blur_map = self.bme(feats)

        # Adaptive DRAM processing with early exit
        exit_decisions = []
        early_outputs = []
        track_exits = self.enable_early_exit and return_exit_info

        for i, dram_block in enumerate(self.dram_blocks):
            if not track_exits:
                h = dram_block(h, blur_map)
                continue

            h, exit_info = dram_block(h, blur_map, return_exit_info=True)
            exit_decisions.append(exit_info)
            should_exit = exit_info['should_exit']

            # Decode an early image only if at least one sample might exit.
            # Resetting to None each iteration guarantees the exit branch
            # below never sees an unbound or stale image from an earlier stage.
            early_out = None
            if should_exit.any():
                early_out = self.early_decoders[i](h, feats[:-1])
            early_outputs.append(early_out)

            # Actually exit during inference (not training), and only when
            # the whole batch agrees.
            if not self.training and early_out is not None and should_exit.all():
                return {
                    'output': early_out,
                    'depth': depth_map,
                    'blur_map': blur_map,
                    'exit_stage': i,
                    'exit_decisions': exit_decisions,
                    'early_outputs': early_outputs
                }

        # Final decoder
        final_output = self.decoder(h, feats[:-1])

        result = {
            'output': final_output,
            'depth': depth_map,
            'blur_map': blur_map
        }

        if return_exit_info:
            result.update({
                'exit_stage': len(self.dram_blocks),  # No early exit
                'exit_decisions': exit_decisions,
                'early_outputs': early_outputs
            })

        return result

# Legacy alias for backward compatibility: `DRAMBlock` is bound to the same
# class object as `AdaptiveDRAMBlock`, so code written against the old name
# keeps working.
DRAMBlock = AdaptiveDRAMBlock 