#!/usr/bin/env python3
"""
Precision Entity Resolver - Maps question entities to exact database values
Combines entity resolution from iter8 with precision focus from iter7
"""

import json
import logging
import os
import re
import sqlite3
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple

class PrecisionEntityResolver:
    """Map words and phrases of a question to exact values stored in a SQLite database.

    The resolver walks every user table, records per-column metadata and builds
    an in-memory index from lowercased values (and their individual words) to
    the original, case-preserved database value.

    Attributes:
        conn: Open ``sqlite3`` connection to the database at ``db_path``.
        cursor: Cursor on ``conn`` shared by all queries of this instance.
        entity_index: Maps a lowercased lookup key to a list of entry dicts
            with keys ``table``, ``column``, ``full_column``, ``exact_value``
            and ``match_type`` (``"exact"`` or ``"partial"``).
        column_metadata: Maps ``"table.column"`` to a dict with the declared
            ``type`` and the ``is_id`` / ``is_name`` / ``is_code`` flags.
        entity_patterns: Empty dict, not used by any method of this class.
    """

    # Longest phrase (in words) of the question tried against the index.
    MAX_PHRASE_WORDS = 5

    def __init__(self, db_path: str):
        """Open a connection to ``db_path`` and create the empty indexes."""
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        self.entity_index = defaultdict(list)
        self.column_metadata = {}
        self.entity_patterns = {}

    def __enter__(self) -> "PrecisionEntityResolver":
        """Allow ``with PrecisionEntityResolver(path) as resolver:`` usage."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the connection on leaving the ``with`` block; never suppresses errors."""
        self.close()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier (embedded quotes doubled)."""
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _has_text_affinity(col_type: str) -> bool:
        """Tell whether a declared column type holds text.

        Matches case-insensitively on the substrings CHAR, CLOB and TEXT, so
        declarations such as ``VARCHAR(255)``, ``nvarchar`` or ``text`` are
        recognised in addition to the bare ``TEXT`` / ``VARCHAR`` / ``CHAR``.
        """
        declared = (col_type or "").upper()
        return any(marker in declared for marker in ("CHAR", "CLOB", "TEXT"))

    def build_complete_index(self) -> None:
        """Build the value index and column metadata for every table.

        The entity index is cleared first, so calling this method repeatedly
        (``generate_precision_report`` does so on every call) does not pile up
        duplicate entries.
        """
        self.entity_index.clear()

        for table in self.get_tables():
            for col_name, col_type in self.get_columns(table):
                full_column = f"{table}.{col_name}"
                self.column_metadata[full_column] = {
                    "type": col_type,
                    "is_id": self.is_id_column(col_name),
                    "is_name": self.is_name_column(col_name),
                    "is_code": self.is_code_column(col_name)
                }

                # Only text-like columns are indexed for entity matching.
                if self._has_text_affinity(col_type):
                    self.index_column_values(table, col_name, full_column)

    def index_column_values(self, table: str, column: str, full_column: str) -> None:
        """Index all distinct non-empty string values of one column.

        Each value is stored under its lowercased form as an ``"exact"`` entry
        and under each of its words longer than two characters as a
        ``"partial"`` entry. A word equal to the whole value is not repeated
        as a partial entry, and a word occurring several times in one value is
        indexed once.

        Args:
            table: Table name as stored in the database.
            column: Column name as stored in the database.
            full_column: ``"table.column"`` label recorded in every entry.

        A column whose query fails is skipped with a logged warning rather
        than aborting the whole indexing run.
        """
        # Identifiers are quoted so names containing spaces, punctuation or
        # reserved words can be queried.
        quoted_column = self._quote_identifier(column)
        quoted_table = self._quote_identifier(table)
        query = (
            f"SELECT DISTINCT {quoted_column} FROM {quoted_table} "
            f"WHERE {quoted_column} IS NOT NULL"
        )

        try:
            self.cursor.execute(query)
            rows = self.cursor.fetchall()
        except sqlite3.Error as exc:
            logging.getLogger(__name__).warning(
                "Skipping column %s during indexing: %s", full_column, exc
            )
            return

        for (value,) in rows:
            if not value or not isinstance(value, str):
                continue

            lowered = value.lower()
            entry = {
                "table": table,
                "column": column,
                "full_column": full_column,
                "exact_value": value,  # original case is preserved here
                "match_type": "exact"
            }
            self.entity_index[lowered].append(entry)

            # dict.fromkeys de-duplicates repeated words while keeping order.
            for word in dict.fromkeys(re.findall(r'\b\w+\b', lowered)):
                if len(word) <= 2 or word == lowered:
                    continue
                self.entity_index[word].append({**entry, "match_type": "partial"})

    def detect_entity_type_requirements(self, question: str) -> Dict[str, Any]:
        """Detect, via regex patterns, what kind of answer the question asks for.

        Args:
            question: Natural-language question; matched case-insensitively.

        Returns:
            Dict with ``detected_patterns`` (names of pattern groups that
            matched, in declaration order), ``entity_type_hint``,
            ``return_column_hint`` and ``critical_warnings``. When several
            hint-bearing groups match, the hints of the last one win.
        """
        patterns = {
            "return_state_not_school": [
                r"which school'?s state",
                r"what school'?s state",
                r"school.{0,20}state"
            ],
            "return_country_not_indicator": [
                r"which country.{0,20}indicator",
                r"what country.{0,20}indicator",
                r"country with.{0,20}indicator"
            ],
            "return_name_not_id": [
                r"who is",
                r"what is the name",
                r"which person",
                r"which company"
            ],
            "return_count_only": [
                r"how many",
                r"number of",
                r"count of",
                r"total number"
            ],
            "return_single_value": [
                r"what is the \w+",
                r"which is the \w+",
                r"find the \w+"
            ]
        }

        results = {
            "detected_patterns": [],
            "entity_type_hint": None,
            "return_column_hint": None,
            "critical_warnings": []
        }

        question_lower = question.lower()

        for pattern_type, pattern_list in patterns.items():
            # A group is reported once, however many of its patterns match.
            if not any(re.search(pattern, question_lower) for pattern in pattern_list):
                continue

            results["detected_patterns"].append(pattern_type)

            if pattern_type == "return_state_not_school":
                results["entity_type_hint"] = "state"
                results["return_column_hint"] = "Return STATE column, NOT school name"
                results["critical_warnings"].append("CRITICAL: Question asks for school's state - return STATE not school")
            elif pattern_type == "return_country_not_indicator":
                results["entity_type_hint"] = "country"
                results["return_column_hint"] = "Return COUNTRY column, NOT indicator"
            elif pattern_type == "return_count_only":
                results["entity_type_hint"] = "count"
                results["return_column_hint"] = "Return COUNT(*) only, no other columns"

        return results

    def _candidate_terms(self, text: str) -> List[str]:
        """List every word and phrase of up to MAX_PHRASE_WORDS words in ``text``.

        Phrases are verbatim slices of ``text`` spanning consecutive words, so
        punctuation between the words is kept (e.g. ``"st. louis"``). Terms are
        unique and ordered by where they start in the text.
        """
        spans = [match.span() for match in re.finditer(r'\b\w+\b', text)]
        terms = {}
        for first, (start, _) in enumerate(spans):
            last_exclusive = min(first + self.MAX_PHRASE_WORDS, len(spans))
            for last in range(first, last_exclusive):
                terms.setdefault(text[start:spans[last][1]], None)
        return list(terms)

    def resolve_entities(self, question: str, evidence: str = "") -> Dict[str, Any]:
        """Resolve words and phrases of the question to exact database values.

        Args:
            question: Natural-language question.
            evidence: Optional extra text searched together with the question.

        Returns:
            Dict with ``exact_matches`` and ``partial_matches`` (lists of dicts
            with ``search_term``, ``database_value``, ``column``,
            ``confidence``), ``recommended_values`` (per column: the first
            exact match in text order plus all distinct matches) and
            ``matching_guidance``.

        Single words and multi-word phrases are both looked up, so a stored
        value such as "New York" resolves as an exact match. Results are
        ordered by position in the text, which makes them reproducible.
        """
        results = {
            "exact_matches": [],
            "partial_matches": [],
            "recommended_values": {},
            "matching_guidance": []
        }

        combined_text = f"{question} {evidence}".lower()

        # Exact-match values collected per column, in text order.
        column_matches = defaultdict(list)

        for term in self._candidate_terms(combined_text):
            # .get() avoids creating empty defaultdict entries for misses.
            for entry in self.entity_index.get(term, ()):
                is_exact = entry["match_type"] == "exact"
                match_info = {
                    "search_term": term,
                    "database_value": entry["exact_value"],
                    "column": entry["full_column"],
                    "confidence": "HIGH" if is_exact else "MEDIUM"
                }
                if is_exact:
                    results["exact_matches"].append(match_info)
                    column_matches[entry["full_column"]].append(entry["exact_value"])
                else:
                    results["partial_matches"].append(match_info)

        for column, values in column_matches.items():
            distinct_values = list(dict.fromkeys(values))
            results["recommended_values"][column] = {
                "use_exact_match": True,
                # First exact match in text order.
                "recommended_value": distinct_values[0],
                "all_matches": distinct_values
            }

        if results["exact_matches"]:
            results["matching_guidance"].append({
                "rule": "Use = operator for exact matches",
                "columns": list(dict.fromkeys(m["column"] for m in results["exact_matches"]))
            })

        if results["partial_matches"] and not results["exact_matches"]:
            results["matching_guidance"].append({
                "rule": "Consider LIKE operator only if no exact matches found",
                "columns": list(dict.fromkeys(m["column"] for m in results["partial_matches"]))
            })

        return results

    def is_id_column(self, col_name: str) -> bool:
        """Check if column is likely an ID column.

        NOTE(review): this is a name-suffix heuristic; any name ending in
        "id" (e.g. "paid", "valid") is reported as an ID column.
        """
        col_lower = col_name.lower()
        return (col_lower.endswith(('_id', 'id')) or
                col_lower in ['id', 'identifier', 'key'])

    def is_name_column(self, col_name: str) -> bool:
        """Check if column is likely a name column (name-suffix heuristic)."""
        col_lower = col_name.lower()
        return (col_lower.endswith(('_name', 'name')) or
                col_lower in ['name', 'title', 'label'])

    def is_code_column(self, col_name: str) -> bool:
        """Check if column is likely a code column (name-suffix heuristic)."""
        col_lower = col_name.lower()
        return (col_lower.endswith(('_code', 'code', '_abbr', 'abbr')) or
                col_lower in ['code', 'abbreviation'])

    def get_tables(self) -> List[str]:
        """Return the names of all tables except SQLite's internal ``sqlite_*`` ones."""
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
        return [row[0] for row in self.cursor.fetchall()]

    def get_columns(self, table: str) -> List[Tuple[str, str]]:
        """Return ``(column_name, declared_type)`` pairs for ``table``.

        The table name is quoted so names with spaces or reserved words work.
        """
        self.cursor.execute(f"PRAGMA table_info({self._quote_identifier(table)})")
        return [(row[1], row[2]) for row in self.cursor.fetchall()]

    def generate_precision_report(self, question: str = "", evidence: str = "") -> Dict[str, Any]:
        """Rebuild the index and assemble the full entity resolution report.

        Args:
            question: Optional question to analyse; when empty, the
                ``entity_type_analysis`` and ``entity_resolution`` sections
                are empty dicts.
            evidence: Optional extra text passed to ``resolve_entities``.

        Returns:
            Dict with ``entity_type_analysis``, ``entity_resolution``,
            ``column_metadata`` and ``precision_rules``.
        """
        self.build_complete_index()

        entity_requirements = self.detect_entity_type_requirements(question) if question else {}
        entity_resolution = self.resolve_entities(question, evidence) if question else {}

        return {
            "entity_type_analysis": entity_requirements,
            "entity_resolution": entity_resolution,
            "column_metadata": self.column_metadata,
            "precision_rules": self.generate_precision_rules()
        }

    def generate_precision_rules(self) -> List[Dict[str, Any]]:
        """Derive matching rules from the collected column metadata.

        Returns:
            Up to two rule dicts (``rule``, ``columns``, ``priority``): a
            CRITICAL one for ID/code columns and a HIGH one for name columns.
            A rule is omitted when no column falls into its category.
        """
        rules = []

        id_columns = [col for col, meta in self.column_metadata.items()
                      if meta.get("is_id") or meta.get("is_code")]
        if id_columns:
            rules.append({
                "rule": "Always use = for ID and code columns",
                "columns": id_columns,
                "priority": "CRITICAL"
            })

        name_columns = [col for col, meta in self.column_metadata.items()
                        if meta.get("is_name")]
        if name_columns:
            rules.append({
                "rule": "Default to = for name columns unless explicitly searching partial",
                "columns": name_columns,
                "priority": "HIGH"
            })

        return rules

    def close(self) -> None:
        """Close database connection."""
        self.conn.close()

