import dataclasses
import math
import random

import torch

from recognizers.automata.automaton import Symbol
from recognizers.dataset_generation.weighted_language import (
    WeightedLanguage,
    LengthRestrictedWeightedLanguage,
    String,
    Parse,
    ValidNextSymbolSet,
    EmptyLanguageError
)
from recognizers.grammars.context_free_grammar_recognition import CNFContextFreeGrammar
from .weight_pushing import NormalizedCountingContextFreeGrammar

@dataclasses.dataclass
class CacheEntry:
    label: bool
    edit_distance: int | None = None

class ContextFreeGrammarLanguage(WeightedLanguage):
    """Weighted language backed by a context-free grammar.

    Sampling is delegated to ``grammar``; membership of non-empty strings is
    decided by a CNF version of the same grammar. Labels (and, on request,
    edit distances) are memoized per string in ``self.cache``.
    """

    def __init__(self,
        grammar: NormalizedCountingContextFreeGrammar,
        alphabet: list[str] | None,
        dtype: torch.dtype,
        device: torch.device
    ):
        """Store the grammar and build its CNF counterpart.

        :param grammar: Grammar used for sampling and for the epsilon check.
        :param alphabet: Optional list mapping each symbol (used as an index)
            to its string form. If ``None``, symbols are rendered with
            ``str``.
        :param dtype: Passed through to ``uncached_edit_distance``.
        :param device: Passed through to ``uncached_edit_distance``.
        """
        super().__init__()
        # Maps each string seen so far to its CacheEntry.
        self.cache = {}
        self.grammar = grammar
        self.cnf_grammar = CNFContextFreeGrammar.from_context_free_grammar(grammar)
        self.dtype = dtype
        self.device = device
        self._symbol_to_str = (
            str if alphabet is None else (lambda s: alphabet[s])
        )

    def alphabet_size(self) -> int:
        """Return the alphabet size reported by the grammar."""
        size = self.grammar.alphabet_size
        return size

    def symbol_to_str(self, symbol: Symbol) -> str:
        """Render ``symbol`` using the mapping chosen in ``__init__``."""
        render = self._symbol_to_str
        return render(symbol)

    def with_length_range(self,
        length_range: tuple[int, int]
    ) -> LengthRestrictedWeightedLanguage:
        """Return a view of this language restricted to ``length_range``."""
        restricted = LengthRestrictedContextFreeGrammarLanguage(self, length_range)
        return restricted

    def sample(self,
        length: int,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool
    ) -> tuple[String, float | None, ValidNextSymbolSet | None]:
        """Sample a string from the grammar and remember it as positive.

        All arguments are forwarded unchanged to ``grammar.sample``, and its
        three results are returned unchanged.
        """
        sampled = self.grammar.sample(
            length,
            generator,
            include_log_probability,
            include_next_symbols
        )
        string, log_probability, next_symbols = sampled
        # Anything the grammar generates is in the language, so record it as
        # a positive example unless an entry already exists.
        if self.cache.get(string) is None:
            self.cache[string] = CacheEntry(label=True)
        return string, log_probability, next_symbols

    def is_negative(self,
        s: String,
        include_edit_distance: bool
    ) -> tuple[bool, int | None]:
        """Tell whether ``s`` is outside the language.

        :param s: The string to classify.
        :param include_edit_distance: Whether to make sure the cached edit
            distance for ``s`` is filled in before returning.
        :return: A pair ``(negative, edit_distance)``, where
            ``edit_distance`` is whatever is cached for ``s`` (possibly
            ``None`` when it was never requested).
        """
        entry = self.cache.get(s)
        if entry is None:
            entry = CacheEntry(label=self.uncached_label(s))
            self.cache[s] = entry
        if include_edit_distance and entry.edit_distance is None:
            # Positive strings are at distance 0 by definition; only negative
            # strings need the real computation.
            # NOTE(review): uncached_edit_distance is not defined in this
            # class; presumably it is inherited or provided elsewhere --
            # confirm.
            entry.edit_distance = (
                0
                if entry.label
                else self.uncached_edit_distance(s, self.dtype, self.device)
            )
        return not entry.label, entry.edit_distance

    def uncached_label(self, s: String) -> bool:
        """Compute membership of ``s`` without consulting the cache."""
        # The empty string is answered by the original grammar rather than by
        # the CNF recognizer.
        if s == ():
            return self.grammar.accepts_epsilon()
        return self.cnf_grammar.recognize(s)

