from recognizers.string_sampling.weighted_language import (
    WeightedLanguage,
    LengthRestrictedWeightedLanguage,
    Parse,
    EmptyLanguageError
)


class KSparseMajority(WeightedLanguage):
    """Binary strings with a strict majority of 1s over k sparse positions.

    Strings consist of the bits 0 and 1 followed by the end symbol 2, so the
    alphabet has three symbols. The k inspected positions are drawn when a
    length-restricted variant is built via ``with_length_range``.
    """

    def alphabet_size(self):
        """Return the number of symbols: the bits 0 and 1 plus the end symbol 2."""
        symbol_count = 3
        return symbol_count

    def with_length_range(self, length, k, generator):
        """Build the fixed-length variant of this language.

        ``generator`` is handed on to the length-restricted language, which
        uses it to pick the k sparse positions.
        """
        restricted = LengthRestrictedKSparseMajority(
            length=length, k=k, generator=generator
        )
        return restricted


class LengthRestrictedKSparseMajority(LengthRestrictedWeightedLanguage):
    """KSparseMajority restricted to strings of a single fixed length.

    Every string has ``length`` symbols: ``length - 1`` bits followed by the
    end symbol 2. ``k`` of the bit positions, drawn at random when the
    language is constructed, are the sparse positions. A string is positive
    iff it has more 1s than 0s at those positions.
    """

    def __init__(self, length, k, generator):
        """Fix the string length and draw the k sparse positions.

        :param length: Total string length, including the trailing end symbol.
        :param k: Number of sparse positions inspected for the majority vote.
        :param generator: Random generator providing ``sample``; used here to
            choose the sparse positions.
        :raises EmptyLanguageError: If ``length < 2``, which leaves no bit
            position before the end symbol.
        :raises ValueError: If ``k`` is negative or exceeds the number of bit
            positions, ``length - 1``.
        """
        super().__init__()
        self.length = length
        # Check emptiness first, so that a too-short length is reported as an
        # empty language rather than as an invalid k.
        if self.length < 2:
            raise EmptyLanguageError
        # The final position is reserved for the end symbol, so only
        # length - 1 positions can be sparse positions. Raise instead of
        # asserting so the check is not stripped under ``python -O``.
        if not 0 <= k <= self.length - 1:
            raise ValueError(
                f"k must be between 0 and length - 1 = {self.length - 1}, "
                f"got k={k}"
            )
        self.k = k
        self.idx = list(generator.sample(range(self.length - 1), k))

    def supports_log_probability(self):
        """Log probabilities of samples are not computed for this language."""
        return False

    def supports_next_symbols(self):
        """Next-symbol sets are not computed for this language."""
        return False

    def supports_edit_distance(self):
        """Edit distances are not computed for this language."""
        return False

    def sample(self, generator, include_log_probability, include_next_symbols):
        """Sample a positive string.

        ``include_log_probability`` is ignored and the log probability in the
        returned ``Parse`` is always None, since log probabilities are not
        supported.

        :returns: ``(w, Parse(None, next_symbols))`` where ``w`` is a tuple of
            symbols and ``next_symbols`` is the result of
            ``_string_to_next_symbols`` if requested, else None.
        :raises EmptyLanguageError: If ``k == 0`` (no positive string exists).
        """
        w = self._sample_string(generator)
        log_probability = None
        if include_next_symbols:
            next_symbols = self._string_to_next_symbols(w)
        else:
            next_symbols = None
        return w, Parse(log_probability, next_symbols)

    def sample_negative(self, generator):
        """Sample a negative string.

        :returns: The 3-tuple ``(w, None, None)``; the two None slots are
            placeholders (their meaning is defined by the caller's protocol).
        """
        w = self._sample_negative_string(generator)
        return w, None, None

    def is_negative(self, s, include_edit_distance):
        """Return ``(True, None)`` if ``s`` is not positive, else ``(False, None)``.

        Edit distance is not supported, so the second element is always None
        regardless of ``include_edit_distance``.
        """
        return (self._parse_string(s) is None, None)

    def _parse_string(self, s):
        """Return ``(n_1s, n_0s)`` at the sparse positions if positive, else None."""
        filtered_s = [s[i] for i in self.idx]
        # Count each symbol explicitly rather than summing: a non-binary
        # symbol (e.g. the end symbol 2) at a sparse position must not be
        # counted as one or more 1s.
        n_1s = filtered_s.count(1)
        n_0s = filtered_s.count(0)
        if n_1s > n_0s:
            return n_1s, n_0s
        return None

    def _sample_string(self, generator):
        """Sample a string with a strict majority of 1s at the sparse positions."""
        k = self.k
        if k == 0:
            # With no sparse positions, 0 ones can never outnumber 0 zeros,
            # so no positive string exists.
            raise EmptyLanguageError
        return self._sample_with_ones(generator, k // 2 + 1, k)

    def _sample_negative_string(self, generator):
        """Sample a string with no strict majority of 1s at the sparse positions."""
        return self._sample_with_ones(generator, 0, self.k // 2)

    def _sample_with_ones(self, generator, min_ones, max_ones):
        """Sample a string with between min_ones and max_ones 1s at the sparse positions.

        Every bit position is first filled with ``generator.choice([0, 1])``.
        The number of 1s at the sparse positions is then drawn with
        ``generator.randint(min_ones, max_ones)`` (inclusive). The matching
        multiset of 0s and 1s is shuffled and written to the sparse
        positions. The string ends with the end symbol 2.
        """
        n = self.length - 1
        final_result = [generator.choice([0, 1]) for _ in range(n)] + [2]
        c_1 = generator.randint(min_ones, max_ones)
        bits = [0] * (self.k - c_1) + [1] * c_1
        generator.shuffle(bits)
        for position, bit in zip(self.idx, bits):
            final_result[position] = bit
        return tuple(final_result)

    def _string_to_next_symbols(self, w):
        """Next-symbol sets are unsupported; always returns None."""
        return None