#!/usr/bin/env python3
"""
Extracts value patterns from columns to determine appropriate SQL operators.
Identifies nullable columns, string patterns, and special encodings.
"""

import sqlite3
import json
import os
import re
from collections import Counter, defaultdict

def _quote_identifier(name):
    """Return *name* as a double-quoted SQL identifier, doubling embedded quotes."""
    return '"' + str(name).replace('"', '""') + '"'


def analyze_value_patterns(db_path="database.sqlite"):
    """Profile the columns of a SQLite database and write operator hints.

    For every non-empty table listed in ``sqlite_master`` the function records:

    * columns that actually contain NULLs, with their NULL percentage;
    * for columns whose declared type contains TEXT/CHAR: the text-pattern
      info returned by ``analyze_text_patterns`` (run on up to 1000 distinct
      non-NULL values) and whether any sampled value contains ``%``, ``_``
      or ``*``;
    * for columns whose declared type contains INT/REAL/NUMERIC: min, max and
      average, plus the distinct values themselves when there are at most ten.

    The collected data is written to ``tool_output/value_patterns.json`` and a
    readable digest to ``tool_output/value_pattern_summary.txt`` (both relative
    to the current working directory).  Any exception raised during the
    analysis is caught, printed, and stored under an ``"error"`` key in the
    JSON file instead of propagating.

    Args:
        db_path: Path of the SQLite database file to inspect.

    Returns:
        None.
    """
    # Create the output directory up front: the error handler below writes
    # into it too, so it must exist even when the analysis fails early.
    os.makedirs("tool_output", exist_ok=True)

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    patterns = {
        "tables": {},
        "operator_hints": {},
        "nullable_columns": {},
        "string_patterns": {},
        "numeric_ranges": {},
        "special_values": {}
    }

    try:
        # Get all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in cursor.fetchall()]

        for table in tables:
            patterns["tables"][table] = {}
            patterns["nullable_columns"][table] = []
            patterns["string_patterns"][table] = {}
            patterns["numeric_ranges"][table] = {}

            # Identifiers cannot be bound as parameters, so quote them; this
            # keeps names with spaces or reserved words valid SQL.
            table_sql = _quote_identifier(table)

            # Get table info
            cursor.execute(f"PRAGMA table_info({table_sql})")
            columns = cursor.fetchall()

            # Get row count for sampling
            cursor.execute(f"SELECT COUNT(*) FROM {table_sql}")
            row_count = cursor.fetchone()[0]

            if row_count == 0:
                continue

            # Sample size (max 1000 rows for performance)
            sample_size = min(1000, row_count)

            for col_info in columns:
                col_name = col_info[1]
                # A column may be declared without a type; treat that as "".
                col_type = (col_info[2] or "").upper()
                col_sql = _quote_identifier(col_name)

                # Count NULLs
                cursor.execute(
                    f"SELECT COUNT(*) FROM {table_sql} WHERE {col_sql} IS NULL"
                )
                null_count = cursor.fetchone()[0]

                if null_count > 0:
                    patterns["nullable_columns"][table].append({
                        "column": col_name,
                        "null_percentage": round((null_count / row_count) * 100, 2)
                    })

                # Analyze based on declared column type ('CHAR' also matches VARCHAR)
                if 'TEXT' in col_type or 'CHAR' in col_type:
                    # Sample text values
                    cursor.execute(f"""
                        SELECT DISTINCT {col_sql}
                        FROM {table_sql}
                        WHERE {col_sql} IS NOT NULL
                        LIMIT {sample_size}
                    """)
                    values = [row[0] for row in cursor.fetchall()]

                    if values:
                        # Detect patterns
                        pattern_info = analyze_text_patterns(values, col_name)
                        if pattern_info:
                            patterns["string_patterns"][table][col_name] = pattern_info

                        # Check if values contain wildcards or special chars
                        special_chars = any(
                            any(c in str(v) for c in ('%', '_', '*')) for v in values
                        )
                        if special_chars:
                            patterns["special_values"][f"{table}.{col_name}"] = "Contains wildcard characters"

                elif 'INT' in col_type or 'REAL' in col_type or 'NUMERIC' in col_type:
                    # Get numeric range
                    cursor.execute(f"""
                        SELECT MIN({col_sql}), MAX({col_sql}), AVG({col_sql})
                        FROM {table_sql}
                        WHERE {col_sql} IS NOT NULL
                    """)
                    min_val, max_val, avg_val = cursor.fetchone()

                    if min_val is not None:
                        patterns["numeric_ranges"][table][col_name] = {
                            "min": min_val,
                            "max": max_val,
                            # Compare against None so an average of exactly 0 is kept.
                            "avg": round(avg_val, 2) if avg_val is not None else None
                        }

                        # Check for special numeric patterns
                        cursor.execute(f"""
                            SELECT DISTINCT {col_sql}
                            FROM {table_sql}
                            WHERE {col_sql} IS NOT NULL
                            ORDER BY {col_sql}
                            LIMIT 20
                        """)
                        sample_values = [row[0] for row in cursor.fetchall()]

                        # Check if values are categorical (limited set)
                        if len(sample_values) <= 10:
                            patterns["special_values"][f"{table}.{col_name}"] = {
                                "type": "categorical_numeric",
                                "values": sample_values
                            }

        # Generate operator hints based on patterns.  Only one hint is kept per
        # column; the most specific finding wins (prefix > special chars > spaces).
        for table, string_cols in patterns["string_patterns"].items():
            for col, info in string_cols.items():
                hint_key = f"{table}.{col}"

                if info.get("common_prefixes"):
                    patterns["operator_hints"][hint_key] = {
                        "recommended": f"LIKE '{info['common_prefixes'][0]}%' for prefix matching",
                        "reason": "Common prefix pattern detected"
                    }
                elif info.get("has_special_chars"):
                    patterns["operator_hints"][hint_key] = {
                        "recommended": "= for exact matches (escape special chars if using LIKE)",
                        "reason": "Values contain special characters"
                    }
                elif info.get("has_spaces"):
                    patterns["operator_hints"][hint_key] = {
                        "recommended": "LIKE for partial matches, = for exact",
                        "reason": "Values contain spaces"
                    }

        # Output results
        with open("tool_output/value_patterns.json", "w", encoding="utf-8") as f:
            json.dump(patterns, f, indent=2)

        # Generate summary
        summary = []
        summary.append("# Value Pattern Analysis\n\n")

        if patterns["nullable_columns"]:
            summary.append("## Nullable Columns (affects COUNT):\n")
            for table, cols in patterns["nullable_columns"].items():
                if cols:
                    summary.append(f"\n{table}:\n")
                    for col_info in cols:
                        summary.append(f"  - {col_info['column']}: {col_info['null_percentage']}% NULL\n")

        if patterns["operator_hints"]:
            summary.append("\n## Operator Recommendations:\n")
            for col, hint in patterns["operator_hints"].items():
                summary.append(f"\n{col}:\n")
                summary.append(f"  Use: {hint['recommended']}\n")
                summary.append(f"  Reason: {hint['reason']}\n")

        with open("tool_output/value_pattern_summary.txt", "w", encoding="utf-8") as f:
            f.writelines(summary)

        print("Value pattern analysis complete - results in tool_output/")

    except Exception as e:
        # Best-effort tool: report the failure and persist whatever was
        # gathered so far rather than crashing the caller.
        print(f"Error analyzing value patterns: {e}")
        patterns["error"] = str(e)
        with open("tool_output/value_patterns.json", "w", encoding="utf-8") as f:
            json.dump(patterns, f, indent=2)

    finally:
        conn.close()


