#!/usr/bin/env python3
"""
Join Pattern Detector - Identifies valid join paths between tables
"""
import sqlite3
import json
import os
from collections import defaultdict, deque

def _quote_ident(name):
    """Return *name* as a double-quoted SQLite identifier.

    Embedded double quotes are doubled, so names containing quotes, spaces or
    keywords can be interpolated into SQL / PRAGMA text without breaking it.
    """
    return '"' + str(name).replace('"', '""') + '"'


def _primary_key_columns(cursor, table):
    """Return the declared primary-key column names of *table*, in key order.

    Returns an empty list when the table declares no primary key or does not
    exist (PRAGMA table_info then yields no rows).
    """
    cursor.execute(f"PRAGMA table_info({_quote_ident(table)})")
    # Rows are (cid, name, type, notnull, dflt_value, pk); pk is the 1-based
    # position of the column within the primary key, or 0 if not part of it.
    keyed = sorted((row[5], row[1]) for row in cursor.fetchall() if row[5])
    return [name for _, name in keyed]


def _load_foreign_keys(cursor, tables):
    """Read the foreign keys of every table in *tables* (one PRAGMA per table).

    Returns a dict mapping each table to a list of dicts with keys ``id``
    (constraint id; all rows of a multi-column key share it), ``to_table``,
    ``from_column`` and ``to_column``, in PRAGMA row order.

    A reference written without a parent column list (``REFERENCES parent``)
    targets the parent's primary key, so ``to_column`` is resolved from it.
    It stays ``None`` if the parent declares no primary key.
    """
    pk_cache = {}
    fk_rows = {}
    for table in tables:
        cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table)})")
        # Rows are (id, seq, table, from, to, on_update, on_delete, match).
        rows = cursor.fetchall()
        resolved = []
        for fk_id, seq, to_table, from_column, to_column, *_ in rows:
            if to_column is None:
                if to_table not in pk_cache:
                    pk_cache[to_table] = _primary_key_columns(cursor, to_table)
                parent_pk = pk_cache[to_table]
                # seq is this column's position within the (composite) key.
                to_column = parent_pk[seq] if seq < len(parent_pk) else None
            resolved.append({
                "id": fk_id,
                "to_table": to_table,
                "from_column": from_column,
                "to_column": to_column,
            })
        fk_rows[table] = resolved
    return fk_rows


def _find_junction_tables(tables, fk_rows):
    """Return tables that declare two or more distinct foreign-key constraints.

    Constraints are counted by id, so one multi-column foreign key counts as a
    single constraint rather than one per column.
    """
    junction_tables = []
    for table in tables:
        fks = fk_rows[table]
        targets = {}  # constraint id -> referenced table, in first-seen order
        for fk in fks:
            targets.setdefault(fk["id"], fk["to_table"])
        if len(targets) >= 2:
            junction_tables.append({
                "table": table,
                "connects": list(targets.values()),
                "foreign_keys": [
                    {"column": fk["from_column"], "references": fk["to_table"]}
                    for fk in fks
                ],
            })
    return junction_tables


def _find_paths(relationships, reverse_relationships, start_table, end_table,
                max_depth=3):
    """Breadth-first search for every simple join path from start to end.

    Foreign-key edges are followed in both directions. A path holds at most
    ``max_depth`` tables, i.e. ``max_depth - 1`` joins. The default of 3
    allows two-hop paths such as A -> junction -> B.

    Each result is a dict with ``tables`` (ordered table names) and ``joins``
    (one ``{"from": "t.col", "to": "t.col"}`` per hop). A multi-column foreign
    key contributes one edge per column, so it shows up as several
    alternative paths.
    """
    paths = []
    queue = deque([(start_table, [start_table], [])])
    visited = set()

    while queue:
        current_table, path, joins = queue.popleft()

        if current_table == end_table and len(path) > 1:
            paths.append({"tables": path, "joins": joins})
            continue

        # The last element of path is always current_table.
        state = tuple(path)
        if state in visited:
            continue
        visited.add(state)

        if len(path) >= max_depth:
            continue  # any extension would exceed the table limit

        # Forward edges: current_table holds the FK column.
        for rel in relationships.get(current_table, []):
            if rel["to_table"] not in path:
                queue.append((
                    rel["to_table"],
                    path + [rel["to_table"]],
                    joins + [{
                        "from": f"{current_table}.{rel['from_column']}",
                        "to": f"{rel['to_table']}.{rel['to_column']}",
                    }],
                ))

        # Reverse edges: another table holds an FK pointing at current_table.
        for rel in reverse_relationships.get(current_table, []):
            if rel["from_table"] not in path:
                queue.append((
                    rel["from_table"],
                    path + [rel["from_table"]],
                    joins + [{
                        "from": f"{current_table}.{rel['to_column']}",
                        "to": f"{rel['from_table']}.{rel['from_column']}",
                    }],
                ))

    return paths


