"""Math-specific commands for the REPL interface."""

from typing import List, Dict, Any, Optional
import ast
from rich.table import Table
from rich import box
from frame.repl.commands.base import BaseCommand
from frame.repl.interface.protocols import CommandResult, REPLContext
from frame.repl.core.events import Event, EventType
from frame.knowledge_base.entities import Concept, Conjecture
from frame.knowledge_base.knowledge_graph import NodeType, ConstructionStep
from frame.repl.core.exceptions import EntityNotFoundError, InvalidArgumentError
from frame.environments.ground_truth_entities import is_ground_truth_entity, update_entity_implementation, normalize_name

class ApplyRuleCommand(BaseCommand):
    """Apply a production rule to create new concepts/conjectures.
    
    Usage: apply <rule_name> <input_entities...> [param1=value1 param2=value2 ...]
    
    Examples:
        apply map_iterate successor
        apply match multiplication indices_to_match=[0,1]
        apply exists multiplication indices_to_quantify=[0]
    """
    
    @property
    def name(self) -> str:
        """Get the command name."""
        return "apply"
    
    def execute(self, state: REPLContext, *args) -> CommandResult:
        """Execute the rule application command.

        ``args[0]`` is the rule name. It is followed by zero or more entity
        names, then zero or more ``key=value`` parameters. Parameter values are
        parsed with ``ast.literal_eval``; a value that is not a Python literal
        is passed through as the raw string.

        The rule's result is added to the knowledge graph. Concepts go through
        ``add_concept`` and conjectures through ``add_conjecture``, each with a
        ``ConstructionStep`` recording the rule, inputs and parameters.
        Construction edges are then added from every input entity. If the
        result's normalized name is a ground-truth entity, its implementation
        and name may be updated and an ``auto_rename`` event is emitted.
        Finally a ``NEW_ENTITY`` event is emitted.

        Args:
            state: The REPL context (rules, entities, graph, event bus).
            *args: Raw command arguments as described above.

        Returns:
            CommandResult: success with the new entity's name and ID, or
            failure for an unknown rule/entity, a malformed parameter, an
            unsupported result type, or any exception raised while applying.
        """
        if not args:
            return CommandResult(False, "Usage: apply <rule_name> <input_entities...> [params...]")
            
        rule_name = args[0]
        rule = state.get_rule_by_name(rule_name)
        if not rule:
            state.emit_event(Event(
                type=EventType.ERROR,
                data={'message': f"Unknown rule: {rule_name}"}
            ))
            return CommandResult(False, f"Unknown rule: {rule_name}")
            
        # Parse input entities (every arg before the first key=value) and track their IDs
        input_entities = []
        input_entity_ids = []
        current_arg = 1
        while current_arg < len(args) and '=' not in args[current_arg]:
            entity_name = args[current_arg]
            entity = state.get_entity_by_name(entity_name)
            if not entity:
                return CommandResult(False, f"Unknown entity: {entity_name}")
            input_entities.append(entity)
            # NOTE(review): an entity whose ID lookup fails is still passed to the
            # rule but silently omitted from the construction record/edges.
            entity_id = state.get_entity_id(entity_name)
            if entity_id:
                input_entity_ids.append(entity_id)
            current_arg += 1
            
        # Parse parameters
        params = {}
        for arg in args[current_arg:]:
            if '=' not in arg:
                return CommandResult(False, f"Invalid parameter format: {arg}")
            key, value = arg.split('=', 1)
            
            try:
                # ast.literal_eval only accepts literals, so user input is never executed
                params[key] = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                # Not a literal (e.g. a bare word): keep the raw string
                params[key] = value
                
        try:
            result = rule.apply(*input_entities, **params)
            result_name = result.name if hasattr(result, 'name') else str(result)
            
            from datetime import datetime
            construction_step = ConstructionStep(
                rule=rule,
                input_node_ids=input_entity_ids,
                parameters=params,
                timestamp=datetime.now()
            )
            
            # Add the new entity to the graph with construction information
            if isinstance(result, Concept):
                entity_id = state.graph.add_concept(result, construction_step)
            elif isinstance(result, Conjecture):
                entity_id = state.graph.add_conjecture(result, construction_step)
            else:
                return CommandResult(False, f"Rule produced unknown entity type: {type(result)}")
                
            for input_id in input_entity_ids:
                state.graph.add_construction_edge(input_id, entity_id)
                
            print("Added to graph; checking for ground truth entity")
            # NOTE(review): the rename below happens after the graph insert; confirm the
            # graph does not index entities by their original name.
            if hasattr(result, 'name'):
                normalized_name = normalize_name(result.name)
                if is_ground_truth_entity(normalized_name):
                    old_name = result.name
                    if update_entity_implementation(result, normalized_name, update_implementation=True):
                        print(f"Automatically updated (ground truth) entity name from {old_name} to {result.name}")
                        state.emit_event(Event(
                            type=EventType.COMMAND_EXECUTED,
                            data={'command': 'auto_rename'},
                            message=f"Automatically renamed '{old_name}' to '{result.name}'"
                        ))
                        result_name = result.name  # Report the updated name in NEW_ENTITY
                
            state.emit_event(Event(
                type=EventType.NEW_ENTITY,
                data={'entity': result, 'name': result_name, 'entity_id': entity_id}
            ))
            
            return CommandResult(True, f"Created new entity: {result_name} (ID: {entity_id})")
        except Exception as e:
            # Command boundary: report any rule failure to the user instead of crashing the REPL
            return CommandResult(False, f"Error applying rule: {str(e)}")
            
    def get_completions(self, state: REPLContext, partial: str) -> List[str]:
        """Get completions for rule names and entity names.

        ``partial`` is the argument text typed so far. The first token is a
        rule name and every later token is an entity name. Trailing whitespace
        means the previous token is finished and a new, empty token is being
        started. For example, ``"map_iterate "`` offers entity names rather
        than trying to extend the rule name.

        Args:
            state: The REPL context providing rules and entities.
            partial: The partially typed argument text.

        Returns:
            Candidate completions for the token under the cursor.
        """
        parts = partial.split()
        starting_new_token = not parts or partial[-1:].isspace()
        current = "" if starting_new_token else parts[-1]
        # Index of the token being completed: 0 is the rule name.
        position = len(parts) if starting_new_token else len(parts) - 1

        if position == 0:
            return [rule.name for rule in state.available_rules if rule.name.startswith(current)]

        entities = [c.name for c in state.get_concepts() if hasattr(c, 'name')]
        entities.extend(c.name for c in state.get_conjectures() if hasattr(c, 'name'))
        return [e for e in entities if e.startswith(current)]

