"""Mathematical Discovery REPL.

This module provides an interactive shell for mathematical discovery.
"""

import cmd
import os
import traceback
import sys
from typing import List, Dict, Any, Optional
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.syntax import Syntax
from rich.prompt import Prompt
from rich import box
from rich.text import Text
from frame.repl.core.engine import MathREPLCore
from frame.repl.core.state import StateSnapshot
from frame.repl.interface.protocols import CommandResult
from frame.repl.core.events import EventType
from frame.knowledge_base.entities import Concept, Conjecture
from frame.knowledge_base.knowledge_graph import KnowledgeGraph, NodeType
from frame.productions.base import ProductionRule
from frame.repl.core.exceptions import EntityNotFoundError, InvalidArgumentError

class RichCmd(cmd.Cmd):
    """``cmd.Cmd`` base class that routes all output through a rich ``Console``.

    Every helper hands its text to ``self.console.print``, so rich markup such as
    ``[red]...[/red]`` embedded in the text is rendered rather than shown verbatim.
    """

    def __init__(self):
        super().__init__()
        # One Console instance shared by all output helpers of this shell.
        self.console = Console()

    def write(self, text):
        """Print ``text`` as-is (rich markup inside it is still interpreted)."""
        console = self.console
        console.print(text)

    def print_error(self, text):
        """Print ``text`` wrapped in red markup."""
        self.console.print("[red]{}[/red]".format(text))

    def print_success(self, text):
        """Print ``text`` wrapped in green markup."""
        self.console.print("[green]{}[/green]".format(text))

