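'''
Evaluate every defense at its default (suggested) hyperparameter setting.

For each defense:
  1. evaluate it on n random subsets,
  2. find the subset with the minimum overall accuracy, and
  3. report that minimum together with the corresponding natural and
     robust accuracy.
'''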


import pickle
import numpy as np
import torch
import os
import argparse
import time

from torch.utils.data import Dataset, DataLoader, Subset
from data.pytorch_datasets import get_dataset
import torch.nn.functional as F
from copy import deepcopy

np.random.seed(0)
pkl_dir = "distribution_vectors_pkls"

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', default='cifar10', type=str)
parser.add_argument('-m', '--methods', default="all", type=str, nargs='+',
                    help='defense names to report; "all" (the default) selects every default hyperparameter choice')
# NOTE(review): action='store_false' means args.debug is True by default and
# passing --debug turns the DEBUG() logging OFF. The behaviour is kept as-is
# so existing invocations produce the same output; confirm this is intended.
parser.add_argument('--debug', action='store_false',
                    help='debug logging is ON by default; pass this flag to turn it OFF')
parser.add_argument('-a','--attack', default='auto-att', type=str)
parser.add_argument('-c', '--constraint', type=float, default=0.0, help='minimum robust accuracy (on the validation set) required for a hyp to be selected')
parser.add_argument('-n','--num_subsets', type=int, default=10000)

args = parser.parse_args()

## the best hyps attack is fixed to pgd
## all the baselines assume pgd attack of the adversary and choose the hyperparameter accordingly using validation set
# NOTE(review): pickle.load executes arbitrary code from the file; only use
# pickles produced by this project's own pipeline.
with open(f'baselines_choose_suggested/best_hyps_{args.data}.pkl', 'rb') as f:
    best_hyps = pickle.load(f)

############################
# Override some of the loaded choices with fixed default settings.
# The order of the branches matters: the first matching rule wins.
new_best_hyps = []
for baseline in best_hyps:
    if "RFGSM" in baseline and args.data=='cifar10':
        new_best_hyps.append("RFGSM_AT_p_1.0")
        # new_best_hyps.append("RFGSM-AT_0.05")
    # elif "PGD_AT" in baseline:
    #     # new_best_hyps.append("RFGSM_AT_p_1.0")
    #     new_best_hyps.append("PGD_AT")
    elif "Ours" in baseline and "AdvGAN" not in baseline:
        new_best_hyps.append("Ours_1.0")
    elif "Ours_AdvGAN" in baseline:
        new_best_hyps.append("Ours_AdvGAN_1.0")
    elif args.data=='fmnist' and "MART" in baseline:
        new_best_hyps.append("MART_1.0")
    else:
        new_best_hyps.append(baseline)
best_hyps = new_best_hyps
# 'FBF' is always added as the second entry of the list.
best_hyps.insert(1, 'FBF')
############################

# With nargs='+', an explicit "-m all" is parsed as the list ["all"], whereas
# the untouched default stays the plain string "all". Accept both spellings so
# that "-m all" does not get treated as a defense literally named "all".
if args.methods == "all" or args.methods == ["all"]:
    args.methods = best_hyps #+ ["Ours_0.5", "Ours_AdvGAN_0.5"]

# Note: There will be NO model loading and evaluation in this file! Just dot products!

def DEBUG(log):
    """Print ``log`` when the parsed ``args.debug`` flag is truthy; otherwise do nothing."""
    if not args.debug:
        return
    print(log)

# 1. Load the distribution vector
# 2. Load the pre saved pickled robust acc vector and natural acc vectors.
# Idea: Store these in a dict (from the model eval file) and reload them here

# 1. Load the pre-sampled subsets from disk and time how long the load takes.
DEBUG("Loading distribution vector and sampled subsets")
start = time.time()
# The path depends only on --num_subsets, so the chosen attack (e.g. pgd)
# plays no role in which subsets are loaded.
# NOTE(review): the file name says these subsets are random, while an earlier
# comment here described a selection based on the clean loss of the vanilla
# classifier -- confirm which description matches the pickle's contents.
load_path = f"attacked_subsets/random_subsets_{args.num_subsets}.pkl"     
with open(load_path, 'rb') as fp:
    # presumably a sequence of 0/1 masks, one per subset (they are later
    # used via S.sum() and 1-S) -- TODO confirm against the file that writes them
    random_subsets = pickle.load(fp)
end = time.time()
DEBUG(f"Distribution vector and subsets loaded! Time taken: {end-start}")

# Keys from the distribution vector:
# ['robust loss vector', 'robust acc vector', 'natural acc vector', 'natural loss vector', 'distribution rob loss',
# 'distribution softmax rob loss', 'distribution softmax nat loss', 
# 'softmax rob loss subsets', 'softmax nat loss subsets', 'rob loss subsets']

# Utility functions:

def eval_acc_on_subsets(acc_vec, subsets, acc_type):
    """Compute one accuracy value per subset from a per-example correctness vector.

    Args:
        acc_vec: Per-example correctness values; it is multiplied element-wise
            with each mask, so it must be broadcast-compatible with a subset.
        subsets: Sequence of masks (objects supporting ``.sum()`` and ``1 - S``).
            Presumably 0/1 arrays where 1 marks an attacked example -- verify
            against the code that generates them.
        acc_type: ``"rob"`` to score the examples selected by each mask, or
            ``"nat"`` to score the complement (``1 - S``).

    Returns:
        ``np.ndarray`` with one accuracy per subset (empty if ``subsets`` is empty).

    Raises:
        ValueError: If ``acc_type`` is neither ``"rob"`` nor ``"nat"``.
    """
    if acc_type not in ("rob", "nat"):
        # Previously an unknown type surfaced as an UnboundLocalError on len_S.
        raise ValueError(f"acc_type must be 'rob' or 'nat', got {acc_type!r}")
    if len(subsets) == 0:
        return np.array([])

    use_selected = acc_type == "rob"

    # The denominator is taken from the first subset only, i.e. all subsets
    # are assumed to select the same number of examples.
    first_mask = subsets[0] if use_selected else 1 - subsets[0]
    len_S = first_mask.sum()

    accs = []
    for S in subsets:
        # The mask is only read, never modified, so no copy is required.
        mask = S if use_selected else 1 - S
        correct = (acc_vec * mask).sum()
        accs.append(correct / len_S)

    return np.array(accs)

def get_overall_accs(rob_accs, nat_accs, subsets):
    """Blend robust and natural accuracies into one overall accuracy per subset.

    The weight of the robust accuracy is the sum of the first subset mask
    divided by its length; the natural accuracy gets the remaining weight.
    Only ``subsets[0]`` is inspected, so every subset is assumed to have the
    same size.
    """
    reference_mask = subsets[0]
    robust_weight = reference_mask.sum() / len(reference_mask)
    natural_weight = 1 - robust_weight
    return robust_weight * rob_accs + natural_weight * nat_accs


# Helper functions to do things more explicitly!
# def eval_robust_acc_on_subsets(rob_acc_vec, subsets):
#     # Gets robust accuracies for each subset
#     len_S = subsets[0].sum()
#     rob_accs = []
#     for S in subsets:
#         # Evaluate robust accuracy on each subset and return value
#         correct = (rob_acc_vec*S).sum()
#         accuracy = correct/len_S
#         rob_accs.append(accuracy)
    
#     return rob_accs

# def eval_natural_acc_on_subsets(nat_acc_vec, subsets):
#     # Gets natural accuracies for each subset

#     len_S = (1-subsets[0]).sum()
#     print(len_S)
#     nat_accs = []
#     for S in subsets:
#         # Evaluate natural accuracy on each subset and return value
#         correct = (nat_acc_vec*(1-S)).sum()
#         accuracy = correct/len_S
#         nat_accs.append(accuracy)
    
#     return nat_accs

def evaluate_model_on_subsets(model_dict, subsets):
    """Evaluate one model's per-example accuracy vectors on every subset.

    Only accuracies are computed; loss vectors that may be present in
    ``model_dict`` are not used and no loss statistics are returned.

    Args:
        model_dict: Mapping that provides at least the keys
            ``'robust acc vector'`` and ``'natural acc vector'``.
        subsets: Subset masks, passed unchanged to ``eval_acc_on_subsets``
            and ``get_overall_accs``.

    Returns:
        dict with three entries, each holding one value per subset:
            ``"rob accs"``     -- robust accuracy on each subset,
            ``"nat accs"``     -- natural accuracy on each subset's complement,
            ``"overall accs"`` -- combination of the two from ``get_overall_accs``.
    """
    rob_accs = eval_acc_on_subsets(model_dict['robust acc vector'], subsets, 'rob')
    nat_accs = eval_acc_on_subsets(model_dict['natural acc vector'], subsets, 'nat')
    overall_accs = get_overall_accs(rob_accs, nat_accs, subsets)

    return {
        "rob accs": rob_accs,
        "nat accs": nat_accs,
        "overall accs": overall_accs,
    }

def tex(method):
    """Return the LaTeX row label for a defense name.

    The label is chosen by substring match against ``method``; the first
    matching keyword wins.

    Args:
        method: Defense / hyperparameter name, e.g. ``"Ours_AdvGAN_1.0"``.

    Returns:
        The LaTeX macro string for the matched defense. If no keyword matches
        (for example a ``"None"`` entry), the name itself is returned with
        underscores escaped, instead of failing on an unbound local.
    """
    # Order matters: 'AdvGAN' must be tested before 'Ours' because the
    # AdvGAN variant's name contains both keywords.
    labels = (
        ('AdvGAN', '\\our\\,($\\advp = \\text{\\modelx}$)'),
        ('Ours', '\\our\\, ($\\advp = \\text{PGD}$)'),
        ('TRADES', '\\trades~\\ctrades'),
        ('Nu_AT', '\\nuat~\\cnuat'),
        ('MART', '\\mart~\\cmart'),
        ('PGD_AT', '\\pgdat~\\cpgdat'),
        ('RFGSM', '\\rfgsmat~\\crfgsmat'),
        ('FBF', '\\fbf~\\cfbf'),
        ('GAT', '\\gat~\\cgat'),
    )
    for keyword, label in labels:
        if keyword in method:
            return label
    # Fallback for names without a dedicated macro.
    return method.replace('_', '\\_')

# ---------------------------------------------------------------------------
# Output locations
# ---------------------------------------------------------------------------
output_folder = 'testset_results_random_min_overall_2'
os.makedirs(output_folder, exist_ok=True)
output_path = f'{output_folder}/{args.data}_{args.attack}.txt'
output_path_tex = f'{output_folder}/{args.data}_{args.attack}_tex.txt'

# 2. Load the pre-saved per-example accuracy vectors of every defense and
#    evaluate them on the random subsets. This runs before the output files
#    are opened, so a missing stats pickle does not truncate earlier results.
folder = 'test_stats_pkl'
random_stats_dict = {}
for defense in args.methods:
    start = time.time()
    if 'None' in defense:
        # No stats are loaded for these entries; they get a "-" row below.
        continue
    filepath = os.path.join(folder, f"{defense}_{args.data}_{args.attack}_test_stats.pkl")

    # The pickle is expected to hold the keys read by
    # evaluate_model_on_subsets ('robust acc vector', 'natural acc vector').
    with open(filepath, 'rb') as fp:
        pkl_dict = pickle.load(fp)

    random_stats_dict[defense] = evaluate_model_on_subsets(pkl_dict, random_subsets)

    end = time.time()
    DEBUG(f"Defense {defense} took {end-start} seconds!")


print(f"Accuracies for attack {args.attack}")
print("Defense: \t natural acc \t robust acc \t min overall acc \t")

# The context manager closes both report files even if a row fails to format.
with open(output_path, 'w') as f, open(output_path_tex, 'w') as ft:
    f.write(f'_________________ Dataset: {args.data}, Attack: {args.attack} _________________ \n\n')

    for defense in args.methods:
        if 'None' in defense:
            op_str= f"{defense:<16} \t\t & - &\t\t - &\t\t -"
            tex_str= f"{tex(defense)} & - & - & -"
        else:
            random_stats = random_stats_dict[defense]
            rob_accs = random_stats["rob accs"]
            nat_accs = random_stats["nat accs"]
            overall_accs = random_stats["overall accs"]

            # Report the subset with the lowest overall accuracy, together
            # with the natural and robust accuracy on that same subset.
            min_overall_idx = np.argmin(overall_accs)
            min_overall_acc = overall_accs[min_overall_idx]*100
            corres_natural_acc = nat_accs[min_overall_idx]*100
            corres_robust_acc = rob_accs[min_overall_idx]*100

            op_str= f"{defense:<16} \t\t {corres_natural_acc:.2f} \t\t {corres_robust_acc:.2f}\t\t {min_overall_acc:.2f}"

            tex_str= f"{tex(defense)} & {corres_natural_acc:.2f} & {corres_robust_acc:.2f} & {min_overall_acc:.2f}"

        print(op_str)
        f.write(op_str+"\n")
        # Rule above the first "Ours" row and below the "AdvGAN" row.
        if 'Ours' in defense and 'AdvGAN' not in defense:
            ft.write("\\hline \n")
        # "\\\\ " is the LaTeX row terminator (two backslashes and a space);
        # written with valid escapes, it emits the same bytes as before.
        ft.write(tex_str+"\\\\ \n")
        if 'AdvGAN' in defense:
            ft.write("\\hline \n")