#!/usr/bin/env python3
"""
Smart Value Extractor Tool
Extracts sample values with intelligent sampling strategies and case pattern detection.
Handles large tables, NULL-heavy columns, and encoding issues gracefully.
"""

import sqlite3
import json
import os
import sys
from typing import Dict, List, Any, Optional
import re

def safe_query(conn, query: str, params=None, timeout_ms: int = 3000) -> Optional[List]:
    """Run a query and return all rows, or ``None`` if it fails or times out.

    SQLite has no ``query_timeout`` pragma (unknown pragmas are silently
    ignored), so the time limit is enforced with a progress handler that
    interrupts the statement once the deadline has passed.

    Args:
        conn: Open ``sqlite3`` connection. Any progress handler already
            installed on it is replaced for the duration of the call and
            cleared afterwards.
        query: SQL text to execute.
        params: Optional bound parameters for ``query``.
        timeout_ms: Time budget for this query, in milliseconds.

    Returns:
        The rows from ``fetchall()``, or ``None`` when the query raised,
        including when it was interrupted by the timeout. Failures are
        printed rather than raised.
    """
    import time  # local import keeps this function self-contained

    deadline = time.monotonic() + timeout_ms / 1000.0

    def _deadline_passed() -> int:
        # A non-zero return value makes SQLite abort the running statement.
        return 1 if time.monotonic() > deadline else 0

    try:
        # Check the clock every 1000 virtual-machine instructions.
        conn.set_progress_handler(_deadline_passed, 1000)
        try:
            cursor = conn.cursor()
            if params:
                cursor.execute(query, params)
            else:
                cursor.execute(query)
            # Rows are produced while fetching, so the handler must stay
            # installed until fetchall() has finished.
            return cursor.fetchall()
        finally:
            conn.set_progress_handler(None, 0)
    except Exception as e:
        print(f"⚠️  Query failed: {str(e)[:100]}")
        return None

def _classify_case(text: str) -> str:
    """Return the case-pattern label for one non-empty, stripped string."""
    # Underscore styles must be tested first: "FOO_BAR".isupper() and
    # "foo_bar".islower() are both true, so the generic checks below would
    # otherwise claim them and these two labels could never be produced.
    if '_' in text and text.replace('_', '').isalpha():
        if text.isupper():
            return "CONSTANT_CASE"
        if text.islower():
            return "snake_case"
    if text.isupper():
        return "UPPERCASE"
    if text.islower():
        return "lowercase"
    if text.istitle():
        return "Title Case"
    if text[0].islower() and any(c.isupper() for c in text[1:]):
        return "camelCase"
    return "mixed"


def detect_case_pattern(values: List[str]) -> Dict[str, Any]:
    """Detect the dominant letter-case pattern among string values.

    Only the first 20 entries of ``values`` are inspected; ``None`` entries
    and entries that are blank after stripping are ignored.

    Args:
        values: Sampled column values.

    Returns:
        A dict with:
          - ``pattern``: the most frequent label, ``"mixed"`` when nothing
            could be classified, or ``"unknown"`` when ``values`` is empty.
            Ties go to the label listed first in the tally.
          - ``confidence``: share of inspected values carrying that label.
          - ``examples``: up to five of the inspected values.
          - ``all_patterns``: every label seen, with its count.
    """
    if not values:
        return {"pattern": "unknown", "confidence": 0, "examples": [], "all_patterns": {}}

    patterns = {
        "UPPERCASE": 0,
        "lowercase": 0,
        "Title Case": 0,
        "camelCase": 0,
        "snake_case": 0,
        "CONSTANT_CASE": 0,
        "mixed": 0
    }

    clean_values = []
    for val in values[:20]:  # Analyze first 20 values
        if val is None:
            continue
        val_str = str(val).strip()
        if not val_str:
            continue

        clean_values.append(val_str)
        patterns[_classify_case(val_str)] += 1

    # max() returns the first entry with the highest count.
    dominant = max(patterns.items(), key=lambda x: x[1])

    return {
        "pattern": dominant[0] if dominant[1] > 0 else "mixed",
        "confidence": dominant[1] / len(clean_values) if clean_values else 0,
        "examples": clean_values[:5],
        "all_patterns": {k: v for k, v in patterns.items() if v > 0}
    }

