import math

from recognizers.automata.reserved import ReservedSymbol
from recognizers.dataset_generation.weighted_language import (
    LengthRestrictedWeightedLanguage,
    Parse,
    WeightedLanguage
)

class DyckK(WeightedLanguage):
    """
    Implements the Dyck language with k types of parentheses.
    The Dyck language consists of balanced strings of parentheses.

    Symbols are integers: ``0 .. k-1`` are the opening parentheses and
    ``k .. 2k-1`` are the closing parentheses, where symbol ``t + k`` closes
    the opening parenthesis of type ``t``.
    """

    def __init__(self, k: int = 1):
        """
        Initialize the Dyck language with k types of parentheses.

        Args:
            k: Number of types of parentheses. Default is 1.

        Raises:
            ValueError: If k is less than 1.
        """
        super().__init__()
        # Fail early with a clear message; otherwise math.log below would
        # raise an opaque "math domain error" for k <= 0.
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")
        self.k = k
        self.num_symbols = 2 * k
        self.log_num_symbols = math.log(self.num_symbols)

    def alphabet_size(self) -> int:
        """Return the number of symbols in the alphabet (2 * k)."""
        return self.num_symbols

    def symbol_to_str(self, symbol) -> str:
        """
        Render a symbol as a human-readable string.

        Args:
            symbol: Integer symbol in the range ``0 .. 2k-1``.

        Returns:
            ``"(t)"`` for the opening parenthesis of type ``t`` and
            ``")t("`` for the closing parenthesis of type ``t``.

        Raises:
            ValueError: If the symbol is not in the alphabet.
        """
        # Reject negative symbols too; they would otherwise be silently
        # rendered as opening parentheses.
        if not 0 <= symbol < self.num_symbols:
            raise ValueError(f"Symbol {symbol} is outside the alphabet range")
        if symbol < self.k:  # Opening parentheses
            return "(" + str(symbol) + ")"
        # Closing parentheses
        return ")" + str(symbol - self.k) + "("

    def with_length_range(self, length_range):
        """
        Return this language restricted to the given length range.

        Args:
            length_range: Pair ``(min_length, max_length)``.

        Returns:
            A LengthRestrictedDyckK wrapping this language.
        """
        return LengthRestrictedDyckK(self, length_range)

    def is_positive(self, s) -> bool:
        """
        Check if a string belongs to the Dyck language.

        Args:
            s: A sequence of symbols representing parentheses.

        Returns:
            bool: True if the string is balanced, False otherwise.
        """
        # Stack of the types of the currently unmatched opening parentheses.
        open_types = []

        for symbol in s:
            if symbol < self.k:  # Opening parenthesis
                open_types.append(symbol)
                continue
            # Closing parenthesis: it must match the innermost open one.
            if not open_types or open_types[-1] != symbol - self.k:
                return False
            open_types.pop()

        # Balanced only if nothing is left open.
        return not open_types


