import math

from recognizers.automata.reserved import ReservedSymbol
from recognizers.dataset_generation.weighted_language import (
    WeightedLanguage,
    LengthRestrictedWeightedLanguage,
    Parse
)

class AnBn(WeightedLanguage):
    """The language { a^n b^n : n >= 0 } over a two-symbol alphabet.

    Symbols are integers: ``0`` stands for ``a`` and ``1`` stands for ``b``.
    """

    def __init__(self):
        super().__init__()
        # Two symbols in total: 0 ('a') and 1 ('b').
        self.num_symbols = 2
        self.log_num_symbols = math.log(self.num_symbols)

    def alphabet_size(self):
        """Return the number of symbols in the alphabet."""
        return self.num_symbols

    def symbol_to_str(self, symbol):
        """Render an integer symbol as ``"a"`` (for 0) or ``"b"`` (for 1).

        :raises ValueError: If ``symbol`` is neither 0 nor 1.
        """
        for code, text in ((0, "a"), (1, "b")):
            if symbol == code:
                return text
        raise ValueError

    def with_length_range(self, length_range):
        """Return this language restricted to lengths in ``length_range``."""
        return LengthRestrictedAnBn(self, length_range)

    def is_positive(self, s):
        """Return whether ``s`` is ``n`` zeros followed by ``n`` ones.

        The empty sequence (``n = 0``) is considered positive.
        """
        half, odd = divmod(len(s), 2)
        if odd:
            # An odd length can never be split into two equal runs.
            return False
        if not all(x == 0 for x in s[:half]):
            return False
        return all(x == 1 for x in s[half:])


class LengthRestrictedAnBn(LengthRestrictedWeightedLanguage):
    """a^n b^n restricted to strings whose length lies in an inclusive range.

    Positive strings are sampled by drawing ``n`` uniformly from the values
    for which ``2n`` lies in ``[min_length, max_length]``. Each ``n`` yields
    exactly one string, so every sampled string has the same probability.
    """

    def __init__(self, parent, length_range):
        """Restrict ``parent`` to strings with length in ``length_range``.

        :param parent: The unrestricted language; its ``is_positive`` is used
            for membership tests.
        :param length_range: ``(min_length, max_length)``, both inclusive.
        :raises ValueError: If no string of the form a^n b^n has a length in
            the range.
        """
        super().__init__()
        self.parent = parent
        self.min_length, self.max_length = length_range
        # a^n b^n has length 2n, so the admissible n satisfy
        # min_length <= 2n <= max_length. n can never be negative; without the
        # clamp a negative min_length would let several "n" values collapse to
        # the empty string and skew both sampling and n_log_prob.
        self.min_n = max(0, math.ceil(self.min_length / 2))
        self.max_n = math.floor(self.max_length / 2)
        if self.min_n > self.max_n:
            raise ValueError(
                f'length range {length_range!r} contains no string of the '
                f'form a^n b^n'
            )
        n_range_size = self.max_n - self.min_n + 1
        # n is drawn uniformly, so every positive string has this log probability.
        self.n_log_prob = -math.log(n_range_size)

    def supports_log_probability(self):
        return True

    def supports_next_symbols(self):
        return True

    def supports_edit_distance(self):
        return False

    def sample(self, generator, include_log_probability, include_next_symbols):
        """Sample one positive string.

        :param generator: Random number generator. NOTE(review): assumes
            ``generator.randint(a, b)`` is inclusive on both ends, as with
            :class:`random.Random`; numpy's ``RandomState.randint`` excludes
            ``b`` — confirm against the caller.
        :param include_log_probability: Whether to compute the log probability.
        :param include_next_symbols: Whether to compute per-prefix next-symbol
            sets.
        :return: ``(s, Parse(log_probability, next_symbols))``, where ``s`` is
            a tuple of ``n`` zeros followed by ``n`` ones. The ``Parse`` fields
            are ``None`` when not requested.
        """
        n = generator.randint(self.min_n, self.max_n)
        # A tuple (rather than a list) so the sample is hashable.
        s = (0,) * n + (1,) * n

        if include_log_probability:
            log_probability = self._n_to_log_probability(n)
        else:
            log_probability = None

        if include_next_symbols:
            next_symbols = self._s_to_next_symbols(s)
        else:
            next_symbols = None

        return s, Parse(log_probability, next_symbols)

    def is_negative(self, s, include_edit_distance):
        """Return ``(is_negative, edit_distance)``.

        Edit distance is not supported, so the second element is always
        ``None`` regardless of ``include_edit_distance``.
        """
        return (not self._is_positive(s), None)

    def _is_positive(self, s):
        """Return whether ``s`` is in the parent language and in the length range."""
        return self.min_length <= len(s) <= self.max_length and self.parent.is_positive(s)

    def _n_to_log_probability(self, n):
        """Return the log probability of the sample for ``n``.

        ``n`` is unused because the distribution over ``n`` is uniform.
        """
        return self.n_log_prob

    def _s_to_next_symbols(self, s):
        """Return one list of valid next symbols per prefix of ``s``.

        The result has ``len(s) + 1`` entries. Entry ``i`` is for the prefix
        ``s[:i]``, including the empty prefix and ``s`` itself.

        NOTE(review): the sets are those of the *unrestricted* a^n b^n language
        and ignore the length range. For example, EOS is offered after the
        empty prefix even when ``min_length > 0``, and ``1`` is offered after
        a^k with ``k < min_n``. Confirm this is the intended semantics.
        """
        # Maintain running counts so this is O(len(s)), rather than re-slicing
        # and re-counting every prefix (O(len(s)^2)).
        a_count = b_count = 0
        result = [self._next_symbols_after(a_count, b_count)]
        for symbol in s:
            if symbol == 0:
                a_count += 1
            elif symbol == 1:
                b_count += 1
            result.append(self._next_symbols_after(a_count, b_count))
        return result

    @staticmethod
    def _next_symbols_after(a_count, b_count):
        """Return a fresh list of the valid next symbols after a prefix with these counts."""
        if b_count == 0:
            if a_count == 0:
                # Empty prefix: start the a's, or stop (n = 0).
                return [0, ReservedSymbol.EOS]
            # Some a's and no b's yet: add another a or start the b's.
            return [0, 1]
        if a_count > b_count:
            # Inside the b's with b's still owed: only b is valid.
            return [1]
        if a_count == b_count:
            # Balanced and non-empty (b_count > 0 here): the string must end.
            return [ReservedSymbol.EOS]
        # More b's than a's: no continuation is valid.
        return []
