"""Synthesis module for combining agent results into coherent answers"""

import logging
from typing import Dict, List, Optional, Tuple, Any

from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models import BaseLanguageModel

from .semantic_query_analyzer import QueryIntent, SemanticQueryAnalysis
from ..state import Finding, AgentAnalysis


class SynthesisStrategy(BaseModel):
    """Strategy for synthesizing results based on query type"""
    # Structured-output schema for the strategy-planning LLM call
    # (IntelligentSynthesizer._create_strategy_planner passes this class to
    # ``with_structured_output``); it is also instantiated directly as the
    # default strategy when planning fails. The docstring above and the Field
    # descriptions below are presumably surfaced to the LLM as part of the
    # schema -- treat them as prompt text and change them with care.
    # In the default strategy these are the keys of ``agent_results``.
    focus_agents: List[str] = Field(description="Agents to prioritize in synthesis")
    key_metrics: List[str] = Field(description="Key measurements to highlight")
    # Free-form text, e.g. "comprehensive review" in the default strategy.
    synthesis_approach: str = Field(description="How to structure the answer")


class SynthesisResult(BaseModel):
    """Result from the synthesis process"""
    # Structured-output schema for the answer-generation LLM call and the
    # return type of IntelligentSynthesizer.synthesize (its rule-based
    # fallback also builds one directly). The docstring and Field
    # descriptions are presumably sent to the LLM as part of the schema --
    # treat them as prompt text and edit with care.
    answer: str = Field(description="Direct answer to the user's query")
    # NOTE(review): the 0-1 range is only stated in the description; there is
    # no ge/le constraint, so out-of-range values from the LLM are accepted.
    confidence: float = Field(description="Overall confidence in the answer (0-1)")
    supporting_evidence: List[str] = Field(description="Key evidence supporting the answer")
    caveats: List[str] = Field(description="Important limitations or caveats")
    requires_review: bool = Field(description="Whether clinical review is recommended")


