import time
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import scipy
import seaborn as sns
from autogluon.timeseries.dataset import TimeSeriesDataFrame
from autogluon.timeseries.metrics.quantile import SQL
from autogluon.timeseries.utils.datetime import get_seasonality

from .elo_utils import compute_elo_ratings
from .metrics import get_ag_metric, get_metric
from .models import BestValidationModel, GreedyEnsemble, LinearEnsemble


def val_preds_to_folds(y_val_preds: Dict[str, List[TimeSeriesDataFrame]]):
    """
    Transpose a dict of per-model lists into a list of per-model dicts, one per fold.

    Useful for turning the base models' validation predictions (one list of fold predictions
    per model) into one ``{model: prediction}`` mapping per fold, e.g. to compute a validation
    loss fold by fold.

    Example: ``{"m1": [a0, a1], "m2": [b0, b1]}`` -> ``[{"m1": a0, "m2": b0}, {"m1": a1, "m2": b1}]``.

    Lists of unequal length are truncated to the shortest one (plain ``zip`` semantics).
    """
    model_names = list(y_val_preds)
    per_fold = zip(*y_val_preds.values())
    return [dict(zip(model_names, fold_preds)) for fold_preds in per_fold]


def compute_loss(metric, y_test, y_pred, data):
    """
    Evaluate ``metric`` on a forecast, averaging over folds when several are given.

    Args:
        metric: object exposing ``save_past_metrics(data_past=..., target=..., seasonal_period=...)``
            and ``compute_metric(y_true, y_pred, target=...)``.
        y_test: ground truth containing the history followed by the last ``prediction_length``
            steps, or a list of such objects (one per fold, paired with ``y_pred``).
        y_pred: forecast for the final ``prediction_length`` steps, a list of forecasts (one per
            fold), or a dict that is first transposed with ``val_preds_to_folds``.
        data: mapping with keys ``"eval_metric_seasonal_period"`` (a falsy value means "infer it
            from ``y_test.freq``"), ``"prediction_length"`` and ``"target"``.

    Returns:
        The metric value, or the mean of the per-fold values for list inputs.
    """
    if isinstance(y_pred, list):
        fold_losses = [
            compute_loss(metric, fold_true, fold_pred, data)
            for fold_true, fold_pred in zip(y_test, y_pred)
        ]
        return np.mean(fold_losses)

    if isinstance(y_pred, dict):
        # NOTE(review): with values shaped {model: [fold preds]} each fold becomes a dict again,
        # so the recursion re-enters this branch with per-fold frames and never reaches the
        # metric. Confirm which input shape this branch is meant to handle.
        folds = val_preds_to_folds(y_pred)
        return compute_loss(metric, y_test, folds, data)

    seasonal_period = data["eval_metric_seasonal_period"]
    if not seasonal_period:
        seasonal_period = get_seasonality(y_test.freq)

    horizon = data["prediction_length"]
    history = y_test.slice_by_timestep(None, -horizon)
    target_column = data["target"]
    # The metric needs the history first (e.g. for scaling), then scores the horizon window.
    metric.save_past_metrics(
        data_past=history,
        target=target_column,
        seasonal_period=seasonal_period,
    )
    future = y_test.slice_by_timestep(-horizon, None)
    return metric.compute_metric(future, y_pred, target=data["target"])


def format_results_df(df: pd.DataFrame, baseline: Optional[Dict[str, Dict[str, float]]] = None):
    """
    SOMEWHAT OUTDATED! I mostly use `add_metrics` now.

    Clean up a results frame exported from mlflow so it is ready for plotting and evaluation.

    ``df`` is modified in place and also returned:
      * literal ``"None"`` strings become NaN,
      * known numeric hyper-parameter columns are cast to float,
      * ``params.optimizer_kwargs.lr`` (when present) overwrites ``params.lr``,
      * a few string-valued parameters become categoricals.

    Args:
        df: data frame obtained from mlflow
        baseline: optional reference values of the form ``{dataset: {metric_column: value}}``.
            When given, ``metrics.{sql,wql}_{val,test}_relative`` columns are added, holding each
            run's metric divided by the reference value of its dataset (``params.dataset``).
    """
    df.replace("None", np.nan, inplace=True)

    numeric_params = (
        "params.lr",
        "params.optimizer_kwargs.lr",
        "params.epochs",
        "params.max_iter",
        "params.optimizer_kwargs.tolerance_change",
        "params.optimizer_kwargs.tolerance_grad",
        "params.optimizer_kwargs.xtol",
        "params.optimizer_kwargs.gtol",
        "params.tolerance_change",
        "params.tolerance_grad",
        "params.rank",
        "params.n_windows",
    )
    for column in [name for name in numeric_params if name in df]:
        df[column] = df[column].astype("float")

    # Runs that configured the optimizer via kwargs store the learning rate there.
    if "params.optimizer_kwargs.lr" in df:
        df["params.lr"] = df["params.optimizer_kwargs.lr"]

    categorical_params = ("params.softmax", "params.optimizer", "params.weights_per")
    for column in [name for name in categorical_params if name in df]:
        df[column] = df[column].astype("category")

    if baseline is None:
        return df

    for dataset in df["params.dataset"].unique():
        rows = df["params.dataset"] == dataset
        for quantile_loss in ("sql", "wql"):
            for split in ("val", "test"):
                column = f"metrics.{quantile_loss}_{split}"
                df.loc[rows, f"{column}_relative"] = df.loc[rows, column] / baseline[dataset][column]
    return df


