#!/usr/bin/env python3
"""
Detect foreign keys and potential join columns in the database.
This tool is available to subagents for database analysis.
"""

import sqlite3
import json
import sys
from pathlib import Path
from collections import defaultdict

def _quote_identifier(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier, doubling embedded quotes."""
    return '"' + name.replace('"', '""') + '"'


def find_relationships(db_path: str) -> dict:
    """Find foreign keys and potential relationships between tables.

    The database is opened read-only, so it is never modified.

    Args:
        db_path: Filesystem path of the SQLite database file.

    Returns:
        A dict with three keys:

        * ``explicit_foreign_keys``: one dict per declared foreign key with
          ``from_table``, ``from_column``, ``to_table`` and ``to_column``.
        * ``potential_joins``: heuristic join candidates. Entries for a column
          name shared by two tables carry ``table1``, ``table2``, ``column``
          and ``confidence``; entries for a ``<table>_id`` / ``<table>id``
          column carry ``table1``, ``table2``, ``column1``, ``column2``,
          ``confidence`` and ``pattern``.
        * ``common_column_names``: a ``defaultdict(list)`` mapping each column
          name to the tables that contain it.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    # Path.as_uri() percent-encodes characters such as '#', '?', '%' and
    # spaces, which would otherwise be misread as URI syntax by SQLite.
    db_uri = f'{Path(db_path).resolve().as_uri()}?mode=ro'
    conn = sqlite3.connect(db_uri, uri=True)
    try:
        cursor = conn.cursor()

        relationships = {
            'explicit_foreign_keys': [],
            'potential_joins': [],
            'common_column_names': defaultdict(list)
        }

        # Get all user tables. '_' is a LIKE wildcard, so it is escaped to
        # exclude only the reserved "sqlite_" prefix (not e.g. "sqlitestats").
        cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'"
        )
        tables = [row[0] for row in cursor.fetchall()]

        # Column names per table, in declaration order
        table_columns = {}

        for table in tables:
            quoted_table = _quote_identifier(table)

            # Explicit foreign keys; row layout is
            # (id, seq, table, from, to, on_update, on_delete, match)
            cursor.execute(f'PRAGMA foreign_key_list({quoted_table})')
            for fk in cursor.fetchall():
                relationships['explicit_foreign_keys'].append({
                    'from_table': table,
                    'from_column': fk[3],
                    'to_table': fk[2],
                    'to_column': fk[4]
                })

            # All columns; row layout is (cid, name, type, notnull, dflt_value, pk)
            cursor.execute(f'PRAGMA table_info({quoted_table})')
            columns = [col[1] for col in cursor.fetchall()]
            for col_name in columns:
                relationships['common_column_names'][col_name].append(table)
            table_columns[table] = columns

        # Potential joins: the same column name appearing in several tables
        for col_name, tables_with_col in relationships['common_column_names'].items():
            if len(tables_with_col) < 2:
                continue
            confidence = 'high' if 'id' in col_name.lower() else 'medium'
            for i, table1 in enumerate(tables_with_col):
                for table2 in tables_with_col[i + 1:]:
                    relationships['potential_joins'].append({
                        'table1': table1,
                        'table2': table2,
                        'column': col_name,
                        'confidence': confidence
                    })

        # Potential joins: "<table>_id" / "<table>id" columns in other tables.
        # Declared foreign keys are skipped so they are not reported twice.
        declared = {
            (fk['from_table'], fk['from_column'], fk['to_table'])
            for fk in relationships['explicit_foreign_keys']
        }
        for table in tables:
            fk_name_candidates = {f"{table}_id".lower(), f"{table}id".lower()}

            for other_table in tables:
                if other_table == table:
                    continue
                # Reuse the columns collected above instead of re-querying
                for col_name in table_columns[other_table]:
                    if col_name.lower() not in fk_name_candidates:
                        continue
                    if (other_table, col_name, table) in declared:
                        continue
                    relationships['potential_joins'].append({
                        'table1': other_table,
                        'table2': table,
                        'column1': col_name,
                        'column2': 'id',  # Assuming primary key
                        'confidence': 'high',
                        'pattern': 'table_id'
                    })
    finally:
        # Release the connection even when a query fails
        conn.close()

    return relationships

def main():
    """Command-line entry point: print the relationship report as JSON.

    Expects exactly one argument, the path to a SQLite database. Prints a
    usage or error message and exits with status 1 when the argument count
    is wrong or the path does not exist.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        print("Usage: python find_relationships.py <database.sqlite>")
        sys.exit(1)

    (db_path,) = cli_args
    database_file = Path(db_path)
    if not database_file.exists():
        print(f"Error: Database not found: {db_path}")
        sys.exit(1)

    found = find_relationships(db_path)

    # The name index is a defaultdict; emit it as a plain dict in the report
    report = {
        **found,
        'common_column_names': dict(found['common_column_names']),
    }
    print(json.dumps(report, indent=2))

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()