import os
import argparse
import numpy as np
import ray
from ray import tune
from ray.tune import Stopper
from ray.tune.schedulers import ASHAScheduler, HyperBandScheduler
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
from ray.tune import CLIReporter
from train_image_classification import CIFAR
from train_NLP_tasks import GlueTask
from train_cluster_GCN import ClusterGcnTask
from train_gan import SNGAN
from train_vae import VAECelebA
import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
import pickle
import random
import torch
import torch.optim as optim
import optimizers as new_optim
from yaml import load
from typing import Dict, Union
from ray.tune.utils import validate_save_restore
import logging

# Start Ray as soon as this module is loaded, so that helpers such as
# validate_save_restore() (imported above) can be called before tune.run().
# Both limits below evaluate to 42 * 10**9 (presumably bytes, i.e. about 42 GB
# each -- confirm against the ray.init() documentation of the installed Ray).
ray.init(
    object_store_memory=1000 * 1000 * 1000 * 42,
    redis_max_memory=1000 * 1000 * 1000 * 42,
)


def get_task_class_and_set_path(task_name, config):
    """Pick the trainable class for a task and add task-specific entries to ``config``.

    Args:
        task_name: name of the task, e.g. ``"CIFAR10"``, ``"Glue-RTE"`` or ``"GCN"``.
        config: experiment configuration; updated in place for the GLUE and
            GCN tasks (absolute data paths and, for GLUE, per-dataset settings).

    Returns:
        The class that implements the requested task.

    Raises:
        NotImplementedError: for ``"RL"`` and for any name not handled below.
    """
    if task_name in ("CIFAR10", "CIFAR100"):
        return CIFAR

    if task_name == "VAE":
        return VAECelebA

    if task_name in ("Glue-RTE", "Glue-MRPC"):
        # Drop the "Glue-" prefix to obtain the dataset name.
        dataset = task_name[len("Glue-"):]
        # Per dataset: (num_classes, max_sentences, total_num_update, warmup_updates)
        glue_settings = {
            "RTE": (2, 16, 2030, 122),
            "MRPC": (2, 16, 2296, 137),
        }
        (
            num_classes,
            max_sentences,
            total_num_update,
            warmup_updates,
        ) = glue_settings[dataset]
        # Paths are made absolute here, in the driver process, before trials start.
        config.update(
            data_dir=os.path.abspath(f"{dataset}-bin/"),
            restore_file=os.path.abspath("./roberta.base/model.pt"),
            dev_file=os.path.abspath(f"./glue_data/{dataset}/dev.tsv"),
            num_classes=num_classes,
            max_sentences=max_sentences,
            total_num_update=total_num_update,
            warmup_updates=warmup_updates,
        )
        return GlueTask

    if task_name == "GCN":
        config["data_prefix"] = os.path.abspath("./fastergcn")
        return ClusterGcnTask

    # "RL" and every other name (including "GAN") are not supported here.
    # NOTE(review): this file imports SNGAN and the CLI accepts "GAN", but no
    # branch maps one to the other -- confirm whether that is intentional.
    # The validate_save_restore() sanity checks on the task class are
    # currently not run.
    raise NotImplementedError


class Uniform(object):
    """Uniform or log-uniform search range for a single hyperparameter.

    When ``log`` is true, ``a_min`` and ``a_max`` are base-10 exponents and
    values are drawn from ``[10 ** a_min, 10 ** a_max]``; otherwise values are
    drawn directly from ``[a_min, a_max]``.

    Args:
        a_min: lower bound (or lower exponent when ``log`` is true).
        a_max: upper bound (or upper exponent when ``log`` is true).
        log: whether to sample on a base-10 logarithmic scale.
    """

    def __init__(self, a_min: float, a_max: float, log: bool):
        self.a_min = a_min
        self.a_max = a_max
        self.log = log

    def make_HB(self):
        """Return the equivalent ray.tune sampling object."""
        if self.log:
            # Sample the exponent uniformly, then raise 10 to it.
            return tune.sample_from(
                lambda _: 10 ** (np.random.uniform(self.a_min, self.a_max))
            )
        return tune.uniform(self.a_min, self.a_max)

    def make_BOHB(self, name):
        """Return the equivalent ConfigSpace hyperparameter called ``name``."""
        if self.log:
            # NOTE(review): ``log_base`` is kept from the original call; confirm
            # that the installed ConfigSpace version accepts this keyword.
            return CSH.UniformFloatHyperparameter(
                name,
                lower=10 ** self.a_min,
                upper=10 ** self.a_max,
                log=True,
                log_base=10.0,
            )
        # A *uniform* hyperparameter, matching make_HB(). The previous
        # NormalFloatHyperparameter call omitted its required mu/sigma
        # arguments and described a different distribution.
        return CSH.UniformFloatHyperparameter(
            name, lower=self.a_min, upper=self.a_max
        )


