import random
import string
import warnings
import numpy as np


class TreeGenerator:
    """Generate batches of random complete trees labelled with distinct letters.

    Each tree is serialised by ``_format_tree_string`` as::

        <pairs>[ROOT]<root>[TARGET]<leaf>[ROUTE]<route>

    where ``<pairs>`` is a comma-separated list of two-letter ``parent+child``
    edges (empty for a single-node tree), ``<leaf>`` is a randomly chosen node
    on the deepest level, and ``<route>`` spells the node labels on the path
    from the root down to ``<leaf>``. Labels are drawn without replacement
    from ``A``-``Z``, which caps a tree at 26 nodes.
    """

    def __init__(
        self,
        depth: int,
        node_width: int,
        batch_size: int,
        rng: "random.Random | None" = None,
    ):
        """Configure the generator and precompute the label pool.

        Args:
            depth: Number of levels including the root; must be >= 1.
            node_width: Children per internal node; must be >= 0. If a complete
                tree of this width and depth needs more than 26 labels, the
                width is reduced to the largest value that fits, with a warning.
            batch_size: Number of trees produced by each ``generate_trees`` call.
            rng: Optional random source (e.g. ``random.Random(seed)``) for
                reproducible output. Defaults to the module-level ``random``.

        Raises:
            ValueError: if ``depth < 1``, ``node_width < 0``, or the tree is
                too large for 26 labels and not even width 2 fits at this depth.
        """
        if depth < 1:
            raise ValueError(f'depth must be >= 1, got {depth}')
        if node_width < 0:
            raise ValueError(f'node_width must be >= 0, got {node_width}')

        self.depth = depth
        self.node_width = node_width
        self.batch_size = batch_size
        self.available_chars = string.ascii_uppercase
        # Fall back to the module-level generator so unseeded behaviour is unchanged.
        self._rng = random if rng is None else rng

        label_limit = len(self.available_chars)
        max_nodes_needed = sum(node_width ** i for i in range(depth))

        if max_nodes_needed > label_limit:
            # Sacrifice width (never depth) to fit within the alphabet.
            adjusted_width = self._find_max_width_for_depth(depth, label_limit)
            if adjusted_width < 2:
                raise ValueError(f'Cannot create tree with depth {depth} within 26 character limit')

            if adjusted_width < node_width:
                warnings.warn(
                    f'Reducing node_width from {node_width} to {adjusted_width} '
                    f'to fit within 26 character limit for depth {depth}',
                    stacklevel=2,  # point at the caller constructing the generator
                )
                self.node_width = adjusted_width

        self._precompute_structures()

    @staticmethod
    def _find_max_width_for_depth(depth: int, max_nodes: int = 26) -> int:
        """Return the largest width >= 2 whose complete tree fits in ``max_nodes``.

        A complete tree of ``depth`` levels and branching ``width`` has
        ``sum(width ** i for i in range(depth))`` nodes. Returns 1 when no
        width >= 2 fits.
        """
        for width in range(max_nodes - 1, 1, -1):
            if sum(width ** i for i in range(depth)) <= max_nodes:
                return width
        return 1

    def _precompute_structures(self):
        """Cache per-level node counts, the total node count and the label pool."""
        self.level_sizes = [self.node_width ** i for i in range(self.depth)]
        self.total_nodes = sum(self.level_sizes)
        # Exactly ``total_nodes`` distinct labels; generate_trees shuffles a
        # fresh copy for every tree (the pool itself stays in A..Z order).
        self.char_pool = list(self.available_chars[:self.total_nodes])

    def generate_trees(self) -> list[str]:
        """Return ``batch_size`` freshly randomised, serialised trees."""
        trees = []
        for _ in range(self.batch_size):
            chars = self.char_pool.copy()
            self._rng.shuffle(chars)
            trees.append(self._format_tree_string(self._build_tree_fast(chars)))
        return trees

    def _build_tree_fast(self, chars: list[str]) -> tuple[list[tuple[str, str]], str, str, str]:
        """Lay ``chars`` out breadth-first as a complete tree.

        ``chars[0]`` becomes the root and the remaining labels fill each level
        left to right, ``node_width`` children per parent.

        Returns:
            ``(parent_child_pairs, root, target_leaf, route)``. ``target_leaf``
            is drawn uniformly from the deepest level; if the tree has no
            children at all, root, target and route are all the root label.
        """
        parent_child_pairs = []
        root = chars[0]
        next_char = 1

        current_level = [root]
        leaves = []
        last_level = self.depth - 1

        for level in range(1, self.depth):
            next_level = []
            for parent in current_level:
                # Never hand out more labels than remain in the pool.
                for _ in range(min(self.node_width, len(chars) - next_char)):
                    child = chars[next_char]
                    next_char += 1
                    parent_child_pairs.append((parent, child))
                    next_level.append(child)
                    if level == last_level:
                        leaves.append(child)

            current_level = next_level
            if not current_level:
                break

        if leaves:
            target_leaf = self._rng.choice(leaves)
            route = self._find_route(parent_child_pairs, root, target_leaf)
        else:
            target_leaf = root
            route = root

        return parent_child_pairs, root, target_leaf, route

    @staticmethod
    def _find_route(pairs: list[tuple[str, str]], root: str, target: str) -> str:
        """Return the labels on the path from ``root`` to ``target`` as one string.

        Walks parent pointers upward from ``target``. Returns ``''`` when
        ``target`` is not reachable from ``root`` through ``pairs``.
        """
        parent_of = {child: parent for parent, child in pairs}
        path = [target]
        node = target
        # A tree path has at most len(pairs) edges; the bound also stops
        # malformed (cyclic) input from looping forever.
        for _ in range(len(pairs)):
            if node == root:
                break
            node = parent_of.get(node)
            if node is None:
                return ''
            path.append(node)
        if node != root:
            return ''
        return ''.join(reversed(path))

    @staticmethod
    def _format_tree_string(tree_data: tuple[list[tuple[str, str]], str, str, str]) -> str:
        """Serialise ``(pairs, root, target, route)`` into the tree string format.

        Pairs become comma-separated ``parent+child`` bigrams. With no pairs
        the string starts directly with ``[ROOT]``.
        """
        pairs, root, target, route = tree_data
        pairs_str = ','.join(parent + child for parent, child in pairs)
        return f'{pairs_str}[ROOT]{root}[TARGET]{target}[ROUTE]{route}'


