from typing import cast

import numpy as np
import pandas as pd
import plotly.graph_objects as go


def plot_multi_performance_heatmaps(
    performance: dict[str, pd.DataFrame],
    sequence_length: int,
    title: str,
) -> go.Figure:
    """Plot several prefix performance heatmaps behind a single dropdown.

    One heatmap is built per entry of ``performance`` using
    ``plot_performance_heatmap``. The figure of the first entry is returned,
    extended with a dropdown whose buttons swap in the heatmap values of the
    other entries.

    Args:
        performance: Mapping from prefix type (used as the dropdown label)
            to the performance DataFrame accepted by
            ``plot_performance_heatmap``.
        sequence_length: Side length of each (square) heatmap.
        title: Title of the returned figure.

    Returns:
        The figure built for the first entry of ``performance``, with the
        dropdown menu and title applied.

    Raises:
        ValueError: If ``performance`` is empty, since there is no figure
            to return.
    """
    if not performance:
        raise ValueError("performance must contain at least one entry")

    figures = [
        plot_performance_heatmap(prefix_performance, sequence_length)
        for prefix_performance in performance.values()
    ]

    # One dropdown button per prefix type. The "update" method expects
    # ``[trace_updates, layout_updates]``: "z" is a trace attribute while
    # "title.text" belongs to the layout, so they must live in separate
    # dicts. The per-figure title is usually unset, in which case the shared
    # title is kept instead of being replaced by ``None``.
    buttons = [
        dict(
            args=[
                {"z": [fig.data[0].z]},
                {"title.text": fig.layout.title.text or title},
            ],
            label=prefix_type,
            method="update",
        )
        for prefix_type, fig in zip(performance, figures)
    ]

    updatemenus = [
        dict(
            active=0,
            buttons=buttons,
            direction="up",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0,
            xanchor="left",
            y=-0.02,
            yanchor="top",
        )
    ]

    layout = go.Layout(
        updatemenus=updatemenus,
        title=title,
    )
    # The first figure doubles as the container for the dropdown; the other
    # figures only serve as the source of the "z" values used by the buttons.
    figures[0].update_layout(layout)
    return figures[0]


def plot_performance_heatmap(
    prefix_performance: pd.DataFrame, sequence_length: int
) -> go.Figure:
    """Plot a heatmap of the performance achieved per token and prefix.

    Each row of the heatmap is a target token and each column a context
    token. The cell ``(token_idx, token_idx - prefix_length)`` holds the
    ``correct_samples`` value for that token and prefix length, the diagonal
    cell of every token present in the data is set to 0.5 to highlight the
    target token, and all remaining cells are -1.

    Args:
        prefix_performance: DataFrame with index levels
            - token_idx: The index of the token in the sequence
            - prefix_length: the number of context tokens retained
            and columns
            - correct_samples: The fraction of tokens correctly predicted
            The index levels are identified by position; the DataFrame
            passed in is not modified.
        sequence_length: Side length of the square heatmap matrix.

    Returns:
        A figure with a heatmap showing for each token which positions
        are correctly predicted.

    Raises:
        IndexError: If a token index or the prefix start derived from it
            falls outside ``[0, sequence_length)``.
    """
    # rename_axis returns a new object, so the caller's DataFrame keeps its
    # own index names (assigning to ``index.names`` would rename it in place).
    named_performance = prefix_performance.rename_axis(
        ["token_idx", "prefix_length"]
    )

    # Cells without any data keep the sentinel value -1.
    heatmap_matrix = np.full((sequence_length, sequence_length), -1.0)

    for token_idx, token_performance in named_performance.groupby("token_idx"):
        token_idx = cast(int, token_idx)
        if not 0 <= token_idx < sequence_length:
            raise IndexError(
                f"token_idx {token_idx} is out of bounds for "
                f"sequence_length {sequence_length}"
            )
        # Highlight the target token
        heatmap_matrix[token_idx, token_idx] = 0.5

        for prefix_length, group in token_performance.groupby("prefix_length"):
            prefix_length = cast(int, prefix_length)
            prefix_start = token_idx - prefix_length
            # Guard explicitly: a negative index would otherwise wrap around
            # and silently write the value into the wrong column.
            if not 0 <= prefix_start < sequence_length:
                raise IndexError(
                    f"prefix_length {prefix_length} for token_idx "
                    f"{token_idx} gives prefix start {prefix_start}, which "
                    f"is out of bounds for sequence_length {sequence_length}"
                )
            # Only the first row of each (token_idx, prefix_length) group
            # is used.
            heatmap_matrix[token_idx, prefix_start] = group[
                "correct_samples"
            ].iloc[0]

    # Plot heatmap
    figure = go.Figure(
        data=go.Heatmap(
            z=heatmap_matrix,
            hoverongaps=False,
        )
    )

    figure.update_layout(
        xaxis_title="Context Token",
        yaxis_title="Target Token",
        height=1000,
    )
    return figure
