"""Enhanced orchestrator with intelligent synthesis and confidence calibration"""

import logging
import re
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from enum import Enum
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models import BaseLanguageModel

from ..state import MultiAgentState, Finding, AgentAnalysis


class QueryType(str, Enum):
    """Types of medical queries"""
    # The docstring above is deliberately left untouched: QueryType is a field
    # type of the LLM structured-output model, so it may be emitted as the
    # enum's description in the schema the model sees.
    # The str mixin makes each member compare equal to its string value; those
    # values are what the LLM must produce for ``query_type``. The members are
    # also the keys used to pick a synthesis strategy.
    DIAGNOSTIC = "diagnostic"  # Yes/No questions, e.g. "Is there ...?"
    MEASUREMENT = "measurement"  # Quantitative measurements (e.g. CTR)
    COMPARISON = "comparison"  # Compare with a prior study
    COMPREHENSIVE = "comprehensive"  # Full assessment
    LOCALIZATION = "localization"  # Where is X?


class QueryAnalysisResult(BaseModel):
    """Enhanced query analysis with priorities and dependencies"""
    # NOTE: this model is used as the structured-output schema for the
    # query-analysis LLM call, so the class docstring and every Field
    # description are presumably serialized into the schema the model sees.
    # Treat them as prompt text: editing them can change model behaviour.
    # Agent names are plain strings; the orchestrator looks each one up in its
    # ``agents`` mapping and names not found there are not executed.
    primary_agents: List[str] = Field(description="Must-have agents for this query")
    secondary_agents: List[str] = Field(description="Optional agents based on findings")
    # The 0-1 range is only requested in the description; it is not validated.
    agent_priorities: Dict[str, float] = Field(description="Priority scores (0-1) for each agent")
    execution_order: List[str] = Field(description="Optimal execution sequence")
    query_type: QueryType = Field(description="Type of medical query")
    expected_outputs: List[str] = Field(description="Expected findings or measurements")
    # Compared against the calibrated overall confidence to decide whether the
    # answer needs clinical review; not range-checked.
    confidence_threshold: float = Field(description="Minimum confidence for reliable answer")
    requires_comparison: bool = Field(description="Whether prior image comparison needed")


@dataclass
class ConfidenceFactors:
    """Breakdown of the calibrated confidence for one orchestrator answer.

    Built by ``ConfidenceCalibrator.calculate_calibrated_confidence``:
    ``overall`` is the product of the other four factors, and every field is
    1.0 when there are no findings.
    """
    base: float  # mean confidence of the findings
    consistency: float  # factor for (dis)agreement between findings
    complexity: float  # adjustment for the query type
    completeness: float  # factor for whether expected findings are present
    overall: float  # final score; reported as OrchestratorResult.confidence


@dataclass
class OrchestratorResult:
    """Everything ``EnhancedOrchestrator.execute`` returns for one query."""
    query: str  # the user's question, as passed to execute()
    query_analysis: QueryAnalysisResult  # plan produced by the query analyzer
    activated_agents: List[str]  # names of the agents activated for this query
    agent_results: Dict[str, Any]  # "<agent>_analysis" -> agent output (None if the agent raised)
    final_answer: str  # synthesized natural-language answer
    confidence: float  # same value as confidence_factors.overall
    confidence_factors: ConfidenceFactors  # per-factor breakdown of ``confidence``
    needs_review: bool  # True when clinical review is recommended
    synthesis_reasoning: str  # short note on the synthesis strategy that was used


