"""Diaphragm specialist agent for diaphragmatic integrity and subdiaphragmatic analysis"""

from typing import List, Dict, Any, Optional, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import json

from .base_abcde import ABCDEAgent
from ..orchestrator.state import MultiAgentState, AgentAnalysis, Finding
from ...tools import ChestXRaySegmentationTool, XRayPhraseGroundingTool

# Subtask priority (reference only, not read at runtime; the planner prompt is
# built from DiaphragmAgent._define_subtask_templates):
#
# High
# -- segment_diaphragm
# -- measure_hemidiaphragm_heights
# -- calculate_height_difference
# -- assess_diaphragm_position
# -- detect_pneumoperitoneum (free air = surgical emergency)
#
# Medium
# -- segment_hemidiaphragms
# -- measure_diaphragm_contours
# -- detect_diaphragm_elevation
# -- measure_diaphragm_angles
# -- compare_diaphragm_contours
# -- segment_subdiaphragmatic_region
# -- measure_free_air_volume
# -- visual_reasoning / validate_measurements
#
# Low
# -- detect_diaphragm_tenting
# -- localize_diaphragm / localize_free_air


class DiaphragmTaskParameters(BaseModel):
    """Parameters for a diaphragm task"""
    # This model is nested in DiaphragmSubtask/DiaphragmPlan, which DiaphragmAgent
    # passes to llm.with_structured_output(). The docstring and Field descriptions
    # are presumably rendered into the schema the LLM sees, so treat their wording
    # as prompt text.

    # Free-text hint; not read by DiaphragmAgent's execution code.
    focus: Optional[str] = Field(default=None, description="Focus area for the task")
    # Single phrase for the xray_phrase_grounding tool.
    phrase: Optional[str] = Field(default=None, description="Phrase to ground/localize")
    # List of phrases for the xray_phrase_grounding tool.
    phrases: Optional[List[str]] = Field(default=None, description="Multiple phrases to ground")
    # Tool calls in DiaphragmAgent take the image path from state["image_path"],
    # not from this field.
    image_path: Optional[str] = Field(default=None, description="Path to the image")
    
    class Config:
        # Reject undeclared keys, so a generated plan cannot add arbitrary
        # parameters (e.g. an organ list).
        extra = "forbid"


class DiaphragmSubtask(BaseModel):
    """A single subtask in the diaphragm analysis plan"""
    # One planner step. DiaphragmAgent.create_plan converts it into a plain dict
    # with the same five keys. The docstring and descriptions are likely part of
    # the schema shown to the LLM, so edit them as prompt text.

    # Also the key that other subtasks reference in depends_on.
    task_id: str = Field(description="Unique identifier for this subtask (e.g., 'segment_diaphragm', 'measure_height_difference', 'detect_free_air')")
    description: str = Field(description="What this subtask accomplishes")
    # DiaphragmAgent.analyze dispatches on this value. Besides the two tools named
    # in the description it also accepts "visual_reasoning"; any other value
    # produces an "Unknown tool" error result.
    tool: str = Field(description="Which tool to use: chest_xray_segmentation or xray_phrase_grounding")
    parameters: DiaphragmTaskParameters = Field(default_factory=DiaphragmTaskParameters, description="Parameters for the tool")
    # task_ids that must already have been executed earlier in the plan;
    # otherwise DiaphragmAgent.analyze skips this subtask.
    depends_on: List[str] = Field(default_factory=list, description="List of task_ids this task depends on")
    
    class Config:
        # Reject undeclared keys in LLM output.
        extra = "forbid"


class DiaphragmPlan(BaseModel):
    """Dynamic plan for diaphragm analysis"""
    # Top-level structured-output schema for the planner: DiaphragmAgent builds
    # `prompt | llm.with_structured_output(DiaphragmPlan)`.

    # Subtasks are executed in list order, so dependencies should come before
    # the subtasks that depend on them.
    subtasks: List[DiaphragmSubtask] = Field(description="Ordered list of subtasks to execute")
    # The planner's rationale; not read by the execution code in this module.
    reasoning: str = Field(description="Brief explanation of why this plan was chosen")
    
    class Config:
        # Reject undeclared keys in LLM output.
        extra = "forbid"


