import numpy as np
from src.evaluation.evaluation_pipeline.evaluate_realizations import *
from src.evaluation.aux.load_results import *

import matplotlib.pyplot as plt
import os
import argparse
import matplotlib as mpl
import seaborn as sns
import pandas as pd
import config as config

my_pal = config.COLOR
my_marker = config.MARKER

# This script plots Figure 2 of the main paper.

parser = argparse.ArgumentParser(description='task9 for plotting figure 2 in the main paper')
parser.add_argument('-t', default="time", type=str, help='please input time of your .txt file generated by .sh bash file. For example, _Date-2022-05-18_Time-16-45')

args = parser.parse_args()
# Run timestamp; it is spliced into the results folder name by the driver at
# the bottom of this file (e.g. "_Date-2022-05-18_Time-16-45").
time = args.t

# "time" is the sentinel default, i.e. -t was not supplied. parser.error prints
# the usage line plus the message to stderr and exits with status 2; the
# previous bare exit() reported success (status 0) and depended on the `site`
# module providing `exit`.
if time == "time":
    parser.error("please input time of your .txt file generated by .sh bash file. For example, _Date-2022-05-18_Time-16-45")


# Display names of each method, as defined in config.py.
n_RS = config.n_RS
n_Oracle = config.n_CAMS_best_policy
n_QBC = config.n_qbc
n_IWAL = config.n_iwal
n_MP = config.n_mp
n_CQBC = config.n_contextual_qbc
n_CIWAL = config.n_contextual_iwal
n_CAMS = config.n_CAMS_identity
n_test = config.n_CAMS_test

def rename_method_list(methods):
    """Map internal method identifiers to their display names from config.py.

    Parameters
    ----------
    methods : iterable of str
        Internal identifiers as stored in the result files, e.g. "rs",
        "qbc", "contextual_iwal", "CAMS_identity".

    Returns
    -------
    list
        The display names, in the same order as ``methods``.

    Raises
    ------
    ValueError
        If an identifier has no known display name. Previously the script
        printed "error" and called ``exit()``, which terminated with a
        success status.
    """
    display_names = {
        "rs": n_RS,
        "qbc": n_QBC,
        "iwal": n_IWAL,
        "mp": n_MP,
        "contextual_qbc": n_CQBC,
        "contextual_iwal": n_CIWAL,
        "CAMS_best_policy": n_Oracle,
        "CAMS_identity": n_CAMS,
        "CAMS_test": n_test,
    }

    renamed = []
    for item in methods:
        try:
            renamed.append(display_names[item])
        except KeyError:
            raise ValueError("unknown method identifier '%s'" % (item,)) from None
    return renamed