def analyze_numeric_column(conn, table: str, column: str) -> Dict[str, Any]:
    """Summarise a numeric column: range, average, flags and sample values.

    Args:
        conn: Open SQLite connection, passed through to ``safe_query``.
        table: Name of the table holding the column.
        column: Name of the column to analyse.

    Returns:
        A dict with ``type`` ("numeric"), ``min``, ``max``, ``avg`` (rounded
        to two decimals), the boolean flags ``has_nulls``, ``has_zeros``,
        ``has_negatives``, ``looks_like_percentage`` and ``looks_like_count``,
        and ``sample_values`` (up to ten distinct non-NULL values). When the
        range fits a percentage, ``percentage_format`` is added as
        ``"decimal"`` (0..1) or ``"percent"`` (0..100). Fields keep their
        defaults when the corresponding query fails.
    """
    result = {
        "type": "numeric",
        "min": None,
        "max": None,
        "avg": None,
        "has_nulls": False,
        "has_zeros": False,
        "has_negatives": False,
        "looks_like_percentage": False,
        "looks_like_count": False,
        "sample_values": []
    }

    # Double any embedded backtick so an unusual identifier cannot terminate
    # the quoting early.
    col = column.replace("`", "``")
    tbl = table.replace("`", "``")

    # The aggregates yield a single row computed over the whole table, so a
    # LIMIT clause here would have no effect and is omitted.
    stats_query = f"""
        SELECT
            MIN(`{col}`) as min_val,
            MAX(`{col}`) as max_val,
            AVG(`{col}`) as avg_val,
            COUNT(*) as total,
            COUNT(`{col}`) as non_null,
            COUNT(CASE WHEN `{col}` = 0 THEN 1 END) as zeros,
            COUNT(CASE WHEN `{col}` < 0 THEN 1 END) as negatives
        FROM `{tbl}`
    """

    stats = safe_query(conn, stats_query)
    if stats and stats[0]:
        min_val, max_val, avg_val, total, non_null, zeros, negatives = stats[0]
        result["min"] = min_val
        result["max"] = max_val
        # Compare against None so an average of exactly 0 is still reported.
        result["avg"] = round(avg_val, 2) if avg_val is not None else None
        result["has_nulls"] = non_null < total
        result["has_zeros"] = zeros > 0
        result["has_negatives"] = negatives > 0

        # Check if looks like percentage. MIN/MAX return text when the column
        # stores numbers as strings; comparing those with ints would raise.
        if isinstance(min_val, (int, float)) and isinstance(max_val, (int, float)):
            if 0 <= min_val <= 1 and 0 <= max_val <= 1:
                result["looks_like_percentage"] = True
                result["percentage_format"] = "decimal"
            elif 0 <= min_val <= 100 and 0 <= max_val <= 100:
                result["looks_like_percentage"] = True
                result["percentage_format"] = "percent"

        # Check if looks like count (decided from the column name only)
        column_lower = column.lower()
        if any(x in column_lower for x in ['count', 'total', 'number', 'qty', 'quantity']):
            result["looks_like_count"] = True

    # Get sample values
    samples = safe_query(conn, f"""
        SELECT DISTINCT `{col}`
        FROM `{tbl}`
        WHERE `{col}` IS NOT NULL
        LIMIT 10
    """)

    if samples:
        result["sample_values"] = [s[0] for s in samples]

    return result

def _infer_sample_type(sample_vals: List[Any]) -> str:
    """Classify sampled values as "numeric", "date", "boolean" or "text".

    Only the first ten samples are inspected. "numeric" requires every one
    of them to parse as a float; "date" and "boolean" require more than 80%
    of them to match.
    """
    inspected = sample_vals[:10]
    numeric_count = 0
    date_count = 0
    bool_count = 0

    for val in inspected:
        if val is None:
            continue

        val_str = str(val)

        # Check if numeric
        try:
            float(val_str)
            numeric_count += 1
        except ValueError:
            pass

        # Check if date-like (prefix match: YYYY-MM-DD or NN/NN/YYYY)
        if re.match(r'\d{4}-\d{2}-\d{2}', val_str) or re.match(r'\d{2}/\d{2}/\d{4}', val_str):
            date_count += 1

        # Check if boolean-like
        if val_str.lower() in ('true', 'false', '0', '1', 'yes', 'no', 't', 'f', 'y', 'n'):
            bool_count += 1

    sample_count = len(inspected)
    if numeric_count == sample_count:
        return "numeric"
    if date_count > sample_count * 0.8:
        return "date"
    if bool_count > sample_count * 0.8:
        return "boolean"
    return "text"


