import html
import json
from dataclasses import dataclass
from typing import cast

import pandas as pd
from IPython.display import HTML
from IPython.display import Markdown as md
from IPython.display import display
from plotly import graph_objects as go

from .rule_extraction import StringPrefixes


# Counter used by ``show_prefixes`` to give each slider widget it renders a
# distinct DOM id suffix; it is read and then incremented on every call.
cur_id: int = 0


def show_prefixes(
    string: list[str],
    prefix_mappings: dict[int, StringPrefixes],
    rules: pd.DataFrame | None = None,
    show_prefixes: bool = True,
    show_mappings: bool = True,
) -> None:
    """Display an interactive, per-epoch view of the prefix analysis.

    One ``visualize_epoch_prefixes`` snippet is rendered per epoch. The
    snippets are shown in the notebook behind a range slider that selects the
    epoch; the slider starts at the last epoch.

    Args:
        string: Tokens of the analysed string.
        prefix_mappings: Prefix analysis keyed by epoch, in display order.
        rules: Optional rules, selected per epoch with ``.loc``. They are
            therefore expected to be indexed by epoch (or to have epoch as
            the first index level). Epochs without any rules get an empty
            rules section.
        show_prefixes: Forwarded to ``visualize_epoch_prefixes``.
        show_mappings: Forwarded to ``visualize_epoch_prefixes``.
    """
    epoch_output = {}
    for epoch, prefixes in prefix_mappings.items():
        if rules is None:
            epoch_rules = None
        elif epoch in rules.index:
            # A list indexer always yields a DataFrame. A scalar label that
            # matches exactly one row would yield a Series, which has no
            # ``iterrows``.
            epoch_rules = rules.loc[[epoch]]
        else:
            # No rules for this epoch: show an empty rules section
            # instead of raising KeyError.
            epoch_rules = rules.iloc[0:0]
        epoch_output[epoch] = (
            f"Epoch {epoch}: Accuracy {prefixes.accuracy:.2%}\n\n"
            + visualize_epoch_prefixes(
                string,
                prefixes,
                epoch_rules,
                show_prefixes=show_prefixes,
                show_mappings=show_mappings,
            )
        )

    # Each call takes a fresh id so several widgets can coexist in one page.
    global cur_id
    slider_id = cur_id
    cur_id += 1

    preamble = """
<style>
    .slider-container {
        width: 100%;
        margin: 20px;
    }

    .slider {
        width: 100%;
        max-width: 400px;
    }

    .correct {
        color: green;
        /* font-weight: bold; */
    }
    .incorrect {
        color: red;
        /* font-weight: bold; */
    }
    .partially {
        color: orange;
        /* font-weight: bold; */
    }
    .no-match {
        color: gray;
        /* font-weight: bold; */
    }
</style>

<script>
    function setupSlider(id, codeBlocks) {
        var slider = document.getElementById(`epochSlider_${id}`);
        var codeBlock = document.getElementById(`displayBlock_${id}`);

        // Initial code block display
        codeBlock.innerHTML = codeBlocks[slider.value];

        // Update code block when slider value changes
        slider.addEventListener("input", function() {
            codeBlock.innerHTML = codeBlocks[slider.value];
        });
    }
</script>
"""

    # JSON string literals are valid JavaScript literals and, unlike
    # backtick template literals, are not broken by backticks, "${" or
    # backslashes in the content. Escaping "</" keeps a "</script>" inside
    # the data from closing the surrounding <script> element.
    string_blocks = json.dumps(list(epoch_output.values())).replace("</", "<\\/")
    epoch_list = "\n".join(
        f"<option value='{i}'>{epoch}</option>"
        for i, epoch in enumerate(epoch_output)
    )
    # The slider selects an index into ``string_blocks``.
    num_epochs = len(epoch_output) - 1

    # Plain f-string: the template itself has no literal braces, so the
    # substituted values are inserted verbatim and never re-interpreted.
    html_code = f"""
{preamble}

<div class="slider-container">
    <label for="epochSlider_{slider_id}">Epoch</label>
    <input type="range" class="slider" min="0" max="{num_epochs}" value="{num_epochs}" step="1" id="epochSlider_{slider_id}" list="epochlist_{slider_id}">
    <datalist id="epochlist_{slider_id}">
        {epoch_list}
    </datalist>
    <pre id="displayBlock_{slider_id}"></pre>
</div>

<script>
    setupSlider({slider_id}, {string_blocks});
</script>
"""
    display(HTML(html_code))


