import dataclasses
import itertools

from recognizers.automata.automaton import (
    Symbol,
    State
)
from recognizers.automata.pushdown_automaton import (
    StackSymbol
)
from recognizers.automata.reserved import ReservedSymbol
from recognizers.string_sampling.weighted_language import String
from recognizers.automata.pushdown_automaton import PushdownAutomaton

@dataclasses.dataclass
class PreprocessedPushdownAutomaton:

    initial_state: State
    initial_stack_symbol: StackSymbol
    accept_state: State
    pop_transitions: dict[Symbol, list[tuple[State, StackSymbol, State]]]
    replace_transitions: dict[tuple[Symbol, StackSymbol, State], list[tuple[State, StackSymbol]]]
    scanning_push_transitions: dict[tuple[Symbol, StackSymbol, State], list[tuple[State, StackSymbol, StackSymbol]]]
    non_scanning_push_transitions: dict[tuple[StackSymbol, State], list[tuple[State, StackSymbol, StackSymbol]]]

    @staticmethod
    def from_pushdown_automaton(automaton: PushdownAutomaton) -> 'PreprocessedPushdownAutomaton':
        pop_transitions = {}
        replace_transitions = {}
        scanning_push_transitions = {}
        non_scanning_push_transitions = {}
        for transition, _ in automaton.transition_weights:
            def fail():
                raise ValueError(f'pushdown automaton is not in top-down normal form: {transition}')
            p = transition.state_from
            a = transition.symbol
            q = transition.state_to
            X = transition.popped_symbol
            match transition.pushed_symbols:
                case ():
                    if a != ReservedSymbol.EPSILON:
                        pop_transitions.setdefault(a, []).append((p, X, q))
                case (Y,):
                    if a == ReservedSymbol.EPSILON:
                        fail()
                    replace_transitions.setdefault((a, Y, q), []).append((p, X))
                case (Y, Z):
                    if a == ReservedSymbol.EPSILON:
                        non_scanning_push_transitions.setdefault((Y, q), []).append((p, X, Z))
                    else:
                        scanning_push_transitions.setdefault((a, Y, q), []).append((p, X, Z))
                case _:
                    fail()
        return PreprocessedPushdownAutomaton(
            initial_state=automaton.initial_state,
            initial_stack_symbol=automaton.initial_stack_symbol,
            accept_state=automaton.accept_state,
            pop_transitions=pop_transitions,
            replace_transitions=replace_transitions,
            scanning_push_transitions=scanning_push_transitions,
            non_scanning_push_transitions=non_scanning_push_transitions
        )

    def recognize(self, string: String) -> bool:
        n = len(string)

        items = {}
        def add_item(i, p, X, j, q):
            items.setdefault((i, j), set()).add((p, X, q))
        def get_items(i, j):
            return items.get((i, j)) or set()
        def has_item(i, p, X, j, q):
            return (p, X, q) in get_items(i, j)

        helper_items = {}
        def add_helper_item(i, p, Z, X, k, s):
            helper_items.setdefault((i, Z, k, s), set()).add((p, X))
        def get_helper_items(i, Z, k, s):
            return helper_items.get((i, Z, k, s)) or set()

        for j in range(1, n+1):
            w_j = string[j-1]
            for p, X, q in self.pop_transitions.get(w_j) or []:
                add_item(j-1, p, X, j, q)
            if j >= 2:
                k = j-1
                for l in range(1, j):
                    i = k - l
                    for r, Y, s in get_items(i+1, k):
                        for p, X, Z in itertools.chain(
                            self.scanning_push_transitions.get((w_j, Y, r)) or [],
                            self.non_scanning_push_transitions.get((Y, r)) or []
                        ):
                            add_helper_item(i, p, Z, X, k, s)
            for l in range(2, j+1):
                i = j - l
                for r, Y, q in get_items(i+1, j):
                    for p, X in self.replace_transitions.get((w_j, Y, r)) or []:
                        add_item(i, p, X, j, q)
                for k in range(i+1, j):
                    for s, Z, q in get_items(k, j):
                        for p, X in get_helper_items(i, Z, k, s):
                            add_item(i, p, X, j, q)
        accept = has_item(
            0,
            self.initial_state,
            self.initial_stack_symbol,
            n,
            self.accept_state
        )
        return accept
