import dataclasses
from collections.abc import Iterable
from typing import TypeVar

import torch

from .finite_automaton import WeightedFiniteAutomaton
from .lehmann import lehmann

T = TypeVar('T')

def inner(
    M: WeightedFiniteAutomaton[T],
    dtype: torch.dtype,
    device: torch.device
) -> T:
    """Build the state-by-state weight matrix of ``M`` and run ``lehmann`` on it.

    A ``num_states x num_states`` matrix of semiring zeros is allocated. Every
    transition weight is accumulated into the entry for its
    ``(state_from, state_to)`` pair. Transitions that differ only in their
    symbol therefore collapse into one entry. The matrix is then handed to
    ``lehmann`` together with the semiring.

    Args:
        M: The weighted automaton. ``M.semiring()`` supplies every algebraic
            operation used here.
        dtype: dtype of the allocated matrix.
        device: Device on which the matrix is allocated.

    Returns:
        The matrix after ``lehmann`` has processed it.
        NOTE(review): ``lehmann``'s return value is ignored, so this relies
        on it updating the matrix in place. Presumably it computes the
        semiring closure (all-pairs path weights); confirm in ``.lehmann``.
    """
    semiring = M.semiring()
    n = M.num_states()
    weights = semiring.zeros(size=(n, n), dtype=dtype, device=device)
    for transition, weight in M.transition_weights():
        src = transition.state_from
        dst = transition.state_to
        # Several transitions (different symbols) can share one (src, dst)
        # pair. Sum them into a single cell. This relies on the view from
        # transform_tensors sharing storage with ``weights`` so that the
        # in-place add is visible in the matrix.
        cell = semiring.transform_tensors(weights, lambda x: x[src, dst])
        semiring.add_in_place(cell, weight)
    lehmann(weights, semiring)
    return weights

def backward_from_inner(
    M: WeightedFiniteAutomaton[T],
    A: T,
    dtype: torch.dtype,
    device: torch.device
) -> T:
    """Combine a precomputed inner-weight matrix with ``M``'s accept weights.

    An accept-weight vector of length ``num_states`` is built. It starts as
    semiring zeros, and each ``(state, weight)`` pair from
    ``M.accept_weights()`` is written into it with ``set_index``. The vector
    gains a leading axis (``x[None]``) and is multiplied element-wise with
    ``A``. The product is then summed over dimension 1. Assuming
    ``multiply`` broadcasts like ordinary tensors, entry ``i`` of the result
    is ``sum_j A[i, j] * accept[j]`` in the semiring.

    Args:
        M: The weighted automaton supplying the semiring and accept weights.
        A: Square state-by-state matrix, presumably the output of ``inner``
            for the same ``M`` (confirm at call sites).
        dtype: dtype of the allocated accept-weight vector.
        device: Device on which the accept-weight vector is allocated.

    Returns:
        The per-state vector obtained by summing over dimension 1.
    """
    semiring = M.semiring()
    accept = semiring.zeros(size=(M.num_states(),), dtype=dtype, device=device)
    for state, weight in M.accept_weights():
        semiring.set_index(accept, state, weight)
    # Add a leading axis so the accept vector lines up with A's columns.
    accept_row = semiring.transform_tensors(accept, lambda x: x[None])
    scaled = semiring.multiply(A, accept_row)
    return semiring.sum(scaled, dims=(1,))

def backward(
    M: WeightedFiniteAutomaton[T],
    dtype: torch.dtype,
    device: torch.device
) -> T:
    """Compute the per-state backward weights of ``M`` from scratch.

    This is a convenience wrapper. It builds the inner-weight matrix with
    ``inner`` and passes the result to ``backward_from_inner``, forwarding
    ``dtype`` and ``device`` to both.
    """
    inner_weights = inner(M, dtype, device)
    return backward_from_inner(M, inner_weights, dtype, device)
