import json
import sys
import torch
from ray import tune
from ray.tune.search import ConcurrencyLimiter
from ray.tune.search.optuna import OptunaSearch
from exp_ood import OODDetectionExp
from factory import (
    make_ood_dataset,
    make_lossfn,
    make_model,
    make_optimizer,
    make_reporter,
)

def parse_param(param):
    """Convert one hyperparameter spec from the JSON config into a Ray Tune domain.

    Accepted forms:

    * A non-dict value is treated as a fixed value and wrapped as
      ``tune.choice([param])``.
    * A dict with ``"min"``, ``"max"`` and ``"distribution"`` keys, where the
      distribution is ``"randint"``, ``"uniform"`` or ``"loguniform"``, becomes
      the matching ``tune`` sampler between ``min`` and ``max``.
    * A dict with a ``"values"`` key becomes ``tune.choice(param["values"])``.

    If a dict has ``"min"`` with a distribution that is not one of the
    supported ranges, the ``"values"`` form is tried next.

    Args:
        param: The raw spec, as loaded from JSON.

    Returns:
        A Ray Tune search-space domain.

    Raises:
        KeyError: If a dict with ``"min"`` lacks ``"distribution"``, or lacks
            ``"max"`` when the distribution is supported.
        ValueError: If a dict spec matches none of the supported forms
            (previously this silently returned ``None``).
    """
    if not isinstance(param, dict):
        return tune.choice([param])

    if "min" in param:
        distribution = param["distribution"]
        if distribution == "randint":
            return tune.randint(param["min"], param["max"])
        if distribution == "uniform":
            return tune.uniform(param["min"], param["max"])
        if distribution == "loguniform":
            return tune.loguniform(param["min"], param["max"])
        # Unknown range distribution: fall through to the "values" form.

    if "values" in param:
        return tune.choice(param["values"])

    raise ValueError(
        f"Unsupported tuning spec {param!r}: expected a 'values' list, or "
        "'min'/'max' with distribution in ('randint', 'uniform', 'loguniform')"
    )


def objective(cfg):
    """Train and evaluate one OOD-detection trial; used as the Ray Tune trainable.

    Builds the reporter, datasets, loss, model and optimizer(s) from the
    sub-configs in ``cfg`` via the ``factory`` helpers, runs an
    ``OODDetectionExp`` and returns its end-of-run OOD metrics.

    Args:
        cfg: Trial config with ``"reporter"``, ``"dataset"``, ``"lossfn"``,
            ``"model"``, ``"optimizer"`` and ``"exp"`` entries; ``cfg["model"]``
            must contain a ``"name"`` string.

    Returns:
        Dict of metrics reported to Ray Tune. Note that ``"ood_aupr"`` holds
        the AUROC value (kept for backward compatibility with the tuning
        metric name); ``"ood_auroc"`` carries the same value under an
        accurate name.
    """
    reporter = make_reporter(cfg["reporter"], cfg)
    dataset_ind, dataset_ood_tr, dataset_ood_te = make_ood_dataset(cfg["dataset"])
    loss_fn, eval_func = make_lossfn(cfg["lossfn"])
    model = make_model(cfg["model"], dataset_ind)

    model_name = cfg["model"]["name"].lower()
    warmup_optimizer = None
    if model_name == "sgcn":
        # make_model returns a (teacher, model) pair for SGCN. `model` is what
        # the experiment trains, so it needs its own optimizer; previously
        # `optimizer` was never bound here and the call below raised
        # UnboundLocalError.
        # NOTE(review): `teacher` / `teacher_optimizer` are not passed to
        # OODDetectionExp — confirm how the teacher is meant to be used.
        teacher, model = model
        teacher_optimizer = make_optimizer(cfg["optimizer"], teacher)
        optimizer = make_optimizer(cfg["optimizer"], model)
    elif model_name == "gpn":
        # For GPN, make_optimizer returns a (main, warmup) optimizer pair.
        optimizer, warmup_optimizer = make_optimizer(cfg["optimizer"], model)
    else:
        optimizer = make_optimizer(cfg["optimizer"], model)

    exp = OODDetectionExp(cfg=cfg["exp"],
                          cfg_model=cfg["model"],
                          model=model,
                          criterion=loss_fn,
                          eval_func=eval_func,
                          optimizer=optimizer,
                          warmup_optimizer=warmup_optimizer,
                          reporter=reporter,
                          dataset_ind=dataset_ind,
                          dataset_ood_tr=dataset_ood_tr,
                          dataset_ood_te=dataset_ood_te)

    metrics = exp.run()
    return {
        # "ood_aupr" actually carries AUROC; the key is kept because the
        # sweep in main() optimizes metric="ood_aupr".
        "ood_aupr": metrics["END_AUROC"],
        "ood_auroc": metrics["END_AUROC"],
        "ood_aupr_in": metrics["END_AUPR_in"],
        "ood_aupr_out": metrics["END_AUPR_out"],
        "ood_fpr95": metrics["END_FPR95"],
        "ood_detection_acc": metrics["END_DETECTION_acc"],
    }


