#!/usr/bin/env python3
"""
Schema Analyzer Tool - Inspired by CHESS Schema Selector
Analyzes and scores column relevance for SQL queries
"""

import sqlite3
import json
import os
from collections import Counter, defaultdict

def calculate_column_importance(cursor, table, column):
    """Calculate a heuristic importance score for one column.

    The score is the sum of these bonuses:

    * +30 if the column is part of the table's primary key (any position of
      a composite key counts),
    * +25 if it is the source column of a declared foreign key,
    * +20 if more than 95% of rows hold distinct values, or +15 if fewer
      than 10% do (low-cardinality columns are useful for filtering),
    * +10 if fewer than 10% of rows are NULL,
    * +5 if the lower-cased column name contains one of a few keywords.

    Args:
        cursor: An open ``sqlite3`` cursor on the database being analysed.
        table: Table name; it is quote-escaped before being put into SQL.
        column: Name of a column of ``table``.

    Returns:
        A ``(score, details)`` tuple. ``details`` records which bonuses
        applied plus the raw ``uniqueness`` and ``null_ratio`` ratios (only
        when the table has rows). If any query fails, the score accumulated
        so far is returned and ``details['error']`` holds the message.
    """
    score = 0
    details = {}

    # Escape identifiers by doubling embedded double quotes so that names
    # such as  my"table  do not produce malformed SQL.
    quoted_table = '"{}"'.format(table.replace('"', '""'))
    quoted_column = '"{}"'.format(column.replace('"', '""'))

    try:
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk).
        # pk is the 1-based position inside the primary key (0 = not a key
        # column), so any positive value marks a key column; testing == 1
        # would miss the later columns of a composite key.
        cursor.execute(f'PRAGMA table_info({quoted_table})')
        for row in cursor.fetchall():
            if row[1] == column and row[5] > 0:
                score += 30
                details['is_primary_key'] = True
                break

        # PRAGMA foreign_key_list rows: (id, seq, table, from, to, ...).
        cursor.execute(f'PRAGMA foreign_key_list({quoted_table})')
        for row in cursor.fetchall():
            if row[3] == column:  # "from" column
                score += 25
                details['is_foreign_key'] = True
                # "to" is NULL when the FK targets the parent's primary key
                # implicitly (REFERENCES parent, with no column list).
                target_column = row[4] if row[4] is not None else '<primary key>'
                details['references'] = f"{row[2]}.{target_column}"
                break

        # A single scan yields distinct, non-NULL and total counts;
        # COUNT(DISTINCT col) and COUNT(col) both ignore NULLs.
        cursor.execute(
            f'SELECT COUNT(DISTINCT {quoted_column}), COUNT({quoted_column}), COUNT(*) '
            f'FROM {quoted_table}'
        )
        distinct_count, non_null_count, total_count = cursor.fetchone()

        if total_count > 0:
            uniqueness = distinct_count / total_count
            details['uniqueness'] = uniqueness
            if uniqueness > 0.95:
                score += 20
                details['highly_unique'] = True
            elif uniqueness < 0.1:
                score += 15  # Low cardinality can be good for filtering
                details['low_cardinality'] = True

            null_ratio = (total_count - non_null_count) / total_count
            details['null_ratio'] = null_ratio
            if null_ratio < 0.1:
                score += 10  # Mostly populated
                details['mostly_populated'] = True

        # Name-based hint: first matching keyword (plain substring match).
        important_keywords = ('id', 'name', 'date', 'time', 'amount', 'total',
                              'count', 'status', 'type', 'code', 'number')
        col_lower = column.lower()
        keyword = next((kw for kw in important_keywords if kw in col_lower), None)
        if keyword is not None:
            score += 5
            details['has_important_keyword'] = keyword

    except Exception as e:
        # Best-effort analysis: record the failure and keep the partial score
        # rather than aborting the whole schema scan.
        details['error'] = str(e)

    return score, details

