from typing import Any, Callable, Iterable, cast

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from lib_dl.analysis.aggregate import aggregate_mean_std
from lib_dl.analysis.experiment import ExperimentResult, load_results
from lib_dl.analysis.publish.notebook import publish_notebook
from utils.results.plotting import add_std_dev_trace

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as RCSExperimentResult


# Type alias used throughout this module: the generic ``ExperimentResult``
# imported from ``lib_dl`` parameterized with this experiment's config type
# and this experiment's own result type (imported as ``RCSExperimentResult``).
Result = ExperimentResult[ExperimentConfig, RCSExperimentResult]


def load(
    config_name: str,
    seed_ids: list[int],
) -> list[Result]:
    """Load stored results of this experiment for one configuration.

    Thin wrapper around ``load_results`` that fixes the experiment name to
    ``EXP_NAME`` and the expected result type to ``Result``.

    Args:
        config_name: Name of the configuration whose results are loaded.
        seed_ids: Seed identifiers forwarded to ``load_results``.

    Returns:
        The value returned by ``load_results`` for these arguments.
    """
    loaded_results = load_results(EXP_NAME, config_name, seed_ids, Result)
    return loaded_results


def publish(
    notebook: str,
) -> None:
    """Publish one of this experiment's notebooks via ``publish_notebook``.

    Args:
        notebook: Name of the notebook to publish, without the ``.ipynb``
            extension. Must be one of the keys of the table below.

    Raises:
        ValueError: If ``notebook`` is not a known notebook name.
    """
    # Output file name per notebook. Each name carries a random postfix to
    # make the published file harder to guess.
    output_file_names = {
        "string_length": "string_length_f942goi3.html",
        "alphabet_size": "alphabet_size_3jd031hi.html",
        "size_change": "size_change_k90a321.html",
        "training_stages": "training_stages_fslk3g93s.html",
        "replacement_strategy": "replacement_strategy_39fx32.html",
        "model_type": "model_type_290fhq2e.html",
    }
    file_name = output_file_names.get(notebook)
    if file_name is None:
        raise ValueError(f"Unknown notebook: {notebook}")

    output_path = f"experiments/{EXP_NAME}/{file_name}"
    notebook_path = f"./experiments/{EXP_NAME}/notebooks/{notebook}.ipynb"
    publish_notebook(notebook_path, output_path)


def plot_overview(
    results: dict[str, list[Result]],
    legend_title: str,
    prefix_display_limit: int = 26,
) -> go.Figure:
    """Plot the plurality-prediction accuracy per prefix length, one trace
    per configuration.

    Args:
        results: Maps a configuration name (used as the trace name) to the
            results of that configuration.
        legend_title: Title shown above the legend.
        prefix_display_limit: Forwarded to
            ``plot_prefix_length_correctness`` to limit the displayed
            prefix lengths.

    Returns:
        The figure holding one trace per entry of ``results``.
    """
    overview = go.Figure()
    for trace_idx, (config_name, config_results) in enumerate(results.items()):
        # Annotate every result's performance table with the plurality
        # correctness column before aggregating over them.
        run_performances = []
        for run in config_results:
            run_performances.append(
                add_plurality_correctness(
                    run.value.prefix_performance, run.value.sequence
                )
            )
        plot_prefix_length_correctness(
            run_performances,
            trace_idx,
            f"{config_name}",
            threshold=None,
            figure=overview,
            prefix_display_limit=prefix_display_limit,
        )

    layout_options = dict(
        title="Correct predictions per prefix length",
        xaxis_title="Prefix Length",
        yaxis_title="Accuracy",
        legend_title=legend_title,
        width=900,
        height=700,
    )
    overview.update_layout(**layout_options)
    return overview


