#!/usr/bin/env python3
"""
Enhanced Pattern Synthesizer
Analyzes database schema to synthesize query patterns and templates.
Enhanced with better column return patterns and percentage calculations.
"""

import sqlite3
import json
import os
from collections import defaultdict

def synthesize_patterns(db_path):
    """Build the full catalogue of query patterns for one SQLite database.

    Opens the database, lists its user tables and lets each
    ``synthesize_*`` helper fill in its own pattern family.

    Args:
        db_path: Path handed to ``sqlite3.connect``.

    Returns:
        dict keyed by pattern family name (``'max_min_patterns'``,
        ``'aggregation_patterns'``, ...) holding whatever structure the
        matching helper stored there.
    """
    patterns = {
        'max_min_patterns': {},
        'aggregation_patterns': {},
        'ranking_patterns': {},
        'percentage_patterns': {},
        'temporal_patterns': {},
        'join_patterns': {},
        'column_return_patterns': {},
        'value_matching_patterns': {}
    }

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Skip SQLite's own bookkeeping tables (sqlite_sequence, sqlite_stat1,
        # ...): they are not part of the user schema and would only produce
        # noise patterns (sqlite_sequence even has a "name" column that
        # triggers spurious name-based joins).
        cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        )
        tables = [row[0] for row in cursor.fetchall()]

        # Each helper overwrites its own key in ``patterns``; the order is
        # kept so the resulting dict/JSON layout stays stable.
        synthesize_max_min_patterns(cursor, tables, patterns)
        synthesize_aggregation_patterns(cursor, tables, patterns)
        synthesize_ranking_patterns(cursor, tables, patterns)
        synthesize_percentage_patterns(cursor, tables, patterns)
        synthesize_temporal_patterns(cursor, tables, patterns)
        synthesize_join_patterns(cursor, tables, patterns)
        synthesize_column_return_patterns(cursor, tables, patterns)
        synthesize_value_matching_patterns(cursor, tables, patterns)
    finally:
        # Release the connection even when a helper raises.
        conn.close()

    return patterns

def synthesize_max_min_patterns(cursor, tables, patterns):
    """Fill ``patterns['max_min_patterns']`` with MAX/MIN query templates.

    For every column whose declared type contains INT, REAL or NUMERIC the
    function emits: a "row with highest value" and a "row with lowest value"
    template, a "maximum value only" template, and a template returning both
    extreme rows at once.

    Args:
        cursor: Database cursor used to run ``PRAGMA table_info``.
        tables: Iterable of table names to inspect.
        patterns: Dict that receives the ``'max_min_patterns'`` entry
            (any previous value under that key is replaced).

    Note:
        Table and column names are embedded verbatim in the generated
        templates; names that require quoting are not quoted there.
    """
    single_item = []
    value_only = []
    multi_item = []
    patterns['max_min_patterns'] = {
        'single_item_patterns': single_item,
        'value_only_patterns': value_only,
        'comparison_patterns': [],
        'multi_item_patterns': multi_item
    }

    numeric_markers = ('INT', 'REAL', 'NUMERIC')

    for table in tables:
        # Quote the name so tables containing spaces or keywords do not make
        # the PRAGMA itself fail with a syntax error.
        quoted_table = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_table})")

        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        numeric_cols = [
            col[1] for col in cursor.fetchall()
            if any(marker in col[2].upper() for marker in numeric_markers)
        ]

        for num_col in numeric_cols:
            highest = f"SELECT * FROM {table} ORDER BY {num_col} DESC LIMIT 1"
            lowest = f"SELECT * FROM {table} ORDER BY {num_col} ASC LIMIT 1"

            # Single item with max/min value
            single_item.append({
                'pattern': f"Item with highest {num_col}",
                'template': highest,
                'use_case': f"Get the {table} record with maximum {num_col}"
            })
            single_item.append({
                'pattern': f"Item with lowest {num_col}",
                'template': lowest,
                'use_case': f"Get the {table} record with minimum {num_col}"
            })

            # Value only patterns
            value_only.append({
                'pattern': f"Maximum {num_col} value",
                'template': f"SELECT MAX({num_col}) FROM {table}",
                'use_case': f"Get only the maximum value of {num_col}"
            })

            # Multi-item patterns (both highest and lowest). SQLite does not
            # accept parenthesised SELECTs as members of a compound query, and
            # ORDER BY/LIMIT may only follow the last member, so each ordered
            # SELECT is wrapped in a subquery instead.
            multi_item.append({
                'pattern': f"Both highest and lowest {num_col}",
                'template': (
                    f"SELECT * FROM ({highest}) "
                    f"UNION ALL SELECT * FROM ({lowest})"
                ),
                'use_case': f"Get records with both maximum and minimum {num_col}"
            })

