#!/usr/bin/env python3
"""
Relationship analyzer: discovers and documents the foreign-key relationships,
likely junction tables, and join paths of a SQLite database, printing them as JSON.
Enhanced with patterns from all three agents for comprehensive relationship mapping.
"""

import sqlite3
import json
from pathlib import Path
from collections import defaultdict

def _quote_ident(name):
    """Return *name* quoted as an SQLite identifier.

    Identifiers cannot be bound as query parameters, so they are interpolated
    into the SQL text wrapped in double quotes, with embedded quotes doubled.
    This keeps reserved-word names (``order``) and names containing spaces or
    quotes from breaking the statement.
    """
    return '"' + name.replace('"', '""') + '"'


def _table_columns(cursor, table):
    """Return ``(column_name, pk_position)`` pairs for *table*, in column order.

    ``pk_position`` comes from ``PRAGMA table_info`` (row index 5): the 1-based
    position of the column in the primary key, or 0 if it is not a key column.
    """
    cursor.execute(f"PRAGMA table_info({_quote_ident(table)})")
    return [(row[1], row[5]) for row in cursor.fetchall()]


def _match_column(columns, column):
    """Return the declared spelling of *column* among *columns*, or None.

    SQLite matches identifiers case-insensitively, so a foreign key written as
    ``REFERENCES t(ID)`` may target a column declared as ``id``.
    """
    if column is None:
        return None
    wanted = column.lower()
    return next((name for name, _ in columns if name.lower() == wanted), None)


def _collect_foreign_keys(cursor, tables, columns_by_table):
    """Read every declared foreign key of *tables*.

    Returns ``(direct_relationships, fk_map, reverse_fk_map)``:
      * direct_relationships: one dict per (child column -> parent column) pair;
      * fk_map[child]: ``(parent, child_col, parent_col)`` tuples;
      * reverse_fk_map[parent]: ``(child, parent_col, child_col)`` tuples.
    """
    canonical = {table.lower(): table for table in tables}
    direct = []
    fk_map = defaultdict(list)
    reverse_fk_map = defaultdict(list)

    for table in tables:
        cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table)})")
        # Row layout: id, seq, table, from, to, on_update, on_delete, match.
        for fk in cursor.fetchall():
            seq, from_col, to_col = fk[1], fk[3], fk[4]
            # The pragma reports the parent name as written in the REFERENCES
            # clause; map it to the real table name so graph keys line up.
            to_table = canonical.get(fk[2].lower(), fk[2])
            if to_col is None:
                # "REFERENCES parent" without a column list targets the parent's
                # primary key, and the pragma reports `to` as NULL in that case.
                pk_cols = [
                    name
                    for _, name in sorted(
                        (pos, name)
                        for name, pos in columns_by_table.get(to_table, [])
                        if pos
                    )
                ]
                if seq < len(pk_cols):
                    to_col = pk_cols[seq]

            direct.append({
                "from_table": table,
                "from_column": from_col,
                "to_table": to_table,
                "to_column": to_col,
            })
            fk_map[table].append((to_table, from_col, to_col))
            reverse_fk_map[to_table].append((table, to_col, from_col))

    return direct, fk_map, reverse_fk_map


def _infer_referenced_table(col_name, table_set):
    """Guess which table an id-like column points at, from its name alone.

    Only the ``_id`` / ``ID`` suffix is removed (``user_identity_id`` ->
    ``user_identity``); then the bare name, its plural (``+ "s"``) and its
    singular (one trailing ``s`` removed) are tried in that order.
    Returns None when no candidate is an existing table.
    """
    if col_name.endswith("_id"):
        base = col_name[:-3]
    elif col_name.endswith("ID"):
        base = col_name[:-2]
    else:
        return None
    if not base:
        return None
    candidates = [base, base + "s"]
    if base.endswith("s"):
        candidates.append(base[:-1])
    return next((c for c in candidates if c in table_set), None)


