from typing import Any, Callable, cast

import numpy as np
import pandas as pd
import plotly.graph_objects as go

from lib_dl_base.results.aggregate import (
    aggregate_mean_std_dev,
    get_color,
    plot_mean_std_dev,
)


def plot_training_loss(
    results: dict[str, list[pd.DataFrame]],
    legend_title: str,
    guess_baselines: dict[str, float] | None = None,
    show_std_dev: bool = True,
    target_column: str = "loss",
) -> go.Figure:
    """Plot the aggregated training loss curve of every configuration.

    Args:
        results: Maps a configuration name to its list of result frames
            (one per run); every frame must contain ``target_column``.
        legend_title: Title displayed for the figure legend.
        guess_baselines: Optional baseline value per configuration name,
            forwarded to ``plot_scalar_curves``.
        show_std_dev: When ``False`` no standard deviations are handed to
            the plotting routine.
        target_column: Name of the column that holds the loss values.

    Returns:
        The figure built by ``plot_scalar_curves``.
    """

    def select_loss(frame: pd.DataFrame) -> pd.Series:
        # Only the loss column takes part in the aggregation.
        return cast(pd.Series, frame[target_column])

    loss_means, loss_std_devs = compute_mean_std_dev(
        results,
        preprocess=select_loss,
        compute_std_dev=show_std_dev,
    )
    figure = plot_scalar_curves(
        loss_means,
        loss_std_devs,
        title="Training loss",
        xaxis_title="Epoch",
        yaxis_title="Loss",
        legend_title=legend_title,
        guess_baselines=guess_baselines,
    )
    return figure


def plot_training_accuracy(
    results: dict[str, list[pd.DataFrame]],
    legend_title: str,
    guess_baselines: dict[str, float] | None = None,
    show_std_dev: bool = True,
    target_column: str = "correct",
) -> go.Figure:
    """Plot the aggregated training accuracy curve of every configuration.

    Args:
        results: Maps a configuration name to its list of result frames
            (one per run); every frame must contain ``target_column``.
        legend_title: Title displayed for the figure legend.
        guess_baselines: Optional baseline value per configuration name,
            forwarded to ``plot_scalar_curves``.
        show_std_dev: When ``False`` no standard deviations are handed to
            the plotting routine.
        target_column: Name of the column whose values are averaged to
            obtain the accuracy.

    Returns:
        The figure built by ``plot_scalar_curves``.
    """

    def select_correctness(frame: pd.DataFrame) -> pd.Series:
        # Only the correctness column takes part in the aggregation.
        return cast(pd.Series, frame[target_column])

    accuracy_means, accuracy_std_devs = compute_mean_std_dev(
        results,
        preprocess=select_correctness,
        compute_std_dev=show_std_dev,
    )
    figure = plot_scalar_curves(
        accuracy_means,
        accuracy_std_devs,
        title="Training accuracy",
        xaxis_title="Epoch",
        yaxis_title="Accuracy",
        legend_title=legend_title,
        guess_baselines=guess_baselines,
    )
    return figure


def plot_training_cum_prob(
    results: dict[str, list[pd.DataFrame]],
    legend_title: str,
    token_subset: list[str] | None = None,
    show_std_dev: bool = True,
) -> go.Figure:
    """Plot the aggregated summed token probability of every configuration.

    Args:
        results: Maps a configuration name to its list of result frames
            (one per run).
        legend_title: Title displayed for the figure legend.
        token_subset: Optional list of token columns to restrict the sum
            to; ``None`` sums over every token column.
        show_std_dev: When ``False`` no standard deviations are handed to
            the plotting routine.

    Returns:
        The figure built by ``plot_scalar_curves``.
    """

    def summed_probability(frame: pd.DataFrame) -> pd.Series:
        return _compute_cum_prob(frame, token_subset=token_subset)

    prob_means, prob_std_devs = compute_mean_std_dev(
        results,
        preprocess=summed_probability,
        compute_std_dev=show_std_dev,
    )
    figure = plot_scalar_curves(
        prob_means,
        prob_std_devs,
        title="Aggregate Probability over Alphabet Tokens",
        xaxis_title="Epoch",
        yaxis_title="Aggregate Probability",
        legend_title=legend_title,
    )
    return figure


