from collections import Counter
from typing import Callable, Iterator, Sequence, cast

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from IPython.display import Markdown as md
from IPython.display import display

from experiments.memorization_dynamics.results import (
    produce_accuracy_paper_plot,
    produce_cum_prob_paper_plot,
    produce_entropy_paper_plot,
    produce_kld_paper_plot,
    produce_loss_paper_plot,
    show_dynamics,
)
from lib_project.experiment import ExperimentResult, load_results
from lib_project.notebook import publish_notebook
from lib_project.visualization import with_paper_style
from lib_project.visualization.arrange import arrange_figures_in_grid

from .experiment import EXP_NAME, ExperimentConfig
from .experiment import ExperimentResult as CPMDExperimentResult


# Result type for this experiment: the generic ``ExperimentResult`` container
# from ``lib_project.experiment`` parameterized with this experiment's config
# type and its own result-value type (imported as ``CPMDExperimentResult``).
Result = ExperimentResult[ExperimentConfig, CPMDExperimentResult]


def load(
    config_name: str | list[str],
    seed_ids: list[int],
) -> list[Result]:
    """Load stored results of this experiment.

    Thin wrapper around ``load_results`` that fills in this experiment's
    name (``EXP_NAME``) together with its config type and result-value type.

    Args:
        config_name: Name of the config to load, or a list of config names.
        seed_ids: Ids of the seeds whose results should be loaded.

    Returns:
        The results returned by ``load_results``.
    """
    experiment_types = (ExperimentConfig, CPMDExperimentResult)
    return load_results(EXP_NAME, config_name, seed_ids, *experiment_types)


def publish(
    notebook: str,
) -> None:
    """Export one of this experiment's analysis notebooks to HTML.

    The notebook is read from
    ``./experiments/<EXP_NAME>/dynamics_analysis/<notebook>.ipynb`` and
    handed to ``publish_notebook`` with the output path
    ``experiments/<EXP_NAME>/<notebook>_<postfix>.html``.

    Args:
        notebook: Name of the notebook; either ``"relative_probability"``
            or ``"ngram_length"``.

    Raises:
        ValueError: If ``notebook`` is not one of the known names.
    """
    # Use a random postfix to make it harder to guess the file name
    postfix_by_notebook = {
        "relative_probability": "dk29sdj",
        "ngram_length": "3g03hs",
    }
    postfix = postfix_by_notebook.get(notebook)
    if postfix is None:
        raise ValueError(f"Unknown notebook: {notebook}")

    source_path = f"./experiments/{EXP_NAME}/dynamics_analysis/{notebook}.ipynb"
    target_path = f"experiments/{EXP_NAME}/{notebook}_{postfix}.html"
    publish_notebook(source_path, target_path)


def show_eANONYMOUSrical_entropies(
    results: dict[str, list[Result]],
    condition_length: int | list[int],
    variation_name: str,
    variation_labels: list[str],
) -> None:
    """Display a table of mean unconditional and conditional entropies.

    For every entry of ``results`` (one per variation), the first entry of
    ``raw_tokens`` is taken from each result. It comes from
    ``value.random_data`` when the result value has that attribute, otherwise
    from ``value.data``. Its eANONYMOUSrical entropy is computed both
    unconditionally and conditioned on the preceding ``condition_length``
    elements. Each of the two values is averaged over that variation's
    results. The averages are displayed as a DataFrame indexed by
    ``variation_labels``.

    Args:
        results: Results per variation. Values are iterated in insertion
            order and paired positionally with ``variation_labels``.
        condition_length: Condition length for the conditional entropy.
            Either one value shared by all variations, or a list with one
            value per variation.
        variation_name: Name given to the DataFrame index.
        variation_labels: One label per entry in ``results``.
    """
    assert len(variation_labels) == len(results), (
        f"Expected {len(variation_labels)} results, "
        f"but got {len(results)} instead."
    )
    if isinstance(condition_length, list):
        assert len(condition_length) == len(variation_labels), (
            f"Expected {len(variation_labels)} condition lengths, "
            f"but got {len(condition_length)} instead."
        )
        condition_lengths = condition_length
        conditional_entropy_label = "Conditional"
    else:
        # Format the label from the shared scalar value, not from the
        # per-variation list built from it below.
        conditional_entropy_label = f"Conditional (n = {condition_length})"
        condition_lengths = [condition_length] * len(variation_labels)

    unconditional_entropies = []
    conditional_entropies = []
    for config_results, config_cond_length in zip(
        results.values(),
        condition_lengths,
    ):
        strings = [
            (
                res.value.random_data.raw_tokens
                if hasattr(res.value, "random_data")
                else res.value.data.raw_tokens
            )[0]
            for res in config_results
        ]
        # A condition length of 0 yields the unconditional entropy.
        for cond_length, target_arr in zip(
            [0, config_cond_length],
            [unconditional_entropies, conditional_entropies],
        ):
            mean_entropy = np.mean(
                [
                    eANONYMOUSrical_entropy(
                        string,
                        condition_length=cond_length,
                    )
                    for string in strings
                ]
            )
            target_arr.append(mean_entropy)

    entropies = pd.DataFrame(
        {
            "Unconditional": unconditional_entropies,
            conditional_entropy_label: conditional_entropies,
        },
        index=pd.Index(variation_labels, name=variation_name),
    )
    print("EANONYMOUSrical entropies:")
    with pd.option_context("display.float_format", "{:.3f}".format):
        display(entropies)


