import random
import json
from typing import List, Dict, Tuple
from collections import defaultdict
import polars as pl
import tqdm


class CFGRule:
    """A single grammar production of the form ``lhs -> rhs``.

    Symbols are plain integers. ``rhs`` is the ordered list of symbols that
    ``lhs`` rewrites into; it is stored by reference, not copied.
    """

    def __init__(self, lhs: int, rhs: List[int]):
        # Symbol being rewritten.
        self.lhs = lhs
        # Ordered replacement symbols.
        self.rhs = rhs

    def __repr__(self):
        rhs_text = ' '.join(str(sym) for sym in self.rhs)
        return f"{self.lhs} -> {rhs_text}"

    def to_dict(self) -> Dict:
        """Return a JSON-serializable mapping with ``lhs`` and ``rhs`` keys."""
        payload = dict(lhs=self.lhs, rhs=self.rhs)
        return payload

    @classmethod
    def from_dict(cls, data: Dict) -> 'CFGRule':
        """Rebuild a rule from a mapping shaped like the output of ``to_dict``."""
        lhs, rhs = data['lhs'], data['rhs']
        return cls(lhs, rhs)


class GenerationTracker:
    """Collect every level of one derivation and project per-level annotations
    onto the positions of the last recorded sequence.

    NOTE(review): positions are projected proportionally (an index is scaled by
    the ratio of the two sequence lengths). The recorded ``parent_maps`` are not
    consulted, so projected positions approximate the real derivation rather
    than trace it.
    """

    def __init__(self, levels: int):
        # Kept for reference; none of the methods below read it.
        self.levels = levels
        self.reset()

    def reset(self):
        """Forget all recorded levels so that a fresh string can be tracked."""
        self.generations = []  # symbol sequence of every recorded level
        self.boundaries = []   # 0/1 boundary flags per level
        self.ancestors = []    # per-position ancestor symbol, per level
        self.parent_maps = []  # child index -> parent index, per level

    def add_generation(self, symbols: List[int], boundaries: List[int],
                       ancestors: List[int], parent_map: Dict[int, int]):
        """Append shallow copies of one level's symbols and annotations."""
        pairs = (
            (self.generations, symbols),
            (self.boundaries, boundaries),
            (self.ancestors, ancestors),
            (self.parent_maps, parent_map),
        )
        for store, value in pairs:
            store.append(value.copy())

    def get_final_boundaries(self) -> Dict[str, List[int]]:
        """Return ``{'b1': [...], 'b2': [...], ...}``.

        There is one 0/1 list per recorded level. Each list is as long as the
        last recorded sequence and holds 1 at every projected boundary position.
        Returns ``{}`` when nothing has been recorded.
        """
        if not self.boundaries:
            return {}

        width = len(self.generations[-1])
        projected = {}
        for depth in range(len(self.boundaries)):
            marks = [0] * width
            for target in self._map_boundaries_to_final(depth):
                if target < width:
                    marks[target] = 1
            projected[f'b{depth + 1}'] = marks
        return projected

    def get_final_ancestors(self) -> Dict[str, List[int]]:
        """Return ``{'s1': [...], 's2': [...], ...}``.

        There is one list per recorded level, as long as the last recorded
        sequence. Each projected position holds that level's ancestor symbol,
        and every other position holds 0. Returns ``{}`` when nothing has been
        recorded.
        """
        if not self.ancestors:
            return {}

        width = len(self.generations[-1])
        projected = {}
        for depth in range(len(self.ancestors)):
            row = [0] * width
            for target, symbol in self._map_ancestors_to_final(depth).items():
                if target < width:
                    row[target] = symbol
            projected[f's{depth + 1}'] = row
        return projected

    def _map_boundaries_to_final(self, level: int) -> List[int]:
        """Project the boundary flags of ``level`` onto final-sequence indices.

        Only entries equal to 1 are projected. Returns ``[]`` when ``level``
        was never recorded.
        """
        if level >= len(self.boundaries):
            return []
        flags = self.boundaries[level]
        return [
            int(idx * len(self.generations[-1]) / len(flags))
            for idx, flag in enumerate(flags)
            if flag == 1
        ]

    def _map_ancestors_to_final(self, level: int) -> Dict[int, int]:
        """Project the ancestor symbols of ``level`` onto final-sequence indices.

        If two entries project onto the same index, the later one wins.
        Returns ``{}`` when ``level`` was never recorded.
        """
        if level >= len(self.ancestors):
            return {}
        row = self.ancestors[level]
        mapping = {}
        for idx, symbol in enumerate(row):
            target = int(idx * len(self.generations[-1]) / len(row))
            if target < len(self.generations[-1]):
                mapping[target] = symbol
        return mapping


