import pickle
import numpy as np
import torch
torch.backends.cuda.matmul.allow_tf32 = False
import os
import argparse
from models.defense.nn_cifar10 import NN_CIFAR10
from models.defense.nn_cifar100 import NN_CIFAR100
from models.defense.nn_mnist import NN_MNIST
from models.defense.resnet import ResNet18
from models.defense.baselines.TRADES.models.resnet import *
from copy import deepcopy
import time

from torch.utils.data import Dataset, DataLoader, Subset
from data.pytorch_datasets import get_dataset
import torch.nn.functional as F

np.random.seed(0)
torch.manual_seed(0)

from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent

# Command-line interface. The attack budget (eps / step size / number of
# steps) is not exposed as flags; it is fixed per dataset further down.
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', default='cifar10', type=str,
                    help='dataset name: cifar10, cifar100 or fmnist')
parser.add_argument('-gpu', default=0, type=int, help='gpu:id to be used')
parser.add_argument('-m', '--methods', default="all", type=str, nargs='+',
                    help='one or more defense names to evaluate, or "all"')
# Restricting the choices makes a mistyped attack name fail at parse time
# instead of surfacing only after models and data have been loaded.
parser.add_argument('-a', '--attack', default='pgd', type=str,
                    choices=['pgd', 'auto-att', 'square', 'mifgsm'],
                    help='attack used to craft the perturbed inputs')
# NOTE: this flag is store_false, so debug logging is ON by default and
# passing --debug turns it OFF.
parser.add_argument('--debug', action='store_false',
                    help='pass this flag to DISABLE debug logging (on by default)')
parser.add_argument('--val_size', type=int, default=512,
                    help='batch size of the validation DataLoader')
args = parser.parse_args()

if args.attack in ['square', 'mifgsm', 'pgd']:
    import torchattacks
elif args.attack in ['auto-att']:
    from autoattack import AutoAttack


# Canonicalise the dataset name once. The rest of the script compares
# `args.data` against lower-case literals without calling .lower(), so a
# mixed-case name used to pass this block and then fail later on.
args.data = args.data.lower()

# Per-dataset attack budget.
#   args.eps / args.step_size / args.num_steps : budget passed to the attacks
#   args.norm_eps : divisor applied to eps when inputs are normalised
#   mi_*          : separate budget used by the MI-FGSM attack
# NOTE(review): norm_eps equals the dataset *mean* value listed in
# normalize() (0.4465 / 0.2860), not a std -- confirm this is the intended
# divisor.
if args.data in ('cifar10', 'cifar100'):
    args.eps = 0.031
    args.num_steps = 20
    args.step_size = 0.007
    args.norm_eps = 0.4465
    image_nc = 3
    gen_input_nc = image_nc
    mi_eps = 0.2
    mi_steps = 3 * args.num_steps
    mi_step_size = args.step_size
elif args.data == 'fmnist':
    args.eps = 0.3
    args.num_steps = 40
    args.step_size = 0.01
    args.norm_eps = 0.2860
    image_nc = 1
    gen_input_nc = image_nc
    mi_eps = 0.305
    mi_steps = 2 * args.num_steps
    mi_step_size = args.step_size
else:
    raise NotImplementedError


# Expand the "all" shortcut. The argparse default is the bare string "all",
# but an explicit `-m all` arrives as the list ["all"] because the option
# uses nargs='+', so both spellings have to be accepted.
if args.methods == "all" or args.methods == ["all"]:
    args.methods = ["GAT", "FBF", "TRADES", "Nu_AT_4.5", "MART", "PGD_AT", "RFGSM-AT_0.05", "Ours_0.5", "Ours_AdvGAN"]

# Pick the compute device: CUDA when present, CPU otherwise.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
if _cuda_ok:
    # Newly created tensors default to CUDA floats from here on.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    # Make the GPU requested on the command line the current device.
    torch.cuda.set_device(args.gpu)
    print('Using Device: ', torch.cuda.get_device_name())
print('Using device:', device)

def DEBUG(log):
    """Print *log* only while ``args.debug`` is truthy; otherwise do nothing."""
    if not args.debug:
        return
    print(log)

