from typing import cast

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from lib_llm.eval.memorization.dynamics.plots import (
    plot_string_position_cum_prob,
    plot_training_accuracy,
    plot_training_cum_prob,
    plot_training_entropy,
    plot_training_loss,
)
from lib_project.analysis.aggregate import add_mean_std_dev_trace
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook
from utils import memorization_order as mem_order

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as RSExperimentResult


# Fully parameterized result type used throughout this module: the generic
# ExperimentResult specialized with this experiment's config type and its
# experiment-specific result type (imported here as RSExperimentResult).
Result = ExperimentResult[ExperimentConfig, RSExperimentResult]


def load(
    config_name: str | list[str],
    seed_ids: list[int],
) -> list[Result]:
    """Load the stored results of this experiment.

    Thin wrapper around ``load_results`` that pins the experiment name and
    the config/result types of this experiment.

    Args:
        config_name: A single config name or a list of config names,
            forwarded unchanged to ``load_results``.
        seed_ids: The seed ids to load, forwarded unchanged to
            ``load_results``.

    Returns:
        The value returned by ``load_results`` for ``EXP_NAME``.
    """
    loaded_results = load_results(
        EXP_NAME,
        config_name,
        seed_ids,
        ExperimentConfig,
        RSExperimentResult,
    )
    return loaded_results


def publish(
    notebook: str,
) -> None:
    """Publish one of this experiment's notebooks via ``publish_notebook``.

    Args:
        notebook: Name of the notebook (without the ``.ipynb`` extension).
            Only ``"num_repetitions"`` is currently known.

    Raises:
        ValueError: If ``notebook`` is not a known notebook name.
    """
    # Reject unknown names up front, before any path is built.
    if notebook != "num_repetitions":
        raise ValueError(f"Unknown notebook: {notebook}")

    # A random postfix makes the published file name harder to guess.
    destination = f"experiments/{EXP_NAME}/num_repetitions_4923kl2s.html"
    source = (
        f"./experiments/memorability/{EXP_NAME}/notebooks/{notebook}.ipynb"
    )
    publish_notebook(source, destination)


def show_results(
    results: dict[int, list[Result]],
    title: str,
    *,
    show_std_dev: bool = False,
) -> None:
    """Render the training-loss and memorization-epoch figures.

    Args:
        results: Maps each configuration key to the list of results (one per
            seed) obtained for that configuration.
        title: Passed as the title argument to both plotting helpers.
        show_std_dev: Forwarded to both plotting helpers as ``show_std_dev``.
            Defaults to ``False``, which reproduces the previous hard-coded
            behavior.
    """
    # The loss plot is keyed by string config names and only needs the
    # memorization log of every result.
    memorization_results = {
        str(config_name): [res.value.memorization_log for res in config_results]
        for config_name, config_results in results.items()
    }

    loss_fig = plot_training_loss(
        memorization_results,
        title,
        show_std_dev=show_std_dev,
    )
    loss_fig.show()

    # NOTE: the accuracy plot is currently disabled.
    # accuracy_fig = plot_training_accuracy(
    #     memorization_results,
    #     title,
    #     show_std_dev=show_std_dev,
    # )
    # accuracy_fig.show()

    epoch_fig = plot_memorization_epochs(
        results,
        title,
        show_std_dev=show_std_dev,
    )
    epoch_fig.show()


def plot_memorization_epochs(
    results: dict[int, list[Result]],
    title: str,
    show_std_dev: bool = False,
) -> go.Figure:
    """Plot the seed-averaged memorization epochs for every configuration.

    For each configuration, ``compute_memorization_epochs`` is evaluated on
    every seed result; the per-seed values of each metric are then reduced
    with ``np.mean`` and ``np.std``. Each metric becomes one trace whose
    x-values are the keys of ``results``.

    Args:
        results: Maps each configuration key to its per-seed results.
        title: Used as the legend title of the figure.
        show_std_dev: If true, the per-metric standard deviations are passed
            to ``add_mean_std_dev_trace``; otherwise ``None`` is passed.

    Returns:
        The figure with one trace per metric.
    """
    mean_columns = {}
    std_columns = {}
    for seed_runs in results.values():
        # Collect, per metric label, the values obtained across seeds.
        per_seed_values = {}
        for seed_run in seed_runs:
            seed_epochs = compute_memorization_epochs(seed_run)
            for label, value in seed_epochs.items():
                per_seed_values.setdefault(label, []).append(value)
        # Reduce across seeds and append one entry per configuration.
        for label, values in per_seed_values.items():
            mean_columns.setdefault(label, []).append(np.mean(values))
            std_columns.setdefault(label, []).append(np.std(values))

    config_keys = list(results)
    mean_frame = pd.DataFrame(mean_columns, index=config_keys)
    std_frame = pd.DataFrame(std_columns, index=config_keys)

    fig = go.Figure()
    for trace_idx, mem_type in enumerate(mean_frame.columns):
        add_mean_std_dev_trace(
            fig,
            trace_idx,
            mean_frame[mem_type],
            std_frame[mem_type] if show_std_dev else None,
            name=cast(str, mem_type),
        )

    fig.update_layout(
        title="Memorization Epochs",
        xaxis_title="Num Clean Repetitions (n)",
        yaxis_title="Memorization Epoch",
        legend_title=title,
    )
    return fig


def compute_memorization_epochs(
    result: Result,
) -> dict[str, float]:
    """Summarize the memorization epochs of a single result.

    For both the "initial" and the "stable" memorization order, the mean
    epoch is reported over three selections: all entries, the first
    ``result.config.data.num_tokens`` entries, and the entries at the first
    position of every conflicting token.

    Args:
        result: The experiment result to summarize.

    Returns:
        A dict with six entries (three per memorization type) mapping a
        descriptive label to the corresponding mean.
    """
    log = result.value.memorization_log
    first_copy_length = result.config.data.num_tokens
    conflict_idx = [
        token.positions[0] for token in result.value.conflicting_tokens
    ]

    order_functions = {
        "initial": mem_order.initial_memorization_order,
        "stable": mem_order.stable_memorization_order,
    }

    summary = {}
    for kind, order_fn in order_functions.items():
        # Assumes the order functions return a pandas Series (``.iloc`` and
        # ``.mean()`` are used below) -- TODO confirm.
        epochs = order_fn(log)
        summary.update(
            {
                f"overall ({kind} epoch)": epochs.mean(),
                f"First copy ({kind} epoch)": (
                    epochs.iloc[:first_copy_length].mean()
                ),
                f"Conflicting tokens ({kind} epoch)": (
                    epochs.iloc[conflict_idx].mean()
                ),
            }
        )
    return summary
