"""
Enhanced query analyzer for the multi-agent system.
"""

import logging
import re
from enum import Enum
from typing import Dict, List

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field

class QueryType(str, Enum):
    """Category assigned to a medical imaging query.

    The ``str`` mixin means every member compares equal to its lowercase
    string value, e.g. ``QueryType.COMPARISON == "comparison"``.
    """

    DIAGNOSTIC = "diagnostic"
    MEASUREMENT = "measurement"
    COMPARISON = "comparison"
    COMPREHENSIVE = "comprehensive"
    LOCALIZATION = "localization"

class QueryAnalysisResult(BaseModel):
    """Analysis plan for one query: which agents to run, and how.

    The LLM chain uses this model as its structured-output schema, and the
    rule-based fallback constructs it directly. Only ``primary_agents`` and
    ``query_type`` are required; every other field has a default.
    """

    # Required: agents that must run for this query.
    primary_agents: List[str] = Field(
        description="Must-have agents for this query",
    )
    secondary_agents: List[str] = Field(
        default=[],
        description="Optional agents based on findings",
    )
    agent_priorities: Dict[str, float] = Field(
        default={},
        description="Priority scores (0-1) for each agent",
    )
    execution_order: List[str] = Field(
        default=[],
        description="Optimal execution sequence",
    )
    # Required: category of the query.
    query_type: QueryType = Field(
        description="Type of medical query",
    )
    expected_outputs: List[str] = Field(
        default=[],
        description="Expected findings or measurements",
    )
    confidence_threshold: float = Field(
        default=0.7,
        description="Minimum confidence for reliable answer",
    )
    requires_comparison: bool = Field(
        default=False,
        description="Whether prior image comparison needed",
    )

