import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Self-attention applied ``num_layers`` times in sequence with one shared module.

    Args:
        embed_dim: feature size of each token.
        num_heads: number of attention heads (``nn.MultiheadAttention`` requires it
            to divide ``embed_dim``).
        num_layers: how many times the attention module is applied.
    """

    def __init__(self, embed_dim, num_heads, num_layers) -> None:
        super().__init__()
        self.embed_dim, self.num_heads, self.num_layers = embed_dim, num_heads, num_layers
        # NOTE(review): a single nn.MultiheadAttention is reused on every iteration,
        # so the "layers" share weights. Confirm this is intended before num_layers > 1.
        self.MSA = nn.MultiheadAttention(self.embed_dim, self.num_heads)

    def forward(self, x):
        """Run the shared self-attention ``num_layers`` times and return the result.

        ``x`` goes straight into ``nn.MultiheadAttention`` with its default
        ``batch_first=False``, so it is expected as (seq_len, batch, embed_dim).
        Attention weights are discarded; ``num_layers == 0`` returns ``x`` unchanged.
        """
        out = x
        for _ in range(self.num_layers):
            out = self.MSA(out, out, out)[0]
        return out
class LayerNorm(nn.Module):
    """Thin wrapper around ``nn.LayerNorm``.

    Args:
        normalized_shape: trailing dimension(s) to normalize over; forwarded
            unchanged to ``nn.LayerNorm``.
    """

    def __init__(self, normalized_shape) -> None:
        super().__init__()
        self.layernorm = nn.LayerNorm(normalized_shape)

    def forward(self, x):
        """Return ``x`` layer-normalized over its trailing ``normalized_shape`` dims."""
        return self.layernorm(x)
class MLP(nn.Module):
    """Three-layer perceptron with ReLU between layers (no activation on the output).

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        hidden_dim: width of both hidden layers. Defaults to 1024, the value that
            used to be hard-coded.
    """

    def __init__(self, input_dim, output_dim, hidden_dim: int = 1024) -> None:
        super(MLP, self).__init__()
        # Layer names are unchanged so existing state_dicts still load.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` (..., input_dim) to (..., output_dim)."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
class CrossViewAttention(nn.Module):
    """Cross-attention: tokens of one view query the tokens of another view.

    Args:
        embed_dim: feature size of each token.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, num_heads) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.CSA = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

    def forward(self, x_0, x_1):
        """Attend from ``x_0`` (queries) to ``x_1`` (keys and values).

        Inputs use ``nn.MultiheadAttention``'s default (seq_len, batch, embed_dim)
        layout. The output has ``x_0``'s sequence length. Attention weights are
        discarded.
        """
        attended, _weights = self.CSA(query=x_0, key=x_1, value=x_1)
        return attended
class Embedding(nn.Module):
    """Patch embedding: a strided conv produces tokens, then a learned position table is added.

    Args:
        input_channels: channels of the input image.
        embed_dim: output channels of the patch conv, i.e. the token size.
        patch_size: conv kernel size.
        patch_stride: conv stride.
        num_patches: number of learned positions. It must equal the number of
            spatial positions the conv produces (H' * W').
    """

    def __init__(self, input_channels, embed_dim, patch_size, patch_stride, num_patches) -> None:
        super(Embedding, self).__init__()
        self.embedding = nn.Conv2d(in_channels=input_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_stride)
        self.num_patches = num_patches
        self.position_embedding = nn.Embedding(self.num_patches, embedding_dim=embed_dim)

    def forward(self, x):
        """Embed ``x`` (batch, input_channels, H, W) into (batch, num_patches, embed_dim).

        Tokens are ordered row-major over the conv output grid.
        """
        x = self.embedding(x)             # (batch, embed_dim, H', W')
        # flatten instead of a hard-coded view(batch, 64, 32*20): works for any
        # embed_dim / grid size and for non-contiguous (e.g. channels_last) output.
        x = x.flatten(2).transpose(1, 2)  # (batch, H'*W', embed_dim)
        # The position table is (num_patches, embed_dim) and broadcasts over the batch.
        return x + self.position_embedding.weight