class TreeTokenizer:
    """Map tree strings to fixed-width int32 token-id rows and back.

    Vocabulary: ``{PAD}``=0, ``{EOS}``=1, ``,``=2, ``[ROOT]``=3,
    ``[TARGET]``=4, ``[ROUTE]``=5, and ``A``..``Z`` = 6..31.
    """

    def __init__(self, context_len: int):
        """Build the vocabulary.

        Args:
            context_len: Width of every encoded row. Longer sequences are
                truncated; shorter ones are right-padded with ``{PAD}`` (0).
        """
        self.context_len = context_len

        self.token_to_id = {
            '{PAD}': 0,
            '{EOS}': 1,
            ',': 2,
            '[ROOT]': 3,
            '[TARGET]': 4,
            '[ROUTE]': 5,
        }
        for i, char in enumerate(string.ascii_uppercase):
            self.token_to_id[char] = 6 + i

        # Reverse mapping used by decode().
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}

        self._compile_patterns()

    def _compile_patterns(self):
        """Build the ordered token list tried by the greedy tokenizer.

        Longest tokens come first, so a multi-character marker such as
        ``[TARGET]`` wins at its position. ``{PAD}`` and ``{EOS}`` are
        excluded because they never occur in raw tree strings.
        """
        special = {'{PAD}', '{EOS}'}
        self.tokens = [
            token
            for token in sorted(self.token_to_id, key=len, reverse=True)
            if token not in special
        ]

    def _tokenize_with_eos(self, text: str) -> list[int]:
        """Tokenize ``text`` and append the ``{EOS}`` id."""
        tokens = self._tokenize_single(text)
        tokens.append(self.token_to_id['{EOS}'])
        return tokens

    def encode(self, tree_strings: list[str]) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Encode strings for next-token prediction.

        Returns:
            data: ``(batch_size, context_len)`` int32 token ids followed by
                ``{EOS}``, truncated to ``context_len`` (in which case EOS is
                dropped), zero-padded on the right.
            target: ``data`` shifted left by one (``target[:, i]`` is the token
                at ``i + 1``). For untruncated rows the last target is ``{EOS}``.
            mask: 1 at positions ``0 .. actual_len - 2``, i.e. wherever
                ``target`` holds a real next token (EOS included). 0 at the
                position whose input is the final token and over padding.
            indices: ``(batch_size,)`` number of tokens written per row.
        """
        batch_size = len(tree_strings)
        shape = (batch_size, self.context_len)
        data = np.zeros(shape, dtype=np.int32)
        target = np.zeros(shape, dtype=np.int32)
        mask = np.zeros(shape, dtype=np.int32)
        indices = np.zeros(batch_size, dtype=np.int32)

        for batch_idx, tree_string in enumerate(tree_strings):
            tokens = self._tokenize_with_eos(tree_string)
            actual_len = min(len(tokens), self.context_len)
            indices[batch_idx] = actual_len
            data[batch_idx, :actual_len] = tokens[:actual_len]

            if actual_len > 1:
                target[batch_idx, :actual_len - 1] = tokens[1:actual_len]
                mask[batch_idx, :actual_len - 1] = 1

        return data, target, mask, indices

    def shift_encode(self, tree_strings: list[str], shift: int = 1) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Encode tree strings with multiple shifted targets for non-teacher-forcing training.

        Args:
            tree_strings: List of tree expression strings.
            shift: Number of shifted target sets to build (default 1).

        Returns:
            data: Same as ``encode()``.
            target: ``(shift, batch_size, context_len)``. ``target[k]`` holds
                ``data`` shifted left by ``k + 1`` positions.
            mask: ``(shift, batch_size, context_len)``. ``mask[k]`` is 1
                exactly where ``target[k]`` holds a real token.
            indices: Same as ``encode()``.

        With ``shift=1``, ``target[0]`` and ``mask[0]`` equal ``encode()``'s
        ``target`` and ``mask``.
        """
        batch_size = len(tree_strings)

        data = np.zeros((batch_size, self.context_len), dtype=np.int32)
        target = np.zeros((shift, batch_size, self.context_len), dtype=np.int32)
        mask = np.zeros((shift, batch_size, self.context_len), dtype=np.int32)
        indices = np.zeros(batch_size, dtype=np.int32)

        for batch_idx, tree_string in enumerate(tree_strings):
            tokens = self._tokenize_with_eos(tree_string)
            actual_len = min(len(tokens), self.context_len)
            indices[batch_idx] = actual_len
            data[batch_idx, :actual_len] = tokens[:actual_len]

            for shift_idx in range(shift):
                shift_amount = shift_idx + 1
                target_len = actual_len - shift_amount
                if target_len <= 0:
                    break  # larger shifts leave even fewer targets
                target[shift_idx, batch_idx, :target_len] = tokens[shift_amount:actual_len]
                mask[shift_idx, batch_idx, :target_len] = 1

        return data, target, mask, indices

    def _tokenize_single(self, text: str) -> list[int]:
        """Greedily split ``text`` into token ids (longest match first).

        Characters that start no known token are skipped one at a time.
        """
        ids = []
        pos = 0
        length = len(text)

        while pos < length:
            for token in self.tokens:
                # startswith with an offset avoids copying text[pos:] per probe.
                if text.startswith(token, pos):
                    ids.append(self.token_to_id[token])
                    pos += len(token)
                    break
            else:
                pos += 1

        return ids

    def decode(self, token_ids: np.ndarray) -> list[str]:
        """Decode token IDs back to strings.

        Args:
            token_ids: Array-like of shape ``(batch_size, seq_len)``. A 1-D (or
                scalar) input is treated as a single sequence.

        Returns:
            One string per row. Each string stops at the first ``{EOS}`` or
            ``{PAD}``; unknown ids are skipped.
        """
        rows = np.atleast_2d(np.asarray(token_ids))
        stop_ids = (self.token_to_id['{EOS}'], self.token_to_id['{PAD}'])

        decoded_strings = []
        for row in rows:
            pieces = []
            for token_id in row:
                if token_id in stop_ids:
                    break
                token = self.id_to_token.get(token_id)
                if token is not None:
                    pieces.append(token)
            decoded_strings.append(''.join(pieces))

        return decoded_strings