import os
from typing import Any, Callable

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display
from plotly.subplots import make_subplots

from lib_dl.analysis.aggregate import aggregate_mean_std
from lib_dl.analysis.experiment import ExperimentResult, load_results
from lib_dl.analysis.publish.notebook import publish_notebook

from .config import (
    DATA_TYPE_ARGS,
    LR_SCHEDULE_VARIATION_ARGS,
    NUM_SEQUENCES_VARIATION_ARGS,
    SEQUENCE_LENGTH_ARGS,
    WARMUP_STEPS_VARIATION_ARGS,
    ConfigArgs,
)
from .experiment import EXP_NAME, SDDDExperimentConfig, SDDDExperimentResult


# Concrete result type used throughout this module: the generic
# ExperimentResult specialized with this experiment's config type and
# result-value type.
SDDDResult = ExperimentResult[SDDDExperimentConfig, SDDDExperimentResult]


def load(
    config_name: str,
    seed_ids: list[int],
) -> list[SDDDResult]:
    """Load the stored results of one configuration of this experiment.

    Thin wrapper around ``load_results`` that pins the experiment name
    (``EXP_NAME``) and the expected result type (``SDDDResult``).

    Args:
        config_name: Name of the configuration whose results are loaded.
        seed_ids: Seed ids forwarded to ``load_results``.

    Returns:
        The value returned by ``load_results`` for these arguments.
    """
    loaded_results = load_results(EXP_NAME, config_name, seed_ids, SDDDResult)
    return loaded_results


def publish(
    notebook: str = "pythia",
) -> None:
    """Publish one of the experiment's notebooks as an HTML file.

    Args:
        notebook: Which notebook to publish; one of ``"pythia"``,
            ``"training_hyperparams"`` or ``"data_params"``.

    Raises:
        ValueError: If ``notebook`` is not one of the known names.
    """
    known_notebooks = ("pythia", "training_hyperparams", "data_params")
    if notebook not in known_notebooks:
        raise ValueError(f"Unknown notebook: {notebook}")

    # (notebook source, published HTML) for every known notebook. The HTML
    # names carry a random postfix to make the file name harder to guess.
    locations = {
        "pythia": (
            f"./experiments/{EXP_NAME}/notebooks/pythia.ipynb",
            f"experiments/{EXP_NAME}/pythia_dkjwl4209f.html",
        ),
        "training_hyperparams": (
            f"./experiments/{EXP_NAME}/notebooks/training_hyperparams.ipynb",
            f"experiments/{EXP_NAME}/training_hyperparams_398dg21ev8.html",
        ),
        "data_params": (
            f"./experiments/{EXP_NAME}/notebooks/data_params.ipynb",
            f"experiments/{EXP_NAME}/data_params_bh3982dk.html",
        ),
    }
    notebook_path, output_path = locations[notebook]
    publish_notebook(notebook_path, output_path)


def show_training_param_results(
    model_id: str,
    model_name: str,
) -> None:
    """Display the training-hyperparameter results of one model.

    Emits a heading with ``model_name`` followed by three sections, each
    rendered through ``show_constrained_results``: results grouped by the
    number of training sequences, by learning rate schedule, and by the
    number of warmup steps.

    Args:
        model_id: Value the ``model_id`` attribute of a configuration must
            have for it to be shown.
        model_name: Human-readable model name used in the heading.
    """
    display(md(f"### {model_name}"))

    display(md("#### Grouping by the number of training sequences:"))
    for sequence_count in (1, 16, 256):
        display(md(f"\n**{sequence_count} sequences:**"))
        selection = {"model_id": model_id, "num_sequences": sequence_count}
        show_constrained_results(
            NUM_SEQUENCES_VARIATION_ARGS,
            "Learning rate",
            constraints=selection,
        )

    # NOTE: a "Grouping by learning rate" section used to follow here. It is
    # disabled, and its old call no longer matched the signature of
    # show_constrained_results (it passed model_id positionally and a tuple
    # of constraints).

    display(md("#### Learning rate schedules"))
    for schedule_lr in (1e-4, 1e-5, 1e-6):
        display(md(f"\n**Schedules for learning rate {schedule_lr}:**"))
        selection = {"model_id": model_id, "learning_rate": schedule_lr}
        show_constrained_results(
            LR_SCHEDULE_VARIATION_ARGS,
            "Learning rate schedule",
            constraints=selection,
        )

    display(md("#### Warmup steps"))
    for warmup_lr in (1e-5, 1e-6):
        display(md(f"\n**Warmup steps for learning rate {warmup_lr}:**"))
        selection = {"model_id": model_id, "learning_rate": warmup_lr}
        # Warmup results are plotted per step rather than per epoch.
        show_constrained_results(
            WARMUP_STEPS_VARIATION_ARGS,
            "Warmup steps",
            constraints=selection,
            progress_unit="step",
        )


