from recognizers.string_sampling.weighted_language import (
    WeightedLanguage,
    LengthRestrictedWeightedLanguage,
    Parse,
    EmptyLanguageError
)


class RandomLanguage(WeightedLanguage):
    """A language whose positive and negative strings are picked at random.

    Strings are tuples over the symbols ``0``, ``1`` and ``2``. The actual
    sampling, and the bookkeeping of which strings count as positive or
    negative, is done by ``LengthRestrictedRandomLanguage``.
    """

    def alphabet_size(self):
        """Return the number of distinct symbols used by this language (3)."""
        # The samplers emit the bits 0 and 1 plus the final symbol 2.
        symbols = (0, 1, 2)
        return len(symbols)

    def with_length_range(self, length):
        """Return this language restricted to strings of a given length.

        Args:
            length: Passed unchanged to ``LengthRestrictedRandomLanguage``,
                which uses it as a single integer string length (not as a
                pair of bounds).

        Returns:
            A new ``LengthRestrictedRandomLanguage`` for ``length``.

        Raises:
            EmptyLanguageError: Raised by the restricted language's
                constructor when ``length`` is less than 2.
        """
        restricted = LengthRestrictedRandomLanguage(length)
        return restricted


class LengthRestrictedRandomLanguage(LengthRestrictedWeightedLanguage):
    """A randomly labelled language over strings of one fixed length.

    Every candidate string is a tuple of ``length - 1`` symbols drawn from
    ``{0, 1}`` followed by the symbol ``2``, so there are
    ``2 ** (length - 1)`` candidates in total.

    Membership is decided lazily: the first time a candidate is returned by
    the positive sampler it is recorded in ``positives``; the first time it
    is returned by the negative sampler it is recorded in ``negatives``. A
    string already recorded on one side is never returned by the other
    sampler. The labels are stored on the instance, so they are not shared
    between separately constructed instances.
    """

    def __init__(self, length):
        """Create the restricted language.

        Args:
            length: Total length of every string, including the final ``2``
                symbol. Used as a single integer.

        Raises:
            EmptyLanguageError: If ``length`` is less than 2.
        """
        super().__init__()
        self.length = length
        if self.length < 2:
            raise EmptyLanguageError
        # Strings handed out so far as positive / negative examples.
        self.positives = set()
        self.negatives = set()

    def supports_log_probability(self):
        """Return False: ``sample`` never computes a log-probability."""
        return False

    def supports_next_symbols(self):
        """Return False: no next-symbol information is produced."""
        return False

    def supports_edit_distance(self):
        """Return False: ``is_negative`` never computes an edit distance."""
        return False

    def sample(self, generator, include_log_probability, include_next_symbols):
        """Sample a positive string.

        Args:
            generator: Random source; only its ``choice`` method is used.
            include_log_probability: Ignored; the log-probability in the
                returned ``Parse`` is always ``None``.
            include_next_symbols: If true, ``_string_to_next_symbols`` is
                consulted (it currently always returns ``None``).

        Returns:
            A pair ``(w, Parse(None, next_symbols))`` where ``w`` is the
            sampled tuple of symbols.

        Raises:
            EmptyLanguageError: If every candidate string of this length has
                already been recorded as a negative example.
        """
        w = self._sample_string(generator)
        if include_next_symbols:
            next_symbols = self._string_to_next_symbols(w)
        else:
            next_symbols = None
        return w, Parse(None, next_symbols)

    def sample_negative(self, generator):
        """Sample a negative string.

        Args:
            generator: Random source; only its ``choice`` method is used.

        Returns:
            A triple ``(w, None, None)`` where ``w`` is the sampled tuple of
            symbols; the last two entries are always ``None``.

        Raises:
            ValueError: If every candidate string of this length has already
                been recorded as a positive example.
        """
        w = self._sample_negative_string(generator)
        return w, None, None

    def is_negative(self, s, include_edit_distance):
        """Tell whether ``s`` has been recorded as a negative example.

        Args:
            s: The string to test; it must be hashable (e.g. a tuple)
                because it is looked up in a set.
            include_edit_distance: Ignored; no edit distance is computed.

        Returns:
            A pair ``(is_negative, None)``.
        """
        return (self._parse_string(s) is None, None)

    def _parse_string(self, s):
        """Return ``None`` if ``s`` is a recorded negative, otherwise ``1``.

        Any string that has not been recorded as negative is treated as
        positive, including strings that have never been sampled.
        """
        if s in self.negatives:
            return None
        return 1

    def _sample_string(self, generator):
        """Draw a string that is not a recorded negative and mark it positive.

        Raises:
            EmptyLanguageError: If all candidates are recorded negatives, in
                which case no positive string can exist.
        """
        if self._covers_all_candidates(self.negatives):
            # Without this check the rejection loop would never terminate.
            raise EmptyLanguageError
        return self._draw_unclaimed(generator, self.negatives, self.positives)

    def _sample_negative_string(self, generator):
        """Draw a string that is not a recorded positive and mark it negative.

        Raises:
            ValueError: If all candidates are recorded positives, in which
                case no negative string can exist.
        """
        if self._covers_all_candidates(self.positives):
            # Without this check the rejection loop would never terminate.
            raise ValueError(
                'cannot sample a negative string: every string of length '
                f'{self.length} is already a positive example'
            )
        return self._draw_unclaimed(generator, self.positives, self.negatives)

    def _covers_all_candidates(self, recorded):
        """Return True if ``recorded`` holds every candidate of this length.

        Relies on ``recorded`` containing only strings produced by this
        instance's samplers, which is how the visible code fills both sets.
        """
        return len(recorded) >= 2 ** (int(self.length) - 1)

    def _draw_unclaimed(self, generator, excluded, target):
        """Rejection-sample a candidate outside ``excluded``; add to ``target``.

        The caller must have checked that ``excluded`` does not already
        contain every candidate, otherwise this loop cannot finish.
        """
        while True:
            num_bits = self.length - 1
            bits = [generator.choice([0, 1]) for _ in range(num_bits)]
            candidate = tuple(bits + [2])
            if candidate not in excluded:
                target.add(candidate)
                return candidate

    def _string_to_next_symbols(self, w):
        """Return ``None``; next-symbol sets are not supported."""
        return None