import pickle
import numpy as np
import torch
import os
import argparse
from models.defense.nn_cifar10 import NN_CIFAR10
from models.defense.nn_mnist import NN_MNIST
from copy import deepcopy
import time

from torch.utils.data import Dataset, DataLoader, Subset
from data.pytorch_datasets import get_dataset
import torch.nn.functional as F

from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent


parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', default='cifar10', type=str,
                    help="dataset to evaluate on: 'cifar10' or 'fmnist'")
parser.add_argument('-gpu', default=0, type=int, help='gpu:id to be used')
# The attack budget (eps), PGD step size and number of PGD steps are not CLI
# options: they are set per dataset further below (see args.eps etc.).
# NOTE: --debug uses action='store_false', so args.debug is True by default
# (debug output ON) and passing --debug turns debug output OFF.
parser.add_argument('--debug', action='store_false',
                    help='pass to DISABLE debug output (debug output is on by default)')
parser.add_argument('-a', '--attack', default='pgd', type=str,
                    help="attack to run: 'pgd', 'auto-att' or 'square'")
parser.add_argument('--val', action='store_true', help='to run on validation set')
parser.add_argument('-f', '--frac', type=float, default=0.1,
                    help='fraction of the dataset put into each sampled subset')

args = parser.parse_args()

# Optional attack libraries are only imported when the chosen attack needs them.
if args.attack in ['auto-att']:
    from autoattack import AutoAttack
elif args.attack in ['square']:
    import torchattacks

# Quantity of interest (P(i) is presumably the probability that point i is attacked):
#   ( \sum_i [ P(i) RobustLoss(i) + (1 - P(i)) CleanLoss(i) ] - E[RobustLoss(x, y)] ) / E[CleanLoss(x, y)]
def DEBUG(log):
    """Print ``log``, but only when debug output is enabled (``args.debug``)."""
    if not args.debug:
        return
    print(log)

def softmax(vec):
    """Return the softmax of a 1-D array as a probability vector summing to 1.

    The maximum entry is subtracted before exponentiating. This leaves the
    result mathematically unchanged but prevents ``np.exp`` from overflowing to
    ``inf`` for large inputs (e.g. large robust losses). It also prevents every
    term from underflowing to 0 for very negative inputs (e.g. ``-nat_loss``);
    either case would otherwise produce ``nan``.

    Args:
        vec: 1-D array-like of scores.

    Returns:
        A float64 numpy array of the same length; empty input gives an empty array.
    """
    vec = np.asarray(vec, dtype=np.float64)
    if vec.size == 0:
        # Nothing to normalise; also avoids calling .max() on an empty array.
        return vec
    exps = np.exp(vec - vec.max())
    return exps / exps.sum()

def get_perturbed(attack, x, y, model, args, use_normalize=False):
    """Return adversarially perturbed versions of the batch ``x``.

    Args:
        attack: attack name: 'pgd' (cleverhans L-inf PGD), 'auto-att'
            (AutoAttack, standard version, L-inf) or 'square'
            (torchattacks Square attack, L-inf).
        x: input batch.
        y: labels for ``x``; passed to 'auto-att' and 'square'. PGD is called
            without ``y`` (per the cleverhans docs it then uses the model's own
            predictions as labels).
        model: classifier being attacked.
        args: namespace providing ``eps``, ``step_size``, ``num_steps`` and
            ``data`` (plus ``norm_eps`` when ``use_normalize`` is True).
        use_normalize: if True, normalise ``x`` and divide the L-inf budget by
            ``args.norm_eps`` before attacking.

    Returns:
        The perturbed batch produced by the selected attack.

    Raises:
        NotImplementedError: if ``attack`` is not one of the supported names.
    """
    eps = args.eps
    if use_normalize:
        DEBUG("normalizing!")
        # NOTE(review): `normalize` and `args.norm_eps` are not defined in this
        # file as shown, so this branch fails as written; confirm their source.
        x = normalize(x, args.data)
        eps = eps/args.norm_eps
    if attack == 'pgd':
        # NOTE(review): no clip_min/clip_max is passed, so x_adv is not clipped
        # to the valid input range; confirm this is intended.
        x_adv = projected_gradient_descent(model, x, eps, args.step_size, args.num_steps, np.inf)
    elif attack == "auto-att":
        # Every attack uses the same (possibly rescaled) budget `eps`.
        adversary = AutoAttack(model, norm='Linf', eps=eps, version='standard', verbose=True)
        x_adv = adversary.run_standard_evaluation(x, y, bs=256)
    elif attack == "square":
        adversary = torchattacks.Square(model, norm='Linf', n_queries=1000, n_restarts=1, eps=eps, p_init=.8, seed=0, verbose=True, loss='margin', resc_schedule=True)
        x_adv = adversary(x, y)
    else:
        raise NotImplementedError(f"Unsupported attack: {attack!r}")
    return x_adv

