"""
This script defines the hyperparameter search space and runs hyperparameter
tuning for every combination of the models, datasets and latent ratios given
on the command line.
"""

import argparse
import json
import logging
import os
from math import isnan
from syne_tune import Tuner, StoppingCriterion
from syne_tune.backend import PythonBackend
from syne_tune import config_space as config_space_class
from syne_tune.optimizer.baselines import ASHA
from syne_tune.experiments import load_experiment
from numpyencoder import NumpyEncoder


def train_hpo_model(
    model_name: str,
    dataset_name: str,
    dim_reduction: float,
    epochs: int,
    lr: float,
    lambda_c: float,
    package_path: str,
):
    """Train one model configuration as a single hyperparameter-tuning trial.

    The static part of the model's ``conf.json`` is loaded, overlaid with the
    sampled hyperparameters, and handed to ``train_model``.

    Args:
        model_name: Name of the model class exposed by the ``models`` package;
            also the name of its directory under ``models/architectures``.
        dataset_name: Dataset identifier forwarded to ``train_model``.
        dim_reduction: Latent/input dimension ratio; stored in the config and
            also forwarded to ``train_model`` directly.
        epochs: Number of epochs written into the model config.
        lr: Learning rate written into the model config.
        lambda_c: Value for the ``lambda_c`` config entry, or the string
            ``"None"`` to leave the static config untouched.
        package_path: Directory that contains the ``models`` package. It is
            added to ``sys.path`` and used as the base for ``conf.json``.

    Raises:
        FileNotFoundError: If the model's ``conf.json`` does not exist.
        AttributeError: If ``models`` does not expose ``model_name``.
    """
    # Imports are kept function-local on purpose: this function is passed to
    # the tuning backend as `tune_function` and is expected to be
    # self-contained (see Syne Tune PythonBackend docs).
    import importlib
    import json
    import os
    import sys

    # Make the project package importable from the trial process.
    if package_path not in sys.path:
        sys.path.append(package_path)
    from models.train_utils.train import train_model

    # The class is selected by name at runtime, hence the dynamic lookup on
    # the package object rather than a static `from models import ...`.
    model_class = getattr(importlib.import_module("models"), model_name)

    # Anchor the config path at package_path so the trial does not depend on
    # the working directory it happens to be started in.
    conf_path = os.path.join(
        package_path, "models", "architectures", model_name, "conf.json"
    )
    with open(conf_path, "r") as fp:
        model_config = json.load(fp)["static"]

    # Overlay the dynamic (tuned / per-run) parameters on the static config.
    model_config["lr"] = lr
    model_config["epochs"] = epochs
    model_config["dim_reduction"] = dim_reduction
    # "None" is the placeholder used when lambda_c is not part of the search.
    if lambda_c != "None":
        model_config["lambda_c"] = lambda_c

    train_model(
        model_class=model_class,
        dataset_name=dataset_name,
        config=model_config,
        dim_reduction=dim_reduction,
        tuning=True,
        checkpointing=False,
    )


def get_tuning_range(range_type, lower, upper):
    """Build a search-space domain of the given type over ``[lower, upper]``.

    Args:
        range_type: Name of the domain constructor; ``"uniform"`` or
            ``"loguniform"``.
        lower: Lower bound passed to the constructor.
        upper: Upper bound passed to the constructor.

    Returns:
        The object returned by the matching constructor in the config-space
        module.

    Raises:
        ValueError: If ``range_type`` is not one of the supported names.
    """
    supported = ("uniform", "loguniform")
    # Explicit raise instead of `assert`: asserts are removed under
    # `python -O`, which would let arbitrary attribute names reach getattr.
    if range_type not in supported:
        raise ValueError(
            f"Unsupported range type: {range_type!r} (expected one of {supported})"
        )
    return getattr(config_space_class, range_type)(lower, upper)


