#!/usr/bin/env python3
"""
Enhanced Junction Table Analyzer with Foreign Key Focus
Builds on iter8's approach with better foreign key detection and join patterns.
"""

import sqlite3
import os
import json
from collections import defaultdict

def ensure_output_dir():
    """Create the ``tool_output`` directory in the current working directory.

    Idempotent: an already-existing directory is left as-is.
    """
    target = 'tool_output'
    os.makedirs(target, exist_ok=True)

# Lower-cased column-name endings treated as key/identifier columns rather
# than payload data when deciding whether a junction table carries data.
_KEY_COLUMN_SUFFIXES = ('_id', 'id', '_no', '_code')


def _quote_ident(name):
    """Return *name* as a double-quoted SQLite identifier, doubling embedded quotes."""
    return '"' + name.replace('"', '""') + '"'


def _primary_key_columns(cursor, table):
    """Return *table*'s primary-key column names ordered by position in the key.

    Returns an empty list when the table has no declared primary key (or
    does not exist — PRAGMA table_info then yields no rows).
    """
    cursor.execute(f"PRAGMA table_info({_quote_ident(table)})")
    # PRAGMA table_info row: (cid, name, type, notnull, dflt_value, pk)
    keyed = [(row[5], row[1]) for row in cursor.fetchall() if row[5]]
    return [name for _, name in sorted(keyed)]


def _collect_columns(cursor, table, table_info, analysis, column_occurrences):
    """Record column/PK metadata for *table* plus global column, date and award hits."""
    cursor.execute(f"PRAGMA table_info({_quote_ident(table)})")
    for _cid, col_name, col_type, notnull, _default, pk in cursor.fetchall():
        table_info['columns'][col_name] = {
            'type': col_type,
            'is_primary': bool(pk),
            'is_nullable': not notnull,
        }
        if pk:
            table_info['primary_keys'].append(col_name)

        lowered = col_name.lower()
        analysis['column_locations'][lowered].append(f"{table}.{col_name}")
        column_occurrences[lowered].append((table, col_name))

        # Name-based heuristics only; the declared type is not consulted.
        if any(dt in lowered for dt in ('date', 'time', 'year', 'month', 'day')):
            analysis['date_time_columns'].append({
                'table': table,
                'column': col_name,
                'type': col_type,
            })
        if any(aw in lowered for aw in ('award', 'result', 'winner', 'rank')):
            analysis['award_patterns'].append({
                'table': table,
                'column': col_name,
            })


def _collect_foreign_keys(cursor, table, table_info, analysis):
    """Record one entry per FK constraint of *table*, with a ready-to-use JOIN template.

    PRAGMA foreign_key_list returns one row per column, so rows are grouped
    by constraint id: a composite key becomes a single join whose ON clause
    ANDs every column pair. A NULL target column means the constraint
    implicitly references the parent's primary key; that key is looked up so
    the template names real columns.
    """
    cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table)})")
    # Row layout: (id, seq, table, from, to, on_update, on_delete, match)
    constraints = defaultdict(list)
    for row in cursor.fetchall():
        constraints[row[0]].append(row)

    for rows in constraints.values():
        rows.sort(key=lambda r: r[1])  # column order within the key
        parent = rows[0][2]
        from_cols = [r[3] for r in rows]
        to_cols = [r[4] for r in rows]
        if any(col is None for col in to_cols):
            parent_pk = _primary_key_columns(cursor, parent)
            if len(parent_pk) == len(from_cols):
                to_cols = parent_pk

        on_clause = " AND ".join(
            f"`{table}`.`{src}` = `{parent}`.`{dst}`"
            for src, dst in zip(from_cols, to_cols)
        )
        join_template = f"JOIN `{parent}` ON {on_clause}"

        table_info['foreign_keys'].append({
            # Single-column view kept for existing consumers of the JSON.
            'from_column': from_cols[0],
            'to_table': parent,
            'to_column': to_cols[0],
            'from_columns': from_cols,
            'to_columns': to_cols,
            'join_template': join_template,
        })
        table_info['linked_tables'].append(parent)
        analysis['foreign_key_joins'].append({
            'from': table,
            'to': parent,
            'on': on_clause,
            'template': join_template,
        })


def _classify_table(table, table_info, analysis):
    """Derive *table*'s data columns and file it as a junction or base table."""
    key_columns = set(table_info['primary_keys'])
    for fk in table_info['foreign_keys']:
        key_columns.update(fk['from_columns'])

    # True suffix match: a substring test would wrongly drop data columns
    # such as "width", "candidate" or "delivery_notes".
    table_info['data_columns'] = [
        col for col in table_info['columns']
        if col not in key_columns
        and not col.lower().endswith(_KEY_COLUMN_SUFFIXES)
    ]
    data_cols = table_info['data_columns']

    # Junction = at least two distinct FK constraints (not FK columns).
    if len(table_info['foreign_keys']) >= 2:
        table_info['is_junction'] = True
        analysis['junction_tables'][table] = table_info
        if data_cols:
            purpose = f"Junction table WITH DATA: {', '.join(data_cols[:3])}"
        else:
            purpose = "Pure junction table (links only)"
    else:
        analysis['base_tables'][table] = table_info
        if table_info['foreign_keys']:
            purpose = "Entity table with foreign key"
        else:
            purpose = "Base reference table"
    analysis['table_purposes'][table] = purpose


