import pickle
import numpy as np
import torch
import os
import argparse
import time

from torch.utils.data import Dataset, DataLoader, Subset
# from data.pytorch_datasets import get_dataset
import torch.nn.functional as F
from copy import deepcopy

# Seed NumPy's global RNG once at start-up.
np.random.seed(0)
pkl_dir = "distribution_vectors_pkls"  # NOTE(review): not referenced in the visible part of this file

# Command-line interface of the script.
parser = argparse.ArgumentParser()
# Dataset name; interpolated into the pickle and result file names below.
parser.add_argument('-d', '--data', default='cifar10', type=str)
# Left at its default this is the plain string "all"; when given on the
# command line it is a list of one or more names (nargs='+').
parser.add_argument('-m', '--methods', default="all", type=str, nargs='+')
# NOTE(review): store_false means args.debug defaults to True and passing
# --debug switches DEBUG() logging OFF -- confirm this inversion is intended.
parser.add_argument('--debug', action='store_false')
# Attack name; part of the per-defense stats file name and the output file name.
parser.add_argument('-a','--attack', default='auto-att', type=str)
# Attacked class label (kept as a string); selects the subset pickle to load.
parser.add_argument('-c','--att_class', default='0', type=str)

args = parser.parse_args()

## Every baseline assumes a PGD adversary and had its hyperparameter picked on
## the validation set under that assumption, hence the fixed "_pgd" file suffix.
# NOTE: pickle.load can execute arbitrary code; only read trusted files here.
hyps_path = f'baselines_choose_hyp_worst_case/best_hyps_{args.data}_pgd.pkl'
with open(hyps_path, 'rb') as f:
    best_hyps = pickle.load(f)

############################
# Replace selected entries with fixed names; the first matching rule wins and
# everything else is kept as loaded.
is_fmnist = args.data == 'fmnist'
new_best_hyps = [
    "RFGSM-AT_0.05" if "RFGSM" in baseline
    else "PGD_AT" if "PGD_AT" in baseline
    else "MART_1.0" if (is_fmnist and "MART" in baseline)
    else baseline
    for baseline in best_hyps
]
best_hyps = new_best_hyps

# For the CIFAR datasets two extra entries are spliced in near the end of the
# list; only the RFGSM variant differs between the two datasets.
extra_rfgsm = {'cifar100': 'RFGSM-AT_0.5', 'cifar10': 'RFGSM-AT_0.05'}.get(args.data)
if extra_rfgsm is not None:
    best_hyps.insert(-2, extra_rfgsm)
    best_hyps.insert(-3, 'PGD_AT')

# FBF always goes into the second slot.
best_hyps.insert(1, 'FBF')

############################

# Resolve the list of defenses to evaluate. argparse yields the plain string
# "all" only as the default; anything typed after -m arrives as a list
# (nargs='+'), so "-m all" shows up as ["all"]. Previously `methods` was left
# undefined whenever -m was supplied, which crashed the loops below.
if args.methods == "all" or args.methods == ["all"]:
    methods = best_hyps
elif isinstance(args.methods, str):
    methods = [args.methods]
else:
    methods = list(args.methods)

# Note: There will be NO model loading and evaluation in this file! Just dot products!

def DEBUG(log):
    """Print `log` to stdout, but only while `args.debug` is truthy."""
    if not args.debug:
        return
    print(log)

# 1. Load the distribution vector
# 2. Load the pre saved pickled robust acc vector and natural acc vectors.
# Idea: Store these in a dict (from the model eval file) and reload them here

# 1. Load the distribution vector
DEBUG("Loading distribution vector and sampled subsets")
load_started = time.time()
# The subset was selected from the clean loss of the vanilla classifier, so the
# attack method (e.g. pgd) plays no part in this path.
load_path = f"attack_label_subsets/{args.data}_label_{args.att_class}.pkl"
# NOTE: pickle.load can execute arbitrary code; only read trusted files here.
with open(load_path, 'rb') as subset_file:
    subset = pickle.load(subset_file)