class IntelligentSynthesizer:
    """Synthesize per-agent results into a single answer using an LLM.

    Two structured-output chains are built at construction time: one plans a
    ``SynthesisStrategy`` from the query intent, the other generates the final
    ``SynthesisResult``. If strategy planning fails, a default "comprehensive
    review" strategy is used; if answer generation fails or yields nothing, a
    rule-based fallback returns a result flagged for clinical review.
    """

    # Records LLM-chain failures with a traceback instead of printing them.
    _logger = logging.getLogger(__name__)

    def __init__(self, llm: BaseLanguageModel):
        """Build both LLM chains around ``llm``.

        Args:
            llm: Language model; it must support ``with_structured_output``.
        """
        self.llm = llm
        self.strategy_planner = self._create_strategy_planner()
        self.answer_generator = self._create_answer_generator()

    def _create_strategy_planner(self):
        """Return a prompt | LLM chain that emits a ``SynthesisStrategy``."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are planning how to synthesize medical findings into a clear answer.
            
Based on the query intent and available findings, determine:
1. Which agents' findings to prioritize
2. Key metrics to highlight
3. How to structure the answer

Consider the query intent:
- DIAGNOSTIC: Focus on yes/no with confidence
- MEASUREMENT: Highlight specific values
- COMPARISON: Emphasize changes
- LOCALIZATION: Describe locations
- COMPREHENSIVE: Systematic review
- CHARACTERIZATION: Detailed description"""),
            ("human", "Query intent: {query_intent}\nAvailable agents: {available_agents}")
        ])

        return prompt | self.llm.with_structured_output(SynthesisStrategy)

    def _create_answer_generator(self):
        """Return a prompt | LLM chain that emits a ``SynthesisResult``."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a radiologist providing clear, clinically relevant answers.

Given the query, findings, and synthesis strategy, generate:
1. A direct answer to the query
2. Overall confidence level
3. Supporting evidence
4. Important caveats or limitations
5. Whether clinical review is needed

Guidelines:
- Be concise but complete
- Use medical terminology appropriately
- Include specific measurements when relevant
- Flag uncertain findings
- Consider clinical implications"""),
            ("human", """Query: {query}
Query Intent: {query_intent}
Clinical Context: {clinical_context}

Agent Findings:
{agent_findings}

Synthesis Strategy:
{synthesis_strategy}

Generate a comprehensive answer.""")
        ])

        return prompt | self.llm.with_structured_output(SynthesisResult)

    def synthesize(
        self,
        query: str,
        query_analysis: SemanticQueryAnalysis,
        agent_results: Dict[str, AgentAnalysis]
    ) -> SynthesisResult:
        """Synthesize agent results into a final answer.

        Args:
            query: The user's original question.
            query_analysis: Analysis providing ``query_intent`` (an enum with
                ``.value``) and ``clinical_context``.
            agent_results: Mapping of agent name to that agent's analysis.

        Returns:
            The LLM-generated ``SynthesisResult``, or the rule-based fallback
            when answer generation raises or returns ``None``.
        """
        available_agents = list(agent_results.keys())
        agent_findings = self._format_agent_findings(agent_results)

        strategy = self._plan_strategy(query_analysis, available_agents)

        try:
            # Prefer Pydantic v2's model_dump(); fall back to the v1 dict() API.
            dump_strategy = getattr(strategy, "model_dump", None) or strategy.dict
            result = self.answer_generator.invoke({
                "query": query,
                "query_intent": query_analysis.query_intent.value,
                "clinical_context": query_analysis.clinical_context,
                "agent_findings": agent_findings,
                "synthesis_strategy": dump_strategy()
            })
        except Exception:
            # Deliberately broad: any LLM/parsing failure degrades to the
            # rule-based fallback instead of failing the whole request.
            self._logger.exception("Error generating answer")
            result = None

        # Guard against the chain returning None so callers always receive
        # a SynthesisResult.
        if result is None:
            return self._fallback_synthesis(query, query_analysis, agent_results)
        return result

    def _plan_strategy(
        self,
        query_analysis: SemanticQueryAnalysis,
        available_agents: List[str]
    ) -> SynthesisStrategy:
        """Ask the planner LLM for a strategy; default to a comprehensive review.

        The default is used when the planner raises or returns ``None``.
        """
        strategy = None
        try:
            strategy = self.strategy_planner.invoke({
                "query_intent": query_analysis.query_intent.value,
                "available_agents": ", ".join(available_agents)
            })
        except Exception:
            self._logger.exception("Error planning synthesis strategy")

        if strategy is None:
            strategy = SynthesisStrategy(
                focus_agents=available_agents,
                key_metrics=["all findings"],
                synthesis_approach="comprehensive review"
            )
        return strategy

    def _format_agent_findings(self, agent_results: Dict[str, AgentAnalysis]) -> str:
        """Render every agent's findings as plain text for the answer prompt.

        Each agent with a truthy analysis gets its own section: a blank line,
        an ``<NAME> AGENT:`` header, then one line per dict-shaped finding (or
        ``- No significant findings`` when its findings are empty/missing).
        Sections are joined with newlines.

        Returns:
            The formatted text, or ``"No findings available"`` when no agent
            produced a section.
        """
        sections = []
        for agent_name, analysis in agent_results.items():
            if not analysis:
                continue

            findings = getattr(analysis, "findings", None)
            if findings:
                lines = [
                    self._format_finding(finding)
                    for finding in findings
                    if isinstance(finding, dict)
                ]
            else:
                lines = ["- No significant findings"]

            # Each section is built only from its own agent's lines, so a
            # finding can never appear under another agent's header.
            sections.append(f"\n{agent_name.upper()} AGENT:\n" + "\n".join(lines))

        return "\n".join(sections) if sections else "No findings available"

    @classmethod
    def _format_finding(cls, finding: Dict[str, Any]) -> str:
        """Format one finding dict as ``- <pathology>: <pct> confidence``.

        Measurements are appended in parentheses and evidence on an indented
        second line. A missing confidence renders as 0%; a non-numeric one
        (e.g. ``None``) renders as ``unknown``.
        """
        pathology = finding.get('pathology', 'unknown')
        confidence = cls._format_number(finding.get('confidence', 0), ".0%", "unknown")
        text = f"- {pathology}: {confidence} confidence"

        measurements = finding.get('measurements')
        if measurements:
            text += f" ({cls._format_measurement_pairs(measurements.items())})"

        evidence = finding.get('evidence', '')
        if evidence:
            text += f"\n  Evidence: {evidence}"
        return text

    @classmethod
    def _format_measurement_pairs(cls, pairs) -> str:
        """Join ``(name, value)`` pairs as ``name=value`` with 3-decimal numbers."""
        return ", ".join(f"{name}={cls._format_number(value, '.3f')}" for name, value in pairs)

    @staticmethod
    def _format_number(value: Any, spec: str, fallback: Optional[str] = None) -> str:
        """Format ``value`` with ``spec``; degrade gracefully if it is not numeric.

        Returns ``fallback`` (or ``str(value)`` when no fallback is given) if
        formatting fails. This keeps malformed agent output from crashing
        prompt construction, which runs outside the LLM error handling.
        """
        try:
            return format(value, spec)
        except (TypeError, ValueError):
            return str(value) if fallback is None else fallback

    def _fallback_synthesis(
        self,
        query: str,
        query_analysis: SemanticQueryAnalysis,
        agent_results: Dict[str, AgentAnalysis]
    ) -> SynthesisResult:
        """Build a simple rule-based answer when LLM answer generation fails.

        The answer text depends on the query intent (DIAGNOSTIC, MEASUREMENT,
        or anything else). The result always carries confidence 0.7, no
        supporting evidence, a caveat noting the fallback, and
        ``requires_review=True``. ``query`` is currently unused.
        """
        all_findings = []
        for analysis in agent_results.values():
            # ``or []`` tolerates analyses whose findings are None or missing.
            all_findings.extend(getattr(analysis, "findings", None) or [])

        intent = query_analysis.query_intent
        if intent == QueryIntent.DIAGNOSTIC:
            # NOTE(review): counts every finding regardless of its confidence.
            if all_findings:
                answer = f"Yes, abnormalities detected: {len(all_findings)} findings."
            else:
                answer = "No significant abnormalities detected."

        elif intent == QueryIntent.MEASUREMENT:
            measurements = [
                pair
                for finding in all_findings
                if isinstance(finding, dict) and finding.get('measurements')
                for pair in finding['measurements'].items()
            ]
            if measurements:
                answer = "Measurements: " + self._format_measurement_pairs(measurements)
            else:
                answer = "No measurements available."

        else:
            # Default comprehensive summary
            if all_findings:
                answer = f"Analysis complete. {len(all_findings)} findings detected."
            else:
                answer = "Analysis complete. No significant abnormalities detected."

        return SynthesisResult(
            answer=answer,
            confidence=0.7,
            supporting_evidence=[],
            caveats=["Fallback synthesis used due to error"],
            requires_review=True
        )


