import numpy as np
import pandas as pd
import plotly.express as px
from plotly import graph_objects as go
from plotly.subplots import make_subplots

from lib_llm.eval.memorization.dynamics.utils import (
    get_max_epoch,
    get_string_tokens,
    reindex_positionwise,
)


def stable_memorization_order(result: pd.DataFrame) -> pd.Series:
    """
    Return, per token of the string, the epoch from which it stays correct.

    For every token position except the first, the latest epoch with
    ``correct == 0`` is looked up in the position-wise reindexed result;
    the stable epoch is the epoch right after it.

    Special values:
        * ``0``  -- no epoch with ``correct == 0`` was found for the token.
        * ``-1`` -- the epoch after the last incorrect one would exceed the
          maximum epoch, i.e. the token is still incorrect at the end.
        * The first token is always assigned ``-1``.

    Returns a Series named ``stable_mem_epoch``, indexed by the tokens
    returned by ``get_string_tokens``.
    """
    string_tokens = get_string_tokens(result)
    last_epoch = get_max_epoch(result)
    positionwise = reindex_positionwise(result)
    token_positions = positionwise.index.get_level_values("token_index")

    def _first_stable_epoch(position: int) -> int:
        # All rows (one per epoch) that belong to this token position,
        # with the epoch exposed as a regular column.
        rows = positionwise.loc[token_positions == position]
        rows = rows.reset_index(level=["epoch"])
        missed_epochs = rows.loc[rows["correct"] == 0, "epoch"]
        if missed_epochs.empty:
            return 0
        candidate = int(missed_epochs.max()) + 1
        return candidate if candidate <= last_epoch else -1

    epochs = [-1]
    epochs.extend(
        _first_stable_epoch(position)
        for position in range(1, len(string_tokens))
    )
    return pd.Series(
        epochs,
        index=string_tokens,
        name="stable_mem_epoch",
        dtype=int,
    )


def initial_memorization_order(result: pd.DataFrame) -> pd.Series:
    """
    Compute for each token in the string when it is first memorized.

    For every token position except the first, the smallest epoch with
    ``correct == 1`` is reported. Tokens that are never correct get ``-1``,
    and the first token is always assigned ``-1``.

    Returns a Series named ``initial_mem_epoch``, indexed by the tokens
    returned by ``get_string_tokens``.
    """
    tokens = get_string_tokens(result)
    reindexed_result = reindex_positionwise(result)
    token_positions = reindexed_result.index.get_level_values("token_index")

    initial_epochs = [-1]
    for token_idx in range(1, len(tokens)):
        # Select from the reindexed frame: the mask is derived from its
        # index, so applying it to the raw `result` would only pick the
        # right rows if both frames happened to share the same row order.
        # This mirrors what `stable_memorization_order` does.
        token_res = reindexed_result.loc[
            token_positions == token_idx
        ].reset_index(level=["epoch"])
        correct_epochs = token_res.loc[token_res["correct"] == 1, "epoch"]
        first_correct_epoch = (
            -1 if correct_epochs.empty else int(correct_epochs.min())
        )
        initial_epochs.append(first_correct_epoch)
    return pd.Series(
        initial_epochs,
        index=tokens,
        name="initial_mem_epoch",
        dtype=int,
    )