elapsed = time.time() - load_started
DEBUG(f"Subset vector! Time taken: {elapsed}")


# Utility functions:

def eval_acc_on_subsets(acc_vec, subsets, acc_type):
    """Return the accuracy of `acc_vec` restricted to each subset mask.

    Args:
        acc_vec: per-example vector that is multiplied elementwise with the
            mask and summed (presumably 0/1 correctness flags -- confirm
            against the file that writes the stats pickles).
        subsets: non-empty sequence of array-like masks supporting ``1 - S``
            and ``.sum()`` (assumed to be 0/1 indicator arrays -- TODO confirm).
        acc_type: ``"rob"`` to evaluate on the masked entries themselves,
            ``"nat"`` to evaluate on the complement ``1 - S``.

    Returns:
        np.ndarray with one accuracy value per subset.

    Raises:
        ValueError: if `acc_type` is neither ``"rob"`` nor ``"nat"``.
    """
    if acc_type not in ("rob", "nat"):
        raise ValueError(f"acc_type must be 'rob' or 'nat', got {acc_type!r}")
    use_complement = acc_type == "nat"

    # The normaliser is taken from the first mask only, i.e. every subset is
    # assumed to select the same number of examples. get_overall_accs weights
    # by the same quantity, so the two must stay in step.
    first = subsets[0]
    len_S = (1 - first).sum() if use_complement else first.sum()

    accs = []
    for S in subsets:
        # The multiplication builds a new array, so the mask needs no copy.
        multiplier = 1 - S if use_complement else S
        accs.append((acc_vec * multiplier).sum() / len_S)

    return np.array(accs)

def get_overall_accs(rob_accs, nat_accs, subsets):
    """Blend robust and natural accuracies into one overall accuracy.

    The weight of the robust part is the fraction of entries selected by the
    first mask in `subsets` (its sum divided by its length); the natural part
    gets the remaining weight.

    Args:
        rob_accs: robust accuracies, one per subset.
        nat_accs: natural accuracies, one per subset.
        subsets: non-empty sequence of masks; only the first one is read.

    Returns:
        ``frac_rob * rob_accs + frac_nat * nat_accs``.
    """
    reference_mask = subsets[0]
    attacked_fraction = reference_mask.sum() / len(reference_mask)
    clean_fraction = 1 - attacked_fraction
    return attacked_fraction * rob_accs + clean_fraction * nat_accs


def evaluate_model_on_subsets(model_dict, subsets):
    """Summarise one model's accuracy vectors over the given subsets.

    Args:
        model_dict: mapping that provides 'robust acc vector' and
            'natural acc vector'. Loss vectors that may also be present are
            not used, and no loss statistics are computed.
        subsets: sequence of masks, passed unchanged to eval_acc_on_subsets
            and get_overall_accs.

    Returns:
        dict with three entries, each holding one value per subset:
        "rob accs" (robust accuracy on the masked entries),
        "nat accs" (natural accuracy on the complement) and
        "overall accs" (their weighted combination).
    """
    rob_accs = eval_acc_on_subsets(model_dict['robust acc vector'], subsets, 'rob')
    nat_accs = eval_acc_on_subsets(model_dict['natural acc vector'], subsets, 'nat')

    return {
        "rob accs": rob_accs,
        "nat accs": nat_accs,
        "overall accs": get_overall_accs(rob_accs, nat_accs, subsets),
    }

def tex(method):
    """Translate a method name into the LaTeX label used in the result table.

    The name is tested for a series of marker substrings in a fixed order and
    the label of the first marker found is returned; a name that contains no
    marker is returned unchanged.
    """
    # Order matters: e.g. a name containing both 'AdvGAN' and 'Ours' must map
    # to the AdvGAN label.
    marker_labels = (
        ('AdvGAN', '\\our\\,($\\advp = \\text{\\modelx}$)'),
        ('Ours', '\\our\\, ($\\advp = \\text{PGD}$)'),
        ('TRADES', '\\trades~\\ctrades'),
        ('Nu_AT', '\\nuat~\\cnuat'),
        ('MART', '\\mart~\\cmart'),
        ('PGD_AT', '\\pgdat~\\cpgdat'),
        ('RFGSM', '\\rfgsmat~\\crfgsmat'),
        ('FBF', '\\fbf~\\cfbf'),
        ('GAT', '\\gat~\\cgat'),
    )
    for marker, label in marker_labels:
        if marker in method:
            return label
    return method

