from dataclasses import dataclass
from typing import cast

import pandas as pd


@dataclass
class StringPrefixes:
    """Prefix results recorded for one token string.

    The two prefix lists are turned into count tables by pairing entry ``i``
    of a list with ``string[i + 1]`` as the token that prefix leads to
    (NOTE(review): this assumes the lists start at the second token of
    ``string`` and have no gaps -- confirm against the producer).
    """

    # Token sequence the prefixes are slices of.
    string: list[str]
    # One flag per recorded token; ``accuracy`` is the mean of these flags.
    full_prefix_correct: list[bool]
    # One prefix per recorded token; ``()`` entries are ignored in mappings.
    min_necessary_prefixes: list[tuple[str, ...]]
    # One prefix per recorded token; ``()`` entries are ignored in mappings.
    converged_prefixes: list[tuple[str, ...]]

    @property
    def accuracy(self) -> float:
        """Fraction of truthy entries in ``full_prefix_correct``.

        Raises ``ZeroDivisionError`` when the list is empty.
        """
        hits = sum(self.full_prefix_correct)
        total = len(self.full_prefix_correct)
        return hits / total

    @property
    def min_necessary_mappings(self) -> dict[tuple[str, ...], dict[str, int]]:
        """Count table built from ``min_necessary_prefixes``."""
        return self._get_mappings(self.min_necessary_prefixes)

    @property
    def converged_mappings(self) -> dict[tuple[str, ...], dict[str, int]]:
        """Count table built from ``converged_prefixes``."""
        return self._get_mappings(self.converged_prefixes)

    def _get_mappings(
        self, prefixes: list[tuple[str, ...]]
    ) -> dict[tuple[str, ...], dict[str, int]]:
        """Tally, for every non-empty prefix, the tokens it is paired with.

        Returns ``{prefix: {target_token: occurrences}}`` where the target of
        the prefix at list position ``i`` is ``self.string[i + 1]``.
        """
        tally: dict[tuple[str, ...], dict[str, int]] = {}
        for target_pos, prefix in enumerate(prefixes, start=1):
            if prefix != ():
                following = self.string[target_pos]
                per_prefix = tally.setdefault(prefix, {})
                per_prefix[following] = per_prefix.get(following, 0) + 1
        return tally


def extract_prefixes(
    string: list[str],
    result: pd.DataFrame,
) -> dict[int, StringPrefixes]:
    """Summarise, per epoch, which prefix lengths predict each token.

    ``result`` is grouped by ``"epoch"`` and then by ``"token_idx"``. It is
    looked up with ``(epoch, token_idx, prefix_length)`` keys and must offer
    a ``"prefix_length"`` index level and a ``"top_1_token"`` column. A row
    counts as correct when its ``"top_1_token"`` equals ``string[token_idx]``.

    For every ``token_idx`` group three values are recorded, in group order:

    * the minimal necessary prefix -- the slice of ``string`` ending right
      before the token, as long as the shortest correct ``prefix_length``;
    * the converged prefix -- the same kind of slice, one element longer than
      the longest incorrect ``prefix_length`` but never shorter than the
      minimal necessary one;
    * whether the row whose ``prefix_length`` equals ``token_idx`` is correct.

    A prefix is stored as ``()`` when its length falls outside
    ``1..token_idx``.

    Returns:
        A dict from epoch to the ``StringPrefixes`` collected for that epoch.
    """

    def context_of_length(position: int, length: int) -> tuple[str, ...]:
        # Only lengths in 1..position correspond to a slice of `string`.
        if 0 < length <= position:
            return tuple(string[position - length : position])
        return ()

    summaries: dict[int, StringPrefixes] = {}
    for epoch_key, epoch_rows in result.groupby("epoch"):
        epoch_id = int(cast(int, epoch_key))
        shortest_correct = []
        stable_correct = []
        full_context_hits = []
        for position_key, position_rows in epoch_rows.groupby("token_idx"):
            position = cast(int, position_key)
            expected = string[position]

            hit_index = position_rows[
                position_rows["top_1_token"] == expected
            ].index
            miss_index = position_rows[
                position_rows["top_1_token"] != expected
            ].index

            # Without any correct prefix the minimal length falls back to 0,
            # which later maps to the empty prefix.
            min_correct = (
                cast(int, hit_index.get_level_values("prefix_length").min())
                if len(hit_index) > 0
                else 0
            )
            # Without any incorrect prefix, pick the value that makes the
            # converged length equal to the minimal correct length.
            max_incorrect = (
                cast(int, miss_index.get_level_values("prefix_length").max())
                if len(miss_index) > 0
                else min_correct - 1
            )
            # The longest incorrect prefix can be shorter than the shortest
            # correct one (not every prefix length is sampled), so the
            # converged length is clamped from below.
            converged = max(max_incorrect + 1, min_correct)
            assert (
                min_correct <= converged
            ), f"{min_correct} > {converged}"

            shortest_correct.append(context_of_length(position, min_correct))
            stable_correct.append(context_of_length(position, converged))
            full_context_hits.append(
                position_rows.loc[
                    (epoch_id, position, position), "top_1_token"
                ]
                == expected
            )
        summaries[epoch_id] = StringPrefixes(
            string,
            full_context_hits,
            shortest_correct,
            stable_correct,
        )
    return summaries