def show_data_param_results(
    model_id: str,
    model_name: str,
) -> None:
    """Display the data-parameter results of one model.

    Emits a heading with ``model_name`` and then the results for the
    data type / tokenization variations of that model.

    Args:
        model_id: Value the ``model_id`` attribute of a configuration must
            have for it to be shown.
        model_name: Human-readable model name used in the heading.
    """
    display(md(f"### {model_name}"))

    # NOTE: a "Sequence lengths" section based on SEQUENCE_LENGTH_ARGS is
    # currently disabled.

    display(md("#### Data type and tokenization:"))
    model_only = {"model_id": model_id}
    show_constrained_results(
        DATA_TYPE_ARGS,
        "Data type and tokenization",
        constraints=model_only,
    )


def show_constrained_results(
    source_args: dict[str, ConfigArgs],
    legend_title: str,
    constraints: dict[str, Any],
    config_descriptor: Callable[[SDDDExperimentConfig], str] = lambda c: "",
    progress_unit: str = "epoch",
) -> None:
    """Show the results of all configurations that satisfy ``constraints``.

    For the selected configurations this shows one combined training loss
    plot and then, per configuration, the correctness and the token
    distribution plots of the first sequence of the first loaded result.

    Args:
        source_args: Candidate configurations, keyed by configuration name.
        legend_title: Legend title of the training loss plot.
        constraints: Attribute name -> required value. A configuration is
            selected when every named attribute equals its required value.
        config_descriptor: Produces the text of the heading displayed
            before each configuration's plots; the default yields an empty
            heading text.
        progress_unit: Training-history column used as the x-axis of the
            loss plot.
    """

    def satisfies_constraints(candidate: ConfigArgs) -> bool:
        for attr_name, required_value in constraints.items():
            if not (getattr(candidate, attr_name) == required_value):
                return False
        return True

    selected_names = [
        name
        for name, candidate in source_args.items()
        if satisfies_constraints(candidate)
    ]
    # Only the result of seed 0 is loaded for every configuration.
    results = {name: load(name, [0]) for name in selected_names}

    # NOTE: reporting the summed running time of the loaded results is
    # currently disabled.

    loss_fig = plot_training_loss(
        results, legend_title, progress_unit=progress_unit
    )
    loss_fig.show()

    # NOTE: printing the configurations with the best training loss (at any
    # step and after the last epoch, see compute_best_training_loss_configs)
    # is currently disabled.

    # Per configuration, plot how the predictions for one sequence evolve.
    shown_sequence_idx = 0
    for seed_results in results.values():
        first_result = seed_results[0]
        display(md(f"### {config_descriptor(first_result.config)}"))
        token_dists = first_result.value.token_distributions
        known_sequences = token_dists.index.get_level_values(
            "sequence"
        ).unique()
        sequence: str = known_sequences[shown_sequence_idx]
        print(f"String to memorize: {sequence}")
        plot_sequence_error_distribution(token_dists, sequence).show()
        plot_sequence_token_distribution(token_dists, sequence).show()


def plot_training_loss(
    results: dict[str, list[SDDDResult]],
    legend_title: str,
    progress_unit: str,
) -> go.Figure:
    """Plot the mean training loss of every configuration with error bars.

    Mean and standard deviation come from
    ``compute_training_loss_mean_std``; the standard deviation of the loss
    is drawn as the error bar of each point.

    Args:
        results: Loaded results, keyed by configuration name.
        legend_title: Title shown above the legend.
        progress_unit: Training-history column used as the x-axis (and as
            the x-axis title).

    Returns:
        A figure with one scatter trace per configuration.
    """
    mean_results, std_results = compute_training_loss_mean_std(
        results, [progress_unit]
    )

    # Shorten the legend labels by dropping the leading characters shared
    # by all configuration names. os.path.commonprefix compares character
    # by character, so for a single configuration (or when one name is a
    # prefix of another) the prefix equals a whole name and stripping it
    # would leave an empty label; keep the full names in that case.
    config_names = list(mean_results)
    name_prefix = os.path.commonprefix(config_names)
    if any(name == name_prefix for name in config_names):
        name_prefix = ""
    prefix_len = len(name_prefix)

    fig = go.Figure()
    for config_name, mean_res in mean_results.items():
        std_res = std_results[config_name]
        fig.add_trace(
            go.Scatter(
                x=mean_res[progress_unit],
                y=mean_res["loss"],
                error_y=dict(
                    type="data",
                    array=std_res["loss"],
                    visible=True,
                ),
                name=config_name[prefix_len:],
            )
        )
    fig.update_layout(
        title="Training loss",
        xaxis_title=progress_unit,
        yaxis_title="Loss",
        legend_title_text=legend_title,
        width=800,
        height=600,
    )
    return fig