class CFGGenerator:
    """Context-Free Grammar generator based on the paper's methodology."""

    def __init__(self, levels: int = 7, sizes: List[int] = None, rule_lengths: List[int] = None):
        """
        Initialize CFG generator.

        Args:
            levels: Number of levels in the CFG (L)
            sizes: Number of symbols at each level [NT1, NT2, ..., NTL]
            rule_lengths: Allowed rule lengths (2 or 3 in the paper)
        """
        self.levels = levels
        self.sizes = sizes or [1] + [3] * (levels - 1)  # Default cfg3 structure
        self.rule_lengths = rule_lengths or [2, 3]

        if len(self.sizes) != levels:
            raise ValueError(f"Sizes list must have {levels} elements")

        # Create symbol mappings
        self.symbols_by_level = []
        symbol_id = 1

        for level in range(levels):
            level_symbols = list(range(symbol_id, symbol_id + self.sizes[level]))
            self.symbols_by_level.append(level_symbols)
            symbol_id += self.sizes[level]

        self.root_symbol = self.symbols_by_level[0][0]
        self.terminal_symbols = set(self.symbols_by_level[-1])

        # Initialize empty rules - will be populated by generate_rules() or load_rules()
        self.rules = []
        self.rules_by_symbol = {}

        # Add generation tracker
        self.tracker = GenerationTracker(levels)

    def _generate_rules(self) -> List[CFGRule]:
        """Generate CFG rules for all levels except the last (terminal) level."""
        rules = []

        for level in range(self.levels - 1):  # Skip terminal level
            current_symbols = self.symbols_by_level[level]
            next_symbols = self.symbols_by_level[level + 1]

            for symbol in current_symbols:
                # Generate rules for this symbol
                degree = random.randint(2, 4)  # Number of rules per symbol

                for _ in range(degree):
                    rule_length = random.choice(self.rule_lengths)
                    rhs = random.choices(next_symbols, k=rule_length)
                    rules.append(CFGRule(symbol, rhs))

        return rules

    def _index_rules_by_symbol(self) -> Dict[int, List[CFGRule]]:
        """Create an index of rules by left-hand side symbol."""
        index = defaultdict(list)
        for rule in self.rules:
            index[rule.lhs].append(rule)
        return dict(index)

    def generate_rules(self):
        """Generate new rules and index them."""
        self.rules = self._generate_rules()
        self.rules_by_symbol = self._index_rules_by_symbol()

    def save_rules(self, filename: str):
        """
        Save CFG rules to a JSON file.

        Args:
            filename: Path to the JSON file to save
        """
        cfg_data = {
            'metadata': {
                'levels': self.levels,
                'sizes': self.sizes,
                'rule_lengths': self.rule_lengths,
                'root_symbol': self.root_symbol,
                'terminal_symbols': list(self.terminal_symbols),
                'symbols_by_level': self.symbols_by_level
            },
            'rules': [rule.to_dict() for rule in self.rules]
        }

        with open(filename, 'w') as f:
            json.dump(cfg_data, f, indent=2)

        print(f"CFG rules saved to {filename}")

    def load_rules(self, filename: str):
        """
        Load CFG rules from a JSON file.

        Args:
            filename: Path to the JSON file to load
        """
        try:
            with open(filename, 'r') as f:
                cfg_data = json.load(f)

            # Restore metadata
            metadata = cfg_data['metadata']
            self.levels = metadata['levels']
            self.sizes = metadata['sizes']
            self.rule_lengths = metadata['rule_lengths']
            self.root_symbol = metadata['root_symbol']
            self.terminal_symbols = set(metadata['terminal_symbols'])
            self.symbols_by_level = metadata['symbols_by_level']

            # Restore rules
            self.rules = [CFGRule.from_dict(rule_data) for rule_data in cfg_data['rules']]
            self.rules_by_symbol = self._index_rules_by_symbol()

            print(f"CFG rules loaded from {filename}")
            print(f"Loaded {len(self.rules)} rules")

        except FileNotFoundError:
            print(f"Error: File {filename} not found")
            raise
        except json.JSONDecodeError:
            print(f"Error: Invalid JSON format in {filename}")
            raise
        except KeyError as e:
            print(f"Error: Missing required field {e} in JSON file")
            raise

    @classmethod
    def from_json(cls, filename: str) -> 'CFGGenerator':
        """
        Create a CFGGenerator instance from a saved JSON file.

        Args:
            filename: Path to the JSON file to load

        Returns:
            CFGGenerator instance with loaded rules
        """
        # Create a temporary instance to load the metadata
        temp_instance = cls()
        temp_instance.load_rules(filename)

        # Create the proper instance with loaded parameters
        instance = cls(
            levels=temp_instance.levels,
            sizes=temp_instance.sizes,
            rule_lengths=temp_instance.rule_lengths
        )

        # Copy the loaded rules
        instance.rules = temp_instance.rules
        instance.rules_by_symbol = temp_instance.rules_by_symbol

        return instance

    def get_rule_stats(self) -> Dict:
        """Get statistics about the current rules."""
        stats = {
            'total_rules': len(self.rules),
            'rules_by_length': defaultdict(int),
            'rules_by_symbol': defaultdict(int)
        }

        for rule in self.rules:
            stats['rules_by_length'][len(rule.rhs)] += 1
            stats['rules_by_symbol'][rule.lhs] += 1

        return dict(stats)

    def generate_string_with_tracking(self, max_length: int = 500) -> Tuple[List[int], Dict]:
        """
        Generate a string from the CFG with full tracking information.

        Returns:
            Tuple of (terminal_string, tracking_info)
            tracking_info contains boundaries and ancestors for each level
        """
        if not self.rules:
            raise ValueError("No rules available. Call generate_rules() or load_rules() first.")

        # Reset tracker for new generation
        self.tracker.reset()

        # Start with root symbol
        current_sequence = [self.root_symbol]
        current_boundaries = [1]  # Root is always a boundary
        current_ancestors = [self.root_symbol]
        parent_map = {}

        # Track the initial state (level 0)
        self.tracker.add_generation(current_sequence, current_boundaries, current_ancestors, parent_map)

        # Generate through all levels except terminal
        for level in range(self.levels - 1):
            next_sequence = []
            next_boundaries = []
            next_ancestors = []
            next_parent_map = {}

            for pos, symbol in enumerate(current_sequence):
                if symbol in self.rules_by_symbol:
                    # Choose a random rule for this symbol
                    rule = random.choice(self.rules_by_symbol[symbol])

                    # Add the rule's RHS to the next sequence
                    start_pos = len(next_sequence)

                    for i, child in enumerate(rule.rhs):
                        next_sequence.append(child)
                        next_ancestors.append(symbol)  # Parent symbol as ancestor

                        # Boundary: first symbol of a rule expansion is a boundary
                        next_boundaries.append(1 if i == 0 else 0)

                        # Map child position to parent position
                        next_parent_map[start_pos + i] = pos

                else:
                    # Terminal symbol, keep as is
                    next_sequence.append(symbol)
                    next_boundaries.append(current_boundaries[pos])
                    next_ancestors.append(current_ancestors[pos])
                    next_parent_map[len(next_sequence) - 1] = pos

            # Store this generation's information
            self.tracker.add_generation(next_sequence, next_boundaries, next_ancestors, next_parent_map)

            current_sequence = next_sequence
            current_boundaries = next_boundaries
            current_ancestors = next_ancestors

            # Stop if we've reached terminal symbols or max length
            if all(s in self.terminal_symbols for s in current_sequence):
                break
            if len(current_sequence) > max_length:
                break

        # Prepare tracking information
        tracking_info = {
            'generations': self.tracker.generations.copy(),
            'boundaries': self.tracker.get_final_boundaries(),
            'ancestors': self.tracker.get_final_ancestors(),
            'parent_maps': self.tracker.parent_maps.copy()
        }

        return current_sequence, tracking_info

    def generate_string(self, max_length: int = 500) -> List[int]:
        """
        Generate a string from the CFG (original method for backward compatibility).
        """
        string, _ = self.generate_string_with_tracking(max_length)
        return string

    def generate_dataset(self, num_samples: int = 1000, include_tracking: bool = True) -> List[Dict]:
        """Generate a dataset of CFG strings with optional tracking information."""
        if not self.rules:
            raise ValueError("No rules available. Call generate_rules() or load_rules() first.")

        dataset = []

        pbar = tqdm.tqdm(total=num_samples, desc="Generating dataset")
        for _ in range(num_samples):
            pbar.update(1)

            if include_tracking:
                string, tracking_info = self.generate_string_with_tracking()

                sample = {
                    'text': ' '.join(map(str, string)),
                    'length': len(string),
                }

                # Add boundary information
                boundaries = tracking_info['boundaries']
                for level_name, boundary_list in boundaries.items():
                    sample[level_name] = ' '.join(map(str, boundary_list))

                # Add ancestor information
                ancestors = tracking_info['ancestors']
                for level_name, ancestor_list in ancestors.items():
                    sample[level_name] = ' '.join(map(str, ancestor_list))

            else:
                string = self.generate_string()
                sample = {
                    'text': ' '.join(map(str, string)),
                    'length': len(string),
                }

            dataset.append(sample)

        pbar.close()
        return dataset


