from typing import Any, Callable

from IPython.display import Markdown as md
from IPython.display import display

from lib_dl.analysis.experiment import ExperimentResult, load_results
from lib_dl.analysis.publish.notebook import publish_notebook
from utils.results.training_eval import (
    compute_training_loss_mean_std,
    plot_sequence_error_distribution,
    plot_sequence_token_distribution,
    plot_training_loss,
)

from .config import (
    DATA_TYPE_ARGS,
    LR_SCHEDULE_VARIATION_ARGS,
    WARMUP_STEPS_VARIATION_ARGS,
    ConfigArgs,
)
from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as MHRExperimentResult


# Result type for this experiment: the generic ``ExperimentResult``
# specialised with this experiment's config type (``ExperimentConfig``) and
# value type (``MHRExperimentResult``). ``load`` returns a list of these
# (presumably one per requested seed — confirm against ``load_results``).
Result = ExperimentResult[ExperimentConfig, MHRExperimentResult]


def load(
    config_name: str,
    seed_ids: list[int],
) -> list[Result]:
    """Return the results of ``config_name`` for each seed in ``seed_ids``.

    Thin wrapper that binds ``load_results`` to this experiment
    (``EXP_NAME``) and to its result type (``Result``).
    """
    return load_results(EXP_NAME, config_name, seed_ids, Result)


# Published HTML file name for each notebook that can be exported. The names
# carry a random postfix to make them harder to guess.
_PUBLISHED_NOTEBOOK_FILES: dict[str, str] = {
    "learning_rate": "learning_rate_wio43rif2.html",
    "data_params": "data_params_bh3982dk.html",
    # Currently disabled:
    # "training_hyperparams": "training_hyperparams_398dg21ev8.html",
}


def publish(
    notebook: str = "pythia",
) -> None:
    """Export one of this experiment's notebooks with ``publish_notebook``.

    The source is ``./experiments/<EXP_NAME>/notebooks/<notebook>.ipynb`` and
    the output is ``experiments/<EXP_NAME>/<registered file name>``.

    Args:
        notebook: Notebook name (without ``.ipynb``); must be a key of
            ``_PUBLISHED_NOTEBOOK_FILES``.

    Raises:
        ValueError: If ``notebook`` has no registered output file.
    """
    # NOTE(review): the default "pythia" is not a registered notebook, so a
    # bare ``publish()`` always raises ValueError — confirm the intended default.
    try:
        output_name = _PUBLISHED_NOTEBOOK_FILES[notebook]
    except KeyError:
        known = ", ".join(sorted(_PUBLISHED_NOTEBOOK_FILES))
        raise ValueError(
            f"Unknown notebook: {notebook} (expected one of: {known})"
        ) from None

    notebook_path = f"./experiments/{EXP_NAME}/notebooks/{notebook}.ipynb"
    output_path = f"experiments/{EXP_NAME}/{output_name}"
    publish_notebook(notebook_path, output_path)


def show_training_param_results(
    model_id: str,
    model_name: str,
) -> None:
    """Render the training-hyperparameter comparisons for one model.

    Displays a ``model_name`` heading, then shows results only for
    configurations whose ``model_id`` attribute equals ``model_id``:

    * one learning-rate-schedule comparison (``LR_SCHEDULE_VARIATION_ARGS``)
      for each learning rate in ``[1e-4, 1e-5, 1e-6]``;
    * one warmup-steps comparison (``WARMUP_STEPS_VARIATION_ARGS``) for each
      learning rate in ``[1e-5, 1e-6]``.
    """
    display(md(f"### {model_name}"))

    display(md("#### Learning rate schedules"))
    for learning_rate in [1e-4, 1e-5, 1e-6]:
        display(md(f"\n**Schedules for learning rate {learning_rate}:**"))
        show_constrained_results(
            LR_SCHEDULE_VARIATION_ARGS,
            "Learning rate schedule",
            constraints={"model_id": model_id, "learning_rate": learning_rate},
        )

    display(md("#### Warmup steps"))
    for learning_rate in [1e-5, 1e-6]:
        display(md(f"\n**Warmup steps for learning rate {learning_rate}:**"))
        # Do not pass ``progress_unit`` here: ``show_constrained_results`` has
        # no such parameter, and passing it raises TypeError. If warmup plots
        # need a step-based x-axis, add that support there first.
        show_constrained_results(
            WARMUP_STEPS_VARIATION_ARGS,
            "Warmup steps",
            constraints={"model_id": model_id, "learning_rate": learning_rate},
        )


