import sys
import torchvision
import argparse
import os
import shutil
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
from losses import *
import resnet
import vgg



# Command-line interface. Help texts render their defaults through argparse's
# ``%(default)s`` placeholder so the documented value cannot drift from the
# actual one.
parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR10 in pytorch')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
# NOTE(review): main() builds its DataLoaders with a hard-coded batch size of
# 128, so this flag currently has no effect -- confirm which value is intended.
parser.add_argument('-b', '--batch-size', default=32, type=int,
                    metavar='N', help='mini-batch size (default: %(default)s)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weightdecay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: %(default)s)')
parser.add_argument('--print-freq', '-p', default=50, type=int,
                    metavar='N', help='print frequency (default: %(default)s)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--save-dir', dest='save_dir',
                    help='The directory used to save the trained models',
                    default='save_temp', type=str)
parser.add_argument('--save-every', dest='save_every',
                    help='Saves checkpoints at every specified number of epochs',
                    type=int, default=10)
parser.add_argument('--gpu', default='1,2', type=str)
parser.add_argument('--coeff', default=-1, type=float, help='Coefficient to KL term in BM loss. Set -1 to use CrossEntropy Loss')
parser.add_argument('--rate', default=0.0, type=float, help='Dropout rate')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--exp_num', default=1, type=int)
parser.add_argument('--num-eval', dest='num_eval', default=100, type=int, help='Evaluation count for MC dropout')

# Optional output regularizer: at most one of the --use_* switches is honoured
# (main() checks them in the order per, wass, l1, l2) and its value is scaled
# by --reg_coeff before being added to the training loss.
parser.add_argument('--reg_coeff', default=0., type=float,
                    help='coefficient of the selected regularization term')
parser.add_argument('--use_l1', action='store_true')
parser.add_argument('--use_l2', action='store_true')
parser.add_argument('--use_wass', action='store_true')
parser.add_argument('--use_per', action='store_true')

# Module-level training state shared with main().
best_prec1 = 0
test_error_best = -1





