from typing import cast

import numpy as np
import pandas as pd
import plotly.graph_objects as go

from ..utils import reindex_positionwise


def plot_correctness_over_epochs(
    memorization_log: pd.DataFrame,
    epochs_to_show: list[int],
) -> go.Figure:
    """Plot a red/green grid of per-token prediction correctness per epoch.

    Args:
        memorization_log: Dataframe with a "correct" column. After
            ``reindex_positionwise`` its index is expected to support
            ``.loc`` selection by the values in ``epochs_to_show`` and to
            contain a "token_index" level (as required by the ``.loc`` and
            ``unstack`` calls below — confirm against ``reindex_positionwise``).
        epochs_to_show: Index labels (epochs) to include as heatmap rows.

    Returns:
        A heatmap figure with one row per selected epoch and one column per
        token index. Cells with value 1 are green, cells with value 0 are
        red; a manual legend explains the two colors.
    """
    correctness = (
        reindex_positionwise(cast(pd.DataFrame, memorization_log[["correct"]]))[
            "correct"
        ]
        .loc[epochs_to_show]
        .dropna(axis=0)
    )
    # Cast to float so plotly receives numeric cells even when the flags are
    # booleans; (epoch, token) pairs missing after unstack stay NaN and are
    # drawn as gaps.
    correctness_grid = correctness.unstack(level="token_index").astype(float)

    fig = go.Figure(
        data=go.Heatmap(
            z=correctness_grid,
            x=correctness_grid.columns,
            y=correctness_grid.index,
            colorscale=[(0, "red"), (1, "green")],
            # Pin the color range to the 0/1 flag values. Otherwise plotly
            # autoscales to the data, and a grid that is all correct (or all
            # incorrect) would not be drawn in pure green (or red).
            zmin=0,
            zmax=1,
            showscale=False,
        ),
    )

    fig.update_layout(width=800, height=600)
    # Reverse the y-axis so rows are listed top-to-bottom in index order.
    fig.update_yaxes(autorange="reversed")

    # Heatmaps produce no legend entries, so add empty marker-only traces
    # that exist purely to label the two colors in the legend.
    for label, color in (("Correct", "green"), ("Incorrect", "red")):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(size=10, color=color),
                legendgroup="group",
                showlegend=True,
                name=label,
            )
        )
    return fig


def plot_sequence_prob_distribution(
    token_distribution: pd.DataFrame,
    sequence: str,
    log_scale: bool,
) -> go.Figure:
    """Create a heatmap over one sequence's token probabilities, animated
    over the different epoch timesteps.

    Args:
        token_distribution: Dataframe whose index contains a "string" level
            (used to select ``sequence``) and the "epoch" level required by
            ``plot_sequence_distribution``. Every column except an optional
            "correct" column is plotted; the values are presumably
            log-probabilities, since they are exponentiated when
            ``log_scale`` is False.
        sequence: Value of the "string" index level to plot.
        log_scale: If True, plot the stored values unchanged; otherwise plot
            ``exp`` of them.

    Returns:
        The animated plotly heatmap figure.

    Raises:
        ValueError: If ``sequence`` does not occur in the "string" level.
    """
    # Exclude the "correct" flag column if present; it is not a token column.
    distribution_columns = token_distribution.columns.drop(
        "correct", errors="ignore"
    )
    sequence_distribution = token_distribution.loc[
        token_distribution.index.get_level_values("string") == sequence,
        distribution_columns,
    ]
    if sequence_distribution.empty:
        raise ValueError(f"Sequence {sequence!r} not found in token_distribution")
    if not log_scale:
        sequence_distribution = np.exp(sequence_distribution)
    return plot_sequence_distribution(sequence_distribution, "distribution")


def plot_sequence_error_distribution(
    token_distribution: pd.DataFrame,
    sequence: str,
) -> go.Figure:
    """Plot a heatmap showing which tokens were predicted correctly at
    which position. The heatmap is animated over the training epochs,
    i.e. you can play it or scroll to a specific epoch.

    Args:
        token_distribution: The token distribution dataframe. Its index must
            contain a "string" level (used to select ``sequence``) and the
            "epoch" level required by ``plot_sequence_distribution``. It must
            have a "correct" column flagging whether the token at each
            position was predicted correctly.
        sequence: The string memorized by the model (tokenized
            character-wise); selects rows by the "string" index level.

    Returns:
        The plotly figure object of the heatmap.

    Raises:
        ValueError: If ``sequence`` does not occur in the "string" level.
    """
    prediction_correctness = token_distribution.loc[
        token_distribution.index.get_level_values("string") == sequence,
        ["correct"],
    ]
    if prediction_correctness.empty:
        raise ValueError(f"Sequence {sequence!r} not found in token_distribution")
    # Heatmap cells must be numeric; cast so boolean flags become 0.0 / 1.0
    # (assumes "correct" holds bool or numeric values).
    prediction_correctness = prediction_correctness.astype(float)
    return plot_sequence_distribution(prediction_correctness, "correctness")


