from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

import os, sys
import argparse
import inspect
import shutil
import time
import pickle
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
from loss import KLLoss

import models
import dataset
import learning

# Parser settings
# Options are declared in the order they are printed below; set_argument()
# later offers every parsed option to the model constructor by name.
parser = argparse.ArgumentParser(description=' ')

# Model: --arch is the key looked up in models.__dict__.
parser.add_argument('--arch', type=str, default='WideResNet', metavar='string')
parser.add_argument('--dml_arch', type=str, default='ResNet', metavar='string')
parser.add_argument('--num_branches', type=int, default=4, metavar='int')
parser.add_argument('--depth', type=int, default=20, metavar='int')
parser.add_argument('--wf', type=int, default=1, metavar='int')
parser.add_argument('--bottleneck', action='store_true', default=False)
parser.add_argument('--se', action='store_true', default=False)

# Data: all four values are passed to dataset.batch_loader.loader().
parser.add_argument('--dataset', type=str, default='cifar10', metavar='string')
parser.add_argument('--batch-size', type=int, default=128, metavar='int')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='int')
parser.add_argument('--data', type=str, default='', metavar='string')

# Optimisation schedule.
parser.add_argument('--epochs', type=int, default=300, metavar='int')
parser.add_argument('--start-epoch', type=int, default=0, metavar='int')
parser.add_argument('--lr', type=float, default=0.1, metavar='float')
parser.add_argument('--wd', type=float, default=5e-4, metavar='float')
parser.add_argument('--milestones', type=int, nargs='+', default=[60, 100])
parser.add_argument('--consistency_rampup', type=float, default=80, metavar='float')
parser.add_argument('--seed', type=int, default=3, metavar='int')

# Class windows and distillation: --margin is the weight of classes outside a
# branch window, --gamma the window length as a fraction of the class count,
# --T the temperature handed to KLLoss.
parser.add_argument('--margin', type=float, default=0.0, metavar='float')
parser.add_argument('--gamma', type=float, default=0.5, metavar='float')
parser.add_argument('--T', type=float, default=4, metavar='float')

# Runtime: despite its name, --ngpu is a torch device string such as 'cuda:0'.
parser.add_argument('--no-cuda', action='store_true', default=False)
parser.add_argument('--ngpu', type=str, default='cuda:0', metavar='string')
parser.add_argument('--resume', type=str, default='', metavar='string')
parser.add_argument('--evaluate', action='store_true', default=False)
parser.add_argument('--pretrained', action='store_true', default=False)
parser.add_argument('--save', type=str, default='', metavar='string')


args = parser.parse_args()
print("=> Framework Arguments : {}".format(args))

# Create the checkpoint directory only when one was requested.  With the
# default '' os.mkdir raises FileNotFoundError (not FileExistsError), which
# used to abort the script; makedirs also handles nested paths.
if args.save:
    os.makedirs(args.save, exist_ok=True)

# Fall back to the CPU when CUDA is disabled or unavailable.
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = args.ngpu if args.cuda else "cpu"

# Seed the CPU and every CUDA generator, and ask cuDNN for deterministic kernels.
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Define Data Loader

train_loader, test_loader = dataset.batch_loader.loader(
    args.dataset,
    args.data,
    args.batch_size,
    args.test_batch_size,
    device,
)

def set_argument(sig, args):
    """Collect the parsed options that a callable's signature accepts.

    Args:
        sig: ``inspect.Signature`` of the model constructor or factory.
        args: parsed ``argparse.Namespace`` (any object ``vars()`` accepts).

    Returns:
        dict mapping every parameter name of ``sig`` that is also an
        attribute of ``args`` to that attribute's value.  ``self`` is never
        included; parameters without a matching option are left out so the
        callable's own defaults apply.
    """
    options = vars(args)
    # ``self`` is skipped by name rather than by position, so a signature
    # with a single real parameter (no ``self``) is still forwarded.
    return {
        name: options[name]
        for name in sig.parameters
        if name != 'self' and name in options
    }

def class_specialization(init_weights, class_info):
    """Set the weight of every class index in ``class_info`` to 1.0.

    ``init_weights`` is updated in place and the same object is returned.
    ``class_info`` must be a list of indices into ``init_weights``.
    """
    assert isinstance(class_info, list), 'Check out info structure'
    for class_idx in class_info:
        init_weights[class_idx] = 1.0
    return init_weights

