import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from lib_llm.eval.memorization.dynamics.plots import (
    plot_discrepancy,
    plot_sequence_error_distribution,
    plot_sequence_prob_distribution,
    plot_string_position_cum_prob,
    plot_training_accuracy,
    plot_training_cum_prob,
    plot_training_entropy,
    plot_training_loss,
)
from lib_llm.eval.memorization.dynamics.utils import (
    get_string_tokens,
    reindex_positionwise,
)
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook
from utils import memorization_order as mem_order

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as RSExperimentResult


# Type of one loaded result of this experiment: the generic project-level
# ExperimentResult specialized with this experiment's config class and its
# own result class (imported above as RSExperimentResult).
Result = ExperimentResult[ExperimentConfig, RSExperimentResult]


def load(
    config_name: str | list[str],
    seed_ids: list[int],
) -> list[Result]:
    """Load stored results of this experiment via ``load_results``.

    Args:
        config_name: A single configuration name, or a list of them.
        seed_ids: The seeds whose results should be loaded.

    Returns:
        Whatever ``load_results`` returns for ``EXP_NAME`` and the given
        configuration(s) and seeds — annotated as a list of ``Result``.
    """
    # The config and result classes are handed to ``load_results`` —
    # presumably so it can rebuild typed objects; verify in lib_project.
    typed_classes = (ExperimentConfig, RSExperimentResult)
    return load_results(EXP_NAME, config_name, seed_ids, *typed_classes)


def publish(
    notebook: str,
) -> None:
    """Publish one of this experiment's notebooks as an HTML file.

    Args:
        notebook: Either ``"frequency"`` or ``"prefix_length"``; selects
            both the source ``.ipynb`` and the output HTML file name.

    Raises:
        ValueError: If ``notebook`` is not one of the known names.
    """
    # Each output name carries a random postfix to make it harder to guess.
    html_file_by_notebook = {
        "frequency": "frequency_k32f9fs9s.html",
        "prefix_length": "prefix_length_t2i23fa.html",
    }
    if notebook not in html_file_by_notebook:
        raise ValueError(f"Unknown notebook: {notebook}")

    source = (
        f"./experiments/memorability/{EXP_NAME}/notebooks/{notebook}.ipynb"
    )
    target = f"experiments/{EXP_NAME}/{html_file_by_notebook[notebook]}"
    publish_notebook(source, target)


def show_results(
    results: list[Result],
    order_type: str,
    order_dimension: str,
) -> None:
    """Display the memorization order of one result, grouped by token.

    Shows a Markdown header, a per-position scatter plot and a per-token box
    plot of the memorization order.

    Args:
        results: A list containing exactly one loaded result.
        order_type: ``"initial"`` or ``"stable"``; selects which function
            from ``mem_order`` is applied to the result's memorization log.
        order_dimension: How token labels are annotated, ``"frequency"`` or
            ``"prefix_length"`` (an unknown value is rejected later, when the
            token positions are annotated).

    Raises:
        ValueError: If ``results`` does not hold exactly one element, or if
            ``order_type`` is unknown.
    """
    # Validate with real exceptions (``assert`` is stripped under -O) and
    # before anything is displayed, so a bad call produces no partial output.
    if len(results) != 1:
        raise ValueError(
            f"Expected exactly one result, got {len(results)}"
        )
    if order_type not in ("initial", "stable"):
        raise ValueError(f"Unknown order type: {order_type}")

    (result,) = results
    mem_log = result.value.memorization_log

    display(md(f"### {order_type.capitalize()} memorization epoch"))

    if order_type == "initial":
        order = mem_order.initial_memorization_order(mem_log)
    else:
        order = mem_order.stable_memorization_order(mem_log)

    token_positions = _get_annotated_token_positions(
        result,
        annotation_type=order_dimension,
    )

    string_order_fig = plot_string_memorization_order(order, token_positions)
    string_order_fig.show()

    display(md("#### Memorization across all positions"))
    token_order_fig = plot_token_memorization_order(
        order,
        token_positions,
    )
    token_order_fig.show()


def plot_string_memorization_order(
    memorization_order: pd.Series,
    token_positions: dict[str, list[int]],
) -> go.Figure:
    """Scatter-plot the memorization order of every position in the string.

    Args:
        memorization_order: One value per string position. Its index is
            turned into the "character" column (assumes a single-level
            index) and its ``name`` is used as the y-axis column.
        token_positions: Mapping from a token label to the 0-based positions
            it covers. Uncovered positions are labelled ``"Other"``; if a
            position is listed under several labels, the last one wins.

    Returns:
        A Plotly scatter figure with points colored by token label.

    Raises:
        IndexError: If any position is outside
            ``[0, len(memorization_order))``.
    """
    value_column = memorization_order.name
    # The first reset_index moves the index into a column ("character");
    # the second adds a 0..n-1 column, which becomes "position".
    df = memorization_order.reset_index().reset_index()
    df.columns = ["position", "character", value_column]

    num_positions = len(df)
    token_types = ["Other"] * num_positions
    for label, positions in token_positions.items():
        for pos in positions:
            # Assigning through ``df.at`` with an unknown row label would
            # silently append a NaN row; reject bad positions instead.
            if not 0 <= pos < num_positions:
                raise IndexError(
                    f"Position {pos} of token {label!r} is out of range "
                    f"for a string of length {num_positions}"
                )
            token_types[pos] = label
    df["token_type"] = token_types

    return px.scatter(
        df,
        x="position",
        y=value_column,
        color="token_type",
    )


