#!/usr/bin/env python3
"""
Enhanced Relationship Mapper Tool
Maps all join paths and identifies junction tables.
"""

import sqlite3
import json
import os

def _quote_ident(name):
    """Return *name* as a double-quoted SQLite identifier, doubling embedded quotes."""
    return '"' + name.replace('"', '""') + '"'


def _sql_name(name):
    """Return *name* bare when it is a plain identifier, otherwise quoted."""
    return name if name.isidentifier() else _quote_ident(name)


def _table_info(cursor, table):
    """Return the ``PRAGMA table_info`` rows (cid, name, type, notnull, dflt, pk) of *table*."""
    cursor.execute(f"PRAGMA table_info({_quote_ident(table)})")
    return cursor.fetchall()


def _primary_key_columns(cursor, table):
    """Return *table*'s primary-key column names in key order (empty if none or unknown)."""
    pk_rows = [row for row in _table_info(cursor, table) if row[5]]
    return [row[1] for row in sorted(pk_rows, key=lambda row: row[5])]


def _explicit_foreign_keys(cursor, table):
    """Return the foreign keys declared on *table*, read via ``PRAGMA foreign_key_list``.

    Composite keys are reported as one entry whose column lists are joined by ", ".
    ``to_column`` is None only when the parent key cannot be resolved.
    """
    cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table)})")
    groups = {}
    for fk_id, seq, ref_table, from_col, to_col, *_ in cursor.fetchall():
        groups.setdefault(fk_id, []).append((seq, ref_table, from_col, to_col))

    foreign_keys = []
    for rows in groups.values():
        rows.sort(key=lambda row: row[0])  # order the columns of a composite key by seq
        ref_table = rows[0][1]
        from_cols = [row[2] for row in rows]
        to_cols = [row[3] for row in rows]
        if any(col is None for col in to_cols):
            # "REFERENCES parent" without a column list targets the parent's primary key.
            to_cols = _primary_key_columns(cursor, ref_table)
        foreign_keys.append({
            "from_table": table,
            "from_column": ", ".join(from_cols),
            "to_table": ref_table,
            "to_column": ", ".join(to_cols) if to_cols else None,
            "type": "explicit",
        })
    return foreign_keys


def _id_stem(column):
    """Return the table-name stem of a ``*_id`` / ``*id`` column (lowercased), or ''."""
    lower = column.lower()
    if lower.endswith("_id"):
        stem = lower[:-3]
    elif lower.endswith("id"):
        stem = lower[:-2]
    else:
        return ""
    return stem.strip("_")


def _implied_relationships(tables, table_columns, foreign_keys):
    """Guess relationships from ``<table>_id`` / ``<table>Id`` column names."""
    explicit = {(fk["from_table"], fk["from_column"], fk["to_table"]) for fk in foreign_keys}
    implied = []
    for table1 in tables:
        for col1 in table_columns[table1]:
            stem = _id_stem(col1)
            if not stem:
                # A bare "id" (empty stem) would be a substring of every table name.
                continue
            col1_lower = col1.lower()
            for table2 in tables:
                if table2 == table1 or stem not in table2.lower():
                    continue
                cols2 = table_columns[table2]
                # Prefer the target's "id" column (any case), else a same-named column.
                target = next((c for c in cols2 if c.lower() == "id"), None)
                if target is None:
                    target = next((c for c in cols2 if c.lower() == col1_lower), None)
                if target is None or (table1, col1, table2) in explicit:
                    continue
                implied.append({
                    "from_table": table1,
                    "from_column": col1,
                    "to_table": table2,
                    "to_column": target,
                    "confidence": "high" if stem == table2.lower() else "medium",
                })
    return implied


def _junction_tables(tables, table_columns, foreign_keys):
    """Flag tables that look like many-to-many bridges between other tables."""
    junctions = []
    for table in tables:
        targets = [fk["to_table"] for fk in foreign_keys if fk["from_table"] == table]
        fk_count = len(targets)
        col_count = len(table_columns[table])
        # Junction tables typically have 2+ foreign keys and few other columns.
        if fk_count >= 2 and col_count <= fk_count + 2:
            junctions.append({
                "table": table,
                "connects": targets,
                "foreign_key_count": fk_count,
                "total_columns": col_count,
            })
    return junctions


