from typing import Callable, Iterator, cast

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from experiments.memorization_dynamics.results import (
    plot_training_loss,
    produce_accuracy_paper_plot,
    produce_cum_prob_paper_plot,
    produce_entropy_paper_plot,
    produce_kld_paper_plot,
    produce_loss_paper_plot,
    show_dynamics,
)
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook
from lib_project.visualization import with_paper_style
from lib_project.visualization.arrange import arrange_figures_in_grid

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as PMDExperimentResult


# Result container specialised to this experiment: the generic
# `ExperimentResult` parameterised with this experiment's config type and
# its result-payload type (imported here as `PMDExperimentResult`).
Result = ExperimentResult[ExperimentConfig, PMDExperimentResult]


def load(
    config_name: str | list[str],
    seed_ids: list[int],
) -> list[Result]:
    """Load stored results of this experiment.

    Thin wrapper around ``load_results`` that pins the experiment name
    (``EXP_NAME``) and the config / result-payload types of this experiment.

    Args:
        config_name: A single config name or a list of config names,
            forwarded unchanged to ``load_results``.
        seed_ids: Seed ids to load, forwarded unchanged to ``load_results``.

    Returns:
        Whatever ``load_results`` returns for the given arguments.
    """
    loaded = load_results(
        EXP_NAME,
        config_name,
        seed_ids,
        ExperimentConfig,
        PMDExperimentResult,
    )
    return loaded


def publish(
    notebook: str,
) -> None:
    """Publish one of the known analysis notebooks as an HTML file.

    Args:
        notebook: Notebook path relative to the experiment folder, without
            the ``.ipynb`` extension. Only the notebooks listed in the table
            below are accepted.

    Raises:
        ValueError: If ``notebook`` is not one of the known notebooks.
    """
    # Use a random postfix to make it harder to guess the file name
    postfix_by_notebook = {
        "dynamics_analysis/batch_size": "djqzkd3",
        "dynamics_analysis/context_size": "502dkdk",
    }
    postfix = postfix_by_notebook.get(notebook)
    if postfix is None:
        raise ValueError(f"Unknown notebook: {notebook}")

    source_path = f"./experiments/{EXP_NAME}/{notebook}.ipynb"
    target_path = f"experiments/{EXP_NAME}/{notebook}_{postfix}.html"
    publish_notebook(source_path, target_path)


def convert_to_discrete_training_steps(
    results: dict[str, list[Result]],
) -> dict[str, list[Result]]:
    """Rescale the ``epoch`` index level of every memorization log to integers.

    For each result, the ``epoch`` level of the ``memorization_log`` index is
    multiplied by that run's ``num_train_epochs`` and converted to integers.
    The ``string`` and ``character`` levels are carried over unchanged.

    The logs are modified in place; the same ``results`` mapping is returned
    so the call can be chained.

    Args:
        results: Mapping from config name to the results of that config.

    Returns:
        The ``results`` object that was passed in, with re-indexed logs.
    """
    for config_results in results.values():
        for res in config_results:
            num_epochs = res.config.training.args.num_train_epochs
            memorization_log = res.value.memorization_log
            log_index = memorization_log.index

            scaled_epochs = log_index.get_level_values("epoch") * num_epochs
            # Round to the nearest integer before casting. A bare
            # ``astype(int)`` truncates towards zero, so floating point
            # artefacts such as 0.29 * 100 == 28.999999999999996 would end up
            # as step 28 instead of 29 (duplicating / dropping steps).
            epoch_values = np.rint(np.asarray(scaled_epochs)).astype(int)

            memorization_log.index = pd.MultiIndex.from_arrays(
                (
                    epoch_values,
                    log_index.get_level_values("string"),
                    log_index.get_level_values("character"),
                ),
                names=["epoch", "string", "character"],
            )
    return results