def plot_sequence_distribution(
    distribution: pd.DataFrame,
    description: str,
    epoch_stride: int = 2,
) -> go.Figure:
    """Create a heatmap over a token distribution, animated over the
    different epoch timesteps.

    One animation frame is built per kept epoch. In each frame the heatmap's
    x-axis is the token position (the rows of that epoch's slice) and its
    y-axis is the dataframe's columns. All frames share one color range so
    colors are comparable across epochs.

    Args:
        distribution: Numeric dataframe whose index has an "epoch" level;
            after dropping that level, each row is one token position.
        description: Text appended to the "Epoch <n>" title of every frame.
        epoch_stride: Only epochs whose number is a multiple of this value
            are rendered. The default of 2 shows every second epoch to save
            space; pass 1 to show all epochs.

    Returns:
        A plotly figure that starts at the first kept epoch, with play/pause
        buttons and an epoch slider.

    Raises:
        ValueError: If ``epoch_stride`` is smaller than 1, or if no epoch
            remains to plot (e.g. ``distribution`` is empty).
    """
    if epoch_stride < 1:
        raise ValueError(f"epoch_stride must be >= 1, got {epoch_stride}")

    # One shared color range for all frames.
    max_val = distribution.max().max()
    min_val = distribution.min().min()

    first_heatmap = None
    frames = []
    for epoch, epoch_dist in distribution.groupby(level="epoch"):
        if cast(int, epoch) % epoch_stride != 0:
            continue
        epoch_dist = epoch_dist.droplevel("epoch")
        heatmap = go.Heatmap(
            z=epoch_dist.T,
            y=epoch_dist.columns,
            zmin=min_val,
            zmax=max_val,
        )
        # The figure initially shows the first kept epoch (not necessarily 0).
        if first_heatmap is None:
            first_heatmap = heatmap
        frames.append(
            go.Frame(
                data=[heatmap],
                name=str(epoch),
                layout=go.Layout(
                    title_text=f"Epoch {epoch} {description}",
                    xaxis_title="Token position",
                ),
            ),
        )

    if first_heatmap is None:
        raise ValueError(
            f"No epochs to plot (empty input or no epoch divisible by {epoch_stride})"
        )

    return go.Figure(
        data=[first_heatmap],
        frames=frames,
        layout=go.Layout(
            title=f"Epoch {frames[0].name} {description}",
            xaxis_title="Token position",
            width=1000,
            height=500,
            updatemenus=[_play_pause_menu()],
            sliders=[_epoch_slider(frames)],
        ),
    )


def _play_pause_menu() -> dict:
    """Build the play/pause button menu used by the epoch animation."""
    return {
        "buttons": [
            {
                "args": [
                    None,
                    {
                        "frame": {
                            "duration": 500,
                            "redraw": True,
                        },
                        "fromcurrent": True,
                    },
                ],
                "label": "Play",
                "method": "animate",
            },
            {
                "args": [
                    [None],
                    {
                        "frame": {
                            "duration": 0,
                            "redraw": True,
                        },
                        "mode": "immediate",
                        "transition": {"duration": 0},
                    },
                ],
                "label": "Pause",
                "method": "animate",
            },
        ],
        "direction": "left",
        "pad": {"r": 10, "t": 87},
        "type": "buttons",
        "x": 0.1,
        "xanchor": "right",
        "y": 0,
        "yanchor": "top",
    }


def _epoch_slider(frames: list) -> dict:
    """Build an epoch slider with one step per animation frame, labelled
    by the frame's name."""
    return {
        "steps": [
            {
                "args": [
                    [frame.name],
                    {
                        "frame": {
                            "duration": 300,
                            "redraw": True,
                        },
                        "mode": "immediate",
                        "transition": {"duration": 0},
                    },
                ],
                "label": str(frame.name),
                "method": "animate",
            }
            for frame in frames
        ],
        "transition": {"duration": 0},
        "x": 0.1,
        "y": 0,
        "yanchor": "top",
        "xanchor": "left",
        "currentvalue": {
            "font": {"size": 20},
            "prefix": "Epoch:",
            "visible": True,
            "xanchor": "right",
        },
        "len": 0.9,
        "pad": {"t": 50, "b": 10},
    }
