import logging
import warnings
import numpy as np
import matplotlib.pyplot as plt
import mlflow

# this import is necessary to set the tracking URI correctly
import utils.result_management.result_management as result_management
from hyperparameter_tuning.utils.gpytorch.models.variational_gpr import NativeVariationalGPR, OPTIMIZE_INDUCING_INPUTS, \
    VariationalGPR, NUM_INDUCING_INPUTS
from utils.result_management.constants import LOSS_TIME, GRAD_TIME, EXACT_LOG_DET, EXACT_QUAD, ALGORITHM, RMSE, DATASET, \
    KERNEL, EXACT_LOSS, BLOCK_SIZE, LOSS_TIME_PS, GRAD_TIME_PS, SETUP_TIME, APPROXIMATE_LOSS, SEED, NLPD
from utils.result_management.result_management import get_steps_and_values_from_run as get_steps_and_values_from_run_
from utils.visualization.visualization_constants import exact_color, acgp_color, cglb_color, svgp_color
from hyperparameter_tuning.analysis.make_results_table import checkRunIsIncluded as checkRunIsIncludedInTable
from hyperparameter_tuning.analysis.make_results_table import algo_labels_dict, algos, metric_names

# When True, every figure is written to a PDF file; when False, it is shown interactively.
MAKE_TIKZ = True

# Metric plotted on the y-axis.
metric = NLPD  #EXACT_LOSS  #RMSE  #NLPD
# Only runs tagged with this optimizer are fetched and plotted.
optimizer = "L-BFGS-B"
#optimizer = "BFGS"  # why is BFGS no good? That's surprising

# IDs of the experiments to plot. Exactly ONE selection below must be active;
# the others are kept as commented-out alternatives. (Previously the first
# selection was also left active and then silently overwritten further down.)
#id_list = [str(i) for i in range(5, 9)]  # large scale CPU experiments
#id_list = [str(i) for i in range(11, 15)]  # few-cores experiments
#id_list = [str(i) for i in range(1, 5)]   # GPU experiments
id_list = [str(i) for i in range(1, 9)]  # active: GPU and large scale CPU experiments together
#id_list = ["5"]


def checkRunIsIncluded(algo, run):
    """Decide whether ``run`` should be drawn in the figure.

    A run is kept only if the results-table filter accepts it. Runs whose
    algorithm is the native variational GPR must additionally carry the tag
    value ``"1024"`` under ``NUM_INDUCING_INPUTS``.

    :param algo: registry key of the algorithm the run belongs to
    :param run: the run object whose tags are inspected
    :return: True if the run should be plotted, False otherwise
    """
    # The table filter is the baseline; anything it rejects is never plotted.
    if not checkRunIsIncludedInTable(algo, run):
        return False
    # Every algorithm other than the native variational GPR needs no further check.
    if algo != NativeVariationalGPR.get_registry_key():
        return True
    return run.data.tags[NUM_INDUCING_INPUTS] == "1024"


# Global matplotlib styling: render text through LaTeX (with the listed
# packages in the preamble) and use an 18pt serif font.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amssymb} \usepackage{amsmath} \usepackage{marvosym} \usepackage{bm}',
    'font.family': 'serif',
    'font.size': 18,
})


def get_algo_labels():
    """Return a fresh copy of the algorithm-label mapping.

    A copy is handed out so that callers can overwrite entries without
    touching the imported ``algo_labels_dict``.
    """
    labels = algo_labels_dict.copy()
    return labels
# Working copy of the legend labels; entries are overwritten while plotting.
algo_labels = get_algo_labels()
# Reverse lookup: legend label -> algorithm registry key.
inv_map = dict((label, key) for key, label in algo_labels.items())
# Colors and line symbols are matched to ``algos`` by position.
colors = [palette()[-1] for palette in (exact_color, acgp_color, cglb_color, svgp_color)]
algo_colors = {algo.get_registry_key(): colors[i] for i, algo in enumerate(algos)}
symbols = ['x-', '.-', '+-', 'o-']
algo_symbols = {algo.get_registry_key(): symbols[i] for i, algo in enumerate(algos)}


def get_steps_and_values_from_run(run, metric: str):
    """Fetch the history of ``metric`` for ``run``, ordered by step.

    The raw steps and values are obtained from the result-management helper
    using ``run.info.run_id`` and converted to NumPy arrays.

    :param run: run object; only ``run.info.run_id`` is read
    :param metric: name of the metric whose history is requested
    :return: tuple ``(steps, values)`` of NumPy arrays, sorted by ascending
        step; entries sharing the same step keep the order in which the
        helper returned them
    """
    steps, values = get_steps_and_values_from_run_(run_id=run.info.run_id, metric=metric)
    steps = np.array(steps)
    values = np.array(values)
    # A stable sort keeps duplicate steps in their original relative order
    # (the default argsort algorithm gives no such guarantee). The local is
    # deliberately not called ``sorted`` to avoid shadowing the builtin.
    order = np.argsort(steps, kind="stable")
    return steps[order], values[order]