def _find_junction_tables(cursor, tables, columns_by_table, fk_map):
    """Return ``{"table", "links", "row_count"}`` dicts for likely link tables.

    Heuristic: at least two id-like columns (``*_id`` / ``*ID``), at most two
    other columns, and at least two of the id-like columns resolve to
    existing tables.
    """
    table_set = set(tables)
    junctions = []
    for table in tables:
        col_names = [name for name, _ in columns_by_table[table]]
        id_cols = [name for name in col_names if name.endswith(("_id", "ID"))]
        if len(id_cols) < 2 or len(col_names) - len(id_cols) > 2:
            continue

        # A declared foreign key is authoritative; name inference is the
        # fallback for schemas that do not declare their constraints.
        declared = {from_col: parent for parent, from_col, _ in fk_map.get(table, ())}
        links = []
        for col in id_cols:
            target = declared.get(col)
            if target not in table_set:
                target = _infer_referenced_table(col, table_set)
            if target is not None:
                links.append(target)
        if len(links) < 2:
            continue

        cursor.execute(f"SELECT COUNT(*) FROM {_quote_ident(table)}")
        junctions.append({"table": table, "links": links, "row_count": cursor.fetchone()[0]})
    return junctions


def _is_unique(cursor, table, column):
    """Return True when the non-NULL values of ``table.column`` are all distinct.

    ``COUNT(col)`` (not ``COUNT(*)``) is the baseline because
    ``COUNT(DISTINCT col)`` ignores NULLs; a NULL foreign key means "no related
    row" and must not count as a duplicate. An empty column counts as unique.
    """
    quoted = _quote_ident(column)
    cursor.execute(
        f"SELECT COUNT(DISTINCT {quoted}), COUNT({quoted}) FROM {_quote_ident(table)}"
    )
    distinct, non_null = cursor.fetchone()
    return distinct == non_null


def _classify_cardinality(cursor, rel, columns_by_table):
    """Label *rel* as "1:1", "many:1", "1:many", "many:many" or "unknown".

    Labels read from the referencing side: "many:1" means several
    ``from_table`` rows share one ``to_column`` value. "unknown" is returned
    when an endpoint column cannot be found (SQLite accepts foreign keys to
    missing tables/columns) instead of failing the whole analysis.
    """
    from_col = _match_column(columns_by_table.get(rel["from_table"], []), rel["from_column"])
    to_col = _match_column(columns_by_table.get(rel["to_table"], []), rel["to_column"])
    # Only verified column names reach the SQL: SQLite treats a double-quoted
    # name that matches no column as a string literal, which would silently
    # produce a wrong count instead of an error.
    if from_col is None or to_col is None:
        return "unknown"
    from_unique = _is_unique(cursor, rel["from_table"], from_col)
    to_unique = _is_unique(cursor, rel["to_table"], to_col)
    if from_unique and to_unique:
        return "1:1"
    if to_unique:
        return "many:1"
    if from_unique:
        return "1:many"
    return "many:many"


def _neighbors(table, fk_map, reverse_fk_map):
    """Yield ``(adjacent_table, step_text)`` for every FK edge touching *table*.

    Edges are usable in both directions: "→" follows a foreign key from child
    to parent, "←" follows it backwards from parent to child.
    """
    for next_table, from_col, to_col in fk_map.get(table, ()):
        yield next_table, f"{table}.{from_col} → {next_table}.{to_col}"
    for prev_table, to_col, from_col in reverse_fk_map.get(table, ()):
        yield prev_table, f"{table}.{to_col} ← {prev_table}.{from_col}"


def _shortest_path_tree(start, fk_map, reverse_fk_map):
    """Breadth-first search of the FK graph from *start*.

    Returns ``preds``: for every reachable table other than *start*, the list
    of ``(previous_table, step_text)`` edges that lie on some shortest path to
    it. This replaces enumerating every simple path, whose count grows
    exponentially with how interconnected the schema is.
    """
    depth = {start: 0}
    preds = defaultdict(list)
    frontier = [start]
    while frontier:
        next_frontier = []
        for table in frontier:
            for neighbor, step in _neighbors(table, fk_map, reverse_fk_map):
                if neighbor not in depth:
                    depth[neighbor] = depth[table] + 1
                    next_frontier.append(neighbor)
                # Keep every edge arriving from the previous BFS level so that
                # all equally-short paths (e.g. parallel FKs) are available.
                if depth[neighbor] == depth[table] + 1:
                    preds[neighbor].append((table, step))
        frontier = next_frontier
    return preds


