"""Everything Else specialist agent for bones, soft tissues, and devices"""

from typing import List, Dict, Any, Optional, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import json

from .base_abcde import ABCDEAgent
from ..orchestrator.state import MultiAgentState, AgentAnalysis, Finding
from ...tools import ChestXRayClassifierTool, XRayPhraseGroundingTool

# Reference priorities for the EverythingElseAgent subtasks (documentation only;
# not read by the code). Task names match the templates listed in
# EverythingElseAgent._define_subtask_templates.
#
# High
# -- detect_rib_fractures
# -- detect_endotracheal_tube (position)
# -- check_central_lines
# -- assess_chest_tubes
# -- check_nasogastric_tube
# -- detect_foreign_bodies (safety)
#
# Medium
# -- check_clavicle_fractures
# -- examine_spine
# -- detect_bone_lesions
# -- assess_scapular_bones
# -- check_subcutaneous_emphysema
# -- assess_soft_tissue_masses
# -- examine_neck_structures
# -- detect_surgical_hardware
# -- detect_surgical_clips / check_sternotomy_wires / other post-surgical changes
# -- localize_fractures / localize_devices / localize_foreign_bodies / localize_surgical_hardware
#
# Low
# -- examine_ecg_leads (usually obvious on overview image)

class EverythingElseTaskParameters(BaseModel):
    """Parameters for an everything else task"""
    # Every field is optional (default None); a subtask fills in only the ones
    # its tool needs. The docstring above and the descriptions below are part of
    # the schema handed to the LLM, so they are kept verbatim.
    focus: Optional[str] = Field(None, description="Focus area for the task")
    phrase: Optional[str] = Field(None, description="Phrase to ground/localize")
    phrases: Optional[List[str]] = Field(None, description="Multiple phrases to ground")
    image_path: Optional[str] = Field(None, description="Path to the image")

    class Config:
        # Unknown keys are rejected during validation.
        extra = "forbid"


class EverythingElseSubtask(BaseModel):
    """A single subtask in the everything else analysis plan"""
    # Required fields are spelled with an explicit Ellipsis default; the last two
    # fields get fresh empty values per instance via default_factory.
    task_id: str = Field(..., description="Unique identifier for this subtask (e.g., 'detect_fractures', 'check_devices', 'assess_tubes')")
    description: str = Field(..., description="What this subtask accomplishes")
    tool: str = Field(..., description="Which tool to use: chest_xray_classifier or xray_phrase_grounding")
    parameters: EverythingElseTaskParameters = Field(default_factory=EverythingElseTaskParameters, description="Parameters for the tool")
    depends_on: List[str] = Field(default_factory=list, description="List of task_ids this task depends on")

    class Config:
        # Unknown keys are rejected during validation.
        extra = "forbid"


class EverythingElsePlan(BaseModel):
    """Dynamic plan for everything else analysis"""
    # Top-level schema passed to llm.with_structured_output by the agent's
    # planner; both fields are required.
    subtasks: List[EverythingElseSubtask] = Field(..., description="Ordered list of subtasks to execute")
    reasoning: str = Field(..., description="Brief explanation of why this plan was chosen")

    class Config:
        # Unknown keys are rejected during validation.
        extra = "forbid"


