'''
Hyper-parameter (hyp) selection for the baseline defenses.
For each hyp value, take the minimum overall accuracy across the n random subsets.
Then average the three hyp values with the highest minimum accuracy and choose
the grid value closest to that average.
'''


import pickle
import numpy as np
import torch
import os
import argparse
import time

from torch.utils.data import Dataset, DataLoader, Subset
from data.pytorch_datasets import get_dataset
import torch.nn.functional as F
from copy import deepcopy

np.random.seed(0)
pkl_dir = "distribution_vectors_pkls"

# Command-line interface.
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', default='cifar10', type=str)
# One or more defense names, or the single word "all".
parser.add_argument('-m', '--methods', default="all", type=str, nargs='+')
# NOTE(review): 'store_false' means args.debug is True by default and passing
# --debug turns the DEBUG() logging OFF. Left as is to keep the current
# command-line behaviour; confirm whether 'store_true' was intended.
parser.add_argument('--debug', action='store_false')
parser.add_argument('-a','--attack', default='pgd', type=str)
args = parser.parse_args()

# With nargs='+', values given on the command line are wrapped in a list, so
# `-m all` arrives as ["all"], while the untouched default is the bare string
# "all". Both spellings must expand to the full list of defenses.
if args.methods in ("all", ["all"]):
    if args.data == 'cifar100':
        args.methods = ["GAT", "TRADES", "Nu_AT", "MART", "Ours", "Ours_AdvGAN"]
    else:
        args.methods = ["GAT", "TRADES", "Nu_AT", "MART", "PGD_AT_p", "RFGSM_AT_p", "Ours", "Ours_AdvGAN"]

# Hyper-parameter grids, one per defense. Values are kept as strings because
# they are spliced verbatim into the stats file names and only converted with
# float() when the grid value closest to an average is searched.
# Keys carrying a "_cifar10" / "_fmnist" suffix hold dataset-specific grids;
# keys without a suffix are shared by those datasets.
hyp_dict = {
    "GAT": "2.5 5.0 10.0 15.0 20.0 30.0".split(),
    "TRADES_cifar10": "0.1 0.4 0.6 0.8 2.0 4.0 6.0".split(),
    "TRADES_fmnist": "0.4 0.6 0.8 1.0 2.0 4.0 6.0".split(),
    "Nu_AT_cifar10": "2.0 2.5 3.0 3.5 4.0 4.5 5.0".split(),
    "Nu_AT_fmnist": "1.0 2.0 2.5 3.0 3.5 4.0 4.5 5.0".split(),
    "MART": "0.5 1.0 2.5 5.0 7.5 10.0".split(),
    "PGD_AT_p": "0.05 0.1 0.15 0.2 0.25".split(),
    "RFGSM_AT_p": "0.05 0.1 0.15 0.2 0.25".split(),
    "Ours_cifar10": "0.01 0.05 0.5 1.0 2.0 5.0 8.0 10.0".split(),
    "Ours_fmnist": "0.01 0.1 0.5 1.0 1.5 2.0 4.0 8.0".split(),
    "Ours_AdvGAN_cifar10": "0.01 0.25 0.5 1.0 1.5 2.0".split(),
    "Ours_AdvGAN_fmnist": "0.01 0.1 0.5 1.0 1.5 2.0".split(),
}

# Grids used when the dataset is cifar100; keyed by the plain defense name.
hyp_dict_cifar100 = {
    "GAT": "2.0 5.0 10.0 20.0 30.0".split(),
    "TRADES": "1.0 2.0 4.0 6.0 8.0".split(),
    "Nu_AT": "1.0 2.0 4.5 6.0 8.0".split(),
    "MART": "0.5 1.0 2.5 5.0 10.0".split(),
    "Ours": "0.5 1.0 2.0 5.0 10.0".split(),
    "Ours_AdvGAN": "0.005 0.01 0.25 0.5 1.0".split(),
}


# Note: There will be NO model loading and evaluation in this file! Just dot products!

def DEBUG(log):
    """Print ``log`` to stdout when the parsed ``args.debug`` option is truthy."""
    if not args.debug:
        return
    print(log)

