import pickle
import numpy as np
import torch
import os
import argparse
import time

from torch.utils.data import Dataset, DataLoader, Subset
from data.pytorch_datasets import get_dataset
import torch.nn.functional as F
from copy import deepcopy

np.random.seed(0)
pkl_dir = "distribution_vectors_pkls"

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--data', default='cifar10', type=str)
# parser.add_argument('-gpu', default=0, type=int, help='gpu:id to be used')
# parser.add_argument('-eps', '--eps', type=float, default=0.031)
# parser.add_argument('-epsiter', '--epsiter', type=float, default=0.007)
# parser.add_argument('-ns', '--ns', type=int, default=20)
parser.add_argument('-m', '--methods', default="all", type=str, nargs='+',
                    help='defenses to evaluate, or "all" for the per-dataset default list')
# NOTE: action='store_false' makes args.debug True by default, so debug output
# is printed unless --debug is passed (which turns it OFF). Kept as-is so that
# existing invocations keep behaving the same way.
parser.add_argument('--debug', action='store_false',
                    help='disable debug printing (debug output is on by default)')
parser.add_argument('-a', '--attack', default='pgd', type=str)
# parser.add_argument('-c', '--constraint', type=float, default=0.0, help='minimum robust accuracy required for a hyp to be selected')
args = parser.parse_args()

# The default is the plain string "all", but an explicit "-m all" is parsed by
# nargs='+' into the list ["all"]; accept both spellings.
if args.methods in ("all", ["all"]):
    if args.data == 'cifar100':
        # PGD_AT_p / RFGSM_AT_p are left out here; their entries are also
        # commented out in hyp_dict_cifar100.
        args.methods = ["GAT", "TRADES", "Nu_AT", "MART", "Ours", "Ours_AdvGAN"]
    else:
        args.methods = ["GAT", "TRADES", "Nu_AT", "MART", "PGD_AT_p", "RFGSM_AT_p", "Ours", "Ours_AdvGAN"]

# Hyperparameter values swept per defense for non-cifar100 datasets.
# Values are kept as strings because they are spliced verbatim into the
# val-stats pickle filenames ("{defense}_{hyp}_{data}_{attack}_val_stats.pkl");
# they are converted with float() only when selecting the best hyp.
# TRADES, Ours and Ours_AdvGAN have dataset-specific sweeps keyed
# "<defense>_<dataset>" (only cifar10 and fmnist are defined here).
hyp_dict = {
    "GAT": ["2.5", "5.0", "10.0", "15.0", "20.0", "30.0"],
    "TRADES_cifar10": ["0.1", "0.4", "0.6", "0.8", "2.0", "4.0", "6.0"],
    "TRADES_fmnist": ["0.4", "0.6", "0.8", "1.0", "2.0", "4.0", "6.0"],
    "Nu_AT": ["2.0", "2.5", "3.0", "3.5", "4.0", "4.5", "5.0"],
    "MART": ["0.5", "1.0", "2.5", "5.0", "7.5", "10.0"],
    "PGD_AT_p": ["0.05", "0.1", "0.15", "0.2", "0.25"],
    "RFGSM_AT_p": ["0.05", "0.1", "0.15", "0.2", "0.25"],
    "Ours_cifar10": ["0.01", "0.05", "0.5", "1.0", "2.0", "5.0", "8.0", "10.0"],
    "Ours_fmnist": ["0.01", "0.1", "0.5", "1.0", "1.5", "2.0", "4.0", "8.0"],
    "Ours_AdvGAN_cifar10": ["0.01", "0.25", "0.5", "1.0", "1.5", "2.0"],
    "Ours_AdvGAN_fmnist": ["0.01", "0.1", "0.5", "1.0", "1.5", "2.0"]
}

# Hyperparameter sweeps used when --data is cifar100 (one entry per defense,
# no dataset suffix). PGD_AT_p and RFGSM_AT_p are commented out and are
# likewise absent from the cifar100 "all" method list.
hyp_dict_cifar100 = {
    "GAT": ["2.0", "5.0", "10.0", "20.0", "30.0"],
    "TRADES": ["1.0", "2.0", "4.0", "6.0", "8.0"],
    "Nu_AT": ["1.0", "2.0", "4.5", "6.0", "8.0"],
    "MART": ["0.5", "1.0", "2.5", "5.0", "10.0"],
    # "PGD_AT_p": ["0.25"],
    # "RFGSM_AT_p": ["1.0"],
    "Ours": ["0.5", "1.0", "2.0", "5.0", "10.0"],
    "Ours_AdvGAN": ["0.005", "0.01", "0.25", "0.5", "1.0"],
}


# Note: There will be NO model loading and evaluation in this file! Just dot products!

def DEBUG(log):
    """Print ``log`` only when the script-level ``args.debug`` flag is set."""
    if not args.debug:
        return
    print(log)

# 1. Load the distribution vector.
# 2. Load the pre-saved, pickled robust-accuracy and natural-accuracy vectors.
# Idea: store these in a dict (from the model-eval file) and reload them here.