# Default alphabet sizes that `show_alphabet_pretraining_results` iterates
# over when the caller does not pass its own list.
ALPHABET_SIZES = [2, 7, 26]
# Alternative kept for reference: restrict to a single alphabet size.
# ALPHABET_SIZES = [26]


def show_alphabet_pretraining_results(
    load_results: Callable[[str, int, list[int]], dict[str, list[Result]]],
    show_results: Callable[[dict[str, list[Result]]], dict[str, go.Figure]],
    paper_plots: Callable[[dict, str, str], None] | None,
    model: str,
    seed_ids: list[int],
    alphabet_sizes: list[int] = ALPHABET_SIZES,
    pretrained: bool = True,
    untrained: bool = True,
) -> None:
    """Display dynamics and test-set figures per alphabet size and model variant.

    For every alphabet size and every selected model variant the results are
    loaded, shown via ``show_results`` and ``show_testset_performance``, and
    optionally handed to ``paper_plots``.

    Args:
        load_results: Called as ``load_results(model_id, alphabet_size,
            seed_ids)``; returns the results grouped by config name.
        show_results: Called with the loaded results; returns a dict of
            figures.
        paper_plots: If not ``None``, called with the merged figure dict, the
            tag ``f"a-{alphabet_size}"`` and the model id.
        model: Base model name; the untrained variant uses ``f"{model}_u"``.
        seed_ids: Seed ids passed through to ``load_results``.
        alphabet_sizes: Alphabet sizes to iterate over.
        pretrained: Whether to include the pretrained variant.
        untrained: Whether to include the untrained variant.
    """
    print("alphabet_sizes", alphabet_sizes)
    for alphabet_size in alphabet_sizes:
        display(md(f"### Alphabet Size $l = {alphabet_size}$"))
        model_variants = _get_model_infos(model, pretrained, untrained)
        for model_id, pt_type in model_variants:
            display(md(f"#### {pt_type} Model"))
            results = load_results(model_id, alphabet_size, seed_ids)

            if pt_type == "Pretrained":
                # Only show the first 100 epochs, since a few results
                # have more epochs and we want to keep the plot clean.
                # The loaded result objects are trimmed in place.
                for config_results in results.values():
                    for res in config_results:
                        mem_log = res.value.memorization_log
                        within_epochs = (
                            mem_log.index.get_level_values("epoch") <= 100
                        )
                        res.value.memorization_log = mem_log.loc[within_epochs]

                        training_log = res.value.training_log
                        within_steps = training_log["step"] <= 100
                        res.value.training_log = training_log.loc[within_steps]

            dynamics_figures = show_results(results)
            display(md("Performance on wikitext test set"))
            testset_figures = show_testset_performance(results)

            if paper_plots is None:
                continue
            all_figures = dynamics_figures | testset_figures
            paper_plots(all_figures, f"a-{alphabet_size}", model_id)


def _get_model_infos(
    model: str,
    pretrained: bool,
    untrained: bool,
) -> Iterator[tuple[str, str]]:
    """Yield ``(model_id, label)`` pairs for the selected model variants.

    The pretrained variant (if selected) comes first and uses ``model`` as
    its id; the untrained variant uses ``model`` with a ``_u`` suffix.
    """
    variants = (
        (pretrained, model, "Pretrained"),
        (untrained, f"{model}_u", "Untrained"),
    )
    for selected, model_id, label in variants:
        if selected:
            yield model_id, label


