'''
Evaluate defenses on the test set using the uncertainty-based subset selection strategy.
The hyperparameters are chosen based on the mean overall accuracy over 10000 random subsets.
'''

import pickle
import numpy as np
import torch
import os
import argparse
import time

from torch.utils.data import Dataset, DataLoader, Subset
# from data.pytorch_datasets import get_dataset
import torch.nn.functional as F
from copy import deepcopy

np.random.seed(0)
pkl_dir = "distribution_vectors_pkls"

parser = argparse.ArgumentParser(
    description="Evaluate defenses on the test set using uncertainty-based subset selection.")
parser.add_argument('-d', '--data', default='cifar10', type=str,
                    help='dataset name (e.g. cifar10, fmnist)')
# parser.add_argument('-gpu', default=0, type=int, help='gpu:id to be used')
# parser.add_argument('-eps', '--eps', type=float, default=0.031)
# parser.add_argument('-epsiter', '--epsiter', type=float, default=0.007)
# parser.add_argument('-ns', '--ns', type=int, default=20)
parser.add_argument('-m', '--methods', default="all", type=str, nargs='+',
                    help='method/hyperparameter tags to evaluate; '
                         'the default "all" evaluates every selected baseline')
# NOTE: ``--debug`` is a store_false flag: debug/timing output is ON by default
# and passing ``--debug`` turns it OFF. Kept this way for backward compatibility.
parser.add_argument('--debug', action='store_false',
                    help='disable debug/timing output (it is enabled by default)')
parser.add_argument('-a', '--attack', default='auto-att', type=str,
                    help='attack whose saved test statistics are evaluated')
parser.add_argument('-c', '--constraint', type=float, default=0.0, help='minimum robust accuracy (on the validation set) required for a hyp to be selected')

args = parser.parse_args()

## All the baselines assume a PGD adversary and choose their hyperparameter
## accordingly using the validation set.
# NOTE: pickle.load can execute arbitrary code; only load result files you produced.
with open(f'baselines_choose_hyp_random_c{args.constraint}/best_hyps_{args.data}_pgd_random_c{args.constraint}.pkl', 'rb') as f:
    best_hyps = pickle.load(f)

############################
# Replace the loaded hyperparameter choice of some baselines with fixed settings.
def _override_hyp(baseline):
    """Return the hyperparameter tag to evaluate for ``baseline``.

    The first matching rule wins; baselines matched by no rule keep the tag
    that was loaded into ``best_hyps``.
    """
    if "RFGSM" in baseline:
        return "RFGSM-AT_0.05"
    if "PGD_AT" in baseline:
        return "PGD_AT"
    if "TRADES" in baseline:
        return "TRADES_0.1"
    if args.data == 'fmnist' and "MART" in baseline:
        return "MART_1.0"
    return baseline


best_hyps = [_override_hyp(baseline) for baseline in best_hyps]

# if args.data == 'cifar100':
    # best_hyps.insert(-3, 'RFGSM-AT_0.5')
    # best_hyps.insert(-4, 'PGD_AT')
# best_hyps = best_hyps[:-2] + best_hyps_ours[-2:]
############################

# ``--methods`` defaults to the string "all", but when given on the command line
# it is a list (nargs='+'). Accept an explicit ``-m all`` too, and otherwise use the
# requested list instead of leaving ``methods`` undefined.
if args.methods == "all" or args.methods == ["all"]:
    methods = best_hyps  # + ["Ours_0.5", "Ours_AdvGAN_0.5"]
else:
    methods = list(args.methods)

# Note: There will be NO model loading and evaluation in this file! Just dot products!

def DEBUG(log):
    """Print ``log`` when debug output is enabled via the parsed ``args.debug``."""
    if not args.debug:
        return
    print(log)

# 1. Load the clean model's per-sample statistics together with the pre-sampled
#    subsets. Each defense's robust/natural accuracy vectors are loaded later,
#    one pickle per defense.
def _load_distribution_data(dataset):
    """Unpickle the clean-model losses/accuracies/subsets file for ``dataset``."""
    # Subset selection is based on the clean loss of the vanilla classifier, so
    # the attack method (e.g. pgd) plays no role here; "pgd" is only part of the
    # saved file's name.
    pickle_path = f"attacked_subsets/clean_model_{dataset}_pgd_losses_accs_subsets.pkl"
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)


DEBUG("Loading distribution vector and sampled subsets")
start = time.time()
distr_data = _load_distribution_data(args.data)
end = time.time()
DEBUG(f"Distribution vector and subsets loaded! Time taken: {end-start}")

# Keys from the distribution vector:
# ['robust loss vector', 'robust acc vector', 'natural acc vector', 'natural loss vector', 'distribution rob loss',
# 'distribution softmax rob loss', 'distribution softmax nat loss', 
# 'softmax rob loss subsets', 'softmax nat loss subsets', 'rob loss subsets']
# NOTE(review): the evaluation loop also reads 'nat max prob subsets', which is
# missing from this list — confirm the list is up to date.