def _analyze_table(cursor, table, analysis, column_occurrences):
    """Gather row count, columns and foreign keys for *table*, then classify it."""
    table_info = {
        'foreign_keys': [],
        'primary_keys': [],
        'columns': {},
        'row_count': 0,
        'is_junction': False,
        'linked_tables': [],
        'data_columns': [],
        'sample_values': {}
    }

    cursor.execute(f"SELECT COUNT(*) FROM {_quote_ident(table)}")
    table_info['row_count'] = cursor.fetchone()[0]

    _collect_columns(cursor, table, table_info, analysis, column_occurrences)
    _collect_foreign_keys(cursor, table, table_info, analysis)
    _classify_table(table, table_info, analysis)


def analyze_junction_patterns(db_path):
    """Analyze every user table of a SQLite database and write the reports.

    Each table is classified as a junction table (two or more distinct
    foreign-key constraints) or a base table. The analysis also records where
    every column name occurs, FK-based JOIN templates, and date/time- and
    award-like columns. It then runs the confusion/guidance passes and writes
    the reports under ``tool_output/``.

    Args:
        db_path: Path to an existing SQLite database file (or ``":memory:"``).

    Returns:
        The analysis dict.

    Raises:
        FileNotFoundError: If *db_path* does not exist. ``sqlite3.connect``
            would otherwise silently create an empty database there.
    """
    if db_path != ":memory:" and not os.path.exists(db_path):
        raise FileNotFoundError(f"SQLite database not found: {db_path}")

    analysis = {
        'junction_tables': {},
        'base_tables': {},
        'table_purposes': {},
        'column_locations': defaultdict(list),
        'foreign_key_joins': [],
        'optimal_join_paths': [],
        'column_confusion_warnings': [],
        'critical_guidance': [],
        'date_time_columns': [],
        'award_patterns': []
    }

    # Lower-cased column name -> [(table, original column name), ...]
    column_occurrences = defaultdict(list)

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        # Names starting with "sqlite_" are reserved for SQLite's own
        # bookkeeping tables (e.g. sqlite_sequence); they are not schema.
        tables = [
            row[0] for row in cursor.fetchall()
            if not row[0].lower().startswith('sqlite_')
        ]
        for table in tables:
            _analyze_table(cursor, table, analysis, column_occurrences)
    finally:
        conn.close()

    identify_column_confusion(analysis, column_occurrences)
    generate_critical_guidance(analysis)

    create_junction_report(analysis)
    create_foreign_key_report(analysis)

    return analysis

def identify_column_confusion(analysis, column_occurrences):
    """Store warnings about column names likely to be confused when querying.

    Two kinds of warning are produced, and at most 10 are kept in
    ``analysis['column_confusion_warnings']`` (replacing any previous value):

    * a column name that appears in more than one table and contains one of
      a set of commonly ambiguous fragments ('name', 'id', 'date', ...);
    * a pair of related but distinct names (e.g. 'name' vs 'firstname')
      when columns of both kinds exist.

    Args:
        analysis: Analysis dict to update.
        column_occurrences: Mapping of lower-cased column name to a list of
            ``(table, original_column_name)`` tuples.
    """
    warnings = []

    # Same column name in several tables: callers must qualify it.
    for col_name, occurrences in column_occurrences.items():
        if len(occurrences) > 1 and any(
            key in col_name
            for key in ['name', 'title', 'type', 'id', 'date', 'issue', 'sub']
        ):
            warnings.append({
                'column': col_name,
                'tables': [f"{t}.{c}" for t, c in occurrences],
                'guidance': f"Column '{col_name}' in {len(occurrences)} tables - use exact table.column"
            })

    similar_pairs = [
        ('issue', 'sub-issue'),
        ('name', 'firstname'),
        ('name', 'lastname'),
        ('date', 'datetime'),
        ('type', 'category'),
        ('division', 'region')
    ]

    for general, specific in similar_pairs:
        # The general term is often a substring of the specific one
        # ('name' in 'firstname'), so names matching the specific term are
        # excluded from the general bucket. Otherwise a lone 'firstname'
        # column would be reported as being confusable with itself.
        general_cols = [
            k for k in column_occurrences if general in k and specific not in k
        ]
        specific_cols = [k for k in column_occurrences if specific in k]

        if general_cols and specific_cols:
            warnings.append({
                'column': f"{general} vs {specific}",
                'tables': general_cols[:2] + specific_cols[:2],
                'guidance': f"Be precise: {general} and {specific} are different columns"
            })

    analysis['column_confusion_warnings'] = warnings[:10]

