import os
from typing import Any, Callable, Union

import numpy as np
import pandas as pd
import plotly.graph_objects as go

from lib_dl.analysis.aggregate import aggregate_mean_std

from .plotting import add_std_dev_trace


def plot_training_loss(
    results: dict[str, list[pd.DataFrame]],
    legend_title: str,
) -> go.Figure:
    """Plot the mean training loss per epoch, with a std-dev band, per group.

    Each run's history is first averaged per epoch, so every run contributes
    one loss value per epoch. These per-epoch curves are then aggregated
    across the runs of a group via ``compute_training_loss_mean_std``.

    Args:
        results: Maps a group name to the training histories of its runs.
            Every dataframe needs a ``"loss"`` column and an ``"epoch"``
            column or index level. Neither the dict nor its dataframes are
            modified.
        legend_title: Title shown above the legend.

    Returns:
        A plotly figure with one mean/std trace per group. The longest common
        prefix of the group names is stripped from the legend labels, unless
        doing so would leave some label empty.
    """
    progress_unit = "epoch"
    # Build a new dict instead of overwriting the caller's entries, so the
    # same ``results`` can safely be passed to other plotting functions.
    per_epoch_results = {
        res_name: [
            # Only the loss is needed downstream. Averaging unrelated
            # (possibly non-numeric) columns raises TypeError in pandas >= 2.
            res.groupby(progress_unit)[["loss"]].mean().reset_index()
            for res in res_list
        ]
        for res_name, res_list in results.items()
    }
    mean_results, std_results = compute_training_loss_mean_std(
        per_epoch_results, [progress_unit]
    )

    names = list(mean_results)
    prefix_len = len(os.path.commonprefix(names))
    # With a single group, or when one name is a prefix of another, stripping
    # the common prefix would produce an empty legend label; keep full names.
    if any(len(name) == prefix_len for name in names):
        prefix_len = 0

    fig = go.Figure()
    for i, (res_name, mean_res) in enumerate(mean_results.items()):
        std_res = std_results[res_name]
        add_std_dev_trace(
            fig,
            i,
            mean_res["loss"],
            std_res["loss"],
            name=res_name[prefix_len:],
            x_values=mean_res[progress_unit],
        )
    fig.update_layout(
        title="Training loss",
        xaxis_title=progress_unit.capitalize(),
        yaxis_title="Loss",
        legend_title_text=legend_title,
        width=800,
        height=600,
    )
    return fig


def compute_training_loss_mean_std(
    results: dict[str, list[pd.DataFrame]],
    additional_columns: list[str],
) -> tuple[dict[str, pd.DataFrame], dict[str, pd.DataFrame]]:
    """Aggregate the loss histories of each group of runs into mean and std.

    Args:
        results: Maps a group name to one dataframe per run. Each dataframe
            needs a ``"loss"`` column plus every column listed in
            ``additional_columns``.
        additional_columns: Extra columns carried along with ``"loss"``.

    Returns:
        Two dicts keyed by group name: the mean and the standard deviation
        results returned by ``aggregate_mean_std`` for that group.
    """
    selected_columns = ["loss", *additional_columns]

    def _prepare(history: pd.DataFrame) -> pd.DataFrame:
        # Name the row index "iter" to match ``levels_to_preserve`` below,
        # and drop every row that has a missing value.
        subset = history[selected_columns]
        return subset.rename_axis("iter", axis=0).dropna(axis=0)

    mean_results: dict[str, pd.DataFrame] = {}
    std_results: dict[str, pd.DataFrame] = {}
    for group_name, run_histories in results.items():
        mean_df, std_df = aggregate_mean_std(
            [_prepare(history) for history in run_histories],
            levels_to_preserve=["iter"],
        )
        mean_results[group_name] = mean_df
        std_results[group_name] = std_df
    return mean_results, std_results


