from typing import cast

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Markdown as md

from defs import BASE_FIGURE_DIR
from experiments.memorization_dynamics.results import Result as MDResult
from lib_llm.eval.memorization.dynamics.plots import (
    plot_training_accuracy,
    plot_training_loss,
)
from lib_llm.eval.memorization.prefix_mappings import (
    plot_epoch_prefix_performance,
    plot_prefix_length_performance,
)
from lib_project.analysis.aggregate import COLOR_SEQUENCE
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook
from lib_project.visualization import with_paper_style
from lib_project.visualization.arrange import arrange_figures_in_grid

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as PMExperimentResult


# Type alias for a stored result of this experiment: the generic
# `ExperimentResult` specialized with this experiment's config type and
# result-value type. It is the element type returned by `load`.
Result = ExperimentResult[ExperimentConfig, PMExperimentResult]


def load(
    config_name: str,
    seed_ids: list[int],
) -> list[Result]:
    """Load the stored results of this experiment for one configuration.

    Args:
        config_name: Name of the experiment configuration to load.
        seed_ids: Seeds whose results should be loaded.

    Returns:
        Whatever ``load_results`` returns for ``EXP_NAME`` when given
        ``ExperimentConfig`` and ``PMExperimentResult`` (presumably the
        types the stored config and result value are loaded as).
    """
    config_type, value_type = ExperimentConfig, PMExperimentResult
    loaded = load_results(
        EXP_NAME, config_name, seed_ids, config_type, value_type
    )
    return loaded


def publish(
    notebook: str,
) -> None:
    """Publish one of this experiment's notebooks as an HTML page.

    The notebook is read from ``./experiments/<EXP_NAME>/<notebook>.ipynb``
    and handed to ``publish_notebook`` together with its output path.

    Args:
        notebook: Notebook path relative to the experiment folder, without
            the ``.ipynb`` suffix (e.g. ``"prefix_length/entropy_level"``).

    Raises:
        ValueError: If ``notebook`` is not one of the known notebooks.
    """
    # Use a random postfix to make it harder to guess the file name.
    # A single lookup table (instead of an if/elif chain) guarantees that
    # every notebook maps to exactly one output file. The previous chain
    # restarted with a plain `if`, so "prefix_length/alphabet_size" always
    # fell through to the "Unknown notebook" error.
    output_file_names = {
        "prefix_length/alphabet_size": "alphabet_size_fj39g402.html",
        "prefix_length/untrained_alphabet_size": "alphabet_size_a3h3af8.html",
        "prefix_length/entropy_level": "entropy_level_5902fs.html",
        "prefix_length/size_change": "size_change_5932fops.html",
        "prefix_length/replacement_strategy": (
            "replacement_strategy_29gyws.html"
        ),
        "prefix_length/conditional_probabilities": (
            "conditional_probability_592sa.html"
        ),
        # The prefix_identity notebooks keep their sub-folder in the
        # published path.
        "prefix_identity/random_strings": (
            "prefix_identity/random_strings_shw24ld0.html"
        ),
        "prefix_identity/deterministic_rule_strings": (
            "prefix_identity/deterministic_rule_strings_43jd92j.html"
        ),
    }
    try:
        output_file_name = output_file_names[notebook]
    except KeyError:
        raise ValueError(f"Unknown notebook: {notebook}") from None

    output_path = f"experiments/{EXP_NAME}/{output_file_name}"
    notebook_path = f"./experiments/{EXP_NAME}/{notebook}.ipynb"
    publish_notebook(
        notebook_path,
        output_path,
    )


def show_epoch_prefix_lengths(
    results: dict[str, list[Result]],
    *,
    prefix_lengths: tuple[int, ...] = (1, 2, 4, 8, 16, 32, 64, 128, 256),
) -> dict[str, go.Figure]:
    """Plot accuracy over epochs for several prefix lengths, per config.

    For every configuration in ``results``, one figure is built with
    ``plot_epoch_prefix_performance`` (with ``show_full_prefix`` and
    ``show_std_dev`` enabled). All figures are then shown together in a
    two-column grid.

    Args:
        results: Mapping from configuration name to that configuration's
            results (e.g. one result per seed, as returned by ``load``).
            The ``prefix_mappings`` of each result's ``value`` are plotted.
        prefix_lengths: Prefix lengths to plot. Defaults to the powers of
            two from 1 to 256.

    Returns:
        The individual per-configuration figures, keyed by config name.
    """
    figures: dict[str, go.Figure] = {}
    for config_name, config_results in results.items():
        prefix_mappings = [res.value.prefix_mappings for res in config_results]
        fig = plot_epoch_prefix_performance(
            prefix_mappings,
            # Pass a fresh list so the plotting helper receives the same
            # type as before and cannot mutate the caller's sequence.
            list(prefix_lengths),
            show_full_prefix=True,
            show_std_dev=True,
        )
        fig.update_layout(
            title=f"Prefix Lengths: {config_name}",
            yaxis_title="Accuracy",
            xaxis_title="Epoch",
            width=800,
            height=600,
        )
        figures[config_name] = fig

    combined_fig = arrange_figures_in_grid(
        figures,
        n_cols=2,
        size=(900, 1200),
    )
    combined_fig.show()

    return figures