# Load the model
def load_ours(model_path):
    """Load a checkpoint saved as a whole model object and return its state dict.

    Args:
        model_path: path of the serialized object; it must expose
            ``state_dict()``.

    Returns:
        The value of ``state_dict()`` called on the loaded object.
    """
    # Map onto the device selected at start-up rather than a hard-coded
    # 'cuda:0', so the -gpu choice and the CPU fallback are respected.
    loaded = torch.load(model_path, map_location=device)
    return loaded.state_dict()

def get_state_dict(model_path):
    """Load and return whatever object is stored at *model_path*.

    Callers treat the result either as a state dict directly or as a dict
    holding one under the ``'state_dict'`` key.

    Args:
        model_path: path of the checkpoint file.

    Returns:
        The deserialized checkpoint, with tensors mapped onto ``device``.
    """
    print('fetching from ', model_path)
    # Map onto the device selected at start-up rather than a hard-coded
    # 'cuda:0', so the -gpu choice and the CPU fallback are respected.
    return torch.load(model_path, map_location=device)

def normalize(X, data):
    """Standardise an image tensor with fixed per-dataset mean and std.

    Args:
        X: image tensor. For ``'cifar10'`` the constants are shaped
            ``(3, 1, 1)``, so X needs a channel dimension of size 3 in
            third-from-last position; for ``'fmnist'`` a single scalar
            mean/std is broadcast over X.
        data: dataset name, ``'cifar10'`` or ``'fmnist'``.

    Returns:
        ``(X - mean) / std`` on the same device as X.

    Raises:
        NotImplementedError: for any other dataset name (including
            ``'cifar100'``, for which no constants are defined here).
    """
    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.2471, 0.2435, 0.2616)
    fmnist_mean = 0.2860
    fmnist_std = 0.3530
    # The constants are created on X's own device instead of via .cuda(),
    # so the function also works on CPU and on a non-default GPU.
    if data == 'cifar10':
        mu = torch.tensor(cifar10_mean, device=X.device).view(3, 1, 1)
        std = torch.tensor(cifar10_std, device=X.device).view(3, 1, 1)
    elif data == 'fmnist':
        mu = torch.tensor([fmnist_mean], device=X.device)
        std = torch.tensor([fmnist_std], device=X.device)
    else:
        raise NotImplementedError
    return (X - mu) / std

def get_perturbed(attack, x, y, model_copy, base_model, args, use_normalize=False):
    """Craft adversarial versions of the batch *x* with the named attack.

    Args:
        attack: one of ``'pgd'``, ``'auto-att'``, ``'square'``, ``'mifgsm'``.
        x: input batch; replaced by ``normalize(x, args.data)`` when
            *use_normalize* is set.
        y: labels, passed to every attack except PGD.
        model_copy: network attacked by PGD, AutoAttack and Square.
        base_model: network attacked by MI-FGSM.
        args: namespace providing ``eps``, ``step_size``, ``num_steps``,
            ``norm_eps`` and ``data``.
        use_normalize: normalise *x* and divide the PGD budget by
            ``args.norm_eps`` before attacking.

    Returns:
        The adversarial batch produced by the chosen attack.

    Raises:
        NotImplementedError: if *attack* is not a known name.

    Note:
        Only the PGD branch uses the rescaled budget; AutoAttack and Square
        always use ``args.eps`` and MI-FGSM uses the module-level ``mi_*``
        settings.
    """
    budget = args.eps
    if use_normalize:
        DEBUG("normalizing!")
        x = normalize(x, args.data)
        budget = budget / args.norm_eps

    if attack == 'pgd':
        # No labels are passed to the cleverhans PGD call.
        return projected_gradient_descent(model_copy, x, budget, args.step_size, args.num_steps, np.inf)

    if attack == "auto-att":
        runner = AutoAttack(model_copy, norm='Linf', eps=args.eps, version='standard', verbose=True)
        return runner.run_standard_evaluation(x, y, bs=256)

    if attack == "square":
        square = torchattacks.Square(model_copy, norm='Linf', n_queries=1000, n_restarts=1, eps=args.eps, p_init=.8, seed=0, verbose=False, loss='margin', resc_schedule=True)
        return square(x, y)

    if attack == "mifgsm":
        mifgsm = torchattacks.MIFGSM(base_model, eps=mi_eps, alpha=mi_step_size, steps=mi_steps, decay=1.0)
        return mifgsm(x, y)

    raise NotImplementedError

