import argparse
import os
import random
import json
from tqdm import trange
import numpy as np

from data.meta_test_datasets import META_TEST_DATASET_DICT
from benchmarks.lcbench import LCBench
from benchmarks.taskset import TaskSet
from benchmarks.pd1 import PD1

from hpo_methods.dyhpo_hpo import DyHPOAlgorithm

# Command-line interface for the experiment script.
parser = argparse.ArgumentParser(description='DyHPO experiments.')

# Directory that holds the benchmark JSON files.
parser.add_argument('--data_dir', type=str, default='./data')
# Benchmark to evaluate on: 'lcbench', 'taskset' or 'pd1'.
parser.add_argument('--benchmark_name', type=str, default='lcbench')
# 'all' selects the meta-test datasets of the chosen benchmark; any other
# value is parsed into an explicit list of dataset names.
parser.add_argument('--dataset_names', type=str, default='all')
# Number of independent repetitions; one trajectory file is written per
# repetition and dataset.
parser.add_argument('--repeat', type=int, default=5)
# Total budget, used to derive how many configurations are sampled.
parser.add_argument('--budget_limit', type=int, default=300)
# Accepted on the command line but not read by the visible part of this script.
parser.add_argument('--model_ckpt_dir', type=str, default=None)
# Root directory under which '<benchmark>/<dataset>/' result folders are created.
parser.add_argument('--output_dir', type=str, default='./bo_results/random')

args = parser.parse_args()
# ---- Resolve the run configuration from the parsed CLI arguments ----
benchmark_name = args.benchmark_name
budget_limit = args.budget_limit

# Dataset selection. Supported forms of --dataset_names:
#   * "all"          -> every meta-test dataset registered for the benchmark
#   * "a,b,c"        -> comma-separated list (safe for names containing "_")
#   * "<known name>" -> a single registered dataset, taken verbatim
#   * "a_b_c"        -> legacy underscore-separated list (kept for
#                       backward compatibility)
# The comma and known-name forms exist because dataset names such as
# 'char_rnn_language_model_family_seed94' contain underscores themselves and
# would be shredded by the legacy split.
if args.dataset_names == "all":
    dataset_names = META_TEST_DATASET_DICT[benchmark_name]
elif "," in args.dataset_names:
    dataset_names = [name for name in args.dataset_names.split(",") if name]
elif args.dataset_names in META_TEST_DATASET_DICT.get(benchmark_name, ()):
    dataset_names = [args.dataset_names]
else:
    dataset_names = args.dataset_names.split("_")

# data
# Each benchmark reads one JSON file under --data_dir and is constructed with
# a placeholder dataset name; the dataset actually evaluated is selected
# later through benchmark.set_dataset_name().
if benchmark_name == 'lcbench':
    benchmark_data_path = os.path.join(args.data_dir, "data_2k.json")
    dummy_dataset_name = 'segment'
elif benchmark_name == 'taskset':
    benchmark_data_path = os.path.join(args.data_dir, 'taskset_chosen.json')
    dummy_dataset_name = 'char_rnn_language_model_family_seed94'
elif benchmark_name == 'pd1':
    benchmark_data_path = os.path.join(args.data_dir, "pd1_preprocessed.json")
    dummy_dataset_name = 'imagenet_resnet_batch_size_512'
else:
    raise NotImplementedError(
        f"Unsupported benchmark '{benchmark_name}'; "
        "expected one of: 'lcbench', 'taskset', 'pd1'."
    )

# Whether lower scores are better, per benchmark. Every entry is False, i.e.
# all benchmarks are treated as maximization problems.
minimization_problem_type = {
    'lcbench': False,
    'taskset': False,
    'pd1': False,
}

# NOTE(review): `minimization` is not read anywhere in the visible part of
# this script; the search loop always keeps the largest score seen.
minimization = minimization_problem_type[benchmark_name]

# Benchmark name -> class implementing that benchmark.
benchmark_dict = {
    'lcbench': LCBench,
    'taskset': TaskSet,
    'pd1': PD1
}

benchmark = benchmark_dict[benchmark_name](benchmark_data_path, dummy_dataset_name)

# One seed per repetition. Each seed drives the sampling of configurations in
# its repetition (see `rng` below).
seed_list = [ random.randint(0, 9999) for _ in range(args.repeat) ]

# Random-search baseline: for every repetition and dataset, sample distinct
# configurations, evaluate each one at every budget from 1 to max_budget, and
# record the best score seen so far after every evaluation step.
for repeat_idx, repeat_seed in enumerate(seed_list):
    # Dedicated generator per repetition so the seed assigned to this
    # repetition actually determines which configurations are drawn, without
    # touching NumPy's global random state.
    rng = np.random.RandomState(repeat_seed)

    for dataset_name in dataset_names:
        benchmark.set_dataset_name(dataset_name)

        output_dir = os.path.join(
            args.output_dir,
            f'{benchmark_name}',
            dataset_name
        )
        os.makedirs(output_dir, exist_ok=True)

        print(dataset_name)

        # Enough full-length runs to cover the budget limit (the "+ 1" rounds
        # up), but never more than the number of available configurations:
        # sampling without replacement would otherwise raise ValueError.
        # NOTE(review): assumes benchmark.nr_hyperparameters is an integer
        # count of configurations - confirm against the benchmark classes.
        n_trials = min(
            args.budget_limit // benchmark.max_budget + 1,
            benchmark.nr_hyperparameters,
        )

        config_indices = rng.choice(
            benchmark.nr_hyperparameters,
            size=n_trials,
            replace=False
        )

        # Running best score; starts at 0, so scores that never exceed 0
        # leave the recorded incumbent at 0.
        incumbent = 0
        trajectory = []
        for config_index in config_indices:
            for budget in range(1, benchmark.max_budget + 1):
                score = benchmark.get_performance(config_index, budget)
                if score > incumbent:
                    incumbent = score
                trajectory.append(incumbent)

        with open(os.path.join(output_dir, f"trajectory_{repeat_idx}.json"), "w") as fp:
            json.dump(trajectory, fp)