def get_args_parser(add_help=True):
    """Create the command-line parser for the tuning script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; controls whether
            the ``-h/--help`` option is registered.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--model-name``,
        ``--dataset-name``, ``--latent-ratio``, ``--n-workers`` and
        ``--max-time``.
    """
    cli = argparse.ArgumentParser(
        add_help=add_help, description="Cuso Verse Benchmarking"
    )
    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        ("--model-name", {"nargs": "+", "help": "Model name, e.g. StandardAE"}),
        ("--dataset-name", {"nargs": "+", "help": "Dataset name, e.g. TeaRetail"}),
        (
            "--latent-ratio",
            {
                "nargs": "+",
                "type": float,
                "default": [0.505],
                "help": "The ratio between input and latent dimension (latent/input)",
            },
        ),
        ("--n-workers", {"default": 8, "type": int, "help": "Number of workers"}),
        (
            "--max-time",
            {
                "default": 120,
                "type": int,
                "help": "Maximum Wallclock time for one tuning job",
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli


def _is_unset(value):
    """Return True if ``value`` is a placeholder meaning "no value was set".

    Recognised placeholders are ``None``, the literal string ``"None"`` and a
    floating-point NaN. Any other value (including other strings) is treated
    as set.
    """
    if value is None:
        return True
    if isinstance(value, str):
        return value == "None"
    try:
        return isnan(value)
    except TypeError:
        # Not a real number, so it cannot be the NaN placeholder.
        return False


def main():
    """Run hyperparameter tuning for every model/dataset/latent-ratio combination.

    For each combination the full trial results are written to
    ``artifacts/tuning/<model>_<dataset>_<ratio>.csv`` and the best
    configuration to ``artifacts/tuning/<model>_<dataset>_<ratio>_best.json``.
    """
    parser = get_args_parser()
    args = parser.parse_args()
    # Both options default to None; fail with a usage message instead of a
    # TypeError when iterating over them below.
    if not args.model_name or not args.dataset_name:
        parser.error("--model-name and --dataset-name are both required")

    root = logging.getLogger()
    root.setLevel(logging.INFO)

    n_workers = args.n_workers
    max_time = args.max_time
    metric = "mean_loss"
    mode = "min"
    # Trials import the project's `models` package relative to this directory.
    package_path = os.getcwd()

    # Create the output directory up front so a finished tuning run cannot
    # fail (and lose its results) at the very last step.
    output_path = os.path.join("artifacts", "tuning")
    os.makedirs(output_path, exist_ok=True)

    for m_name in args.model_name:
        # The model config only depends on the model, so load it once here.
        conf_path = os.path.join("models", "architectures", m_name, "conf.json")
        with open(conf_path, "r") as fp:
            model_config = json.load(fp)

        for d_name in args.dataset_name:
            for ratio in args.latent_ratio:
                # Fixed entries are passed to every trial unchanged; the
                # "None" string marks lambda_c as "not tuned" unless the
                # model declares it under "tunable" below.
                config_space = {
                    "dim_reduction": ratio,
                    "model_name": m_name,
                    "dataset_name": d_name,
                    "epochs": model_config["static"]["epochs"],
                    "lambda_c": "None",
                    "package_path": package_path,
                }
                # Each tunable entry is (range_type, lower, upper).
                for key, value in model_config["tunable"].items():
                    config_space[key] = get_tuning_range(*value)

                # Create tuning components
                scheduler = ASHA(
                    config_space,
                    metric=metric,
                    max_resource_attr="epochs",
                    resource_attr="epoch",
                    mode=mode,
                )
                trial_backend = PythonBackend(
                    tune_function=train_hpo_model, config_space=config_space
                )
                stop_criterion = StoppingCriterion(
                    max_wallclock_time=max_time, min_metric_value={metric: 0.0}
                )
                tuner = Tuner(
                    trial_backend=trial_backend,
                    scheduler=scheduler,
                    stop_criterion=stop_criterion,
                    n_workers=n_workers,
                    save_tuner=False,
                    wait_trial_completion_when_stopping=True,
                    start_jobs_without_delay=False,
                )

                # Start hyperparameter tuning
                tuner.run()

                # Load the results the tuner wrote for this experiment.
                tuning_experiment = load_experiment(tuner.tuner_path)
                root.info(
                    f"Tuning {m_name} on {d_name} with "
                    f"{ratio} compression finished:"
                )
                run_prefix = os.path.join(output_path, f"{m_name}_{d_name}_{ratio}")
                tuning_experiment.results.to_csv(run_prefix + ".csv")

                best_config = tuning_experiment.best_config()
                # Drop the lambda_c entry when it only carries the "not tuned"
                # placeholder. Depending on how the results were loaded it may
                # come back as NaN or as the literal string, so accept both.
                if _is_unset(best_config.get("config_lambda_c")):
                    best_config.pop("config_lambda_c", None)
                root.info(f"-> best result found: {best_config}")
                with open(run_prefix + "_best.json", "w") as jf:
                    json.dump(best_config, jf, cls=NumpyEncoder)


if __name__ == "__main__":
    main()
