from dataclasses import replace

import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from lib_llm.eval.memorization.dynamics.plots import (
    plot_discrepancy,
    plot_sequence_error_distribution,
    plot_sequence_prob_distribution,
    plot_string_position_cum_prob,
    plot_training_accuracy,
    plot_training_cum_prob,
    plot_training_entropy,
    plot_training_loss,
)
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as MDExperimentResult


# Result type for this experiment: the project's generic ExperimentResult
# specialized to this experiment's config class (ExperimentConfig) and its
# result payload class (MDExperimentResult, reached via `.value` below).
Result = ExperimentResult[ExperimentConfig, MDExperimentResult]


def load(
    config_name: str | list[str],
    seed_ids: list[int],
) -> list[Result]:
    """Load stored results of this experiment for the given configs and seeds.

    Thin wrapper around ``load_results`` that fixes the experiment name to
    ``EXP_NAME``. It also passes this experiment's config and result classes,
    presumably so the stored runs can be rebuilt as typed objects. Confirm
    against ``load_results``.

    Args:
        config_name: One config name or a list of config names to load.
        seed_ids: Seed ids to load for each config.

    Returns:
        Whatever ``load_results`` returns, annotated as a list of ``Result``.
    """
    experiment_types = (ExperimentConfig, MDExperimentResult)
    return load_results(EXP_NAME, config_name, seed_ids, *experiment_types)


def publish(
    notebook: str,
) -> None:
    """Publish one of this experiment's notebooks via ``publish_notebook``.

    The source is ``./experiments/<EXP_NAME>/notebooks/<notebook>.ipynb``. The
    destination is a fixed ``.html`` path under ``experiments/<EXP_NAME>/``.

    Args:
        notebook: Notebook stem. Only ``"icl_at_checkpoints"`` is accepted.

    Raises:
        ValueError: If ``notebook`` is not a known, publishable notebook.
    """
    if notebook != "icl_at_checkpoints":
        raise ValueError(f"Unknown notebook: {notebook}")
    # The published file name carries a random postfix so it is hard to guess.
    html_path = f"experiments/{EXP_NAME}/icl_at_checkpoints_5j3chfu3.html"
    source_path = f"./experiments/{EXP_NAME}/notebooks/{notebook}.ipynb"
    publish_notebook(source_path, html_path)


def split_into_string_results(
    result: Result,
) -> dict[str, list[Result]]:
    """Split one run into per-string copies, keyed by a readable label.

    The run's memorization log is partitioned on its ``"string"`` index level.
    Each partition is placed into a shallow copy of ``result``, made with
    ``dataclasses.replace``, so the pieces can be passed wherever a
    ``dict[str, list[Result]]`` is expected (e.g. ``show_results_overview``).

    Args:
        result: A loaded run whose ``value.memorization_log`` has an index
            level named ``"string"``.

    Returns:
        Mapping from label to a one-element list holding the filtered copy of
        ``result``. Keys appear in the order of ids 0, 1, 2, 3. Rows whose
        ``"string"`` value is not one of these ids are not included anywhere.
    """
    # "string" index-level id -> label used as the result key.
    string_labels = {
        0: "7 char training string",
        1: "7 char alternative string",
        2: "2 char string",
        3: "26 char string",
    }

    memorization_log = result.value.memorization_log
    string_level = memorization_log.index.get_level_values("string")

    split: dict[str, list[Result]] = {}
    for string_id, label in string_labels.items():
        # Boolean mask over the rows, keeping only this string's entries.
        string_log = memorization_log.loc[string_level == string_id]
        string_value = replace(result.value, memorization_log=string_log)
        split[label] = [replace(result, value=string_value)]
    return split


def show_results_overview(
    results: dict[str, list[Result]],
    title: str,
    show_loss: bool = True,
    show_accuracy: bool = True,
    show_cum_prob: bool = True,
    show_entropy: bool = True,
    show_discrepancy: bool = True,
    show_in_context_learning: bool = True,
    icl_epoch: int = 0,
    discrepancy_seed: int = 4019,
) -> dict[str, go.Figure]:
    """Plot, display and return the standard overview figures for some runs.

    Args:
        results: Mapping from a config label to the runs for that config. Each
            run's ``value.memorization_log`` is what gets plotted.
        title: Title passed to every plotting function.
        show_loss: Whether to plot the training loss.
        show_accuracy: Whether to plot the training accuracy.
        show_cum_prob: Whether to plot the training cumulative probability.
        show_entropy: Whether to plot the training entropy.
        show_discrepancy: Whether to plot the discrepancy.
        show_in_context_learning: Whether to plot the per-string-position
            cumulative probability.
        icl_epoch: Passed as ``epoch`` to ``plot_string_position_cum_prob``.
        discrepancy_seed: Passed as ``seed`` to ``plot_discrepancy``. The
            default is the value that was previously hard-coded.

    Returns:
        Mapping from figure key (e.g. ``"training_loss"``) to figure, for every
        enabled plot. Each figure has already been shown with ``.show()``.
    """
    # Config label -> memorization logs, one per run of that config.
    memorization_results = {
        config_name: [
            config_res.value.memorization_log for config_res in config_results
        ]
        for config_name, config_results in results.items()
    }

    # (enabled, figure key, zero-argument factory). Factories are lazy so
    # disabled plots are never computed; list order fixes the display order.
    plot_specs = [
        (
            show_loss,
            "training_loss",
            lambda: plot_training_loss(memorization_results, title),
        ),
        (
            show_accuracy,
            "training_accuracy",
            lambda: plot_training_accuracy(memorization_results, title),
        ),
        (
            show_cum_prob,
            "training_cum_prob",
            lambda: plot_training_cum_prob(memorization_results, title),
        ),
        (
            show_entropy,
            "training_entropy",
            lambda: plot_training_entropy(memorization_results, title),
        ),
        (
            show_discrepancy,
            "discrepancy",
            lambda: plot_discrepancy(
                memorization_results, title, seed=discrepancy_seed
            ),
        ),
        (
            show_in_context_learning,
            "in_context_learning",
            lambda: plot_string_position_cum_prob(
                memorization_results,
                title,
                epoch=icl_epoch,
            ),
        ),
    ]

    figures: dict[str, go.Figure] = {}
    for enabled, key, make_figure in plot_specs:
        if not enabled:
            continue
        fig = make_figure()
        fig.show()
        figures[key] = fig
    return figures
