import math
from typing import cast

import pandas as pd
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from defs import BASE_FIGURE_DIR
from lib_llm.eval.memorization.dynamics.plots import (
    plot_discrepancy,
    plot_sequence_error_distribution,
    plot_sequence_prob_distribution,
    plot_string_position_cum_prob,
    plot_training_accuracy,
    plot_training_cum_prob,
    plot_training_entropy,
    plot_training_kld,
    plot_training_loss,
)
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook
from lib_project.visualization import with_paper_style
from lib_project.visualization.arrange import arrange_figures_in_grid

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as MDExperimentResult
from .experiment import RandomStringConfig


# Concrete result type for this experiment: the generic ExperimentResult from
# lib_project, parameterised with this experiment's config and result types.
Result = ExperimentResult[ExperimentConfig, MDExperimentResult]


def load(
    config_name: str | list[str],
    seed_ids: list[int],
) -> list[Result]:
    """Load stored results of this experiment.

    Thin wrapper around ``lib_project.experiment.load_results`` that fixes the
    experiment name and the config/result classes used for deserialization.

    Args:
        config_name: One config name or a list of config names to load.
        seed_ids: Seeds whose results should be loaded.

    Returns:
        The loaded results, as returned by ``load_results``.
    """
    return load_results(
        EXP_NAME,
        config_name,
        seed_ids,
        ExperimentConfig,
        MDExperimentResult,
    )


# Random postfix appended to each published notebook's HTML file name so that
# the published URLs are harder to guess.
_NOTEBOOK_POSTFIXES: dict[str, str] = {
    # Dynamics analysis
    "dynamics_analysis/alphabet_size": "3jd031hi",
    "dynamics_analysis/entropy_levels": "592djow4",
    "dynamics_analysis/string_length": "f9sdgs23f",
    "dynamics_analysis/partitions": "g39djsa",
    "dynamics_analysis/non_latin_alphabet": "420fjlaw",
    "dynamics_analysis/non_pretrained_models": "dsyg31",
    "dynamics_analysis/carlini_comparison": "gabe248",
    # Order analysis
    "order_analysis/memorization_order": "29f2dDs",
    "order_analysis/untrained_memorization_order": "52udjas",
    "order_analysis/prefix_order_rel": "593kddfu",
}


def publish(
    notebook: str,
) -> None:
    """Publish one of this experiment's notebooks as HTML.

    Args:
        notebook: Notebook path relative to the experiment folder, without
            the ``.ipynb`` suffix (e.g. ``"dynamics_analysis/alphabet_size"``).

    Raises:
        ValueError: If ``notebook`` is not one of the known notebooks.
    """
    try:
        postfix = _NOTEBOOK_POSTFIXES[notebook]
    except KeyError:
        raise ValueError(f"Unknown notebook: {notebook}") from None
    output_path = f"experiments/{EXP_NAME}/{notebook}_{postfix}.html"
    notebook_path = f"./experiments/{EXP_NAME}/{notebook}.ipynb"
    publish_notebook(
        notebook_path,
        output_path,
    )


def _infer_data_configs(
    results: dict[str, list[Result | pd.DataFrame]],
) -> dict[str, RandomStringConfig]:
    """Read each config's data config from the first result of that config.

    The data config is taken from ``config.random_data`` when that attribute
    exists, otherwise from ``config.data``.

    Raises:
        ValueError: If a config has no results, or if its first result is a
            bare DataFrame (which carries no experiment config).
    """
    data_configs: dict[str, RandomStringConfig] = {}
    for config_name, config_results in results.items():
        if not config_results:
            raise ValueError(
                f"No results for config {config_name!r}; "
                "cannot infer its data config"
            )
        first_result = config_results[0]
        if isinstance(first_result, pd.DataFrame):
            raise ValueError(
                f"Cannot infer the data config of {config_name!r} from a raw "
                "DataFrame; pass `data_configs` explicitly"
            )
        experiment_config = first_result.config
        data_config = (
            experiment_config.random_data
            if hasattr(experiment_config, "random_data")
            else getattr(experiment_config, "data")
        )
        data_configs[config_name] = cast(RandomStringConfig, data_config)
    return data_configs