def _add_relative_to_baseline(df, is_baseline, column, baseline_column, relative_column, clip=None):
    """Add the baseline's value of `column` for each row's dataset and the ratio to it (df indexed by dataset)."""
    # Series indexed by dataset holding the baseline method's value; broadcast to every row of that dataset.
    df[baseline_column] = df.index.map(df.loc[is_baseline, column])
    ratio = df[column] / df[baseline_column]
    df[relative_column] = ratio if clip is None else ratio.clip(*clip)


def add_metrics(df: pd.DataFrame, baseline: str, compute_rank_only_on_complete=True, tie_tolerance=1e-3):
    """
    Add ranks and baseline-relative metrics to a per-(dataset, method) results table.

    Args:
        df: results with at least the columns "dataset", "method", "test_loss" and "val_loss".
            Optional "duration" (or, if absent, "training_time"/"inference_time") columns get
            baseline-relative counterparts too. The input frame is not modified.
        baseline: value of the "method" column that every method is compared against.
        compute_rank_only_on_complete: if True, only datasets on which every method has a result
            are ranked; rows of other datasets get a NaN rank.
        tie_tolerance: a row ties with the baseline when
            ``|test_loss - baseline_loss| < tie_tolerance * baseline_loss``.

    Returns:
        A new frame with "dataset" as first column, a fresh RangeIndex, and the added columns
        rank, baseline_loss, relative_loss (clipped to [1e-3, 5]), baseline_val_loss,
        relative_val_loss (clipped to [1e-4, 10]), baseline_/relative_ timing columns,
        performance ("win"/"tie"/"loss"), champion, top-1, top-3 and win
        (1 for a win, 0.5 for a tie, 0 otherwise).

    Note:
        Rows whose dataset has no baseline result get NaN baselines and are classified "loss".
        Duplicate baseline rows for one dataset make the baseline lookup raise.
    """
    rank_type = "average"
    # Work on a copy so the caller's frame does not silently gain a "rank" column.
    df = df.copy()

    if compute_rank_only_on_complete:
        n_methods = df["method"].nunique()
        complete = df.groupby("dataset").filter(lambda group: group["method"].nunique() == n_methods)
        df["rank"] = np.nan
        df.loc[complete.index, "rank"] = complete.groupby("dataset")["test_loss"].rank(rank_type)
    else:
        df["rank"] = df.groupby("dataset")["test_loss"].rank(rank_type)

    df = df.set_index("dataset")
    is_baseline = df["method"] == baseline

    # Relative losses are clipped to keep outliers from dominating averages.
    _add_relative_to_baseline(df, is_baseline, "test_loss", "baseline_loss", "relative_loss", clip=(1e-3, 5))
    _add_relative_to_baseline(
        df, is_baseline, "val_loss", "baseline_val_loss", "relative_val_loss", clip=(1e-4, 10)
    )

    if "duration" in df.columns:
        timing_columns = ("duration",)
    else:
        timing_columns = tuple(c for c in ("training_time", "inference_time") if c in df.columns)
    for column in timing_columns:
        _add_relative_to_baseline(df, is_baseline, column, f"baseline_{column}", f"relative_{column}")

    # Vectorized win/tie/loss; NaN baselines fail both comparisons and therefore count as "loss".
    diff = df["test_loss"] - df["baseline_loss"]
    is_tie = diff.abs() < tie_tolerance * df["baseline_loss"]
    df["performance"] = np.select([is_tie, diff < 0], ["tie", "win"], default="loss")

    df = df.reset_index()

    df["champion"] = df["rank"] == 1
    df["top-1"] = df["rank"] == 1
    df["top-3"] = df["rank"] <= 3
    df["win"] = (df["performance"] == "win") + 0.5 * (df["performance"] == "tie")

    return df


