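"""
Check that every (algorithm, seed) combination has exactly one complete MLflow run in experiments 5-8.

Runs that logged fewer than 20 RMSE values (or none at all) are deleted. When several complete runs exist for the
same combination, the user is asked which one to keep. Missing combinations are reported as warnings.
"""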
import warnings
import mlflow
import shutil
from os.path import sep

from mlflow.exceptions import MlflowException

from hyperparameter_tuning.utils.gpytorch.cglb import CGLB, OPTIMIZE_INDUCING_INPUTS
from hyperparameter_tuning.utils.gpytorch.exact_gpr import ExactGPR
from hyperparameter_tuning.utils.gpytorch.kernel_factory import KernelFactory
from hyperparameter_tuning.utils.gpytorch.stopped_cholesky import StoppedCholesky
from utils.registry import KERNEL_DICT
from utils.result_management import ENV_CPUS, find_experiments_with_tags, delete_runs_if_crashed, filter_runs_with_tags, \
    get_run_list_from_dataframe
from hyperparameter_tuning.run_hyper_parameter_tuning import OPTIMIZER
from utils.result_management.constants import SEED, ALGORITHM, RMSE
from utils.result_management.result_management import delete_run

# Dataset names covered by the experiments (not referenced elsewhere in the visible part of this script).
datasets = [
    'metro',
    'tamilnadu_electricity',
    'pm25',
    'protein',
    'bank',
]

# Seeds are kept as strings because they are compared against MLflow tag values.
seeds = list(map(str, range(5)))

# Compute budget: a CPU count, a mapping that exposes it under the ENV_CPUS key, and a block size of 256 per CPU.
# NOTE(review): none of these names is used in the visible part of this script.
cpus = 40
environment = {ENV_CPUS: cpus}
block_size = cpus * 256


def _ask_which_run_to_keep(count):
    """Prompt until the user picks a run number in ``1..count`` or cancels.

    :param count: number of candidate runs, numbered from 1 in the preceding printout.
    :return: the chosen 1-based run number, or ``None`` if the user typed ``c`` or nothing.
    """
    while True:
        action = input("Type number for which run to keep or [c]ancel:").strip()
        if action in ("c", ""):
            return None
        try:
            choice = int(action)
        except ValueError:
            choice = None
        # Reject anything outside 1..count: an out-of-range number would match no run and every run would be deleted.
        if choice is not None and 1 <= choice <= count:
            return choice
        print(f"Please enter a number between 1 and {count}, or 'c' to cancel.")


def check_consistency(all_runs, algorithm, seed, min_steps=20):
    """Ensure exactly one complete run exists for ``algorithm`` and ``seed``.

    A run counts as complete when it logged at least ``min_steps`` values of the ``RMSE`` metric.
    Incomplete runs (too few values, or no ``RMSE`` metric at all) are deleted. If more than one complete
    run remains, the user is asked which one to keep and the others are deleted. If none remains, a
    warning is issued.

    :param all_runs: runs dataframe as returned by ``mlflow.search_runs``; must contain the
        ``tags.<SEED>`` and ``tags.<ALGORITHM>`` columns.
    :param algorithm: value of the ``ALGORITHM`` tag to select.
    :param seed: value of the ``SEED`` tag to select (tags are strings).
    :param min_steps: minimum number of logged ``RMSE`` values for a run to count as complete.
    :return: the rows of ``all_runs`` matching ``algorithm`` and ``seed``, including rows whose runs
        were deleted by this call.
    :raises MlflowException: if fetching the metric history fails for a reason other than the
        ``RMSE`` metric being absent.
    """
    runs = all_runs.loc[(all_runs[f"tags.{SEED}"] == seed) & (all_runs[f"tags.{ALGORITHM}"] == algorithm)]
    client = mlflow.tracking.MlflowClient()
    complete_runs = []
    for _, r in runs.iterrows():
        run_id = r['run_id']
        try:
            history = client.get_metric_history(run_id, RMSE)
        except MlflowException as e:
            # A run that never logged RMSE is incomplete. Any other MLflow error is unexpected and must not
            # trigger a deletion. An explicit check is used instead of ``assert``, which ``python -O`` strips.
            if not e.message.startswith(f"Metric '{RMSE}' not found"):
                raise
            history = None
        if history is None or len(history) < min_steps:
            delete_run(mlflow.get_run(run_id))
            continue
        complete_runs.append(r)
        if len(runs) > 1:
            # Number the candidates so the user can pick one at the prompt below.
            print(f"{len(complete_runs)}: took {len(history)} steps ({run_id}, {r['experiment_id']})")

    if len(complete_runs) > 1:
        warnings.warn(f"More than one result for {algorithm} and seed {seed}.")
        keep = _ask_which_run_to_keep(len(complete_runs))
        if keep is not None:
            for number, r in enumerate(complete_runs, start=1):
                if number != keep:
                    delete_run(mlflow.get_run(r['run_id']))
    elif not complete_runs:
        # TODO: run job
        warnings.warn(f"Missing run for {algorithm} and seed {seed}")
    return runs


# Experiments whose runs are audited, and the registry keys of the algorithms expected in each of them.
experiment_ids = [str(i) for i in range(5, 9)]
algorithms = [a.get_registry_key() for a in [ExactGPR, StoppedCholesky, CGLB]]
# Tag columns read by the filter below and by check_consistency. mlflow.search_runs only creates a
# ``tags.<name>`` column when at least one returned run carries that tag. An empty experiment, or one
# where no run is tagged ``r``, would otherwise raise a KeyError, so missing columns are added as empty.
_required_tag_columns = ["tags." + OPTIMIZER, "tags.r", "tags." + OPTIMIZE_INDUCING_INPUTS,
                         f"tags.{SEED}", f"tags.{ALGORITHM}"]
for experiment_id in experiment_ids:
    runs = mlflow.search_runs(experiment_ids=[experiment_id])
    for column in _required_tag_columns:
        if column not in runs.columns:
            runs[column] = None
    # Keep L-BFGS-B runs that have no ``r`` tag and whose OPTIMIZE_INDUCING_INPUTS tag is not "False".
    runs = runs.loc[(runs["tags." + OPTIMIZER] == 'L-BFGS-B') & (runs["tags.r"].isna()) & (runs["tags." + OPTIMIZE_INDUCING_INPUTS] != "False")]
    for a in algorithms:
        for s in seeds:
            check_consistency(all_runs=runs, algorithm=a, seed=s)
