from typing import Any, Iterable

import pandas as pd
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display
from plotly.subplots import make_subplots

from lib_dl.analysis.aggregate import aggregate_mean_std
from lib_dl.analysis.experiment import ExperimentResult, load_results
from lib_dl.analysis.publish.notebook import publish_notebook

from .config import PREFIX_LENGTH_ARGS, ConfigArgs
from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as CSExperimentResult


# Concrete result type for this experiment: the generic ExperimentResult
# parameterised by this experiment's config class and its own result payload
# type (imported here under the alias CSExperimentResult).
Result = ExperimentResult[ExperimentConfig, CSExperimentResult]


def load(
    config_name: str,
    seed_ids: list[int],
) -> list[Result]:
    """Load the stored results of this experiment for one configuration.

    Thin wrapper around ``load_results`` that pins the experiment name to
    ``EXP_NAME`` and the result type to ``Result``.

    Args:
        config_name: Name of the configuration whose results to load.
        seed_ids: Seed ids to load; passed through to ``load_results``
            unchanged.

    Returns:
        Whatever ``load_results`` returns for these arguments (presumably one
        ``Result`` per requested seed; confirm against ``load_results``).
    """
    loaded: list[Result] = load_results(EXP_NAME, config_name, seed_ids, Result)
    return loaded


def publish(
    notebook: str = "eval",
) -> None:
    """Publish one of this experiment's notebooks to an HTML file.

    Delegates to ``publish_notebook`` with the notebook's ``.ipynb`` path and
    its output ``.html`` path.

    Args:
        notebook: Which notebook to publish. Only ``"eval"`` is supported. The
            previous default (``"pythia"``) was not a handled value, so calling
            ``publish()`` without arguments always raised. The default now
            points at the only supported notebook.

    Raises:
        ValueError: If ``notebook`` is not a supported notebook name.
    """
    if notebook != "eval":
        raise ValueError(f"Unknown notebook: {notebook}")

    # Use a random postfix to make it harder to guess the file name
    output_path = f"experiments/{EXP_NAME}/eval_t93jdtg83.html"
    notebook_path = f"./experiments/{EXP_NAME}/notebooks/eval.ipynb"
    publish_notebook(
        notebook_path,
        output_path,
    )


def show_constrained_results(
    source_args: dict[str, ConfigArgs],
    constraints: dict[str, Any],
) -> None:
    """Plot per-context-type accuracy and KLD for configs matching constraints.

    A config is kept only if ``getattr(config_args, name) == value`` for every
    ``name, value`` pair in ``constraints``. For each kept config, the results
    for seed 0 are loaded. Two figures are then shown: the ``"correct"`` column
    as "Accuracy" and the ``"kld"`` column as "KLD from full prefix".

    Args:
        source_args: Mapping from config name to its ``ConfigArgs``.
        constraints: Mapping from ``ConfigArgs`` attribute name to the value
            that attribute must equal.

    Raises:
        ValueError: If no config in ``source_args`` satisfies ``constraints``.
    """
    # Filter to get the appropriate configurations to show
    matching_config_names = [
        config_name
        for config_name, config_args in source_args.items()
        if all(
            getattr(config_args, arg_name) == arg_value
            for arg_name, arg_value in constraints.items()
        )
    ]
    if not matching_config_names:
        # Fail with an actionable message instead of pandas' opaque
        # "No objects to concatenate" raised later while plotting.
        raise ValueError(
            f"No configuration in source_args matches constraints {constraints!r}"
        )

    # Load everything up front, then plot. Only seed 0 is requested, and the
    # first returned result is used for each config.
    first_seed_results = [
        load(config_name, [0])[0] for config_name in matching_config_names
    ]

    for col_name, metric_name in (
        ("correct", "Accuracy"),
        ("kld", "KLD from full prefix"),
    ):
        plot_context_results(first_seed_results, col_name, metric_name).show()


def plot_context_results(
    config_results: Iterable[Result],
    col_name: str,
    metric_name: str,
) -> go.Figure:
    """Plot the mean of one metric column against prefix length, per context type.

    Each result's ``value.context_results`` frame is stacked under the key
    ``config.context_probing.prefix_length``. The frames are then grouped by
    the ``context_type`` index level, and ``col_name`` is averaged per prefix
    length within each group.

    Args:
        config_results: Results to combine; consumed once.
        col_name: Column of ``context_results`` to average.
        metric_name: Label used for the y axis and the figure title.

    Returns:
        A figure with one lines+markers trace per context type. The x axis is
        categorical.

    Raises:
        ValueError: If two results share the same prefix length (they would
            otherwise silently overwrite each other), or if ``config_results``
            is empty (raised by ``pd.concat``).
    """
    frames_by_prefix_length: dict[Any, pd.DataFrame] = {}
    for config_res in config_results:
        prefix_length = config_res.config.context_probing.prefix_length
        if prefix_length in frames_by_prefix_length:
            raise ValueError(
                f"Multiple results share prefix_length={prefix_length!r}; "
                "only one result per prefix length can be plotted."
            )
        frames_by_prefix_length[prefix_length] = config_res.value.context_results

    # NOTE(review): the names below assume each context_results frame is
    # indexed by (context_type, sequence) — confirm against the producer.
    combined_config_results = pd.concat(
        frames_by_prefix_length,
        names=["prefix_length", "context_type", "sequence"],
    )

    fig = go.Figure()
    for context_type, context_results in combined_config_results.groupby(
        level="context_type"
    ):
        # Select the metric column before averaging. Under pandas >= 2.0,
        # unrelated non-numeric columns would otherwise make .mean() raise.
        mean_by_prefix_length = (
            context_results.groupby(level="prefix_length")[col_name]
            .mean()
            .reset_index()
        )
        fig.add_trace(
            go.Scatter(
                x=mean_by_prefix_length["prefix_length"],
                y=mean_by_prefix_length[col_name],
                name=context_type,
                mode="lines+markers",
            ),
        )
    fig.update_layout(
        xaxis_title="Prefix length",
        yaxis_title=metric_name,
        legend_title="Context type",
        title=metric_name,
        height=500,
    )
    # Categorical x axis: prefix lengths are evenly spaced regardless of value.
    fig.update_xaxes(type="category")
    return fig