def _plot_win_rate(agg_df, ax, highlight_best=3, alphas=None):
    """
    Draw each method's win rate against the baseline as horizontal bars centred on 50%.

    Bars start at the 0.5 break-even line. Methods below it are drawn in red, and the
    ``highlight_best`` highest win rates (all ties included) use the highlight colour.
    The first row of ``agg_df`` ends up at the top, matching ``_plot_rank``/``_plot_elo``.

    Args:
        agg_df: frame with "method" and "win_rate" columns.
        ax: matplotlib Axes to draw on.
        highlight_best: number of best methods to highlight.
        alphas: per-row alpha values in ``agg_df`` row order; defaults to fully opaque.
    """
    # Colors are assigned positionally from index labels below, so force a 0..n-1 index
    # (the same normalisation _plot_rank and _plot_elo perform).
    agg_df = agg_df.reset_index(drop=True)
    if alphas is None:
        alphas = np.ones(len(agg_df["method"]))
    values = agg_df["win_rate"] - 0.5
    top_idx = values.nlargest(highlight_best, keep="all").index
    colors = [sns.color_palette()[0]] * len(values)
    highlight_color = sns.color_palette()[2]
    for i, v in enumerate(values):
        if v < 0:
            colors[i] = sns.color_palette("deep")[3]
    for idx in top_idx:
        colors[idx] = highlight_color
    # Reverse labels AND values together so each bar carries its own method's label
    # and the first method is drawn at the top.
    ax.barh(
        agg_df["method"][::-1],
        values[::-1],
        left=0.5,
        color=[c + (a,) for c, a in zip(colors[::-1], alphas[::-1])],
    )
    ax.set_xlabel("Win rate vs. baseline")
    ax.set_ylim(-0.5, len(agg_df["method"]) - 0.5)
    ax.axvline(x=0.5, color="black", linewidth=1)
    ax.set_xlim(0, 1)
    ax.set_xticks([0, 1])
    ax.set_xticklabels([f"{int(tick*100)}%" for tick in ax.get_xticks()])
    # Light vertical grid lines every 10%.
    ax.set_axisbelow(True)
    ax.minorticks_on()
    ax.tick_params(axis="y", which="minor", left=False)  # Disable minor ticks on y-axis
    ax.set_xticks(np.arange(0, 1.1, 0.1), minor=True)
    ax.grid(True, which="both", axis="x", linestyle="-", linewidth=0.5)


def _plot_rank(agg_df, ax, highlight_best=3, alphas=None):
    """
    Draw each method's average rank as a horizontal bar, first method at the top.

    The ``highlight_best`` lowest average ranks (ties included) get the highlight colour.
    Each bar is labelled inside with its value to one decimal.

    Args:
        agg_df: frame with "method" and "average_rank" columns.
        ax: matplotlib Axes to draw on.
        highlight_best: number of best (lowest-rank) methods to highlight.
        alphas: per-row alpha values in ``agg_df`` row order; defaults to fully opaque.
    """
    table = agg_df.reset_index(drop=True)
    n_methods = len(table["method"])
    if alphas is None:
        alphas = np.ones(n_methods)

    ranks = table["average_rank"]
    palette = sns.color_palette()
    best_rows = set(ranks.nsmallest(highlight_best, keep="all").index)
    bar_colors = [palette[2] if row in best_rows else palette[0] for row in range(len(ranks))]
    rgba = [color + (alpha,) for color, alpha in zip(bar_colors[::-1], alphas[::-1])]

    # Reversed so the first method in agg_df ends up at the top of the axis.
    ax.barh(table["method"][::-1], ranks[::-1], color=rgba)
    ax.bar_label(
        ax.containers[0],
        fontsize=10,
        padding=-24,
        color="white",
        fontweight="bold",
        fmt="%.1f",
    )
    ax.set_ylim(-0.5, n_methods - 0.5)


