import pandas as pd
import plotly.graph_objects as go
from IPython.display import Markdown as md

from lib_llm.eval.memorization.dynamics.plots import (
    plot_string_position_cum_prob,
)
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as ICEExperimentResult


# Type alias for one loaded result of this experiment: the generic
# ExperimentResult parameterized with this experiment's config class and its
# result-value class (imported locally as ICEExperimentResult).
Result = ExperimentResult[ExperimentConfig, ICEExperimentResult]


def load(
    config_name: str,
    seed_ids: list[int],
) -> list[Result]:
    """Load this experiment's stored results for one configuration.

    Thin wrapper around ``load_results`` that fills in the experiment name and
    the config/result classes specific to this experiment.

    Args:
        config_name: Name of the experiment configuration to load.
        seed_ids: Seeds whose results should be loaded.

    Returns:
        The value returned by ``load_results`` for these arguments.
    """
    config_cls = ExperimentConfig
    value_cls = ICEExperimentResult
    return load_results(EXP_NAME, config_name, seed_ids, config_cls, value_cls)


def publish(
    notebook: str,
) -> None:
    """Render one of this experiment's notebooks to its published HTML path.

    Args:
        notebook: Stem of a notebook under
            ``./experiments/<EXP_NAME>/notebooks/``. Only
            ``"pretraining_steps"`` is currently supported.

    Raises:
        ValueError: If ``notebook`` has no configured output file.
    """
    if notebook != "pretraining_steps":
        raise ValueError(f"Unknown notebook: {notebook}")
    # The output name carries a fixed random-looking suffix so the published
    # file is harder to guess (obscurity only, not access control).
    target = f"experiments/{EXP_NAME}/pretraining_steps_9384fj3.html"
    source = f"./experiments/{EXP_NAME}/notebooks/{notebook}.ipynb"
    publish_notebook(source, target)


def show_results(
    results: list[Result],
    title: str,
) -> dict[str, go.Figure]:
    """Plot and display the in-context-learning curve for the first result.

    Only ``results[0]`` is plotted; any further entries are ignored. The
    figure is shown immediately via ``Figure.show()`` and also returned.

    Args:
        results: Loaded results (e.g. from ``load``). Must be non-empty.
        title: Title passed through to the plotting function.

    Returns:
        Mapping from figure name to figure. Currently contains only
        ``"in_context_learning"``.

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        raise ValueError("show_results() requires at least one result")

    performance = results[0].value.in_context_performance
    # NOTE(review): ``.loc[step]`` selects on the outermost index level, so
    # this assumes "step" is the first level of the index — confirm.
    steps = performance.index.get_level_values("step").unique()
    step_results = {
        step: [
            # Prepend an "epoch" *index level* (fixed to 0) to the per-step
            # frame; the plot below is called with ``epoch=0`` to select it.
            pd.concat([performance.loc[step]], keys=[0], names=["epoch"])
        ]
        for step in steps
    }

    in_context_learning_fig = plot_string_position_cum_prob(
        step_results,
        title,
        epoch=0,
    )
    in_context_learning_fig.show()
    return {"in_context_learning": in_context_learning_fig}
