"""LangGraph-style orchestrator for multi-agent coordination"""

import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import TimeoutError as FuturesTimeoutError
from dataclasses import dataclass
from typing import Dict, List, Optional, Any

from langchain_core.language_models import BaseLanguageModel
from langgraph.graph import StateGraph, END

from .state import MultiAgentState, AgentAnalysis
from ..agents.base_abcde import ABCDEAgent
from .semantic_query_analyzer import SemanticQueryAnalyzer, SemanticQueryAnalysis


@dataclass()
class OrchestratorConfig:
    """Tunable settings consumed by ``LangGraphOrchestrator``.

    Attributes:
        parallel_execution: When no image comparison is requested, run the
            activated agents on a thread pool (True) or one by one (False).
        max_parallel_agents: Upper bound on worker threads for parallel runs.
        agent_timeout: Seconds handed to ``as_completed`` while collecting
            parallel results (a budget for the whole batch, not per agent).
        priority_threshold: Threshold passed to the query analyzer's
            ``get_active_agents`` when choosing which agents to activate.
    """

    # --- execution strategy ---
    parallel_execution: bool = True
    max_parallel_agents: int = 3
    agent_timeout: int = 30

    # --- agent activation ---
    priority_threshold: float = 0.3


@dataclass()
class OrchestratorResult:
    """Outcome of one ``LangGraphOrchestrator.execute`` call.

    Attributes:
        query: The query text that was processed.
        query_analysis: The semantic analyzer's view of the query.
        synthesis_result: Whatever the final state held under
            ``synthesis_result`` (the synthesis agent's output), or ``None``.
        agent_results: Individual agent analyses keyed by agent name.
        execution_time: Wall-clock seconds spent in ``execute``.
        activated_agents: Names of the agents chosen for this query.
    """

    query: str
    query_analysis: SemanticQueryAnalysis
    # Kept as Any: the concrete SynthesisResult type lives with the synthesis agent.
    synthesis_result: Any
    agent_results: Dict[str, AgentAnalysis]
    execution_time: float
    activated_agents: List[str]


