from typing import Any, cast

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display
from scipy.stats import kendalltau, spearmanr
from scipy.stats._stats_py import SignificanceResult

from defs import BASE_FIGURE_DIR
from lib_llm.eval.memorization.dynamics.plots import (
    plot_correctness_over_epochs,
    plot_discrepancy,
)
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook
from lib_project.visualization import with_paper_style
from lib_project.visualization.arrange import arrange_figures_in_grid
from utils import memorization_order as mem_order

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as MDExperimentResult
from .experiment import RandomStringConfig


# Alias for the project's generic ``ExperimentResult`` specialised to this
# experiment's config type (``ExperimentConfig``) and per-run result type
# (``MDExperimentResult``). ``show_order`` reads ``.value.memorization_log``
# from objects of this type.
Result = ExperimentResult[ExperimentConfig, MDExperimentResult]


def show_order(
    results: dict[str, list[Result | pd.DataFrame]],
    title: str,
    show_epochs: list[int] | None = None,
) -> tuple[dict[str, go.Figure], dict[str, go.Figure], pd.DataFrame]:
    """Plot and tabulate the memorization order for every configuration.

    Shows three things: a grid of correctness-over-epochs heatmaps (one per
    configuration, from its first run), a discrepancy plot over all runs,
    and the rank-correlation table produced by ``compute_order_correlation``.

    Args:
        results: Maps a configuration name to its runs. Each run is either an
            experiment result, in which case ``value.memorization_log`` is
            used, or an already-extracted memorization-log DataFrame.
        title: Title passed to ``plot_discrepancy``.
        show_epochs: Epochs to show in the heatmaps. ``None`` (the default)
            means ``list(range(0, 40, 4))``.

    Returns:
        ``(order_heatmaps, figures, order_correlations)``:
        - ``order_heatmaps``: the per-configuration heatmap figures.
        - ``figures``: a dict holding the ``"discrepancy"`` figure.
        - ``order_correlations``: the correlation DataFrame.

    Raises:
        IndexError: If a configuration has no runs, because its first run is
            needed for the heatmap.
    """
    if show_epochs is None:
        # Build the default per call so no list object is shared between
        # calls through the function signature.
        show_epochs = list(range(0, 40, 4))

    # Reduce every run to its memorization log, so the helpers below can
    # treat experiment results and pre-extracted DataFrames alike.
    memorization_results = {
        config_name: [
            (
                config_res
                if isinstance(config_res, pd.DataFrame)
                else config_res.value.memorization_log
            )
            for config_res in config_results
        ]
        for config_name, config_results in results.items()
    }

    # One correctness heatmap per configuration, taken from its first run.
    # Read from the normalised logs so DataFrame inputs also work here.
    order_heatmaps = {
        config_name: plot_correctness_over_epochs(
            memorization_logs[0],
            epochs_to_show=show_epochs,
        )
        for config_name, memorization_logs in memorization_results.items()
    }
    combined_order_fig = arrange_figures_in_grid(
        order_heatmaps,
        n_cols=2,
        size=(1000, 1400),
    )
    combined_order_fig.show()

    # Fixed seed passed to plot_discrepancy, presumably for reproducible plots.
    discrepancy_fig = plot_discrepancy(memorization_results, title, seed=4019)
    discrepancy_fig.show()
    figures = {"discrepancy": discrepancy_fig}

    order_correlations = compute_order_correlation(memorization_results)
    show_order_correlation(order_correlations)

    return order_heatmaps, figures, order_correlations


def show_order_correlation(
    order_correlations: pd.DataFrame,
) -> None:
    """Display an explanatory Markdown heading followed by the correlation table.

    Args:
        order_correlations: The table to display, as produced by
            ``compute_order_correlation``.
    """
    explanation = md(
        "### Rank correlations\n"
        "The table shows the average Spearman and Kendall Tau\n"
        "correlation coefficients between the memorization order and the\n"
        "tokens' position in the string.\n"
    )
    # Heading first, then the table itself.
    for item in (explanation, order_correlations):
        display(item)


def print_order_correlation(
    order_correlations: pd.DataFrame,
) -> None:
    """Print the order correlation table as a Latex table."""
    latex_table = order_correlations[
        [("initial", "spearman"), ("stable", "spearman")]
    ].T.to_latex(
        caption="Rank correlations between memorization order and token position",
        label="tab:order_correlations",
        float_format="%.3f",
    )
    print(latex_table)


def _orders_for_log(memorization_log: pd.DataFrame) -> pd.DataFrame:
    """Return one run's initial and stable memorization orders as two columns."""
    return pd.DataFrame(
        {
            "initial": mem_order.initial_memorization_order(memorization_log),
            "stable": mem_order.stable_memorization_order(memorization_log),
        }
    )


def _rank_correlations_with_position(
    order_values: pd.Series,
) -> tuple[float, float]:
    """Return ``(spearman_rho, kendall_tau)`` between order values and position.

    Entries with a negative order value are dropped first. The remaining
    values are correlated against ``0, 1, ..., k - 1``. This assumes the
    series lists tokens in string order; TODO confirm against ``mem_order``.
    """
    memorized = order_values[order_values >= 0]
    positions = np.arange(len(memorized))
    spearman = cast(SignificanceResult, spearmanr(positions, memorized))
    kendall = cast(SignificanceResult, kendalltau(positions, memorized))
    return spearman.statistic, kendall.statistic


