from __future__ import division
# Code is borrowed from https://github.com/vikasverma1077/manifold_mixup/tree/master/supervised

import os, sys, shutil, time, random
sys.path.append('..')
if sys.version_info[0] < 3:
    import cPickle as pickle
else:
    import _pickle as pickle
from glob import glob
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

import models
from mixup import distance
from collections import OrderedDict, Counter
from load_data import load_data_subset
from logger import plotting, copy_script_to_folder, AverageMeter, RecorderMeter, time_string, convert_secs2time


# Sorted names of every lower-case, non-dunder callable exported by `models`;
# used as the allowed values of the --arch option.
model_names = sorted(
    candidate
    for candidate, obj in models.__dict__.items()
    if candidate.islower() and not candidate.startswith("__") and callable(obj)
)

def str2bool(v):
    """Convert a command-line token to a bool (argparse `type=` helper).

    Booleans are returned unchanged. Strings are compared case-insensitively
    against a fixed set of true/false spellings.

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
        

# Command-line interface. ArgumentDefaultsHelpFormatter appends "(default: ...)"
# to every help string, so defaults are not repeated by hand where they could go stale.
parser = argparse.ArgumentParser(description='Train Classifier with mixup', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Data
parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar10', 'cifar100', 'tiny-imagenet-200'], help='Choose between Cifar10/100 and Tiny-ImageNet.')
parser.add_argument('--data_dir', type = str, default = '~/Datasets/cifar100', help='directory where the dataset is stored')
parser.add_argument('--root_dir', type = str, default = 'experiments', help='folder where results are to be stored')
parser.add_argument('--labels_per_class', type=int, default=500, metavar='NL', help='labels_per_class')
parser.add_argument('--valid_labels_per_class', type=int, default=0, metavar='NL', help='validation labels_per_class')

# Model
parser.add_argument('--arch', metavar='ARCH', default='preactresnet18', choices=model_names, help='model architecture: ' + ' | '.join(model_names))
parser.add_argument('--initial_channels', type=int, default=64, choices=(16,64))

# Optimization options
parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train')
parser.add_argument('--train', type=str, default='mixup', choices=['vanilla', 'mixup', 'mixup_hidden'], help='mixup layer')
parser.add_argument('--mixup_alpha', type=float, default=1.0, help='alpha parameter for mixup')
parser.add_argument('--dropout', type=str2bool, default=False, help='whether to use dropout or not in final layer')

# Co-Mixup
parser.add_argument('--match_mix', type=str2bool, default=False, help='true for Co-Mixup')
parser.add_argument('--m_block_num', type=int, default=4, help='resolution of labeling, -1 for random')

parser.add_argument('--m_beta', type=float, default=0.32, help='label smoothness coef, 0.5-2.0')
parser.add_argument('--m_gamma', type=float, default=1.0, help='supermodular diversity coef')
parser.add_argument('--m_thres', type=float, default=0.83, help='threshold for over-penalization, [0,1]')
parser.add_argument('--m_thres_type', type=str, default='hard', choices=['soft', 'hard'], help='thresholding type')
parser.add_argument('--m_eta', type=float, default=0.05, help='prior coef')

# The backslash is escaped explicitly: '\o' is not a valid escape sequence.
parser.add_argument('--lam_dist', type=float, default=0., help='input compatibility coef \\omega')
parser.add_argument('--m_dist_type', type=str, default='sc', choices=['l2', 'l22', 'l1', 'linf', 'sc'], help='compatibility metric')
parser.add_argument('--set_resolve', type=str2bool, default=True, help='post-processing for resolving the same outputs')
parser.add_argument('--m_niter', type=int, default=4, help='number of outer iteration')
parser.add_argument('--m_part', type=int, default=20, help='partition size')

parser.add_argument('--clean_lam', type=float, default=1.0, help='clean input regularization')
parser.add_argument('--clean_eph', type=int, default=0)
parser.add_argument('--sc_loss', type=str, default='ce', choices=['ce', 'bce'], help='')


# Baselines
parser.add_argument('--box', type=str2bool, default=False, help='true for CutMix')
parser.add_argument('--graph', type=str2bool, default=False, help='true for PuzzleMix')
parser.add_argument('--neigh_size', type=int, default=4, help='neighbor size for computing distance between image regions')
parser.add_argument('--n_labels', type=int, default=3, help='label space size')

parser.add_argument('--beta', type=float, default=1.2, help='label smoothness')
parser.add_argument('--gamma', type=float, default=0.5, help='data local smoothness')
parser.add_argument('--eta', type=float, default=0.2, help='prior term')
parser.add_argument('--transport', type=str2bool, default=False, help='whether to use transport')
parser.add_argument('--t_eps', type=float, default=0.8, help='transport cost coefficient')
parser.add_argument('--t_size', type=int, default=-1, help='transport resolution. -1 for using the same resolution with graphcut')

# training
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--learning_rate', type=float, default=0.2)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--decay', type=float, default=0.0001, help='weight decay (L2 penalty)')
parser.add_argument('--schedule', type=int, nargs='+', default=[100, 200], help='decrease learning rate at these epochs')
parser.add_argument('--gammas', type=float, nargs='+', default=[0.1, 0.1], help='LR is multiplied by gamma on schedule, number of gammas should be equal to schedule')

# Checkpoints
parser.add_argument('--print_freq', default=100, type=int, metavar='N', help='print frequency')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')

# Acceleration
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU')
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')

# random seed
parser.add_argument('--seed', default=0, type=int, help='manual seed')
parser.add_argument('--add_name', type=str, default='')
parser.add_argument('--log_off', type=str2bool, default=False)
parser.add_argument('--job_id', type=str, default='')

# Parsed once at import time; `args` is read as a module-level global below.
args = parser.parse_args()
args.use_cuda = args.ngpu>0 and torch.cuda.is_available()

# random seed: seed every RNG the script touches with the same value
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# NOTE(review): cudnn autotuning may pick non-deterministic kernels, so runs
# are presumably not bit-reproducible despite the seeding above -- confirm if that matters.
cudnn.benchmark = True


def experiment_name_non_mnist(args=args):
    '''
    Build the experiment result folder name from the run configuration.

    The name is the concatenation of tagged option values in a fixed order;
    optional groups (mixup alpha, CutMix, PuzzleMix, transport, Co-Mixup,
    clean-epoch/clean-loss settings, extra name) are appended only when the
    corresponding option is active. The name is also printed to stdout.

    Args:
        args: parsed argument namespace; defaults to the module-level `args`.

    Returns:
        The experiment name as a string.
    '''
    parts = [
        args.dataset,
        '_per_' + str(args.labels_per_class),
        '_arch_' + str(args.arch),
        '_train_' + str(args.train),
        '_eph_' + str(args.epochs),
        '_lr_' + str(args.learning_rate),
    ]
    if args.mixup_alpha:
        parts.append('_m_alpha_' + str(args.mixup_alpha))
    if args.box:
        parts.append('_box')
    if args.graph:
        parts.append('_graph' + '_n_labels_' + str(args.n_labels) + '_beta_' + str(args.beta)
                     + '_gamma_' + str(args.gamma) + '_neigh_' + str(args.neigh_size) + '_eta_' + str(args.eta))
    if args.transport:
        parts.append('_transport' + '_eps_' + str(args.t_eps) + '_size_' + str(args.t_size))
    if args.match_mix:
        # '_mixmatch' is kept as the tag so existing result folders still match.
        parts.append('_mixmatch')
        parts.append('_mblock_' + str(args.m_block_num) + '_mbeta_' + str(args.m_beta) + '_mgamma_' + str(args.m_gamma)
                     + '_mthres_' + str(args.m_thres_type) + str(args.m_thres) + '_meta_' + str(args.m_eta))
        parts.append('_mpart_' + str(args.m_part) + '_niter_' + str(args.m_niter)
                     + '_lamdist_' + str(args.lam_dist) + str(args.m_dist_type))
        if args.set_resolve:
            parts.append('_set')
    if args.clean_eph > 0:
        parts.append('_clean_eph_' + str(args.clean_eph))
    # NOTE(review): --job_id defaults to '', so this tag is always appended;
    # left as is to keep folder names stable -- confirm whether `!= ''` was meant.
    if args.job_id is not None:
        parts.append('_job_id_' + str(args.job_id))
    if args.clean_lam > 0:
        parts.append('_clean_' + str(args.clean_lam))
    if args.sc_loss == 'bce':
        parts.append('bce')
    parts.append('_seed_' + str(args.seed))
    if args.add_name != '':
        parts.append('_add_name_' + str(args.add_name))

    exp_name = ''.join(parts)
    print('\nexperiment name: ' + exp_name)
    return exp_name


def print_log(print_string, log, end='\n'):
    '''Print a message to stdout and mirror it to `log` when one is given.

    On stdout the message is terminated with `end`. In the log it is
    terminated with a newline when `end` is a newline and with a single space
    otherwise; the log is flushed after every write. `log=None` disables the
    file output.
    '''
    message = "{}".format(print_string)
    print(message, end=end)
    if log is None:
        return

    terminator = '\n' if end == '\n' else ' '
    log.write(message + terminator)
    log.flush()

def save_checkpoint(state, is_best, save_path, filename):
    '''Serialize `state` to `save_path/filename` with `torch.save`.

    When `is_best` is true the written file is additionally copied to
    `save_path/model_best.pth.tar`.
    '''
    checkpoint_path = os.path.join(save_path, filename)
    torch.save(state, checkpoint_path)
    if not is_best:
        return

    best_path = os.path.join(save_path, 'model_best.pth.tar')
    shutil.copyfile(checkpoint_path, best_path)

def adjust_learning_rate(optimizer, epoch, gammas, schedule):
    """Set the step-decayed learning rate for `epoch` on every param group.

    Starting from the module-level `args.learning_rate`, the rate is
    multiplied by `gammas[i]` for each leading milestone `schedule[i]` that
    `epoch` has reached; scanning stops at the first milestone not yet
    reached, so `schedule` is expected to be in increasing order.

    Args:
        optimizer: object exposing `param_groups`, a list of dicts whose
            'lr' entry is overwritten.
        epoch: current epoch index.
        gammas: multiplicative decay factors, one per milestone.
        schedule: epoch milestones, same length as `gammas`.

    Returns:
        The learning rate that was written to the optimizer.

    Raises:
        AssertionError: if `gammas` and `schedule` differ in length.
    """
    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    if len(gammas) != len(schedule):
        raise AssertionError("length of gammas and schedule should be equal")

    lr = args.learning_rate
    for gamma, step in zip(gammas, schedule):
        if epoch < step:
            break
        lr = lr * gamma

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def accuracy(output, target, topk=(1,)):
    """Compute the precision@k (in percent) for the specified values of k.

    Args:
        output: 2-D score tensor; the top-k is taken along dim 1.
        target: 1-D tensor of class indices, one per row of `output`.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim float tensor per entry of `topk`.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` derives from a transposed tensor and may be non-contiguous;
        # `view(-1)` raises on such a slice for k > 1, `reshape(-1)` copies if needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


# Module-level loss / activation objects shared by train() and validate().
# Cross-entropy on raw scores: batch-averaged, and per-sample (no reduction).
criterion = nn.CrossEntropyLoss(reduction='mean').cuda()
criterion_batch = nn.CrossEntropyLoss(reduction='none').cuda()
# Softmax over dim 1, applied to the scores before the binary cross-entropy losses.
softmax = nn.Softmax(dim=1).cuda()
# Binary cross-entropy: averaged and summed variants.
bce_loss = nn.BCELoss(reduction='mean').cuda()
bce_loss_sum = nn.BCELoss(reduction='sum').cuda()


def train(train_loader, model, dense, optimizer, epoch, args, log):
    '''Train `model` and `dense` for one epoch.

    Depending on `args.train` each batch is trained on clean inputs
    ('vanilla', or any mode while `epoch < args.clean_eph`), with input mixup
    ('mixup', optionally guided by input-gradient saliency when `args.graph`
    or `args.match_mix` is set) or with manifold mixup ('mixup_hidden').
    Accuracy is always measured against the original, unmixed targets.

    Args:
        train_loader: iterable yielding (input, target) batches.
        model: feature extractor, called as `model(x)` or `model(x, y, ...)`.
        dense: classification head applied to the features from `model`.
        optimizer: optimizer; `zero_grad()` and `step()` are called per batch.
        epoch: current epoch index, compared against `args.clean_eph`.
        args: parsed command-line namespace (also forwarded to `model`).
        log: open log file or None, forwarded to `print_log`.

    Returns:
        Tuple (top-1 accuracy average, top-5 accuracy average, loss average).

    Raises:
        AssertionError: if `args.train` is not a known training mode.
    '''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    dense.train()

    end = time.time()
    for inputs, target in train_loader:
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        inputs = inputs.cuda()
        target = target.long().cuda()

        # train with clean images
        if args.train == 'vanilla' or epoch < args.clean_eph:
            input_var, target_var = Variable(inputs), Variable(target)
            z, reweighted_target = model(input_var, target_var)
            output = dense(z)
            loss = bce_loss(softmax(output), reweighted_target)

        # train with mixup images
        elif args.train == 'mixup':
            sc = None
            A_dist = None
            # saliency-guided mixing (PuzzleMix / Co-Mixup)
            if args.graph or args.match_mix:
                input_var = Variable(inputs, requires_grad=True)
                target_var = Variable(target)

                # With clean_lam == 0 the clean pass only provides saliency, so it
                # runs in eval mode and its parameter gradients are discarded below.
                # Otherwise it runs in train mode and its (scaled) loss gradients
                # stay accumulated for the optimizer step.
                saliency_only = args.clean_lam == 0
                if saliency_only:
                    model.eval()
                    dense.eval()
                else:
                    model.train()
                    dense.train()

                # calculate saliency (unary)
                if args.sc_loss == 'ce':
                    z = model(input_var)
                    output = dense(z)
                    loss_batch = criterion_batch(output, target_var)
                    if not saliency_only:
                        loss_batch = 2 * args.clean_lam * loss_batch / args.num_classes
                    loss_batch_mean = torch.mean(loss_batch, dim=0)
                    loss_batch_mean.backward(retain_graph=True)
                else:
                    z, reweighted_target = model(input_var, target_var)
                    output = dense(z)
                    loss_sc = bce_loss(softmax(output), reweighted_target)
                    if not saliency_only:
                        loss_sc = args.clean_lam * loss_sc
                    loss_sc.backward(retain_graph=True)

                # root-mean-square of the input gradient over dim 1
                sc = torch.sqrt(torch.mean(input_var.grad ** 2, dim=1))

                with torch.no_grad():
                    if args.m_dist_type == 'sc':
                        # Use the real batch size: the last batch may be smaller
                        # than args.batch_size.
                        n_samples = sc.shape[0]
                        pooled = F.avg_pool2d(sc, kernel_size=8, stride=1)
                        flat_idx = torch.argmax(pooled.reshape(n_samples, -1), dim=1)
                        # (row, col) of the most salient pooled location per sample
                        z_idx_2d = torch.zeros(n_samples, 2)
                        z_idx_2d[:, 0] = flat_idx // pooled.shape[-1]
                        z_idx_2d[:, 1] = flat_idx % pooled.shape[-1]
                        A_dist = distance(z_idx_2d, dist_type='l1')
                    else:
                        A_dist = distance(z, dist_type=args.m_dist_type)

                if saliency_only:
                    # restore train mode on BOTH modules and drop saliency gradients
                    model.train()
                    dense.train()
                    optimizer.zero_grad()

            input_var, target_var = Variable(inputs), Variable(target)
            # perform mixup and calculate loss
            z, reweighted_target = model(input_var, target_var, mixup=True, args=args, sc=sc, A_dist=A_dist)
            output = dense(z)
            loss = bce_loss(softmax(output), reweighted_target)

        # for manifold mixup
        elif args.train == 'mixup_hidden':
            input_var, target_var = Variable(inputs), Variable(target)
            z, reweighted_target = model(input_var, target_var, mixup_hidden=True, args=args)
            output = dense(z)
            loss = bce_loss(softmax(output), reweighted_target)
        else:
            raise AssertionError('wrong train type!!')

        # measure accuracy and record loss
        n_batch = inputs.size(0)
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), n_batch)
        top1.update(prec1.item(), n_batch)
        top5.update(prec5.item(), n_batch)

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print_log('  **Train** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, top5=top5, error1=100-top1.avg), log)
    return top1.avg, top5.avg, losses.avg


def validate(val_loader, model, dense, log):
    '''Evaluate `dense(model(x))` on every batch of `val_loader`.

    Both modules are switched to eval mode. Cross-entropy loss and top-1 /
    top-5 precision are averaged over the loader, weighted by batch size, and
    a summary line is written through `print_log`.

    Returns:
        Tuple (top-1 accuracy average, loss average).
    '''
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    # switch to evaluate mode
    model.eval()
    dense.eval()

    for images, labels in val_loader:
        if args.use_cuda:
            images = images.cuda()
            labels = labels.cuda()

        # forward pass only; no graph is needed for evaluation
        with torch.no_grad():
            images_var = Variable(images)
            labels_var = Variable(labels)
            logits = dense(model(images_var))
            batch_loss = criterion(logits, labels_var)

        # measure accuracy and record loss
        n_batch = images.size(0)
        prec1, prec5 = accuracy(logits.data, labels, topk=(1, 5))
        loss_meter.update(batch_loss.item(), n_batch)
        acc1_meter.update(prec1.item(), n_batch)
        acc5_meter.update(prec5.item(), n_batch)

    summary = '**Test** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f} Loss: {losses.avg:.3f} '.format(
        top1=acc1_meter, top5=acc5_meter, error1=100-acc1_meter.avg, losses=loss_meter)
    print_log(summary, log)
    return acc1_meter.avg, loss_meter.avg


# Best test accuracy seen so far; updated by the training loop (and on resume).
best_acc = 0


def _run_experiment(log, exp_dir, result_png_path):
    '''Build data/model/optimizer, optionally resume, then evaluate or train.

    Args:
        log: open log file, or None when logging is off.
        exp_dir: experiment folder for logs/checkpoints; only used when
            `args.log_off` is false.
        result_png_path: path for the recorder's curve plot; only used when
            `args.log_off` is false.
    '''
    global best_acc

    state = dict(args._get_kwargs())
    print("")
    print_log(state, log)
    print("")
    print_log("Random Seed: {}".format(args.seed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

    # dataloader
    # NOTE(review): the literal 2 is presumably the worker count (cf. --workers,
    # which is otherwise unused here) -- confirm against load_data_subset.
    train_loader, valid_loader, _ , test_loader, num_classes = load_data_subset(args.batch_size, 2 ,args.dataset, args.data_dir,
                labels_per_class=args.labels_per_class, valid_labels_per_class=args.valid_labels_per_class)

    # per-dataset stride and normalization statistics stored on args
    if args.dataset == 'tiny-imagenet-200':
        stride = 2
        args.mean = torch.tensor([0.5] * 3, dtype=torch.float32).view(1,3,1,1).cuda()
        args.std = torch.tensor([0.5] * 3, dtype=torch.float32).view(1,3,1,1).cuda()
        args.labels_per_class = 500
    elif args.dataset == 'cifar10':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [125.3, 123.0, 113.9]], dtype=torch.float32).view(1,3,1,1).cuda()
        args.std = torch.tensor([x / 255 for x in [63.0, 62.1, 66.7]], dtype=torch.float32).view(1,3,1,1).cuda()
        args.labels_per_class = 5000
    elif args.dataset == 'cifar100':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [129.3, 124.1, 112.4]], dtype=torch.float32).view(1,3,1,1).cuda()
        args.std = torch.tensor([x / 255 for x in [68.2, 65.4, 70.4]], dtype=torch.float32).view(1,3,1,1).cuda()
        args.labels_per_class = 500
    else:
        raise AssertionError('Given Dataset is not supported!')

    # create model
    print_log("=> creating model '{}'".format(args.arch), log)
    net = models.__dict__[args.arch](num_classes, args.dropout, stride).cuda()
    args.num_classes = num_classes
    dense = models.__dict__['dense'](net.n_feature, args.num_classes).cuda()

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    dense = torch.nn.DataParallel(dense, device_ids=list(range(args.ngpu)))
    optimizer = torch.optim.SGD(list(net.parameters()) + list(dense.parameters()), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=True)

    recorder = RecorderMeter(args.epochs)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("\n=> loading checkpoint '{}'".format(args.resume), log)
            # SECURITY: torch.load unpickles arbitrary objects (the checkpoint
            # stores a RecorderMeter instance); only resume from trusted files.
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            dense.load_state_dict(checkpoint['state_dict_dense'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        validate(test_loader, net, dense, log)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)
        # Disable the clean-input regularizer from the first LR milestone on.
        # (The original `args.clean_lam == 0` was a comparison with no effect;
        # `>=` keeps a run resumed past the milestone consistent.)
        if epoch >= args.schedule[0]:
            args.clean_lam = 0

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        tr_acc, _, tr_los = train(train_loader, net, dense, optimizer, epoch, args, log)

        # evaluate on the test split
        val_acc, val_los = validate(test_loader, net, dense, log)

        train_loss.append(tr_los)
        train_acc.append(tr_acc)
        test_loss.append(val_los)
        test_acc.append(val_acc)

        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

        if args.log_off:
            continue

        # save log
        recorder.update(epoch, tr_los, tr_acc, val_los, val_acc)
        if (epoch+1) % 100 == 0:
            recorder.plot_curve(result_png_path)

        train_log = OrderedDict()
        train_log['train_loss'] = train_loss
        train_log['train_acc'] = train_acc
        train_log['test_loss'] = test_loss
        train_log['test_acc'] = test_acc

        # context manager so the pickle file is flushed and closed every epoch
        with open(os.path.join(exp_dir, 'log.pkl'), 'wb') as log_pkl:
            pickle.dump(train_log, log_pkl)
        plotting(exp_dir)

        save_checkpoint({
                  'epoch': epoch + 1,
                  'arch': args.arch,
                  'state_dict': net.state_dict(),
                  'state_dict_dense': dense.state_dict(),
                  'recorder': recorder,
                  'optimizer' : optimizer.state_dict(),
                }, is_best, exp_dir, 'checkpoint.pth.tar')

    # Summary over the last (up to) 10 epochs; skipped when no epoch ran,
    # e.g. when resuming a run that had already finished.
    if test_acc:
        last_accs = test_acc[-10:]
        median_acc = np.median(last_accs)
        acc_var = np.maximum(np.max(last_accs) - median_acc, median_acc - np.min(last_accs))
        print_log("\nfinal 10 epoch acc (median) : {:.2f} (+- {:.2f})".format(median_acc, acc_var), log)


def main():
    '''Script entry point.

    Unless `args.log_off` is set, creates the experiment folder, copies this
    script into it and opens `log.txt`; then runs the experiment. The log file
    is closed on every exit path, including `--evaluate` and exceptions.
    '''
    exp_dir = None
    result_png_path = None
    log = None

    # set up the experiment directories
    if not args.log_off:
        exp_name = experiment_name_non_mnist()
        exp_dir = os.path.join(args.root_dir, exp_name)

        if not os.path.exists(exp_dir):
            os.makedirs(exp_dir)

        copy_script_to_folder(os.path.abspath(__file__), exp_dir)

        result_png_path = os.path.join(exp_dir, 'results.png')
        log = open(os.path.join(exp_dir, 'log.txt'), 'w')

    try:
        if log is not None:
            print_log('save path : {}'.format(exp_dir), log)
        _run_experiment(log, exp_dir, result_png_path)
    finally:
        if log is not None:
            log.close()


if __name__ == '__main__':
    main()

