import string
from typing import Callable, cast

import pandas as pd
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from defs import BASE_FIGURE_DIR
from lib_llm.eval.memorization.dynamics.plots import (
    plot_string_position_cum_prob,
    plot_training_accuracy,
    plot_training_cum_prob,
    plot_training_entropy,
    plot_training_loss,
)
from lib_llm.eval.memorization.prefix_mappings import (
    plot_epoch_prefix_performance,
)
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook
from lib_project.visualization import with_paper_style

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as RTExperimentResult


# Type alias for a loaded result of this experiment: the generic project
# `ExperimentResult` parameterized with this experiment's config type and its
# own result type (both imported from `.experiment`).
Result = ExperimentResult[ExperimentConfig, RTExperimentResult]


def load(
    config_name: str | list[str],
    seed_ids: list[int],
) -> list[Result]:
    """Load the stored results of this experiment.

    Args:
        config_name: Configuration name, or list of names, to load.
        seed_ids: Ids of the seeds whose results should be loaded.

    Returns:
        The results returned by ``load_results`` for ``EXP_NAME``, decoded
        with this experiment's config and result types.
    """
    config_type = ExperimentConfig
    result_type = RTExperimentResult
    return load_results(EXP_NAME, config_name, seed_ids, config_type, result_type)


# Published HTML file of each notebook, relative to `experiments/<EXP_NAME>/`.
# The random postfixes make the published file names harder to guess.
_PUBLISHED_NOTEBOOK_FILES: dict[str, str] = {
    "same_strings": "same_strings_2gb73hs.html",
    # NOTE(review): the output name says "16-23x" while the notebook is
    # "16-32x"; kept unchanged in case this file is already published.
    "16-32x_same_strings": "16-23x_same_strings_bqo3raf.html",
    "untrained_32x": "untrained_32x_3fq3xf.html",
    "long_training": "long_training_dk39f3kj.html",
    "string_lengths": "string_lengths_43fiu23f.html",
    "alphabet_sizes": "dynamics_analysis/alphabet_sizes_592q1dhu4.html",
}


def _published_output_path(notebook: str, exp_name: str) -> str:
    """Return the HTML output path under which `notebook` is published.

    Raises:
        ValueError: If `notebook` has no registered output file.
    """
    try:
        file_name = _PUBLISHED_NOTEBOOK_FILES[notebook]
    except KeyError:
        raise ValueError(f"Unknown notebook: {notebook}") from None
    return f"experiments/{exp_name}/{file_name}"


def publish(
    notebook: str,
) -> None:
    """Publish `./experiments/<EXP_NAME>/notebooks/<notebook>.ipynb` as HTML.

    The output file name for every supported notebook is looked up in
    `_PUBLISHED_NOTEBOOK_FILES`.

    Args:
        notebook: Notebook name, without the `.ipynb` extension.

    Raises:
        ValueError: If `notebook` is not one of the supported notebooks.
    """
    # A single lookup replaces the previous if/elif chains, in which a stray
    # `if` made "same_strings" and "16-32x_same_strings" always raise.
    output_path = _published_output_path(notebook, EXP_NAME)
    notebook_path = f"./experiments/{EXP_NAME}/notebooks/{notebook}.ipynb"
    publish_notebook(
        notebook_path,
        output_path,
    )


def show_string_lengths(
    results: list[Result],
    token_amounts: list[int],
    title: str = "String Length",
    show_std_dev: bool = True,
    show_cum_prob: bool = False,
    show_entropy: bool = False,
) -> tuple[dict[str, go.Figure], dict[str, go.Figure]]:
    """Show sequential and parallel memorization plots for a string-length sweep.

    Args:
        results: Loaded experiment results.
        token_amounts: One entry per string. Only its length is used: the
            strings are labelled "1", "2", ... in order.
        title: Plot title.
        show_std_dev: Whether to draw standard-deviation traces.
        show_cum_prob: Whether to add the cumulative-probability plot to the
            sequential results.
        show_entropy: Whether to add the entropy plot to the sequential
            results.

    Returns:
        A pair `(sequential_figures, parallel_figures)`, each keyed by plot
        name.
    """
    n_strings = len(token_amounts)
    labels = [str(number) for number in range(1, n_strings + 1)]

    sequential = show_sequential_results(
        results,
        title,
        labels,
        show_std_dev=show_std_dev,
        show_cum_prob=show_cum_prob,
        show_entropy=show_entropy,
    )
    parallel = show_parallel_results(
        results,
        title,
        labels,
        show_std_dev=show_std_dev,
    )
    return sequential, parallel