def window_generalization(n_classes, num_branches, margin=None, gamma=None):
    """Build one wrap-around window of classes, and its class weights, per branch.

    Branch ``b`` starts at class ``b * (n_classes // num_branches)`` and covers
    the next ``int(n_classes * gamma)`` classes, wrapping past the last class
    back to class 0.  Classes inside the window get weight 1.0, every other
    class gets ``margin``.

    Args:
        n_classes: total number of classes.
        num_branches: number of branches (one window each).
        margin: weight of the classes outside a window; falls back to the
            module-level ``args.margin`` when omitted.
        gamma: window length as a fraction of ``n_classes``; falls back to
            the module-level ``args.gamma`` when omitted.

    Returns:
        ``(window_weight, window_pool)``, two dicts keyed by branch index:
        ``window_weight[b]`` is a list of ``n_classes`` weights and
        ``window_pool[b]`` is the list of class indices in the window.

    Raises:
        AssertionError: if ``gamma`` is below ``(n_classes // num_branches) / 100``.
    """
    if margin is None:
        margin = args.margin
    if gamma is None:
        gamma = args.gamma

    width = n_classes // num_branches
    print('Window width: {}'.format(width))

    # The bound uses this call's ``num_branches`` (not the global option) so
    # the check always matches the windows actually built.
    # NOTE(review): the divisor 100 is hard-coded; it equals the CIFAR-100
    # class count -- confirm whether ``n_classes`` was intended.
    assert gamma >= (width / 100), \
        'gamma={} is too small for a window width of {}'.format(gamma, width)

    window_size = int(n_classes * gamma)
    window_pool = {}
    window_weight = {}
    for branch in range(num_branches):
        first_class = branch * width
        # Modulo wraps indices that run past the last class back to the start.
        members = [(first_class + offset) % n_classes for offset in range(window_size)]

        weights = [margin] * n_classes
        for class_idx in members:
            weights[class_idx] = 1.0

        window_pool[branch] = members
        window_weight[branch] = weights

    print(window_pool)
    print(window_weight)
    return window_weight, window_pool  # both are dicts keyed by branch index

def get_loss(num_branches, class_weights, device):
    """Create one class-weighted cross-entropy loss per branch.

    Args:
        num_branches: number of branches; keys ``0 .. num_branches - 1`` of
            ``class_weights`` are read.
        class_weights: dict mapping a branch index to its list of per-class
            weights.
        device: torch device (or device string) the losses are moved to.

    Returns:
        dict mapping each branch index to an ``nn.CrossEntropyLoss`` with
        ``reduction='mean'`` and that branch's weights.

    Raises:
        AssertionError: if ``class_weights`` is not a dict.
    """
    assert isinstance(class_weights, dict), 'Check out "class_weights" is dictionary'
    # .to() accepts CUDA and "cpu" devices alike, whereas .cuda("cpu") raises,
    # which broke the --no-cuda path.
    return {
        branch: nn.CrossEntropyLoss(
            weight=torch.tensor(class_weights[branch], dtype=torch.float),
            reduction='mean',
        ).to(device)
        for branch in range(num_branches)
    }

def get_n_classes(dataset):
    """Return the number of classes for a supported dataset name.

    Only 'cifar10' and 'cifar100' are known; any other name fails the assert.
    """
    n_classes_by_name = dict(cifar10=10, cifar100=100)
    assert dataset in n_classes_by_name, '{} is not included in data pool'.format(dataset)
    return n_classes_by_name[dataset]

def count_model_parameters(model, detail=False):
    """Count the trainable (``requires_grad``) parameters of ``model``.

    With ``detail=False`` the total element count is returned as an int.

    With ``detail=True`` a tuple ``(full, shared, task)`` is returned.  A
    parameter counts as task specific when the first dotted component of its
    name contains exactly one underscore; everything else is shared.  The
    task-specific total is divided by ``args.num_branches + 1`` (an assert
    checks that the division is exact).
    """
    if not detail:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    shared_count = 0
    task_count = 0
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        top_level = param_name.split('.')[0]
        if top_level.count('_') == 1:  # splits into two parts: task specific
            task_count += param.numel()
        else:  # shared
            shared_count += param.numel()

    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_groups = args.num_branches+1
    assert task_count % n_groups == 0
    return total, shared_count, task_count // n_groups

## Define Model
# For CIFAR datasets the constructor signature is read from the entry's
# __init__, otherwise from the models.__dict__ entry itself.
sig = inspect.signature(models.__dict__[args.arch].__init__) if args.dataset.startswith('cifar') else \
        inspect.signature(models.__dict__[args.arch])