def compute_best_training_loss_configs(
    results: dict[str, list[SDDDResult]]
) -> dict[str, tuple[str, float]]:
    """Find the configurations with the lowest mean training loss.

    Two winners are determined from the mean training histories: the
    configuration with the lowest loss at any step, and the one with the
    lowest loss at the last step.

    Args:
        results: Loaded results, keyed by configuration name.

    Returns:
        A dict with the keys ``"any_step"`` and ``"last_step"``, each
        mapping to a ``(config_name, loss)`` pair. A pair stays at
        ``("", inf)`` when no configuration has a loss below infinity,
        e.g. when ``results`` is empty.
    """
    mean_results, _ = compute_training_loss_mean_std(results, ["step"])

    best_any_step: tuple[str, float] = ("", float("inf"))
    best_last_step: tuple[str, float] = ("", float("inf"))
    for config_name, mean_history in mean_results.items():
        losses = mean_history["loss"]
        lowest_loss = losses.min()
        final_loss = losses.iloc[-1]
        # Strict comparisons: on ties the earlier configuration is kept.
        if lowest_loss < best_any_step[1]:
            best_any_step = (config_name, lowest_loss)
        if final_loss < best_last_step[1]:
            best_last_step = (config_name, final_loss)
    return {"any_step": best_any_step, "last_step": best_last_step}


def compute_training_loss_mean_std(
    results: dict[str, list[SDDDResult]],
    additional_columns: list[str],
) -> tuple[dict[str, pd.DataFrame], dict[str, pd.DataFrame]]:
    """Aggregate the training histories of every configuration.

    From each result's training history the ``"loss"`` column and
    ``additional_columns`` are selected, the row index is named
    ``"iter"``, and rows with missing values are dropped. The histories of
    one configuration are then combined with ``aggregate_mean_std``,
    preserving the ``"iter"`` level.

    Args:
        results: Loaded results, keyed by configuration name.
        additional_columns: Training-history columns to keep besides
            ``"loss"``.

    Returns:
        Two dicts keyed by configuration name: the first and the second
        element of the pair returned by ``aggregate_mean_std``, used by
        callers as the mean and the standard deviation respectively.
    """
    selected_columns = ["loss", *additional_columns]
    means: dict[str, pd.DataFrame] = {}
    stds: dict[str, pd.DataFrame] = {}
    for config_name, seed_results in results.items():
        histories = []
        for seed_result in seed_results:
            history = seed_result.value.training_history[selected_columns]
            history = history.rename_axis("iter", axis=0)
            histories.append(history.dropna(axis=0))
        config_mean, config_std = aggregate_mean_std(
            histories,
            levels_to_preserve=["iter"],
        )
        means[config_name] = config_mean
        stds[config_name] = config_std
    return means, stds


def plot_sequence_token_distribution(
    token_distribution: pd.DataFrame,
    sequence: str,
) -> go.Figure:
    """Plot the token distribution of one sequence, animated over epochs.

    Args:
        token_distribution: Frame whose index has a ``"sequence"`` level.
        sequence: Value of the ``"sequence"`` level selecting the rows to
            plot.

    Returns:
        The figure produced by ``plot_sequence_distribution`` for the
        selected rows.
    """
    sequence_labels = token_distribution.index.get_level_values("sequence")
    rows_of_sequence = token_distribution.loc[sequence_labels == sequence]
    return plot_sequence_distribution(rows_of_sequence, "distribution")


