import math
import random

import torch

from rayuela.base.semiring import Tropical as RayuelaTropical
from rayuela.base.symbol import Sym as RayuelaSym
from rayuela.base.state import State as RayuelaState
from rayuela.fsa.fsa import FSA

from recognizers.automata.automaton import Symbol
from recognizers.string_sampling.finite_automaton.weight_pushing import (
    WeightPushedFiniteAutomaton
)
from recognizers.dataset_generation.weighted_language import (
    LengthRestrictedWeightedLanguage,
    String,
    Parse,
    ValidNextSymbolSet,
    ValidNextSymbolList,
    EmptyLanguageError
)
from recognizers.dataset_generation.weight_pushed_language import (
    WeightPushedLanguage,
    LengthRestrictedWeightPushedLanguage
)
from recognizers.string_sampling.finite_automaton.edit_distance import (
    compute_edit_distance
)

def to_tropical_rayuela_fsa(automaton):
    """Build an unweighted rayuela FSA over the tropical semiring.

    The initial state, every ``(state, symbol) -> state`` transition and every
    accept state of ``automaton`` are copied into a fresh rayuela ``FSA``.
    No explicit weights are passed, so rayuela's default weights are used.
    The result is consumed by the edit-distance computation.

    :param automaton: Object exposing ``initial_state``, a ``transitions``
        mapping keyed by ``(state, symbol)`` pairs, and ``accept_states``.
    :return: The populated rayuela ``FSA``.
    """
    fsa = FSA(R=RayuelaTropical)
    fsa.set_I(RayuelaState(automaton.initial_state))
    for (source, symbol), target in automaton.transitions.items():
        fsa.set_arc(
            RayuelaState(source),
            RayuelaSym(symbol),
            RayuelaState(target),
        )
    for final_state in automaton.accept_states:
        fsa.set_F(RayuelaState(final_state))
    return fsa

class FiniteAutomatonLanguage(WeightPushedLanguage):
    """Weighted language backed by a weight-pushed finite automaton.

    Sampling, membership, maximum length and valid-length queries are
    delegated to ``automaton``. Edit distances are computed on a
    tropical-semiring rayuela copy of the automaton, built once in
    ``__init__``.
    """

    def __init__(self,
        automaton: WeightPushedFiniteAutomaton,
        alphabet: list[str] | None,
        dtype: torch.dtype,
        device: torch.device
    ):
        """
        :param automaton: Automaton that defines and samples the language.
        :param alphabet: Passed unchanged to the base-class constructor.
        :param dtype: Tensor dtype forwarded to ``compute_edit_distance``.
        :param device: Tensor device forwarded to ``compute_edit_distance``.
        """
        super().__init__(alphabet)
        self.automaton = automaton
        # Converted eagerly so every edit-distance query reuses the same FSA.
        self.tropical_rayuela_fsa = to_tropical_rayuela_fsa(automaton)
        self.dtype = dtype
        self.device = device

    def alphabet_size(self) -> int:
        """Return the automaton's alphabet size."""
        return self.automaton.alphabet_size

    def uncached_sample(self,
        length: int,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool
    ) -> tuple[String, float | None, ValidNextSymbolSet | None]:
        """Sample a string of ``length`` symbols directly from the automaton.

        The automaton's result is returned unchanged. The log-probability and
        next-symbol entries may be ``None``, presumably when the matching
        ``include_*`` flag is false.
        """
        return self.automaton.sample(
            length,
            generator,
            include_log_probability,
            include_next_symbols
        )

    def uncached_label(self, s: String) -> bool:
        """Return whether the automaton accepts ``s``."""
        return self.automaton.accepts(s)

    def uncached_edit_distance(self, s: String) -> int:
        """Compute the edit distance of ``s`` against the tropical FSA.

        Presumably this is the minimum edit distance from ``s`` to any string
        accepted by the automaton; see ``compute_edit_distance`` to confirm.
        """
        return compute_edit_distance(self.tropical_rayuela_fsa, s, self.dtype, self.device)

    def max_length(self) -> int:
        """Return the automaton's ``max_length``."""
        return self.automaton.max_length

    def valid_lengths(self, length_range: tuple[int, int]) -> list[int]:
        """Return the automaton's valid lengths within ``length_range``."""
        return self.automaton.valid_lengths(length_range)

    def supports_log_probability(self) -> bool:
        return True

    def supports_next_symbols(self) -> bool:
        return True

    def supports_edit_distance(self) -> bool:
        return True

    def sample(self,
        length: int,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool
    ) -> tuple[String, float | None, ValidNextSymbolList | None]:
        """Sample a string of ``length`` symbols from the automaton.

        Overrides the base-class ``sample`` and returns exactly what
        ``uncached_sample`` returns. The log probability and next-symbol
        entries may be ``None``; the return annotation previously omitted
        this.
        """
        # NOTE(review): uncached_sample annotates the next-symbol entry as
        # ValidNextSymbolSet, while this method says ValidNextSymbolList. Both
        # pass automaton.sample's value through unchanged; confirm which
        # container type is correct.
        return self.uncached_sample(
            length,
            generator,
            include_log_probability,
            include_next_symbols
        )

class LengthRestrictedFiniteAutomatonLanguage(LengthRestrictedWeightPushedLanguage):
    """Finite-automaton language restricted to a set of string lengths.

    This class relies on the base class to provide ``parent``,
    ``valid_lengths``, ``log_num_lengths``, ``min_length`` and
    ``max_length``; none of them is defined here.
    """

    def sample(self,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool
    ) -> tuple[String, Parse]:
        """Sample a length, then a string of that length from ``parent``.

        :param generator: Source of randomness for both sampling steps.
        :param include_log_probability: Forwarded to ``parent.sample``.
        :param include_next_symbols: Forwarded to ``parent.sample``.
        :return: The sampled string and a ``Parse`` holding its (possibly
            ``None``) renormalized log probability and next-symbol data.
        """
        length = self.sample_length(generator)
        s, log_probability, next_symbols = self.parent.sample(
            length,
            generator,
            include_log_probability,
            include_next_symbols
        )
        if log_probability is not None:
            # Renormalize the probability according to the length selected.
            # The length is picked uniformly from valid_lengths (see
            # sample_length), hence the log_num_lengths term (presumably
            # log(len(valid_lengths)); confirm in the base class). Subtracting
            # the parent's total weight for this length turns the sampled
            # weight into a within-length log probability.
            log_probability = (
                log_probability
                - self.log_num_lengths
                - self.parent.automaton.total_length_weight(length)
            )
        return s, Parse(log_probability, next_symbols)

    def sample_length(self, generator):
        """Pick one entry of ``valid_lengths`` uniformly at random."""
        return generator.choice(self.valid_lengths)

    def is_negative(self,
        s: String,
        include_edit_distance: bool
    ) -> tuple[bool, int | None]:
        """Delegate to ``parent.is_negative`` after checking ``s``'s length.

        :raises ValueError: If ``len(s)`` is outside
            ``[min_length, max_length]``.
        """
        if not (self.min_length <= len(s) <= self.max_length):
            raise ValueError(
                f'string length {len(s)} is outside the allowed range '
                f'[{self.min_length}, {self.max_length}]'
            )
        return self.parent.is_negative(s, include_edit_distance)