kwargs = set_argument(sig, args)
# .to() works for CUDA and "cpu" devices alike; .cuda("cpu") raises, which
# broke the --no-cuda path.
model = models.__dict__[args.arch](**kwargs).to(device)

#print(model)
params = count_model_parameters(model, detail=True)
print("=> Model Parameters: {}, Shared: {}, Task: {}".format(params[0], params[1], params[2]))

# Per-branch class weights and the class indices inside each branch window.
c_weights, c_pool = window_generalization(get_n_classes(args.dataset), args.num_branches)

## Define Loss
loss_fn = get_loss(args.num_branches, c_weights, device)
loss_stu = nn.CrossEntropyLoss().to(device)
loss_kl = KLLoss(temperature=args.T, device=device)

optimizer = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=args.wd,
    nesterov=True,
)
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=args.milestones,
    gamma=0.1,
)

## For model monitoring
if args.resume:
    if not os.path.isfile(args.resume):
        raise ValueError("=> no checkpoint found at '{}'".format(args.resume))

    print("=> loading checkpoint '{}'".format(args.resume))
    # map_location lets a checkpoint written on one device be resumed on
    # another (e.g. saved on cuda:0, resumed on the CPU or a different GPU).
    # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
    checkpoint = torch.load(args.resume, map_location=device)
    args.start_epoch = checkpoint['epoch']
    best_prec1 = checkpoint['best_prec1']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # NOTE(review): the scheduler state is not part of the checkpoint, so it
    # restarts from its initial state after a resume.
    print("=> loaded checkpoint '{}'\n=> (epoch {}) Prec1: {:f}"
          .format(args.resume, checkpoint['epoch'], best_prec1))
## Storing model checkpoint
def save_checkpoint(state, is_best, filepath):
    """Write ``state`` to ``<filepath>/model_best.pth.tar`` when ``is_best`` is truthy.

    Nothing is written for a state that is not the best so far.
    """
    if not is_best:
        return
    best_path = os.path.join(filepath, 'model_best.pth.tar')
    torch.save(state, best_path)

# Log of the per-branch class weights, one column per branch, giving a
# (n_classes, num_branches) tensor.  Classes whose weight is 0 (the default
# --margin) become -inf.
# .to() is used instead of .cuda() so the CPU path (--no-cuda) works too.
comp = torch.stack(
    [torch.log(torch.tensor(c_weights[i], dtype=torch.float, requires_grad=False))
     for i in range(args.num_branches)],
    dim=-1,
).to(device)

# Keep the best accuracy restored from the checkpoint when resuming; it used
# to be reset to zero here, so the first resumed epoch always overwrote
# model_best.pth.tar.
if not args.resume:
    best_prec1 = 0.

## Just for testing
if args.evaluate:
    test_metrics = learning.OkdTest(
        test_loader,
        model,
        loss_fn,
        loss_stu,
        0,
        comp,
        device,
        args.num_branches,
        args.margin,
    )
    # sys.exit() instead of the interactive-only exit() helper from site.
    sys.exit()

# NOTE(review): `scheduler` is never stepped in this file and is not passed to
# learning.OkdTrain, so --milestones has no visible effect here -- confirm
# whether OkdTrain adjusts the learning rate itself or a scheduler.step()
# call is missing from this loop.
for epoch in range(args.start_epoch, args.epochs):
    print('Current epoch: {}, Learning rate: {}'.format(epoch, optimizer.param_groups[0]['lr']))
    train_metrics = learning.OkdTrain(
        train_loader,
        model,
        loss_fn,
        loss_stu,
        loss_kl,
        optimizer,
        epoch,
        comp,
        device,
        args.num_branches,
        args.consistency_rampup,
        args.margin,
    )
    test_metrics = learning.OkdTest(
        test_loader,
        model,
        loss_fn,
        loss_stu,
        epoch,
        comp,
        device,
        args.num_branches,
        args.margin,
    )

    # Track the best 'Top1_stu' test metric; only a new best is written to disk.
    is_best = test_metrics['Top1_stu'] > best_prec1
    best_prec1 = max(test_metrics['Top1_stu'], best_prec1)
    if args.save:
        save_checkpoint({
                'epoch': epoch + 1,  # epoch to resume from
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
        }, is_best, filepath=args.save)
    print("Best accuracy: "+str(best_prec1))
    print("Exposure: {}".format(args.margin))
print("Finished saving training history")
