import dataclasses
import math
import random

from recognizers.automata.automaton import Symbol
from recognizers.dataset_generation.weighted_language import (
    WeightedLanguage,
    LengthRestrictedWeightedLanguage,
    Parse,
    String,
    ValidNextSymbolSet,
    EmptyLanguageError
)

@dataclasses.dataclass
class CacheEntry:
    label: bool
    edit_distance: int | None = None

class WeightPushedLanguage(WeightedLanguage):
    """Base class for weighted languages sampled one fixed length at a time.

    Subclasses implement the ``uncached_*`` primitives and the capability /
    length queries. This class memoizes, per string, whether it is in the
    language and (on demand) its edit distance, so repeated queries for the
    same string are answered from ``self.cache``.
    """

    def __init__(self, alphabet: list[str] | None):
        """
        :param alphabet: Optional list mapping each integer symbol to its
            display string. When None, symbols are rendered with ``str``.
        """
        super().__init__()
        # Maps each string queried or sampled so far to its CacheEntry.
        self.cache = {}
        if alphabet is not None:
            self._symbol_to_str = lambda s: alphabet[s]
        else:
            self._symbol_to_str = str

    def symbol_to_str(self, symbol: Symbol) -> str:
        """Return the display string for ``symbol``."""
        return self._symbol_to_str(symbol)

    def with_length_range(self,
        length_range: tuple[int, int]
    ) -> LengthRestrictedWeightedLanguage:
        """Return a view of this language restricted to ``length_range``."""
        return LengthRestrictedWeightPushedLanguage(self, length_range)

    def sample(self,
        length: int,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool
    ) -> tuple[String, float | None, ValidNextSymbolSet | None]:
        """Sample a string of ``length`` symbols via ``uncached_sample``.

        The sampled string is recorded in the cache as a member of the
        language (``label=True``) unless it is already cached, so a later
        ``is_negative`` call on it skips ``uncached_label``.

        :return: ``(string, log_probability, next_symbols)``; the last two are
            whatever ``uncached_sample`` returned for the include flags.
        """
        s, log_probability, next_symbols = self.uncached_sample(
            length,
            generator,
            include_log_probability,
            include_next_symbols
        )
        if s not in self.cache:
            self.cache[s] = CacheEntry(label=True)
        return s, log_probability, next_symbols

    def is_negative(self,
        s: String,
        include_edit_distance: bool
    ) -> tuple[bool, int | None]:
        """Return whether ``s`` is outside the language, plus its edit distance.

        Labels and edit distances are computed at most once per string and
        then served from the cache.

        :param include_edit_distance: If True, make sure the edit distance is
            computed (0 for members of the language).
        :return: ``(is_negative, edit_distance)``. ``edit_distance`` is None
            unless it has been computed for ``s`` by this or an earlier call.
        """
        cache_entry = self.cache.get(s)
        if cache_entry is None:
            cache_entry = self.cache[s] = CacheEntry(label=self.uncached_label(s))
        if include_edit_distance and cache_entry.edit_distance is None:
            if cache_entry.label:
                # A member of the language is at distance 0 from it.
                cache_entry.edit_distance = 0
            else:
                # NOTE(review): self.dtype / self.device are not set by this
                # class; presumably subclasses that support edit distance
                # define them — confirm.
                cache_entry.edit_distance = self.uncached_edit_distance(
                    s, self.dtype, self.device
                )
        return not cache_entry.label, cache_entry.edit_distance

    def uncached_sample(self,
        length: int,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool
    ) -> tuple[String, float | None, ValidNextSymbolSet | None]:
        """Sample a string of exactly ``length`` symbols (subclass hook)."""
        raise NotImplementedError

    def uncached_label(self, s: String) -> bool:
        """Return True iff ``s`` is in the language (subclass hook)."""
        raise NotImplementedError

    def uncached_edit_distance(self, s: String, dtype=None, device=None) -> int:
        """Return the edit distance of non-member ``s`` to the language.

        Subclass hook. ``is_negative`` calls it with ``self.dtype`` and
        ``self.device``; the meaning of those arguments is up to the subclass.
        """
        raise NotImplementedError

    def max_length(self) -> int:
        """Return the longest string length this language can sample."""
        raise NotImplementedError

    def valid_lengths(self, length_range: tuple[int, int]) -> list[int]:
        """Return the lengths in ``length_range`` that strings can have."""
        raise NotImplementedError

    def supports_log_probability(self) -> bool:
        """Return whether sampling can report log probabilities."""
        raise NotImplementedError

    def supports_next_symbols(self) -> bool:
        """Return whether sampling can report valid next-symbol sets."""
        raise NotImplementedError

    def supports_edit_distance(self) -> bool:
        """Return whether ``uncached_edit_distance`` is implemented."""
        raise NotImplementedError

    def total_length_log_probability(self, length: int) -> float:
        """Return the log normalizer for strings of ``length`` (subclass hook).

        Callers subtract it from a sample's log probability to condition on
        the sampled length; presumably it is the log total weight of all
        strings of that length — confirm in subclasses.
        """
        raise NotImplementedError

