from typing import Dict, List, Optional, Any, Union, ClassVar
from pydantic import BaseModel, Field
from enum import Enum
import re
import json
from src.utils.logsetup import logger

class TaskType(str, Enum):
    """Categories of user task recognised by the classifier.

    Each member is also a ``str`` equal to its lowercase value, so members
    compare equal to and serialize as plain strings.
    """

    # Question-answering flavours
    FACTUAL_QA = "factual_qa"
    MULTI_HOP_QA = "multi_hop_qa"
    COMPARATIVE_QA = "comparative_qa"
    TEMPORAL_QA = "temporal_qa"
    CAUSAL_QA = "causal_qa"

    # Computational / technical work
    MATHEMATICAL = "mathematical"
    CODE_GENERATION = "code_generation"
    DATA_ANALYSIS = "data_analysis"

    # Open-ended work
    CREATIVE_WRITING = "creative_writing"
    REASONING = "reasoning"
    COMPLEX_PLANNING = "complex_planning"

class TaskComplexity(str, Enum):
    """Coarse difficulty levels for a task, from least to most demanding.

    The per-member notes record the intended meaning of each level in terms
    of execution steps and tooling.
    """

    # 1-2 steps, a single tool
    SIMPLE = "simple"
    # 3-5 steps, several tools
    MODERATE = "moderate"
    # 6-10 steps, planning required
    COMPLEX = "complex"
    # 10+ steps, domain expertise needed
    EXPERT = "expert"

class TaskDomain(str, Enum):
    """Subject-area buckets a task can be assigned to.

    ``GENERAL`` is the catch-all bucket; the others name specific fields.
    Members are ``str`` instances equal to their lowercase value.
    """

    # Fallback bucket
    GENERAL = "general"

    # Academic / knowledge areas
    SCIENCE = "science"
    HISTORY = "history"
    GEOGRAPHY = "geography"
    ENTERTAINMENT = "entertainment"
    TECHNOLOGY = "technology"

    # Professional / institutional areas
    BUSINESS = "business"
    EDUCATION = "education"
    HEALTHCARE = "healthcare"
    LEGAL = "legal"

class TaskAnalysisResult(BaseModel):
    """Structured outcome of analysing a single task query.

    Combines the three classification labels (type, complexity, domain) and
    a confidence value with the evidence extracted from the query text and a
    rough execution plan (tools, step estimate, success criteria).
    """

    # --- Classification labels (required) ---
    task_type: TaskType = Field(..., description="Identified task type")
    complexity: TaskComplexity = Field(..., description="Task complexity level")
    domain: TaskDomain = Field(..., description="Task domain")
    confidence_score: float = Field(..., description="Confidence in classification (0-1)")

    # --- Evidence extracted from the query (each defaults to an empty list) ---
    entities: List[str] = Field(description="Key entities identified", default_factory=list)
    keywords: List[str] = Field(description="Important keywords", default_factory=list)
    intent_indicators: List[str] = Field(description="Intent indicators found", default_factory=list)

    # --- Execution requirements (estimated_steps is required) ---
    required_tools: List[str] = Field(description="Tools likely needed", default_factory=list)
    estimated_steps: int = Field(..., description="Estimated number of execution steps")
    success_criteria: List[str] = Field(description="Success criteria", default_factory=list)

    # --- Contextual mentions found in the query (each defaults to an empty list) ---
    temporal_mentions: List[str] = Field(description="Time references found", default_factory=list)
    geographic_mentions: List[str] = Field(description="Location references", default_factory=list)
    numeric_mentions: List[str] = Field(description="Numbers/quantities", default_factory=list)

    # Human-readable justification of the classification (required)
    reasoning: str = Field(..., description="Explanation of classification")