def get_model(data):
    """Return a freshly constructed defense network for the dataset *data*.

    Supported names are ``'cifar10'``, ``'fmnist'`` and ``'cifar100'``;
    anything else raises ``NotImplementedError``.
    """
    if data == 'cifar10':
        return NN_CIFAR10()
    if data == 'fmnist':
        return NN_MNIST()
    if data == 'cifar100':
        return NN_CIFAR100()
    raise NotImplementedError

def get_base_model(data):
    """Build the base classifier for *data* and load its checkpoint weights.

    Args:
        data: dataset name, one of ``'cifar10'``, ``'fmnist'``, ``'cifar100'``.

    Returns:
        The constructed model with the checkpoint's ``'state_dict'`` loaded.

    Raises:
        NotImplementedError: if *data* is not a supported dataset. This
            replaces a print followed by ``exit()``, which terminated the
            script with a success status and differed from ``get_model``.
    """
    # Every torch.load passes map_location=device so loading does not depend
    # on the device the checkpoint was originally saved from.
    if data == 'cifar10':
        # NOTE(review): ResNet18 is imported explicitly and a later star
        # import could rebind it -- confirm which definition is intended.
        base_model = ResNet18()
        base_model_path = 'models/defense/CIFAR10_models_LATEST/model-120-checkpoint'
        sd = torch.load(base_model_path, map_location=device)
        base_model.load_state_dict(sd['state_dict'])
    elif data == 'fmnist':
        base_model = NN_MNIST()
        base_model_path = 'base_fmnist_model_wd0/model-100-checkpoint'
        sd = torch.load(base_model_path, map_location=device)['state_dict']
        # Every checkpoint key gets a 'model.' prefix before loading;
        # presumably NN_MNIST nests its layers under `.model` -- confirm.
        new_sd = {'model.' + str(key): value for key, value in sd.items()}
        base_model.load_state_dict(new_sd)
    elif data == 'cifar100':
        base_model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=100)
        base_model_path = 'CIFAR100_resnet18_models/model-120-checkpoint'
        sd = torch.load(base_model_path, map_location=device)
        base_model.load_state_dict(sd['state_dict'])
    else:
        raise NotImplementedError('Dataset not available!')
    return base_model


# Root directory holding one sub-folder of checkpoints per defense.
results_dir = f'../../results/{args.data}'

# The cifar10 split file carries no dataset suffix; every other dataset's
# file is named dataset_<name>_split.pkl.
split_suffix = "" if args.data == "cifar10" else "_" + args.data
split_path = f'dataset{split_suffix}_split.pkl'
with open(split_path, 'rb') as f:
    dat = pickle.load(f)

# use valset.dataset.transform = transforms.Compose([transforms.ToTensor()])  to correct the randomness
valset = dat["val_ds"]
print("Loaded validation set")
print('Valset size: ', len(valset))

# Fixed order (no shuffling); --val_size is the batch size.
val_dl = DataLoader(valset, batch_size=args.val_size, shuffle=False)

base_model = get_base_model(args.data)