def visualize_epoch_prefixes(
    string: list[str],
    prefixes: StringPrefixes,
    rules: pd.DataFrame | None = None,
    show_prefixes: bool = True,
    show_mappings: bool = True,
) -> str:
    """Render one epoch's prefix analysis of ``string`` as an HTML snippet.

    The snippet is laid out with spaces and newlines for monospace display
    (``show_prefixes`` places it inside a ``<pre>`` element). It consists of:

    * the token line, with each token colored by its position's correctness;
    * optionally, one line per target token showing its prefixes;
    * optionally, the minimum-necessary and converged prefix mappings;
    * optionally, the rules, four per line.

    Args:
        string: Tokens of the analysed string.
        prefixes: Prefix analysis for this epoch. Entry ``i`` of
            ``min_necessary_prefixes`` / ``converged_prefixes`` belongs to
            the target ``string[i + 1]``.
        rules: Optional rules for this epoch. The columns
            ``premise_tokens``, ``conclusion_token``, ``correct`` and
            ``count`` are read.
        show_prefixes: Whether to add the per-target prefix lines.
        show_mappings: Whether to add the prefix mapping sections.

    Returns:
        The HTML snippet.
    """
    position_correctness = _position_correctness(prefixes)
    output = " " + " ".join(
        _get_highlighted_token(token, token_correctness)
        for token, token_correctness in zip(string, position_correctness)
    )
    if show_prefixes:
        output += _format_prefix_lines(string, prefixes, position_correctness)
    if show_mappings:
        output += _format_mappings(prefixes)
    if rules is not None:
        output += _format_rules(rules)
    return output


def _position_correctness(prefixes: StringPrefixes) -> list[str]:
    """Classify each token position as "correct", "partially" or "incorrect".

    Position 0 is given empty prefixes and therefore classified "incorrect".
    A position is "correct" when its minimum necessary prefix is non-empty
    and equals its converged prefix. It is "partially" when at least one of
    the two is non-empty, and "incorrect" otherwise.
    """
    labels = []
    for min_necessary_prefix, converged_prefix in zip(
        [()] + list(prefixes.min_necessary_prefixes),
        [()] + list(prefixes.converged_prefixes),
    ):
        if min_necessary_prefix and min_necessary_prefix == converged_prefix:
            labels.append("correct")
        elif min_necessary_prefix or converged_prefix:
            labels.append("partially")
        else:
            labels.append("incorrect")
    return labels


def _format_prefix_lines(
    string: list[str],
    prefixes: StringPrefixes,
    position_correctness: list[str],
) -> str:
    """Render one line per target: ``<converged-only><marker><min-necessary>:<target>``.

    ``marker`` is ``+`` when a converged prefix exists and ``-`` otherwise.
    Each line is left-padded so that its ``:`` lands in the column directly
    before the target in the token line. That token line has one leading
    space and single-space separators.
    """
    # Start column of every token in the token line.
    token_starts = []
    column = 1
    for token in string:
        token_starts.append(column)
        column += len(token) + 1

    lines = ["\nPrefixes:"]
    for i, (min_necessary_prefix, converged_prefix, target_correctness) in enumerate(
        zip(
            prefixes.min_necessary_prefixes,
            prefixes.converged_prefixes,
            position_correctness[1:],
        )
    ):
        target_token = string[i + 1]
        if converged_prefix:
            # The converged prefix must end with the minimum necessary prefix.
            # Split by length rather than slicing with ``-len(...)``: that
            # bound is ``-0`` (the whole tuple) for an empty minimum prefix.
            split = len(converged_prefix) - len(min_necessary_prefix)
            assert split >= 0 and converged_prefix[split:] == min_necessary_prefix, (
                f"converged prefix {converged_prefix} does not end with "
                f"min necessary prefix {min_necessary_prefix}"
            )
            converged_only_prefix = converged_prefix[:split]
            marker = "+"
        else:
            # The full prefix is not correct, i.e. there is no converged
            # prefix.
            converged_only_prefix = ()
            marker = "-"

        body = (
            " ".join(converged_only_prefix) + marker + " ".join(min_necessary_prefix)
        )
        # Measure the unescaped text: that is the width shown on screen.
        padding = max(0, token_starts[i + 1] - 1 - len(body))
        lines.append(
            "\n"
            + " " * padding
            + html.escape(body, quote=False)
            + ":"
            + _get_highlighted_token(target_token, target_correctness)
        )
    return "".join(lines)


