import logging
import math
import warnings

import mlflow
import numpy as np
import torch

import run_script_util  # this import is necessary to run the script from different locations
from utils.data.load_dataset import get_train_test_dataset
from acgp.hooks.fixed_stopping_point_hook import FixedStoppingPointHook
from hyperparameter_tuning.utils.abstract_hyper_parameter_tuning_algorithm import AbstractHyperParameterTuningAlgorithm
from hyperparameter_tuning.utils.gpytorch.kernel_factory import KernelFactory
from hyperparameter_tuning.utils.gpytorch.stopped_cholesky import _StoppedCholesky, StoppedCholesky
from hyperparameter_tuning.utils.gpytorch.variational_gpr import SELECTION_SCHEME, RANDOM
from make_results_table_old import checkRunIsIncluded
from hyperparameter_tuning.run_hyper_parameter_tuning import HYPER_PARAMETER_TUNING_ALGOS, LOWER_NOISE_CONSTRAINT, _make_noise_func, \
    _make_mean_func
from utils.execution.run_cluster import execute_job_array_on_slurm_cluster
from utils.result_management.constants import ALGORITHM, PARSER_ARGS, DATASET, PARAMETERS, NLPD, SEED, NLPDfixed
from utils.result_management.result_management import load_artifact_dict, get_steps_and_values_from_run, get_results_path


# This script reads and writes MLflow runs, so refuse to run against any store
# other than the project's results path. An explicit raise (instead of `assert`)
# keeps the guard active when Python runs with -O.
if mlflow.get_tracking_uri() != get_results_path():
    raise RuntimeError(
        f"MLflow tracking URI {mlflow.get_tracking_uri()!r} does not match "
        f"the results path {get_results_path()!r}."
    )


def _set_module_parameter(module, dotted_name, value):
    """Assign ``value`` to the attribute addressed by the dotted path ``dotted_name``.

    ``setattr`` alone cannot handle the dotted names produced by
    ``named_parameters()``, so every path component except the last is
    resolved with ``getattr`` first. Unlike building ``module.<name> = value``
    as source text for ``exec``, this also accepts purely numeric path
    components and never evaluates the name as code.
    """
    *parent_names, leaf_name = dotted_name.split(".")
    owner = module
    for attr_name in parent_names:
        owner = getattr(owner, attr_name)
    setattr(owner, leaf_name, value)


def update_nlpd_value(experiment_id, run_id):
    """Recompute the test-set NLPD of one MLflow run and log it as ``NLPDfixed``.

    The dataset (experiment tag), the seed (run tag) and the original parser
    arguments (run artifact) are used to rebuild the train/test split, the
    kernel and the hyper-parameter tuning algorithm. For every step that has
    an ``NLPD`` metric but no ``NLPDfixed`` metric yet (currently restricted
    to the last ``NLPD`` step, see the TODO below), the following happens:
    the stored hyper-parameters of that step are restored, the algorithm's
    loss closure is evaluated once, and the Gaussian NLPD of the test targets
    is logged under ``NLPDfixed`` at that step.

    Args:
        experiment_id: MLflow experiment id (as passed on the command line).
        run_id: MLflow run id.
    """
    print(experiment_id)
    print(run_id)
    client = mlflow.tracking.MlflowClient()
    experiment = client.get_experiment(experiment_id)
    mlflow.set_experiment(experiment.name)
    dataset = experiment.tags[DATASET]
    seed = mlflow.get_run(run_id=run_id).data.tags[SEED]
    device = "cpu"

    original_parser_args = load_artifact_dict(experiment_id=experiment_id, run_id=run_id,
                                              artifact_name=PARSER_ARGS)

    X, y, X_test, y_test = get_train_test_dataset(dataset, seed=int(seed))
    X = torch.tensor(X)
    y = torch.tensor(y)
    X_test = torch.tensor(X_test)
    y_test = torch.tensor(y_test)

    k = KernelFactory().create(args=original_parser_args, X=X)

    log_2pi = torch.log(2 * torch.tensor(math.pi))

    # steps at which the original run logged an NLPD value (accepted steps)
    steps, _ = get_steps_and_values_from_run(run_id, NLPD)
    if len(steps) == 0:
        # np.max below would otherwise fail on an empty sequence
        warnings.warn(f"Run {run_id} has no {NLPD} values; nothing to recompute.")
        return

    # TODO: remove -- for now only the final step is re-evaluated
    steps = [np.max(steps)]

    # skip steps for which NLPDfixed has already been logged
    try:
        already_computed_steps, _ = get_steps_and_values_from_run(run_id, NLPDfixed)
    except Exception:
        # presumably raised when the run has no NLPDfixed metric yet -- TODO confirm
        already_computed_steps = []

    steps = np.setdiff1d(steps, already_computed_steps)
    if len(steps) == 0:
        return

    algo_key = original_parser_args[ALGORITHM]
    is_acgp = algo_key == StoppedCholesky.get_registry_key()
    if is_acgp:
        # ACGP is a special case: it is stopped at the number of fully processed
        # datapoints that the original run logged for the same step
        acgp_steps, processed_points = get_steps_and_values_from_run(run_id, "FULLY_PROCESSED_DATAPOINTS")
        assert len(np.setdiff1d(steps, acgp_steps)) == 0  # all steps should be contained in acgp_steps
        acgp_steps = np.array(acgp_steps)
    elif SELECTION_SCHEME in original_parser_args:
        # the inducing inputs are going to be overwritten anyway,
        # so there is no need to start the pivoted Cholesky
        original_parser_args[SELECTION_SCHEME] = RANDOM

    # loop-invariant: the noise lower bound only depends on the original arguments
    lower_noise_constraint = torch.tensor(float(original_parser_args[LOWER_NOISE_CONSTRAINT]), dtype=torch.float64)

    for step in steps:
        # load the hyper-parameters stored for this step
        params = load_artifact_dict(experiment_id=experiment_id, run_id=run_id, artifact_name=PARAMETERS + str(step))
        raw_sn2 = torch.tensor(params.pop("raw_sn2"))
        sn2 = _make_noise_func(raw_sn2, lower_noise_constraint=lower_noise_constraint)
        raw_mu = torch.tensor(params.pop("raw_mu"))
        mean_func = _make_mean_func(raw_mu)

        if is_acgp:
            matches = np.where(acgp_steps == step)[0]
            assert len(matches) == 1
            stopping_point = int(processed_points[matches[0]])

            class PostExperimentStoppedCholesky(_StoppedCholesky):
                # bound at class creation so the hook does not depend on the loop
                # variable still referring to this step whenever it is requested
                _post_experiment_stopping_point = stopping_point

                def _get_hook(self):
                    return FixedStoppingPointHook(self._post_experiment_stopping_point)

            algo = PostExperimentStoppedCholesky
        else:
            algo = HYPER_PARAMETER_TUNING_ALGOS[algo_key]

        algorithm: AbstractHyperParameterTuningAlgorithm = algo(X, y, k, sn2, mean_func, original_parser_args, device)
        closure = algorithm.create_loss_closure()

        # Restore the kernel hyper-parameters. The names are materialized first
        # because the loop replaces the very parameters it iterates over.
        for name, _ in list(k.named_parameters()):
            new_param = torch.nn.Parameter(torch.tensor(params[name], dtype=torch.float64, requires_grad=False))
            _set_module_parameter(k, name, new_param)
        for name, param in algorithm.get_named_tunable_parameters():
            param.data = torch.tensor(params[name], dtype=torch.float64)
        closure()

        f_mean, f_var = algorithm.get_posterior(X_test)
        y_mean, y_var = algorithm.get_y_posterior(X_test, f_mean, f_var)
        # Gaussian NLPD averaged over the test points:
        #   0.5 * (mean((y - m)^2 / v + log v) + log(2 * pi))
        # NOTE(review): assumes y_mean, y_var and y_test have identical shapes; an
        # (n,) vs (n, 1) mismatch would broadcast silently -- TODO confirm
        nlpd = (torch.mean(torch.square(y_mean - y_test) / y_var + torch.log(y_var)) + log_2pi) / 2
        client.log_metric(run_id, NLPDfixed, nlpd.item(), step=step)