# Evaluate every requested defense: load its weights, compute per-example
# clean and adversarial loss/correctness over the validation set, and pickle
# the four resulting vectors.
for defense in args.methods:
    save_dict = dict()

    # Two networks that receive the same weights: `model` scores the clean
    # inputs, `model_copy` is handed to the attack and scores its output.
    model = get_model(args.data).to(device)
    model_copy = get_model(args.data).to(device)
    nets = (model, model_copy)
    model_path = os.path.join(results_dir, defense, 'model-final')
    print('Evaluating ', defense)

    # Checkpoint layout differs per defense: some store a whole model object,
    # some a dict with a 'state_dict' entry, some a bare state dict; and the
    # weights belong either to the wrapper or to its inner `.model`.
    if "Ours" in defense:
        sd = load_ours(model_path)
        for net in nets:
            net.model.load_state_dict(sd)
    elif "Nu_AT" in defense:
        sd = get_state_dict(model_path)
        try:
            for net in nets:
                net.model.load_state_dict(sd['state_dict'])
        except (RuntimeError, AttributeError):
            # Key mismatch (RuntimeError) or no inner `.model` attribute:
            # retry against the wrapper itself. Narrowed from a bare
            # `except:` so unrelated failures and Ctrl-C are not swallowed.
            for net in nets:
                net.load_state_dict(sd['state_dict'])
    elif "MART" in defense and args.data != 'cifar100':
        sd = get_state_dict(model_path)
        for net in nets:
            net.load_state_dict(sd)
    else:
        sd = get_state_dict(model_path)
        if args.data == 'cifar10' or args.data == 'cifar100':
            for net in nets:
                net.model.load_state_dict(sd)
        else:
            for net in nets:
                net.load_state_dict(sd)

    # Per-example statistics accumulated over the whole validation set.
    rob_loss_vec = []  # losses on perturbed points
    rob_acc_vec = []   # correctness of predictions on perturbed points
    nat_loss_vec = []  # losses on clean points
    nat_acc_vec = []   # correctness of predictions on clean points

    # Only defenses whose name contains "FBF" get normalised inputs.
    use_normalize = ("FBF" in defense)

    for i, (x, y) in enumerate(val_dl):
        start = time.time()
        # Re-asserted every batch, as in the original loop.
        model.eval()
        model_copy.eval()
        x = x.to(device)
        y = y.to(device)

        # Clean pass. Pure scoring, so no autograd graph is needed.
        # NOTE(review): for FBF the clean pass sees raw `x` while the
        # adversarial pass below sees normalised inputs -- confirm intended.
        with torch.no_grad():
            clean_outputs = model(x)
            _, predicted = torch.max(clean_outputs, 1)
            clean_accuracy = (predicted == y).detach().cpu()
            clean_loss = F.cross_entropy(clean_outputs, y, reduction="none").detach().cpu()
        nat_acc_vec = np.concatenate((nat_acc_vec, clean_accuracy), axis=0)
        nat_loss_vec = np.concatenate((nat_loss_vec, clean_loss), axis=0)

        print('_batch_nat_acc_vec.sum() = ', nat_acc_vec.sum())

        # Step 1: perturb with the attack chosen on the command line. This
        # used to be hard-coded to 'pgd' although the output file below is
        # named after args.attack, which mislabelled results for any other
        # --attack value.
        x_adv = get_perturbed(args.attack, x.detach().clone(), y.detach().clone(), model_copy, base_model, args, use_normalize=use_normalize)

        # Step 2 and 3: score the perturbed points (predictions and losses).
        with torch.no_grad():
            rob_outputs = model_copy(x_adv)
            _, rob_predicted = torch.max(rob_outputs, 1)
            rob_accuracy = (rob_predicted == y).detach().cpu()
            rob_loss = F.cross_entropy(rob_outputs, y, reduction="none").detach().cpu()
        rob_acc_vec = np.concatenate((rob_acc_vec, rob_accuracy), axis=0)
        rob_loss_vec = np.concatenate((rob_loss_vec, rob_loss), axis=0)
        # Clear any parameter gradients the attack left on the attacked net.
        model_copy.zero_grad()

        end = time.time()
        if i % 5 == 0:
            DEBUG(f"Time taken for batch {i+1} is {end-start}")

    DEBUG(f"Serializing!")
    # Save the vectors of losses and correctness for this defense.
    save_dict["robust loss vector"] = rob_loss_vec
    save_dict["robust acc vector"] = rob_acc_vec
    save_dict["natural acc vector"] = nat_acc_vec
    save_dict["natural loss vector"] = nat_loss_vec

    folder = 'val_stats_pkl'
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, f"{defense}_{args.data}_{args.attack}_val_stats.pkl")
    with open(filepath, 'wb') as fp:
        pickle.dump(save_dict, fp)
    print("saved pickle successfully!")
