#!/usr/bin/env python3
"""
Discover and analyze relationships between tables in the database.
"""

import json
import re
import sqlite3
import sys
from collections import defaultdict
from pathlib import Path

def find_relationships(db_path="database.sqlite"):
    """Find and analyze table relationships in a SQLite database.

    Collects the declared foreign keys, relationships inferred from column
    naming, a per-table reference graph, a dependency ordering, and common
    join patterns.

    Args:
        db_path: Path to the SQLite database file. ``":memory:"`` is also
            accepted, although a fresh in-memory database is empty.

    Returns:
        On success, a dict with the keys ``foreign_keys``,
        ``inferred_relationships``, ``relationship_graph``,
        ``table_dependencies`` and ``join_patterns``. On failure, a dict with
        a single ``error`` key describing the problem.
    """
    relationships = {
        "foreign_keys": [],
        "inferred_relationships": [],
        "relationship_graph": {},
        "table_dependencies": {}
    }

    # sqlite3.connect() silently creates an empty database file for a missing
    # path, which would leave a stray file behind and report "no tables".
    if db_path != ":memory:" and not Path(db_path).exists():
        return {"error": f"Database error: database file not found: {db_path}"}

    try:
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()

            # Get all user tables (sqlite_* tables are SQLite internals).
            cursor.execute("""
                SELECT name FROM sqlite_master
                WHERE type='table'
                AND name NOT LIKE 'sqlite_%'
                ORDER BY name
            """)
            tables = [row[0] for row in cursor.fetchall()]
            # SQLite identifiers are case-insensitive, so a REFERENCES clause
            # may spell the parent table differently than sqlite_master does.
            canonical_names = {name.lower(): name for name in tables}

            graph = relationships["relationship_graph"]
            for table in tables:
                graph[table] = {
                    "references": [],  # Tables this table references
                    "referenced_by": []  # Tables that reference this table
                }

            # Find explicit foreign keys.
            for table in tables:
                quoted = table.replace("'", "''")  # PRAGMA takes a string literal
                cursor.execute(f"PRAGMA foreign_key_list('{quoted}')")
                # Rows are (id, seq, table, from, to, on_update, on_delete, match);
                # a composite key produces one row per column.
                for _fk_id, seq, parent, from_column, to_column, *_rest in cursor.fetchall():
                    to_table = canonical_names.get(parent.lower(), parent)
                    if to_column is None:
                        # "REFERENCES parent" with no column list targets the
                        # parent's primary key.
                        pk_columns = _primary_key_columns(cursor, to_table)
                        if seq < len(pk_columns):
                            to_column = pk_columns[seq]

                    relationships["foreign_keys"].append({
                        "from": f"{table}.{from_column}",
                        "to": f"{to_table}.{to_column}",
                        "from_table": table,
                        "from_column": from_column,
                        "to_table": to_table,
                        "to_column": to_column,
                        "type": "foreign_key"
                    })

                    graph[table]["references"].append({
                        "table": to_table,
                        "via": from_column,
                        "to": to_column
                    })
                    # A dangling reference (parent table absent) has no graph
                    # node to attach to; the FK itself is still reported.
                    if to_table in graph:
                        graph[to_table]["referenced_by"].append({
                            "table": table,
                            "via": to_column,
                            "from": from_column
                        })

            relationships["inferred_relationships"] = infer_relationships(cursor, tables)
            relationships["table_dependencies"] = build_dependency_order(graph)
            relationships["join_patterns"] = analyze_join_patterns(
                cursor, relationships["foreign_keys"], relationships["inferred_relationships"]
            )
        finally:
            # Close on every path, including when analysis raises.
            conn.close()

    except sqlite3.Error as e:
        return {"error": f"Database error: {str(e)}"}
    except Exception as e:  # Top-level boundary: report rather than crash the CLI.
        return {"error": f"Unexpected error: {str(e)}"}

    return relationships


def _primary_key_columns(cursor, table):
    """Return *table*'s primary-key column names in key order (empty if none)."""
    quoted = table.replace("'", "''")
    cursor.execute(f"PRAGMA table_info('{quoted}')")
    # table_info rows are (cid, name, type, notnull, dflt_value, pk); pk is the
    # 1-based position within the primary key, or 0 for non-key columns.
    key_columns = sorted((row[5], row[1]) for row in cursor.fetchall() if row[5])
    return [name for _, name in key_columns]

# Words that mark a column as a key/identifier for the same-name heuristic.
_KEY_NAME_WORDS = frozenset({"id", "uuid", "guid", "code"})
# Splits snake_case, camelCase and ACRONYMCase names into words,
# e.g. "CustomerID" -> ["Customer", "ID"], "zip_code" -> ["zip", "code"].
_NAME_WORD_RE = re.compile(r"[A-Z]+(?![a-z])|[A-Z]?[a-z]+|\d+")