class MathREPLShell(RichCmd):
    """Interactive shell for mathematical discovery."""
    
    def __init__(self, initial_graph: Optional[KnowledgeGraph] = None):
        """Create the core engine, hook up events, register commands and greet.

        Args:
            initial_graph: Optional knowledge graph passed to ``MathREPLCore`` as
                ``initial_graph``.

        Raises:
            RuntimeError: if the registry returns no ``inspect`` command after
                the built-in commands have been registered.
        """
        super().__init__()
        self.repl = MathREPLCore(initial_graph = initial_graph)

        # Route engine events back into this shell.
        self.repl.subscribe(EventType.NEW_ENTITY, self._handle_new_entity)
        self.repl.subscribe(EventType.ERROR, self._handle_error)

        # Command modules are imported here rather than at module level
        # (presumably to avoid an import cycle — confirm before hoisting).
        from frame.repl.commands.basic import (
            HelpCommand, ListCommand, RenameCommand,
            ClearCommand, VisualizeCommand, ExitCommand,
            SaveCommand, RemoveCommand
        )
        from frame.repl.commands.math import ApplyRuleCommand, InspectCommand, ComputeCommand

        registry = self.repl.registry
        # Registration order is significant only in that it is kept stable.
        builtin_commands = (
            HelpCommand, ListCommand, RenameCommand, ApplyRuleCommand,
            ClearCommand, VisualizeCommand, ExitCommand, InspectCommand,
            SaveCommand, ComputeCommand, RemoveCommand,
        )
        for command_cls in builtin_commands:
            registry.register(command_cls())

        # Direct handles to the commands this shell invokes itself.
        self._inspect_cmd = registry.get_command('inspect')
        self._compute_cmd = registry.get_command('compute')
        if not self._inspect_cmd:
            raise RuntimeError("Failed to initialize inspect command")

        # Greeting shown once at start-up.
        self.console.print("\nWelcome to the Mathematical Discovery REPL.", style="bold blue")
        self.console.print("Type 'help' for a list of commands.\n")

        self._update_prompt()

    def _update_prompt(self):
        """Set ``self.prompt``, the string ``cmd.Cmd`` shows before each command.

        ``cmd.Cmd`` hands the prompt to ``input()`` (or writes it to stdout), so
        it is emitted as plain text and rich styles cannot be applied to it.
        Calling ``str()`` on a styled ``rich.text.Text`` would only return its
        plain characters, so the plain literal is assigned directly.
        """
        self.prompt = "math> "

    def _handle_new_entity(self, event):
        """Receive ``EventType.NEW_ENTITY`` notifications; deliberately prints nothing.

        The handler is subscribed in ``__init__``. Commands that create entities
        (e.g. ``apply``) already print the engine's result message, so presumably
        no extra per-entity announcement is wanted here; revisit if that changes.
        ``event`` is not inspected.
        """
        # Intentionally a no-op: the former body only built unused locals
        # from event.data, which could raise without producing any output.
        return None
        
    def _handle_error(self, event):
        """Show the ``message`` carried by an ``EventType.ERROR`` event in red."""
        error_text = event.message
        self.print_error(error_text)

    def _group_rules_by_category(self, rules: List[ProductionRule]) -> Dict[str, List[ProductionRule]]:
        """Split ``rules`` into "Concept Rules" and "Conjecture Rules".

        A rule whose ``__module__`` contains ``'conjectures'`` is a conjecture
        rule; every other rule is a concept rule. Input order is preserved inside
        each group, the concept group is listed first, and empty groups are
        omitted from the returned mapping.
        """
        concept_rules = []
        conjecture_rules = []
        for rule in rules:
            bucket = conjecture_rules if 'conjectures' in rule.__module__ else concept_rules
            bucket.append(rule)

        grouped = {}
        if concept_rules:
            grouped["Concept Rules"] = concept_rules
        if conjecture_rules:
            grouped["Conjecture Rules"] = conjecture_rules
        return grouped

    def _get_rule_description(self, rule: ProductionRule) -> str:
        """Build a one-line summary of ``rule`` for tables and panels.

        The summary is the first non-blank line of the rule's docstring, or
        ``"No description available"``. If the rule has ``get_param_info()`` and
        it returns a non-empty mapping, the parameter names are appended as
        ``[Parameters: a (int), b (list of ints), c]``. Names containing
        ``indices`` are tagged ``list of ints``; names containing ``index`` (and
        not ``indices``) are tagged ``int``; other names are left bare.
        """
        fallback = "No description available"
        docstring = rule.__doc__
        if not docstring:
            return fallback

        stripped_lines = (raw.strip() for raw in docstring.split('\n'))
        summary = next((text for text in stripped_lines if text), fallback)

        param_map = rule.get_param_info() if hasattr(rule, 'get_param_info') else None
        if not param_map:
            return summary

        labels = []
        for param_name, _description in param_map.items():
            if 'indices' in param_name:
                labels.append(f"{param_name} (list of ints)")
            elif 'index' in param_name:
                labels.append(f"{param_name} (int)")
            else:
                labels.append(param_name)

        if not labels:
            return summary
        return f"{summary} [Parameters: {', '.join(labels)}]"

    def _get_rule_requirements(self, rule: ProductionRule) -> str:
        """Describe the inputs and parameters ``rule`` expects, as multi-line text.

        Input section (only if ``get_input_types()`` exists and is non-empty):
        a list whose first item is itself a list is treated as several
        alternative specifications and numbered ("Can be used in multiple
        ways:"); otherwise it is one specification ("Takes ..."). Each entry of a
        specification is an ``(entity_type, concept_type)`` pair, rendered as the
        entity type's ``__name__`` plus ``of type X`` (or ``X or Y`` for a list)
        when ``concept_type`` is truthy.

        Parameter section (only if ``get_param_info()`` exists and is non-empty):
        one ``- name: description`` line per parameter, followed by an example
        for names containing ``index`` (``0``) or ``indices`` (``[0, 1]``).
        """
        def describe_spec(spec):
            # Render one specification as "Type of type X, Type2, ...".
            labels = []
            for entity_type, concept_type in spec:
                label = entity_type.__name__
                if concept_type:
                    if isinstance(concept_type, list):
                        label += " of type " + " or ".join(str(t) for t in concept_type)
                    else:
                        label += f" of type {concept_type}"
                labels.append(label)
            return ", ".join(labels)

        lines = []

        input_types = rule.get_input_types() if hasattr(rule, 'get_input_types') else None
        if input_types:
            if isinstance(input_types[0], list):
                lines.append("Can be used in multiple ways:")
                for number, spec in enumerate(input_types, start=1):
                    lines.append(f"{number}. Takes {describe_spec(spec)}")
            else:
                lines.append(f"Takes {describe_spec(input_types)}")

        param_info = rule.get_param_info() if hasattr(rule, 'get_param_info') else None
        if param_info:
            lines.append("\nRequired parameters:")
            for param_name, param_desc in param_info.items():
                lines.append(f"- {param_name}: {param_desc}")
                if 'index' in param_name:
                    lines.append("  Example: 0")
                elif 'indices' in param_name:
                    lines.append("  Example: [0, 1]")

        return "\n".join(lines)

    def _format_example(self, example) -> str:
        """Render ``example`` as text, using its ``value`` attribute when it has one."""
        return str(getattr(example, 'value', example))

    def do_clear(self, arg: str) -> bool:
        """Clear the screen."""
        # Shell out to the platform's clear command ('cls' on Windows, 'clear'
        # elsewhere); ``arg`` is ignored.
        os.system('cls' if os.name == 'nt' else 'clear')
        # Re-print the start-up greeting (same text as __init__) so the cleared
        # screen is not empty.
        self.console.print("\nWelcome to the Mathematical Discovery REPL.", style="bold blue")
        self.console.print("Type 'help' for a list of commands.\n")
        # cmd.Cmd ends its loop when a do_* method returns a truthy value;
        # False keeps the REPL running.
        return False

    def do_inspect(self, arg: str) -> bool:
        """Inspect an entity in the knowledge graph.
        
        Usage: inspect <entity_name_or_id>
        """
        # The docstring above is the user-facing ``help inspect`` text.
        # Only the first whitespace-separated token is used as the target.
        tokens = arg.split()
        if not tokens:
            self.console.print("[red]Error: Entity name or ID is required[/red]")
            return False

        target = tokens[0]

        try:
            outcome = self._inspect_cmd.safe_execute(
                context=self.repl.get_state(),
                entity_name=target,
            )

            if outcome.success:
                if outcome.data:
                    self.console.print(outcome.data)
                if outcome.message:  # success message is optional
                    self.console.print(f"[green]{outcome.message}[/green]")
                return False

            # Failure: tailor the message to the error type reported.
            failure = outcome.error
            if isinstance(failure, EntityNotFoundError):
                self.console.print(f"[red]Error: Entity '{failure.entity_name}' not found[/red]")
            elif isinstance(failure, InvalidArgumentError):
                self.console.print(f"[red]Error: {failure.message}[/red]")
                if failure.arg_name:
                    self.console.print(f"[yellow]Hint: Check the '{failure.arg_name}' argument[/yellow]")
            else:
                self.console.print(f"[red]Error: {outcome.message}[/red]")
            return False

        except Exception as e:
            # Top-level boundary of the command: report and keep the REPL alive.
            self.console.print(f"[red]An unexpected error occurred in command 'inspect': {str(e)}[/red]")
            return False

    def do_rename(self, arg: str) -> None:
        """Rename an entity.
        
        Usage: rename <old_name_or_id> <new_name>
        
        You can use either the entity's name or its ID (e.g., 'concept_0', 'conjecture_1')
        """
        # The docstring above is the user-facing ``help rename`` text.
        # Delegate to the registered 'rename' command and echo its message,
        # green on success and red on failure.
        outcome = self.repl.execute_command(f"rename {arg}")
        colour = "green" if outcome.success else "red"
        self.console.print(f"[{colour}]{outcome.message}[/{colour}]")

    def do_help(self, arg: str) -> None:
        """Get help on commands or list available rules.
        
        Usage:
            help           - Show available commands
            help COMMAND  - Show help for a specific command
        """
        # The docstring above is itself shown by ``help help``.
        import inspect
        from rich.markup import escape

        tokens = arg.split()
        if not tokens:
            # No topic: summary table of the commands this shell advertises.
            table = Table(title="Available Commands", box=box.ROUNDED)
            table.add_column("Command", style="cyan")
            table.add_column("Description", style="yellow")

            commands = {
                "help": "Get help on commands or list available rules",
                "list": "List available concepts, rules, or conjectures",
                "apply": "Apply a production rule to create new concepts/conjectures",
                "inspect": "Show detailed information about an entity",
                "compute": "Test computational implementation with args",
                "rename": "Rename an entity",
                "remove": "Remove an entity",
                "visualize": "Create a visualization of the current knowledge graph",
                "clear": "Clear the screen",
                "save": "Save knowledge graph to file in data/graphs; use --load arg of frame.repl to restore",
                "exit": "Exit the REPL",
            }

            # ``name`` rather than ``cmd`` so the stdlib ``cmd`` module is not shadowed.
            for name, desc in commands.items():
                table.add_row(name, desc)

            self.console.print(table)
            return

        # Only the first word names the command, so "help list rules" shows "list".
        topic = tokens[0]
        func = getattr(self, 'do_' + topic, None)
        if func is None:
            self.console.print(f"[red]No such command: '{escape(topic)}'[/red]")
            return
        if not func.__doc__:
            self.console.print(f"[red]No help available for '{escape(topic)}'[/red]")
            return

        # cleandoc strips the docstring's source indentation. escape() keeps usage
        # placeholders such as "[output_path]" or "[param1=value1]" from being
        # parsed as rich markup tags, which would silently drop them.
        doc = escape(inspect.cleandoc(func.__doc__))
        self.console.print(Panel(doc, title=f"Help: {topic}", border_style="blue"))

    def do_list(self, arg: str) -> None:
        """List available concepts, rules, or conjectures.
        
        Usage:
            list concepts    - List all available concepts
            list rules      - List all available rules with descriptions
            list conjectures - List all available conjectures
        """
        # The docstring above is the user-facing ``help list`` text.
        if not arg:
            self.console.print("[yellow]Usage: list <concepts|rules|conjectures>[/yellow]")
            return

        state = self.repl.get_state()
        choice = arg.strip()
        limit = 50  # maximum rows in the concept / conjecture tables

        if choice == "concepts":
            concepts = state.get_concepts()
            if not concepts:
                self.console.print("[yellow]No concepts available[/yellow]")
                return

            table = Table(title="Available Concepts", box=box.ROUNDED)
            table.add_column("ID", style="magenta", justify="right")
            table.add_column("Name", style="cyan")
            table.add_column("Type", style="yellow")

            # Keep the highest-numbered concepts (tail of the ID-sorted list).
            shown = self._sorted_id_pairs(state, concepts, 'concept_')[-limit:]
            for concept_id, concept in shown:
                concept_type = concept.examples.example_structure.concept_type if hasattr(concept, 'examples') else "unknown type"
                table.add_row(concept_id, concept.name, str(concept_type))

            # Report what was actually rendered: entities without an ID are
            # skipped, so fewer than ``limit`` rows may be shown.
            if len(shown) < len(concepts):
                table.caption = f"Showing {len(shown)} of {len(concepts)} concepts"
                table.caption_style = "italic yellow"

            self.console.print(table)

        elif choice == "rules":
            rules = state.available_rules
            if not rules:
                self.console.print("[yellow]No rules available[/yellow]")
                return

            # One table per category (concept rules, conjecture rules).
            for category, category_rules in self._group_rules_by_category(rules).items():
                table = Table(title=category, box=box.ROUNDED)
                table.add_column("Rule", style="cyan")
                table.add_column("Description", style="yellow", max_width=60)

                for rule in category_rules:
                    table.add_row(rule.name, self._get_rule_description(rule))

                self.console.print(table)
                self.console.print()  # blank line between category tables

        elif choice == "conjectures":
            conjectures = state.get_conjectures()
            if not conjectures:
                self.console.print("[yellow]No conjectures available[/yellow]")
                return

            table = Table(title="Available Conjectures", box=box.ROUNDED)
            table.add_column("ID", style="magenta", justify="right")
            table.add_column("Name", style="cyan")

            # Keep the lowest-numbered conjectures (head of the ID-sorted list).
            shown = self._sorted_id_pairs(state, conjectures, 'conjecture_')[:limit]
            for conjecture_id, conjecture in shown:
                table.add_row(conjecture_id, conjecture.name)

            if len(shown) < len(conjectures):
                table.caption = f"Showing {len(shown)} of {len(conjectures)} conjectures"
                table.caption_style = "italic yellow"

            self.console.print(table)

        else:
            self.console.print(f"[red]Invalid argument: '{arg}'. Use 'concepts', 'rules', or 'conjectures'[/red]")

    @staticmethod
    def _entity_id_sort_key(entity_id, prefix):
        """Sort key for IDs like ``concept_12``: the integer after the first ``_``.

        IDs that are empty, lack ``prefix``, or carry no parseable number sort
        last (``float('inf')``) instead of raising.
        """
        if entity_id and entity_id.startswith(prefix):
            try:
                return int(entity_id.split('_')[1])
            except (IndexError, ValueError):
                pass
        return float('inf')

    def _sorted_id_pairs(self, state, entities, prefix):
        """Return ``(entity_id, entity)`` pairs sorted by the numeric part of the ID.

        Entities without a ``name`` attribute, or for which
        ``state.get_entity_id(name)`` is falsy, are left out. Ties keep their
        original order (``list.sort`` is stable).
        """
        pairs = []
        for entity in entities:
            if hasattr(entity, 'name'):
                entity_id = state.get_entity_id(entity.name)
                if entity_id:
                    pairs.append((entity_id, entity))
        pairs.sort(key=lambda pair: self._entity_id_sort_key(pair[0], prefix))
        return pairs

    def do_apply(self, arg: str) -> None:
        """Apply a production rule to create new concepts/conjectures.
        
        Usage:
        Interactive mode: apply
        Partial mode: apply <rule_name>  # Will prompt for remaining inputs
        Non-interactive mode: apply <rule_name> <input1> <input2> ... [param1=value1] [param2=value2] ...
        
        Examples:
        apply map_iteration successor
        apply match iterate_(iterate_(successor)) indices_to_match=[0,1]
        apply compose tau tau output_to_input_map={0:0}
        apply match  # Will prompt for inputs and parameters
        """
        # The docstring above is the user-facing ``help apply`` text.
        # Rich parses "[input1 ...]" / "[param1=...]" as markup tags and drops
        # them, so the literal brackets are escaped with a backslash.
        usage = (
            "[red]Usage: apply <rule_name> \\[input1 input2 ...] "
            "\\[param1=value1 param2=value2 ...][/red]"
        )

        state = self.repl.get_state()
        rules = state.available_rules

        # No arguments: fully interactive (pick rule, inputs and parameters).
        if not arg:
            self._apply_interactive_mode(state, rules)
            return

        parts = arg.split()
        if len(parts) < 1:
            self.console.print(usage)
            return

        rule_name = parts[0]
        if rule_name == "iter":
            rule_name = "map_iteration"  # command-line shorthand
        selected_rule = next((r for r in rules if r.name == rule_name), None)
        if not selected_rule:
            self.console.print(f"[red]Invalid rule: '{rule_name}'[/red]")
            return

        # Only a rule name: prompt for the remaining inputs and parameters.
        if len(parts) == 1:
            self._apply_interactive_mode(state, rules, selected_rule)
            return

        try:
            inputs = parts[1:]
            params = {}

            # Trailing "name=value" tokens are parameters; everything before is an input.
            while inputs and '=' in inputs[-1]:
                name, raw_value = inputs.pop().split('=', 1)
                try:
                    params[name] = self._parse_apply_param_value(raw_value)
                except ValueError as e:
                    self.console.print(f"[red]Invalid parameter value: {str(e)}[/red]")
                    return

            # Resolve inputs (names or IDs) to entity objects.
            input_entities = []
            for name_or_id in inputs:
                entity = state.get_entity_by_name(name_or_id)
                if entity is None:
                    self.console.print(f"[red]Entity not found: {name_or_id}[/red]")
                    return
                input_entities.append(entity)

            try:
                valid_params = selected_rule.get_valid_parameterizations(*input_entities)
                if not valid_params:
                    self.console.print("[red]Rule cannot be applied to these entities: No valid parameterizations found[/red]")
                    return

                # User-supplied parameters must agree with at least one valid
                # parameterization. Note(_; 5/4): 'forall' doesn't return all
                # valid params atm, so it is exempt from this check.
                if rule_name != "forall" and params and not any(
                    all(k in valid_param and valid_param[k] == v for k, v in params.items())
                    for valid_param in valid_params
                ):
                    self.console.print("[red]Invalid parameters for this rule[/red]")
                    return

                result = self.repl.execute_command(
                    self._build_apply_command_line(rule_name, inputs, params)
                )
                if hasattr(result, 'message'):
                    # Failures are shown in red (as do_rename does); results
                    # without a ``success`` flag keep the previous green output.
                    color = "green" if getattr(result, 'success', True) else "red"
                    self.console.print(f"[{color}]{result.message}[/{color}]")
            except Exception as e:
                self.console.print(f"[red]Error applying rule: {e}[/red]")

        except Exception as e:
            self.console.print(f"[red]Error parsing command: {e}[/red]")
            self.console.print(usage)

    @staticmethod
    def _parse_apply_param_value(value: str):
        """Convert the right-hand side of a ``name=value`` token.

        ``[0,1]`` or ``0,1`` -> list of ints; ``{0:1,2:3}`` -> dict of ints;
        all-digit text -> int; ``0->1`` -> single-entry dict; anything else is
        returned unchanged as a string.

        Raises:
            ValueError: when a list/dict/arrow form has a non-integer element or
                a malformed ``key:value`` / ``src->dst`` pair.
        """
        if value.startswith('[') and value.endswith(']'):
            return [int(x.strip()) for x in value.strip('[]').split(',')]
        if value.startswith('{') and value.endswith('}'):
            mapping = {}
            for pair in value.strip('{}').strip().split(','):
                k, v = pair.split(':')
                mapping[int(k.strip())] = int(v.strip())
            return mapping
        if ',' in value:
            return [int(x.strip()) for x in value.split(',')]
        if value.isdigit():
            return int(value)
        if '->' in value:
            src, dst = value.split('->')
            return {int(src.strip()): int(dst.strip())}
        return value

    @staticmethod
    def _build_apply_command_line(rule_name, inputs, params):
        """Serialize an apply request into ``apply <rule> <inputs...> name=value ...``.

        Lists are written as ``[a,b]`` and dicts as ``{k:v,...}`` (no spaces), so
        each parameter stays a single whitespace-free token.
        """
        cmd_parts = ["apply", rule_name] + list(inputs)
        for name, value in params.items():
            if isinstance(value, list):
                cmd_parts.append(f"{name}=[{','.join(str(x) for x in value)}]")
            elif isinstance(value, dict):
                dict_str = "{" + ",".join(f"{k}:{v}" for k, v in value.items()) + "}"
                cmd_parts.append(f"{name}={dict_str}")
            else:
                cmd_parts.append(f"{name}={value}")
        return " ".join(str(part) for part in cmd_parts)

    def _apply_interactive_mode(self, state, rules, selected_rule=None):
        """Interactively collect a rule, its inputs and parameters, then apply it.

        Flow:
          1. If ``selected_rule`` is None, show rules grouped by category and read
             a rule name (``iter`` is accepted as shorthand for ``map_iteration``,
             as on the ``apply`` command line).
          2. Show the rule's description and requirements.
          3. List the concepts and conjectures in ``state`` with their IDs.
          4. Read input entities (names or IDs) and resolve them.
          5. Ask the rule for valid parameterizations of those inputs; if any is
             non-empty, show up to three examples and read a value for each
             parameter of the first non-empty one.
          6. Serialize the request to an ``apply ...`` command line, run it via
             ``self.repl.execute_command`` and print the result message (green on
             success, red otherwise).

        Args:
            state: REPL state used for listing entities and resolving names/IDs.
            rules: Candidate production rules, matched by ``rule.name``.
            selected_rule: Rule to apply; when None the user is asked for one.

        Ctrl-D (EOF) or Ctrl-C at any prompt cancels the command instead of
        propagating out of the command loop.
        """
        def ask(prompt_text):
            # Read one stripped line; None signals that the user aborted input.
            try:
                return input(prompt_text).strip()
            except (EOFError, KeyboardInterrupt):
                self.console.print("\n[yellow]Apply cancelled[/yellow]")
                return None

        def parse_param(param_type, raw):
            # Convert raw text for one parameter; raises ValueError if malformed.
            if param_type == 'int':
                return int(raw)
            if param_type == 'list':
                # Accept both "[0,1]" and "0,1".
                return [int(x.strip()) for x in raw.strip('[]').strip().split(',')]
            if param_type == 'dict':
                if '->' in raw:
                    # Single mapping written as "0->0".
                    src, dst = raw.split('->')
                    return {int(src.strip()): int(dst.strip())}
                # One or more mappings written as "{0: 0, 1: 2}".
                mapping = {}
                for pair in raw.strip('{}').strip().split(','):
                    k, v = pair.split(':')
                    mapping[int(k.strip())] = int(v.strip())
                return mapping
            return raw

        if not selected_rule:
            # Step 1: compact rule overview, one row per category.
            table = Table(title="Available Rules", box=box.ROUNDED)
            table.add_column("Category", style="blue")
            table.add_column("Rules", style="cyan")
            for category, category_rules in self._group_rules_by_category(rules).items():
                table.add_row(category, ", ".join(rule.name for rule in category_rules))
            self.console.print(table)

            rule_name = ask("\nEnter rule name: ")
            if rule_name is None:
                return
            if rule_name == "iter":
                rule_name = "map_iteration"
            selected_rule = next((r for r in rules if r.name == rule_name), None)
            if not selected_rule:
                self.console.print(f"[red]Invalid rule: '{rule_name}'[/red]")
                return

        # Step 2: what the rule does and what it needs.
        self.console.print(Panel(
            self._get_rule_description(selected_rule),
            title=f"Rule: {selected_rule.name}",
            border_style="blue"
        ))
        self.console.print(Panel(
            self._get_rule_requirements(selected_rule),
            title="Requirements",
            border_style="blue"
        ))

        # Step 3: every entity the user could pass as input.
        table = Table(title="Available Entities", box=box.ROUNDED)
        table.add_column("ID", style="magenta", justify="right")
        table.add_column("Name", style="cyan")
        table.add_column("Type", style="yellow")

        for concept in state.get_concepts():
            if hasattr(concept, 'name'):
                concept_id = state.get_entity_id(concept.name)
                concept_type = concept.examples.example_structure.concept_type if hasattr(concept, 'examples') else "unknown type"
                table.add_row(concept_id if concept_id else "N/A", concept.name, f"concept of type {concept_type}")
        for conjecture in state.get_conjectures():
            if hasattr(conjecture, 'name'):
                conjecture_id = state.get_entity_id(conjecture.name)
                table.add_row(conjecture_id if conjecture_id else "N/A", conjecture.name, "conjecture")

        self.console.print(table)

        # Step 4: read and resolve the input entities.
        self.console.print("\n[blue]Enter input entities (space-separated). You can use either names or IDs (e.g., 'concept_0')[/blue]")
        raw_inputs = ask("Input entities: ")
        if raw_inputs is None:
            return
        inputs = raw_inputs.split()
        if not inputs:
            self.console.print("[red]No inputs provided[/red]")
            return

        input_entities = []
        for name_or_id in inputs:
            entity = state.get_entity_by_name(name_or_id)
            if entity is None:
                self.console.print(f"[red]Entity not found: {name_or_id}[/red]")
                return
            input_entities.append(entity)

        # Step 5: parameterizations valid for these inputs.
        try:
            valid_params = selected_rule.get_valid_parameterizations(*input_entities)
            if not valid_params:
                self.console.print("[red]Rule cannot be applied to these entities: No valid parameterizations found[/red]")
                return

            params = {}
            # The first non-empty parameterization decides which parameters to
            # ask for; the list may start with an empty dict even when later
            # entries need parameters.
            template = next((candidate for candidate in valid_params if candidate), None)
            if template is not None:
                self.console.print("\n[blue]This rule requires additional parameters for these inputs.[/blue]")

                # Up to 3 example parameterizations (empty ones are skipped).
                table = Table(title="Example Parameterizations", box=box.ROUNDED)
                table.add_column("Option", style="cyan")
                table.add_column("Parameters", style="yellow")
                for option, candidate in enumerate(valid_params[:3], 1):
                    if candidate:
                        table.add_row(str(option), str(candidate))
                self.console.print(table)

                for param_name in template.keys():
                    # Infer the expected value kind from the parameter name.
                    if param_name == 'output_to_input_map' or param_name == "indices_to_map":
                        example, param_type = " (example: {0: 0} or 0->0)", 'dict'
                    elif 'index' in param_name and 'indices' not in param_name:
                        example, param_type = " (example: 0)", 'int'
                    elif 'indices' in param_name:
                        example, param_type = " (example: [0, 1] or 0,1)", 'list'
                    else:
                        example, param_type = "", 'str'

                    raw_value = ask(f"Enter {param_name}{example}: ")
                    if raw_value is None:
                        return
                    try:
                        params[param_name] = parse_param(param_type, raw_value)
                    except ValueError as e:
                        self.console.print(f"[red]Invalid value for {param_name}: {str(e)}[/red]")
                        return

            # Step 6: serialize and run through the engine's 'apply' command.
            try:
                cmd_parts = ["apply", selected_rule.name] + inputs
                for name, value in params.items():
                    if isinstance(value, list):
                        # Brackets, no spaces: keeps the parameter one token.
                        cmd_parts.append(f"{name}=[{','.join(str(x) for x in value)}]")
                    elif isinstance(value, dict):
                        dict_str = "{" + ",".join(f"{k}:{v}" for k, v in value.items()) + "}"
                        cmd_parts.append(f"{name}={dict_str}")
                    else:
                        cmd_parts.append(f"{name}={value}")

                result = self.repl.execute_command(" ".join(str(part) for part in cmd_parts))
                if hasattr(result, 'message'):
                    # Failures in red; results without ``success`` stay green.
                    color = "green" if getattr(result, 'success', True) else "red"
                    self.console.print(f"[{color}]{result.message}[/{color}]")
            except Exception as e:
                self.console.print(f"[red]Error applying rule: {e}[/red]")

        except Exception as e:
            self.console.print(f"[red]Error checking parameterizations: {e}[/red]")

    def do_visualize(self, arg: str) -> None:
        """Visualize the current knowledge graph.
        
        Usage: visualize [output_path]
        If no output path is provided, saves to 'data/visualizations/knowledge_graph_TIMESTAMP'
        """
        from datetime import datetime

        # Pick the destination: the user-supplied path, or a timestamped
        # default under data/visualizations.
        if arg:
            output_path = arg
        else:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_path = os.path.join(
                "data", "visualizations", f"knowledge_graph_{timestamp}"
            )

        try:
            # Ensure the directory that will actually receive the file exists.
            # This is done inside the try so that filesystem errors (permission
            # denied, a path component that is a regular file, ...) are
            # reported to the user instead of escaping the command handler and
            # terminating the REPL loop.
            parent_dir = os.path.dirname(output_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Get current graph state and generate visualization
            state = self.repl.get_state()
            state.graph.visualize_construction_tree(output_path)
            # NOTE(review): assumes visualize_construction_tree appends ".pdf"
            # to the given path — confirm against KnowledgeGraph.
            self.console.print(f"[green]Logged visualization to: {output_path}.pdf[/green]")
        except Exception as e:
            self.print_error(f"Failed to create visualization: {e}")
            # Print stack trace for debugging
            self.print_error(traceback.format_exc())

    def do_save(self, arg: str) -> None:
        """Save the current REPL state to a file.
        Usage: save [filename]
        If no filename is provided, defaults to 'repl_state.dill'
        """
        # Forward verbatim to the core "save" command (an empty arg still
        # produces "save "), then echo its message coloured by outcome.
        outcome = self.repl.execute_command(f"save {arg}")
        colour = "green" if outcome.success else "red"
        self.console.print(f"[{colour}]{outcome.message}[/{colour}]")

    def do_remove(self, arg: str) -> None:
        """Remove a concept or conjecture from the knowledge graph.
        
        Usage: remove <entity_name_or_id>
        
        You can use either the entity's name or its ID (e.g., 'concept_0', 'conjecture_1')
        """
        # Name/ID resolution is delegated to the core "remove" command; this
        # wrapper only renders the result (green on success, red otherwise).
        response = self.repl.execute_command(f"remove {arg}")
        tag = "green" if response.success else "red"
        self.console.print(f"[{tag}]{response.message}[/{tag}]")

    def do_exit(self, arg: str) -> bool:
        """Exit the REPL."""
        # The argument is ignored. Returning a truthy value tells cmd.Cmd to
        # stop its command loop.
        farewell = "Goodbye!"
        print(farewell)
        return True
        
    def do_EOF(self, arg: str) -> bool:
        """Handle Ctrl+D gracefully."""
        # Leading newline presumably moves off the prompt line, since Ctrl+D
        # does not emit one itself. A truthy return ends cmd.Cmd's loop.
        farewell = "\nGoodbye!"
        print(farewell)
        return True
        
    def emptyline(self) -> None:
        """Treat a blank input line as a no-op.

        cmd.Cmd's default implementation re-executes the previous command;
        overriding it prevents accidental repeats.
        """
        return None
        
    def default(self, line: str) -> None:
        """Report an unrecognized command line and point the user at ``help``."""
        notice = f"Unknown command: '{line}'. Type 'help' for available commands."
        print(notice)

    def completedefault(self, text: str, line: str, begidx: int, endidx: int) -> List[str]:
        """Handle argument completion for commands.
        
        This method is called by cmd.Cmd when the command name doesn't have a
        complete_* method defined.
        """
        # Split the line into command and args
        parts = line[:endidx].split()
        if not parts:
            return []
            
        cmd_name = parts[0]
        command = self.repl.registry.get_command(cmd_name)
        if not command:
            return []
            
        # Get completions from the command
        state = self.repl.get_state()
        partial = text if text else ""
        return command.get_completions(state, partial)

    def completenames(self, text: str, *ignored) -> List[str]:
        """Complete command names."""
        state = self.repl.get_state()
        return self.repl.registry.get_completions(state, text)

    def on_entity_created(self, name: str, entity_id: str) -> None:
        """Entity-creation hook; intentionally does nothing.

        The result of the command that created the entity already reports it,
        so printing here would only duplicate that output.
        """
        return None

    def do_compute(self, arg: str) -> bool:
        """Compute the result of a concept with given arguments.
        
        Usage: compute <concept_name> <arg1> <arg2> ...
        
        Examples:
            compute successor 5
            compute addition 3 4
        """
        # Whitespace-separated tokens: the concept name followed by its
        # arguments, all passed to the compute command as strings.
        tokens = arg.split()
        if tokens:
            outcome = self._compute_cmd.safe_execute(self.repl.get_state(), *tokens)
            if not outcome.success:
                self.console.print(f"[red]Error: {outcome.message}[/red]")
            elif outcome.message:
                self.console.print(f"[green]{outcome.message}[/green]")
        else:
            self.console.print("[red]Error: Concept name and arguments are required[/red]")
        # Always falsy so cmd.Cmd keeps the REPL loop running.
        return False