from ofa.utils.layers import ResidualBlock
from ofa.imagenet_classification.networks import ProxylessNASNets
from .modules import my_set_layer_from_config
import torch
import numpy as np
import torch.nn as nn
import torchvision
from torchvision import models
import torch.nn.utils.weight_norm as weightNorm
from ofa.model_zoo import proxylessnas_net

# Names exported by ``from <this module> import *``.
__all__ = [
    'build_residual_block_from_config',
    'build_network_from_config',
    'calc_coeff',
    'init_weights',
    'VGGBase',
    'ProxylessBase',
    'ResBase',
    'gen_shot_model',
    'gen_optim',
]


def build_residual_block_from_config(config):
    """Instantiate a ``ResidualBlock`` from its serialized config.

    The main branch is read from ``config['conv']`` when that key is present,
    otherwise from ``config['mobile_inverted_conv']``. The shortcut branch is
    read from ``config['shortcut']``. Both branches are built with
    ``my_set_layer_from_config``, main branch first.
    """
    branch_key = 'conv' if 'conv' in config else 'mobile_inverted_conv'
    main_branch = my_set_layer_from_config(config[branch_key])
    shortcut_branch = my_set_layer_from_config(config['shortcut'])
    return ResidualBlock(main_branch, shortcut_branch)


def build_network_from_config(config):
    """Assemble a ``ProxylessNASNets`` from a serialized network config.

    Reads the keys 'first_conv', 'feature_mix_layer', 'classifier' and
    'blocks' (a list of residual-block configs). If the optional 'bn' key is
    present, its entries are passed to ``set_bn_param``; otherwise BN uses
    momentum=0.1 and eps=1e-3.
    """
    # Construction order is kept as before (stem, mix layer, head, then the
    # blocks) in case layer construction draws random initial weights.
    stem = my_set_layer_from_config(config['first_conv'])
    mixer = my_set_layer_from_config(config['feature_mix_layer'])
    head = my_set_layer_from_config(config['classifier'])
    body = [build_residual_block_from_config(block_cfg) for block_cfg in config['blocks']]

    network = ProxylessNASNets(stem, body, mixer, head)
    bn_kwargs = config['bn'] if 'bn' in config else {'momentum': 0.1, 'eps': 1e-3}
    network.set_bn_param(**bn_kwargs)
    return network

def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    """Return a sigmoid-shaped ramp coefficient for training step ``iter_num``.

    Computes
    ``2 * (high - low) / (1 + exp(-alpha * iter_num / max_iter)) - (high - low) + low``.
    The value equals ``low`` at ``iter_num == 0`` and approaches ``high`` as
    ``iter_num / max_iter`` grows. ``alpha`` controls how steep the ramp is.

    Returns:
        A built-in Python ``float``.
    """
    span = high - low
    progress = alpha * iter_num / max_iter
    # Use the builtin float(). The ``np.float`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    return float(2.0 * span / (1.0 + np.exp(-progress)) - span + low)