# 1. Load the pre-sampled random subsets (random_subsets_10000.pkl).
# The subset selection strategy is based on the clean loss of the vanilla
# classifier, so the attack method (e.g. pgd) plays no role in this step.
DEBUG("Loading distribution vector and sampled subsets")
start = time.time()
load_path = "attacked_subsets/random_subsets_10000.pkl"
with open(load_path, 'rb') as subsets_file:
    random_subsets = pickle.load(subsets_file)
end = time.time()
DEBUG(f"Distribution vector and subsets loaded! Time taken: {end-start}")

# Keys from the distribution vector:
# ['robust loss vector', 'robust acc vector', 'natural acc vector', 'natural loss vector', 'distribution rob loss',
# 'distribution softmax rob loss', 'distribution softmax nat loss', 
# 'softmax rob loss subsets', 'softmax nat loss subsets', 'rob loss subsets']

# Utility functions:

def eval_acc_on_subsets(acc_vec, subsets, acc_type):
    """Return the accuracy of ``acc_vec`` restricted to each subset mask.

    Args:
        acc_vec: per-sample correctness vector; it is multiplied element-wise
            by the mask and summed (presumably 0/1 per sample -- confirm
            against the val-stats pickles).
        subsets: sequence of masks, each the same length as ``acc_vec``. All
            masks are assumed to select the same number of samples, because
            the denominator is computed once from ``subsets[0]``.
        acc_type: ``"rob"`` averages over the samples inside each mask ``S``;
            ``"nat"`` averages over the samples outside it (``1 - S``).

    Returns:
        np.ndarray with one accuracy per subset.

    Raises:
        ValueError: if ``acc_type`` is neither ``"rob"`` nor ``"nat"``.
    """
    if acc_type not in ("rob", "nat"):
        raise ValueError(f"acc_type must be 'rob' or 'nat', got {acc_type!r}")

    def select(S):
        # Robust accuracy is measured on the mask, natural accuracy on its complement.
        return S if acc_type == "rob" else 1 - S

    # Shared denominator: every subset is assumed to have the size of subsets[0].
    len_S = select(subsets[0]).sum()

    # No defensive copy is needed: the product allocates a new array and never
    # mutates S.
    return np.array([(acc_vec * select(S)).sum() / len_S for S in subsets])

def get_overall_accs(rob_accs, nat_accs, subsets):
    """Blend robust and natural accuracies into one overall accuracy per subset.

    The robust part is weighted by the fraction of samples selected by the
    mask ``subsets[0]`` and the natural part by the remaining fraction; all
    masks are assumed to have the same size as the first one.
    """
    reference_mask = subsets[0]
    rob_weight = reference_mask.sum() / len(reference_mask)
    return rob_weight * rob_accs + (1 - rob_weight) * nat_accs


# Helper functions to do things more explicitly!
# def eval_robust_acc_on_subsets(rob_acc_vec, subsets):
#     # Gets robust accuracies for each subset
#     len_S = subsets[0].sum()
#     rob_accs = []
#     for S in subsets:
#         # Evaluate robust accuracy on each subset and return value
#         correct = (rob_acc_vec*S).sum()
#         accuracy = correct/len_S
#         rob_accs.append(accuracy)
    
#     return rob_accs

# def eval_natural_acc_on_subsets(nat_acc_vec, subsets):
#     # Gets natrual accuracies for each subset

#     len_S = (1-subsets[0]).sum()
#     print(len_S)
#     nat_accs = []
#     for S in subsets:
#         # Evaluate robust accuracy on each subset and return value
#         correct = (nat_acc_vec*(1-S)).sum()
#         accuracy = correct/len_S
#         nat_accs.append(accuracy)
    
#     return nat_accs

def evaluate_model_on_subsets(model_dict, subsets):
    """Evaluate one model's saved per-sample accuracy vectors on every subset.

    Args:
        model_dict: dict loaded from a val-stats pickle; must contain the keys
            ``'robust acc vector'`` and ``'natural acc vector'``.
        subsets: sequence of subset masks, forwarded to ``eval_acc_on_subsets``.

    Returns:
        dict with three per-subset arrays:
            ``"rob accs"``     -- accuracy on the samples inside each mask,
            ``"nat accs"``     -- accuracy on the samples outside each mask,
            ``"overall accs"`` -- the two combined by ``get_overall_accs``.
        Loss-based statistics are not computed.
    """
    rob_acc_vec = model_dict['robust acc vector']
    nat_acc_vec = model_dict['natural acc vector']

    rob_accs = eval_acc_on_subsets(rob_acc_vec, subsets, 'rob')
    nat_accs = eval_acc_on_subsets(nat_acc_vec, subsets, 'nat')
    overall_accs = get_overall_accs(rob_accs, nat_accs, subsets)

    return {
        "rob accs": rob_accs,
        "nat accs": nat_accs,
        "overall accs": overall_accs,
    }