def _name_based_patterns(col_name: str) -> List[str]:
    """Guess a value pattern from substrings of the column name (at most one)."""
    lowered = col_name.lower()
    if 'email' in lowered:
        return ["email"]
    if 'phone' in lowered or 'tel' in lowered:
        return ["phone"]
    if 'url' in lowered or 'link' in lowered:
        return ["url"]
    if 'percent' in lowered or 'rate' in lowered:
        return ["percentage"]
    return []


def extract_table_values(conn, table_name: str, columns: List[str]) -> Dict[str, Any]:
    """Extract sample values and inferred metadata for each column of a table.

    Args:
        conn: Open SQLite connection, passed through to ``safe_query``.
        table_name: Table to sample.
        columns: Names of the columns to analyse.

    Returns:
        A dict with ``table_name``, ``row_count`` (an int, or the string
        ``"10000+"`` once the 10000-row probe is saturated), ``columns``
        (per-column results keyed by name), ``analysis_status`` ("complete",
        or "partial" when a query or a column analysis failed) and
        ``warnings``. Each column entry holds ``distinct_count`` and
        ``null_count`` (both measured on the first 5000 rows),
        ``sample_values``, ``case_pattern``, ``value_patterns``,
        ``data_type_inferred`` and ``warnings``; numeric columns are also
        merged with the output of ``analyze_numeric_column``.
    """
    table_data = {
        "table_name": table_name,
        "columns": {},
        "row_count": 0,
        "analysis_status": "complete",
        "warnings": []
    }

    def _query_failed(warnings: List[str], what: str) -> None:
        # safe_query() swallows errors and returns None, so a failure has to
        # be recorded here or the table would still be reported as complete.
        warnings.append(f"{what} query failed")
        table_data["analysis_status"] = "partial"

    # Double any embedded backtick so the identifier stays inside its quotes.
    tbl = table_name.replace("`", "``")

    # Get approximate row count (the probe stops after 10000 rows)
    sample_size = 100
    count_result = safe_query(conn, f"SELECT COUNT(*) FROM (SELECT 1 FROM `{tbl}` LIMIT 10000)")
    if count_result:
        row_count = count_result[0][0]
        if row_count == 10000:
            table_data["row_count"] = "10000+"
            table_data["warnings"].append("Large table - sampling limited")
        else:
            table_data["row_count"] = row_count
            sample_size = min(100, row_count)
    else:
        # Keep the default sample size instead of sampling zero rows.
        _query_failed(table_data["warnings"], "Row count")

    for col_name in columns:
        col_data = {
            "distinct_count": 0,
            "null_count": 0,
            "sample_values": [],
            "case_pattern": None,
            "value_patterns": [],
            "data_type_inferred": "unknown",
            "warnings": []
        }

        try:
            col = col_name.replace("`", "``")

            # Get NULL count and distinct count
            stats = safe_query(conn, f"""
                SELECT
                    COUNT(DISTINCT `{col}`) as distinct_count,
                    COUNT(*) - COUNT(`{col}`) as null_count
                FROM (SELECT `{col}` FROM `{tbl}` LIMIT 5000)
            """)

            if stats:
                col_data["distinct_count"] = stats[0][0]
                col_data["null_count"] = stats[0][1]

                # High cardinality warning
                if col_data["distinct_count"] > 1000:
                    col_data["warnings"].append("High cardinality column")
            elif stats is None:
                _query_failed(col_data["warnings"], "Statistics")

            if col_data["distinct_count"] < 50:
                # Low cardinality: the 20 most frequent values with counts.
                # GROUP BY already yields one row per value.
                samples = safe_query(conn, f"""
                    SELECT `{col}`, COUNT(*) as cnt
                    FROM `{tbl}`
                    WHERE `{col}` IS NOT NULL
                    GROUP BY `{col}`
                    ORDER BY cnt DESC
                    LIMIT 20
                """)
                if samples:
                    col_data["sample_values"] = [{"value": s[0], "count": s[1]} for s in samples]
            else:
                # High cardinality: up to 20 distinct values from a random draw
                samples = safe_query(conn, f"""
                    SELECT DISTINCT `{col}`
                    FROM (
                        SELECT `{col}` FROM `{tbl}`
                        WHERE `{col}` IS NOT NULL
                        ORDER BY RANDOM()
                        LIMIT {sample_size}
                    )
                    LIMIT 20
                """)
                if samples:
                    col_data["sample_values"] = [s[0] for s in samples]

            if samples is None:
                _query_failed(col_data["warnings"], "Sample")

            # Infer data type and patterns from samples
            if col_data["sample_values"]:
                sample_vals = [s["value"] if isinstance(s, dict) else s
                               for s in col_data["sample_values"]]

                inferred = _infer_sample_type(sample_vals)
                col_data["data_type_inferred"] = inferred

                if inferred == "numeric":
                    # Detailed numeric analysis
                    numeric_info = analyze_numeric_column(conn, table_name, col_name)
                    if not numeric_info.get("sample_values"):
                        # Do not replace samples already collected with an
                        # empty list when the numeric sample query failed.
                        numeric_info.pop("sample_values", None)
                    col_data.update(numeric_info)
                elif inferred == "text":
                    # Case pattern analysis for text
                    col_data["case_pattern"] = detect_case_pattern(sample_vals)

            # Detect special patterns from the column name
            col_data["value_patterns"].extend(_name_based_patterns(col_name))

        except Exception as e:
            col_data["warnings"].append(f"Analysis failed: {str(e)[:50]}")
            table_data["analysis_status"] = "partial"

        table_data["columns"][col_name] = col_data

    return table_data

