import numpy as np
import torch
import utils
import argparse
import torch.nn as nn
import torch.utils
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import genotypes
from torch.autograd import Variable
from model import NetworkCIFAR
from loss import trades_loss, madry_loss
from utils import madry_generate
import torch.nn.functional as F
from utils import load_tinyimagenet
import logging
import datetime



# Command-line interface for the training script.
parser = argparse.ArgumentParser("cifar")

# Runtime / data options.
parser.add_argument('--gpu', type=int, default=1, help='gpu device id')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--epochs', type=int, default=120, help='num of training epochs')
# NOTE(review): the help text says 'experiment name' but the default 'pgd'
# suggests this selects the adversarial loss -- confirm before rewording.
parser.add_argument('--adv_loss', type=str, default='pgd', help='experiment name')
parser.add_argument('--data', type=str, default="", help='location of the data corpus')

# Optimizer options.
parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')

# Adversarial-example generation options.
parser.add_argument('--epsilon', type=float, default=8/255, help='perturbation')
parser.add_argument('--num_steps', type=int, default=7, help='perturb number of steps')
parser.add_argument('--step_size', type=float, default=0.01, help='perturb step size')
parser.add_argument('--beta', type=float, default=6.0, help='regularization in TRADES')

# Network options.
parser.add_argument('--init_channels', type=int, default=32, help='num of init channels')
parser.add_argument('--layers', type=int, default=10, help='total number of layers')
# `store_true` combined with `default=True` can never yield False, so each
# default-on switch gets a paired `--no_*` flag writing to the same dest.
parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
parser.add_argument('--no_auxiliary', dest='auxiliary', action='store_false', help='disable the auxiliary tower')
parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
parser.add_argument('--no_cutout', dest='cutout', action='store_false', help='disable cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability')

# Miscellaneous options.
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--data_dir', type=str, default="", help='imagenet-200 datadir')
parser.add_argument('--workers', type=int, default=4, help='data loader workers')
parser.add_argument('--resume', type=str, default="", help='path to the checkpoint to resume from')

# Parsed once, at import time. main() and adjust_learning_rate() read this
# module-level namespace directly.
args = parser.parse_args()


# Number of output classes handed to NetworkCIFAR in main(); main() loads
# torchvision's CIFAR100 dataset.
CIFAR_CLASSES = 100