def show_dynamics(
    results: dict[str, list[Result | pd.DataFrame]],
    title: str,
    show_loss: bool = True,
    show_accuracy: bool = True,
    show_cum_prob: bool = True,
    show_entropy: bool = True,
    show_kld: bool = True,
    show_in_context_learning: bool = True,
    data_configs: dict[str, RandomStringConfig] | None = None,
) -> dict[str, go.Figure]:
    """Plot training dynamics for several configs and show them in a grid.

    Args:
        results: Per-config lists of runs. Each run is either an experiment
            result (whose ``value.memorization_log`` is plotted) or a
            DataFrame that is plotted in place of a memorization log.
        title: Title passed to every plot.
        show_loss, show_accuracy, show_cum_prob, show_entropy, show_kld,
            show_in_context_learning: Toggle the individual plots.
        data_configs: Per-config data configs that supply the guess baselines
            (loss, accuracy, entropy) and character probabilities (KLD). When
            omitted, they are inferred from the first result of each config.
            Inference only happens if one of those four plots is requested,
            and it needs a non-DataFrame first result.

    Returns:
        The requested figures keyed by ``"training_loss"``,
        ``"training_accuracy"``, ``"training_cum_prob"``,
        ``"training_entropy"``, ``"training_kld"`` and
        ``"in_context_learning"``.

    Raises:
        ValueError: If ``data_configs`` must be inferred but a config has no
            results or its first result is a DataFrame.
    """
    figures: dict[str, go.Figure] = {}

    memorization_results = {
        config_name: [
            (
                config_res
                if isinstance(config_res, pd.DataFrame)
                else config_res.value.memorization_log
            )
            for config_res in config_results
        ]
        for config_name, config_results in results.items()
    }

    # Only the baseline-annotated plots read data configs; skipping inference
    # otherwise lets DataFrame-only inputs be used for the remaining plots.
    needs_data_configs = show_loss or show_accuracy or show_entropy or show_kld
    if data_configs is None:
        data_configs = (
            _infer_data_configs(results) if needs_data_configs else {}
        )

    if show_loss:
        guess_baselines = {
            config_name: data_config.guess_ce_loss
            for config_name, data_config in data_configs.items()
        }
        figures["training_loss"] = plot_training_loss(
            memorization_results,
            title,
            guess_baselines=guess_baselines,
        )

    if show_accuracy:
        guess_baselines = {
            config_name: data_config.guess_accuracy
            for config_name, data_config in data_configs.items()
        }
        figures["training_accuracy"] = plot_training_accuracy(
            memorization_results,
            title,
            guess_baselines=guess_baselines,
        )

    if show_cum_prob:
        figures["training_cum_prob"] = plot_training_cum_prob(
            memorization_results, title
        )

    if show_entropy:
        guess_baselines = {
            config_name: data_config.uniform_entropy
            for config_name, data_config in data_configs.items()
        }
        figures["training_entropy"] = plot_training_entropy(
            memorization_results,
            title,
            guess_baselines=guess_baselines,
        )

    if show_kld:
        character_probabilities = {
            config_name: data_config.character_probabilities
            for config_name, data_config in data_configs.items()
        }
        figures["training_kld"] = plot_training_kld(
            memorization_results,
            title,
            character_probabilities=character_probabilities,
        )

    if show_in_context_learning:
        figures["in_context_learning"] = plot_string_position_cum_prob(
            memorization_results,
            title,
            epoch=0,
        )

    # Nothing to arrange when every plot is disabled; avoid building an empty,
    # zero-height grid.
    if figures:
        # Two columns; the grid grows by 450 in height per row.
        n_rows = math.ceil(len(figures) / 2)
        combined_fig = arrange_figures_in_grid(
            figures,
            n_cols=2,
            size=(1000, n_rows * 450),
        )
        combined_fig.show()

    # TODO: add a correctness plot.
    return figures