def synthesize_aggregation_patterns(cursor, tables, patterns):
    """Fill ``patterns['aggregation_patterns']`` with aggregation templates.

    Per table this emits a ``COUNT(*)`` template, ``COUNT(DISTINCT ...)``
    templates for the first three text columns, conditional counts and
    ``GROUP BY`` sums for the first two category-like text columns, and a
    combined ``SUM`` for every group returned by ``find_related_columns``.

    Args:
        cursor: Database cursor used to run ``PRAGMA table_info``.
        tables: Iterable of table names to inspect.
        patterns: Dict that receives the ``'aggregation_patterns'`` entry
            (any previous value under that key is replaced).

    Note:
        Table and column names are embedded verbatim in the generated
        templates; names that require quoting are not quoted there.
    """
    count_patterns = []
    sum_patterns = []
    group_by_patterns = []
    conditional_patterns = []
    distinct_patterns = []
    patterns['aggregation_patterns'] = {
        'count_patterns': count_patterns,
        'sum_patterns': sum_patterns,
        'group_by_patterns': group_by_patterns,
        'conditional_patterns': conditional_patterns,
        'distinct_patterns': distinct_patterns
    }

    # Column-name fragments that mark a text column as a likely category.
    category_words = ['type', 'category', 'status', 'state', 'city', 'name']

    for table in tables:
        # Quote the name so tables containing spaces or keywords do not make
        # the PRAGMA itself fail with a syntax error.
        quoted_table = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_table})")
        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        columns = cursor.fetchall()

        numeric_cols = [col[1] for col in columns if 'INT' in col[2].upper() or 'REAL' in col[2].upper()]
        text_cols = [col[1] for col in columns if 'TEXT' in col[2].upper() or 'VARCHAR' in col[2].upper()]
        likely_categories = [col for col in text_cols if any(word in col.lower() for word in category_words)]

        # Count patterns
        count_patterns.append({
            'pattern': f"Count all {table}",
            'template': f"SELECT COUNT(*) FROM {table}",
            'use_case': "Total number of records"
        })

        # Distinct count patterns
        for col in text_cols[:3]:  # Limit to avoid too many patterns
            distinct_patterns.append({
                'pattern': f"Count distinct {col}",
                'template': f"SELECT COUNT(DISTINCT {col}) FROM {table}",
                'use_case': f"Number of unique {col} values"
            })

        # Conditional count patterns
        for cat_col in likely_categories[:2]:
            conditional_patterns.append({
                'pattern': f"Count by {cat_col} condition",
                'template': f"SELECT SUM(CASE WHEN {cat_col} = 'value' THEN 1 ELSE 0 END) FROM {table}",
                'alternative': f"SELECT COUNT(*) FROM {table} WHERE {cat_col} = 'value'",
                'use_case': f"Count records matching specific {cat_col}"
            })

        # Group by patterns
        for group_col in likely_categories[:2]:
            for agg_col in numeric_cols[:2]:
                group_by_patterns.append({
                    'pattern': f"Sum {agg_col} by {group_col}",
                    'template': f"SELECT {group_col}, SUM({agg_col}) FROM {table} GROUP BY {group_col}",
                    'use_case': f"Aggregate {agg_col} grouped by {group_col}"
                })

        # Combined column patterns
        for col_group in find_related_columns(numeric_cols):
            if len(col_group) > 1:
                sum_expr = ' + '.join(col_group)
                sum_patterns.append({
                    'pattern': f"Combined total of {' and '.join(col_group)}",
                    'template': f"SELECT SUM({sum_expr}) FROM {table}",
                    'use_case': "Sum of related columns together"
                })
