from recognizers.automata.automaton import Symbol
from recognizers.grammars.grammar import Variable, Rule
from recognizers.grammars.context_free_grammar import ContextFreeGrammarContainer
from recognizers.grammars.trim_cnf_cfg import trim_cnf_cfg

def make_cfg(num_variables, num_terminals, get_rules):
    G = ContextFreeGrammarContainer(
        num_variables=num_variables,
        num_terminals=num_terminals
    )
    for left, right in get_rules(G.variables(), G.terminals()):
        assert G.is_variable(left)
        match len(right):
            case 2:
                assert all(G.is_variable(X) for X in right)
            case 1:
                assert all(G.is_terminal(X) for X in right)
            case 0:
                assert left == G.start_variable()
            case _:
                raise ValueError
        G.add_rule(Rule(Variable(left), tuple(right)))
    return G

def cfg_to_result(G):
    """Summarize grammar ``G`` as a tuple that can be compared with ``==``.

    The tuple is ``(num_variables, num_terminals, rule_set, start_variable)``.
    ``rule_set`` is a set of ``(left, right)`` pairs, so rule order and
    duplicate rules do not affect equality.
    """
    variable_count = G.num_variables()
    terminal_count = G.num_terminals()
    rule_pairs = {(rule.left, rule.right) for rule in G.rules()}
    start = G.start_variable()
    return variable_count, terminal_count, rule_pairs, start

def result(num_variables, num_terminals, get_rules):
    """Build a grammar from ``get_rules``, trim it, and summarize it.

    The summary has the same shape as ``expected(...)``, so the two can be
    compared directly in a test.
    """
    untrimmed = make_cfg(num_variables, num_terminals, get_rules)
    return cfg_to_result(trim_cnf_cfg(untrimmed))

def expected(num_variables, num_terminals, get_rules):
    """Build a grammar from ``get_rules`` and summarize it without trimming.

    Describes the grammar a test expects ``trim_cnf_cfg`` to produce.
    """
    grammar = make_cfg(num_variables, num_terminals, get_rules)
    return cfg_to_result(grammar)

def test_simple_already_trim():
    """Leave a grammar unchanged when every variable is reachable and derives terminals."""
    def grammar_rules(variables, terminals):
        S, A1, A2 = variables
        a, b = terminals
        return [
            (S, (A1, A2)),
            (A1, (a,)),
            (A2, (b,)),
        ]

    actual = result(3, 2, grammar_rules)
    wanted = expected(3, 2, grammar_rules)
    assert actual == wanted

def test_unreachable():
    """Drop variables that cannot be reached from S.

    A3, A4 and A6 appear only in rules whose left side is one of themselves,
    so trimming removes them and their rules. The expected grammar numbers
    the surviving variables S, A1, A2, A5, A7 in their original relative order.
    """
    def grammar_rules(variables, terminals):
        S, A1, A2, A3, A4, A5, A6, A7 = variables
        (a,) = terminals
        return [
            (S, (A1, A2)),
            (A1, (a,)),
            (A2, (A5, A7)),
            (A5, (a,)),
            (A7, (a,)),
            (A4, (A6, A3)),
            (A6, (a,)),
            (A3, (A6, A4)),
        ]

    def trimmed_rules(variables, terminals):
        S, A1, A2, A5, A7 = variables
        (a,) = terminals
        return [
            (S, (A1, A2)),
            (A1, (a,)),
            (A2, (A5, A7)),
            (A5, (a,)),
            (A7, (a,)),
        ]

    actual = result(8, 1, grammar_rules)
    wanted = expected(5, 1, trimmed_rules)
    assert actual == wanted

def test_just_epsilon():
    """Keep a grammar whose only rule is S -> epsilon unchanged."""
    def grammar_rules(variables, terminals):
        (S,) = variables
        return [(S, ())]

    actual = result(1, 0, grammar_rules)
    wanted = expected(1, 0, grammar_rules)
    assert actual == wanted

def test_just_epsilon_extra_terminals():
    """Remove terminals that no rule uses: all 10 terminals disappear."""
    def grammar_rules(variables, terminals):
        (S,) = variables
        return [(S, ())]

    actual = result(1, 10, grammar_rules)
    wanted = expected(1, 0, grammar_rules)
    assert actual == wanted