def main():
    """Train and validate NetworkCIFAR on CIFAR-100 with adversarial training.

    Seeds NumPy and PyTorch, builds the network from a fixed genotype,
    optionally resumes model/optimizer/epoch/best accuracy from
    ``args.resume``, and then for every epoch sets the learning rate, derives
    the loss weights, runs ``train`` and ``infer``, and checkpoints through
    ``utils.save``.

    All settings come from the module-level ``args`` namespace. Returns
    nothing; progress is reported with ``print``.
    """
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    print('gpu device: ', args.gpu)
    print("args: ", args)

    max_weight_loss = 1.2  # upper bound for the adversarial-loss weight
    base_lambda = 2        # slope factor of the adversarial-loss weight ramp

    genotype = genotypes.search_cifar10_338_0
    model = NetworkCIFAR(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
    print(model)
    print("param size: ", utils.count_parameters_in_MB(model), 'MB')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # Defaults for a fresh run; both are overwritten when resuming. Without
    # these, a run that does not pass --resume fails on an undefined
    # start_epoch.
    start_epoch = 0
    best_acc = 0
    if args.resume:
        # Load onto the CPU first: the model has not been moved to the GPU
        # yet, and the checkpoint may have been written from another device.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        best_acc = checkpoint['valid_acc']
        print(f"Resumed from checkpoint: {args.resume}, starting from epoch {start_epoch}")
        # Reset to the configured base rate; adjust_learning_rate() sets the
        # scheduled value again at the start of every epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.learning_rate
        print(f"Learning rate manually set to {args.learning_rate}")
        # Move restored optimizer state (e.g. momentum buffers) to the GPU so
        # it lives on the same device as the parameters after model.cuda().
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    model = model.cuda()

    train_transform, valid_transform = utils._data_transforms_cifar100else(args)
    train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
    valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        print('epoch: %d' % epoch)
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        # Adversarial-loss weight: linear ramp, capped at max_weight_loss.
        lambda_epoch = min(base_lambda * ((epoch + 1) / 60), max_weight_loss)

        # Logit-pairing weight: linear warm-up over the first 20 epochs that
        # reaches 0.5 and then stays there. The parentheses around
        # (epoch + 1) matter: `epoch+1 / 20.0` would grow far past 0.5.
        lambda_alp = 0.5 * min((epoch + 1) / 20.0, 1.0)

        train_acc, train_obj = train(train_queue, model, criterion, optimizer, args, lambda_epoch, lambda_alp)

        print('train_acc: ', train_acc)

        valid_acc, valid_obj = infer(valid_queue, model, criterion)

        is_best = valid_acc > best_acc
        if is_best:
            best_acc = valid_acc
        # Save on every new best, and unconditionally for the late epochs;
        # a single call covers the case where both conditions hold.
        if is_best or epoch > 100:
            utils.save(model, epoch, optimizer, valid_acc)
            print('Save model!')
        print('valid acc: ', valid_acc, 'best_acc: ', best_acc)



def train(train_queue, model, criterion, optimizer, args, lambda_epoch, lambda_alp, misclass_weight=1.2):
    """Run one epoch of adversarial training over ``train_queue``.

    Per batch the objective is::

        criterion(clean logits, target)
        + lambda_epoch * mean(sample_weight * CE(adversarial logits, target))
        + lambda_alp   * MSE(clean logits, detached adversarial logits)
        + args.auxiliary_weight * criterion(aux logits, target)   # if args.auxiliary

    where ``sample_weight`` is ``misclass_weight`` for samples whose
    adversarial prediction is wrong and 1 otherwise.

    Args:
        train_queue: iterable of ``(input, target)`` batches.
        model: network returning ``(logits, logits_aux)``.
        criterion: loss applied to the clean and auxiliary logits.
        optimizer: optimizer stepped once per batch; also handed to
            ``madry_generate``.
        args: namespace providing ``step_size``, ``epsilon``, ``num_steps``,
            ``auxiliary``, ``auxiliary_weight`` and ``grad_clip``.
        lambda_epoch: weight of the adversarial cross-entropy term.
        lambda_alp: weight of the logit-pairing (MSE) term.
        misclass_weight: per-sample weight for adversarially misclassified
            samples.

    Returns:
        ``(top1.avg, objs.avg)``: the ``avg`` of the clean top-1 meter and of
        the total-loss meter.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    robust_top1 = utils.AvgrageMeter()  # For robust accuracy
    robust_top5 = utils.AvgrageMeter()  # For robust accuracy
    standard_losses = utils.AvgrageMeter()  # For tracking standard loss
    robust_losses = utils.AvgrageMeter()  # For tracking robust loss
    num_steps = len(train_queue)

    for step, (inputs, target) in enumerate(train_queue):
        # Re-asserted every step, before the attack is generated.
        model.train()
        n = inputs.size(0)

        inputs = inputs.cuda()
        target = target.cuda(non_blocking=True)

        # Generate adversarial example (PGD)
        input_pgd = madry_generate(model, inputs, target, optimizer,
                                   step_size=args.step_size, epsilon=args.epsilon,
                                   perturb_steps=args.num_steps)

        # Standard loss: loss on the clean data
        logits_standard, logits_aux = model(inputs)
        loss_standard = criterion(logits_standard, target)

        # Adversarial loss: per-sample loss on the adversarial data
        logits_pgd, _ = model(input_pgd)
        loss_pgd = F.cross_entropy(logits_pgd, target, reduction='none')

        # Logit pairing: the adversarial logits are detached, so this term
        # only back-propagates through the clean forward pass.
        loss_alp = F.mse_loss(logits_standard, logits_pgd.detach())

        # Up-weight samples the network gets wrong under attack.
        preds_pgd = logits_pgd.argmax(dim=1)
        sample_weights = torch.ones_like(target, dtype=torch.float, device=target.device)
        sample_weights[preds_pgd != target] = misclass_weight
        loss_adv_weighted = (sample_weights * loss_pgd).mean()

        print(lambda_epoch)
        loss = loss_standard + lambda_epoch * loss_adv_weighted + lambda_alp * loss_alp

        # Always bind loss_aux: the step report below prints it even when the
        # auxiliary head is disabled.
        loss_aux = logits_standard.new_zeros(())
        if args.auxiliary:
            loss_aux = criterion(logits_aux, target)
            loss += args.auxiliary_weight * loss_aux

        optimizer.zero_grad()
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        # Compute standard accuracy (on original data)
        prec1_standard, prec5_standard = utils.accuracy(logits_standard, target, topk=(1, 5))

        # Compute robust accuracy (on adversarial data)
        prec1_robust, prec5_robust = utils.accuracy(logits_pgd, target, topk=(1, 5))

        # Update standard metrics
        top1.update(prec1_standard.item(), n)
        top5.update(prec5_standard.item(), n)

        # Update robust metrics
        robust_top1.update(prec1_robust.item(), n)
        robust_top5.update(prec5_robust.item(), n)

        # Update loss meters
        standard_losses.update(loss_standard.item(), n)
        robust_losses.update(loss_adv_weighted.item(), n)

        # Update loss
        objs.update(loss.item(), n)

        print(f"Step [{step + 1}/{num_steps}] | "
              f"Loss: {loss.item():.4f} | "
              f"last Top-1robust Accuracy: {prec1_robust.item():.2f}% | "
              f"last Top-1std Accuracy: {prec1_standard.item():.2f}% |"
              f"loss_standard: {loss_standard} |"
              f"loss_adv_weighted: {lambda_epoch *loss_adv_weighted} |"
              f"loss_loss_aux: {loss_aux} |"
              f"loss_alp: {loss_alp} ")

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
    """Evaluate ``model`` on ``valid_queue`` without tracking gradients.

    The model is switched to eval mode and every batch is moved to the GPU.
    Only the first element of the model's output pair is scored.

    Args:
        valid_queue: iterable of ``(input, target)`` batches.
        model: network returning a pair whose first element is the logits.
        criterion: loss applied to ``(logits, target)``.

    Returns:
        ``(top1.avg, objs.avg)``: the ``avg`` of the top-1 meter and of the
        loss meter.
    """
    loss_meter = utils.AvgrageMeter()
    top1_meter = utils.AvgrageMeter()
    top5_meter = utils.AvgrageMeter()

    model.eval()
    with torch.no_grad():
        for batch, labels in valid_queue:
            batch_size = batch.size(0)
            batch = Variable(batch, requires_grad=False).cuda(non_blocking=True)
            labels = Variable(labels, requires_grad=False).cuda(non_blocking=True)

            # The second output of the model is not used for evaluation.
            outputs = model(batch)
            logits = outputs[0]
            batch_loss = criterion(logits, labels)
            acc1, acc5 = utils.accuracy(logits, labels, topk=(1, 5))

            loss_meter.update(batch_loss.data.item(), batch_size)
            top1_meter.update(acc1.data.item(), batch_size)
            top5_meter.update(acc5.data.item(), batch_size)

    return top1_meter.avg, loss_meter.avg



def adjust_learning_rate(optimizer, epoch):
    """Apply the step learning-rate schedule for ``epoch`` to ``optimizer``.

    The base rate is the module-level ``args.learning_rate``. It is used
    unchanged before epoch 99, multiplied by 0.1 from epoch 99, and
    multiplied by 0.01 from epoch 104. Every parameter group receives the
    same value.
    """
    base_lr = args.learning_rate
    if epoch >= 104:
        new_lr = base_lr * 0.01
    elif epoch >= 99:
        new_lr = base_lr * 0.1
    else:
        new_lr = base_lr

    for group in optimizer.param_groups:
        group['lr'] = new_lr


# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()