class EnhancedQueryAnalyzer:
    """Plan which agents should run for a query and how to judge the answer.

    The primary path asks the LLM for a ``QueryAnalysisResult`` via
    structured output. If that call raises or yields no result, a
    keyword-based plan from ``_fallback_analysis`` is returned instead.
    """

    # ---- Keyword tables for the rule-based fallback ------------------------
    # Everyday words are matched only at the start of a word (``_mentions``)
    # so that "line" does not fire on "midline", "any" not on "many",
    # "full" not on "carefully" and "ratio" not on "respiration".
    # Distinctive medical stems are matched as plain substrings so compounds
    # such as "hemidiaphragm" or "endotracheal" still hit.
    _COMPARISON_WORDS = ("compare", "comparison", "prior")
    _DIAGNOSTIC_WORDS = ("is there", "do you see", "any")
    _MEASUREMENT_WORDS = ("measure", "calculate", "ratio")
    _MEASUREMENT_STEMS = ("ctr",)
    _COMPREHENSIVE_WORDS = ("full", "complete", "comprehensive")
    _LOCALIZATION_WORDS = ("where",)
    # (agent, substring stems, word-start words); tuple order is plan order.
    _AGENT_KEYWORDS = (
        ("cardiac", ("heart", "cardiac", "cardiomegaly", "ctr"), ()),
        ("breathing", ("lung", "pneumonia", "breathing", "pneumothorax", "effusion"), ()),
        ("airway", ("trachea", "airway", "mediastinum"), ()),
        ("diaphragm", ("diaphragm", "pneumoperitoneum"), ()),
        ("everything", ("fracture",), ("tube", "line", "device")),
    )
    _ALL_AGENTS = ("airway", "breathing", "cardiac", "diaphragm", "everything")
    _DEFAULT_AGENTS = ("cardiac", "breathing")
    # (substring stem, expected-output label), in reporting order.
    _EXPECTED_OUTPUT_STEMS = (
        ("cardiomegaly", "cardiomegaly"),
        ("ctr", "ctr_measurement"),
        ("pneumonia", "pneumonia"),
        ("pneumothorax", "pneumothorax"),
        ("fracture", "fracture"),
    )

    def __init__(self, llm: BaseLanguageModel):
        self.llm = llm
        self.analyzer_chain = self._create_analyzer_chain()

    def _create_analyzer_chain(self):
        """Build the prompt -> LLM chain that emits a ``QueryAnalysisResult``."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a medical imaging query analyzer. Analyze the query and determine:

1. Primary agents needed (must execute to answer the query)
2. Secondary agents (might be needed based on findings)
3. Priority scores for each agent (0-1, higher = more important)
4. Optimal execution order
5. Query type (diagnostic/measurement/comparison/comprehensive/localization)
6. Expected outputs (what findings or measurements are needed)
7. Minimum confidence threshold for reliable answer

Available agents and their capabilities:

AIRWAY Agent:
- Tracheal position (midline vs deviation)
- Carina and bronchi assessment
- Paratracheal/mediastinal masses or lymphadenopathy

BREATHING Agent:
- Lung opacities (alveolar, interstitial, nodular)
- Infiltrates, consolidation patterns
- Pulmonary vessels (redistribution, cephalization)
- Pleural effusion or pneumothorax

CARDIAC Agent:
- Heart size and CTR calculation
- Chamber enlargement or contour asymmetry
- Mediastinal widening or masses
- Cardiac devices (pacemaker, valves, stents)
- Calcifications

DIAPHRAGM Agent:
- Hemidiaphragm height comparison
- Free air/pneumoperitoneum
- Costophrenic angle assessment

EVERYTHING ELSE Agent:
- Bone fractures (ribs, clavicles, scapulae, spine)
- Soft tissue abnormalities
- Lines and tubes (ETT, central lines, chest tubes)
- Foreign bodies

Create a plan that:
1. Addresses the specific query requirements
2. Uses appropriate agents in logical order
3. Is efficient (no redundant agents)
4. Identifies primary vs secondary agents

Be liberal in interpretation - if unsure, include the agent."""),
            ("human", "Query: {query}")
        ])

        return prompt | self.llm.with_structured_output(QueryAnalysisResult)

    def analyze_query(self, query: str) -> QueryAnalysisResult:
        """Return an analysis plan for *query*.

        Uses the LLM chain when it succeeds; otherwise logs a warning and
        returns the rule-based fallback plan. Never raises for LLM failures.
        """
        try:
            result = self.analyzer_chain.invoke({"query": query})
        except Exception:
            # Deliberate best-effort: any LLM/parsing failure degrades to the
            # keyword rules, but the cause is recorded instead of swallowed.
            logging.getLogger(__name__).warning(
                "LLM query analysis failed; using rule-based fallback",
                exc_info=True,
            )
            return self._fallback_analysis(query)
        if result is None:
            # Some structured-output parsers return None when the model emits
            # no tool call; callers need a real plan, so fall back.
            logging.getLogger(__name__).warning(
                "LLM query analysis returned no structured output; "
                "using rule-based fallback"
            )
            return self._fallback_analysis(query)
        return result

    def _fallback_analysis(self, query: str) -> QueryAnalysisResult:
        """Build a keyword-based plan when the LLM analysis is unavailable."""
        query_lower = query.lower()
        query_type = self._classify_query_type(query_lower)
        primary_agents = self._select_agents(query_lower)

        return QueryAnalysisResult(
            primary_agents=primary_agents,
            secondary_agents=[],
            # Derived from the FINAL agent list so the default plan never ends
            # up with empty priorities / execution order.
            agent_priorities={agent: 1.0 for agent in primary_agents},
            execution_order=list(primary_agents),
            query_type=query_type,
            expected_outputs=self._extract_expected_outputs(query_lower),
            confidence_threshold=0.7,
            # Tied to the query type so the two fields can never disagree
            # (e.g. a "prior"-only query now also requests a comparison).
            requires_comparison=query_type is QueryType.COMPARISON,
        )

    def _classify_query_type(self, query_lower: str) -> QueryType:
        """Pick a QueryType from keyword cues in the lower-cased query.

        Comparison cues are checked first because only that type needs a
        prior study; unmatched queries default to DIAGNOSTIC.
        """
        if self._mentions(query_lower, self._COMPARISON_WORDS):
            return QueryType.COMPARISON
        if self._mentions(query_lower, self._DIAGNOSTIC_WORDS):
            return QueryType.DIAGNOSTIC
        if (any(stem in query_lower for stem in self._MEASUREMENT_STEMS)
                or self._mentions(query_lower, self._MEASUREMENT_WORDS)):
            return QueryType.MEASUREMENT
        if self._mentions(query_lower, self._COMPREHENSIVE_WORDS):
            return QueryType.COMPREHENSIVE
        if self._mentions(query_lower, self._LOCALIZATION_WORDS):
            return QueryType.LOCALIZATION
        return QueryType.DIAGNOSTIC

    def _select_agents(self, query_lower: str) -> List[str]:
        """Return the agents whose keywords appear in the query.

        Full/complete/comprehensive requests get every agent; if nothing
        matches, the cardiac + breathing default is used.
        """
        if self._mentions(query_lower, self._COMPREHENSIVE_WORDS):
            return list(self._ALL_AGENTS)
        selected = [
            agent
            for agent, stems, words in self._AGENT_KEYWORDS
            if any(stem in query_lower for stem in stems)
            or self._mentions(query_lower, words)
        ]
        return selected or list(self._DEFAULT_AGENTS)

    @staticmethod
    def _mentions(text: str, words) -> bool:
        """Return True if any of *words* occurs in *text* at the start of a word."""
        return any(re.search(r"\b" + re.escape(word), text) for word in words)

    def _extract_expected_outputs(self, query: str) -> List[str]:
        """Return expected-output labels for pathologies named in *query*.

        *query* is expected to be lower-cased already; matching is by
        substring, and labels are returned in a fixed order.
        """
        return [label for stem, label in self._EXPECTED_OUTPUT_STEMS if stem in query]


