#!/usr/bin/env python3
"""
Extract exact values from database columns to ensure precise matching.
This tool addresses the common failure pattern where values don't match due to
case sensitivity, trailing spaces, or special character issues.
"""

import sqlite3
import json
import os
import sys
from typing import Dict, List, Set, Any
import re

class ExactValueExtractor:
    """Collect the exact stored values of likely filter columns in a SQLite database.

    ``extract()`` returns a mapping ``{table: {column: info}}``. Each ``info``
    lists the most frequent distinct non-NULL values (up to 500) with their
    row counts and detected quirks: leading/trailing spaces, mixed case,
    special characters, and path separators. The mapping also has a
    ``"_lookup_map"`` entry that maps exact and case/space-normalized
    spellings back to the table, column and exact value where they occur.
    """

    # Anything other than word characters, whitespace, '-', '_' or '.' is "special".
    _SPECIAL_CHARS_RE = re.compile(r'[^\w\s\-_.]')

    # Substrings of column names that usually indicate a filterable column.
    _NAME_PATTERNS = ('name', 'type', 'status', 'category', 'region', 'country',
                      'state', 'city', 'code', 'indicator', 'title', 'path')

    # Column-name suffixes treated as foreign-key columns.
    _FK_SUFFIXES = ('_id', 'Id', 'ID')

    def __init__(self, db_path: str):
        """Open a SQLite connection to *db_path* (sqlite3 creates the file if missing)."""
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        self.exact_values = {}
        # table name -> list of column names; filled by _identify_important_columns().
        self.important_columns = {}

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier.

        Schema names may contain spaces, punctuation or SQL keywords (e.g. a
        table called ``order``). Interpolating them unquoted produces invalid
        SQL. Embedded double quotes are escaped by doubling them.
        """
        return '"' + name.replace('"', '""') + '"'

    def extract(self):
        """Run the full extraction pipeline and return the collected values.

        Returns:
            ``self.exact_values`` on success, or
            ``{"error": "Value extraction failed: <message>"}`` if any step
            raises. The connection is always closed afterwards, so an
            instance can be used only once.
        """
        try:
            self._identify_important_columns()
            self._extract_exact_values()
            self._detect_value_variations()
            self._create_value_lookup_map()
            return self.exact_values
        except Exception as e:
            return {"error": f"Value extraction failed: {str(e)}"}
        finally:
            self.conn.close()

    def _identify_important_columns(self):
        """Select, per user table, the columns likely to be used in WHERE clauses.

        A column is selected when any of these holds:

        * its declared type contains TEXT or CHAR, the table is non-empty, and
          the column has at most 1000 distinct values or fewer distinct values
          than 10% of the rows;
        * its lowercased name contains one of ``_NAME_PATTERNS``;
        * its name ends with one of ``_FK_SUFFIXES``.

        SQLite's reserved internal tables (names starting with ``sqlite_``)
        are skipped.
        """
        self.important_columns = {}

        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        table_names = [row[0] for row in self.cursor.fetchall()]

        for table_name in table_names:
            # sqlite_sequence, sqlite_stat1, ... are bookkeeping, not user data.
            if table_name.lower().startswith("sqlite_"):
                continue

            self.important_columns[table_name] = []
            quoted_table = self._quote_identifier(table_name)

            # Each row is (cid, name, type, notnull, dflt_value, pk).
            self.cursor.execute(f"PRAGMA table_info({quoted_table})")
            columns = self.cursor.fetchall()

            # The row count drives the distinct-ratio test below.
            self.cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
            row_count = self.cursor.fetchone()[0]

            for col in columns:
                col_name = col[1]
                # The declared type may be empty for untyped SQLite columns.
                col_type = (col[2] or "").upper()

                is_important = False

                # String columns with a limited number of distinct values
                # ('CHAR' also covers VARCHAR / NCHAR).
                if ('TEXT' in col_type or 'CHAR' in col_type) and row_count > 0:
                    quoted_col = self._quote_identifier(col_name)
                    self.cursor.execute(
                        f"SELECT COUNT(DISTINCT {quoted_col}) FROM {quoted_table}"
                    )
                    distinct_count = self.cursor.fetchone()[0]

                    if distinct_count <= 1000 or distinct_count / row_count < 0.1:
                        is_important = True

                # Column names that suggest categorical / lookup data.
                lowered = col_name.lower()
                if any(pattern in lowered for pattern in self._NAME_PATTERNS):
                    is_important = True

                # Foreign-key-like columns.
                if col_name.endswith(self._FK_SUFFIXES):
                    is_important = True

                if is_important:
                    self.important_columns[table_name].append(col_name)

    def _extract_exact_values(self):
        """Fetch the most frequent distinct values of every important column.

        Stores ``{"exact_values": [...], "value_count": n,
        "special_characteristics": {}}`` under ``self.exact_values[table][col]``.
        ``value_count`` is the number of values fetched, so it is capped by
        ``LIMIT 500``. If a column's query fails, that column is recorded as
        ``{"error": "<message>"}`` and the remaining columns are still processed.
        """
        for table_name, columns in self.important_columns.items():
            if not columns:
                continue

            self.exact_values[table_name] = {}
            quoted_table = self._quote_identifier(table_name)

            for col_name in columns:
                try:
                    quoted_col = self._quote_identifier(col_name)
                    # GROUP BY yields one row per distinct value; most frequent first.
                    query = f"""
                        SELECT {quoted_col}, COUNT(*) AS cnt
                        FROM {quoted_table}
                        WHERE {quoted_col} IS NOT NULL
                        GROUP BY {quoted_col}
                        ORDER BY cnt DESC
                        LIMIT 500
                    """
                    self.cursor.execute(query)
                    results = self.cursor.fetchall()

                    values_info = {
                        "exact_values": [],
                        "value_count": len(results),
                        "special_characteristics": {}
                    }

                    for value, count in results:
                        value_entry = {
                            "value": value,
                            "count": count,
                            "characteristics": []
                        }

                        # Only text values get characteristics.
                        if isinstance(value, str):
                            # Leading/trailing spaces.
                            if value != value.strip():
                                value_entry["characteristics"].append("has_spaces")
                                if value.startswith(' '):
                                    value_entry["characteristics"].append("leading_space")
                                if value.endswith(' '):
                                    value_entry["characteristics"].append("trailing_space")

                            # Neither all-lower nor all-upper.
                            if value != value.lower() and value != value.upper():
                                value_entry["characteristics"].append("mixed_case")

                            if self._SPECIAL_CHARS_RE.search(value):
                                value_entry["characteristics"].append("special_chars")

                            # Path separators ("forward_slash" only when no backslash).
                            if '\\' in value:
                                value_entry["characteristics"].append("backslash")
                            if '/' in value and '\\' not in value:
                                value_entry["characteristics"].append("forward_slash")

                        values_info["exact_values"].append(value_entry)

                    self.exact_values[table_name][col_name] = values_info

                except Exception as e:
                    # Best effort: one bad column must not abort the whole table.
                    self.exact_values[table_name][col_name] = {"error": str(e)}

    def _detect_value_variations(self):
        """Record values that collide after lowercasing and stripping outer spaces.

        The first value seen (most frequent) for each normalized form is the
        "original". Every later value with the same normalized form is a
        "variation". Up to 10 variations are stored under
        ``special_characteristics["variations"]``.
        """
        for table_name, columns_data in self.exact_values.items():
            for col_name, values_info in columns_data.items():
                if "error" in values_info:
                    continue

                exact_vals = values_info.get("exact_values", [])
                if not exact_vals:
                    continue

                # normalized form -> first (most frequent) spelling seen
                normalized_map = {}
                variations = []

                for val_entry in exact_vals:
                    value = val_entry["value"]
                    if not isinstance(value, str):
                        continue

                    normalized = value.lower().strip()

                    if normalized in normalized_map:
                        original = normalized_map[normalized]
                        variations.append({
                            "original": original,
                            "variation": value,
                            "type": self._identify_variation_type(original, value)
                        })
                    else:
                        normalized_map[normalized] = value

                if variations:
                    values_info["special_characteristics"]["variations"] = variations[:10]

    def _identify_variation_type(self, val1: str, val2: str) -> str:
        """Classify how two strings that match after lower()+strip() differ.

        Returns:
            ``"spacing_only"`` if they differ only in leading/trailing
            whitespace, ``"case_only"`` if they differ only in letter case,
            ``"case_and_spacing"`` if they differ in both, and ``"other"``
            if they are not equal even after lowercasing and stripping.
        """
        # Same text once outer whitespace is removed -> only spacing differs.
        if val1.strip() == val2.strip():
            return "spacing_only"
        # Same text once case is ignored (whitespace untouched) -> only case differs.
        if val1.lower() == val2.lower():
            return "case_only"
        if val1.strip().lower() == val2.strip().lower():
            return "case_and_spacing"
        return "other"

    def _create_value_lookup_map(self):
        """Build ``self.exact_values["_lookup_map"]``.

        The map goes from a spelling (exact, lower, upper, stripped, or
        lower+stripped) to the list of ``{table, column, exact_value, count}``
        entries whose value has that spelling. Only the top 100 string values
        per column are indexed.
        """
        lookup_map = {}

        for table_name, columns_data in self.exact_values.items():
            for col_name, values_info in columns_data.items():
                if "error" in values_info:
                    continue

                exact_vals = values_info.get("exact_values", [])

                for val_entry in exact_vals[:100]:
                    value = val_entry["value"]
                    if not isinstance(value, str):
                        continue

                    # dict.fromkeys de-duplicates while preserving order. A value
                    # whose variants coincide (e.g. already lowercase) must be
                    # listed only once per key.
                    keys = dict.fromkeys((
                        value,                  # exact
                        value.lower(),          # lowercase
                        value.upper(),          # uppercase
                        value.strip(),          # outer spaces removed
                        value.lower().strip(),  # normalized
                    ))

                    for key in keys:
                        lookup_map.setdefault(key, []).append({
                            "table": table_name,
                            "column": col_name,
                            "exact_value": value,
                            "count": val_entry["count"]
                        })

        self.exact_values["_lookup_map"] = lookup_map

def format_exact_values_output(exact_values: Dict) -> str:
    """Render an extraction result as a Markdown reference guide.

    Args:
        exact_values: The mapping returned by ``ExactValueExtractor.extract()``,
            i.e. ``{table: {column: info}}`` plus an optional ``"_lookup_map"``
            entry, which is ignored here. When extraction failed the mapping
            is ``{"error": "<message>"}``, and that message is reported under
            an "Extraction Error" heading.

    Returns:
        The guide text. A column is listed only if some of its values have
        characteristics or it has at most 20 distinct values. Per column, at
        most 15 values and 3 case/spacing variations are shown.
    """
    output = [
        "# Exact Value Reference Guide\n",
        "## Purpose",
        "This guide contains EXACT values from the database to ensure precise matching.",
        "Pay special attention to case sensitivity, spaces, and special characters.\n",
    ]

    for table_name, columns_data in exact_values.items():
        if table_name == "_lookup_map":
            continue

        # Table entries are always dicts. A non-dict entry is the top-level
        # extraction error message. Checking the type (rather than looking for
        # an "error" key) also keeps tables that have a column named "error".
        if not isinstance(columns_data, dict):
            output.append("\n## Extraction Error")
            output.append(str(columns_data))
            continue

        has_important_values = False
        table_output = [f"\n## Table: {table_name}"]

        for col_name, values_info in columns_data.items():
            # Skip columns whose query failed or that returned no values.
            if "error" in values_info or not values_info.get("exact_values"):
                continue

            # Only show columns with special characteristics or few values.
            exact_vals = values_info["exact_values"]
            has_special = any(v["characteristics"] for v in exact_vals)
            if not (has_special or len(exact_vals) <= 20):
                continue

            has_important_values = True
            table_output.append(f"\n### {col_name}")

            # Warn about values that only differ in case or outer spacing.
            special_chars = values_info.get("special_characteristics", {})
            if special_chars.get("variations"):
                table_output.append("⚠️  **CASE/SPACING VARIATIONS DETECTED**")
                for var in special_chars["variations"][:3]:
                    table_output.append(f"  - '{var['original']}' vs '{var['variation']}' ({var['type']})")

            table_output.append(f"**Exact Values** ({len(exact_vals)} distinct):")

            # repr() makes leading/trailing spaces and escapes visible.
            for val_entry in exact_vals[:15]:
                value = val_entry["value"]
                chars = val_entry["characteristics"]
                if chars:
                    char_str = ", ".join(chars)
                    table_output.append(f"  - `{repr(value)}` [{char_str}] (count: {val_entry['count']})")
                else:
                    table_output.append(f"  - `{repr(value)}` (count: {val_entry['count']})")

            if len(exact_vals) > 15:
                table_output.append(f"  ... and {len(exact_vals) - 15} more values")

        if has_important_values:
            output.extend(table_output)

    output.append("\n## Critical Value Matching Rules")
    output.append("1. **Always use exact case** - 'FROM' ≠ 'from'")
    output.append("2. **Preserve spaces** - 'West ' ≠ 'West'")
    output.append("3. **Match separators** - Use backslashes in paths exactly as stored")
    output.append("4. **Quote strings properly** - Use single quotes for SQL string literals")

    return "\n".join(output)

def main():
    """Extract exact values from ./database.sqlite and write reports to tool_output/.

    Writes ``tool_output/exact_values.json`` (raw values, without the lookup
    map) and ``tool_output/exact_values_guide.txt`` (formatted guide), then
    prints a short summary. Exits with status 1 if the database is missing
    or extraction fails.
    """
    db_path = "./database.sqlite"

    if not os.path.exists(db_path):
        print(f"Error: Database not found at {db_path}", file=sys.stderr)
        sys.exit(1)

    os.makedirs("tool_output", exist_ok=True)

    print("Extracting exact values from database...")

    try:
        extractor = ExactValueExtractor(db_path)
        exact_values = extractor.extract()

        # On failure extract() returns {"error": "<message>"} (a str value). A
        # table that happens to be named "error" maps to a dict instead.
        extraction_error = exact_values.get("error")
        if not isinstance(extraction_error, str):
            extraction_error = None

        # Per-table entries only (drops the lookup map and any error string).
        table_entries = {
            name: cols for name, cols in exact_values.items()
            if name != "_lookup_map" and isinstance(cols, dict)
        }

        # Save raw JSON; the lookup map is left out for a cleaner file.
        with open("tool_output/exact_values.json", "w", encoding="utf-8") as f:
            json_output = {k: v for k, v in exact_values.items() if k != "_lookup_map"}
            json.dump(json_output, f, indent=2, default=str)

        if extraction_error is not None:
            # sys.exit raises SystemExit, which the except Exception below does not catch.
            print(f"Fatal error: {extraction_error}", file=sys.stderr)
            sys.exit(1)

        # The guide contains non-ASCII symbols (⚠️, ≠). Write UTF-8 explicitly
        # instead of relying on the platform default encoding.
        formatted = format_exact_values_output(exact_values)
        with open("tool_output/exact_values_guide.txt", "w", encoding="utf-8") as f:
            f.write(formatted)

        print("Extraction complete!")
        print("- Raw values: tool_output/exact_values.json")
        print("- Reference guide: tool_output/exact_values_guide.txt")

        # Columns are counted including any whose value query failed.
        total_columns = sum(len(cols) for cols in table_entries.values())
        print(f"\nExtracted values from {total_columns} important columns")

        # Columns where at least one value has a detected characteristic.
        special_count = sum(
            1
            for cols in table_entries.values()
            for info in cols.values()
            if "error" not in info
            and any(val["characteristics"] for val in info.get("exact_values", []))
        )

        if special_count > 0:
            print(f"⚠️  Found {special_count} columns with special value characteristics")

    except Exception as e:
        # Top-level boundary: report any unexpected failure and exit non-zero.
        print(f"Fatal error: {str(e)}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()