from __future__ import print_function
import os
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import models
from utils import Bar, Logger, AverageMeter, accuracy
from utils_awp import TradesAWP
import torchvision
import random
import pickle
from torch.utils.data import Dataset
from PIL import Image
import wandb
from torchvision import datasets, transforms
import torchvision
from torch.utils.data.sampler import SubsetRandomSampler

# Command-line configuration for TRADES(+AWP) training on CIFAR.
parser = argparse.ArgumentParser(description='PyTorch CIFAR TRADES Adversarial Training')
parser.add_argument('--arch', type=str, default='ResNet18')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                    help='input batch size for testing (default: 128)')
parser.add_argument('--epochs', type=int, default=110, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--start_epoch', type=int, default=1, metavar='N',
                    help='retrain from which epoch')
parser.add_argument('--data', type=str, default='CIFAR10', choices=['CIFAR10'])
parser.add_argument('--data-path', type=str, default='../data',
                    help='where is the dataset CIFAR-10')
parser.add_argument('--weight-decay', '--wd', default=5e-4,
                    type=float, metavar='W')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--attack_train_mode', action='store_true', default=False,
                    help='which mode the model should be while generating the attack')
parser.add_argument('--norm', default='l_inf', type=str, choices=['l_inf', 'l_2'],
                    help='The threat model')
parser.add_argument('--epsilon', default=8/255, type=float,
                    help='perturbation')
parser.add_argument('--num-steps', default=10, type=int,
                    help='perturb number of steps')
parser.add_argument('--step-size', default=2/255, type=float,
                    help='perturb step size')
parser.add_argument('--beta', default=6.0, type=float,
                    help='regularization, i.e., 1/lambda in TRADES')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--model-dir', default='./model-cifar-ResNet',
                    help='directory of model for saving checkpoint')
parser.add_argument('--resume-model', default='', type=str,
                    help='directory of model for retraining')
parser.add_argument('--resume-optim', default='', type=str,
                    help='directory of optimizer for retraining')
parser.add_argument('--save-freq', '-s', default=1, type=int, metavar='N',
                    help='save frequency')
parser.add_argument('--awp-gamma', default=0.005, type=float,
                    help='AWP perturbation size (gamma) passed to TradesAWP; <= 0 disables AWP')
parser.add_argument('--awp-warmup', default=10, type=int,
                    help='first epoch at which AWP is applied (skipped before it for speed)')
# Restrict to the schedules main() implements; any other value would leave the
# learning rate unset.
parser.add_argument('--lr_schedule', default='step', choices=['step', 'cosine'],
                    help='schedule used for training')
parser.add_argument('--exp_name', default='AWP',
                    help='name of the method used for training')
### args for wandb initialization and logging in wandb ####
parser.add_argument('--wandb-run', default="AWP_sample")
parser.add_argument('--wandb-notes', default="AWP_sample")
parser.add_argument('--wandb-project', default="OAAT")
parser.add_argument('--wandb-dir', default="./wandb_log")


args = parser.parse_args()
print(args)
epsilon = args.epsilon
if args.awp_gamma <= 0.0:
    # A non-positive gamma disables AWP: no epoch ever reaches an infinite warmup.
    # ``np.infty`` was removed in NumPy 2.0; ``np.inf`` works on every version.
    args.awp_warmup = np.inf
if args.data == 'CIFAR100':
    NUM_CLASSES = 100
elif args.data in ('CIFAR10', 'SVHN'):
    NUM_CLASSES = 10
else:
    # Fail here rather than later with an opaque NameError on NUM_CLASSES.
    raise ValueError('Unsupported dataset: {}'.format(args.data))

# settings
model_dir = args.model_dir
wandb_dir = args.wandb_dir
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(wandb_dir, exist_ok=True)
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}

# setup data loader


transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])


transform_test = transforms.Compose([
    transforms.ToTensor(),
])

trainset = getattr(datasets, args.data)(root=args.data_path, train=True, download=True, transform=transform_train)
# Same training images, but without augmentation, for the held-out validation split.
valset = getattr(datasets, args.data)(root=args.data_path, train=True, download=True, transform=transform_test)
testset = getattr(datasets, args.data)(root=args.data_path, train=False, download=True, transform=transform_test)


train_size = 49000
valid_size = 1000
test_size  = 10000

# Hold out the first ``valid_size // NUM_CLASSES`` training images of every
# class (in dataset order) as a class-balanced validation split.
val_per_class = valid_size // NUM_CLASSES
# Read labels directly when the dataset exposes them: indexing ``trainset[i]``
# would decode and augment every image just to obtain its label.
train_targets = getattr(trainset, 'targets', None)
if train_targets is None:
    train_targets = [trainset[i][1] for i in range(len(trainset))]
count = np.zeros(NUM_CLASSES)
val_indices = []
for index, target in enumerate(train_targets):
    if np.all(count == val_per_class):
        break
    if count[target] < val_per_class:
        count[target] += 1
        val_indices.append(index)
# Set membership keeps this linear (list.remove per index was O(n^2)).
val_index_set = set(val_indices)
train_indices = [i for i in range(len(trainset)) if i not in val_index_set]

print("Overlap indices:",list(set(train_indices) & set(val_indices)))
print("Size of train set:",len(train_indices))
print("Size of val set:",len(val_indices))
# Build data loaders for the train, validation and test splits.

train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,sampler=SubsetRandomSampler(train_indices), **kwargs)
val_loader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size,sampler=SubsetRandomSampler(val_indices), **kwargs)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, **kwargs)
print('{} dataloader: Done'.format(args.data)) 

def perturb_input(model,
                  x_natural, args,
                  step_size=0.003,
                  epsilon=0.031,
                  perturb_steps=10,
                  distance='l_inf'):
    """Craft TRADES adversarial inputs by maximising a KL divergence.

    The objective is ``F.kl_div(log_softmax(model(x_adv)), softmax(model(x_natural)),
    reduction='sum')``, i.e. the divergence of the adversarial prediction from
    the clean prediction. It is maximised inside a threat model of radius
    ``epsilon`` around ``x_natural``, and results are kept in ``[0, 1]``.

    Side effect: unless ``args.attack_train_mode`` is set, ``model`` is switched
    to eval mode and left there; the caller must switch it back if needed.

    Args:
        model: network mapping a batch of inputs to logits.
        x_natural: clean input batch; values are assumed to lie in [0, 1].
            The ``'l_2'`` branch reshapes per-sample norms with
            ``view(-1, 1, 1, 1)`` and therefore requires a 4-D batch.
        args: namespace providing ``attack_train_mode``.
        step_size: per-step signed-gradient step (``'l_inf'`` only).
        epsilon: radius of the L-inf box / L2 ball.
        perturb_steps: number of ascent steps.
        distance: ``'l_inf'``, ``'l_2'``; anything else returns ``x_natural``
            plus small Gaussian noise, clamped to [0, 1].

    Returns:
        Adversarial batch with the same shape and device as ``x_natural``.

    Note:
        The ``'l_2'`` branch calls ``loss.backward()``, which also accumulates
        gradients into ``model``'s parameters. Callers must zero parameter
        gradients before their own backward pass.
    """
    if not args.attack_train_mode:
        model.eval()
    batch_size = len(x_natural)
    if distance == 'l_inf':
        # Small random start; randn_like keeps the noise on x_natural's device.
        x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural).detach()
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = F.kl_div(F.log_softmax(model(x_adv), dim=1),
                                   F.softmax(model(x_natural), dim=1),
                                   reduction='sum')
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            # Project back into the L-inf box, then into the valid pixel range.
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    elif distance == 'l_2':
        delta = (0.001 * torch.randn_like(x_natural)).detach().requires_grad_(True)

        # Plain SGD on delta with a fixed step of 2 * epsilon / perturb_steps.
        optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)

        for _ in range(perturb_steps):
            adv = x_natural + delta

            optimizer_delta.zero_grad()
            with torch.enable_grad():
                # Negated so that SGD's descent step maximises the KL term.
                loss = (-1) * F.kl_div(F.log_softmax(model(adv), dim=1),
                                       F.softmax(model(x_natural), dim=1),
                                       reduction='sum')
            loss.backward()
            # Normalise each sample's gradient to unit L2 norm.
            grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
            delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
            # Samples whose gradient was exactly zero became 0/0 = NaN above;
            # give them a random direction so the perturbation stays finite.
            zero_grad_mask = grad_norms == 0
            if zero_grad_mask.any():
                delta.grad[zero_grad_mask] = torch.randn_like(delta.grad[zero_grad_mask])
            optimizer_delta.step()

            # Projection: keep x_natural + delta in [0, 1], then cap ||delta||_2 at epsilon.
            delta.data.add_(x_natural)
            delta.data.clamp_(0, 1).sub_(x_natural)
            delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
        x_adv = (x_natural + delta).detach()
    else:
        x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural).detach()
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    return x_adv


