# -*- coding: utf-8 -*-
# @Author: Chen Renjie
# @Date:   2021-08-16 23:04:20
# @Last Modified by:   Chen Renjie
# @Last Modified time: 2021-10-05 21:40:52


import os
import sys
import random
import argparse

import numpy as np

from tqdm import tqdm 

import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, CIFAR100, SVHN, ImageFolder

from models import ResNet18, PreActResNet18, WideResNet_28_10


def worker_init(worker_id):
    """Seed the ``random`` and NumPy generators from torch's initial seed.

    Intended for use as a ``DataLoader`` ``worker_init_fn``. The seed is
    ``torch.initial_seed()`` reduced modulo 2**32 and is applied to both
    ``random`` and ``np.random``.

    Args:
        worker_id: Accepted to match the ``worker_init_fn`` signature; unused.
    """
    seed32 = torch.initial_seed() % (1 << 32)
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed32)
    
def setup_seed(seed):
    """Seed every random number generator used here for reproducible runs.

    Seeds Python's ``random``, NumPy and torch (CPU, plus all CUDA devices
    when CUDA is available) and, on CUDA, configures cuDNN for deterministic
    execution.

    Args:
        seed: Integer seed applied to all generators.
    """
    # NOTE: hash randomisation of the running interpreter is fixed at start-up;
    # this variable is only picked up by interpreters launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Benchmark mode auto-tunes the convolution algorithm and may pick a
        # different one on each run, which defeats reproducibility, so it must
        # be disabled together with requesting deterministic kernels.
        cudnn.benchmark = False
        cudnn.deterministic = True

def logits_accuracy(logits, labels):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Args:
        logits: Tensor of scores; the prediction is the arg-max over the last
            dimension.
        labels: Tensor of target class indices; its first dimension is the
            number of samples used as the denominator.

    Returns:
        Accuracy as a Python float.
    """
    predicted = torch.argmax(logits, dim=-1)
    num_correct = predicted.eq(labels).sum().item()
    return num_correct / labels.size(0)

def load_model(args):
    """Instantiate the backbone network selected by ``args.model``.

    Args:
        args: Namespace providing ``model`` (backbone name) and ``num_class``.

    Returns:
        A freshly constructed model. ``"PreActResNet18"`` and
        ``"WideResNet-28-10"`` select their own classes; ``"ResNet18"`` and
        any unrecognised name both yield a ``ResNet18``.
    """
    name, num_class = args.model, args.num_class
    if name == "PreActResNet18":
        # The second positional argument is passed as None, as before.
        return PreActResNet18(num_class, None)
    if name == "WideResNet-28-10":
        return WideResNet_28_10(num_class)
    # Explicit "ResNet18" and the fallback case share the same constructor.
    return ResNet18(num_class)

def load_dataset(args):
    """Build the training and test ``DataLoader`` objects for ``args.dataset``.

    Training images are randomly cropped (padding 4) and, except for SVHN,
    randomly flipped horizontally; test images are only converted to tensors.
    No normalisation is applied here.

    Args:
        args: Namespace providing ``dataset``, ``seed``, ``img_size``,
            ``data_path`` and ``batch_size``.

    Returns:
        Tuple ``(train_loader, test_loader)``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    supported = ("CIFAR10", "CIFAR100", "SVHN", "Tiny-ImageNet")
    # Fail early with a clear message instead of an unbound-variable error
    # further down when no branch matches.
    if args.dataset not in supported:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected one of {supported}"
        )

    # Generator handed to both loaders so their randomness derives from args.seed.
    g = torch.Generator()
    g.manual_seed(args.seed)
    train_transform = transforms.Compose([
        transforms.RandomCrop(args.img_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor()
    ])
    if args.dataset == "CIFAR10":
        train_set = CIFAR10(root=args.data_path, train=True, download=True, transform=train_transform)
        test_set = CIFAR10(root=args.data_path, train=False, download=False, transform=test_transform)
    elif args.dataset == "CIFAR100":
        train_set = CIFAR100(root=args.data_path, train=True, download=True, transform=train_transform)
        test_set = CIFAR100(root=args.data_path, train=False, download=False, transform=test_transform)
    elif args.dataset == "SVHN":
        # SVHN training uses the crop only, without the horizontal flip.
        train_transform = transforms.Compose([
            transforms.RandomCrop(args.img_size, padding=4),
            transforms.ToTensor()
        ])
        train_set = SVHN(root=args.data_path, split="train", download=True, transform=train_transform)
        test_set = SVHN(root=args.data_path, split="test", download=False, transform=test_transform)
    else:  # "Tiny-ImageNet"
        # Prefer the copy under args.data_path; fall back to the previously
        # hard-coded location so existing setups keep working.
        tiny_root = os.path.join(args.data_path, "tiny-imagenet-200")
        if not os.path.isdir(tiny_root):
            tiny_root = "/content/tiny-imagenet-200"
        train_set = ImageFolder(root=os.path.join(tiny_root, "train"), transform=train_transform)
        test_set = ImageFolder(root=os.path.join(tiny_root, "val"), transform=test_transform)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=2, worker_init_fn=worker_init, generator=g)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=2, worker_init_fn=worker_init, generator=g)
    return train_loader, test_loader