def get_model(data):
    """Build a fresh model wrapper for the dataset named ``data``.

    Supported names are 'cifar10' (NN_CIFAR10) and 'fmnist' (NN_MNIST);
    any other name raises NotImplementedError.
    """
    # Lambdas defer construction (and name lookup) until a match is found.
    builders = {
        'cifar10': lambda: NN_CIFAR10(),
        'fmnist': lambda: NN_MNIST(),
    }
    build = builders.get(data)
    if build is None:
        raise NotImplementedError
    return build()

# Per-dataset settings: (L-inf eps, number of PGD steps, PGD step size,
# path of the clean model checkpoint).
_DATASET_CONFIGS = {
    'cifar10': (0.031, 20, 0.007, "models/defense/CIFAR10_models_LATEST/model-120-checkpoint"),
    'fmnist': (0.3, 40, 0.01, 'base_fmnist_model_wd0/model-100-checkpoint'),
}
if args.data not in _DATASET_CONFIGS:
    raise NotImplementedError
args.eps, args.num_steps, args.step_size, clean_model_path = _DATASET_CONFIGS[args.data]
    

# Load the clean model: read the checkpoint onto the CPU and copy its
# 'state_dict' into the architecture that matches args.data.
# NOTE(review): depending on the PyTorch version / weights_only setting,
# torch.load may unpickle arbitrary objects; only load trusted checkpoints.
with open(clean_model_path, 'rb') as fp:
    load_dict = torch.load(fp, map_location=torch.device('cpu'))
a = get_model(args.data)
a.model.load_state_dict(load_dict['state_dict'])
clean_model = deepcopy(a.model)

print("Done till here...")
if not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    # Make new float tensors default to CUDA, then select the requested GPU.
    # NOTE: torch.set_default_tensor_type is deprecated in recent PyTorch.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    device = torch.device('cuda')
    torch.cuda.set_device(args.gpu)
    print('Using Device: ', torch.cuda.get_device_name())
print('Using device:', device)

print("Moving to device!")
start = time.time()
clean_model = clean_model.to(device)
end = time.time()
DEBUG(f"Time taken to move to GPU is {end-start}")

start = time.time()
if args.val:
    # Validation split stored in a pickle alongside the train split.
    split_suffix = "" if args.data == "cifar10" else "_" + args.data
    with open(f'dataset{split_suffix}_split.pkl', 'rb') as f:
        dat = pickle.load(f)
    val_ds = dat["val_ds"]
    print("Loaded validation set")
    print('Valset size: ', len(val_ds))
    ds = val_ds
else:
    # Test set is cached to a pickle on first use so later runs see the same data.
    cache_prefix = 'cifar' if args.data == 'cifar10' else args.data
    test_ds_cache = f"{cache_prefix}_fixed_testds.pkl"
    if os.path.exists(test_ds_cache):
        print("Loading fixed test dataset...")
        with open(test_ds_cache, 'rb') as fp:
            test_ds = pickle.load(fp)
    else:
        test_ds = get_dataset(args)[1]
        with open(test_ds_cache, 'wb') as fp:
            pickle.dump(test_ds, fp)
    ds = test_ds

dl = DataLoader(ds, batch_size=512, shuffle=False)
end = time.time()
DEBUG(f"Time taken to load dataset to GPU is {end-start}")
# For every example in the dataset, compute four statistics on both the clean
# input and its adversarial perturbation: the per-example cross-entropy loss,
# whether the prediction is correct, the max softmax probability, and the
# predicted label.
#
# Per-batch results are collected in lists and concatenated once after the
# loop; concatenating onto a growing array every batch is quadratic overall.


