import dataclasses

from recognizers.automata.automaton import Symbol
from recognizers.automata.reserved import ReservedSymbol
from recognizers.dataset_generation.weighted_language import String
from recognizers.grammars.grammar import Nonterminal
from recognizers.grammars.context_free_grammar import ContextFreeGrammar

@dataclasses.dataclass
class CNFContextFreeGrammar:

    start_symbol: Nonterminal
    lexical_productions: dict[Symbol, list[Nonterminal]]
    binary_productions: dict[tuple[Nonterminal, Nonterminal], list[Nonterminal]]

    @staticmethod
    def from_context_free_grammar(grammar: ContextFreeGrammar) -> 'CNFContextFreeGrammar':
        lexical_productions = {}
        binary_productions = {}
        for production, _ in grammar.production_weights:
            def fail():
                raise ValueError(f'grammar is not in Chomsky normal form: {production}')
            X = production.left_hand_side
            match production.right_hand_side:
                case ():
                    pass
                case (a,):
                    lexical_productions.setdefault(a, set()).add(X)
                case (Y, Z):
                    binary_productions.setdefault((Y, Z), set()).add(X)
                case _:
                    fail()
        return CNFContextFreeGrammar(
            start_symbol=grammar.start_symbol,
            lexical_productions=lexical_productions,
            binary_productions=binary_productions
        )

    def recognize(self, string: String) -> bool:
        """CKY algorithm."""
        n = len(string)

        items = {}
        def add_item(i, X, j):
            items.setdefault((i, j), set()).add(X)
        def get_items(i, j):
            return items.get((i, j)) or set()
        def has_item(i, X, j):
            return X in get_items(i, j)

        for j in range(0, n):
            w_j = string[j]
            for X in self.lexical_productions.get(w_j):
                add_item(j, X, j + 1)

        for j in range(0, n + 1):
            for span in range(2, j + 1):
                i = j - span
                for k in range(i + 1, j):
                    for (Y, Z) in self.binary_productions:
                        if Y in get_items(i, k) and Z in get_items(k, j):
                            for X in self.binary_productions.get((Y, Z)):
                                add_item(i, X, j)
        accept = has_item(
            0,
            self.start_symbol,
            n
        )
        return accept
