

import os
import numpy as np
from src.common.constant import Config
from src.tools.io_tools import read_json

# Root folder of the project data: the "base_dir" environment variable when it
# is set, otherwise the current working directory.
base_dir_folder = os.environ.get("base_dir")
if base_dir_folder is None:
    base_dir_folder = os.getcwd()

# All result files read by this module live in the "img_data" sub-folder.
base_dir = os.path.join(base_dir_folder, "img_data")
print("gt_api running at {}".format(base_dir))

# Result files for the 201 search space, one per dataset
# (presumably CIFAR-10, CIFAR-100 and ImageNet, going by the file names).
train_base201_c10, train_base201_c100, train_base201_img = (
    os.path.join(base_dir, file_name)
    for file_name in (
        "train_based_201_c10.json",
        "train_based_201_c100.json",
        "train_based_201_img.json",
    )
)

# Result file for the 101 search space (only one dataset is available).
train_base101_c10 = os.path.join(
    base_dir, "train_based_101_c10_100run_24k_models.json"
)


def post_processing_train_base_result(search_space, dataset, x_max_value: int = None):
    """Reduce the stored training-based search runs to quartile curves.

    Loads the result file registered for ``(search_space, dataset)``. The
    file maps each run id to a record of the form::

        data[run_id]["arch_id_list"]
        data[run_id]["current_best_acc"]
        data[run_id]["x_axis_time"]

    Every run is cut to the length of the shortest ``current_best_acc`` list
    (and to at most 15625 entries) so the runs can be stacked into one array.
    The 25% / 50% / 75% quantiles of the accuracy are then taken across runs
    at each position. The x axis is the per-position median of
    ``x_axis_time`` across runs divided by 60 (presumably seconds converted
    to minutes -- confirm against the code that writes the result files).

    Accuracies are multiplied by 100 when the last ``current_best_acc`` value
    of the reference run (run ``"0"``, or the first run if there is no such
    id) is below 1, i.e. when they look like fractions rather than percent.

    NOTE(review): the 15625 cap is kept from the original implementation; it
    also truncates the 101 results -- confirm this is intended.

    Args:
        search_space: ``Config.NB201`` or ``Config.NB101``.
        dataset: ``Config.c10``, ``Config.c100`` or ``Config.imgNet``
            (for ``Config.NB101`` only ``Config.c10`` is registered).
        x_max_value: optional upper bound on the x axis. When given, only
            the leading points up to (and excluding) the first one whose x
            value is not ``<= x_max_value`` are returned.

    Returns:
        A tuple ``(x_list, y_low, y_median, y_high)`` of four equally long
        lists of plain Python floats.

    Raises:
        ValueError: if no result file is registered for the requested
            combination, or if the result file contains no runs.
    """
    # Registered result files, checked in the same order as before.
    result_files = (
        (Config.NB201, Config.c10, train_base201_c10),
        (Config.NB201, Config.c100, train_base201_c100),
        (Config.NB201, Config.imgNet, train_base201_img),
        (Config.NB101, Config.c10, train_base101_c10),
    )
    for known_space, known_dataset, result_file in result_files:
        if search_space == known_space and dataset == known_dataset:
            data = read_json(result_file)
            break
    else:
        # An explicit exception: a bare ``raise`` here has nothing to re-raise
        # (or, worse, re-raises whatever the caller is currently handling).
        raise ValueError(
            f"Cannot read dataset {dataset} of search space {search_space}: "
            "no training-based result file is registered for this combination"
        )

    runs = [data[run_id] for run_id in data]
    if not runs:
        raise ValueError(f"Result file {result_file} contains no runs")

    # Runs may differ in length; keep only the prefix shared by all of them.
    max_points = 15625
    n_points = min(max_points, *(len(run["current_best_acc"]) for run in runs))
    acc_got = np.array([run["current_best_acc"][:n_points] for run in runs])
    time_used = np.array([run["x_axis_time"][:n_points] for run in runs])

    # Accuracies stored as fractions are rescaled to percent. The decision is
    # made on one reference run, using its full (untruncated) history.
    reference_run = data["0"] if "0" in data else runs[0]
    if reference_run["current_best_acc"][-1] < 1:
        acc_got = acc_got * 100

    acc_low, acc_mid, acc_high = (
        np.quantile(acc_got, q, axis=0) for q in (0.25, 0.5, 0.75)
    )
    x_list = (np.quantile(time_used, 0.5, axis=0) / 60).tolist()

    n_keep = len(x_list)
    if x_max_value is not None:
        # Stop at the first point that is not within the budget.
        n_keep = next(
            (i for i, x in enumerate(x_list) if not x <= x_max_value),
            len(x_list),
        )

    return (
        x_list[:n_keep],
        acc_low[:n_keep].tolist(),
        acc_mid[:n_keep].tolist(),
        acc_high[:n_keep].tolist(),
    )


if __name__ == "__main__":
    # Manual check: plot the median accuracy curve of the stored runs together
    # with the band between the lower and upper quartile.
    search_space = Config.NB201
    dataset = Config.c100
    minutes, acc_q25, acc_median, acc_q75 = post_processing_train_base_result(
        search_space, dataset
    )

    from matplotlib import pyplot as plt

    _, axes = plt.subplots()

    # Shaded inter-quartile band with the median drawn on top of it.
    axes.fill_between(minutes, acc_q25, acc_q75, alpha=0.1)
    axes.plot(minutes, acc_median, "-*", label="Training-based")

    axes.set_xscale("symlog")
    axes.grid()
    axes.set_xlabel("Time Budget given by user (mins)")
    axes.set_ylabel("Test Accuracy")
    axes.legend()
    plt.show()