def _is_key_like_column(column_name):
    """Return True if any word of *column_name* is an identifier word.

    Matching whole words (rather than substrings) keeps names such as
    ``width``, ``valid`` or ``paid`` from being mistaken for keys.
    """
    return any(
        word.lower() in _KEY_NAME_WORDS
        for word in _NAME_WORD_RE.findall(column_name)
    )


def infer_relationships(cursor, tables):
    """Infer likely relationships between tables from column naming.

    Two heuristics are applied:

    1. A ``<name>_id`` column points at the ``id`` column of a table named
       ``<name>`` (type ``inferred_naming``, confidence ``high``).
    2. Columns with the same name and declared type in two different tables,
       whose name contains an identifier word (id/uuid/guid/code), may join
       (type ``inferred_matching``, confidence ``medium``). Each such pair is
       reported once, directed from the table that comes first in *tables*.

    Table, column and type names are compared case-insensitively, since
    SQLite identifiers are case-insensitive.

    Args:
        cursor: An open ``sqlite3`` cursor on the database.
        tables: Sequence of table names to inspect.

    Returns:
        A list of relationship dicts (``from``, ``to``, ``from_table``,
        ``from_column``, ``to_table``, ``to_column``, ``type``,
        ``confidence``) with no duplicate ``(from, to)`` pairs.
    """
    table_columns = {}
    for table in tables:
        quoted = table.replace("'", "''")  # PRAGMA takes the name as a string literal
        cursor.execute(f"PRAGMA table_info('{quoted}')")
        # (name, declared type); type is normalised to "" when absent.
        table_columns[table] = [(row[1], row[2] or "") for row in cursor.fetchall()]

    tables_by_lower_name = {table.lower(): table for table in tables}
    inferred = []

    for index, table1 in enumerate(tables):
        for col1_name, col1_type in table_columns[table1]:
            # Pattern 1: <table>_id references <table>.id
            if col1_name.lower().endswith("_id"):
                target_table = tables_by_lower_name.get(col1_name[:-3].lower())
                if target_table is not None:
                    # Report the target's id column with its real spelling.
                    id_column = next(
                        (name for name, _ in table_columns[target_table] if name.lower() == "id"),
                        None,
                    )
                    if id_column is not None:
                        inferred.append({
                            "from": f"{table1}.{col1_name}",
                            "to": f"{target_table}.{id_column}",
                            "from_table": table1,
                            "from_column": col1_name,
                            "to_table": target_table,
                            "to_column": id_column,
                            "type": "inferred_naming",
                            "confidence": "high"
                        })

            # Pattern 2: same key-like column name and type in a later table
            # (later only, so each unordered pair is reported once).
            if not _is_key_like_column(col1_name):
                continue
            for table2 in tables[index + 1:]:
                for col2_name, col2_type in table_columns[table2]:
                    if (col1_name.lower() == col2_name.lower()
                            and col1_type.upper() == col2_type.upper()):
                        inferred.append({
                            "from": f"{table1}.{col1_name}",
                            "to": f"{table2}.{col2_name}",
                            "from_table": table1,
                            "from_column": col1_name,
                            "to_table": table2,
                            "to_column": col2_name,
                            "type": "inferred_matching",
                            "confidence": "medium"
                        })

    # Remove duplicates, keeping the first occurrence.
    seen = set()
    unique_inferred = []
    for rel in inferred:
        key = (rel["from"], rel["to"])
        if key not in seen:
            seen.add(key)
            unique_inferred.append(rel)

    return unique_inferred

def build_dependency_order(graph):
    """Rank tables by how deeply they depend on other tables.

    Args:
        graph: Mapping of table name -> {"references": [...],
            "referenced_by": [...]}.

    Returns:
        A dict with:
        - "dependency_order": table names, shallowest depth first; ties are
          broken by the number of direct dependents (most first), then by the
          graph's own ordering.
        - "table_metrics": per-table "depth", "direct_dependencies" and
          "direct_dependents".
    """
    metrics = {
        name: {
            "depth": calculate_depth(name, graph),
            "direct_dependencies": len(edges["references"]),
            "direct_dependents": len(edges["referenced_by"]),
        }
        for name, edges in graph.items()
    }

    def _rank(name):
        entry = metrics[name]
        return entry["depth"], -entry["direct_dependents"]

    # sorted() is stable, so equal ranks keep the graph's insertion order.
    ordered_names = sorted(metrics, key=_rank)

    return {
        "dependency_order": ordered_names,
        "table_metrics": metrics,
    }