def analyze_relationships(cursor, tables):
    """Collect declared and name-inferred relationships between tables.

    Two kinds of relationship are reported, in table order:

    * ``'foreign_key'`` — every row of ``PRAGMA foreign_key_list``.
    * ``'potential'`` — a column whose name ends in ``_id`` (any letter
      case) where the prefix names another table in ``tables``
      (case-insensitive). The target column is assumed to be ``'id'``.
      Columns that already carry a declared foreign key are skipped, so the
      same link is not counted twice.

    Args:
        cursor: An open ``sqlite3`` cursor.
        tables: Table names to inspect; also the candidate targets for
            inferred relationships.

    Returns:
        A list of dicts with keys ``from_table``, ``from_column``,
        ``to_table``, ``to_column`` and ``type``.
    """
    relationships = []

    # Lower-cased name -> actual table name, built once. setdefault keeps the
    # first occurrence, matching a first-match linear scan of `tables`.
    tables_by_lower = {}
    for name in tables:
        tables_by_lower.setdefault(name.lower(), name)

    for table in tables:
        # Escape embedded double quotes so unusual table names stay valid SQL.
        quoted_table = '"{}"'.format(table.replace('"', '""'))

        # Declared foreign keys: rows are (id, seq, table, from, to, ...).
        cursor.execute(f'PRAGMA foreign_key_list({quoted_table})')
        declared_fk_columns = set()
        for row in cursor.fetchall():
            relationships.append({
                'from_table': table,
                'from_column': row[3],
                'to_table': row[2],
                'to_column': row[4],
                'type': 'foreign_key'
            })
            declared_fk_columns.add(row[3])

        # Inferred relationships from *_id naming.
        cursor.execute(f'PRAGMA table_info({quoted_table})')
        columns = [row[1] for row in cursor.fetchall()]

        for column in columns:
            if column in declared_fk_columns:
                continue  # already reported as a real foreign key
            if not column.lower().endswith('_id'):
                continue
            target_table = tables_by_lower.get(column[:-3].lower())
            if target_table is None:
                continue
            relationships.append({
                'from_table': table,
                'from_column': column,
                'to_table': target_table,
                'to_column': 'id',  # Assumed
                'type': 'potential'
            })

    return relationships

def identify_key_patterns(cursor, table, columns, max_columns=20):
    """Classify a table's columns into coarse query-relevant groups.

    Produced keys (in this order):

    * ``date_columns`` — names containing date/time/created/updated.
    * ``numeric_columns`` — columns whose first non-NULL value has SQLite
      storage class ``integer`` or ``real``.
    * ``identifier_columns`` — names containing id/code/number/key.
    * ``category_columns`` — columns with between 2 and 19 distinct
      non-NULL values.

    The name-based groups consider every column. The two data-scanning
    groups only look at the first ``max_columns`` columns, because each
    check issues a query against the table.

    Args:
        cursor: An open ``sqlite3`` cursor.
        table: Table name; it is quote-escaped before being put into SQL.
        columns: Column names of ``table``, in the order to examine them.
        max_columns: Cap on columns scanned by the data queries (default
            20); ``None`` scans all columns.

    Returns:
        A dict of the groups above. If a query fails, the groups computed so
        far are returned together with ``patterns['error']``.
    """
    patterns = {}

    # Escape identifiers by doubling embedded double quotes.
    quoted_table = '"{}"'.format(table.replace('"', '""'))
    scanned_columns = columns[:max_columns]  # Limit for performance

    try:
        # Date/time columns (name heuristic)
        patterns['date_columns'] = [
            c for c in columns
            if any(d in c.lower() for d in ('date', 'time', 'created', 'updated'))
        ]

        # Numeric columns: check the storage class of the first non-NULL
        # value. Sampling an arbitrary first row would report 'null' for a
        # leading NULL and hide a numeric column.
        numeric_columns = []
        for column in scanned_columns:
            quoted_column = '"{}"'.format(column.replace('"', '""'))
            cursor.execute(
                f'SELECT typeof({quoted_column}) FROM {quoted_table} '
                f'WHERE {quoted_column} IS NOT NULL LIMIT 1'
            )
            col_type = cursor.fetchone()
            if col_type and col_type[0] in ('integer', 'real'):
                numeric_columns.append(column)
        patterns['numeric_columns'] = numeric_columns

        # Identifier-like columns (name heuristic)
        patterns['identifier_columns'] = [
            c for c in columns
            if any(i in c.lower() for i in ('id', 'code', 'number', 'key'))
        ]

        # Status/category columns: low but non-trivial cardinality. Having
        # more than one distinct value already implies the table is non-empty.
        category_columns = []
        for column in scanned_columns:
            quoted_column = '"{}"'.format(column.replace('"', '""'))
            cursor.execute(f'SELECT COUNT(DISTINCT {quoted_column}) FROM {quoted_table}')
            (distinct,) = cursor.fetchone()
            if 1 < distinct < 20:
                category_columns.append(column)
        patterns['category_columns'] = category_columns

    except Exception as e:
        # Best-effort: keep whatever was gathered and note the failure.
        patterns['error'] = str(e)

    return patterns