# Default sub-folder of BASE_FIGURE_DIR that the produce_*_paper_plot
# functions write their PDFs into.
FIGURE_FOLDER: str = "memorability"
# (width, height) applied via update_layout to every exported paper figure.
FIGURE_SIZE: tuple[int, int] = (800, 600)


def produce_loss_paper_plot(
    figures: dict[str, go.Figure],
    model: str,
    variation_dimension: str,
    figure_folder: str = FIGURE_FOLDER,
    show_legend: bool = True,
) -> None:
    """Style, display and export the training-loss figure for the paper.

    The PDF is written to
    ``BASE_FIGURE_DIR / "<figure_folder>/loss_<variation_dimension>_<model>.pdf"``.

    Args:
        figures: Named figures; the ``"training_loss"`` entry is exported.
        model: Model identifier embedded in the file name.
        variation_dimension: Varied dimension embedded in the file name.
        figure_folder: Sub-folder of ``BASE_FIGURE_DIR`` to write into.
        show_legend: If true, the legend is placed at (1, 1), anchored
            top/right and laid out horizontally; otherwise ``legend_pos=None``
            is passed to ``with_paper_style``.
    """
    width, height = FIGURE_SIZE
    target = BASE_FIGURE_DIR / (
        f"{figure_folder}/loss_{variation_dimension}_{model}.pdf"
    )

    styled = with_paper_style(
        figures["training_loss"],
        legend_pos=None if not show_legend else (1, 1),
        legend_yanchor="top",
        legend_xanchor="right",
        legend_orientation="h",
    )
    # Allow at most a small margin below zero, and start the view at -0.05
    # with an automatic upper bound.
    styled.update_yaxes(minallowed=-0.1)
    styled.update_layout(
        width=width,
        height=height,
        yaxis_range=[-0.05, None],
    )
    styled.show()

    styled.write_image(str(target))
    print("Saved figure to", target)


def produce_accuracy_paper_plot(
    figures: dict[str, go.Figure],
    model: str,
    variation_dimension: str,
    figure_folder: str = FIGURE_FOLDER,
    show_legend: bool = True,
) -> None:
    """Style, display and export the training-accuracy figure for the paper.

    The PDF is written to
    ``BASE_FIGURE_DIR / "<figure_folder>/accuracy_<variation_dimension>_<model>.pdf"``.

    Args:
        figures: Named figures; the ``"training_accuracy"`` entry is exported.
        model: Model identifier embedded in the file name.
        variation_dimension: Varied dimension embedded in the file name.
        figure_folder: Sub-folder of ``BASE_FIGURE_DIR`` to write into.
        show_legend: If true, the legend is placed at (1, 0), anchored
            bottom/right and laid out horizontally; otherwise
            ``legend_pos=None`` is passed to ``with_paper_style``.
    """
    legend_pos = (1, 0) if show_legend else None
    accuracy_fig = with_paper_style(
        figures["training_accuracy"],
        legend_pos=legend_pos,
        legend_yanchor="bottom",
        legend_xanchor="right",
        legend_orientation="h",
    )
    # Fixed [0, 1] y-axis with a small pad on both ends.
    accuracy_fig.update_yaxes(range=[-0.05, 1.05])
    fig_width, fig_height = FIGURE_SIZE
    accuracy_fig.update_layout(width=fig_width, height=fig_height)
    accuracy_fig.show()

    pdf_path = BASE_FIGURE_DIR / (
        f"{figure_folder}/accuracy_{variation_dimension}_{model}.pdf"
    )
    accuracy_fig.write_image(str(pdf_path))
    print("Saved figure to", pdf_path)