class DiaphragmAgent(ABCDEAgent):
    """Specialist agent for diaphragmatic findings with LangGraph-style flexible P&E mode.

    ``analyze`` drives the workflow:

    1. ``create_plan`` asks an LLM planner (structured output ``DiaphragmPlan``)
       for an ordered subtask list. It falls back to a fixed segmentation plan
       when planning raises or returns no subtasks.
    2. Each subtask runs on the segmentation tool, the optional phrase-grounding
       tool, or the optional V-CoT module.
    3. Segmentation ``metrics`` become ``Finding`` entries. The summary is
       stored as ``state["diaphragm_analysis"]``.
    """

    def __init__(
        self,
        llm: BaseLanguageModel,
        segmentation_tool: ChestXRaySegmentationTool,
        grounding_tool: Optional[XRayPhraseGroundingTool] = None,
        vcot_module: Optional[Any] = None
    ):
        """Register tools with the base agent and build the LLM planner.

        Args:
            llm: Model used for planning; must support ``with_structured_output``.
            segmentation_tool: Required segmentation tool.
            grounding_tool: Optional phrase-grounding tool. When absent, the
                planner is not offered it and grounding subtasks return an
                error result.
            vcot_module: Optional object exposing ``generate(...)`` for visual
                chain-of-thought. When absent, V-CoT steps produce no output.
        """
        # DiaphragmAgent primarily uses segmentation + optional grounding
        tools = [segmentation_tool]
        if grounding_tool is not None:
            tools.append(grounding_tool)

        super().__init__(
            agent_name="DiaphragmAgent",
            llm=llm,
            tools=tools,
            mode="plan_execute",  # P&E mode with LLM-driven planning
            vcot_policy="borderline_only",  # V-CoT only for borderline findings
            vcot_threshold_low=0.3,
            vcot_threshold_high=0.7
        )
        self.vcot_module = vcot_module
        self.has_grounding = grounding_tool is not None

        # Text describing the subtask vocabulary; injected into the planner prompt.
        self.subtask_templates = self._define_subtask_templates()

        # Runnable: prompt | llm.with_structured_output(DiaphragmPlan)
        self.planner = self._create_planner()

    def _define_subtask_templates(self) -> str:
        """Define available subtask templates for diaphragm analysis"""
        return """
Available Diaphragm Subtask Templates:

1. DIAPHRAGM SEGMENTATION:
   - segment_diaphragm: Segment both hemidiaphragm domes
   - segment_hemidiaphragms: Segment left and right hemidiaphragms separately
   - measure_diaphragm_contours: Analyze diaphragm shape and contour
   
2. HEIGHT AND POSITION ASSESSMENT:
   - measure_hemidiaphragm_heights: Compare right vs left diaphragm height
   - calculate_height_difference: Quantify height difference between hemidiaphragms
   - assess_diaphragm_position: Evaluate overall diaphragm position (elevated/depressed)
   - detect_diaphragm_elevation: Check for unilateral or bilateral elevation
   
3. GEOMETRIC ANALYSIS:
   - measure_diaphragm_angles: Assess diaphragm dome angles
   - compare_diaphragm_contours: Compare left vs right diaphragm contours
   - detect_diaphragm_tenting: Identify tenting or localized elevation
   
4. FREE AIR DETECTION:
   - detect_pneumoperitoneum: Look for free air under diaphragm domes
   - segment_subdiaphragmatic_region: Analyze area below diaphragm
   - measure_free_air_volume: Quantify pneumoperitoneum if present
   
5. LOCALIZATION TASKS:
   - localize_diaphragm: Find and mark diaphragm boundaries
   - localize_free_air: Ground pneumoperitoneum locations
   
6. VERIFICATION TASKS:
   - visual_reasoning: Apply V-CoT for borderline findings only
   - validate_measurements: Cross-check geometric measurements

Remember: 
- DiaphragmAgent focuses on geometric analysis and measurements
- Borderline findings (0.3-0.7 confidence) trigger V-CoT for verification
- Normal right hemidiaphragm is typically 1-3cm higher than left
- NOTE: Costophrenic angles belong to BreathingAgent per ABCDEF B: Breathing"""

    def _tools_description(self) -> str:
        """Return the newline-joined tool list advertised to the planner.

        The grounding tool is listed only when it was supplied, so the planner
        is never offered a tool that cannot run.
        """
        tools_desc = [
            "- chest_xray_segmentation: Segments diaphragm structures, provides geometric measurements"
        ]
        if self.has_grounding:
            tools_desc.append("- xray_phrase_grounding: Localizes specific diaphragmatic findings")
        return "\n".join(tools_desc)

    def _create_planner(self):
        """Create the LLM-based planner with structured output.

        Returns:
            A runnable that takes ``{"query", "subtask_templates",
            "tools_available"}`` and yields a ``DiaphragmPlan``.
        """
        planner_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a specialist in diaphragmatic imaging planning geometric analysis workflows.
