#!/usr/bin/env python3
"""
Query Validator Tool - Inspired by CHESS Unit Tester
Generates validation checks and test cases for SQL queries
"""

import sqlite3
import json
import os
from collections import defaultdict

def generate_validation_tests(cursor, table, columns):
    """Generate validation hints for the columns of a single table.

    Four kinds of hint are produced, in this order:

    * ``null_handling`` -- the column contains NULLs (first 10 columns).
    * ``aggregation`` -- the column's first non-NULL value is stored as
      INTEGER or REAL, so SUM/AVG are meaningful (first 10 columns).
    * ``grouping`` -- distinct non-NULL values are fewer than 10% of the
      table's rows (first 10 columns).
    * ``filtering`` -- up to three sample non-NULL values (first 5 columns).

    Args:
        cursor: A ``sqlite3`` cursor on the database that contains ``table``.
        table: Name of the table to inspect.
        columns: Column names of ``table``, in the order to examine them.

    Returns:
        A list of hint dicts. If a query fails, the hints collected so far are
        kept, one ``{'type': 'error', 'table': ..., 'error': ...}`` entry is
        appended, and the remaining checks are skipped.
    """
    def quote(name):
        # Quote as an SQL identifier; embedded double quotes must be doubled.
        return '"' + name.replace('"', '""') + '"'

    tests = []
    qtable = quote(table)

    try:
        # The row count does not depend on the column, so fetch it once.
        cursor.execute(f'SELECT COUNT(*) FROM {qtable}')
        total_count = cursor.fetchone()[0]

        # Test 1: Check for NULL handling
        for column in columns[:10]:  # Limit for performance
            qcol = quote(column)
            cursor.execute(f'SELECT COUNT(*) FROM {qtable} WHERE {qcol} IS NULL')
            null_count = cursor.fetchone()[0]
            if null_count > 0 and total_count > 0:
                tests.append({
                    'type': 'null_handling',
                    'table': table,
                    'column': column,
                    'test': f'When querying {column}, consider that {null_count}/{total_count} rows have NULL values',
                    'sql_check': f'WHERE {qcol} IS NOT NULL'
                })

        # Test 2: Check for common aggregation patterns
        for column in columns[:10]:
            qcol = quote(column)
            # typeof(NULL) is 'null', so probe the first non-NULL value;
            # otherwise a numeric column whose first row is NULL is missed.
            cursor.execute(
                f'SELECT typeof({qcol}) FROM {qtable} WHERE {qcol} IS NOT NULL LIMIT 1'
            )
            col_type = cursor.fetchone()
            if col_type and col_type[0] in ('integer', 'real'):
                tests.append({
                    'type': 'aggregation',
                    'table': table,
                    'column': column,
                    'test': f'For {column}, SUM/AVG/COUNT operations are valid',
                    'sql_check': f'SUM({qcol}), AVG({qcol}), COUNT({qcol})'
                })

        # Test 3: Check for grouping columns
        if total_count > 0:
            for column in columns[:10]:
                qcol = quote(column)
                cursor.execute(f'SELECT COUNT(DISTINCT {qcol}) FROM {qtable}')
                distinct = cursor.fetchone()[0]
                if distinct / total_count < 0.1:  # Low cardinality - good for grouping
                    tests.append({
                        'type': 'grouping',
                        'table': table,
                        'column': column,
                        'test': f'{column} has {distinct} unique values - good for GROUP BY',
                        'sql_check': f'GROUP BY {qcol}'
                    })

        # Test 4: Check for filtering patterns
        for column in columns[:5]:
            qcol = quote(column)
            # Only non-NULL samples are useful as "= ?" filter examples.
            cursor.execute(
                f'SELECT DISTINCT {qcol} FROM {qtable} WHERE {qcol} IS NOT NULL LIMIT 3'
            )
            samples = [row[0] for row in cursor.fetchall()]
            if samples:
                tests.append({
                    'type': 'filtering',
                    'table': table,
                    'column': column,
                    'test': f'Common filter values for {column}: {samples}',
                    'sql_check': f'WHERE {qcol} = ?'
                })

    except sqlite3.Error as e:
        tests.append({
            'type': 'error',
            'table': table,
            'error': str(e)
        })

    return tests