def produce_cum_prob_paper_plot(
    figures: dict[str, go.Figure],
    model: str,
    variation_dimension: str,
    figure_folder: str = FIGURE_FOLDER,
    show_legend: bool = True,
) -> None:
    """Style, display and export the cumulative-probability figure.

    The PDF is written to
    ``BASE_FIGURE_DIR / "<figure_folder>/cum-prob_<variation_dimension>_<model>.pdf"``.

    Args:
        figures: Named figures; the ``"training_cum_prob"`` entry is exported.
        model: Model identifier embedded in the file name.
        variation_dimension: Varied dimension embedded in the file name.
        figure_folder: Sub-folder of ``BASE_FIGURE_DIR`` to write into.
        show_legend: If true, the legend is placed at (1, 0) anchored at the
            bottom; otherwise ``legend_pos=None`` is passed to
            ``with_paper_style``.
    """
    fig_width, fig_height = FIGURE_SIZE
    out_file = BASE_FIGURE_DIR / (
        f"{figure_folder}/cum-prob_{variation_dimension}_{model}.pdf"
    )

    # NOTE(review): produce_accuracy_paper_plot uses the same legend position
    # but also passes legend_xanchor="right"; here with_paper_style's default
    # x-anchor applies. Confirm this difference is intended.
    cum_prob_fig = with_paper_style(
        figures["training_cum_prob"],
        legend_pos=None if not show_legend else (1, 0),
        legend_yanchor="bottom",
    )
    # Fixed [0, 1] y-axis with a small pad on both ends.
    cum_prob_fig.update_yaxes(range=[-0.05, 1.05])
    cum_prob_fig.update_layout(height=fig_height, width=fig_width)
    cum_prob_fig.show()

    cum_prob_fig.write_image(str(out_file))
    print("Saved figure to", out_file)


def produce_entropy_paper_plot(
    figures: dict[str, go.Figure],
    model: str,
    variation_dimension: str,
    figure_folder: str = FIGURE_FOLDER,
    show_legend: bool = True,
) -> None:
    """Style, display and export the training-entropy figure for the paper.

    The PDF is written to
    ``BASE_FIGURE_DIR / "<figure_folder>/entropy_<variation_dimension>_<model>.pdf"``.

    Args:
        figures: Named figures; the ``"training_entropy"`` entry is exported.
        model: Model identifier embedded in the file name.
        variation_dimension: Varied dimension embedded in the file name.
        figure_folder: Sub-folder of ``BASE_FIGURE_DIR`` to write into.
        show_legend: If true, the legend is placed at (1, 1), anchored
            top/right; otherwise ``legend_pos=None`` is passed to
            ``with_paper_style``.
    """
    if show_legend:
        legend_pos = (1, 1)
    else:
        legend_pos = None

    entropy_fig = with_paper_style(
        figures["training_entropy"],
        legend_pos=legend_pos,
        legend_yanchor="top",
        legend_xanchor="right",
    )
    # Allow at most a small margin below zero on the y-axis.
    entropy_fig.update_yaxes(minallowed=-0.1)
    entropy_fig.update_layout(width=FIGURE_SIZE[0], height=FIGURE_SIZE[1])
    entropy_fig.show()

    file_name = f"entropy_{variation_dimension}_{model}.pdf"
    destination = BASE_FIGURE_DIR / f"{figure_folder}/{file_name}"
    entropy_fig.write_image(str(destination))
    print("Saved figure to", destination)