def show_results(
    results: list[Result],
    show_breakdown: bool = False,
    prefix_display_limit: int = 26,
) -> None:
    """Print and plot the prefix performance of a single result.

    Prints the memorized string and the fraction of tokens that are
    predicted correctly with their full prefix, shows the per-prefix
    heatmaps, and shows the accuracy per prefix length for the plurality
    prediction and for several sample-accuracy thresholds.

    Args:
        results: List that must contain exactly one result.
        show_breakdown: If True, additionally show the plurality prediction
            breakdown figure.
        prefix_display_limit: Forwarded to
            ``plot_prefix_length_correctness``.
    """
    assert len(results) == 1
    run = results[0].value
    performances: pd.DataFrame = add_plurality_correctness(
        run.prefix_performance, run.sequence
    )

    print("Memorized string:", run.sequence)

    # A token counts as correct with its full prefix when the row whose
    # prefix length equals the token index has at least one correct sample.
    num_correct = sum(
        int(
            token_performance.loc[(token_idx, token_idx), "correct_samples"]
            >= 1
        )
        for token_idx, token_performance in performances.groupby("token_idx")
    )
    num_tokens = len(performances.index.unique("token_idx"))
    frac_full_prefix_correct = num_correct / num_tokens
    print(f"Full prefix correct: {frac_full_prefix_correct:.2%}")

    # Variant of the table in which "correct_samples" is replaced by the
    # 0/1 plurality correctness, so both can share the heatmap code.
    plurality_view = performances.copy()
    plurality_view["correct_samples"] = performances[
        "plurality_correct"
    ].astype(int)
    heatmap_inputs = {
        "Plurality Prediction": plurality_view,
        "Sample Distribution": performances,
    }
    show_prefix_performance(
        heatmap_inputs,
        sequence_length=len(run.sequence),
        title="Accuracy per prefix",
    )

    threshold_figure = go.Figure()
    # None selects the plurality prediction instead of a threshold.
    thresholds = [None, 1.0, 0.8, 0.6, 0.4, 0.2]
    for trace_idx, threshold in enumerate(thresholds):
        name = (
            "Perturbed (Plurality)"
            if threshold is None
            else f"Perturbed (Threshold={threshold})"
        )
        plot_prefix_length_correctness(
            [performances],
            trace_idx,
            name,
            threshold=threshold,
            figure=threshold_figure,
            prefix_display_limit=prefix_display_limit,
        )
    threshold_figure.update_layout(
        title="Correct predictions per prefix length",
        xaxis_title="Prefix Length",
        yaxis_title="Accuracy",
        height=600,
    )
    threshold_figure.show()

    if not show_breakdown:
        return
    breakdown_figure = plot_plurality_prediction_breakdown(
        performances,
        run.sequence,
    )
    breakdown_figure.show()


def add_plurality_correctness(
    performances: pd.DataFrame,
    string: str,
) -> pd.DataFrame:
    """Add a ``plurality_correct`` column to ``performances``.

    For every row, the column holds whether the row's ``top_1_token``
    equals the character of ``string`` at the row's token index (the first
    element of the two-element index tuple).

    Note that the column is added to the passed DataFrame in place; the
    same object is returned.

    Args:
        performances: DataFrame with a two-level index whose first level is
            the token index, and a ``top_1_token`` column.
        string: The target string the predictions are compared against.

    Returns:
        ``performances`` itself, with the new column.
    """
    matches_target = []
    for (token_idx, _), row in performances.iterrows():
        matches_target.append(string[token_idx] == row["top_1_token"])

    performances["plurality_correct"] = pd.Series(
        matches_target,
        index=performances.index,
    )
    return performances


def show_prefix_performance(
    performance: dict[str, pd.DataFrame],
    sequence_length: int,
    title: str,
) -> None:
    """Show performance heatmaps with a dropdown to switch between them.

    One heatmap is built per entry of ``performance``. The figure of the
    first entry is the one displayed; each dropdown button swaps in the
    ``z`` data of the heatmap belonging to its entry.

    Args:
        performance: Maps the dropdown label to the performance table that
            is plotted for it.
        sequence_length: Side length of the heatmaps.
        title: Title of the displayed figure.
    """
    heatmaps = [
        plot_performance_heatmap(table, sequence_length)
        for table in performance.values()
    ]

    # One dropdown button per performance table.
    dropdown_buttons = [
        {
            "args": [
                {
                    "z": [heatmap.data[0].z],
                    "title.text": heatmap.layout.title.text,
                }
            ],
            "label": prefix_type,
            "method": "update",
        }
        for prefix_type, heatmap in zip(performance, heatmaps)
    ]
    dropdown = {
        "active": 0,
        "buttons": dropdown_buttons,
        "direction": "up",
        "pad": {"r": 10, "t": 10},
        "showactive": True,
        "x": 0,
        "xanchor": "left",
        "y": -0.02,
        "yanchor": "top",
    }

    displayed = heatmaps[0]
    displayed.update_layout(go.Layout(updatemenus=[dropdown], title=title))
    displayed.show()


