# Load test dataset D (WITHOUT SHUFFLING!)
# Requires an attack method (selected with --attack)
# For each model M:
# 	For i \in D:
# 		xp_i = Perturb(M, x_i, y_i)
# 		nat_loss[i] = loss(M, x_i, y_i)
# 		rob_loss[i] = loss(M, xp_i, y_i)
# 		nat_acc[i] = M(x_i) == y_i			// binary
# 		rob_acc[i] = M(xp_i) == y_i			// binary

# 	save_dict[model] = {
# 						"nat loss" : nat_loss
# 						"rob loss" : rob_loss
#                       "nat acc" : nat_acc
#                       "rob acc" : rob_acc
# 					}
    
#   filename = f"{model_with_param}_{dataset}_{attack}_test_stats.pkl"
#   with open(filename, 'wb') as fp:
#       pickle.dump(save_dict, fp)

# The code below implements this plan incrementally, one test batch at a time.

import pickle
import numpy as np
import torch
import os
import argparse
from models.defense.nn_cifar10 import NN_CIFAR10
from models.defense.nn_cifar100 import NN_CIFAR100
from models.defense.nn_mnist import NN_MNIST
from models.attacks.AdvGAN import AdvGAN
from models.defense.resnet import ResNet18
from models.defense.baselines.TRADES.models.resnet import *
from copy import deepcopy
import time

from torch.utils.data import Dataset, DataLoader, Subset
from data.pytorch_datasets import get_dataset
import torch.nn.functional as F

from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent

parser = argparse.ArgumentParser(
    description='Evaluate per-example natural and robust loss/accuracy of trained '
                'defenses on a fixed (unshuffled) test set and pickle the results.')
parser.add_argument('-d', '--data', default='cifar10', type=str,
                    help="dataset: 'cifar10', 'cifar100' or 'fmnist'")
parser.add_argument('-gpu', default=0, type=int, help='gpu:id to be used')
parser.add_argument('-m', '--methods', default="all", type=str, nargs='+',
                    help="one or more defense names (sub-directories of the "
                         "results directory), or 'all'")
parser.add_argument('-a', '--attack', default='pgd', type=str,
                    help="attack: 'pgd', 'auto-att', 'square', 'mifgsm' or 'advgan'")
# eps, step size and number of steps are not command-line options: they are
# fixed per dataset after parsing.
# Debug output is ON by default. The switch uses ``store_false``, so passing it
# turns debug output OFF. ``--no-debug`` is the explicitly-named spelling;
# ``--debug`` is kept as an alias for backward compatibility.
parser.add_argument('--debug', '--no-debug', dest='debug', action='store_false',
                    help='disable debug logging (debug output is on by default)')
args = parser.parse_args()

# Attack libraries are optional dependencies: import them only when the chosen
# attack actually needs them.
if args.attack in ['square', 'mifgsm']:
    import torchattacks
elif args.attack in ['auto-att']:
    from autoattack import AutoAttack

# Attack hyper-parameters are fixed per dataset. The dataset name is
# lower-cased once here because every later comparison on ``args.data``
# (helper functions, file paths) is case-sensitive.
args.data = args.data.lower()

if args.data in ('cifar10', 'cifar100'):
    args.eps = 0.031
    args.num_steps = 20
    args.step_size = 0.007
    # NOTE(review): 0.4465 is also the CIFAR-10 blue-channel *mean* used in
    # normalize(); get_perturbed divides eps by this value for normalized
    # inputs, where dividing by the std would be the usual rescaling. Confirm.
    args.norm_eps = 0.4465
    image_nc = 3
    mi_eps = 0.2
    mi_steps = 3 * args.num_steps
elif args.data == 'fmnist':
    args.eps = 0.3
    args.num_steps = 40
    args.step_size = 0.01
    # NOTE(review): 0.2860 is also the FMNIST *mean* used in normalize() —
    # same concern as above; confirm.
    args.norm_eps = 0.2860
    image_nc = 1
    mi_eps = 0.305
    mi_steps = 2 * args.num_steps