def init_weights(m):
    """Initialize one module in place; intended for ``module.apply(init_weights)``.

    Dispatch is by substring match on the module's class name:

    * ``Conv2d`` / ``ConvTranspose2d``: Kaiming-uniform weight, zero bias.
    * ``BatchNorm``: weight drawn from N(1.0, 0.02), zero bias.
    * ``Linear``: Xavier-normal weight, zero bias.

    Any other module is left untouched. A ``weight`` or ``bias`` that is
    absent or ``None`` is skipped instead of raising. This covers
    ``bias=False`` layers, ``affine=False`` norm layers, and wrapper modules
    whose class name merely contains one of the patterns.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv2d' in classname or 'ConvTranspose2d' in classname:
        if weight is not None:
            nn.init.kaiming_uniform_(weight)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
    elif 'Linear' in classname:
        if weight is not None:
            nn.init.xavier_normal_(weight)
    else:
        return
    if bias is not None:
        nn.init.zeros_(bias)

# torchvision VGG constructors keyed by the names VGGBase accepts; the
# "bn" suffix selects the batch-norm variant of each depth.
vgg_dict = dict(
    vgg11=models.vgg11,
    vgg13=models.vgg13,
    vgg16=models.vgg16,
    vgg19=models.vgg19,
    vgg11bn=models.vgg11_bn,
    vgg13bn=models.vgg13_bn,
    vgg16bn=models.vgg16_bn,
    vgg19bn=models.vgg19_bn,
)
class VGGBase(nn.Module):
    """Pretrained torchvision VGG with the final classification layer removed.

    ``vgg_name`` must be a key of ``vgg_dict``. ``forward`` returns the
    activations that fed the original last classifier layer
    (``classifier[6]``). ``in_features`` records that layer's input width.
    """

    def __init__(self, vgg_name):
        super().__init__()
        backbone = vgg_dict[vgg_name](pretrained=True)
        self.features = backbone.features
        self.classifier = nn.Sequential()
        # Keep classifier stages 0..5 under the names "classifier0".."classifier5"
        # so the submodule / state_dict keys stay as before.
        for idx in range(6):
            self.classifier.add_module(f"classifier{idx}", backbone.classifier[idx])
        self.in_features = backbone.classifier[6].in_features

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)


class ProxylessBase(ProxylessNASNets):
    """Feature extractor built from an existing ProxylessNAS network.

    The source network's first_conv, blocks, feature_mix_layer and
    classifier are handed to the parent constructor, and its global_avg_pool
    module is reused. The classifier stays registered as a submodule, but
    ``forward`` does not apply it. Instead it returns the pooled features
    flattened to ``(batch, -1)``. ``in_features`` is the source classifier's
    input width.
    """

    def __init__(self, model_proxyless):
        src = model_proxyless
        super().__init__(src.first_conv, src.blocks, src.feature_mix_layer, src.classifier)
        self.global_avg_pool = src.global_avg_pool
        self.in_features = src.classifier.in_features

    def forward(self, x):
        out = self.first_conv(x)
        for blk in self.blocks:
            out = blk(out)
        mix = self.feature_mix_layer
        if mix is not None:
            out = mix(out)
        pooled = self.global_avg_pool(out)
        return pooled.view(pooled.size(0), -1)

# torchvision ResNet / ResNeXt constructors keyed by the names ResBase accepts;
# "resnext50" / "resnext101" are the 32x4d / 32x8d ResNeXt variants.
res_dict = dict(
    resnet18=models.resnet18,
    resnet34=models.resnet34,
    resnet50=models.resnet50,
    resnet101=models.resnet101,
    resnet152=models.resnet152,
    resnext50=models.resnext50_32x4d,
    resnext101=models.resnext101_32x8d,
)

class ResBase(nn.Module):
    """Pretrained torchvision ResNet/ResNeXt trunk without its ``fc`` head.

    ``res_name`` must be a key of ``res_dict``. ``forward`` returns the
    average-pooled activations flattened to ``(batch, -1)``. ``in_features``
    records the input width of the dropped ``fc`` layer.
    """

    def __init__(self, res_name):
        super().__init__()
        backbone = res_dict[res_name](pretrained=True)
        # Reuse the pretrained stages under their torchvision attribute names,
        # registered in the same order as before.
        for stage_name in ('conv1', 'bn1', 'relu', 'maxpool', 'layer1',
                           'layer2', 'layer3', 'layer4', 'avgpool'):
            setattr(self, stage_name, getattr(backbone, stage_name))
        self.in_features = backbone.fc.in_features

    def forward(self, x):
        stages = (self.conv1, self.bn1, self.relu, self.maxpool,
                  self.layer1, self.layer2, self.layer3, self.layer4,
                  self.avgpool)
        out = x
        for stage in stages:
            out = stage(out)
        return out.view(out.size(0), -1)

class feat_bootleneck(nn.Module):
    """Linear bottleneck that projects features to ``bottleneck_dim``.

    Args:
        feature_dim: input feature width.
        bottleneck_dim: output width (default 256).
        type: "bn" applies BatchNorm1d after the projection; any other value
            returns the raw projection.
    """

    def __init__(self, feature_dim, bottleneck_dim=256, type="ori"):
        super().__init__()
        # Submodules are registered in the original order (bn, relu, dropout,
        # bottleneck), which fixes the parameter / state_dict ordering.
        self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
        # NOTE: relu and dropout are constructed but never used in forward().
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=0.5)
        self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
        self.bottleneck.apply(init_weights)
        # Stored as ``self.type``; note this shadows nn.Module.type on instances.
        self.type = type

    def forward(self, x):
        projected = self.bottleneck(x)
        return self.bn(projected) if self.type == "bn" else projected

class feat_classifier(nn.Module):
    """Final linear classifier mapping bottleneck features to class scores.

    Args:
        class_num: number of output classes.
        bottleneck_dim: input feature width (default 256).
        type: "wn" wraps the Linear layer with weight normalization, which
            re-parameterizes ``weight`` as ``weight_g`` / ``weight_v``. Any
            other value uses a plain ``nn.Linear``.
    """

    def __init__(self, class_num, bottleneck_dim=256, type="linear"):
        super(feat_classifier, self).__init__()
        self.type = type
        fc = nn.Linear(bottleneck_dim, class_num)
        # Initialize BEFORE applying weight norm. weight_norm recomputes
        # ``weight`` from ``weight_g`` / ``weight_v`` in a forward pre-hook.
        # An in-place init of ``weight`` done after wrapping is therefore
        # overwritten on the first forward pass and never takes effect.
        fc.apply(init_weights)
        if type == 'wn':
            fc = weightNorm(fc, name="weight")
        self.fc = fc

    def forward(self, x):
        return self.fc(x)

class feat_classifier_two(nn.Module):
    """Two stacked Linear layers: ``input_dim -> bottleneck_dim -> class_num``.

    There is no non-linearity between ``fc0`` and ``fc1``, so the composition
    is itself an affine map.
    """

    def __init__(self, class_num, input_dim, bottleneck_dim=256):
        super(feat_classifier_two, self).__init__()
        # This constructor takes no ``type`` argument, so nothing is stored
        # under ``self.type``. Assigning the builtin ``type`` would shadow
        # nn.Module.type.
        self.fc0 = nn.Linear(input_dim, bottleneck_dim)
        self.fc0.apply(init_weights)
        self.fc1 = nn.Linear(bottleneck_dim, class_num)
        self.fc1.apply(init_weights)

    def forward(self, x):
        return self.fc1(self.fc0(x))

class Res50(nn.Module):
    """Pretrained torchvision ResNet-50 exposing both features and logits.

    ``forward`` returns ``(features, logits)``: the average-pooled activations
    flattened to ``(batch, -1)``, and the original ``fc`` layer applied to
    them. ``in_features`` is the input width of ``fc``.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # Reuse the pretrained stages under their torchvision attribute names,
        # registered in the same order as before (fc last).
        for stage_name in ('conv1', 'bn1', 'relu', 'maxpool', 'layer1',
                           'layer2', 'layer3', 'layer4', 'avgpool'):
            setattr(self, stage_name, getattr(backbone, stage_name))
        self.in_features = backbone.fc.in_features
        self.fc = backbone.fc

    def forward(self, x):
        out = x
        for stage in (self.conv1, self.bn1, self.relu, self.maxpool,
                      self.layer1, self.layer2, self.layer3, self.layer4,
                      self.avgpool):
            out = stage(out)
        features = out.view(out.size(0), -1)
        return features, self.fc(features)