class LangGraphOrchestrator:
    """Orchestrator that coordinates the multi-agent workflow.

    The compiled LangGraph pipeline is ``analyze_query`` -> ``execute_agents``
    -> ``synthesize`` -> ``END``. Every node reads and writes the shared
    ``MultiAgentState`` mapping and returns it.
    """

    def __init__(
        self,
        llm: BaseLanguageModel,
        agents: Dict[str, ABCDEAgent],
        config: Optional[OrchestratorConfig] = None
    ):
        """Create the orchestrator and compile its workflow.

        Args:
            llm: Language model handed to the ``SemanticQueryAnalyzer``.
            agents: Agents keyed by name. Agents selected by the query analyzer
                only run if they are present here; the entry named
                ``'synthesis'`` (if any) is run by the synthesis node.
            config: Orchestration settings; ``OrchestratorConfig()`` when omitted.
        """
        self.llm = llm
        self.agents = agents
        self.config = config or OrchestratorConfig()

        # Chooses active agents, their tasks and their execution modes.
        self.query_analyzer = SemanticQueryAnalyzer(llm)

        # Compiled graph; ``execute`` calls ``invoke`` on it.
        self.workflow = self._build_workflow()

    def _build_workflow(self) -> StateGraph:
        """Build and compile the linear LangGraph workflow.

        Returns:
            The result of ``StateGraph.compile()``, i.e. the compiled graph
            (not the uncompiled ``StateGraph`` the annotation names).
        """
        workflow = StateGraph(MultiAgentState)

        # Nodes
        workflow.add_node("analyze_query", self._analyze_query_node)
        workflow.add_node("execute_agents", self._execute_agents_node)
        workflow.add_node("synthesize", self._synthesize_node)

        # Edges: strictly linear, no branching.
        workflow.add_edge("analyze_query", "execute_agents")
        workflow.add_edge("execute_agents", "synthesize")
        workflow.add_edge("synthesize", END)

        workflow.set_entry_point("analyze_query")

        return workflow.compile()

    def _analyze_query_node(self, state: MultiAgentState) -> MultiAgentState:
        """Analyze the query to determine agent activation and priorities.

        Writes ``query_analysis``, ``agent_tasks``, ``active_agents`` (limited
        to agents registered with this orchestrator), ``agent_execution_modes``
        and ``need_comparison`` into ``state`` and returns it.
        """
        print(f"\n{'='*60}")
        print("ORCHESTRATOR: Analyzing query...")
        print(f"Query: {state['query']}")

        start_time = time.time()

        analysis = self.query_analyzer.analyze(state['query'])

        state['query_analysis'] = analysis
        state['agent_tasks'] = analysis.agent_requirements

        # Determine active agents based on the configured priority threshold.
        active_agents = self.query_analyzer.get_active_agents(
            analysis,
            threshold=self.config.priority_threshold
        )

        # The analyzer may name agents we were not given; drop those.
        state['active_agents'] = [
            agent for agent in active_agents
            if agent in self.agents
        ]

        # Per-agent execution modes, consumed by the execution paths below.
        state['agent_execution_modes'] = {
            agent: self.query_analyzer.get_agent_execution_mode(analysis, agent)
            for agent in state['active_agents']
        }

        state['need_comparison'] = analysis.requires_comparison

        elapsed = time.time() - start_time
        print(f"\nQuery Analysis ({elapsed:.2f}s):")
        print(f"  Intent: {analysis.query_intent.value}")
        print(f"  Active Agents: {state['active_agents']}")
        print(f"  Analysis Focus: {analysis.analysis_focus}")

        return state

    def _execute_agents_node(self, state: MultiAgentState) -> MultiAgentState:
        """Execute the active agents and record their analyses in ``state``.

        Comparison mode (``need_comparison`` set and a ``prior_image_path``
        present) runs every agent on both images. Otherwise agents run in
        parallel or sequentially according to ``config.parallel_execution``.

        Writes ``agent_results``, ``completed_agents`` and one
        ``"<agent>_analysis"`` key per agent that produced a truthy result.
        """
        print(f"\n{'='*60}")
        print("ORCHESTRATOR: Executing agents...")

        start_time = time.time()

        active_agents = state.get('active_agents', [])

        if not active_agents:
            print("No agents to execute")
            state['agent_results'] = {}
            return state

        # A comparison is only possible when a prior image was actually supplied.
        needs_comparison = bool(
            state.get('need_comparison', False) and state.get('prior_image_path')
        )

        if needs_comparison:
            print("Comparison mode: Processing both current and prior images")
            agent_results = self._execute_comparison(state, active_agents)
        elif self.config.parallel_execution:
            print(f"Executing {len(active_agents)} agents in parallel...")
            agent_results = self._execute_parallel(state, active_agents)
        else:
            print(f"Executing {len(active_agents)} agents sequentially...")
            agent_results = self._execute_sequential(state, active_agents)

        # Mirror each result under its own "<agent>_analysis" key; both
        # _synthesize_node and execute() look results up by that key.
        for name, result in agent_results.items():
            if result:
                state[f"{name}_analysis"] = result

        state['agent_results'] = agent_results
        state['completed_agents'] = list(agent_results.keys())

        elapsed = time.time() - start_time
        print(f"\nAgent execution completed in {elapsed:.2f}s")

        return state

    def _execute_sequential(self, state: MultiAgentState, agents_to_run: List[str]) -> Dict[str, AgentAnalysis]:
        """Execute agents one after another on the shared workflow state.

        Unlike the parallel path, ``state`` itself (not a copy) is handed to
        each agent. As a result, ``current_agent``, ``agent_tasks`` and
        ``execution_mode`` are overwritten for every agent, and the last
        agent's values remain afterwards. Errors are logged and that agent is
        skipped, so one failure does not abort the run.

        Returns:
            Analyses keyed by agent name, for agents whose returned state
            contains ``"<agent>_analysis"``.
        """
        agent_results = {}

        for agent_name in agents_to_run:
            print(f"Running {agent_name}...")
            start_time = time.time()

            if agent_name in self.agents:
                try:
                    state['current_agent'] = agent_name

                    # Narrow the task list to this agent's tasks.
                    if state.get('query_analysis'):
                        state['agent_tasks'] = self.query_analyzer.get_agent_tasks(
                            state['query_analysis'],
                            agent_name
                        )

                    execution_mode = state.get('agent_execution_modes', {}).get(agent_name, 'function_calling')
                    state['execution_mode'] = execution_mode
                    print(f"  Using execution mode: {execution_mode}")

                    # Set agent mode if it supports dynamic mode setting
                    if hasattr(self.agents[agent_name], 'set_mode'):
                        self.agents[agent_name].set_mode(execution_mode)

                    updated_state = self.agents[agent_name].analyze(state)

                    analysis_key = f"{agent_name}_analysis"
                    if analysis_key in updated_state:
                        agent_results[agent_name] = updated_state[analysis_key]

                    elapsed = time.time() - start_time
                    print(f"✓ Completed {agent_name} in {elapsed:.2f}s")

                except Exception as e:
                    # Best-effort: log and continue with the remaining agents.
                    elapsed = time.time() - start_time
                    print(f"✗ Error in {agent_name} after {elapsed:.2f}s: {e}")

        return agent_results

    def _execute_comparison(self, state: MultiAgentState, agents_to_run: List[str]) -> Dict[str, AgentAnalysis]:
        """Execute agents for comparison - process both current and prior images.

        Each registered agent runs twice, sequentially. The first run uses a
        copy of ``state`` with ``comparison_mode='current'``. The second uses a
        copy whose ``image_path`` is replaced by ``prior_image_path``, with
        ``comparison_mode='prior'``. The two analyses are merged by
        ``_combine_comparison_results``. If any step fails, whichever of THIS
        agent's single-image results exists is used instead (current first).

        Returns:
            Combined (or fallback) analyses keyed by agent name.
        """
        agent_results = {}

        print(f"Running comparison analysis on {len(agents_to_run)} agents")

        for agent_name in agents_to_run:
            if agent_name not in self.agents:
                continue

            print(f"\n  Processing {agent_name} for comparison...")
            start_time = time.time()
            agent = self.agents[agent_name]

            # Reset on every iteration so the fallback in the except-branch can
            # never hit an unbound name or reuse a previous agent's result.
            current_result = None
            prior_result = None

            try:
                current_state = state.copy()
                current_state['current_agent'] = agent_name
                current_state['comparison_mode'] = 'current'

                prior_state = state.copy()
                prior_state['image_path'] = state['prior_image_path']  # Use prior image
                prior_state['current_agent'] = agent_name
                prior_state['comparison_mode'] = 'prior'

                if state.get('query_analysis'):
                    tasks = self.query_analyzer.get_agent_tasks(state['query_analysis'], agent_name)
                    current_state['agent_tasks'] = tasks
                    prior_state['agent_tasks'] = tasks

                # NOTE(review): this default differs from the other paths
                # ('function_calling'). It only applies when _analyze_query_node
                # recorded no mode for this agent.
                execution_mode = state.get('agent_execution_modes', {}).get(agent_name, 'plan_execute')
                current_state['execution_mode'] = execution_mode
                prior_state['execution_mode'] = execution_mode

                # Apply the mode as the sequential/parallel paths do, so the agent
                # does not run with a mode left over from an earlier invocation.
                if hasattr(agent, 'set_mode'):
                    agent.set_mode(execution_mode)

                print("    Analyzing current image...")
                current_result = agent.analyze(current_state).get(f"{agent_name}_analysis")

                print("    Analyzing prior image...")
                prior_result = agent.analyze(prior_state).get(f"{agent_name}_analysis")

                # query_analysis is a SemanticQueryAnalysis object, not a dict
                comparison_focus = []
                if state.get('query_analysis'):
                    comparison_focus = getattr(state['query_analysis'], 'comparison_focus', []) or []

                agent_results[agent_name] = self._combine_comparison_results(
                    agent_name, current_result, prior_result,
                    comparison_focus
                )

                elapsed = time.time() - start_time
                print(f"  ✓ {agent_name} comparison completed in {elapsed:.2f}s")

            except Exception as e:
                print(f"  ✗ Error in {agent_name} comparison: {e}")
                # Still provide an individual result even if the comparison failed.
                if current_result:
                    agent_results[agent_name] = current_result
                elif prior_result:
                    agent_results[agent_name] = prior_result

        return agent_results

    @staticmethod
    def _read_field(result: Any, field_name: str, default: Any) -> Any:
        """Read ``field_name`` from an attribute-style or dict-style agent result.

        Returns ``default`` when ``result`` is falsy or has no such attribute/key.
        """
        if not result:
            return default
        if hasattr(result, field_name):
            return getattr(result, field_name)
        if isinstance(result, dict) and field_name in result:
            return result[field_name]
        return default

    @staticmethod
    def _tag_findings(findings: Any, temporal_context: str) -> List[Any]:
        """Return copies of ``findings`` annotated with their temporal context.

        Dict findings are shallow-copied so the agent's own result is not
        mutated. Any other finding is wrapped as ``{'finding': <original>}``
        because it cannot take item assignment.
        """
        is_current = temporal_context == 'current'
        tagged = []
        for finding in findings or []:
            entry = finding.copy() if isinstance(finding, dict) else {'finding': finding}
            entry['temporal_context'] = temporal_context
            entry['comparison_data'] = {
                'is_current': is_current,
                'image_type': temporal_context
            }
            tagged.append(entry)
        return tagged

    def _combine_comparison_results(
        self,
        agent_name: str,
        current_result: AgentAnalysis,
        prior_result: AgentAnalysis,
        comparison_focus: List[str]
    ) -> AgentAnalysis:
        """Combine current and prior analyses for comparison - general implementation.

        No clinical interpretation happens here. The raw findings of both images
        are packaged for the synthesis agent, along with a flat ``findings``
        list whose entries carry a ``temporal_context`` of ``'current'`` or
        ``'prior'``.

        Args:
            agent_name: Name of the agent that produced both analyses.
            current_result: Analysis of the current image (object or dict; may be None).
            prior_result: Analysis of the prior image (object or dict; may be None).
            comparison_focus: Aspects the query asked to compare.

        Returns:
            A plain dict (not an ``AgentAnalysis`` instance, despite the
            annotation). ``plan_executed``/``react_steps`` are taken from the
            current analysis only.
        """
        # ``or []`` guards against results whose ``findings`` is None.
        current_findings = self._read_field(current_result, 'findings', None) or []
        prior_findings = self._read_field(prior_result, 'findings', None) or []

        # Structured comparison data; the synthesizer's LLM interprets the changes.
        comparison_data = {
            'current_findings': current_findings,
            'prior_findings': prior_findings,
            'comparison_focus': comparison_focus
        }

        # Flat list: current findings first, then prior, each tagged with its origin.
        all_findings = (
            self._tag_findings(current_findings, 'current')
            + self._tag_findings(prior_findings, 'prior')
        )

        plan_executed = self._read_field(current_result, 'plan_executed', [])
        react_steps = self._read_field(current_result, 'react_steps', [])

        return {
            'agent_name': agent_name,
            'findings': all_findings,  # All findings with temporal context
            'comparison_data': comparison_data,  # Structured comparison data
            'plan_executed': plan_executed,
            'react_steps': react_steps,
            'visual_cot_triggered': False,
            'confidence_level': "medium",
            'needs_human_review': True,  # Comparisons often need review
            'comparison_performed': True,
            'current_analysis': current_result,
            'prior_analysis': prior_result
        }

    def _execute_parallel(self, state: MultiAgentState, agents_to_run: List[str]) -> Dict[str, AgentAnalysis]:
        """Execute agents in parallel with timing.

        Each registered agent gets its own shallow copy of ``state`` and runs on
        a thread pool of at most ``config.max_parallel_agents`` workers.
        Per-agent exceptions are logged and that agent is skipped.

        ``config.agent_timeout`` is passed to ``as_completed``, so it bounds the
        wait for the whole batch rather than each agent. Queued agents only
        start when a worker frees up. When the timeout expires:
        - results collected so far are kept;
        - agents that have not started are cancelled;
        - the method returns without waiting for agents still running, whose
          results are discarded.

        Returns:
            Truthy analyses keyed by agent name.
        """
        agent_results: Dict[str, AgentAnalysis] = {}

        runnable = [name for name in agents_to_run if name in self.agents]
        if not runnable:
            # Nothing to submit (and ThreadPoolExecutor rejects max_workers=0).
            return agent_results

        def run_agent(agent_name: str, agent_state: MultiAgentState):
            """Run one agent on its private state copy; return (name, analysis or None)."""
            start_time = time.time()
            try:
                agent_state['current_agent'] = agent_name

                if agent_state.get('query_analysis'):
                    agent_state['agent_tasks'] = self.query_analyzer.get_agent_tasks(
                        agent_state['query_analysis'],
                        agent_name
                    )

                execution_mode = agent_state.get('agent_execution_modes', {}).get(agent_name, 'function_calling')
                agent_state['execution_mode'] = execution_mode
                print(f"  {agent_name} using execution mode: {execution_mode}")

                # Set agent mode if it supports dynamic mode setting
                if hasattr(self.agents[agent_name], 'set_mode'):
                    self.agents[agent_name].set_mode(execution_mode)

                updated_state = self.agents[agent_name].analyze(agent_state)

                elapsed = time.time() - start_time
                print(f"✓ {agent_name} completed in {elapsed:.2f}s")

                return agent_name, updated_state.get(f"{agent_name}_analysis")

            except Exception as e:
                elapsed = time.time() - start_time
                print(f"✗ {agent_name} failed after {elapsed:.2f}s: {e}")
                return agent_name, None

        # Clamp to >= 1: ThreadPoolExecutor raises ValueError for max_workers <= 0.
        max_workers = max(1, min(len(runnable), self.config.max_parallel_agents))
        executor = ThreadPoolExecutor(max_workers=max_workers)
        timed_out = False
        try:
            futures = {
                executor.submit(run_agent, agent_name, state.copy()): agent_name
                for agent_name in runnable
            }

            try:
                for future in as_completed(futures, timeout=self.config.agent_timeout):
                    agent_name, result = future.result()
                    if result:
                        agent_results[agent_name] = result
            except FuturesTimeoutError:
                timed_out = True
                unfinished = [name for fut, name in futures.items() if not fut.done()]
                for fut in futures:
                    fut.cancel()  # only succeeds for agents that have not started yet
                print(
                    f"✗ Timed out after {self.config.agent_timeout}s waiting for: "
                    f"{', '.join(unfinished)}"
                )
        finally:
            # Deliberately not a ``with`` block: its exit always waits for running
            # threads, which would make the timeout ineffective.
            executor.shutdown(wait=not timed_out)

        return agent_results

    def _synthesize_node(self, state: MultiAgentState) -> MultiAgentState:
        """Synthesis node - always runs after agents complete.

        Warns about any completed agent whose ``"<agent>_analysis"`` key is
        missing. Then, if a ``'synthesis'`` agent is registered, it runs and its
        returned state becomes the new state. Otherwise a placeholder
        ``final_report`` is written.
        """
        print(f"\n{'='*60}")
        print("ORCHESTRATOR: Running synthesis agent...")

        start_time = time.time()

        # _execute_agents_node should already have stored each analysis in state.
        for agent_name in state.get('completed_agents', []):
            analysis_key = f"{agent_name}_analysis"
            if analysis_key not in state:
                print(f"  WARNING: {analysis_key} not found in state - synthesis may not have access to findings")

        if 'synthesis' in self.agents:
            state = self.agents['synthesis'].analyze(state)
        else:
            print("Warning: No synthesis agent available")
            state['final_report'] = "Analysis complete but no synthesis available."

        elapsed = time.time() - start_time
        print(f"\nSynthesis completed in {elapsed:.2f}s")

        return state

    def execute(
        self,
        query: str,
        image_path: str,
        prior_image_path: Optional[str] = None
    ) -> OrchestratorResult:
        """Execute the complete orchestration workflow.

        Args:
            query: Query text; drives agent selection via the query analyzer.
            image_path: Stored in the initial state as ``image_path`` for the agents.
            prior_image_path: Optional earlier image. It is used only when the
                query analysis also sets ``requires_comparison``.

        Returns:
            An ``OrchestratorResult`` containing:
            - the query analysis;
            - the final state's ``synthesis_result`` (``None`` if absent);
            - the per-agent analyses of completed agents;
            - total wall-clock time and the activated agent names.
        """
        start_time = time.time()

        # NOTE: 'execution_mode' here is "parallel"/"sequential". The agent
        # execution paths also write a per-agent mode under the same key; the
        # sequential path does so on the shared state.
        initial_state = MultiAgentState(
            image_path=image_path,
            query=query,
            prior_image_path=prior_image_path,
            messages=[],
            current_step="started",
            completed_agents=[],
            active_agents=[],
            need_comparison=False,
            execution_mode="parallel" if self.config.parallel_execution else "sequential"
        )

        final_state = self.workflow.invoke(initial_state)

        agent_results = {
            agent_name: final_state[f"{agent_name}_analysis"]
            for agent_name in final_state.get('completed_agents', [])
            if f"{agent_name}_analysis" in final_state
        }

        execution_time = time.time() - start_time

        return OrchestratorResult(
            query=query,
            query_analysis=final_state['query_analysis'],
            synthesis_result=final_state.get('synthesis_result'),
            agent_results=agent_results,
            execution_time=execution_time,
            activated_agents=final_state.get('active_agents', [])
        )