# Result files: a plain-text table and a companion file with LaTeX rows.
output_folder = 'testset_results_class_attack'
os.makedirs(output_folder, exist_ok=True)
# Whenever args.methods is not the default string "all", its str() form is
# appended to the file stem.
methods_suffix = "_"+str(args.methods) if args.methods!="all" else ""
output_stem = f'{output_folder}/{args.data}_{args.attack}_label_{args.att_class}{methods_suffix}'
output_path = f'{output_stem}.txt'
output_path_tex = f'{output_stem}_tex.txt'
# Both handles stay open here; they are written to and closed at the end of
# the script.
f = open(output_path, 'w')
ft = open(output_path_tex, 'w')
f.write(f'_________________ Dataset: {args.data}, Attack: {args.attack} _________________ \n\n')

class_attack_stats_dict = {}
# 2. Load the pre saved pickled robust acc vector and natural acc vectors.
# The stats folder and the single-mask subset list are the same for every
# defense, so they are built once instead of on each iteration.
folder = 'test_stats_pkl'
class_attack_subsets = [subset]
for defense in methods:
    # Entries whose name contains 'None' have no stats file; the report
    # prints placeholder dashes for them.
    if 'None' in defense:
        continue
    start = time.time()
    filepath = os.path.join(folder, f"{defense}_{args.data}_{args.attack}_test_stats.pkl")

    # NOTE: pickle.load can execute arbitrary code; only read trusted files here.
    with open(filepath, 'rb') as fp:
        pkl_dict = pickle.load(fp)

    # Only the accuracy vectors inside pkl_dict are consumed; the evaluation
    # helper picks them out itself, so nothing is unpacked here.
    class_attack_stats_dict[defense] = evaluate_model_on_subsets(pkl_dict, class_attack_subsets)

    end = time.time()
    DEBUG(f"Defense {defense} took {end-start} seconds!")


print(f"Accuracies for attack {args.attack}")
print(f"Defense: \t & natural accuracy & \t robust accuracy& \t & overall accuracy \t")
# Entering `with` on the already-open handles guarantees that both files are
# closed when the report is done, even if formatting one of the rows raises.
with f, ft:
    for defense in methods:
        if 'None' in defense:
            # No stats were loaded for these entries; emit placeholder dashes.
            op_str = f"{defense:<16} \t\t & - &\t\t - &\t\t -"
            tex_str = f"{tex(defense)} & - & - & -"
        else:
            class_attack_stats = class_attack_stats_dict[defense]

            # Average over the subsets and express as a percentage.
            mean_rob_acc = class_attack_stats["rob accs"].mean()*100.0
            mean_nat_acc = class_attack_stats["nat accs"].mean()*100.0
            mean_overall_acc = class_attack_stats["overall accs"].mean()*100.0

            op_str = f"{defense:<16} \t\t {mean_nat_acc:.3f} \t\t {mean_rob_acc:.3f} \t\t {mean_overall_acc:.3f}"

            tex_str = f"{tex(defense)} & {mean_nat_acc:.2f} & {mean_rob_acc:.2f} & {mean_overall_acc:.2f}"

        print(op_str)
        f.write(op_str+"\n")
        # A rule is written before every 'Ours' row that is not the AdvGAN variant.
        if 'Ours' in defense and 'AdvGAN' not in defense:
            ft.write("\\hline \n")
        # LaTeX row terminator: two backslashes, a space and a newline. Written
        # with fully escaped backslashes; the earlier spelling relied on the
        # invalid escape sequence "\ ".
        ft.write(tex_str+"\\\\ \n")