from pathlib import Path
from typing import Any, Callable, Iterable, cast

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from lib_dl.analysis.experiment import ExperimentResult, load_results
from lib_dl.analysis.publish.notebook import publish_notebook
from utils.results.training_eval import (
    plot_sequence_error_distribution,
    plot_sequence_token_distribution,
    plot_training_accuracy,
    plot_training_loss,
)

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as MTExperimentResult


# Concrete result type for this experiment: lib_dl's generic ExperimentResult
# parameterized with this experiment's config type and its own per-experiment
# result class (imported as MTExperimentResult to avoid clashing with the
# lib_dl ExperimentResult name).
Result = ExperimentResult[ExperimentConfig, MTExperimentResult]


def load(
    config_name: str,
    seed_ids: list[int],
) -> list[Result]:
    """Load this experiment's stored results for one configuration.

    Thin wrapper over lib_dl's ``load_results`` that pins the experiment name
    to ``EXP_NAME`` and the result type to ``Result``.

    Args:
        config_name: Name of the configuration whose results should be loaded.
        seed_ids: Seed ids to load. Presumably one result comes back per seed
            id; confirm against ``load_results``.

    Returns:
        The results produced by ``load_results``.
    """
    experiment, result_type = EXP_NAME, Result
    return load_results(experiment, config_name, seed_ids, result_type)


def publish(
    notebook: str,
) -> None:
    """Publish one of this experiment's notebooks to an HTML file.

    The notebook is read from
    ``./experiments/<EXP_NAME>/notebooks/<notebook>.ipynb``. It is handed to
    ``publish_notebook`` together with a per-notebook ``.html`` output path
    under ``experiments/<EXP_NAME>/``.

    Args:
        notebook: One of ``"model_type"``, ``"alphabet_size"``,
            ``"num_tokens"`` or ``"num_partitions"``.

    Raises:
        ValueError: If ``notebook`` is not one of the names above.
    """
    # Each output file name carries a fixed (hard-coded, not per-call
    # generated) random-looking suffix so the published page cannot be
    # found just by guessing the notebook name.
    published_paths = {
        "model_type": f"experiments/{EXP_NAME}/model_type_ckls4329.html",
        "alphabet_size": f"experiments/{EXP_NAME}/alphabet_size_3jd031hi.html",
        "num_tokens": f"experiments/{EXP_NAME}/num_tokens_45902f2h.html",
        "num_partitions": f"experiments/{EXP_NAME}/num_partitions_23gh214.html",
    }
    if notebook not in published_paths:
        raise ValueError(f"Unknown notebook: {notebook}")

    source_path = f"./experiments/{EXP_NAME}/notebooks/{notebook}.ipynb"
    publish_notebook(source_path, published_paths[notebook])


def set_base_storage_dir(
    results: dict[str, list[Result]], base_storage_dir: Path
) -> dict[str, list[Result]]:
    """Set ``value.base_storage_dir`` on every result, in place.

    The input mapping is mutated directly. The same object is also returned,
    so the call can be chained or used inline.

    Args:
        results: Mapping from configuration name to that configuration's
            results.
        base_storage_dir: Directory assigned to each result's
            ``value.base_storage_dir``.

    Returns:
        ``results`` itself (not a copy), with every result updated.
    """
    for config_results in results.values():
        for result in config_results:
            result.value.base_storage_dir = base_storage_dir
    return results


def show_results_overview(
    results: dict[str, list[Result]],
    title: str,
) -> dict[str, go.Figure]:
    """Plot, display and return aggregate training curves for all configs.

    For each configuration, the chosen log is gathered from every one of its
    results. The logs are passed to a shared plotting helper, and each
    resulting figure is shown immediately. The figures are also returned.

    Args:
        results: Mapping from configuration name to that configuration's
            list of results (presumably one per seed, as returned by ``load``).
        title: Title forwarded to each plotting helper.

    Returns:
        Figures keyed by ``"training_loss"`` (built from ``loss_log``) and
        ``"training_accuracy"`` (built from ``memorization_log``).
    """
    # (figure key, attribute read from each result's ``value``, plotter).
    # The order is significant: the loss figure is built and shown first.
    panels = (
        ("training_loss", "loss_log", plot_training_loss),
        ("training_accuracy", "memorization_log", plot_training_accuracy),
    )

    figures = {}
    for figure_key, log_attr, plot_fn in panels:
        logs_by_config = {
            name: [getattr(run.value, log_attr) for run in runs]
            for name, runs in results.items()
        }
        fig = plot_fn(logs_by_config, title)
        fig.show()
        figures[figure_key] = fig

    # TODO: add a correctness plot.
    return figures


def show_individual_results(
    results: dict[str, Result],
) -> None:
    """Display per-run memorization diagnostics in the notebook.

    For every named result, this emits a Markdown heading and prints the
    memorized string. It then shows three figures built from the run's
    ``memorization_log``:
    - the token distribution on a linear (0-1) scale;
    - the same distribution on a log scale;
    - the error distribution.

    Args:
        results: Mapping from a display name to a single run's result. Each
            run's ``value.strings`` must contain exactly one string.

    Raises:
        ValueError: If a result's ``value.strings`` does not contain exactly
            one string.
    """
    for result_name, result in results.items():
        display(md(f"### {result_name}"))

        strings = result.value.strings
        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with ``-O``.
        if len(strings) != 1:
            raise ValueError(
                f"Expected exactly one memorized string for {result_name!r}, "
                f"got {len(strings)}"
            )
        string = strings[0]
        print(f"Memorized string: {string}")

        memorization_log = result.value.memorization_log
        print("\n0-1 distribution:")
        plot_sequence_token_distribution(
            memorization_log,
            string,
            log_scale=False,
        ).show()
        print("Log-scale distribution:")
        plot_sequence_token_distribution(
            memorization_log,
            string,
            log_scale=True,
        ).show()

        # TODO: put time on the y-axis and change the color scheme.
        plot_sequence_error_distribution(memorization_log, string).show()