def _analyze_cardinality(cursor, tables, relationships):
    """Return per-table row counts and a heuristic type for each outgoing FK.

    An FK is labelled "1:many" when its non-NULL values repeat on average more
    than 1.5 times, "1:1" otherwise, and "unknown" when the column has no
    non-NULL values or the query fails. ``row_count`` is ``None`` if the table
    cannot be counted (e.g. a virtual table whose module is unavailable).
    """
    cardinality = {}
    for table in tables:
        qtable = _quote_ident(table)
        try:
            cursor.execute(f"SELECT COUNT(*) FROM {qtable}")
            row_count = cursor.fetchone()[0]
        except sqlite3.Error:
            row_count = None

        rel_info = []
        for rel in relationships.get(table, []):
            qcol = _quote_ident(rel["from_column"])
            try:
                cursor.execute(
                    f"SELECT COUNT(DISTINCT {qcol}), COUNT(*) "
                    f"FROM {qtable} WHERE {qcol} IS NOT NULL"
                )
                result = cursor.fetchone()
            except sqlite3.Error:
                result = None

            if result and result[1] > 0:
                ratio = result[1] / max(result[0], 1)
                relationship_type = "1:many" if ratio > 1.5 else "1:1"
            else:
                relationship_type = "unknown"

            rel_info.append({
                "to_table": rel["to_table"],
                "type": relationship_type,
                "from_column": rel["from_column"],
                "to_column": rel["to_column"],
            })

        cardinality[table] = {"row_count": row_count, "relationships": rel_info}
    return cardinality


def _print_summary(junction_tables, relationships, common_patterns,
                   cardinality, output_file):
    """Print a human-readable summary of the analysis to stdout."""
    print("=" * 60)
    print("JOIN PATTERN ANALYSIS COMPLETE")
    print("=" * 60)
    print(f"Junction tables found: {len(junction_tables)}")
    for jt in junction_tables:
        print(f"  - {jt['table']} connects: {', '.join(jt['connects'])}")

    print(f"\nDirect relationships: {sum(len(r) for r in relationships.values())}")
    print(f"Common join patterns identified: {len(common_patterns)}")

    print("\nRelationship Summary:")
    for table, rels in relationships.items():
        if rels:
            print(f"  {table} ->")
            # Cardinality entries are built in the same order as rels, so
            # pair them positionally. Matching on to_table alone would
            # confuse two FKs that point at the same table.
            for rel, card in zip(rels, cardinality[table]["relationships"]):
                print(f"    {rel['to_table']} ({card['type']}) via {rel['from_column']}")

    print(f"\nResults saved to {output_file}")


def detect_join_patterns(db_path, *, output_dir="tool_output", max_pairs=20,
                         max_depth=3):
    """Identify valid join paths between the tables of a SQLite database.

    Reads foreign-key metadata and row counts from *db_path*. Writes the
    analysis to ``<output_dir>/join_patterns.json`` and prints a summary.

    Args:
        db_path: Path to the SQLite database file.
        output_dir: Directory that receives ``join_patterns.json``; it is
            created if missing.
        max_pairs: Maximum number of distinct (from, to) FK table pairs to
            search join paths for, which bounds the work; ``None`` means all.
        max_depth: Maximum number of tables in one join path
            (``max_depth - 1`` joins).

    Returns:
        The dict that was written to JSON, with keys ``junction_tables``,
        ``direct_relationships``, ``reverse_relationships``,
        ``common_join_patterns`` and ``cardinality``.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in cursor.fetchall()]

        fk_rows = _load_foreign_keys(cursor, tables)

        # Relationship graph: edges keyed by the referencing table (forward)
        # and by the referenced table (reverse).
        relationships = defaultdict(list)
        reverse_relationships = defaultdict(list)
        for table in tables:
            for fk in fk_rows[table]:
                relationships[table].append({
                    "to_table": fk["to_table"],
                    "from_column": fk["from_column"],
                    "to_column": fk["to_column"],
                })
                reverse_relationships[fk["to_table"]].append({
                    "from_table": table,
                    "from_column": fk["from_column"],
                    "to_column": fk["to_column"],
                })

        junction_tables = _find_junction_tables(tables, fk_rows)

        # Distinct FK table pairs in discovery order. Dedup so that several
        # FKs between the same tables do not consume the max_pairs budget.
        important_pairs = list(dict.fromkeys(
            (table, rel["to_table"])
            for table in tables
            for rel in relationships.get(table, [])
        ))

        common_patterns = {}
        for start, end in important_pairs[:max_pairs]:
            paths = _find_paths(relationships, reverse_relationships,
                                start, end, max_depth)
            if paths:
                common_patterns[f"{start}_to_{end}"] = {
                    "from": start,
                    "to": end,
                    "path_count": len(paths),
                    "shortest_path": min(paths, key=lambda p: len(p["tables"])),
                }

        cardinality = _analyze_cardinality(cursor, tables, relationships)
    finally:
        conn.close()

    output = {
        "junction_tables": junction_tables,
        "direct_relationships": dict(relationships),
        "reverse_relationships": dict(reverse_relationships),
        "common_join_patterns": common_patterns,
        "cardinality": cardinality,
    }

    os.makedirs(output_dir, exist_ok=True)
    output_file = f"{output_dir}/join_patterns.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    _print_summary(junction_tables, relationships, common_patterns,
                   cardinality, output_file)
    return output

if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides the default database
    # path; with no argument the script analyses ./database.sqlite as before.
    detect_join_patterns(sys.argv[1] if len(sys.argv) > 1 else "database.sqlite")