def identify_common_mistakes(cursor, tables_info):
    """Flag schema features that commonly lead to SQL mistakes.

    For every table in ``tables_info`` this reports, in order:

    * ``column_confusion`` -- pairs of columns that ``similar_names`` considers
      confusingly alike (pairs differing only by case are skipped).
    * ``quoting_required`` -- table or column names containing a space or a
      hyphen, and column names equal to ORDER/GROUP/SELECT (case-insensitive).
    * ``type_mismatch`` -- among the first 20 columns, names containing "date"
      or "id" whose first non-NULL value is stored as TEXT.

    Args:
        cursor: ``sqlite3`` cursor on the database described by ``tables_info``.
        tables_info: Mapping of table name to a dict with a ``'columns'`` list.

    Returns:
        A list of mistake dicts; every entry has ``'type'``, ``'table'`` and
        ``'warning'`` keys.
    """
    def quote(name):
        # Quote as an SQL identifier; embedded double quotes must be doubled.
        return '"' + name.replace('"', '""') + '"'

    reserved_words = {'ORDER', 'GROUP', 'SELECT'}
    mistakes = []

    for table, info in tables_info.items():
        columns = info.get('columns', [])

        # Mistake 1: Confusing similar column names
        for i, col1 in enumerate(columns):
            for col2 in columns[i + 1:]:
                if col1.lower() != col2.lower() and similar_names(col1, col2):
                    mistakes.append({
                        'type': 'column_confusion',
                        'table': table,
                        'warning': f'Columns "{col1}" and "{col2}" are similar - verify you\'re using the right one',
                        'columns': [col1, col2]
                    })

        # Mistake 2: Missing quotes for special names
        if ' ' in table or '-' in table:
            mistakes.append({
                'type': 'quoting_required',
                'table': table,
                'warning': f'Table name "{table}" requires quotes due to special characters',
                'correct_usage': f'"{table}"'
            })

        for column in columns:
            if ' ' in column or '-' in column or column.upper() in reserved_words:
                mistakes.append({
                    'type': 'quoting_required',
                    'table': table,
                    'column': column,
                    'warning': f'Column "{column}" requires quotes',
                    'correct_usage': f'"{column}"'
                })

        # Mistake 3: Type mismatches
        qtable = quote(table)
        for column in columns[:20]:
            qcol = quote(column)
            try:
                # typeof(NULL) is 'null', so look at the first non-NULL value.
                cursor.execute(
                    f'SELECT typeof({qcol}) FROM {qtable} WHERE {qcol} IS NOT NULL LIMIT 1'
                )
                col_type = cursor.fetchone()
            except sqlite3.Error:
                # Best effort: a column that cannot be probed is skipped.
                continue
            if not col_type or col_type[0] != 'text':
                continue

            name = column.lower()
            if 'date' in name:
                mistakes.append({
                    'type': 'type_mismatch',
                    'table': table,
                    'column': column,
                    'warning': f'{column} looks like a date but is stored as TEXT',
                    'recommendation': 'Use string comparison or date functions'
                })
            # NOTE: loose substring heuristic; also matches names like "paid".
            elif 'id' in name:
                mistakes.append({
                    'type': 'type_mismatch',
                    'table': table,
                    'column': column,
                    'warning': f'{column} is an ID stored as TEXT - use string comparison',
                    # SQL string literals use single quotes; double quotes
                    # denote identifiers in SQLite.
                    'recommendation': "Compare with quotes: = 'value'"
                })

    return mistakes

def similar_names(s1, s2):
    """Check if two names are confusingly similar"""
    s1, s2 = s1.lower(), s2.lower()
    # Check for subset
    if s1 in s2 or s2 in s1:
        return True
    # Check for edit distance
    if abs(len(s1) - len(s2)) <= 2:
        differences = sum(c1 != c2 for c1, c2 in zip(s1, s2))
        if differences <= 2:
            return True
    return False

def generate_query_templates(cursor, tables_info):
    """Build generic SQL query templates for each table in the schema.

    Per table, in order, this emits:

    * one ``count_all`` template;
    * ``conditional_count`` templates for the first 5 columns;
    * one ``percentage`` template containing a literal ``condition``
      placeholder to be filled in;
    * ``top_n`` templates for the first 3 columns.

    ``?`` marks a bind parameter. Identifiers are double-quoted, with any
    embedded double quotes doubled.

    Args:
        cursor: Unused here; accepted so the signature parallels
            ``identify_common_mistakes``.
        tables_info: Mapping of table name to a dict with a ``'columns'`` list.

    Returns:
        A list of dicts with ``'pattern'``, ``'description'`` and ``'template'``.
    """
    def quote(name):
        return '"' + name.replace('"', '""') + '"'

    templates = []

    for table, info in tables_info.items():
        columns = info.get('columns', [])
        qtable = quote(table)

        # Template 1: Simple COUNT
        templates.append({
            'pattern': 'count_all',
            'description': f'Count all rows in {table}',
            'template': f'SELECT COUNT(*) FROM {qtable}'
        })

        # Template 2: Conditional COUNT
        for column in columns[:5]:
            templates.append({
                'pattern': 'conditional_count',
                'description': f'Count rows where {column} meets condition',
                'template': f'SELECT COUNT(*) FROM {qtable} WHERE {quote(column)} = ?'
            })

        # Template 3: Percentage calculation
        templates.append({
            'pattern': 'percentage',
            'description': f'Calculate percentage in {table}',
            'template': f'SELECT CAST(COUNT(CASE WHEN condition THEN 1 END) AS REAL) * 100.0 / COUNT(*) FROM {qtable}'
        })

        # Template 4: Top N (plain LIMIT; rows tied at the cutoff may be dropped)
        for column in columns[:3]:
            templates.append({
                'pattern': 'top_n',
                'description': f'Get top N by {column}',
                'template': f'SELECT * FROM {qtable} ORDER BY {quote(column)} DESC LIMIT ?'
            })

    return templates