def _table_connections(tables, foreign_keys, junction_tables):
    """Build the per-table graph of direct (FK) and junction-mediated connections."""
    connections = {
        table: {"directly_connected_to": [], "connected_via_junction": []}
        for table in tables
    }

    for fk in foreign_keys:
        entry = {
            "table": fk["to_table"],
            "via_column": f"{fk['from_column']} -> {fk['to_column'] or '?'}",
        }
        direct = connections[fk["from_table"]]["directly_connected_to"]
        if entry not in direct:
            direct.append(entry)

    for junction in junction_tables:
        # De-duplicate so a junction referencing one table twice creates no self-pair.
        linked = [t for t in dict.fromkeys(junction["connects"]) if t in connections]
        for i, table1 in enumerate(linked):
            for table2 in linked[i + 1:]:
                connections[table1]["connected_via_junction"].append(
                    {"table": table2, "junction": junction["table"]}
                )
                connections[table2]["connected_via_junction"].append(
                    {"table": table1, "junction": junction["table"]}
                )
    return connections


def _join_sql(fk):
    """Return an example JOIN clause for one foreign key, or None if its target is unknown."""
    if not fk["to_column"]:
        return None
    child = _sql_name(fk["from_table"])
    parent = _sql_name(fk["to_table"])
    parent_ref = parent
    join_target = parent
    if fk["from_table"] == fk["to_table"]:
        # A self-reference needs an alias, otherwise every column reference is ambiguous.
        parent_ref = _sql_name(fk["to_table"] + "_ref")
        join_target = f"{parent} AS {parent_ref}"
    conditions = " AND ".join(
        f"{child}.{_sql_name(a)} = {parent_ref}.{_sql_name(b)}"
        for a, b in zip(fk["from_column"].split(", "), fk["to_column"].split(", "))
    )
    return f"{child} JOIN {join_target} ON {conditions}"


def map_relationships(db_path="database.sqlite", output_dir="tool_output"):
    """Map relationships between the tables of a SQLite database.

    Collects four things:
    - declared foreign keys;
    - relationships implied by ``*_id`` / ``*Id`` column names;
    - likely junction tables;
    - a per-table connection graph and example JOIN clauses.

    The result is written as JSON to ``<output_dir>/relationships.json`` and
    summarised on stdout.

    Args:
        db_path: Path of the SQLite database. Note that ``sqlite3.connect``
            creates an empty database if the file does not exist, so a wrong
            path yields an empty mapping rather than an error.
        output_dir: Directory for ``relationships.json``. It is created if missing.

    Returns:
        The relationship mapping dict that was written to disk.

    Raises:
        sqlite3.Error: If the schema cannot be read.
        OSError: If the output directory or file cannot be written.
    """
    os.makedirs(output_dir, exist_ok=True)

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    relationship_info = {
        "foreign_keys": [],
        "implied_relationships": [],
        "junction_tables": [],
        "join_paths": {},
        "table_connections": {}
    }

    try:
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        tables = [row[0] for row in cursor.fetchall() if not row[0].startswith("sqlite_")]

        # Read declared keys from SQLite's own parser rather than regex-scanning CREATE text.
        foreign_keys = []
        for table in tables:
            foreign_keys.extend(_explicit_foreign_keys(cursor, table))
        relationship_info["foreign_keys"] = foreign_keys

        table_columns = {table: [row[1] for row in _table_info(cursor, table)] for table in tables}

        relationship_info["implied_relationships"] = _implied_relationships(
            tables, table_columns, foreign_keys
        )
        relationship_info["junction_tables"] = _junction_tables(
            tables, table_columns, foreign_keys
        )
        relationship_info["table_connections"] = _table_connections(
            tables, foreign_keys, relationship_info["junction_tables"]
        )

        examples = []
        for fk in foreign_keys[:10]:  # First 10 foreign keys as examples
            sql = _join_sql(fk)
            if sql is not None:
                examples.append({
                    "description": f"Join {fk['from_table']} with {fk['to_table']}",
                    "sql": sql
                })
        relationship_info["join_paths"]["examples"] = examples

        output_path = os.path.join(output_dir, "relationships.json")
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(relationship_info, f, indent=2)

        print(f"✓ Relationship mapping complete: {len(relationship_info['foreign_keys'])} foreign keys found")
        print(f"✓ Identified {len(relationship_info['junction_tables'])} junction tables")
        print(f"✓ Found {len(relationship_info['implied_relationships'])} implied relationships")

    except Exception as e:
        print(f"✗ Relationship mapping failed: {str(e)}")
        raise
    finally:
        conn.close()

    return relationship_info

if __name__ == "__main__":
    # Script entry point: map the database at the default path.
    map_relationships(db_path="database.sqlite")