def gen_shot_model(res_name, class_num, B_type='ori', bottleneck_dim=256, C_type='linear', device='cuda'):
    '''
    gen_shot_model
    FUNCTION: generate a SHOT-like model made of a feature extractor (netF),
    a bottleneck (netB) and a classifier (netC). Each part is moved to
    ``device`` and put in eval mode.

    ARGUMENTS:
    "res_name": backbone name of the feature extractor. Names starting with
        'res' select a torchvision ResNet/ResNeXt via ResBase (lower case,
        e.g. 'resnet18', 'resnext50'). Names starting with 'pro' are passed
        to ``proxylessnas_net`` and wrapped in ProxylessBase.
    "class_num": the total number of classes in the dataset.
    "B_type": 'bn' applies Batch Normalization after the bottleneck; any
        other value does not.
    "bottleneck_dim": the dimension of the bottleneck.
    "C_type": 'wn' applies Weight Normalization to the classifier; any other
        value does not.
    "device": device the three networks are moved to. Defaults to 'cuda'
        (the current CUDA device), matching the previous ``.cuda()`` behavior.

    RETURN:
    netF, netB, netC.

    RAISES:
    NotImplementedError if ``res_name`` starts with neither 'res' nor 'pro'.
    '''
    if res_name.startswith('res'):
        netF = ResBase(res_name=res_name)
    elif res_name.startswith('pro'):
        net_full = proxylessnas_net(res_name, pretrained=True)
        netF = ProxylessBase(net_full)
    else:
        raise NotImplementedError(
            "gen_shot_model: unsupported backbone %r "
            "(expected a name starting with 'res' or 'pro')" % (res_name,))
    netF = netF.to(device)

    netB = feat_bootleneck(type=B_type, feature_dim=netF.in_features,
                           bottleneck_dim=bottleneck_dim).to(device)
    netC = feat_classifier(type=C_type, class_num=class_num,
                           bottleneck_dim=bottleneck_dim).to(device)

    netF.eval()
    netB.eval()
    netC.eval()
    return netF, netB, netC