# Step 1: read the pre-sampled random subsets from disk.
# Step 2 (further down): read the pickled robust / natural accuracy vectors
# that a separate model-evaluation run stored for every model.
#
# Per the original author's note, the subset selection strategies are based on
# the clean loss of the vanilla classifier, so the attack method (e.g. pgd)
# plays no role here and the path does not depend on args.attack.
# NOTE(review): the log messages mention a distribution vector, but the only
# thing read in this section is the subsets pickle.
DEBUG("Loading distribution vector and sampled subsets")
start = time.time()
load_path = "attacked_subsets/random_subsets_10000.pkl"
with open(load_path, 'rb') as subsets_file:
    random_subsets = pickle.load(subsets_file)
end = time.time()
DEBUG(f"Distribution vector and subsets loaded! Time taken: {end-start}")


# Utility functions:

def eval_acc_on_subsets(acc_vec, subsets, acc_type):
    """Score a per-example vector on every subset mask.

    Args:
        acc_vec: per-example vector that is multiplied element-wise with each
            mask and summed (presumably 0/1 correctness indicators; confirm
            against the code that writes the stats pickles).
        subsets: non-empty sequence of masks supporting ``.sum()`` and
            ``1 - mask`` (presumably 0/1 membership flags per test example;
            TODO confirm).
        acc_type: ``"rob"`` scores the examples selected by each mask,
            ``"nat"`` scores the complement ``1 - mask``.

    Returns:
        ``np.ndarray`` with one value per subset: the masked sum divided by
        the number of selected examples.

    Raises:
        ValueError: if ``acc_type`` is neither ``"rob"`` nor ``"nat"``.
    """
    if acc_type not in ("rob", "nat"):
        # Previously an unknown type fell through and crashed later with an
        # unrelated "local variable referenced before assignment" error.
        raise ValueError(f"acc_type must be 'rob' or 'nat', got {acc_type!r}")
    use_complement = acc_type == "nat"

    # The denominator is taken from the first subset only, i.e. every subset
    # is assumed to select the same number of examples.
    first = subsets[0]
    len_S = (1 - first).sum() if use_complement else first.sum()

    accs = []
    for S in subsets:
        # No copy needed: the mask is only read, never modified.
        multiplier = 1 - S if use_complement else S
        correct = (acc_vec * multiplier).sum()
        accs.append(correct / len_S)

    return np.array(accs)

def get_overall_accs(rob_accs, nat_accs, subsets):
    """Blend robust and natural accuracies into one overall accuracy.

    The robust weight is the fraction of entries selected by the first subset
    (``subsets[0].sum() / len(subsets[0])``); the natural weight is the
    remainder. The same pair of weights is applied to every entry.

    Args:
        rob_accs: robust accuracies, one per subset.
        nat_accs: natural accuracies, one per subset.
        subsets: non-empty sequence of masks; only the first one is inspected.

    Returns:
        ``rob_weight * rob_accs + nat_weight * nat_accs``.
    """
    reference = subsets[0]
    rob_weight = reference.sum() / len(reference)
    nat_weight = 1 - rob_weight
    return rob_weight * rob_accs + nat_weight * nat_accs


def evaluate_model_on_subsets(model_dict, subsets):
    """Evaluate one model's stored per-example vectors on every subset.

    Only accuracies are computed. The loss vectors that may also be stored in
    ``model_dict`` are not read, and no loss statistics are returned.

    Args:
        model_dict: mapping that provides at least the keys
            ``'robust acc vector'`` and ``'natural acc vector'``.
        subsets: sequence of subset masks, passed on to the helper functions.

    Returns:
        dict with three per-subset vectors:
            ``"rob accs"``: robust accuracy on each subset,
            ``"nat accs"``: natural accuracy on the complement of each subset,
            ``"overall accs"``: size-weighted blend of the two.
    """
    rob_accs = eval_acc_on_subsets(model_dict['robust acc vector'], subsets, 'rob')
    nat_accs = eval_acc_on_subsets(model_dict['natural acc vector'], subsets, 'nat')

    return {
        "rob accs": rob_accs,
        "nat accs": nat_accs,
        "overall accs": get_overall_accs(rob_accs, nat_accs, subsets),
    }

# Text report for this (dataset, attack) pair, kept in a fixed output folder.
output_folder = 'baselines_choose_hyp_worst_case'
os.makedirs(output_folder, exist_ok=True)
output_path = f'{output_folder}/{args.data}_{args.attack}_top3.txt'

# `f` intentionally stays open past this point: the selection code further
# down writes one row per hyper-parameter to it and closes it at the end.
f = open(output_path, 'w')
f.write(f'_________________ Dataset: {args.data}, Attack: {args.attack} _________________ \n\n')