def main():
    """Run a general (question-less) analysis of ``database.sqlite``.

    Writes the report to ``tool_output/entity_resolution.json``. If anything
    fails, the failure is printed and an error report with the message is
    written to the same path instead. The database connection is closed on
    both the success and the failure path.
    """
    db_path = "database.sqlite"
    output_path = "tool_output/entity_resolution.json"

    # Ensure output directory exists
    os.makedirs("tool_output", exist_ok=True)

    resolver = None
    try:
        resolver = PrecisionEntityResolver(db_path)

        # No question is passed, so only metadata and rules are produced.
        report = resolver.generate_precision_report()

        # default=str stringifies any value json cannot serialise natively.
        with open(output_path, 'w') as f:
            json.dump(report, f, indent=2, default=str)

        print("✓ Precision entity resolution complete")
        print(f"  - Indexed {len(resolver.entity_index)} unique value patterns")
        print(f"  - Analyzed {len(resolver.column_metadata)} columns")
        print(f"  - Generated {len(report['precision_rules'])} precision rules")
        print(f"  - Results saved to {output_path}")

    except Exception as e:
        # Top-level boundary: report the failure instead of crashing the tool.
        print(f"✗ Entity resolution failed: {str(e)}")
        error_report = {
            "error": str(e),
            "partial_results": {}
        }
        with open(output_path, 'w') as f:
            json.dump(error_report, f, indent=2)

    finally:
        # Release the connection even when report generation failed.
        if resolver is not None:
            resolver.close()

# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()