def eANONYMOUSrical_entropy(
    seq: Sequence,
    condition_length: int = 0,
) -> float:
    """Compute the (conditional) eANONYMOUSrical entropy of a sequence in bits.

    The estimate is based on the frequency with which each element appears.
    Every element ``seq[i]`` with ``i >= condition_length`` is counted as a
    continuation of the condition formed by the ``condition_length``
    elements directly preceding it. The result is the frequency-weighted
    average, over all observed conditions, of the entropy of that
    condition's continuation distribution. With ``condition_length == 0``
    every element shares the empty condition, so this reduces to the
    unconditional entropy of the element frequencies.

    Args:
        seq (Sequence): The sequence to compute the entropy of. Its elements
            must be hashable.
        condition_length (int, optional): The length of the condition.
            Defaults to 0, i.e. unconditional entropy.

    Returns:
        float: The entropy in bits. ``0.0`` if no element of ``seq`` has a
        full-length condition (e.g. ``len(seq) <= condition_length``).

    Raises:
        ValueError: If ``condition_length`` is negative.
    """
    if condition_length < 0:
        raise ValueError(
            f"condition_length must be non-negative, got {condition_length}"
        )

    # Map each condition (tuple of preceding elements) to the counts of the
    # elements that directly followed it.
    continuation_counts: dict[tuple, Counter] = {}
    for i in range(condition_length, len(seq)):
        condition = tuple(seq[i - condition_length : i])
        continuation_counts.setdefault(condition, Counter())[seq[i]] += 1

    total = len(seq) - condition_length
    entropy = 0.0
    for counts in continuation_counts.values():
        condition_total = sum(counts.values())
        continuation_probabilities = (
            np.array(list(counts.values())) / condition_total
        )
        condition_entropy = -sum(
            continuation_probabilities * np.log2(continuation_probabilities)
        )
        # Weight each condition by how often it was observed.
        entropy += (condition_total / total) * condition_entropy
    return float(entropy)


def produce_paper_plots(
    figures: dict[str, dict[str, go.Figure]],
    model: str,
    variation_dimension: str,
    figure_folder: str,
    show_legend: bool = True,
) -> None:
    """Produce the accuracy paper plot for every group of figures.

    For each outer key of ``figures`` (named ``alphabet_size`` in the
    original code), ``produce_accuracy_paper_plot`` is called with that
    group's figures. The variation dimension is suffixed with
    ``_<key>`` so every group gets its own name.

    Args:
        figures: Figures grouped by outer key; each group maps names to
            figures.
        model: Model identifier forwarded to the plotting helper.
        variation_dimension: Base name of the varied dimension.
        figure_folder: Folder forwarded to the plotting helper (presumably
            where the plots are written; confirm in
            ``produce_accuracy_paper_plot``).
        show_legend: Whether the plots should show a legend.
    """
    for size_key, size_figures in figures.items():
        dimension_name = f"{variation_dimension}_{size_key}"
        produce_accuracy_paper_plot(
            size_figures,
            model,
            variation_dimension=dimension_name,
            figure_folder=figure_folder,
            show_legend=show_legend,
        )
