from __future__ import print_function

import torch
import numpy as np
import argparse
import os

from models.resnet import InsResNet50, InsResNet34
from models.resnet_cifar import InsResNet50_cifar, InsResNet34_cifar
from models.resnet_MOE_full import InsResNet34 as InsResNet34_MOE
from models.resnet_cifar_MOE_full import InsResNet34_cifar as InsResNet34_cifar_MOE


def count_parameters(model):
    """Return how many scalar values live in the trainable parameters of ``model``.

    Only parameters with ``requires_grad=True`` are counted; frozen ones are
    skipped.
    """
    trainable = (param for param in model.parameters() if param.requires_grad)
    return sum(param.numel() for param in trainable)

def create_model(model_name, contrastive_model, n_label=10):
    """Instantiate a network and its EMA (momentum) twin by architecture name.

    Args:
        model_name: one of ``'resnet34'``, ``'resnet34_cifar'``,
            ``'resnet34_moe'`` or ``'resnet34_cifar_moe'``.
        contrastive_model: accepted for interface compatibility; not used here.
        n_label: forwarded as ``n_label`` to the two MOE variants only.

    Returns:
        ``(model, model_ema)`` -- two independently constructed instances of
        the same class. The trainable-parameter count of ``model`` is printed.

    Raises:
        NotImplementedError: for any other ``model_name``.
    """
    builders = {
        'resnet34': InsResNet34,
        'resnet34_cifar': InsResNet34_cifar,
        'resnet34_moe': lambda: InsResNet34_MOE(n_label=n_label),
        'resnet34_cifar_moe': lambda: InsResNet34_cifar_MOE(n_label=n_label),
    }
    build = builders.get(model_name)
    if build is None:
        raise NotImplementedError('model not supported {}'.format(model_name))

    if model_name == 'resnet34_moe':
        print("Get MOE full")
    # Tuple elements are evaluated left to right: `model` is built first.
    model, model_ema = build(), build()

    print("Number of parameters:", count_parameters(model))
    return model, model_ema



def adjust_learning_rate(epoch, opt, optimizer):
    """Set the learning rate of every param group in ``optimizer`` for ``epoch``.

    Schedule:
      * while ``epoch <= opt.lr_warmup``: linear warm-up,
        ``opt.learning_rate * linear_rampup(epoch, opt.lr_warmup)``;
      * afterwards: step decay,
        ``opt.learning_rate * opt.lr_decay_rate ** steps``, where ``steps``
        is the number of entries of ``opt.lr_decay_epochs`` strictly smaller
        than ``epoch``.

    The rate is always written explicitly, including when no decay milestone
    has been passed yet. The result therefore depends only on ``epoch`` and
    ``opt``, not on whatever rate the optimizer held before the call.

    Args:
        epoch: current epoch number.
        opt: options object providing ``learning_rate``, ``lr_warmup``,
            ``lr_decay_epochs`` and ``lr_decay_rate``.
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` key), e.g. a ``torch.optim.Optimizer``.
    """
    if epoch <= opt.lr_warmup:
        print("Linear warm up")
        factor = linear_rampup(epoch, opt.lr_warmup)
    else:
        steps = int(np.sum(epoch > np.asarray(opt.lr_decay_epochs)))
        # decay_rate ** 0 == 1, so before the first milestone the base rate is
        # (re)applied instead of leaving a stale value in the optimizer.
        factor = opt.lr_decay_rate ** steps

    new_lr = opt.learning_rate * factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242

    Returns ``exp(-5 * (1 - t) ** 2)`` with
    ``t = clip(current, 0, rampup_length) / rampup_length``. The value rises
    from ``exp(-5)`` at ``current <= 0`` to 1.0 at ``current >= rampup_length``.
    A ``rampup_length`` of 0 means "no ramp" and yields 1.0 immediately.
    """
    if rampup_length != 0:
        clipped = np.clip(current, 0.0, rampup_length)
        remaining = 1.0 - clipped / rampup_length
        return float(np.exp(-5.0 * remaining * remaining))
    return 1.0


def linear_rampup(current, rampup_length):
    """Linear rampup: ``current / rampup_length``, capped at 1.0.

    Returns 1.0 once ``current >= rampup_length``; this includes the case
    ``rampup_length == 0``, so no division by zero can occur.

    Raises:
        AssertionError: if either argument is negative (only while assertions
            are enabled, i.e. not under ``python -O``).
    """
    assert current >= 0 and rampup_length >= 0
    if current >= rampup_length:
        return 1.0
    # float() forces true division on Python 2 as well (this file only imports
    # print_function from __future__), where int / int would floor to 0 and
    # make every warm-up factor zero.
    return current / float(rampup_length)

def linear_rampdown(current, rampup_length):
    """Linear rampdown: ``1 - current / rampup_length``, floored at 0.0.

    ``rampup_length`` is the length of the ramp-*down* (the parameter name is
    kept for backward compatibility with keyword callers). Returns 0.0 once
    ``current >= rampup_length``; this includes ``rampup_length == 0``, so no
    division by zero can occur.

    Raises:
        AssertionError: if either argument is negative (only while assertions
            are enabled).
    """
    assert current >= 0 and rampup_length >= 0
    if current >= rampup_length:
        return 0.
    # float() keeps this a true division under Python 2 as well.
    return 1 - current / float(rampup_length)



def cosine_rampdown(current, rampdown_length):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983

    Returns ``0.5 * (cos(pi * current / rampdown_length) + 1)``. This is 1.0
    at ``current == 0`` and falls to 0.0 at ``current == rampdown_length``.
    The result is clamped to be non-negative.

    Raises:
        AssertionError: unless ``0 <= current <= rampdown_length`` (only while
            assertions are enabled).
    """
    assert 0 <= current <= rampdown_length
    angle = np.pi * current / rampdown_length
    half_cosine = float(.5 * (np.cos(angle) + 1))
    return max(0., half_cosine)


