"""Flexible workflow with query-driven agent selection and execution modes"""

from typing import Dict, Literal, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from langgraph.graph import StateGraph, END
from langchain_core.messages import SystemMessage

from ..state import MultiAgentState
from .query_analyzer import QueryAnalyzer
from ...agents.base_abcde import ABCDEAgent
from ....tools import DicomProcessorTool


def extract_metadata(state: MultiAgentState) -> MultiAgentState:
    """Extract metadata directly in orchestrator (not a separate agent).

    Ensures the accumulator lists (``messages``, ``completed_agents``,
    ``visualizations``) exist. If ``state["image_path"]`` is a DICOM file
    (``.dcm``, any case), runs ``DicomProcessorTool`` on it and stores:

    * ``dicom_metadata`` -- the raw tool result,
    * ``image_path`` -- replaced by the tool's ``image_path`` when present,
    * ``view_position`` -- ``ViewPosition`` tag, ``"UNKNOWN"`` if absent,
    * ``pixel_spacing`` -- ``PixelSpacing`` as a 2-tuple of floats,
      defaulting to ``(1.0, 1.0)``.

    Processing errors are not raised. They are printed and appended to
    ``messages`` as a ``SystemMessage``.

    Args:
        state: Workflow state; must contain ``image_path`` (a str).

    Returns:
        The same ``state`` object, mutated in place, with ``current_step``
        set to ``"metadata_complete"``.
    """
    # Downstream nodes append to these lists unconditionally.
    for list_key in ("messages", "completed_agents", "visualizations"):
        if list_key not in state:
            state[list_key] = []

    source_path = state["image_path"]
    # DICOM files are frequently named with an upper-case extension (".DCM").
    if source_path.lower().endswith(".dcm"):
        try:
            dicom_tool = DicomProcessorTool()
            # NOTE(review): assumes the tool returns a dict-like mapping of
            # DICOM tags plus an optional "image_path"; confirm against the tool.
            result = dicom_tool.invoke({"dicom_path": source_path})

            # Store the raw metadata and switch to the converted image first,
            # so a malformed optional tag below cannot leave the agents pointed
            # at the .dcm file.
            state["dicom_metadata"] = result
            if "image_path" in result:
                state["image_path"] = result["image_path"]

            state["view_position"] = result.get("ViewPosition", "UNKNOWN")
            # Warn about non-frontal views
            if state["view_position"] not in ("PA", "AP"):
                state["messages"].append(
                    SystemMessage(content=f"⚠️ Non-frontal view: {state['view_position']}. Analysis accuracy may be reduced.")
                )

            # Parsed last: the only step here that commonly fails on odd input.
            pixel_spacing = result.get("PixelSpacing", [1.0, 1.0])
            state["pixel_spacing"] = (float(pixel_spacing[0]), float(pixel_spacing[1]))
        except Exception as e:
            # Best-effort: report the failure but let the workflow continue.
            print(f"Error extracting DICOM metadata: {e}")
            state["messages"].append(
                SystemMessage(content=f"⚠️ Error processing DICOM file: {e}")
            )

    state["current_step"] = "metadata_complete"
    return state


