from typing import TypeVar

import torch

from recognizers.automata.automaton import Semiring
from recognizers.grammars.context_free_grammar import WeightedContextFreeGrammar

from recognizers.automata.fixed_point_iteration import fixed_point_iteration

T = TypeVar('T')

def context_free_grammar_allsum(
    G: WeightedContextFreeGrammar[T],
    dtype: torch.dtype,
    device: torch.device
) -> T:
    semiring = G.semiring()
    num_nonterminals = G.num_nonterminals()
    # Coalesce the production weights over all generated strings.
    lexical_weights = semiring.zeros(
        (num_nonterminals, ),
        dtype,
        device
    )
    binary_weights = semiring.zeros(
        (num_nonterminals, num_nonterminals, num_nonterminals),
        dtype,
        device
    )
    for p, weight in G.production_weights():
        X = p.left_hand_side
        match p.right_hand_side:
            case ():
                semiring.add_in_place(
                    semiring.transform_tensors(lexical_weights, lambda x: x[G.variable_index(X),]),
                    weight
                )
            case (a,):
                semiring.add_in_place(
                    semiring.transform_tensors(lexical_weights, lambda x: x[G.variable_index(X),]),
                    weight
                )
            case (Y, Z):
                semiring.add_in_place(
                    semiring.transform_tensors(binary_weights, lambda x: x[G.variable_index(X), G.variable_index(Y), G.variable_index(Z)]),
                    weight
                )
            case _:
                raise ValueError('the grammar is not in Chomsky normal form')

    def func(x: T) -> T:
        return context_free_grammar_allsum_step(
            semiring,
            x,
            lexical_weights,
            binary_weights
        )

    # Initialize the item weights to zero.
    zero = semiring.zeros(
        (num_nonterminals,),
        dtype,
        device
    )
    # Run fixed point iteration to solve the system of nonlinear equations.
    return fixed_point_iteration(func, semiring.equal, zero)

def context_free_grammar_allsum_step(
    semiring: Semiring[T],
    item_weights: T,
    lexical_weights: T,
    binary_weights: T
):
    """Apply the allsum equations once to the current item weights.

    For every nonterminal X this computes, using the operations of ``semiring``::

        lexical_weights[X] + sum over Y, Z of
            binary_weights[X, Y, Z] * item_weights[Y] * item_weights[Z]

    The right child Z is summed out first, then the left child Y.
    ``binary_weights`` is assumed to be indexed (X, Y, Z) and ``item_weights``
    and ``lexical_weights`` by nonterminal.

    :param semiring: Semiring supplying ``multiply``, ``sum``, ``add`` and
        ``transform_tensors``.
    :param item_weights: Current per-nonterminal weights.
    :param lexical_weights: Per-nonterminal weights of non-binary productions.
    :param binary_weights: Weights of binary productions X -> Y Z.
    :return: The updated per-nonterminal weights.
    """
    # Item weights broadcast along the last axis of an (X, Y, Z) tensor.
    right_child = semiring.transform_tensors(item_weights, lambda t: t[None, None])
    # Sum out Z, leaving a tensor indexed by (X, Y).
    per_left_child = semiring.sum(
        semiring.multiply(binary_weights, right_child),
        dims=(2,)
    )
    # Item weights broadcast along the last axis of an (X, Y) tensor.
    left_child = semiring.transform_tensors(item_weights, lambda t: t[None])
    # Sum out Y, leaving a tensor indexed by X.
    from_binary = semiring.sum(
        semiring.multiply(per_left_child, left_child),
        dims=(1,)
    )
    return semiring.add(lexical_weights, from_binary)