class CFGValidator:
    """Validates if a sequence complies with CFG rules using dynamic programming.

    Membership is decided by a memoized top-down search over
    ``(start, end, symbol)`` spans. The search assumes the grammar has no
    empty (epsilon) productions and no cycles of unit rules. Both hold for
    grammars built by ``CFGGenerator._generate_rules`` with positive rule
    lengths, whose rules always rewrite one level into the next.
    """

    def __init__(self, cfg_generator: 'CFGGenerator'):
        self.cfg = cfg_generator
        # (start, end, symbol) -> bool. Valid only for the string currently
        # being validated; validate_string() clears it before each run.
        self.memo = {}

    def can_derive(self, start: int, end: int, symbol: int, string: List[int]) -> bool:
        """
        Check if symbol can derive string[start:end+1] using dynamic programming.

        Args:
            start: Start index in string.
            end: End index in string (inclusive).
            symbol: Symbol to check (terminal or non-terminal).
            string: The input string of terminal symbols.

        Returns:
            True if symbol can derive the substring, False otherwise.
        """
        key = (start, end, symbol)
        if key in self.memo:
            return self.memo[key]

        if start > end:
            # Empty span: without epsilon rules nothing derives it.
            result = False
        elif symbol in self.cfg.terminal_symbols:
            # A terminal derives exactly itself, i.e. a single matching position.
            result = start == end and symbol == string[start]
        else:
            # Non-terminal: try every rule. Unit rules such as A -> B recurse
            # through can_derive, so chains A -> B -> t are found even when
            # the span has length 1.
            result = any(
                self._can_derive_with_rule(start, end, rule, string)
                for rule in self.cfg.rules_by_symbol.get(symbol, ())
            )

        self.memo[key] = result
        return result

    def _can_derive_with_rule(self, start: int, end: int, rule: 'CFGRule', string: List[int]) -> bool:
        """Check if a specific rule can derive string[start:end+1]."""
        rhs = rule.rhs
        rhs_len = len(rhs)

        if rhs_len == 1:
            return self.can_derive(start, end, rhs[0], string)

        def try_splits(pos: int, rule_index: int) -> bool:
            """Can rhs[rule_index:] derive string[pos:end+1]?"""
            if rule_index == rhs_len:
                return pos == end + 1

            if rule_index == rhs_len - 1:
                # Last symbol must cover all remaining positions.
                return self.can_derive(pos, end, rhs[rule_index], string)

            # Each later RHS symbol needs at least one position (no epsilon
            # rules), so stop early enough to leave room for them.
            remaining = rhs_len - rule_index - 1
            for split_end in range(pos, end - remaining + 1):
                if (self.can_derive(pos, split_end, rhs[rule_index], string)
                        and try_splits(split_end + 1, rule_index + 1)):
                    return True
            return False

        return try_splits(start, 0)

    def validate_string(self, string) -> bool:
        """
        Validate if a string belongs to the CFG language L(G).

        Args:
            string: List of terminal symbols, or a space-separated string of
                integer symbols.

        Returns:
            True if string ∈ L(G), False otherwise (including empty input).

        Raises:
            ValueError: If a space-separated token is not an integer.
        """
        if isinstance(string, str):  # a space-separated string
            string = [int(num) for num in string.split()]

        # Checked after conversion so blank text counts as empty too.
        if not string:
            return False

        self.memo.clear()
        return self.can_derive(0, len(string) - 1, self.cfg.root_symbol, string)


# Example usage
if __name__ == "__main__":
    from pathlib import Path

    # Load the cfg3f grammar previously written by CFGGenerator.save_rules().
    cfg = CFGGenerator.from_json('cfg3f_rules.json')

    output_dir = Path('../data')
    output_dir.mkdir(parents=True, exist_ok=True)

    samples_per_shard = 50000
    # (split name, number of shards); every shard is written as cfg_<split>_<i>.parquet.
    splits = [('train', 10), ('test', 2)]

    print("\nGenerating dataset with tracking...")
    for split, num_shards in splits:
        for shard in range(num_shards):
            records = cfg.generate_dataset(num_samples=samples_per_shard, include_tracking=True)
            # Write each shard right away so only one shard is held in memory.
            pl.DataFrame(records).write_parquet(output_dir / f'cfg_{split}_{shard}.parquet')