#!/usr/bin/env python3
"""
Relationship Mapper Tool
Identifies foreign key relationships and common join patterns.
Critical for generating correct JOINs in SQL queries.
"""

import sqlite3
import json
import os
import re

# Matches columns such as ``user_id`` or ``userId``; group 1 is the stem
# (``user``) that may name the referenced table.
_ID_COLUMN_PATTERN = re.compile(r'(.+?)_?[iI][dD]$')


def _quote_ident(name):
    """Return *name* quoted as a SQLite identifier using backticks.

    Embedded backticks are doubled so an unusual table or column name cannot
    break out of the quoting.
    """
    return "`" + str(name).replace("`", "``") + "`"


def _primary_key_columns(cursor, table):
    """Return the primary-key column names of *table*, in key order."""
    cursor.execute(f"PRAGMA table_info({_quote_ident(table)})")
    # table_info rows are (cid, name, type, notnull, dflt_value, pk); pk is the
    # 1-based position within the primary key, or 0 for non-key columns.
    keyed = sorted((row[5], row[1]) for row in cursor.fetchall() if row[5] > 0)
    return [name for _, name in keyed]


def _collect_explicit_foreign_keys(cursor, tables, out):
    """Append one record per declared foreign-key column pair to *out*."""
    for table in tables:
        cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table)})")
        # Rows are (id, seq, table, from, to, on_update, on_delete, match).
        # fetchall() materialises the rows, so reusing the cursor below is safe.
        for fk in cursor.fetchall():
            to_column = fk[4]
            if to_column is None:
                # ``REFERENCES parent`` without a column list targets the
                # parent's primary key; SQLite reports ``to`` as NULL then.
                # ``seq`` (fk[1]) is the position within a composite key.
                pk_columns = _primary_key_columns(cursor, fk[2])
                if fk[1] < len(pk_columns):
                    to_column = pk_columns[fk[1]]
            out.append({
                "from_table": table,
                "from_column": fk[3],
                "to_table": fk[2],
                "to_column": to_column,
                "type": "explicit",
            })


def _column_names(cursor, tables):
    """Return ``{table: [column names as listed by PRAGMA table_info]}``."""
    names = {}
    for table in tables:
        cursor.execute(f"PRAGMA table_info({_quote_ident(table)})")
        names[table] = [row[1] for row in cursor.fetchall()]
    return names


def _id_match_ratio(cursor, table, column, parent):
    """Return the share of distinct non-NULL *column* values found in parent.id.

    Returns None when the column has no non-NULL values or when the sampling
    query fails, so the caller treats the link as unverified.
    """
    col = _quote_ident(column)
    src = _quote_ident(table)
    try:
        cursor.execute(
            f"SELECT COUNT(DISTINCT t1.{col}) FROM {src} t1 "
            f"WHERE t1.{col} IN (SELECT `id` FROM {_quote_ident(parent)})"
        )
        matching = cursor.fetchone()[0]
        cursor.execute(f"SELECT COUNT(DISTINCT {col}) FROM {src}")
        total = cursor.fetchone()[0]
    except sqlite3.Error:
        # Best effort: a table we cannot sample just stays unverified.
        return None
    return matching / total if total else None


def _infer_relationships(cursor, tables, table_columns, min_match_ratio, out):
    """Append relationships guessed from column names to *out*.

    Two heuristics run per column, in this order:

    * ``<stem>_id`` / ``<stem>Id`` pointing at a table named ``<stem>``,
      ``<stem>s`` or ``<stem>`` minus one trailing ``s`` that has an ``id``
      column; kept as "verified" only when more than *min_match_ratio* of
      the distinct values exist in that ``id`` column.
    * the same ``...id`` column name appearing in another table ("medium"),
      skipped if an identical (from_table, from_column, to_table) link was
      already recorded.
    """
    tables_with_id = {
        t for t in tables if any(c.lower() == "id" for c in table_columns[t])
    }

    for table1 in tables:
        for col1 in table_columns[table1]:
            match = _ID_COLUMN_PATTERN.match(col1)
            if match:
                stem = match.group(1).lower()
                # Remove at most ONE trailing "s": rstrip('s') would turn
                # "address" into "addre".
                singular = stem[:-1] if stem.endswith("s") else stem
                candidates = {stem, stem + "s", singular}
                for table2 in tables:
                    if table2.lower() not in candidates or table2 not in tables_with_id:
                        continue
                    ratio = _id_match_ratio(cursor, table1, col1, table2)
                    if ratio is not None and ratio > min_match_ratio:
                        out.append({
                            "from_table": table1,
                            "from_column": col1,
                            "to_table": table2,
                            "to_column": "id",
                            "type": "inferred",
                            "confidence": "verified",
                        })

            lowered = col1.lower()
            # A bare ``id`` is excluded: it would pair the primary keys of
            # unrelated tables (orders.id = users.id), and because ``id`` is
            # usually listed first that bogus edge used to win in join_paths.
            if lowered == "id" or not lowered.endswith("id"):
                continue
            for table2 in tables:
                if table2 == table1 or col1 not in table_columns[table2]:
                    continue
                already_recorded = any(
                    r["from_table"] == table1
                    and r["from_column"] == col1
                    and r["to_table"] == table2
                    for r in out
                )
                if not already_recorded:
                    out.append({
                        "from_table": table1,
                        "from_column": col1,
                        "to_table": table2,
                        "to_column": col1,
                        "type": "inferred",
                        "confidence": "medium",
                    })