output_folder = 'baselines_choose_hyp_random_top3'
os.makedirs(output_folder, exist_ok=True)
output_path = f'{output_folder}/{args.data}_{args.attack}_top3.txt'


def _hyp_list_for(defense, data):
    """Return the hyperparameter strings swept for ``defense`` on ``data``.

    cifar100 uses ``hyp_dict_cifar100``; otherwise TRADES / Ours / Ours_AdvGAN
    use the dataset-suffixed keys of ``hyp_dict`` and all others the plain key.
    """
    if data == 'cifar100':
        return hyp_dict_cifar100[defense]
    # 'Ours' also matches 'Ours_AdvGAN'.
    if 'Ours' in defense or 'TRADES' in defense:
        return hyp_dict[f'{defense}_{data}']
    return hyp_dict[defense]


def _mean_and_stderr(values):
    """Return (mean, standard error of the mean) of ``values``, both in percent.

    Uses NumPy's default population std (ddof=0).
    """
    values = np.asarray(values)
    mean = values.mean() * 100.0
    stderr = values.std() * 100.0 / np.sqrt(len(values))
    return mean, stderr


def _load_random_stats(methods, data, attack, subsets):
    """Evaluate every (defense, hyp) val-stats pickle on ``subsets``.

    Returns ``{defense: {hyp: stats}}`` where ``stats`` is the dict produced by
    ``evaluate_model_on_subsets``.
    """
    stats = {}
    for defense in methods:
        defense_dict = {}
        for hyp in _hyp_list_for(defense, data):
            start = time.time()
            filepath = os.path.join(
                'val_stats_pkl',
                f"{defense}_{hyp}_{data}_{attack}_val_stats.pkl",
            )
            with open(filepath, 'rb') as fp:
                pkl_dict = pickle.load(fp)
            defense_dict[hyp] = evaluate_model_on_subsets(pkl_dict, subsets)
            end = time.time()
            DEBUG(f"Defense {defense}_{hyp} took {end-start} seconds!")
        stats[defense] = defense_dict
    return stats


def _select_best_hyp(defense, defense_stats, data, out_file, top_k=3):
    """Report every hyp of ``defense`` and return the selected one.

    Each hyp's mean +- stderr natural/robust/overall accuracy is printed and
    written to ``out_file``. The selected hyp is the entry of the sweep closest
    to the average value of the ``top_k`` hyps with the highest mean overall
    accuracy (fewer if the sweep is shorter).

    Raises:
        ValueError: if no hyperparameters are configured for ``defense``.
    """
    hyp_list = _hyp_list_for(defense, data)
    if not hyp_list:
        raise ValueError(f"No hyperparameters configured for {defense}")

    mean_overall = []
    for hyp in hyp_list:
        stats = defense_stats[hyp]
        nat_mean, nat_err = _mean_and_stderr(stats["nat accs"])
        rob_mean, rob_err = _mean_and_stderr(stats["rob accs"])
        ove_mean, ove_err = _mean_and_stderr(stats["overall accs"])
        mean_overall.append(ove_mean)

        # "\\pm" emits a literal \pm (LaTeX) without an invalid-escape warning.
        op_str = (
            f"{defense + hyp:<16} \t\t & {nat_mean:.3f}\\pm{nat_err:.4f} &\t\t "
            f"{rob_mean:.3f}\\pm{rob_err:.4f} &\t\t {ove_mean:.3f}\\pm{ove_err:.4f}"
        )
        print(op_str)
        out_file.write(op_str + "\n")

    # sorted() is stable even with reverse=True, so among equal accuracies the
    # hyp listed first is preferred.
    ranked = sorted(range(len(hyp_list)), key=lambda i: mean_overall[i], reverse=True)
    top_hyps = [float(hyp_list[i]) for i in ranked[:top_k]]
    avg_hyp = sum(top_hyps) / len(top_hyps)

    hyp_values = np.array([float(h) for h in hyp_list])
    return hyp_list[int(np.argmin(np.abs(hyp_values - avg_hyp)))]


best_hyps = []
with open(output_path, 'w') as f:
    f.write(f'_________________ Dataset: {args.data}, Attack: {args.attack} _________________ \n\n')

    # 2. Load the pre-saved pickled robust/natural acc vectors and evaluate
    # every (defense, hyp) pair on the random subsets.
    random_stats_dict = _load_random_stats(args.methods, args.data, args.attack, random_subsets)

    print(f"Accuracies for attack {args.attack}")
    print("Defense: \t & natural accuracy & \t robust accuracy & \t overall accuracy \t")
    for defense in args.methods:
        best_hyp = _select_best_hyp(defense, random_stats_dict[defense], args.data, f)
        f.write(f'\n {defense}: best hyp = {best_hyp} \n\n\n')
        best_hyps.append(defense + '_' + best_hyp)

with open(f'{output_folder}/best_hyps_{args.data}_{args.attack}_random_top3.pkl', 'wb') as fp:
    pickle.dump(best_hyps, fp)
