from ctypes import sizeof
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import res_net
from models import repvgg as RepVGG
from models.convnext import LayerNorm,Block
import numpy as np
import copy
class md_resnet18(nn.Module):
    """Multi-scale feature extractor built on a ConvNeXt-style backbone.

    Despite the class name, the backbone is not a ResNet. It is a 4-stage
    stack of ``models.convnext.Block`` modules with depths 3/3/9/3 and widths
    96/192/384/768. It has a 4x4 stride-4 conv stem, and LayerNorm followed by
    a 2x2 stride-2 conv between stages.

    ``forward`` fuses three tensors at the spatial size of the stage-2 output:

    * the stage-2 features (192 ch);
    * the stage-4 features, bilinearly resized (768 ch);
    * the global average of the stage-4 features, replicated spatially (768 ch).

    A 1x1 conv + BatchNorm then projects the concatenation to 512 channels.

    Args:
        in_channel: number of input channels consumed by the stem conv.
        strides: accepted for signature compatibility with the other
            extractors in this module but NOT used here. The downsampling
            schedule is fixed by the layers built below.
    """

    def __init__(self, in_channel=3, strides=(2, 2, 1)):
        super(md_resnet18, self).__init__()

        depths = [3, 3, 9, 3]
        dims = [96, 192, 384, 768]
        drop_path_rate = 0.1
        layer_scale_init_value = 1e-6

        # Stem (4x4 conv, stride 4) followed by three
        # LayerNorm + 2x2 stride-2 conv downsamplers.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_channel, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # One drop_path value per block, linearly spaced from 0 to
        # drop_path_rate across all blocks of all stages.
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[
                    Block(
                        dim=dims[i],
                        drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        # NOTE(review): `norm` and `head` are never used by forward(). They
        # are still part of parameters()/state_dict(), so removing them would
        # break strict loading of existing checkpoints. Presumably that is
        # why they remain.
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.head = nn.Sequential(
            nn.Conv2d(dims[-1], 512, kernel_size=1),
            nn.BatchNorm2d(512),
        )
        # Input = stage-2 features + resized stage-4 features + replicated
        # global stage-4 descriptor (192 + 768 + 768 = 1728 channels).
        self.embedding = nn.Sequential(
            nn.Conv2d(dims[1] + dims[3] + dims[3], 512, kernel_size=1),
            nn.BatchNorm2d(512),
        )

    def forward(self, x):
        """Return the fused 512-channel feature map.

        Args:
            x: input batch. Assumed to be (N, in_channel, H, W); TODO confirm.

        Returns:
            Output of ``self.embedding``, at the spatial size of the stage-2
            output (after the /4 stem and one /2 downsampler).
        """
        l1 = self.downsample_layers[0](x)
        l1 = self.stages[0](l1)
        l2 = self.downsample_layers[1](l1)
        l2 = self.stages[1](l2)
        l3 = self.downsample_layers[2](l2)
        l3 = self.stages[2](l3)
        l4 = self.downsample_layers[3](l3)
        l4 = self.stages[3](l4)

        size = l2.shape[-2:]
        # Global 1x1 descriptor, taken before l4 is resized.
        lg = F.adaptive_avg_pool2d(l4, (1, 1))
        # align_corners=False is PyTorch's default for bilinear. It is passed
        # explicitly to make the choice visible.
        l4 = F.interpolate(l4, size, mode='bilinear', align_corners=False)
        # Adaptive pooling from 1x1 to `size` replicates the global vector
        # at every spatial position.
        lg = F.adaptive_avg_pool2d(lg, size)
        return self.embedding(torch.cat([l2, l4, lg], dim=1))



class md_resnet34(nn.Module):
    """ResNet-34-style multi-scale feature extractor.

    Layout:

    * a 3x3 stride-1 stem conv + BN + ReLU;
    * a 3x3 stride-2 average pool;
    * four stages of ``res_net.BasicBlock`` with 3/4/6/3 blocks and
      64/128/256/512 channels.

    ``forward`` fuses three tensors at the spatial size of the stage-2 output:

    * the stage-2 features (128 ch);
    * the stage-4 features, bilinearly resized (512 ch);
    * the global average of the stage-4 features, replicated spatially (512 ch).

    A 1x1 conv + BatchNorm projects the 1152-channel concatenation to 512
    channels.

    Args:
        in_channel: number of input channels consumed by ``conv1``.
        strides: strides of the first block of stages 2, 3 and 4, and of the
            matching ``downsample`` convs. With the default ``(2, 2, 1)``,
            stage 4 does not reduce resolution further.
    """

    def __init__(self, in_channel=3, strides=(2, 2, 1)):
        super(md_resnet34, self).__init__()

        # NOTE: module creation order is intentionally unchanged. It
        # determines seeded weight initialisation and parameters() ordering.
        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        # Passed to the first BasicBlock of stages 2-4 as `downsample`. They
        # are also registered directly on this module. If BasicBlock stores
        # them as submodules (presumably), the same parameters appear under
        # two state_dict prefixes. Kept as-is for checkpoint compatibility.
        self.downsample2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=strides[0], padding=1, bias=False),
            nn.BatchNorm2d(128),
        )
        self.downsample3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=strides[1], padding=1, bias=False),
            nn.BatchNorm2d(256),
        )
        self.downsample4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=strides[2], padding=1, bias=False),
            nn.BatchNorm2d(512),
        )
        # Input = stage-2 (128) + resized stage-4 (512) + global stage-4 (512).
        self.embedding = nn.Sequential(
            nn.Conv2d(128 + 512 + 512, 512, kernel_size=1),
            nn.BatchNorm2d(512),
        )
        self.block1 = nn.Sequential(*[res_net.BasicBlock(64, 64) for _ in range(3)])
        self.block2 = nn.Sequential(
            res_net.BasicBlock(64, 128, stride=strides[0], downsample=self.downsample2),
            *[res_net.BasicBlock(128, 128) for _ in range(3)]
        )
        self.block3 = nn.Sequential(
            res_net.BasicBlock(128, 256, stride=strides[1], downsample=self.downsample3),
            *[res_net.BasicBlock(256, 256) for _ in range(5)]
        )
        self.block4 = nn.Sequential(
            res_net.BasicBlock(256, 512, stride=strides[2], downsample=self.downsample4),
            *[res_net.BasicBlock(512, 512) for _ in range(2)]
        )

    def forward(self, x):
        """Return the fused 512-channel feature map at stage-2 resolution.

        Args:
            x: input batch. Assumed to be (N, in_channel, H, W); TODO confirm.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        l1 = self.block1(x)
        l2 = self.block2(l1)
        l3 = self.block3(l2)
        l4 = self.block4(l3)

        size = l2.shape[-2:]
        # Global 1x1 descriptor, taken before l4 is resized.
        lg = F.adaptive_avg_pool2d(l4, (1, 1))
        # align_corners=False is PyTorch's bilinear default. It is made
        # explicit here.
        l4 = F.interpolate(l4, size, mode='bilinear', align_corners=False)
        # 1x1 -> size via adaptive pooling replicates the global vector.
        lg = F.adaptive_avg_pool2d(lg, size)

        return self.embedding(torch.cat([l2, l4, lg], dim=1))

def conv_bn_relu(in_channel, out_channel, kernel_sz=3, stride=1, pad=1):
    """Build a Conv2d -> BatchNorm2d -> in-place ReLU unit.

    Args:
        in_channel: channels consumed by the convolution.
        out_channel: channels produced by the convolution. This is also the
            BatchNorm width.
        kernel_sz: convolution kernel size.
        stride: convolution stride.
        pad: convolution padding.

    Returns:
        An ``nn.Sequential`` whose children are, in order: the conv (with
        bias), the batch norm, and the activation.
    """
    conv = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=kernel_sz,
        stride=stride,
        padding=pad,
    )
    norm = nn.BatchNorm2d(out_channel)
    act = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, act)

class plain_cnn(nn.Module):
    """Plain (non-residual) CNN feature extractor made of ``conv_bn_relu`` units.

    Stages, all with 3x3 convs:

    * 5 convs at 64 ch;
    * 5 convs at 128 ch;
    * 5 convs at 256 ch;
    * 2 convs at 512 ch.

    The first conv of stages 2-4 has stride 2.

    ``forward`` concatenates two tensors at stage-3 resolution:

    * the stage-3 features (256 ch);
    * the global average of the stage-4 features, replicated (512 ch).

    A 1x1 conv + BatchNorm projects the result to 512 channels.

    Args:
        in_channel: number of input channels consumed by the first conv.
        strides: accepted for signature compatibility with the other
            extractors in this module but NOT used. Stages 2-4 always
            downsample with stride 2, regardless of this argument.
    """

    def __init__(self, in_channel=3, strides=(2, 2, 1)):
        super(plain_cnn, self).__init__()

        # Layer creation order matches the original layout exactly, so that
        # seeded initialisation and state_dict indices are unchanged.
        self.block1 = nn.Sequential(
            conv_bn_relu(in_channel, 64, 3, 1, 1),
            *[conv_bn_relu(64, 64, 3, 1, 1) for _ in range(4)]
        )
        self.block2 = nn.Sequential(
            conv_bn_relu(64, 128, 3, 2, 1),
            *[conv_bn_relu(128, 128, 3, 1, 1) for _ in range(4)]
        )
        self.block3 = nn.Sequential(
            conv_bn_relu(128, 256, 3, 2, 1),
            *[conv_bn_relu(256, 256, 3, 1, 1) for _ in range(4)]
        )
        self.block4 = nn.Sequential(
            conv_bn_relu(256, 512, 3, 2, 1),
            conv_bn_relu(512, 512, 3, 1, 1),
        )

        # Input = stage-3 features (256) + global stage-4 descriptor (512).
        self.embedding = nn.Sequential(
            nn.Conv2d(256 + 512, 512, kernel_size=1),
            nn.BatchNorm2d(512),
        )

    def forward(self, x):
        """Return the fused 512-channel feature map at stage-3 resolution.

        Args:
            x: input batch. Assumed to be (N, in_channel, H, W); TODO confirm.
        """
        l1 = self.block1(x)
        l2 = self.block2(l1)
        l3 = self.block3(l2)
        l4 = self.block4(l3)
        # Only the global average of stage 4 is used. Pooling 1x1 up to
        # l3's size replicates that vector at every position.
        lg = F.adaptive_avg_pool2d(l4, (1, 1))
        lg = F.adaptive_avg_pool2d(lg, l3.shape[-2:])

        return self.embedding(torch.cat([l3, lg], dim=1))



class md2_resnet18(nn.Module):
    """ResNet-18-style extractor returning local and fused multi-scale features.

    Layout:

    * a 3x3 stride-1 stem conv + BN + ReLU;
    * a 3x3 stride-2 average pool;
    * four stages of two ``res_net.BasicBlock`` each, with 64/128/256/512
      channels.

    ``forward`` resizes the stage-3 features (256 ch) bilinearly to the
    stage-2 resolution. It concatenates them with the stage-2 features
    (128 ch) and the replicated global average of the stage-4 features
    (512 ch). A 1x1 conv + BatchNorm projects the 896-channel result to 512
    channels.

    Args:
        in_channel: number of input channels consumed by ``conv1``.
        strides: strides of the first block of stages 2, 3 and 4, and of the
            matching ``downsample`` convs.
    """

    def __init__(self, in_channel=3, strides=(2, 2, 2)):
        super(md2_resnet18, self).__init__()
        # NOTE: module creation order is intentionally unchanged. It
        # determines seeded weight initialisation and parameters() ordering.
        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        # Passed to the first BasicBlock of stages 2-4 as `downsample`, and
        # also registered directly on this module. Kept for checkpoint
        # compatibility.
        self.downsample2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=strides[0], padding=1, bias=False),
            nn.BatchNorm2d(128),
        )
        self.downsample3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=strides[1], padding=1, bias=False),
            nn.BatchNorm2d(256),
        )
        self.downsample4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=strides[2], padding=1, bias=False),
            nn.BatchNorm2d(512),
        )
        # Input = stage-2 (128) + resized stage-3 (256) + global stage-4 (512).
        self.embedding = nn.Sequential(
            nn.Conv2d(128 + 256 + 512, 512, kernel_size=1),
            nn.BatchNorm2d(512),
        )
        self.block1 = nn.Sequential(
            res_net.BasicBlock(64, 64),
            res_net.BasicBlock(64, 64),
        )
        self.block2 = nn.Sequential(
            res_net.BasicBlock(64, 128, stride=strides[0], downsample=self.downsample2),
            res_net.BasicBlock(128, 128),
        )
        self.block3 = nn.Sequential(
            res_net.BasicBlock(128, 256, stride=strides[1], downsample=self.downsample3),
            res_net.BasicBlock(256, 256),
        )
        self.block4 = nn.Sequential(
            res_net.BasicBlock(256, 512, stride=strides[2], downsample=self.downsample4),
            res_net.BasicBlock(512, 512),
        )

    def forward(self, x):
        """Compute local and fused features.

        Args:
            x: input batch. Assumed to be (N, in_channel, H, W); TODO confirm.

        Returns:
            A tuple ``(local, global_ft)``:

            * ``local`` is the stage-3 output (256 ch), bilinearly resized to
              the stage-2 spatial size;
            * ``global_ft`` is the 512-channel output of ``self.embedding`` at
              that same size.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        l1 = self.block1(x)
        l2 = self.block2(l1)
        l3 = self.block3(l2)
        l4 = self.block4(l3)

        size = l2.shape[-2:]
        # Global 1x1 descriptor of the deepest stage.
        lg = F.adaptive_avg_pool2d(l4, (1, 1))
        # align_corners=False is PyTorch's bilinear default. It is made
        # explicit here.
        local = F.interpolate(l3, size, mode='bilinear', align_corners=False)
        # 1x1 -> size via adaptive pooling replicates the global vector.
        lg = F.adaptive_avg_pool2d(lg, size)

        global_ft = self.embedding(torch.cat([l2, local, lg], dim=1))
        return local, global_ft