class InspectCommand(BaseCommand):
    """Command to inspect entities in the knowledge graph.

    Renders a rich ``Table`` with the entity's name, type and ID. When the
    entity has an ``examples`` collection, the table also shows its example
    structure and a bounded, sorted sample of examples and nonexamples.
    """
    
    MAX_EXAMPLES = 5  # Maximum number of examples/nonexamples to show
    
    def validate_args(self, args: tuple, kwargs: dict) -> None:
        """Validate inspect command arguments.

        Raises:
            InvalidArgumentError: If no entity name is given positionally or as
                the ``entity_name`` keyword.
        """
        # NOTE(review): unlike its sibling commands this class defines no `name`
        # property; `self.name` must be supplied by BaseCommand — confirm.
        if not args and 'entity_name' not in kwargs:
            raise InvalidArgumentError(
                "Entity name or ID is required",
                self.name,
                "entity_name"
            )
    
    def execute(self, context: REPLContext, *args, **kwargs) -> CommandResult:
        """Execute the inspect command.
        
        Args:
            context: The REPL context
            *args: Positional arguments (entity_name can be passed here)
            **kwargs: Named arguments (entity_name if passed as named arg)
            
        Returns:
            CommandResult with no message and the rendered ``Table`` as data.

        Raises:
            InvalidArgumentError: If no entity name is supplied.
            EntityNotFoundError: If the name does not resolve to an entity.
        """
        entity_name = args[0] if args else kwargs.get('entity_name')
        if not entity_name:
            raise InvalidArgumentError("Entity name or ID is required", self.name, "entity_name")
        
        entity = context.get_entity_by_name(entity_name)
        if not entity:
            raise EntityNotFoundError(entity_name, self.name)
        
        table = Table(box=box.ROUNDED)
        table.add_column("Property", style="cyan")
        table.add_column("Value", style="green")
        
        # Basic information
        table.add_row("Name", entity.name if hasattr(entity, 'name') else str(entity))
        table.add_row("Type", entity.__class__.__name__)
        entity_id = context.get_entity_id(entity_name)
        if entity_id:
            table.add_row("ID", str(entity_id))
            
        # Example structure information, if available
        if hasattr(entity, 'examples') and hasattr(entity.examples, 'example_structure'):
            structure = entity.examples.example_structure
            table.add_row("Concept Type", str(structure.concept_type))
            table.add_row("Component Types", ", ".join(str(t) for t in structure.component_types))
            if structure.input_arity is not None:
                table.add_row("Input Arity", str(structure.input_arity))
        
        if hasattr(entity, 'examples'):
            self._add_example_row(table, "Examples", entity.examples.get_examples())
            self._add_example_row(table, "Nonexamples", entity.examples.get_nonexamples())
        
        return CommandResult(
            success=True,
            message=None,  # The table carries all output
            data=table
        )

    def _add_example_row(self, table: Table, label: str, items) -> None:
        """Add a ``label`` row listing up to MAX_EXAMPLES items; skip the row if empty.

        Items are sorted by their displayed text rather than by ``str(item)``.
        The order is then deterministic and matches what the user sees, even
        if the item objects have no meaningful ``__str__``.
        """
        formatted = sorted(self._format_example(item) for item in items)
        if not formatted:
            return
        shown = formatted[:self.MAX_EXAMPLES]
        hidden = len(formatted) - len(shown)
        if hidden > 0:
            shown.append(f"... ({hidden} more)")
        table.add_row(label, "\n".join(shown))
    
    def get_completions(self, context: REPLContext, partial: str) -> List[str]:
        """Get completions for entity names.
        
        Args:
            context: The REPL context
            partial: Partial input to complete
            
        Returns:
            Names of concepts, then conjectures, that start with ``partial``.
        """
        completions = []
        for concept in context.get_concepts():
            if hasattr(concept, 'name') and concept.name.startswith(partial):
                completions.append(concept.name)
        for conjecture in context.get_conjectures():
            if hasattr(conjecture, 'name') and conjecture.name.startswith(partial):
                completions.append(conjecture.name)
        return completions
        
    def _format_example(self, example) -> str:
        """Format an example as its ``value`` if it has one, else as ``str(example)``."""
        if hasattr(example, 'value'):
            return str(example.value)
        return str(example)