class TaskTypeClassifier(BaseModel):
    """Rule-based task classifier used for task routing and planning.

    Every decision is made with regular expressions, fixed keyword lists and
    simple text statistics; nothing here calls a model or external service.
    ``name``, ``description`` and ``parameters`` describe the tool in the
    function-calling format returned by ``to_param``; ``execute`` is the entry
    point that produces a ``TaskAnalysisResult``.
    """
    
    name: str = "task_type_classifier"
    description: str = """Analyze and classify user tasks to determine task type, complexity, domain, 
    and execution requirements. Provides detailed analysis for optimal tool selection and planning strategy."""
    
    # JSON schema for the arguments accepted by ``execute``.
    # NOTE(review): "complexity_hint" declares default "" although "" is not
    # one of its enum values; strict schema validators may flag this.
    parameters: dict = {
        "type": "object",
        "properties": {
            "task_query": {
                "type": "string",
                "description": "(required) The user's task or query to be analyzed and classified"
            },
            "context": {
                "type": "string",
                "description": "(optional) Additional context about the task or user's goals",
                "default": ""
            },
            "domain_hint": {
                "type": "string",
                "description": "(optional) Hint about the expected domain if known",
                "default": ""
            },
            "complexity_hint": {
                "type": "string", 
                "enum": ["simple", "moderate", "complex", "expert"],
                "description": "(optional) Hint about expected complexity level",
                "default": ""
            }
        },
        "required": ["task_query"]
    }
    
    # Regex patterns per task type. ``_classify_task_type`` applies them to the
    # lowercased query and adds 1 to that type's score for every match.
    # NOTE(review): DATA_ANALYSIS, CREATIVE_WRITING, REASONING and
    # COMPLEX_PLANNING have neither patterns here nor keyword boosts, so the
    # classifier can never return them.
    TASK_TYPE_PATTERNS: ClassVar[Dict] = {
        TaskType.FACTUAL_QA: [
            r"\b(who|what|when|where|which)\s+is\b",
            r"\bwhat\s+(is|are|was|were)\b",
            r"\bwho\s+(is|was|are|were)\b",
            r"\btell me about\b",
            r"\bdefine\b",
            r"\bexplain\b"
        ],
        TaskType.MULTI_HOP_QA: [
            r"\bof the\s+\w+\s+who\b",
            r"\bthat\s+(also|additionally|furthermore)\b",
            r"\bwhose\s+\w+\s+(is|was|are|were)\b",
            r"\b(along with|together with|in addition to)\b",
            r"\bthe\s+\w+\s+of\s+the\s+\w+\s+of\b"
        ],
        TaskType.COMPARATIVE_QA: [
            r"\b(first|earlier|older|younger|before|after)\b",
            r"\b(better|worse|more|less|higher|lower)\b",
            r"\b(compare|versus|vs\.?|against)\b",
            r"\bwhich\s+(is|was|are|were)\s+(more|less|better|worse)\b",
            r"\bdifference\s+between\b"
        ],
        TaskType.TEMPORAL_QA: [
            r"\b(when|what year|what date|what time)\b",
            r"\b(19|20)\d{2}\b",
            r"\b(january|february|march|april|may|june|july|august|september|october|november|december)\b",
            r"\b(birthday|birth|death|died|born)\b",
            r"\btimeline\b"
        ],
        TaskType.CAUSAL_QA: [
            r"\b(why|how|because|due to|caused by|reason for)\b",
            r"\bwhat (caused|led to|resulted in)\b",
            r"\b(consequence|result|effect|impact)\b",
            r"\bwhat happens (if|when)\b"
        ],
        TaskType.MATHEMATICAL: [
            r"\b(calculate|compute|solve|find|determine)\b.*\b(sum|average|percentage|ratio|equation)\b",
            r"\b\d+\s*[+\-*/=]\s*\d+\b",
            r"\bmathematical\s+(problem|equation|formula)\b",
            r"\b(derivative|integral|matrix|probability)\b"
        ],
        TaskType.CODE_GENERATION: [
            r"\b(write|create|generate|implement)\s+(code|program|function|script)\b",
            # "c++" ends in non-word characters, so a trailing \b would only
            # match when a letter/digit follows it ("c++ code" would fail);
            # use a negative lookahead for that alternative instead.
            r"\b(?:python|java|javascript|html|css)\b|\bc\+\+(?!\w)",
            r"\balgorithm\s+(to|for)\b",
            r"\bfunction\s+that\b"
        ]
    }
    
    # Per-level thresholds and patterns.
    # NOTE(review): not referenced by any method of this class;
    # ``_assess_complexity`` uses its own scoring formula.
    COMPLEXITY_INDICATORS: ClassVar[Dict] = {
        TaskComplexity.SIMPLE: {
            "max_entities": 2,
            "max_concepts": 1,
            "max_words": 15,
            "patterns": [r"\bwhat is\b", r"\bwho is\b", r"\bwhen is\b"]
        },
        TaskComplexity.MODERATE: {
            "max_entities": 4,
            "max_concepts": 3,
            "max_words": 25,
            "patterns": [r"\bcompare\b", r"\bfind.*and\b", r"\blist\b"]
        },
        TaskComplexity.COMPLEX: {
            "max_entities": 8,
            "max_concepts": 5,
            "max_words": 40,
            "patterns": [r"\bmulti-step\b", r"\banalyze\b", r"\bsynthesis\b"]
        },
        TaskComplexity.EXPERT: {
            "max_entities": float('inf'),
            "max_concepts": float('inf'),
            "max_words": float('inf'),
            "patterns": [r"\bcomplex\s+analysis\b", r"\badvanced\b", r"\bresearch\s+project\b"]
        }
    }
    
    # Keywords that vote for a domain in ``_identify_domain``. Each keyword
    # must appear at the start of a word in the lowercased query.
    DOMAIN_KEYWORDS: ClassVar[Dict] = {
        TaskDomain.SCIENCE: ["research", "study", "experiment", "hypothesis", "theory", "data", "analysis"],
        TaskDomain.HISTORY: ["historical", "century", "war", "ancient", "medieval", "revolution", "empire"],
        TaskDomain.GEOGRAPHY: ["country", "city", "continent", "mountain", "river", "climate", "population"],
        TaskDomain.ENTERTAINMENT: ["movie", "film", "actor", "director", "music", "song", "artist", "album"],
        TaskDomain.TECHNOLOGY: ["software", "hardware", "computer", "internet", "programming", "digital"],
        TaskDomain.BUSINESS: ["company", "market", "finance", "economy", "profit", "investment", "strategy"],
        TaskDomain.HEALTHCARE: ["medical", "health", "disease", "treatment", "doctor", "hospital", "medicine"],
        TaskDomain.LEGAL: ["law", "legal", "court", "judge", "attorney", "regulation", "statute"]
    }
    
    async def execute(
        self, 
        task_query: str, 
        context: str = "", 
        domain_hint: str = "", 
        complexity_hint: str = "",
        **kwargs
    ) -> TaskAnalysisResult:
        """
        Analyze and classify a task query.
        
        Args:
            task_query: The user's task or query to analyze.
            context: Additional context about the task. Accepted for interface
                compatibility but not currently used by the analysis.
            domain_hint: Optional domain value (e.g. "science"); when it names
                a TaskDomain (case-insensitive) it overrides domain detection.
            complexity_hint: Optional complexity value (e.g. "simple"); when it
                names a TaskComplexity (case-insensitive) it overrides the
                complexity heuristic.
            **kwargs: Ignored; tolerates extra arguments from tool callers.
            
        Returns:
            TaskAnalysisResult with detailed classification and analysis.
        """
        
        logger.info(f"🔍 Analyzing task: {task_query[:100]}...")
        
        # Basic lexical components reused by several later stages.
        entities = self._extract_entities(task_query)
        keywords = self._extract_keywords(task_query)
        
        # Classification labels.
        task_type, type_confidence = self._classify_task_type(task_query, keywords)
        complexity = self._assess_complexity(task_query, entities, keywords, complexity_hint)
        domain = self._identify_domain(task_query, keywords, domain_hint)
        
        # Execution requirements derived from the labels.
        required_tools = self._identify_required_tools(task_type, complexity, domain)
        estimated_steps = self._estimate_steps(task_type, complexity)
        success_criteria = self._define_success_criteria(task_type, task_query)
        
        # Contextual mentions.
        temporal_mentions = self._extract_temporal_info(task_query)
        geographic_mentions = self._extract_geographic_info(task_query)
        numeric_mentions = self._extract_numeric_info(task_query)
        intent_indicators = self._identify_intent_indicators(task_query, task_type)
        
        confidence_score = self._calculate_confidence(
            type_confidence, len(entities), len(keywords), task_type, complexity
        )
        reasoning = self._generate_reasoning(
            task_query, task_type, complexity, domain, entities, keywords
        )
        
        result = TaskAnalysisResult(
            task_type=task_type,
            complexity=complexity,
            domain=domain,
            confidence_score=confidence_score,
            entities=entities,
            keywords=keywords,
            intent_indicators=intent_indicators,
            required_tools=required_tools,
            estimated_steps=estimated_steps,
            success_criteria=success_criteria,
            temporal_mentions=temporal_mentions,
            geographic_mentions=geographic_mentions,
            numeric_mentions=numeric_mentions,
            reasoning=reasoning
        )
        
        logger.info(f"📊 Task classified as: {task_type.value} ({complexity.value}) in {domain.value} domain")
        return result
    
    def _extract_entities(self, text: str) -> List[str]:
        """Extract candidate named entities using a capitalization heuristic.

        Runs of capitalized words (e.g. "New York") are treated as entities.
        A fixed set of capitalized question/determiner words is dropped, as
        are matches of two characters or fewer.

        Returns:
            Unique entities in order of first appearance.
        """
        entity_pattern = r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b'
        stop_entities = {"The", "This", "That", "When", "Where", "Who", "What", "How", "Why", "Which"}
        entities = [
            e for e in re.findall(entity_pattern, text)
            if e not in stop_entities and len(e) > 2
        ]
        # dict.fromkeys de-duplicates while keeping first-occurrence order;
        # list(set(...)) gives a hash-seed-dependent order, which made the
        # generated reasoning text non-deterministic across runs.
        return list(dict.fromkeys(entities))
    
    def _extract_keywords(self, text: str) -> List[str]:
        """Return up to 10 of the most frequent lowercase words longer than 3 characters.

        A small stop-word list is removed first. Ties are broken by first
        occurrence (``Counter.most_common`` ordering).
        """
        stop_words = {
            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
            "is", "was", "are", "were", "be", "been", "being", "have", "has", "had", "do", "does", "did"
        }
        
        words = re.findall(r'\b[a-z]+\b', text.lower())
        keywords = [w for w in words if w not in stop_words and len(w) > 3]
        
        from collections import Counter
        counter = Counter(keywords)
        return [word for word, _count in counter.most_common(10)]
    
    def _classify_task_type(self, text: str, keywords: List[str]) -> tuple[TaskType, float]:
        """Pick the task type with the highest pattern + keyword score.

        Each regex match in ``TASK_TYPE_PATTERNS`` adds 1 to its type, and each
        extracted keyword found in that type's boost list adds 1 more.

        Returns:
            ``(task_type, confidence)`` where confidence is the winning score
            divided by the sum of all scores. When nothing scores at all the
            result is ``(TaskType.FACTUAL_QA, 0.3)``. Ties go to the type that
            appears first in ``TASK_TYPE_PATTERNS``.
        """
        text_lower = text.lower()
        scores = {}
        
        for task_type, patterns in self.TASK_TYPE_PATTERNS.items():
            scores[task_type] = sum(
                len(re.findall(pattern, text_lower, re.IGNORECASE)) for pattern in patterns
            )
        
        keyword_boosts = {
            TaskType.FACTUAL_QA: ["fact", "information", "about", "explain"],
            TaskType.MULTI_HOP_QA: ["complex", "detailed", "comprehensive"],
            TaskType.COMPARATIVE_QA: ["compare", "difference", "better", "worse"],
            TaskType.MATHEMATICAL: ["calculate", "compute", "number", "amount"],
            TaskType.CODE_GENERATION: ["code", "program", "function", "algorithm"]
        }
        
        for task_type, boost_words in keyword_boosts.items():
            boost = sum(1 for word in keywords if word in boost_words)
            scores[task_type] = scores.get(task_type, 0) + boost
        
        if not scores or max(scores.values()) == 0:
            return TaskType.FACTUAL_QA, 0.3  # Default with low confidence
        
        best_type = max(scores, key=scores.get)
        # All scores are non-negative and the best one is > 0 here, so the
        # total is positive and the ratio lies in (0, 1].
        confidence = scores[best_type] / sum(scores.values())
        
        return best_type, confidence
    
    def _assess_complexity(self, text: str, entities: List[str], keywords: List[str], hint: str) -> TaskComplexity:
        """Estimate task complexity from simple text statistics.

        A ``hint`` naming a TaskComplexity value (compared after stripping
        whitespace and lowercasing) overrides the heuristic. Otherwise a score
        is built from capped word, entity and keyword counts plus 0.5 per
        occurrence of sequencing words, conjunctions, question marks and
        relative pronouns, and bucketed as: <=3 SIMPLE, <=7 MODERATE,
        <=12 COMPLEX, otherwise EXPERT.
        """
        # `hint or ""` tolerates an explicit None from callers.
        normalized_hint = (hint or "").strip().lower()
        if normalized_hint in {c.value for c in TaskComplexity}:
            return TaskComplexity(normalized_hint)
        
        text_lower = text.lower()
        word_count = len(text.split())
        entity_count = len(entities)
        keyword_count = len(keywords)
        
        complexity_indicators = {
            "multi_step": len(re.findall(r'\b(first|then|next|finally|after that)\b', text_lower)),
            "conjunctions": len(re.findall(r'\b(and|or|but|however|furthermore|moreover)\b', text_lower)),
            "questions": text.count('?'),
            "subclauses": len(re.findall(r'\b(that|which|who|whose|where|when)\b', text_lower))
        }
        
        complexity_score = (
            min(word_count / 10, 5) +  # Word count factor (max 5)
            min(entity_count, 3) +     # Entity count factor (max 3)
            min(keyword_count / 2, 3) + # Keyword density factor (max 3)
            sum(complexity_indicators.values()) * 0.5  # Pattern indicators (uncapped)
        )
        
        if complexity_score <= 3:
            return TaskComplexity.SIMPLE
        if complexity_score <= 7:
            return TaskComplexity.MODERATE
        if complexity_score <= 12:
            return TaskComplexity.COMPLEX
        return TaskComplexity.EXPERT
    
    def _identify_domain(self, text: str, keywords: List[str], hint: str) -> TaskDomain:
        """Identify the task domain from ``DOMAIN_KEYWORDS``.

        A ``hint`` naming a TaskDomain value (compared after stripping
        whitespace and lowercasing) wins outright. Otherwise each domain
        keyword scores +2 if it starts a word in the query and +1 if it is
        also one of the extracted ``keywords``. Returns GENERAL when nothing
        scores; ties go to the domain listed first in ``DOMAIN_KEYWORDS``.
        """
        normalized_hint = (hint or "").strip().lower()
        if normalized_hint in {d.value for d in TaskDomain}:
            return TaskDomain(normalized_hint)
        
        text_lower = text.lower()
        domain_scores = {}
        
        for domain, domain_keywords in self.DOMAIN_KEYWORDS.items():
            score = 0
            for keyword in domain_keywords:
                # Require the keyword to start a word. A plain substring test
                # let "war" hit "software"/"award", "law" hit "flaw" and
                # "city" hit "electricity". Suffixes are still allowed, so
                # "movies" or "healthcare" keep matching.
                if re.search(r"\b" + re.escape(keyword), text_lower):
                    score += 2
                if keyword in keywords:
                    score += 1
            domain_scores[domain] = score
        
        if not domain_scores or max(domain_scores.values()) == 0:
            return TaskDomain.GENERAL
        
        return max(domain_scores, key=domain_scores.get)
    
    def _identify_required_tools(self, task_type: TaskType, complexity: TaskComplexity, domain: TaskDomain) -> List[str]:
        """List tool names likely needed for this task.

        Starts from a per-task-type base list, adds planning tools for COMPLEX
        and EXPERT tasks, and one domain-specific tool for TECHNOLOGY or
        SCIENCE. Returns unique names in the order they were added.
        """
        tool_mapping = {
            TaskType.FACTUAL_QA: ["web_search", "answer_summarizer"],
            TaskType.MULTI_HOP_QA: ["enhanced_web_search", "multi_hop_reasoner", "answer_summarizer"],
            TaskType.COMPARATIVE_QA: ["web_search", "multi_hop_reasoner", "answer_summarizer"],
            TaskType.TEMPORAL_QA: ["web_search", "answer_summarizer"],
            TaskType.CAUSAL_QA: ["web_search", "multi_hop_reasoner", "answer_summarizer"],
            TaskType.MATHEMATICAL: ["python_execute", "answer_summarizer"],
            TaskType.CODE_GENERATION: ["python_execute", "answer_summarizer"],
            TaskType.DATA_ANALYSIS: ["python_execute", "web_search", "answer_summarizer"]
        }
        
        tools = list(tool_mapping.get(task_type, ["web_search", "answer_summarizer"]))
        
        if complexity in (TaskComplexity.COMPLEX, TaskComplexity.EXPERT):
            tools.extend(["task_planner", "task_decomposer"])
        
        if domain == TaskDomain.TECHNOLOGY:
            tools.append("code_analyzer")
        elif domain == TaskDomain.SCIENCE:
            tools.append("data_analyzer")
        
        # Order-preserving de-duplication; list(set(...)) produced a
        # run-dependent order for callers that iterate the tool list.
        return list(dict.fromkeys(tools))
    
    def _estimate_steps(self, task_type: TaskType, complexity: TaskComplexity) -> int:
        """Estimate execution steps as ``int(base * multiplier)``, at least 1.

        ``base`` comes from the task type (3 when unlisted) and ``multiplier``
        from the complexity level (1.0 / 1.5 / 2.0 / 3.0).
        """
        base_steps = {
            TaskType.FACTUAL_QA: 2,
            TaskType.MULTI_HOP_QA: 4,
            TaskType.COMPARATIVE_QA: 3,
            TaskType.TEMPORAL_QA: 2,
            TaskType.CAUSAL_QA: 3,
            TaskType.MATHEMATICAL: 2,
            TaskType.CODE_GENERATION: 3,
            TaskType.DATA_ANALYSIS: 4
        }
        
        complexity_multipliers = {
            TaskComplexity.SIMPLE: 1.0,
            TaskComplexity.MODERATE: 1.5,
            TaskComplexity.COMPLEX: 2.0,
            TaskComplexity.EXPERT: 3.0
        }
        
        base = base_steps.get(task_type, 3)
        multiplier = complexity_multipliers.get(complexity, 1.5)
        
        return max(1, int(base * multiplier))
    
    def _define_success_criteria(self, task_type: TaskType, query: str) -> List[str]:
        """Build a list of success criteria.

        Always includes two general criteria, then task-type-specific ones,
        then extras triggered by the words "explain" / "example" in the query.
        """
        criteria = [
            "Provide accurate and relevant information",
            "Address all parts of the query",
        ]
        
        if task_type == TaskType.FACTUAL_QA:
            criteria.append("Provide factual, verifiable information")
            criteria.append("Include credible sources when possible")
        elif task_type == TaskType.MULTI_HOP_QA:
            criteria.append("Connect information from multiple sources")
            criteria.append("Show clear reasoning chain")
        elif task_type == TaskType.COMPARATIVE_QA:
            criteria.append("Provide clear comparison points")
            criteria.append("Support comparison with evidence")
        elif task_type == TaskType.MATHEMATICAL:
            criteria.append("Show calculation steps clearly")
            criteria.append("Verify numerical accuracy")
        elif task_type == TaskType.CODE_GENERATION:
            criteria.append("Code should be functional and tested")
            criteria.append("Include comments and explanations")
        
        query_lower = query.lower()
        if "explain" in query_lower:
            criteria.append("Provide clear explanations and reasoning")
        if "example" in query_lower:
            criteria.append("Include relevant examples")
        
        return criteria
    
    def _extract_temporal_info(self, text: str) -> List[str]:
        """Extract time references (years, months, ordinals, relative words).

        Matching is case-insensitive. Returns unique full matches in pattern
        order, then order of appearance.
        """
        # Groups are non-capturing so re.findall returns the whole match; with
        # a capturing group it returned only the group (e.g. "19" for "1995",
        # "st" for "21st").
        temporal_patterns = [
            r'\b(?:19|20)\d{2}\b',  # Years 1900-2099
            r'\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\b',
            r'\b\d{1,2}(?:st|nd|rd|th)\b',  # Ordinal days, e.g. "21st"
            r'\b(?:yesterday|today|tomorrow|recently|currently|now)\b',
            r'\b(?:before|after|during|while|since|until)\b'
        ]
        
        temporal_mentions = []
        for pattern in temporal_patterns:
            temporal_mentions.extend(re.findall(pattern, text, re.IGNORECASE))
        
        return list(dict.fromkeys(temporal_mentions))
    
    def _extract_geographic_info(self, text: str) -> List[str]:
        """Extract location references with simple patterns (no real NER).

        Returns unique matches in pattern order, then order of appearance.
        """
        geographic_patterns = [
            # Capitalized name followed by a place-type word, e.g. "New York
            # City". Deliberately case-sensitive: under IGNORECASE
            # [A-Z][a-z]+ matches every word, so the match swallowed whole
            # sentences ("find hotels in new york city").
            (r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*(?:\s+(?:City|State|Country|Province|County|Region))\b', 0),
            # Fixed list of country names, matched case-insensitively.
            (r'\b(?:United States|USA|UK|United Kingdom|Canada|Australia|France|Germany|Japan|China|India|Brazil)\b', re.IGNORECASE),
        ]
        
        geographic_mentions = []
        for pattern, flags in geographic_patterns:
            geographic_mentions.extend(re.findall(pattern, text, flags))
        
        return list(dict.fromkeys(geographic_mentions))
    
    def _extract_numeric_info(self, text: str) -> List[str]:
        """Extract numbers, percentages, dollar amounts and unit quantities.

        A value can be reported by more than one pattern (e.g. "50" and
        "50%"). Returns unique matches in pattern order, then order of
        appearance.
        """
        numeric_patterns = [
            # Percentages. "%" is not a word character, so a \b after it
            # would reject "50%" at end of text or before a space; only the
            # word forms carry a trailing boundary.
            r'\b\d+(?:\.\d+)?\s*(?:%|percent(?:age)?\b)',
            r'\b\d+(?:,\d{3})*(?:\.\d+)?\b',  # Numbers with optional thousands separators
            r'\$\d+(?:,\d{3})*(?:\.\d{2})?\b',  # Currency
            r'\b\d+(?:\.\d+)?\s*(?:kg|g|lb|oz|km|m|ft|in|l|ml)\b'  # Units
        ]
        
        numeric_mentions = []
        for pattern in numeric_patterns:
            numeric_mentions.extend(re.findall(pattern, text, re.IGNORECASE))
        
        return list(dict.fromkeys(numeric_mentions))
    
    def _identify_intent_indicators(self, text: str, task_type: TaskType) -> List[str]:
        """Return intent labels whose trigger phrases occur in the text.

        Labels are returned in a fixed order, each at most once. ``task_type``
        is accepted but not currently used.
        """
        intent_patterns = {
            "information_seeking": [r'\b(tell me|explain|describe|what is|who is)\b'],
            "comparison": [r'\b(compare|versus|difference|better|worse)\b'],
            "instruction": [r'\b(how to|step by step|guide|tutorial)\b'],
            "analysis": [r'\b(analyze|examine|evaluate|assess)\b'],
            "creation": [r'\b(create|generate|make|build|write)\b'],
            "problem_solving": [r'\b(solve|fix|resolve|find solution)\b']
        }
        
        text_lower = text.lower()
        return [
            intent
            for intent, patterns in intent_patterns.items()
            if any(re.search(pattern, text_lower) for pattern in patterns)
        ]
    
    def _calculate_confidence(
        self, type_confidence: float, entity_count: int, keyword_count: int, 
        task_type: TaskType, complexity: TaskComplexity
    ) -> float:
        """Combine type confidence with an information bonus and complexity penalty.

        Adds up to 0.2 (0.02 per entity/keyword), then scales by 0.9 for
        EXPERT or 0.95 for COMPLEX tasks, and clamps to [0.1, 1.0].
        ``task_type`` is accepted but not currently used.
        """
        confidence = type_confidence
        
        info_boost = min(0.2, (entity_count + keyword_count) * 0.02)
        confidence += info_boost
        
        # Very complex tasks are harder to classify accurately.
        if complexity == TaskComplexity.EXPERT:
            confidence *= 0.9
        elif complexity == TaskComplexity.COMPLEX:
            confidence *= 0.95
        
        return min(1.0, max(0.1, confidence))
    
    def _generate_reasoning(
        self, query: str, task_type: TaskType, complexity: TaskComplexity, 
        domain: TaskDomain, entities: List[str], keywords: List[str]
    ) -> str:
        """Build a one-paragraph, human-readable explanation of the labels.

        Mentions up to three entities and keywords, the complexity and the
        domain, plus an extra sentence for multi-hop, comparative and temporal
        tasks. ``query`` is accepted but not currently used.
        """
        factors = []
        
        if entities:
            factors.append(f"entities identified: {', '.join(entities[:3])}")
        
        if keywords:
            factors.append(f"key terms: {', '.join(keywords[:3])}")
        
        factors.append(f"complexity assessed as {complexity.value}")
        factors.append(f"domain identified as {domain.value}")
        
        reasoning = f"Task classified as {task_type.value} based on " + "; ".join(factors)
        
        if task_type == TaskType.MULTI_HOP_QA:
            reasoning += ". Multi-hop reasoning required due to complex entity relationships."
        elif task_type == TaskType.COMPARATIVE_QA:
            reasoning += ". Comparison task identified from comparative language patterns."
        elif task_type == TaskType.TEMPORAL_QA:
            reasoning += ". Temporal reasoning required due to time-related queries."
        
        return reasoning
    
    def to_param(self) -> Dict:
        """Return the tool description in function-calling format."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }