# The code is based on https://github.com/locuslab/robust_overfitting (Rice et al.)

import argparse
import logging
import sys
import time
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms

import os

from preactresnet import PreActResNet18

from utils import *

# Per-channel (R, G, B) mean and standard deviation used by normalize().
# NOTE(review): these values match the standard ImageNet normalization constants;
# confirm they are intended here rather than statistics recomputed from the
# TinyImageNet training set.
tiny_imagenet_mean = (0.485, 0.456, 0.406)
tiny_imagenet_std = (0.229, 0.224, 0.225)

# Reshaped to (3, 1, 1) so they broadcast over the trailing (C, H, W) dims of an
# image batch; moved to the GPU once at import time.
mu, std = (
    torch.tensor(channel_stats).view(3, 1, 1).cuda()
    for channel_stats in (tiny_imagenet_mean, tiny_imagenet_std)
)

def normalize(X):
    """Standardize an image batch channel-wise using the module-level ``mu`` and ``std``."""
    centered = X - mu
    return centered / std

lower_limit, upper_limit = 0, 1  # valid input range before normalize(); X + delta is clamped to it


def clamp(X, lower_limit, upper_limit):
    """Clamp ``X`` elementwise into ``[lower_limit, upper_limit]``.

    The bounds may be tensors broadcastable to ``X`` (e.g. ``lower_limit - X``,
    as used to keep ``X + delta`` in range) or plain Python numbers. Number
    bounds are converted to 0-d tensors with ``X``'s dtype and device, because
    ``torch.max``/``torch.min`` with two arguments only accept tensors.
    Where a lower bound exceeds the upper bound, the lower bound wins.
    """
    if not torch.is_tensor(lower_limit):
        lower_limit = torch.as_tensor(lower_limit, dtype=X.dtype, device=X.device)
    if not torch.is_tensor(upper_limit):
        upper_limit = torch.as_tensor(upper_limit, dtype=X.dtype, device=X.device)
    return torch.max(torch.min(X, upper_limit), lower_limit)

def _imitated_step_direction(g, vals, quantile_levels):
    """Return the imitated-PGD ('ipgd') step direction for gradient ``g``.

    Instead of ``sign(g)``, every entry keeps its sign but gets a magnitude from
    ``vals`` chosen by which per-sample quantile bucket of ``|g|`` it falls in:
    ``|g|`` below the first quantile -> ``vals[0]``, between quantiles ``i-1`` and
    ``i`` -> ``vals[i]``, and at or above the last quantile -> ``vals[-1]``.

    Args:
        g: gradient w.r.t. the perturbation, indexed as (batch, C, H, W).
        vals: per-bucket magnitudes (presumably from ``imitation()`` in utils — confirm).
        quantile_levels: 1-D tensor of quantile levels in [0, 1], same dtype/device as ``g``.
    """
    abs_g = g.abs()  # computed once; reused for every bucket mask
    # Shape (len(quantile_levels), batch): one threshold per level per sample.
    thresholds = torch.quantile(abs_g.view(g.size(0), -1), quantile_levels, dim=1)
    direction = g.clone().detach()
    lower = torch.zeros_like(g)
    for i, level_threshold in enumerate(thresholds):
        upper = level_threshold.view(g.size(0), 1, 1, 1)
        direction[(abs_g >= lower) & (abs_g < upper)] = vals[i]
        lower = upper
    # Everything at/above the last threshold (or everything, if there are no levels).
    direction[abs_g >= lower] = vals[-1]
    return direction * g.sign()


def attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts, norm, early_stop=False, fgsm_init=None, attack='pgd', vals=None, probs=None):
    """Craft adversarial perturbations with restarted PGD / FGSM / imitated PGD.

    Args:
        model: classifier, applied to ``normalize(X + delta)``.
        X: clean inputs (assumed to lie in [lower_limit, upper_limit]); y: labels.
        epsilon: perturbation radius (L-inf or L2, depending on ``norm``).
        alpha: step size.
        attack_iters: gradient steps per restart.
        restarts: number of restarts. Per sample, the perturbation with the
            highest final loss across restarts is kept.
        norm: "l_inf" or "l_2".
        early_stop: if True, only samples the model still classifies correctly
            are updated, and a restart stops early once none remain.
        fgsm_init: 'random' forces a random start even when ``attack_iters == 1``;
            otherwise single-step attacks start from zero.
        attack: 'ipgd' replaces the L-inf sign step by the imitated step.
        vals, probs: per-bucket magnitudes and quantile levels used by 'ipgd'.

    Returns:
        Detached perturbation tensor with the same shape as ``X``.

    Raises:
        ValueError: unknown ``norm`` when a random initialisation is needed.
    """
    max_loss = torch.zeros(y.shape[0]).cuda()
    max_delta = torch.zeros_like(X).cuda()
    # Loop-invariant: build the quantile-level tensor once instead of every step.
    quantile_levels = torch.tensor(probs).cuda() if (norm == "l_inf" and attack == 'ipgd') else None

    for _ in range(restarts):
        delta = torch.zeros_like(X).cuda()
        # Multi-step attacks (and FGSM with fgsm_init='random') start at a random point in the ball.
        if attack_iters > 1 or fgsm_init == 'random':
            if norm == "l_inf":
                delta.uniform_(-epsilon, epsilon)
            elif norm == "l_2":
                delta.normal_()
                d_flat = delta.view(delta.size(0), -1)
                n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
                r = torch.zeros_like(n).uniform_(0, 1)
                delta *= r / n * epsilon
            else:
                raise ValueError
        # Keep X + delta inside the valid input range.
        delta = clamp(delta, lower_limit - X, upper_limit - X)
        delta.requires_grad = True
        for _ in range(attack_iters):
            output = model(normalize(X + delta))
            if early_stop:
                # Only keep attacking samples that are still classified correctly.
                index = torch.where(output.max(1)[1] == y)[0]
                if len(index) == 0:
                    break
            else:
                index = slice(None, None, None)
            loss = F.cross_entropy(output, y)
            loss.backward()
            grad = delta.grad.detach()
            d = delta[index, :, :, :]
            g = grad[index, :, :, :]
            x = X[index, :, :, :]
            if norm == "l_inf":
                if attack == 'ipgd':
                    step = _imitated_step_direction(g, vals, quantile_levels)
                else:
                    step = torch.sign(g)
                d = torch.clamp(d + alpha * step, min=-epsilon, max=epsilon)
            elif norm == "l_2":
                g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1)
                scaled_g = g / (g_norm + 1e-10)
                d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)
            d = clamp(d, lower_limit - x, upper_limit - x)
            # Write back only the attacked rows; .data avoids recording this in autograd.
            delta.data[index, :, :, :] = d
            delta.grad.zero_()
        with torch.no_grad():
            all_loss = F.cross_entropy(model(normalize(X + delta)), y, reduction='none')
            # Per sample, keep the perturbation from the restart with the highest loss.
            improved = all_loss >= max_loss
            max_delta[improved] = torch.clone(delta.detach()[improved])
            max_loss = torch.max(max_loss, all_loss)
    return max_delta.detach()

def get_args():
    """Build the command-line parser for training/evaluation and return the parsed namespace.

    Note: main() divides --epsilon and --pgd-alpha by 255, and multiplies
    --fgsm-alpha by epsilon.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Model, data and optimisation schedule.
    add('--model', default='PreActResNet18')
    add('--batch-size', default=128, type=int)
    add('--data-dir', default='./data/', type=str)
    add('--epochs', default=30, type=int)
    add('--lr-schedule', default='cyclic', choices=['cyclic', 'onedrop'])
    add('--lr-drop', default=50, type=int)
    add('--lr-max', default=0.2, type=float)

    # Training-time attack.
    add('--attack', default='ipgd', type=str, choices=['fgsm', 'pgd', 'ipgd'])
    add('--epsilon', default=8, type=int)
    add('--pgd-alpha', default=2, type=float, help='step size for PGD')
    add('--steps', default=7, type=int, help='steps=1 for ipgd, step>1 for pgd')
    add('--imitation-steps', default=2, type=int, help='the steps for imitated pgd')
    add('--prob', default=2/3., type=float, help='the probability for two consecutive gradient with same sign')
    add('--fgsm-alpha', default=1, type=float, help='step size for I-PGD-AT and Fast-AT')
    add('--fgsm-init', default='random', type=str, choices=['zero', 'random'])

    # Bookkeeping, evaluation and checkpointing.
    add('--fname', default='tiny_model', type=str)
    add('--seed', default=0, type=int)
    add('--width-factor', default=10, type=int)
    add('--full-test', action='store_true', help='turn on to get test set evaluation in each epoch instead of validation set')
    add('--test-iters', default=10, type=int, help='number of pgd steps used in evaluation during training')
    add('--test-restarts', default=1, type=int, help='number of pgd restarts used in evaluation during training')
    add('--resume', default=0, type=int)
    add('--eval', action='store_true')
    add('--chkpt-iters', default=10, type=int)

    return parser.parse_args()

def l2_norm_batch(v):
    """Return the Euclidean norm of each sample in ``v``, reducing over dims 1, 2 and 3."""
    squared = v ** 2
    return squared.sum(dim=(1, 2, 3)) ** 0.5

def _make_lr_schedule(args):
    """Build the learning-rate schedule selected by ``--lr-schedule``.

    Returns a function of the (fractional) epoch ``t``:
    'cyclic' ramps linearly from 0 to ``lr_max`` at ``epochs // 2`` and back to 0
    at ``epochs``; 'onedrop' returns ``lr_max`` before ``lr_drop`` and
    ``lr_max / 10`` afterwards.

    Raises:
        ValueError: for an unknown schedule name.
    """
    if args.lr_schedule == 'cyclic':
        return lambda t: np.interp([t], [0, args.epochs // 2, args.epochs], [0, args.lr_max, 0])[0]
    if args.lr_schedule == 'onedrop':
        def lr_schedule(t):
            if t < args.lr_drop:
                return args.lr_max
            return args.lr_max / 10.
        return lr_schedule
    raise ValueError(f"Unknown lr schedule: {args.lr_schedule}")


def _final_pgd_eval(model, test_loader, criterion, epsilon, pgd_alpha):
    """Evaluate on the whole ``test_loader`` with PGD-50 (10 restarts, early stop).

    Returns:
        (clean_loss, clean_acc, robust_loss, robust_acc), each averaged per sample.
    """
    test_loss = 0
    test_acc = 0
    test_robust_loss = 0
    test_robust_acc = 0
    test_n = 0
    for X, y in test_loader:
        X, y = X.cuda(), y.cuda()

        delta = attack_pgd(model, X, y, epsilon, pgd_alpha, 50, 10, 'l_inf', early_stop=True)
        delta = delta.detach()

        robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
        robust_loss = criterion(robust_output, y)

        output = model(normalize(X))
        loss = criterion(output, y)

        test_robust_loss += robust_loss.item() * y.size(0)
        test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
        test_loss += loss.item() * y.size(0)
        test_acc += (output.max(1)[1] == y).sum().item()
        test_n += y.size(0)

    return test_loss / test_n, test_acc / test_n, test_robust_loss / test_n, test_robust_acc / test_n


def main():
    """Adversarially train (or, with --eval, only evaluate) a PreActResNet18 on TinyImageNet.

    Each epoch trains against the attack chosen by --attack. It then evaluates
    clean, PGD and FGSM accuracy on roughly 10% of the test set (the full set
    with --full-test or in the last epoch). Checkpoints model_{epoch}.pth and
    opt_{epoch}.pth are written every --chkpt-iters epochs. model_best.pth is
    written whenever the test robust accuracy improves. The run finishes with
    a full-test-set PGD-50 (10 restarts) evaluation.

    With --eval (which requires --resume N), model_{N-1}.pth is loaded and only
    the evaluation passes are run.
    """
    args = get_args()

    os.makedirs(args.fname, exist_ok=True)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(os.path.join(args.fname, 'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    train_transform = transforms.Compose([
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_data = TinyImageNet(root=args.data_dir, train=True, transform=train_transform).data
    test_data = TinyImageNet(root=args.data_dir, train=False, transform=test_transform).data

    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, drop_last=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False)

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)
    # Training PGD needs a step large enough to reach the epsilon boundary within
    # args.steps iterations. This adjustment is kept separate so the evaluation
    # attacks below still use the requested --pgd-alpha.
    train_pgd_alpha = max(pgd_alpha, epsilon / args.steps) if args.attack == 'pgd' else pgd_alpha

    if args.model == 'PreActResNet18':
        model = PreActResNet18(num_classes=200, stride=2)
    else:
        raise ValueError("Unknown model")

    model = nn.DataParallel(model).cuda()
    model.train()

    opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    epochs = args.epochs
    lr_schedule = _make_lr_schedule(args)

    best_test_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_robust_acc = torch.load(os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info("No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    # imitation values and probabilities (per-bucket magnitudes and quantile levels for --attack ipgd)
    vals, probs = imitation(args.imitation_steps, args.prob)

    total_train_time = 0
    logger.info('Epoch \t Train Time \t Test Time \t LR \t \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc \t Test FGSM Acc \t Defence Mean \t Attack Mean')
    # --eval performs exactly one evaluation pass. Iterating [start_epoch] (rather
    # than range(start_epoch, epochs)) means evaluating the final checkpoint
    # (start_epoch == epochs) still runs instead of silently doing nothing.
    epoch_iter = [start_epoch] if args.eval else range(start_epoch, epochs)
    for epoch in epoch_iter:
        model.train()
        start_time = time.time()
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        defence_mean = 0
        # In --eval mode, skip training entirely (no need to even fetch a batch).
        if not args.eval:
            for i, (X, y) in enumerate(train_loader):
                X, y = X.cuda(), y.cuda()

                lr = lr_schedule(epoch + (i + 1) / len(train_loader))
                opt.param_groups[0].update(lr=lr)

                if args.attack == 'fgsm':
                    delta = attack_pgd(model, X, y, epsilon, args.fgsm_alpha * epsilon, 1, 1, 'l_inf', fgsm_init=args.fgsm_init)
                elif args.attack == 'pgd':
                    delta = attack_pgd(model, X, y, epsilon, train_pgd_alpha, args.steps, 1, 'l_inf', attack=args.attack)
                elif args.attack == 'ipgd':
                    delta = attack_pgd(model, X, y, epsilon, args.fgsm_alpha * epsilon, args.steps, 1, 'l_inf', fgsm_init=args.fgsm_init, attack=args.attack, vals=vals, probs=probs)
                delta = delta.detach()
                robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                opt.zero_grad()
                robust_loss.backward()
                opt.step()

                train_robust_loss += robust_loss.item() * y.size(0)
                train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                train_n += y.size(0)
                # Accumulate as a Python float rather than a CUDA tensor.
                defence_mean += torch.mean(torch.abs(delta)).item() * y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_fgsm_acc = 0
        test_n = 0
        attack_mean = 0
        for i, (X, y) in enumerate(test_loader):
            # Except in the last epoch (or with --full-test), use only ~10% of the test set.
            if not epoch+1==epochs and not args.full_test and i > len(test_loader) / 10:
                break
            X, y = X.cuda(), y.cuda()

            delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.test_iters, args.test_restarts, 'l_inf', early_stop=args.eval)
            delta = delta.detach()

            robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)
            attack_mean += torch.mean(torch.abs(delta)).item() * y.size(0)

            # Single-step FGSM from zero init with step size epsilon.
            delta = attack_pgd(model, X, y, epsilon, epsilon, 1, 1, 'l_inf')
            delta = delta.detach()
            fgsm_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
            test_fgsm_acc += (fgsm_output.max(1)[1] == y).sum().item()

        test_time = time.time()

        if not args.eval:
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_robust_loss/train_n, train_robust_acc/train_n,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n, test_fgsm_acc/test_n,
                defence_mean*255/train_n, attack_mean*255/test_n)

            # save checkpoint
            if (epoch+1) % args.chkpt_iters == 0 or epoch+1 == epochs:
                torch.save(model.state_dict(), os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(), os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc/test_n > best_test_robust_acc:
                torch.save({
                        'state_dict':model.state_dict(),
                        'test_robust_acc':test_robust_acc/test_n,
                        'test_robust_loss':test_robust_loss/test_n,
                        'test_loss':test_loss/test_n,
                        'test_acc':test_acc/test_n,
                    }, os.path.join(args.fname, f'model_best.pth'))
                best_test_robust_acc = test_robust_acc/test_n
        else:
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1,
                -1, -1,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n,
                -1, attack_mean*255/test_n)
        total_train_time += (train_time - start_time)

        if args.eval or epoch+1 == epochs:
            start_test_time = time.time()
            test_loss, test_acc, test_robust_loss, test_robust_acc = _final_pgd_eval(
                model, test_loader, criterion, epsilon, pgd_alpha)

            logger.info('Total train Time: %.1f PGD50 \t time: %.1f,\t clean loss: %.4f,\t clean acc: %.4f,\t robust loss: %.4f,\t robust acc: %.4f',
                total_train_time, time.time() - start_test_time,
                test_loss, test_acc, test_robust_loss, test_robust_acc)
            return


if __name__ == "__main__":
    main()