def _read_schema_from_db(conn) -> Dict[str, Any]:
    """Build a minimal schema description directly from ``sqlite_master``."""
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [t[0] for t in cursor.fetchall() if not t[0].startswith('sqlite_')]

    schema_info = {"tables": {}}
    for table in tables:
        # Double any embedded backtick so the name stays inside its quotes.
        quoted = table.replace("`", "``")
        cursor.execute(f"PRAGMA table_info(`{quoted}`)")
        # Index 1 of each table_info row is the column name.
        schema_info["tables"][table] = {
            "columns": {col[1]: {} for col in cursor.fetchall()}
        }
    return schema_info


def _record_global_patterns(global_patterns: Dict[str, Any], table_name: str,
                            table_data: Dict[str, Any]) -> None:
    """Add one table's columns to the cross-table pattern indexes."""
    for col_name, col_data in table_data["columns"].items():
        location = f"{table_name}.{col_name}"

        # Track common values across tables (very low cardinality only)
        if col_data.get("sample_values") and col_data["distinct_count"] < 10:
            for val_info in col_data["sample_values"][:5]:
                val = val_info["value"] if isinstance(val_info, dict) else val_info
                if val is not None:
                    global_patterns["common_values"].setdefault(str(val), []).append(location)

        # Track case patterns
        if col_data.get("case_pattern"):
            pattern = col_data["case_pattern"]["pattern"]
            global_patterns["case_patterns"].setdefault(pattern, []).append(location)

        # Track data types
        dtype = col_data.get("data_type_inferred", "unknown")
        global_patterns["data_types"].setdefault(dtype, []).append(location)


