import pickle
import numpy as np
import torch
import os
import argparse
import time

from torch.utils.data import Dataset, DataLoader, Subset
from data.pytorch_datasets import get_dataset
import torch.nn.functional as F
from copy import deepcopy

np.random.seed(0)
# NOTE(review): pkl_dir is only referenced by the commented-out legacy code at
# the bottom of this file; the live code reads pickles from other paths.
pkl_dir = "distribution_vectors_pkls"

parser = argparse.ArgumentParser()
parser.add_argument('--data', default='cifar10', type=str)
parser.add_argument('-gpu', default=0, type=int, help='gpu:id to be used')
parser.add_argument('-eps', '--eps', type=float, default=0.031)
parser.add_argument('-epsiter', '--epsiter', type=float, default=0.007)
parser.add_argument('-ns', '--ns', type=int, default=20)
parser.add_argument('-m', '--methods', default="all", type=str, nargs='+',
                    help='defense names to report, or "all" for the built-in list')
# NOTE: action='store_false' means debug output is ON by default and passing
# --debug turns it OFF. Kept as-is so existing invocations behave the same.
parser.add_argument('--debug', action='store_false',
                    help='pass this flag to DISABLE debug/timing output (on by default)')
parser.add_argument('-a','--attack', default='auto-att', type=str)
args = parser.parse_args()

# Expand "all" into the built-in defense list. With nargs='+', an explicit
# "-m all" arrives as ["all"], while the untouched default is the string "all";
# accept both so "-m all" does not try to load a pickle named "all_...".
if args.methods == "all" or args.methods == ["all"]:
    args.methods = ["GAT", "FBF", "TRADES", "Nu_AT_4.5", "MART", "PGD_AT", "RFGSM-AT_0.05", "Ours_0.5", "Ours_AdvGAN"]

# Note: There will be NO model loading and evaluation in this file! Just dot products!

def DEBUG(log):
    """Print ``log`` only when debug output is enabled (module-level ``args.debug``)."""
    if not args.debug:
        return
    print(log)

# 1. Load the distribution vector
# 2. Load the pre saved pickled robust acc vector and natural acc vectors.
# Idea: Store these in a dict (from the model eval file) and reload them here

# 1. Load the distribution vector together with the subset masks sampled from
#    it; the pickle name is derived from the dataset chosen on the command line.
# NOTE(review): the path is relative to the working directory and does not use
# pkl_dir — confirm this is intended.
load_path = f"clean_model_{args.data}_saved_acc_losses_accs.pkl"
DEBUG("Loading distribution vector and sampled subsets")
start = time.time()
with open(load_path, "rb") as pkl_file:
    distr_data = pickle.load(pkl_file)
end = time.time()
DEBUG(f"Distribution vector and subsets loaded! Time taken: {end-start}")

# Keys from the distribution vector:
# ['robust loss vector', 'robust acc vector', 'natural acc vector', 'natural loss vector', 'distribution rob loss',
# 'distribution softmax rob loss', 'distribution softmax nat loss', 
# 'softmax rob loss subsets', 'softmax nat loss subsets', 'rob loss subsets']

# Utility functions:

def eval_acc_on_subsets(acc_vec, subsets, acc_type):
    """Compute the accuracy encoded by ``acc_vec`` on each sampled subset.

    Args:
        acc_vec: per-example correctness values over the test set (presumably
            0/1 indicators — TODO confirm against the saved pickles).
        subsets: sequence of 0/1 membership masks, one per sampled subset S,
            each the same length as ``acc_vec``.
        acc_type: ``"rob"`` to score the examples inside S, or ``"nat"`` to
            score the examples in the complement of S.

    Returns:
        np.ndarray holding one accuracy per subset (empty if ``subsets`` is empty).

    Raises:
        ValueError: if ``acc_type`` is neither ``"rob"`` nor ``"nat"``.
    """
    if acc_type not in ("rob", "nat"):
        raise ValueError(f"acc_type must be 'rob' or 'nat', got {acc_type!r}")

    accs = []
    for S in subsets:
        # Robust accuracy is measured on S, natural accuracy on its complement.
        # No copy is needed: the arithmetic builds new arrays and never mutates S.
        mask = S if acc_type == "rob" else 1 - S
        # Normalise by this subset's own size instead of assuming every subset
        # is as large as subsets[0] (identical result when sizes match).
        accs.append((acc_vec * mask).sum() / mask.sum())

    return np.array(accs)

