from typing import Any, Callable, Iterable

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from IPython.display import HTML
from IPython.display import Markdown as md
from IPython.display import display
from plotly.subplots import make_subplots

from lib_dl.analysis.aggregate import aggregate_mean_std
from lib_dl.analysis.experiment import ExperimentResult, load_results
from lib_dl.analysis.publish.notebook import publish_notebook

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as RCSExperimentResult


# Type alias for the results handled in this module: the generic library
# `ExperimentResult` wrapper parameterized with this experiment's config
# type and its own result type (imported here as `RCSExperimentResult`).
Result = ExperimentResult[ExperimentConfig, RCSExperimentResult]


def load(
    config_name: str,
    seed_ids: list[int],
) -> list[Result]:
    """Load the results of this experiment for one configuration.

    Thin wrapper around ``load_results`` that pins the experiment name
    (``EXP_NAME``) and the expected result type (``Result``).

    Args:
        config_name: Name of the configuration whose results are loaded.
        seed_ids: Seed ids for which results are requested.

    Returns:
        Whatever ``load_results`` returns for these arguments, annotated
        as a list of ``Result``.
    """
    loaded = load_results(EXP_NAME, config_name, seed_ids, Result)
    return loaded


def publish(
    notebook: str = "eval",
) -> None:
    """Publish one of this experiment's notebooks via ``publish_notebook``.

    Args:
        notebook: Name of the notebook to publish. ``"eval"`` is the only
            known notebook and is therefore the default, so that calling
            ``publish()`` without arguments works instead of raising.

    Raises:
        ValueError: If ``notebook`` is not a known notebook name.
    """
    if notebook != "eval":
        raise ValueError(f"Unknown notebook: {notebook}")

    notebook_path = f"./experiments/{EXP_NAME}/notebooks/eval.ipynb"
    # Use a random postfix to make it harder to guess the file name
    output_path = f"experiments/{EXP_NAME}/eval_39fd32d.html"
    publish_notebook(
        notebook_path,
        output_path,
    )


def show_sequence_results(
    results: list[Result],
    show_strings: bool = True,
    show_required_tokens: bool = True,
    show_attention_maps: bool = False,
) -> None:
    """Render the analysis of a single result in the notebook.

    Always shows the table of mean/std minimum context sizes per context
    type; the other sections are optional.

    Args:
        results: Loaded results; must contain exactly one element.
        show_strings: Print the sequences stored in the result as a
            bullet list.
        show_required_tokens: Show the heatmaps produced by
            ``show_correct_prediction_tokens``.
        show_attention_maps: Show the attention maps produced by
            ``show_multilayer_attention_maps``.
    """
    assert len(results) == 1
    result_value = results[0].value

    # Fixed order in which the context types are reported.
    context_type_order = [
        "preceding_masked",
        "preceding_shuffled",
        "attention_0_masked",
        "attention_0_shuffled",
        "attention_8_masked",
        "attention_8_shuffled",
        "attention_15_masked",
        "attention_15_shuffled",
    ]
    ordered_context_sizes = result_value.min_token_context_sizes.reindex(
        context_type_order,
        level="context_type",
    )

    if show_strings:
        print("Memorized strings:")
        bullet_list = "\n".join(f"- {s}" for s in result_value.sequences)
        display(md(bullet_list))

    display(md("### Minimum context sizes for correct prediction"))
    summary_table = compute_mean_min_tokens_for_correct_prediction(
        ordered_context_sizes,
    )
    display(summary_table)

    if show_required_tokens:
        display(md("### Minimum contexts for correct prediction"))
        show_correct_prediction_tokens(
            ordered_context_sizes,
            result_value.attentions,
        )

    if show_attention_maps:
        display(md("### Attention maps"))
        show_multilayer_attention_maps(result_value.attentions)


def compute_mean_min_tokens_for_correct_prediction(
    min_context_results: pd.DataFrame,
) -> pd.DataFrame:
    """Summarize the minimum context size per context type.

    Args:
        min_context_results: Frame that can be grouped by
            ``"context_type"`` and holds a single value column (the two
            output column names are assigned positionally).

    Returns:
        One row per context type, in order of first appearance, with the
        columns ``"min_context_size (mean)"`` and
        ``"min_context_size (std)"``.
    """
    # sort=False keeps the context types in the order given by the caller.
    per_context_type = min_context_results.groupby("context_type", sort=False)
    summary = per_context_type.agg(["mean", "std"])
    # Flatten the (column, statistic) header into plain column names.
    summary.columns = [
        "min_context_size (mean)",
        "min_context_size (std)",
    ]
    return summary


