from ctypes import sizeof
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import res_net
from models import repvgg as RepVGG
import numpy as np
import copy
class md_resnet18(nn.Module):
    """Backbone built from ``RepVGG.RepVGGBlock`` stages with a fused output.

    Despite the class name, the four stages are stacks of RepVGG blocks
    (2, 4, 14 and 1 blocks with 64, 128, 256 and 1280 output channels).
    The forward pass concatenates three tensors at the spatial size of the
    ``block2`` output -- the ``block2`` features, the upsampled ``block4``
    features and the globally pooled ``block4`` features -- and projects
    them to 512 channels with a 1x1 convolution followed by batch norm.

    Args:
        in_channel: Number of channels of the input image tensor.
        strides: Accepted for signature parity with the sibling backbones
            in this file. It is not read here; the first block of every
            stage always uses stride 2.
    """

    def __init__(self, in_channel=3, strides=(2, 2, 1)):
        super().__init__()

        # Stem: 3x3 conv + BN (ReLU is applied in forward), then 2x average pool.
        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

        # Input channels = block2 (128) + block4 (1280) + pooled block4 (1280).
        self.embedding = nn.Sequential(
            nn.Conv2d(128 + 1280 + 1280, 512, kernel_size=1),
            nn.BatchNorm2d(512),
        )

        # Attribute names and creation order are kept stable so existing
        # checkpoints and seeded initialisation are unaffected.
        self.block1 = self._make_stage(64, 64, num_blocks=2)
        self.block2 = self._make_stage(64, 128, num_blocks=4)
        self.block3 = self._make_stage(128, 256, num_blocks=14)
        self.block4 = self._make_stage(256, 1280, num_blocks=1)

    @staticmethod
    def _make_stage(in_channels, out_channels, num_blocks):
        """Stack ``num_blocks`` RepVGG blocks; only the first has stride 2."""
        blocks = []
        for index in range(num_blocks):
            is_first = index == 0
            blocks.append(
                RepVGG.RepVGGBlock(
                    in_channels=in_channels if is_first else out_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2 if is_first else 1,
                    padding=1,
                    groups=1,
                    deploy=False,
                    use_se=False,
                )
            )
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return the fused 512-channel feature map for ``x``.

        The result has the spatial size of the ``block2`` output.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)

        l1 = self.block1(x)
        l2 = self.block2(l1)
        l3 = self.block3(l2)
        l4 = self.block4(l3)

        target_size = l2.size()[-2:]

        # Global descriptor: pool to 1x1, then spread it back over the
        # block2 grid (adaptive pooling of a 1x1 map repeats its value).
        lg = F.adaptive_avg_pool2d(l4, (1, 1))
        lg = F.adaptive_avg_pool2d(lg, target_size)

        # align_corners=False is the documented default for bilinear mode;
        # spelling it out keeps the numerics and avoids the implicit-default
        # warning some PyTorch releases emit.
        l4 = F.interpolate(l4, target_size, mode='bilinear', align_corners=False)

        ft = self.embedding(torch.cat([l2, l4, lg], dim=1))
        return ft