def test_no_rules():
    """Keep a lone start variable with no rules as-is."""
    def grammar_rules(variables, terminals):
        # Unpack only to check that exactly one variable is supplied.
        (_start,) = variables
        return []

    actual = result(1, 0, grammar_rules)
    wanted = expected(1, 0, grammar_rules)
    assert actual == wanted

def test_simple_has_empty_variables():
    """Remove rules that mention variables with no rules of their own.

    A3 and A4 have no rules, so A1 -> A3 A4 derives nothing. A1 and the rule
    S -> A1 A2 are therefore dropped, leaving only S and A2.
    """
    def grammar_rules(variables, terminals):
        S, A1, A2, A3, A4 = variables
        (a,) = terminals
        return [
            (S, (A1, A2)),
            (S, (A2, A2)),
            (A1, (A3, A4)),
            (A2, (a,)),
        ]

    def trimmed_rules(variables, terminals):
        S, A2 = variables
        (a,) = terminals
        return [
            (S, (A2, A2)),
            (A2, (a,)),
        ]

    actual = result(5, 1, grammar_rules)
    wanted = expected(2, 1, trimmed_rules)
    assert actual == wanted

def test_has_unreachable_and_empty():
    """Combine non-generating variables with an unreachable subgrammar.

    A3 and A4 have no rules, which makes A1 useless. A5, A6 and A7 are never
    reached from S. The variables are unpacked in a non-sequential order, so
    the surviving A2 sits at index 3 of the original grammar and index 1 of
    the trimmed one.
    """
    def grammar_rules(variables, terminals):
        S, A1, A5, A2, A6, A3, A7, A4 = variables
        (a,) = terminals
        return [
            (S, (A1, A2)),
            (S, (A2, A2)),
            (A1, (A3, A4)),
            (A2, (a,)),
            (A5, (A6, A7)),
            (A6, (a,)),
            (A7, (a,)),
        ]

    def trimmed_rules(variables, terminals):
        S, A2 = variables
        (a,) = terminals
        return [
            (S, (A2, A2)),
            (A2, (a,)),
        ]

    actual = result(8, 1, grammar_rules)
    wanted = expected(2, 1, trimmed_rules)
    assert actual == wanted

def test_extra_variables():
    """Remove variables that appear in no rule and compact the rest.

    Only indices 0, 2 and 5 of the 10 variables are used. After trimming they
    become the three variables S, A1, A2.
    """
    def grammar_rules(variables, terminals):
        S, _, A1, _, _, A2, *_ = variables
        (a,) = terminals
        return [
            (S, (A1, A2)),
            (A1, (a,)),
            (A2, (a,)),
        ]

    def trimmed_rules(variables, terminals):
        S, A1, A2 = variables
        (a,) = terminals
        return [
            (S, (A1, A2)),
            (A1, (a,)),
            (A2, (a,)),
        ]

    actual = result(10, 1, grammar_rules)
    wanted = expected(3, 1, trimmed_rules)
    assert actual == wanted

def test_extra_variables_and_terminals():
    """Remove both unused variables and unused terminals.

    Of 13 variables only three are used, and of 17 terminals only those at
    indices 2 and 4 are used. The trimmed grammar has 3 variables and
    2 terminals.
    """
    def grammar_rules(variables, terminals):
        S, _, A1, _, _, A2, *_ = variables
        _, _, a, _, b, *_ = terminals
        return [
            (S, (A1, A2)),
            (A1, (a,)),
            (A2, (b,)),
        ]

    def trimmed_rules(variables, terminals):
        S, A1, A2 = variables
        a, b = terminals
        return [
            (S, (A1, A2)),
            (A1, (a,)),
            (A2, (b,)),
        ]

    actual = result(13, 17, grammar_rules)
    wanted = expected(3, 2, trimmed_rules)
    assert actual == wanted

def test_a_star():
    """Keep a trim grammar that includes the start-variable epsilon rule unchanged."""
    def grammar_rules(variables, terminals):
        S, A, B = variables
        (a,) = terminals
        return [
            (S, ()),
            (S, (a,)),
            (S, (A, B)),
            (A, (a,)),
            (B, (A, B)),
            (B, (a,)),
        ]

    actual = result(3, 1, grammar_rules)
    wanted = expected(3, 1, grammar_rules)
    assert actual == wanted