def parse_distributions(hyper: Dict[str, Union[str, float]], use_BOHB: bool):
    """Split a parameter mapping into a tunable search space and fixed values.

    Numeric values are treated as constants. String values are evaluated as
    Python expressions that build a distribution object, which is then
    converted for the selected tuner.

    Args:
        hyper: mapping from parameter name to either a number or an
            expression string.
        use_BOHB: if true, the tunable part is a list of objects returned by
            ``make_BOHB(name)``; otherwise it is a dict keyed by parameter name.

    Returns:
        A ``(hyper_space, const_space)`` tuple.

    Raises:
        ValueError: if a value is neither a number nor a string.
    """
    tunable = [] if use_BOHB else {}
    fixed = {}
    for name, spec in hyper.items():
        if isinstance(spec, (float, int)):
            fixed[name] = spec
            continue
        if not isinstance(spec, str):
            raise ValueError(
                f"Value type not understood: {spec} with type: {type(spec)}"
            )
        # NOTE: eval() executes arbitrary code taken from the configuration;
        # only use configuration files from a trusted source.
        distribution = eval(spec)
        if use_BOHB:
            tunable.append(distribution.make_BOHB(name))
        elif "choice" in spec:
            # Expressions containing "choice" are used as evaluated, without
            # calling make_HB().
            tunable[name] = distribution
        else:
            tunable[name] = distribution.make_HB()
    return tunable, fixed


def make_optimizer(config, optimizer_name: str, use_BOHB: bool):
    """Resolve an optimizer class and its parameter spaces from a dotted name.

    Args:
        config: dict, configuration dictionary; the parameter settings are
            read from ``config["Optimizers"][<optimizer>][<variant>]``.
        optimizer_name: str, ``"<optimizer>.<variant>"`` such as ``SGD.1param``.
        use_BOHB: bool, whether or not to use BOHB.

    Returns:
        A tuple ``(optimizer_class, hyper_space, const_space)``.

    Raises:
        NotImplementedError: if the optimizer part of the name is unknown.
    """
    opt, nparams = optimizer_name.split(".")
    hyper = config["Optimizers"][opt][nparams]

    # Optimizers provided by the project-local ``optimizers`` module,
    # keyed by their name in the configuration.
    custom_class_names = {
        "Adam": "NewAdam",
        "RAdam": "NewRAdam",
        "Yogi": "NewYogi",
        "Lookahead": "NewLookahead",
        "LAMB": "NewLAMB",
        "LARS": "NewLARS",
    }
    if opt == "SGD":
        optimizer = optim.SGD
    elif opt in custom_class_names:
        optimizer = getattr(new_optim, custom_class_names[opt])
    else:
        raise NotImplementedError

    hyper_space, const_space = parse_distributions(hyper, use_BOHB)
    return optimizer, hyper_space, const_space


def make_lr_scheduler(config, scheduler_name: str, use_BOHB: bool):
    """Given lr scheduler name, process the lr_scheduler parameter distribution."""
    hyper = config["LR_scheduler"][scheduler_name]
    hyper_space, const_space = parse_distributions(hyper, use_BOHB)
    return hyper_space, const_space


def task_info(config, task: str):
    """Return the settings of ``task`` layered over the default task settings.

    The result is a new dict: ``config["Tasks"]["default"]`` is copied before
    the task-specific entries are applied, so the loaded configuration is left
    untouched and repeated calls do not leak one task's overrides into another.

    Args:
        config: dict, configuration dictionary with a ``"Tasks"`` section that
            contains a ``"default"`` entry and one entry per task.
        task: str, name of the task whose overrides are applied.

    Returns:
        The merged task configuration (also printed to stdout).

    Raises:
        KeyError: if ``"default"`` or ``task`` is missing from the section.
    """
    tasks = config["Tasks"]
    # ``or {}`` tolerates a section that is present but empty (None).
    task_config = dict(tasks["default"] or {})
    task_config.update(tasks[task] or {})
    print("Task_config: ", task_config)
    return task_config


def make_search_algo(config, task_config, tuner: str, hyper_space=None):
    """Create the trial scheduler and, for BOHB, the matching search algorithm.

    Args:
        config: dict, configuration dictionary; ``config["Tuners"][tuner]``
            supplies extra keyword arguments for the scheduler.
        task_config: dict providing ``"metric"``, ``"mode"`` and ``"max_t"``.
        tuner: str, ``"HB"`` or ``"BOHB"``.
        hyper_space: search space handed to ``TuneBOHB``; only used for BOHB.

    Returns:
        A tuple ``(scheduler, search_algorithm)``; the search algorithm is
        ``None`` for ``"HB"``.

    Raises:
        NotImplementedError: if ``tuner`` is neither ``"HB"`` nor ``"BOHB"``.
    """
    tuner_config = config["Tuners"][tuner]
    experiment_metrics = {
        "metric": task_config["metric"],
        "mode": task_config["mode"],
    }

    is_bohb = tuner == "BOHB"
    if not is_bohb and tuner != "HB":
        raise NotImplementedError(f"Unknown tuner: {tuner}")

    scheduler_cls = HyperBandForBOHB if is_bohb else HyperBandScheduler
    sched = scheduler_cls(
        max_t=task_config["max_t"], **experiment_metrics, **tuner_config,
    )
    search_algo = (
        TuneBOHB(hyper_space, max_concurrent=2, **experiment_metrics)
        if is_bohb
        else None
    )
    return sched, search_algo


