import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
from typing import List, Tuple, Optional
from collections import OrderedDict

class SimpleGate(nn.Module):
    """Parameter-free gate that multiplies the two halves of dim 1 together."""

    def forward(self, x):
        """Return the element-wise product of the two chunks of ``x`` along dim 1.

        Args:
            x: Tensor that is split into two chunks along dimension 1.

        Returns:
            Tensor holding ``first_chunk * second_chunk``.
        """
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return torch.mul(first_half, second_half)

class LayerNorm2d(nn.Module):
    """Normalise a tensor over dim 1, then apply a learnable per-channel affine.

    Args:
        channels: Number of entries in the learnable ``weight`` / ``bias``.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        """Return ``weight * (x - mean) / sqrt(var + eps) + bias``.

        Mean and (biased) variance are taken over dim 1 independently for
        every remaining position.
        """
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / (variance + self.eps).sqrt()

        # Reshape the (C,) parameters to (C, 1, 1) so they broadcast over the
        # two trailing dimensions.
        scale = self.weight.view(-1, 1, 1)
        shift = self.bias.view(-1, 1, 1)
        return scale * normalized + shift

class NAFBlock(nn.Module):
    """Residual block with a gated depthwise-conv branch and a gated 1x1 branch.

    Both branches are scaled by learnable per-channel factors (``beta`` and
    ``gamma``) that start at zero, so a freshly constructed block returns its
    input unchanged.

    Args:
        c: Number of input/output channels.
        DW_Expand: Channel multiplier for the depthwise-conv branch.
        FFN_Expand: Channel multiplier for the second (1x1 conv) branch.
        drop_out_rate: Dropout probability; values ``<= 0`` disable dropout.
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()

        def pointwise(in_ch, out_ch):
            # 1x1 convolution used everywhere except the depthwise stage.
            return nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        def make_dropout():
            if drop_out_rate > 0.:
                return nn.Dropout(drop_out_rate)
            return nn.Identity()

        expanded = c * DW_Expand
        gated = expanded // 2  # the gate halves the channel count

        self.conv1 = pointwise(c, expanded)
        self.conv2 = nn.Conv2d(
            expanded, expanded, kernel_size=3, padding=1, stride=1, groups=expanded, bias=True
        )
        self.conv3 = pointwise(gated, c)

        # Channel attention: global average pool followed by a 1x1 conv.
        self.sca = nn.Sequential(nn.AdaptiveAvgPool2d(1), pointwise(gated, gated))

        self.sg = SimpleGate()

        ffn_width = FFN_Expand * c
        self.conv4 = pointwise(c, ffn_width)
        self.conv5 = pointwise(ffn_width // 2, c)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = make_dropout()
        self.dropout2 = make_dropout()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        """Apply both residual branches to ``inp`` and return the result."""
        # Branch 1: norm -> 1x1 -> depthwise 3x3 -> gate -> channel attention -> 1x1.
        spatial = self.conv2(self.conv1(self.norm1(inp)))
        spatial = self.sg(spatial)
        spatial = self.conv3(spatial * self.sca(spatial))
        y = inp + self.dropout1(spatial) * self.beta

        # Branch 2: norm -> 1x1 -> gate -> 1x1.
        ffn = self.sg(self.conv4(self.norm2(y)))
        ffn = self.dropout2(self.conv5(ffn))
        return y + ffn * self.gamma

class FeatureFusionBlock(nn.Module):
    """Fuse a decoder path with an optional skip feature, resize, and project.

    Args:
        features: Number of input channels.
        activation: Activation module handed to the residual conv units.
        deconv: Stored on the instance; not used by this block.
        bn: Whether the residual conv units use batch normalisation.
        expand: If true, the output projection halves the channel count.
        align_corners: Passed to the bilinear interpolation.
        size: Default resize target used when ``forward`` is called without
            ``size``. ``None`` (the default) means "do not resize".
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
        super().__init__()
        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features // 2 if self.expand else features

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

        # Previously accepted but dropped; kept so it can act as the default
        # resize target in forward().
        self.size = size

    def forward(self, *xs, size=None):
        """Fuse one or two feature maps.

        Args:
            *xs: ``xs[0]`` is the running decoder path. When exactly two
                tensors are given, ``xs[1]`` is refined by a residual unit and
                added to ``xs[0]``.
            size: Optional spatial size to resize to before the output
                projection. Falls back to the ``size`` given at construction.

        Returns:
            The fused (and possibly resized) feature map after the 1x1
            output convolution.
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            # Route the sum through the registered functional, as
            # ResidualConvUnit does for its own skip connection.
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        target_size = size if size is not None else self.size
        if target_size is not None:
            output = F.interpolate(output, size=target_size, mode="bilinear", align_corners=self.align_corners)

        return self.out_conv(output)

class ResidualConvUnit(nn.Module):
    """Two (activation -> 3x3 conv [-> batch norm]) stages plus a skip connection.

    Args:
        features: Number of input and output channels.
        activation: Module applied before each convolution.
        bn: When truthy, a ``BatchNorm2d`` follows each convolution.
    """

    def __init__(self, features, activation, bn):
        super().__init__()

        self.bn = bn
        self.groups = 1

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_kwargs)
        self.conv2 = nn.Conv2d(features, features, **conv_kwargs)

        if self.bn:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Return ``x`` plus the output of the two conv stages applied to ``x``."""
        out = x
        for conv, norm_name in ((self.conv1, "bn1"), (self.conv2, "bn2")):
            out = conv(self.activation(out))
            if self.bn:
                # The norm layers only exist when bn was requested.
                out = getattr(self, norm_name)(out)

        return self.skip_add.add(out, x)

class Interpolate(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    Args:
        scale_factor: Spatial scale multiplier passed to ``interpolate``.
        mode: Interpolation mode passed to ``interpolate``.
        align_corners: ``align_corners`` flag passed to ``interpolate``.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        super().__init__()
        self.interp = F.interpolate
        self.scale_factor, self.mode, self.align_corners = scale_factor, mode, align_corners

    def forward(self, x):
        """Resample ``x`` using the settings captured at construction."""
        options = {
            "scale_factor": self.scale_factor,
            "mode": self.mode,
            "align_corners": self.align_corners,
        }
        return self.interp(x, **options)

def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()
    
    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)

    return scratch

class DINOv2Encoder(nn.Module):
    """Vision-transformer encoder sized by a DINOv2-style model name.

    Images are split into ``patch_size`` x ``patch_size`` patches by a strided
    convolution, a class token is prepended, positional embeddings are added
    and the sequence is passed through a stack of ``TransformerBlock``s.

    Args:
        model_name: One of ``'vits'``, ``'vitb'``, ``'vitl'``, ``'vitg'``.
        img_size: Image side length the positional embedding is sized for.
        patch_size: Side length of one patch.

    Raises:
        ValueError: If ``model_name`` is not a supported size.
    """

    # model_name -> (embed_dim, num_heads, num_layers)
    _CONFIGS = {
        'vits': (384, 6, 12),
        'vitb': (768, 12, 12),
        'vitl': (1024, 16, 24),
        'vitg': (1536, 24, 40),
    }

    def __init__(self, model_name='vitl', img_size=518, patch_size=14):
        super().__init__()
        if model_name not in self._CONFIGS:
            # Fail here rather than with an AttributeError on embed_dim below.
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected one of {sorted(self._CONFIGS)}"
            )

        self.model_name = model_name
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim, self.num_heads, self.num_layers = self._CONFIGS[model_name]

        self.num_patches = (img_size // patch_size) ** 2

        self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, self.embed_dim))

        self.blocks = nn.ModuleList([
            TransformerBlock(self.embed_dim, self.num_heads)
            for _ in range(self.num_layers)
        ])

        self.norm = nn.LayerNorm(self.embed_dim)

    def _pos_embed_for_grid(self, grid_h, grid_w):
        """Return positional embeddings matching a ``grid_h`` x ``grid_w`` patch grid.

        The stored embedding is returned as-is when the grid matches the one
        implied by ``img_size``; otherwise the patch part is resampled
        bicubically and the class-token entry is kept unchanged.
        """
        base = self.img_size // self.patch_size
        if (grid_h, grid_w) == (base, base):
            return self.pos_embed

        cls_pos = self.pos_embed[:, :1]
        patch_pos = self.pos_embed[:, 1:]
        # (1, N, D) -> (1, D, base, base) so the grid can be resized as an image.
        patch_pos = patch_pos.reshape(1, base, base, self.embed_dim).permute(0, 3, 1, 2)
        patch_pos = F.interpolate(
            patch_pos.float(), size=(grid_h, grid_w), mode="bicubic", align_corners=False
        ).to(self.pos_embed.dtype)
        patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, self.embed_dim)
        return torch.cat((cls_pos, patch_pos), dim=1)

    def _prepare_tokens(self, x):
        """Patchify ``x``, prepend the class token and add positional embeddings."""
        B, C, H, W = x.shape
        x = self.patch_embed(x)
        grid_h, grid_w = x.shape[-2:]
        x = x.flatten(2).transpose(1, 2)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        return x + self._pos_embed_for_grid(grid_h, grid_w)

    def forward(self, x, return_class_token=False):
        """Encode a batch of images.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``.
            return_class_token: If true, also return the class-token feature.

        Returns:
            The normalised patch tokens ``(B, N, D)``, or the tuple
            ``(class_token, patch_tokens)`` when ``return_class_token`` is set.
        """
        x = self._prepare_tokens(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        if return_class_token:
            return x[:, 0], x[:, 1:]
        return x[:, 1:]

    def get_intermediate_layers(self, x, indices, return_class_token=False):
        """Return the outputs of the transformer blocks listed in ``indices``.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``.
            indices: Zero-based block indices whose outputs are collected.
            return_class_token: If true, each returned tensor keeps the class
                token as its first token; otherwise the class token is
                stripped and only patch tokens are returned.

        Returns:
            List of tensors in block order. The final ``norm`` layer is not
            applied to these outputs.
        """
        wanted = set(indices)
        x = self._prepare_tokens(x)

        features = []
        for i, block in enumerate(self.blocks):
            x = block(x)
            if i in wanted:
                features.append(x if return_class_token else x[:, 1:])

        return features

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention and MLP, each with a residual.

    Args:
        embed_dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim)
        )

    def forward(self, x):
        """Apply the block to a ``(batch, tokens, embed_dim)`` tensor."""
        # Normalise once and reuse the same tensor as query, key and value:
        # this avoids two redundant LayerNorm passes.
        normed = self.norm1(x)
        # The attention weights are never used, so skip computing them.
        x = x + self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x

class DPTHead(nn.Module):
    """Decode four token/feature maps into a single-channel dense map.

    Each input level is projected with a 1x1 conv, resized (x4, x2, x1, /2),
    reduced to ``features`` channels and merged coarse-to-fine through four
    ``FeatureFusionBlock``s, followed by a 3x3 output convolution.

    Args:
        in_channels: Channel (embedding) size of the incoming features.
        features: Channel count used inside the fusion path.
        use_bn: Whether the fusion blocks use batch normalisation.
        out_channels: Per-level channel counts after projection (four values).
        use_clstoken: Stored on the instance. No class-token readout is
            implemented here; a leading class token is simply discarded.
    """

    def __init__(self, in_channels, features=256, use_bn=False, out_channels=(256, 512, 1024, 1024), use_clstoken=False):
        super().__init__()

        self.use_clstoken = use_clstoken

        self.projects = nn.ModuleList([
            nn.Conv2d(in_channels, out_channel, kernel_size=1, stride=1, padding=0)
            for out_channel in out_channels
        ])

        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0),
            nn.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0),
            nn.Identity(),
            nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1)
        ])

        self.scratch = _make_scratch(out_channels, features, groups=1, expand=False)

        self.scratch.refinenet1 = FeatureFusionBlock(features, nn.ReLU(False), deconv=False, bn=use_bn, expand=False, align_corners=True)
        self.scratch.refinenet2 = FeatureFusionBlock(features, nn.ReLU(False), deconv=False, bn=use_bn, expand=False, align_corners=True)
        self.scratch.refinenet3 = FeatureFusionBlock(features, nn.ReLU(False), deconv=False, bn=use_bn, expand=False, align_corners=True)
        self.scratch.refinenet4 = FeatureFusionBlock(features, nn.ReLU(False), deconv=False, bn=use_bn, expand=False, align_corners=True)

        self.scratch.output_conv = nn.Conv2d(features, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, features, patch_h, patch_w):
        """Run the head.

        Args:
            features: Four tensors, either token sequences ``(B, N, C)`` or
                feature maps ``(B, C, patch_h, patch_w)``. A token sequence
                with ``patch_h * patch_w + 1`` tokens is assumed to carry a
                leading class token, which is dropped.
            patch_h: Height of the patch grid.
            patch_w: Width of the patch grid.

        Returns:
            Tensor with one channel at the resolution of the finest level
            (4x the patch grid).
        """
        num_patches = patch_h * patch_w

        out = []
        for feat, proj, resize in zip(features, self.projects, self.resize_layers):
            x = feat

            if x.dim() == 3:
                if x.shape[1] == num_patches + 1:
                    # Drop the class token so the sequence maps onto the grid.
                    x = x[:, 1:]
                batch, _, channels = x.shape
                x = x.permute(0, 2, 1).reshape(batch, channels, patch_h, patch_w)

            x = proj(x)
            x = resize(x)
            out.append(x)

        layer_1, layer_2, layer_3, layer_4 = out

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Coarse-to-fine fusion; each step is resized to the next finer level.
        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        return self.scratch.output_conv(path_1)

class ReversibleDecoder(nn.Module):
    """Pixel-shuffle upsampling decoder with additive skip connections.

    The decoder expects a bottleneck tensor with
    ``width * 2 ** len(dec_blk_nums)`` channels. Every stage doubles the
    spatial size and halves the channel count, so the last stage yields
    ``width`` channels, which ``ending`` maps to 3 output channels.

    Args:
        width: Channel count at the finest resolution.
        dec_blk_nums: Number of ``NAFBlock``s in each decoder stage, ordered
            coarse to fine.
    """

    def __init__(self, width=64, dec_blk_nums=(1, 1, 1, 1)):
        super().__init__()
        self.width = width

        self.ups = nn.ModuleList()
        self.decoders = nn.ModuleList()

        # Start from the bottleneck width so that, after one halving per
        # stage, the final stage produces exactly `width` channels for
        # `ending`. (Starting at 8 * width left width / 2 channels after four
        # stages, which never matched the ending convolution.)
        chan = width * 2 ** len(dec_blk_nums)
        for num_blocks in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    # 1x1 conv doubles channels; PixelShuffle(2) divides them
                    # by 4, giving a net halving with 2x spatial upsampling.
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num_blocks)]
                )
            )

        self.ending = nn.Conv2d(width, 3, 3, 1, 1)

    def forward(self, x, skip_connections):
        """Decode ``x`` using encoder skips.

        Args:
            x: Bottleneck tensor.
            skip_connections: Encoder features ordered fine to coarse; they
                are consumed in reverse. Stages without a matching skip are
                run without one.

        Returns:
            Tensor with 3 channels.
        """
        skips = skip_connections[::-1]

        for i, (up, decoder) in enumerate(zip(self.ups, self.decoders)):
            x = up(x)
            if i < len(skips):
                x = x + skips[i]
            x = decoder(x)

        return self.ending(x)

class AdaptivePatchExiting(nn.Module):
    """Small classifier: global average pool followed by a two-layer MLP.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channels, num_classes=2):
        super().__init__()
        hidden = in_channels // 4
        layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return one row of ``num_classes`` logits per sample in ``x``."""
        logits = self.classifier(x)
        return logits

class HybridAdaDepthModel(nn.Module):
    """Joint depth-estimation / deblurring network.

    Two optional branches share the input image:

    * depth: ``DINOv2Encoder`` intermediate features decoded by ``DPTHead``;
    * deblur: a NAFBlock encoder with per-stage exit classifiers, followed by
      ``ReversibleDecoder``.

    When both are enabled their outputs are concatenated and fused by two
    3x3 convolutions.

    Args:
        width: Base channel count of the deblur branch.
        img_channel: Number of input image channels for the deblur branch.
        out_channels: Channels of the fused output.
        enc_blk_nums: NAFBlocks per encoder stage.
        dec_blk_nums: NAFBlocks per decoder stage.
        depth_encoder: Model-size name passed to ``DINOv2Encoder``.
        depth_features: Feature width passed to ``DPTHead``.
        depth_out_channels: Per-level channels passed to ``DPTHead``.
        exit_threshold: Stored on the instance so callers can threshold the
            reported ``exit_decisions``; the forward pass itself always runs
            every encoder stage.
        enable_depth: Build and run the depth branch.
        enable_deblur: Build and run the deblur branch.
    """

    def __init__(
        self,
        width=64,
        img_channel=3,
        out_channels=3,
        enc_blk_nums=(1, 1, 1, 28),
        dec_blk_nums=(1, 1, 1, 1),
        depth_encoder='vitl',
        depth_features=256,
        depth_out_channels=(256, 512, 1024, 1024),
        exit_threshold=0.5,
        enable_depth=True,
        enable_deblur=True
    ):
        super().__init__()

        self.enable_depth = enable_depth
        self.enable_deblur = enable_deblur
        self.exit_threshold = exit_threshold

        self.intro = nn.Conv2d(img_channel, width, kernel_size=3, padding=1, stride=1, bias=True)

        # Zero-based transformer block indices tapped for the depth head.
        self.intermediate_layer_idx = {
            'vits': [2, 5, 8, 11],
            'vitb': [2, 5, 8, 11],
            'vitl': [4, 11, 17, 23],
            'vitg': [9, 19, 29, 39]
        }

        if self.enable_depth:
            self.depth_encoder = DINOv2Encoder(model_name=depth_encoder)
            self.depth_head = DPTHead(
                self.depth_encoder.embed_dim,
                depth_features,
                use_bn=False,
                out_channels=depth_out_channels,
                use_clstoken=False
            )

        if self.enable_deblur:
            self.encoders = nn.ModuleList()
            self.downs = nn.ModuleList()

            chan = width
            for num_blocks in enc_blk_nums:
                self.encoders.append(
                    nn.Sequential(*[NAFBlock(chan) for _ in range(num_blocks)])
                )
                self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
                chan = chan * 2

            self.middle_blks = nn.Sequential(NAFBlock(chan))

            self.decoder = ReversibleDecoder(width, dec_blk_nums)

            # Exit head i sees the output of encoder stage i, which has
            # width * 2**i channels.
            self.ape = nn.ModuleList([
                AdaptivePatchExiting(width * 2 ** i)
                for i in range(len(enc_blk_nums))
            ])

        if self.enable_depth and self.enable_deblur:
            # Fusion consumes the decoder's image output concatenated with
            # the single-channel depth map.
            fusion_in = self.decoder.ending.out_channels + 1
        else:
            # Not used by forward() in these modes; original shape retained.
            fusion_in = width + (1 if self.enable_depth else 0)
        self.fusion_conv = nn.Conv2d(fusion_in, width, 3, 1, 1)
        self.final_conv = nn.Conv2d(width, out_channels, 3, 1, 1)

    def forward(self, x):
        """Run the enabled branches on an image batch.

        Args:
            x: Image batch of shape ``(B, C, H, W)``.

        Returns:
            Dict with keys:
              * ``'output'``: fused output, or the single enabled branch's
                output, or ``x`` itself when both branches are disabled;
              * ``'depth'``: depth map resized to ``(H, W)`` or ``None``;
              * ``'deblurred'``: deblur branch output or ``None``;
              * ``'exit_decisions'``: list with one per-sample exit
                probability tensor for every encoder stage.
        """
        B, C, H, W = x.shape

        depth_output = None
        deblur_output = None
        exit_decisions = []

        if self.enable_depth:
            patch = self.depth_encoder.patch_size
            patch_h, patch_w = H // patch, W // patch
            # The head reshapes tokens onto the patch grid, so request patch
            # tokens only (no class token).
            features = self.depth_encoder.get_intermediate_layers(
                x, self.intermediate_layer_idx[self.depth_encoder.model_name], return_class_token=False
            )
            depth_output = self.depth_head(features, patch_h, patch_w)
            depth_output = F.relu(depth_output)
            depth_output = F.interpolate(depth_output, (H, W), mode="bilinear", align_corners=True)

        if self.enable_deblur:
            # Pad so every stride-2 downsample divides evenly; otherwise the
            # upsampled decoder features would not line up with the skips.
            multiple = 2 ** len(self.downs)
            pad_h = (multiple - H % multiple) % multiple
            pad_w = (multiple - W % multiple) % multiple
            x_deblur = self.intro(F.pad(x, (0, pad_w, 0, pad_h)))

            encs = []
            for encoder, down, exit_head in zip(self.encoders, self.downs, self.ape):
                x_deblur = encoder(x_deblur)
                encs.append(x_deblur)

                exit_prob = torch.softmax(exit_head(x_deblur), dim=1)[:, 1]
                exit_decisions.append(exit_prob)

                # No early break: the bottleneck and decoder are built for
                # the full encoder depth, so stopping here would hand
                # middle_blks a tensor with the wrong channel count.
                x_deblur = down(x_deblur)

            x_deblur = self.middle_blks(x_deblur)
            # Crop the padding back off.
            deblur_output = self.decoder(x_deblur, encs)[:, :, :H, :W]

        if self.enable_depth and self.enable_deblur:
            combined_features = torch.cat([deblur_output, depth_output], dim=1)
            fused = self.fusion_conv(combined_features)
            output = self.final_conv(fused)
        elif self.enable_depth:
            output = depth_output
        elif self.enable_deblur:
            output = deblur_output
        else:
            output = x

        results = {
            'output': output,
            'depth': depth_output,
            'deblurred': deblur_output,
            'exit_decisions': exit_decisions
        }

        return results