import torch
import torch.nn as nn
import copy
import numpy as np
from ResNetCifar import ResNetCifar
from networks import ConvNet, ResNet101, ResNet18BN, ResNet50BN
from networks2 import WideResNet28
from collections import OrderedDict

from dassl.utils.torchtools import load_checkpoint


# Names of the 15 image corruption types iterated over for robustness
# evaluation. NOTE(review): these appear to match the CIFAR-10-C / ImageNet-C
# corruption set (Hendrycks & Dietterich, 2019) — confirm against the loader.
common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
                    'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
                    'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']


def load_pretrained_weight(net, resume, models):
    """Load a source checkpoint into ``net``, remapping keys to ``net``'s names.

    Args:
        net: Module that receives the weights via
            ``load_state_dict(..., strict=True)``.
        resume: Path to a checkpoint file, or to a directory that contains
            ``ckpt.pth``.
        models: Architecture name. ``'ResNet101'`` (VisDA-C) uses a dedicated
            key remapping; any other value is dispatched on the top-level keys
            of the loaded checkpoint (``best_val`` / ``model`` / ``net`` /
            ``state_dict`` / bare state dict).

    Returns:
        The (remapped) state dict that was loaded into ``net``.

    Raises:
        RuntimeError: From ``load_state_dict`` when the remapped keys or tensor
            shapes do not match ``net`` exactly.
    """
    import os

    # A directory is expected to hold ``ckpt.pth``; anything else is taken to
    # be the checkpoint file itself. (Testing ``'pt' in resume`` also matched
    # directory names such as ``ckpt_dir`` or ``/opt/...``.)
    if os.path.isdir(resume):
        weight_path = os.path.join(resume, 'ckpt.pth')
    else:
        weight_path = resume
    ckpt = torch.load(weight_path)

    if models == 'ResNet101':  # for VISDA-C
        weight = ckpt['state_dict']
        net_dict = {}
        for k, v in weight.items():
            k = k.replace('module.', '')
            k = k.replace('downsample', 'shortcut')

            if 'encoder.0.' in k:
                # Backbone weights: BN running statistics are skipped, so with
                # strict loading ``net`` presumably has no such buffers in its
                # backbone — TODO confirm against ResNet101.
                if 'running_var' not in k and 'running_mean' not in k and 'num_batches_tracked' not in k:
                    k = k.replace('encoder.0.', "")
                    k = k.replace('fc', 'classifier')
                    net_dict[k] = v
            elif 'encoder.1.' in k:
                k = k.replace('encoder.1.', "classifier_bn.")
                net_dict[k] = v
            else:
                k = k.replace('fc', 'classifier2')
                net_dict[k] = v
    elif 'best_val' in ckpt:
        weight = ckpt['model']
        net_dict = {}
        for k, v in weight.items():
            k = k.replace('downsample', 'shortcut')
            k = k.replace('fc', 'classifier')
            net_dict[k] = v
    elif 'model' in ckpt:
        # Keep encoder weights (prefix stripped) and the ``fc`` head (renamed
        # to ``classifier``); every other entry is reported and dropped.
        weight = ckpt['model']
        net_dict = {}
        drop_list = []
        for k, v in weight.items():
            if 'encoder.' in k:
                k = k.replace('encoder.', "")
            elif 'fc' in k:
                k = k.replace('fc', 'classifier')
            else:
                drop_list.append(k)
                continue
            net_dict[k] = v

        print("========================")
        print("Dropped layers : ", drop_list)
        print("========================")
    elif 'net' in ckpt:
        net_dict = ckpt['net']
    elif 'state_dict' in ckpt:
        net_dict = ckpt['state_dict']
    else:
        # Bare state dict, possibly saved with a ``module.`` key prefix.
        net_dict = rm_substr_from_state_dict(ckpt, 'module.')

    net.load_state_dict(net_dict, strict=True)

    return net_dict

def rm_substr_from_state_dict(state_dict, substr):
    """Return a copy of ``state_dict`` with the leading ``substr`` removed from keys.

    Only keys that *start* with ``substr`` are rewritten (e.g. the
    ``'module.'`` prefix added by wrappers such as ``nn.DataParallel``); all
    other keys are kept unchanged. Key order is preserved.

    Args:
        state_dict: Mapping of parameter names to values.
        substr: Prefix to strip.

    Returns:
        A new ``OrderedDict``; the input mapping is not modified.
    """
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Strip only a true prefix: slicing off len(substr) characters when
        # ``substr`` occurs elsewhere in the key would corrupt the name.
        if key.startswith(substr):
            key = key[len(substr):]
        new_state_dict[key] = value
    return new_state_dict