@torch.no_grad()
def test(epoch, model, dataloader, attacker):
    """Evaluate clean and adversarial accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode. For every batch the attacker's output
    is added to the images and accuracy is measured on both the original and
    the perturbed inputs. Accuracies are weighted by batch size so that a
    smaller final batch does not skew the result.

    Args:
        epoch: Unused by this function.
        model: Network to evaluate; its device is taken from its first parameter.
        dataloader: Iterable of ``(images, labels)`` batches that supports ``len``.
        attacker: Callable ``attacker(model, images, labels)`` returning a
            perturbation that is added to ``images``.

    Returns:
        Tuple ``(clean_accuracy, adversarial_accuracy)``; ``(0.0, 0.0)`` when
        the dataloader yields no samples.

    Raises:
        TypeError: If ``attacker`` is not callable.
    """
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if not callable(attacker):
        raise TypeError(f"attacker must be callable, got {type(attacker).__name__}")
    device = next(model.parameters()).device
    model.eval()
    # Callables such as functools.partial objects carry no __name__.
    name = getattr(attacker, "__name__", type(attacker).__name__)

    correct_ori, correct_adv = 0.0, 0.0
    seen = 0
    with tqdm(dataloader, total=len(dataloader), desc=f"Test {name}") as t:
        for images, labels in t:
            images, labels = images.to(device), labels.to(device)
            images_adv = images + attacker(model, images, labels)

            batch = labels.size(0)
            correct_ori += logits_accuracy(model(images), labels) * batch
            correct_adv += logits_accuracy(model(images_adv), labels) * batch
            seen += batch

            t.set_postfix({"ORI": f"{correct_ori/seen:4.2%}", "ADV": f"{correct_adv/seen:4.2%}"})

    if seen == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0, 0.0
    return correct_ori/seen, correct_adv/seen


def parse_args(prefix, addition={}):
    parser = argparse.ArgumentParser(description="Adversarial Traing Parser")
    parser.add_argument("--suffix", required=True, type=str, help="File suffix")
    # Basic
    parser.add_argument("--cuda", action="store_true", default=False)
    parser.add_argument("--root", default='./', type=str)
    parser.add_argument("--folder", default='result', type=str)
    parser.add_argument("--log", choices=["debug", "info"], default="debug", help='Log print level.')
    parser.add_argument("--seed", default=0, type=int, help="random seed")
    # Model
    parser.add_argument("-m", "--model", default="ResNet18", choices=["PreActResNet18", "ResNet18", "WideResNet-28-10"], type=str, help="Backbone Network {Res-18, Res-34, Res-50}.")
    parser.add_argument("--num_class", default=10, type=int, help="Number of image classes")
    # Dataset
    parser.add_argument("-d", "--dataset", default="CIFAR10", choices=["CIFAR10", "CIFAR100", "SVHN", "Tiny-ImageNet"], type=str, help="Dataset.")   # "CIFAR10"
    parser.add_argument("--normalize", action="store_true", default=False)
    # Optimizer & scheduler
    parser.add_argument("--optimizer", default="SGD", type=str)
    parser.add_argument("--lr", default=0.1, type=float, help="Learning rate")
    parser.add_argument("--momentum", default=0.9, type=float)
    parser.add_argument("--weight_decay", default=5e-4, type=float)
    parser.add_argument("--gamma", default=0.1, type=float)
    parser.add_argument("--milestones", default=[40, 45], type=int, nargs='+')
    # Hyper Parameter
    parser.add_argument("--epoch_num", default=50, type=int, help="Number of Epochs")
    parser.add_argument("--start_epoch", default=1, type=int, help="Train start from # epochs")
    parser.add_argument("--batch_size", default=128, type=int, help="Batch size")
    # Adversarial Attack
    parser.add_argument("--epsilon", default=8, type=float, help="Perturbation radius")
    parser.add_argument("--steps", default=10, type=int, help="Perturb steps")
    parser.add_argument("--step_size", default=2, type=float, help="Perturb step size")  

    for k, v in addition.items():
        parser.add_argument(k, **v)

    args = parser.parse_args(sys.argv[1:])

    args.epsilon = args.epsilon/255
    args.step_size = args.step_size/255

    args.result_path = os.path.join(args.root, args.folder, f"{args.dataset}_{args.model}", f"{prefix}_{args.suffix}")
    if not os.path.exists(args.result_path):
        os.makedirs(args.result_path)

    if "Tiny-ImageNet" in args.dataset:
        args.img_size = 64
    elif "CIFAR" in args.dataset:
        args.img_size = 32
    else:
        args.img_size = 32
    args.data_path = os.path.join(args.root, "data")

    return args