def calculate_depth(table, graph, visited=None):
    """Return the length of the longest reference chain starting at *table*.

    A table that references no other known table has depth 0. Otherwise its
    depth is one more than the deepest table it references. Two kinds of
    reference are ignored:

    - references to tables missing from *graph*;
    - references leading back to a table already on the current chain
      (self-references and cycles).

    A table whose only reference is to itself is therefore still a base
    table.

    Note: every reference path is walked without memoization, so run time can
    grow quickly on large, densely connected schemas.

    Args:
        table: Name of the table to measure; must be a key of *graph*.
        graph: Mapping of table name -> {"references": [{"table": ...}, ...], ...}.
        visited: Tables already on the current chain (used internally by the
            recursion; callers normally omit it).

    Returns:
        The dependency depth as a non-negative int.
    """
    if visited is None:
        visited = set()

    if table in visited:
        return 0  # Already on the current chain: a cycle contributes nothing.

    visited.add(table)

    max_depth = 0
    for ref in graph[table]["references"]:
        ref_table = ref["table"]
        # Skip dangling references and edges back into the current chain;
        # counting them would add a phantom level for every self-reference
        # or cycle.
        if ref_table not in graph or ref_table in visited:
            continue
        depth = calculate_depth(ref_table, graph, visited.copy())
        max_depth = max(max_depth, depth + 1)

    return max_depth

def analyze_join_patterns(cursor, foreign_keys, inferred_rels, max_join_paths=10):
    """Summarize common join patterns from known relationships.

    Args:
        cursor: Unused; kept for interface compatibility.
        foreign_keys: Declared foreign-key dicts with ``from_table``,
            ``from_column``, ``to_table``, ``to_column`` and ``type``.
        inferred_rels: Inferred relationship dicts with the same keys.
        max_join_paths: Maximum number of distinct entries in ``join_paths``
            (``None`` for no limit). Defaults to 10.

    Returns:
        A dict with these keys:

        - ``one_to_many``: one entry per foreign key; the parent is the
          referenced table.
        - ``many_to_many``: tables with foreign keys to at least two distinct
          tables, reported as possible junction tables.
        - ``self_joins``: foreign keys from a table to itself.
        - ``join_paths``: maps ``"<from>_to_<to>"`` to the direct join
          condition and relationship type. Declared foreign keys are
          considered before inferred relationships, and the first
          relationship for a given key wins.
    """
    patterns = {
        "one_to_many": [],
        "many_to_many": [],
        "self_joins": [],
        "join_paths": {}
    }

    fks_by_table = defaultdict(list)
    for fk in foreign_keys:
        # Every foreign key is a one-to-many edge from parent to child.
        patterns["one_to_many"].append({
            "parent": fk["to_table"],
            "child": fk["from_table"],
            "via": f"{fk['from_column']} = {fk['to_column']}"
        })
        if fk["from_table"] == fk["to_table"]:
            patterns["self_joins"].append({
                "table": fk["from_table"],
                "join_condition": f"{fk['from_column']} = {fk['to_column']}"
            })
        fks_by_table[fk["from_table"]].append(fk)

    # Tables with foreign keys to several distinct tables might be junction
    # tables for a many-to-many relationship.
    for table, fks in fks_by_table.items():
        # dict.fromkeys de-duplicates while keeping declaration order; a set
        # would make the output order vary between runs (str hash randomization).
        referenced_tables = list(dict.fromkeys(fk["to_table"] for fk in fks))
        if len(referenced_tables) >= 2:
            patterns["many_to_many"].append({
                "junction_table": table,
                "connects": referenced_tables,
                "via_columns": [fk["from_column"] for fk in fks]
            })

    # Direct join paths between related tables.
    join_paths = patterns["join_paths"]
    for rel in [*foreign_keys, *inferred_rels]:
        if max_join_paths is not None and len(join_paths) >= max_join_paths:
            break  # Limit output size on large schemas.
        path_key = f"{rel['from_table']}_to_{rel['to_table']}"
        if path_key in join_paths:
            continue  # Keep the first entry so declared FKs beat inferred ones.
        join_paths[path_key] = {
            "direct": f"{rel['from_table']}.{rel['from_column']} = {rel['to_table']}.{rel['to_column']}",
            "type": rel["type"]
        }

    return patterns

def main(argv=None):
    """Run the relationship analysis and emit the result as JSON.

    The JSON is printed to stdout and also written to
    ``tool_output/relationships.json``.

    Args:
        argv: Command-line arguments excluding the program name; defaults to
            ``sys.argv[1:]``. An optional first argument overrides the
            database path; otherwise ``find_relationships``' default is used.

    Returns:
        Process exit status: 0 on success, 1 if the analysis reported an error.
    """
    args = sys.argv[1:] if argv is None else argv
    relationships = find_relationships(args[0]) if args else find_relationships()

    output = json.dumps(relationships, indent=2)
    print(output)

    # Also save to file (written even on error so the failure is recorded).
    output_dir = Path("tool_output")
    output_dir.mkdir(exist_ok=True)
    (output_dir / "relationships.json").write_text(output, encoding="utf-8")

    return 1 if "error" in relationships else 0


if __name__ == "__main__":
    sys.exit(main())