class LengthRestrictedContextFreeGrammarLanguage(LengthRestrictedWeightedLanguage):
    """A ``ContextFreeGrammarLanguage`` restricted to a range of lengths.

    A length is first drawn uniformly from the lengths the grammar reports as
    valid within the range, and a string of that length is then sampled from
    the parent language.
    """

    def __init__(self, parent: ContextFreeGrammarLanguage, length_range: tuple[int, int]):
        """Validate the range and precompute the valid lengths.

        :param parent: The unrestricted language to delegate to.
        :param length_range: Inclusive ``(min_length, max_length)`` bounds.
        :raises ValueError: If ``max_length`` exceeds the maximum length the
            parent's grammar was prepared for.
        :raises EmptyLanguageError: If the grammar reports no valid lengths
            within ``length_range``.
        """
        super().__init__()
        self.parent = parent
        self.min_length, self.max_length = length_range
        if self.max_length > self.parent.grammar.max_length:
            raise ValueError(
                f'the prepared grammar is prepared for sampling strings up '
                f'to length {self.parent.grammar.max_length}, but '
                f'{self.max_length} is required'
            )
        # Figure out which lengths are possible within this length range.
        self.valid_lengths = self.parent.grammar.valid_lengths(length_range)
        if not self.valid_lengths:
            raise EmptyLanguageError(
                f'no lengths are valid within the length range {length_range}'
            )
        # Precompute the log probability of selecting a length (lengths are
        # chosen uniformly, so this is log of the number of choices).
        self.log_num_lengths = math.log(len(self.valid_lengths))

    def supports_log_probability(self) -> bool:
        """Report that log probabilities are not supported."""
        return False

    def supports_next_symbols(self) -> bool:
        """Report that next-symbol sets are not supported."""
        return False

    def supports_edit_distance(self) -> bool:
        """Report that edit distances are not supported."""
        return False

    def sample(self,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool
    ) -> tuple[String, Parse]:
        """Sample a string whose length lies in the configured range.

        :param generator: Source of randomness for both the length and the
            string.
        :param include_log_probability: Forwarded to the parent's ``sample``.
        :param include_next_symbols: Forwarded to the parent's ``sample``.
        :return: The sampled string and a ``Parse`` holding the (possibly
            ``None``) log probability and next-symbol set.
        """
        length = self.sample_length(generator)
        s, log_probability, next_symbols = self.parent.sample(
            length,
            generator,
            include_log_probability,
            include_next_symbols
        )
        if log_probability is not None:
            # Renormalize the probability according to the length selected.
            # The parent keeps its sampler in `grammar`; it has no
            # `automaton` attribute, so reading that would raise
            # AttributeError.
            # NOTE(review): assumes the grammar provides
            # total_length_weight(length) -- confirm.
            log_probability = (
                log_probability
                - self.log_num_lengths
                - self.parent.grammar.total_length_weight(length)
            )
        return s, Parse(log_probability, next_symbols)

    def sample_length(self, generator: random.Random) -> int:
        """Pick one of the valid lengths uniformly at random."""
        return generator.choice(self.valid_lengths)

    def is_negative(self,
        s: String,
        include_edit_distance: bool
    ) -> tuple[bool, int | None]:
        """Tell whether ``s`` is outside the language.

        :param s: The string to classify; its length must lie within the
            configured range.
        :param include_edit_distance: Forwarded to the parent's
            ``is_negative``.
        :return: The parent's ``(negative, edit_distance)`` result.
        :raises ValueError: If ``len(s)`` is outside the length range.
        """
        if not (self.min_length <= len(s) <= self.max_length):
            raise ValueError(
                f'string of length {len(s)} is outside the length range '
                f'[{self.min_length}, {self.max_length}]'
            )
        return self.parent.is_negative(s, include_edit_distance)