@torch.enable_grad()
def pgd(model, images, labels, steps, step_size, epsilon, restarts):
    """Run a signed-gradient PGD attack with restarts and return the perturbation.

    The model is put into eval mode. Each restart attacks only the samples not
    yet misclassified: the first restart starts from a zero perturbation, later
    ones from a uniform draw in ``[-epsilon, epsilon]``. After every step the
    perturbation is clamped element-wise to ``[-epsilon, epsilon]`` and then so
    that ``image + delta`` stays in ``[0, 1]``.

    Args:
        model: Network under attack.
        images: Batch of input images.
        labels: Target class indices for ``images``.
        steps: Number of gradient steps per restart.
        step_size: Magnitude of each signed-gradient step.
        epsilon: Element-wise bound on the perturbation.
        restarts: Maximum number of restarts.

    Returns:
        Tensor shaped like ``images`` holding, for each sample that became
        misclassified, the perturbation that achieved it; samples that stayed
        correctly classified keep a zero perturbation.
    """
    model.eval()
    best_delta = torch.zeros_like(images)
    # True while a sample has not yet been misclassified by any restart.
    still_robust = torch.ones_like(labels, dtype=torch.bool)
    active_images = images[still_robust]

    for restart in range(restarts):
        if restart == 0:
            delta = torch.zeros_like(active_images)
        else:
            # Random start inside the epsilon box, kept within the image range.
            delta = torch.zeros_like(active_images).uniform_(-epsilon, epsilon)
            delta = torch.clamp(active_images + delta, 0.0, 1.0) - active_images

        active_labels = labels[still_robust]
        for _ in range(steps):
            delta.requires_grad_(True)
            loss = F.cross_entropy(model(active_images + delta), active_labels)
            (grad,) = torch.autograd.grad(loss, [delta])

            stepped = delta.detach() + step_size * grad.detach().sign()
            stepped = stepped.clamp(-epsilon, epsilon)
            delta = (active_images + stepped).clamp(0.0, 1.0) - active_images

        with torch.no_grad():
            fooled = model(active_images + delta).argmax(dim=-1) != active_labels
            # Map positions within the active subset back to batch indices.
            active_idx = still_robust.nonzero(as_tuple=False).flatten()
            fooled_idx = active_idx[fooled]
            best_delta[fooled_idx] = delta[fooled]
            still_robust[fooled_idx] = False

        if not still_robust.any():
            break
        active_images = images[still_robust]

    return best_delta

def FGSM(model, images, labels):
    """Single-step attack: one ``pgd`` step of size 8/255 with bound 8/255, one restart."""
    eps = 8/255
    return pgd(
        model, images, labels,
        steps=1, step_size=eps, epsilon=eps, restarts=1,
    )
def PGD10(model, images, labels):
    """``pgd`` with 10 steps of size 2/255, bound 8/255 and a single restart."""
    config = {"steps": 10, "step_size": 2/255, "epsilon": 8/255, "restarts": 1}
    return pgd(model, images, labels, **config)
def PGD_20_10(model, images, labels):
    """``pgd`` with 20 steps of size 1/255, bound 8/255 and up to 10 restarts."""
    config = {"steps": 20, "step_size": 1/255, "epsilon": 8/255, "restarts": 10}
    return pgd(model, images, labels, **config)
def PGD_50_10(model, images, labels):
    """``pgd`` with 50 steps of size 5/2550, bound 8/255 and up to 10 restarts."""
    step = 5/2550  # numerically 0.5/255
    config = {"steps": 50, "step_size": step, "epsilon": 8/255, "restarts": 10}
    return pgd(model, images, labels, **config)