def extract_values(db_path="database.sqlite"):
    """Extract value samples for every table and save them as JSON.

    Table and column names are read from
    ``tool_output/schema_analysis.json``; when that file is missing they are
    read from the database itself and the quality is marked "degraded". The
    result is written to ``tool_output/value_samples.json``.

    Args:
        db_path: Path of the SQLite database file.

    Returns:
        The extraction result dict with ``tables``, ``summary``
        (``tables_analyzed``, ``columns_analyzed``, ``extraction_quality``),
        ``global_patterns`` and ``errors``. ``extraction_quality`` is one of
        "complete", "partial", "degraded" or "failed". Errors are recorded in
        the result rather than raised.
    """

    os.makedirs("tool_output", exist_ok=True)

    extraction_result = {
        "tables": {},
        "summary": {
            "tables_analyzed": 0,
            "columns_analyzed": 0,
            "extraction_quality": "complete"
        },
        "global_patterns": {
            "common_values": {},
            "case_patterns": {},
            "data_types": {}
        },
        "errors": []
    }
    summary = extraction_result["summary"]
    errors = extraction_result["errors"]

    def _downgrade_to_partial() -> None:
        # Never overwrite a "degraded" (or worse) status with "partial".
        if summary["extraction_quality"] == "complete":
            summary["extraction_quality"] = "partial"

    conn = None
    try:
        # sqlite3.connect() creates an empty database when the file does not
        # exist, which would make a wrong path look like a successful run.
        if db_path not in (":memory:", "") and not os.path.exists(db_path):
            raise FileNotFoundError(f"Database file not found: {db_path}")
        conn = sqlite3.connect(db_path, timeout=30.0)

        # Read schema analysis to get table and column info
        schema_path = "tool_output/schema_analysis.json"
        if os.path.exists(schema_path):
            with open(schema_path, 'r', encoding='utf-8') as f:
                schema_info = json.load(f)
        else:
            errors.append("Schema analysis not found - running blind extraction")
            summary["extraction_quality"] = "degraded"
            schema_info = _read_schema_from_db(conn)

        # Extract values for each table
        for table_name, table_info in schema_info.get("tables", {}).items():
            print(f"  Extracting values from: {table_name}")

            try:
                # Inside the try so a malformed entry fails only this table.
                if table_info.get("analysis_status") == "failed":
                    errors.append(f"Skipping {table_name} - schema analysis failed")
                    continue

                columns = list(table_info.get("columns", {}).keys())
                if not columns:
                    errors.append(f"No columns found for {table_name}")
                    continue

                table_data = extract_table_values(conn, table_name, columns)
                extraction_result["tables"][table_name] = table_data

                summary["tables_analyzed"] += 1
                summary["columns_analyzed"] += len(columns)

                _record_global_patterns(
                    extraction_result["global_patterns"], table_name, table_data
                )

                if table_data.get("warnings"):
                    errors.extend(
                        [f"{table_name}: {w}" for w in table_data["warnings"]]
                    )
                    _downgrade_to_partial()

            except Exception as e:
                print(f"  ⚠️  Failed to extract from {table_name}: {str(e)}")
                errors.append(f"{table_name}: {str(e)}")
                _downgrade_to_partial()

    except Exception as e:
        errors.append(f"Critical error: {str(e)}")
        summary["extraction_quality"] = "failed"
        print(f"❌ Critical error: {str(e)}")

    finally:
        if conn:
            conn.close()

    # Save results
    output_path = "tool_output/value_samples.json"
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            # default=str stringifies values json cannot encode (e.g. BLOBs).
            json.dump(extraction_result, f, indent=2, default=str)
        print(f"\n✓ Value extraction saved to {output_path}")
        print(f"  Quality: {summary['extraction_quality']}")
        print(f"  Tables: {summary['tables_analyzed']}")
        print(f"  Columns: {summary['columns_analyzed']}")

        if errors:
            print(f"  ⚠️  {len(errors)} warnings/errors")

    except Exception as e:
        print(f"❌ Failed to save results: {str(e)}")
        try:
            with open("tool_output/value_samples.error", 'w', encoding='utf-8') as f:
                f.write(f"Extraction failed: {str(e)}\n")
        except OSError:
            # Best effort only: the failure was already reported above.
            pass

    return extraction_result

if __name__ == "__main__":
    result = extract_values()

    # Exit status 1 only when the extraction failed outright; any other
    # quality level counts as success.
    sys.exit(1 if result["summary"]["extraction_quality"] == "failed" else 0)