def plot_performance_heatmap(
    prefix_performance: pd.DataFrame, sequence_length: int
) -> go.Figure:
    """Plot a target-token x context-token heatmap of prefix performance.

    Row ``t`` of the heatmap belongs to target token ``t``. For every
    measured prefix length ``p`` of that token, the cell in column
    ``t - p`` (the position where the prefix starts) holds the
    ``correct_samples`` value of the first row measured for ``(t, p)``.
    The diagonal cell ``(t, t)`` is set to 0.5 to highlight the target
    token, and cells without any measurement hold -1.

    The index level names of ``prefix_performance`` are set in place to
    ``["token_idx", "prefix_length"]``.

    Args:
        prefix_performance: DataFrame with a two-level index
            (token index, prefix length) and a ``correct_samples`` column.
        sequence_length: Side length of the square heatmap.

    Returns:
        A figure containing the heatmap.
    """
    prefix_performance.index.names = ["token_idx", "prefix_length"]

    # -1 marks cells for which no measurement exists.
    cells = np.full((sequence_length, sequence_length), -1.0)

    for target, target_rows in prefix_performance.groupby("token_idx"):
        target = cast(int, target)
        # Highlight the target token itself on the diagonal.
        cells[target, target] = 0.5

        for length, length_rows in target_rows.groupby("prefix_length"):
            length = cast(int, length)
            first_value = length_rows["correct_samples"].iloc[0]
            # The column is the position at which the prefix starts.
            cells[target, target - length] = first_value

    heatmap = go.Heatmap(z=cells, hoverongaps=False)
    figure = go.Figure(data=heatmap)
    figure.update_layout(
        xaxis_title="Context Token",
        yaxis_title="Target Token",
        height=1000,
    )
    return figure


def plot_prefix_length_correctness(
    prefix_performances: list[pd.DataFrame],
    index: int,
    name: str,
    # None means use the plurality prediction
    threshold: float | None,
    figure: go.Figure | None = None,
    prefix_display_limit: int = 30,
) -> go.Figure:
    """Show for each prefix length what fraction of token positions can be
    correctly predicted.

    Each performance table is densified with ``_interpolate_performance``
    and reduced to a per-prefix-length accuracy with
    ``_compute_length_correctness``. The accuracies of all tables are then
    combined with ``aggregate_mean_std`` and added to ``figure`` with
    ``add_std_dev_trace``.

    Args:
        prefix_performances: One performance table per result. All tables
            must contain the same number of distinct ``token_idx`` values.
        index: Forwarded to ``add_std_dev_trace`` (presumably the position
            of the trace, e.g. for picking its color -- confirm against
            that helper).
        name: Name of the added trace.
        threshold: Minimum ``correct_samples`` value for a prediction to
            count as correct, or None to use the plurality prediction.
        figure: Figure to add the trace to. A new, empty figure is created
            for each call when omitted.
        prefix_display_limit: Upper limit that, together with the sequence
            length, bounds the displayed prefix lengths.

    Returns:
        The figure the trace was added to.
    """
    # A figure created in the signature would be built once at definition
    # time and shared (with all its traces) by every call that omits the
    # argument, so create it per call instead.
    if figure is None:
        figure = go.Figure()

    sequence_length = len(
        prefix_performances[0].index.get_level_values("token_idx").unique()
    )
    assert all(
        len(perf.index.get_level_values("token_idx").unique())
        == sequence_length
        for perf in prefix_performances
    ), "All performance tables must cover the same number of tokens"

    max_display_length = min(sequence_length, prefix_display_limit)
    interpolated_performances = [
        _interpolate_performance(prefix_performance, max_display_length)
        for prefix_performance in prefix_performances
    ]
    length_correctness = [
        _compute_length_correctness(interpolated_performance, threshold)
        for interpolated_performance in interpolated_performances
    ]
    performance_means, performance_stds = aggregate_mean_std(
        length_correctness,
        levels_to_preserve=["prefix_length"],
    )
    add_std_dev_trace(
        figure,
        index,
        performance_means,
        performance_stds,
        name=name,
    )
    return figure