def process_rules(
    rules: pd.DataFrame,
    string: list[str],
    prefix_mappings: dict[int, StringPrefixes],
) -> pd.DataFrame:
    """Evaluate every occurring rule against each epoch's prefix mappings.

    The input ``rules`` frame is not modified; all work happens on copies.

    Args:
        rules: Rule table; the helpers called here read its
            ``"premise_tokens"`` and ``"conclusion_token"`` columns.
        string: Token sequence the rules are counted in.
        prefix_mappings: Prefix results keyed by epoch.

    Returns:
        The per-epoch rule tables (with ``"count"`` and ``"correct"``
        columns added by the helpers) concatenated under an outer index
        level named ``"epoch"``; the original rule index becomes the
        ``"rule_idx"`` level.

    Raises:
        ValueError: From ``pd.concat`` when ``prefix_mappings`` is empty.
    """
    # Counting and filtering depend only on `rules` and `string`, so they are
    # done once instead of being repeated for every epoch.
    occurring_rules = filter_out_non_occurring_rules(
        add_rule_counts(rules.copy(), string)
    )
    # add_rule_correctness adds its column in place, so every epoch gets an
    # independent copy of the filtered table.
    epoch_results = [
        add_rule_correctness(occurring_rules.copy(), epoch_mappings)
        for epoch_mappings in prefix_mappings.values()
    ]
    return pd.concat(
        epoch_results,
        names=["epoch", "rule_idx"],
        keys=list(prefix_mappings),
    )


def add_rule_counts(
    rules: pd.DataFrame,
    string: list[str],
) -> pd.DataFrame:
    """Add a ``"count"`` column with each rule's occurrences in ``string``.

    A rule occurs wherever its premise tokens immediately followed by its
    conclusion token appear as a contiguous run in ``string``; overlapping
    occurrences are all counted.

    Args:
        rules: Frame with ``"premise_tokens"`` (an iterable of tokens per
            row) and ``"conclusion_token"`` columns. It is modified in place.
        string: Token sequence to search.

    Returns:
        The same ``rules`` object, now carrying an integer ``"count"`` column.
        An empty frame gets an empty ``"count"`` column instead of failing.
    """
    counts = []
    for premise, conclusion in zip(
        rules["premise_tokens"], rules["conclusion_token"]
    ):
        rule_substring = [*premise, str(conclusion)]
        width = len(rule_substring)
        counts.append(
            sum(
                1
                for start in range(len(string) - width + 1)
                if string[start : start + width] == rule_substring
            )
        )
    # Assign the whole column at once: this also creates the column when
    # there are no rows and is positional, so duplicate index labels are safe.
    rules["count"] = counts
    rules["count"] = rules["count"].astype(int)
    return rules