def _format_mappings(prefixes: StringPrefixes) -> str:
    """Render the minimum-necessary and converged prefix -> target counts."""
    parts = []
    for mapping_name, mappings in (
        ("Minimum Necessary", prefixes.min_necessary_mappings),
        ("Converged", prefixes.converged_mappings),
    ):
        parts.append(f"\n\n{mapping_name} Prefix Mappings:")
        for prefix, target_tokens in mappings.items():
            prefix_text = html.escape("".join(prefix), quote=False)
            targets = ", ".join(
                f"{html.escape(str(target_token), quote=False)} ({count}x)"
                for target_token, count in target_tokens.items()
            )
            parts.append(f"\n{prefix_text}: {targets}")
    return "".join(parts)


def _format_rules(rules: pd.DataFrame) -> str:
    """Render the rules four per line.

    Each rule is wrapped in a span whose CSS class is its ``correct`` value.
    """
    parts = ["\n\nRules:"]
    for i, (_, rule) in enumerate(rules.iterrows()):
        premise = html.escape("".join(rule["premise_tokens"]), quote=False)
        conclusion = html.escape(str(rule["conclusion_token"]), quote=False)
        parts.append("\n" if i % 4 == 0 else "  ")
        parts.append(
            f"<span class='{rule['correct']}'>"
            f"{premise}: {conclusion} ({rule['count']:2d}x)</span>"
        )
    return "".join(parts)


def _get_highlighted_token(
    token: str,
    correct: str,
) -> str:
    return f"<span class='{correct}'>{token}</span>"


def plot_rule_correctness(
    rules: pd.DataFrame,
) -> go.Figure:
    """Plot, per epoch, how the rules split across correctness categories.

    Draws one stacked bar per epoch, with one segment per category:
    correct, partially, incorrect and no-match. Values of ``correct``
    outside these four categories are not plotted.

    Args:
        rules: Rules with a ``correct`` column. They are grouped by
            ``"epoch"`` (a column or an index level).

    Returns:
        A plotly figure with one bar trace per category.
    """
    # Category -> bar color. The insertion order fixes the column order and
    # hence the trace order.
    category_colors = {
        "correct": "green",
        "partially": "orange",
        "incorrect": "red",
        "no-match": "gray",
    }
    per_epoch = (
        rules.groupby("epoch")["correct"]
        .value_counts()
        .unstack(fill_value=0)
        .reindex(columns=list(category_colors), fill_value=0)
        .sort_index(ascending=False)
    )

    fig = go.Figure()
    for category, color in category_colors.items():
        bar = go.Bar(
            x=per_epoch.index,
            y=per_epoch[category],
            name=category,
            marker=dict(color=color),
        )
        fig.add_trace(bar)

    fig.update_layout(
        barmode="stack",
        title="Correctness of Rules per Epoch",
        xaxis_title="Epoch",
        yaxis_title="Number of Rules",
        xaxis=dict(type="category", autorange="reversed"),
    )
    return fig