def show_data_param_results(
    model_id: str,
    model_name: str,
) -> None:
    """Render the data-parameter comparison for one model.

    Shows a ``model_name`` heading followed by the results of every
    ``DATA_TYPE_ARGS`` configuration whose ``model_id`` attribute equals
    ``model_id``.
    """
    section_title = f"### {model_name}"
    display(md(section_title))

    # A "Sequence lengths" section (based on SEQUENCE_LENGTH_ARGS) used to be
    # rendered here; it is currently disabled.

    model_filter = {"model_id": model_id}
    display(md("#### Data type and tokenization:"))
    show_constrained_results(
        DATA_TYPE_ARGS,
        "Data type and tokenization",
        constraints=model_filter,
    )


def show_constrained_results(
    source_args: dict[str, ConfigArgs],
    legend_title: str,
    constraints: dict[str, Any],
    config_descriptor: Callable[[ExperimentConfig], str] = lambda c: "",
    show_distributions: bool = True,
    seed_ids: tuple[int, ...] = (0,),
    sequence_idx: int = 0,
) -> None:
    """Plot training loss (and optionally sequence distributions) for the
    configurations in ``source_args`` that satisfy ``constraints``.

    Args:
        source_args: Mapping from config name to its ``ConfigArgs``.
        legend_title: Legend title of the training-loss plot.
        constraints: Attribute name -> required value. A configuration is
            kept only if ``getattr(config_args, name) == value`` holds for
            every entry.
        config_descriptor: Builds the heading shown above each
            configuration's distribution plots. If it returns an empty
            string (as the default does), the config name is used instead.
        show_distributions: Whether to plot, per configuration, the error
            and token distributions of one memorized sequence.
        seed_ids: Seeds whose results are loaded for each configuration.
            Only the first loaded result is used for the distribution plots.
        sequence_idx: Position, among the unique values of the ``"string"``
            index level of the token distributions, of the sequence to plot.
    """
    selected_names = [
        config_name
        for config_name, config_args in source_args.items()
        if _matches_constraints(config_args, constraints)
    ]
    if not selected_names:
        # Nothing to plot; avoid handing an empty dict to the plotting helper.
        print(f"No configuration matches the constraints {constraints}")
        return

    results = {
        config_name: load(config_name, list(seed_ids))
        for config_name in selected_names
    }

    loss_results = {
        config_name: [res.value.training_history for res in config_results]
        for config_name, config_results in results.items()
    }
    plot_training_loss(loss_results, legend_title).show()

    if show_distributions:
        for config_name, config_results in results.items():
            _show_sequence_distributions(
                config_name, config_results, config_descriptor, sequence_idx
            )


def _matches_constraints(
    config_args: ConfigArgs, constraints: dict[str, Any]
) -> bool:
    """Return True if every constrained attribute has the required value."""
    return all(
        getattr(config_args, name) == value
        for name, value in constraints.items()
    )


def _show_sequence_distributions(
    config_name: str,
    config_results: list[Result],
    config_descriptor: Callable[[ExperimentConfig], str],
    sequence_idx: int,
) -> None:
    """Plot the error and token distributions of one sequence for one config."""
    if not config_results:
        # Previously this crashed with IndexError; skip the config instead.
        print(f"No results found for configuration {config_name}")
        return

    config_res = config_results[0]
    # Fall back to the config name so the plots are never left unlabeled.
    heading = config_descriptor(config_res.config) or config_name
    display(md(f"### {heading}"))

    distributions = config_res.value.token_distributions
    sequence: str = distributions.index.get_level_values(
        "string"
    ).unique()[sequence_idx]
    print(f"String to memorize: {sequence}")

    plot_sequence_error_distribution(distributions, sequence).show()
    plot_sequence_token_distribution(
        distributions,
        sequence,
        log_scale=True,
    ).show()


def compute_best_training_loss_configs(
    results: dict[str, list[Result]]
) -> dict[str, tuple[str, float]]:
    """Find the configurations with the lowest training loss.

    Uses the per-config ``"loss"`` series returned (as the first element) by
    ``compute_training_loss_mean_std(results, ["step"])``.

    Returns:
        ``{"any_step": (name, loss), "last_step": (name, loss)}``, where
        ``"any_step"`` holds the config whose series has the smallest
        minimum, and ``"last_step"`` holds the config whose final value is
        smallest. On ties the first config encountered wins. With no
        configs, both entries are ``("", inf)``.
    """
    mean_results, _ = compute_training_loss_mean_std(results, ["step"])

    best_any_step: tuple[str, float] = ("", float("inf"))
    best_last_step: tuple[str, float] = ("", float("inf"))
    for name, mean_df in mean_results.items():
        losses = mean_df["loss"]
        lowest_loss = losses.min()
        final_loss = losses.iloc[-1]
        if lowest_loss < best_any_step[1]:
            best_any_step = (name, lowest_loss)
        if final_loss < best_last_step[1]:
            best_last_step = (name, final_loss)

    return {"any_step": best_any_step, "last_step": best_last_step}
