import torch
from torch.testing import assert_close

from recognizers.automata.automaton import Symbol
from recognizers.automata.pushdown_automaton import (
    StackSymbol,
    PushdownAutomatonTransition,
    WeightedPushdownAutomatonContainer
)
from recognizers.automata.pushdown_automaton_allsum import (
    top_down_pushdown_automaton_allsum
)
from recognizers.automata.counting_semiring import CountingSemiring

def test_single_pop() -> None:
    """Allsum over a PDA whose only transition pops ``A`` while moving q1 -> q2.

    That transition is the sole derivation, so the only nonzero item is
    (q1, A, q2), and its weight equals the transition's own weight.
    """
    num_coefficients = 4
    semiring = CountingSemiring(num_coefficients)
    dtype, device = torch.float32, torch.device('cpu')

    def weight(values):
        # Semiring weight vector on the test's dtype/device.
        return torch.tensor(values, dtype=dtype, device=device)

    M = WeightedPushdownAutomatonContainer[torch.Tensor](
        semiring=semiring,
        num_states=2,
        stack_alphabet_size=2,
    )
    q1, q2 = M.states()
    a = Symbol(0)
    # B is part of the stack alphabet but no transition references it.
    A, B = (StackSymbol(i) for i in range(M.stack_alphabet_size()))

    pop_a = PushdownAutomatonTransition(q1, a, q2, A, ())
    M.set_transition_weight(pop_a, weight([0, 1, 0, 0]))
    M.set_accept_state(q2)

    actual = top_down_pushdown_automaton_allsum(M, dtype, device)
    assert actual.size() == (2, 2, 2, num_coefficients)

    expected = semiring.zeros((2, 2, 2), dtype, device)
    expected[q1, A, q2] = weight([0, 1, 0, 0])
    assert_close(actual, expected)

def test_simple_chain() -> None:
    """Allsum over a four-transition PDA with nested push/pop structure.

    q1 pops A and pushes (B, C); q2 pops C and pushes D; q3 pops D; q4 pops B
    and enters the accept state q5. The transitions consume C before B, so the
    pushed sequence (B, C) is evidently expected to leave C on top.

    Every transition has weight [0, 1, 0, 0, 0]. In the expectations below,
    each item's weight sits at the index equal to the number of transitions in
    its derivation. For example, (q1, A, q5) uses all four, so it lands at
    index 4.
    """
    semiring = CountingSemiring(5)
    dtype, device = torch.float32, torch.device('cpu')

    def weight(values):
        # Semiring weight vector on the test's dtype/device.
        return torch.tensor(values, dtype=dtype, device=device)

    M = WeightedPushdownAutomatonContainer[torch.Tensor](
        semiring=semiring,
        num_states=5,
        stack_alphabet_size=4,
    )
    q1, q2, q3, q4, q5 = M.states()
    a, b, c, d = (Symbol(i) for i in range(4))
    A, B, C, D = (StackSymbol(i) for i in range(M.stack_alphabet_size()))

    # All transitions share the same weight; insertion order matches the chain.
    chain = (
        PushdownAutomatonTransition(q1, a, q2, A, (B, C)),
        PushdownAutomatonTransition(q2, b, q3, C, (D,)),
        PushdownAutomatonTransition(q3, c, q4, D, ()),
        PushdownAutomatonTransition(q4, d, q5, B, ()),
    )
    for transition in chain:
        M.set_transition_weight(transition, weight([0, 1, 0, 0, 0]))
    M.set_accept_state(q5)

    actual = top_down_pushdown_automaton_allsum(M, dtype, device)
    assert actual.size() == (5, 4, 5, 5)

    expected = semiring.zeros((5, 4, 5), dtype, device)
    nonzero_items = [
        (q1, A, q5, [0, 0, 0, 0, 1]),
        (q2, C, q4, [0, 0, 1, 0, 0]),
        (q3, D, q4, [0, 1, 0, 0, 0]),
        (q4, B, q5, [0, 1, 0, 0, 0]),
    ]
    for source, symbol, target, coefficients in nonzero_items:
        expected[source, symbol, target] = weight(coefficients)
    assert_close(actual, expected)

def test_cycle() -> None:
    """Allsum over a PDA whose self-loop on q1 makes one item an infinite series.

    Transitions and their weights:
      - on ``a``, q1 replaces A with (A, A), weight 0.1;
      - on ``b``, q1 pops A and moves to q2, weight 0.9;
      - on ``b``, q2 pops A and stays in q2, weight 0.1.

    Write x for the index-1 basis vector. No pop lands in q1, so
    (q1, A, q1) = 0. The item f = (q1, A, q2) therefore satisfies
    f = 0.9 x + 0.1 x * f * 0.1 x. Truncated to 8 coefficients, that is
    0.9 x + 0.009 x^3 + 0.00009 x^5 + 0.0000009 x^7. The item (q2, A, q2) is
    just its single pop, 0.1 x.
    """
    semiring = CountingSemiring(8)
    dtype, device = torch.float32, torch.device('cpu')

    def weight(values):
        # Semiring weight vector on the test's dtype/device.
        return torch.tensor(values, dtype=dtype, device=device)

    def step(probability):
        # Single-transition weight: ``probability`` at index 1, zeros elsewhere.
        return weight([0, probability] + [0] * 6)

    M = WeightedPushdownAutomatonContainer[torch.Tensor](
        semiring=semiring,
        num_states=2,
        stack_alphabet_size=1,
    )
    q1, q2 = M.states()
    a, b = Symbol(0), Symbol(1)
    (A,) = [StackSymbol(i) for i in range(M.stack_alphabet_size())]

    M.set_transition_weight(PushdownAutomatonTransition(q1, a, q1, A, (A, A)), step(0.1))
    M.set_transition_weight(PushdownAutomatonTransition(q1, b, q2, A, ()), step(0.9))
    M.set_transition_weight(PushdownAutomatonTransition(q2, b, q2, A, ()), step(0.1))
    M.set_accept_state(q2)

    actual = top_down_pushdown_automaton_allsum(M, dtype, device)
    assert actual.size() == (2, 1, 2, 8)

    expected = semiring.zeros((2, 1, 2), dtype, device)
    # Odd-index series from the fixed point described in the docstring.
    expected[q1, A, q2] = weight([0, 0.9, 0, 0.009, 0, 0.00009, 0, 0.0000009])
    expected[q2, A, q2] = step(0.1)
    assert_close(actual, expected)
