"""Visual Chain-of-Thought module with bias-resistant visual reasoning"""

from typing import Optional, Dict, Tuple, List, Any
import base64
from pathlib import Path
import uuid
import cv2
import numpy as np
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import ChatPromptTemplate
from dataclasses import dataclass
from enum import Enum


class VisualEvidenceType(Enum):
    """Types of visual evidence that can be analyzed"""
    # The string values are what ends up in the validation prompt: the
    # evidence formatter renders each item's type via ``evidence_type.value``.
    SEGMENTATION_MASK = "segmentation_mask"
    BOUNDING_BOX = "bounding_box"
    CLASSIFICATION_RESULT = "classification_result"
    MEASUREMENT = "measurement"


@dataclass
class VisualEvidence:
    """Container for visual evidence from tools"""
    # Kind of evidence; its ``.value`` string is shown to the LLM.
    evidence_type: VisualEvidenceType
    # Name of the tool that produced this evidence (shown to the LLM as-is).
    tool_name: str
    # Tool-reported confidence. Must be numeric: it is rendered with ``:.2f``
    # when the evidence is formatted for the prompt.
    # NOTE(review): presumably in the 0.0-1.0 range -- not enforced here.
    confidence: float
    # Raw tool payload; it is interpolated into the prompt via ``str()``.
    data: Any  # Can be bbox, mask, classification score, etc.
    description: str  # Human-readable description of what the tool claims
    # Only echoed as text in the formatted evidence ("None" when unset).
    visualization_path: Optional[str] = None  # Path to existing visualization if available


@dataclass
class VCoTResult:
    """Structured result from V-CoT analysis.

    ``images_analyzed`` may be omitted or passed as ``None``; in both cases it
    is normalised to a fresh empty list so consumers can always iterate it.
    """
    visual_observations: str  # Raw text of the independent visual analysis step
    evidence_validation: str  # Raw text of the tool-evidence comparison step
    tool_agreement: str  # "agrees", "disagrees", "uncertain"
    final_confidence: float  # Adjusted confidence after visual verification
    direct_visual_assessment: Optional[str] = None  # Direct visual answer to clinical question
    alternative_assessment: Optional[str] = None  # What to use if tool result is rejected
    requires_human_review: bool = False
    # Paths of the images that were sent to the model. ``None`` is accepted
    # for backward compatibility and replaced in ``__post_init__``.
    images_analyzed: Optional[List[str]] = None
    full_reasoning: str = ""  # Complete reasoning for transparency

    def __post_init__(self) -> None:
        # A per-instance list (never a shared default) keeps the field
        # consistent with its List[str] contract.
        if self.images_analyzed is None:
            self.images_analyzed = []


class BiasResistantVisualCoT:
    """
    Enhanced Visual Chain-of-Thought module that prioritizes visual evidence over prior knowledge.

    Implements a two-step process:
    1. Pure visual analysis without tool influence
    2. Tool evidence validation against visual observations

    This approach minimizes the risk of LLM relying on training knowledge
    instead of actual visual evidence in the image.

    ``generate`` never raises: any failure in either step is converted into a
    ``VCoTResult`` with zero confidence that is flagged for human review.
    """

    # Pixels of surrounding context kept around a bounding box when cropping.
    _ROI_PADDING = 50

    # The only agreement labels the validation prompt asks the model to emit.
    _AGREEMENT_VALUES = ("agrees", "disagrees", "uncertain")

    def __init__(self, llm: BaseLanguageModel, temp_dir: str = "temp"):
        """
        Args:
            llm: Multimodal chat model used for both reasoning steps.
            temp_dir: Directory where cropped ROI images are written. It is
                created (including missing parents) if it does not exist.
        """
        self.llm = llm
        self.temp_dir = Path(temp_dir)
        # parents=True so a nested temp_dir such as "out/tmp/vcot" also works.
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        # Create structured prompts for different reasoning steps
        self._create_reasoning_prompts()

    def _create_reasoning_prompts(self):
        """Create structured prompts for bias-resistant reasoning.

        Sets ``self.visual_analysis_prompt`` (variables: ``task``, ``target``)
        and ``self.evidence_validation_prompt`` (variables:
        ``visual_observations``, ``tool_evidence``).
        """

        # Step 1: Pure visual analysis without tool influence
        self.visual_analysis_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a radiologist analyzing a chest X-ray image. Your task is to make observations based PURELY on visual evidence in the image and provide a DIRECT ANSWER to the clinical question.

