"""Cardiac specialist agent for cardiac findings analysis"""

from typing import List, Dict, Any, Optional, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import json

from .base_abcde import ABCDEAgent
from ..orchestrator.state import MultiAgentState, AgentAnalysis, Finding
from ...tools import ChestXRayClassifierTool, ChestXRaySegmentationTool, XRayPhraseGroundingTool

'''
Priority of the cardiac subtask templates (High / Medium / Low):

High
-- classify_cardiomegaly
-- classify_mediastinum  (mediastinal widening)
-- classify_effusion
-- segment_heart_lungs   (needed for all measurements)
-- calculate_ctr      (Cardiothoracic ratio)
-- calculate_mediastinal_ratio
-- detect_cardiac_devices  (pacemakers, valves, etc.)
-- check_support_devices (ECMO, VAD)

Medium
-- measure_cardiac_area
-- measure_chambers
-- assess_hilar_vessels
-- localize_heart
-- find_calcifications
-- trace_cardiac_device_leads
-- assess_cardiac_position (dextrocardia)
-- evaluate_post_surgical
-- visual_reasoning / cross_validate

Low
-- measure_aortic_knob
'''

class CardiacTaskParameters(BaseModel):
    """Parameters the planner LLM may attach to a single cardiac subtask.

    Every field is optional. ``CardiacAgent.create_plan`` serializes this model
    with ``exclude_none=True``, so only the fields the LLM actually filled in
    reach the task executors.

    ``extra = "forbid"`` makes the structured-output schema reject undeclared
    keys. Every parameter that the planner prompt recommends, or that an
    executor reads, must therefore be declared here explicitly.
    """
    focus: Optional[str] = Field(default=None, description="Focus area for the task")
    phrase: Optional[str] = Field(default=None, description="Phrase to ground/localize")
    phrases: Optional[List[str]] = Field(default=None, description="Multiple phrases to ground")
    image_path: Optional[str] = Field(default=None, description="Path to the image")
    # The planner prompt tells the LLM to pass {"organs": [...]} for segmentation
    # tasks, and the segmentation executor reads parameters["organs"]. Without
    # this field, the forbid-extra schema made that instruction impossible to follow.
    organs: Optional[List[str]] = Field(
        default=None,
        description=(
            "Anatomical structures to segment, e.g. ['Heart', 'Left Lung', 'Right Lung']; "
            "omit to let the agent choose the organs for the task"
        ),
    )

    class Config:
        extra = "forbid"  # Required for Azure OpenAI structured output


class CardiacSubtask(BaseModel):
    """A single subtask in the cardiac analysis plan.

    Produced by the planner LLM through structured output. ``CardiacAgent.create_plan``
    then flattens it into a plain dict for execution. Field order and
    descriptions define the JSON schema the LLM sees.
    """
    # Identifier that other subtasks reference in ``depends_on``. The segmentation
    # executor also branches on specific ids (e.g. "calculate_ctr", "segment_heart_lungs").
    task_id: str = Field(description="Unique identifier for this subtask (e.g., 'classify_cardiomegaly', 'calculate_ctr', 'detect_devices')")
    description: str = Field(description="What this subtask accomplishes")
    # CardiacAgent.analyze dispatches on this value. Besides the three tools named
    # in the description, it also accepts "visual_reasoning"; any other value
    # produces a failed result ("Unknown tool: ...").
    tool: str = Field(description="Which tool to use: chest_xray_classifier, chest_xray_segmentation, or xray_phrase_grounding")
    parameters: CardiacTaskParameters = Field(default_factory=CardiacTaskParameters, description="Parameters for the tool")
    # analyze() runs a task only after every id listed here has a recorded result.
    depends_on: List[str] = Field(default_factory=list, description="List of task_ids this task depends on")
    
    class Config:
        extra = "forbid"  # Required for Azure OpenAI structured output


class CardiacPlan(BaseModel):
    """Dynamic plan for cardiac analysis.

    Output schema of the planner chain built in ``CardiacAgent._create_planner``
    (``llm.with_structured_output(CardiacPlan)``).
    """
    # Processed in list order by CardiacAgent.analyze, subject to each subtask's depends_on.
    subtasks: List[CardiacSubtask] = Field(description="Ordered list of subtasks to execute")
    # Stored in state["cardiac_planning_reasoning"] by create_plan for transparency.
    reasoning: str = Field(description="Brief explanation of why this plan was chosen")
    
    class Config:
        extra = "forbid"  # Required for Azure OpenAI structured output