def _interpolate_performance(
    prefix_performance: pd.DataFrame,
    max_display_length: int,
) -> pd.DataFrame:
    """Build a dense (token_idx, prefix_length) table of performance values.

    The result has one row for every combination of a token index found in
    ``prefix_performance`` and a prefix length in
    ``range(1, max_display_length)``. Measured rows with a prefix length
    below ``max_display_length`` are taken from the input. Prefix lengths
    that directly follow a measured one but are missing from the input get
    the values of that measured prefix length, up to the next measured
    prefix length.

    Args:
        prefix_performance: DataFrame with index levels ``token_idx`` and
            ``prefix_length`` and columns ``correct_samples`` and
            ``plurality_correct``.
        max_display_length: Exclusive upper bound of the prefix lengths in
            the result.

    Returns:
        DataFrame indexed by (``token_idx``, ``prefix_length``) with the
        columns ``correct_samples`` and ``plurality_correct``.

    Raises:
        AssertionError: If the result still contains NA values.
    """
    token_indices = prefix_performance.index.get_level_values(
        "token_idx"
    ).unique()
    # Dense target index: every token combined with every prefix length
    # from 1 up to (but excluding) max_display_length.
    interpolated_idx = pd.MultiIndex.from_product(
        [token_indices, range(1, max_display_length)],
        names=["token_idx", "prefix_length"],
    )
    column_dtypes = {
        "correct_samples": float,
        "plurality_correct": bool,
    }
    # Start from a frame filled with NaN placeholders.
    # NOTE(review): the placeholders of "plurality_correct" are cast to bool
    # here; presumably NaN becomes True, in which case the NA check at the
    # end only guards "correct_samples" -- confirm.
    interpolated_performance = (
        pd.DataFrame(
            index=interpolated_idx,
            columns=list(column_dtypes.keys()),
        )
        .assign(**{col: np.nan for col in column_dtypes.keys()})
        .astype(column_dtypes)
    )
    # Only measurements below the display limit are relevant.
    sub_max_performance = prefix_performance.loc[
        prefix_performance.index.get_level_values("prefix_length")
        < max_display_length
    ]
    # Copy the measured values into the dense frame (presumably only for
    # rows and columns present in both frames -- see DataFrame.update).
    interpolated_performance.update(sub_max_performance)

    for token_idx, token_performance in sub_max_performance.groupby(
        "token_idx"
    ):
        token_idx = cast(int, token_idx)
        for prefix_length, performance in token_performance.groupby(
            "prefix_length"
        ):
            prefix_length = cast(int, prefix_length)
            # Values of the measured prefix length that are carried forward.
            correct_samples_value = performance["correct_samples"].iloc[0]
            plurality_correct_value = performance["plurality_correct"].iloc[0]
            # Fill the following prefix lengths until the next measured one
            # (or the display limit) is reached.
            while prefix_length < max_display_length - 1:
                prefix_length += 1
                if not (token_idx, prefix_length) in prefix_performance.index:
                    interpolated_performance.loc[
                        (token_idx, prefix_length), "correct_samples"
                    ] = correct_samples_value
                    interpolated_performance.loc[
                        (token_idx, prefix_length), "plurality_correct"
                    ] = plurality_correct_value
                else:
                    break
    # Every cell must have been filled by a measurement or by carrying one
    # forward.
    assert interpolated_performance.isna().sum().sum() == 0
    return interpolated_performance


def _compute_length_correctness(
    interpolated_performance: pd.DataFrame,
    threshold: float | None,
) -> pd.DataFrame:
    if threshold is None:
        correct_prefixes = interpolated_performance.loc[
            interpolated_performance["plurality_correct"]
        ]
    else:
        correct_prefixes = interpolated_performance.loc[
            interpolated_performance["correct_samples"] >= threshold
        ]

    correct_prefixes = correct_prefixes.groupby("prefix_length").size()
    # Compute the fraction of correct predictions for each prefix length
    correct_prefixes = (
        correct_prefixes
        / interpolated_performance.groupby("prefix_length").size()
    ).dropna()
    return correct_prefixes