def _compute_cum_prob(
    df: pd.DataFrame,
    token_subset: list[str] | None = None,
) -> pd.Series:
    """Compute the cumulative probability of each token in the vocabulary
    for each sequence in each epoch"""
    df = df.drop(columns=["correct", "loss"])
    if token_subset is not None:
        df = cast(pd.DataFrame, df[token_subset])
    df_prob = np.exp(df)
    return pd.Series(
        df_prob.sum(axis=1),
        index=df.index,
        dtype=np.float64,
    )


def plot_training_entropy(
    results: dict[str, list[pd.DataFrame]],
    legend_title: str,
    guess_baselines: dict[str, float] | None = None,
    show_std_dev: bool = True,
) -> go.Figure:
    """Plot the aggregated entropy curve of every configuration.

    Args:
        results: Maps a configuration name to its list of result frames
            (one per run).
        legend_title: Title displayed for the figure legend.
        guess_baselines: Optional baseline value per configuration name,
            forwarded to ``plot_scalar_curves``.
        show_std_dev: When ``False`` no standard deviations are handed to
            the plotting routine.

    Returns:
        The figure built by ``plot_scalar_curves``.
    """
    aggregated = compute_mean_std_dev(
        results,
        preprocess=_compute_entropy,
        compute_std_dev=show_std_dev,
    )
    means, std_devs = aggregated
    figure = plot_scalar_curves(
        means,
        std_devs,
        title="Entropy over Alphabet Tokens",
        xaxis_title="Epoch",
        yaxis_title="Entropy",
        legend_title=legend_title,
        guess_baselines=guess_baselines,
    )
    return figure


def _compute_entropy(
    df: pd.DataFrame,
) -> pd.Series:
    """Compute the entropy of each sequence in each epoch"""
    df = df.drop(columns=["correct", "loss"])
    df_prob = cast(pd.DataFrame, np.exp(df))
    df_norm = df_prob.div(df_prob.sum(axis=1), axis=0)
    # Filter out zero values to avoid log2(0)
    non_zero_df = df_norm[df_norm > 0]
    return pd.Series(
        -(non_zero_df * np.log2(non_zero_df)).sum(axis=1),
        index=df.index,
        dtype=np.float64,
    )


def plot_training_kld(
    results: dict[str, list[pd.DataFrame]],
    legend_title: str,
    character_probabilities: dict[str, np.ndarray],
    show_std_dev: bool = True,
) -> go.Figure:
    """Plot the aggregated KL divergence curve of every configuration.

    Each result frame is first converted to a one-column ``kld`` frame via
    ``_compute_kld`` using the reference distribution registered for its
    configuration, then aggregated and plotted.

    Args:
        results: Maps a configuration name to its list of result frames
            (one per run).
        legend_title: Title displayed for the figure legend.
        character_probabilities: Reference distribution per configuration
            name. Every key of ``results`` must be present, otherwise a
            ``KeyError`` is raised.
        show_std_dev: When ``False`` no standard deviations are handed to
            the plotting routine (same switch as the sibling plot helpers).

    Returns:
        The figure built by ``plot_scalar_curves``.
    """
    kld_results = {
        config_name: [
            _compute_kld(res, character_probabilities[config_name])
            for res in config_results
        ]
        for config_name, config_results in results.items()
    }
    kld_means, kld_std_devs = compute_mean_std_dev(
        kld_results,
        preprocess=lambda res: cast(pd.Series, res["kld"]),
        compute_std_dev=show_std_dev,
    )
    return plot_scalar_curves(
        kld_means,
        kld_std_devs,
        title="KLD from True Token Distribution",
        xaxis_title="Epoch",
        yaxis_title="KLD",
        legend_title=legend_title,
    )


def _compute_kld(
    df: pd.DataFrame,
    character_probabilities: np.ndarray,
) -> pd.DataFrame:
    df = df.drop(columns=["correct", "loss"])
    df_prob = cast(pd.DataFrame, np.exp(df))
    epsilon = 1e-9
    df_log = np.log2((df_prob + epsilon) / (character_probabilities + epsilon))
    kld = -(character_probabilities * df_log).sum(axis=1)
    return pd.DataFrame(
        {"kld": kld},
        index=df.index,
        dtype=np.float32,
    )