def _search_all_runs(client, experiment_id):
    """Yield every run of ``experiment_id``, following MLflow's result pagination.

    A single ``MlflowClient.search_runs`` call returns at most ``max_results``
    runs, so one call can miss runs of large experiments. The remaining pages
    are fetched through the ``token`` of each returned page.
    """
    page_token = None
    while True:
        page = client.search_runs(experiment_ids=[experiment_id], page_token=page_token)
        yield from page
        page_token = page.token
        if not page_token:
            break


def generate_batch_jobs():
    """Submit one SLURM job array per MLflow experiment to compute ``NLPDfixed``.

    For every experiment returned by ``mlflow.list_experiments()``, each run
    that has no ``NLPDfixed`` metric yet and passes ``checkRunIsIncluded``
    becomes one command ``python <this script> <experiment_id> <run_id>``. The
    commands of one experiment are submitted together as a job array. Runs
    whose tags cannot be read are logged and skipped. If an experiment has
    nothing to run, a warning is issued instead.
    """
    import shlex  # only needed here, to quote the script path for the shell

    # quote the path so the command survives spaces / shell metacharacters
    template = "python %s " % shlex.quote(__file__)
    client = mlflow.tracking.MlflowClient()
    for experiment in mlflow.list_experiments():
        command_ls = []
        # start a job array for each experiment, that is each run is one cluster job
        for run in _search_all_runs(client, experiment.experiment_id):
            try:
                algo = run.data.tags[ALGORITHM]
                if NLPDfixed in run.data.metrics:
                    continue  # already computed
                if not checkRunIsIncluded(algo, run):
                    continue
                command_ls.append(template + experiment.experiment_id + " " + run.info.run_id)
            except Exception as ex:
                logging.exception(ex)
        if command_ls:
            # no need to run exclusive experiments
            execute_job_array_on_slurm_cluster(command_ls, cpus=10, exclusive=False, max_jobs_parallel=4, set_core_affinity=False)
        else:
            warnings.warn(f"Nothing to run for experiment {experiment.experiment_id}?!")


if __name__ == "__main__":
    from sys import argv

    # Usage:
    #   python <script>                       -> submit cluster jobs for all pending runs
    #   python <script> EXPERIMENT_ID RUN_ID  -> recompute NLPDfixed for a single run
    if len(argv) >= 3:
        update_nlpd_value(argv[1], argv[2])
    elif len(argv) == 1:
        generate_batch_jobs()
    else:
        # an experiment id without a run id cannot be processed
        raise SystemExit(f"usage: python {argv[0]} [EXPERIMENT_ID RUN_ID]")
