import pickle
import pandas as pd
import argparse
import os
from utils import *

# import matplotlib.pyplot as plt


def _sequence_iterations(trials, hb, brackets):
    """Concatenate trial epochs in Hyperband/BOHB bracket/rung order.

    Trials are consumed bracket by bracket: bracket ``k`` takes the next
    ``hb[k]["n"][0]`` trials. For each rung ``r`` of the bracket, every trial
    that logged at least ``r`` rows contributes the rows between the previous
    rung and ``r``.

    Args:
        trials: list of per-trial DataFrames, in the order they are grouped.
        hb: schedule mapping bracket id (str) to ``{"n": [...], "r": [...]}``.
        brackets: bracket ids to visit, in order.

    Returns:
        ``(frame, interval)``: ``frame`` holds the re-ordered rows with a fresh
        RangeIndex. It is an empty frame with the first trial's columns when no
        rows qualify. ``interval`` is the per-rung epoch increment: ``r[0]``
        for the first rung, ``r[i] - r[i-1]`` afterwards.
    """
    pieces = []
    interval = []
    offset = 0
    for bracket in brackets:
        n0 = hb[bracket]["n"][0]
        rungs = hb[bracket]["r"]
        group = trials[offset : offset + n0]
        prev = 0
        for r in rungs:
            interval.append(r - prev)
            for config in group:
                # Only trials that actually reached this rung contribute rows.
                if len(config) >= r:
                    pieces.append(config.iloc[prev:r])
            prev = r
        offset += n0

    # A single concat replaces repeated DataFrame.append, which is quadratic
    # and was removed in pandas 2.0.
    if pieces:
        frame = pd.concat(pieces, ignore_index=True)
    else:
        frame = pd.DataFrame(columns=trials[0].columns)
    return frame, interval


def _running_epochs(budget, interval):
    """Return the cumulative-epoch checkpoints at which best accuracy is sampled.

    Segment ``i`` runs in steps of ``interval[i]``. Segment 0 starts at 1;
    later segments start at ``budget[i - 1] + interval[i]``. Each segment stops
    before ``budget[i] + interval[i]``.
    """
    epochs = []
    for i, b in enumerate(budget):
        gap = interval[i]
        start = 1 if i == 0 else budget[i - 1] + gap
        epochs.extend(range(start, b + gap, gap))
    return epochs


def main(args):
    """Rebuild the best-accuracy-vs-epoch curve of a Hyperband/BOHB run.

    Loads ``<save_root>/<exp_name>/log.pkl``, a dict whose values are per-trial
    DataFrames. It re-orders the logged epochs into the sequential schedule
    returned by ``compute_hb_details`` (eta fixed to 3). It then writes the
    running maximum of ``mean_accuracy`` at each cumulative-epoch checkpoint
    to ``<save_root>/<exp_name>/acc.pkl`` as
    ``{"running_epoch": [...], "acc": [...]}``.

    Args:
        args: parsed CLI namespace with ``max_t``, ``tune_algo``, ``dataset``,
            ``optimizer``, ``num_config``, ``seed`` and ``save_root``.

    Raises:
        ValueError: if ``log.pkl`` contains no trials.
    """
    # Detailed n / r schedule of Hyperband/BOHB with a fixed eta=3.
    hb, budget, _total_n, s_max = compute_hb_details(args.max_t, 3)
    brackets = [str(i) for i in range(s_max, -1, -1)]

    exp_name = (
        f"{args.dataset}_{args.optimizer}_{args.tune_algo}"
        f"_{args.num_config}_seed-{args.seed}"
    )
    exp_dir = os.path.join(args.save_root, exp_name)

    # NOTE: pickle.load can execute arbitrary code; only load logs you produced.
    with open(os.path.join(exp_dir, "log.pkl"), "rb") as f:
        dfs = pickle.load(f)

    trials = list(dfs.values())
    if not trials:
        raise ValueError(f"no trials found in {exp_dir}/log.pkl")
    if args.tune_algo == "HB":
        # The HB log order is reversed before grouping into brackets.
        # Presumably HB trials are stored newest-first -- TODO confirm.
        trials = trials[::-1]

    new_df, interval = _sequence_iterations(trials, hb, brackets)
    running_epoch = _running_epochs(budget, interval)

    # Best accuracy among the first t sequenced epochs, for each checkpoint t.
    acc = [new_df[:t].mean_accuracy.max() for t in running_epoch]

    acc_dict = {
        "running_epoch": running_epoch,
        "acc": acc,
    }
    with open(os.path.join(exp_dir, "acc.pkl"), "wb") as f:
        pickle.dump(acc_dict, f)


if __name__ == "__main__":
    # Command-line interface. The flags select which experiment directory
    # under --save_root main() reads log.pkl from and writes acc.pkl to.
    parser = argparse.ArgumentParser()
    cli_options = [
        ("--save_root", dict(type=str, default="./checkpoint", help="Root of checkpoint.")),
        (
            "--tune_algo",
            dict(
                type=str,
                required=True,
                choices=["HB", "BOHB"],
                help="Name of tuning algorithm.",
            ),
        ),
        (
            "--optimizer",
            dict(
                type=str,
                required=True,
                choices=["SGD", "Adam"],
                help="Name of optimizer.",
            ),
        ),
        ("--seed", dict(type=int, help="Random seed")),
        ("--max_t", dict(type=int, default=50, help="maximum epoch per trial")),
        ("--num_config", dict(type=int, default=62, help="number of configurations")),
        ("--dataset", dict(type=str, default="mnist")),
    ]
    # Register in list order so the --help listing matches the declaration order.
    for flag, opts in cli_options:
        parser.add_argument(flag, **opts)
    args = parser.parse_args()
    main(args)
