#!/usr/bin/env python3
"""
Relationship Mapper Tool
Identifies foreign key relationships and common join patterns.
"""

import sqlite3
import json
import os
import re

def _quote_ident(name):
    """Return *name* as a backtick-quoted SQLite identifier.

    Embedded backticks are doubled so that unusual table or column names
    cannot terminate the quoting early in the generated SQL.
    """
    return "`" + str(name).replace("`", "``") + "`"


def _load_schema(cursor):
    """Read user tables, their columns and their primary-key columns.

    Returns a ``(tables, table_columns, primary_keys)`` tuple where
    ``tables`` is the sorted list of non-internal table names,
    ``table_columns`` maps each table to its column names in declaration
    order, and ``primary_keys`` maps the lower-cased table name to its
    primary-key columns in key order.
    """
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    tables = [row[0] for row in cursor.fetchall() if not row[0].startswith("sqlite_")]

    table_columns = {}
    primary_keys = {}
    for table in tables:
        cursor.execute(f"PRAGMA table_info({_quote_ident(table)})")
        info = cursor.fetchall()
        table_columns[table] = [col[1] for col in info]
        # col[5] is the 1-based position inside the primary key (0 = not part of it).
        pk_cols = sorted((col for col in info if col[5]), key=lambda col: col[5])
        primary_keys[table.lower()] = [col[1] for col in pk_cols]
    return tables, table_columns, primary_keys


def _explicit_foreign_keys(cursor, tables, primary_keys):
    """Collect the foreign keys declared in the schema.

    When a constraint names only the parent table (``REFERENCES parent``),
    SQLite reports no target column; in that case the parent's primary-key
    column at the same position is substituted so the generated join
    condition refers to a real column.
    """
    # Map lower-cased names back to the spelling used in sqlite_master so the
    # relationship graph is keyed consistently even if the FK clause used a
    # different letter case.
    canonical = {name.lower(): name for name in tables}

    found = []
    for table in tables:
        cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table)})")
        for fk in cursor.fetchall():
            seq, parent, from_column, to_column = fk[1], fk[2], fk[3], fk[4]
            if to_column is None:
                parent_pk = primary_keys.get(parent.lower(), [])
                if seq < len(parent_pk):
                    to_column = parent_pk[seq]
            found.append({
                "from_table": table,
                "from_column": from_column,
                "to_table": canonical.get(parent.lower(), parent),
                "to_column": to_column,
                "type": "explicit"
            })
    return found


def _values_mostly_match(cursor, from_table, from_column, to_table, to_column,
                         threshold=0.8):
    """Return True if more than *threshold* of the distinct non-NULL values of
    ``from_table.from_column`` also occur in ``to_table.to_column``.

    Database errors are treated as "no match" because this check is only a
    best-effort heuristic.
    """
    src_table, src_col = _quote_ident(from_table), _quote_ident(from_column)
    dst_table, dst_col = _quote_ident(to_table), _quote_ident(to_column)
    try:
        cursor.execute(f"SELECT COUNT(DISTINCT {src_col}) FROM {src_table}")
        total = cursor.fetchone()[0]
        if not total:
            # Empty (or all-NULL) column: nothing to verify against.
            return False
        cursor.execute(
            f"SELECT COUNT(DISTINCT src.{src_col}) FROM {src_table} src "
            f"WHERE src.{src_col} IN (SELECT {dst_col} FROM {dst_table})"
        )
        matching = cursor.fetchone()[0]
    except sqlite3.Error:
        return False
    return matching / total > threshold


def _infer_relationships(cursor, tables, table_columns, explicit_fks):
    """Guess relationships from column naming conventions.

    Two heuristics are applied per column, in this order:

    * ``<name>_id`` / ``<name>Id`` columns are matched against tables called
      ``<name>``, ``<name>s`` or ``<name>`` minus one trailing "s", and kept
      (confidence "verified") only if the data overlap check passes.
    * Columns ending in "id" that exist under the same name in another table
      are recorded with confidence "medium".

    Pairs already covered by an explicit foreign key or an earlier inference
    (same source table, source column and target table) are skipped.
    """
    id_pattern = re.compile(r'(.+?)_?[iI][dD]$')
    inferred = []
    seen = {
        (fk["from_table"].lower(), fk["from_column"].lower(), fk["to_table"].lower())
        for fk in explicit_fks
    }

    for table1 in tables:
        for col1 in table_columns[table1]:
            col1_lower = col1.lower()

            # Heuristic 1: the column name encodes the referenced table.
            match = id_pattern.match(col1)
            if match:
                stem = match.group(1).lower()
                table_names = {stem, stem + 's'}
                if stem.endswith('s'):
                    # Drop exactly one plural "s" ("bosses" aside, "boss" must
                    # not collapse to "bo" the way str.rstrip('s') would).
                    table_names.add(stem[:-1])
                column_names = {'id', col1_lower, stem + 'id'}

                for table2 in tables:
                    if table2.lower() not in table_names:
                        continue
                    key = (table1.lower(), col1_lower, table2.lower())
                    if key in seen:
                        continue
                    for col2 in table_columns[table2]:
                        if col2.lower() not in column_names:
                            continue
                        if table1 == table2 and col1 == col2:
                            # A column always matches itself; not a relationship.
                            continue
                        if _values_mostly_match(cursor, table1, col1, table2, col2):
                            inferred.append({
                                "from_table": table1,
                                "from_column": col1,
                                "to_table": table2,
                                "to_column": col2,
                                "type": "inferred",
                                "confidence": "verified"
                            })
                            seen.add(key)
                            break

            # Heuristic 2: identically named "...id" column in another table.
            if not col1_lower.endswith('id'):
                continue
            for table2 in tables:
                if table1 == table2 or col1 not in table_columns[table2]:
                    continue
                key = (table1.lower(), col1_lower, table2.lower())
                if key in seen:
                    continue
                inferred.append({
                    "from_table": table1,
                    "from_column": col1,
                    "to_table": table2,
                    "to_column": col1,
                    "type": "inferred",
                    "confidence": "medium"
                })
                seen.add(key)

    return inferred


