import copy

import torch
import torchvision
from torch import nn
from torch.nn import init

from models.classifier import NormalizedClassifier
from models.utils import pooling


class ResNet50(nn.Module):
    """ResNet-50 backbone with a second, independently weighted ``layer4`` branch.

    The main branch (``layer4`` -> global pooling -> ``bn``) produces the
    embedding returned to callers. The auxiliary branch (``layer4c`` -> global
    pooling -> ``bn2``) feeds ``cam_classifier`` and ``clothes_classifier``.

    Args:
        config: Config node. Reads ``MODEL.RES4_STRIDE``,
            ``MODEL.POOLING.NAME``, ``MODEL.POOLING.P`` (gem only) and
            ``MODEL.FEATURE_DIM``. The pooled layer4 feature size must equal
            ``MODEL.FEATURE_DIM``, since it is fed to ``BatchNorm1d``.
        num_clothes: Number of output classes of ``clothes_classifier``.
        num_cams: Number of output classes of ``cam_classifier``. Defaults to
            12, the value that was previously hard-coded.
        **kwargs: Accepted and ignored.

    Raises:
        KeyError: If ``MODEL.POOLING.NAME`` is not one of
            'avg', 'max', 'gem', 'maxavg'.
    """

    def __init__(self, config, num_clothes, num_cams=12, **kwargs):
        super().__init__()

        # NOTE(review): ``pretrained=`` is deprecated in newer torchvision in
        # favour of ``weights=``; left as-is since the pinned version is unknown.
        resnet50 = torchvision.models.resnet50(pretrained=True)
        if config.MODEL.RES4_STRIDE == 1:
            # Remove layer4's downsampling so it keeps layer3's spatial size.
            resnet50.layer4[0].conv2.stride = (1, 1)
            resnet50.layer4[0].downsample[0].stride = (1, 1)

        self.conv1 = resnet50.conv1
        self.bn1 = resnet50.bn1
        self.relu = resnet50.relu
        self.maxpool = resnet50.maxpool
        self.layer1 = resnet50.layer1
        self.layer2 = resnet50.layer2
        self.layer3 = resnet50.layer3
        self.layer4 = resnet50.layer4
        # Independent copy (same pretrained init and stride setting). Assigning
        # ``resnet50.layer4`` directly would make both attributes the same
        # module, i.e. shared weights and duplicated computation.
        self.layer4c = copy.deepcopy(resnet50.layer4)

        self.globalpooling = self._build_pooling(config)

        feature_dim = config.MODEL.FEATURE_DIM
        self.bn = nn.BatchNorm1d(feature_dim)
        self.bn2 = nn.BatchNorm1d(feature_dim)
        for norm in (self.bn, self.bn2):
            init.normal_(norm.weight.data, 1.0, 0.02)
            init.constant_(norm.bias.data, 0.0)

        self.cam_classifier = nn.Linear(feature_dim, num_cams)
        init.normal_(self.cam_classifier.weight.data, std=0.001)
        init.constant_(self.cam_classifier.bias.data, 0.0)
        self.clothes_classifier = NormalizedClassifier(
            feature_dim=feature_dim, num_classes=num_clothes)

    @staticmethod
    def _build_pooling(config):
        """Return the global pooling module named by ``config.MODEL.POOLING.NAME``."""
        name = config.MODEL.POOLING.NAME
        if name == 'avg':
            return nn.AdaptiveAvgPool2d(1)
        if name == 'max':
            return nn.AdaptiveMaxPool2d(1)
        if name == 'gem':
            return pooling.GeMPooling(p=config.MODEL.POOLING.P)
        if name == 'maxavg':
            return pooling.MaxAvgPooling()
        raise KeyError("Invalid pooling: '{}'".format(name))

    def _pool_and_norm(self, feat_map, norm):
        """Globally pool ``feat_map``, flatten to (N, -1) and apply ``norm``.

        Returns:
            ``(pooled, normed)``: the flattened pooled feature and its
            batch-normalised version.
        """
        pooled = self.globalpooling(feat_map).flatten(1)
        return pooled, norm(pooled)

    def forward(self, x, mode='train'):
        """Run the backbone and both heads.

        Args:
            x: Input batch. Presumably images shaped (N, 3, H, W); confirm
                with the caller.
            mode: ``'test'`` selects the inference path; any other value
                selects the training path.

        Returns:
            Training path: ``(base_f, f, f1, score2, score3)``.
                - ``base_f``: pooled, flattened main-branch feature.
                - ``f``: ``bn(base_f)``.
                - ``f1``: ``bn2`` output of ``layer4c`` applied to the first
                  ``N // 2`` samples.
                - ``score2``, ``score3``: ``cam_classifier`` and
                  ``clothes_classifier`` outputs on ``f1``.
            Test path: ``(preds, fx)``.
                - ``preds``: argmax of the ``cam_classifier`` logits over the
                  full batch.
                - ``fx``: ``bn`` output of the main branch.
        """
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer3(self.layer2(self.layer1(x)))
        base_f = self.layer4(x)

        if mode != 'test':
            # NOTE(review): assumes the training batch is two stacked halves,
            # of which only the first goes through layer4c; confirm with the
            # data loader. Slicing (rather than torch.split) avoids a
            # ValueError when N is odd.
            half = x.size(0) // 2
            base_f2 = self.layer4c(x[:half])
            base_f, f = self._pool_and_norm(base_f, self.bn)
            _, f1 = self._pool_and_norm(base_f2, self.bn2)
            score2 = self.cam_classifier(f1)
            score3 = self.clothes_classifier(f1)
            return base_f, f, f1, score2, score3

        _, fx = self._pool_and_norm(base_f, self.bn)
        _, f1 = self._pool_and_norm(self.layer4c(x), self.bn2)
        preds = self.cam_classifier(f1).argmax(dim=1)
        return preds, fx
    