def get_overall_accs(rob_accs, nat_accs, subsets):
    """Blend per-subset robust and natural accuracies into overall accuracies.

    The test set is split into S (scored robustly) and its complement (scored
    naturally). Each part's accuracy is weighted by its share of the test set,
    with the share taken from the first subset mask.
    """
    n_total = len(subsets[0])
    rob_weight = subsets[0].sum() / n_total
    return rob_weight * rob_accs + (1 - rob_weight) * nat_accs


# Helper functions to do things more explicitly!
# def eval_robust_acc_on_subsets(rob_acc_vec, subsets):
#     # Gets robust accuracies for each subset
#     len_S = subsets[0].sum()
#     rob_accs = []
#     for S in subsets:
#         # Evaluate robust accuracy on each subset and return value
#         correct = (rob_acc_vec*S).sum()
#         accuracy = correct/len_S
#         rob_accs.append(accuracy)
    
#     return rob_accs

# def eval_natural_acc_on_subsets(nat_acc_vec, subsets):
#     # Gets natrual accuracies for each subset

#     len_S = (1-subsets[0]).sum()
#     print(len_S)
#     nat_accs = []
#     for S in subsets:
#         # Evaluate robust accuracy on each subset and return value
#         correct = (nat_acc_vec*(1-S)).sum()
#         accuracy = correct/len_S
#         nat_accs.append(accuracy)
    
#     return nat_accs

def evaluate_model_on_subsets(model_dict, subsets):
    """Evaluate one defense's saved per-example results on every sampled subset.

    Args:
        model_dict: dict containing at least ``'robust acc vector'`` and
            ``'natural acc vector'`` (per-example correctness over the test set).
            Other keys (e.g. the loss vectors) are not read.
        subsets: sequence of 0/1 membership masks over the test set, one per
            sampled subset S.

    Returns:
        dict with keys:
            ``"rob accs"``: robust accuracy on each subset S,
            ``"nat accs"``: natural accuracy on the complement of each S,
            ``"overall accs"``: the two blended by their share of the test set.

    Loss-based statistics are not computed here.
    """
    rob_acc_vec = model_dict['robust acc vector']
    nat_acc_vec = model_dict['natural acc vector']

    rob_accs = eval_acc_on_subsets(rob_acc_vec, subsets, 'rob')
    nat_accs = eval_acc_on_subsets(nat_acc_vec, subsets, 'nat')
    overall_accs = get_overall_accs(rob_accs, nat_accs, subsets)

    return {
        "rob accs": rob_accs,
        "nat accs": nat_accs,
        "overall accs": overall_accs,
    }

rob_softmax_stats_dict = {}
nat_softmax_stats_dict = {}

# 2. Load the pre saved pickled robust acc vector and natural acc vectors.
# The subset masks come from the clean-model pickle and are the same for every
# defense, so fetch them once rather than on every iteration.
folder = 'test_stats_pkl'
rob_softmax_subsets = distr_data['softmax rob loss subsets']
nat_softmax_subsets = distr_data['softmax nat loss subsets']

for defense in args.methods:
    start = time.time()
    # `folder` is an input directory that is only read from, so it is not
    # created here (that would only leave an empty folder before open() fails).
    filepath = os.path.join(folder, f"{defense}_{args.data}_{args.attack}_test_stats.pkl")

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(filepath, 'rb') as fp:
        defense_dict = pickle.load(fp)

    # defense_dict keys: ['robust loss vector', 'robust acc vector', 'natural acc vector', 'natural loss vector']
    rob_softmax_stats_dict[defense] = evaluate_model_on_subsets(defense_dict, rob_softmax_subsets)
    nat_softmax_stats_dict[defense] = evaluate_model_on_subsets(defense_dict, nat_softmax_subsets)

    end = time.time()
    DEBUG(f"Defense {defense} took {end-start} seconds!")


# Print the stats
def _print_stats_table(title, stats_dict):
    """Print a LaTeX-style table of mean \\pm standard error for each defense.

    Args:
        title: heading line describing how the subsets were sampled.
        stats_dict: maps defense name -> dict with "nat accs", "rob accs" and
            "overall accs" arrays (one accuracy per sampled subset).
    """
    print(title)
    print(f"Accuracies for attack {args.attack}")
    print("Defense: \t & natural accuracy & \t robust accuracy & \t overall accuracy \t")
    for defense in args.methods:
        stats = stats_dict[defense]
        cells = []
        for key in ("nat accs", "rob accs", "overall accs"):
            accs = stats[key]
            mean = accs.mean() * 100.0
            # Standard error of the mean across subsets, in percent. Uses numpy's
            # default population std (ddof=0), matching the previous output.
            std_err = (accs.std() * 100.0) / np.sqrt(len(accs))
            # "\\pm" emits a literal \pm (a bare "\p" is an invalid escape).
            cells.append(f"{mean:.3f}\\pm{std_err:.4f}")
        print(f"{defense:<12} \t\t & " + " &\t\t ".join(cells))


