from recognizers.graph_traversal import reachable
from .finite_automaton import (
    FiniteAutomaton,
    FiniteAutomatonContainer,
    FiniteAutomatonTransition
)

def trim_dfa(dfa: FiniteAutomaton) -> FiniteAutomatonContainer:
    """Build a trimmed copy of ``dfa`` as a new ``FiniteAutomatonContainer``.

    A state survives when it is reachable from the initial state and can
    also reach an accept state; the initial state is always kept, even if
    it fails the second condition. Surviving states are renumbered in the
    order they appear in ``dfa.states()``, and the symbols still used by a
    surviving transition are renumbered in the order they appear in
    ``dfa.alphabet()``. The reported alphabet size is never less than 1.
    """
    # Forward pass: collect successors per state, then find everything
    # reachable from the initial state.
    successors = {}
    for edge in dfa.transitions():
        if edge.state_from not in successors:
            successors[edge.state_from] = set()
        successors[edge.state_from].add(edge.state_to)

    def forward_neighbors(state):
        return successors.get(state, [])

    forward_reachable = reachable([dfa.initial_state()], forward_neighbors)

    # Backward pass: collect predecessors per state, using only edges whose
    # endpoints both survived the forward pass.
    predecessors = {}
    for edge in dfa.transitions():
        if edge.state_from not in forward_reachable:
            continue
        if edge.state_to not in forward_reachable:
            continue
        if edge.state_to not in predecessors:
            predecessors[edge.state_to] = set()
        predecessors[edge.state_to].add(edge.state_from)

    def backward_neighbors(state):
        return predecessors.get(state, [])

    # Accept states that are still alive, in the order given by dfa.states().
    live_accept_states = []
    for state in dfa.states():
        if dfa.is_accept_state(state) and state in forward_reachable:
            live_accept_states.append(state)

    # Walking the reversed edges from the live accept states yields the
    # states that can reach one of them.
    kept_states = reachable(live_accept_states, backward_neighbors)

    # Keep only transitions between kept states, and note which symbols
    # they use. This is done before the initial state is forced back in.
    kept_transitions = []
    used_symbols = set()
    for edge in dfa.transitions():
        if edge.state_from in kept_states and edge.state_to in kept_states:
            kept_transitions.append(edge)
            used_symbols.add(edge.symbol)

    # The result needs an initial state even when no accept state is
    # reachable from it.
    kept_states.add(dfa.initial_state())

    # Renumber states and symbols densely, preserving the original order.
    state_index = {}
    position = 0
    for state in dfa.states():
        if state in kept_states:
            state_index[state] = position
            position += 1

    symbol_index = {}
    position = 0
    for symbol in dfa.alphabet():
        if symbol in used_symbols:
            symbol_index[symbol] = position
            position += 1

    trimmed = FiniteAutomatonContainer(
        num_states=len(state_index),
        alphabet_size=max(len(symbol_index), 1),
        initial_state=state_index[dfa.initial_state()]
    )
    for edge in kept_transitions:
        source = state_index[edge.state_from]
        label = symbol_index[edge.symbol]
        target = state_index[edge.state_to]
        trimmed.add_transition(FiniteAutomatonTransition(source, label, target))
    for state in live_accept_states:
        trimmed.add_accept_state(state_index[state])
    return trimmed