class FeatureExtractor(nn.Module):
    """Pretrained ResNet-50 trunk, then a 1x1 conv projection, BatchNorm, ReLU and 2x2 max-pool.

    Args:
        input_channel: channels entering the 1x1 conv. Must match the trunk output
            (1024, per the shape notes in ``forward``).
        output_channel: channels produced by the 1x1 conv.
    """

    def __init__(self, input_channel, output_channel) -> None:
        super(FeatureExtractor, self).__init__()
        # `pretrained=True` is deprecated since torchvision 0.13 in favour of the
        # `weights=` enum. IMAGENET1K_V1 is the checkpoint `pretrained=True` loads.
        # Older torchvision has no enum, so fall back to the legacy flag there.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet50(pretrained=True)
        # Keep every child of the torchvision ResNet except the last three.
        self.resnet = nn.Sequential(*list(resnet.children())[:-3])
        self.conv = nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=1)
        self.batchnorm = nn.BatchNorm2d(output_channel)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return pooled features for an image batch ``x`` of shape (batch, 3, H, W).

        Example shapes with output_channel=64:
        (2, 3, 512, 334) -> (2, 64, 16, 10).
        """
        x = self.resnet(x)     # (2, 3, 512, 334) -> (2, 1024, 32, 21)
        x = self.conv(x)       # -> (2, 64, 32, 21)
        x = self.batchnorm(x)  # -> (2, 64, 32, 21)
        x = self.relu(x)       # -> (2, 64, 32, 21)
        x = self.maxpool(x)    # -> (2, 64, 16, 10)
        return x
class GlobalEmbedding(nn.Module):
    """Single multi-head self-attention pass over a token sequence.

    Args:
        embed_dim: feature size of each token.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, num_heads) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.GEmbedding = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

    def forward(self, x):
        """Return self-attention of ``x`` over itself; attention weights are dropped.

        ``x`` uses ``nn.MultiheadAttention``'s default (seq_len, batch, embed_dim) layout.
        """
        attn_out, _attn_weights = self.GEmbedding(query=x, key=x, value=x)
        return attn_out
class MLPHead(nn.Module):
    """Three-layer perceptron head with ReLU between layers (no activation on the output).

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        hidden_dim: width of both hidden layers. Defaults to 2*40*26 (2080), the
            value that used to be hard-coded.
    """

    def __init__(self, input_dim, output_dim, hidden_dim: int = 2*40*26) -> None:
        super(MLPHead, self).__init__()
        # Layer names are unchanged so existing state_dicts still load.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` (..., input_dim) to (..., output_dim)."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