class ComputeCommand(BaseCommand):
    """Compute the result of a concept with given arguments.
    
    Usage: compute <concept_name> <arg1> <arg2> ...
    
    Examples:
        compute successor 5
        compute addition 3 4
    """
    
    @property
    def name(self) -> str:
        """Get the command name."""
        return "compute"
    
    def execute(self, state: REPLContext, *args) -> CommandResult:
        """Execute the compute command.

        Arguments after the concept name are converted with ``int()`` when
        possible and otherwise passed through as strings.

        Args:
            state: The REPL context used to resolve the concept.
            *args: ``concept_name`` followed by its arguments.

        Returns:
            CommandResult with ``"Result: ..."`` on success. On failure it
            explains the problem: missing arguments, an unknown concept, no
            computational implementation, or an error raised while computing.
        """
        if not args:
            return CommandResult(False, "Usage: compute <concept_name> <arg1> <arg2> ...")
            
        concept_name = args[0]
        concept = state.get_entity_by_name(concept_name)
        
        if not concept:
            state.emit_event(Event(
                type=EventType.ERROR,
                data={'message': f"Unknown concept: {concept_name}"}
            ))
            return CommandResult(False, f"Unknown concept: {concept_name}")

        # get_entity_by_name is not restricted to concepts; guard in case the
        # resolved entity does not expose this method, so the REPL reports a
        # failure instead of raising AttributeError outside the try below.
        has_impl = getattr(concept, 'has_computational_implementation', None)
        if not callable(has_impl) or not has_impl():
            return CommandResult(False, f"No computational implementation available for {concept_name}")
            
        try:
            compute_args = []
            for arg in args[1:]:
                try:
                    compute_args.append(int(arg))
                except ValueError:
                    # Not an integer: pass the raw string through
                    compute_args.append(arg)
            
            result = concept.compute(*compute_args)
            return CommandResult(True, f"Result: {result}")
            
        except Exception as e:
            # Command boundary: surface computation errors instead of crashing the REPL
            return CommandResult(False, f"Error computing result: {str(e)}")
            
    def get_completions(self, state: REPLContext, partial: str) -> List[str]:
        """Get completions for concept names.

        Only the first token (the concept name) is completed, and only with
        concepts that have a computational implementation. Once a second token
        has begun (another word, or trailing whitespace after the name), the
        cursor is on an argument and no completions are offered. Leading
        whitespace in ``partial`` is ignored.
        """
        parts = partial.split()
        if len(parts) > 1 or (parts and partial[-1:].isspace()):
            return []  # No completions for arguments

        prefix = parts[0] if parts else ""
        return [
            concept.name
            for concept in state.get_concepts()
            if hasattr(concept, 'name')
            and concept.has_computational_implementation()
            and concept.name.startswith(prefix)
        ]