def create_flexible_workflow(
    agents: Dict[str, ABCDEAgent],
    query_analyzer: QueryAnalyzer,
    execution_mode: Literal["parallel", "sequential"] = "parallel",
    max_workers: int = 3,
) -> StateGraph:
    """Create workflow with query-driven agent selection and execution modes.

    Graph shape: ``analyze_query`` -> ``execute_abcde`` -> (``comparison`` when
    ``need_comparison`` and ``prior_image_path`` are both truthy) ->
    ``synthesis_report`` -> END.

    Args:
        agents: Mapping from agent name ("airway", "breathing", "cardiac",
            "diaphragm", "everything") to an agent whose ``analyze(state)``
            is called with a shallow copy of the state and returns a state
            that may contain ``"<name>_analysis"``.
        query_analyzer: Object whose ``analyze_query(query)`` returns a dict
            of ``need_*`` flags used to select agents.
        execution_mode: ``"parallel"`` runs the selected agents in a thread
            pool; any other value runs them one after another.
        max_workers: Thread-pool size in parallel mode. Caps concurrent agent
            calls (the original motivation was rate-limit protection).

    Returns:
        The compiled graph from ``workflow.compile()``. The annotation says
        ``StateGraph``, but the value is the compiled form.
    """
    # Canonical ABCDE order: used both for a full review and to keep the
    # agent selection order stable.
    abcde_order = ("airway", "breathing", "cardiac", "diaphragm", "everything")

    workflow = StateGraph(MultiAgentState)

    def analyze_and_plan(state: MultiAgentState) -> MultiAgentState:
        """Run metadata extraction, then pick which ABCDE agents to run."""
        state = extract_metadata(state)

        agent_needs = query_analyzer.analyze_query(state["query"])

        if agent_needs.get("need_full", False):
            active_agents = list(abcde_order)
        else:
            # Each agent is selected by its own "need_<name>" flag.
            active_agents = [
                name for name in abcde_order
                if agent_needs.get(f"need_{name}", False)
            ]

        state["active_agents"] = active_agents
        state["need_comparison"] = agent_needs.get("need_comparison", False)
        state["execution_mode"] = execution_mode

        print(f"Query analysis result: {agent_needs}")
        print(f"Active agents: {active_agents}")

        return state

    workflow.add_node("analyze_query", analyze_and_plan)

    def _merge_outcome(state, agent_name, get_agent_state) -> None:
        """Merge one agent's ``<name>_analysis`` into ``state``; record None on failure.

        ``get_agent_state`` is a zero-argument callable that either runs the
        agent or fetches a finished future's result. Failures on both paths
        are therefore handled the same way.
        """
        analysis_key = f"{agent_name}_analysis"
        try:
            agent_state = get_agent_state()
            # Only the agent's own analysis is merged back.
            if analysis_key in agent_state:
                state[analysis_key] = agent_state[analysis_key]
                state.setdefault("completed_agents", []).append(agent_name)
                print(f"Completed {agent_name} analysis")
            else:
                print(f"{agent_name} returned no '{analysis_key}'; nothing merged")
        except Exception as e:
            print(f"Error in {agent_name}: {e}")
            state[analysis_key] = None

    def execute_abcde(state: MultiAgentState) -> MultiAgentState:
        """Execute the selected ABCDE agents in the configured mode."""
        active_agents = state.get("active_agents", [])

        if not active_agents:
            print("No agents selected for execution")
            return state

        runnable = []
        for agent_name in active_agents:
            if agent_name in agents:
                runnable.append(agent_name)
            else:
                print(f"Agent '{agent_name}' is not registered; skipping")

        # NOTE(review): state.copy() is shallow. List values such as
        # "messages" stay shared between the copies and the main state, so
        # in parallel mode agents mutating them do so concurrently.
        if execution_mode == "parallel":
            print(f"Executing {len(active_agents)} agents in parallel mode")
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Every copy is taken before any result is merged, so each
                # agent sees the same (original) state.
                submitted = [
                    (agent_name, executor.submit(agents[agent_name].analyze, state.copy()))
                    for agent_name in runnable
                ]
            # Leaving the with-block waits for every future. Merging in
            # selection order, not completion order, keeps completed_agents
            # (and thus the report) deterministic.
            for agent_name, future in submitted:
                _merge_outcome(state, agent_name, future.result)
        else:  # sequential
            print(f"Executing {len(active_agents)} agents in sequential mode")
            for agent_name in runnable:
                print(f"Running {agent_name}...")
                _merge_outcome(
                    state,
                    agent_name,
                    lambda agent=agents[agent_name]: agent.analyze(state.copy()),
                )

        return state

    workflow.add_node("execute_abcde", execute_abcde)

    # Comparison node (conditional) - placeholder for now
    def comparison_placeholder(state: MultiAgentState) -> MultiAgentState:
        """Placeholder for comparison agent"""
        print("Comparison agent not yet implemented")
        state["comparison_results"] = None
        return state

    workflow.add_node("comparison", comparison_placeholder)

    def _format_finding(finding) -> str:
        """Render one finding as a markdown bullet, tolerating missing fields."""
        # NOTE(review): assumes findings are dicts with "pathology",
        # "confidence", optional "measurements" and "evidence" keys.
        pathology = finding.get("pathology", "Unspecified finding")
        try:
            confidence = f"{float(finding.get('confidence')):.2f}"
        except (TypeError, ValueError):
            # Missing or non-numeric confidence must not abort the report.
            confidence = "n/a"
        text = f"- **{pathology}**: Confidence {confidence}"
        if finding.get("measurements"):
            text += f", Measurements: {finding['measurements']}"
        text += f"\n  Evidence: {finding.get('evidence', 'not provided')}\n"
        return text

    # Synthesis node (always runs) - placeholder for now
    def synthesis_placeholder(state: MultiAgentState) -> MultiAgentState:
        """Placeholder synthesis: build a markdown report from merged analyses."""
        print("Synthesis agent not yet implemented")

        report_parts = ["# Chest X-Ray Analysis Report\n"]

        for agent_name in state.get("completed_agents", []):
            analysis = state.get(f"{agent_name}_analysis")
            if not analysis:
                continue
            report_parts.append(f"\n## {agent_name.capitalize()} Analysis\n")

            findings = analysis.get("findings")
            if findings:
                report_parts.extend(_format_finding(f) for f in findings)
            else:
                report_parts.append("No significant findings.\n")

            if analysis.get("needs_human_review"):
                report_parts.append("⚠️ **Human review recommended**\n")

        state["final_report"] = "".join(report_parts)
        return state

    workflow.add_node("synthesis_report", synthesis_placeholder)

    def _route_after_abcde(state: MultiAgentState) -> str:
        """Go to comparison only when requested and a prior image exists."""
        if state.get("need_comparison") and state.get("prior_image_path"):
            return "comparison"
        return "synthesis_report"

    workflow.add_edge("analyze_query", "execute_abcde")
    workflow.add_conditional_edges(
        "execute_abcde",
        _route_after_abcde,
        {
            "comparison": "comparison",
            "synthesis_report": "synthesis_report",
        },
    )
    workflow.add_edge("comparison", "synthesis_report")
    workflow.add_edge("synthesis_report", END)

    workflow.set_entry_point("analyze_query")

    return workflow.compile()