def _enumerate_paths(start, end, preds, limit):
    """Return up to *limit* shortest step lists from *start* to *end*."""
    paths = []

    def walk(table, suffix):
        if len(paths) >= limit:
            return
        if table == start:
            paths.append(suffix)
            return
        for prev_table, step in preds.get(table, ()):
            walk(prev_table, [step] + suffix)

    if end in preds:
        walk(end, [])
    return paths


def analyze_relationships(conn, max_paths=3):
    """Analyze foreign keys, junction tables and join paths of a SQLite database.

    Args:
        conn: An open ``sqlite3.Connection``; only read queries are issued.
        max_paths: Maximum number of shortest join paths kept per table pair.

    Returns:
        ``{"join_paths": relationships}`` where ``relationships`` holds:
          * ``direct_relationships``: one dict per foreign-key column pair with
            ``from_table``, ``from_column``, ``to_table``, ``to_column`` and
            ``cardinality``;
          * ``junction_tables``: ``{"table", "links", "row_count"}`` dicts for
            tables that look like many-to-many link tables;
          * ``join_paths``: ``"A to B"`` -> list of shortest step lists, for
            every pair of non-junction tables connected through foreign keys;
          * ``relationship_types``: ``"A → B"`` -> cardinality (a later foreign
            key between the same two tables overwrites an earlier one).
        The outer ``"join_paths"`` wrapper key is preserved as-is; output
        consumers may depend on it.
    """
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    # Names starting with "sqlite_" are reserved for SQLite's internal tables
    # (e.g. sqlite_sequence, created for AUTOINCREMENT); they are not schema.
    tables = [
        name for (name,) in cursor.fetchall()
        if not name.lower().startswith("sqlite_")
    ]
    # One PRAGMA table_info per table, shared by all the steps below.
    columns_by_table = {table: _table_columns(cursor, table) for table in tables}

    direct, fk_map, reverse_fk_map = _collect_foreign_keys(cursor, tables, columns_by_table)
    relationships = {
        "direct_relationships": direct,
        "junction_tables": _find_junction_tables(cursor, tables, columns_by_table, fk_map),
        "join_paths": {},
        "relationship_types": {},
    }

    for rel in direct:
        cardinality = _classify_cardinality(cursor, rel, columns_by_table)
        rel["cardinality"] = cardinality
        relationships["relationship_types"][f"{rel['from_table']} → {rel['to_table']}"] = cardinality

    junction_names = {jt["table"] for jt in relationships["junction_tables"]}
    major_tables = [table for table in tables if table not in junction_names]
    for i, table1 in enumerate(major_tables):
        # One BFS per source table serves every destination table.
        preds = _shortest_path_tree(table1, fk_map, reverse_fk_map)
        for table2 in major_tables[i + 1:]:
            paths = _enumerate_paths(table1, table2, preds, max_paths)
            if paths:
                relationships["join_paths"][f"{table1} to {table2}"] = paths

    return {"join_paths": relationships}

def main():
    """Analyze ./database.sqlite and print the result, or an error, as JSON.

    Errors are reported as ``{"error": "..."}`` on stdout rather than as a
    traceback, so the output is always a single JSON document.
    """
    db_path = Path("./database.sqlite")

    # is_file() rather than exists(): a directory with this name is not a
    # database and would only fail later with a less helpful message.
    if not db_path.is_file():
        print(json.dumps({"error": "Database file not found"}))
        return

    try:
        conn = sqlite3.connect(str(db_path))
        try:
            results = analyze_relationships(conn)
        finally:
            # Close even when the analysis raises, so the handle is not leaked.
            conn.close()

        # Output as JSON
        print(json.dumps(results, indent=2))

    except Exception as e:
        # Top-level boundary: report any failure in the same JSON shape as the
        # "not found" case above.
        print(json.dumps({"error": str(e)}))


if __name__ == "__main__":
    main()