def show_alphabet_sizes(
    results: list[Result],
    alphabet_sizes: list[int],
    show_std_dev: bool = True,
):
    """Show sequential and parallel memorization plots for an alphabet sweep.

    Each alphabet size labels one string. For the sequential plots, that
    string's token columns are restricted to the first `size` lowercase ASCII
    letters. Slicing `string.ascii_lowercase` silently caps sizes above 26.
    The produced figures are displayed but not returned.

    NOTE(review): with token subsets, `reindex_by_string` keeps only the
    "loss", "correct" and token columns. Confirm that the entropy plot drawn
    by default in the sequential results does not need other columns.
    """
    title = "Alphabet Size"
    string_names: list[str] = []
    token_subsets: list[list[str]] = []
    for size in alphabet_sizes:
        string_names.append(str(size))
        token_subsets.append(list(string.ascii_lowercase[:size]))

    show_sequential_results(
        results,
        title,
        string_names,
        cum_prob_token_subsets=token_subsets,
        show_std_dev=show_std_dev,
    )
    show_parallel_results(
        results,
        title,
        string_names,
        show_std_dev=show_std_dev,
    )


# Default sub-folder of BASE_FIGURE_DIR that paper figures are written to.
FIGURE_FOLDER = "repeated_memorization"
# Default paper figure size as (width, height), passed to Plotly's layout.
FIGURE_SIZE = (800, 600)


def produce_accuracy_paper_plot(
    figures: dict[str, go.Figure],
    model: str,
    variation_dimension: str,
    pretrained: bool,
    filter_runs: list[int] | None = None,
    figure_folder: str = FIGURE_FOLDER,
    size: tuple[int, int] = FIGURE_SIZE,
    *,
    n_std_dev_traces: int = 2,
) -> None:
    """Show and save a paper-styled version of a training-accuracy figure.

    Takes `figures["training_accuracy"]` (as produced by `plot_dynamics`),
    optionally keeps only selected runs, applies `with_paper_style`, shows the
    result and writes it as a PDF to
    `BASE_FIGURE_DIR/<figure_folder>/<pretrained|untrained>/
    accuracy_<variation_dimension>_<model>.pdf`.

    Args:
        figures: Figures keyed by plot name; must contain "training_accuracy".
        model: Model name, used in the output file name.
        variation_dimension: Name of the varied dimension, used in the output
            file name.
        pretrained: Selects the "pretrained" or "untrained" sub-folder.
        filter_runs: If given, keep only traces named `str(run)` for one of
            these runs, plus each one's trailing std-dev traces, and show
            the legend. If None, all traces are kept and the legend is hidden.
        figure_folder: Sub-folder of `BASE_FIGURE_DIR` to save into.
        size: Figure `(width, height)`.
        n_std_dev_traces: Number of traces that directly follow a run's trace
            and belong to it (its std-dev traces). Pass 0 for figures built
            with `show_std_dev=False`. Otherwise the following runs leak into
            the filtered figure. NOTE(review): assumes the std-dev traces are
            emitted right after their run's trace. Confirm against
            `plot_training_accuracy`.
    """
    accuracy_fig = figures["training_accuracy"]
    run_filter = (
        None if filter_runs is None else {str(run) for run in filter_runs}
    )

    filtered_fig = go.Figure()
    # Traces still to copy as companions of the most recently kept run trace.
    companions_left = 0
    for trace in accuracy_fig.data:
        if run_filter is None or trace.name in run_filter:
            filtered_fig.add_trace(trace)
            companions_left = n_std_dev_traces
        elif companions_left > 0:
            filtered_fig.add_trace(trace)
            companions_left -= 1

    paper_fig = with_paper_style(
        filtered_fig,
        legend_pos=(1, 0),
        legend_yanchor="bottom",
        legend_orientation="h",
    )
    paper_fig.update_layout(
        # With every run shown there are too many traces for a legend.
        showlegend=filter_runs is not None,
        width=size[0],
        height=size[1],
        xaxis_title="Epoch",
        yaxis_title="Accuracy",
    )
    # Leave a small margin around the [0, 1] accuracy range.
    paper_fig.update_yaxes(range=[-0.05, 1.05])
    paper_fig.show()

    save_path = (
        BASE_FIGURE_DIR
        / figure_folder
        / ("pretrained" if pretrained else "untrained")
        / f"accuracy_{variation_dimension}_{model}.pdf"
    )
    # Writing the image fails if the target folder does not exist yet.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    paper_fig.write_image(str(save_path))
    print("Saved figure to", save_path)


