#!/usr/bin/env python3
"""
Relationship Detector - a tool for analysing complex table relationships.
Helps with databases such as the hockey dataset, whose many-to-many relationships are hard to join correctly.
"""

import sqlite3
import os

# SQL keywords that must be quoted when shown as identifiers in the generated
# SQL snippets. Not exhaustive: covers those most likely to be used as
# table or column names.
_SQL_KEYWORDS = frozenset({
    "ALL", "AND", "AS", "BY", "CASE", "CHECK", "COLUMN", "CREATE", "DEFAULT",
    "DELETE", "DISTINCT", "DROP", "ELSE", "FROM", "GROUP", "HAVING", "IN",
    "INDEX", "INSERT", "INTO", "IS", "JOIN", "KEY", "LIMIT", "NOT", "NULL",
    "ON", "OR", "ORDER", "PRIMARY", "REFERENCES", "SELECT", "SET", "TABLE",
    "THEN", "TO", "UNION", "UNIQUE", "UPDATE", "VALUES", "WHEN", "WHERE",
})

_FALLBACK_REPORT = (
    "# RELATIONSHIP DETECTION\n"
    "Could not detect relationships.\n"
    "Use column name patterns to infer joins.\n"
)


def _quote_ident(name):
    """Return *name* as a double-quoted SQLite identifier, escaping embedded quotes."""
    return '"' + name.replace('"', '""') + '"'


def _display_ident(name):
    """Return *name* for report SQL: bare when that is safe, double-quoted otherwise."""
    if name.isidentifier() and name.upper() not in _SQL_KEYWORDS:
        return name
    return _quote_ident(name)


def _is_id_like(col_name):
    """Heuristic used for junction detection: name ends in 'id' or contains 'fk'."""
    lowered = col_name.lower()
    # endswith("id") already covers "_id". It also matches words such as
    # "paid" or "valid"; that false positive is accepted by the heuristic.
    return lowered.endswith("id") or "fk" in lowered


def _primary_key_columns(columns):
    """Return primary-key column names from PRAGMA table_info rows, in key order."""
    keyed = [col for col in columns if col[5]]
    return [col[1] for col in sorted(keyed, key=lambda col: col[5])]


def _junction_section(cursor, columns_by_table):
    """Detect junction (many-to-many) tables; return (report_lines, [(table, id_columns)])."""
    lines = []
    junction_tables = []
    for table, columns in columns_by_table.items():
        id_columns = [col[1] for col in columns if _is_id_like(col[1])]
        # Junction tables typically have 2+ key columns and few other columns.
        if len(id_columns) < 2 or len(columns) > len(id_columns) + 2:
            continue
        junction_tables.append((table, id_columns))
        lines.append(f"**{table}**:")
        lines.append(f"  Connects: {' ↔ '.join(id_columns)}")
        # Sample the cardinality of the first two key columns.
        for id_col in id_columns[:2]:
            try:
                cursor.execute(
                    f"SELECT COUNT(DISTINCT {_quote_ident(id_col)}), COUNT(*) "
                    f"FROM {_quote_ident(table)}"
                )
            except sqlite3.Error as exc:
                lines.append(f"  {id_col}: cardinality unavailable ({exc})")
                continue
            distinct, total = cursor.fetchone()
            if distinct and total:
                lines.append(
                    f"  {id_col}: {distinct} unique values, "
                    f"avg {total / distinct:.1f} per value"
                )
        lines.append("")
    return lines, junction_tables


def _find_implicit_relationships(tables, columns_by_table):
    """Guess (table, column, target_table) links from '<name>_id' / '<name>id' columns."""
    rels = []
    for table, columns in columns_by_table.items():
        for col in columns:
            col_name = col[1]
            lowered = col_name.lower()
            if not lowered.endswith("id"):
                continue
            stem = (col_name[:-3] if lowered.endswith("_id") else col_name[:-2]).lower()
            if not stem:
                # A bare "id" / "_id" column does not name another table.
                continue
            candidates = {stem, stem + "s", stem + "es"}
            squashed = stem.replace("_", "")
            for other in tables:
                if other == table:
                    # A table's own key (e.g. players.player_id) is not a relationship.
                    continue
                other_lower = other.lower()
                if other_lower in candidates or other_lower.replace("_", "") == squashed:
                    rels.append((table, col_name, other))
                    break
    return rels


def _collect_foreign_keys(cursor, columns_by_table, problems):
    """Return declared foreign keys as (from_table, from_cols, to_table, to_cols).

    PRAGMA rows sharing one constraint id (composite keys) are merged, ordered by
    their seq value. A NULL parent column ("REFERENCES parent" with no column
    list) is resolved to the parent's declared primary key. Tables whose FK list
    cannot be read are appended to *problems*.
    """
    columns_ci = {name.lower(): cols for name, cols in columns_by_table.items()}
    fks = []
    for table in columns_by_table:
        try:
            cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table)})")
            rows = cursor.fetchall()
        except sqlite3.Error as exc:
            problems.append(f"{table}: {exc}")
            continue
        constraints = {}
        for fk_id, seq, to_table, from_col, to_col, *_ in rows:
            constraints.setdefault(fk_id, (to_table, []))[1].append((seq, from_col, to_col))
        for fk_id in sorted(constraints):
            to_table, parts = constraints[fk_id]
            parts.sort(key=lambda part: part[0])
            from_cols = [part[1] for part in parts]
            to_cols = [part[2] for part in parts]
            if None in to_cols:
                pk = _primary_key_columns(columns_ci.get(to_table.lower(), []))
                if len(pk) == len(from_cols):
                    to_cols = pk
                else:
                    # No matching declared primary key; rowid is a best guess.
                    to_cols = [c if c is not None else "rowid" for c in to_cols]
            fks.append((table, from_cols, to_table, to_cols))
    return fks