def synthesize_ranking_patterns(cursor, tables, patterns):
    """Fill ``patterns['ranking_patterns']`` with top-N and RANK() templates.

    For the first three columns of each table whose declared type contains
    INT or REAL, a ``... ORDER BY col DESC LIMIT N`` template and a
    ``RANK() OVER (ORDER BY col DESC)`` template are emitted. The
    ``'percentile'`` list is created but left empty.

    Args:
        cursor: Database cursor used to run ``PRAGMA table_info``.
        tables: Iterable of table names to inspect.
        patterns: Dict that receives the ``'ranking_patterns'`` entry
            (any previous value under that key is replaced).

    Note:
        Table and column names are embedded verbatim in the generated
        templates; names that require quoting are not quoted there.
    """
    top_n = []
    rank_by = []
    patterns['ranking_patterns'] = {
        'top_n': top_n,
        'rank_by': rank_by,
        'percentile': []
    }

    for table in tables:
        # Quote the name so tables containing spaces or keywords do not make
        # the PRAGMA itself fail with a syntax error.
        quoted_table = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_table})")

        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        numeric_cols = [
            col[1] for col in cursor.fetchall()
            if 'INT' in col[2].upper() or 'REAL' in col[2].upper()
        ]

        for num_col in numeric_cols[:3]:  # Limit patterns
            # Top N patterns ("N" is a placeholder for the caller to fill in)
            top_n.append({
                'pattern': f"Top N {table} by {num_col}",
                'template': f"SELECT * FROM {table} ORDER BY {num_col} DESC LIMIT N",
                'use_case': f"Get top N records ordered by {num_col}"
            })

            # Rank patterns
            rank_by.append({
                'pattern': f"Rank {table} by {num_col}",
                'template': f"SELECT *, RANK() OVER (ORDER BY {num_col} DESC) as rank FROM {table}",
                'use_case': f"Add ranking based on {num_col}"
            })

def synthesize_percentage_patterns(cursor, tables, patterns):
    """Fill ``patterns['percentage_patterns']`` with percentage templates.

    For the first two text columns of each table whose name contains
    ``type``, ``status``, ``state`` or ``category`` the function emits a
    "share of rows matching one value" template (plain and rounded to two
    decimals) and a per-value percentage breakdown. The
    ``'conditional_percentage'`` list is created but left empty.

    Args:
        cursor: Database cursor used to run ``PRAGMA table_info``.
        tables: Iterable of table names to inspect.
        patterns: Dict that receives the ``'percentage_patterns'`` entry
            (any previous value under that key is replaced).

    Note:
        Table and column names are embedded verbatim in the generated
        templates; names that require quoting are not quoted there.
    """
    simple_percentage = []
    group_percentage = []
    patterns['percentage_patterns'] = {
        'simple_percentage': simple_percentage,
        'group_percentage': group_percentage,
        'conditional_percentage': []
    }

    # Column-name fragments that mark a text column as a likely category.
    category_words = ['type', 'status', 'state', 'category']

    for table in tables:
        # Quote the name so tables containing spaces or keywords do not make
        # the PRAGMA itself fail with a syntax error.
        quoted_table = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_table})")
        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        columns = cursor.fetchall()

        text_cols = [col[1] for col in columns if 'TEXT' in col[2].upper() or 'VARCHAR' in col[2].upper()]
        likely_categories = [col for col in text_cols if any(word in col.lower() for word in category_words)]

        for cat_col in likely_categories[:2]:
            # CAST to REAL before dividing so the quotient is not truncated
            # by integer division.
            ratio = f"CAST(COUNT(CASE WHEN {cat_col} = 'value' THEN 1 END) AS REAL) * 100 / COUNT(*)"

            # Simple percentage
            simple_percentage.append({
                'pattern': f"Percentage of specific {cat_col}",
                'template': f"SELECT {ratio} FROM {table}",
                'rounded': f"SELECT ROUND({ratio}, 2) FROM {table}",
                'use_case': f"Calculate percentage of records with specific {cat_col}"
            })

            # Group percentage
            group_percentage.append({
                'pattern': f"Percentage breakdown by {cat_col}",
                'template': f"SELECT {cat_col}, COUNT(*) * 100.0 / (SELECT COUNT(*) FROM {table}) as percentage FROM {table} GROUP BY {cat_col}",
                'use_case': f"Show percentage distribution of {cat_col}"
            })