def op_copy(optimizer):
    """Record each param group's current learning rate under the key 'lr0'.

    The optimizer is modified in place and also returned, so calls can be
    chained.
    """
    for group in optimizer.param_groups:
        group.update(lr0=group['lr'])
    return optimizer

def gen_optim(args, net):
    '''
    Build an SGD optimizer that trains only the first two parts of a
    SHOT-like model.

    INPUT:
    args: attribute container holding "lr", "lr_decay1", "lr_decay2".
        Missing ones are set on ``args`` in place, with a printed notice, to
        1e-2, 0.1 and 1.0 respectively.
    net: the loaded model pretrained on SOURCE DOMAIN. It is indexed as
        net[0] (F), net[1] (B), net[2] (C) and must support ``.train()``.

    SIDE EFFECTS:
    net is put in train mode, then net[2] is put back in eval mode.
    F parameters whose name contains 'classifier' get ``requires_grad = False``.
    All B parameters get ``requires_grad = False`` when lr_decay2 <= 0.

    OUTPUT: a ``torch.optim.SGD`` with one param group per trainable F/B
    parameter. F uses lr * lr_decay1 and B uses lr * lr_decay2. Each group
    also stores its initial rate under 'lr0' (see ``op_copy``).
    '''
    net.train()
    # For safety: fill in any missing hyperparameters.
    if not hasattr(args, 'lr'):
        print('args.lr filled as 0.01')
        args.lr = 1e-2
    if not hasattr(args, 'lr_decay1'):
        print('args.lr_decay1 filled as 0.1')
        args.lr_decay1 = 0.1
    if not hasattr(args, 'lr_decay2'):
        print('args.lr_decay2 filled as 1.0')
        args.lr_decay2 = 1.0

    # only net[0](F) and net[1](B) are trained
    param_group = []
    for k, v in net[0].named_parameters():
        if 'classifier' in k:
            # A classifier head kept inside F (e.g. the one ProxylessBase
            # retains but never uses in forward) is frozen.
            v.requires_grad = False
        else:
            param_group.append({'params': v, 'lr': args.lr * args.lr_decay1})

    for v in net[1].parameters():
        if args.lr_decay2 > 0:
            param_group.append({'params': v, 'lr': args.lr * args.lr_decay2})
        else:
            v.requires_grad = False
    net[2].eval()  # classifier set to eval mode
    optimizer = torch.optim.SGD(param_group)
    return op_copy(optimizer)