CRITICAL INSTRUCTIONS:
1. DO NOT use any prior medical knowledge about "normal" positions or typical findings
2. DO NOT make assumptions about what "should" be there
3. Base your analysis ONLY on what you can actually see in the image
4. Provide a direct clinical answer to the specific question asked
5. Describe visual features using objective radiographic terms (density, borders, shapes, positions, etc.)
6. If you cannot clearly see a structure, say so explicitly

Your goal is to provide an independent clinical assessment that can be compared with automated tool results."""),

            ("human", """Clinical Question: {task}
Target anatomical structure: {target}

Please analyze this chest X-ray image and provide:

1. DIRECT CLINICAL ANSWER: Answer the specific clinical question based on what you observe
2. VISUAL OBSERVATIONS: Detailed visual description of what you see
3. ASSESSMENT CONFIDENCE: How confident are you in your assessment (0-100%)

For example, if asked "Is the trachea midline?":
- DIRECT ANSWER: "The trachea appears deviated to the right by approximately 1-2cm" OR "The trachea appears midline"
- VISUAL OBSERVATIONS: "I can see a tubular radiolucent structure representing the trachea. It appears to be positioned [describe position relative to chest midline]..."
- CONFIDENCE: 85%

Remember: Provide your independent clinical assessment based solely on visual evidence.""")
        ])

        # Step 2: Tool evidence validation with structured output
        self.evidence_validation_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are comparing your independent visual assessment with automated tool results. Your task is to determine if the tool's conclusion matches your clinical assessment.

CRITICAL INSTRUCTIONS:
1. Compare your direct clinical answer with the tool's conclusion
2. Determine if they agree, disagree, or if there's uncertainty
3. Provide a final confidence level based on your visual assessment
4. If there's disagreement, explain what the correct answer should be based on visual evidence

You must provide a structured assessment with:
- TOOL_AGREEMENT: "agrees", "disagrees", or "uncertain"
- FINAL_CONFIDENCE: A number between 0.0 and 1.0 based on your visual assessment
- DIRECT_VISUAL_ASSESSMENT: Your direct clinical answer (repeated for clarity)
- ALTERNATIVE_ASSESSMENT: If disagreement, what the correct clinical answer should be
- REQUIRES_REVIEW: Only if image quality prevents reliable assessment"""),

            ("human", """Your independent visual assessment:
{visual_observations}

Tool Results:
{tool_evidence}

Step 2: Compare Your Assessment with Tool Results
- Does the tool's conclusion match your direct clinical answer?
- Are there any discrepancies between your assessment and the tool's findings?
- What might explain any differences (tool limitations, measurement errors, etc.)?
- Based on your visual evidence, what is the correct clinical answer?

Provide your response in this format:
TOOL_AGREEMENT: [agrees/disagrees/uncertain]
FINAL_CONFIDENCE: [0.0-1.0]
DIRECT_VISUAL_ASSESSMENT: [Your direct clinical answer to the question]
ALTERNATIVE_ASSESSMENT: [If disagreement, what the correct answer should be]
REQUIRES_REVIEW: [yes/no - only if image quality prevents assessment]

Example for tracheal position:
TOOL_AGREEMENT: disagrees
FINAL_CONFIDENCE: 0.8
DIRECT_VISUAL_ASSESSMENT: The trachea appears deviated to the right by approximately 1-2cm from midline
ALTERNATIVE_ASSESSMENT: Right tracheal deviation is present, contrary to tool's assessment of midline position
REQUIRES_REVIEW: no
""")
        ])

    def _crop_and_save_roi(self, img: np.ndarray, bbox: Tuple[int, int, int, int],
                           highlight_color: Tuple[int, int, int], prefix: str) -> str:
        """Crop a padded region around ``bbox``, outline the box, and save it as PNG.

        Args:
            img: Image already loaded with ``cv2.imread``.
            bbox: ``(x1, y1, x2, y2)`` in pixel coordinates of ``img``.
            highlight_color: Colour of the rectangle drawn around the box.
            prefix: File-name prefix for the saved crop inside ``self.temp_dir``.

        Returns:
            Path of the written ROI image.

        Raises:
            ValueError: If the padded box does not overlap the image.
            OSError: If OpenCV reports that the file could not be written.
        """
        height, width = img.shape[:2]

        # Tools may report float coordinates; slicing and cv2 drawing need ints.
        x1, y1, x2, y2 = (int(round(v)) for v in bbox)

        # Clamp BOTH ends of each axis: a negative upper bound would otherwise
        # be read by numpy as "count from the end" and crop the wrong region.
        pad = self._ROI_PADDING
        x1_pad = min(max(0, x1 - pad), width)
        y1_pad = min(max(0, y1 - pad), height)
        x2_pad = min(max(0, x2 + pad), width)
        y2_pad = min(max(0, y2 + pad), height)

        if x2_pad <= x1_pad or y2_pad <= y1_pad:
            raise ValueError(
                f"ROI {tuple(bbox)} does not overlap image of size {width}x{height}"
            )

        roi = img[y1_pad:y2_pad, x1_pad:x2_pad].copy()

        # Outline the original (unpadded) box, in crop-local coordinates.
        cv2.rectangle(roi, (x1 - x1_pad, y1 - y1_pad), (x2 - x1_pad, y2 - y1_pad),
                      highlight_color, 3)

        roi_path = self.temp_dir / f"{prefix}_{uuid.uuid4().hex[:8]}.png"
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(str(roi_path), roi):
            raise OSError(f"Could not write ROI image: {roi_path}")
        return str(roi_path)

    def _create_roi_from_existing_viz(self, existing_viz_path: str, bbox: Tuple[int, int, int, int]) -> str:
        """Create ROI from existing visualization instead of original image.

        Best-effort: on any failure (unreadable file, bad box, write error) a
        warning is printed and ``existing_viz_path`` is returned unchanged.
        """
        try:
            img = cv2.imread(existing_viz_path)
            if img is None:
                return existing_viz_path  # Fall back to existing viz
            return self._crop_and_save_roi(img, bbox, (0, 255, 0), "roi_from_viz")
        except Exception as e:
            print(f"Warning: Could not create ROI from existing viz: {e}")
            return existing_viz_path  # Fall back to existing viz

    def _format_tool_evidence(self, evidence_list: List[VisualEvidence]) -> str:
        """Format tool evidence for prompt.

        Returns a fixed placeholder sentence when the list is empty or ``None``.
        """
        if not evidence_list:
            return "No tool evidence provided."

        formatted_evidence = [
            f"""
Tool: {evidence.tool_name}
Type: {evidence.evidence_type.value}
Confidence: {evidence.confidence:.2f}
Claim: {evidence.description}
Data: {evidence.data}
Existing Visualization: {evidence.visualization_path or "None"}
"""
            for evidence in evidence_list
        ]

        return "\n".join(formatted_evidence)

    def generate(
        self,
        image_path: str,
        task: str,
        target: str,
        evidence_list: List[VisualEvidence],
        roi_bbox: Optional[Tuple[int, int, int, int]] = None,
        measurements: Optional[Dict[str, float]] = None
    ) -> VCoTResult:
        """
        Generate bias-resistant visual reasoning with structured output.

        Args:
            image_path: Path to the chest X-ray image
            task: Description of the analysis task
            target: Target anatomical structure or finding
            evidence_list: List of visual evidence from tools
            roi_bbox: Optional ROI bounding box
            measurements: Optional measurements

        Returns:
            VCoTResult containing structured analysis results. If either
            reasoning step fails, the result has ``final_confidence=0.0``,
            ``tool_agreement="uncertain"`` and ``requires_human_review=True``
            instead of silently reporting default values.
        """
        # CRITICAL: Use original image for V-CoT analysis, not highlighted segmentation
        # V-CoT should provide independent visual assessment, not validate highlighted areas
        images_to_analyze = [image_path]  # Start with original image

        # Kept outside the try so the failure path can report whatever step 1
        # managed to produce before step 2 failed.
        visual_observations: Optional[str] = None

        try:
            # Only add ROI if we have a specific bounding box to focus on
            # Do NOT add existing segmentation visualizations as they bias the analysis
            # (explicit None/len test: truthiness of a numpy bbox is ambiguous)
            if roi_bbox is not None and len(roi_bbox) == 4:
                try:
                    # Create ROI from original image, not from existing visualization
                    roi_path = self._create_roi_image(image_path, roi_bbox)
                    images_to_analyze.append(roi_path)
                except Exception as e:
                    # The ROI is optional context; continue with the full image.
                    print(f"Warning: Could not create ROI image: {e}")

            print("INFO: V-CoT analyzing original image independently (not using highlighted segmentation)")

            # Step 1: Pure visual analysis of original image
            visual_observations = self._perform_visual_analysis(
                images_to_analyze, task, target
            )

            # Step 2: Tool evidence validation
            evidence_validation = self._validate_tool_evidence(
                visual_observations, evidence_list
            )

            # Parse structured validation result
            parsed_validation = self._parse_validation_result(evidence_validation)

            return VCoTResult(
                visual_observations=visual_observations,
                evidence_validation=evidence_validation,
                tool_agreement=parsed_validation.get("tool_agreement", "uncertain"),
                final_confidence=parsed_validation.get("final_confidence", 0.5),
                direct_visual_assessment=parsed_validation.get("direct_visual_assessment"),
                alternative_assessment=parsed_validation.get("alternative_assessment"),
                requires_human_review=parsed_validation.get("requires_review", False),
                images_analyzed=images_to_analyze,
                full_reasoning=self._create_full_reasoning(
                    visual_observations, evidence_validation, measurements
                )
            )

        except Exception as e:
            print(f"Error in bias-resistant V-CoT: {e}")
            return VCoTResult(
                visual_observations=(
                    visual_observations
                    if visual_observations is not None
                    else f"Visual analysis failed: {e}"
                ),
                evidence_validation=f"Evidence validation failed: {e}",
                tool_agreement="uncertain",
                final_confidence=0.0,
                alternative_assessment=f"Tool evidence: {self._format_tool_evidence(evidence_list)}",
                requires_human_review=True,
                images_analyzed=images_to_analyze,
                full_reasoning=f"V-CoT failed with error: {e}"
            )

    def _create_roi_image(self, image_path: str, bbox: Tuple[int, int, int, int],
                         highlight_color: Tuple[int, int, int] = (0, 255, 0)) -> str:
        """Create cropped ROI image with context (fallback when no existing viz).

        Raises:
            ValueError: If the image cannot be loaded or the box lies outside it.
            OSError: If the crop cannot be written to ``self.temp_dir``.
        """
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Could not load image: {image_path}")
        return self._crop_and_save_roi(img, bbox, highlight_color, "roi")

    @staticmethod
    def _image_mime_type(img_path: str) -> str:
        """Guess the image MIME type from the file name, defaulting to PNG."""
        import mimetypes  # local import: only needed when building data URLs

        mime, _ = mimetypes.guess_type(str(img_path))
        if mime and mime.startswith("image/"):
            return mime
        return "image/png"

    @staticmethod
    def _response_text(response: Any) -> str:
        """Return the text of an LLM response (message object or plain value)."""
        return response.content if hasattr(response, 'content') else str(response)

    def _perform_visual_analysis(self, images: List[str], task: str, target: str) -> str:
        """Perform Step 1: Pure visual analysis.

        Each image is sent inline as a base64 data URL, followed by the
        rendered human prompt.

        Raises:
            Exception: Any file or model error is propagated so that
                ``generate`` can flag the result for human review rather than
                feeding an error string into step 2.
        """
        # Prepare multimodal input
        image_parts = []
        for img_path in images:
            with open(img_path, "rb") as img_file:
                img_b64 = base64.b64encode(img_file.read()).decode()
            # Label the payload with its real type (e.g. image/jpeg) instead
            # of always claiming PNG.
            mime = self._image_mime_type(img_path)
            image_parts.append({
                "type": "image_url",
                "image_url": {"url": f"data:{mime};base64,{img_b64}"}
            })

        # Render the text prompts
        rendered = self.visual_analysis_prompt.invoke({
            "task": task,
            "target": target
        })

        # Add multimodal content
        full_messages = [
            {"role": "system", "content": rendered.messages[0].content},
            {"role": "user", "content": image_parts + [{"type": "text", "text": rendered.messages[1].content}]}
        ]

        return self._response_text(self.llm.invoke(full_messages))

    def _validate_tool_evidence(self, visual_observations: str, evidence_list: List[VisualEvidence]) -> str:
        """Perform Step 2: Tool evidence validation.

        Text-only call: the model compares the step-1 observations with the
        formatted tool evidence.

        Raises:
            Exception: Any prompt or model error is propagated to ``generate``.
        """
        rendered = self.evidence_validation_prompt.invoke({
            "visual_observations": visual_observations,
            "tool_evidence": self._format_tool_evidence(evidence_list)
        })

        result = self.llm.invoke([
            {"role": "system", "content": rendered.messages[0].content},
            {"role": "user", "content": rendered.messages[1].content}
        ])

        return self._response_text(result)

    @staticmethod
    def _first_token(value: str) -> str:
        """Return the first word of ``value``, lower-cased, without wrapping punctuation."""
        words = value.lower().split()
        return words[0].strip("[]()\"'`*.,;:-") if words else ""

    @staticmethod
    def _parse_confidence(value: str) -> Optional[float]:
        """Parse a confidence such as ``0.8``, ``[0.8]`` or ``85%`` into [0.0, 1.0].

        Returns ``None`` when the value is not a usable number (including NaN).
        """
        words = value.split()
        if not words:
            return None
        is_percent = "%" in value
        # Keep a leading "." or "-" (".8", "-0.2") but drop brackets, quotes,
        # markdown markers, the percent sign and trailing sentence punctuation.
        token = words[0].lstrip("[(\"'`*").rstrip("])\"'`*%.,;")
        try:
            number = float(token)
        except ValueError:
            return None
        # NaN is the only float not equal to itself. It must be rejected
        # explicitly because min(1.0, nan) evaluates to 1.0.
        if number != number:
            return None
        if is_percent:
            number /= 100.0
        return max(0.0, min(1.0, number))

    def _parse_validation_result(self, validation_text: str) -> Dict[str, Any]:
        """Parse structured validation result.

        Scans the text line by line for ``KEY: value`` pairs. Keys are matched
        case-insensitively and may carry markdown decoration (list bullets,
        bold markers). Unknown or malformed lines are ignored and leave the
        defaults in place.

        Returns:
            Dict with keys ``tool_agreement``, ``final_confidence``,
            ``direct_visual_assessment``, ``alternative_assessment`` and
            ``requires_review``.
        """
        result = {
            "tool_agreement": "uncertain",
            "final_confidence": 0.5,
            "direct_visual_assessment": None,
            "alternative_assessment": None,
            "requires_review": False
        }

        for raw_line in (validation_text or "").splitlines():
            # Tolerate decoration such as "- **FINAL_CONFIDENCE:** 0.8".
            line = raw_line.strip().lstrip("-*#> \t")
            key, separator, value = line.partition(':')
            if not separator:
                continue
            key = key.strip(" \t*`").upper()
            value = value.strip().strip("*").strip()

            if key == 'TOOL_AGREEMENT':
                agreement = self._first_token(value)
                if agreement in self._AGREEMENT_VALUES:
                    result["tool_agreement"] = agreement
            elif key == 'FINAL_CONFIDENCE':
                confidence = self._parse_confidence(value)
                if confidence is not None:
                    result["final_confidence"] = confidence
            elif key == 'DIRECT_VISUAL_ASSESSMENT':
                result["direct_visual_assessment"] = value or None
            elif key == 'ALTERNATIVE_ASSESSMENT':
                if value and value.lower() not in ('none', 'n/a'):
                    result["alternative_assessment"] = value
            elif key == 'REQUIRES_REVIEW':
                # First word only, so "yes - image is blurry" still counts.
                result["requires_review"] = self._first_token(value) in ('yes', 'true')

        return result

    def _create_full_reasoning(self, visual_obs: str, evidence_val: str,
                                measurements: Optional[Dict[str, float]]) -> str:
        """Create complete reasoning for transparency.

        Concatenates both step outputs and the (optional) measurements into a
        single human-readable report.
        """
        measurement_str = f"Measurements: {measurements}" if measurements else "No measurements available"

        return f"""
BIAS-RESISTANT V-CoT ANALYSIS:

Step 1 - Visual Observations:
{visual_obs}

Step 2 - Evidence Validation:
{evidence_val}

{measurement_str}

Note: This assessment prioritizes visual evidence over tool confidence scores and prior knowledge.
"""


