from typing import Optional, cast

import numpy as np
import pandas as pd
import plotly.graph_objects as go

from lib_dl_base.results.aggregate import (
    add_mean_std_dev_trace,
    aggregate_mean_std_dev,
    plot_mean_std_dev,
)


def plot_epoch_prefix_performance(
    prefix_mappings: list[pd.DataFrame],
    prefix_lengths: list[int],
    *,
    show_full_prefix: bool = True,
    show_std_dev: bool = True,
) -> go.Figure:
    """Plot per-epoch accuracy, with one curve per selected prefix length.

    Args:
        prefix_mappings: Dataframes that each carry a "prefix_length"
            index level and a "plurality_correct" column. Presumably one
            dataframe per run/seed -- confirm with the caller.
        prefix_lengths: The prefix lengths to draw a curve for; rows with
            any other prefix length are ignored.
        show_full_prefix: If True, add a curve labelled "full", computed
            by `get_full_prefix_performance` on every dataframe.
        show_std_dev: If False, `None` is passed to `plot_mean_std_dev`
            in place of each curve's standard deviation.

    Returns:
        The figure produced by `plot_mean_std_dev`, with axis and legend
        titles set.
    """
    # Curve key (a prefix length or "full") -> one result per dataframe.
    # Insertion order is kept as-is; it presumably determines the order of
    # the curves in the figure.
    results_by_curve = {}
    for run_df in prefix_mappings:
        run_lengths = run_df.index.get_level_values("prefix_length")
        selected = run_df.loc[
            run_lengths.isin(prefix_lengths),
            "plurality_correct",
        ]
        for length_key, length_series in selected.groupby("prefix_length"):
            length_key = cast(int, length_key)
            if length_key not in results_by_curve:
                results_by_curve[length_key] = []
            results_by_curve[length_key].append(length_series)

        if not show_full_prefix:
            continue
        # Add the full-length prefix result for this dataframe.
        if "full" not in results_by_curve:
            results_by_curve["full"] = []
        results_by_curve["full"].append(get_full_prefix_performance(run_df))

    # Collapse the per-dataframe results of every curve, keeping "epoch".
    aggregated = {
        curve_key: aggregate_mean_std_dev(
            curve_results,
            levels_to_preserve=["epoch"],
        )
        for curve_key, curve_results in results_by_curve.items()
    }
    means = {
        curve_key: curve_mean
        for curve_key, (curve_mean, _) in aggregated.items()
    }
    std_devs = {
        curve_key: (curve_std if show_std_dev else None)
        for curve_key, (_, curve_std) in aggregated.items()
    }

    fig = plot_mean_std_dev(means, std_devs)
    fig.update_layout(
        xaxis_title="Epoch",
        yaxis_title="Accuracy",
        legend_title="Prefix length",
    )
    return fig


def plot_prefix_length_performance(
    prefix_performances: dict[str, list[pd.DataFrame]],
    legend_title: str,
    epoch: int = -1,
    prefix_lengths: Optional[set[int]] = None,
) -> go.Figure:
    """Plot accuracy per prefix length at one epoch, one curve per config.

    Every curve additionally gets a final "full" point, computed with
    `get_full_prefix_performance` on the rows of the selected epoch.

    Args:
        prefix_performances: Maps a configuration name (used as the curve
            name) to a list of dataframes. Each dataframe needs the index
            levels "epoch", "prefix_length" and "token_idx" and a
            "plurality_correct" column. Presumably one dataframe per
            run/seed -- confirm with the caller.
        legend_title: The title shown above the legend.
        epoch: The epoch to plot. The default of -1 selects the largest
            epoch found in the first dataframe of the first configuration.
        prefix_lengths: The prefix lengths to show on the x-axis; all
            other prefix lengths are skipped. Defaults to the powers of
            two from 1 to 512.

    Returns:
        A figure with one mean/standard-deviation trace per
        configuration. If `prefix_performances` is empty the figure has
        no traces.

    Raises:
        IndexError: If `epoch` is -1 and the first configuration has no
            dataframes.
    """
    if prefix_lengths is None:
        # Built per call so the default is never shared between calls.
        prefix_lengths = {1, 2, 4, 8, 16, 32, 64, 128, 256, 512}

    # With nothing to plot there is no epoch to resolve; skipping the
    # lookup avoids a bare StopIteration and yields an empty figure,
    # matching what an explicit epoch already does for empty input.
    if epoch == -1 and prefix_performances:
        epoch = cast(
            int,
            next(iter(prefix_performances.values()))[0]
            .index.get_level_values("epoch")
            .max(),
        )

    fig = go.Figure()
    for i, (config_name, config_res) in enumerate(prefix_performances.items()):
        epoch_length_results = [
            prefix_res.loc[epoch, "plurality_correct"]
            for prefix_res in config_res
        ]
        means, std_devs = aggregate_mean_std_dev(
            epoch_length_results,
            levels_to_preserve=["prefix_length"],
        )
        x_labels, mean_values, std_dev_values = [], [], []
        for prefix_length, mean_val, std_dev_val in zip(
            means.index, means.values, std_devs.values
        ):
            if prefix_length not in prefix_lengths:
                continue
            # String labels let the trailing "full" entry share the axis.
            x_labels.append(str(prefix_length))
            mean_values.append(mean_val)
            std_dev_values.append(std_dev_val)

        # `.loc[[epoch]]` (list selector) keeps the "epoch" index level,
        # which get_full_prefix_performance groups by.
        full_prefix_results = [
            get_full_prefix_performance(prefix_res.loc[[epoch]])
            for prefix_res in config_res
        ]
        x_labels.append("full")
        mean_values.append(np.mean(full_prefix_results).item())
        std_dev_values.append(np.std(full_prefix_results).item())

        add_mean_std_dev_trace(
            fig,
            i,
            mean_values=np.array(mean_values),
            std_dev_values=np.array(std_dev_values),
            x_values=x_labels,
            name=config_name,
        )
    fig.update_layout(
        xaxis_title="Prefix length",
        yaxis_title="Accuracy",
        legend_title=legend_title,
    )
    return fig


def get_full_prefix_performance(
    prefix_performance: pd.DataFrame,
) -> pd.Series:
    """Return the per-epoch mean accuracy of the "full prefix" rows.

    A row counts as "full prefix" when its "prefix_length" index level
    equals its "token_idx" index level (presumably the prediction that
    uses all preceding tokens -- confirm against the data producer).

    Args:
        prefix_performance: Dataframe whose index has the levels "epoch",
            "prefix_length" and "token_idx", and which has a
            "plurality_correct" column.

    Returns:
        A series indexed by "epoch" holding the mean of
        "plurality_correct" over the full-prefix rows of that epoch.
    """
    index = prefix_performance.index
    is_full_prefix = index.get_level_values(
        "prefix_length"
    ) == index.get_level_values("token_idx")
    full_prefix_correct = prefix_performance.loc[
        is_full_prefix, "plurality_correct"
    ]
    # Selecting a single column yields a Series, so the result is a Series.
    return full_prefix_correct.groupby(level="epoch").mean()