def plot_plurality_prediction_breakdown(
    prefix_performance: pd.DataFrame,
    sequence: str,
    max_prefix_size: int = 20,
) -> go.Figure:
    """Plot, per prefix size, how plurality predictions relate to the
    full-prefix (FP) predictions.

    The full-prefix prediction of a token is the ``top_1_token`` of the row
    whose prefix length equals the token index. For every prefix size up to
    ``max_prefix_size`` the bar chart shows the fraction of rows of that
    prefix size for which

    - the plurality prediction and the full-prefix prediction are both
      correct ("Plurality & FP correct"),
    - the plurality prediction is correct but the full-prefix prediction
      is not ("Plurality only correct"),
    - the plurality prediction is incorrect but equals the full-prefix
      prediction ("Incorrect, plurality = FP").

    Args:
        prefix_performance: DataFrame with index levels ``token_idx`` and
            ``prefix_length`` and the columns ``top_1_token`` and
            ``plurality_correct``.
        sequence: The target string; ``sequence[token_idx]`` is the correct
            token for ``token_idx``.
        max_prefix_size: Largest prefix size included in the plot.

    Returns:
        A bar chart figure with one group of bars per prefix size.
    """
    # Step 1: collect the full-prefix prediction of every token and whether
    # it is correct.
    token_indices = []
    full_prefix_pred_values = []
    full_prefix_correct_values = []
    for token_idx, token_prefix_res in prefix_performance.groupby("token_idx"):
        token_indices.append(token_idx)
        full_prefix_pred = token_prefix_res.loc[
            (token_idx, token_idx), "top_1_token"
        ]
        full_prefix_pred_values.append(full_prefix_pred)
        full_prefix_correct = full_prefix_pred == sequence[int(token_idx)]
        full_prefix_correct_values.append(full_prefix_correct)
    full_prefix_results = pd.DataFrame(
        {
            "pred": full_prefix_pred_values,
            "correct": full_prefix_correct_values,
        },
        index=token_indices,
    )
    full_prefix_results.index.name = "token_idx"

    # Step 2: per prefix size, compute the fraction of rows falling into
    # each of the three categories.
    prefix_sizes = []
    plurality_fp_correct = []
    plurality_only_correct = []
    incorrect_plurality_fp_match = []
    for prefix_size, prefix_res in prefix_performance.groupby("prefix_length"):
        prefix_size = int(prefix_size)
        # The groups are visited in ascending prefix size, so everything
        # after the first size above the limit can be skipped.
        if prefix_size > max_prefix_size:
            break
        prefix_sizes.append(prefix_size)

        num_prefix_length_results = len(prefix_res)
        # Index by token only, so the rows align with full_prefix_results.
        prefix_res.index = prefix_res.index.droplevel("prefix_length")
        plurality_correct = prefix_res["plurality_correct"]
        plurality_fp_correct.append(
            (plurality_correct & full_prefix_results["correct"]).sum()
            / num_prefix_length_results
        )
        plurality_only_correct.append(
            (plurality_correct & ~full_prefix_results["correct"]).sum()
            / num_prefix_length_results
        )

        plurality_preds = prefix_res["top_1_token"]
        # NOTE(review): the comparison is made against the tokens with index
        # >= prefix_size only; this presumably relies on the rows of this
        # prefix size covering exactly those tokens -- confirm.
        incorrect_plurality_fp_match.append(
            (
                ~plurality_correct
                & (
                    plurality_preds
                    == full_prefix_results.loc[prefix_size:, "pred"]
                )
            ).sum()
            / num_prefix_length_results
        )
    prefix_size_performances = pd.DataFrame(
        {
            "Plurality & FP correct": plurality_fp_correct,
            "Plurality only correct": plurality_only_correct,
            "Incorrect, plurality = FP": incorrect_plurality_fp_match,
        },
        index=prefix_sizes,
    )

    fig = px.bar(
        prefix_size_performances,
        x=prefix_size_performances.index,
        y=prefix_size_performances.columns,
    )
    fig.update_layout(
        title="Plurality Prediction Breakdown",
        yaxis_title="Fraction of predictions",
        xaxis_title="Prefix size",
        legend_title="Prediction type",
        height=600,
    )
    return fig