def synthesize_temporal_patterns(cursor, tables, patterns):
    """Fill ``patterns['temporal_patterns']`` with date filter/group templates.

    A column is treated as temporal when its *name* contains ``date``,
    ``time``, ``year`` or ``month`` (the declared type is not consulted).
    For the first two such columns of each table, filtering templates
    (exact date, range, year substring) and grouping templates (by date,
    year, year-month) are emitted. The ``'date_extraction'`` list is created
    but left empty.

    Args:
        cursor: Database cursor used to run ``PRAGMA table_info``.
        tables: Iterable of table names to inspect.
        patterns: Dict that receives the ``'temporal_patterns'`` entry
            (any previous value under that key is replaced).

    Note:
        Table and column names are embedded verbatim in the generated
        templates; names that require quoting are not quoted there.
    """
    date_filtering = []
    date_grouping = []
    patterns['temporal_patterns'] = {
        'date_filtering': date_filtering,
        'date_grouping': date_grouping,
        'date_extraction': []
    }

    temporal_words = ['date', 'time', 'year', 'month']

    for table in tables:
        # Quote the name so tables containing spaces or keywords do not make
        # the PRAGMA itself fail with a syntax error.
        quoted_table = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_table})")

        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        date_cols = [
            col[1] for col in cursor.fetchall()
            if any(word in col[1].lower() for word in temporal_words)
        ]

        for date_col in date_cols[:2]:
            # Date filtering patterns
            date_filtering.append({
                'pattern': f"Filter by {date_col}",
                'exact_date': f"SELECT * FROM {table} WHERE DATE({date_col}) = 'YYYY-MM-DD'",
                'date_range': f"SELECT * FROM {table} WHERE {date_col} BETWEEN 'start' AND 'end'",
                'year_filter': f"SELECT * FROM {table} WHERE {date_col} LIKE '%YYYY%'",
                'use_case': f"Filter records by {date_col}"
            })

            # Date grouping patterns
            date_grouping.append({
                'pattern': f"Group by {date_col} period",
                'by_date': f"SELECT DATE({date_col}), COUNT(*) FROM {table} GROUP BY DATE({date_col})",
                'by_year': f"SELECT strftime('%Y', {date_col}), COUNT(*) FROM {table} GROUP BY strftime('%Y', {date_col})",
                'by_month': f"SELECT strftime('%Y-%m', {date_col}), COUNT(*) FROM {table} GROUP BY strftime('%Y-%m', {date_col})",
                'use_case': "Aggregate by time periods"
            })

def synthesize_join_patterns(cursor, tables, patterns):
    """Fill ``patterns['join_patterns']`` with JOIN templates.

    Two sources are used:

    * declared foreign keys (``PRAGMA foreign_key_list``), one template per
      foreign-key column;
    * a name heuristic: for each unordered pair of tables, the first
      ``*name*`` column of each is compared after removing the table's own
      name and surrounding underscores; equal remainders yield a join.

    ``'multi_table_joins'`` and ``'self_joins'`` are created but left empty.

    Args:
        cursor: Database cursor used to run the PRAGMA statements.
        tables: Iterable of table names to inspect.
        patterns: Dict that receives the ``'join_patterns'`` entry
            (any previous value under that key is replaced).

    Note:
        Table and column names are embedded verbatim in the generated
        templates; names that require quoting are not quoted there.
    """
    foreign_key_joins = []
    name_based_joins = []
    patterns['join_patterns'] = {
        'foreign_key_joins': foreign_key_joins,
        'name_based_joins': name_based_joins,
        'multi_table_joins': [],
        'self_joins': []
    }

    def quote(name):
        """Return *name* as a double-quoted SQL identifier."""
        return '"' + name.replace('"', '""') + '"'

    info_cache = {}

    def table_info(table):
        """Return (and memoise) the ``PRAGMA table_info`` rows of *table*."""
        if table not in info_cache:
            cursor.execute(f"PRAGMA table_info({quote(table)})")
            info_cache[table] = cursor.fetchall()
        return info_cache[table]

    # Detect foreign key relationships
    for table in tables:
        cursor.execute(f"PRAGMA foreign_key_list({quote(table)})")

        # foreign_key_list rows are
        # (id, seq, table, from, to, on_update, on_delete, match).
        for fk in cursor.fetchall():
            seq, to_table, from_col, to_col = fk[1], fk[2], fk[3], fk[4]

            if to_col is None:
                # "REFERENCES parent" without a column list targets the
                # parent's primary key; resolve it instead of emitting the
                # literal text "None" into the template.
                pk_rows = sorted(
                    (col for col in table_info(to_table) if col[5] > 0),
                    key=lambda col: col[5],
                )
                if seq >= len(pk_rows):
                    # Parent missing or without a usable primary key: there
                    # is no column to join on.
                    continue
                to_col = pk_rows[seq][1]

            foreign_key_joins.append({
                'pattern': f"Join {table} with {to_table}",
                'template': f"SELECT * FROM {table} t1 JOIN {to_table} t2 ON t1.{from_col} = t2.{to_col}",
                'use_case': "Join using foreign key relationship"
            })

    def first_name_column(table):
        """Return the first column of *table* whose name contains 'name'."""
        for col in table_info(table):
            if 'name' in col[1].lower():
                return col[1]
        return None

    def stripped(column, table):
        """Drop the table's own name from *column*, ignoring case."""
        return column.lower().replace(table.lower(), '').strip('_')

    # Detect potential name-based joins
    for t1 in tables:
        for t2 in tables:
            # Visit each unordered pair once and never pair a table with itself.
            if t1 >= t2:
                continue

            c1 = first_name_column(t1)
            c2 = first_name_column(t2)
            if c1 is None or c2 is None:
                continue

            # Compare case-insensitively: the table name is lower-cased, so
            # the column must be too or mixed-case prefixes never strip.
            if stripped(c1, t1) == stripped(c2, t2):
                name_based_joins.append({
                    'pattern': f"Name-based join {t1}-{t2}",
                    'template': f"SELECT * FROM {t1} t1 JOIN {t2} t2 ON t1.{c1} = t2.{c2}",
                    'use_case': "Join tables on name columns"
                })

