from recognizers.graph_traversal import reachable
from recognizers.utils import group_by

from .grammar import Variable, Rule
from .context_free_grammar import (
    ContextFreeGrammar,
    ContextFreeGrammarContainer
)

def trim_cnf_cfg(cfg: ContextFreeGrammar) -> ContextFreeGrammarContainer:
    """Return a trimmed, renumbered copy of a grammar in Chomsky normal form.

    The rules of ``cfg`` are classified by the length of their right-hand
    side: 0 (empty), 1 (a single terminal) or 2 (two variables). Rules of
    any other length are dropped.

    Steps:

    1. Compute the variables reported by ``non_empty_variables``.
    2. Starting from the start variable, follow only binary rules whose two
       right-hand-side variables are both in that set, and collect every
       variable visited.
    3. Keep every empty rule, every unary rule whose left-hand side was
       visited, and every binary rule whose left-hand side and both
       right-hand-side variables were visited.
    4. Copy the surviving rules into a fresh ``ContextFreeGrammarContainer``
       sized to the surviving variables and the terminals that still occur
       in a kept unary rule. Surviving symbols keep their relative order
       from ``cfg.variables()`` / ``cfg.terminals()``.

    NOTE(review): this relies on ``reachable`` returning a container that
    supports ``in`` and includes the start variable itself -- confirm
    against ``recognizers.graph_traversal``.
    """
    productive = non_empty_variables(cfg)
    # Successor graph used for the reachability search: an edge A -> B exists
    # only if B occurs in a binary rule of A whose right side is entirely
    # made of variables from `productive`.
    successors = {}
    for rule in cfg.rules():
        if len(rule.right) != 2:
            continue
        if not all(A in productive for A in rule.right):
            continue
        successors.setdefault(rule.left, set()).update(rule.right)
    kept_variables = reachable(
        [cfg.start_variable()],
        lambda A: successors.get(A, [])
    )

    # Select the rules that survive trimming, preserving their original order.
    kept_rules = []
    for rule in cfg.rules():
        arity = len(rule.right)
        if arity == 0:
            # Empty rules are always kept.
            keep = True
        elif arity == 1:
            keep = rule.left in kept_variables
        elif arity == 2:
            keep = all(X in kept_variables for X in (rule.left, *rule.right))
        else:
            keep = False
        if keep:
            kept_rules.append(rule)

    # Only terminals that are still produced by some kept unary rule survive.
    kept_terminals = set()
    for rule in kept_rules:
        if len(rule.right) == 1:
            kept_terminals.add(rule.right[0])

    # Preserve the relative order of the surviving symbols.
    ordered_variables = [A for A in cfg.variables() if A in kept_variables]
    ordered_terminals = [a for a in cfg.terminals() if a in kept_terminals]

    trimmed = ContextFreeGrammarContainer(
        num_variables=len(ordered_variables),
        num_terminals=len(ordered_terminals)
    )
    # A single lookup table translating old variables and old terminals to
    # the corresponding symbols of the new grammar.
    old_to_new = dict(zip(ordered_variables, trimmed.variables()))
    old_to_new.update(zip(ordered_terminals, trimmed.terminals()))

    trimmed.set_start_variable(old_to_new[cfg.start_variable()])
    for rule in kept_rules:
        translated_right = tuple(old_to_new[X] for X in rule.right)
        trimmed.add_rule(Rule(old_to_new[rule.left], translated_right))
    return trimmed

def non_empty_variables(cfg: ContextFreeGrammar) -> set[Variable]:
    """Return the non-start variables that can be expanded via unary/binary rules.

    A variable other than the start variable is included if it is the
    left-hand side of a rule with a right-hand side of length 1, or of a rule
    with a right-hand side of length 2 whose symbols are both already
    included. Rules whose left-hand side is the start variable are ignored,
    so the start variable is never part of the result.
    """
    start = cfg.start_variable()
    binary_rules = []
    for rule in cfg.rules():
        if len(rule.right) == 2 and rule.left != start:
            binary_rules.append(rule)
    base = set()
    for rule in cfg.rules():
        if len(rule.right) == 1 and rule.left != start:
            base.add(rule.left)

    # Round-based fixpoint iteration: each round adds the left-hand sides of
    # binary rules whose right-hand side lies entirely in the previous
    # round's set. The set can only grow, so an unchanged size means that
    # nothing new was found and the fixpoint has been reached.
    known = set(base)
    while True:
        derived = {
            rule.left
            for rule in binary_rules
            if all(A in known for A in rule.right)
        }
        grown = base | derived
        if len(grown) == len(known):
            return known
        known = grown