class QueryAnalyzer:
    """
    Enhanced query analyzer with structured output.

    This replaces the 'EnhancedQueryAnalyzer' and is designed to be used
    within the LangGraph workflow. Each query is first sent through an LLM
    chain whose reply is parsed into a ``QueryAnalysisResult``. If that call
    raises or yields no result, a deterministic keyword-based fallback is
    used instead.
    """

    # --- Rule-based fallback configuration ---------------------------------
    # Maps each agent name to its keyword pattern. Dict order is the canonical
    # order in which the fallback reports agents. A list built from this dict
    # is deterministic, unlike iteration over a set of strings, whose order
    # depends on the per-process hash seed.
    #
    # Short or ambiguous words are anchored with \b so they no longer fire on
    # substrings of unrelated words ("rib" in "describe", "line" in
    # "baseline", "tube" in "tuberculosis", "ctr" in "spectrum").
    # Distinctive medical stems stay unanchored so that compounds such as
    # "hydropneumothorax" or "bronchopneumonia" still match.
    _AGENT_PATTERNS = {
        "cardiac": re.compile(
            r"heart|cardiac|cardiomegaly|\bctr\b|mediastin", re.IGNORECASE
        ),
        "breathing": re.compile(
            r"lung|pneumonia|breathing|effusion|pneumothorax|infiltrat", re.IGNORECASE
        ),
        "airway": re.compile(r"trachea|airway|mediastin", re.IGNORECASE),
        "diaphragm": re.compile(r"diaphragm|pneumoperitoneum|free air", re.IGNORECASE),
        "everything": re.compile(
            r"fracture|\btubes?\b|\blines?\b|device|\bbones?\b|\bribs?\b", re.IGNORECASE
        ),
    }
    # Agents activated when no agent pattern matches at all.
    _DEFAULT_AGENTS = ("cardiac", "breathing")
    # The stems "measur", "calculat" and "compar" also cover inflections such
    # as "measuring", "calculation" and "comparison". \bratios?\b avoids
    # "respiration" and "duration"; \bpriors?\b avoids "priority".
    _MEASUREMENT_PATTERN = re.compile(
        r"measur|calculat|\bctr\b|\bratios?\b", re.IGNORECASE
    )
    _COMPARISON_PATTERN = re.compile(r"compar|\bpriors?\b", re.IGNORECASE)
    # Agent selection and query typing share this pattern, so "comprehensive"
    # both activates every agent and yields QueryType.COMPREHENSIVE.
    _COMPREHENSIVE_PATTERN = re.compile(
        r"\bfull\b|\bcomplete\b|comprehensive", re.IGNORECASE
    )

    def __init__(self, llm: BaseLanguageModel):
        """Build the analyzer around ``llm``.

        Args:
            llm: Language model used for the primary, LLM-based analysis. Its
                ``with_structured_output`` method is called here, during
                construction, when the analyzer chain is built.
        """
        self.llm = llm
        self.analyzer_chain = self._create_analyzer_chain()

    def _create_analyzer_chain(self):
        """Create LLM chain for structured query analysis.

        Returns:
            A runnable made of the chat prompt piped into
            ``self.llm.with_structured_output(QueryAnalysisResult)``. Invoke
            it with ``{"query": <text>}``.
        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a medical imaging query analyzer. Your task is to analyze the user's query and create a plan for a team of specialist AI agents.

Analyze the query and determine:
1.  **primary_agents**: A list of specialist agents that MUST be activated to answer the query.
2.  **query_type**: The category of the medical query.
3.  **requires_comparison**: A boolean indicating if the query involves comparing with a prior study.

Available agents and their capabilities:
-   **cardiac**: Analyzes heart size (CTR), cardiac devices, and mediastinal structures.
-   **breathing**: Analyzes lungs for pneumonia, opacities, pleural effusion, and pneumothorax.
-   **airway**: Assesses tracheal position and major airways.
-   **diaphragm**: Checks diaphragm position and for free air.
-   **everything**: Handles bones (fractures), soft tissues, and other lines/tubes.

- If the query is general, like "full analysis" or "comprehensive report", activate all agents.
- Be liberal in your interpretation. If a query mentions a topic, include the relevant agent. For example, "mediastinum" can relate to both "cardiac" and "airway". Include both.
- Base your decision solely on the query provided.
"""),
            ("human", "Query: {query}")
        ])

        return prompt | self.llm.with_structured_output(QueryAnalysisResult)

    def analyze_query(self, query: str) -> QueryAnalysisResult:
        """Analyze query with structured output, with a fallback for safety.

        Args:
            query: Free-text user query.

        Returns:
            The LLM's structured ``QueryAnalysisResult``. If the chain raises,
            or returns ``None`` instead of a result, the rule-based fallback
            result is returned.
        """
        try:
            result = self.analyzer_chain.invoke({"query": query})
        except Exception as exc:  # any LLM/parsing failure degrades to the fallback
            logging.getLogger(__name__).warning(
                "LLM-based query analysis failed: %s. Using fallback.", exc
            )
            return self._fallback_analysis(query)
        if result is None:
            # NOTE(review): some structured-output parsers return None (rather
            # than raising) when the model emits no tool call; treat it as a failure.
            logging.getLogger(__name__).warning(
                "LLM-based query analysis returned no result. Using fallback."
            )
            return self._fallback_analysis(query)
        return result

    def _fallback_analysis(self, query: str) -> QueryAnalysisResult:
        """Rule-based fallback analysis to ensure system stability.

        Matching is case-insensitive and uses the class-level patterns above.

        Args:
            query: Free-text user query. ``None`` is treated as an empty query.

        Returns:
            A ``QueryAnalysisResult`` with these properties:

            * ``primary_agents`` is listed in the canonical order of
              ``_AGENT_PATTERNS``. It is every agent for a full/complete/
              comprehensive request, and ``_DEFAULT_AGENTS`` when nothing
              matches.
            * ``query_type`` is MEASUREMENT, COMPARISON, COMPREHENSIVE or
              DIAGNOSTIC, checked in that precedence order. This fallback
              never emits LOCALIZATION.
        """
        text = query or ""

        wants_everything = bool(self._COMPREHENSIVE_PATTERN.search(text))
        primary_agents = [
            agent
            for agent, pattern in self._AGENT_PATTERNS.items()
            if wants_everything or pattern.search(text)
        ]
        requires_comparison = bool(self._COMPARISON_PATTERN.search(text))

        # Measurement wins over comparison, which wins over a comprehensive
        # request; anything else is treated as diagnostic.
        if self._MEASUREMENT_PATTERN.search(text):
            query_type = QueryType.MEASUREMENT
        elif requires_comparison:
            query_type = QueryType.COMPARISON
        elif wants_everything:
            query_type = QueryType.COMPREHENSIVE
        else:
            query_type = QueryType.DIAGNOSTIC

        return QueryAnalysisResult(
            primary_agents=primary_agents or list(self._DEFAULT_AGENTS),
            query_type=query_type,
            requires_comparison=requires_comparison,
        )