def _build_graph_and_templates(all_relationships):
    """Turn relationship records into an adjacency map and JOIN snippets.

    Returns ``(graph, templates)``: ``graph`` maps a source table to a list
    of outgoing edges, ``templates`` holds one ready-to-use JOIN per record.
    """
    graph = {}
    templates = []
    for rel in all_relationships:
        from_table = rel["from_table"]
        to_table = rel["to_table"]
        from_col = _quote_ident(rel["from_column"])
        to_col = _quote_ident(rel["to_column"])

        graph.setdefault(from_table, []).append({
            "to_table": to_table,
            "join_condition": f"{_quote_ident(from_table)}.{from_col} = {_quote_ident(to_table)}.{to_col}",
            "type": rel["type"],
            "confidence": rel.get("confidence", "explicit")
        })

        templates.append({
            "description": f"Join {from_table} with {to_table}",
            "sql": f"FROM {_quote_ident(from_table)} T1\nJOIN {_quote_ident(to_table)} T2 ON T1.{from_col} = T2.{to_col}",
            "tables": [from_table, to_table],
            "relationship_type": rel["type"]
        })
    return graph, templates


def _build_join_paths(tables, graph):
    """Compute direct and 2-hop join paths from every table.

    For each reachable target only the first path found is kept, and direct
    paths are registered before 2-hop ones so they take precedence.
    """
    join_paths = {}
    for table1 in tables:
        reachable = join_paths[table1] = {}
        edges = graph.get(table1, [])

        for rel in edges:
            table2 = rel["to_table"]
            if table2 not in reachable:
                reachable[table2] = {
                    "direct": True,
                    "path": [table1, table2],
                    "conditions": [rel["join_condition"]]
                }

        for rel1 in edges:
            intermediate = rel1["to_table"]
            for rel2 in graph.get(intermediate, []):
                table3 = rel2["to_table"]
                if table3 != table1 and table3 not in reachable:
                    reachable[table3] = {
                        "direct": False,
                        "path": [table1, intermediate, table3],
                        "conditions": [rel1["join_condition"], rel2["join_condition"]]
                    }
    return join_paths


def map_relationships(db_path="database.sqlite", output_dir="tool_output"):
    """Map foreign key relationships and identify join patterns.

    Inspects the SQLite database at *db_path*, collects declared foreign
    keys, infers further relationships from column names (verified against
    the data where possible), and derives a relationship graph, JOIN
    templates and direct / 2-hop join paths.

    The result is written as JSON to ``<output_dir>/relationships.json``
    (the directory is created if needed) and a short summary is printed.
    If anything fails while reading the database, the partial result is
    still written with the message stored under the ``"error"`` key.

    Args:
        db_path: Path of the SQLite database file to inspect.
        output_dir: Directory that receives ``relationships.json``.

    Returns:
        The relationships dictionary that was written to disk.
    """
    os.makedirs(output_dir, exist_ok=True)

    relationships = {
        "foreign_keys": [],
        "inferred_relationships": [],
        "join_paths": {},
        "relationship_graph": {},
        "join_templates": []  # Ready-to-use JOIN templates
    }

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        tables, table_columns, primary_keys = _load_schema(cursor)

        # Each stage is stored as soon as it completes so that a later
        # failure still leaves the earlier results in the output.
        relationships["foreign_keys"] = _explicit_foreign_keys(
            cursor, tables, primary_keys)
        relationships["inferred_relationships"] = _infer_relationships(
            cursor, tables, table_columns, relationships["foreign_keys"])

        all_relationships = (relationships["foreign_keys"]
                             + relationships["inferred_relationships"])
        graph, templates = _build_graph_and_templates(all_relationships)
        relationships["relationship_graph"] = graph
        relationships["join_templates"] = templates

        relationships["join_paths"] = _build_join_paths(tables, graph)
    except Exception as e:
        # Top-level boundary: report the failure in the output instead of
        # crashing, so callers always get a JSON file.
        relationships["error"] = str(e)
    finally:
        conn.close()

    # Write to output file
    output_path = f"{output_dir}/relationships.json"
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(relationships, f, indent=2)

    print(f"Relationship mapping complete - results in {output_path}")
    print(f"Found {len(relationships['foreign_keys'])} explicit foreign keys")
    print(f"Inferred {len(relationships['inferred_relationships'])} additional relationships")
    print(f"Generated {len(relationships['join_templates'])} JOIN templates")

    return relationships

# Script entry point: when run directly, map the default database
# ("database.sqlite", resolved against the current working directory).
if __name__ == "__main__":
    map_relationships()