def _batch_stats(outputs, labels):
    """Return (loss, correct, max_prob, pred) as CPU numpy arrays for one batch.

    ``outputs`` are treated as logits, since they are what F.cross_entropy
    consumes. They are passed through a softmax before taking the max, so
    max_prob is a probability in [0, 1]. The prediction is the argmax of the
    raw outputs.
    """
    loss = F.cross_entropy(outputs, labels, reduction="none")
    _, pred = torch.max(outputs, 1)
    max_prob = F.softmax(outputs, dim=1).max(dim=1).values
    correct = pred == labels
    return tuple(t.detach().cpu().numpy() for t in (loss, correct, max_prob, pred))


def _concat_chunks(chunks):
    """Concatenate per-batch arrays into one float64 vector (empty if no batches)."""
    if not chunks:
        return np.array([], dtype=np.float64)
    return np.concatenate(chunks).astype(np.float64)


# Per-batch chunks of (loss, correct, max_prob, pred) for perturbed / clean points.
rob_chunks = ([], [], [], [])
nat_chunks = ([], [], [], [])

for i, (x, y) in enumerate(dl):
    start = time.time()
    # Re-assert eval mode every batch (cheap; guards against an attack changing it).
    clean_model.eval()
    x = x.to(device)
    y = y.to(device)

    # Step 1: perturb. The attack needs input gradients, so it runs outside no_grad.
    x_adv = get_perturbed(args.attack, x, y, clean_model, args, use_normalize=False)
    clean_model.zero_grad()

    # Step 2: the evaluation forward passes only need values, not gradients.
    with torch.no_grad():
        batch_rob = _batch_stats(clean_model(x_adv), y)
        batch_nat = _batch_stats(clean_model(x), y)
    for store, values in zip(rob_chunks + nat_chunks, batch_rob + batch_nat):
        store.append(values)

    end = time.time()
    if i % 5 == 0:
        DEBUG(f"Time taken for batch {i+1} is {end-start}")

rob_loss_vec, rob_acc_vec, rob_max_prob_vec, rob_preds_vec = (_concat_chunks(c) for c in rob_chunks)
nat_loss_vec, nat_acc_vec, nat_max_prob_vec, nat_preds_vec = (_concat_chunks(c) for c in nat_chunks)

DEBUG("Serializing!")


def _normalize(weights):
    """Scale a weight vector so that it sums to one."""
    return weights / np.sum(weights)


# Per-point vectors, followed by sampling distributions derived from them.
my_dict = {
    "robust loss vector": rob_loss_vec,
    "robust acc vector": rob_acc_vec,
    "natural acc vector": nat_acc_vec,
    "natural loss vector": nat_loss_vec,
    "robust max prob vector": rob_max_prob_vec,
    "natural max prob vector": nat_max_prob_vec,
    "robust predictions vector": rob_preds_vec,
    "natural predictions vector": nat_preds_vec,
    # P(x_i) proportional to rob_loss(x_i).
    "distribution rob loss": _normalize(rob_loss_vec),
    # P(x_i) proportional to exp(rob_loss(x_i)): favours points that are hard to
    # classify once perturbed; targeting them lowers robust accuracy.
    "distribution softmax rob loss": softmax(rob_loss_vec),
    # P(x_i) proportional to exp(-nat_loss(x_i)): favours points the classifier
    # gets most right when clean; targeting them directly lowers natural
    # accuracy and may also lower robust accuracy.
    "distribution softmax nat loss": softmax(-1 * nat_loss_vec),
    # P(x_i) proportional to 1 - max_prob(x_i): favours points the classifier is
    # least confident about (clean / perturbed respectively), right or wrong.
    # NOTE(review): only a valid distribution if the max-prob vectors lie in [0, 1].
    "distribution natural max prob": _normalize(1 - nat_max_prob_vec),
    "distribution robust max prob": _normalize(1 - rob_max_prob_vec),
}

DEBUG("Done with loss vectors computation. Now sampling subsets! :)")