class IntelligentSynthesizer:
    """Turn per-agent analyses into one natural-language answer.

    ``agent_results`` maps ``"<agent>_analysis"`` keys to either a dict with a
    ``"findings"`` entry or an object exposing ``.findings``; the strategy is
    chosen by ``QueryAnalysisResult.query_type``. Only dict-shaped findings
    are interpreted, via their ``pathology``, ``confidence`` and optional
    ``measurements`` / ``evidence`` entries. A missing or None confidence is
    treated as 0.
    """

    # Findings at or below this confidence are not reported as present, by
    # either the diagnostic or the comprehensive strategy.
    MIN_REPORTABLE_CONFIDENCE = 0.3

    def __init__(self, llm: BaseLanguageModel):
        # Stored for interface compatibility; the current strategies are
        # rule-based and do not call the LLM.
        self.llm = llm
        self.synthesis_strategies = {
            QueryType.DIAGNOSTIC: self._diagnostic_synthesis,
            QueryType.MEASUREMENT: self._measurement_synthesis,
            QueryType.COMPARISON: self._comparison_synthesis,
            QueryType.COMPREHENSIVE: self._comprehensive_synthesis,
            QueryType.LOCALIZATION: self._localization_synthesis
        }

    def synthesize(self, query: str, agent_results: Dict,
                   query_analysis: QueryAnalysisResult) -> Tuple[str, str]:
        """Return ``(answer, reasoning)`` for *query* using the strategy for its type."""
        strategy = self.synthesis_strategies.get(
            query_analysis.query_type,
            self._default_synthesis
        )
        answer = strategy(query, agent_results, query_analysis)
        reasoning = self._generate_synthesis_reasoning(
            query_analysis.query_type,
            agent_results
        )
        return answer, reasoning

    def _diagnostic_synthesis(self, query: str, results: Dict,
                              analysis: QueryAnalysisResult) -> str:
        """Answer a yes/no question from the most confident relevant finding.

        Only findings above ``MIN_REPORTABLE_CONFIDENCE`` count as positive,
        so a near-zero-confidence finding is no longer reported as present.
        """
        findings = [
            finding
            for finding in self._extract_relevant_findings(results, analysis.expected_outputs)
            if self._confidence_of(finding) > self.MIN_REPORTABLE_CONFIDENCE
        ]

        if not findings:
            if analysis.expected_outputs:
                expected = " or ".join(self._humanize(o) for o in analysis.expected_outputs)
            else:
                expected = "requested pathology"
            return f"No evidence of {expected} detected in the image."

        # Relevant findings are sorted by descending confidence.
        primary_finding = findings[0]
        answer = f"Yes, {self._pathology_label(primary_finding)} is present "
        answer += f"with {self._confidence_of(primary_finding):.0%} confidence."

        # Supporting evidence: measurements first, then free-text evidence.
        for key, value in (primary_finding.get('measurements') or {}).items():
            answer += f" {str(key).upper()}: {self._format_value(value)}."

        if primary_finding.get('evidence'):
            answer += f" {primary_finding['evidence']}"

        return answer

    def _measurement_synthesis(self, query: str, results: Dict,
                               analysis: QueryAnalysisResult) -> str:
        """Report every measurement found; CTR additionally gets an interpretation."""
        measurements = self._extract_measurements(results)

        if not measurements:
            return "Unable to calculate the requested measurements from the image."

        answer_parts = []
        for measure_type, value in measurements.items():
            number = self._as_float(value)
            if measure_type == "ctr" and number is not None:
                # CTR > 0.5 is treated as enlarged (the usual PA-film cut-off).
                interpretation = "enlarged" if number > 0.5 else "normal"
                answer_parts.append(
                    f"The cardiothoracic ratio (CTR) measures {number:.3f} ({interpretation})"
                )
            else:
                # Non-numeric values are reported verbatim instead of crashing.
                answer_parts.append(f"{measure_type}: {self._format_value(value)}")

        return ". ".join(answer_parts) + "."

    def _comparison_synthesis(self, query: str, results: Dict,
                              analysis: QueryAnalysisResult) -> str:
        """Report prior-study comparison results, or current findings if none exist."""
        # Note: Full comparison would need prior image results
        if "comparison_results" in results and results["comparison_results"]:
            return f"Comparison findings: {results['comparison_results']}"

        return "Comparison with prior imaging requires access to previous studies. " \
               "Current findings: " + self._comprehensive_synthesis(query, results, analysis)

    def _comprehensive_synthesis(self, query: str, results: Dict,
                                 analysis: QueryAnalysisResult) -> str:
        """List every reportable finding in airway/breathing/cardiac/diaphragm/everything order."""
        all_findings = []

        for agent_name in ("airway", "breathing", "cardiac", "diaphragm", "everything"):
            for finding in self._findings_of(results.get(f"{agent_name}_analysis")):
                if not isinstance(finding, dict):
                    continue
                conf = self._confidence_of(finding)
                if conf > self.MIN_REPORTABLE_CONFIDENCE:
                    all_findings.append(f"{self._pathology_label(finding)} ({conf:.0%})")

        if all_findings:
            return f"Chest X-ray findings: {', '.join(all_findings)}. " \
                   f"Total {len(all_findings)} abnormalities detected."
        return "No significant abnormalities detected in the chest X-ray."

    def _localization_synthesis(self, query: str, results: Dict,
                                analysis: QueryAnalysisResult) -> str:
        """Synthesis for localization queries (currently the diagnostic strategy)."""
        return self._diagnostic_synthesis(query, results, analysis)

    def _default_synthesis(self, query: str, results: Dict,
                           analysis: QueryAnalysisResult) -> str:
        """Fallback strategy for unknown query types: comprehensive synthesis."""
        return self._comprehensive_synthesis(query, results, analysis)

    def _extract_relevant_findings(self, results: Dict,
                                   expected_outputs: List[str]) -> List[Dict]:
        """Return dict findings matching *expected_outputs*, most confident first.

        Labels are compared case-insensitively with spaces/hyphens treated as
        underscores (so "Pleural Effusion" matches "pleural_effusion"). An
        empty *expected_outputs* matches every finding.
        """
        wanted = {self._normalize_label(output) for output in expected_outputs}
        relevant = []

        for key, agent_result in results.items():
            if not key.endswith('_analysis'):
                continue
            for finding in self._findings_of(agent_result):
                if not isinstance(finding, dict):
                    continue
                if not wanted or self._normalize_label(finding.get('pathology', '')) in wanted:
                    relevant.append(finding)

        return sorted(relevant, key=self._confidence_of, reverse=True)

    def _extract_measurements(self, results: Dict) -> Dict[str, Any]:
        """Merge the ``measurements`` dicts of all findings (later keys overwrite earlier)."""
        measurements = {}

        for key, agent_result in results.items():
            if not key.endswith('_analysis'):
                continue
            for finding in self._findings_of(agent_result):
                if isinstance(finding, dict) and finding.get('measurements'):
                    measurements.update(finding['measurements'])

        return measurements

    def _generate_synthesis_reasoning(self, query_type: QueryType,
                                      results: Dict) -> str:
        """Describe the strategy used and how many agents/findings fed into it."""
        analysis_results = [
            result for key, result in results.items() if key.endswith('_analysis')
        ]
        agent_count = sum(1 for result in analysis_results if result)
        # Counts findings of both dict- and object-shaped results, matching
        # what the strategies themselves read.
        finding_count = sum(len(self._findings_of(result)) for result in analysis_results)

        return (f"Used {query_type.value} synthesis strategy. "
                f"Analyzed {agent_count} agent(s) with {finding_count} total findings.")

    # ---- Shared helpers ----------------------------------------------------

    @staticmethod
    def _findings_of(agent_result: Any) -> List[Any]:
        """Return one agent result's findings (dict key or attribute), or [] if absent/None."""
        if not agent_result:
            return []
        if isinstance(agent_result, dict):
            findings = agent_result.get('findings')
        else:
            findings = getattr(agent_result, 'findings', None)
        return list(findings) if findings else []

    @staticmethod
    def _confidence_of(finding: Dict) -> float:
        """Return a finding's confidence, treating a missing or None value as 0."""
        return finding.get('confidence') or 0

    @staticmethod
    def _humanize(label: Any) -> str:
        """Render a snake_case label for display ("pleural_effusion" -> "pleural effusion")."""
        return str(label).replace('_', ' ')

    def _pathology_label(self, finding: Dict) -> str:
        """Display name of a finding's pathology, with a placeholder when it is missing."""
        return self._humanize(finding.get('pathology') or 'unspecified finding')

    @staticmethod
    def _normalize_label(label: Any) -> str:
        """Canonical form used to compare pathology labels."""
        return str(label).strip().lower().replace('-', '_').replace(' ', '_')

    @staticmethod
    def _as_float(value: Any) -> Optional[float]:
        """Return *value* as a float, or None if it is not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def _format_value(self, value: Any) -> str:
        """Format numbers to three decimals; anything else via ``str``."""
        number = self._as_float(value)
        return f"{number:.3f}" if number is not None else str(value)


class ConfidenceCalibrator:
    """Combine heuristic factors into a single confidence score.

    ``overall`` is the product of the base, consistency, complexity and
    completeness factors. Findings are read as dicts with an optional
    ``confidence`` entry; a missing or ``None`` confidence counts as 0.
    """

    def calculate_calibrated_confidence(self, findings: List[Finding],
                                        query_type: QueryType) -> ConfidenceFactors:
        """Return the confidence breakdown for *findings* under *query_type*.

        With no findings every factor is 1.0: a negative result is treated as
        fully confident. This cannot distinguish "agents found nothing" from
        "no agent produced output", so callers must check the latter.
        """
        if not findings:
            # High confidence in negative findings
            return ConfidenceFactors(
                base=1.0,
                consistency=1.0,
                complexity=1.0,
                completeness=1.0,
                overall=1.0
            )

        base_confidence = self._weighted_confidence(findings)
        consistency_factor = self._check_finding_consistency(findings)
        complexity_factor = self._query_complexity_factor(query_type)
        completeness_factor = self._check_completeness(findings)

        overall = (base_confidence * consistency_factor *
                   complexity_factor * completeness_factor)

        return ConfidenceFactors(
            base=base_confidence,
            consistency=consistency_factor,
            complexity=complexity_factor,
            completeness=completeness_factor,
            overall=overall
        )

    def _weighted_confidence(self, findings: List[Finding]) -> float:
        """Return the mean confidence of the dict-shaped findings.

        Despite the name, every finding is weighted equally. Returns 1.0 for
        an empty list and 0 when no finding is a dict.
        """
        if not findings:
            return 1.0

        # `or 0` also covers a key that is present but None, which would
        # otherwise make sum() raise TypeError.
        confidences = [
            finding.get('confidence') or 0
            for finding in findings
            if isinstance(finding, dict)
        ]
        return sum(confidences) / len(confidences) if confidences else 0

    def _check_finding_consistency(self, findings: List[Finding]) -> float:
        """Return 1.0 for zero or one finding, else a fixed 0.95 placeholder penalty."""
        if len(findings) <= 1:
            return 1.0

        # Simple consistency check - can be enhanced
        return 0.95

    def _query_complexity_factor(self, query_type: QueryType) -> float:
        """Return a per-query-type adjustment (0.9 for unknown types)."""
        complexity_factors = {
            QueryType.DIAGNOSTIC: 1.0,
            QueryType.MEASUREMENT: 0.95,
            QueryType.COMPARISON: 0.85,
            QueryType.COMPREHENSIVE: 0.9,
            QueryType.LOCALIZATION: 0.95
        }
        return complexity_factors.get(query_type, 0.9)

    def _check_completeness(self, findings: List[Finding]) -> float:
        """Placeholder completeness factor: 0.95 with findings, 0.8 without.

        The 0.8 branch is only reachable when called directly, because
        ``calculate_calibrated_confidence`` returns early for empty input.
        """
        return 0.95 if findings else 0.8


class EnhancedOrchestrator:
    """Coordinate query analysis, agent execution, synthesis and confidence.

    ``agents`` maps agent names (as produced by the query analyzer) to objects
    whose ``analyze(state)`` returns an updated state mapping that is
    expected to contain ``"<name>_analysis"``.
    """

    def __init__(self, llm: BaseLanguageModel, agents: Dict[str, Any]):
        self.llm = llm
        self.agents = agents
        self.query_analyzer = EnhancedQueryAnalyzer(llm)
        self.synthesizer = IntelligentSynthesizer(llm)
        self.confidence_calibrator = ConfidenceCalibrator()

    def execute(self, query: str, image_path: str,
                prior_image_path: Optional[str] = None) -> OrchestratorResult:
        """Run the full pipeline for *query* on *image_path* and return the result.

        Primary agents run one after another, each on a copy of the same
        initial state, so they do not see each other's output. A failing
        agent is recorded as ``None`` in ``agent_results`` and does not abort
        the run.

        ``activated_agents`` lists the agents that were actually invoked
        (including ones that raised), not agents that are merely named in the
        plan. Clinical review is forced whenever a planned agent is not
        registered, raises, or returns no analysis, or when no agent ran at
        all. Otherwise an empty finding list would be reported as a fully
        confident negative.
        """
        # Step 1: Analyze query
        print(f"\n{'='*60}")
        print("ORCHESTRATOR: Analyzing query...")
        print(f"Query: {query}")

        query_analysis = self.query_analyzer.analyze_query(query)
        print("\nQuery Analysis:")
        print(f"  Type: {query_analysis.query_type}")
        print(f"  Primary agents: {query_analysis.primary_agents}")
        print(f"  Expected outputs: {query_analysis.expected_outputs}")

        # Step 2: Create initial state
        state = MultiAgentState(
            image_path=image_path,
            query=query,
            prior_image_path=prior_image_path,
            messages=[],
            current_step="analysis",
            completed_agents=[],
            active_agents=query_analysis.primary_agents,
            need_comparison=query_analysis.requires_comparison,
            execution_mode="sequential"  # Default to sequential for safety
        )

        # Step 3: Execute agents
        agent_results: Dict[str, Any] = {}
        all_findings: List[Finding] = []
        executed_agents: List[str] = []
        incomplete_agents: List[str] = []  # unregistered, failed, or no analysis

        for agent_name in query_analysis.primary_agents:
            if agent_name not in self.agents:
                print(f"ORCHESTRATOR: No '{agent_name}' agent registered; skipping.")
                incomplete_agents.append(agent_name)
                continue

            print(f"\n{'='*60}")
            print(f"ORCHESTRATOR: Executing {agent_name} agent...")
            executed_agents.append(agent_name)
            # Bound before the try so the except handler always records the
            # failure under THIS agent's key, never a previous agent's.
            analysis_key = f"{agent_name}_analysis"

            try:
                updated_state = self.agents[agent_name].analyze(state.copy())
                if analysis_key in updated_state:
                    analysis = updated_state[analysis_key]
                    agent_results[analysis_key] = analysis
                    all_findings.extend(self._findings_of(analysis))
                else:
                    print(f"ORCHESTRATOR: {agent_name} agent returned no '{analysis_key}'.")
                    incomplete_agents.append(agent_name)
            except Exception as e:
                # Best-effort: one failing agent must not abort the whole query.
                print(f"Error in {agent_name} agent: {e}")
                agent_results[analysis_key] = None
                incomplete_agents.append(agent_name)

        # Step 4: Synthesize final answer
        print(f"\n{'='*60}")
        print("ORCHESTRATOR: Synthesizing final answer...")

        final_answer, reasoning = self.synthesizer.synthesize(
            query, agent_results, query_analysis
        )

        # Step 5: Calculate confidence
        confidence_factors = self.confidence_calibrator.calculate_calibrated_confidence(
            all_findings, query_analysis.query_type
        )

        # Step 6: Determine review needs. An empty finding list scores as a
        # confident negative, which is only trustworthy if every planned agent
        # actually produced an analysis.
        needs_review = (
            bool(incomplete_agents)
            or not executed_agents
            or self._needs_clinical_review(
                confidence_factors.overall,
                query_analysis.confidence_threshold,
                all_findings
            )
        )

        return OrchestratorResult(
            query=query,
            query_analysis=query_analysis,
            activated_agents=executed_agents,
            agent_results=agent_results,
            final_answer=final_answer,
            confidence=confidence_factors.overall,
            confidence_factors=confidence_factors,
            needs_review=needs_review,
            synthesis_reasoning=reasoning
        )

    @staticmethod
    def _findings_of(analysis: Any) -> List[Any]:
        """Return an analysis' findings (dict key or attribute), or [] if absent/None."""
        if isinstance(analysis, dict):
            findings = analysis.get('findings')
        else:
            findings = getattr(analysis, 'findings', None)
        return list(findings) if findings else []

    def _needs_clinical_review(self, confidence: float, threshold: float,
                               findings: List[Finding]) -> bool:
        """Return True if the answer should be checked by a clinician.

        Review is recommended when *confidence* is below *threshold*, or when
        any dict finding has a borderline confidence strictly between 0.3 and
        0.7 (a missing or None confidence counts as 0). Conflicting findings
        are not detected yet.
        """
        if confidence < threshold:
            return True

        return any(
            isinstance(finding, dict) and 0.3 < (finding.get('confidence') or 0) < 0.7
            for finding in findings
        )