def train(model, train_loader, optimizer, epoch, awp_adversary):
    """Run one epoch of TRADES adversarial training, with AWP after warm-up.

    Each batch goes through these steps:
    1. Craft adversarial inputs with ``perturb_input``.
    2. When ``epoch >= args.awp_warmup``, perturb the weights via ``awp_adversary``.
    3. Take one optimizer step on ``CE(clean) + args.beta * KL(adv || clean)``.
    4. Restore the weight perturbation.

    Reads the module-level ``args``, ``device`` and ``epsilon``.

    Args:
        model: network being trained.
        train_loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epoch: current epoch number, compared against ``args.awp_warmup``.
        awp_adversary: object providing ``calc_awp``, ``perturb`` and ``restore``.

    Returns:
        ``(loss_avg, adv_top1_avg, natural_loss_avg, clean_top1_avg)``.
        ``loss_avg`` is the full TRADES objective (natural CE + beta * KL),
        not the KL term alone.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    losses_clean = AverageMeter()
    top1_clean = AverageMeter()
    end = time.time()

    print('epoch: {}'.format(epoch))
    bar = Bar('Processing', max=len(train_loader))

    for batch_idx, (data, target) in enumerate(train_loader):
        x_natural, target = data.to(device), target.to(device)
        batch_size = x_natural.size(0)

        # craft adversarial examples
        x_adv = perturb_input(model=model,
                              x_natural=x_natural, args=args,
                              step_size=args.step_size,
                              epsilon=epsilon,
                              perturb_steps=args.num_steps,
                              distance=args.norm)

        # perturb_input may have left the model in eval mode.
        model.train()
        use_awp = epoch >= args.awp_warmup
        # calculate adversarial weight perturbation
        if use_awp:
            awp = awp_adversary.calc_awp(inputs_adv=x_adv,
                                         inputs_clean=x_natural,
                                         targets=target,
                                         beta=args.beta)
            awp_adversary.perturb(awp)

        optimizer.zero_grad()
        logits_adv = model(x_adv)
        # One clean forward pass serves both the CE loss and the KL target.
        # A second pass would cost extra compute and repeat train-mode side
        # effects (e.g. BatchNorm running-stat updates, if the arch has them).
        logits = model(x_natural)
        loss_natural = F.cross_entropy(logits, target)
        loss_robust = F.kl_div(F.log_softmax(logits_adv, dim=1),
                               F.softmax(logits, dim=1),
                               reduction='batchmean')
        loss = loss_natural + args.beta * loss_robust

        loss.backward()
        optimizer.step()

        if use_awp:
            awp_adversary.restore(awp)

        prec1, _ = accuracy(logits_adv, target, topk=(1, 5))
        prec1_clean, _ = accuracy(logits, target, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        losses_clean.update(loss_natural.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top1_clean.update(prec1_clean.item(), batch_size)

        batch_time.update(time.time() - end)
        end = time.time()

        bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s| Total:{total:}| ETA:{eta:}| Loss:{loss:.4f}| top1:{top1:.2f}'.format(
            batch=batch_idx + 1,
            size=len(train_loader),
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            top1=top1.avg)
        bar.next()
    bar.finish()
    return losses.avg, top1.avg, losses_clean.avg, top1_clean.avg


def test(model, test_loader, criterion):
    """Evaluate ``model`` on clean (unperturbed) batches from ``test_loader``.

    Puts the model in eval mode and runs under ``torch.no_grad()``.

    Args:
        model: network to evaluate.
        test_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss callable ``criterion(outputs, targets)``.

    Returns:
        ``(loss_avg, top1_avg)``: sample-weighted averages over the loader.
        Top-1 is in whatever units ``utils.accuracy`` reports.
    """
    model.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    bar = Bar('Processing', max=len(test_loader))
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # No need for .data: no autograd graph is built under no_grad().
            prec1, _ = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s| Total: {total:}| ETA: {eta:}| Loss:{loss:.4f}| top1: {top1:.2f}'.format(
                batch=batch_idx + 1,
                size=len(test_loader),
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg,
                top1=top1.avg)
            bar.next()
    bar.finish()
    return losses.avg, top1.avg

################################ learning-rate schedules (selected by --lr_schedule: 'cosine' or 'step') ################################
def adjust_learning_rate_cosine(optimizer, epoch, args):
    """Apply cosine annealing for ``epoch`` and return the new learning rate.

    ``lr = args.lr * 0.5 * (1 + cos(pi * (epoch - 1) / args.epochs))``: epoch 1
    uses the full ``args.lr``, and the value decays towards zero as ``epoch``
    approaches ``args.epochs + 1``. Every param group of ``optimizer`` is set
    to this value in place.
    """
    progress = (epoch - 1) / args.epochs
    cosine = np.cos(progress * np.pi)
    new_lr = args.lr * 0.5 * (1 + cosine)

    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr


def adjust_learning_rate_step(optimizer, epoch, base_lr=None,
                              milestones=(75, 90, 100), gamma=0.1):
    """Set every param group's LR from a piecewise-constant step schedule.

    The LR starts at ``base_lr`` and is multiplied by ``gamma`` once for each
    milestone that ``epoch`` has reached (``epoch >= milestone``). With the
    defaults this gives:

    - ``base_lr`` before epoch 75,
    - ``base_lr * 0.1`` from epoch 75,
    - ``base_lr * 0.01`` from epoch 90,
    - ``base_lr * 0.001`` from epoch 100.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current epoch number.
        base_lr: initial learning rate; defaults to the module-level ``args.lr``.
        milestones: epochs at which the learning rate decays.
        gamma: multiplicative decay applied at each reached milestone.

    Returns:
        The learning rate written to the optimizer.
    """
    if base_lr is None:
        base_lr = args.lr
    num_decays = sum(1 for milestone in milestones if epoch >= milestone)
    lr = base_lr * gamma ** num_decays
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr

def main():
    """Train ``args.arch`` with TRADES (+AWP), validating and checkpointing each epoch.

    Setup:
    - Build the model, its SGD optimizer and an AWP proxy model.
    - Initialise wandb and the text logger.
    - Optionally resume model/optimizer state.

    Then, for every epoch in ``[args.start_epoch, args.epochs]``:
    1. Set the learning rate.
    2. Train one epoch.
    3. Evaluate on the clean validation split.
    4. Log to wandb and ``log.txt``.
    5. Save model/optimizer checkpoints every ``args.save_freq`` epochs.

    Raises:
        ValueError: if ``args.lr_schedule`` is neither ``'cosine'`` nor ``'step'``.
    """
    # Validate up front: with an unknown schedule ``lr`` would never be bound
    # and the loop would only crash (NameError) after a full training epoch.
    if args.lr_schedule not in ('cosine', 'step'):
        raise ValueError("Unknown lr_schedule {!r}; expected 'cosine' or 'step'".format(args.lr_schedule))

    model = nn.DataParallel(getattr(models, args.arch)(num_classes=NUM_CLASSES)).to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # We use a proxy model to calculate AWP, which does not affect the statistics of BN.
    proxy = nn.DataParallel(getattr(models, args.arch)(num_classes=NUM_CLASSES)).to(device)
    proxy_optim = optim.SGD(proxy.parameters(), lr=args.lr)
    awp_adversary = TradesAWP(model=model, proxy=proxy, proxy_optim=proxy_optim, gamma=args.awp_gamma)

    criterion = nn.CrossEntropyLoss()

    wandb.init(name=args.wandb_run,notes = args.wandb_notes,project = args.wandb_project,dir = args.wandb_dir,config=args)

    logger = Logger(os.path.join(model_dir, 'log.txt'), title=args.arch)
    logger.set_names(['Learning Rate',
                      'Adv Train Loss', 'Nat Train Loss', 'Nat Val Loss',
                      'Adv Train Acc.', 'Nat Train Acc.', 'Nat Val Acc.'])

    if args.resume_model:
        model.load_state_dict(torch.load(args.resume_model, map_location=device))
    if args.resume_optim:
        optimizer.load_state_dict(torch.load(args.resume_optim, map_location=device))

    for epoch in range(args.start_epoch, args.epochs + 1):
        # adjust learning rate for SGD
        if args.lr_schedule == 'cosine':
            lr = adjust_learning_rate_cosine(optimizer, epoch, args)
        else:
            lr = adjust_learning_rate_step(optimizer, epoch)

        # adversarial training
        adv_loss, adv_acc, clean_loss, clean_acc = train(model, train_loader, optimizer, epoch, awp_adversary)

        # Evaluation and logging.
        # NOTE: adv_loss is the average of train()'s total loss
        # (natural CE + beta * KL), despite the key name below. The key is
        # kept unchanged so existing wandb runs remain comparable.
        wandb.log({'Adv Loss (Train set) (Beta*KL(Adv||Clean)': adv_loss},step=epoch)
        wandb.log({'Adv Acc @ vareps (Train set)': adv_acc},step=epoch)
        wandb.log({'Clean Loss (Train set)': clean_loss},step=epoch)
        wandb.log({'Clean Acc (Train set)': clean_acc},step=epoch)
        # evaluation on natural examples
        print('================================================================')
        val_loss, val_acc = test(model, val_loader, criterion)
        wandb.log({'CE loss on clean samples (Val set)': val_loss},step=epoch)
        wandb.log({'Clean Acc (Val set)': val_acc},step=epoch)
        print('================================================================')

        logger.append([lr, adv_loss, clean_loss, val_loss, adv_acc, clean_acc, val_acc])

        # save checkpoint
        if epoch % args.save_freq == 0:
            torch.save(model.state_dict(),
                       os.path.join(model_dir, '{}_{}_{}_{}_{}_{}.pkl'.format(args.exp_name, args.data, args.start_epoch, args.beta, args.weight_decay, epoch)))
            torch.save(optimizer.state_dict(),
                       os.path.join(model_dir, '{}_{}_{}_{}_{}_{}.tar'.format(args.exp_name, args.data, args.start_epoch, args.beta, args.weight_decay, epoch)))


if __name__ == '__main__':
    main()