class CardiacAgent(ABCDEAgent):
    """Specialist agent for cardiac findings with LangGraph-style dynamic P&E mode"""
    
    def __init__(
        self,
        llm: BaseLanguageModel,
        classification_tool: ChestXRayClassifierTool,
        segmentation_tool: ChestXRaySegmentationTool,
        grounding_tool: Optional[XRayPhraseGroundingTool] = None,
        vcot_module: Optional[Any] = None
    ):
        """Register the tools, set up caches and build the structured planner.

        Args:
            llm: Language model used for planning (and handed to the base agent).
            classification_tool: Classifier for cardiac pathologies.
            segmentation_tool: Anatomical segmenter that provides metrics.
            grounding_tool: Optional phrase-grounding tool. Grounding tasks are
                advertised and executed only when it is supplied.
            vcot_module: Optional visual chain-of-thought generator used by
                ``visual_reasoning`` tasks.
        """
        optional_tools = [grounding_tool] if grounding_tool else []
        super().__init__(
            agent_name="CardiacAgent",
            llm=llm,
            tools=[classification_tool, segmentation_tool, *optional_tools],
            mode="plan_execute",  # LLM-driven plan-and-execute
            vcot_policy="aggressive",  # cardiac findings favour V-CoT
        )
        self.vcot_module = vcot_module
        self.has_grounding = grounding_tool is not None

        # Per-instance memoization of tool results, keyed by image path (+ focus / organ set).
        self._segmentation_cache = dict()
        self._classification_cache = dict()
        # Keys of findings already emitted; cleared at the start of every analyze().
        self._findings_dedup = set()

        # Must run after has_grounding is set: the description depends on it.
        self._build_tools_description()
        self.subtask_templates = self._define_subtask_templates()
        self.planner = self._create_planner()
    
    def _build_tools_description(self):
        """Store a bullet list of the usable tools in ``self.tools_available_desc``.

        The grounding entry is listed only when a grounding tool was supplied,
        so the planner is never offered a tool it cannot use.
        """
        classifier_line = "- chest_xray_classifier: Detects pathologies (cardiomegaly, mediastinal widening, pericardial effusion)"
        segmentation_line = "- chest_xray_segmentation: Segments anatomical structures, provides measurements"
        grounding_lines = (
            ["- xray_phrase_grounding: Localizes specific phrases/objects in the image"]
            if self.has_grounding
            else []
        )
        self.tools_available_desc = "\n".join([classifier_line, segmentation_line] + grounding_lines)
    
    def _define_subtask_templates(self) -> str:
        """Define available subtask templates for cardiac analysis.

        Returns a static catalogue of suggested subtask ids, grouped by kind.
        ``create_plan`` passes it to the planner prompt as ``subtask_templates``.
        The catalogue is advisory: the text itself tells the LLM it may invent
        custom subtasks, and execution is routed by each subtask's ``tool`` field.

        Returns:
            The template text; it does not depend on instance state.
        """
        return """
Available Cardiac Subtask Templates:

1. CLASSIFICATION TASKS:
   - classify_cardiomegaly: Detect enlarged heart using classifier
   - classify_mediastinum: Check for mediastinal widening (per ABCDEF - C: Circulation)
   - classify_effusion: Detect pericardial effusion
   
2. MEASUREMENT TASKS:
   - segment_heart_lungs: Segment cardiac and lung structures
     * For CTR: only needs ["Heart", "Left Lung", "Right Lung"]
     * For mediastinum: needs ["Heart", "Mediastinum", "Aorta", "Left Lung", "Right Lung"]
     * For chamber assessment: needs ["Heart", "Aorta", "Left Lung", "Right Lung"]
   - calculate_ctr: Compute cardiothoracic ratio from segmentation (CTR > 0.5 = cardiomegaly)
   - calculate_mediastinal_ratio: Measure upper mediastinal width ratio (per ABCDEF standard)
   - measure_chambers: Assess individual chamber sizes and position (needs ["Heart", "Aorta"])
   - measure_aortic_knob: Calculate aortic prominence metrics
   
3. LOCALIZATION TASKS:
   - localize_heart: Find cardiac silhouette boundaries
   - detect_cardiac_devices: Locate cardiac pacemakers, valves, stents, or leads
     * Needs extended segmentation: ["Heart", "Left Lung", "Right Lung", "Mediastinum", "Left Clavicle", "Right Clavicle"]
   - find_calcifications: Identify cardiac or valvular calcifications
   - trace_cardiac_device_leads: Follow cardiac pacemaker/ICD lead trajectories
   - assess_hilar_vessels: Evaluate hilar vessel position, size, and clarity (per ABCDEF - C: Circulation)
   
4. SPECIALIZED TASKS:
   - assess_cardiac_position: Check for dextrocardia or malposition (calculates position metrics)
   - evaluate_post_surgical: Look for post-operative changes
   - check_support_devices: Identify VADs, ECMO cannulas

5. VERIFICATION TASKS:
   - visual_reasoning: Apply V-CoT for uncertain findings
   - cross_validate: Compare findings across different tools

Remember: 
- You can create custom subtasks beyond these templates if the query requires it
- The system will automatically optimize organ selection based on the task
- Quantitative measurements provide objective assessment (like CTR for cardiomegaly)"""
    
    def _create_planner(self):
        """Build the planning chain: a chat prompt piped into structured LLM output.

        The returned runnable is invoked with ``query``, ``subtask_templates`` and
        ``tools_available`` and yields a ``CardiacPlan``. The JSON example's
        braces are doubled so the prompt template treats them as literal text
        rather than input variables.
        """
        system_text = """You are a cardiac imaging specialist planning an analysis workflow.
Given a query about cardiac findings, create a detailed plan using available tools.

{subtask_templates}

Tools available:
{tools_available}

EFFICIENCY TIP: For CTR calculation, the segmentation tool can be optimized by specifying only the required organs:
- Use parameters: {{"organs": ["Heart", "Left Lung", "Right Lung"]}} instead of segmenting all 14 structures
- This is much faster and provides the same CTR result

MEDICAL BEST PRACTICES:
- For cardiomegaly assessment, ALWAYS include CTR calculation (cardiothoracic ratio > 0.5 indicates cardiomegaly)
- For mediastinal assessment, calculate mediastinal/thoracic ratio (> 0.25 may indicate widening) - per ABCDEF C: Circulation
- For chamber assessment, analyze heart position and aspect ratio
- For hilar vessels, assess position, size, and symmetry - per ABCDEF C: Circulation
- For aortic assessment, measure aortic prominence relative to heart
- Quantitative measurements are the gold standard - always prefer them over classification alone
- The workflow should typically be: classify → segment → calculate measurements

Create a plan that:
1. Addresses the specific query requirements
2. Uses appropriate tools in logical order
3. Includes dependencies between tasks
4. Is efficient (no redundant steps)
5. Can go beyond templates if needed for the query
6. Only uses tools that are actually available
7. Prioritizes quantitative measurements for objective assessment

IMPORTANT: Do not include tasks that require xray_phrase_grounding if it's not available.
For diagnostic queries (e.g., "Is there X?"), always include both classification AND relevant measurements."""
        human_text = "Query: {query}\n\nCreate a plan to analyze the cardiac findings."

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_text),
            ("human", human_text),
        ])
        structured_llm = self.llm.with_structured_output(CardiacPlan)
        return prompt | structured_llm
    
    def get_react_prompt_template(self) -> str:
        """Return an empty ReAct prompt template.

        This agent runs in plan-and-execute mode, so it has no ReAct prompt.
        """
        empty_template = ""
        return empty_template
    
    def create_plan(self, state: MultiAgentState) -> List[Dict[str, Any]]:
        """Ask the planner LLM for a cardiac analysis plan.

        In comparison mode the query is replaced by a single-image assessment
        query, so this agent does not attempt the comparison itself.

        Args:
            state: Shared multi-agent state. Reads ``query`` and
                ``comparison_mode``; writes ``cardiac_planning_reasoning`` when
                the LLM plan is used.

        Returns:
            A list of task dicts with ``task_id``, ``description``, ``tool``,
            ``parameters`` (None-valued fields dropped) and ``depends_on``.
            Returns ``_get_fallback_plan(query)`` instead when planning raises
            or when the LLM returns no subtasks. An empty plan would leave
            ``analyze`` with nothing to execute.
        """
        query = state.get("query", "")

        # If in comparison mode, simplify the query to a single-image analysis
        if state.get("comparison_mode"):
            lowered = query.lower()
            if "cardiomegaly" in lowered:
                query = "Assess if cardiomegaly is present and calculate CTR"
            elif "heart" in lowered:
                query = "Assess heart size and calculate CTR"
            else:
                query = "Perform cardiac assessment with CTR calculation"

        try:
            plan_response = self.planner.invoke({
                "query": query,
                "subtask_templates": self.subtask_templates,
                "tools_available": self.tools_available_desc
            })
            # Convert the structured plan to plain dicts for execution
            execution_plan = [
                {
                    "task_id": subtask.task_id,
                    "description": subtask.description,
                    "tool": subtask.tool,
                    "parameters": subtask.parameters.dict(exclude_none=True),
                    "depends_on": subtask.depends_on
                }
                for subtask in plan_response.subtasks
            ]
            reasoning = plan_response.reasoning
        except Exception as e:
            # Best-effort: any planner failure degrades to the deterministic plan
            print(f"Error in LLM planning, falling back to basic plan: {e}")
            return self._get_fallback_plan(query)

        if not execution_plan:
            print("LLM planner returned no subtasks, falling back to basic plan")
            return self._get_fallback_plan(query)

        # Store reasoning for transparency
        state["cardiac_planning_reasoning"] = reasoning
        return execution_plan
    
    def _get_fallback_plan(self, query: str) -> List[Dict[str, Any]]:
        """Build the fixed minimal plan used when LLM planning fails.

        The plan classifies cardiac pathology, segments heart and lungs, and
        derives the CTR from that segmentation. A heart-localization grounding
        task is appended only when the grounding tool is available. ``query`` is
        accepted but currently unused.
        """
        classify_step = {
            "task_id": "classify_cardiac",
            "description": "General cardiac pathology detection",
            "tool": "chest_xray_classifier",
            "parameters": {"focus": "cardiac"},
            "depends_on": []
        }
        segment_step = {
            "task_id": "segment_heart_lungs",
            "description": "Segment heart and lungs for CTR measurement",
            "tool": "chest_xray_segmentation",
            "parameters": {"organs": ["Heart", "Left Lung", "Right Lung"]},
            "depends_on": []
        }
        ctr_step = {
            "task_id": "calculate_ctr",
            "description": "Calculate cardiothoracic ratio",
            "tool": "chest_xray_segmentation",
            "parameters": {},
            "depends_on": ["segment_heart_lungs"]
        }

        # Grounding is only usable when the tool was supplied at construction
        grounding_steps = [
            {
                "task_id": "localize_heart",
                "description": "Find cardiac silhouette",
                "tool": "xray_phrase_grounding",
                "parameters": {"phrase": "heart cardiac silhouette"},
                "depends_on": []
            }
        ] if self.has_grounding else []

        return [classify_step, segment_step, ctr_step] + grounding_steps
    
    def analyze(self, state: MultiAgentState) -> MultiAgentState:
        """Create and execute a cardiac plan, then record the outcome in ``state``.

        Tasks execute in plan order. A task whose ``depends_on`` ids have no
        recorded result yet is deferred and retried after the remaining tasks
        have run, so a plan not listed in dependency order still completes.
        Tasks whose dependencies never produce a result are skipped. A
        dependency counts as met once it has a result, successful or not.

        Args:
            state: Shared multi-agent state passed to planning and executors.

        Returns:
            The same ``state`` object, with ``cardiac_analysis`` set to an
            ``AgentAnalysis`` holding findings, the executed plan and review flags.
        """
        plan = self.create_plan(state)
        print(f"CardiacAgent: Executing plan with {len(plan)} tasks")
        for t in plan:
            print(f"  - {t['task_id']}: {t['description']} (tool: {t['tool']})")

        findings = []
        executed_plan = []
        measurements = {}
        task_results = {}

        # Reset deduplication set for this analysis
        self._findings_dedup.clear()

        # Execute with dependency handling: repeat passes until no task becomes runnable
        pending = list(plan)
        while pending:
            deferred = []
            for task in pending:
                if all(dep in task_results for dep in task.get("depends_on", [])):
                    self._run_planned_task(state, task, task_results, executed_plan, findings, measurements)
                else:
                    deferred.append(task)
            if len(deferred) == len(pending):
                # A full pass made no progress, so these dependencies can never be met
                for task in deferred:
                    print(f"Skipping {task['task_id']} - dependencies not met")
                break
            pending = deferred

        print(f"\nCardiacAgent: Generated {len(findings)} findings")
        for f in findings:
            print(f"  - {self._describe_finding(f)}")

        try:
            confidence_level = self._assess_confidence(findings)
            needs_review = self._needs_review(findings, measurements)
            visual_cot_triggered = any(f.get("visual_cot") is not None for f in findings)
        except Exception as e:
            print(f"Warning: Error in confidence assessment: {e}")
            # Conservative fallback values if assessment fails
            confidence_level = "medium"
            needs_review = True
            visual_cot_triggered = False

        state["cardiac_analysis"] = AgentAnalysis(
            agent_name="CardiacAgent",
            findings=findings,
            plan_executed=executed_plan,
            react_steps=[],  # Empty for P&E mode
            visual_cot_triggered=visual_cot_triggered,
            confidence_level=confidence_level,
            needs_human_review=needs_review
        )

        return state

    def _run_planned_task(
        self,
        state: MultiAgentState,
        task: Dict[str, Any],
        task_results: Dict[str, Dict],
        executed_plan: List[Dict[str, Any]],
        findings: List[Finding],
        measurements: Dict,
    ) -> None:
        """Dispatch one task to its executor, record the result and fold successes into findings."""
        task_id = task["task_id"]
        print(f"\nExecuting task: {task_id}")

        tool = task["tool"]
        if tool == "chest_xray_classifier":
            result = self._execute_classification_task(state, task, task_results)
        elif tool == "chest_xray_segmentation":
            result = self._execute_segmentation_task(state, task, task_results)
        elif tool == "xray_phrase_grounding":
            result = self._execute_grounding_task(state, task, task_results)
        elif tool == "visual_reasoning":
            result = self._execute_vcot_task(state, task, task_results, findings)
        else:
            result = {"success": False, "error": f"Unknown tool: {tool}"}

        print(f"Task {task_id} result: success={result.get('success', False)}")

        task_results[task_id] = result
        executed_plan.append({
            "task_id": task_id,
            "description": task["description"],
            "result": result
        })

        # Only successful tasks contribute findings
        if result.get("success", False):
            self._process_task_results(task, result, findings, measurements)
        else:
            print(f"  Skipping finding processing for failed task: {result.get('error', 'Unknown error')}")

    @staticmethod
    def _describe_finding(finding: Any) -> str:
        """Format a finding for the progress log without assuming numeric confidence.

        A non-numeric confidence is printed as-is instead of raising, so a log
        line can never abort the analysis.
        """
        if not isinstance(finding, dict):
            return str(finding)
        confidence = finding.get("confidence")
        try:
            confidence_text = f"{float(confidence):.2f}"
        except (TypeError, ValueError):
            confidence_text = str(confidence)
        return f"{finding.get('pathology')}: {confidence_text} - {finding.get('evidence')}"

    def _execute_classification_task(self, state: MultiAgentState, task: Dict, task_results: Dict) -> Dict:
        """Run the pathology classifier for ``task``, memoized per image and focus.

        The task's parameter dict is completed in place with the state's
        ``image_path`` when missing. Numeric scores, including objects exposing
        ``.item()`` such as numpy scalars, are converted to plain floats.

        Returns:
            ``{"success": True, "classification": ..., "metadata": ...}``.
            On a cache hit this is the cached dict itself. On any exception the
            result is ``{"success": False, "error": <message>}``.
        """
        try:
            params = task.get("parameters", {})
            if "image_path" not in params:
                params["image_path"] = state["image_path"]

            focus = params.get('focus', 'default')
            image = params.get('image_path', state.get('image_path', ''))
            key = f"classification_{image}_{focus}"
            if key in self._classification_cache:
                print(f"  Using cached classification for {focus}")
                return self._classification_cache[key]

            print(f"  Calling classifier with params: {params}")
            raw = self.tools["chest_xray_classifier"].invoke(params)
            # The tool may return (classification, metadata) or the bare classification
            classification, metadata = raw if isinstance(raw, tuple) else (raw, {})

            if isinstance(classification, dict):
                def to_plain(value):
                    # numpy scalars expose .item(); plain ints/floats are widened to float
                    if hasattr(value, 'item'):
                        return float(value.item())
                    return float(value) if isinstance(value, (int, float)) else value

                classification = {name: to_plain(score) for name, score in classification.items()}

            print(f"  Classification results: {classification}")
            outcome = {"success": True, "classification": classification, "metadata": metadata}
            self._classification_cache[key] = outcome
            return outcome
        except Exception as e:
            print(f"Error in {task['task_id']}: {e}")
            import traceback
            traceback.print_exc()
            return {"success": False, "error": str(e)}
    
    def _execute_segmentation_task(self, state: MultiAgentState, task: Dict, task_results: Dict) -> Dict:
        """Execute a segmentation-based task, reusing cached or earlier results.

        Steps:
          1. Complete the task's parameter dict in place: ``image_path`` from the
             state and, when absent, a task-optimized ``organs`` list.
          2. If this image/organ set was already segmented, answer from
             ``self._segmentation_cache``.
          3. A ``calculate_ctr`` task with an earlier ``segment_heart_lungs``
             result reads the CTR from that result instead of segmenting again.
          4. Otherwise call the segmentation tool, derive measurements and cache.

        Returns:
            A result dict with ``success``. A full segmentation carries
            ``segmentation``, ``metadata`` and ``measurements``. A CTR shortcut
            carries ``ctr`` as a float. Failures carry ``error``.
        """
        try:
            task_id = task["task_id"]
            parameters = task.get("parameters", {})
            if "image_path" not in parameters:
                parameters["image_path"] = state["image_path"]

            # Dynamically optimize organs based on task
            if "organs" not in parameters:
                parameters["organs"] = self._get_optimal_organs_for_task(task_id, task["description"])
                print(f"  Optimizing segmentation for {task_id} - only segmenting: {parameters['organs']}")

            # Cache key: image path + order-independent organ set
            img_path = parameters.get('image_path', state.get('image_path', ''))
            cache_key = f"segmentation_{img_path}_{','.join(sorted(parameters['organs']))}"

            if cache_key in self._segmentation_cache:
                print(f"  Using cached segmentation for {task_id}")
                return self._serve_cached_segmentation(cache_key, task_id)

            if task_id == "calculate_ctr" and "segment_heart_lungs" in task_results:
                return self._ctr_from_previous_segmentation(task_results["segment_heart_lungs"])

            return self._run_segmentation_tool(parameters, task_id, cache_key)
        except Exception as e:
            print(f"Error in {task['task_id']}: {e}")
            import traceback
            traceback.print_exc()
            return {"success": False, "error": str(e)}

    @staticmethod
    def _coerce_ctr_value(ctr_data: Any) -> Optional[float]:
        """Return the CTR as a float from ``{"value": x, ...}`` or a bare number, else None."""
        if isinstance(ctr_data, dict):
            ctr_data = ctr_data.get("value")
        if ctr_data is None:
            return None
        try:
            return float(ctr_data)
        except (TypeError, ValueError):
            return None

    def _serve_cached_segmentation(self, cache_key: str, task_id: str) -> Dict:
        """Answer a task from a cached segmentation, adding any task-specific measurements it lacks."""
        cached_result = self._segmentation_cache[cache_key]
        segmentation = cached_result.get("segmentation")
        metrics = segmentation.get("metrics", {}) if isinstance(segmentation, dict) else {}

        # For CTR calculation, extract just the CTR value
        if "calculate_ctr" in task_id:
            ctr = self._coerce_ctr_value(metrics.get("CardiothoracicRatio"))
            if ctr is not None:
                return {"success": True, "ctr": ctr}

        # The cache is keyed by organ set, but mediastinal/chamber measurements are
        # chosen by task id. An entry created by a differently named task may
        # therefore lack them; compute the missing ones instead of returning none.
        cached_measurements = cached_result.get("measurements", {})
        missing = {
            name: value
            for name, value in self._task_specific_segmentation_measurements(task_id, metrics).items()
            if name not in cached_measurements
        }
        if not missing:
            return cached_result

        merged = {**cached_measurements, **missing}
        enriched = {**cached_result, "measurements": merged}
        if isinstance(segmentation, dict):
            enriched["segmentation"] = {**segmentation, "measurements": merged}
        self._segmentation_cache[cache_key] = enriched
        return enriched

    def _ctr_from_previous_segmentation(self, seg_result: Dict) -> Dict:
        """Derive the CTR from the result of an earlier ``segment_heart_lungs`` task."""
        print("  CTR calculation using previous segmentation")

        measurements = seg_result.get("measurements") if seg_result.get("success") else None
        if isinstance(measurements, dict) and "ctr" in measurements:
            ctr_data = measurements["ctr"]
            ctr_value = self._coerce_ctr_value(ctr_data)
            if ctr_value is not None:
                print(f"  CTR value: {ctr_value}")
                # Measurements may be a rich dict or a bare float
                details = ctr_data if isinstance(ctr_data, dict) else {}
                return {
                    "success": True,
                    "ctr": ctr_value,
                    "interpretation": details.get("interpretation", "enlarged" if ctr_value > 0.5 else "normal"),
                    "confidence": details.get("confidence", 0.9)
                }

        # Fallback to old metrics structure
        segmentation = seg_result.get("segmentation")
        if isinstance(segmentation, dict) and "metrics" in segmentation:
            print(f"  Metrics keys: {list(segmentation['metrics'].keys())}")
            ctr_data = segmentation["metrics"].get("CardiothoracicRatio")
            ctr = self._coerce_ctr_value(ctr_data)
            if ctr is not None:
                print(f"  CTR value: {ctr}")
                return {"success": True, "ctr": ctr, "segmentation": segmentation}
            # Never report success with a missing CTR; downstream code expects a number
            print(f"  CTR data not found or invalid: {ctr_data}")
        return {"success": False, "error": "CTR calculation failed"}

    def _run_segmentation_tool(self, parameters: Dict, task_id: str, cache_key: str) -> Dict:
        """Invoke the segmentation tool, derive measurements and cache the result."""
        print(f"  Calling segmentation with params: {parameters}")
        result = self.tools["chest_xray_segmentation"].invoke(parameters)

        # Handle tuple return from segmentation tool
        if isinstance(result, tuple):
            segmentation_data, metadata = result
            print("  Segmentation returned tuple, extracting first element")
        else:
            segmentation_data, metadata = result, {}

        print(f"  Segmentation result keys: {list(segmentation_data.keys()) if isinstance(segmentation_data, dict) else type(segmentation_data)}")

        metrics = segmentation_data.get('metrics', {}) if isinstance(segmentation_data, dict) else {}
        print(f"  Metrics keys: {list(metrics.keys())}")

        measurements = {}

        # CTR, stored as a plain float for consistency; accepts dict or bare number
        ctr_value = self._coerce_ctr_value(metrics.get("CardiothoracicRatio"))
        if ctr_value is not None:
            measurements["ctr"] = ctr_value
            print(f"  CTR value: {ctr_value}")

        measurements.update(self._task_specific_segmentation_measurements(task_id, metrics))

        # Aortic measurements
        if "Aorta" in metrics:
            aortic_metrics = self._calculate_aortic_prominence(metrics)
            if aortic_metrics:
                measurements["aortic_prominence"] = aortic_metrics
                print(f"  Aortic width: {aortic_metrics['width_pixels']} pixels")

        segmentation_data['measurements'] = measurements
        segmentation_data['metrics'] = metrics

        result_dict = {"success": True, "segmentation": segmentation_data, "metadata": metadata, "measurements": measurements}
        self._segmentation_cache[cache_key] = result_dict
        return result_dict

    def _task_specific_segmentation_measurements(self, task_id: str, metrics: Dict) -> Dict:
        """Compute the measurements that only particular task ids request."""
        measurements = {}

        # Mediastinal ratio calculation (per ABCDEF C: Circulation)
        if "mediastin" in task_id.lower():
            med_ratio = self._calculate_mediastinal_ratio(metrics)
            if med_ratio is not None:
                measurements["mediastinal_ratio"] = {
                    "value": med_ratio,
                    "interpretation": "widened" if med_ratio > 0.25 else "normal",
                    "confidence": 0.85
                }
                print(f"  Mediastinal ratio: {med_ratio:.3f}")

        # Chamber position assessment
        if task_id in ("assess_cardiac_position", "measure_chambers"):
            chamber_assessment = self._assess_chamber_position(metrics)
            if chamber_assessment:
                measurements["chamber_assessment"] = chamber_assessment
                print(f"  Chamber position: {chamber_assessment.get('position', 'unknown')}")

        return measurements
    
    def _execute_grounding_task(self, state: MultiAgentState, task: Dict, task_results: Dict) -> Dict:
        """Localize one or more phrases with the grounding tool.

        When ``parameters["phrases"]`` is present, each phrase is grounded
        separately. The result is
        ``{"success": True, "grounding_results": {phrase: {"bbox", "confidence"}}}``,
        and phrases with no box are omitted. Otherwise the whole parameter dict
        is passed to the tool and the result is
        ``{"success": True, "bbox": ..., "confidence": ...}``, with ``bbox``
        None and confidence 0.5 when nothing was found.

        Returns ``{"success": False, "error": ...}`` when the tool is
        unavailable or raises.
        """
        if not self.has_grounding:
            print(f"Skipping {task['task_id']}: grounding tool not available")
            return {"success": False, "error": "Grounding tool not available"}

        try:
            parameters = task.get("parameters", {})
            if "image_path" not in parameters:
                parameters["image_path"] = state["image_path"]

            grounder = self.tools["xray_phrase_grounding"]

            # Handle multiple phrases (for device detection)
            if "phrases" in parameters:
                all_results = {}
                for phrase in parameters["phrases"]:
                    raw = grounder.invoke({
                        "image_path": parameters["image_path"],
                        "phrase": phrase
                    })
                    top = self._first_grounded_box(raw)
                    if top is not None:
                        bbox, confidence = top
                        all_results[phrase] = {"bbox": bbox, "confidence": confidence}
                return {"success": True, "grounding_results": all_results}

            # Single phrase grounding
            top = self._first_grounded_box(grounder.invoke(parameters))
            bbox, confidence = top if top is not None else (None, 0.5)
            return {"success": True, "bbox": bbox, "confidence": confidence}
        except Exception as e:
            print(f"Error in {task['task_id']}: {e}")
            return {"success": False, "error": str(e)}

    @staticmethod
    def _first_grounded_box(result: Any) -> Optional[tuple]:
        """Return ``(bbox, confidence)`` for the first prediction's first box, or None.

        Defensively unwraps a ``(payload, metadata)`` tuple, the shape the
        classifier and segmentation executors in this class already handle.
        Previously such a tuple was silently treated as "no predictions".
        """
        if isinstance(result, tuple):
            result = result[0] if result else None
        if not result or "predictions" not in result:
            return None

        predictions = result["predictions"]
        if not predictions:
            return None

        first = predictions[0]
        bboxes = first.get("bounding_boxes", {})
        if "image_coordinates" not in bboxes or len(bboxes["image_coordinates"]) == 0:
            return None
        return tuple(bboxes["image_coordinates"][0]), first.get("confidence", 0.5)
    
    def _execute_vcot_task(self, state: MultiAgentState, task: Dict, task_results: Dict, findings: List[Finding]) -> Dict:
        """Run visual chain-of-thought reasoning over the evidence gathered so far.

        Evidence is assembled as follows:
          - cardiomegaly findings, matched case-insensitively;
          - the CTR, taken from the ``calculate_ctr`` result or, failing that,
            from any successful result whose ``measurements`` contain ``ctr``;
          - the heart bounding box from ``localize_heart``, when present.

        Args:
            state: Shared state; ``image_path`` is passed to the V-CoT module.
            task: The plan task; ``parameters["focus"]`` selects the target.
            task_results: Results of previously executed tasks, by task id.
            findings: Findings accumulated so far in this analysis.

        Returns:
            ``{"success": True, "visual_cot": <text>}``, or
            ``{"success": False, "error": ...}`` when no module is configured or
            generation fails.
        """
        try:
            if not self.vcot_module:
                return {"success": False, "error": "V-CoT module not available"}

            def as_ctr(value):
                # Accept {"value": x, ...} or a bare number; None if unusable
                if isinstance(value, dict):
                    value = value.get("value")
                if value is None:
                    return None
                try:
                    return float(value)
                except (TypeError, ValueError):
                    return None

            measurements = {}
            bbox = None

            # Prefer the dedicated CTR task; previously a None CTR raised here
            ctr = None
            ctr_result = task_results.get("calculate_ctr", {})
            if ctr_result.get("success") and "ctr" in ctr_result:
                ctr = as_ctr(ctr_result["ctr"])
            if ctr is None:
                # The CTR may exist only as a segmentation measurement
                for other in task_results.values():
                    if not isinstance(other, dict) or not other.get("success"):
                        continue
                    other_measurements = other.get("measurements")
                    if isinstance(other_measurements, dict) and "ctr" in other_measurements:
                        ctr = as_ctr(other_measurements["ctr"])
                        if ctr is not None:
                            break
            if ctr is not None:
                measurements["ctr"] = ctr

            # Extract heart bbox if available
            heart_result = task_results.get("localize_heart", {})
            if heart_result.get("success") and heart_result.get("bbox"):
                bbox = heart_result["bbox"]

            focus = task.get("parameters", {}).get("focus", "cardiac findings")

            from ..reasoning.visual_cot import VisualEvidence, VisualEvidenceType

            evidence_list = []

            # Classification evidence; pathology names may be capitalized ("Cardiomegaly")
            for finding in findings:
                if isinstance(finding, dict) and str(finding.get("pathology", "")).lower() == "cardiomegaly":
                    evidence_list.append(VisualEvidence(
                        tool_name="chest_xray_classifier",
                        evidence_type=VisualEvidenceType.CLASSIFICATION,
                        confidence=self._extract_confidence_value(finding.get("confidence", 0.5)),
                        description="Cardiomegaly detected",
                        data={"pathology": "cardiomegaly"}
                    ))

            # Measurement evidence; lower confidence near the 0.5 decision threshold
            if "ctr" in measurements:
                ctr_value = measurements["ctr"]
                evidence_list.append(VisualEvidence(
                    tool_name="chest_xray_segmentation",
                    evidence_type=VisualEvidenceType.MEASUREMENT,
                    confidence=0.9 if abs(ctr_value - 0.5) > 0.1 else 0.6,
                    description=f"CTR = {ctr_value:.3f}",
                    data={"ctr": ctr_value, "enlarged": ctr_value > 0.5}
                ))

            target = "mediastinum" if "mediastin" in focus.lower() else "heart and cardiac silhouette"

            vcot_result = self.vcot_module.generate(
                image_path=state["image_path"],
                task=f"Visual assessment: {focus}",
                target=target,
                evidence_list=evidence_list,
                roi_bbox=bbox,
                measurements=measurements
            )

            # Result objects expose .visual_reasoning; anything else is stringified
            visual_cot = vcot_result.visual_reasoning if hasattr(vcot_result, 'visual_reasoning') else str(vcot_result)

            return {"success": True, "visual_cot": visual_cot}
        except Exception as e:
            print(f"Error in V-CoT: {e}")
            return {"success": False, "error": str(e)}
    
    def _process_task_results(self, task: Dict, result: Dict, findings: List[Finding], measurements: Dict):
        """Process task results into findings and measurements with deduplication"""
        if not result.get("success"):
            return
        
        task_id = task["task_id"]
        
        # Process classification results
        if "classification" in result:
            classification = result["classification"]
            cardiac_pathologies = ["Cardiomegaly", "Enlarged Cardiomediastinum", "Pericardial Effusion"]
            
            for pathology in cardiac_pathologies:
                if pathology in classification:
                    conf = classification[pathology]
                    if conf > 0.3:
                        # Create unique key for deduplication
                        finding_key = f"{pathology.lower().replace(' ', '_')}_{task_id}"
                        if finding_key not in self._findings_dedup:
                            self._findings_dedup.add(finding_key)
                        findings.append({
                            "pathology": pathology.lower().replace(" ", "_"),
                            "confidence": float(conf),  # Ensure float
                            "location": None,
                            "measurements": None,
                            "visual_cot": None,
                            "evidence": f"{pathology} detected by {task_id}"
                        })
        
        # Process CTR measurement
        elif "ctr" in result:
            ctr_data = result["ctr"]
            # Handle both dict and float formats for CTR
            if isinstance(ctr_data, dict):
                ctr_value = ctr_data.get("value", ctr_data)
                ctr_confidence = float(ctr_data.get("confidence", 0.9 if abs(ctr_value - 0.5) > 0.1 else 0.6))
            else:
                ctr_value = float(ctr_data)
                ctr_confidence = float(0.9 if abs(ctr_value - 0.5) > 0.1 else 0.6)
            
            measurements["ctr"] = ctr_value  # Store as float for consistency
            is_enlarged = ctr_value > 0.5
            
            findings.append({
                "pathology": "ctr_measurement",
                "confidence": ctr_confidence,  # Now guaranteed to be float
                "location": None,
                "measurements": {"ctr": ctr_value},
                "visual_cot": None,
                "evidence": f"CTR = {ctr_value:.3f} {'(enlarged)' if is_enlarged else '(normal)'}"
            })
        
        # Process other measurements
        elif "measurements" in result:
            task_measurements = result["measurements"]
            
            # Mediastinal ratio (per ABCDEF C: Circulation)
            if "mediastinal_ratio" in task_measurements:
                med_data = task_measurements["mediastinal_ratio"]
                # Ensure confidence is a float, not a dict
                med_confidence = med_data.get("confidence", 0.85) if isinstance(med_data, dict) else 0.85
                if isinstance(med_confidence, dict):
                    med_confidence = med_confidence.get("value", 0.85)
                findings.append({
                    "pathology": "mediastinal_assessment",
                    "confidence": float(med_confidence),  # Ensure it's a float
                    "location": None,
                    "measurements": {"mediastinal_ratio": med_data["value"]},
                    "visual_cot": None,
                    "evidence": f"Mediastinal ratio = {med_data['value']:.3f} ({med_data['interpretation']}) - per ABCDEF C: Circulation"
                })
            
            
            # Chamber position
            if "chamber_assessment" in task_measurements:
                chamber_data = task_measurements["chamber_assessment"]
                if chamber_data.get("position") != "normal":
                    findings.append({
                        "pathology": f"cardiac_position_{chamber_data['position']}",
                        "confidence": float(0.85),  # Ensure float
                        "location": None,
                        "measurements": chamber_data,
                        "visual_cot": None,
                        "evidence": f"Cardiac position: {chamber_data['position']}"
                    })
            
            # Aortic prominence
            if "aortic_prominence" in task_measurements:
                aortic_data = task_measurements["aortic_prominence"]
                if aortic_data.get("aorta_heart_width_ratio", 0) > 0.5:  # Threshold TBD
                    findings.append({
                        "pathology": "aortic_prominence",
                        "confidence": float(0.75),  # Ensure float
                        "location": None,
                        "measurements": aortic_data,
                        "visual_cot": None,
                        "evidence": f"Aortic prominence: width ratio = {aortic_data.get('aorta_heart_width_ratio', 0):.3f}"
                    })
            
            # Update global measurements
            measurements.update(task_measurements)
        
        # Process grounding results
        elif "grounding_results" in result:
            # Multiple devices/objects found
            for phrase, grounding_data in result["grounding_results"].items():
                if grounding_data and grounding_data.get("confidence", 0) > 0.3:
                    # Ensure confidence is a float
                    grounding_confidence = grounding_data.get("confidence", 0.5)
                    if isinstance(grounding_confidence, dict):
                        grounding_confidence = grounding_confidence.get("value", 0.5)
                    findings.append({
                        "pathology": f"cardiac_finding_{phrase.replace(' ', '_')}",
                        "confidence": float(grounding_confidence),  # Ensure float
                        "location": grounding_data.get("bbox"),
                        "measurements": None,
                        "visual_cot": None,
                        "evidence": f"Detected: {phrase}"
                    })
        
        # Process single bbox result
        elif "bbox" in result and result["bbox"]:
            # Store bbox for later use (e.g., heart location)
            if task_id == "localize_heart":
                measurements["heart_bbox"] = result["bbox"]
        
        # Process V-CoT results
        elif "visual_cot" in result:
            # Update existing findings with V-CoT
            vcot_text = result["visual_cot"]
            for finding in findings:
                if finding["pathology"] in ["cardiomegaly", "ctr_measurement", "enlarged_cardiomediastinum"]:
                    finding["visual_cot"] = vcot_text
    
    def _should_trigger_vcot_for_findings(self, findings: List[Finding], result: Dict) -> bool:
        """Determine if V-CoT should be triggered based on task results"""
        # Aggressive triggering for cardiac findings
        has_cardiomegaly = any(
            f.get("pathology") == "cardiomegaly" and self._extract_confidence_value(f.get("confidence", 0)) > 0.2 
            for f in findings if isinstance(f, dict)
        )
        has_borderline_ctr = result.get("ctr", 0) > 0.45
        has_low_conf_findings = any(
            0.3 < self._extract_confidence_value(f.get("confidence", 0)) < 0.7 
            for f in findings if isinstance(f, dict)
        )
        
        return has_cardiomegaly or has_borderline_ctr or has_low_conf_findings
    
    def _execute_vcot(self, state: MultiAgentState, result: Dict, findings: List[Finding]) -> Optional[str]:
        """Run a visual chain-of-thought pass focused on the heart.

        Builds classification evidence from cardiomegaly findings and, when the
        result carries a CTR, measurement evidence. It then calls
        ``self.vcot_module.generate`` with the heart bbox as the region of
        interest.

        Args:
            state: Multi-agent state; ``state["image_path"]`` is read.
            result: Task result. ``"heart_bbox"`` and ``"ctr"`` are read if
                present; ``"ctr"`` may be a number or a ``{"value": x}`` dict.
            findings: Findings accumulated so far.

        Returns:
            The reasoning text (``visual_reasoning`` attribute, else ``str()`` of
            the result), or None if any step raised.
        """
        try:
            # Import required classes for new interface
            from ..reasoning.visual_cot import VisualEvidence, VisualEvidenceType

            # Use heart bbox if available
            heart_bbox = result.get("heart_bbox")

            # Classification evidence from cardiomegaly findings
            evidence_list = []
            for finding in findings:
                if isinstance(finding, dict) and finding.get("pathology") == "cardiomegaly":
                    evidence_list.append(VisualEvidence(
                        tool_name="chest_xray_classifier",
                        evidence_type=VisualEvidenceType.CLASSIFICATION,
                        confidence=self._extract_confidence_value(finding.get("confidence", 0.5)),
                        description="Cardiomegaly detected",
                        data={"pathology": "cardiomegaly"}
                    ))

            # Normalise the CTR once. This way the evidence and the measurements
            # handed to the V-CoT module carry the same float, and a dict
            # without "value" no longer aborts the whole pass.
            raw_ctr = result.get("ctr")
            if isinstance(raw_ctr, dict):
                raw_ctr = raw_ctr.get("value")
            ctr_value = float(raw_ctr) if raw_ctr is not None else None

            measurements: Dict[str, float] = {}
            if ctr_value is not None:
                measurements["ctr"] = ctr_value
            if ctr_value:
                evidence_list.append(VisualEvidence(
                    tool_name="chest_xray_segmentation",
                    evidence_type=VisualEvidenceType.MEASUREMENT,
                    # Higher confidence the further the ratio is from the 0.5 cutoff.
                    confidence=0.9 if abs(ctr_value - 0.5) > 0.1 else 0.6,
                    description=f"CTR = {ctr_value:.3f}",
                    data={"ctr": ctr_value, "enlarged": ctr_value > 0.5}
                ))

            vcot_result = self.vcot_module.generate(
                image_path=state["image_path"],
                task="Assess cardiac size and CTR measurement accuracy",
                target="heart and cardiac silhouette",
                evidence_list=evidence_list,
                roi_bbox=heart_bbox,
                measurements=measurements
            )

            # Extract text from result
            visual_cot = vcot_result.visual_reasoning if hasattr(vcot_result, 'visual_reasoning') else str(vcot_result)
            return visual_cot
        except Exception as e:
            print(f"Error in V-CoT: {e}")
            return None
    
    def _needs_review(self, findings: List[Finding], measurements: Dict) -> bool:
        """Determine if human review is needed"""
        # Borderline CTR
        ctr = measurements.get("ctr", 0)
        if ctr and 0.48 < ctr < 0.52:
            return True
        
        # Low confidence findings - safely extract confidence values
        for f in findings:
            if isinstance(f, dict):
                conf = f.get("confidence", 0)
                # Handle case where confidence might be stored as dict
                if isinstance(conf, dict):
                    conf = conf.get("value", 0) if "value" in conf else 0
                # Ensure conf is numeric before comparison
                try:
                    conf = float(conf)
                    if 0.3 < conf < 0.7:
                        return True
                except (TypeError, ValueError):
                    pass  # Skip if conversion fails
        
        # Conflicting findings
        has_normal_ctr = ctr and ctr <= 0.5
        has_cardiomegaly_finding = any(
            f.get("pathology") == "cardiomegaly" and self._extract_confidence_value(f.get("confidence", 0)) > 0.5 
            for f in findings if isinstance(f, dict)
        )
        if has_normal_ctr and has_cardiomegaly_finding:
            return True
        
        return False
    
    def _assess_confidence(self, findings: List[Finding]) -> str:
        """Assess overall confidence level based on findings"""
        if not findings:
            return "low"
        
        # Calculate average confidence (handle dict and Finding TypedDict properly)
        confidences = []
        for f in findings:
            if isinstance(f, dict) and 'confidence' in f:
                conf = f.get('confidence', 0.5)
                confidences.append(self._extract_confidence_value(conf))
        
        if not confidences:
            return "low"
        
        avg_confidence = sum(confidences) / len(confidences)
        
        # Check for V-CoT usage
        has_vcot = any(
            isinstance(f, dict) and f.get("visual_cot") is not None 
            for f in findings
        )
        
        # Determine confidence level
        if avg_confidence >= 0.8 and not has_vcot:
            return "high"
        elif avg_confidence >= 0.6:
            return "medium"
        else:
            return "low"

    def _get_optimal_organs_for_task(self, task_id: str, task_description: str) -> List[str]:
        """Determine optimal organs to segment based on the task"""
        
        # Define organ sets for different cardiac tasks
        organ_sets = {
            # CTR calculation
            "ctr": ["Heart", "Left Lung", "Right Lung"],
            
            # Mediastinal assessment (per ABCDEF C: Circulation)
            "mediastinum": ["Heart", "Mediastinum", "Aorta", "Left Lung", "Right Lung"],
            
            # Pericardial effusion
            "effusion": ["Heart", "Mediastinum", "Left Lung", "Right Lung"],
            
            # Chamber assessment
            "chambers": ["Heart", "Aorta"],
            
            # Full cardiac assessment
            "comprehensive": ["Heart", "Left Lung", "Right Lung", "Mediastinum", "Aorta"],
            
            # Cardiac device assessment
            "devices": ["Heart", "Left Lung", "Right Lung", "Mediastinum", "Left Clavicle", "Right Clavicle"]
        }
        
        # Determine which organ set to use based on task
        task_lower = task_id.lower() + " " + task_description.lower()
        
        if "ctr" in task_lower or "cardiothoracic" in task_lower or "cardiomegaly" in task_lower:
            return organ_sets["ctr"]
        elif "mediastin" in task_lower or "widening" in task_lower:
            return organ_sets["mediastinum"]
        elif "effusion" in task_lower or "pericardial" in task_lower:
            return organ_sets["effusion"]
        elif "chamber" in task_lower or "atrium" in task_lower or "ventricle" in task_lower:
            return organ_sets["chambers"]
        elif "device" in task_lower or "pacemaker" in task_lower or "lead" in task_lower:
            return organ_sets["devices"]
        elif "comprehensive" in task_lower or "full" in task_lower:
            return organ_sets["comprehensive"]
        else:
            # Default to CTR organs for general cardiac tasks
            return organ_sets["ctr"]

    def _calculate_mediastinal_ratio(self, metrics: Dict[str, Any]) -> Optional[float]:
        """Calculate mediastinal/thoracic ratio from segmentation metrics (per ABCDEF C: Circulation)"""
        if "Mediastinum" in metrics and "Left Lung" in metrics and "Right Lung" in metrics:
            # Get mediastinal width
            mediastinal_width = metrics["Mediastinum"]["width"]
            
            # Get thoracic width from lung boundaries
            left_lung_bbox = metrics["Left Lung"]["bbox"]
            right_lung_bbox = metrics["Right Lung"]["bbox"]
            thoracic_width = right_lung_bbox[3] - left_lung_bbox[1]  # rightmost - leftmost
            
            if thoracic_width > 0:
                return mediastinal_width / thoracic_width
        return None
    
    
    def _extract_confidence_value(self, conf: Any) -> float:
        """Safely extract confidence value from various formats"""
        if isinstance(conf, (int, float)):
            return float(conf)
        elif isinstance(conf, dict):
            return float(conf.get("value", 0))
        else:
            return 0.0
    
    def _assess_chamber_position(self, metrics: Dict[str, Any]) -> Dict[str, Any]:
        """Assess cardiac chamber position and size indicators"""
        assessment = {}
        
        if "Heart" in metrics:
            heart_centroid = metrics["Heart"]["centroid"]
            heart_bbox = metrics["Heart"]["bbox"]
            
            # Check if heart is displaced (dextrocardia, etc.)
            relative_pos = metrics["Heart"]["relative_position"]
            if relative_pos["left"] > 0.6:  # Heart centroid too far right
                assessment["position"] = "rightward_displaced"
            elif relative_pos["left"] < 0.3:  # Too far left
                assessment["position"] = "leftward_displaced"
            else:
                assessment["position"] = "normal"
            
            # Assess heart height/width ratio for chamber enlargement
            assessment["height_width_ratio"] = metrics["Heart"]["aspect_ratio"]
            
        return assessment
    
    def _calculate_aortic_prominence(self, metrics: Dict[str, Any]) -> Optional[Dict[str, float]]:
        """Calculate aortic knob prominence metrics"""
        if "Aorta" in metrics:
            aorta_width = metrics["Aorta"]["width"]
            aorta_area = metrics["Aorta"]["area_pixels"]
            
            result = {
                "width_pixels": aorta_width,
                "area_pixels": aorta_area,
            }
            
            # If we have heart metrics, calculate relative size
            if "Heart" in metrics:
                heart_width = metrics["Heart"]["width"]
                result["aorta_heart_width_ratio"] = aorta_width / heart_width if heart_width > 0 else None
                
            return result
        return None 