import math
from typing import Callable, Iterator, cast

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from experiments.memorization_dynamics.results import (
    plot_training_accuracy,
    plot_training_loss,
    produce_accuracy_paper_plot,
    produce_cum_prob_paper_plot,
    produce_entropy_paper_plot,
    produce_kld_paper_plot,
    produce_loss_paper_plot,
    show_dynamics,
)
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook
from lib_project.visualization import with_paper_style
from lib_project.visualization.arrange import arrange_figures_in_grid

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as PMDExperimentResult


# Alias for the generic ExperimentResult parameterised with this experiment's
# config class and its result payload class (imported as PMDExperimentResult).
Result = ExperimentResult[ExperimentConfig, PMDExperimentResult]


def load(
    config_name: str | list[str],
    seed_ids: list[int],
) -> list[Result]:
    """Fetch results of this experiment via ``load_results``.

    Args:
        config_name: A single config name or a list of config names,
            forwarded unchanged to ``load_results``.
        seed_ids: Seed identifiers, forwarded unchanged to ``load_results``.

    Returns:
        Whatever ``load_results`` returns for ``EXP_NAME`` and the given
        arguments, typed here as a list of ``Result``.
    """
    # The config and result classes are passed along with the experiment
    # name; presumably they are used to deserialize the stored data.
    loaded = load_results(
        EXP_NAME,
        config_name,
        seed_ids,
        ExperimentConfig,
        PMDExperimentResult,
    )
    return loaded


def publish(
    notebook: str,
) -> None:
    """Publish one of the known analysis notebooks as an HTML file.

    Args:
        notebook: Name of the notebook (without extension). Only
            ``"alphabet_size"`` is currently recognised.

    Raises:
        ValueError: If ``notebook`` is not a recognised notebook name.
    """
    # Guard clause: bail out before building any paths for unknown names.
    if notebook != "alphabet_size":
        raise ValueError(f"Unknown notebook: {notebook}")

    # Use a random postfix to make it harder to guess the file name
    target_html = f"experiments/{EXP_NAME}/{notebook}_592sflie.html"
    source_ipynb = (
        f"./experiments/{EXP_NAME}/dynamics_analysis/{notebook}.ipynb"
    )
    publish_notebook(source_ipynb, target_html)


def convert_to_discrete_training_steps(
    results: dict[str, list[Result]],
) -> dict[str, list[Result]]:
    """Rescale the ``epoch`` index level of every memorization log to ints.

    For each result, the ``epoch`` level of ``value.memorization_log.index``
    is multiplied by ``config.training.args.num_train_epochs`` and cast to
    ``int``. The index is then replaced by a two-level
    ``("epoch", "string")`` MultiIndex.

    The logs are modified in place; the same ``results`` mapping that was
    passed in is returned.

    Args:
        results: Mapping from config name to the results of that config.

    Returns:
        The input mapping, with every memorization log re-indexed.
    """
    for runs in results.values():
        for run in runs:
            total_epochs = run.config.training.args.num_train_epochs
            log = run.value.memorization_log
            old_index = log.index

            scaled_epochs = old_index.get_level_values("epoch") * total_epochs
            # NOTE(review): astype(int) truncates toward zero, so a product
            # such as 28.999999999999996 becomes 28 — confirm truncation
            # (rather than rounding) is the intended behaviour.
            discrete_epochs = scaled_epochs.astype(int)

            # Only the "epoch" and "string" levels are carried over.
            log.index = pd.MultiIndex.from_arrays(
                [discrete_epochs, old_index.get_level_values("string")],
                names=["epoch", "string"],
            )
    return results