class EverythingElseAgent(ABCDEAgent):
    """Specialist agent for bones, soft tissues, and devices with LangGraph-style flexible P&E mode.

    Workflow (see ``analyze``):
      1. ``create_plan`` asks the LLM planner for an ``EverythingElsePlan`` and
         converts it into plain task dicts, falling back to a single
         classification task when planning fails or yields no subtasks.
      2. Each task runs either the chest X-ray classifier or, when configured,
         the phrase-grounding tool.
      3. Tool outputs are folded into ``Finding`` entries and stored as an
         ``AgentAnalysis`` under ``state["everything_analysis"]``.
    """

    # Classifier labels this agent reports on; their scores are read by key
    # from the classifier output.
    EVERYTHING_ELSE_PATHOLOGIES = (
        # Bones
        "Fracture", "Bone Lesion", "Scoliosis",
        # Devices/Hardware
        "Support Devices", "Prosthetic Material", "Surgical Material",
        # Other
        "Foreign Object", "Subcutaneous Emphysema",
    )
    # A classifier label is reported as present only when its score exceeds this.
    CLASSIFICATION_THRESHOLD = 0.5
    # Confidence assigned to findings that exist only because grounding located them.
    GROUNDING_CONFIDENCE = 0.8
    # Ordered (keywords, phrase) rules used to pick a grounding phrase from the
    # lower-cased task_id when the planner supplied no phrase. First match wins,
    # so more specific keywords ("endotracheal", "nasogastric") precede "tube".
    _DEFAULT_GROUNDING_PHRASES = (
        (("endotracheal",), "endotracheal tube"),
        (("nasogastric",), "nasogastric tube"),
        (("tube",), "chest tube"),
        (("line",), "central line"),
        (("fracture",), "fracture"),
        (("device", "pacemaker"), "medical device"),
        (("ecg", "lead"), "ecg leads"),
        (("clip",), "surgical clips"),
        (("sternotomy", "wire"), "sternotomy wires"),
        (("hardware", "implant", "prosthe"), "surgical hardware"),
        (("emphysema",), "subcutaneous emphysema"),
    )
    # Phrase used when no rule above matches the task_id.
    _FALLBACK_GROUNDING_PHRASE = "foreign body"

    def __init__(
        self,
        llm: BaseLanguageModel,
        classification_tool: ChestXRayClassifierTool,
        grounding_tool: Optional[XRayPhraseGroundingTool] = None,
        vcot_module: Optional[Any] = None
    ):
        """Register the tools and build the LLM planner.

        Args:
            llm: Model used for planning; must support ``with_structured_output``.
            classification_tool: Chest X-ray classifier; always registered.
            grounding_tool: Optional phrase-grounding tool. Without it, grounding
                subtasks return a failed result instead of running.
            vcot_module: Stored on the instance; this agent never triggers V-CoT.
        """
        # EverythingElseAgent uses classification + optional grounding
        tools = [classification_tool]
        if grounding_tool is not None:
            tools.append(grounding_tool)

        super().__init__(
            agent_name="EverythingElseAgent",
            llm=llm,
            tools=tools,
            mode="plan_execute",  # P&E mode with LLM-driven planning
            vcot_policy="never"  # No V-CoT for binary findings
        )
        self.vcot_module = vcot_module
        self.has_grounding = grounding_tool is not None

        # Define subtask templates for the planner
        self.subtask_templates = self._define_subtask_templates()

        # Create structured planner
        self.planner = self._create_planner()

    def _define_subtask_templates(self) -> str:
        """Define available subtask templates for everything else analysis"""
        return """
Available Everything Else Subtask Templates:

1. BONE ASSESSMENT:
   - detect_rib_fractures: Identify rib fractures or abnormalities
   - check_clavicle_fractures: Assess clavicular integrity
   - examine_spine: Check for spinal abnormalities or fractures
   - assess_scapular_bones: Evaluate scapular fractures or lesions
   - detect_bone_lesions: Identify bone metastases or destructive lesions
   
2. SOFT TISSUE EVALUATION:
   - check_subcutaneous_emphysema: Detect air in soft tissues
   - assess_soft_tissue_masses: Identify soft tissue abnormalities
   - examine_neck_structures: Check for neck soft tissue abnormalities
   - detect_foreign_bodies: Identify radiopaque foreign objects
   
3. DEVICE AND HARDWARE ASSESSMENT:
   - detect_endotracheal_tube: Locate and assess ETT position
   - check_central_lines: Identify central venous catheters
   - assess_chest_tubes: Locate chest drainage tubes
   - detect_surgical_hardware: Identify surgical clips, wires, prosthetics
   - check_nasogastric_tube: Assess NG tube position
   - examine_ecg_leads: Identify ECG monitoring leads
   
4. POST-SURGICAL CHANGES:
   - detect_surgical_clips: Identify vascular or tissue clips
   - check_sternotomy_wires: Look for median sternotomy hardware
   - assess_surgical_changes: Identify post-operative anatomical changes
   
5. LOCALIZATION TASKS:
   - localize_fractures: Ground specific fracture locations
   - localize_devices: Ground medical device positions
   - localize_foreign_bodies: Ground foreign object locations
   - localize_surgical_hardware: Ground surgical implant positions

Remember: 
- EverythingElseAgent handles binary present/absent findings
- No V-CoT needed - findings are typically clear-cut
- Focus on systematic detection and accurate localization
- Device position is critical for patient safety"""

    def _tools_description(self) -> str:
        """Return the bullet list of tools shown to the planner (grounding only when configured)."""
        tools_desc = [
            "- chest_xray_classifier: Detects pathologies (fractures, devices, foreign bodies)"
        ]
        if self.has_grounding:
            tools_desc.append("- xray_phrase_grounding: Localizes specific findings or devices")
        return "\n".join(tools_desc)

    def _create_planner(self):
        """Create the LLM-based planner with structured output.

        Returns the runnable ``prompt | llm.with_structured_output(EverythingElsePlan)``.
        The prompt variables ``query``, ``subtask_templates`` and
        ``tools_available`` are all supplied at invoke time by ``create_plan``.
        """
        planner_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a specialist in detecting bones, soft tissues, and medical devices planning systematic detection workflows.
