import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from lib_llm.eval.memorization.dynamics.plots import (
    plot_discrepancy,
    plot_sequence_error_distribution,
    plot_sequence_prob_distribution,
    plot_string_position_cum_prob,
    plot_training_accuracy,
    plot_training_cum_prob,
    plot_training_entropy,
    plot_training_loss,
)
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook

from ..substrings import KeepSpec, filter_for_substrings
from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as RSExperimentResult


# Alias for the generic project-level ExperimentResult, parametrised with this
# experiment's ExperimentConfig and its own result class (imported above as
# RSExperimentResult to avoid the name clash).
Result = ExperimentResult[ExperimentConfig, RSExperimentResult]


def load(
    config_name: str | list[str],
    seed_ids: list[int],
) -> list[Result]:
    """Load stored results of this experiment via ``load_results``.

    Args:
        config_name: A single config name or a list of config names,
            forwarded unchanged to ``load_results``.
        seed_ids: Seed ids, forwarded unchanged to ``load_results``.

    Returns:
        Whatever ``load_results`` returns for ``EXP_NAME`` and the given
        arguments.
    """
    # The config and result classes are passed along after the selectors.
    result_types = (ExperimentConfig, RSExperimentResult)
    return load_results(EXP_NAME, config_name, seed_ids, *result_types)


def publish(
    notebook: str,
) -> None:
    """Publish one of this experiment's notebooks under an obscured name.

    Args:
        notebook: Name of the notebook (without the ``.ipynb`` extension).
            Must be one of the names listed in the table below.

    Raises:
        ValueError: If ``notebook`` is not a known notebook name.
    """
    # Use a random postfix to make it harder to guess the file name
    published_file_names = {
        "entropy": "entropy_93jdw2slfi.html",
        "placement_order": "placement_order_vkj20ks3.html",
        "substring_length": "substring_length_5932xjdfa.html",
        "num_unique_substrings": "num_unique_substrings_993kjds.html",
        "frac_independent_tokens": "frac_independent_tokens_asoghoa3.html",
    }
    file_name = published_file_names.get(notebook)
    if file_name is None:
        raise ValueError(f"Unknown notebook: {notebook}")

    source_path = (
        f"./experiments/memorability/{EXP_NAME}/notebooks/{notebook}.ipynb"
    )
    target_path = f"experiments/{EXP_NAME}/{file_name}"
    publish_notebook(source_path, target_path)


def show_results(
    results: dict[str, list[Result]],
    title: str,
    show_remaining: bool = False,
    *,
    show_std_dev: bool = False,
) -> None:
    """Display the substrings and the memorization plots for ``results``.

    The config names are first suffixed with an occurrence count (see
    ``_add_num_string_occurrences``) and the substrings of every config are
    listed. Then, for each selected ``keep`` mode, a heading is displayed
    followed by the training-loss and training-accuracy figures built from
    the correspondingly filtered memorization logs.

    Args:
        results: Maps a config name to the results of that config.
        title: Title passed to every plot.
        show_remaining: If true, additionally show the section built with
            ``keep="remaining"``.
        show_std_dev: Forwarded to the plotting functions; defaults to the
            previously hard-coded ``False``.
    """
    results = _add_num_string_occurrences(results)

    _show_substring_data(results)

    # (keep mode, markdown heading) for every section, in display order.
    sections: list[tuple[KeepSpec, str]] = [
        ("all", "### Full memorization dynamics"),
        ("first", "### First occurrence memorization dynamics"),
    ]
    if show_remaining:
        sections.append(
            ("remaining", "### Remaining occurrence memorization dynamics")
        )

    for keep, heading in sections:
        # Filter before showing the heading so that a failure in the
        # filtering step does not leave an orphaned heading behind.
        mem_results = _get_memorization_results(results, keep=keep)
        display(md(heading))
        plot_training_loss(
            mem_results,
            title,
            show_std_dev=show_std_dev,
        ).show()
        plot_training_accuracy(
            mem_results,
            title,
            show_std_dev=show_std_dev,
        ).show()


def _add_num_string_occurrences(
    results: dict[str, list[Result]],
) -> dict[str, list[Result]]:
    """Suffix every config name with an occurrence count, e.g. ``"cfg (4)"``.

    The count is the number of positions recorded for the first substring
    of the first result of the config. The result lists themselves are
    passed through unchanged (same objects).

    Args:
        results: Maps a config name to the results of that config.

    Returns:
        A new dict with relabelled keys and the original result lists.

    Raises:
        IndexError: If a config has no results, or its first result has no
            substrings.
    """
    relabelled = {}
    for config_name, config_results in results.items():
        # Only the first result determines the label, so the remaining
        # results are deliberately not inspected.
        first_substring = config_results[0].value.substrings[0]
        num_occurrences = len(first_substring.positions)
        relabelled[f"{config_name} ({num_occurrences})"] = config_results
    return relabelled


def _get_memorization_results(
    results: dict[str, list[Result]],
    keep: KeepSpec,
) -> dict[str, list[pd.DataFrame]]:
    """Filter the memorization log of every result with ``keep``.

    Args:
        results: Maps a config name to the results of that config.
        keep: Selection mode handed to ``filter_for_substrings``.

    Returns:
        A dict with the same keys as ``results``; each value holds one
        ``filter_for_substrings`` output per result, in the original order.
    """
    filtered: dict[str, list[pd.DataFrame]] = {}
    for name, runs in results.items():
        logs = []
        for run in runs:
            payload = run.value
            logs.append(
                filter_for_substrings(
                    payload.memorization_log,
                    substrings=payload.substrings,
                    keep=keep,
                )
            )
        filtered[name] = logs
    return filtered


def _show_substring_data(
    results: dict[str, list[Result]],
) -> None:
    """Display the substrings of every config as a nested markdown list.

    Only the first result of each config is consulted. Each substring is
    rendered by concatenating its tokens; the rendered substrings are then
    packed into comma-separated rows, one sub-bullet per row.

    Args:
        results: Maps a config name to the results of that config.
    """
    display(md("### Substrings"))

    # A row is closed as soon as it holds at least this many characters.
    max_row_chars = 64
    separator = ",   "

    sections = []
    for name, config_results in results.items():
        rows = []
        pending = []
        pending_chars = 0
        for substring in config_results[0].value.substrings:
            text = "".join(substring.tokens)
            pending.append(text)
            pending_chars += len(text)
            if pending_chars < max_row_chars:
                continue
            rows.append(separator.join(pending))
            pending = []
            pending_chars = 0
        # Flush the last, partially filled row.
        if pending:
            rows.append(separator.join(pending))

        bullets = "\n".join(f"  - {row}" for row in rows)
        sections.append(f"- {name}:\n" + bullets)

    display(md("\n".join(sections)))