# 2. Load the pre-saved pickled accuracy vectors of every
#    (defense, hyper-parameter) model and score them on the random subsets.
#    Result layout: random_stats_dict[defense][hyp] -> dict of per-subset accs.
random_stats_dict = {}
for defense in args.methods:
    # cifar100 has its own grids keyed by the plain defense name. For the
    # other datasets, the defenses matched below use a "<defense>_<data>" key.
    if args.data == 'cifar100':
        hyp_list = hyp_dict_cifar100[defense]
    elif 'Ours' in defense or 'TRADES' in defense or 'Nu_AT' in defense:
        hyp_list = hyp_dict[f'{defense}_{args.data}']
    else:
        hyp_list = hyp_dict[defense]

    defense_dict = {}
    for hyp in hyp_list:
        start = time.time()
        filepath = os.path.join('val_stats_pkl', f"{defense}_{hyp}_{args.data}_{args.attack}_val_stats.pkl")

        # NOTE: unpickling can execute arbitrary code; only point this script
        # at stats files produced by a trusted evaluation run.
        with open(filepath, 'rb') as fp:
            pkl_dict = pickle.load(fp)

        random_stats = evaluate_model_on_subsets(pkl_dict, random_subsets)

        end = time.time()
        DEBUG(f"Defense {defense}_{hyp} took {end-start} seconds!")

        defense_dict[hyp] = random_stats

    random_stats_dict[defense] = defense_dict

# Hyper-parameter selection.
# For every defense: rank each hyper-parameter by its worst-case (minimum over
# the random subsets) overall accuracy, keep the TOP_K best ones, average their
# values and report the grid value closest to that average.
TOP_K = 3

best_hyps = []
print(f"Accuracies for attack {args.attack}")
print("Defense: \t & natural accuracy & \t robust accuracy& \t & overall accuracy \t")
try:
    for defense in args.methods:
        # The per-defense stats dict was filled in grid order, so its keys
        # are exactly the hyper-parameter grid of this defense.
        hyp_list = list(random_stats_dict[defense])

        # Up to TOP_K (worst-case overall accuracy, hyp value) pairs, best first.
        top_entries = []
        for hyp in hyp_list:
            random_stats = random_stats_dict[defense][hyp]
            overall_accs = random_stats["overall accs"]

            # Worst subset for this hyp, with the matching natural and robust
            # accuracies on that same subset (all in percent).
            min_overall_idx = np.argmin(overall_accs)
            min_overall_acc = overall_accs[min_overall_idx]*100
            corres_natural_acc = random_stats["nat accs"][min_overall_idx]*100
            corres_robust_acc = random_stats["rob accs"][min_overall_idx]*100

            # Always admit while fewer than TOP_K hyps were seen, so that no
            # dummy value can leak into the average. After that, a hyp must
            # strictly beat the current last entry.
            if len(top_entries) < TOP_K or min_overall_acc > top_entries[-1][0]:
                # On equal scores the newer entry goes in front of the older one.
                insert_at = len(top_entries)
                for pos, (kept_acc, _) in enumerate(top_entries):
                    if kept_acc <= min_overall_acc:
                        insert_at = pos
                        break
                top_entries.insert(insert_at, (min_overall_acc, float(hyp)))
                del top_entries[TOP_K:]

            op_str= f"{defense+'_'+hyp:<16} \t\t {corres_natural_acc:.3f} \t\t {corres_robust_acc:.3f} \t\t {min_overall_acc:.3f}"

            print(op_str)
            f.write(op_str+"\n")

        # Average only over the hyps that were actually selected.
        avg_hyp = np.array([value for _, value in top_entries]).sum()/len(top_entries)
        hyp_list_float = np.array([float(h) for h in hyp_list])
        closest_hyp_ind = np.argmin(np.abs(hyp_list_float - avg_hyp))
        best_hyp = hyp_list[closest_hyp_ind]

        f.write(f'\n {defense}: best hyp = {best_hyp} \n\n\n')
        best_hyps.append(defense+'_'+best_hyp)
finally:
    # Close (and flush) the text report even if the selection fails midway.
    f.close()

with open(f'{output_folder}/best_hyps_{args.data}_{args.attack}.pkl', 'wb') as fp:
    pickle.dump(best_hyps, fp)