def _build_relationship_graph(all_relationships):
    """Return ``{from_table: [edge, ...]}`` with edges in relationship order."""
    graph = {}
    for rel in all_relationships:
        from_table = rel["from_table"]
        to_table = rel["to_table"]
        graph.setdefault(from_table, []).append({
            "to_table": to_table,
            "join_condition": (
                f"{_quote_ident(from_table)}.{_quote_ident(rel['from_column'])}"
                f" = {_quote_ident(to_table)}.{_quote_ident(rel['to_column'])}"
            ),
            "type": rel["type"],
            "confidence": rel.get("confidence", "explicit"),
        })
    return graph


def _find_join_paths(tables, graph):
    """Return direct and 2-hop join paths from every table in *tables*.

    For each (source, target) pair the first path found is kept: all direct
    edges are considered before any 2-hop route, each in graph order.
    """
    join_paths = {}
    for source in tables:
        reachable = {}
        edges = graph.get(source, [])

        for edge in edges:
            target = edge["to_table"]
            if target not in reachable:
                reachable[target] = {
                    "direct": True,
                    "path": [source, target],
                    "conditions": [edge["join_condition"]],
                }

        for first in edges:
            middle = first["to_table"]
            for second in graph.get(middle, []):
                target = second["to_table"]
                if target != source and target not in reachable:
                    reachable[target] = {
                        "direct": False,
                        "path": [source, middle, target],
                        "conditions": [first["join_condition"], second["join_condition"]],
                    }

        join_paths[source] = reachable
    return join_paths


def map_relationships(db_path="database.sqlite", *, output_dir="tool_output",
                      min_match_ratio=0.8):
    """Map foreign key relationships and identify join patterns.

    Explicit foreign keys come from ``PRAGMA foreign_key_list``. Further
    relationships are inferred from column naming (see
    ``_infer_relationships``). A relationship graph and 1-/2-hop join paths
    are then derived. The report is written as JSON to
    ``<output_dir>/relationships.json``, and a short summary is printed.

    Args:
        db_path: Path of the SQLite database to analyse. NOTE: like
            ``sqlite3.connect`` in general, a missing path yields a new empty
            database and therefore an empty report.
        output_dir: Directory for the JSON report (created if missing).
        min_match_ratio: An inferred ``<stem>_id`` link is kept only if the
            fraction of its distinct non-NULL values found in the target's
            ``id`` column strictly exceeds this.

    Returns:
        The report dict that was written to disk. If analysis fails part-way,
        it carries an ``"error"`` message plus whatever was collected so far.
    """
    os.makedirs(output_dir, exist_ok=True)

    relationships = {
        "foreign_keys": [],
        "inferred_relationships": [],
        "join_paths": {},
        "relationship_graph": {},
    }

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        tables = [row[0] for row in cursor.fetchall() if not row[0].startswith("sqlite_")]

        # The helpers append in place so partial results survive an error.
        _collect_explicit_foreign_keys(cursor, tables, relationships["foreign_keys"])
        table_columns = _column_names(cursor, tables)
        _infer_relationships(cursor, tables, table_columns, min_match_ratio,
                             relationships["inferred_relationships"])

        # Explicit keys come first, so they win when picking a table pair's path.
        graph = _build_relationship_graph(
            relationships["foreign_keys"] + relationships["inferred_relationships"]
        )
        relationships["relationship_graph"] = graph
        relationships["join_paths"] = _find_join_paths(tables, graph)
    except Exception as e:
        # Top-level boundary: record the failure in the report instead of
        # crashing, so the partial results are still written below.
        relationships["error"] = str(e)
    finally:
        conn.close()

    output_path = os.path.join(output_dir, "relationships.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(relationships, f, indent=2)

    print(f"Relationship mapping complete - results in {output_path}")
    print(f"Found {len(relationships['foreign_keys'])} explicit foreign keys")
    print(f"Inferred {len(relationships['inferred_relationships'])} additional relationships")

    if relationships["join_paths"]:
        total_paths = sum(len(paths) for paths in relationships["join_paths"].values())
        print(f"Mapped {total_paths} join paths between tables")

    return relationships

if __name__ == "__main__":
    # Script entry point: analyse ./database.sqlite with the default settings.
    map_relationships(db_path="database.sqlite")