class ConfideceMap(nn.Module):
    """Predict a flat confidence vector from an image batch using a pretrained ResNet-50 trunk.

    Pipeline: ResNet-50 trunk -> 1x1 conv -> BatchNorm -> ReLU -> 2x2 max-pool
    -> flatten -> fully connected layer.

    Args:
        input_dim: channels entering the 1x1 conv (the trunk's output channels).
        output_dim: channels produced by the 1x1 conv.
        out_feature: length of the returned vector.
        feature_hw: (height, width) of the pooled feature map fed to ``fc``.
            Defaults to (16, 10), the value that used to be hard-coded. That is the
            pooled size this trunk/pool combination gives for a 512x334 input.
    """

    def __init__(self, input_dim, output_dim, out_feature, feature_hw=(16, 10)):
        super(ConfideceMap, self).__init__()
        # `pretrained=True` is deprecated since torchvision 0.13 in favour of the
        # `weights=` enum. IMAGENET1K_V1 is the checkpoint `pretrained=True` loads.
        # Older torchvision has no enum, so fall back to the legacy flag there.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            backbone = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.resnet50(pretrained=True)
        # Keep every child of the torchvision ResNet except the last three.
        self.backbone = nn.Sequential(*list(backbone.children())[:-3])
        self.conv = nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=1, stride=1)
        self.batchnorm = nn.BatchNorm2d(output_dim)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        feat_h, feat_w = feature_hw
        self.fc = nn.Linear(in_features=output_dim*feat_h*feat_w, out_features=out_feature)

    def forward(self, x):
        """Return a (batch, out_feature) tensor for the image batch ``x``."""
        x = self.backbone(x)
        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.relu(x)
        x = self.maxpool(x)
        # flatten (not view) so non-contiguous / channels_last feature maps also work
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
class PreMultiViewAttentionCMP(nn.Module):
    """Multi-view attention model whose per-view features are gated by a confidence map.

    For every view ``x_all['img_{i}']`` the forward pass:
      1. predicts a (batch, 40*26) confidence vector with ``ConfideceMap``;
      2. extracts a (batch, 64, 16, 10) feature map with ``FeatureExtractor``. Its
         160 spatial positions are treated as a (160, batch, 64) token sequence;
      3. applies self-attention, then cross-view attention against the previous
         view's self-attention output (the first view attends to itself);
      4. maps the result to a (batch, 40*26) vector with ``MLP`` and multiplies it
         element-wise by the confidence vector.
    The per-view vectors are concatenated and fused by ``MLPHead`` into a
    (batch, 40, 26) map.

    Args:
        num_views: number of views expected in ``x_all``. Defaults to 4, the value
            that used to be hard-coded into the ``MLPHead`` input width.
    """

    def __init__(self, num_views: int = 4) -> None:
        super(PreMultiViewAttentionCMP, self).__init__()
        self.num_views = num_views
        # NOTE(review): this one LayerNorm normalizes jointly over (channels, tokens)
        # and is applied both before self-attention and before the MLP, so the two
        # sites share parameters. Kept as-is for checkpoint compatibility.
        self.layernorm = LayerNorm(normalized_shape=(64, 160))
        self.MSA = MultiHeadSelfAttention(embed_dim=64, num_heads=2, num_layers=1)
        self.MLP = MLP(input_dim=160*64, output_dim=40*26)
        self.CVA = CrossViewAttention(embed_dim=64, num_heads=2)
        # NOTE(review): Embedding, GEmbedding and linear are not used in forward().
        # They are kept so existing state_dicts still load. Under
        # DistributedDataParallel they may require find_unused_parameters=True.
        self.Embedding = Embedding(input_channels=3, embed_dim=64, patch_size=16, patch_stride=16, num_patches=640)
        self.GEmbedding = GlobalEmbedding(embed_dim=64, num_heads=2)
        self.linear = nn.Linear(in_features=64, out_features=64)
        self.MLPHead = MLPHead(input_dim=40*26*num_views, output_dim=40*26)
        self.feature_extractor = FeatureExtractor(input_channel=1024, output_channel=64)
        self.CMP = ConfideceMap(1024, 64, 40*26)

    def forward(self, x_all):
        """Fuse all views into a single (batch, 40, 26) map.

        Args:
            x_all: mapping with keys ``'img_0'`` ... ``'img_{num_views-1}'``. Each value
                is an image batch fed to both ResNet-based branches. The shape notes
                assume (batch, 3, 512, 334), which yields the 16x10 feature map the
                (64, 160) LayerNorm requires.

        Returns:
            dict with key ``'laban_0'`` mapping to a (batch, 40, 26) tensor.

        Raises:
            ValueError: if ``x_all`` does not hold exactly ``num_views`` entries
                (``MLPHead`` is sized for that many views).
        """
        if len(x_all) != self.num_views:
            raise ValueError(
                f"expected {self.num_views} views (keys 'img_0'..'img_{self.num_views - 1}'), "
                f"got {len(x_all)}"
            )

        cross_feature = []  # self-attention output of each view, (160, batch, 64)
        feature_maps = []   # gated per-view vectors, (batch, 40*26)
        for i in range(self.num_views):
            images = x_all[f'img_{i}']
            cmp = self.CMP(images)                             # (batch, 40*26)

            feat = self.feature_extractor(images).flatten(2)   # (batch, 64, 160)
            x_0 = self.layernorm(feat).permute(2, 0, 1)        # (160, batch, 64): seq-first for MHA
            attn = self.MSA(x_0)                               # (160, batch, 64)
            cross_feature.append(attn)
            x_1 = attn + x_0                                   # residual around self-attention

            # The first view attends to itself; later views attend to the previous
            # view's self-attention output.
            context = x_1 if i == 0 else cross_feature[i - 1]
            x_2 = self.CVA(x_1, context) + x_1                 # residual around cross-attention

            x_2 = self.layernorm(x_2.permute(1, 2, 0))         # (batch, 64, 160)
            per_view = self.MLP(x_2.flatten(1))                # (batch, 40*26)
            feature_maps.append(per_view * cmp)                # confidence gating

        feature = torch.cat(feature_maps, dim=1)               # (batch, num_views*40*26)
        feature = self.MLPHead(feature)                        # (batch, 40*26)
        feature = feature.view(feature.size(0), 40, 26)
        return {'laban_0': feature}
        
            
            
        
        
        
        