import math
def main():
    """
    Parse the command line, build data/model/optimizer, then train or evaluate.

    The CIFAR training set is split 80/20 into a training and a validation
    subset. After every epoch the model is scored on the validation subset;
    whenever that score improves, the model is also scored on the test set and
    both the test accuracy and the weights are written to ``args.save_dir``.
    A resumable checkpoint is additionally written every ``args.save_every``
    epochs. With ``--evaluate`` only a single validation pass is run.

    Raises:
        ValueError: if ``--dataset`` or the ``--arch``/``--rate`` combination
            is not one this script knows how to build.
    """
    # test_error_best must be declared global too: it is assigned below, and
    # would otherwise be an unbound local when no epoch ever improves.
    global args, best_prec1, test_error_best
    args = parser.parse_args()

    # Seed every RNG in use with the experiment number.
    torch.manual_seed(args.exp_num)
    torch.cuda.manual_seed_all(args.exp_num)
    torch.cuda.manual_seed(args.exp_num)
    np.random.seed(args.exp_num)

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Dataset class, class count and per-channel normalisation statistics.
    if args.dataset == 'cifar10':
        dataset_cls, banner, num_classes = torchvision.datasets.CIFAR10, 'CIFAR-10', 10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif args.dataset == 'cifar100':
        dataset_cls, banner, num_classes = torchvision.datasets.CIFAR100, 'CIFAR-100', 100
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    else:
        raise ValueError(
            "unsupported --dataset '{}' (expected 'cifar10' or 'cifar100')".format(args.dataset))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])  # augmentation + mean/std normalisation

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    print("| Preparing {} dataset...".format(banner))
    sys.stdout.write("| ")
    trainset = dataset_cls(root='./data', train=True, download=True, transform=transform_train)
    testset = dataset_cls(root='./data', train=False, download=False, transform=transform_test)

    # Hold out 20% of the training set for validation.
    validation_split = 0.2
    dataset_size = len(trainset)
    num_val = int(np.floor(validation_split * dataset_size))

    # NOTE(review): both subsets wrap the dataset built with transform_train,
    # so validation samples presumably go through the random crop/flip as
    # well -- confirm whether that is intended.
    trainset, valset = torch.utils.data.random_split(trainset, [dataset_size - num_val, num_val])

    # NOTE(review): the batch size is fixed at 128 here; --batch-size is not
    # consulted -- confirm which value is intended before wiring it in.
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=True, num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=args.workers, pin_memory=True)

    # Make sure the output directory exists (no check-then-create race).
    os.makedirs(args.save_dir, exist_ok=True)

    # Architectures are split by whether dropout is requested; the constructor
    # is only invoked for the one that was selected.
    if args.rate > 0:
        builders = {
            'DropoutResNet50': lambda: resnet.DropoutResNet50(num_classes, args.rate),
            'DropoutVGG': lambda: vgg.DropoutVGG('VGG16', args.rate, num_classes),
        }
    else:
        builders = {
            'resnet50': lambda: resnet.ResNet50(num_classes),
            'resnet18': lambda: resnet.ResNet18(num_classes),
            'resnet34': lambda: resnet.ResNet34(num_classes),
            'resnet101': lambda: resnet.ResNet101(num_classes),
            'vgg16': lambda: vgg.VGG('VGG16', num_classes),
        }
    if args.arch not in builders:
        raise ValueError(
            "unsupported --arch '{}' for --rate {} (choices: {})".format(
                args.arch, args.rate, ', '.join(sorted(builders))))
    model = builders[args.arch]()

    model = torch.nn.DataParallel(model)
    model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # Best-model snapshots written below carry no 'epoch' entry; keep
            # the command-line start epoch for those.
            args.start_epoch = checkpoint.get('epoch', args.start_epoch)
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, args.start_epoch))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Run name used for every file written to save_dir.
    name = '{}-dropout{}-exp_num{}-wd{}'.format(args.arch, args.rate, args.exp_num, args.weightdecay)
    name += '-{}'.format(args.dataset)

    # Optional regularizer on the network outputs (first matching flag wins).
    if args.use_per:
        reg_loss = ProjectedErrorFunction()
        name += '-per_{}'.format(args.reg_coeff)
    elif args.use_wass:
        reg_loss = WassDist()
        name += '-wass_{}'.format(args.reg_coeff)
    elif args.use_l1:
        reg_loss = LpNorm(1)
        name += '-l1_{}'.format(args.reg_coeff)
    elif args.use_l2:
        reg_loss = LpNorm(2)
        name += '-l2_{}'.format(args.reg_coeff)
    else:
        reg_loss = None

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weightdecay)

    def set_lr(value):
        """Assign ``value`` as the learning rate of every parameter group."""
        for group in optimizer.param_groups:
            group['lr'] = value

    milestones = [100, 150]
    gamma = 0.1

    # The scheduler refuses last_epoch != -1 (i.e. a resumed run) unless each
    # parameter group already carries 'initial_lr'.
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', args.lr)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=milestones, gamma=gamma,
                                                        last_epoch=args.start_epoch - 1)

    if args.start_epoch == 0:
        # Warmup: the first epoch runs at a tenth of the base rate.
        set_lr(args.lr * 0.1)
    elif args.start_epoch >= 6:
        # Resuming after warmup: restore the decayed rate for this epoch
        # instead of leaving it at the warmup value.
        passed = sum(1 for m in milestones if m <= args.start_epoch)
        set_lr(args.lr * gamma ** passed)
    # For 0 < start_epoch < 6 the loop below sets the warmup rate itself.

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if 0 < epoch < 6:
            # Linear warmup: 0.2x, 0.4x, ... 1.0x of the base rate.
            set_lr(args.lr*0.1*(epoch*2))
        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))

        train(train_loader, model, criterion, optimizer, epoch, reg_loss)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, filename=os.path.join(args.save_dir, '{}_checkpoint.th'.format(name)))

        if is_best:
            # Test accuracy is recorded only at the best validation epoch.
            test_prec1 = validate(test_loader, model, criterion)
            test_error_best = test_prec1
            np.save(os.path.join(args.save_dir, '{}_testacc'.format(name)), test_prec1)
            save_checkpoint({
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, filename=os.path.join(args.save_dir, '{}_model.th'.format(name)))

    print('='*100)
    print('='*100)
    print('\t\t', name)
    if test_error_best >= 0:
        print('\t\tBEST ERROR:', 100 - test_error_best)
    else:
        # No epoch improved on the starting best_prec1, so the test set was
        # never scored in this run.
        print('\t\tBEST ERROR: n/a')
    print('='*100)
    print('='*100)


def train(train_loader, model, criterion, optimizer, epoch, reg_loss):
    """
    Run one train epoch

    Args:
        train_loader: iterable yielding ``(input, target)`` batches.
        model: network to optimise; it is switched to train mode.
        criterion: callable ``criterion(output, target)`` giving the base loss.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the progress print-out.
        reg_loss: optional callable applied to the model output; when not
            ``None`` its value times ``args.reg_coeff`` is added to the loss.

    Reads the module-level ``args`` for ``reg_coeff`` and ``print_freq``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda()
        images = images.cuda()

        # compute output
        output = model(images)

        loss = criterion(output, target)

        if reg_loss is not None:
            loss = loss + args.reg_coeff * reg_loss(output)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        # NOTE(review): the norm of each parameter tensor's gradient is clipped
        # to 1 separately, not the global norm across all parameters --
        # confirm this is intended.
        for p in model.parameters():
            nn.utils.clip_grad_norm_(p, 1.)
        optimizer.step()

        output = output.float()
        loss = loss.float()
        # measure accuracy and record loss
        prec1 = accuracy(output.data, target)[0]
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))


def validate(val_loader, model, criterion):
    """
    Run evaluation

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        model: network to score; it is switched to eval mode.
        criterion: callable ``criterion(output, target)`` giving the loss.

    Returns:
        The top-1 precision in percent, averaged over all samples seen.

    Reads the module-level ``args`` for ``print_freq``.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            # One host-to-device copy per tensor is enough.
            target = target.cuda()
            images = images.cuda()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            output = output.float()
            loss = loss.float()

            # measure accuracy and record loss
            prec1 = accuracy(output.data, target)[0]
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'
          .format(top1=top1))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    Serialize ``state`` to ``filename`` with ``torch.save``.

    ``is_best`` is accepted for call-site compatibility but is not consulted.
    """
    torch.save(obj=state, f=filename)


class AverageMeter:
    """Track the most recent value of a metric together with its running mean.

    Attributes are ``val`` (last value recorded), ``sum`` (weighted total),
    ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest reading, counted as ``n`` observations."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: 1-D tensor holding the true class index of each sample.
        topk: values of k to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``: the percentage of
        samples whose true class is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, transposed to (maxk, batch)
    # so that row r holds every sample's rank-r prediction.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape rather than view: the transposed layout is not contiguous,
        # and view(-1) raises on a multi-row slice of it (any k > 1).
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res



# Run the entry point only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()