"""
Chest X-ray measurement tools for calculating medical metrics from segmentation results.

This module keeps measurement calculations separate from segmentation, which provides:
1. Cleaner separation of concerns
2. Better maintainability
3. Easier unit testing
4. Reusable measurement logic across agents
"""

from typing import Dict, Any, Optional, List, Tuple
from pydantic import BaseModel, Field
from langchain.tools import BaseTool
import numpy as np


class MeasurementInput(BaseModel):
    """Arguments accepted by the chest X-ray measurement tool."""

    # Per-structure metrics dict produced by the segmentation tool (required).
    segmentation_metrics: Dict[str, Any] = Field(
        ...,
        description="Segmentation metrics output from ChestXRaySegmentationTool",
    )
    # Name of the measurement to run; accepted values are listed in the description (required).
    measurement_type: str = Field(
        ...,
        description="Type of measurement to calculate: 'ctr', 'mediastinal_ratio', 'tracheal_deviation', 'chamber_position'",
    )
    # Physical size of one pixel; 0.2 mm is used when the caller omits it.
    pixel_spacing_mm: Optional[float] = Field(
        0.2,
        description="Pixel spacing in mm (from DICOM or estimated)",
    )


class MeasurementResult(BaseModel):
    """Structured outcome of a single measurement calculation."""

    # Always populated by the calculator.
    measurement_name: str = Field(..., description="Name of the measurement")
    value: float = Field(..., description="Calculated measurement value")
    interpretation: str = Field(..., description="Clinical interpretation (normal, enlarged, etc.)")
    confidence: float = Field(..., description="Confidence in the measurement (0-1)")
    method: str = Field(..., description="Method used for calculation")
    units: str = Field(..., description="Units of measurement")
    # Optional context; stays None when not applicable.
    reference_range: Optional[str] = Field(None, description="Normal reference range")
    clinical_significance: Optional[str] = Field(None, description="Clinical significance if abnormal")


