#!/usr/bin/env python3
"""
Enhanced Database Fingerprinter
Identifies database type, patterns, and characteristics for adaptive query generation.
Enhanced with better date/time detection and value precision mapping.
"""

import sqlite3
import json
import os
from collections import defaultdict, Counter
import re

def analyze_database(db_path):
    """Build a fingerprint dict describing an SQLite database.

    Every user table is profiled with ``analyze_table``. The detectors then
    fill in the database type, aggregation contexts, special characteristics,
    date/time formats, categorical values, complexity level, join hints and
    the optimization approach.

    Args:
        db_path: Path to the SQLite database file. ``sqlite3.connect`` creates
            an empty database if the file does not exist, so callers should
            check for existence first.

    Returns:
        The populated fingerprint dict.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Get all user tables. Names starting with "sqlite_" are reserved for
        # SQLite's own bookkeeping tables (e.g. sqlite_sequence) and would
        # otherwise inflate table_count and the type/complexity heuristics.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [
            row[0] for row in cursor.fetchall()
            if not row[0].lower().startswith('sqlite_')
        ]

        fingerprint = {
            'database_type': 'unknown',
            'complexity_level': 'unknown',
            'key_patterns': [],
            'aggregation_contexts': {},
            'special_characteristics': [],
            'date_time_patterns': {},
            'value_patterns': {},
            'join_hints': {},
            'table_count': len(tables),
            'total_rows': 0,
            'table_analysis': {},
            'optimization_approach': 'standard'
        }

        # Analyze each table
        for table in tables:
            table_info = analyze_table(cursor, table)
            fingerprint['table_analysis'][table] = table_info
            fingerprint['total_rows'] += table_info['row_count']

        # Order matters: later detectors read fields written by earlier ones.
        detect_database_type(fingerprint, tables)
        detect_aggregation_contexts(fingerprint)
        detect_special_patterns(fingerprint)
        detect_date_time_patterns(cursor, fingerprint, tables)
        detect_value_precision_patterns(cursor, fingerprint, tables)
        # Complexity scoring reads database_type, special_characteristics and
        # date_time_patterns, so it must run after the detectors above.
        detect_complexity_level(fingerprint, tables)
        # Join hints read special_characteristics; the optimization approach
        # reads complexity_level.
        generate_join_hints(fingerprint)
        determine_optimization_approach(fingerprint)
    finally:
        # Close the connection even if a detector raises.
        conn.close()
    return fingerprint

def analyze_table(cursor, table_name):
    """Profile one table: columns, row count, keys, sample values and flags.

    Args:
        cursor: An open ``sqlite3`` cursor on the database being analysed.
        table_name: Name of the table to inspect. It is quoted as an SQL
            identifier, so names with spaces, quotes or SQL keywords (such as
            ``Order`` or ``Match``) work.

    Returns:
        dict with ``row_count``, ``columns`` (per-column metadata keyed by
        column name, including up to 3 non-NULL ``sample_values`` as strings),
        ``primary_key`` (column names), ``foreign_keys`` (dicts with
        ``column`` and ``references``) and the aggregate flags ``has_stats``,
        ``has_dates``, ``has_locations``, ``has_categories``.
    """
    def quote(identifier):
        # Double-quote an SQL identifier, doubling any embedded double quotes.
        return '"' + identifier.replace('"', '""') + '"'

    quoted_table = quote(table_name)

    # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
    cursor.execute(f"PRAGMA table_info({quoted_table})")
    columns = cursor.fetchall()

    cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
    row_count = cursor.fetchone()[0]

    # Select the columns explicitly so that each sample row lines up with
    # `columns` by position. Only the first 3 rows are used as samples.
    sample_data = []
    if columns:
        column_list = ", ".join(quote(col[1]) for col in columns)
        cursor.execute(f"SELECT {column_list} FROM {quoted_table} LIMIT 3")
        sample_data = cursor.fetchall()

    table_info = {
        'row_count': row_count,
        'columns': {},
        'primary_key': [],
        'foreign_keys': [],
        'has_stats': False,
        'has_dates': False,
        'has_locations': False,
        'has_categories': False
    }

    for col_index, col in enumerate(columns):
        col_name = col[1]
        col_type = col[2]
        # pk is 0 for non-key columns, otherwise the 1-based position within
        # the primary key. Composite keys therefore use 1, 2, ...
        is_primary = col[5] > 0

        col_info = {
            'type': col_type,
            'is_primary': is_primary,
            'is_foreign_key': False,
            'is_stat': False,
            'is_date': False,
            'is_location': False,
            'is_category': False,
            'sample_values': []
        }
        table_info['columns'][col_name] = col_info

        if is_primary:
            table_info['primary_key'].append(col_name)

        # Name/type-based flags (stat, date, location, category)
        detect_column_patterns(col_name, col_type, col_info)

        col_info['sample_values'] = [
            str(row[col_index]) for row in sample_data
            if row[col_index] is not None
        ]

    # PRAGMA foreign_key_list rows:
    # (id, seq, table, from, to, on_update, on_delete, match)
    cursor.execute(f"PRAGMA foreign_key_list({quoted_table})")
    for fk in cursor.fetchall():
        ref_table, fk_col, ref_col = fk[2], fk[3], fk[4]
        # 'to' is NULL when the constraint names no parent column (it then
        # targets the parent's primary key); avoid emitting "Table.None".
        references = f"{ref_table}.{ref_col}" if ref_col is not None else ref_table
        table_info['foreign_keys'].append({
            'column': fk_col,
            'references': references
        })
        if fk_col in table_info['columns']:
            table_info['columns'][fk_col]['is_foreign_key'] = True

    # Table-level flags: true when any column carries the matching flag.
    column_infos = list(table_info['columns'].values())
    for table_flag, column_flag in (('has_stats', 'is_stat'),
                                    ('has_dates', 'is_date'),
                                    ('has_locations', 'is_location'),
                                    ('has_categories', 'is_category')):
        table_info[table_flag] = any(info[column_flag] for info in column_infos)

    return table_info

def detect_column_patterns(col_name, col_type, col_info):
    """Set the is_stat / is_date / is_location / is_category flags of a column.

    Matching is a case-insensitive *substring* test on the column name, so for
    example ``country`` also hits the stat keyword ``count``. A declared type
    containing ``DATE`` (which covers ``DATETIME``) also marks a date column.
    Flags are only ever switched on; existing True values are never cleared.

    Args:
        col_name: Column name as reported by SQLite.
        col_type: Declared column type string.
        col_info: Per-column metadata dict, updated in place.
    """
    lowered = col_name.lower()

    keyword_table = (
        ('is_stat', ('score', 'points', 'goals', 'assists', 'rebounds', 'yards',
                     'wins', 'losses', 'rating', 'rank', 'total', 'sum', 'count',
                     'avg', 'average', 'min', 'max', 'stats', 'metric')),
        ('is_date', ('date', 'time', 'year', 'month', 'day', 'timestamp',
                     'created', 'updated', 'modified', 'start', 'end', 'dob',
                     'birth', 'expired', 'valid')),
        ('is_location', ('city', 'state', 'country', 'location', 'address',
                         'zip', 'postal', 'region', 'area', 'lat', 'long',
                         'latitude', 'longitude', 'place')),
        ('is_category', ('type', 'category', 'class', 'group', 'status',
                         'level', 'tier', 'division', 'department', 'role',
                         'position', 'style', 'genre', 'kind')),
    )

    for flag, keywords in keyword_table:
        matched = any(keyword in lowered for keyword in keywords)
        if not matched and flag == 'is_date':
            # Fall back to the declared type; 'DATE' also matches 'DATETIME'.
            matched = 'DATE' in col_type.upper()
        if matched:
            col_info[flag] = True

def detect_database_type(fingerprint, tables):
    """Classify the database domain from its table and column names.

    Each domain has a list of keyword indicators. A domain earns one point per
    indicator found as a substring of any lower-cased table or column name.
    The top-scoring domain wins if it reaches 3 points, with ties going to
    the domain listed first. Otherwise the type is ``'general'``.
    Domain-specific entries are then appended to
    ``fingerprint['key_patterns']``.

    Args:
        fingerprint: Fingerprint dict. Reads ``table_analysis`` and writes
            ``database_type`` and ``key_patterns``.
        tables: Table names to include in the search.
    """
    names = [table.lower() for table in tables]
    for analysis in fingerprint['table_analysis'].values():
        names.extend(column.lower() for column in analysis['columns'].keys())

    indicators_by_domain = {
        'sports': ['player', 'team', 'game', 'match', 'season', 'league',
                   'score', 'goals', 'assists', 'rebounds', 'yards', 'stats'],
        'geographic': ['city', 'country', 'state', 'region', 'location', 'continent',
                       'population', 'area', 'capital', 'border'],
        'business': ['customer', 'order', 'product', 'invoice', 'payment',
                     'employee', 'department', 'salary', 'sales', 'revenue'],
        'transportation': ['trip', 'station', 'driver', 'vehicle', 'route',
                           'shipment', 'delivery', 'bike', 'ride', 'journey'],
        'recipe_food': ['recipe', 'ingredient', 'nutrition', 'cooking', 'food',
                        'meal', 'dish', 'cuisine', 'flavor', 'taste'],
    }

    def score(indicators):
        # Count indicators that occur inside at least one table/column name.
        return sum(1 for word in indicators if any(word in name for name in names))

    scores = {domain: score(words) for domain, words in indicators_by_domain.items()}
    winner = max(scores, key=scores.get)
    fingerprint['database_type'] = winner if scores[winner] >= 3 else 'general'

    patterns_by_domain = {
        'sports': ['player_stats', 'team_aggregations', 'season_tracking'],
        'geographic': ['hierarchical_locations', 'population_analysis', 'area_calculations'],
        'business': ['transaction_tracking', 'customer_analysis', 'time_series'],
        'transportation': ['trip_analysis', 'station_connections', 'duration_calculations'],
        'recipe_food': ['ingredient_combinations', 'nutrition_analysis', 'recipe_matching'],
    }
    fingerprint['key_patterns'].extend(
        patterns_by_domain.get(fingerprint['database_type'], [])
    )

def detect_aggregation_contexts(fingerprint):
    """Group tables by the aggregation level they naturally support.

    Writes ``fingerprint['aggregation_contexts']`` with these lists:
      * ``individual_level``: tables whose name contains player/person/user.
      * ``group_level``: tables whose name contains team/department/group.
      * ``overall_level``: currently never populated.
      * ``temporal_level``: tables with at least one date column.
      * ``combined_stats``: unique hints for stats that are usually summed.
        A hint is added when a table has more than one stat column and one
        of those columns mentions rebounds or yards.

    Args:
        fingerprint: Fingerprint dict whose ``table_analysis`` is populated.
    """
    contexts = {
        'individual_level': [],
        'group_level': [],
        'overall_level': [],
        'temporal_level': [],
        'combined_stats': []
    }
    individual_markers = ('player', 'person', 'user')
    group_markers = ('team', 'department', 'group')

    for table_name, table_info in fingerprint['table_analysis'].items():
        lowered = table_name.lower()

        if any(marker in lowered for marker in individual_markers):
            contexts['individual_level'].append(table_name)

        if any(marker in lowered for marker in group_markers):
            contexts['group_level'].append(table_name)

        if table_info['has_dates']:
            contexts['temporal_level'].append(table_name)

        stat_columns = [col.lower() for col, info in table_info['columns'].items()
                        if info['is_stat']]
        if len(stat_columns) > 1:
            hints = []
            if any('rebound' in col for col in stat_columns):
                hints.append('rebounds (offensive + defensive)')
            if any('yards' in col for col in stat_columns):
                hints.append('total_yards (passing + rushing)')
            # Several tables can carry the same stats; record each hint once.
            for hint in hints:
                if hint not in contexts['combined_stats']:
                    contexts['combined_stats'].append(hint)

    fingerprint['aggregation_contexts'] = contexts

def detect_special_patterns(fingerprint):
    """Append structural traits to ``fingerprint['special_characteristics']``.

    Traits, in the order they may be appended:
      * ``hierarchical_structure``: some column name contains parent/child.
      * ``time_series_heavy``: more than half of the tables have date columns.
      * ``statistics_focused``: more than half of the tables have stat columns.

    Args:
        fingerprint: Fingerprint dict whose ``table_analysis`` is populated.
    """
    analyses = fingerprint['table_analysis']
    traits = fingerprint['special_characteristics']
    majority_threshold = len(analyses) * 0.5

    hierarchical = any(
        'parent' in column.lower() or 'child' in column.lower()
        for analysis in analyses.values()
        for column in analysis['columns']
    )
    if hierarchical:
        traits.append('hierarchical_structure')

    dated = sum(1 for analysis in analyses.values() if analysis['has_dates'])
    if dated > majority_threshold:
        traits.append('time_series_heavy')

    statistical = sum(1 for analysis in analyses.values() if analysis['has_stats'])
    if statistical > majority_threshold:
        traits.append('statistics_focused')

def detect_date_time_patterns(cursor, fingerprint, tables):
    """Infer the textual format of date-like columns from their sample values.

    For every column flagged ``is_date`` that has sample values, up to three
    samples are matched against known formats. The most frequent format is
    recorded under ``fingerprint['date_time_patterns']["table.column"]`` as
    ``{'format': ..., 'sample': <first sample>}``. Columns whose samples match
    no format, such as bare years, are omitted.

    Args:
        cursor: Unused. Kept so the signature matches the other detectors.
        fingerprint: Fingerprint dict whose ``table_analysis`` is populated.
        tables: Table names to inspect.
    """
    # Most specific formats first. re.match anchors only at the start of the
    # string, so a bare date pattern would also match a datetime value and
    # hide the datetime label if it were tried first.
    format_rules = (
        (re.compile(r'\d{1,2}/\d{1,2}/\d{4} \d{1,2}:\d{2}'), 'M/D/YYYY H:MM'),
        (re.compile(r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}'), 'YYYY-MM-DD HH:MM:SS'),
        (re.compile(r'\d{1,2}/\d{1,2}/\d{4}'), 'M/D/YYYY'),
        (re.compile(r'\d{4}-\d{2}-\d{2}'), 'YYYY-MM-DD'),
    )

    date_patterns = {}
    for table in tables:
        table_info = fingerprint['table_analysis'][table]
        for col_name, col_info in table_info['columns'].items():
            samples = col_info['sample_values']
            if not (col_info['is_date'] and samples):
                continue

            formats_detected = []
            for val in samples[:3]:
                if not val:
                    continue
                for pattern, label in format_rules:
                    if pattern.match(val):
                        formats_detected.append(label)
                        break

            if formats_detected:
                date_patterns[f"{table}.{col_name}"] = {
                    'format': Counter(formats_detected).most_common(1)[0][0],
                    'sample': samples[0],
                }

    fingerprint['date_time_patterns'] = date_patterns

def detect_value_precision_patterns(cursor, fingerprint, tables):
    """Record the exact distinct values of low-cardinality text/category columns.

    A column is inspected when it is flagged ``is_category``, or when it has
    SQLite TEXT affinity and is not a date column. TEXT affinity means the
    declared type contains CHAR, CLOB or TEXT but not INT, which covers
    ``VARCHAR(255)``, ``text``, ``NVARCHAR`` and ``CHAR(n)``. Columns with
    fewer than 50 distinct non-NULL values are stored under
    ``fingerprint['value_patterns']["table.column"]`` with up to 20 sample
    values, the distinct count, and whether any string value is mixed-case.

    Args:
        cursor: An open ``sqlite3`` cursor on the analysed database.
        fingerprint: Fingerprint dict whose ``table_analysis`` is populated.
        tables: Table names to inspect.
    """
    def quote_ident(name):
        # Double-quote an SQL identifier so spaces/keywords/quotes are safe.
        return '"' + str(name).replace('"', '""') + '"'

    def has_text_affinity(declared_type):
        # SQLite affinity rules: "INT" wins first, then CHAR/CLOB/TEXT => TEXT.
        upper = (declared_type or '').upper()
        return 'INT' not in upper and any(k in upper for k in ('CHAR', 'CLOB', 'TEXT'))

    value_patterns = {}

    for table in tables:
        table_info = fingerprint['table_analysis'][table]
        quoted_table = quote_ident(table)
        for col_name, col_info in table_info['columns'].items():
            is_plain_text = has_text_affinity(col_info['type']) and not col_info['is_date']
            if not (col_info['is_category'] or is_plain_text):
                continue

            column = quote_ident(col_name)
            try:
                cursor.execute(
                    f"SELECT DISTINCT {column} FROM {quoted_table} "
                    f"WHERE {column} IS NOT NULL LIMIT 50"
                )
                values = [row[0] for row in cursor.fetchall()]
            except sqlite3.Error:
                # Best effort: skip columns the query fails on instead of
                # aborting the whole fingerprint.
                continue

            # LIMIT 50 plus "< 50" keeps only columns with at most 49 distinct values.
            if values and len(values) < 50:
                value_patterns[f"{table}.{col_name}"] = {
                    'distinct_values': values[:20],  # Limit to 20 for output
                    'total_distinct': len(values),
                    'case_sensitive': any(v != v.upper() and v != v.lower()
                                          for v in values if isinstance(v, str))
                }

    fingerprint['value_patterns'] = value_patterns

def generate_join_hints(fingerprint):
    """Store free-text join advice in ``fingerprint['join_hints']``.

    Hints are keyed by topic:
      * ``station_joins`` and ``trip_patterns`` for transportation databases.
      * ``weather_joins`` when any table name contains "weather".
      * ``hierarchy`` when ``hierarchical_structure`` was detected.

    Args:
        fingerprint: Fingerprint dict with ``database_type``,
            ``table_analysis`` and ``special_characteristics`` populated.
    """
    hints = {}

    if fingerprint['database_type'] == 'transportation':
        hints.update({
            'station_joins': 'Prefer name-based joins over ID-based when both are available',
            'trip_patterns': 'Check for start/end station patterns',
        })

    has_weather_table = any('weather' in name.lower()
                            for name in fingerprint['table_analysis'])
    if has_weather_table:
        hints['weather_joins'] = 'Join on date AND location (zip_code/city)'

    if 'hierarchical_structure' in fingerprint['special_characteristics']:
        hints['hierarchy'] = 'Use recursive CTEs or multiple joins for hierarchy traversal'

    fingerprint['join_hints'] = hints

def detect_complexity_level(fingerprint, tables):
    """Score the database and store 'simple' / 'moderate' / 'complex'.

    Points:
      * table count: 0 below 5, 1 below 15, 2 otherwise.
      * total foreign keys across tables: 0 below 5, 1 below 20, 2 otherwise.
      * +1 if any date/time patterns were found.
      * +1 each for the ``hierarchical_structure`` and ``time_series_heavy``
        special characteristics.
      * +1 for transportation, business or geographic database types.
    A total of 0-2 is 'simple', 3-4 is 'moderate', and 5 or more is 'complex'.

    Args:
        fingerprint: Fingerprint dict. Reads ``table_count``,
            ``table_analysis``, ``date_time_patterns``,
            ``special_characteristics`` and ``database_type``.
        tables: Unused. Kept for signature compatibility.

    Returns:
        The same ``fingerprint`` dict, with ``complexity_level`` set.
    """
    def tier(value, low, high):
        # 0 below `low`, 1 below `high`, 2 at or above `high`.
        if value < low:
            return 0
        if value < high:
            return 1
        return 2

    points = tier(fingerprint['table_count'], 5, 15)

    fk_total = sum(len(analysis.get('foreign_keys', []))
                   for analysis in fingerprint['table_analysis'].values())
    points += tier(fk_total, 5, 20)

    points += 1 if fingerprint['date_time_patterns'] else 0

    traits = fingerprint['special_characteristics']
    points += sum(1 for trait in ('hierarchical_structure', 'time_series_heavy')
                  if trait in traits)

    if fingerprint['database_type'] in ('transportation', 'business', 'geographic'):
        points += 1

    if points <= 2:
        level = 'simple'
    elif points <= 4:
        level = 'moderate'
    else:
        level = 'complex'
    fingerprint['complexity_level'] = level

    return fingerprint

def determine_optimization_approach(fingerprint):
    """Map the complexity level to an optimization approach and notes.

    'simple' maps to 'minimal' and 'moderate' maps to 'balanced'. Any other
    level maps to 'comprehensive'. Each approach comes with four generic
    notes. Transportation and business databases get two extra
    domain-specific notes.

    Args:
        fingerprint: Fingerprint dict with ``complexity_level`` and
            ``database_type`` set. Writes ``optimization_approach`` and
            ``optimization_notes``.
    """
    profiles = {
        'simple': ('minimal', [
            'Use simple, direct queries',
            'Avoid unnecessary joins',
            'Skip complex pattern matching',
            'Focus on straightforward column selection'
        ]),
        'moderate': ('balanced', [
            'Use pattern matching where helpful',
            'Apply targeted optimizations',
            'Balance simplicity with accuracy',
            'Use advanced features selectively'
        ]),
    }
    fallback = ('comprehensive', [
        'Apply all available optimizations',
        'Use sophisticated pattern matching',
        'Leverage all analysis tools',
        'Focus on handling edge cases'
    ])

    approach, notes = profiles.get(fingerprint['complexity_level'], fallback)
    fingerprint['optimization_approach'] = approach
    fingerprint['optimization_notes'] = list(notes)

    domain_notes = {
        'transportation': [
            'Pay special attention to station name joins',
            'Handle date/time formats carefully'
        ],
        'business': [
            'Distinguish between Sales and Purchase tables',
            'Navigate through Customer for territories'
        ],
    }
    fingerprint['optimization_notes'].extend(
        domain_notes.get(fingerprint['database_type'], [])
    )

def main():
    """Fingerprint ./database.sqlite and save the result as JSON.

    Prints an error and returns early if the database file is missing.
    Otherwise it writes ``tool_output/database_fingerprint.json``, creating
    the directory if needed, and prints a short summary.
    """
    db_path = "./database.sqlite"
    if not os.path.exists(db_path):
        print(f"Error: Database not found at {db_path}")
        return

    os.makedirs("tool_output", exist_ok=True)

    fingerprint = analyze_database(db_path)

    output_path = "tool_output/database_fingerprint.json"
    with open(output_path, 'w') as handle:
        json.dump(fingerprint, handle, indent=2)

    summary_lines = (
        "✅ Database fingerprinting complete",
        f"📊 Database type: {fingerprint['database_type']}",
        f"🎯 Complexity level: {fingerprint['complexity_level']}",
        f"⚙️ Optimization approach: {fingerprint['optimization_approach']}",
        f"🔑 Key patterns: {', '.join(fingerprint['key_patterns'])}",
        f"📅 Date patterns found: {len(fingerprint['date_time_patterns'])}",
        f"💾 Results saved to: {output_path}",
    )
    for line in summary_lines:
        print(line)

if __name__ == "__main__":
    # Run the fingerprinter only when executed as a script, not when imported.
    main()