Given a query about these findings, create a detailed plan using available tools.

{subtask_templates}

Tools available:
{tools_available}

MEDICAL BEST PRACTICES:
- Systematic approach: bones → soft tissues → devices → foreign bodies
- Device position is critical for patient safety (ETT, central lines)
- Fractures require careful systematic evaluation
- Foreign bodies can be subtle but important
- Post-surgical changes should be documented
- Binary findings: present vs absent (no uncertainty)

Create a plan that:
1. Addresses the specific query requirements
2. Uses appropriate tools in logical order
3. Includes dependencies between tasks
4. Is efficient (no redundant steps)
5. Can go beyond templates if needed for the query
6. Only uses tools that are actually available
7. Prioritizes patient safety for device positioning

IMPORTANT: Do not include tasks that require xray_phrase_grounding if it's not available.
For device queries, always include localization if grounding is available."""),
            ("human", "Query: {query}\n\nCreate a plan to analyze bones, soft tissues, devices, and other findings.")
        ])

        return planner_prompt | self.llm.with_structured_output(EverythingElsePlan)

    def get_react_prompt_template(self) -> str:
        """Not used for P&E mode"""
        return ""

    def create_plan(self, state: MultiAgentState) -> List[Dict[str, Any]]:
        """Create a dynamic plan using the LLM planner with structured output.

        Returns a list of task dicts with keys ``task_id``, ``description``,
        ``tool``, ``parameters`` and ``depends_on``. Falls back to
        ``_get_fallback_plan`` when the planner raises or returns no subtasks.
        """
        query = self._planning_query(state)

        try:
            plan_response = self.planner.invoke({
                "query": query,
                "subtask_templates": self.subtask_templates,
                "tools_available": self._tools_description()
            })
            execution_plan = [self._subtask_to_task(subtask) for subtask in plan_response.subtasks]
        except Exception as e:
            print(f"Error in LLM planning for everything else analysis, falling back: {e}")
            return self._get_fallback_plan(query)

        # An empty plan would leave the analysis with nothing executed at all.
        if not execution_plan:
            print("LLM planner returned no subtasks for everything else analysis, falling back")
            return self._get_fallback_plan(query)
        return execution_plan

    def _planning_query(self, state: MultiAgentState) -> str:
        """Return the query to plan for; in comparison mode, a single-image rewrite of it."""
        query = state.get("query") or ""
        if not state.get("comparison_mode"):
            return query

        # Comparison is handled outside this agent, so plan a single-image task.
        query_lower = query.lower()
        if "fracture" in query_lower:
            return "Detect bone fractures and abnormalities"
        if "device" in query_lower or "tube" in query_lower:
            return "Assess medical devices and hardware position"
        return "Perform comprehensive bone, soft tissue, and device assessment"

    @staticmethod
    def _subtask_to_task(subtask: "EverythingElseSubtask") -> Dict[str, Any]:
        """Convert one structured planner subtask into the plain dict used for execution."""
        params = subtask.parameters
        # Pydantic v2 renamed .dict() to .model_dump(); support both versions.
        dump = getattr(params, "model_dump", None) or params.dict
        return {
            "task_id": subtask.task_id,
            "description": subtask.description,
            "tool": subtask.tool,
            "parameters": dump(exclude_none=True),
            "depends_on": list(subtask.depends_on)
        }

    def _get_fallback_plan(self, query: str) -> List[Dict[str, Any]]:
        """Fallback plan if LLM planning fails: a single comprehensive classification task."""
        return [
            {
                "task_id": "classify_everything_else",
                "description": "Check for fractures, devices, and other findings",
                "tool": "chest_xray_classifier",
                "parameters": {"focus": "comprehensive"},
                "depends_on": []
            }
        ]

    def analyze(self, state: MultiAgentState) -> MultiAgentState:
        """Execute the dynamically generated plan and record the results.

        A task runs once every task_id in its ``depends_on`` has been executed
        (successfully or not). Tasks whose dependencies appear later in the plan
        are deferred rather than dropped; tasks whose dependencies never run
        (unknown ids or cycles) are skipped.

        Writes an ``AgentAnalysis`` to ``state["everything_analysis"]`` and
        returns the same ``state`` object.
        """
        plan = self.create_plan(state)
        findings = []
        executed_plan = []
        measurements = {}
        task_results = {}

        pending = list(plan)
        while pending:
            deferred = []
            for task in pending:
                if all(dep in task_results for dep in task.get("depends_on", [])):
                    self._execute_task(state, task, task_results, executed_plan, findings, measurements)
                else:
                    deferred.append(task)
            if len(deferred) == len(pending):
                # No progress this pass: the remaining dependencies can never be met.
                for task in deferred:
                    print(f"Skipping {task['task_id']} - dependencies not met")
                break
            pending = deferred

        state["everything_analysis"] = AgentAnalysis(
            agent_name="EverythingElseAgent",
            findings=findings,
            plan_executed=executed_plan,
            react_steps=[],  # P&E mode doesn't use react steps
            visual_cot_triggered=False,  # Never triggers V-CoT
            confidence_level=self._assess_confidence(findings),
            needs_human_review=self._needs_review(findings)
        )

        return state

    def _execute_task(self, state: MultiAgentState, task: Dict, task_results: Dict,
                      executed_plan: List[Dict], findings: List[Finding], measurements: Dict) -> None:
        """Run one task with its tool, record the result, and fold it into findings."""
        tool_name = task.get("tool")
        if tool_name == "chest_xray_classifier":
            result = self._execute_classification_task(state, task, task_results)
        elif tool_name == "xray_phrase_grounding":
            result = self._execute_grounding_task(state, task, task_results)
        else:
            result = {"success": False, "error": f"Unknown tool: {tool_name}"}

        task_results[task["task_id"]] = result
        executed_plan.append({
            "task_id": task["task_id"],
            "description": task.get("description", ""),
            "result": result
        })
        self._process_task_results(task, result, findings, measurements)

    def _execute_classification_task(self, state: MultiAgentState, task: Dict, task_results: Dict) -> Dict:
        """Run the chest X-ray classifier on ``state["image_path"]``.

        Returns ``{"success": True, "result": <tool output>, "tool": "classification"}``
        or ``{"success": False, "error": <message>}`` if anything raises.
        """
        try:
            tool = self.tools["chest_xray_classifier"]
            result = tool.invoke({"image_path": state["image_path"]})
        except Exception as e:
            return {"success": False, "error": str(e)}
        return {"success": True, "result": result, "tool": "classification"}

    def _execute_grounding_task(self, state: MultiAgentState, task: Dict, task_results: Dict) -> Dict:
        """Ground every requested phrase on ``state["image_path"]``.

        Phrases come from the task's ``phrase`` and ``phrases`` parameters, or
        are derived from the task_id when neither is given. On success returns::

            {"success": True, "tool": "grounding",
             "result": <output for the first phrase>, "phrase": <first phrase>,
             "grounded": [{"phrase": ..., "result": ...}, ...]}

        If the grounding tool is absent or any call raises, returns
        ``{"success": False, "error": <message>}``.
        """
        if not self.has_grounding:
            return {"success": False, "error": "Grounding tool not available"}

        try:
            tool = self.tools["xray_phrase_grounding"]
            grounded = []
            for phrase in self._grounding_phrases(task):
                output = tool.invoke({"image_path": state["image_path"], "phrase": phrase})
                grounded.append({"phrase": phrase, "result": output})
        except Exception as e:
            return {"success": False, "error": str(e)}

        first = grounded[0]
        return {
            "success": True,
            "result": first["result"],
            "tool": "grounding",
            "phrase": first["phrase"],
            "grounded": grounded
        }

    def _grounding_phrases(self, task: Dict) -> List[str]:
        """Return the de-duplicated, non-empty phrases to ground for ``task`` (never empty)."""
        params = task.get("parameters") or {}
        phrase = params.get("phrase")
        requested = ([phrase] if phrase else []) + [p for p in (params.get("phrases") or []) if p]
        # dict.fromkeys drops duplicates while keeping first-seen order.
        unique = list(dict.fromkeys(requested))
        return unique or [self._default_grounding_phrase(task.get("task_id", ""))]

    @classmethod
    def _default_grounding_phrase(cls, task_id: str) -> str:
        """Pick a grounding phrase from keywords in ``task_id`` (first matching rule wins)."""
        task_lower = task_id.lower()
        for keywords, phrase in cls._DEFAULT_GROUNDING_PHRASES:
            if any(keyword in task_lower for keyword in keywords):
                return phrase
        return cls._FALLBACK_GROUNDING_PHRASE

    def _process_task_results(self, task: Dict, result: Dict, findings: List[Finding], measurements: Dict):
        """Fold one task result into ``findings`` (mutated in place).

        Failed results are ignored. ``measurements`` is accepted for interface
        compatibility but is not written by this agent.
        """
        if not result.get("success"):
            return

        if result.get("tool") == "classification":
            self._add_classification_findings(result.get("result"), findings)
        elif result.get("tool") == "grounding":
            grounded = result.get("grounded")
            if grounded is None:
                # Result built without the per-phrase list: use the task's phrase.
                phrase = (task.get("parameters") or {}).get("phrase") or "finding"
                grounded = [{"phrase": phrase, "result": result.get("result")}]
            for entry in grounded:
                self._add_grounding_location(entry["phrase"], entry["result"], findings)

    def _add_classification_findings(self, scores: Any, findings: List[Finding]) -> None:
        """Append a finding for each everything-else label scoring above the threshold."""
        # NOTE(review): assumes the classifier returns a mapping of label -> score.
        if not hasattr(scores, "get"):
            print(f"Unexpected classifier output type {type(scores).__name__}; skipping classification findings")
            return

        for pathology in self.EVERYTHING_ELSE_PATHOLOGIES:
            try:
                confidence = float(scores.get(pathology, 0))
            except (TypeError, ValueError):
                continue  # Non-numeric score: cannot be thresholded.
            if confidence > self.CLASSIFICATION_THRESHOLD:  # Higher threshold for binary findings
                findings.append(Finding(
                    pathology=pathology.lower().replace(" ", "_"),
                    confidence=confidence,
                    location=None,  # Filled in later if grounding locates it
                    measurements=None,
                    visual_cot=None,  # Never trigger V-CoT
                    evidence=f"Classification confidence: {confidence:.3f}"
                ))

    def _add_grounding_location(self, phrase: str, tool_result: Any, findings: List[Finding]) -> None:
        """Attach the first grounded box to a matching finding, or add a new finding for it.

        A finding matches when any word of its pathology occurs in the phrase or
        any word of the phrase occurs in its pathology (substring tests).
        """
        bbox = self._first_bbox(tool_result)
        if not bbox:
            return

        phrase_lower = phrase.lower()
        phrase_words = phrase_lower.split()
        for finding in findings:
            if not isinstance(finding, dict) or not finding.get("pathology"):
                continue
            pathology = finding["pathology"].replace("_", " ")
            if any(word in phrase_lower for word in pathology.split()) or \
               any(word in pathology for word in phrase_words):
                finding["location"] = bbox
                finding["evidence"] = f"{finding.get('evidence') or ''} | Located at: {bbox}"
                return

        # No existing finding matched: the grounded item becomes its own finding.
        findings.append(Finding(
            pathology=phrase_lower.replace(" ", "_"),
            confidence=self.GROUNDING_CONFIDENCE,  # High confidence for successfully grounded items
            location=bbox,
            measurements=None,
            visual_cot=None,  # Never trigger V-CoT
            evidence=f"Grounding result: {phrase} located at {bbox}"
        ))

    @staticmethod
    def _first_bbox(tool_result: Any) -> Any:
        """Return ``predictions[0].bounding_boxes.image_coordinates[0]`` or None if absent/empty."""
        # NOTE(review): layout inferred from the keys read here; confirm against the grounding tool.
        if not hasattr(tool_result, "get"):
            return None
        predictions = tool_result.get("predictions")
        if not predictions or not hasattr(predictions[0], "get"):
            return None
        boxes = predictions[0].get("bounding_boxes") or {}
        coordinates = boxes.get("image_coordinates") or []
        return coordinates[0] if coordinates and coordinates[0] else None

    def _needs_review(self, findings: List[Finding]) -> bool:
        """Return True when any finding warrants human review.

        Triggers: a safety-critical device with confidence > 0.7, a fracture or
        bone lesion with confidence > 0.8, or any foreign-object/body finding.
        """
        # Devices whose malposition is a patient-safety issue (incl. NG tube).
        device_pathologies = ["endotracheal_tube", "central_line", "chest_tube", "nasogastric_tube"]
        fracture_pathologies = ["fracture", "bone_lesion"]

        for finding in findings:
            if not isinstance(finding, dict):
                continue
            pathology = finding.get("pathology", "")
            confidence = self._extract_confidence_value(finding.get("confidence", 0))

            # High confidence device findings need position verification
            if any(device in pathology for device in device_pathologies) and confidence > 0.7:
                return True

            # High confidence fractures need review
            if any(fracture in pathology for fracture in fracture_pathologies) and confidence > 0.8:
                return True

            # Foreign objects always need review
            if "foreign" in pathology:
                return True

        return False

    def _assess_confidence(self, findings: List[Finding]) -> str:
        """Map the mean finding confidence to "high" (>= 0.8), "medium" (>= 0.6) or "low".

        No findings (or none readable) is reported as "high" confidence in a
        negative study.
        """
        confidences = [
            self._extract_confidence_value(f.get("confidence", 0))
            for f in findings
            if isinstance(f, dict)
        ]
        if not confidences:
            return "high"  # High confidence in negative findings

        avg_confidence = sum(confidences) / len(confidences)
        if avg_confidence >= 0.8:  # Higher threshold for binary findings
            return "high"
        if avg_confidence >= 0.6:
            return "medium"
        return "low"

    def _extract_confidence_value(self, conf: Any) -> float:
        """Safely extract a float confidence from a number or a ``{"value": ...}`` dict.

        Anything else, including a dict whose value is not numeric, yields 0.0.
        """
        if isinstance(conf, (int, float)):
            return float(conf)
        if isinstance(conf, dict):
            try:
                return float(conf.get("value", 0))
            except (TypeError, ValueError):
                return 0.0
        return 0.0