def parse_tuning_subconfigs(group_name, cfg):
    """Build the Tune parameter sub-space for one config group.

    Each spec listed under ``cfg[f"{group_name}_tuning"]`` is converted with
    ``parse_param`` and keyed by its ``"name"``. The fixed settings in
    ``cfg[group_name]`` are then layered on top, so a fixed value wins over a
    tuned one with the same name. Either section may be absent.

    Returns:
        A new dict; ``cfg`` is not modified.
    """
    tuning_specs = cfg.get(f"{group_name}_tuning", ())
    merged = {spec["name"]: parse_param(spec) for spec in tuning_specs}
    merged.update(cfg.get(group_name, {}))
    return merged


def main():
    """Run an Optuna-driven Ray Tune sweep for each dataset in a JSON config.

    Usage: ``python <script> CONFIG.json``.

    The config provides, for each of the groups ``exp``, ``model``, ``lossfn``
    and ``optimizer``, fixed settings under ``<group>`` and/or tuned
    hyperparameters under ``<group>_tuning`` (merged by
    ``parse_tuning_subconfigs``). It also provides a ``dataset`` entry (one dict
    or a list of dicts) and an optional ``reporter`` dict. One independent
    sweep is run per dataset, and the best trial (max ``ood_aupr``) is printed.
    """
    if len(sys.argv) < 2:
        sys.exit("usage: <script> CONFIG.json")
    config_name = sys.argv[1]
    with open(config_name, mode="r", encoding="utf-8") as f:
        cfg = json.load(f)

    dataset_cfgs = cfg["dataset"]
    if isinstance(dataset_cfgs, dict):
        dataset_cfgs = [dataset_cfgs]

    exp_config = parse_tuning_subconfigs("exp", cfg)
    model_config = parse_tuning_subconfigs("model", cfg)
    loss_fn_config = parse_tuning_subconfigs("lossfn", cfg)
    optimizer_config = parse_tuning_subconfigs("optimizer", cfg)
    reporter_cfg = cfg.get("reporter", {})

    num_samples = 64
    max_concurrent = 64
    search_seed = 52519

    # The trainable and its resource request do not depend on the dataset,
    # so build them once instead of on every loop iteration.
    trainable = tune.with_parameters(objective)
    resources = {"gpu": 1} if torch.cuda.is_available() else {"cpu": 2}
    trainable_with_resources = tune.with_resources(trainable, resources)

    for dataset in dataset_cfgs:
        param_space = {
            "dataset": dataset,
            "model": model_config,
            "lossfn": loss_fn_config,
            "optimizer": optimizer_config,
            "reporter": reporter_cfg,
            "exp": exp_config,
        }
        print(param_space)

        # A fresh seeded searcher per dataset so sweeps do not share state.
        search_alg = ConcurrencyLimiter(
            OptunaSearch(seed=search_seed), max_concurrent=max_concurrent
        )
        tuner = tune.Tuner(
            trainable_with_resources,
            tune_config=tune.TuneConfig(
                # NOTE: objective() reports AUROC under the "ood_aupr" key.
                metric="ood_aupr",
                mode="max",
                search_alg=search_alg,
                num_samples=num_samples,
            ),
            param_space=param_space,
        )
        best = tuner.fit().get_best_result()
        print("Best config is:", best.config)
        print("Best results are:", best)


if __name__ == "__main__":  # run the sweep only when executed as a script, not on import
    main()