import argparse
import os
import random
import json
from tqdm import trange
import numpy as np

from data.meta_test_datasets import META_TEST_DATASET_DICT
from benchmarks.lcbench import LCBench
from benchmarks.taskset import TaskSet
from benchmarks.pd1 import PD1

from hpo_methods.dyhpo_hpo import DyHPOAlgorithm

# Command-line interface of the DyHPO experiment runner.
parser = argparse.ArgumentParser(
    description='DyHPO experiments.',
)
parser.add_argument(
    '--data_dir',
    type=str,
    default='./data',
    help='Directory containing the benchmark JSON data files.',
)
parser.add_argument(
    '--benchmark_name',
    type=str,
    default='lcbench',
    help="Benchmark to run: 'lcbench', 'taskset' or 'pd1'.",
)
parser.add_argument(
    '--dataset_names',
    type=str,
    default='all',
    help="'all' to run every meta-test dataset of the benchmark, "
         "or an explicit list of dataset names.",
)
parser.add_argument(
    '--repeat',
    type=int,
    default=5,
    help='Number of independent repetitions, each with its own random seed.',
)
parser.add_argument(
    '--budget_limit',
    type=int,
    default=300,
    help='Number of suggest/observe iterations per dataset run '
         '(also passed to DyHPO as its total budget).',
)
parser.add_argument(
    '--model_ckpt_dir',
    type=str,
    default=None,
    help='Model checkpoint directory passed to DyHPO. When given, an initial '
         'configuration index is also derived from the non-meta-test datasets.',
)
parser.add_argument(
    '--output_dir',
    type=str,
    default='./bo_results/quicktune',
    help='Root directory for results; trajectories are written to '
         '<output_dir>/<benchmark_name>/<dataset_name>/.',
)

args = parser.parse_args()
benchmark_name = args.benchmark_name
budget_limit = args.budget_limit

# Datasets to evaluate: every meta-test dataset of the benchmark ("all"), or an
# explicit selection. Dataset names may themselves contain underscores
# (e.g. 'rnn_text_classification_family_seed19'), so splitting on "_" alone
# cannot address them. Instead:
#   * an exact meta-test dataset name selects just that dataset;
#   * a comma-separated list is split on commas;
#   * otherwise the legacy underscore-separated form is used.
if args.dataset_names == "all":
    dataset_names = META_TEST_DATASET_DICT[args.benchmark_name]
elif args.dataset_names in META_TEST_DATASET_DICT.get(benchmark_name, ()):
    dataset_names = [args.dataset_names]
elif "," in args.dataset_names:
    dataset_names = [
        name.strip() for name in args.dataset_names.split(",") if name.strip()
    ]
else:
    dataset_names = args.dataset_names.split("_")

# data
# Per-benchmark settings: the JSON data file (inside --data_dir) and the dataset
# name handed to the benchmark constructor. The active dataset is switched with
# `set_dataset_name` before every run, so this is only an initial placeholder.
_benchmark_settings = {
    'lcbench': ("data_2k.json", 'segment'),
    'taskset': ("taskset_chosen.json", 'rnn_text_classification_family_seed19'),
    'pd1': ("pd1_preprocessed.json", 'imagenet_resnet_batch_size_512'),
}
if benchmark_name not in _benchmark_settings:
    raise NotImplementedError(
        f"Unsupported benchmark {benchmark_name!r}; "
        f"expected one of {sorted(_benchmark_settings)}"
    )
_data_file, dummy_dataset_name = _benchmark_settings[benchmark_name]
benchmark_data_path = os.path.join(args.data_dir, _data_file)

# Whether lower scores are better. False for every benchmark here, i.e. all are
# treated as maximisation problems.
minimization_problem_type = {
    'lcbench': False,
    'taskset': False,
    'pd1': False,
}

minimization = minimization_problem_type[benchmark_name]
benchmark_dict = {
    'lcbench': LCBench,
    'taskset': TaskSet,
    'pd1': PD1,
}

benchmark = benchmark_dict[benchmark_name](benchmark_data_path, dummy_dataset_name)

# One random seed per repetition.
# NOTE(review): `random` is never seeded in this script, so the seeds (and hence
# the runs) differ between invocations.
seed_list = [random.randint(0, 9999) for _ in range(args.repeat)]

if args.model_ckpt_dir is None:
    init_index = None
else:
    # Initial configuration: the hyperparameter index that is most often the
    # argmax of column 0 of `benchmark.data` across the datasets that are NOT
    # meta-test datasets (presumably rows are configurations and column 0 is a
    # score where higher is better -- TODO confirm against the benchmark class).
    #
    # Use a separate loop variable: `dataset_names` holds the datasets the main
    # loop must run on, and rebinding it here would silently replace the
    # requested datasets with every dataset of the benchmark.
    meta_test_names = set(META_TEST_DATASET_DICT[benchmark_name])
    counter = np.zeros(benchmark.nr_hyperparameters, dtype=int)
    for train_dataset_name in benchmark.load_dataset_names():
        if train_dataset_name in meta_test_names:
            continue
        benchmark.set_dataset_name(train_dataset_name)
        data = np.array(benchmark.data)
        best_index = np.argmax(data[:, 0]).item()
        counter[best_index] += 1
    # np.argmax returns the first index on ties.
    init_index = np.argmax(counter).item()

# Run DyHPO on every selected dataset, `args.repeat` times. Each run records the
# best score seen so far after every suggest/observe step and writes it to
# <output_dir>/<benchmark_name>/<dataset_name>/trajectory_<repeat_idx>.json.
for repeat_idx in range(args.repeat):

    for dataset_name in dataset_names:
        benchmark.set_dataset_name(dataset_name)

        output_dir = os.path.join(args.output_dir, benchmark_name, dataset_name)
        os.makedirs(output_dir, exist_ok=True)

        print(dataset_name)
        dyhpo_surrogate = DyHPOAlgorithm(
            args.model_ckpt_dir,
            init_index,
            benchmark.get_hyperparameter_candidates(),
            benchmark.log_indicator,
            seed=seed_list[repeat_idx],
            max_benchmark_epochs=benchmark.max_budget,
            fantasize_step=1,
            minimization=minimization,
            total_budget=budget_limit,
            dataset_name=dataset_name,
            output_path=output_dir,
        )

        # Best score observed so far. It starts as None rather than 0 so the
        # trajectory is correct when scores can be <= 0, and it follows the
        # benchmark's optimisation direction.
        incumbent = None
        trajectory = []

        for _ in trange(budget_limit):

            hp_index, budget = dyhpo_surrogate.suggest()
            performance_curve = benchmark.get_curve(hp_index, budget)
            # The score of this step is the last value of the returned curve.
            score = performance_curve[-1]
            dyhpo_surrogate.observe(hp_index, budget, performance_curve)

            if incumbent is None:
                incumbent = score
            elif minimization:
                if score < incumbent:
                    incumbent = score
            elif score > incumbent:
                incumbent = score
            trajectory.append(incumbent)

        trajectory_path = os.path.join(output_dir, f"trajectory_{repeat_idx}.json")
        with open(trajectory_path, "w", encoding="utf-8") as fp:
            json.dump(trajectory, fp)