def plot_string_position_cum_prob(
    results: dict[str, list[pd.DataFrame]],
    legend_title: str,
    epoch: int,
    figure_offset: tuple[go.Figure, int] | None = None,
    token_subset: list[str] | None = None,
    show_std_dev: bool = True,
) -> go.Figure:
    """Plot the summed token probability per position at a single epoch.

    Args:
        results: Maps a configuration name to its list of result frames
            (one per run).
        legend_title: Title displayed for the figure legend.
        epoch: Epoch whose rows are selected from every result frame.
        figure_offset: Optional ``(figure, offset)`` pair passed through
            unchanged to ``plot_scalar_curves``.
        token_subset: Optional list of token columns to restrict the sum
            to; ``None`` sums over every token column.
        show_std_dev: When ``False`` no standard deviations are handed to
            the plotting routine.

    Returns:
        The figure built by ``plot_scalar_curves``.
    """

    def position_probability(frame: pd.DataFrame) -> pd.Series:
        return _compute_string_position_cum_prob(
            frame, epoch, token_subset=token_subset
        )

    # Aggregation keeps the "character" level instead of the default one.
    prob_means, prob_std_devs = compute_mean_std_dev(
        results,
        preprocess=position_probability,
        target_level="character",
        compute_std_dev=show_std_dev,
    )
    figure = plot_scalar_curves(
        prob_means,
        prob_std_devs,
        title=f"Aggregate Probability over Alphabet Tokens at Epoch {epoch}",
        xaxis_title="Position in String",
        yaxis_title="Aggregate Probability",
        legend_title=legend_title,
        figure_offset=figure_offset,
    )
    return figure


def _compute_string_position_cum_prob(
    df: pd.DataFrame,
    epoch: int,
    window_size: int = 50,
    token_subset: list[str] | None = None,
) -> pd.Series:
    """Compute a rolling mean of the summed token probability at one epoch.

    Args:
        df: Result frame from which the rows of ``epoch`` are selected with
            ``df.loc[epoch]``; the selection must expose a ``string`` index
            level.
        epoch: Epoch to select.
        window_size: Number of consecutive rows in each rolling window.
        token_subset: Optional list of token columns to restrict the sum to.

    Returns:
        The rolling mean of the per-row probability sum; the first
        ``window_size - 1`` entries are NaN because their window is
        incomplete.
    """
    at_epoch = df.loc[epoch]
    string_labels = at_epoch.index.get_level_values("string")
    # Second level becomes a running row number over the whole epoch slice.
    # Reusing the existing level names assumes exactly two levels remain
    # after selecting the epoch -- TODO confirm.
    running_position = range(len(at_epoch))
    at_epoch.index = pd.MultiIndex.from_arrays(
        [string_labels, running_position],
        names=at_epoch.index.names,
    )
    row_probability = _compute_cum_prob(at_epoch, token_subset=token_subset)
    smoothed = row_probability.rolling(window_size).mean()
    return cast(pd.Series, smoothed)


def plot_discrepancy(
    results: dict[str, list[pd.DataFrame]],
    legend_title: str,
    seed: int,
    show_std_dev: bool = True,
) -> go.Figure:
    """Plot the aggregated discrepancy score curve of every configuration.

    Args:
        results: Maps a configuration name to its list of result frames
            (one per run).
        legend_title: Title displayed for the figure legend.
        seed: Seed forwarded to ``_compute_discrepancy`` for every frame.
        show_std_dev: When ``False`` no standard deviations are handed to
            the plotting routine.

    Returns:
        The figure built by ``plot_scalar_curves``.
    """

    def discrepancy_score(frame: pd.DataFrame) -> pd.Series:
        return _compute_discrepancy(frame, seed)

    score_means, score_std_devs = compute_mean_std_dev(
        results,
        preprocess=discrepancy_score,
        compute_std_dev=show_std_dev,
    )
    figure = plot_scalar_curves(
        score_means,
        score_std_devs,
        title="Discrepancy Score",
        xaxis_title="Epoch",
        yaxis_title="Discrepancy Score",
        legend_title=legend_title,
    )
    return figure