def main(args):
    """Run one hyperparameter-tuning experiment described by ``args``.

    Steps: seed ``random``, NumPy and PyTorch; load the YAML configuration;
    build the optimizer / learning-rate-scheduler parameter spaces; create the
    trial scheduler (and the BOHB search algorithm when requested); launch
    ``tune.run``; print the best configuration; and pickle the per-trial
    dataframes and all trial configurations into the experiment directory.

    Args:
        args: parsed command-line arguments providing ``seed``,
            ``config_file``, ``task``, ``tuner``, ``optimizer``, ``scheduler``
            and ``smoke_test``. ``args.seed`` is filled in with a random value
            when it is ``None``.
    """
    # Imported locally because the module only imports ``load`` from yaml.
    from yaml import safe_load

    # Set random seeds
    logging.getLogger("tensorflow").disabled = True
    if args.seed is None:
        # randint() needs integer bounds; a float such as 1e3 is rejected by
        # recent Python versions.
        args.seed = random.randint(0, 1000)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Parse config file. safe_load() is used because calling yaml.load()
    # without an explicit Loader is deprecated and fails on current PyYAML.
    with open(args.config_file, "r") as reader:
        config = safe_load(reader)

    use_bohb = args.tuner == "BOHB"

    # Parse distribution settings in optimizer and lr scheduler
    optimizer, opt_hyper_space, opt_const_space = make_optimizer(
        config, args.optimizer, use_BOHB=use_bohb
    )
    lr_hyper_space, lr_const_space = make_lr_scheduler(
        config, args.scheduler, use_BOHB=use_bohb
    )

    task_config = task_info(config, args.task)

    if use_bohb:
        # For BOHB the tunable parameters live in a ConfigSpace object that is
        # handed to the search algorithm, so they are left out of the trial
        # config built below.
        bohb_hyper_space = CS.ConfigurationSpace(seed=args.seed)
        bohb_hyper_space.add_hyperparameters(opt_hyper_space)
        bohb_hyper_space.add_hyperparameters(lr_hyper_space)
        opt_hyper_space = {}
        lr_hyper_space = {}
    else:
        bohb_hyper_space = None
    sched, search_algo = make_search_algo(
        config, task_config, args.tuner, bohb_hyper_space
    )

    expr_config = {
        "args": args,
        "task_config": task_config,
        "optimizer": optimizer,
        **opt_const_space,
        **lr_const_space,
        **opt_hyper_space,
        **lr_hyper_space,
    }

    exp_name = (
        f"{args.task}_{args.optimizer}_{args.scheduler}_{args.tuner}"
        f"_{task_config['num_config']}_seed-{args.seed}"
    )

    # Get task class and setup the data path before spawning the subprocess
    task_cls = get_task_class_and_set_path(args.task, expr_config)
    analysis = tune.run(
        task_cls,
        name=exp_name,
        local_dir=task_config["save_root"],
        scheduler=sched,
        search_alg=search_algo,
        stop={
            "training_iteration": 3 if args.smoke_test else task_config["max_t"],
            "early_stop": True,
        },
        resources_per_trial={"cpu": task_config["ncpu"], "gpu": task_config["ngpu"]},
        num_samples=1 if args.smoke_test else task_config["num_config"],
        checkpoint_at_end=True,
        checkpoint_freq=task_config["checkpoint_freq"],
        config=expr_config,
    )

    print(
        "Best config: ",
        analysis.get_best_config(
            metric=task_config["metric"], mode=task_config["mode"]
        ),
    )

    # Persist the results next to the Tune output; ``with`` makes sure the
    # files are flushed and closed.
    experiment_dir = os.path.join(task_config["save_root"], exp_name)
    with open(f"{experiment_dir}/log.pkl", "wb") as log_file:
        pickle.dump(analysis.trial_dataframes, log_file)

    with open(f"{experiment_dir}/all_configs.pkl", "wb") as configs_file:
        pickle.dump(analysis.get_all_configs(), configs_file)


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them and hand the
    # result to main().
    task_choices = [
        "CIFAR10",
        "CIFAR100",
        "VAE",
        "GAN",
        "Glue-RTE",
        "Glue-MRPC",
        "GCN",
        "RL",
    ]
    tuner_choices = ["HB", "BOHB"]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        type=str,
        default="./config.yaml",
        help="Path to config.yaml file",
    )
    parser.add_argument(
        "--task", type=str, required=True, choices=task_choices, help="Type of tasks."
    )
    parser.add_argument(
        "--tuner",
        type=str,
        required=True,
        choices=tuner_choices,
        help="Name of tuning algorithm.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        required=True,
        help="Name of optimizer (in format `optimizer.Nparam`, e.g. SGD.1param).",
    )
    parser.add_argument(
        "--scheduler", type=str, required=True, help="Name of learning rate scheduler"
    )
    # Optional; main() picks a random seed when this is not given.
    parser.add_argument("--seed", type=int, help="Random seed")
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    args = parser.parse_args()
    main(args)