def _plot_elo(agg_df, ax, highlight_best=3, alphas=None):
    """
    Draw each method's Elo score as a horizontal bar, first method at the top.

    The ``highlight_best`` highest scores (ties included) get the highlight colour.
    Each bar is labelled inside with its rounded value.

    Args:
        agg_df: frame with "method" and "elo" columns.
        ax: matplotlib Axes to draw on.
        highlight_best: number of best (highest-Elo) methods to highlight.
        alphas: per-row alpha values in ``agg_df`` row order; defaults to fully opaque.
    """
    table = agg_df.reset_index(drop=True)
    n_methods = len(table["method"])
    if alphas is None:
        alphas = np.ones(n_methods)

    scores = table["elo"]
    palette = sns.color_palette()
    best_rows = set(scores.nlargest(highlight_best, keep="all").index)
    bar_colors = [palette[2] if row in best_rows else palette[0] for row in range(len(scores))]
    rgba = [color + (alpha,) for color, alpha in zip(bar_colors[::-1], alphas[::-1])]

    # Reversed so the first method in agg_df ends up at the top of the axis.
    ax.barh(table["method"][::-1], scores[::-1], left=0, color=rgba)
    ax.bar_label(
        ax.containers[0],
        fontsize=10,
        padding=-30,
        color="white",
        fontweight="bold",
        fmt="%.0f",
    )
    ax.set_ylim(-0.5, n_methods - 0.5)


def _plot_relative_error(agg_df, ax, highlight_best=3, alphas=None):
    """Draw each method's average relative error as bars starting at 1.0 (the baseline level)."""
    agg_df = agg_df.reset_index(drop=True)
    if alphas is None:
        alphas = np.ones(len(agg_df["method"]))
    values = agg_df["average_relative_error"] - 1
    top_idx = values.nsmallest(highlight_best).index
    colors = [sns.color_palette()[0]] * len(values)
    for i, v in enumerate(values):
        if v > 0:  # worse than the baseline
            colors[i] = sns.color_palette("deep")[3]
    highlight_color = sns.color_palette()[2]
    for idx in top_idx:
        colors[idx] = highlight_color
    # Labels and values reversed together: first method on top, each bar paired with its method.
    ax.barh(
        agg_df["method"][::-1],
        values[::-1],
        left=1,
        color=[c + (a,) for c, a in zip(colors[::-1], alphas[::-1])],
    )
    ax.set_xlabel("Average relative error")
    ax.set_ylim(-0.5, len(agg_df["method"]) - 0.5)
    ax.axvline(x=1, color="black", linewidth=1)
    ax.set_axisbelow(True)
    ax.grid(True, which="both", axis="x", linestyle="-", linewidth=0.5)


def _plot_relative_times(df, order, ax, plot_training_time=True, plot_inference_time=True):
    """Draw log-scale box plots of per-dataset training/inference time relative to the baseline."""
    panels = (
        (plot_training_time, "relative_training_time", "Training", 0),
        (plot_inference_time, "relative_inference_time", "Inference", 1),
    )
    for enabled, column, label, palette_idx in panels:
        if not enabled:
            continue
        sns.boxplot(
            data=df,
            x=column,
            y="method",
            ax=ax,
            legend=False,
            order=order,
            label=label,
            color=sns.color_palette()[palette_idx],
        )
        ax.set_yticklabels([])
        ax.set_ylabel(None)
        ax.set_xscale("log")
        ax.axvline(x=1, color="black", zorder=0, linewidth=1)
    if plot_training_time and plot_inference_time:
        ax.legend()
        ax.set_xlabel("Relative training/inference time")


def _plot_progress(agg_df, ax, total, alphas):
    """Draw the share of `total` datasets each method has results for, labelled with the count."""
    done = agg_df["total"] / total * 100
    base_color = sns.color_palette()[0]
    ax.barh(
        agg_df["method"][::-1],
        done[::-1],
        color=[base_color + (a,) for a in alphas[::-1]],
    )
    ax.set_xlabel("experiments done")
    ax.set_yticklabels([])
    ax.set_ylim(-0.5, len(agg_df["method"]) - 0.5)
    ax.set_xlim(0, 100)
    ax.set_xticks([0, 100])
    ax.set_xticklabels([f"{int(tick)}%" for tick in ax.get_xticks()])
    ax.set_ylabel(None)
    counts = ["nan" if np.isnan(x) else str(int(x)) for x in agg_df["total"].to_numpy()]
    ax.bar_label(
        ax.containers[0],
        labels=counts[::-1],
        fontsize=10,
        padding=-18,
        color="white",
        fontweight="bold",
        fmt="%.0f",
    )