def plot_training_accuracy(
    results: dict[str, list[pd.DataFrame]],
    legend_title: str,
) -> go.Figure:
    """Plot the per-epoch training accuracy (mean with std band) per group.

    For every run, the accuracy of an epoch is the mean of its ``"correct"``
    column within that epoch. These per-run curves are aggregated across the
    runs of a group with ``aggregate_mean_std``.

    Args:
        results: Maps a group name to the dataframes of its runs. Each needs
            a ``"correct"`` column and an ``"epoch"`` index level.
        legend_title: Title shown above the legend.

    Returns:
        A plotly figure with one mean/std trace per group.
    """
    # First aggregate every group, then draw; the two phases are kept separate.
    aggregated = []
    for group_name, run_frames in results.items():
        per_run_accuracy = [
            frame["correct"].groupby("epoch").mean() for frame in run_frames
        ]
        acc_mean, acc_std = aggregate_mean_std(per_run_accuracy, ["epoch"])
        aggregated.append((group_name, acc_mean, acc_std))

    fig = go.Figure()
    for trace_idx, (group_name, acc_mean, acc_std) in enumerate(aggregated):
        add_std_dev_trace(
            fig,
            trace_idx,
            acc_mean,
            acc_std,
            name=group_name,
            x_values=acc_mean.index,
        )

    layout_options = dict(
        title="Training accuracy",
        xaxis_title="Epoch",
        yaxis_title="Accuracy",
        legend_title=legend_title,
        width=800,
        height=600,
    )
    fig.update_layout(**layout_options)
    return fig


def plot_sequence_token_distribution(
    token_distribution: pd.DataFrame,
    sequence: str,
    log_scale: bool,
) -> go.Figure:
    """Create a heatmap over the token values of one sequence, animated over
    the different epochs.

    Args:
        token_distribution: Dataframe whose index has a ``"string"`` level
            (the sequence a row belongs to) and an ``"epoch"`` level (needed
            by ``plot_sequence_distribution``). It has one column per token.
            An optional ``"correct"`` column is excluded from the heatmap.
        sequence: Value of the ``"string"`` index level whose rows are plotted.
        log_scale: If False, the values are exponentiated before plotting.
            Presumably this turns stored log-probabilities into
            probabilities; confirm against the producer of the dataframe.

    Returns:
        The animated plotly heatmap figure.
    """
    # Every column except the bookkeeping "correct" column is a token value.
    # Filtering (instead of list.remove) also works when "correct" is absent.
    distribution_columns = [
        column for column in token_distribution.columns if column != "correct"
    ]
    in_sequence = token_distribution.index.get_level_values("string") == sequence
    sequence_distribution = token_distribution.loc[in_sequence, distribution_columns]
    if not log_scale:
        sequence_distribution = sequence_distribution.apply(np.exp)
    return plot_sequence_distribution(sequence_distribution, "distribution")


def plot_sequence_error_distribution(
    token_distribution: pd.DataFrame,
    sequence: str,
) -> go.Figure:
    """Plot a heatmap showing which positions of ``sequence`` were predicted
    correctly. The heatmap is animated over the training epochs, i.e. you can
    play it or scroll to a specific epoch.

    Args:
        token_distribution: The token distribution dataframe. Its index must
            have a ``"string"`` level (the sequence each row belongs to) and
            an ``"epoch"`` level. It must also have a ``"correct"`` column
            holding whether the prediction at that position was correct.
        sequence: The string memorized by the model (tokenized
            character-wise). It selects the rows whose ``"string"`` index
            level equals it.

    Returns:
        The plotly figure object of the heatmap.
    """
    in_sequence = token_distribution.index.get_level_values("string") == sequence
    # Select with a list to keep a one-column DataFrame rather than a Series.
    # The heatmap plots the transposed frame, so "correct" becomes its single
    # row.
    prediction_correctness = token_distribution.loc[in_sequence, ["correct"]]
    return plot_sequence_distribution(prediction_correctness, "correctness")


