import torch
import torch.nn as nn
from .backbones.resnet import resnet50
from .backbones.resnet_local import resnet50_local
from loss.arcface import ArcFace
from .backbones.resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a
from .backbones.resnet_ibn_a_local import resnet50_ibn_a_local
from .backbones.se_resnet_ibn_a import se_resnet101_ibn_a
import torch.nn.functional as F
from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Download URLs for ImageNet-pretrained ResNet checkpoints hosted on
# download.pytorch.org, keyed by architecture name. build_model() looks
# an entry up here when the configured pretrain choice is 'imagenet'.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth'
}

def weights_init_kaiming(m):
    """Initialise a single module in place; intended for ``Module.apply``.

    The module is dispatched on its class name:

    * names containing ``Linear``: Kaiming-normal weights (``fan_out``),
      zero bias;
    * names containing ``Conv``: Kaiming-normal weights (``fan_in``),
      zero bias;
    * names containing ``BatchNorm``: weight 1 and bias 0, only when the
      layer is affine.

    Modules of any other class are left untouched.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        # Layers built with bias=False expose ``bias`` as None; skip them
        # instead of crashing, exactly as the Conv branch does.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'Conv' in classname:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'BatchNorm' in classname:
        # Non-affine BatchNorm has no learnable weight/bias to set.
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

def weights_init_classifier(m):
    """Initialise a classifier layer in place; intended for ``Module.apply``.

    Modules whose class name contains ``Linear`` get weights drawn from a
    normal distribution with ``std=0.001`` and, when a bias exists, a zero
    bias. Every other module is left untouched.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        nn.init.normal_(m.weight, std=0.001)
        # Compare against None explicitly: taking the truth value of a
        # multi-element bias tensor raises a RuntimeError.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

def build_model(cfg, order):
    """Build one backbone network and optionally load pretrained weights.

    Args:
        cfg: Configuration object. Reads ``cfg.MODEL.NAME``,
            ``cfg.MODEL.DOWNSAMPLE``, ``cfg.MODEL.PRETRAIN_PATH`` and
            ``cfg.MODEL.PRETRAIN_CHOICE`` (each indexed by ``order``) and
            ``cfg.MODEL.FROZEN``.
        order: Index selecting which entry of the per-network config lists
            to use.

    Returns:
        A ``(model, output_dim)`` tuple, where ``output_dim`` is 2048 for
        every supported backbone.

    Raises:
        ValueError: If the configured backbone name is not supported, or
            if ImageNet weights are requested for a name that contains no
            architecture listed in ``model_urls``.
    """
    model_name = cfg.MODEL.NAME[order]
    frozen_stage = cfg.MODEL.FROZEN
    downsample_spec = cfg.MODEL.DOWNSAMPLE[order]
    # SECURITY NOTE(review): eval() executes arbitrary Python taken from the
    # config. Only load configs from trusted sources; consider replacing
    # this with ast.literal_eval once the set of allowed values is confirmed.
    downsample = eval(downsample_spec) if downsample_spec != '' else None

    # Every supported backbone produces feature maps with 2048 channels.
    output_dim = 2048
    if model_name == 'resnet50':
        model = resnet50(frozen_stages=frozen_stage)
    elif model_name == 'resnet50_local':
        model = resnet50_local(downsample=downsample, frozen_stages=frozen_stage)
    elif model_name == 'resnet50_ibn_a':
        model = resnet50_ibn_a(frozen_stages=frozen_stage)
    elif model_name == 'resnet50_ibn_a_local':
        model = resnet50_ibn_a_local(downsample=downsample, frozen_stages=frozen_stage)
    elif model_name == 'resnet101_ibn_a':
        model = resnet101_ibn_a(frozen_stages=frozen_stage)
    elif model_name == 'se_resnet101_ibn_a':
        model = se_resnet101_ibn_a(frozen_stages=frozen_stage)
    else:
        # Fail here with a clear message rather than later with an
        # UnboundLocalError on ``model``.
        raise ValueError('unsupported backbone! but got {}'.format(model_name))
    print('using {} as a backbone'.format(model_name))

    model_path = cfg.MODEL.PRETRAIN_PATH[order]
    pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE[order]

    if pretrain_choice == 'self':
        model.load_param(torch.load(model_path))
        print('Loading self model......from {}'.format(model_path))
    elif pretrain_choice == 'imagenet':
        # The base architecture is not always the first token of the name
        # (e.g. 'se_resnet101_ibn_a'), so pick the first token that has a
        # registered checkpoint URL.
        base_arch = next(
            (token for token in model_name.split('_') if token in model_urls),
            None,
        )
        if base_arch is None:
            raise ValueError(
                'no ImageNet checkpoint registered for backbone {}'.format(model_name)
            )
        state_dict = load_state_dict_from_url(model_urls[base_arch])
        model.load_param(state_dict)
        print('Loading pretrained ImageNet model......from model zoo')
    return model, output_dim
    
