from typing import Any, Callable

from IPython.display import Markdown as md
from IPython.display import display

from lib_llm.eval.memorization.dynamics.plots import plot_training_loss
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook

from .config import ConfigArgs
from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as MHRExperimentResult


# Typed result record for this experiment: lib_project's generic
# `ExperimentResult`, specialized to this experiment's config type and its own
# result type. (The local `.experiment.ExperimentResult` is imported as
# `MHRExperimentResult` so it does not shadow the generic one.)
Result = ExperimentResult[
    ExperimentConfig,
    MHRExperimentResult,
]


def load(
    config_name: str,
    seed_ids: list[int],
) -> list[Result]:
    """Fetch the results of one configuration for the given seeds.

    Thin wrapper around `load_results` that supplies this experiment's name
    (`EXP_NAME`) and its config/result types.

    Args:
        config_name: Name of the configuration whose results are requested.
        seed_ids: Seed indices to request.

    Returns:
        The value returned by `load_results` for these arguments (presumably
        one entry per seed — confirm against `load_results`).
    """
    config_type = ExperimentConfig
    result_type = MHRExperimentResult
    loaded_results = load_results(
        EXP_NAME, config_name, seed_ids, config_type, result_type
    )
    return loaded_results


def publish(
    notebook: str,
) -> None:
    """Publish one of this experiment's notebooks via `publish_notebook`.

    The source is ``./experiments/<EXP_NAME>/notebooks/<notebook>.ipynb`` and
    the output is an ``.html`` file under ``experiments/<EXP_NAME>/``. The
    output file name carries an arbitrary postfix so it is hard to guess.

    Args:
        notebook: Notebook name without extension. Must be one of
            "learning_rate", "untrained_learning_rate" or "data_params".

    Raises:
        ValueError: If ``notebook`` is not one of the known names.
    """
    # Output file names. The random-looking postfixes only serve to make the
    # published files harder to guess.
    html_names = {
        "learning_rate": "learning_rate_wio43rif2.html",
        "untrained_learning_rate": "untrained_learning_rate_4810d.html",
        # Currently disabled:
        # "training_hyperparams": "training_hyperparams_398dg21ev8.html",
        "data_params": "data_params_bh3982dk.html",
    }
    if notebook not in html_names:
        raise ValueError(f"Unknown notebook: {notebook}")

    source_path = f"./experiments/{EXP_NAME}/notebooks/{notebook}.ipynb"
    target_path = f"experiments/{EXP_NAME}/{html_names[notebook]}"
    publish_notebook(source_path, target_path)


def show_results(
    results: dict[str, list[Result]],
    title: str,
) -> None:
    """Plot the memorization logs of several configurations and show the plot.

    Args:
        results: Maps each configuration name to its list of results. Each
            result's ``value.memorization_log`` is passed to the plot.
        title: Passed as the second argument to `plot_training_loss`.
    """
    logs_by_config = {}
    for name, runs in results.items():
        logs_by_config[name] = [run.value.memorization_log for run in runs]

    figure = plot_training_loss(logs_by_config, title)
    figure.show()


# def show_constrained_results(
#     source_args: dict[str, ConfigArgs],
#     legend_title: str,
#     constraints: dict[str, Any],
#     config_descriptor: Callable[[ExperimentConfig], str] = lambda c: "",
#     show_distributions: bool = True,
# ) -> None:
#     # Filter to get the appropriate configurations to show
#     model_config_args = {
#         config_name: config_args
#         for config_name, config_args in source_args.items()
#         if all(
#             getattr(config_args, arg_name) == arg_value
#             for arg_name, arg_value in constraints.items()
#         )
#     }
#     results = {
#         config_name: load(config_name, list(range(1)))
#         for config_name in model_config_args.keys()
#     }
#     # sum_exec_time = np.sum([
#     #     res.running_time
#     #     for arg_results in results.values()
#     #     for res in arg_results
#     # ])
#     # print(f"\nComputing results took {sum_exec_time:.1f} seconds")

#     loss_results = {
#         config_name: [
#             config_res.value.memorization_log for config_res in config_results
#         ]
#         for config_name, config_results in results.items()
#     }
#     plot_training_loss(
#         loss_results,
#         legend_title,
#     ).show()

#     # best_loss_configs = compute_best_training_loss_configs(results)
#     # print(
#     #     "Configuration with the best training loss at any step:",
#     #     best_loss_configs["any_step"],
#     # )
#     # print(
#     #     "Configuration with the best training loss after the last epoch:",
#     #     best_loss_configs["last_step"],
#     # )

#     if show_distributions:
#         # Plot the character distribution of the sequence
#         sequence_idx = 0
#         for config_name, config_results in results.items():
#             config_res = config_results[0]
#             display(md(f"### {config_descriptor(config_res.config)}"))
#             distributions = config_res.value.token_distributions
#             sequence: str = distributions.index.get_level_values(
#                 "string"
#             ).unique()[sequence_idx]
#             print(f"String to memorize: {sequence}")
#             error_fig = plot_sequence_error_distribution(
#                 distributions,
#                 sequence,
#             )
#             error_fig.show()
#             dist_fig = plot_sequence_token_distribution(
#                 distributions,
#                 sequence,
#                 log_scale=True,
#             )
#             dist_fig.show()


# def compute_best_training_loss_configs(
#     results: dict[str, list[Result]]
# ) -> dict[str, tuple[str, float]]:
#     """Compute which configuration achieved the lowest loss at any step
#     as well as at the end of training (at the last step)
#     """
#     mean_results, _ = compute_training_loss_mean_std(results, ["step"])

#     best_config: dict[str, tuple[str, float]] = {
#         "any_step": ("", float("inf")),
#         "last_step": ("", float("inf")),
#     }
#     for config_name, res in mean_results.items():
#         best_loss = res["loss"].min()
#         best_loss_last_step = res["loss"].iloc[-1]
#         if best_loss < best_config["any_step"][1]:
#             # best_step_idx = res["loss"].argmin()
#             # best_step = res["step"][best_step_idx]
#             best_config["any_step"] = (config_name, best_loss)
#         if best_loss_last_step < best_config["last_step"][1]:
#             best_config["last_step"] = (config_name, best_loss_last_step)
#     return best_config