def find_nearest(array, value):
    """Return the element of ``array`` closest to ``value``.

    The input is converted with ``np.asarray``, so lists and arrays both work.
    On ties the earliest matching element wins (``argmin`` semantics).
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    return candidates[np.argmin(distances)]



def _floor_to_budget(value, budgets):
    """Return the last entry of ``budgets`` (in order) with ``value >= entry``, else 0.

    For ascending ``budgets`` this is the largest configured budget that does
    not exceed ``value``.
    """
    return next((b for b in reversed(list(budgets)) if value >= b), 0)


def _plot_cost_effective(shade_df, dataset_name, my_pal, output_dir):
    """Plot regret against query cost for every method, then save it and a standalone legend.

    ``shade_df`` must have the columns "budget_fixed", "c_regret" and "method"
    (display names). Writes ``<dataset_name>_shade_line_plot.{png,pdf}`` and
    ``legend.{png,pdf}`` into ``output_dir`` and closes both figures afterwards.
    """
    # (display name, line width): the two CAMS variants are drawn thicker.
    series = [
        (n_RS, 1), (n_Oracle, 1), (n_QBC, 1), (n_IWAL, 1), (n_MP, 1),
        (n_CQBC, 1), (n_CIWAL, 1), (n_CAMS, 4), (n_test, 4),
    ]

    plot_fig = plt.figure(figsize=(10, 10), dpi=300)
    ax = plot_fig.gca()
    for name, width in series:
        # NOTE(review): seaborn >= 0.12 deprecates `ci` in favour of
        # errorbar=("ci", 63); kept unchanged here.
        sns.lineplot(
            x="budget_fixed", y="c_regret", label=name,
            data=shade_df[shade_df["method"] == name],
            color=my_pal[name], ci=63, linewidth=width, ax=ax,
        )

    # Figure 2 (bottom).
    ax.tick_params(axis="both", labelsize=20)
    ax.set_xlabel("Query cost", fontsize=30)
    ax.set_ylabel("", fontsize=30)
    # The legend is saved as its own figure below, so drop any legend attached
    # to the axes. The old `plt.legend('')` relied on a bare string being read
    # as an empty label list, which recent matplotlib versions reject.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    plot_stem = os.path.join(output_dir, dataset_name + "_shade_line_plot")
    plot_fig.savefig(plot_stem + ".png", bbox_inches='tight', pad_inches=0.01)
    plot_fig.savefig(plot_stem + ".pdf", bbox_inches='tight', pad_inches=0.01)

    # Standalone legend built from the labelled lines on the plot's axes.
    handles, labels = ax.get_legend_handles_labels()
    legend_fig = plt.figure(figsize=(10, 10), dpi=300)
    legend_fig.legend(handles, labels, ncol=8, loc='center')
    legend_fig.savefig(os.path.join(output_dir, 'legend.png'), bbox_inches='tight', pad_inches=0)
    legend_fig.savefig(os.path.join(output_dir, 'legend.pdf'), bbox_inches='tight', pad_inches=0)

    # Release both figures so repeated calls do not accumulate open figures.
    plt.close(legend_fig)
    plt.close(plot_fig)


def organize_plot(dataset_name, budget, folder_name, my_pal=my_pal):
    """Build the cost-effectiveness table and the shaded regret-vs-query-cost plot for one dataset.

    Reads ``data.npz`` and ``eval_results.npz`` from
    ``<cwd>/resources/results/<folder_name>/``. Outputs are written into
    ``./task9/``, which is created if missing:

    - ``task9_<dataset_name>_cost_effective.csv``: mean cumulative loss
      ("c_regret"), rounded to integers, per (budget level, method);
    - ``<dataset_name>_shade_line_plot.png`` / ``.pdf``;
    - ``legend.png`` / ``legend.pdf``.

    Parameters
    ----------
    dataset_name : str
        Used in log lines and output file names.
    budget : int
        Kept for backward compatibility. It used to select
        ``experiment_results_budget<budget>.npz``. Nothing from that file was
        used, so the file is no longer read.
    folder_name : str
        Results sub-folder under ``resources/results``.
    my_pal : mapping, optional
        Display name -> colour. Defaults to ``config.COLOR``.
    """
    output_dir = "./task9"
    # to_csv/savefig fail if the output directory does not exist.
    os.makedirs(output_dir, exist_ok=True)

    path = os.getcwd() + "/resources/results/" + folder_name + "/"
    print(dataset_name, ":", os.listdir(path))

    # `with` closes the .npz archives. Indexing an NpzFile materialises the
    # array, so the values remain usable after the block.
    with np.load(path + "data.npz") as data:
        print(dataset_name, ":", data["num_reals"])
        methods = data["methods"]
        budget_raw = data["budgets"]

    with np.load(path + "eval_results.npz") as eval_results:
        box_budget = eval_results["box_budget"]
        box_budget_actual = eval_results["box_budget_actual"]
        box_cumulative_loss = eval_results["box_cumulative_loss"]
        raw_box_method = eval_results["box_method"]
        raw_max_method = eval_results["max_method"]
        max_budget_actual = eval_results["max_budget_actual"]

    max_method = rename_method_list(raw_max_method)
    print(max_budget_actual)

    # For each method: the last configured budget its maximum actual query
    # count reaches.
    max_bar_query = [_floor_to_budget(item, budget_raw) for item in max_budget_actual]
    print(dataset_name, ":", budget_raw)

    box_method = rename_method_list(raw_box_method)
    print(dataset_name, ":", box_method)

    box_df = pd.DataFrame({
        "budget": box_budget_actual,
        "budget_fixed": box_budget,
        "c_regret": box_cumulative_loss,
        "method": box_method,
    })

    # Snap each realization's actual query count to the nearest of the
    # configured budgets plus the method's maximum actual count. If that
    # nearest value equals the method's max_bar_query entry, the method's
    # maximum actual count is used instead.
    snapped = []
    for actual, nominal, regret, method in zip(
        box_df["budget"], box_df["budget_fixed"], box_df["c_regret"], box_df["method"]
    ):
        print(actual, nominal, regret, method)
        k = max_method.index(method)
        candidates = np.concatenate((budget_raw, [max_budget_actual[k]]))
        nearest = find_nearest(candidates, actual)
        snapped.append(max_budget_actual[k] if nearest == max_bar_query[k] else nearest)
    box_df["budget_fixed"] = snapped

    # Move each (method, budget) group to the mean actual query count of its rows.
    # NOTE(review): the "method" column holds display names (renamed above).
    # If data.npz stores raw identifiers such as "rs", this loop matches no
    # rows unless a config display name equals its raw identifier. Confirm
    # whether `methods` should be passed through rename_method_list as well.
    for item in methods:
        for budget_ in budget_raw:
            print(item)
            mask = (box_df["method"] == item) & (box_df["budget_fixed"] == budget_)
            if mask.any():
                box_df.loc[mask, "budget_fixed"] = box_df.loc[mask, "budget"].mean()

    shade_df = (
        box_df.filter(["budget_fixed", "method", "c_regret"], axis=1)
        .drop_duplicates()
        .reset_index(drop=True)
    )
    cost_effective_table = (
        shade_df.groupby(['budget_fixed', 'method'])["c_regret"]
        .mean()
        .reset_index()
        .round(0)
    )
    cost_effective_table.to_csv(
        os.path.join(output_dir, "task9_" + dataset_name + "_cost_effective.csv")
    )

    _plot_cost_effective(shade_df, dataset_name, my_pal, output_dir)


# ############
# Generate the cost-effectiveness plot.
# Exactly one dataset is plotted per run: set `dataset_name` below to choose
# it. The entries must stay aligned with config.py and
# task9_apply_scaling_parameter.sh. Each entry is
# (budget, folder-name prefix, folder-name suffix); the timestamp passed with
# -t is inserted between prefix and suffix. The comment above each entry
# records the run timestamp(s) used for that dataset.
_DATASET_RUNS = {
    # _Date-2022-08-08_Time-00-58
    "DRIFT": (400, "drift_contextual_streamsize3000_numreals50", "_which_methods11011011011_policy[1]"),
    # _Date-2022-08-07_Time-22-32
    "HIV": (600, "HIV_contextual_streamsize4000_numreals50", "_which_methods11011011011_policy[0]"),
    # _Date-2022-08-07_Time-20-13 / _Date-2022-08-08_Time-04-38
    "VERTEBRAL": (60, "VERTEBRAL_contextual_streamsize80_numreals50", "_which_methods11011011011_policy[0]"),
    # _Date-2022-08-08_Time-03-40
    "CIFAR10": (200, "cifar_contextual_streamsize5000_numreals10", "_which_methods11011011011_policy[11]"),
}

dataset_name = "DRIFT"
budget, _folder_prefix, _folder_suffix = _DATASET_RUNS[dataset_name]
folder_name = _folder_prefix + time + _folder_suffix
organize_plot(dataset_name, budget, folder_name)