def generate_critical_guidance(analysis):
    """Build short hints for table selection and joins.

    Reads 'junction_tables', 'foreign_key_joins', 'date_time_columns' and
    'award_patterns' from *analysis* and stores the resulting list of
    strings in ``analysis['critical_guidance']``.

    Table lists are de-duplicated in first-seen order rather than via
    ``set``. String hashing is randomized per process, so a set made both
    the listed order and the choice of the first three tables vary from run
    to run.
    """
    guidance = []

    # Junction tables that carry their own payload columns.
    for table_name, table_info in analysis['junction_tables'].items():
        if table_info['data_columns']:
            cols = ', '.join(table_info['data_columns'][:3])
            guidance.append(
                f"JUNCTION WITH DATA: {table_name} has columns ({cols}) - query directly"
            )

    if analysis['foreign_key_joins']:
        guidance.append(
            f"FOREIGN KEY JOINS: {len(analysis['foreign_key_joins'])} available - always use these"
        )

    if analysis['date_time_columns']:
        date_tables = list(dict.fromkeys(d['table'] for d in analysis['date_time_columns']))
        guidance.append(
            f"DATE COLUMNS: Use STRFTIME for {', '.join(date_tables[:3])}"
        )

    if analysis['award_patterns']:
        award_tables = list(dict.fromkeys(a['table'] for a in analysis['award_patterns']))
        guidance.append(
            f"AWARD TABLES: {', '.join(award_tables[:3])} may need result='Winner' filter"
        )

    analysis['critical_guidance'] = guidance

def create_junction_report(analysis):
    """Write a Markdown-style junction-table summary to tool_output/junction_analysis.txt.

    Sections are emitted only when non-empty:
    * critical guidance;
    * each junction table with its purpose, links, data columns and row count;
    * up to five column-confusion warnings.

    The file is written as UTF-8 explicitly. The report contains emoji and
    arrows that a non-UTF-8 platform default encoding (e.g. cp1252) cannot
    encode.
    """
    report = ["# Junction Table Analysis", ""]

    if analysis['critical_guidance']:
        report.append("## 🚨 CRITICAL GUIDANCE")
        report.append("")
        for guide in analysis['critical_guidance']:
            report.append(f"- {guide}")
        report.append("")

    if analysis['junction_tables']:
        report.append("## Junction Tables (Check for Data!)")
        report.append("")
        for table_name, table_info in analysis['junction_tables'].items():
            report.append(f"### {table_name}")
            report.append(f"- Purpose: {analysis['table_purposes'].get(table_name, 'Unknown')}")
            report.append(f"- Links: {' ↔ '.join(table_info['linked_tables'])}")
            if table_info['data_columns']:
                report.append(f"- ❗ DATA COLUMNS: {', '.join(table_info['data_columns'][:5])}")
                report.append("- **Query this table directly for these columns**")
            report.append(f"- Row Count: {table_info['row_count']:,}")
            report.append("")

    if analysis['column_confusion_warnings']:
        report.append("## ⚠️ Column Confusion Warnings")
        report.append("")
        for warning in analysis['column_confusion_warnings'][:5]:
            report.append(f"- **{warning['column']}**: {warning['guidance']}")
        report.append("")

    ensure_output_dir()
    with open('tool_output/junction_analysis.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(report))

def create_foreign_key_report(analysis):
    """Write FK join templates (text) and the full analysis (JSON) to tool_output/.

    Outputs:
    * ``tool_output/foreign_keys.json``: the whole *analysis* dict. Values
      that are not JSON-native are converted with ``str``.
    * ``tool_output/foreign_key_joins.txt``: up to 20 JOIN templates in SQL
      code fences, with a note when more exist or when there are none.

    The output directory is created here as well, so this function also
    works when called on its own. Both files are written as UTF-8 because
    the text report contains emoji and arrows.
    """
    max_listed = 20
    joins = analysis['foreign_key_joins']

    report = [
        "# Foreign Key Join Templates",
        "",
        "## 🔑 Use These Join Patterns (Not Name Matching!)",
        "",
    ]

    if joins:
        for join in joins[:max_listed]:
            report.append(f"### {join['from']} → {join['to']}")
            report.append("```sql")
            report.append(join['template'])
            report.append("```")
            report.append("")
        if len(joins) > max_listed:
            report.append(f"_{len(joins) - max_listed} more joins omitted; see foreign_keys.json._")
    else:
        report.append("_No foreign key constraints are declared in this database._")

    ensure_output_dir()

    # JSON copy for programmatic use.
    with open('tool_output/foreign_keys.json', 'w', encoding='utf-8') as f:
        json.dump(analysis, f, indent=2, default=str)

    with open('tool_output/foreign_key_joins.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(report))

    print("Enhanced junction analysis complete - results in tool_output/")

if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the default database path.
    db_file = sys.argv[1] if len(sys.argv) > 1 else "database.sqlite"
    # Fail loudly: sqlite3.connect would otherwise create an empty database
    # at this path and the reports would be silently empty.
    if not os.path.exists(db_file):
        sys.exit(f"Database not found: {db_file}")
    analyze_junction_patterns(db_file)