class LengthRestrictedWeightPushedLanguage(LengthRestrictedWeightedLanguage):
    """A WeightPushedLanguage restricted to an inclusive range of lengths.

    Sampling first draws a length uniformly from the valid lengths in the
    range, then delegates to the parent to sample a string of that length.
    """

    def __init__(self, parent: WeightPushedLanguage, length_range: tuple[int, int]):
        """
        :param parent: The unrestricted language to sample from.
        :param length_range: ``(min_length, max_length)``, both inclusive.
        :raises ValueError: If ``max_length`` exceeds ``parent.max_length()``.
        :raises EmptyLanguageError: If no length in the range is valid.
        """
        super().__init__()
        self.parent = parent
        self.min_length, self.max_length = length_range
        parent_max_length = parent.max_length()
        if self.max_length > parent_max_length:
            # Report the same limit that was checked; the parent class does
            # not define an `automaton` attribute to read it from.
            raise ValueError(
                f'this language is prepared for sampling strings up to length '
                f'{parent_max_length}, but {self.max_length} '
                f'is required'
            )
        # Figure out which lengths are possible within this length range.
        self.valid_lengths = parent.valid_lengths(length_range)
        if not self.valid_lengths:
            raise EmptyLanguageError(
                f'no lengths are valid within the length range {length_range}'
            )
        # Lengths are chosen uniformly, so each has probability
        # 1 / len(valid_lengths); precompute its log.
        self.log_num_lengths = math.log(len(self.valid_lengths))

    def supports_log_probability(self) -> bool:
        """Delegate to the parent language."""
        return self.parent.supports_log_probability()

    def supports_next_symbols(self) -> bool:
        """Delegate to the parent language."""
        return self.parent.supports_next_symbols()

    def supports_edit_distance(self) -> bool:
        """Delegate to the parent language."""
        return self.parent.supports_edit_distance()

    def sample(self,
        generator: random.Random,
        include_log_probability: bool,
        include_next_symbols: bool
    ) -> tuple[String, Parse]:
        """Sample a string whose length lies in this language's range.

        :return: ``(string, Parse(log_probability, next_symbols))``. The log
            probability, when present, accounts for the uniform length choice
            and is conditioned on the chosen length.
        """
        length = self.sample_length(generator)
        s, log_probability, next_symbols = self.parent.sample(
            length,
            generator,
            include_log_probability,
            include_next_symbols
        )
        if log_probability is not None:
            # Renormalize the probability according to the length selected.
            log_probability = (
                log_probability
                - self.log_num_lengths
                - self.parent.total_length_log_probability(length)
            )
        return s, Parse(log_probability, next_symbols)

    def sample_length(self, generator):
        """Draw one of the valid lengths uniformly at random."""
        return generator.choice(self.valid_lengths)

    def is_negative(self,
        s: String,
        include_edit_distance: bool
    ) -> tuple[bool, int | None]:
        """Delegate to the parent after checking that ``s`` is in range.

        :raises ValueError: If ``len(s)`` is outside ``[min_length, max_length]``.
        """
        if not (self.min_length <= len(s) <= self.max_length):
            raise ValueError(
                f'string of length {len(s)} is outside the length range '
                f'[{self.min_length}, {self.max_length}]'
            )
        return self.parent.is_negative(s, include_edit_distance)

    def compute_sensitivity(self,
        generator: random.Random,
        num_samples: int = 100000
    ) -> float | None:
        """Estimate how often flipping one random symbol of a sample makes it negative.

        Only done for a binary alphabet: a symbol equal to 1 becomes 0 and
        any other symbol becomes 1. Empty samples count toward
        ``num_samples`` but can never be flipped.

        :return: The estimated fraction, also stored in ``self.sensitivity``.
            Returns None, and leaves ``self.sensitivity`` unset, when the
            alphabet is not binary.
        :raises ValueError: If ``num_samples`` is not positive.
        """
        # NOTE(review): `automaton` is not defined by WeightPushedLanguage;
        # presumably the concrete parent provides it — confirm.
        if self.parent.automaton.alphabet_size != 2:
            return None
        if num_samples <= 0:
            raise ValueError(f'num_samples must be positive, got {num_samples}')
        num_flips = 0
        for _ in range(num_samples):
            s, _parse = self.sample(generator, False, False)
            if not s:
                continue
            i = generator.randrange(len(s))
            flipped = 0 if s[i] == 1 else 1
            s_prime = (*s[:i], flipped, *s[i + 1:])
            negative, _distance = self.is_negative(s_prime, False)
            if negative:
                num_flips += 1
        self.sensitivity = num_flips / num_samples
        return self.sensitivity