# Backward compatibility - keep the original class name as an alias
class VisualCoT(BiasResistantVisualCoT):
    """Legacy entry point kept so callers of the old V-CoT API keep working."""

    def generate(
        self,
        image_path: str,
        roi_bbox: Optional[Tuple[int, int, int, int]] = None,
        measurements: Optional[Dict[str, float]] = None,
        task: str = "Visual analysis",
        confidence: float = 0.5
    ) -> str:
        """
        Run the bias-resistant analysis through the old call signature.

        The target structure is guessed from keywords in ``task``; the bounding
        box and measurements, when given, are wrapped as tool evidence carrying
        the supplied ``confidence``. Only the full reasoning text of the
        structured result is returned.
        """
        lowered_task = task.lower()

        # Ordered keyword groups: the first group with a hit decides the target.
        keyword_targets = (
            (("trachea",), "trachea"),
            (("heart", "cardiac"), "heart and cardiac silhouette"),
            (("lung", "breathing"), "lung fields"),
            (("diaphragm",), "diaphragm"),
        )
        target = "anatomical structure or finding"
        for keywords, label in keyword_targets:
            if any(word in lowered_task for word in keywords):
                target = label
                break

        # Translate the legacy arguments into evidence records.
        legacy_evidence = []
        if roi_bbox:
            bbox_evidence = VisualEvidence(
                evidence_type=VisualEvidenceType.BOUNDING_BOX,
                tool_name="grounding_tool",
                confidence=confidence,
                data=roi_bbox,
                description=f"Tool located {target} at bounding box coordinates"
            )
            legacy_evidence.append(bbox_evidence)

        if measurements:
            measurement_evidence = VisualEvidence(
                evidence_type=VisualEvidenceType.MEASUREMENT,
                tool_name="measurement_tool",
                confidence=confidence,
                data=measurements,
                description=f"Tool measured: {measurements}"
            )
            legacy_evidence.append(measurement_evidence)

        # Delegate to the structured implementation and unwrap the text.
        structured = super().generate(
            image_path=image_path,
            task=task,
            target=target,
            evidence_list=legacy_evidence,
            roi_bbox=roi_bbox,
            measurements=measurements
        )
        return structured.full_reasoning