# my load model
def my_get_pretrained_network(net, args):
    """Load the checkpoint at ``args.resume_path`` into ``net``.

    The checkpoint is read with dassl's ``load_checkpoint`` and its
    ``"state_dict"`` entry is applied with ``load_state_dict``'s default
    (strict) key matching. Returns the same ``net`` instance.
    """
    weights = load_checkpoint(args.resume_path)["state_dict"]
    net.load_state_dict(weights)
    return net


def my_get_pretrained_network_ossfda(net, args):
    """Load separately saved backbone and classifier checkpoints into ``net``.

    ``args.resume_path`` supplies ``net.backbone`` and
    ``args.resume_classifier_path`` supplies ``net.classifier``. Both files
    are read (via dassl's ``load_checkpoint``) before either sub-module is
    modified, and each ``'state_dict'`` entry is loaded with default strict
    matching. Returns the same ``net`` instance.
    """
    sources = (args.resume_path, args.resume_classifier_path)
    backbone_weights, head_weights = [
        load_checkpoint(path)['state_dict'] for path in sources
    ]
    net.backbone.load_state_dict(backbone_weights)
    net.classifier.load_state_dict(head_weights)
    return net
    


def _build_network(args, num_classes, im_size):
    """Instantiate the architecture named by ``args.model`` (no weights loaded).

    Raises:
        NotImplementedError: If ``args.model`` is not a supported name.
    """
    model = args.model
    if model == 'ConvNet':
        from utils import get_default_convnet_setting
        net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting()
        return ConvNet(channel=3, num_classes=num_classes, net_width=net_width,
                       net_depth=net_depth, net_act=net_act, net_norm=net_norm,
                       net_pooling=net_pooling, im_size=im_size)
    if model == 'WideResNet28':
        return WideResNet28(num_classes=num_classes)
    if model == 'ResNet18BN':
        return ResNet18BN(3, num_classes)
    if model == 'ResNet26':
        return ResNetCifar(26, width=1, classes=num_classes, norm_layer=nn.BatchNorm2d,
                           inject_layer=args.inject_layer)
    if model == 'ResNet101':
        return ResNet101(channel=3, num_classes=num_classes)
    if model == 'ResNet50':
        return ResNet50BN(3, num_classes, conv1_size=7)
    # ``NotImplemented`` is not an exception class; raising it gives a TypeError.
    raise NotImplementedError("Unsupported model: {}".format(model))


def _load_torchvision_resnet50(net):
    """Strictly load torchvision's pretrained ResNet-50 weights into ``net``.

    Keys are renamed (``downsample`` -> ``shortcut``, ``fc`` -> ``classifier``)
    to match this project's ResNet50BN naming.
    """
    from torchvision.models import resnet50
    tv_state = resnet50(pretrained=True).state_dict()

    net_dict = {}
    for k, v in tv_state.items():
        k = k.replace('downsample', 'shortcut')
        k = k.replace('fc', 'classifier')
        net_dict[k] = v

    net.load_state_dict(net_dict, strict=True)


def get_pretrained_network(args, num_classes, im_size=(32, 32), state_dict=None):
    """Build ``args.model`` and initialise its weights.

    Weights come from the first available source:
      1. ``state_dict`` (if given), loaded strictly;
      2. ``args.resume_path`` via ``load_pretrained_weight``;
      3. torchvision pretrained weights when ``args.pretrained`` is set
         (ResNet50 only).
    If none applies, the freshly constructed weights are kept.

    Args:
        args: Namespace providing ``model``, ``resume_path``, ``pretrained``
            and (for ResNet26) ``inject_layer``.
        num_classes: Number of output classes.
        im_size: Input spatial size, used by ConvNet only.
        state_dict: Optional weights to load directly.

    Returns:
        ``(net, state_dict)``. ``net`` is on the GPU when CUDA is available.
        ``state_dict`` is the argument itself in case 1, a deep copy of the
        loaded weights in cases 2 and 3, and ``None`` otherwise.

    Raises:
        NotImplementedError: For an unsupported ``args.model``, or when
            ``args.pretrained`` is set for a model other than ResNet50.
    """
    net = _build_network(args, num_classes, im_size)

    if state_dict is not None:
        net.load_state_dict(state_dict, strict=True)
    elif args.resume_path is not None:
        state_dict = copy.deepcopy(load_pretrained_weight(net, args.resume_path, args.model))
    elif args.pretrained:
        if args.model != 'ResNet50':
            raise NotImplementedError(
                "ImageNet-pretrained weights are only available for ResNet50, "
                "not {}".format(args.model))
        _load_torchvision_resnet50(net)
        state_dict = copy.deepcopy(net.state_dict())

    # Move to the GPU when one is present instead of crashing in ``.cuda()``
    # on CPU-only hosts.
    if torch.cuda.is_available():
        net = net.cuda()

    return net, state_dict



