import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

# Building blocks: backbone feature extractor + confidence-map branch, and the multi-view model that fuses them.
class FeatureExtractor(nn.Module):
    """Truncated ResNet-152 trunk followed by a 1x1 projection to 64 channels.

    The trunk keeps every child of torchvision's ResNet-152 except the last
    four. Its 512-channel output then goes through a 1x1 conv (512 -> 64),
    BatchNorm, ReLU and a 2x2 / stride-2 max-pool.

    Per the original shape notes, the trunk output is ``(B, 512, 64, 42)`` and
    the final output is ``(B, 64, 32, 21)``. That presumably corresponds to a
    512x336 input image — TODO confirm against the data pipeline.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): `pretrained=` is deprecated in newer torchvision releases
        # in favour of `weights=`; kept as-is for compatibility.
        trunk = models.resnet152(pretrained=True)
        trunk_children = list(trunk.children())
        self.resnet = nn.Sequential(*trunk_children[:-4])
        self.conv = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=1, stride=1)
        self.batchnorm = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return the pooled 64-channel feature map for the image batch ``x``."""
        pipeline = (self.resnet, self.conv, self.batchnorm, self.relu, self.maxpool)
        out = x
        for stage in pipeline:
            out = stage(out)
        return out
    
class ConfideceMap(nn.Module):
    """Confidence-map branch: truncated ResNet-50 projected to a ``(64, 26)`` map.

    The backbone keeps every child of torchvision's ResNet-50 except the last
    three. Its 1024-channel output then goes through a 1x1 conv (1024 -> 64),
    BatchNorm, ReLU and a 2x2 / stride-2 max-pool. The pooled map is
    flattened, passed through ``fc`` and reshaped to ``(B, 64, 26)``.

    ``fc`` expects exactly ``64*16*10`` flattened features, which pins the
    accepted input resolution. It presumably matches a 512x336 image, like the
    sibling feature extractor — TODO confirm.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): `pretrained=` is deprecated in newer torchvision releases
        # in favour of `weights=`; kept as-is for compatibility.
        base = models.resnet50(pretrained=True)
        self.backbone = nn.Sequential(*list(base.children())[:-3])
        self.conv = nn.Conv2d(in_channels=1024, out_channels=64, kernel_size=1, stride=1)
        self.batchnorm = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(in_features=64*16*10, out_features=64*26)

    def forward(self, x):
        """Map the image batch ``x`` to a ``(B, 64, 26)`` confidence tensor."""
        pooled = self.maxpool(self.relu(self.batchnorm(self.conv(self.backbone(x)))))
        batch = pooled.size(0)
        scores = self.fc(pooled.view(batch, -1))
        return scores.view(batch, -1, 26)

class ResNet152_CMP(nn.Module):
    """Multi-view model that fuses heatmap features with confidence maps.

    For each view ``x_all['img_{i}']``:
      1. the shared ``FeatureExtractor`` produces a heatmap feature map;
      2. the shared ``ConfideceMap`` produces a ``(B, 64, 26)`` tensor;
      3. both are contracted over the 64-channel axis with ``einsum``;
      4. the result is flattened and projected by ``fc0`` to ``40*26``.

    The per-view projections are concatenated and fused by
    ``fc1 -> ReLU -> Dropout -> fc2``. The result is reshaped to ``(B, 40, 26)``.

    Args:
        num_views: number of ``img_{i}`` entries consumed from the input
            mapping. ``fc1`` is sized for exactly this many views. Defaults to
            4, the previously hard-coded value, so existing checkpoints load
            unchanged.

    Raises:
        ValueError: if ``num_views`` is smaller than 1.
    """

    def __init__(self, num_views: int = 4) -> None:
        super(ResNet152_CMP, self).__init__()
        if num_views < 1:
            raise ValueError(f"num_views must be >= 1, got {num_views}")
        self.feature_extractor = FeatureExtractor()
        self.confidence_map = ConfideceMap()
        # 21*32*26 is the flattened einsum output. This fixes the input
        # resolution the sub-networks must see (presumably 512x336 — TODO confirm).
        self.fc0 = nn.Linear(in_features=21*32*26, out_features=40*26)
        self.fc1 = nn.Linear(in_features=40*26*num_views, out_features=40*26*2)
        self.fc2 = nn.Linear(in_features=40*26*2, out_features=40*26)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()

    @property
    def num_views(self) -> int:
        """Number of views ``fc1`` was built for, derived from the layer widths.

        The value is computed rather than stored, so it always agrees with the
        loaded weights.
        """
        return self.fc1.in_features // self.fc0.out_features

    def forward(self, x_all):
        """Fuse all views into a single prediction.

        Args:
            x_all: mapping holding keys ``'img_0'`` ... ``'img_{num_views-1}'``,
                each an image batch accepted by both sub-networks. Any other
                keys are ignored.

        Returns:
            ``{'laban_0': tensor}``, where the tensor has shape ``(B, 40, 26)``.

        Raises:
            KeyError: if one of the expected ``img_{i}`` keys is missing.
        """
        num_views = self.num_views
        feature_map_list = []
        for i in range(num_views):
            key = f'img_{i}'
            if key not in x_all:
                raise KeyError(
                    f"missing input view {key!r}; expected keys img_0..img_{num_views - 1}"
                )
            x = x_all[key]
            x_heatmap = self.feature_extractor(x)  # (B, 64, 32, 21) per the original notes
            x_confidence = self.confidence_map(x)  # (B, 64, 26)
            # Dot product over channels between every spatial cell of the heatmap
            # and each of the 26 confidence vectors -> (B, 32, 21, 26).
            feature_map = torch.einsum('bcwh,bcn->bwhn', x_heatmap, x_confidence)
            # einsum does not guarantee a contiguous result, so flatten() (which
            # copies when needed) is used instead of view().
            feature_map = feature_map.flatten(1)
            feature_map_list.append(self.fc0(feature_map))  # (B, 40*26)
        feature = torch.cat(feature_map_list, dim=1)  # (B, 40*26*num_views)
        feature = self.dropout(self.relu(self.fc1(feature)))
        feature = self.fc2(feature)
        return {'laban_0': feature.view(feature.size(0), 40, 26)}

            
            
            
        