def prefix_agreement(
    result: pd.DataFrame,
    prefix_length: int,
) -> pd.DataFrame:
    """
    Compute for each token in the string how many prefixes of length
    `prefix_length` exist that agree with the token (i.e. same prefixes
    followed by the same token), and how many disagree (i.e. same prefix
    followed by different tokens).

    The occurrence at the token's own position is not counted. The first
    `prefix_length` tokens have no full prefix and are left out of the
    output.

    Returns a DataFrame with integer columns ``agreement`` and
    ``disagreement``, indexed by the tokens from position `prefix_length`
    onwards (index name ``token``).
    """
    sequence = tuple(get_string_tokens(result))

    # Single pass: tally how often each prefix occurs, and how often each
    # (prefix, following token) pair occurs. This replaces re-scanning all
    # positions that share a prefix for every token.
    prefix_counts = {}
    continuation_counts = {}
    for i in range(prefix_length, len(sequence)):
        prefix = sequence[i - prefix_length : i]
        continuation = (prefix, sequence[i])
        prefix_counts[prefix] = prefix_counts.get(prefix, 0) + 1
        continuation_counts[continuation] = (
            continuation_counts.get(continuation, 0) + 1
        )

    agreement_counts = np.zeros(len(sequence), dtype=int)
    disagreement_counts = np.zeros(len(sequence), dtype=int)
    for i, token in enumerate(sequence[prefix_length:], prefix_length):
        prefix = sequence[i - prefix_length : i]
        same_continuation = continuation_counts[(prefix, token)]
        # Subtract one so the token's own occurrence is not counted.
        agreement_counts[i] = same_continuation - 1
        # Every other occurrence of the prefix continues differently.
        disagreement_counts[i] = prefix_counts[prefix] - same_continuation
    return pd.DataFrame(
        {
            "agreement": agreement_counts[prefix_length:],
            "disagreement": disagreement_counts[prefix_length:],
        },
        index=pd.Index(sequence[prefix_length:], name="token"),
    )


def plot_order_agreement_disagreement(
    order: pd.Series,
    agreement: pd.DataFrame,
) -> go.Figure:
    """Plot two side-by-side box-plot subplots,
    with memorization order on the y-axis and
    agreement/disagreement on the x-axis.

    `order` is truncated to its last ``len(agreement)`` entries so that it
    lines up with `agreement`, which is expected to provide the columns
    ``agreement`` and ``disagreement``. Each subplot title shows the
    Pearson correlation between the order and the respective column
    (``nan`` when the column has fewer than two distinct values), and the
    number of tokens per x value is annotated below the boxes.
    """
    # Only keep the tokens for which there is agreement/disagreement
    # data, i.e. after the first prefix_length tokens. The start offset is
    # computed explicitly because `order.iloc[-0:]` would select the whole
    # series when `agreement` is empty.
    order = order.iloc[max(len(order) - len(agreement), 0) :]

    agreement_types = ["agreement", "disagreement"]
    colors = ["blue", "orange"]

    agreement_correlations = {}
    for agreement_type in agreement_types:
        # The correlation is undefined for a constant (or empty) column.
        if agreement[agreement_type].unique().size <= 1:
            agreement_correlation = np.nan
        else:
            agreement_correlation = np.corrcoef(
                order,
                agreement[agreement_type],
            )[0, 1]
        agreement_correlations[agreement_type] = agreement_correlation

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=[
            f"Agreement (corr {agreement_correlations['agreement']:.3f})",
            f"Disagreement (corr {agreement_correlations['disagreement']:.3f})",
        ],
    )

    for i, (agreement_type, color) in enumerate(
        zip(agreement_types, colors)
    ):
        agreement_data = agreement[agreement_type]

        # Build the box plot with plotly express and move its single
        # trace into the corresponding subplot.
        fig.add_trace(
            px.box(
                x=agreement_data,
                y=order,
                points="outliers",
                color_discrete_sequence=[color],
                boxmode="overlay",
                title=agreement_type,
            ).data[0],
            row=1,
            col=i + 1,
        )

        # Annotate, below each box, how many tokens share that x value.
        for x_val in agreement_data.unique():
            total_points = int((agreement_data == x_val).sum())
            fig.add_annotation(
                text=str(total_points),
                x=x_val,
                y=-1.5,
                showarrow=False,
                font=dict(size=10),
                row=1,
                col=i + 1,
            )

    fig.update_layout(
        yaxis_title="Memorization epoch",
        showlegend=False,
    )
    fig.update_xaxes(
        title_text="No. of same prefixes with the same next token",
        row=1,
        col=1,
    )
    fig.update_xaxes(
        title_text="No. of same prefixes with a different next token",
        row=1,
        col=2,
    )
    return fig