def plot_results(
    agg_df,
    df,
    methods_to_ignore=None,
    show_progress=False,
    n_datasets=None,
    figsize=None,
    highlight_best=3,
    plot_win_rate=True,
    plot_training_time=True,
    plot_inference_time=True,
):
    """
    Plot a one-row summary figure comparing methods, first row of ``agg_df`` at the top.

    Panels, left to right:
      * average rank;
      * Elo, only if ``agg_df`` has an "elo" column;
      * average relative error;
      * win rate (optional);
      * relative training/inference time box plots (optional);
      * progress (optional).

    Args:
        agg_df: per-method aggregate table (as produced by ``aggregate``).
        df: per-(dataset, method) results, used for the timing box plots and dataset counts.
        methods_to_ignore: methods kept as rows but with all values blanked (NaN), including
            their relative timing values in ``df``.
        show_progress: add a panel with each method's share of datasets done. Incomplete methods
            are faded in the error/win-rate/progress panels; rank/Elo use one shared alpha
            (the faintest).
        n_datasets: dataset count for the progress panel; defaults to the number of distinct
            datasets in ``df``.
        figsize: figure size; by default the height grows with the number of methods.
        highlight_best: number of best methods highlighted in the bar panels.
        plot_win_rate, plot_training_time, plot_inference_time: toggle the optional panels.

    Returns:
        The matplotlib Figure.
    """
    agg_df = agg_df.reset_index(drop=True)
    if figsize is None:
        figsize = (13, 12 / 100 * len(agg_df) + 2)
    # Copies so blanking ignored methods never touches the caller's frames.
    agg_df = agg_df.copy()
    df = df.copy()
    if methods_to_ignore is not None:
        for method in methods_to_ignore:
            agg_df.loc[agg_df["method"] == method, agg_df.columns != "method"] = np.nan
            # Blank the columns actually plotted, not only the legacy "relative_duration".
            for column in ("relative_duration", "relative_training_time", "relative_inference_time"):
                if column in df.columns:
                    df.loc[df["method"] == method, column] = np.nan

    plot_elo = "elo" in agg_df.columns
    plot_time = plot_training_time or plot_inference_time
    n_plots = 2 + plot_elo + plot_win_rate + plot_time + show_progress
    fig, axes = plt.subplots(1, n_plots, figsize=figsize)

    if show_progress:
        total = n_datasets or df["dataset"].nunique()
        fade_window = 5
        low = total - fade_window
        # Alpha drops by 0.2 per missing dataset, floored at 0.2.
        alphas = np.clip((agg_df["total"].to_numpy() - low) / (total - low), a_min=0.2, a_max=1.0)
    else:
        alphas = np.ones_like(agg_df["total"].to_numpy())
    uniform_alphas = np.ones_like(alphas) * np.nanmin(alphas)

    # Panels are laid out in creation order; consuming the axes sequentially avoids
    # index arithmetic that could point past the last axis.
    panel_axes = iter(axes)

    ax = next(panel_axes)
    _plot_rank(agg_df, ax, highlight_best=highlight_best, alphas=uniform_alphas)
    ax.set_xlabel("Average rank")

    if plot_elo:
        ax = next(panel_axes)
        _plot_elo(agg_df, ax, highlight_best=highlight_best, alphas=uniform_alphas)
        ax.set_xlabel("ELO")
        ax.set_yticklabels([])
        ax.set_ylabel(None)

    ax = next(panel_axes)
    _plot_relative_error(agg_df, ax, highlight_best=highlight_best, alphas=alphas)
    ax.set_yticklabels([])
    ax.set_ylabel(None)

    if plot_win_rate:
        ax = next(panel_axes)
        _plot_win_rate(agg_df, ax, highlight_best=highlight_best, alphas=alphas)
        ax.set_yticklabels([])
        ax.set_ylabel(None)

    if plot_time:
        _plot_relative_times(
            df,
            agg_df["method"],
            next(panel_axes),
            plot_training_time=plot_training_time,
            plot_inference_time=plot_inference_time,
        )

    if show_progress:
        _plot_progress(agg_df, next(panel_axes), total, alphas)

    fig.tight_layout()
    return fig


