import mlflow
import numpy as np
from mlflow.entities import Experiment
from warnings import warn
import torch
import logging

from utils.data.load_dataset import get_train_test_dataset
from hyperparameter_tuning.utils.gpytorch.models.exact_gpr import ExactGPR
from utils.result_management.constants import EXPERIMENT_TYPE, HYPER_PARAMETER_TUNING_EXPERIMENT, SEED, PARSER_ARGS, \
    ALGORITHM, DATASET, KL_CURR_FINAL, KL_FINAL_CURR, EXACT_LOG_DET, EXACT_QUAD, PARAMETERS, APPROXIMATE_LOSS, RMSE, \
    EXACT_LOSS
from utils.result_management.result_management import get_results_path, get_steps_and_values_from_run, load_artifact_dict

# Import-time sanity check: the active MLflow tracking URI must equal the value
# returned by get_results_path(), otherwise nothing below should run.
assert mlflow.get_tracking_uri() == get_results_path()


def run_local_auxiliary_computations(experiment_id: str) -> None:
    """
    Fill the EXACT_LOSS metric of every run of an experiment from already logged values.

    Runs whose algorithm tag equals ``ExactGPR.get_registry_key()`` get their
    APPROXIMATE_LOSS values re-logged as EXACT_LOSS at the steps at which RMSE was
    logged. For every other run EXACT_LOSS is computed as
    ``EXACT_QUAD / 2 + EXACT_LOG_DET / 2 + N / 2 * log(2 * pi)``, where ``N`` is the
    number of rows of the first array returned by ``get_train_test_dataset``.

    Experiments that are not tagged as hyper-parameter tuning experiments are skipped
    with a warning. A failure while processing one run is logged (with traceback and
    run id) and does not stop the processing of the remaining runs.

    :param experiment_id: MLflow id of the experiment whose runs are updated.
    :return: None. Results are written to the tracking store via ``log_metric``.
    """
    mlfc = mlflow.tracking.MlflowClient()
    exp: Experiment = mlfc.get_experiment(experiment_id)
    # Use .get() so that an experiment without the type tag takes the warn-and-skip
    # path below instead of aborting the caller with a KeyError.
    if exp.tags.get(EXPERIMENT_TYPE) != HYPER_PARAMETER_TUNING_EXPERIMENT:
        warn("Experiment is not a hyper-parameter tuning experiment. Abort.")
        return
    dataset = exp.tags[DATASET]
    X, _, _, _ = get_train_test_dataset(dataset)
    # Constant term N / 2 * log(2 * pi) added to every computed loss value.
    c = X.shape[0] / 2 * np.log(2 * np.pi)
    del X  # only the row count was needed
    all_runs = mlflow.search_runs(experiment_ids=[experiment_id], output_format="pandas")

    # No tensor operations are visible in this block; the no_grad context is kept
    # in case the helpers called below touch torch.
    with torch.no_grad():
        for _, run in all_runs.iterrows():
            run_id = run["run_id"]
            is_exact_gpr = run["tags." + ALGORITHM] == ExactGPR.get_registry_key()
            try:
                if is_exact_gpr:
                    steps, _ = get_steps_and_values_from_run(run_id, RMSE)
                    _, losses = get_steps_and_values_from_run(run_id, APPROXIMATE_LOSS)
                    for step in steps:
                        # NOTE(review): assumes the loss belonging to step s is stored
                        # at position s - 1 (loss logged at every step, starting at 1)
                        # -- confirm against the logging code.
                        mlfc.log_metric(run_id, EXACT_LOSS, losses[step - 1], step=step)
                else:
                    _, quads = get_steps_and_values_from_run(run_id, EXACT_QUAD)
                    # NOTE(review): the steps of EXACT_LOG_DET are used for both series;
                    # presumably both metrics are logged at identical steps -- confirm.
                    steps, dets = get_steps_and_values_from_run(run_id, EXACT_LOG_DET)
                    for idx, step in enumerate(steps):
                        v = quads[idx] / 2 + dets[idx] / 2 + c
                        mlfc.log_metric(run_id, EXACT_LOSS, v, step=step)
            except Exception:
                # Best effort per run: keep going, but record which run failed and why.
                logging.exception("Could not fill %s for run %s", EXACT_LOSS, run_id)


if __name__ == "__main__":
    # Process the experiments with ids "1" through "8", one after the other.
    for experiment_number in range(1, 9):
        run_local_auxiliary_computations(str(experiment_number))