def compute_order_correlation(
    memorization_results: dict[str, list[pd.DataFrame]],
) -> pd.DataFrame:
    """Average rank correlation between memorization order and token position.

    For every memorization log, the initial and stable memorization orders
    are computed with ``utils.memorization_order``. Tokens with a negative
    order value are ignored; presumably these were never memorized (confirm
    against ``mem_order``). The remaining order values are correlated with
    their position using Spearman's rho and Kendall's tau. The per-log
    statistics are then averaged (``np.mean``) over all logs of a
    configuration.

    Args:
        memorization_results: Maps a configuration name to the memorization
            logs of its runs.

    Returns:
        A DataFrame indexed by configuration name. Its two-level columns are
        ``(order_type, method)``:
        - ``order_type`` is ``"initial"`` or ``"stable"``.
        - ``method`` is ``"spearman"`` or ``"kendall"``.

        A cell is NaN when the configuration has no logs, or when any of its
        logs gives a NaN statistic. scipy returns NaN for degenerate inputs,
        such as too few memorized tokens or constant order values.
    """
    order_types = ["initial", "stable"]
    methods = ["spearman", "kendall"]

    correlations: dict[tuple[str, str], list[float]] = {}
    for memorization_logs in memorization_results.values():
        run_orders = [_orders_for_log(log) for log in memorization_logs]
        for order_type in order_types:
            per_run = [
                _rank_correlations_with_position(orders[order_type])
                for orders in run_orders
            ]
            spearman_stats = [rho for rho, _ in per_run]
            kendall_stats = [tau for _, tau in per_run]
            correlations.setdefault((order_type, "spearman"), []).append(
                np.mean(spearman_stats)
            )
            correlations.setdefault((order_type, "kendall"), []).append(
                np.mean(kendall_stats)
            )

    return pd.DataFrame(
        correlations,
        columns=pd.MultiIndex.from_product([order_types, methods]),
        # Dict iteration order matches the order the rows were appended in.
        index=pd.Index(list(memorization_results)),
    )


# Default sub-folder of ``BASE_FIGURE_DIR`` that the paper plots are saved into.
FIGURE_FOLDER = "memorization_order"
# (width, height) applied to the paper figures through ``update_layout``.
FIGURE_SIZE = (800, 600)


def produce_epoch_correctness_paper_plot(
    figure: go.Figure,
    model: str,
    variation_dimension: str,
    figure_folder: str = FIGURE_FOLDER,
    show_legend: bool = True,
) -> None:
    """Restyle a correctness-over-epochs figure for the paper, show it and save it as PDF.

    The file is written to
    ``BASE_FIGURE_DIR / figure_folder / order_{variation_dimension}_{model}.pdf``.
    The folder is created if it does not exist yet.

    Args:
        figure: The figure to restyle, e.g. one of the heatmaps returned by
            ``show_order``.
        model: Model name; used in the file name.
        variation_dimension: Name of the varied experiment dimension; used in
            the file name.
        figure_folder: Sub-folder of ``BASE_FIGURE_DIR`` to save into.
        show_legend: When False, ``legend_pos=None`` is passed to
            ``with_paper_style``, which presumably hides the legend.
    """
    paper_fig = with_paper_style(
        figure,
        legend_pos=(1, 0) if show_legend else None,
        legend_yanchor="bottom",
        legend_orientation="v",
    )
    paper_fig.update_layout(
        width=FIGURE_SIZE[0],
        height=FIGURE_SIZE[1],
    )
    paper_fig.show()
    save_path = (
        BASE_FIGURE_DIR
        / f"{figure_folder}/order_{variation_dimension}_{model}.pdf"
    )
    # Create the target folder first. Otherwise write_image fails for a
    # figure_folder that does not exist yet. This assumes BASE_FIGURE_DIR is
    # a pathlib.Path, which its use with "/" suggests.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    paper_fig.write_image(str(save_path))
    print("Saved figure to", save_path)


def produce_discrepancy_paper_plot(
    figures: dict[str, go.Figure],
    model: str,
    variation_dimension: str,
    figure_folder: str = FIGURE_FOLDER,
    show_legend: bool = True,
) -> None:
    """Restyle the discrepancy figure for the paper, show it and save it as PDF.

    The file is written to
    ``BASE_FIGURE_DIR / figure_folder / discrepancy_{variation_dimension}_{model}.pdf``.
    The folder is created if it does not exist yet.

    Args:
        figures: A dict containing a ``"discrepancy"`` figure, such as the
            second value returned by ``show_order``.
        model: Model name; used in the file name.
        variation_dimension: Name of the varied experiment dimension; used in
            the file name.
        figure_folder: Sub-folder of ``BASE_FIGURE_DIR`` to save into.
        show_legend: When False, ``legend_pos=None`` is passed to
            ``with_paper_style``, which presumably hides the legend.

    Raises:
        KeyError: If ``figures`` has no ``"discrepancy"`` entry.
    """
    paper_fig = with_paper_style(
        figures["discrepancy"],
        legend_pos=(1, 0) if show_legend else None,
        legend_yanchor="bottom",
        legend_orientation="h",
    )
    paper_fig.update_layout(
        width=FIGURE_SIZE[0],
        height=FIGURE_SIZE[1],
    )
    paper_fig.show()
    save_path = (
        BASE_FIGURE_DIR
        / f"{figure_folder}/discrepancy_{variation_dimension}_{model}.pdf"
    )
    # Create the target folder first. Otherwise write_image fails for a
    # figure_folder that does not exist yet. This assumes BASE_FIGURE_DIR is
    # a pathlib.Path, which its use with "/" suggests.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    paper_fig.write_image(str(save_path))
    print("Saved figure to", save_path)