class Backbone(nn.Module):
    """Feature extractor with up to two cascaded sub-networks and a head.

    ``scale1_net`` consumes the input and returns a pair
    ``(feature_map, low_feat)``; when enabled, ``scale2_net`` consumes
    ``low_feat``. Each feature map is average-pooled over its spatial
    extent, the pooled vectors are concatenated, optionally projected by a
    linear layer, optionally normalised by a BatchNorm1d bottleneck, and
    finally classified by either an ArcFace layer or a bias-free linear
    classifier.
    """

    def __init__(self, cfg, num_classes):
        """Build the sub-networks, bottleneck and classification head.

        Args:
            cfg: Configuration object. Reads ``MODEL.USE_SCALE1``,
                ``MODEL.USE_SCALE2``, ``MODEL.COS_LAYER``, ``MODEL.NECK``,
                ``MODEL.IN_PLANES`` and ``TEST.NECK_FEAT`` here; the rest
                is consumed by ``build_model``.
            num_classes: Number of output classes of the classifier.
        """
        super(Backbone, self).__init__()
        self.use_scale1, self.use_scale2 = cfg.MODEL.USE_SCALE1, cfg.MODEL.USE_SCALE2
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT
        # Accumulates the width of the concatenated pooled feature vector.
        self.in_planes = 0

        # Build scale1 Network
        if self.use_scale1:
            self.scale1_net, output_dim = build_model(cfg, 0)
            self.in_planes += output_dim

        # Build scale2 Network
        if self.use_scale2:
            self.scale2_net, output_dim = build_model(cfg, 1)
            self.in_planes += output_dim

        # Optional projection of the concatenated feature to a fixed width.
        self.fc = None
        if cfg.MODEL.IN_PLANES is not None:
            self.fc = nn.Linear(self.in_planes, cfg.MODEL.IN_PLANES)
            self.in_planes = cfg.MODEL.IN_PLANES
        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        # The bottleneck bias is frozen: it is initialised to 0 below and
        # excluded from gradient updates.
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)

        self.num_classes = num_classes
        if self.cos_layer:
            print('using cosine layer')
            self.arcface = ArcFace(self.in_planes, self.num_classes, s=30.0, m=0.50)
        else:
            self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
            self.classifier.apply(weights_init_classifier)

    def forward(self, x, label=None):
        """Compute class scores (training) or an embedding (evaluation).

        Args:
            x: Input batch passed to ``scale1_net``.
            label: Ground-truth labels forwarded to the ArcFace layer. Only
                needed in training mode when the cosine layer is enabled.

        Returns:
            In training mode, a tuple ``(cls_score, global_feat)`` where
            ``global_feat`` is the feature before the bottleneck. In
            evaluation mode, the feature after the neck if ``neck_feat`` is
            ``'after'``, otherwise the feature before it.

        Raises:
            RuntimeError: If the scale1 branch is disabled; the scale2
                branch consumes its intermediate output.
            ValueError: If ``label`` is missing while training with the
                cosine layer.
        """
        if not self.use_scale1:
            raise RuntimeError(
                'Backbone.forward requires MODEL.USE_SCALE1: the scale2 '
                'branch consumes the low-level feature produced by scale1'
            )

        scale1_feat, low_feat = self.scale1_net(x)
        # Global average pooling: the kernel covers the whole spatial extent.
        scale1_feat = nn.functional.avg_pool2d(scale1_feat, scale1_feat.shape[2:4])
        scale1_feat = scale1_feat.view(scale1_feat.shape[0], -1)  # flatten to (bs, channels)

        if self.use_scale2:
            scale2_feat = self.scale2_net(low_feat)
            scale2_feat = nn.functional.avg_pool2d(scale2_feat, scale2_feat.shape[2:4])
            scale2_feat = scale2_feat.view(scale2_feat.shape[0], -1)  # flatten to (bs, channels)
            scale1_feat = torch.cat((scale1_feat, scale2_feat), 1)

        if self.fc is not None:
            scale1_feat = self.fc(scale1_feat)

        # Run the BatchNorm bottleneck exactly once, and only when its
        # output is used. Calling it twice per step (or when neck == 'no')
        # would update its running statistics more often than intended.
        if self.neck == 'no':
            feat = scale1_feat
        else:
            feat = self.bottleneck(scale1_feat)

        if self.training:
            if self.cos_layer:
                if label is None:
                    raise ValueError(
                        'label is required in training mode when the cosine layer is enabled'
                    )
                cls_score = self.arcface(feat, label)
            else:
                cls_score = self.classifier(feat)
            return cls_score, scale1_feat  # global feature for triplet loss

        if self.neck_feat == 'after':
            # Test with feature after BN
            return feat
        # Test with feature before BN
        return scale1_feat

    def load_param(self, model_path):
        """Copy parameters from a saved state dict into this model.

        Entries whose key contains ``'classifier'`` or ``'arcface'`` are
        skipped, so the classification head keeps its current weights.

        Args:
            model_path: Path (or file-like object) accepted by
                ``torch.load`` that holds a state dict.

        Raises:
            KeyError: If the checkpoint has a non-head key that this model
                does not have.
        """
        # Load onto the CPU so checkpoints saved on a GPU also work on
        # CPU-only hosts; copy_ moves the data to each parameter's device.
        param_dict = torch.load(model_path, map_location='cpu')
        # Build the state dict once instead of once per checkpoint entry.
        own_state = self.state_dict()
        for name, value in param_dict.items():
            if 'classifier' in name or 'arcface' in name:
                continue
            own_state[name].copy_(value)
        print('Loading pretrained model from {}'.format(model_path))

        
def make_model(cfg, num_class):
    """Create a ``Backbone`` model.

    Args:
        cfg: Configuration object forwarded to ``Backbone``.
        num_class: Number of classes forwarded to ``Backbone``.

    Returns:
        The newly constructed ``Backbone`` instance.
    """
    return Backbone(cfg, num_class)