def show_name_dynamics(
    results: dict[str, list[Result]],
    title: str,
) -> None:
    """Plot name loss and name accuracy per config and show them in a grid.

    Two figures are built, one with ``plot_training_loss`` on the
    ``"name_loss"`` column and one with ``plot_training_accuracy`` on the
    ``"name_correct"`` column. They are arranged in a two-column grid and
    displayed.

    Args:
        results: Mapping from config name to the results of that config.
        title: Title forwarded to both plotting helpers.
    """
    # Entries that already are DataFrames are used as-is; otherwise the
    # memorization log is pulled out of the result object.
    logs_by_config = {}
    for name, runs in results.items():
        logs_by_config[name] = [
            run if isinstance(run, pd.DataFrame) else run.value.memorization_log
            for run in runs
        ]

    # NOTE(review): this reads `.config` from the first entry of each list,
    # so an empty list raises IndexError and a raw DataFrame entry would
    # presumably fail here despite the isinstance handling above — confirm
    # whether DataFrame inputs are actually supported.
    data_configs = {name: runs[0].config.data for name, runs in results.items()}

    # (figure key, plotting helper, baseline attribute, target column)
    plot_specs = (
        ("name_loss", plot_training_loss, "guess_ce_loss", "name_loss"),
        (
            "name_accuracy",
            plot_training_accuracy,
            "guess_accuracy",
            "name_correct",
        ),
    )
    figures = {}
    for figure_key, plot_fn, baseline_attr, column in plot_specs:
        baselines = {
            name: getattr(data_config, baseline_attr)
            for name, data_config in data_configs.items()
        }
        figures[figure_key] = plot_fn(
            logs_by_config,
            title,
            guess_baselines=baselines,
            target_column=column,
        )

    # One 450px-high row for every two figures in the two-column grid.
    n_rows = math.ceil(len(figures) / 2)
    grid = arrange_figures_in_grid(
        figures,
        n_cols=2,
        size=(1000, n_rows * 450),
    )
    grid.show()


def produce_paper_plots(
    figures: dict,
    model: str,
    variation_dimension: str,
    figure_folder: str,
    show_legend: bool = True,
):
    """Run every paper-plot producer with the same set of arguments.

    The producers are invoked in this order: accuracy, loss, cumulative
    probability, entropy, and random-walk test loss.

    Args:
        figures: Figures handed to each producer.
        model: Model identifier handed to each producer.
        variation_dimension: Forwarded as ``variation_dimension``.
        figure_folder: Forwarded as ``figure_folder``.
        show_legend: Forwarded as ``show_legend``.
    """
    shared_options = {
        "variation_dimension": variation_dimension,
        "figure_folder": figure_folder,
        "show_legend": show_legend,
    }
    producers = (
        produce_accuracy_paper_plot,
        produce_loss_paper_plot,
        produce_cum_prob_paper_plot,
        produce_entropy_paper_plot,
        produce_rw_test_loss_paper_plot,
    )
    for produce_plot in producers:
        produce_plot(figures, model, **shared_options)


def produce_rw_test_loss_paper_plot(
    figures: dict[str, go.Figure],
    model: str,
    variation_dimension: str,
    figure_folder: str,
    show_legend: bool = True,
) -> None:
    """Produce the loss paper plot from the ``"rw_test_loss"`` figure.

    The ``"rw_test_loss"`` figure is re-keyed as ``"training_loss"`` so that
    ``produce_loss_paper_plot`` can be reused, and the variation dimension
    is prefixed with ``rw_test_``.

    Args:
        figures: Mapping of figure names to figures; must contain the key
            ``"rw_test_loss"``.
        model: Model identifier forwarded to ``produce_loss_paper_plot``.
        variation_dimension: Dimension name; ``rw_test_`` is prepended
            before forwarding.
        figure_folder: Forwarded to ``produce_loss_paper_plot``.
        show_legend: Forwarded to ``produce_loss_paper_plot``.

    Raises:
        KeyError: If ``figures`` has no ``"rw_test_loss"`` entry.
    """
    # Present the test-loss figure under the key the loss producer is
    # called with here.
    loss_only = {"training_loss": figures["rw_test_loss"]}
    prefixed_dimension = f"rw_test_{variation_dimension}"
    produce_loss_paper_plot(
        loss_only,
        model,
        prefixed_dimension,
        figure_folder,
        show_legend,
    )