# Utility functions:

def eval_acc_on_subsets(acc_vec, subsets, acc_type):
    """Compute the accuracy restricted to each subset (or to its complement).

    Args:
        acc_vec: per-sample correctness values over the whole test set
            (presumably 1 for a correct prediction, 0 otherwise).
        subsets: sequence of per-sample membership masks over the test set
            (1 = sample is in S). All masks are assumed to have the same number
            of members, so the denominator is taken from the first one.
        acc_type: "rob" scores the samples inside each subset; "nat" scores the
            samples outside it (the complement ``1 - S``).

    Returns:
        np.ndarray with one accuracy per subset (empty if ``subsets`` is empty).

    Raises:
        ValueError: if ``acc_type`` is neither "rob" nor "nat".
    """
    if acc_type not in ("rob", "nat"):
        raise ValueError(f"acc_type must be 'rob' or 'nat', got {acc_type!r}")
    if len(subsets) == 0:
        return np.array([])

    def _mask(S):
        # Selects the samples scored for this accuracy type; no copy is needed
        # because the mask is never mutated.
        return S if acc_type == "rob" else 1 - S

    len_S = _mask(subsets[0]).sum()
    return np.array([(acc_vec * _mask(S)).sum() / len_S for S in subsets])

def get_overall_accs(rob_accs, nat_accs, subsets):
    """Blend robust and natural accuracies into the overall accuracy.

    The robust accuracy is weighted by the fraction of the test set that lies in
    a subset, and the natural accuracy by the remaining fraction. Both fractions
    are taken from the first subset, so all subsets are assumed to share its size.
    """
    reference = subsets[0]
    rob_weight = reference.sum() / len(reference)
    return rob_weight * rob_accs + (1 - rob_weight) * nat_accs


# Helper functions to do things more explicitly!
# def eval_robust_acc_on_subsets(rob_acc_vec, subsets):
#     # Gets robust accuracies for each subset
#     len_S = subsets[0].sum()
#     rob_accs = []
#     for S in subsets:
#         # Evaluate robust accuracy on each subset and return value
#         correct = (rob_acc_vec*S).sum()
#         accuracy = correct/len_S
#         rob_accs.append(accuracy)
    
#     return rob_accs

# def eval_natural_acc_on_subsets(nat_acc_vec, subsets):
#     # Gets natrual accuracies for each subset

#     len_S = (1-subsets[0]).sum()
#     print(len_S)
#     nat_accs = []
#     for S in subsets:
#         # Evaluate robust accuracy on each subset and return value
#         correct = (nat_acc_vec*(1-S)).sum()
#         accuracy = correct/len_S
#         nat_accs.append(accuracy)
    
#     return nat_accs

def evaluate_model_on_subsets(model_dict, subsets):
    """Evaluate one model's saved per-sample results on every subset.

    Args:
        model_dict: per-sample statistics of a model. Only the
            'robust acc vector' and 'natural acc vector' entries are read.
        subsets: sequence of per-sample membership masks (presumably 0/1).
            Samples inside a subset are scored with the robust accuracy vector,
            samples outside it with the natural accuracy vector.

    Returns:
        dict with keys "rob accs", "nat accs" and "overall accs", each an array
        holding one value per subset.
    """
    rob_acc_vec = model_dict['robust acc vector']
    nat_acc_vec = model_dict['natural acc vector']

    rob_accs = eval_acc_on_subsets(rob_acc_vec, subsets, 'rob')
    nat_accs = eval_acc_on_subsets(nat_acc_vec, subsets, 'nat')
    overall_accs = get_overall_accs(rob_accs, nat_accs, subsets)

    return {
        "rob accs": rob_accs,
        "nat accs": nat_accs,
        "overall accs": overall_accs,
    }

def tex(method):
    """Map a method tag to the LaTeX macro used in the results table.

    The first substring found in ``method`` wins, so the order of the table
    matters (e.g. 'AdvGAN' must be checked before 'Ours'). Tags matching no
    entry are returned unchanged.
    """
    substring_to_macro = (
        ('AdvGAN', '\\our\\,($\\advp = \\text{\\modelx}$)'),
        ('Ours', '\\our\\, ($\\advp = \\text{PGD}$)'),
        ('TRADES', '\\trades~\\ctrades'),
        ('Nu_AT', '\\nuat~\\cnuat'),
        ('MART', '\\mart~\\cmart'),
        ('PGD_AT', '\\pgdat~\\cpgdat'),
        ('RFGSM', '\\rfgsmat~\\crfgsmat'),
        ('FBF', '\\fbf~\\cfbf'),
        ('GAT', '\\gat~\\cgat'),
    )
    for needle, macro in substring_to_macro:
        if needle in method:
            return macro
    return method

