import math

import torch

from recognizers.grammars.grammar import (
    Production,
    Nonterminal
)
from recognizers.grammars.context_free_grammar import (
    ContextFreeGrammar,
    WeightedContextFreeGrammar,
    WeightedContextFreeGrammarContainer
)
from recognizers.automata.reserved import ReservedSymbol
from recognizers.automata.log_counting_semiring import LogCountingSemiring

def get_production_weight(
    production: Production,
    log_probability: float,
    weight_size: int,
    dtype: torch.dtype,
    device: torch.device
) -> torch.Tensor:
    """Build the log-counting-semiring weight vector for one production.

    The result is a 1-D tensor of length ``weight_size``. Every entry is
    ``-inf`` (log of zero) except one slot, which holds
    ``log_probability``:

    * slot 1 when the right-hand side has exactly one symbol and that
      symbol is not ``ReservedSymbol.EPSILON``;
    * slot 0 otherwise (e.g. two-symbol right-hand sides and epsilon
      rules).

    Presumably the slot index is the number of terminals the production
    emits, with single-symbol rules being terminal rules in CNF. This is
    unconfirmed and should be checked against the grammar's conventions.

    Note: when ``weight_size`` is 1, a single-symbol non-epsilon production
    indexes slot 1 out of range and raises ``IndexError``.
    """
    result = torch.full((weight_size,), -math.inf, dtype=dtype, device=device)
    right = production.right_hand_side
    # TODO Use [] instead of EPSILON for the right side of epsilon rules.
    # The short-circuit keeps EPSILON from being compared unless len == 1.
    emits_symbol = len(right) == 1 and right[0] != ReservedSymbol.EPSILON
    slot = 1 if emits_symbol else 0
    result[slot] = log_probability
    return result

def lift_cnf_cfg_weights(
    M: ContextFreeGrammar,
    max_count: int,
    dtype: torch.dtype,
    device: torch.device
) -> WeightedContextFreeGrammar[torch.Tensor]:
    """Lift an unweighted grammar to one weighted in the log-counting semiring.

    Productions are grouped by left-hand side. Each production receives a
    uniform probability within its group, i.e. ``log(1 / k)`` where ``k``
    is the number of productions sharing its left-hand side. That
    log-probability is placed into a weight vector of length
    ``max_count + 1`` by ``get_production_weight``.

    Args:
        M: The unweighted grammar. The function name suggests it is
            expected to be in Chomsky normal form (TODO confirm callers
            guarantee this).
        max_count: Weight vectors (and the ``LogCountingSemiring``) have
            ``max_count + 1`` entries.
        dtype: dtype of every weight tensor.
        device: device of every weight tensor.

    Returns:
        A new ``WeightedContextFreeGrammarContainer`` with the same number
        of variables and terminals and the same start variable as ``M``.
        A weight is set for every production yielded by
        ``M.productions()``.
    """
    weight_size = max_count + 1
    result = WeightedContextFreeGrammarContainer[torch.Tensor](
        num_variables=M.num_variables(),
        num_terminals=M.num_terminals(),
        semiring=LogCountingSemiring(weight_size)
    )
    result.set_start_variable(M.start_variable())
    # Group productions by left-hand side; dicts keep first-seen order, so
    # weights are assigned in the same order as before.
    productions_by_lhs: dict[Nonterminal, list[Production]] = {}
    for p in M.productions():
        productions_by_lhs.setdefault(p.left_hand_side, []).append(p)
    for productions in productions_by_lhs.values():
        # Groups are non-empty by construction, so the log is well defined.
        log_prob = -math.log(len(productions))
        for p in productions:
            result.set_production_weight(
                p,
                get_production_weight(p, log_prob, weight_size, dtype, device)
            )
    return result