def show_testset_performance(
    results: dict[str, list[Result]],
) -> dict[str, go.Figure]:
    """Plot and show the evaluation loss recorded in the training logs.

    From every result's ``training_log`` the ``eval_loss`` and ``step``
    columns are extracted, renamed to ``loss`` / ``epoch``, indexed by
    ``epoch`` and stripped of rows containing NaN. The resulting frames are
    passed to ``plot_training_loss``.

    Args:
        results: Mapping from config name to the results of that config.

    Returns:
        A dict with the eval-loss figure under the key ``"rw_test_loss"``.
    """
    column_names = {
        "eval_loss": "loss",
        "step": "epoch",
    }
    eval_results = {}
    for config_name, config_results in results.items():
        eval_logs = []
        for config_res in config_results:
            eval_log = cast(
                pd.DataFrame,
                config_res.value.training_log[["eval_loss", "step"]],
            )
            eval_log = eval_log.rename(columns=column_names)
            eval_logs.append(eval_log.set_index("epoch").dropna())
        eval_results[config_name] = eval_logs

    loss_fig = plot_training_loss(eval_results, "Eval Loss")
    loss_fig.update_layout(title="Eval Loss")
    figures = {"rw_test_loss": loss_fig}

    # Perplexity plot is currently disabled:
    # figures["perplexity"] = plot_perplexity(eval_results)

    grid_fig = arrange_figures_in_grid(
        figures,
        n_cols=2,
        size=(1000, 450),
    )
    grid_fig.show()

    return figures


def plot_perplexity(
    results: dict[str, list[pd.DataFrame]],
) -> go.Figure:
    """Plot perplexity, computed as ``exp(loss)``, for every config.

    Args:
        results: Mapping from config name to data frames that each carry a
            ``loss`` column. The input frames are not modified.

    Returns:
        The figure produced by ``plot_training_loss``, titled "Perplexity".
    """
    perplexity_results = {}
    for config_name, config_results in results.items():
        perplexity_logs = []
        for loss_log in config_results:
            # ``assign`` returns a new frame with ``loss`` replaced.
            perplexity = np.exp(loss_log["loss"])
            perplexity_logs.append(loss_log.assign(loss=perplexity))
        perplexity_results[config_name] = perplexity_logs

    perplexity_fig = plot_training_loss(perplexity_results, "Perplexity")
    perplexity_fig.update_layout(title="Perplexity")
    return perplexity_fig


def produce_paper_plots(
    figures: dict,
    model: str,
    variation_dimension: str,
    figure_folder: str,
    show_legend: bool = True,
):
    """Produce the currently enabled paper plots from ``figures``.

    Runs the accuracy plot first and the test-loss plot second, passing the
    same arguments to both.

    Args:
        figures: Figure dict handed to each plot function.
        model: Model name handed to each plot function.
        variation_dimension: Forwarded as ``variation_dimension``.
        figure_folder: Forwarded as ``figure_folder``.
        show_legend: Forwarded as ``show_legend``.
    """
    shared_options = {
        "variation_dimension": variation_dimension,
        "figure_folder": figure_folder,
        "show_legend": show_legend,
    }
    # Currently disabled plots: produce_loss_paper_plot,
    # produce_cum_prob_paper_plot, produce_entropy_paper_plot.
    # They were previously called with the same arguments as the enabled
    # ones; add them to the tuple below to re-enable.
    enabled_plots = (
        produce_accuracy_paper_plot,
        produce_rw_test_loss_paper_plot,
    )
    for produce_plot in enabled_plots:
        produce_plot(figures, model, **shared_options)


def produce_rw_test_loss_paper_plot(
    figures: dict[str, go.Figure],
    model: str,
    variation_dimension: str,
    figure_folder: str,
    show_legend: bool = True,
) -> None:
    """Produce the paper plot for the test-set loss figure.

    Forwards the figure stored under ``"rw_test_loss"`` to
    ``produce_loss_paper_plot``, with the variation dimension prefixed by
    ``rw_test_``.

    Raises:
        KeyError: If ``figures`` has no ``"rw_test_loss"`` entry.
    """
    # The figure is re-keyed as "training_loss" -- presumably the key
    # `produce_loss_paper_plot` looks up; confirm against that helper.
    produce_loss_paper_plot(
        {"training_loss": figures["rw_test_loss"]},
        model,
        f"rw_test_{variation_dimension}",
        figure_folder,
        show_legend,
    )
