from typing import cast

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from lib_llm.eval.memorization.dynamics.plots import (
    compute_mean_std_dev,
    plot_discrepancy,
    plot_scalar_curves,
    plot_sequence_error_distribution,
    plot_sequence_prob_distribution,
    plot_string_position_cum_prob,
    plot_training_accuracy,
    plot_training_cum_prob,
    plot_training_entropy,
    plot_training_loss,
)
from lib_llm.eval.memorization.dynamics.utils import (
    get_max_epoch,
    reindex_positionwise,
)
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as RSExperimentResult


# Alias for the project-level generic ``ExperimentResult`` specialised with
# this experiment's ``ExperimentConfig`` and its own result class (imported
# here as ``RSExperimentResult``). Used in the annotations of ``load`` and
# ``show_results``.
Result = ExperimentResult[ExperimentConfig, RSExperimentResult]


def load(
    config_name: str | list[str],
    seed_ids: list[int],
) -> list[Result]:
    """Load stored results for this experiment.

    Thin wrapper around ``load_results`` that pins the experiment name
    (``EXP_NAME``) and the config/result classes, so callers only supply
    the arguments that vary.

    Args:
        config_name: A config name or a list of config names; forwarded
            unchanged to ``load_results``.
        seed_ids: Seed identifiers; forwarded unchanged to ``load_results``.

    Returns:
        Whatever ``load_results`` returns for these arguments.
    """
    # Classes that ``load_results`` receives after the selection arguments.
    result_types = (ExperimentConfig, RSExperimentResult)
    loaded = load_results(EXP_NAME, config_name, seed_ids, *result_types)
    return loaded


def publish(
    notebook: str,
) -> None:
    """Publish one of this experiment's notebooks as an HTML file.

    Args:
        notebook: Name of the notebook to publish. Only ``"entropy"`` is
            recognised.

    Raises:
        ValueError: If ``notebook`` is not a recognised notebook name.
    """
    # Reject unknown names up front, before any path is built.
    if notebook != "entropy":
        raise ValueError(f"Unknown notebook: {notebook}")

    # The output file name carries a random postfix so that the published
    # file is harder to guess.
    output_path = f"experiments/{EXP_NAME}/entropy_93jdw2slfi.html"
    notebook_path = (
        f"./experiments/memorability/{EXP_NAME}/notebooks/{notebook}.ipynb"
    )
    publish_notebook(notebook_path, output_path)


def show_results(
    results: dict[str, list[Result]],
    title: str,
    *,
    show_std_dev: bool = False,
) -> None:
    """Display the training-loss and ICL-metric figures for loaded results.

    Args:
        results: Maps a config name to the list of results loaded for that
            config. Only ``res.value.memorization_log`` of each result is
            used.
        title: Forwarded to both plotting functions (``plot_icl_metric``
            uses it as the legend title).
        show_std_dev: Forwarded to both plotting functions. Keyword-only;
            the default ``False`` matches the value that used to be
            hard-coded here, so existing callers are unaffected.
    """
    # Reduce every result to its memorization log, keeping the grouping by
    # config name that the plotting helpers expect.
    memorization_results = {
        config_name: [res.value.memorization_log for res in config_results]
        for config_name, config_results in results.items()
    }

    loss_fig = plot_training_loss(
        memorization_results,
        title,
        show_std_dev=show_std_dev,
    )
    loss_fig.show()

    # The accuracy plot is currently disabled. To re-enable it:
    # accuracy_fig = plot_training_accuracy(
    #     memorization_results,
    #     title,
    #     show_std_dev=show_std_dev,
    # )
    # accuracy_fig.show()

    icl_fig = plot_icl_metric(
        memorization_results,
        title,
        show_std_dev=show_std_dev,
    )
    icl_fig.show()


def plot_icl_metric(
    mem_results: dict[str, list[pd.DataFrame]],
    title: str,
    show_std_dev: bool = False,
) -> go.Figure:
    """Plot the ICL metric (first-half minus second-half loss) per config.

    Every frame in ``mem_results`` is converted with ``icl_metric``; the
    ``"difference"`` column of each converted frame is then aggregated by
    ``compute_mean_std_dev`` and drawn with ``plot_scalar_curves``.

    Args:
        mem_results: Maps a config name to a list of memorization-log frames.
        title: Used as the legend title of the figure (the figure title
            itself is fixed).
        show_std_dev: Forwarded to ``compute_mean_std_dev`` as
            ``compute_std_dev``.

    Returns:
        The figure produced by ``plot_scalar_curves``.
    """

    def _difference_column(res: pd.DataFrame) -> pd.Series:
        # Select the only column of the ICL frame that gets plotted.
        return cast(pd.Series, res["difference"])

    icl_results: dict[str, list[pd.DataFrame]] = {}
    for config_name, config_results in mem_results.items():
        icl_results[config_name] = [icl_metric(res) for res in config_results]

    mean_results, std_dev_results = compute_mean_std_dev(
        icl_results,
        preprocess=_difference_column,
        compute_std_dev=show_std_dev,
    )

    figure = plot_scalar_curves(
        mean_results,
        std_dev_results,
        title="ICL: Loss on first half - Loss on second half of the string",
        xaxis_title="Epoch",
        yaxis_title="Loss Difference",
        legend_title=title,
    )
    return figure


def icl_metric(
    mem_res: pd.DataFrame,
) -> pd.DataFrame:
    """Compute per-epoch mean loss on each half of the string.

    Rows are split by their ``"token_index"`` (taken from the index of the
    positionwise-reindexed frame): positions ``<= max / 2`` form the first
    half, positions ``> max / 2`` the second half. The ``"loss"`` column is
    then averaged per ``"epoch"`` within each half.

    Args:
        mem_res: Memorization log with a ``"loss"`` column that can be
            grouped by ``"epoch"``.

    Returns:
        A frame with the columns ``"first_half"``, ``"second_half"`` and
        ``"difference"`` (first half minus second half).
    """
    positionwise = reindex_positionwise(mem_res)
    token_positions = positionwise.index.get_level_values("token_index")
    midpoint = cast(int, token_positions.max()) / 2

    def _mean_loss_per_epoch(row_mask) -> pd.Series:
        # NOTE(review): the mask comes from the reindexed frame but is
        # applied to ``mem_res``; this presumably relies on
        # ``reindex_positionwise`` preserving row count and order - confirm.
        return mem_res.loc[row_mask, "loss"].groupby("epoch").mean()

    first_half_loss = _mean_loss_per_epoch(token_positions <= midpoint)
    second_half_loss = _mean_loss_per_epoch(token_positions > midpoint)

    columns = {
        "first_half": first_half_loss,
        "second_half": second_half_loss,
        "difference": first_half_loss - second_half_loss,
    }
    return pd.DataFrame(columns)