def test_a_star_missing_rule():
    """Drop a variable that can never finish deriving.

    Without B -> a, B's only rule is B -> A B, so B derives no terminal
    string. S -> A B is removed, A is then unreachable, and only the S rules
    S -> epsilon and S -> a remain.
    """
    def grammar_rules(variables, terminals):
        S, A, B = variables
        (a,) = terminals
        return [
            (S, ()),
            (S, (a,)),
            (S, (A, B)),
            (A, (a,)),
            (B, (A, B)),
        ]

    def trimmed_rules(variables, terminals):
        (S,) = variables
        (a,) = terminals
        return [
            (S, ()),
            (S, (a,)),
        ]

    actual = result(3, 1, grammar_rules)
    wanted = expected(1, 1, trimmed_rules)
    assert actual == wanted

def test_empty_2():
    """Remove one rule of S whose first variable derives nothing.

    A4 has no rules, so A1 -> A3 A4 derives nothing and S -> A1 A2 is dropped.
    A2 and A3 stay because S -> A2 A3 still reaches them.
    """
    def grammar_rules(variables, terminals):
        S, A1, A2, A3, A4, A5, A6 = variables
        (a,) = terminals
        return [
            (S, (A1, A2)),
            (S, (A2, A3)),
            (A1, (A3, A4)),
            (A3, (a,)),
            (A2, (A5, A6)),
            (A5, (a,)),
            (A6, (a,)),
        ]

    def trimmed_rules(variables, terminals):
        S, A2, A3, A5, A6 = variables
        (a,) = terminals
        return [
            (S, (A2, A3)),
            (A3, (a,)),
            (A2, (A5, A6)),
            (A5, (a,)),
            (A6, (a,)),
        ]

    actual = result(7, 1, grammar_rules)
    wanted = expected(5, 1, trimmed_rules)
    assert actual == wanted

def test_unreachable_and_empty_2():
    """Detect a variable that becomes unreachable only after an empty one is removed.

    A5 has no rules, so S -> A5 A6 is dropped. A6 was reachable only through
    that rule and disappears too. A3 (and A4, used only by A3) is never
    reachable from S.
    """
    def grammar_rules(variables, terminals):
        S, A1, A2, A3, A4, A5, A6 = variables
        (a,) = terminals
        return [
            (S, (A1, A2)),
            (A1, (a,)),
            (A2, (a,)),
            (A3, (A2, A4)),
            (A4, (a,)),
            (S, (A5, A6)),
            (A6, (a,)),
        ]

    def trimmed_rules(variables, terminals):
        S, A1, A2 = variables
        (a,) = terminals
        return [
            (S, (A1, A2)),
            (A1, (a,)),
            (A2, (a,)),
        ]

    actual = result(7, 1, grammar_rules)
    wanted = expected(3, 1, trimmed_rules)
    assert actual == wanted

def test_empty_and_unreachable_2():
    """Remove a large subtree made useless by empty variables.

    A4 and A5 have no rules, so A1 -> A3 A4 and A2 -> A5 A6 derive nothing.
    That makes S -> A1 A2 useless, and everything below it (A3, A6, A10, A11)
    becomes unreachable. A7 is never referenced. Only S -> A8 A9 survives.
    """
    def grammar_rules(variables, terminals):
        S, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11 = variables
        (a,) = terminals
        return [
            (S, (A1, A2)),
            (S, (A8, A9)),
            (A1, (A3, A4)),
            (A3, (a,)),
            (A2, (A5, A6)),
            (A6, (a,)),
            (A8, (a,)),
            (A9, (a,)),
            (A7, (a,)),
            (A6, (A10, A11)),
            (A10, (a,)),
            (A11, (a,)),
        ]

    def trimmed_rules(variables, terminals):
        S, A8, A9 = variables
        (a,) = terminals
        return [
            (S, (A8, A9)),
            (A8, (a,)),
            (A9, (a,)),
        ]

    actual = result(12, 1, grammar_rules)
    wanted = expected(3, 1, trimmed_rules)
    assert actual == wanted