Given a query about diaphragm findings, create a detailed plan using available tools.

{subtask_templates}

Tools available:
{tools_available}

MEDICAL BEST PRACTICES:
- Normal right hemidiaphragm is 1-3cm higher than left
- Height difference >3cm suggests pathology (paralysis, eventration)
- Free air under diaphragm indicates pneumoperitoneum (surgical emergency)
- Geometric measurements are objective and quantitative
- Always compare bilateral structures for asymmetry
- NOTE: Costophrenic angles belong to BreathingAgent per ABCDEF B: Breathing

Create a plan that:
1. Addresses the specific query requirements
2. Uses appropriate tools in logical order
3. Includes dependencies between tasks
4. Is efficient (no redundant steps)
5. Can go beyond templates if needed for the query
6. Only uses tools that are actually available
7. Prioritizes quantitative geometric measurements

IMPORTANT: Do not include tasks that require xray_phrase_grounding if it's not available.
For diagnostic queries, always include measurements and comparisons."""),
            ("human", "Query: {query}\n\nCreate a plan to analyze the diaphragmatic findings.")
        ])

        return planner_prompt | self.llm.with_structured_output(DiaphragmPlan)

    def get_react_prompt_template(self) -> str:
        """Not used for P&E mode"""
        return ""

    def create_plan(self, state: MultiAgentState) -> List[Dict[str, Any]]:
        """Create a dynamic plan using the LLM planner with structured output.

        Args:
            state: Shared state; reads ``"query"`` and ``"comparison_mode"``.

        Returns:
            List of task dicts with keys ``task_id``, ``description``, ``tool``,
            ``parameters`` and ``depends_on``. Falls back to
            ``_get_fallback_plan`` when the planner raises or returns no subtasks.
        """
        query = state.get("query", "") or ""

        # In comparison mode each image is analyzed on its own, so strip any
        # comparison wording from the query before planning.
        if state.get("comparison_mode"):
            query = self._simplify_query_for_comparison(query)

        try:
            plan_response = self.planner.invoke({
                "query": query,
                "subtask_templates": self.subtask_templates,
                "tools_available": self._tools_description()
            })
            execution_plan = [self._subtask_to_dict(s) for s in plan_response.subtasks]
        except Exception as e:
            print(f"Error in LLM planning for diaphragm analysis, falling back: {e}")
            return self._get_fallback_plan(query)

        if not execution_plan:
            # An empty plan would silently produce no analysis at all.
            print("LLM planner returned an empty diaphragm plan, falling back")
            return self._get_fallback_plan(query)
        return execution_plan

    @staticmethod
    def _simplify_query_for_comparison(query: str) -> str:
        """Map a free-form query to a canned single-image query for comparison mode."""
        lowered = query.lower()
        # Check the specific free-air intent first. Such queries usually also
        # mention the diaphragm ("free air under the diaphragm") and would
        # otherwise collapse into the generic diaphragm assessment.
        if "pneumoperitoneum" in lowered or "free air" in lowered:
            return "Detect pneumoperitoneum and free air"
        if "diaphragm" in lowered:
            return "Assess diaphragmatic integrity and position"
        return "Perform diaphragmatic assessment with measurements"

    @staticmethod
    def _subtask_to_dict(subtask: DiaphragmSubtask) -> Dict[str, Any]:
        """Convert a planner subtask into the plain-dict execution format."""
        params = subtask.parameters
        # Pydantic v2 renamed .dict() to .model_dump(); support both versions.
        dump = getattr(params, "model_dump", None) or params.dict
        return {
            "task_id": subtask.task_id,
            "description": subtask.description,
            "tool": subtask.tool,
            "parameters": dump(exclude_none=True),
            "depends_on": list(subtask.depends_on)
        }

    def _get_fallback_plan(self, query: str) -> List[Dict[str, Any]]:
        """Fallback plan if LLM planning fails.

        ``query`` is currently unused; the same two segmentation steps are
        always returned.
        """
        return [
            {
                "task_id": "segment_diaphragm",
                "description": "Segment bilateral hemidiaphragms",
                "tool": "chest_xray_segmentation",
                "parameters": {"organs": ["Diaphragm"]},
                "depends_on": []
            },
            {
                "task_id": "measure_heights",
                "description": "Measure hemidiaphragm heights",
                "tool": "chest_xray_segmentation",
                "parameters": {"calculate_diaphragm_metrics": True},
                "depends_on": ["segment_diaphragm"]
            }
        ]

    def analyze(self, state: MultiAgentState) -> MultiAgentState:
        """Plan, execute every runnable subtask, and record the analysis.

        Args:
            state: Shared state; must contain ``"image_path"`` for tool calls.

        Returns:
            The same ``state`` with ``state["diaphragm_analysis"]`` set to an
            ``AgentAnalysis``.
        """
        plan = self.create_plan(state)
        findings: List[Finding] = []
        executed_plan: List[Dict[str, Any]] = []
        measurements: Dict[str, Any] = {}
        task_results: Dict[str, Dict[str, Any]] = {}

        for task in plan:
            task_id = task["task_id"]

            # Dependencies are ordering constraints: each must have been
            # executed earlier in the plan (its success is not required).
            if not all(dep in task_results for dep in task.get("depends_on") or []):
                print(f"Skipping {task_id} - dependencies not met")
                continue

            result = self._execute_task(state, task, task_results, findings)

            task_results[task_id] = result
            executed_plan.append({
                "task_id": task_id,
                "description": task.get("description", ""),
                "result": result
            })
            self._process_task_results(task, result, findings, measurements, state)

        # V-CoT counts as triggered if a finding carries a rationale OR an
        # explicit visual_reasoning subtask completed successfully.
        vcot_triggered = any(
            f.get("visual_cot") for f in findings if isinstance(f, dict)
        ) or any(
            r.get("tool") == "vcot" and r.get("success") for r in task_results.values()
        )

        state["diaphragm_analysis"] = AgentAnalysis(
            agent_name="DiaphragmAgent",
            findings=findings,
            plan_executed=executed_plan,
            react_steps=[],  # P&E mode doesn't use react steps
            visual_cot_triggered=vcot_triggered,
            confidence_level=self._assess_confidence(findings),
            needs_human_review=self._needs_review(findings, measurements)
        )

        return state

    def _execute_task(
        self,
        state: MultiAgentState,
        task: Dict,
        task_results: Dict,
        findings: List[Finding]
    ) -> Dict:
        """Dispatch one task to the executor for its ``tool`` value."""
        tool_name = task.get("tool")
        if tool_name == "chest_xray_segmentation":
            return self._execute_segmentation_task(state, task, task_results)
        if tool_name == "xray_phrase_grounding":
            return self._execute_grounding_task(state, task, task_results)
        if tool_name == "visual_reasoning":
            return self._execute_vcot_task(state, task, task_results, findings)
        return {"success": False, "error": f"Unknown tool: {tool_name}"}

    def _execute_segmentation_task(self, state: MultiAgentState, task: Dict, task_results: Dict) -> Dict:
        """Run the segmentation tool for one task.

        Returns:
            ``{"success": True, "result": <tool output>, "tool": "segmentation"}``
            or ``{"success": False, "error": str}``.
        """
        try:
            tool = self.tools["chest_xray_segmentation"]

            # Honour an explicit organ list (the fallback plan supplies one);
            # otherwise derive it from the task id and description.
            organs = (task.get("parameters") or {}).get("organs")
            if not organs:
                organs = self._get_optimal_organs_for_task(task["task_id"], task["description"])

            result = tool.invoke({
                "image_path": state["image_path"],
                "organs": organs
            })
            return {"success": True, "result": result, "tool": "segmentation"}
        except Exception as e:
            return {"success": False, "error": str(e)}

    def _execute_grounding_task(self, state: MultiAgentState, task: Dict, task_results: Dict) -> Dict:
        """Ground one or more phrases with the phrase-grounding tool.

        Phrases are chosen in this order:

        1. ``parameters["phrase"]`` if set.
        2. Otherwise every distinct entry of ``parameters["phrases"]``.
        3. Otherwise a default derived from the task id.

        Returns:
            ``{"success": True, "result": <output for the first phrase>,
            "results_by_phrase": {phrase: output}, "tool": "grounding"}``,
            or ``{"success": False, "error": str}``.
        """
        if not self.has_grounding:
            return {"success": False, "error": "Grounding tool not available"}

        try:
            tool = self.tools["xray_phrase_grounding"]
            params = task.get("parameters") or {}

            phrase = params.get("phrase")
            if phrase:
                phrase_list = [phrase]
            else:
                # dict.fromkeys de-duplicates while keeping the planner's order.
                phrase_list = [p for p in dict.fromkeys(params.get("phrases") or []) if p]

            if not phrase_list:
                task_id = task["task_id"].lower()
                if "free_air" in task_id or "pneumoperitoneum" in task_id:
                    phrase_list = ["free air under diaphragm"]
                else:
                    phrase_list = ["diaphragm"]

            results_by_phrase = {
                p: tool.invoke({"image_path": state["image_path"], "phrase": p})
                for p in phrase_list
            }
            return {
                "success": True,
                # Primary result keeps the original single-phrase shape for consumers.
                "result": results_by_phrase[phrase_list[0]],
                "results_by_phrase": results_by_phrase,
                "tool": "grounding"
            }
        except Exception as e:
            return {"success": False, "error": str(e)}

    def _execute_vcot_task(self, state: MultiAgentState, task: Dict, task_results: Dict, findings: List[Finding]) -> Dict:
        """Run visual chain-of-thought using context from earlier task results.

        ROI: first bounding box from a successful grounding result, if any.
        Measurements: merged ``metrics`` from successful segmentation results.
        """
        if not self.vcot_module:
            return {"success": False, "error": "V-CoT module not available"}

        try:
            visual_cot = self.vcot_module.generate(
                image_path=state["image_path"],
                roi_bbox=self._first_grounding_bbox(task_results),
                measurements=self._collect_segmentation_metrics(task_results),
                task=task["description"],
                confidence=0.5  # Default confidence for V-CoT trigger
            )
            return {"success": True, "result": {"visual_cot": visual_cot}, "tool": "vcot"}
        except Exception as e:
            return {"success": False, "error": str(e)}

    @staticmethod
    def _first_grounding_bbox(task_results: Dict) -> Optional[Any]:
        """Return the first non-empty bounding box from successful grounding results."""
        for res in task_results.values():
            if res.get("tool") != "grounding" or not res.get("success"):
                continue
            grounding = res.get("result")
            if not isinstance(grounding, dict):
                continue
            predictions = grounding.get("predictions")
            if not predictions:
                continue
            boxes = (predictions[0].get("bounding_boxes") or {}).get("image_coordinates", [])
            if boxes and boxes[0]:
                return boxes[0]
        return None

    @staticmethod
    def _collect_segmentation_metrics(task_results: Dict) -> Dict[str, Any]:
        """Merge ``metrics`` dicts from successful segmentation results (later wins)."""
        merged: Dict[str, Any] = {}
        for res in task_results.values():
            if res.get("tool") == "segmentation" and res.get("success"):
                seg = res.get("result")
                # Guard against non-dict tool output: `"metrics" in <str>` would
                # be a substring test, and the indexing that follows would fail.
                if isinstance(seg, dict) and isinstance(seg.get("metrics"), dict):
                    merged.update(seg["metrics"])
        return merged

    def _process_task_results(
        self,
        task: Dict,
        result: Dict,
        findings: List[Finding],
        measurements: Dict,
        state: Optional[MultiAgentState] = None
    ):
        """Turn successful segmentation metrics into findings (mutates in place).

        Args:
            task: The executed task (unused for now; kept for the interface).
            result: Executor result dict.
            findings: Appended to for each abnormal metric.
            measurements: Updated with every metric reported.
            state: Needed to run V-CoT on borderline findings. When omitted,
                V-CoT is skipped.
        """
        if not result.get("success") or result.get("tool") != "segmentation":
            return

        tool_result = result.get("result")
        if not isinstance(tool_result, dict):
            return
        metrics = tool_result.get("metrics")
        if not isinstance(metrics, dict):
            return

        measurements.update(metrics)

        # Hemidiaphragm height difference (mm). Positive means the right side is higher.
        height_diff = metrics.get("hemidiaphragm_height_difference")
        if height_diff is not None and abs(height_diff) > 30:  # >3cm difference is abnormal
            # Confidence grows with distance from the ~20mm normal right-over-left
            # offset, capped at 0.9.
            confidence = min(0.9, abs(height_diff - 20) / 50)
            side = "right" if height_diff > 0 else "left"
            self._add_metric_finding(
                findings,
                state,
                pathology=f"{side}_hemidiaphragm_elevation",
                confidence=confidence,
                finding_measurements={"height_difference_mm": height_diff},
                evidence=f"Hemidiaphragm height difference: {height_diff:.1f}mm ({side} side elevated)"
            )

        # Free air under the diaphragm.
        volume = metrics.get("pneumoperitoneum_volume")
        if volume is not None and volume > 10:  # Any significant free air
            confidence = min(0.95, volume / 100)
            self._add_metric_finding(
                findings,
                state,
                pathology="pneumoperitoneum",
                confidence=confidence,
                finding_measurements={"volume_ml": volume},
                evidence=f"Pneumoperitoneum: {volume:.1f}ml of free air"
            )

        # Overall diaphragm position percentile.
        percentile = metrics.get("diaphragm_position_percentile")
        if percentile is not None and (percentile < 10 or percentile > 90):  # Abnormally high or low
            position = "elevated" if percentile > 90 else "depressed"
            confidence = min(0.9, abs(percentile - 50) / 50)
            self._add_metric_finding(
                findings,
                state,
                pathology=f"diaphragm_{position}",
                confidence=confidence,
                finding_measurements={"position_percentile": percentile},
                evidence=f"Diaphragm position: {percentile:.1f} percentile ({position})"
            )

    def _add_metric_finding(
        self,
        findings: List[Finding],
        state: Optional[MultiAgentState],
        pathology: str,
        confidence: float,
        finding_measurements: Dict[str, Any],
        evidence: str
    ) -> None:
        """Append one measurement-derived finding, attaching V-CoT when it is borderline."""
        visual_cot = None
        if self.should_trigger_vcot(confidence, is_borderline=True):
            visual_cot = self._execute_vcot(
                state,
                None,
                findings,
                measurements=dict(finding_measurements),
                confidence=confidence
            )
        findings.append(Finding(
            pathology=pathology,
            confidence=confidence,
            location=None,
            measurements=finding_measurements,
            visual_cot=visual_cot,
            evidence=evidence
        ))

    def _execute_vcot(
        self,
        state: Optional[MultiAgentState],
        result: Optional[Dict],
        findings: List[Finding],
        measurements: Optional[Dict[str, Any]] = None,
        confidence: float = 0.5,
        task: str = "Diaphragmatic analysis"
    ) -> Optional[str]:
        """Execute V-CoT analysis for a borderline finding.

        Returns ``None`` when no V-CoT module or state is available, or when
        generation raises (the error is printed). ``result`` and ``findings``
        are accepted for interface compatibility and are not used.
        """
        if not self.vcot_module or state is None:
            return None

        try:
            return self.vcot_module.generate(
                image_path=state["image_path"],
                roi_bbox=None,
                measurements=dict(measurements or {}),
                task=task,
                confidence=confidence
            )
        except Exception as e:
            print(f"Error in V-CoT: {e}")
            return None

    def _needs_review(self, findings: List[Finding], measurements: Dict) -> bool:
        """Determine if human review is needed.

        Review is needed when any of these holds:

        - A pneumoperitoneum finding exists.
        - An elevation finding has confidence above 0.7.
        - The absolute height difference exceeds 50mm.
        - The free-air volume exceeds 50.
        """
        for finding in findings:
            if not isinstance(finding, dict):
                continue
            pathology = finding.get("pathology", "") or ""
            confidence = self._extract_confidence_value(finding.get("confidence", 0))

            # Pneumoperitoneum always needs urgent review
            if "pneumoperitoneum" in pathology:
                return True

            # Significant elevation needs review. This covers both
            # "<side>_hemidiaphragm_elevation" and "diaphragm_elevated".
            if ("elevation" in pathology or "elevated" in pathology) and confidence > 0.7:
                return True

        # `or 0` also guards against metrics explicitly reported as None.
        height_diff = measurements.get("hemidiaphragm_height_difference") or 0
        if abs(height_diff) > 50:  # >5cm difference
            return True

        # Large pneumoperitoneum
        if (measurements.get("pneumoperitoneum_volume") or 0) > 50:
            return True

        return False

    def _assess_confidence(self, findings: List[Finding]) -> str:
        """Map the mean finding confidence to "high" (>=0.7), "medium" (>=0.4) or "low".

        Returns "high" when there are no dict-shaped findings.
        """
        confidences = [
            self._extract_confidence_value(f.get("confidence", 0))
            for f in findings
            if isinstance(f, dict)
        ]
        if not confidences:
            return "high"

        avg_confidence = sum(confidences) / len(confidences)

        if avg_confidence >= 0.7:
            return "high"
        elif avg_confidence >= 0.4:
            return "medium"
        else:
            return "low"

    def _get_optimal_organs_for_task(self, task_id: str, task_description: str) -> List[str]:
        """Choose segmentation targets from keywords in the task id/description."""
        task_text = f"{task_id} {task_description}".lower()

        # Basic diaphragm structures
        organs = ["Diaphragm"]

        # NOTE: Costophrenic angles belong to BreathingAgent per ABCDEF B: Breathing

        # Add lungs for reference in position assessment
        if any(keyword in task_text for keyword in ("position", "elevation", "elevated", "height")):
            organs.extend(["Left Lung", "Right Lung"])

        # Add abdomen for free-air / subdiaphragmatic work. Descriptions say
        # "free air" (with a space) while task ids use "free_air".
        if any(keyword in task_text for keyword in ("free_air", "free air", "pneumoperitoneum", "subdiaphragmatic")):
            organs.append("Abdomen")

        return organs

    def _extract_confidence_value(self, conf: Any) -> float:
        """Coerce a confidence (number or ``{"value": ...}`` dict) to float.

        Anything else, or an unparsable value, yields 0.0.
        """
        if isinstance(conf, (int, float)):
            return float(conf)
        if isinstance(conf, dict):
            try:
                return float(conf.get("value", 0))
            except (TypeError, ValueError):
                return 0.0
        return 0.0