def _load_tables_info(cursor):
    """Return ``(tables, tables_info)`` for the user tables in the database.

    ``tables`` lists table names in the order returned by the
    ``sqlite_master`` query, excluding SQLite's reserved ``sqlite_*`` tables
    (e.g. ``sqlite_sequence``). ``tables_info`` maps each name to
    ``{'columns': [...], 'row_count': int}``.
    """
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [name for (name,) in cursor.fetchall()
              if not name.lower().startswith('sqlite_')]

    tables_info = {}
    for table in tables:
        # Quote as an identifier; embedded double quotes must be doubled.
        quoted = '"' + table.replace('"', '""') + '"'
        cursor.execute(f'PRAGMA table_info({quoted})')
        columns = [row[1] for row in cursor.fetchall()]  # index 1 = column name
        cursor.execute(f'SELECT COUNT(*) FROM {quoted}')
        tables_info[table] = {
            'columns': columns,
            'row_count': cursor.fetchone()[0]
        }
    return tables, tables_info


def _default_validation_rules():
    """Return the fixed, schema-independent validation rules for the report."""
    return [
        {
            'rule': 'Always check for NULL values when aggregating',
            'sql_pattern': 'WHERE column IS NOT NULL',
            'applies_to': 'aggregation queries'
        },
        {
            'rule': 'Use CAST(... AS REAL) for percentage calculations',
            'sql_pattern': 'CAST(COUNT(CASE WHEN ... THEN 1 END) AS REAL) * 100.0 / COUNT(*)',
            'applies_to': 'percentage queries'
        },
        {
            'rule': 'Quote table/column names with spaces or special characters',
            'sql_pattern': '"table name" or "column-name"',
            'applies_to': 'all queries with special names'
        },
        {
            'rule': 'Consider using DISTINCT when counting unique items',
            'sql_pattern': 'COUNT(DISTINCT column)',
            'applies_to': 'counting queries'
        },
        {
            'rule': 'Use proper JOIN conditions for relationships',
            'sql_pattern': 'JOIN table2 ON table1.fk = table2.pk',
            'applies_to': 'multi-table queries'
        }
    ]


def main(db_path='./database.sqlite', output_dir='./tool_output'):
    """Analyse a SQLite database and write a query-validation JSON report.

    The report (``<output_dir>/query_validation.json``) has four keys:
    per-table ``validation_tests`` (first 10 tables), ``common_mistakes``,
    up to 50 ``query_templates``, and the fixed ``validation_rules``.
    A short summary is printed to stdout.

    Args:
        db_path: Path of an existing SQLite database file.
        output_dir: Directory for the report; created if missing.

    Raises:
        FileNotFoundError: If ``db_path`` is not an existing file.
    """
    if not os.path.isfile(db_path):
        # sqlite3.connect() would silently create an empty database here and
        # the report would describe nothing.
        raise FileNotFoundError(f'Database not found: {db_path}')

    output_path = f'{output_dir}/query_validation.json'
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        os.makedirs(output_dir, exist_ok=True)

        tables, tables_info = _load_tables_info(cursor)

        results = {
            'validation_tests': {},
            'common_mistakes': [],
            'query_templates': [],
            'validation_rules': []
        }

        print("Generating validation tests...")
        for table in tables[:10]:  # Limit for performance
            results['validation_tests'][table] = generate_validation_tests(
                cursor, table, tables_info[table]['columns'])

        print("Identifying common SQL mistakes...")
        results['common_mistakes'] = identify_common_mistakes(cursor, tables_info)

        print("Generating query templates...")
        results['query_templates'] = generate_query_templates(cursor, tables_info)[:50]  # Limit output

        print("Creating validation rules...")
        results['validation_rules'] = _default_validation_rules()

        # default=str defensively stringifies anything json cannot encode.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=str)
    finally:
        # Close the connection even if a query or the write fails.
        conn.close()

    print("\n=== QUERY VALIDATION ANALYSIS COMPLETE ===")
    print(f"Generated validation tests for {len(results['validation_tests'])} tables")
    print(f"Identified {len(results['common_mistakes'])} potential mistakes")
    print(f"Created {len(results['query_templates'])} query templates")
    print(f"Defined {len(results['validation_rules'])} validation rules")

    print("\nTop warnings:")
    for mistake in results['common_mistakes'][:5]:
        print(f"  ⚠️ {mistake['warning']}")

    print("\nKey validation rules:")
    for rule in results['validation_rules']:
        print(f"  ✓ {rule['rule']}")

    print(f"\nResults saved to {output_path}")

if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()