else:
    raise NotImplementedError

gen_input_nc = image_nc
# MI-FGSM reuses the PGD step size on every dataset.
mi_step_size = args.step_size


# ``--methods`` uses nargs='+', so an explicit ``-m all`` arrives as ['all'];
# only the untouched default is the bare string "all". Accept both forms.
if args.methods == "all" or args.methods == ["all"]:
    args.methods = ["GAT", "FBF", "TRADES", "Nu_AT_4.5", "MART", "PGD_AT", "RFGSM-AT_0.05", "Ours_0.5", "Ours_AdvGAN"]

# Use the GPU chosen with -gpu when CUDA is present, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Newly created float tensors (including module parameters) default to CUDA.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.set_device(args.gpu)
    print('Using Device: ', torch.cuda.get_device_name())
print('Using device:', device)

def DEBUG(log):
    """Print ``log`` only when debug output is enabled via ``args.debug``."""
    if not args.debug:
        return
    print(log)

# Load the model
def load_ours(model_path, map_location=None):
    """Load a serialized model object from ``model_path`` and return its state dict.

    Args:
        model_path: checkpoint file; the object stored in it must expose
            ``state_dict()`` (presumably a whole pickled ``nn.Module``).
        map_location: where to place the loaded tensors. By default this is the
            current CUDA device when CUDA is available, otherwise the CPU.

    Returns:
        The loaded object's ``state_dict()``.
    """
    if map_location is None:
        # 'cuda' means the *current* device (the one selected with
        # torch.cuda.set_device), not a hard-coded cuda:0, so this works with
        # any -gpu id and on CPU-only hosts.
        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(model_path, map_location=map_location)
    return checkpoint.state_dict()

def get_state_dict(model_path, map_location=None):
    """Load and return the object saved at ``model_path`` (typically a state dict).

    Args:
        model_path: checkpoint file written with ``torch.save``.
        map_location: where to place the loaded tensors. By default this is the
            current CUDA device when CUDA is available, otherwise the CPU.

    Returns:
        Whatever ``torch.load`` returns for the file: a state dict, or a dict
        wrapping one under ``'state_dict'``, depending on the checkpoint.
    """
    print('fetching from ', model_path)
    if map_location is None:
        # Avoid a hard-coded 'cuda:0': it crashes on CPU-only hosts and ignores
        # the device selected with torch.cuda.set_device.
        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.load(model_path, map_location=map_location)