class ChestXRayMeasurementTool(BaseTool):
    """Tool for calculating medical measurements from chest X-ray segmentation results.

    ``segmentation_metrics`` maps structure names (``"Heart"``, ``"Left Lung"``,
    ``"Right Lung"``, ``"Mediastinum"``, ``"Trachea"``/``"Weasand"``) to metric
    dicts. The per-structure keys read here are:

    * ``"width"`` - bounding-box width in pixels (Heart, Mediastinum)
    * ``"bbox"`` - ``(min_y, min_x, max_y, max_x)`` in pixels (lungs)
    * ``"centroid"`` - index 1 is read as the x-coordinate (trachea)
    * ``"relative_position"`` - dict whose ``"left"`` entry is in 0-1 (Heart)

    NOTE(review): this schema is produced by ChestXRaySegmentationTool; confirm
    it there if that tool changes.
    """

    name: str = "chest_xray_measurements"
    description: str = (
        "Calculates medical measurements from chest X-ray segmentation results. "
        "Supports CTR (cardiothoracic ratio), mediastinal ratio, tracheal deviation, "
        "chamber position assessment, and other standardized measurements. "
        "Takes segmentation metrics as input and returns standardized measurements."
    )
    args_schema: type[BaseModel] = MeasurementInput

    def _run(
        self,
        segmentation_metrics: Dict[str, Any],
        measurement_type: str,
        pixel_spacing_mm: Optional[float] = 0.2,
        **kwargs
    ) -> MeasurementResult:
        """Dispatch to the calculation selected by ``measurement_type``.

        Args:
            segmentation_metrics: Per-structure metrics keyed by structure name.
            measurement_type: One of ``'ctr'``, ``'mediastinal_ratio'``,
                ``'tracheal_deviation'`` or ``'chamber_position'``
                (case-insensitive; surrounding whitespace is ignored).
            pixel_spacing_mm: Physical size of one pixel in mm. ``None`` (which
                ``MeasurementInput`` permits) falls back to the 0.2 mm default.
            **kwargs: Accepted and ignored.

        Returns:
            The ``MeasurementResult`` for the requested measurement.

        Raises:
            ValueError: For an unsupported measurement type or when the
                required segmentations are missing or unusable.
        """
        if pixel_spacing_mm is None:
            # The input schema declares this field Optional; treat an explicit
            # None like an omitted value instead of failing later in arithmetic.
            pixel_spacing_mm = 0.2

        measurement_type = measurement_type.strip().lower()

        if measurement_type == "ctr":
            return self._calculate_ctr(segmentation_metrics, pixel_spacing_mm)
        if measurement_type == "mediastinal_ratio":
            return self._calculate_mediastinal_ratio(segmentation_metrics)
        if measurement_type == "tracheal_deviation":
            return self._calculate_tracheal_deviation(segmentation_metrics, pixel_spacing_mm)
        if measurement_type == "chamber_position":
            return self._assess_chamber_position(segmentation_metrics)
        raise ValueError(f"Unsupported measurement type: {measurement_type}")

    @staticmethod
    def _thoracic_extent(metrics: Dict[str, Any]) -> Tuple[float, float]:
        """Return the ``(min_x, max_x)`` horizontal span covered by both lung bboxes.

        Takes the outermost x-coordinates of the two lung bounding boxes, so the
        result does not depend on which side of the image each lung lies on.
        """
        left_bbox = metrics["Left Lung"]["bbox"]
        right_bbox = metrics["Right Lung"]["bbox"]
        # bbox = (min_y, min_x, max_y, max_x)
        return min(left_bbox[1], right_bbox[1]), max(left_bbox[3], right_bbox[3])

    def _calculate_ctr(self, metrics: Dict[str, Any], pixel_spacing: float) -> MeasurementResult:
        """Calculate the Cardiothoracic Ratio (CTR).

        CTR = heart bbox width / thoracic width. Thoracic width is the horizontal
        span of both lung bounding boxes. ``pixel_spacing`` is accepted but not
        used, because the ratio is unitless.

        Raises:
            ValueError: If Heart or either lung is missing, or the thoracic
                width is not positive.
        """
        if not all(organ in metrics for organ in ("Heart", "Left Lung", "Right Lung")):
            raise ValueError("CTR calculation requires Heart, Left Lung, and Right Lung segmentation")

        heart_width_px = metrics["Heart"]["width"]  # cardiac silhouette (bbox) width
        thoracic_min_x, thoracic_max_x = self._thoracic_extent(metrics)
        thoracic_width_px = thoracic_max_x - thoracic_min_x

        if thoracic_width_px <= 0:
            raise ValueError("Invalid thoracic width calculation")

        ctr = heart_width_px / thoracic_width_px
        is_enlarged = ctr > 0.5

        return MeasurementResult(
            measurement_name="Cardiothoracic Ratio (CTR)",
            value=round(ctr, 3),
            interpretation="enlarged" if is_enlarged else "normal",
            # Lower confidence within +/-0.05 of the 0.5 cut-off.
            confidence=0.9 if abs(ctr - 0.5) > 0.05 else 0.7,
            method="heart_width / thoracic_width",
            units="ratio",
            # 0.50 itself is classified as normal above, so the range is inclusive.
            reference_range="≤0.50 (normal)",
            clinical_significance=(
                "Suggests cardiomegaly - may indicate heart failure, valvular disease, or other cardiac pathology"
                if is_enlarged
                else None
            ),
        )

    def _calculate_mediastinal_ratio(self, metrics: Dict[str, Any]) -> MeasurementResult:
        """Calculate the mediastinal-to-thoracic width ratio.

        Thoracic width uses the same definition as the CTR: the horizontal span of
        both lung bounding boxes. Earlier code subtracted one specific lung edge
        from the other, which gave a near-zero or negative width whenever the
        Right Lung lay at lower x.

        Raises:
            ValueError: If Mediastinum or either lung is missing, or the thoracic
                width is not positive.
        """
        required_organs = ["Mediastinum", "Left Lung", "Right Lung"]
        if not all(organ in metrics for organ in required_organs):
            raise ValueError(f"Mediastinal ratio calculation requires: {required_organs}")

        mediastinal_width = metrics["Mediastinum"]["width"]
        thoracic_min_x, thoracic_max_x = self._thoracic_extent(metrics)
        thoracic_width = thoracic_max_x - thoracic_min_x

        if thoracic_width <= 0:
            raise ValueError("Invalid thoracic width for mediastinal ratio calculation")

        mediastinal_ratio = mediastinal_width / thoracic_width
        is_widened = mediastinal_ratio > 0.25  # typical threshold ~0.25

        return MeasurementResult(
            measurement_name="Mediastinal Width Ratio",
            value=round(mediastinal_ratio, 3),
            interpretation="widened" if is_widened else "normal",
            # Lower confidence within +/-0.05 of the 0.25 cut-off.
            confidence=0.85 if abs(mediastinal_ratio - 0.25) > 0.05 else 0.65,
            method="mediastinal_width / thoracic_width",
            units="ratio",
            reference_range="≤0.25 (normal)",
            clinical_significance=(
                "Suggests mediastinal widening - may indicate lymphadenopathy, mass, or vascular pathology"
                if is_widened
                else None
            ),
        )

    def _calculate_tracheal_deviation(self, metrics: Dict[str, Any], pixel_spacing: float) -> MeasurementResult:
        """Calculate tracheal deviation from the thoracic midline in mm.

        The midline is the centre of the horizontal span of both lungs. The
        returned ``value`` is signed in image coordinates (positive means toward
        larger x). The direction named in ``clinical_significance`` is the
        patient's right/left, inferred from where the "Right Lung" lies in the
        image.

        Raises:
            ValueError: If no Trachea/Weasand segmentation or either lung is
                missing, or ``pixel_spacing`` is not a positive number.
        """
        # Prefer an explicit trachea mask; fall back to "Weasand" as a proxy.
        if "Trachea" in metrics:
            trachea_key = "Trachea"
        elif "Weasand" in metrics:
            trachea_key = "Weasand"
        else:
            raise ValueError("Tracheal deviation calculation requires Trachea or Weasand segmentation")

        if not all(organ in metrics for organ in ["Left Lung", "Right Lung"]):
            raise ValueError("Tracheal deviation calculation requires Left Lung and Right Lung segmentation")

        if pixel_spacing is None or pixel_spacing <= 0:
            raise ValueError("Pixel spacing must be a positive number of mm per pixel")

        thoracic_min_x, thoracic_max_x = self._thoracic_extent(metrics)
        thoracic_center_x = (thoracic_min_x + thoracic_max_x) / 2

        trachea_center_x = metrics[trachea_key]["centroid"][1]  # index 1 is x

        deviation_mm = (trachea_center_x - thoracic_center_x) * pixel_spacing
        abs_deviation = abs(deviation_mm)
        is_deviated = abs_deviation > 5.0  # >5 mm is treated as significant

        clinical_significance = None
        if is_deviated:
            left_bbox = metrics["Left Lung"]["bbox"]
            right_bbox = metrics["Right Lung"]["bbox"]
            # Map image x to the patient's sides using where the Right Lung sits,
            # rather than assuming +x is patient-right. A conventionally displayed
            # PA film shows the patient's right on the image left.
            patient_right_is_plus_x = (right_bbox[1] + right_bbox[3]) > (left_bbox[1] + left_bbox[3])
            direction = "right" if (deviation_mm > 0) == patient_right_is_plus_x else "left"
            clinical_significance = f"Tracheal deviation to the {direction} - may indicate mass effect, pneumothorax, or mediastinal shift"

        return MeasurementResult(
            measurement_name="Tracheal Deviation",
            value=round(deviation_mm, 1),
            interpretation="deviated" if is_deviated else "midline",
            confidence=0.8 if abs_deviation > 3.0 else 0.6,
            method="trachea_center - thoracic_midline",
            units="mm",
            reference_range="±5mm (normal)",
            clinical_significance=clinical_significance,
        )

    def _assess_chamber_position(self, metrics: Dict[str, Any]) -> MeasurementResult:
        """Assess cardiac position (normal vs displaced) from the heart's relative position.

        Reads ``metrics["Heart"]["relative_position"]["left"]`` (0-1). Values in
        [0.3, 0.6] are normal. Outside that range, confidence grows with the
        distance past the threshold (capped at 0.9), so results just past a
        threshold have confidence close to 0.

        NOTE(review): the "too far right/left" labels assume how the segmentation
        tool defines ``relative_position["left"]``; confirm there.

        Raises:
            ValueError: If the Heart segmentation is missing.
        """
        if "Heart" not in metrics:
            raise ValueError("Chamber position assessment requires Heart segmentation")

        left_position = metrics["Heart"]["relative_position"]["left"]

        if left_position > 0.6:  # heart too far right
            position = "rightward_displaced"
            confidence = min(0.9, (left_position - 0.6) / 0.2)
            clinical_significance = "Rightward cardiac displacement - may indicate dextrocardia or mediastinal shift"
        elif left_position < 0.3:  # heart too far left
            position = "leftward_displaced"
            confidence = min(0.9, (0.3 - left_position) / 0.2)
            clinical_significance = "Leftward cardiac displacement - may indicate pneumothorax, mass effect, or rotation"
        else:
            position = "normal"
            confidence = 0.8
            clinical_significance = None

        return MeasurementResult(
            measurement_name="Cardiac Position",
            value=round(left_position, 3),
            interpretation=position,
            confidence=confidence,
            method="heart_centroid_relative_position",
            units="relative position (0-1)",
            reference_range="0.3-0.6 (normal)",
            clinical_significance=clinical_significance,
        )