def show_sequential_results(
    results: list[Result],
    title: str,
    string_names: list[str] | None = None,
    cum_prob_token_subsets: list[list[str]] | None = None,
    show_std_dev: bool = True,
    show_cum_prob: bool = True,
    show_entropy: bool = True,
) -> dict[str, go.Figure]:
    """Display how each string's memorization evolves over all iterations.

    Each result's per-iteration memorization logs are concatenated along the
    epoch axis (`combine_iteration_mem_logs`). They are then grouped per
    string (`reindex_by_string`), so every string's curve spans the whole
    sequence of iterations.

    Args:
        results: Loaded experiment results.
        title: Plot title.
        string_names: Optional display name per string index.
        cum_prob_token_subsets: Optional per-string token columns to keep
            (in addition to "loss" and "correct").
        show_std_dev: Whether to draw standard-deviation traces.
        show_cum_prob: Whether to add the cumulative-probability plot.
        show_entropy: Whether to add the entropy plot.

    Returns:
        The figures produced by `plot_dynamics`, keyed by plot name.
    """
    # Fixed typo in the displayed text: "affects the them" -> "affects them".
    display(
        md(
            """### Iterative memorization dynamics
We show the memorization dynamics for each string as the model memorizes
them one after the other, as well as how the subsequent memorization of
other strings affects them.
"""
        )
    )
    memorization_results = reindex_by_string(
        results,
        combination_func=combine_iteration_mem_logs,
        string_names=string_names,
        token_subsets=cum_prob_token_subsets,
    )
    return plot_dynamics(
        memorization_results,
        title=title,
        show_std_dev=show_std_dev,
        show_cum_prob=show_cum_prob,
        show_entropy=show_entropy,
    )


def show_parallel_results(
    results: list[Result],
    title: str,
    string_names: list[str] | None = None,
    show_std_dev: bool = True,
) -> dict[str, go.Figure]:
    """Compare strings over only the iteration in which each is memorized.

    Every iteration's log is filtered to the string trained in that iteration
    (`filter_iteration_mem_logs`) and grouped per string. Only the loss and
    accuracy plots are drawn.

    Returns:
        The figures produced by `plot_dynamics`, keyed by plot name.
    """
    intro = """### Initial memorization comparison
We compare the memorization dynamics for each string when the model first
memorizes that string, to compare memorization speeds.
I.e. for the 0th string, we show epoch 0 - 50,
for the 1st string epoch 50 - 100, etc.
"""
    display(md(intro))

    per_string = reindex_by_string(
        results,
        combination_func=filter_iteration_mem_logs,
        string_names=string_names,
    )
    figures = plot_dynamics(
        per_string,
        title=title,
        show_cum_prob=False,
        show_entropy=False,
        show_std_dev=show_std_dev,
    )
    return figures


def show_prefix_performance(
    results: list[Result],
    title: str = "Prefix Performance",
) -> None:
    """Display prefix-mapping performance plots, one per training iteration.

    Prefix mappings of the same iteration index are pooled across all
    results. Iterations whose `prefix_mappings` is None are skipped.
    Iterations are shown in the order they are first encountered.
    """
    display(md(f"### {title}"))

    prefix_lengths = [1, 2, 4, 8, 16, 32, 64, 128]
    mappings_by_iteration: dict[int, list] = {}
    for result in results:
        iteration_results = result.value.iteration_results
        for iteration, iteration_result in enumerate(iteration_results):
            mappings = iteration_result.prefix_mappings
            if mappings is not None:
                mappings_by_iteration.setdefault(iteration, []).append(mappings)

    for iteration, pooled_mappings in mappings_by_iteration.items():
        display(md(f"#### String {iteration + 1}"))
        plot_epoch_prefix_performance(
            pooled_mappings,
            prefix_lengths,
            show_full_prefix=True,
            show_std_dev=False,
        ).show()