def _compute_discrepancy(
    df: pd.DataFrame,
    seed: int,
) -> pd.Series:
    """Compute the discrepancy score of each sequence in each epoch"""
    rng = np.random.default_rng(seed)
    epoch_random_center = []
    for epoch in range(50):
        epoch_correctness = df["correct"].xs(epoch, level="epoch")

        # randomize the correct position for 10 times, then compute
        # the center of the correct recollection
        random_correctness = rng.permutation(epoch_correctness)
        # get a window of 10 tokens, randomly select length//window_size
        # windows, count  [correct in window size, actual]
        # - [correct in window size, random], average for 10 samples
        window_size = 20
        length = len(epoch_correctness)
        num_windows = length // window_size
        # get the correct number for epoch_correctness and random_correctness
        window_list = []
        for _ in range(num_windows):
            # randomly select one position between
            # [window_size, length-window_size]
            start = np.random.randint(window_size, length - window_size)
            correct_num = np.sum(
                epoch_correctness[
                    start * window_size : (start + 1) * window_size
                ]
            )
            random_correct_num = np.sum(
                random_correctness[
                    start * window_size : (start + 1) * window_size
                ]
            )
            window_list.append(correct_num - random_correct_num)
        # get the average of the window_list
        epoch_random_center.append(np.mean(window_list))
    return pd.Series(
        epoch_random_center,
        index=pd.Index(
            list(range(50)),
            name="epoch",
        ),
        dtype=np.float32,
    )


def compute_mean_std_dev(
    results: dict[str, list[pd.DataFrame]],
    preprocess: (
        Callable[[pd.DataFrame], pd.DataFrame]
        | Callable[[pd.DataFrame], pd.Series]
    ),
    target_level: str = "epoch",
    compute_std_dev: bool = True,
) -> tuple[dict[str, pd.DataFrame], dict[str, pd.DataFrame] | None]:
    """Compute mean and standard deviation of results.

    Args:
        results (dict[str, list[pd.DataFrame]]): Results from the experiments.

    Returns:
        tuple[dict[str, pd.DataFrame], dict[str, pd.DataFrame]]:
        Mean and standard deviation of results.
    """
    mean_results = {}
    std_dev_results = {}
    for config_name, seed_results in results.items():
        seed_results: list[Any] = [
            # The first row has nan values, because there is no probability for
            # the first token in the string, so we drop it.
            preprocess(df).dropna(axis=0)
            for df in seed_results
        ]
        mean, std = aggregate_mean_std_dev(
            seed_results, levels_to_preserve=[target_level]
        )
        mean_results[config_name] = mean
        std_dev_results[config_name] = std
    if compute_std_dev:
        return mean_results, std_dev_results
    else:
        return mean_results, None


def plot_scalar_curves(
    result_means: dict[str, pd.DataFrame],
    result_std_devs: dict[str, pd.DataFrame] | None,
    title: str,
    xaxis_title: str,
    yaxis_title: str,
    legend_title: str,
    guess_baselines: dict[str, float] | None = None,
    figure_offset: tuple[go.Figure, int] | None = None,
) -> go.Figure:
    """Draw mean/std-dev curves and apply the common figure layout.

    Args:
        result_means: Mean curve per result name.
        result_std_devs: Standard deviation per result name, or ``None``.
        title: Figure title.
        xaxis_title: Label of the x axis.
        yaxis_title: Label of the y axis.
        legend_title: Title displayed for the figure legend.
        guess_baselines: Optional baseline value per result name. When
            given, one dashed horizontal line is added for every key of
            ``result_means``; a missing key raises ``KeyError``.
        figure_offset: Optional ``(figure, offset)`` pair passed through
            unchanged to ``plot_mean_std_dev``.

    Returns:
        The figure returned by ``plot_mean_std_dev`` with baselines and
        layout applied; its size is fixed to 800 x 600.
    """
    figure = plot_mean_std_dev(
        result_means,
        result_std_devs,
        figure_offset=figure_offset,
    )

    if guess_baselines is not None:
        # One dashed baseline per curve, coloured by the curve's position
        # in ``result_means``.
        for position, name in enumerate(result_means):
            baseline_value = guess_baselines[name]
            line_style = {
                "color": get_color(position),
                "width": 2,
                "dash": "dash",
            }
            figure.add_hline(y=baseline_value, line=line_style)

    layout = {
        "title": title,
        "xaxis_title": xaxis_title,
        "yaxis_title": yaxis_title,
        "legend_title": legend_title,
        "width": 800,
        "height": 600,
    }
    figure.update_layout(**layout)
    return figure