def show_multilayer_attention_maps(maps: np.ndarray) -> None:
    """Show attention maps for the first, middle and last layer.

    Args:
        maps: Attention maps indexed as ``maps[layer, sequence]``. Axis 1
            must have length 1 (a single sequence); each ``maps[layer, 0]``
            is handed to ``plot_attention_maps``.
    """
    # Make sure that there's only one sequence
    assert maps.shape[1] == 1
    num_layers = maps.shape[0]
    # Integer division instead of round(): round() rounds halves to the
    # nearest even number, so e.g. 3 layers yielded [0, 2, 2] and skipped
    # the middle layer. dict.fromkeys drops duplicates while keeping order,
    # so models with one or two layers do not show a layer twice.
    layer_indices = dict.fromkeys([0, num_layers // 2, num_layers - 1])
    for layer_idx in layer_indices:
        display(md(f"**Layer {layer_idx} attention**"))
        fig = plot_attention_maps(maps[layer_idx, 0])
        fig.show()


def plot_attention_maps(
    attention: np.ndarray,
) -> go.Figure:
    """Build a heatmap figure with a dropdown to switch attention heads.

    The figure initially shows the attention averaged over the first axis
    of ``attention``; the dropdown offers that mean plus one entry per
    element along the first axis (labelled as attention heads).

    Args:
        attention: Attention of one layer; iterating over it yields one
            map per head.

    Returns:
        The figure; it is not shown by this function.
    """

    def menu_entry(label: str, z_values: Any) -> dict[str, Any]:
        # "update" takes [trace updates, layout updates]; the label is
        # reused as the figure title.
        return {
            "label": label,
            "method": "update",
            "args": [
                {"z": [z_values]},
                {"title": label},
            ],
        }

    averaged = attention.mean(axis=0)

    menu_entries = [menu_entry("Mean attention", averaged)]
    menu_entries.extend(
        menu_entry(f"Attention head {head_idx}", per_head)
        for head_idx, per_head in enumerate(attention)
    )

    dropdown = {
        "buttons": menu_entries,
        "direction": "down",
        "showactive": True,
        "x": 0.8,
        "xanchor": "left",
        "y": 1.04,
        "yanchor": "top",
    }
    figure_layout = go.Layout(
        updatemenus=[dropdown],
        title="Mean attention",
        height=900,
        width=1000,
        xaxis_title="Sequence",
        yaxis_title="Sequence",
    )
    return go.Figure(
        data=[_create_attention_map(averaged)],
        layout=figure_layout,
    )


def _create_attention_map(
    head_attention: np.ndarray,
) -> go.Heatmap:
    """Wrap an attention matrix in a plotly heatmap trace (``z`` values)."""
    return go.Heatmap(z=head_attention)


def show_correct_prediction_tokens(
    min_context_counts: pd.DataFrame,
    attention: np.ndarray,
) -> None:
    """Show which context tokens are needed per ``*_shuffled`` context type.

    Builds one heatmap per context type whose name ends in ``"_shuffled"``
    (via ``plot_correct_prediction_context_type``) and shows a single
    figure with a slider that switches the heatmap values and the title
    between the context types.

    NOTE: only ``z`` and the title are switched by the slider; the axis
    tick labels stay those of the first context type.

    Args:
        min_context_counts: Minimum context sizes with the index levels
            ``sequence`` and ``context_type`` (plus whatever levels
            ``plot_correct_prediction_context_type`` needs) and a single
            value column. Must cover exactly one sequence.
        attention: Attention maps; only ``attention[:, 0]`` (the first
            entry along axis 1) is used.
    """
    sequence_ids = min_context_counts.index.get_level_values("sequence")
    assert len(sequence_ids.unique()) == 1
    # NOTE(review): the heatmap width below is the length of the string
    # form of the sequence index value, i.e. presumably characters rather
    # than tokens -- confirm that this is intended.
    sequence = str(sequence_ids[0])
    sequence_attention = attention[:, 0, :]

    # Step 1: Create one figure per shuffled context type
    figures = []
    context_types = []
    for context_type, context_counts in min_context_counts.groupby(
        "context_type", sort=False
    ):
        context_type = str(context_type)
        if not context_type.endswith("_shuffled"):
            continue
        # Build a new frame rather than mutating the groupby slice in place.
        context_counts = context_counts.reset_index(
            ["sequence", "context_type"], drop=True
        )
        context_counts.columns = ["min_context_size"]
        fig = plot_correct_prediction_context_type(
            context_type,
            len(sequence),
            context_counts,
            sequence_attention,
        )
        context_types.append(context_type)
        figures.append(fig)

    if not figures:
        # Nothing to plot; previously this ended in an IndexError below.
        display(md("*No shuffled context types to show.*"))
        return

    # Step 2: Create a figure with a slider to navigate between the context
    # types. The steps are built before the shared layout is applied to
    # figures[0], so every step keeps its own per-context-type title.
    steps = [
        dict(
            # "update" expects [trace updates, layout updates]; the title
            # is a layout attribute and therefore goes into the second dict.
            args=[
                {"z": [fig.data[0].z]},
                {"title.text": fig.layout.title.text},
            ],
            label=context_type,
            method="update",
        )
        for context_type, fig in zip(context_types, figures)
    ]

    sliders = [
        dict(
            active=0,
            yanchor="top",
            xanchor="left",
            currentvalue={
                "font": {"size": 20},
                "prefix": "Context type: ",
                "visible": True,
                "xanchor": "right",
            },
            pad={"b": 10, "t": 50},
            len=0.9,
            x=0.05,
            y=0,
            steps=steps,
        )
    ]

    layout = go.Layout(
        sliders=sliders,
        title="Minimum context sizes for correct prediction",
    )
    figures[0].update_layout(layout)
    figures[0].show()


def plot_correct_prediction_context_type(
    context_type: str,
    sequence_len: int,
    min_context_counts: pd.DataFrame,
    attention: np.ndarray,
) -> go.Figure:
    """Plot a heatmap of the context tokens needed to predict each token.

    Args:
        context_type: Name of the context type. Names starting with
            ``"attention"`` must look like ``attention_<layer>_...``; the
            head-averaged attention of that layer then ranks the context
            tokens. For every other name, tokens are ranked by position,
            latest position first.
        sequence_len: Number of columns of the heatmap.
        min_context_counts: DataFrame with index levels
            - token_idx: The index of the token in the sequence
            - sample_idx: For the shuffled context, i.e. where non-context
                tokens in the prompt are replaced with random tokens,
                multiple replacements are sampled. This is the sample
                index.
            and columns
            - min_context_size: The number of tokens needed to correctly
                predict the token
        attention: Attention maps of shape
            (num_layers, num_heads, num_tokens, num_tokens)

    Returns:
        A figure with a heatmap that has one row per target token: the
        required context positions are set to 1, positions that are not
        needed stay 0, and the target token itself is highlighted with
        0.5.
    """
    if context_type.startswith("attention"):
        layer_idx = int(context_type.split("_")[1])
        # Attention is averaged over all heads
        scores = attention[layer_idx].mean(axis=0)
    else:
        # Position-based ranking: the score of a token is its position.
        num_positions = attention.shape[-1]
        scores = [np.arange(num_positions)] * num_positions

    per_token_mean = min_context_counts.groupby("token_idx").mean()
    per_token_mean = per_token_mean.sort_values(by="token_idx")

    # One row per target token, one column per sequence position.
    heatmap_matrix = np.zeros((len(per_token_mean), sequence_len))
    for row_idx, (raw_token_idx, mean_row) in enumerate(
        per_token_mean.iterrows()
    ):
        target_idx = int(raw_token_idx)
        # Rounded because shuffled context types average over several
        # samples, which yields fractional sizes.
        num_required = round(mean_row["min_context_size"])
        # Scores of the positions before the target, taken from row
        # target_idx - 1 (presumably the position that predicts the target).
        candidate_scores = scores[target_idx - 1][:target_idx]
        ranked_positions = np.argsort(candidate_scores)[::-1]
        heatmap_matrix[row_idx, ranked_positions[:num_required]] = 1
        # Highlight the target token
        heatmap_matrix[row_idx, target_idx] = 0.5

    # Statistics over all rows (not the per-token means) for the title.
    overall_stats = min_context_counts.aggregate(["mean", "std"])
    count_mean = overall_stats.loc["mean", "min_context_size"]
    count_std = overall_stats.loc["std", "min_context_size"]
    title = (
        f"{context_type}: Min Tokens Required for Correct Prediction "
        f"(mean: {count_mean:.1f}, std: {count_std:.1f})"
    )

    # Row labels: token index plus its rounded mean context size, padded
    # to a fixed width of 3 digits.
    tick_labels = [
        f"{token_idx}: {round(token_count):3d}"
        for token_idx, token_count in zip(
            per_token_mean.index,
            per_token_mean["min_context_size"],
        )
    ]

    figure = go.Figure(
        data=go.Heatmap(
            z=heatmap_matrix,
            hoverongaps=False,
        )
    )
    figure.update_layout(
        title=title,
        xaxis_title="Context Token",
        yaxis_title="Target Token",
        width=1100,
        height=min(sequence_len * 8, 1000),
        yaxis=dict(
            tickvals=list(range(heatmap_matrix.shape[0])),
            ticktext=tick_labels,
        ),
    )
    return figure