def plot_token_memorization_order(
    memorization_order: pd.Series,
    token_positions: dict[str, list[int]],
) -> go.Figure:
    """Box-plot the memorization order of the positions covered by each token.

    Args:
        memorization_order: One value per string position, looked up
            positionally with ``.iloc``; its ``name`` is used as the y-axis
            column.
        token_positions: Mapping from a token label to the positions it
            covers. Each position contributes one sample to that label's box.

    Returns:
        A Plotly box figure with one box per token label.

    Raises:
        IndexError: If a position is out of bounds for ``.iloc``.
    """
    value_column = memorization_order.name
    # One (token, value) row per covered position. The old, never-read
    # ``uncovered_positions`` bookkeeping has been removed.
    rows = []
    for token, positions in token_positions.items():
        rows.extend((token, memorization_order.iloc[pos]) for pos in positions)

    token_distribution_df = pd.DataFrame(rows, columns=["token", value_column])
    return px.box(token_distribution_df, x="token", y=value_column)


def _get_annotated_token_positions(
    result: Result,
    annotation_type: str,
) -> dict[str, list[int]]:
    """Get a mapping from annotated token label to the positions it covers.

    Each token of every token/prefix pair (``token_1`` and, if present,
    ``token_2``) is labelled with its first token plus an annotation from
    ``_get_annotation``: either its number of positions or the pair's prefix
    length. All positions in ``range(result.config.data.num_tokens)`` that
    no token covers are collected under an ``"Other..."`` label.

    Args:
        result: A loaded experiment result.
        annotation_type: ``"frequency"`` or ``"prefix_length"``.

    Returns:
        Dict from label to positions. Labels that collide overwrite earlier
        entries, and the "Other" positions are sorted ascending.
    """
    token_positions: dict[str, list[int]] = {}
    uncovered_positions = set(range(result.config.data.num_tokens))
    for token_prefix_pair in result.value.token_prefix_pairs:
        prefix_length = len(token_prefix_pair.prefix_tokens)
        tokens = [token_prefix_pair.token_1]
        if token_prefix_pair.token_2 is not None:
            tokens.append(token_prefix_pair.token_2)
        for token in tokens:
            annotation = _get_annotation(
                annotation_type,
                num_positions=len(token.positions),
                prefix_length=prefix_length,
            )
            token_positions[f"{token.tokens[0]}{annotation}"] = token.positions
            # Update in place rather than rebuilding the whole
            # num_tokens-sized set for every token.
            uncovered_positions.difference_update(token.positions)

    other_token_annotation = _get_annotation(
        annotation_type,
        num_positions=len(uncovered_positions),
        prefix_length=None,
    )
    # Sort so the result does not depend on set iteration order.
    token_positions[f"Other{other_token_annotation}"] = sorted(
        uncovered_positions
    )

    return token_positions


def _get_annotation(
    annotation_type: str,
    num_positions: int,
    prefix_length: int | None,
) -> str:
    if annotation_type == "frequency":
        return f" ({num_positions})"
    elif annotation_type == "prefix_length":
        if prefix_length is None:
            return ""
        else:
            return f" ({prefix_length})"
    else:
        raise ValueError(f"Unknown additional info type: {annotation_type}")


# def show_dynamics(
#     memorization_results: dict[str, list[pd.DataFrame]],
#     title: str,
#     # only_loss: bool = False,
# ) -> dict[str, go.Figure]:
#     figures = {}

#     loss_fig = plot_training_loss(memorization_results, title)
#     loss_fig.show()
#     figures["training_loss"] = loss_fig
#     # if only_loss:
#     #     return figures

#     accuracy_fig = plot_training_accuracy(memorization_results, title)
#     accuracy_fig.show()
#     figures["training_accuracy"] = accuracy_fig

#     # cum_prob_fig = plot_training_cum_prob(memorization_results, title)
#     # cum_prob_fig.show()
#     # figures["training_cum_prob"] = cum_prob_fig

#     # entropy_fig = plot_training_entropy(memorization_results, title)
#     # entropy_fig.show()
#     # figures["training_entropy"] = entropy_fig

#     return figures