def analyze_text_patterns(values, col_name):
    """Describe the textual shape of a column's sampled values.

    Every value is converted with ``str()`` before inspection.  The returned
    dict always carries ``has_spaces`` and ``has_special_chars``; the keys
    ``common_prefixes``, ``fixed_length`` and ``case`` are added only when the
    corresponding pattern is detected.

    Args:
        values: Sequence of sampled column values.
        col_name: Name of the column (not used by the analysis itself).

    Returns:
        A dict of detected patterns, or None when ``values`` is empty.
    """
    if not values:
        return None

    texts = [str(item) for item in values]
    non_plain = re.compile(r'[^a-zA-Z0-9\s]')

    summary = {
        "has_spaces": any(' ' in text for text in texts),
        # "Special" means anything other than ASCII letters, digits and whitespace.
        "has_special_chars": any(non_plain.search(text) is not None for text in texts),
    }

    # Prefix detection needs more than five values to be meaningful.
    if len(texts) > 5:
        leading = Counter(text[:3] for text in texts if len(text) >= 3)
        # A 3-character prefix is "common" when over 20% of all values share it.
        cutoff = len(texts) * 0.2
        frequent = [prefix for prefix, hits in leading.most_common(3) if hits > cutoff]
        if frequent:
            summary["common_prefixes"] = frequent

    # Length and case checks only look at the first ten values.
    head = texts[:10]

    width = len(head[0])
    if all(len(text) == width for text in head):
        summary["fixed_length"] = width

    case_checks = (
        ("UPPER", str.isupper),
        ("lower", str.islower),
        ("Title", str.istitle),
    )
    for label, matches in case_checks:
        if all(matches(text) for text in head):
            summary["case"] = label
            break

    return summary if summary else None


# Script entry point: when this file is run directly, analyse the default
# database path ("database.sqlite", resolved against the current directory).
if __name__ == "__main__":
    analyze_value_patterns()