def synthesize_column_return_patterns(cursor, tables, patterns):
    """Fill ``patterns['column_return_patterns']`` with projection templates.

    Per table this emits a single-column ``SELECT`` for each of the first
    three columns, one generic multi-column template when the table has more
    than one column, and a "difference of the first two numeric columns"
    template when at least two columns have a declared type containing INT
    or REAL. The ``'conditional_return'`` list is created but left empty.

    Args:
        cursor: Database cursor used to run ``PRAGMA table_info``.
        tables: Iterable of table names to inspect.
        patterns: Dict that receives the ``'column_return_patterns'`` entry
            (any previous value under that key is replaced).

    Note:
        Table and column names are embedded verbatim in the generated
        templates; names that require quoting are not quoted there.
    """
    single_column = []
    multi_column = []
    calculated_column = []
    patterns['column_return_patterns'] = {
        'single_column': single_column,
        'multi_column': multi_column,
        'calculated_column': calculated_column,
        'conditional_return': []
    }

    for table in tables:
        # Quote the name so tables containing spaces or keywords do not make
        # the PRAGMA itself fail with a syntax error.
        quoted_table = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_table})")
        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        columns = cursor.fetchall()

        col_names = [col[1] for col in columns]

        # Single column patterns
        for col in col_names[:3]:
            single_column.append({
                'pattern': f"Return only {col}",
                'template': f"SELECT {col} FROM {table}",
                'use_case': f"When asked specifically for {col}"
            })

        # Multi-column patterns (col1/col2/col3 are literal placeholders)
        if len(col_names) > 1:
            multi_column.append({
                'pattern': f"Return specific columns from {table}",
                'template': f"SELECT col1, col2, col3 FROM {table}",
                'use_case': "Return exactly the columns requested"
            })

        # Calculated column patterns
        numeric_cols = [col[1] for col in columns if 'INT' in col[2].upper() or 'REAL' in col[2].upper()]
        if len(numeric_cols) >= 2:
            calculated_column.append({
                'pattern': "Calculate and return",
                'template': f"SELECT {numeric_cols[0]} - {numeric_cols[1]} as difference FROM {table}",
                'use_case': "Return calculated values"
            })