def plot_sequence_error_distribution(
    token_distribution: pd.DataFrame,
    sequence: str,
) -> go.Figure:
    """Plot, animated over epochs, which tokens of a sequence are correct.

    A row counts as correct (1) when the label of its highest-valued
    column equals the row's value in the ``"character"`` index level, and
    as incorrect (0) otherwise.

    Args:
        token_distribution: Frame whose index has the levels
            ``"sequence"`` and ``"character"``.
        sequence: Value of the ``"sequence"`` level selecting the rows to
            plot.

    Returns:
        The figure produced by ``plot_sequence_distribution`` for the
        0/1 correctness values.
    """
    belongs_to_sequence = (
        token_distribution.index.get_level_values("sequence") == sequence
    )
    rows = token_distribution.loc[belongs_to_sequence]

    # Column label with the highest value in each row.
    predicted = rows.idxmax(axis=1)
    expected = pd.Series(
        rows.index.get_level_values("character"),
        index=rows.index,
        name="target",
    )
    is_correct = expected == predicted
    correctness = is_correct.to_frame(name="correct").astype(int)
    return plot_sequence_distribution(correctness, "correctness")


def plot_sequence_distribution(
    distribution: pd.DataFrame,
    description: str,
) -> go.Figure:
    """Create a heatmap over a token distribution, animated over epochs.

    The rows of ``distribution`` are grouped by ``"epoch"``; every group
    becomes one animation frame in which the columns of ``distribution``
    are laid out along the y-axis. The figure initially shows epoch 0 or,
    when there is no epoch 0, the first epoch present.

    Args:
        distribution: Values to plot; must be groupable by ``"epoch"`` and
            have ``"epoch"`` as an index level.
        description: Text appended to the "Epoch <n>" figure titles.

    Returns:
        An animated figure with play/pause buttons and an epoch slider.

    Raises:
        ValueError: If ``distribution`` contains no epochs to plot.
    """
    # Use one color scale for all frames so that epochs are comparable.
    max_val = distribution.max().max()
    min_val = distribution.min().min()

    initial_epoch = None
    initial_heatmap = None
    frames = []
    for epoch, epoch_dist in distribution.groupby("epoch"):
        epoch_dist = epoch_dist.droplevel("epoch")
        heatmap = go.Heatmap(
            z=epoch_dist.T,
            y=epoch_dist.columns,
            zmin=min_val,
            zmax=max_val,
        )
        # Start from epoch 0 when it exists; otherwise fall back to the
        # first epoch so the figure never starts without a trace.
        if initial_heatmap is None or epoch == 0:
            initial_epoch, initial_heatmap = epoch, heatmap
        frames.append(
            go.Frame(
                data=[heatmap],
                name=str(epoch),
                layout=go.Layout(
                    title_text=f"Epoch {epoch} {description}",
                    xaxis_title="Token position",
                ),
            ),
        )
    if initial_heatmap is None:
        raise ValueError("distribution contains no epochs to plot")

    return go.Figure(
        data=[initial_heatmap],
        frames=frames,
        layout=go.Layout(
            title=f"Epoch {initial_epoch} {description}",
            xaxis_title="Token position",
            width=1000,
            height=500,
            updatemenus=[_animation_buttons()],
            sliders=[_epoch_slider(frames)],
        ),
    )


def _animation_buttons() -> dict[str, Any]:
    """Build the update menu holding the "Play" and "Pause" buttons."""
    play_button = {
        "args": [
            None,
            {
                "frame": {"duration": 500, "redraw": True},
                "fromcurrent": True,
            },
        ],
        "label": "Play",
        "method": "animate",
    }
    pause_button = {
        "args": [
            [None],
            {
                "frame": {"duration": 0, "redraw": True},
                "mode": "immediate",
                "transition": {"duration": 0},
            },
        ],
        "label": "Pause",
        "method": "animate",
    }
    return {
        "buttons": [play_button, pause_button],
        "direction": "left",
        "pad": {"r": 10, "t": 87},
        "type": "buttons",
        "x": 0.1,
        "xanchor": "right",
        "y": 0,
        "yanchor": "top",
    }


def _epoch_slider(frames: list[go.Frame]) -> dict[str, Any]:
    """Build a slider with one step per frame that jumps to that frame."""
    steps = [
        {
            "args": [
                [frame.name],
                {
                    "frame": {"duration": 300, "redraw": True},
                    "mode": "immediate",
                    "transition": {"duration": 0},
                },
            ],
            "label": str(frame.name),
            "method": "animate",
        }
        for frame in frames
    ]
    return {
        "steps": steps,
        "transition": {"duration": 0},
        "x": 0.1,
        "y": 0,
        "yanchor": "top",
        "xanchor": "left",
        "currentvalue": {
            "font": {"size": 20},
            "prefix": "Epoch:",
            "visible": True,
            "xanchor": "right",
        },
        "len": 0.9,
        "pad": {"t": 50, "b": 10},
    }