def plot_sequence_distribution(
    distribution: pd.DataFrame,
    description: str,
    *,
    epoch_step: int = 2,
) -> go.Figure:
    """Create a heatmap over a token distribution that is animated over the
    different epochs.

    Args:
        distribution: Dataframe whose index has an ``"epoch"`` level. Within
            one epoch, the remaining index runs along the x-axis (titled
            "Token position"), and each column becomes one heatmap row.
        description: Text appended to the "Epoch <n>" title of each frame.
        epoch_step: Only epochs whose integer value is divisible by this get
            a frame. The default of 2 shows every second epoch, which keeps
            the figure small.

    Returns:
        A figure that initially shows the first kept epoch. It has Play/Pause
        buttons and an epoch slider to step through all kept epochs. If no
        epoch is kept, the figure has no trace data.

    Raises:
        ValueError: If ``epoch_step`` is smaller than 1.
    """
    if epoch_step < 1:
        raise ValueError(f"epoch_step must be at least 1, got {epoch_step}")

    # One colour range for all frames so colours are comparable across epochs.
    max_val = distribution.max().max()
    min_val = distribution.min().min()

    initial_heatmap = None
    initial_epoch = None
    frames = []
    for epoch, epoch_dist in distribution.groupby("epoch"):
        if int(epoch) % epoch_step != 0:
            continue
        epoch_dist = epoch_dist.droplevel("epoch")
        heatmap = go.Heatmap(
            z=epoch_dist.T,
            y=epoch_dist.columns,
            zmin=min_val,
            zmax=max_val,
        )
        if initial_heatmap is None:
            # Use the first kept epoch as the static view instead of
            # requiring an epoch equal to 0. This works for any numbering
            # (e.g. 1-based or string-valued epochs).
            initial_heatmap, initial_epoch = heatmap, epoch
        frames.append(
            go.Frame(
                data=[heatmap],
                name=str(epoch),
                layout=go.Layout(
                    title_text=f"Epoch {epoch} {description}",
                    xaxis_title="Token position",
                ),
            ),
        )

    if initial_heatmap is None:
        # No epoch survived the filter: return an empty but valid figure.
        initial_data = []
        initial_title = description
    else:
        initial_data = [initial_heatmap]
        initial_title = f"Epoch {initial_epoch} {description}"

    return go.Figure(
        data=initial_data,
        frames=frames,
        layout=go.Layout(
            title=initial_title,
            xaxis_title="Token position",
            width=1000,
            height=500,
            updatemenus=_play_pause_menu(),
            sliders=_epoch_slider([frame.name for frame in frames]),
        ),
    )


def _play_pause_menu() -> list[dict]:
    """Return the layout ``updatemenus`` holding the Play/Pause buttons."""
    return [
        {
            "buttons": [
                {
                    "args": [
                        None,
                        {
                            "frame": {
                                "duration": 500,
                                "redraw": True,
                            },
                            "fromcurrent": True,
                        },
                    ],
                    "label": "Play",
                    "method": "animate",
                },
                {
                    "args": [
                        [None],
                        {
                            "frame": {
                                "duration": 0,
                                "redraw": True,
                            },
                            "mode": "immediate",
                            "transition": {"duration": 0},
                        },
                    ],
                    "label": "Pause",
                    "method": "animate",
                },
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 87},
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top",
        },
    ]


def _epoch_slider(frame_names: list[str]) -> list[dict]:
    """Return the layout ``sliders`` with one step per named animation frame."""
    return [
        {
            "steps": [
                {
                    "args": [
                        [name],
                        {
                            "frame": {
                                "duration": 300,
                                "redraw": True,
                            },
                            "mode": "immediate",
                            "transition": {"duration": 0},
                        },
                    ],
                    "label": str(name),
                    "method": "animate",
                }
                for name in frame_names
            ],
            "transition": {"duration": 0},
            "x": 0.1,
            "y": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {
                "font": {"size": 20},
                "prefix": "Epoch:",
                "visible": True,
                "xanchor": "right",
            },
            "len": 0.9,
            "pad": {"t": 50, "b": 10},
        }
    ]