def synthesize_value_matching_patterns(cursor, tables, patterns):
    """Fill ``patterns['value_matching_patterns']`` with filter templates.

    For the first two columns of each table whose declared type contains
    TEXT or VARCHAR, four templates are emitted: exact match, ``LIKE``
    substring match, case-insensitive match and NULL handling
    (``IS NOT NULL`` / ``COALESCE``).

    Args:
        cursor: Database cursor used to run ``PRAGMA table_info``.
        tables: Iterable of table names to inspect.
        patterns: Dict that receives the ``'value_matching_patterns'`` entry
            (any previous value under that key is replaced).

    Note:
        Table and column names are embedded verbatim in the generated
        templates; names that require quoting are not quoted there.
    """
    exact_match = []
    partial_match = []
    case_insensitive = []
    null_handling = []
    patterns['value_matching_patterns'] = {
        'exact_match': exact_match,
        'partial_match': partial_match,
        'case_insensitive': case_insensitive,
        'null_handling': null_handling
    }

    for table in tables:
        # Quote the name so tables containing spaces or keywords do not make
        # the PRAGMA itself fail with a syntax error.
        quoted_table = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_table})")

        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        text_cols = [
            col[1] for col in cursor.fetchall()
            if 'TEXT' in col[2].upper() or 'VARCHAR' in col[2].upper()
        ]

        for col in text_cols[:2]:
            # Exact match
            exact_match.append({
                'pattern': f"Exact match on {col}",
                'template': f"SELECT * FROM {table} WHERE {col} = 'exact_value'",
                'use_case': "When exact value is known"
            })

            # Partial match
            partial_match.append({
                'pattern': f"Partial match on {col}",
                'template': f"SELECT * FROM {table} WHERE {col} LIKE '%partial%'",
                'use_case': "For substring searching"
            })

            # Case insensitive
            case_insensitive.append({
                'pattern': f"Case insensitive match on {col}",
                'template': f"SELECT * FROM {table} WHERE LOWER({col}) = LOWER('value')",
                'use_case': "When case doesn't matter"
            })

            # NULL handling
            null_handling.append({
                'pattern': f"Handle NULL in {col}",
                'check_null': f"SELECT * FROM {table} WHERE {col} IS NOT NULL",
                'coalesce': f"SELECT COALESCE({col}, 'default') FROM {table}",
                'use_case': "Proper NULL value handling"
            })

def find_related_columns(columns):
    """Group column names that look like parts of one combined total.

    Three kinds of groups are returned, in this order:

    1. columns sharing the text before their first underscore, when that
       prefix contains ``total``, ``sum`` or ``count`` (one group per prefix,
       in order of first appearance);
    2. all columns whose name contains ``rebound``;
    3. all columns whose name contains ``yard``.

    A group is only reported when it has more than one member.

    Args:
        columns: Column names to examine.

    Returns:
        list of lists of column names.
    """
    # Bucket underscore-containing names by the text before the first "_".
    by_prefix = {}
    for name in columns:
        head, separator, _rest = name.partition('_')
        if separator:
            by_prefix.setdefault(head, []).append(name)

    aggregate_words = ('total', 'sum', 'count')
    groups = [
        members
        for head, members in by_prefix.items()
        if len(members) > 1 and any(word in head.lower() for word in aggregate_words)
    ]

    # Domain-specific families; "yard" also covers every name containing "yards".
    for token in ('rebound', 'yard'):
        family = [name for name in columns if token in name.lower()]
        if len(family) > 1:
            groups.append(family)

    return groups

def main(db_path="./database.sqlite", output_dir="tool_output"):
    """Synthesize patterns for one database and write them to a JSON file.

    Prints an error and returns without writing anything when ``db_path``
    does not exist. Otherwise the patterns are written to
    ``synthesized_patterns.json`` inside ``output_dir`` (created if needed)
    and a short summary is printed.

    Args:
        db_path: Path of the SQLite database to analyse.
        output_dir: Directory that receives ``synthesized_patterns.json``.
    """
    if not os.path.exists(db_path):
        print(f"Error: Database not found at {db_path}")
        return

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Synthesize patterns
    patterns = synthesize_patterns(db_path)

    # Save results; the encoding is pinned so the file does not depend on the
    # platform's default locale.
    output_path = os.path.join(output_dir, "synthesized_patterns.json")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(patterns, f, indent=2)

    # Each summary line reports the size of one representative list per family.
    print("✅ Pattern synthesis complete")
    print(f"📊 MAX/MIN patterns: {len(patterns['max_min_patterns']['single_item_patterns'])}")
    print(f"📈 Aggregation patterns: {len(patterns['aggregation_patterns']['group_by_patterns'])}")
    print(f"🏆 Ranking patterns: {len(patterns['ranking_patterns']['top_n'])}")
    print(f"💯 Percentage patterns: {len(patterns['percentage_patterns']['simple_percentage'])}")
    print(f"💾 Results saved to: {output_path}")

# Script entry point: run the synthesizer only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()