from typing import TypeVar

import torch

from .automaton import Semiring
from .pushdown_automaton import WeightedPushdownAutomaton
from .fixed_point_iteration import fixed_point_iteration

# Generic type of the semiring's values (whatever `Semiring.zeros` and
# `transform_tensors` operate on); shared by the functions in this module.
T = TypeVar('T')

def top_down_pushdown_automaton_allsum(
    M: WeightedPushdownAutomaton[T],
    dtype: torch.dtype,
    device: torch.device
) -> T:
    semiring = M.semiring()
    num_states = M.num_states()
    stack_alphabet_size = M.stack_alphabet_size()
    # Coalesce the transition weights over all scanned strings.
    push_weights = semiring.zeros(
        (num_states, stack_alphabet_size, stack_alphabet_size, stack_alphabet_size, num_states),
        dtype,
        device
    )
    replace_weights = semiring.zeros(
        (num_states, stack_alphabet_size, stack_alphabet_size, num_states),
        dtype,
        device
    )
    pop_weights = semiring.zeros(
        (num_states, stack_alphabet_size, num_states),
        dtype,
        device
    )
    for t, weight in M.transition_weights():
        q = t.state_from
        r = t.state_to
        X = t.popped_symbol
        match t.pushed_symbols:
            case ():
                semiring.add_in_place(
                    semiring.transform_tensors(pop_weights, lambda x: x[q, X, r]),
                    weight
                )
            case (Y,):
                semiring.add_in_place(
                    semiring.transform_tensors(replace_weights, lambda x: x[q, X, Y, r]),
                    weight
                )
            case (Y, Z):
                semiring.add_in_place(
                    semiring.transform_tensors(push_weights, lambda x: x[q, X, Y, Z, r]),
                    weight
                )
            case _:
                raise ValueError('the pushdown automaton is not in top-down normal form')

    def func(x: T) -> T:
        return top_down_pushdown_automaton_allsum_step(
            semiring,
            x,
            push_weights,
            replace_weights,
            pop_weights
        )

    # Initialize the item weights to zero.
    zero = semiring.zeros(
        (num_states, stack_alphabet_size, num_states),
        dtype,
        device
    )
    # Run fixed point iteration to solve the system of nonlinear equations.
    return fixed_point_iteration(func, semiring.equal, zero)

def top_down_pushdown_automaton_allsum_step(
    semiring: Semiring[T],
    item_weights: T,
    push_weights: T,
    replace_weights: T,
    pop_weights: T
):
    """Apply one update of the top-down allsum equations to ``item_weights``.

    For every (p, X, q) the new item weight is, in the given semiring::

        pop[p, X, q]
        + sum_{Y, r}       replace[p, X, Y, r] * item[r, Y, q]
        + sum_{Y, Z, r, s} push[p, X, Y, Z, r] * item[r, Z, s] * item[s, Y, q]

    This is a top-down analogue of the bottom-up allsum formula given in
    Section 6 of https://arxiv.org/pdf/2210.06884

    Args:
        semiring: Provides the ``multiply``/``sum``/``add`` used here.
        item_weights: Current item weights, indexed [state, symbol, state].
        push_weights: Push-transition weights, indexed [p, X, Y, Z, r].
        replace_weights: Replace-transition weights, indexed [p, X, Y, r].
        pop_weights: Pop-transition weights, indexed [p, X, q].

    Returns:
        The updated item weights, indexed like ``item_weights``.
    """

    def lift_items(num_leading: int) -> T:
        # item_weights[r, Y, q] -> view indexed [1, ..., 1, Y, r, q], with
        # `num_leading` singleton axes so it broadcasts against the
        # transition tensors.
        leading = (None,) * num_leading
        return semiring.transform_tensors(
            item_weights,
            lambda t: t.transpose(0, 1)[leading]
        )

    def append_axis(weights: T, axis: int) -> T:
        # Insert a singleton axis at `axis` for the not-yet-summed end state.
        return semiring.transform_tensors(weights, lambda t: t.unsqueeze(axis))

    def contract(left: T, right: T, dims: tuple[int, int]) -> T:
        # Broadcasting semiring product, then semiring-sum over `dims`.
        return semiring.sum(semiring.multiply(left, right), dims=dims)

    # sum_{Y, r} replace[p, X, Y, r] * item[r, Y, q]
    replace_term = contract(
        append_axis(replace_weights, 4),  # p X Y r 1
        lift_items(2),                    # 1 1 Y r q
        (2, 3)
    )
    # Resolve the second pushed symbol Z first:
    # partial[p, X, Y, s] = sum_{Z, r} push[p, X, Y, Z, r] * item[r, Z, s]
    partial = contract(
        append_axis(push_weights, 5),     # p X Y Z r 1
        lift_items(3),                    # 1 1 1 Z r s
        (3, 4)
    )
    # Then resolve Y: sum_{Y, s} partial[p, X, Y, s] * item[s, Y, q]
    push_term = contract(
        append_axis(partial, 4),          # p X Y s 1
        lift_items(2),                    # 1 1 Y s q
        (2, 3)
    )
    with_replace = semiring.add(pop_weights, replace_term)
    return semiring.add(with_replace, push_term)