def _write_report(output_path, text):
    """Write *text* to *output_path* as UTF-8, creating the parent directory if needed."""
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(text)


def detect_relationships(db_path="database.sqlite",
                         output_path="tool_output/relationship_detector_output.txt"):
    """Detect and analyze database relationships, including implicit ones.

    Writes a Markdown-style report to *output_path* with these sections:
    junction tables, name-based implicit relationships, recommended JOIN paths
    and cardinality warnings. The JOIN paths come from declared foreign keys and
    detected junction tables. A "SKIPPED TABLES" section is added only when some
    table's schema could not be read.

    The tool is best-effort: on any failure the error is printed and a short
    fallback report is written instead of raising.

    Note: ``sqlite3.connect`` creates an empty database file if *db_path*
    does not exist.

    Args:
        db_path: SQLite database file to analyse.
        output_path: Where to write the report; its directory is created if needed.
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            output = [
                "# RELATIONSHIP DETECTION",
                "(Advanced relationship analysis for complex joins)",
                "",
            ]

            cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
            # Names starting with "sqlite_" are SQLite-internal (sqlite_sequence, ...).
            tables = [name for (name,) in cursor.fetchall()
                      if not name.lower().startswith("sqlite_")]

            # Read every table's columns once; record (don't hide) unreadable tables.
            problems = []
            columns_by_table = {}
            for table in tables:
                try:
                    cursor.execute(f"PRAGMA table_info({_quote_ident(table)})")
                    columns_by_table[table] = cursor.fetchall()
                except sqlite3.Error as exc:
                    problems.append(f"{table}: {exc}")

            # Junction tables (many-to-many)
            output += ["## JUNCTION TABLES",
                       "(Tables that connect many-to-many relationships)", ""]
            junction_lines, junction_tables = _junction_section(cursor, columns_by_table)
            output += junction_lines
            if not junction_tables:
                output.append("No obvious junction tables detected.")
            output.append("")

            # Implicit relationships, grouped by source table
            output += ["## IMPLICIT RELATIONSHIPS", "(Based on column name patterns)", ""]
            implicit_rels = _find_implicit_relationships(tables, columns_by_table)
            if implicit_rels:
                by_table = {}
                for src, col, tgt in implicit_rels:
                    by_table.setdefault(src, []).append((col, tgt))
                for table, rels in by_table.items():
                    output.append(f"**{table}**:")
                    output.extend(f"  {col} → {target} (likely)" for col, target in rels)
                    output.append("")
            else:
                output.append("No implicit relationships detected.")
            output.append("")

            # JOIN paths
            output += ["## RECOMMENDED JOIN PATHS", ""]
            explicit_fks = _collect_foreign_keys(cursor, columns_by_table, problems)
            if explicit_fks:
                output.append("**Based on Foreign Keys:**")
                for from_table, from_cols, to_table, to_cols in explicit_fks[:10]:
                    on_clause = " AND ".join(
                        f"t1.{_display_ident(f)} = t2.{_display_ident(t)}"
                        for f, t in zip(from_cols, to_cols)
                    )
                    output.append("```sql")
                    output.append(f"FROM {_display_ident(from_table)} t1")
                    output.append(f"JOIN {_display_ident(to_table)} t2 ON {on_clause}")
                    output.append("```")
                    output.append("")

            if junction_tables:
                output.append("**Many-to-Many via Junction Tables:**")
                # Junction tables always have >= 2 id columns (see _junction_section).
                for junction, id_cols in junction_tables[:5]:
                    output.append("```sql")
                    output.append(f"-- Connect tables via {junction}")
                    output.append("FROM table1 t1")
                    output.append(f"JOIN {_display_ident(junction)} j "
                                  f"ON t1.id = j.{_display_ident(id_cols[0])}")
                    output.append(f"JOIN table2 t2 ON j.{_display_ident(id_cols[1])} = t2.id")
                    output.append("```")
                    output.append("")

            # Cardinality warnings
            output += [
                "## CARDINALITY WARNINGS",
                "",
                "- Junction tables may produce multiple rows per entity",
                "- Use DISTINCT or GROUP BY to avoid duplicates",
                "- Check if COUNT(*) or COUNT(DISTINCT column) is needed",
                "",
            ]

            if problems:
                output += ["## SKIPPED TABLES",
                           "(Schema could not be read; results above may be incomplete)", ""]
                output += [f"- {problem}" for problem in problems]
                output.append("")
        finally:
            # Always release the connection, even if analysis fails part-way.
            conn.close()

        _write_report(output_path, "\n".join(output))
        print(f"Relationship detection complete - {len(junction_tables)} junction tables, {len(implicit_rels)} implicit relationships")

    except Exception as e:
        # Non-critical tool: report the failure and leave a minimal hint file.
        print(f"Relationship detection failed: {e}")
        _write_report(output_path, _FALLBACK_REPORT)


if __name__ == "__main__":
    detect_relationships()