def generate_sample_from_distr(num_points, total_points, distr, rng=None):
    """
    Sample a subset S of num_points indices from np.arange(total_points), without replacement, according to distr.

    Args:
        num_points (int): The size of the subset S to be sampled.
        total_points (int): The length of the whole dataset from which points are to be sampled.
        distr (Numpy array): Probability of each of the total_points points (must sum to 1).
        rng (numpy.random.Generator, optional): Generator to draw from, for reproducible sampling.
            Defaults to None, which uses NumPy's global random state (np.random.choice).

    The sample is returned as a bitvector. For instance, if total_points is 5 and num_points is 3,
    one subset may look like [0, 1, 1, 0, 1], meaning elements 1, 2 and 4 are selected (0-indexed).

    Returns:
        numpy.ndarray: float array of length total_points, 1.0 at the sampled indices and 0.0 elsewhere.

    Raises:
        ValueError: raised by NumPy, e.g. if num_points exceeds the number of points with
            non-zero probability or distr is not a valid probability vector.
    """
    chooser = np.random if rng is None else rng
    S_inds = chooser.choice(np.arange(total_points), num_points, p=distr, replace=False)
    bitvector = np.zeros(total_points)
    bitvector[S_inds] = 1
    return bitvector

def generate_multiple_S_for_distr(num_subsets, len_test_ds, len_S, distr_vec):
    """
    Draw num_subsets independent subsets S, each sampled according to distr_vec.

    Args:
        num_subsets (int): How many subsets to draw.
        len_test_ds (int): Size of the dataset the subsets are drawn from.
        len_S (int): Number of points in each subset.
        distr_vec (Numpy array): Sampling distribution over the len_test_ds points; within one
            subset, points are drawn without replacement.

    Every subset is encoded as a bitvector over the dataset. With len_test_ds = 5 and len_S = 3,
    [0, 1, 1, 0, 1] means points 1, 2 and 4 (0-indexed) were selected.

    Returns: A list of num_subsets such bitvectors, in the order they were drawn.
    """
    return [
        generate_sample_from_distr(len_S, len_test_ds, distr_vec)
        for _ in range(num_subsets)
    ]

# Sample subsets from each distribution and store them so they can be reused
# for evaluation.
# NOTE(review): each subset is a float64 bitvector of length len_test_ds, so
# n_samples subsets x 5 distributions can make this pickle very large.

n_samples = 10000
len_test_ds = len(rob_acc_vec)
len_S = int(args.frac*len_test_ds)    # attack frac * 100 % of the points

# The softmax-rob-loss distribution is sampled first and is the only one timed.
start = time.time()
my_dict["softmax rob loss subsets"] = generate_multiple_S_for_distr(
    n_samples, len_test_ds, len_S, my_dict["distribution softmax rob loss"])
end = time.time()
DEBUG(f"Time taken for sampling {n_samples} subsets of size {len_S} each is {end-start}")

# The remaining distributions, in a fixed order: all draws share NumPy's global
# random state, so the order determines which subsets each one receives.
for distr_key, subsets_key in (
    ("distribution softmax nat loss", "softmax nat loss subsets"),
    ("distribution rob loss", "rob loss subsets"),
    ("distribution robust max prob", "rob max prob subsets"),
    ("distribution natural max prob", "nat max prob subsets"),
):
    my_dict[subsets_key] = generate_multiple_S_for_distr(
        n_samples, len_test_ds, len_S, my_dict[distr_key])

subsets_dir = 'attacked_subsets'
os.makedirs(subsets_dir, exist_ok=True)
val_tag = "_val" if args.val else ""
out_path = f'{subsets_dir}/clean_model_{args.data}{val_tag}_frac_{args.frac}_{args.attack}_losses_accs_subsets.pkl'
with open(out_path, 'wb') as fp:
    pickle.dump(my_dict, fp)
# Suggestion: Should we do the below in a separate file?
# In that case, we'll have to load what we saved above.

# Now, we should have the vector of perturbed losses. 
# Using that as a distribution over points, sample S
# b = 0.1*len(test_ds)
# probs = losses/losses.sum()
# S_inds = np.random.choice(np.arange(len(test_ds)), b, p=probs, replace=False)
# unattacked_inds = np.setdiff1d(np.arange(len(test_ds)), S_inds)
# attacked_subset = Subset(test_ds, S_inds)
# unattacked_subset = Subset(test_ds, unattacked_inds)