# Main script: for every selected experiment, draw the metric of each included
# run against cumulative wall-clock time and either save or show the figure.

# One client is sufficient for all experiments and all queries.
mlfc = mlflow.tracking.MlflowClient()

for exp_id in id_list:
    exp = mlfc.get_experiment(exp_id)
    dataset_name = exp.tags[DATASET]
    kernel_name = exp.tags["kernel_name"]

    runs = mlfc.search_runs([exp.experiment_id], filter_string=f"tags.optimizer='{optimizer}'")

    #plt.figure(1, figsize=(14, 7))
    plt.figure(1, figsize=(9, 5), constrained_layout=True)
    #plt.xlabel("\\xlabeltime{}")
    #plt.ylabel("\\ylabelRMSE{}")

    for run in runs:
        if ALGORITHM not in run.data.tags:
            continue  # crashed run
        algo = run.data.tags[ALGORITHM]
        if not checkRunIsIncluded(algo, run):
            continue

        # Best effort: a run that cannot be processed is logged and skipped so
        # that the remaining runs still end up in the figure.
        try:
            accepted_steps, metric_values = get_steps_and_values_from_run(run, metric)
            steps_loss, loss_run_times = get_steps_and_values_from_run(run, LOSS_TIME)
            steps_grad, grad_run_times = get_steps_and_values_from_run(run, GRAD_TIME)
            if len(grad_run_times) == len(loss_run_times) - 1:
                # it can happen that a job is killed just after computing the loss but while computing the gradient
                # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
                if steps_grad[-1] != steps_loss[-2]:
                    raise ValueError(
                        f"Run {run.info.run_id}: last gradient step {steps_grad[-1]} does not match "
                        f"second-to-last loss step {steps_loss[-2]}"
                    )
                loss_run_times = loss_run_times[:-1]
            setup_time = run.data.metrics[SETUP_TIME]
            # Elapsed time after each step: setup time plus all loss and gradient evaluations so far.
            run_times = np.cumsum(loss_run_times + grad_run_times) + setup_time

            #accepted_steps = accepted_steps[:-1]  # for some reason, some plots jump back to the start---the last point isn't interesting anyway
            #metric_values = metric_values[:-1]
            if len(metric_values) < 4:
                warnings.warn(f"Run {run.info.run_id} in experiment {run.info.experiment_id} has less than 4 values!")
                continue
            # we need a -1 because the steps start counting at 1
            #plt.plot(run_times[accepted_steps-1], metric_values, color=algo_colors[algo], label=algo_labels[algo],
            #         lw=1.5, marker='.', ms=5.0, mew=2.0, alpha=.7,)
            # Wide translucent line carrying the legend label ...
            plt.plot(run_times[accepted_steps-1], metric_values, color=algo_colors[algo], label=algo_labels[algo],
                     lw=5, ms=1.5, mew=2.0, alpha=.3,)
            # ... overlaid with small opaque dots at the individual measurements.
            plt.plot(run_times[accepted_steps-1], metric_values, '.', color=algo_colors[algo], label=None,
                     ms=2.5)
            algo_labels[algo] = None  # to make sure we add the legend entry only once
        except Exception as e:
            logging.exception(e)

    leg = plt.legend(frameon=False)
    # ``Legend.legendHandles`` was renamed to ``legend_handles`` (deprecated in
    # matplotlib 3.7 and later removed); support both spellings.
    legend_handles = getattr(leg, "legend_handles", None)
    if legend_handles is None:
        legend_handles = leg.legendHandles
    for handle in legend_handles:
        #handle.set_color(algo_colors[inv_map[handle.get_label()]])  # this is to get rid of the transparency
        handle.set_alpha(1.)
    plt.title(r'Metric Evolution over GP training -- \texttt{%s}' % dataset_name.replace("_", "-").replace("wilson-", ""))
    plt.xlabel(r'time in seconds')
    plt.ylabel(r'%s' % metric_names[metric])
    plt.gca().set_xscale('log')
    if MAKE_TIKZ:
        plt.savefig(fname='./output/figures/hyperparametertuning/experiment_4_hyp_tuning' + str(dataset_name) + metric + '.pdf', format='pdf')
    else:
        plt.show()

    # Reset the figure and the label table for the next experiment.
    plt.clf()
    algo_labels = get_algo_labels()
