#!/usr/bin/env python3
"""
Enhanced Pattern Detector Tool
Detects query patterns and common SQL requirements.
"""

import sqlite3
import json
import os

# Column-name fragments that suggest an additive measure (SUM candidates).
_SUM_KEYWORDS = ('count', 'total', 'number', 'num_', 'quantity', 'amount')
# Column-name fragments that suggest an already-averaged value (AVG candidates).
_AVG_KEYWORDS = ('avg', 'average', 'mean', 'rate', 'ratio', 'percentage')
# SQLite type affinities whose columns are probed for numeric ranges.
_NUMERIC_AFFINITIES = ("INTEGER", "REAL", "NUMERIC")


def _quote_identifier(name):
    """Return *name* as a double-quoted SQL identifier with embedded quotes doubled."""
    return '"' + name.replace('"', '""') + '"'


def _template_name(name):
    """Return *name* for a query template, quoting it only when it is not a plain identifier."""
    return name if name.isidentifier() else _quote_identifier(name)


def _type_affinity(declared_type):
    """Map a declared column type to its SQLite type affinity.

    Follows the substring rules from the SQLite datatype documentation, so
    declarations such as ``VARCHAR(50)``, ``BIGINT`` or ``DOUBLE PRECISION``
    are classified the same way SQLite itself classifies them.
    """
    declared = declared_type.upper()
    if "INT" in declared:
        return "INTEGER"
    if any(token in declared for token in ("CHAR", "CLOB", "TEXT")):
        return "TEXT"
    if "BLOB" in declared or not declared:
        return "BLOB"
    if any(token in declared for token in ("REAL", "FLOA", "DOUB")):
        return "REAL"
    return "NUMERIC"


def _warn_skipped(check, table, column, exc):
    """Report a per-column probe that failed and was skipped (best-effort analysis)."""
    print(f"⚠ Skipped {check} for {table}.{column}: {exc}")


def _analyze_table(cursor, table, pattern_info):
    """Inspect one table and record its findings in *pattern_info* (mutated in place).

    Empty tables are skipped entirely. Failures of the individual per-column
    probe queries are reported and skipped; failures of the table-level
    queries propagate to the caller.
    """
    quoted_table = _quote_identifier(table)

    cursor.execute(f"PRAGMA table_info({quoted_table})")
    columns = cursor.fetchall()

    cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
    row_count = cursor.fetchone()[0]
    if row_count == 0:
        return

    guidance = {
        "count_columns": [],  # Columns to COUNT
        "sum_columns": [],    # Columns to SUM
        "avg_columns": [],    # Columns to AVG
        "group_by_columns": [],  # Common grouping columns
    }
    filters = []
    column_patterns = {}
    pattern_info["aggregation_guidance"][table] = guidance
    pattern_info["common_filters"][table] = filters
    pattern_info["column_patterns"][table] = column_patterns

    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
    for _cid, col_name, declared_type, notnull, *_rest in columns:
        affinity = _type_affinity(declared_type)
        col_name_lower = col_name.lower()
        quoted_col = _quote_identifier(col_name)

        # Name-based guess at the aggregation that fits this column.
        if any(keyword in col_name_lower for keyword in _SUM_KEYWORDS):
            guidance["sum_columns"].append({
                "column": col_name,
                "reason": "Column name suggests numeric measure"
            })
        elif any(keyword in col_name_lower for keyword in _AVG_KEYWORDS):
            guidance["avg_columns"].append({
                "column": col_name,
                "reason": "Column name suggests average/ratio"
            })
        elif col_name_lower.endswith('_id') or col_name_lower == 'id':
            guidance["count_columns"].append({
                "column": col_name,
                "reason": "ID column - count for number of entities"
            })

        # Text columns with few distinct values make good GROUP BY candidates.
        if affinity == "TEXT" and not col_name_lower.endswith('_id'):
            try:
                cursor.execute(
                    f"SELECT COUNT(DISTINCT {quoted_col}) FROM {quoted_table}"
                )
                distinct_count = cursor.fetchone()[0]
            except sqlite3.Error as exc:
                _warn_skipped("distinct-value check", table, col_name, exc)
            else:
                if 1 < distinct_count < row_count * 0.5:
                    guidance["group_by_columns"].append({
                        "column": col_name,
                        "distinct_values": distinct_count,
                        "selectivity": round(distinct_count / row_count, 3)
                    })

        # Numeric columns: record the observed value range.
        if affinity in _NUMERIC_AFFINITIES:
            try:
                cursor.execute(
                    f"SELECT MIN({quoted_col}), MAX({quoted_col}), AVG({quoted_col}) "
                    f"FROM {quoted_table} WHERE {quoted_col} IS NOT NULL"
                )
                min_val, max_val, avg_val = cursor.fetchone()
            except sqlite3.Error as exc:
                _warn_skipped("numeric range check", table, col_name, exc)
            else:
                # SQLite is dynamically typed, so a "numeric" column may hold
                # text or blobs (or nothing at all); only real numbers form a range.
                if isinstance(min_val, (int, float)) and isinstance(max_val, (int, float)):
                    filters.append({
                        "column": col_name,
                        "type": "numeric_range",
                        "min": min_val,
                        "max": max_val,
                        # Compare against None so a genuine average of 0 is kept.
                        "avg": round(avg_val, 2) if avg_val is not None else None
                    })

                    # Test the narrower 0..1 range first; it is a subset of
                    # 0..100 and would otherwise never be reported.
                    if 0 <= min_val <= max_val <= 1:
                        column_patterns[col_name] = "percentage_0_1"
                    elif 0 <= min_val <= max_val <= 100:
                        column_patterns[col_name] = "percentage_0_100"

        # Nullable columns: flag those where NULLs are common.
        if not notnull:
            try:
                cursor.execute(
                    f"SELECT COUNT(*) FROM {quoted_table} WHERE {quoted_col} IS NULL"
                )
                null_count = cursor.fetchone()[0]
            except sqlite3.Error as exc:
                _warn_skipped("NULL prevalence check", table, col_name, exc)
            else:
                # row_count is non-zero here: empty tables returned early.
                null_percentage = null_count / row_count * 100
                if null_percentage > 10:
                    pattern_info["special_cases"].append({
                        "table": table,
                        "column": col_name,
                        "issue": "high_null_percentage",
                        "percentage": round(null_percentage, 1),
                        "recommendation": f"Consider using IS NOT NULL when filtering {col_name}"
                    })