# Per-dataset (mean, std) used to standardize inputs. A 3-tuple is applied per
# channel, a 1-tuple to every element.
_NORMALIZATION_STATS = {
    'cifar10': ((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
    'fmnist': ((0.2860,), (0.3530,)),
    'cifar100': ((0.50707525, 0.48654878, 0.44091785),
                 (0.20089656, 0.19844316, 0.2022971)),
}


def normalize(X, data):
    """Return ``(X - mean) / std`` using the statistics of dataset ``data``.

    Args:
        X: input tensor. Multi-channel statistics are viewed as (C, 1, 1) and
            single-channel ones as (1,), so ``X`` is presumably (N, C, H, W).
        data: 'cifar10', 'cifar100' or 'fmnist'.

    Returns:
        The standardized tensor, on the same device as ``X``.

    Raises:
        NotImplementedError: if ``data`` has no registered statistics.
    """
    try:
        mean, std = _NORMALIZATION_STATS[data]
    except KeyError:
        raise NotImplementedError from None
    # Keep the original broadcasting shapes: (C, 1, 1) per channel, or (1,).
    shape = (-1, 1, 1) if len(mean) > 1 else (-1,)
    # Build the statistics on X's device instead of hard-coding .cuda(), so
    # CPU tensors work too. dtype is left to type promotion as before.
    mu = torch.tensor(mean, device=X.device).view(shape)
    sigma = torch.tensor(std, device=X.device).view(shape)
    return (X - mu) / sigma

def _get_advgan_args(args):
    """Return ``(generator_path, advgan_args)`` for the AdvGAN attack on ``args.data``.

    ``advgan_args`` is a deep copy of ``args`` extended with the AdvGAN
    hyper-parameters, so the caller's namespace is never mutated.

    Raises:
        NotImplementedError: if no AdvGAN generator is configured for the dataset.
    """
    if args.data == 'cifar10':
        advgan_path = 'saved_models/cifar10/AdvGAN/AdvGANnetG_epoch_60.pth'
        num_channels, adv_epsilon = 3, 0.031
    elif args.data == 'fmnist':
        advgan_path = 'saved_models/fmnist/2022-09-20 01:25/AdvGANnetG_epoch_60.pth'
        num_channels, adv_epsilon = 1, 0.3
    else:
        raise NotImplementedError
    advgan_args = deepcopy(args)
    advgan_args.num_labels = 10
    advgan_args.num_channels = num_channels
    advgan_args.min_clamp = 0
    advgan_args.max_clamp = 1
    advgan_args.adv_lr = 0.001
    advgan_args.adv_epsilon = adv_epsilon
    advgan_args.adv_gan_batch_size = 128
    advgan_args.adv_gan_train_epochs = 60
    advgan_args.adv_gan_retrain_timestep = 1
    advgan_args.adv_gan_retraining_epochs = 15
    return advgan_path, advgan_args


def get_perturbed(attack, x, y, model, base_model, args, use_normalize=False):
    """Return an adversarially perturbed copy of the batch ``x`` (labels ``y``).

    Args:
        attack: 'pgd', 'auto-att', 'square', 'mifgsm' or 'advgan'.
        x, y: input batch and its labels.
        model: network under attack.
        base_model: for 'mifgsm' the perturbation is crafted on this network
            instead of ``model``.
        args: parsed options. Reads ``eps``, ``step_size``, ``num_steps``,
            ``norm_eps`` and ``data``.
        use_normalize: if True, ``x`` is passed through ``normalize()`` before
            the attack and the PGD radius is divided by ``args.norm_eps``. The
            returned batch is then in normalized space.

    Returns:
        The perturbed batch.

    Raises:
        NotImplementedError: for an unknown ``attack``, or 'advgan' on a dataset
            without a configured generator.
    """
    eps = args.eps
    if use_normalize:
        DEBUG("normalizing!")
        x = normalize(x, args.data)
        # NOTE(review): only the PGD radius is rescaled. args.step_size and the
        # eps given to the other attacks stay in unnormalized units — confirm.
        eps = eps / args.norm_eps

    if attack == 'pgd':
        # y is not passed, so cleverhans chooses the labels itself (presumably
        # the model's own predictions) — confirm this is intended.
        x_adv = projected_gradient_descent(model, x, eps, args.step_size, args.num_steps, np.inf)
    elif attack == "auto-att":
        adversary = AutoAttack(model, norm='Linf', eps=args.eps, version='standard', verbose=True)
        x_adv = adversary.run_standard_evaluation(x, y, bs=256)
    elif attack == "square":
        adversary = torchattacks.Square(model, norm='Linf', n_queries=1000, n_restarts=1, eps=args.eps,
                                        p_init=.8, seed=0, verbose=True, loss='margin', resc_schedule=True)
        x_adv = adversary(x, y)
    elif attack == "mifgsm":
        # Crafted on base_model, not on the model being evaluated.
        adversary = torchattacks.MIFGSM(base_model, eps=mi_eps, alpha=mi_step_size, steps=mi_steps, decay=1.0)
        x_adv = adversary(x, y)
    elif attack == "advgan":
        # Dispatch on the ``attack`` argument (not args.attack) like every
        # other branch, and configure AdvGAN on a copy of ``args``.
        advgan_path, advgan_args = _get_advgan_args(args)
        # NOTE: the generator is rebuilt and reloaded on every call (every batch).
        adversary = AdvGAN(None, advgan_args)
        sd = torch.load(advgan_path)
        adversary.netG.load_state_dict(sd)
        x_adv = adversary.get_perturbed(x)
    else:
        raise NotImplementedError
    return x_adv

def get_model(data):
    """Return a freshly constructed (untrained) wrapper network for ``data``.

    Raises:
        NotImplementedError: unless ``data`` is 'cifar10', 'fmnist' or 'cifar100'.
    """
    if data not in ('cifar10', 'fmnist', 'cifar100'):
        raise NotImplementedError
    if data == 'fmnist':
        return NN_MNIST()
    return NN_CIFAR10() if data == 'cifar10' else NN_CIFAR100()

def get_base_model(data):
    """Build the base network for ``data`` and load its checkpoint.

    In this script the result is used as the network the 'mifgsm' attack
    crafts perturbations on.

    Raises:
        NotImplementedError: unless ``data`` is 'cifar10', 'fmnist' or
            'cifar100' (consistent with the other helpers, instead of
            printing and calling the interactive-only ``exit()``).
    """
    # Load onto the current CUDA device when available, otherwise CPU, so
    # GPU-saved checkpoints also open on CPU-only hosts.
    map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    if data == 'cifar10':
        base_model = ResNet18()
        base_model_path = 'models/defense/CIFAR10_models_LATEST/model-120-checkpoint'
        checkpoint = torch.load(base_model_path, map_location=map_location)
        base_model.load_state_dict(checkpoint['state_dict'])
    elif data == 'fmnist':
        base_model = NN_MNIST()
        base_model_path = 'base_fmnist_model_wd0/model-100-checkpoint'
        sd = torch.load(base_model_path, map_location=map_location)['state_dict']
        # The checkpoint keys lack the 'model.' prefix that NN_MNIST's own
        # parameter names carry, so add it before loading.
        base_model.load_state_dict({'model.' + str(key): value for key, value in sd.items()})
    elif data == 'cifar100':
        base_model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100)
        base_model_path = 'CIFAR100_resnet18_models/model-120-checkpoint'
        checkpoint = torch.load(base_model_path, map_location=map_location)
        base_model.load_state_dict(checkpoint['state_dict'])
    else:
        raise NotImplementedError('Dataset not available!')
    # The base model is only used for inference (attack crafting).
    base_model.eval()
    return base_model


results_dir = f'../../results/{args.data}'

# The test split is pickled once and reused, so every defense is evaluated on
# the same examples in the same order (the DataLoader below never shuffles).
# Only load caches this script wrote itself: pickle is unsafe on untrusted files.
fixed_ds_path = f"{'cifar' if args.data=='cifar10' else args.data}_fixed_testds.pkl"
if os.path.exists(fixed_ds_path):
    print("Loading fixed test dataset...")
    with open(fixed_ds_path, 'rb') as fp:
        test_ds = pickle.load(fp)
else:
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    print("CAUTION: CREATING DATASET")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    _, test_ds = get_dataset(args)[:2]
    with open(fixed_ds_path, 'wb') as fp:
        pickle.dump(test_ds, fp)

test_dl = DataLoader(test_ds, batch_size=512, shuffle=False)

base_model = get_base_model(args.data)

def _load_defense_model(defense):
    """Instantiate the network for ``defense`` and load its trained weights.

    Weights come from ``<results_dir>/<defense>/model-final``, except PGD_AT on
    CIFAR-100, which uses a separate ResNet-18 checkpoint. The loading branch
    is selected by substring match on ``defense``.
    """
    model = get_model(args.data).to(device)
    model_path = os.path.join(results_dir, defense, 'model-final')
    print('Evaluating ', defense)
    if "Ours" in defense:
        model.model.load_state_dict(load_ours(model_path))
    elif "Nu_AT" in defense:
        sd = get_state_dict(model_path)
        # Try the inner network first; fall back to the wrapper when the inner
        # network is missing (AttributeError) or rejects the keys (RuntimeError).
        try:
            model.model.load_state_dict(sd['state_dict'])
        except (AttributeError, RuntimeError):
            model.load_state_dict(sd['state_dict'])
    elif "MART" in defense and args.data != 'cifar100':
        model.load_state_dict(get_state_dict(model_path))
    elif "clean_model" in defense:
        sd = get_state_dict(model_path)['state_dict']
        # Checkpoint keys lack the wrapper's 'model.' prefix.
        model.load_state_dict({'model.' + str(key): value for key, value in sd.items()})
    elif "PGD_AT" in defense and args.data == 'cifar100':
        model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100).to(device)
        model.eval()
        model_path = 'saved_models/baselines/cifar100/PGD_AT/PGD-AT-model-epoch75.pt'
        # NOTE(review): strict=False silently ignores missing/unexpected keys —
        # confirm this checkpoint's keys really match the ResNet.
        model.load_state_dict(get_state_dict(model_path), strict=False)
    else:
        sd = get_state_dict(model_path)
        if args.data in ('cifar10', 'cifar100'):
            model.model.load_state_dict(sd)
        else:
            model.load_state_dict(sd)
    return model