def get_BN_params(net):
    """Yield the learnable parameters of every ``nn.BatchNorm2d`` in ``net``.

    Parameters are produced in ``net.modules()`` order, layer by layer.
    Other normalisation types (``BatchNorm1d``, ``GroupNorm``, ...) are not
    included. This is a generator: nothing is inspected until iteration.
    """
    bn_layers = [layer for layer in net.modules() if isinstance(layer, nn.BatchNorm2d)]
    for layer in bn_layers:
        yield from layer.parameters()


def get_after_inject_params(net, models, inject_layer, only_bn=False):
    """Placeholder for selecting parameters located after ``inject_layer``.

    Not implemented yet.

    Raises:
        NotImplementedError: Always. (``raise NotImplemented`` would instead
            raise a TypeError, because ``NotImplemented`` is not an exception.)
    """
    raise NotImplementedError(
        "get_after_inject_params is not implemented "
        "(models={!r}, inject_layer={!r}, only_bn={!r})".format(models, inject_layer, only_bn))


def set_track_running_stats(net, track):
    """Set the ``track_running_stats`` flag on every ``nn.BatchNorm2d`` in ``net``.

    Only the attribute is assigned; existing running-stat buffers are left
    as they are, and other normalisation layer types are not touched.
    """
    bn_layers = (layer for layer in net.modules() if isinstance(layer, nn.BatchNorm2d))
    for layer in bn_layers:
        layer.track_running_stats = track


def get_sgd_optimizer(optim_parameters, net, model_name, lr, momentum=0.9, weight_decay=0.0005, inject_layer=None):
    """Create an SGD optimizer over a chosen subset of ``net``'s parameters.

    ``optim_parameters`` selects the subset:
      * ``'BN'``             -- BatchNorm2d parameters (``get_BN_params``);
      * ``'AFTER_STYLE'``    -- ``get_after_inject_params`` (all params);
      * ``'AFTER_STYLE_BN'`` -- ``get_after_inject_params`` with ``only_bn``;
      * anything else       -- every parameter of ``net``.
    ``lr``, ``momentum`` and ``weight_decay`` are forwarded to SGD unchanged.
    """
    if optim_parameters == 'BN':
        selected = get_BN_params(net)
    elif optim_parameters == 'AFTER_STYLE':
        selected = get_after_inject_params(net, model_name, inject_layer)
    elif optim_parameters == 'AFTER_STYLE_BN':
        selected = get_after_inject_params(net, model_name, inject_layer, only_bn=True)
    else:
        selected = net.parameters()

    return torch.optim.SGD(selected, lr=lr, momentum=momentum, weight_decay=weight_decay)


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242

    Returns ``exp(-5 * (1 - t)^2)`` as a Python float, where
    ``t = clip(current, 0, rampup_length) / rampup_length``. A zero-length
    ramp is treated as already finished (returns 1.0).
    """
    if rampup_length == 0:
        return 1.0
    progress = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - progress / rampup_length
    return float(np.exp(-5.0 * remaining * remaining))

def linear_rampup(current, rampup_length):
    """Linear rampup: ``current / rampup_length``, saturating at 1.0.

    Both arguments must be non-negative (checked with ``assert``). Once
    ``current`` reaches ``rampup_length`` the result is 1.0, which also
    covers a zero-length ramp without dividing by zero.
    """
    assert current >= 0 and rampup_length >= 0
    return 1.0 if current >= rampup_length else current / rampup_length

def step_rampup(current, rampup_length):
    """Hard step rampup: 0.0 before ``rampup_length``, 1.0 from then on.

    Both arguments must be non-negative (checked with ``assert``).
    """
    assert current >= 0 and rampup_length >= 0
    return 1.0 if current >= rampup_length else 0.0

def get_rampup_weight(iteration, rampup_length, rampup_type="step"):
    """Return the ramp-up weight for ``iteration`` under the named schedule.

    ``rampup_type`` selects ``step_rampup`` (``"step"``), ``linear_rampup``
    (``"linear"``) or ``sigmoid_rampup`` (``"sigmoid"``); any other value
    raises ``ValueError``.
    """
    schedules = (
        ("step", step_rampup),
        ("linear", linear_rampup),
        ("sigmoid", sigmoid_rampup),
    )
    for name, schedule in schedules:
        if rampup_type == name:
            return schedule(iteration, rampup_length)
    raise ValueError("Rampup schedule not implemented")