class LengthRestrictedDyckK(LengthRestrictedWeightedLanguage):
    """
    The Dyck-k language restricted to strings within a length range.

    Strings are sampled by first drawing an even length uniformly from the
    range and then generating a balanced string of exactly that length.
    """

    def __init__(self, parent, length_range):
        """
        Args:
            parent: The DyckK language being restricted; its ``k`` attribute
                and ``is_positive`` method are used.
            length_range: Pair ``(min_length, max_length)``, both inclusive.
                The bounds are tightened to the nearest even lengths inside
                the range, and a negative lower bound is treated as 0.

        Raises:
            ValueError: If no even, non-negative length lies in the range.
        """
        super().__init__()
        self.parent = parent
        min_length, max_length = length_range

        # Lengths cannot be negative; without this clamp, negative "lengths"
        # would be counted in the range and all sampled as the empty string.
        min_length = max(min_length, 0)

        # Dyck strings must have even length
        self.min_length = min_length + min_length % 2
        self.max_length = max_length - max_length % 2

        if self.min_length > self.max_length:
            raise ValueError("Empty length range after ensuring even lengths")

        # Number of even lengths in [min_length, max_length]; the length is
        # drawn uniformly among them.
        length_range_size = (self.max_length - self.min_length) // 2 + 1
        self.length_log_prob = -math.log(length_range_size)

    def supports_log_probability(self):
        """Log-probabilities of sampled strings are available."""
        return True

    def supports_next_symbols(self):
        """Valid-next-symbol sets of sampled strings are available."""
        return True

    def supports_edit_distance(self):
        """Edit distances are not computed by this language."""
        return False

    def sample(self, generator, include_log_probability, include_next_symbols):
        """
        Sample a balanced string together with optional annotations.

        Args:
            generator: Random number generator providing ``randint``,
                ``randrange`` and ``random``.
            include_log_probability: Whether to compute the log-probability
                of the sampled string under this sampler.
            include_next_symbols: Whether to compute, for every prefix of the
                string, the list of symbols that may come next.

        Returns:
            A pair ``(s, Parse(log_probability, next_symbols))`` where ``s``
            is a tuple of symbols and each annotation that was not requested
            is None.
        """
        s = self._sample_balanced_string(generator)
        log_probability = (
            self._s_to_log_probability(s) if include_log_probability else None
        )
        next_symbols = (
            self._s_to_next_symbols(s) if include_next_symbols else None
        )
        return s, Parse(log_probability, next_symbols)

    def is_negative(self, s, include_edit_distance):
        """
        Check whether a string is outside the length-restricted language.

        Args:
            s: A sequence of symbols.
            include_edit_distance: Ignored; edit distance is not supported.

        Returns:
            A pair ``(is_negative, None)``.
        """
        return (not self._is_positive(s), None)

    def _is_positive(self, s):
        """Return True if s is balanced and its length is within the range."""
        return self.min_length <= len(s) <= self.max_length and self.parent.is_positive(s)

    def _sample_balanced_string(self, generator):
        """Draw an even length uniformly, then a balanced string of that length."""
        # Choose an even length within the range
        length = 2 * generator.randint(self.min_length // 2, self.max_length // 2)

        balanced_list = self._generate_balanced(generator, length, self.parent.k)
        # Convert the list to a tuple to make it hashable
        return tuple(balanced_list)

    def _generate_balanced(self, generator, length, k):
        """
        Generate a balanced string of exactly ``length`` symbols.

        At each position an opening parenthesis is forced when nothing is
        open, a closing parenthesis is forced when the remaining positions
        are all needed to close what is open, and otherwise a fair coin
        decides. The type of each opening parenthesis is uniform over ``k``.

        Args:
            generator: Random number generator.
            length: Even, non-negative target length.
            k: Number of parenthesis types.

        Returns:
            A list of symbols.
        """
        result = []
        stack = []

        for _ in range(length):
            # len(result) + len(stack) is twice the number of opens so far;
            # opening again is only possible while that stays below length.
            # The coin is only flipped when both moves are possible.
            if not stack or (len(result) + len(stack) < length and generator.random() > 0.5):
                # Choose a random type of open parenthesis
                open_paren = generator.randrange(k)
                result.append(open_paren)
                stack.append(open_paren)
            else:
                # Close the most recent parenthesis
                open_paren = stack.pop()
                result.append(open_paren + k)  # Corresponding close parenthesis

        return result

    def _length_to_log_probability(self, length):
        """
        Return the log-probability of drawing a given length.

        The length is uniform over the even lengths in the range, so the
        value does not depend on ``length``.
        """
        return self.length_log_prob

    def _s_to_log_probability(self, s):
        """
        Return the log-probability that the sampler produces the string s.

        Replays the decisions of ``_generate_balanced``: the length choice,
        one fair coin flip at every position where both opening and closing
        were possible, and a uniform choice among ``k`` types for every
        opening parenthesis. The coin is treated as exactly fair.

        Args:
            s: A string produced by ``_sample_balanced_string``.

        Returns:
            The natural-log probability of s.
        """
        k = self.parent.k
        length = len(s)
        log_k = math.log(k)
        log_half = math.log(0.5)

        log_probability = self.length_log_prob
        depth = 0  # Number of currently unmatched opening parentheses
        for position, symbol in enumerate(s):
            # A coin was flipped only if something was open and there was
            # still room to open another parenthesis.
            if depth and position + depth < length:
                log_probability += log_half
            if symbol < k:
                log_probability -= log_k
                depth += 1
            else:
                depth -= 1
        return log_probability

    def _s_to_next_symbols(self, s):
        """
        Compute the valid next symbols after every prefix of s.

        Args:
            s: A sequence of symbols.

        Returns:
            A list of ``len(s) + 1`` lists; entry ``i`` holds the symbols
            (possibly including ``ReservedSymbol.EOS``) that may follow the
            prefix ``s[:i]``. Entries for prefixes that are not valid Dyck
            prefixes are empty.
        """
        k = self.parent.k
        stack = []
        valid_prefix = True

        # The stack is updated incrementally instead of being rebuilt for
        # every prefix.
        result = [self._next_symbols_after(0, stack)]
        for i, symbol in enumerate(s, start=1):
            if valid_prefix:
                if symbol < k:  # Opening parenthesis
                    stack.append(symbol)
                elif stack and stack[-1] == symbol - k:  # Matching close
                    stack.pop()
                else:
                    # Once a prefix is invalid, every extension is too.
                    valid_prefix = False
            result.append(self._next_symbols_after(i, stack) if valid_prefix else [])

        return result

    def _next_symbols_after(self, position, stack):
        """
        List the symbols that may follow a valid prefix.

        Args:
            position: Length of the prefix.
            stack: Types of the unmatched opening parentheses of the prefix,
                innermost last.

        Returns:
            A list of allowed next symbols.
        """
        k = self.parent.k
        allowed = []

        # Opening is possible only if the new parenthesis and everything
        # already open can still be closed within max_length.
        if position + len(stack) < self.max_length:
            allowed.extend(range(k))

        if stack:
            # Only the innermost open parenthesis can be closed next.
            allowed.append(stack[-1] + k)
        elif position >= self.min_length:
            # A complete balanced string of sufficient length may end here.
            allowed.append(ReservedSymbol.EOS)

        return allowed