def aggregate(
    df,
    methods: List[str] | None = None,
    add_elo=True,
    elo_calibration_framework=None,
    elo_calibration_elo=None,
):
    if "duration" in df.columns:
        df["training_time"] = df["duration"]
        df["relative_training_time"] = df["relative_duration"]
    for k in [
        "relative_training_time",
        "relative_inference_time",
        "inference_time_with_base",
    ]:
        if k not in df.columns:
            df[k] = np.NaN
    agg_df = df.groupby("method", as_index=False).agg(
        total=("performance", "count"),
        # wins=('performance', lambda x: (x=='win').sum()),
        # losses=('performance', lambda x: (x=='loss').sum()),
        # ties=('performance', lambda x: (x=='tie').sum()),
        champion=("rank", lambda x: (x == 1).sum()),
        average_rank=("rank", "mean"),
        win_rate=("performance", lambda x: ((x == "win") + 0.5 * (x == "tie")).mean()),
        average_relative_error=("relative_loss", scipy.stats.gmean),
        # Training and inference time related things:
        average_training_time=("training_time", "mean"),
        median_training_time=("training_time", "median"),
        median_inference_time=("inference_time", "median"),
        median_inference_time_with_base=("inference_time_with_base", "median"),
        # Relative training and inference times
        average_relative_training_time=(
            "relative_training_time",
            lambda x: scipy.stats.gmean(x, nan_policy="omit"),
        ),
        average_relative_inference_time=(
            "relative_inference_time",
            lambda x: scipy.stats.gmean(x, nan_policy="omit"),
        ),
        # training and inference time in actual numbers, not relative
    )
    if methods is not None:
        agg_df["method"] = pd.Categorical(agg_df["method"], categories=methods, ordered=True)
        agg_df = agg_df.sort_values("method").reset_index(drop=True)

    if add_elo:
        print("Compute ELO scores...")
        _df = df.copy()
        _df["framework"] = _df.method
        _df["metric_error"] = _df["test_loss"]
        _df = _df[["dataset", "framework", "metric_error"]]
        elos = compute_elo_ratings(
            _df,
            calibration_framework=elo_calibration_framework,
            calibration_elo=elo_calibration_elo,
        )
        agg_df.insert(loc=4, column="elo", value=elos.median()[agg_df.method].to_numpy())
        agg_df.insert(
            loc=5,
            column="elo_lower",
            value=elos.quantile(0.025)[agg_df.method].to_numpy(),
        )
        agg_df.insert(
            loc=6,
            column="elo_upper",
            value=elos.quantile(0.975)[agg_df.method].to_numpy(),
        )

    return agg_df


def latexify_results(agg_df):
    """
    Print selected per-method aggregates as a LaTeX table and return them.

    The columns listed below are selected in that order and renamed to their LaTeX headers.
    The table is printed via ``DataFrame.to_latex`` with a per-column formatter. The returned
    frame holds the selected, renamed, but unformatted values.
    """
    # (LaTeX header, source column in agg_df, cell formatter)
    COLUMNS = [
        ("Method", "method", lambda x: x.capitalize()),
        (r"$(\uparrow)$ \makecell[r]{Elo}", "elo", lambda x: f"{x:.0f}"),
        (r"$(\uparrow)$ \makecell[r]{Champion}", "champion", lambda x: f"{x:.0f}"),
        (r"$(\downarrow)$ \makecell[r]{Average \\ rank}", "average_rank", lambda x: f"{x:.2f}"),
        (
            r"$(\downarrow)$ \makecell[r]{Average \\ relative loss}",
            "average_relative_error",
            lambda x: f"{x:.3f}",
        ),
        # ("\makecell[r]{Loss}", "average_relative_error", lambda x: f"{x:.3f}"),
        # ("\makecell[r]{Win rate \\\\ vs.\\ baseline}", "win_rate", lambda x: f"{x*100:.0f}\%"),
        # ("\makecell[r]{Win rate}", "win_rate", lambda x: f"{x*100:.0f}\%"),
        (
            r"$(\downarrow)$ \makecell[r]{Median marginal \\ training time}",
            "median_training_time",
            lambda x: f"{x:.0f}s",
        ),
        # ("\makecell[r]{Average \\ training time}", "average_training_time", lambda x: f"{x:.0f}s"),
        # ("\makecell[r]{Median total \\\\ inference time}", "median_inference_time_with_base", lambda x: f"{x:.0f}s"),
        # ("Relative Inference Time", "average_relative_inference_time", lambda x: f"{x:.2f}"),
    ]
    header_for_source = {source: header for header, source, _ in COLUMNS}
    headers = [header for header, _, _ in COLUMNS]
    formatters = {header: formatter for header, _, formatter in COLUMNS}

    output_df = agg_df.rename(columns=header_for_source)[headers]
    print(output_df.to_latex(index=False, formatters=formatters))
    return output_df