def _table_templates(table, agg_info):
    """Build the COUNT / SUM / GROUP BY query templates for one analyzed table."""
    templates = []
    table_sql = _template_name(table)

    if agg_info["count_columns"]:
        count_sql = _template_name(agg_info["count_columns"][0]["column"])
        templates.append({
            "pattern": f"Count {table} entities",
            "template": f"SELECT COUNT(*) FROM {table_sql}",
            "variant": f"SELECT COUNT(DISTINCT {count_sql}) FROM {table_sql}"
        })

    if agg_info["sum_columns"]:
        sum_col = agg_info["sum_columns"][0]["column"]
        sum_sql = _template_name(sum_col)
        templates.append({
            "pattern": f"Total {sum_col} in {table}",
            "template": f"SELECT SUM({sum_sql}) FROM {table_sql}",
            "note": f"Use SUM not COUNT for {sum_col} - it's a measure column"
        })

        # The GROUP BY template needs both a grouping column and a measure.
        if agg_info["group_by_columns"]:
            group_col = agg_info["group_by_columns"][0]["column"]
            group_sql = _template_name(group_col)
            templates.append({
                "pattern": f"Aggregate by {group_col}",
                "template": (
                    f"SELECT {group_sql}, SUM({sum_sql}) "
                    f"FROM {table_sql} GROUP BY {group_sql}"
                )
            })

    return templates


def detect_patterns(db_path="database.sqlite"):
    """Detect patterns that inform SQL generation.

    Scans every non-empty user table of the SQLite database at *db_path* and
    collects, per table: name-based aggregation hints (SUM / AVG / COUNT
    columns), low-cardinality text columns suitable for GROUP BY, numeric
    value ranges, percentage-like columns and nullable columns with more than
    10% NULLs. Query templates are then derived from those findings.

    The result is written to ``tool_output/patterns.json`` (the directory is
    created in the current working directory if needed) and a short summary
    is printed.

    Args:
        db_path: Path of the SQLite database file to analyze.

    Returns:
        The dictionary that was written to ``tool_output/patterns.json``.

    Raises:
        Exception: Any failure outside the per-column probes (listing tables,
            reading table metadata, writing the output file) is reported and
            re-raised. Failed per-column probes are reported and skipped.
    """

    os.makedirs("tool_output", exist_ok=True)

    conn = sqlite3.connect(db_path)

    pattern_info = {
        "aggregation_guidance": {},
        "common_filters": {},
        "column_patterns": {},
        "query_templates": [],
        "special_cases": []
    }

    try:
        cursor = conn.cursor()

        # Get all user tables (SQLite's internal tables are excluded).
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        tables = [name for (name,) in cursor.fetchall() if not name.startswith("sqlite_")]

        for table in tables:
            _analyze_table(cursor, table, pattern_info)

        # Generate query templates based on discovered patterns
        for table, agg_info in pattern_info["aggregation_guidance"].items():
            pattern_info["query_templates"].extend(_table_templates(table, agg_info))

        # Add percentage calculation templates
        pattern_info["query_templates"].append({
            "pattern": "Calculate percentage",
            "template": "SELECT (COUNT(CASE WHEN condition THEN 1 END) * 100.0 / COUNT(*)) AS percentage",
            "note": "Always multiply by 100.0 for percentage, not 1.0"
        })

        pattern_info["query_templates"].append({
            "pattern": "Ratio as decimal",
            "template": "SELECT (COUNT(CASE WHEN condition THEN 1 END) * 1.0 / COUNT(*)) AS ratio",
            "note": "Use 1.0 for ratio, 100.0 for percentage"
        })

        # Save to file
        with open("tool_output/patterns.json", "w", encoding="utf-8") as f:
            json.dump(pattern_info, f, indent=2)

        print(f"✓ Pattern detection complete: Analyzed {len(tables)} tables")
        print(f"✓ Generated {len(pattern_info['query_templates'])} query templates")
        print(f"✓ Found {len(pattern_info['special_cases'])} special cases to consider")

    except Exception as e:
        print(f"✗ Pattern detection failed: {str(e)}")
        raise
    finally:
        conn.close()

    return pattern_info

# Script entry point: when run directly, analyze the default database path
# ("database.sqlite", the default value of detect_patterns' db_path).
if __name__ == "__main__":
    detect_patterns()