def show_prefix_length_performance(
    results: dict[str, list[Result]],
    legend_title: str,
    epoch: int = -1,
) -> go.Figure:
    """Plot performance as a function of prefix length for each config.

    Args:
        results: Mapping from configuration name to that configuration's
            results. The ``prefix_mappings`` of each result's ``value`` are
            what gets plotted.
        legend_title: Legend title forwarded to
            ``plot_prefix_length_performance``.
        epoch: Epoch forwarded to ``plot_prefix_length_performance``.
            Defaults to ``-1`` (presumably the last epoch).

    Returns:
        The figure, after it has been displayed.
    """
    prefix_performances = {
        config_name: [res.value.prefix_mappings for res in config_results]
        for config_name, config_results in results.items()
    }
    prefix_performance_fig = plot_prefix_length_performance(
        prefix_performances,
        legend_title,
        epoch=epoch,
    )
    prefix_performance_fig.show()

    return prefix_performance_fig


def show_memorization_dynamics(
    results: dict[str, list[MDResult]],
    title: str,
) -> None:
    """Show training loss and training accuracy side by side.

    Args:
        results: Mapping from configuration name to a list of
            memorization-dynamics results. The ``memorization_log`` of each
            result's ``value`` is what gets plotted.
        title: Title forwarded to both the loss plot and the accuracy plot.
    """
    logs_by_config = {}
    for name, runs in results.items():
        logs_by_config[name] = [run.value.memorization_log for run in runs]

    loss_fig = plot_training_loss(logs_by_config, title)
    accuracy_fig = plot_training_accuracy(logs_by_config, title)
    grid = arrange_figures_in_grid(
        {"Training loss": loss_fig, "Training accuracy": accuracy_fig},
        n_cols=2,
        size=(900, 400),
    )
    grid.show()


# Default sub-folder of BASE_FIGURE_DIR for the paper figures; used as the
# default `figure_folder` of the `produce_*_paper_plot` functions.
FIGURE_FOLDER = "prefix_mappings"


def produce_epoch_prefix_paper_plot(
    figure: go.Figure,
    model: str,
    variation_dimension: str,
    figure_folder: str = FIGURE_FOLDER,
    show_legend: bool = True,
) -> None:
    """Restyle an epoch-vs-accuracy figure for the paper, show it and save it.

    The figure is saved as
    ``BASE_FIGURE_DIR/<figure_folder>/epochs_<variation_dimension>_<model>.pdf``.
    The target folder is created if it does not exist yet.

    Args:
        figure: Figure to restyle (e.g. one returned by
            ``show_epoch_prefix_lengths``).
        model: Model name, used in the file name.
        variation_dimension: Name of the varied dimension, used in the
            file name.
        figure_folder: Sub-folder of ``BASE_FIGURE_DIR`` to save into.
        show_legend: If False, the legend is hidden.
    """
    paper_fig = with_paper_style(
        figure,
        legend_pos=(0.98, 1),
        legend_xanchor="right",
        legend_yanchor="top",
        legend_orientation="v",
    )
    # Slightly wider than [0, 1] so curves at the extremes are not clipped
    # (assumes the y-axis is an accuracy in [0, 1]).
    paper_fig.update_yaxes(range=[-0.05, 1.05])
    if not show_legend:
        paper_fig.update_layout(showlegend=False)

    paper_fig.show()
    save_path = (
        BASE_FIGURE_DIR
        / f"{figure_folder}/epochs_{variation_dimension}_{model}.pdf"
    )
    # Plotly's write_image does not create missing directories, so make sure
    # the (possibly new) figure folder exists before exporting.
    # NOTE(review): assumes BASE_FIGURE_DIR is a pathlib.Path.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    paper_fig.write_image(str(save_path))
    print("Saved figure to", save_path)


def produce_prefix_length_paper_plot(
    figure: go.Figure,
    model: str,
    variation_dimension: str,
    figure_folder: str = FIGURE_FOLDER,
    show_legend: bool = True,
) -> None:
    """Restyle a prefix-length figure for the paper, show it and save it.

    The figure is resized to 800x600 and saved as
    ``BASE_FIGURE_DIR/<figure_folder>/<variation_dimension>_<model>.pdf``.
    The target folder is created if it does not exist yet.

    Args:
        figure: Figure to restyle (e.g. one returned by
            ``show_prefix_length_performance``).
        model: Model name, used in the file name.
        variation_dimension: Name of the varied dimension, used in the
            file name.
        figure_folder: Sub-folder of ``BASE_FIGURE_DIR`` to save into.
        show_legend: If False, the legend is hidden. This mirrors the
            option of ``produce_epoch_prefix_paper_plot``.
    """
    paper_fig = with_paper_style(
        figure,
        legend_pos=(1, 0),
        legend_xanchor="right",
        legend_yanchor="bottom",
        legend_orientation="h",
    )
    paper_fig.update_layout(
        width=800,
        height=600,
    )
    # Slightly wider than [0, 1] so curves at the extremes are not clipped
    # (assumes the y-axis is an accuracy in [0, 1]).
    paper_fig.update_yaxes(range=[-0.05, 1.05])
    if not show_legend:
        paper_fig.update_layout(showlegend=False)

    paper_fig.show()
    save_path = (
        BASE_FIGURE_DIR / f"{figure_folder}/{variation_dimension}_{model}.pdf"
    )
    # Plotly's write_image does not create missing directories, so make sure
    # the (possibly new) figure folder exists before exporting.
    # NOTE(review): assumes BASE_FIGURE_DIR is a pathlib.Path.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    paper_fig.write_image(str(save_path))
    print("Saved figure to", save_path)