# output_folder = f'testset_results_c{args.constraint}'
output_folder = f'testset_results_random_set_selection_c{args.constraint}'
os.makedirs(output_folder, exist_ok=True)
# Suffix the report names with the requested methods unless all are evaluated.
# ``--methods`` is a list when given on the command line, so join it into a
# filesystem-friendly token rather than embedding its repr (e.g. "['a', 'b']").
if args.methods == "all" or args.methods == ["all"]:
    methods_suffix = ""
else:
    methods_suffix = "_" + "_".join(args.methods)
output_path = f'{output_folder}/{args.data}_{args.attack}{methods_suffix}.txt'
output_path_tex = f'{output_folder}/{args.data}_{args.attack}{methods_suffix}_tex.txt'
# Both handles stay open for the reporting loop at the end of the script,
# which writes the rows and closes them.
f = open(output_path, 'w')
ft = open(output_path_tex, 'w')
f.write(f'_________________ Dataset: {args.data}, Attack: {args.attack} _________________ \n\n')

uncert_stats_dict = {}
# 2. Load each defense's pre-saved pickled robust/natural accuracy vectors and
#    evaluate them on the uncertainty-based subsets.
stats_folder = 'test_stats_pkl'
# The same subsets are shared by every defense, so look them up once.
uncert_subsets = distr_data["nat max prob subsets"]
for defense in methods:
    # 'None' entries have no saved statistics; they are reported as dashes.
    if 'None' in defense:
        continue
    start = time.time()
    filepath = os.path.join(stats_folder, f"{defense}_{args.data}_{args.attack}_test_stats.pkl")

    # pkl_dict keys: ['robust loss vector', 'robust acc vector', 'natural acc vector', 'natural loss vector']
    with open(filepath, 'rb') as fp:
        pkl_dict = pickle.load(fp)

    uncert_stats_dict[defense] = evaluate_model_on_subsets(pkl_dict, uncert_subsets)

    end = time.time()
    DEBUG(f"Defense {defense} took {end-start} seconds!")


# Proportional to the uncertainty
print("\n\nFor probability of selection of x_i in S proportional to 1-MaxPi(x_i)")
print(f"Accuracies for attack {args.attack}")
print("Defense: \t & natural accuracy & \t robust accuracy & \t overall accuracy \t")


def _mean_pct(x):
    """Mean of ``x`` expressed in percent."""
    return x.mean() * 100.0


def _stderr_pct(x):
    """Standard error of the mean of ``x`` (population std / sqrt(n)), in percent."""
    return (x.std() * 100.0) / np.sqrt(len(x))


for defense in methods:
    if 'None' in defense:
        op_str = f"{defense:<16} \t\t & - &\t\t - &\t\t -"
        tex_str = f"{tex(defense)} & - & - & -"
    else:
        uncert_stats = uncert_stats_dict[defense]
        rob_accs = uncert_stats["rob accs"]
        nat_accs = uncert_stats["nat accs"]
        overall_accs = uncert_stats["overall accs"]

        mean_rob_acc = _mean_pct(rob_accs)
        mean_nat_acc = _mean_pct(nat_accs)
        mean_overall_acc = _mean_pct(overall_accs)

        std_err_rob_accs = _stderr_pct(rob_accs)
        std_err_nat_accs = _stderr_pct(nat_accs)
        std_err_overall_accs = _stderr_pct(overall_accs)

        # Implicit string concatenation avoids the backslash line continuation
        # inside the literal, which used to inject stray indentation spaces.
        op_str = (
            f"{defense:<16} \t\t & {mean_nat_acc:.3f}\\pm{std_err_nat_accs:.4f} &\t\t "
            f"{mean_rob_acc:.3f}\\pm{std_err_rob_accs:.4f} &\t\t "
            f"{mean_overall_acc:.3f}\\pm{std_err_overall_accs:.4f}"
        )

        # tex_str= f"{tex(defense)} & {mean_nat_acc:.2f}$\\pm${std_err_nat_accs:.3f} & {mean_rob_acc:.2f}$\\pm${std_err_rob_accs:.3f} & {mean_overall_acc:.2f}$\\pm${std_err_overall_accs:.3f}"
        tex_str = f"{tex(defense)} & {mean_nat_acc:.2f} & {mean_rob_acc:.2f} & {mean_overall_acc:.2f}"

    print(op_str)
    f.write(op_str + "\n")
    # Emit a horizontal rule before each 'Ours' row that is not the AdvGAN variant.
    if 'Ours' in defense and 'AdvGAN' not in defense:
        ft.write("\\hline \n")
    # LaTeX row terminator: a literal "\\" followed by a space.
    ft.write(tex_str + "\\\\ \n")

f.close()
ft.close()
