import argparse
import os
import time
import shutil
import math

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np


import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

#from models import *
import models
import paddle
import paddlehub as hub

from goodfellow_backprop import goodfellow_backprop

# Root directory handed to the torchvision dataset constructors.
DATADIR = '../data'
# When True, save_checkpoint also keeps a per-epoch copy for non-best epochs.
SAVE = False
# When True, main() applies the step schedule of adjust_learning_rate each epoch.
ADJUST = True


# Command-line interface. main() parses it into the module-level ``args``
# namespace that the training helpers read.
parser = argparse.ArgumentParser(description='PyTorch Cifar10 Training')
parser.add_argument('--epochs', default=160, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int, metavar='N', help='mini-batch size (default: 128),only used for train')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
# The help text used to advertise "default: 10" although the default is 100.
parser.add_argument('--print-freq', '-p', default=100, type=int, metavar='N', help='print frequency (default: 100)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
# Only these five dataset codes are handled by main(); anything else is
# rejected here instead of failing later with an unbound variable.
parser.add_argument('-ct', '--cifar-type', default=10, type=int, metavar='CT', choices=[0, 1, 5, 10, 100],
                    help='10 for cifar10, 100 for cifar100, 0 for FashionMNIST, 1 for MNIST, 5 for SVHN (default: 10)')
parser.add_argument('--arch', default='resnet20_cifar', type=str, help='architecture to use')
parser.add_argument('--reg_lambda', default=0.1, type=float, metavar='L', help='lambda for regularization term')
parser.add_argument('--reg', default='', type=str, metavar='REG', help="regularizer to add to the loss: 'reg' or 'hinge' (default: none)")
parser.add_argument('--T', default=1, type=float, metavar='T', help='temperature')
parser.add_argument('--gamma', default=1, type=float, metavar='GAMMA', help='gamma for hinge loss')
parser.add_argument('--remark', default='test', type=str, metavar='R', help='appending to the save dir')
parser.add_argument('--cyclic', default=1, type=int, metavar='C', help='how many cycles?')
parser.add_argument('--alpha', default=0.01, type=float, metavar='A', help='weight of agn')
#parser.add_argument('--gpu', default='0', type=str, metavar='GPU', help='which gpu?')

parser.add_argument("--saved_params_dir",   type=str,               default="",                        help="Directory for saving model")
parser.add_argument("--model_path",         type=str,               default="",                        help="load model path")



# Best precision seen so far; main() overwrites it when a checkpoint is loaded.
best_prec = 0


def _build_model():
    """Instantiate ``args.arch`` from the ``models`` package for the chosen dataset.

    Raises:
        ValueError: if ``args.cifar_type`` is not one of 0, 1, 5, 10, 100.
    """
    # model can be set to anyone defined in the models folder;
    # note the model should match the cifar type!
    if args.cifar_type in (10, 100):
        return models.__dict__[args.arch](num_classes=args.cifar_type)
    if args.cifar_type in (0, 1):  # 0 is fashionMNIST, 1 is MNIST
        return models.__dict__[args.arch]()
    if args.cifar_type == 5:  # SVHN
        return models.__dict__[args.arch](num_classes=10)
    raise ValueError('unsupported --cifar-type: {}'.format(args.cifar_type))


def _lr_schedule_type(model):
    """Return the schedule id used by adjust_learning_rate, or -1 if unrecognized."""
    if isinstance(model, (models.ResNet_Cifar, models.PreAct_ResNet_Cifar)):
        return 1
    if isinstance(model, models.Wide_ResNet_Cifar):
        return 2
    if isinstance(model, (models.ResNeXt_Cifar, models.DenseNet_Cifar)):
        return 3
    if args.arch == 'conv_mnist':
        return 4
    print('model type unrecognized...')
    return -1


def _cifar_loaders(dataset_cls, mean, std):
    """Build (trainloader, testloader) for a CIFAR-style torchvision dataset class."""
    normalize = transforms.Normalize(mean=mean, std=std)

    # Training images get random crop + horizontal flip augmentation.
    train_dataset = dataset_cls(
        root=DATADIR,
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    trainloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)

    test_dataset = dataset_cls(
        root=DATADIR,
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    # The test loader uses a fixed batch size of 100, independent of --batch-size.
    testloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=100, shuffle=False, num_workers=2)
    return trainloader, testloader


def _build_loaders():
    """Return (trainloader, testloader) for the dataset selected by ``args.cifar_type``.

    Raises:
        ValueError: if ``args.cifar_type`` is not one of 0, 1, 5, 10, 100.
    """
    if args.cifar_type == 10:
        print('=> loading cifar10 data...')
        return _cifar_loaders(torchvision.datasets.CIFAR10,
                              [0.491, 0.482, 0.447], [0.247, 0.243, 0.262])

    if args.cifar_type == 100:
        print('=> loading cifar100 data...')
        return _cifar_loaders(torchvision.datasets.CIFAR100,
                              [0.507, 0.487, 0.441], [0.267, 0.256, 0.276])

    if args.cifar_type == 0:
        print('=> loading fashion mnist...')
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        trainset = torchvision.datasets.FashionMNIST(DATADIR, download=True, train=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True)
        testset = torchvision.datasets.FashionMNIST(DATADIR, download=True, train=False, transform=transform)
        testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=True)
        return trainloader, testloader

    if args.cifar_type == 1:
        print('=> loading mnist...')
        kwargs = {'num_workers': 1, 'pin_memory': True}
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        trainloader = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST(DATADIR, train=True, download=True, transform=transform),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        # The test split is not downloaded explicitly; the train split request
        # above has already triggered the download.
        testloader = torch.utils.data.DataLoader(
            torchvision.datasets.MNIST(DATADIR, train=False, transform=transform),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        return trainloader, testloader

    if args.cifar_type == 5:
        print('=> loading SVHN...')
        kwargs = {'num_workers': 1, 'pin_memory': True}
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # download=False: the SVHN files must already be present under DATADIR.
        # No target_transform is applied to the labels.
        trainloader = torch.utils.data.DataLoader(
            torchvision.datasets.SVHN(
                root=DATADIR, split='train', download=False,
                transform=transform, target_transform=None),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        testloader = torch.utils.data.DataLoader(
            torchvision.datasets.SVHN(
                root=DATADIR, split='test', download=False,
                transform=transform, target_transform=None),
            batch_size=args.batch_size, shuffle=False, **kwargs)
        return trainloader, testloader

    raise ValueError('unsupported --cifar-type: {}'.format(args.cifar_type))


def main():
    """Parse the command line, build model and data loaders, then train or evaluate.

    Sets the module-level ``args`` (read by the other functions of this file)
    and ``best_prec`` (restored from a checkpoint when one is loaded). Returns
    early, after printing a message, when CUDA is not available.

    Raises:
        ValueError: if ``--cifar-type`` names an unsupported dataset.
    """
    global args, best_prec
    args = parser.parse_args()
    print(args)
    #os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu

    # Model building
    print('=> Building model...')
    if not torch.cuda.is_available():
        print('Cuda is not available!')
        return

    model = _build_model()
    # Counts every parameter; the requires_grad filter is intentionally left off.
    pytorch_total_params = sum(p.numel() for p in model.parameters()) # if p.requires_grad)
    print('total num of trainable params', pytorch_total_params)

    # folder to store the checkpoint and best model; exist_ok avoids a race
    # when several runs start at the same time
    os.makedirs('result', exist_ok=True)

    # the lr schedule depends on the model type
    model_type = _lr_schedule_type(model)

    #model = nn.DataParallel(model).cuda()
    model = model.cuda()
    if args.arch != 'linear':
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = nn.MSELoss().cuda()
    optimizer = optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    #scheduler = CosineAnnealingLR(optimizer, args.epochs/args.cyclic)
    cudnn.benchmark = True

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    if args.resume:
        # Full restart: restores epoch counter, best precision, weights and optimizer.
        if os.path.isfile(args.resume):
            print('=> loading checkpoint "{}"'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec = checkpoint['best_prec']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.model_path != '':
        # Warm start: restores weights and best precision only; the epoch
        # counter and the optimizer state are left untouched.
        ckpt = os.path.join(args.model_path, 'checkpoint.pth')
        if os.path.isfile(ckpt):
            print('=> loading checkpoint "{}"'.format(ckpt))
            checkpoint = torch.load(ckpt)
            #args.start_epoch = checkpoint['epoch']
            best_prec = checkpoint['best_prec']
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(ckpt, checkpoint['epoch']))
        else:
            # Report the file that was actually looked up, not just its directory.
            print("=> no checkpoint found at '{}'".format(ckpt))

    # Data loading and preprocessing
    trainloader, testloader = _build_loaders()

    if args.evaluate:
        validate(testloader, model, criterion)
        return

    #norm_sum = 0
    for epoch in range(args.start_epoch, args.epochs):
        if model_type != -1 and ADJUST:
            adjust_learning_rate(optimizer, epoch, model_type)
        print('>>>> current epoch is', epoch, ' lr is', get_lr(optimizer))

        # train for one epoch
        train(trainloader, model, criterion, optimizer, epoch)

        #norm, train_loss = get_grads_norm(trainloader, model, criterion)
        #norm_sum += norm
        #print('norm_sum is', norm_sum)

        # evaluate on test set (validate prints the precision itself)
        validate(testloader, model, criterion)

        # Disabled: cyclic cosine annealing.
        #if args.cyclic != 1 and (epoch+1) % (args.epochs / args.cyclic) == 0:
        #    scheduler = CosineAnnealingLR(optimizer, args.epochs/args.cyclic)
        #else:
        #    scheduler.step()

        # Disabled: remember best precision and save checkpoint.
        ##is_best = valid > best_prec
        #is_best = (epoch == args.epochs-1)
        ##best_prec = max(valid,best_prec)
        ##fdir = args.saved_params_dir
        #fdir = 'result/' + args.arch + args.remark + '/' + args.saved_params_dir
        #if not os.path.exists(fdir):
        #    os.makedirs(fdir)
        #save_checkpoint({
        #    'epoch': epoch + 1,
        #    'state_dict': model.state_dict(),
        #    'best_prec': best_prec,
        #    'optimizer': optimizer.state_dict(),
        #}, is_best, fdir, epoch+1)

    #score = -args.alpha*norm_sum - train_loss
    #score = -train_loss
    #print('score is', score)
    #hub.report_final_result(score)


def is_path_valid(path):
    """Report whether *path* is usable, creating its parent directory if needed.

    Args:
        path: file or directory path; may be relative.

    Returns:
        False when *path* is the empty string, True otherwise. As a side
        effect, the parent directory of the absolute form of *path* is created
        (including any missing intermediate directories) if it does not exist.
    """
    if path == "":
        return False
    dirname = os.path.dirname(os.path.abspath(path))
    if not os.path.exists(dirname):
        # makedirs instead of mkdir: handles several missing levels, and
        # exist_ok tolerates another process creating the directory first.
        os.makedirs(dirname, exist_ok=True)
    return True

class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes are plain public fields:
        val:   the most recent value passed to ``update``.
        sum:   the weighted sum of all values (``val * n`` accumulated).
        count: the total weight (sum of all ``n``).
        avg:   ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed with weight *n* (e.g. the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch recorded first) must not
        # raise ZeroDivisionError; the average simply stays at 0.
        self.avg = self.sum / self.count if self.count else 0


def _hinge_penalty(output, target, gamma):
    """Return the batch mean of the clipped margin loss ``clamp(1 - m, 0, 1)``.

    For each row, ``m`` is (score of the target class - highest score among
    the other classes) / gamma. Rows with m < 0 contribute 1, rows with m > 1
    contribute 0 and the rest contribute 1 - m.
    """
    index = target.view(-1, 1).long()
    correct = output.gather(1, index).squeeze(1)
    # Overwrite the target-class score with -inf on a copy so that the max
    # runs over the competing classes only.
    others = output.clone()
    others.scatter_(1, index, -float('inf'))
    margin = (correct - others.max(dim=1)[0]) / gamma
    return (1 - margin).clamp(0, 1).mean()


def train(trainloader, model, criterion, optimizer, epoch):
    """Run one training epoch and print running statistics.

    Args:
        trainloader: iterable yielding ``(input, target)`` batches.
        model: module whose forward pass returns a 3-tuple; only the first
            element (the class scores) is used here.
        criterion: loss applied to ``(output, target)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the progress line.

    Reads the module-level ``args`` for ``arch``, ``reg``, ``reg_lambda``,
    ``gamma`` and ``print_freq``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()

    end = time.time()
    for i, (images, target) in enumerate(trainloader):
        #if args.cifar_type in [0]:
        if args.arch in ['fashion', 'mnist']:
            # These architectures take flattened 784-element inputs. reshape
            # (unlike resize_) fails loudly if the element count is wrong.
            images = images.reshape(images.shape[0], 784)

        # measure data loading time
        data_time.update(time.time() - end)

        images, target = images.cuda(), target.cuda()

        # compute output; shape is (batch_size, num_of_classes)
        output, _, _ = model(images)

        loss = criterion(output, target)
        if args.reg == 'reg':
            # The per-sample losses were referenced but never computed (the
            # line was commented out), so this branch raised NameError.
            each_loss = F.cross_entropy(output, target, reduction='none')
            # NOTE(review): the term is built from a detached Python float, so
            # it shifts the reported loss but carries no gradient - confirm
            # whether a differentiable regularizer was intended.
            reg = each_loss.detach().pow(2).mean().item()
            loss = loss + math.sqrt(args.reg_lambda * reg)
        elif args.reg == 'hinge':
            # Vectorized: one pass over the batch instead of a Python loop
            # with a host/device sync per sample.
            hinge = _hinge_penalty(output, target, args.gamma)
            loss = loss + args.reg_lambda * (hinge ** 2)

        # measure accuracy and record loss
        prec = accuracy(output, target)[0]
        losses.update(loss.item(), images.size(0))
        top1.update(prec.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 or i == len(trainloader) - 1:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, i+1, len(trainloader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))

def get_grads_norm(trainloader, model, criterion):
    """Average a per-sample gradient 2-norm and the loss over *trainloader*.

    For every batch the loss is differentiated with respect to the third
    value returned by the model's forward pass, and the result is handed to
    ``goodfellow_backprop`` together with the second value (presumably
    yielding per-sample parameter gradients - confirm against that helper).
    The 2-norm of every entry of its second-to-last output is averaged.

    The model is switched to eval mode and is not switched back.

    Returns:
        (average norm, average loss) as a tuple.
    """
    started = time.time()
    model.eval()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()

    for batch, labels in trainloader:
        batch, labels = batch.cuda(), labels.cuda()

        output, activations, linearCombs = model.forward(batch)
        loss = criterion(output, labels)
        loss_meter.update(loss.item(), batch.size(0))

        linear_grads = torch.autograd.grad(loss, linearCombs)
        per_sample = goodfellow_backprop(activations, linear_grads)[-2]
        # One meter update per entry, each with weight 1.
        for sample_grad in per_sample:
            norm_meter.update(torch.norm(sample_grad, 2).item(), 1)

    elapsed = time.time() - started
    print('AGN:', norm_meter.avg)
    print('AGN_time:', elapsed)
    return norm_meter.avg, loss_meter.avg


def validate(val_loader, model, criterion, test=True):
    """Evaluate *model* on *val_loader* without gradients and return top-1 precision.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        model: module whose forward pass returns a 3-tuple; only the first
            element is scored.
        criterion: loss applied to ``(output, target)``.
        test: selects the label of the progress lines ('Test' or 'Valid').

    Returns:
        The average top-1 precision (in percent) over all samples seen.
    """
    timer = AverageMeter()
    loss_meter = AverageMeter()
    prec_meter = AverageMeter()

    # The progress label does not change between batches.
    label = 'Test' if test else 'Valid'
    progress = ('{0}: [{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Prec {top1.val:.3f} ({top1.avg:.3f})')

    # switch to evaluate mode
    model.eval()

    tick = time.time()
    with torch.no_grad():
        for step, (batch, labels) in enumerate(val_loader):
            if args.arch in ['fashion', 'mnist']:
                # these architectures take flattened 784-element inputs
                batch.resize_(batch.shape[0], 784)
            batch, labels = batch.cuda(), labels.cuda()

            # forward pass; the two extra return values are not needed here
            scores, _, _ = model.forward(batch)
            batch_loss = criterion(scores, labels)

            # record loss and precision, weighted by the batch size
            batch_prec = accuracy(scores, labels)[0]
            n_samples = batch.size(0)
            loss_meter.update(batch_loss.item(), n_samples)
            prec_meter.update(batch_prec.item(), n_samples)

            # per-batch wall-clock time
            timer.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0 or step == len(val_loader) - 1:
                print(progress.format(label, step + 1, len(val_loader),
                                      batch_time=timer, loss=loss_meter,
                                      top1=prec_meter))

    print(' * Prec {top1.avg:.3f}% '.format(top1=prec_meter))

    return prec_meter.avg


def _merge_copy_tree(src, dst):
    """Copy the regular files and directories under *src* into *dst*.

    Unlike ``shutil.copytree`` this also works when *dst* already exists:
    missing directories are created and same-named files are overwritten.
    """
    for root, _subdirs, files in os.walk(src):
        rel = os.path.relpath(root, src)
        target = dst if rel == os.curdir else os.path.join(dst, rel)
        if not os.path.isdir(target):
            os.makedirs(target)
        for fname in files:
            shutil.copy2(os.path.join(root, fname), os.path.join(target, fname))


def save_checkpoint(state, is_best, fdir, epoch):
    """Write *state* to ``<fdir>/checkpoint.pth`` and optionally archive it.

    Args:
        state: object serialized with ``torch.save``.
        is_best: when true, keep a per-epoch copy of this checkpoint.
        fdir: existing directory that receives the checkpoint files.
        epoch: number embedded in the per-epoch copy's file name.

    When ``is_best`` is true, or the module-level ``SAVE`` flag is set, the
    checkpoint is also copied to ``<fdir>/epoch<epoch>.pth.tar``; if
    ``args.saved_params_dir`` is non-empty, the whole of *fdir* is then
    exported into that directory and *fdir* is removed.
    """
    filepath = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, filepath)
    # The "best" and "SAVE" cases perform exactly the same archiving steps.
    if not (is_best or SAVE):
        return

    shutil.copyfile(filepath, os.path.join(fdir, 'epoch' + str(epoch) + '.pth.tar'))
    if is_path_valid(args.saved_params_dir) and os.path.exists(filepath):
        # shutil.copytree raised when the export directory already existed
        # (e.g. on the second export of a run); merge into it instead.
        _merge_copy_tree(fdir, args.saved_params_dir)
        shutil.rmtree(fdir)


# Step schedules per model type: ((first epoch NOT covered, factor), ...) tried
# in order, followed by the factor used once every boundary has been passed.
_LR_SCHEDULES = {
    1: (((80, 1.0), (120, 0.1)), 0.01),
    2: (((60, 1.0), (120, 0.2), (160, 0.04)), 0.008),
    3: (((150, 1.0), (225, 0.1)), 0.01),
    4: (((4, 1.0), (10, 0.4), (16, 0.2), (24, 0.1)), 0.02),
}


def adjust_learning_rate(optimizer, epoch, model_type, base_lr=None):
    """Set the learning rate of every param group from a step schedule.

    The rate is ``base_lr`` times a factor chosen by *model_type* and *epoch*:

    * type 1: x1 until epoch 80, x0.1 until 120, then x0.01
    * type 2: x1 until epoch 60, x0.2 until 120, x0.04 until 160, then x0.008
    * type 3: x1 until epoch 150, x0.1 until 225, then x0.01
    * type 4: x1 until epoch 4, x0.4 until 10, x0.2 until 16, x0.1 until 24,
      then x0.02

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts.
        epoch: current epoch index.
        model_type: schedule id, one of 1, 2, 3, 4.
        base_lr: initial learning rate; defaults to the global ``args.lr``.

    Raises:
        ValueError: if *model_type* has no schedule (previously this surfaced
            as an UnboundLocalError on ``lr``).
    """
    if model_type not in _LR_SCHEDULES:
        raise ValueError('no learning-rate schedule for model_type {!r}'.format(model_type))
    if base_lr is None:
        base_lr = args.lr

    milestones, factor = _LR_SCHEDULES[model_type]
    for boundary, stage_factor in milestones:
        if epoch < boundary:
            factor = stage_factor
            break
    lr = base_lr * factor

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D tensor of class scores, one row per sample.
        target: 1-D tensor with the true class index of every sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of samples whose
        true class is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred: (maxk, batch) after the transpose, best prediction in row 0.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` keeps the transposed memory layout, so
        # for k > 1 the slice is not contiguous and view(-1) raises.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group.

    Returns None when the optimizer has no param groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)


# Run main() only when this file is executed as a script, not when it is imported.
if __name__=='__main__':
    main()