def _evaluate_defense(model, defense):
    """Attack every test batch and collect per-example losses and correctness.

    Returns:
        dict mapping "robust loss vector", "robust acc vector",
        "natural acc vector" and "natural loss vector" to float64 NumPy arrays
        with one entry per test example, in DataLoader order. Accuracy entries
        are 1.0 for a correct prediction and 0.0 otherwise.
    """
    # FBF defenses are evaluated on normalized inputs (both attack and clean pass).
    use_normalize = ("FBF" in defense)
    rob_loss_chunks, rob_acc_chunks = [], []
    nat_loss_chunks, nat_acc_chunks = [], []

    for i, (x, y) in enumerate(test_dl):
        start = time.time()
        # Cheap; guards against an attack leaving the model in train mode.
        model.eval()
        x = x.to(device)
        y = y.to(device)

        # Step 1: craft adversarial examples (needs gradients, so not under no_grad).
        x_adv = get_perturbed(args.attack, x, y, model, base_model, args, use_normalize=use_normalize)
        # Discard parameter gradients the attack may have accumulated.
        model.zero_grad()

        # The remaining forward passes only need outputs: skip autograd graphs.
        with torch.no_grad():
            # Step 2: robust predictions, correctness and per-example losses.
            rob_outputs = model(x_adv)
            _, predicted = torch.max(rob_outputs, 1)
            rob_acc_chunks.append((predicted == y).cpu().numpy())
            rob_loss_chunks.append(F.cross_entropy(rob_outputs, y, reduction="none").cpu().numpy())

            # Step 3: the same on the clean inputs.
            if use_normalize:
                DEBUG("normalizing! for clean instances")
                x = normalize(x, args.data)
            clean_outputs = model(x)
            _, predicted = torch.max(clean_outputs, 1)
            nat_acc_chunks.append((predicted == y).cpu().numpy())
            nat_loss_chunks.append(F.cross_entropy(clean_outputs, y, reduction="none").cpu().numpy())

        end = time.time()
        if i % 5 == 0:
            DEBUG(f"Time taken for batch {i+1} is {end-start}")

    def _stack(chunks):
        # One concatenation at the end (instead of one per batch); stored as
        # float64, so boolean correctness becomes 1.0 / 0.0.
        return np.concatenate(chunks).astype(np.float64) if chunks else np.empty(0)

    return {
        "robust loss vector": _stack(rob_loss_chunks),
        "robust acc vector": _stack(rob_acc_chunks),
        "natural acc vector": _stack(nat_acc_chunks),
        "natural loss vector": _stack(nat_loss_chunks),
    }


def _save_stats(save_dict, defense):
    """Pickle ``save_dict`` to test_stats_pkl/<defense>_<data>_<attack>_test_stats.pkl."""
    DEBUG("Serializing!")
    folder = 'test_stats_pkl'
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, f"{defense}_{args.data}_{args.attack}_test_stats.pkl")
    with open(filepath, 'wb') as fp:
        pickle.dump(save_dict, fp)
    print("saved pickle successfully!")


# For each defense: load it, compute per-example natural/robust statistics on
# the fixed test set, and save them.
for defense in args.methods:
    model = _load_defense_model(defense)
    save_dict = _evaluate_defense(model, defense)
    _save_stats(save_dict, defense)