# Proportional to robust loss
_print_stats_table(
    "\n\nFor probability of selection of x_i in S proportional to e^\\{robust_loss(x_i)\\}",
    rob_softmax_stats_dict,
)

# Proportional to negative of nat loss
_print_stats_table(
    "\n\nFor probability of selection of x_i in S proportional to e^\\{-natural_loss(x_i)\\}",
    nat_softmax_stats_dict,
)

# OLD CODE. FOR REFERENCE ONLY.
# robust_acc_file_name = os.path.join(pkl_dir, f"clean_model_{args.data}_saved_acc_losses_accs.pkl")
# # Choose distribution picking mode. Options are:
# # "distribution rob loss" :- Pr(x_i) propotional to rob_loss(x_i)
# # "distribution softmax rob loss" :- Pr(x_i) propotional to exp{rob_loss(x_i)}
# # "distribution softmax nat loss" :- Pr(x_i) propotional to exp{ -nat_loss(x_i)}
# picking_mode = "distribution softmax rob loss"
# with open(robust_acc_file_name, 'rb') as fp:
#     DEBUG(f"Loading file...")
#     start = time.time()
#     clean_model_pkl_data = pickle.load(fp)
#     end = time.time()
#     DEBUG(f"File loaded in {end-start} seconds")
#     # robust_acc_vec = a["robust acc vector"]
#     # nat_acc_vec = a["natural acc vector"]
#     # distr_vec = a[picking_mode]
#     # len_S = a["len S"]
#     # len_test_ds = a["len test ds"] # can also get from len(nat_acc_vec), need not store explicitly

# exit(0)
# # 3. Sample 10000 different subsets and maintain count of the number of times each element appears in the subset. This will be the estimate of the probabilities.
# n_samples = 10000
# len_test_ds = len(robust_acc_vec)
# counts = np.array([0]*len_test_ds)

# sample_rob_acc_vec = []
# sample_nat_acc_vec = []
# sample_combined_acc_vec = []

# for i in range(n_samples):
#     # Sample |S|=len_S points, without replacement, from the test set as per distribution distr_vec
#     S_inds = np.random.choice(np.arange(len_test_ds), len_S, p=distr_vec, replace=False)
#     # Update count
#     for ind in S_inds:
#         counts[ind] += 1
    
#     # While we're at it, compute robust and natural accuracy on this subset
#     rob_bitmask = np.zeros(np.arange(len_test_ds))
#     rob_bitmask[S_inds] = 1
#     rob_correct = np.dot(rob_bitmask, robust_acc_vec)
#     rob_acc = (float)(rob_correct)/(rob_bitmask.sum())

#     nat_bitmask = np.ones(np.arange(len_test_ds))
#     nat_bitmask[S_inds] = 0
#     nat_correct = np.dot(nat_bitmask, nat_acc_vec)
#     nat_acc = (float)(nat_correct)/(nat_bitmask.sum())

#     total_correct = nat_correct + rob_correct
#     combined_acc = (float)(total_correct)/len_test_ds

#     sample_rob_acc_vec.append(rob_acc)
#     sample_nat_acc_vec.append(nat_acc)
#     sample_combined_acc_vec.append(combined_acc)


# # Now you have the vecs. Can compute mean and variance!
# sample_rob_acc_vec = np.array(sample_rob_acc_vec)
# sample_nat_acc_vec = np.array(sample_nat_acc_vec)
# sample_combined_acc_vec = np.array(sample_combined_acc_vec)

# mean_rob_acc = sample_rob_acc_vec.mean()
# mean_nat_acc = sample_nat_acc_vec.mean()
# mean_comb_acc = sample_combined_acc_vec.mean()

# # Normalize counts to get probs
# probs = counts/n_samples

# # 4. Compute the expected accuracy over distribution of subsets.
# expected_acc = np.dot(probs, robust_acc_vec) + np.dot(1-probs, nat_acc_vec)