def get_parameters(args, net):
    '''
    FUNCTION: collect the parameters to optimize for ``net``.
        Only a SHOT-like model (``net.modelname`` starting with "SHOT") is
        special-cased: for it, only the first two parts are trained.
        Any other model (no ``modelname``, or ``modelname`` None or not
        starting with "SHOT") gets all of its parameters.
    INPUT:
    args: attribute container holding "lr", "lr_decay1", "lr_decay2".
        For SHOT-like models, missing ones are set on ``args`` in place, with
        a printed notice, to 1e-2, 0.1 and 1.0 respectively.
    net: the loaded model pretrained on SOURCE DOMAIN. A SHOT-like model is
        indexed as net[0] (F), net[1] (B), net[2] (C).
    OUTPUT:
        For a non-SHOT model: ``net.parameters()``.
        For a SHOT-like model: a list of per-parameter group dicts
        ``{'params': p, 'lr': ...}`` for a torch optimizer. F uses
        lr * lr_decay1 and B uses lr * lr_decay2. F parameters whose name
        contains 'classifier' are frozen, and B is frozen when lr_decay2 <= 0.
        net is set to train mode, except net[2], which is set to eval mode.
    '''
    modelname = getattr(net, 'modelname', None)
    if modelname is None or modelname[:4] != 'SHOT':  # not SHOTLIKE Model
        return net.parameters()  # return all parameters
    # set to training mode
    print("Warning: Only netF and netB will be activated during training.")
    net.train()
    # For safety: fill in any missing hyperparameters.
    if not hasattr(args, 'lr'):
        print('args.lr filled as 0.01')
        args.lr = 1e-2
    if not hasattr(args, 'lr_decay1'):
        print('args.lr_decay1 filled as 0.1')
        args.lr_decay1 = 0.1
    if not hasattr(args, 'lr_decay2'):
        print('args.lr_decay2 filled as 1.0')
        args.lr_decay2 = 1.0

    # only net[0](F) and net[1](B) are trained
    param_group = []
    for k, v in net[0].named_parameters():
        if 'classifier' in k:
            # A classifier head kept inside F (e.g. the one ProxylessBase
            # retains but never uses in forward) is frozen.
            v.requires_grad = False
        else:
            param_group.append({'params': v, 'lr': args.lr * args.lr_decay1})

    for v in net[1].parameters():
        if args.lr_decay2 > 0:
            param_group.append({'params': v, 'lr': args.lr * args.lr_decay2})
        else:
            v.requires_grad = False
    net[2].eval()  # classifier set to eval mode
    return param_group

def compact_train(model):
    '''
    For setting the compact model to training mode.

    For a SHOT-like model (``model.modelname`` is a str starting with "SHOT")
    that can be indexed as model[0], model[1], model[2], the first two parts
    are put in train mode and model[2] in eval mode. Any other model
    (including a "SHOT" model that cannot be indexed that way) gets a printed
    warning and ``model.train()``.
    '''
    modelname = getattr(model, 'modelname', None)
    if isinstance(modelname, str) and modelname[:4] == 'SHOT':
        try:
            feature_net, bottleneck, classifier = model[0], model[1], model[2]
        except (TypeError, IndexError, KeyError):
            pass  # claims to be SHOT but has no (F, B, C) parts; fall back below
        else:
            feature_net.train()
            bottleneck.train()
            classifier.eval()
            return
    print("WARNING: Not set to SHOT-type model training mode.")
    model.train()