#!/usr/bin/env python3
"""
AgentOrchestrator: Manages agent invocations for RoboPhD system.
Works exclusively with three-artifact agents (agent.md, eval_instructions.md, tools/).
"""

import json
import subprocess
import shutil
import time
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from datetime import datetime
import hashlib

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent))
from config import MODEL_FALLBACKS, CLAUDE_CLI_MODEL_MAP
# Add grandparent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))
from utilities.claude_cli import find_claude_cli

class AgentOrchestrator:
    """
    Orchestrates database analysis agents in RoboPhD system.

    For each (iteration, agent, database) combination a workspace directory is
    populated with a three-artifact package (agent.md, eval_instructions.md,
    optional tools/) and a symlink to the database. Phase 1 then runs the
    Claude CLI inside that workspace; the agent's output is combined with the
    eval instructions and saved to output/system_prompt.txt.
    """

    # Name used in the Phase 1 prompt when agent.md has no usable `name:` field.
    DEFAULT_AGENT_NAME = "database analysis"

    def __init__(self, 
                 base_experiment_dir: Path,
                 analysis_model: str = 'opus-4.1',
                 claude_path: Optional[str] = None,
                 timeout_phase1: int = 900):  # 15 minutes default
        """
        Initialize the orchestrator.
        
        Args:
            base_experiment_dir: Base directory for the experiment
            analysis_model: Model to use for analysis (opus-4.1, sonnet-4, haiku-3.5)
            claude_path: Path to Claude CLI; auto-detected when not given
            timeout_phase1: Timeout for Phase 1 in seconds
        """
        self.base_dir = Path(base_experiment_dir)
        self.analysis_model = analysis_model
        # `or` short-circuits, so CLI discovery only runs when no path is supplied.
        self.claude_path = claude_path or find_claude_cli()
        self.timeout_phase1 = timeout_phase1
        # One dict per primary Phase 1 attempt; summarized by get_performance_summary().
        self.performance_log: List[Dict] = []
        # Context recorded by setup_workspace() and used for log prefixes / perf entries.
        self.active_agent_name: Optional[str] = None
        self.current_agent_id: Optional[str] = None
        self.current_database_name: Optional[str] = None

    def _log_with_context(self, message: str):
        """Print *message*, prefixed with a timestamp and agent/database context when both are set."""
        if self.current_agent_id and self.current_database_name:
            timestamp = datetime.now().strftime("%H:%M:%S")
            print(f"    [{timestamp}] {self.current_agent_id} | {self.current_database_name}: {message}")
        else:
            print(f"  {message}")

    def setup_workspace(self,
                       iteration: int,
                       database_name: str,
                       database_path: Path,
                       package_dir: Path,
                       agent_id: str) -> Path:
        """
        Create workspace with specific agent.
        
        Args:
            iteration: Iteration number
            database_name: Name of the database
            database_path: Path to the SQLite database (str also accepted)
            package_dir: Path to three-artifact package directory (str also accepted)
            agent_id: ID for tracking this agent
            
        Returns:
            Path to the configured workspace

        Raises:
            FileNotFoundError: If the package lacks agent.md or eval_instructions.md.
        """
        database_path = Path(database_path)
        package_dir = Path(package_dir)

        workspace = self.base_dir / f"iteration_{iteration:03d}" / f"agent_{agent_id}" / database_name
        workspace.mkdir(parents=True, exist_ok=True)

        self._setup_three_artifact_workspace(workspace, package_dir, database_path)

        # Agent name comes from the package directory; ids feed log prefixes.
        self.active_agent_name = package_dir.name
        self.current_agent_id = agent_id
        self.current_database_name = database_name

        # Symlink instead of copying to avoid duplicating large database files.
        db_dest = workspace / "database.sqlite"
        # Path.exists() follows symlinks, so a dangling link left by a previous
        # run reports False; without is_symlink() the symlink_to() call below
        # would raise FileExistsError.
        if db_dest.is_symlink() or db_dest.exists():
            db_dest.unlink()
        db_dest.symlink_to(database_path.absolute())

        (workspace / "output").mkdir(exist_ok=True)

        return workspace

    def _setup_three_artifact_workspace(self, workspace: Path, package_dir: Path, database_path: Path):
        """
        Copy a three-artifact package into *workspace*.

        agent.md goes to .claude/agents/, eval_instructions.md to the workspace
        root, and tools/ (if present) replaces any existing workspace tools/.
        A tool_output/ directory is also created.

        Args:
            workspace: The workspace directory
            package_dir: Directory containing the package
            database_path: Path to the database (not used here; kept for interface compatibility)

        Raises:
            FileNotFoundError: If agent.md or eval_instructions.md is missing.
        """
        agent_src = package_dir / "agent.md"
        eval_src = package_dir / "eval_instructions.md"
        # Validate both required artifacts before touching the workspace so a
        # broken package does not leave it half-populated.
        if not agent_src.exists():
            raise FileNotFoundError(f"agent.md not found: {agent_src}")
        if not eval_src.exists():
            raise FileNotFoundError(f"eval_instructions.md not found: {eval_src}")

        agents_dir = workspace / ".claude" / "agents"
        agents_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(agent_src, agents_dir / agent_src.name)
        shutil.copy2(eval_src, workspace / "eval_instructions.md")

        tools_src = package_dir / "tools"
        if tools_src.exists():
            tools_dest = workspace / "tools"
            # Replace rather than merge so files removed from the package don't linger.
            if tools_dest.exists():
                shutil.rmtree(tools_dest)
            shutil.copytree(tools_src, tools_dest)

        (workspace / "tool_output").mkdir(exist_ok=True)

    def _extract_agent_name(self, workspace: Path) -> str:
        """
        Return the `name:` value from the YAML frontmatter of the workspace's agent.md.

        The frontmatter is the text between an opening `---` line and the next
        line consisting only of `---`. Surrounding single or double quotes are
        stripped. Falls back to DEFAULT_AGENT_NAME when the file is missing or
        unreadable, has no frontmatter, or the name is absent or empty.
        """
        fallback_name = self.DEFAULT_AGENT_NAME
        agent_file = workspace / ".claude" / "agents" / "agent.md"
        if not agent_file.exists():
            return fallback_name

        try:
            content = agent_file.read_text()
        except (OSError, UnicodeDecodeError):
            return fallback_name

        lines = content.splitlines()
        if not lines or lines[0].rstrip() != '---':
            return fallback_name

        # Only a whole line of `---` closes the frontmatter; a `---` inside a
        # value (e.g. a description) must not end it early.
        closing = next(
            (i for i, line in enumerate(lines[1:], start=1) if line.rstrip() == '---'),
            None,
        )
        if closing is None:
            return fallback_name

        for line in lines[1:closing]:
            if line.startswith('name:'):
                name = line.split(':', 1)[1].strip()
                if len(name) >= 2 and name[0] == name[-1] and name[0] in ('"', "'"):
                    name = name[1:-1].strip()
                # An empty name would yield the prompt "Use your  agent".
                return name or fallback_name

        return fallback_name

    def _build_cli_command(self, model: str, prompt: str) -> List[str]:
        """Build the Claude CLI argv that runs *prompt* non-interactively with *model*."""
        cli_model = CLAUDE_CLI_MODEL_MAP.get(model, model)
        return [
            self.claude_path,
            "--model", cli_model,
            "--print", prompt,
            # Permission bypass for automation.
            "--permission-mode", "bypassPermissions",
        ]

    def _combine_with_eval_instructions(self, workspace: Path, agent_output: str) -> Optional[str]:
        """
        Append eval_instructions.md to *agent_output* and save it as output/system_prompt.txt.

        Returns:
            The combined system prompt, or None if eval_instructions.md is missing.
        """
        eval_instructions_file = workspace / "eval_instructions.md"
        if not eval_instructions_file.exists():
            return None
        combined_prompt = f"{agent_output}\n\n---\n\n{eval_instructions_file.read_text()}"
        (workspace / "output" / "system_prompt.txt").write_text(combined_prompt)
        return combined_prompt

    def run_phase1(self,
                   workspace: Path,
                   cache_key: Optional[str] = None) -> Tuple[bool, Optional[str]]:
        """
        Phase 1: Agent analyzes database and generates analysis.
        
        Args:
            workspace: Workspace directory prepared by setup_workspace()
            cache_key: When truthy and output/agent_output.txt already exists, the
                CLI is not invoked and the existing output is reused. The key's
                value itself is not inspected.
            
        Returns:
            Tuple of (success, system_prompt). On success, system_prompt is the agent
            output combined with eval_instructions.md, for both cached and fresh runs.
            On failure it is None.
        """
        output_file = workspace / "output" / "agent_output.txt"
        if cache_key and output_file.exists():
            self._log_with_context(f"ℹ️  Using cached output for {workspace.name}")
            # Return the same combined prompt a fresh run would, not the raw agent output.
            combined_prompt = self._combine_with_eval_instructions(workspace, output_file.read_text())
            if combined_prompt is None:
                self._log_with_context("⚠️  No eval_instructions.md found")
                return False, None
            return True, combined_prompt

        # Three-artifact: agent generates database-specific output with agent name
        agent_name = self._extract_agent_name(workspace)
        prompt = f"""Use your {agent_name} agent to analyze the database at ./database.sqlite. Verify that it has generated output at ./output/agent_output.txt.

In the event that the agent fails to generate the required output, read and follow the instructions in .claude/agents/agent.md to analyze the database. The agent file contains specific instructions for database analysis. Follow those instructions to analyze ./database.sqlite and save your output to ./output/agent_output.txt
"""
        cmd = self._build_cli_command(self.analysis_model, prompt)

        phase1_start = time.time()

        try:
            result = subprocess.run(
                cmd,
                cwd=str(workspace),
                capture_output=True,
                text=True,
                timeout=self.timeout_phase1
            )

            phase1_time = time.time() - phase1_start

            self.performance_log.append({
                'phase': 'phase1',
                'agent': self.active_agent_name,
                'database': self.current_database_name,
                'time': phase1_time,
                'success': result.returncode == 0
            })

            if result.returncode != 0:
                self._log_with_context(f"⚠️  Phase 1 failed (code {result.returncode})")
                if result.stderr:
                    error_preview = result.stderr[:500]
                    self._log_with_context(f"Error: {error_preview}")

                # Only a non-zero exit triggers the fallback model; timeouts do not.
                if self.analysis_model in MODEL_FALLBACKS:
                    fallback = MODEL_FALLBACKS[self.analysis_model]
                    if fallback:
                        self._log_with_context(f"Retrying with fallback model: {fallback}")
                        return self._run_phase1_with_fallback(workspace, fallback, prompt)

                return False, None

            if not output_file.exists():
                self._log_with_context("⚠️  Phase 1 completed but no output file created")
                return False, None

            combined_prompt = self._combine_with_eval_instructions(workspace, output_file.read_text())
            if combined_prompt is None:
                # Shouldn't happen in three-artifact mode
                self._log_with_context("⚠️  No eval_instructions.md found")
                return False, None

            self._log_with_context(f"✅ Phase 1 complete ({phase1_time:.1f}s)")
            return True, combined_prompt

        except subprocess.TimeoutExpired:
            phase1_time = time.time() - phase1_start
            self._log_with_context(f"⏱️  Phase 1 timeout after {phase1_time:.1f}s")

            self.performance_log.append({
                'phase': 'phase1',
                'agent': self.active_agent_name,
                'database': self.current_database_name,
                'time': phase1_time,
                'success': False,
                'error': 'timeout'
            })

            return False, None
        except Exception as e:
            self._log_with_context(f"❌ Phase 1 error: {e}")
            return False, None

    def _run_phase1_with_fallback(self, workspace: Path, fallback_model: str, 
                                  prompt: str) -> Tuple[bool, Optional[str]]:
        """
        Re-run Phase 1 with *fallback_model* after the primary model failed.

        Never raises. Every failure (timeout, non-zero exit, missing output or
        eval instructions, unexpected error) is logged and reported as (False, None).

        Returns:
            Tuple of (success, combined system prompt or None).
        """
        cmd = self._build_cli_command(fallback_model, prompt)

        try:
            result = subprocess.run(
                cmd,
                cwd=str(workspace),
                capture_output=True,
                text=True,
                timeout=self.timeout_phase1
            )

            if result.returncode != 0:
                self._log_with_context(
                    f"⚠️  Fallback model {fallback_model} failed (code {result.returncode})")
                return False, None

            output_file = workspace / "output" / "agent_output.txt"
            if not output_file.exists():
                self._log_with_context("⚠️  Fallback completed but no output file created")
                return False, None

            combined_prompt = self._combine_with_eval_instructions(workspace, output_file.read_text())
            if combined_prompt is None:
                self._log_with_context("⚠️  No eval_instructions.md found")
                return False, None

            self._log_with_context("✅ Phase 1 complete with fallback")
            return True, combined_prompt

        except subprocess.TimeoutExpired:
            self._log_with_context(
                f"⏱️  Fallback model {fallback_model} timed out after {self.timeout_phase1}s")
            return False, None
        except Exception as e:
            self._log_with_context(f"❌ Fallback model {fallback_model} error: {e}")
            return False, None

    def validate_agent_output(self, workspace: Path, min_length: int = 100) -> bool:
        """
        Validate that agent produced expected output.
        
        Args:
            workspace: Workspace directory
            min_length: Minimum number of characters output/system_prompt.txt must contain
            
        Returns:
            True if output/system_prompt.txt exists and has at least min_length characters
        """
        output_file = workspace / "output" / "system_prompt.txt"

        if not output_file.exists():
            self._log_with_context("❌ No system prompt generated")
            return False

        content = output_file.read_text()
        if len(content) < min_length:
            self._log_with_context(f"⚠️  System prompt too short ({len(content)} chars)")
            return False

        return True

    def get_performance_summary(self) -> Dict:
        """
        Summarize the recorded Phase 1 performance entries.

        Returns:
            {} when nothing has been recorded. Otherwise a dict with run counts,
            success rate, total/average time and the raw log entries.
        """
        if not self.performance_log:
            return {}

        runs = len(self.performance_log)
        total_time = sum(entry['time'] for entry in self.performance_log)
        success_count = sum(1 for entry in self.performance_log if entry['success'])

        return {
            'total_runs': runs,
            'successful_runs': success_count,
            'success_rate': success_count / runs,
            'total_time': total_time,
            'average_time': total_time / runs,
            'logs': self.performance_log,
        }