import argparse
import pandas as pd

import sys
from benchmarks.lcbench import LCBench
from benchmarks.pd1 import PD1
from benchmarks.taskset import TaskSet
from data.meta_test_datasets import META_TEST_DATASET_DICT

# from hpo.optimizers.asha.asha import AHBOptimizer
from sklearn.neighbors import KNeighborsRegressor
import os
import time
import json
import random

import numpy as np
from syne_tune.blackbox_repository import (
    load_blackbox,
    add_surrogate,
    BlackboxRepositoryBackend,
    UserBlackboxBackend,
)
from syne_tune.experiments import load_experiment
from syne_tune.blackbox_repository.blackbox_tabular import BlackboxTabular

# from benchmarking.commons.benchmark_definitions.lcbench import lcbench_benchmark
# from syne_tune.blackbox_repository import BlackboxRepositoryBackend
from syne_tune.backend.simulator_backend.simulator_callback import SimulatorCallback
from syne_tune.optimizer.baselines import (
    BayesianOptimization,
    RandomSearch,
    ASHA,
    BOHB,
    DEHB,
    SyncBOHB,
)
from syne_tune import Tuner, StoppingCriterion
import syne_tune.config_space as sp


if __name__ == "__main__":

    # NOTE(review): not referenced anywhere else in the visible script.
    aggregate_data = False

    # Command-line interface of the baseline runner.
    # ``ArgumentDefaultsHelpFormatter`` only appends "(default: ...)" to
    # arguments that carry a help string, so every option gets one.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--repeat",
        type=int,
        default=5,
        help="Number of repetitions per dataset; each one uses its own random seed.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./bo_results",
        help="Output directory.",
    )
    parser.add_argument(
        "--dataset_names",
        type=str,
        default="all",
        help="Dataset names joined by '_', or 'all' for the benchmark's meta-test datasets.",
    )
    parser.add_argument(
        "--method_name",
        type=str,
        default="dehb",
        help="The algorithm name (asha, dehb, bohb, sync-bohb or gp).",
    )
    parser.add_argument(
        "--benchmark_name",
        type=str,
        default="lcbench",
        help="The benchmark name (lcbench, taskset or pd1).",
    )

    args = parser.parse_args()

    def _hp_bounds(values):
        """Return the ``(lower, upper)`` search range for one hyperparameter.

        The range is the observed minimum/maximum of ``values``. If all values
        are identical the range is widened so that it includes ``0``: a
        constant ``0`` gives ``(0, 1)``, a positive constant ``c`` gives
        ``(0, c)`` and a negative constant ``c`` gives ``(c, 0)`` (so the
        lower bound never exceeds the upper bound).
        """
        lower = np.min(values)
        upper = np.max(values)
        if lower == upper:
            if lower == 0:
                upper = 1
            elif lower > 0:
                lower = 0
            else:
                upper = 0
        return lower, upper

    def _incumbent_trajectory(values, floor=0):
        """Return the running maximum of ``values``, never dropping below ``floor``."""
        best = floor
        trajectory = []
        for value in values:
            if value > best:
                best = value
            trajectory.append(best)
        return trajectory

    def run(args):
        """Run one Syne Tune baseline on a tabular benchmark and save the results.

        For every requested dataset and every repetition, a ``BlackboxTabular``
        is built from the benchmark's learning curves, the selected scheduler is
        run on the simulator backend, and a JSON file
        ``<output_dir>/<method_name>/<benchmark_name>/<dataset>_<repeat>.json``
        is written. The file holds three lists: the result index, the reported
        runtimes and the incumbent (best-so-far) accuracy.

        :param args: Parsed command-line arguments; reads ``repeat``,
            ``output_dir``, ``dataset_names``, ``method_name`` and
            ``benchmark_name``.
        :raises NotImplementedError: If the benchmark or the method name is
            not one of the supported values.
        """
        method_algorithms = {
            "asha": ASHA,
            "dehb": DEHB,
            "bohb": BOHB,
            "sync-bohb": SyncBOHB,
            "gp": BayesianOptimization,
        }
        # Fail fast on a bad method name, before any data is loaded.
        if args.method_name not in method_algorithms:
            raise NotImplementedError(
                f"Unknown method {args.method_name!r}; "
                f"expected one of {sorted(method_algorithms)}."
            )
        algorithm = method_algorithms[args.method_name]
        # These methods are run against a 1-nearest-neighbour surrogate of the
        # table instead of the raw table.
        needs_surrogate = args.method_name in ("dehb", "bohb", "sync-bohb")

        # One scheduler seed per repetition, shared across datasets.
        seed_list = [random.randint(0, 9999) for _ in range(args.repeat)]

        # data
        data_dir = "./data"
        if args.benchmark_name == "lcbench":
            benchmark = LCBench(os.path.join(data_dir, "data_2k.json"), "segment")
        elif args.benchmark_name == "taskset":
            benchmark = TaskSet(
                os.path.join(data_dir, "taskset_chosen.json"),
                "rnn_text_classification_family_seed19",
            )
        elif args.benchmark_name == "pd1":
            benchmark = PD1(
                os.path.join(data_dir, "pd1_preprocessed.json"),
                "imagenet_resnet_batch_size_512",
            )
        else:
            raise NotImplementedError(
                f"Unknown benchmark {args.benchmark_name!r}; "
                "expected 'lcbench', 'taskset' or 'pd1'."
            )

        if args.dataset_names == "all":
            dataset_names = META_TEST_DATASET_DICT[args.benchmark_name]
        else:
            dataset_names = args.dataset_names.split("_")

        hp_names = list(benchmark.param_space.keys())
        full_hp_candidates = benchmark.get_hyperparameter_candidates()
        max_budget = benchmark.max_budget
        num_configs, num_hps = full_hp_candidates.shape

        # configuration space: one uniform range per hyperparameter column
        config_space = {
            hp_names[hp_index]: sp.uniform(
                *_hp_bounds(full_hp_candidates[:, hp_index])
            )
            for hp_index in range(num_hps)
        }
        # NOTE(review): the lower bound is 0 while the objectives tensor below
        # holds ``max_budget`` fidelity steps -- confirm the intended range.
        cs_fidelity = {
            "hp_epochs": sp.randint(0, max_budget),
        }
        max_resource_attr = "hp_epochs"

        output_dir = os.path.join(
            args.output_dir,
            args.method_name,
            args.benchmark_name,
        )

        for dataset_name in dataset_names:
            benchmark.set_dataset_name(dataset_name)

            # The learning curves only depend on the dataset, so the tensor is
            # filled once here rather than once per repetition.
            # Last axis: index 0 is "accuracy", index 1 is "runtime" (left at 1).
            objectives_evaluations = np.ones((num_configs, 1, max_budget, 2))
            for hp_index in range(num_configs):
                curve = benchmark.get_curve(hp_index, max_budget)
                objectives_evaluations[hp_index, 0, :, 0] = [
                    curve[step] for step in range(max_budget)
                ]

            for repeat_idx in range(args.repeat):
                seed = seed_list[repeat_idx]

                # Each repetition gets its own table and tensor copy so that no
                # state can leak between runs.
                blackbox_tabular = BlackboxTabular(
                    hyperparameters=pd.DataFrame(
                        full_hp_candidates, columns=hp_names
                    ),
                    configuration_space=config_space,
                    fidelity_space=cs_fidelity,
                    objectives_evaluations=objectives_evaluations.copy(),
                    objectives_names=["accuracy", "runtime"],
                )

                if needs_surrogate:
                    backend_blackbox = add_surrogate(
                        blackbox_tabular, surrogate=KNeighborsRegressor(n_neighbors=1)
                    )
                else:
                    backend_blackbox = blackbox_tabular

                trial_backend = UserBlackboxBackend(
                    blackbox=backend_blackbox,
                    elapsed_time_attr="runtime",
                )
                blackbox = trial_backend.blackbox
                restrict_configurations = blackbox_tabular.all_configurations()

                scheduler = algorithm(
                    config_space=blackbox.configuration_space_with_max_resource_attr(
                        max_resource_attr
                    ),
                    resource_attr="hp_epochs",
                    max_resource_attr=max_resource_attr,
                    mode="max",
                    metric="accuracy",
                    random_seed=seed,
                    search_options=dict(restrict_configurations=restrict_configurations),
                )

                stop_criterion = StoppingCriterion(max_num_evaluations=10000)
                # Printing the status during tuning takes a lot of time, and so does
                # storing results.
                print_update_interval = 700
                results_update_interval = 300
                # It is important to set ``sleep_time`` to 0 here (mandatory for simulator
                # backend)
                tuner = Tuner(
                    trial_backend=trial_backend,
                    scheduler=scheduler,
                    stop_criterion=stop_criterion,
                    n_workers=1,
                    sleep_time=0,
                    results_update_interval=results_update_interval,
                    print_update_interval=print_update_interval,
                    # This callback is required in order to make things work with the
                    # simulator callback. It makes sure that results are stored with
                    # simulated time (rather than real time), and that the time_keeper
                    # is advanced properly whenever the tuner loop sleeps
                    callbacks=[SimulatorCallback()],
                    tuner_name=f"{int(time.time())}",
                    metadata={"description": "Running a baseline for AutoFineTune"},
                )
                try:
                    tuner.run()
                except ValueError as e:
                    # Best effort: a ValueError from the tuner is reported and
                    # whatever results were stored so far are still exported.
                    print(e)

                result_df = load_experiment(tuner.name).results

                # Output layout: [result index, runtimes, incumbent accuracy].
                info_list = [
                    list(range(len(result_df))),
                    result_df["runtime"].tolist(),
                    _incumbent_trajectory(result_df["accuracy"].tolist()),
                ]

                os.makedirs(output_dir, exist_ok=True)
                with open(os.path.join(output_dir, f"{dataset_name}_{repeat_idx}.json"), "w") as fp:
                    json.dump(info_list, fp)

    # Execute the experiment sweep with the parsed command-line arguments and
    # report completion once every dataset/repetition has been processed.
    run(args)
    print("Done")