def produce_kld_paper_plot(
    figures: dict[str, go.Figure],
    model: str,
    variation_dimension: str,
    figure_folder: str = FIGURE_FOLDER,
    show_legend: bool = True,
) -> None:
    """Style, display and export the training-KLD figure for the paper.

    The PDF is written to
    ``BASE_FIGURE_DIR / "<figure_folder>/kld_<variation_dimension>_<model>.pdf"``.

    Args:
        figures: Named figures; the ``"training_kld"`` entry is exported.
        model: Model identifier embedded in the file name.
        variation_dimension: Varied dimension embedded in the file name.
        figure_folder: Sub-folder of ``BASE_FIGURE_DIR`` to write into.
        show_legend: If true, the legend is placed at (1, 0) anchored at the
            bottom; otherwise ``legend_pos=None`` is passed to
            ``with_paper_style``.
    """
    source_fig = figures["training_kld"]
    # NOTE(review): unlike the accuracy plot at the same legend position, no
    # legend_xanchor is passed, so with_paper_style's default applies. Confirm
    # this is intended.
    kld_fig = with_paper_style(
        source_fig,
        legend_pos=(1, 0) if show_legend else None,
        legend_yanchor="bottom",
    )
    # Allow at most a small margin below zero on the y-axis.
    kld_fig.update_yaxes(minallowed=-0.1)
    width, height = FIGURE_SIZE
    kld_fig.update_layout(width=width, height=height)
    kld_fig.show()

    kld_path = BASE_FIGURE_DIR / (
        f"{figure_folder}/kld_{variation_dimension}_{model}.pdf"
    )
    kld_fig.write_image(str(kld_path))
    print("Saved figure to", kld_path)


def produce_icl_paper_plot(
    figures: dict[str, go.Figure],
    model: str,
    variation_dimension: str,
    figure_folder: str = FIGURE_FOLDER,
    show_legend: bool = True,
) -> None:
    """Style, display and export the in-context-learning figure.

    The PDF is written to
    ``BASE_FIGURE_DIR / "<figure_folder>/icl_<variation_dimension>_<model>.pdf"``.

    Args:
        figures: Named figures; the ``"in_context_learning"`` entry is
            exported.
        model: Model identifier embedded in the file name.
        variation_dimension: Varied dimension embedded in the file name.
        figure_folder: Sub-folder of ``BASE_FIGURE_DIR`` to write into.
        show_legend: If true, the legend is placed at (0, 1), anchored
            top/left and laid out horizontally; otherwise ``legend_pos=None``
            is passed to ``with_paper_style``.
    """
    icl_target = BASE_FIGURE_DIR / (
        f"{figure_folder}/icl_{variation_dimension}_{model}.pdf"
    )
    fig_width, fig_height = FIGURE_SIZE

    icl_fig = with_paper_style(
        figures["in_context_learning"],
        legend_pos=None if not show_legend else (0, 1),
        legend_yanchor="top",
        legend_xanchor="left",
        legend_orientation="h",
    )
    # Fixed [0, 1] y-axis with a small pad on both ends.
    icl_fig.update_yaxes(range=[-0.05, 1.05])
    icl_fig.update_layout(width=fig_width, height=fig_height)
    icl_fig.show()

    icl_fig.write_image(str(icl_target))
    print("Saved figure to", icl_target)


def show_individual_results(
    results: dict[str, Result],
) -> None:
    """Display memorization diagnostics for each result in a notebook.

    For every entry this renders a Markdown heading and prints the memorized
    string. It then shows the sequence probability distribution of its
    memorization log (linear and log scale) and the sequence error
    distribution.

    Args:
        results: Mapping from display name to experiment result. Each result
            must contain exactly one memorized string.

    Raises:
        ValueError: If a result does not contain exactly one string.
    """
    for result_name, result in results.items():
        display(md(f"### {result_name}"))

        strings = result.value.strings
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(strings) != 1:
            raise ValueError(
                f"Expected exactly one memorized string for {result_name!r}, "
                f"got {len(strings)}"
            )
        string = strings[0]
        print(f"Memorized string: {string}")

        memorization_log = result.value.memorization_log
        print("\n0-1 distribution:")
        plot_sequence_prob_distribution(
            memorization_log,
            string,
            log_scale=False,
        ).show()
        print("Log-scale distribution:")
        plot_sequence_prob_distribution(
            memorization_log,
            string,
            log_scale=True,
        ).show()

        # TODO: put time on the y-axis and change the color scheme.
        plot_sequence_error_distribution(memorization_log, string).show()