def filter_out_non_occurring_rules(
    rules: pd.DataFrame,
) -> pd.DataFrame:
    """Return the rows of ``rules`` whose ``"count"`` is greater than zero.

    The result is an independent copy, so callers can add or overwrite
    columns on it without pandas warning about writing to a slice of the
    original frame. ``rules`` itself is left untouched.
    """
    return rules[rules["count"] > 0].copy()


def add_rule_correctness(
    rules: pd.DataFrame,
    prefix_mappings: StringPrefixes,
) -> pd.DataFrame:
    """Add a categorical ``"correct"`` column rating each rule.

    Every rule is checked against the minimal necessary mappings and against
    the converged mappings. The two verdicts are merged as follows: equal
    verdicts are kept; if exactly one is ``"no-match"`` the other one wins;
    any other disagreement yields ``"partially"``.

    Args:
        rules: Frame with ``"premise_tokens"`` and ``"conclusion_token"``
            columns. It is modified in place.
        prefix_mappings: Source of the two mapping tables.

    Returns:
        The same ``rules`` object with the ``"correct"`` column added. An
        empty frame gets an empty column instead of failing.
    """
    # Both properties rebuild their dict on every access, so read them once
    # rather than once per rule.
    min_necessary_mappings = prefix_mappings.min_necessary_mappings
    converged_mappings = prefix_mappings.converged_mappings

    verdicts = []
    for premise, conclusion in zip(
        rules["premise_tokens"], rules["conclusion_token"]
    ):
        premise_tokens = tuple(premise)
        conclusion_token = str(conclusion)

        min_necessary_agreement = _check_mapping_agreement(
            premise_tokens,
            conclusion_token,
            min_necessary_mappings,
            is_min_necessary=True,
        )
        converged_agreement = _check_mapping_agreement(
            premise_tokens,
            conclusion_token,
            converged_mappings,
            is_min_necessary=False,
        )
        if min_necessary_agreement == converged_agreement:
            agreement = min_necessary_agreement
        elif min_necessary_agreement == "no-match":
            agreement = converged_agreement
        elif converged_agreement == "no-match":
            agreement = min_necessary_agreement
        else:
            agreement = "partially"
        verdicts.append(agreement)

    # Whole-column assignment creates the column even when there are no rows
    # and is positional, so duplicate index labels cannot cross-contaminate.
    rules["correct"] = pd.Categorical(verdicts)
    return rules


def _check_mapping_agreement(
    rule_premise: tuple[str, ...],
    rule_conclusion: str,
    mappings: dict[tuple[str, ...], dict[str, int]],
    is_min_necessary: bool,
) -> str:
    """Rate how well the applicable mappings agree with a rule.

    A mapping is applicable when its premise and the rule's premise end the
    same way:

    * ``is_min_necessary=True``: the mapping premise is a suffix of the
      rule premise (the mapping is at most as long as the rule);
    * ``is_min_necessary=False``: the rule premise is a suffix of the
      mapping premise (the mapping is at least as long as the rule).

    An empty premise is a suffix of every premise. An applicable mapping
    agrees with the rule when ``rule_conclusion`` is one of its targets.

    Returns:
        ``"no-match"`` if no mapping is applicable, ``"correct"`` if all
        applicable mappings agree, ``"partially"`` if only some agree, and
        ``"incorrect"`` if none agree.
    """

    def is_suffix(short: tuple[str, ...], long: tuple[str, ...]) -> bool:
        # Slice from an explicit start offset: `long[-len(short):]` would
        # return all of `long` when `short` is empty, because -0 == 0.
        return (
            len(short) <= len(long)
            and long[len(long) - len(short) :] == short
        )

    all_agree = True
    some_agree = False
    found_matching = False
    for mapping_premise, mapping_targets in mappings.items():
        if is_min_necessary:
            applicable = is_suffix(mapping_premise, rule_premise)
        else:
            applicable = is_suffix(rule_premise, mapping_premise)
        if not applicable:
            continue
        found_matching = True
        if rule_conclusion in mapping_targets:
            some_agree = True
        else:
            all_agree = False

    if not found_matching:
        return "no-match"
    if all_agree:
        return "correct"
    if some_agree:
        return "partially"
    return "incorrect"
