from copy import deepcopy
from typing import cast

import pandas as pd
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from defs import BASE_FIGURE_DIR
from lib_llm.eval.memorization.dynamics.plots import (
    compute_mean_std_dev,
    plot_scalar_curves,
    plot_training_accuracy,
)
from lib_llm.eval.memorization.dynamics.utils import reindex_positionwise
from lib_project.experiment import ExperimentResult, load_results
from lib_project.visualization import with_paper_style
from lib_project.visualization.arrange import arrange_figures_in_grid

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as MDExperimentResult


# Typed alias for one loaded run of this experiment: the generic
# lib_project ExperimentResult parameterized by this experiment's config type
# and its own result type (imported here as MDExperimentResult).
Result = ExperimentResult[ExperimentConfig, MDExperimentResult]


def show_models(
    results: dict[str, list[Result]],
    title: str,
    *,
    marker_threshold: float = 0.9,
) -> dict[str, go.Figure]:
    """Plot accuracy and contiguous-recall curves for several configurations.

    For every configuration, the memorization logs of all its runs are
    extracted. Training-accuracy and window-recall figures are built from
    them, and each is annotated with the point where every named curve first
    reaches ``marker_threshold``. The two are also overlaid in one
    "combined" figure. All figures are displayed together in a two-column
    grid.

    Args:
        results: Mapping from configuration name to the loaded results of
            that configuration's runs.
        title: Title used for the plots / legends.
        marker_threshold: y value at which threshold markers are drawn.

    Returns:
        Mapping with keys ``"accuracy"``, ``"contiguous_recall"`` and
        ``"combined"`` to the corresponding figures.
    """
    memorization_results = {
        config_name: [res.value.memorization_log for res in config_results]
        for config_name, config_results in results.items()
    }

    # add_threshold_markers annotates the figure it is given and returns it.
    accuracy_fig = add_threshold_markers(
        plot_training_accuracy(memorization_results, title),
        marker_threshold,
    )
    recall_fig = add_threshold_markers(
        plot_consecutive_window_recall_comparison(memorization_results, title),
        marker_threshold,
    )

    figures = {
        "accuracy": accuracy_fig,
        "contiguous_recall": recall_fig,
        "combined": combine_accuracy_recall_figures(
            accuracy_fig,
            recall_fig,
            legend_title=title,
        ),
    }

    # The grid is only for interactive display; it is not returned.
    grid_fig = arrange_figures_in_grid(
        figures,
        n_cols=2,
        size=(1000, 800),
    )
    grid_fig.show()

    return figures


def combine_accuracy_recall_figures(
    accuracy_fig: go.Figure,
    recall_fig: go.Figure,
    legend_title: str,
) -> go.Figure:
    """Overlay accuracy curves (solid) and recall curves (dashed) in one figure.

    The accuracy traces are copied unchanged and keep their legend entries.
    The recall traces are copied, drawn dashed and hidden from the legend.
    Two legend-only black entries then explain the solid/dashed encoding.
    The input figures are not modified.

    NOTE(review): hiding the recall legend entries assumes both figures
    contain the same configurations with matching colors, so the accuracy
    entries also identify the recall curves. Confirm against the plot helpers.

    Args:
        accuracy_fig: Figure with one or more accuracy traces.
        recall_fig: Figure with the corresponding recall traces.
        legend_title: Title shown above the legend.

    Returns:
        A new figure containing copies of all traces.
    """
    combined_fig = go.Figure()

    # Deep-copy so styling changes never leak back into the source figures.
    for trace in accuracy_fig.data:
        combined_fig.add_trace(deepcopy(trace))

    for trace in recall_fig.data:
        trace_copy = deepcopy(trace)
        trace_copy.line.dash = "dash"
        trace_copy.showlegend = False
        combined_fig.add_trace(trace_copy)

    # Legend-only entries (no visible data) describing the line styles.
    for metric_type, dash_type in zip(
        ["Accuracy", "Contig. Recall"],
        ["solid", "dash"],
    ):
        combined_fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="lines",
                name=metric_type,
                line=dict(
                    dash=dash_type,
                    color="black",
                    width=3,
                ),
            )
        )

    combined_fig.update_layout(
        title="Accuracy and recall",
        xaxis_title="Epoch",
        yaxis_title="Accuracy / Contig. Recall",
        legend_title=legend_title,
    )
    return combined_fig


def plot_consecutive_window_recall_comparison(
    results: dict[str, list[pd.DataFrame]],
    legend_title: str,
    window_size: int = 50,
) -> go.Figure:
    """Plot, per configuration, the windowed contiguous-recall curve.

    Each memorization log is turned into per-position "recalled" flags by
    ``compute_extractability`` with the given ``window_size``. The flags are
    aggregated into mean / std-dev curves by ``compute_mean_std_dev``
    (presumably across runs and positions per epoch; confirm in that helper)
    and then plotted.

    Args:
        results: Mapping from configuration name to the memorization logs of
            that configuration's runs.
        legend_title: Title shown above the legend.
        window_size: Number of consecutive tokens that must all be correct for
            a window to count as recalled.

    Returns:
        The recall figure produced by ``plot_scalar_curves``.
    """
    recall_results = {
        config_name: [
            compute_extractability(log, window_size=window_size)
            for log in config_logs
        ]
        for config_name, config_logs in results.items()
    }
    recall_means, recall_std_devs = compute_mean_std_dev(
        recall_results,
        preprocess=lambda res: res["recalled"],
    )
    return plot_scalar_curves(
        recall_means,
        recall_std_devs,
        title=f"{window_size}-token window recall",
        xaxis_title="Epoch",
        yaxis_title="Contiguous Recall",
        legend_title=legend_title,
    )