class AverageMeter(object):
    """Keep the latest value together with a weighted running sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value and add it ``n`` times to the mean.

        ``n`` acts as a weight. Presumably it is the number of samples ``val``
        was averaged over, e.g. a batch size; confirm at call sites.
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: 2-D score tensor; the top-k classes are taken along dim 1, so
            rows are samples and columns are classes.
        target: ground-truth class indices, one per row of ``output``.
        topk: iterable of k values. Each must not exceed the number of columns
            of ``output``.

    Returns:
        A list with one 1-element tensor per k. Each holds the percentage of
        samples whose target is among their k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # After the transpose, pred is (maxk, batch_size); row i holds every
        # sample's i-th best class.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: `correct` inherits the non-contiguous layout
            # of the transposed `pred`. On recent PyTorch, view(-1) raises as
            # soon as the slice spans more than one row (k >= 2).
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res



def moment_update(model, model_ema, m):
    """In place: ``model_ema = m * model_ema + (1 - m) * model``, per parameter.

    Parameters are paired positionally through
    ``zip(model.parameters(), model_ema.parameters())``, so both modules are
    expected to share the same architecture; zip stops at the shorter list.
    Only parameters are blended; buffers are not touched.

    Args:
        model: source module whose parameters are read (left unchanged).
        model_ema: target module whose parameters are updated in place.
        m: momentum, i.e. the weight kept on the old ``model_ema`` values.
    """
    for p1, p2 in zip(model.parameters(), model_ema.parameters()):
        # add_(other, alpha=...) replaces the deprecated add_(alpha, other)
        # overload; `.data` keeps the update out of autograd as before.
        p2.data.mul_(m).add_(p1.data, alpha=1 - m)


def get_shuffle_ids(bsz, device='cuda'):
    """Generate shuffle ids for ShuffleBN.

    Returns a random permutation ``forward_inds`` of ``range(bsz)`` and its
    inverse ``backward_inds``, so that ``x[forward_inds][backward_inds]``
    restores the original order of ``x`` along dim 0.

    Args:
        bsz: number of indices to permute.
        device: where to place the returned tensors. The default ``'cuda'``
            (the current CUDA device) matches the previous hard-coded
            ``.cuda()``; pass ``'cpu'`` to run without a GPU.

    Returns:
        ``(forward_inds, backward_inds)``: two int64 tensors of length ``bsz``.
    """
    # Draw the permutation from the CPU generator as before, so seeded runs
    # reproduce the same shuffles, then move it to the target device.
    forward_inds = torch.randperm(bsz).long().to(device)
    backward_inds = torch.zeros(bsz, dtype=torch.long, device=device)
    value = torch.arange(bsz, dtype=torch.long, device=device)
    # backward_inds[forward_inds[i]] = i, i.e. the inverse permutation.
    backward_inds.index_copy_(0, forward_inds, value)
    return forward_inds, backward_inds


def set_bn_train(m):
    """Switch ``m`` to training mode if its class name contains 'BatchNorm'.

    Any other module is left untouched. The signature fits the callback of
    ``nn.Module.apply``.
    """
    if 'BatchNorm' in m.__class__.__name__:
        m.train()



if __name__ == '__main__':
    # This module only defines helpers; running it directly just emits an
    # empty line.
    print('')