def reindex_by_string(
    results: list[Result],
    combination_func: Callable[[Result], pd.DataFrame],
    string_names: list[str] | None = None,
    token_subsets: list[list[str]] | None = None,
) -> dict[str, list[pd.DataFrame]]:
    """Regroup memorization logs by string, collecting one frame per result.

    Args:
        results: Loaded experiment results.
        combination_func: Turns one result into a single memorization log
            with a "string" index level.
        string_names: Optional display name per string index. Defaults to the
            index itself, as a string.
        token_subsets: Optional per-string token columns. If given, only
            "loss", "correct" and those columns are kept.

    Returns:
        A mapping from string name to the list of that string's logs, one
        entry per result in which the string appears.
    """
    logs_by_name: dict[str, list[pd.DataFrame]] = {}
    for result in results:
        combined = combination_func(result)
        for raw_idx, group in combined.groupby(level="string"):
            idx = cast(int, raw_idx)
            name = str(idx) if string_names is None else string_names[idx]
            columns = (
                group.columns
                if token_subsets is None
                else ["loss", "correct", *token_subsets[idx]]
            )
            logs_by_name.setdefault(name, []).append(group[columns])
    return logs_by_name


def plot_dynamics(
    memorization_results: dict[str, list[pd.DataFrame]],
    title: str,
    show_cum_prob: bool = True,
    show_entropy: bool = True,
    show_std_dev: bool = True,
) -> dict[str, go.Figure]:
    """Draw, show and return the training-dynamics figures.

    Loss and accuracy are always plotted. Cumulative probability and entropy
    are plotted only when requested. Figures are shown in that order.

    Returns:
        The figures keyed by "training_loss", "training_accuracy" and,
        when enabled, "training_cum_prob" / "training_entropy".
    """
    plot_specs = [
        ("training_loss", plot_training_loss, True),
        ("training_accuracy", plot_training_accuracy, True),
        ("training_cum_prob", plot_training_cum_prob, show_cum_prob),
        ("training_entropy", plot_training_entropy, show_entropy),
    ]
    figures = {}
    for key, plot_fn, enabled in plot_specs:
        if not enabled:
            continue
        fig = plot_fn(
            memorization_results,
            title,
            show_std_dev=show_std_dev,
        )
        fig.show()
        figures[key] = fig
    return figures


def combine_iteration_mem_logs(result: Result) -> pd.DataFrame:
    """Sequentially combine the memorization logs from each iteration.

    Each iteration's "epoch" index level is shifted by the sum of the
    maximum epochs of all previous iterations. If epochs start at 0, an
    iteration's first epoch therefore coincides with the previous
    iteration's last one. The index is rebuilt with "epoch" first, followed
    by the remaining levels. This assumes "epoch" is already the first
    level.
    """
    shifted_logs = []
    offset = 0
    for iteration_res in result.value.iteration_results:
        log = iteration_res.memorization_log
        index = log.index
        epochs = index.get_level_values("epoch")
        other_levels = index.names[1:]

        arrays = [epochs + offset]
        arrays.extend(index.get_level_values(level) for level in other_levels)

        shifted = log.copy()
        shifted.index = pd.MultiIndex.from_arrays(
            arrays, names=["epoch", *other_levels]
        )
        shifted_logs.append(shifted)

        offset += cast(int, epochs.max())
    return pd.concat(shifted_logs)


def filter_iteration_mem_logs(
    result: Result,
) -> pd.DataFrame:
    """Filter the memorization logs from each iteration for the string
    memorized at that iteration.

    Iteration `k` keeps only the rows whose "string" index level equals `k`.
    The filtered frames are concatenated in iteration order.
    """
    kept = []
    for string_idx, iteration_res in enumerate(result.value.iteration_results):
        log = iteration_res.memorization_log
        is_current_string = log.index.get_level_values("string") == string_idx
        kept.append(log.loc[is_current_string])
    return pd.concat(kept)