def main(db_path='./database.sqlite', output_dir='./tool_output'):
    """Analyse every user table of an SQLite database and write a JSON report.

    For each table this records column importance scores, column pattern
    groups and basic statistics. It also records inter-table relationships
    and the 20 highest-scoring columns overall. The report goes to
    ``<output_dir>/schema_analysis.json``, and a summary is printed.

    Args:
        db_path: Path to an existing SQLite database file.
        output_dir: Directory for the JSON report; created if missing.

    Raises:
        FileNotFoundError: if ``db_path`` does not exist. Without this check,
            ``sqlite3.connect`` would silently create an empty database there.
    """
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"Database not found: {db_path}")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'schema_analysis.json')

    results = {
        'schema_analysis': {},
        'relationships': [],
        'column_importance': {},
        'query_patterns': {}
    }

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Names starting with "sqlite_" are reserved for SQLite's internal
        # tables (e.g. sqlite_sequence), so they are not part of the user schema.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [name for (name,) in cursor.fetchall() if not name.startswith('sqlite_')]

        print("Analyzing schema structure...")

        for table in tables:
            # Escape embedded double quotes so unusual table names stay valid SQL.
            quoted_table = '"{}"'.format(table.replace('"', '""'))
            cursor.execute(f'PRAGMA table_info({quoted_table})')
            columns = [row[1] for row in cursor.fetchall()]

            # Calculate column importance
            table_scores = {}
            for column in columns:
                score, details = calculate_column_importance(cursor, table, column)
                table_scores[column] = {'score': score, 'details': details}
            results['column_importance'][table] = table_scores

            # Identify patterns
            results['query_patterns'][table] = identify_key_patterns(cursor, table, columns)

            # Get table statistics
            cursor.execute(f'SELECT COUNT(*) FROM {quoted_table}')
            row_count = cursor.fetchone()[0]
            results['schema_analysis'][table] = {
                'columns': columns,
                'row_count': row_count,
                'column_count': len(columns)
            }

        print("Analyzing table relationships...")
        results['relationships'] = analyze_relationships(cursor, tables)
    finally:
        # Close the connection even if a query or helper raises.
        conn.close()

    # Rank columns across the whole database. sorted() is stable, so tied
    # scores keep table/column discovery order.
    print("Identifying key columns...")
    all_scores = [
        {
            'table': table,
            'column': column,
            'score': info['score'],
            'full_name': f"{table}.{column}"
        }
        for table, table_scores in results['column_importance'].items()
        for column, info in table_scores.items()
    ]
    results['top_important_columns'] = sorted(
        all_scores, key=lambda x: x['score'], reverse=True
    )[:20]

    # Save results (default=str stringifies any non-JSON value, e.g. bytes)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=str)

    # Print summary
    print("\n=== SCHEMA ANALYSIS COMPLETE ===")
    print(f"Analyzed {len(tables)} tables")
    print(f"Found {len(results['relationships'])} relationships")
    print("\nTop 10 most important columns:")
    for col in results['top_important_columns'][:10]:
        print(f"  {col['full_name']}: score={col['score']}")

    print("\nRelationship summary:")
    fk_count = sum(1 for r in results['relationships'] if r['type'] == 'foreign_key')
    pot_count = sum(1 for r in results['relationships'] if r['type'] == 'potential')
    print(f"  Foreign keys: {fk_count}")
    print(f"  Potential relationships: {pot_count}")

    print(f"\nResults saved to {output_path}")

if "__main__" == __name__:
    # Script entry point: run the analysis only when executed directly,
    # never as a side effect of importing this module.
    main()