class md_resnet34(nn.Module):
    """Residual backbone (3/4/6/3 ``res_net.BasicBlock`` stages) with a fused output.

    The forward pass concatenates the ``block2`` features (128 channels),
    the ``block4`` features upsampled to the ``block2`` size (512 channels)
    and the globally pooled ``block4`` features (512 channels), then
    projects them to 512 channels with a 1x1 convolution and batch norm.

    Args:
        in_channel: Number of channels of the input image tensor.
        strides: Three strides used by the first block (and its projection
            shortcut) of ``block2``, ``block3`` and ``block4`` respectively.
            Any indexable sequence works; the default is a tuple so no
            mutable object is shared between calls.
    """

    def __init__(self, in_channel=3, strides=(2, 2, 1)):
        super().__init__()

        # Stem: 3x3 conv + BN (ReLU is applied in forward), then 2x average pool.
        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

        # Projection shortcuts. They stay registered under these attribute
        # names AND are handed to the first block of each stage, exactly as
        # before, so existing checkpoint keys remain valid.
        # NOTE(review): how BasicBlock stores `downsample` is not visible
        # here -- confirm before renaming or dropping these attributes.
        self.downsample2 = self._projection(64, 128, strides[0])
        self.downsample3 = self._projection(128, 256, strides[1])
        self.downsample4 = self._projection(256, 512, strides[2])

        # Input channels = block2 (128) + block4 (512) + pooled block4 (512).
        self.embedding = nn.Sequential(
            nn.Conv2d(128 + 512 + 512, 512, kernel_size=1),
            nn.BatchNorm2d(512),
        )

        self.block1 = nn.Sequential(*[res_net.BasicBlock(64, 64) for _ in range(3)])
        self.block2 = self._make_stage(64, 128, 4, strides[0], self.downsample2)
        self.block3 = self._make_stage(128, 256, 6, strides[1], self.downsample3)
        self.block4 = self._make_stage(256, 512, 3, strides[2], self.downsample4)

    @staticmethod
    def _projection(in_channels, out_channels, stride):
        """Build the 3x3 conv + BN shortcut used when a stage changes width."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    @staticmethod
    def _make_stage(in_channels, out_channels, num_blocks, stride, downsample):
        """Build one stage: a strided/projected block, then plain blocks."""
        blocks = [res_net.BasicBlock(in_channels, out_channels, stride=stride, downsample=downsample)]
        blocks.extend(res_net.BasicBlock(out_channels, out_channels) for _ in range(num_blocks - 1))
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return the fused 512-channel feature map for ``x``.

        The result has the spatial size of the ``block2`` output.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)

        l1 = self.block1(x)
        l2 = self.block2(l1)
        l3 = self.block3(l2)
        l4 = self.block4(l3)

        target_size = l2.size()[-2:]

        # Global descriptor: pool to 1x1, then spread it back over the
        # block2 grid (adaptive pooling of a 1x1 map repeats its value).
        lg = F.adaptive_avg_pool2d(l4, (1, 1))
        lg = F.adaptive_avg_pool2d(lg, target_size)

        # align_corners=False is the documented default for bilinear mode;
        # spelling it out keeps the numerics and avoids the implicit-default
        # warning some PyTorch releases emit.
        l4 = F.interpolate(l4, target_size, mode='bilinear', align_corners=False)

        ft = self.embedding(torch.cat([l2, l4, lg], dim=1))
        return ft

def conv_bn_relu(in_channel, out_channel, kernel_sz=3, stride=1, pad=1):
    """Build a ``Conv2d -> BatchNorm2d -> ReLU`` unit as one ``nn.Sequential``.

    Args:
        in_channel: Input channels of the convolution.
        out_channel: Output channels of the convolution and the width of
            the batch-norm layer.
        kernel_sz: Convolution kernel size.
        stride: Convolution stride.
        pad: Convolution zero-padding.

    Returns:
        An ``nn.Sequential`` holding the three layers at indices 0, 1 and 2.
        The convolution keeps its default bias and the ReLU is in-place.
    """
    convolution = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=kernel_sz,
        stride=stride,
        padding=pad,
    )
    normalization = nn.BatchNorm2d(out_channel)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(convolution, normalization, activation)

class plain_cnn(nn.Module):
    """Plain (non-residual) CNN of ``conv_bn_relu`` units with a fused output.

    Four stages of 5, 5, 5 and 2 units produce 64, 128, 256 and 512
    channels. The ``block3`` features are concatenated with the globally
    pooled ``block4`` features and projected to 512 channels by a 1x1
    convolution followed by batch norm.

    Args:
        in_channel: Number of channels of the input image tensor.
        strides: Accepted but not read; the stage strides are fixed at
            1, 2, 2, 2.
    """

    def __init__(self, in_channel=3, strides=[2, 2, 1]):
        super().__init__()

        def build_stage(c_in, c_out, first_stride, depth):
            # First unit changes width/stride; the rest keep the shape.
            units = [conv_bn_relu(c_in, c_out, 3, first_stride, 1)]
            for _ in range(depth - 1):
                units.append(conv_bn_relu(c_out, c_out, 3, 1, 1))
            return nn.Sequential(*units)

        self.block1 = build_stage(in_channel, 64, 1, 5)
        self.block2 = build_stage(64, 128, 2, 5)
        self.block3 = build_stage(128, 256, 2, 5)
        self.block4 = build_stage(256, 512, 2, 2)

        # Input channels = block3 (256) + pooled block4 (512).
        projection = nn.Conv2d(256 + 512, 512, kernel_size=1)
        self.embedding = nn.Sequential(projection, nn.BatchNorm2d(512))

    def forward(self, x):
        """Return the fused 512-channel map at the ``block3`` spatial size."""
        stage3 = self.block3(self.block2(self.block1(x)))
        stage4 = self.block4(stage3)

        # Pool block4 to a 1x1 descriptor, then spread it over the block3
        # grid (adaptive pooling of a 1x1 map repeats its value).
        pooled = F.adaptive_avg_pool2d(stage4, (1, 1))
        pooled = F.adaptive_avg_pool2d(pooled, stage3.size()[-2:])

        fused = torch.cat([stage3, pooled], dim=1)
        return self.embedding(fused)



class md2_resnet18(nn.Module):
    """Residual backbone (2/2/2/2 ``res_net.BasicBlock`` stages) with two outputs.

    The forward pass returns the ``block3`` features upsampled to the
    ``block2`` spatial size, together with a fused 512-channel map built
    from the ``block2`` features (128 channels), the upsampled ``block3``
    features (256 channels) and the globally pooled ``block4`` features
    (512 channels).

    Args:
        in_channel: Number of channels of the input image tensor.
        strides: Three strides used by the first block (and its projection
            shortcut) of ``block2``, ``block3`` and ``block4`` respectively.
            Any indexable sequence works; the default is a tuple so no
            mutable object is shared between calls.
    """

    def __init__(self, in_channel=3, strides=(2, 2, 2)):
        super().__init__()

        # Stem: 3x3 conv + BN (ReLU is applied in forward), then 2x average pool.
        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

        # Projection shortcuts. They stay registered under these attribute
        # names AND are handed to the first block of each stage, exactly as
        # before, so existing checkpoint keys remain valid.
        # NOTE(review): how BasicBlock stores `downsample` is not visible
        # here -- confirm before renaming or dropping these attributes.
        self.downsample2 = self._projection(64, 128, strides[0])
        self.downsample3 = self._projection(128, 256, strides[1])
        self.downsample4 = self._projection(256, 512, strides[2])

        # Input channels = block2 (128) + block3 (256) + pooled block4 (512).
        self.embedding = nn.Sequential(
            nn.Conv2d(128 + 256 + 512, 512, kernel_size=1),
            nn.BatchNorm2d(512),
        )

        self.block1 = nn.Sequential(
            res_net.BasicBlock(64, 64),
            res_net.BasicBlock(64, 64),
        )
        self.block2 = self._make_stage(64, 128, strides[0], self.downsample2)
        self.block3 = self._make_stage(128, 256, strides[1], self.downsample3)
        self.block4 = self._make_stage(256, 512, strides[2], self.downsample4)

    @staticmethod
    def _projection(in_channels, out_channels, stride):
        """Build the 3x3 conv + BN shortcut used when a stage changes width."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    @staticmethod
    def _make_stage(in_channels, out_channels, stride, downsample):
        """Build a two-block stage: a strided/projected block, then a plain one."""
        return nn.Sequential(
            res_net.BasicBlock(in_channels, out_channels, stride=stride, downsample=downsample),
            res_net.BasicBlock(out_channels, out_channels),
        )

    def forward(self, x):
        """Return ``(l3, global_ft)`` for ``x``.

        ``l3`` is the ``block3`` output after bilinear resizing to the
        ``block2`` spatial size (not the raw stage output); ``global_ft``
        is the fused 512-channel map at that same size.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)

        l1 = self.block1(x)
        l2 = self.block2(l1)
        l3 = self.block3(l2)
        l4 = self.block4(l3)

        target_size = l2.size()[-2:]

        # Global descriptor: pool to 1x1, then spread it back over the
        # block2 grid (adaptive pooling of a 1x1 map repeats its value).
        lg = F.adaptive_avg_pool2d(l4, (1, 1))
        lg = F.adaptive_avg_pool2d(lg, target_size)

        # align_corners=False is the documented default for bilinear mode;
        # spelling it out keeps the numerics and avoids the implicit-default
        # warning some PyTorch releases emit.
        l3 = F.interpolate(l3, target_size, mode='bilinear', align_corners=False)

        global_ft = self.embedding(torch.cat([l2, l3, lg], dim=1))
        return l3, global_ft