def compute_extractability(
    memorization_log: pd.DataFrame, window_size: int = 50
) -> pd.DataFrame:
    """Flag, per epoch and position, whether a full window was recalled.

    The log is reindexed position-wise via ``reindex_positionwise``. Within
    each epoch, a rolling sum of the ``correct`` column is taken over
    ``window_size`` consecutive rows. A position counts as recalled when all
    ``window_size`` rows of the window ending at it are correct. Only
    positions with ``token_index >= window_size`` are kept.

    NOTE(review): assumes ``reindex_positionwise`` yields an index with
    ``epoch`` and ``token_index`` levels, and a ``correct`` column holding
    0/1 (or boolean) values. The rolling window counts rows, not token
    indices, so it also assumes positions are contiguous within an epoch.

    Args:
        memorization_log: Raw memorization log for one run.
        window_size: Number of consecutive positions that must all be correct.

    Returns:
        DataFrame with a single integer ``recalled`` column (1 = recalled,
        0 = not), indexed like the position-wise log.
    """
    pos_memorization_log = reindex_positionwise(memorization_log)
    rolling_sums = (
        pos_memorization_log["correct"]
        .groupby(level="epoch")
        .rolling(window=window_size, min_periods=window_size)
        .sum()
        # groupby().rolling() prepends the group key as an extra index
        # level; drop it to get back to the original index levels.
        .reset_index(level=0, drop=True)
    )
    # Incomplete windows (fewer than `window_size` rows so far) are NaN.
    # Drop them before comparing: after `==` they would read as "not
    # recalled" instead of "not evaluable".
    complete_windows = rolling_sums.dropna()
    # Build the position filter from the rolling result's own index. The
    # groupby reorders rows by epoch, so a mask taken from the input frame
    # would misalign if the input was not already sorted by epoch.
    post_window_positions = (
        complete_windows.index.get_level_values("token_index") >= window_size
    )
    # Recalled only if every position in the window is correct, i.e. the
    # window sum equals the window size.
    recalled = (
        (complete_windows.loc[post_window_positions] == window_size)
        .astype(int)
    )
    return pd.DataFrame({"recalled": recalled}, index=recalled.index)


def add_threshold_markers(
    figure: go.Figure,
    threshold: float,
) -> go.Figure:
    """Mark where each named trace first reaches ``threshold``.

    For every trace that has a ``name``, this finds the first x whose y value
    is >= ``threshold``. It draws a dashed vertical line at that x and a bold
    label with ``int(x)`` just above the plot area, both in the trace's line
    color. Traces that never reach the threshold, or that have no x/y data,
    get no marker. ``None`` y values are skipped rather than compared.

    Args:
        figure: Figure to annotate; it is modified in place.
        threshold: y value the curves have to reach.

    Returns:
        The same ``figure`` object, to allow chaining.
    """
    trace_idx = -1  # index among *named* traces only
    for trace in figure.data:
        trace = cast(go.Scatter, trace)
        if trace.name is None:
            continue
        trace_idx += 1

        xs, ys = trace.x, trace.y
        if xs is None or ys is None:
            continue
        # First x at which the curve reaches the threshold (None if never).
        x_position = next(
            (x for x, y in zip(xs, ys) if y is not None and y >= threshold),
            None,
        )
        if x_position is None:
            continue

        trace_color = trace.line.color
        figure.add_vline(
            x=x_position,
            line=dict(color=trace_color, width=2, dash="dash"),
        )
        figure.add_annotation(
            x=x_position,
            y=1.05,
            text=f"<b>{int(x_position)}</b>",
            showarrow=False,
            # The second named trace's label goes right of its line and all
            # others go left, presumably so that two nearby labels don't
            # overlap.
            xshift=15 if trace_idx == 1 else -15,
            font=dict(
                color=trace_color,
                size=26,
            ),
        )
    return figure


# Default subfolder of BASE_FIGURE_DIR that produce_accuracy_recall_paper_plots
# writes its PDF figures into.
FIGURE_FOLDER: str = "memorability"
# (width, height) applied to each paper figure's layout before export.
FIGURE_SIZE: tuple[int, int] = (800, 600)


def produce_accuracy_recall_paper_plots(
    figures: dict[str, go.Figure],
    figure_folder: str = FIGURE_FOLDER,
    show_legend: bool = True,
    figure_size: tuple[int, int] = FIGURE_SIZE,
) -> None:
    """Restyle figures for the paper, display them and export them as PDFs.

    Each figure is passed through ``with_paper_style`` (legend at the
    bottom-right, or no legend), given a fixed y range and size, shown, and
    written to ``BASE_FIGURE_DIR / figure_folder / f"{figure_type}.pdf"``.
    Missing output directories are created.

    Args:
        figures: Mapping from figure type (used as the file stem) to figure.
        figure_folder: Subfolder of ``BASE_FIGURE_DIR`` to write into.
        show_legend: Whether the exported figures include a legend.
        figure_size: (width, height) applied to each figure's layout.
    """
    for figure_type, fig in figures.items():
        paper_fig = with_paper_style(
            fig,
            legend_pos=(1, 0) if show_legend else None,
            legend_yanchor="bottom",
            legend_xanchor="right",
            legend_orientation="v",
        )
        # Small margins around [0, 1] so that curves and markers at the
        # extremes are not clipped.
        paper_fig.update_yaxes(range=[-0.05, 1.1])
        paper_fig.update_layout(
            width=figure_size[0],
            height=figure_size[1],
        )
        paper_fig.show()
        save_path = BASE_FIGURE_DIR / figure_folder / f"{figure_type}.pdf"
        # write_image does not create missing directories.
        save_path.parent.mkdir(parents=True, exist_ok=True)
        paper_fig.write_image(str(save_path))
        print("Saved figure to", save_path)