class ConfidenceCalculator:
    """Heuristic confidence scoring and clinical-review triage.

    Both public helpers are static and work on the raw per-agent findings;
    the thresholds used are fixed heuristics, not learned calibrations.
    """

    @staticmethod
    def _coerce_confidence(finding: Dict[str, Any], default: float) -> float:
        """Return ``finding['confidence']`` as a float, or ``default``.

        ``default`` is used when the key is missing or its value cannot be
        converted (e.g. ``None``), so one malformed finding cannot raise a
        ``TypeError`` in the arithmetic and comparisons that follow.
        """
        value = finding.get('confidence', default)
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def calculate_overall_confidence(
        agent_results: Dict[str, AgentAnalysis],
        query_analysis: SemanticQueryAnalysis
    ) -> float:
        """Combine per-finding confidences into one overall score.

        Returns:
            - 0.5 when ``agent_results`` is empty.
            - 0.9 (DIAGNOSTIC intent) or 0.7 (other intents) when there are no
              dict-shaped findings to score.
            - Otherwise the plain mean of the finding confidences (missing or
              non-numeric values count as 0.5), multiplied by 0.8 if any
              finding is below 0.3, and capped at 0.95.
        """
        if not agent_results:
            return 0.5

        confidences = [
            ConfidenceCalculator._coerce_confidence(finding, 0.5)
            for analysis in agent_results.values()
            # ``or []`` tolerates analyses whose findings are None or missing.
            for finding in (getattr(analysis, 'findings', None) or [])
            if isinstance(finding, dict)
        ]

        if not confidences:
            # No scorable findings at all is treated as evidence of absence;
            # DIAGNOSTIC (yes/no) queries get the higher score.
            return 0.9 if query_analysis.query_intent == QueryIntent.DIAGNOSTIC else 0.7

        mean_confidence = sum(confidences) / len(confidences)

        # Penalize if any single finding is very uncertain.
        if min(confidences) < 0.3:
            mean_confidence *= 0.8

        return min(mean_confidence, 0.95)  # Cap at 95%

    @staticmethod
    def needs_clinical_review(
        confidence: float,
        findings: List[Finding],
        query_intent: QueryIntent
    ) -> bool:
        """Decide whether a result should be flagged for clinical review.

        Review is required when the overall ``confidence`` is below 0.6, or
        when any dict-shaped finding has a borderline confidence strictly
        between 0.3 and 0.7 (missing/non-numeric values count as 0). For
        COMPARISON and COMPREHENSIVE intents, review is also required when
        ``confidence`` is below 0.8.
        """
        if confidence < 0.6:
            return True

        borderline = any(
            0.3 < ConfidenceCalculator._coerce_confidence(finding, 0) < 0.7
            for finding in (findings or [])
            if isinstance(finding, dict)
        )
        if borderline:
            return True

        if query_intent in (QueryIntent.COMPARISON, QueryIntent.COMPREHENSIVE):
            return confidence < 0.8

        return False