#!/usr/bin/env python3
"""
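Map the relationships between tables in an SQLite database, with exact join columns.

Identifies foreign keys, junction tables, multi-hop join paths through those
junction tables, and likely implicit (name-based) joins. Results are written to
tool_output/relationships.json and tool_output/relationship_map.txt.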
"""

import sqlite3
import json
import os
from collections import defaultdict

def _quote_identifier(name):
    """Return *name* as a double-quoted SQLite identifier, doubling embedded quotes."""
    return '"' + name.replace('"', '""') + '"'


def _primary_key_columns(cursor, table, cache):
    """Return *table*'s primary-key column names in key order, memoized in *cache*.

    An empty list means the table declares no primary key (or does not exist).
    """
    if table not in cache:
        cursor.execute(f"PRAGMA table_info({_quote_identifier(table)})")
        # table_info rows are (cid, name, type, notnull, dflt_value, pk); pk is the
        # 1-based position within the primary key, or 0 for non-key columns.
        keyed = sorted((row[5], row[1]) for row in cursor.fetchall() if row[5] > 0)
        cache[table] = [name for _position, name in keyed]
    return cache[table]


def _foreign_key_constraints(cursor, table, pk_cache):
    """Return *table*'s foreign keys grouped by constraint, in PRAGMA order.

    Each item is ``(to_table, [(from_column, to_column), ...])``; a composite key
    contributes several column pairs to one constraint.
    """
    cursor.execute(f"PRAGMA foreign_key_list({_quote_identifier(table)})")
    # Rows are (id, seq, table, from, to, on_update, on_delete, match); rows of
    # the same composite key share an id. Materialized before any nested query.
    rows = cursor.fetchall()
    constraints = {}
    for fk_id, seq, to_table, from_col, to_col, *_rest in rows:
        if to_col is None:
            # `REFERENCES parent` without a column list targets parent's primary key.
            pk_cols = _primary_key_columns(cursor, to_table, pk_cache)
            if seq < len(pk_cols):
                to_col = pk_cols[seq]
        if fk_id not in constraints:
            constraints[fk_id] = (to_table, [])
        constraints[fk_id][1].append((from_col, to_col))
    return list(constraints.values())


def _is_bookkeeping_column(name):
    """Return True for surrogate-id / timestamp columns ignored when judging junctions."""
    lowered = name.lower()
    return (
        lowered == "id"
        or lowered.endswith(("_id", "_at"))
        or any(word in lowered for word in ("created", "updated", "timestamp"))
    )


def _describe_junction(cursor, table, constraints):
    """Return the junction-table summary for *table*, or None if it has < 2 FK constraints."""
    if len(constraints) < 2:
        return None

    cursor.execute(f"PRAGMA table_info({_quote_identifier(table)})")
    column_names = [row[1] for row in cursor.fetchall()]
    fk_columns = {
        from_col for _to_table, pairs in constraints for from_col, _to_col in pairs
    }
    additional = [
        name for name in column_names
        if name not in fk_columns and not _is_bookkeeping_column(name)
    ]

    junction = {
        "table": table,
        "connects": [],
        "join_columns": {},
        # Unlike join_columns (keyed by target table), this keeps every constraint,
        # so two foreign keys into the same table are not collapsed.
        "foreign_keys": [],
        "additional_data": additional,
        "is_pure_junction": len(additional) <= 2,
    }
    for to_table, pairs in constraints:
        from_cols = [from_col for from_col, _to_col in pairs]
        to_cols = [to_col for _from_col, to_col in pairs]
        junction["connects"].append(to_table)
        junction["join_columns"][to_table] = {
            # Single-column keys keep the raw names; composite keys are comma-joined.
            "from": from_cols[0] if len(pairs) == 1 else ", ".join(map(str, from_cols)),
            "to": to_cols[0] if len(pairs) == 1 else ", ".join(map(str, to_cols)),
        }
        junction["foreign_keys"].append({
            "to_table": to_table,
            "columns": [{"from": f, "to": t} for f, t in pairs],
        })
    return junction


def _join_path(junction_table, constraints):
    """Build the join_paths entry linking the two tables a junction table references."""
    (table1, pairs1), (table2, pairs2) = constraints
    # Each side uses its own constraint's columns, which keeps self-referencing
    # junctions (two FKs into the same table) correct; composite keys AND together.
    first = " AND ".join(f"{table1}.{to} = {junction_table}.{frm}" for frm, to in pairs1)
    second = " AND ".join(f"{junction_table}.{frm} = {table2}.{to}" for frm, to in pairs2)
    on_first = " AND ".join(f"t1.{to} = j.{frm}" for frm, to in pairs1)
    on_second = " AND ".join(f"j.{frm} = t2.{to}" for frm, to in pairs2)
    return {
        "from": table1,
        "to": table2,
        "via": junction_table,
        "join_sequence": [first, second],
        "sql_template": (
            "SELECT t1.*, t2.*\n"
            f"FROM {table1} t1\n"
            f"JOIN {junction_table} j ON {on_first}\n"
            f"JOIN {table2} t2 ON {on_second}"
        ),
    }


def map_relationships(db_path="database.sqlite"):
    """Map the foreign-key relationships of an SQLite database and save the results.

    Writes ``tool_output/relationships.json`` and, via
    ``generate_relationship_report``, ``tool_output/relationship_map.txt``.
    Any exception is printed and recorded under an ``"error"`` key in the JSON
    file instead of being raised.

    Args:
        db_path: Path of the SQLite database file to inspect.
    """
    # NOTE: sqlite3.connect() creates an empty database when db_path does not
    # exist; every section of the output is then simply empty.
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    relationships = {
        "direct_relationships": [],
        "junction_tables": [],
        "join_paths": {},
        "relationship_graph": defaultdict(list)
    }

    try:
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        # sqlite_* tables (e.g. sqlite_sequence) are SQLite's internal bookkeeping.
        tables = [row[0] for row in cursor.fetchall() if not row[0].startswith("sqlite_")]

        # Direct foreign keys; each table's FK list is read once and reused below.
        pk_cache = {}
        constraints_by_table = {}
        for table in tables:
            constraints = _foreign_key_constraints(cursor, table, pk_cache)
            constraints_by_table[table] = constraints
            for to_table, pairs in constraints:
                for from_col, to_col in pairs:
                    relationships["direct_relationships"].append({
                        "from_table": table,
                        "from_column": from_col,
                        "to_table": to_table,
                        "to_column": to_col,
                        "join_syntax": f"{table}.{from_col} = {to_table}.{to_col}",
                    })
                relationships["relationship_graph"][table].append(to_table)
                relationships["relationship_graph"][to_table].append(table)

        # Junction tables (>= 2 FK constraints) and the two-hop paths through them.
        for table in tables:
            constraints = constraints_by_table[table]
            junction = _describe_junction(cursor, table, constraints)
            if junction is None:
                continue
            relationships["junction_tables"].append(junction)
            if len(constraints) == 2:
                path = _join_path(table, constraints)
                path_key = f"{path['from']}_to_{path['to']}"
                # Keep both paths when several junction tables link the same pair.
                if path_key in relationships["join_paths"]:
                    path_key = f"{path_key}_via_{table}"
                relationships["join_paths"][path_key] = path

        # Find implicit relationships (matching column names)
        find_implicit_relationships(cursor, tables, relationships)

        os.makedirs("tool_output", exist_ok=True)
        with open("tool_output/relationships.json", "w", encoding="utf-8") as f:
            json.dump(relationships, f, indent=2)

        generate_relationship_report(relationships)

        print("Relationship mapping complete - results in tool_output/")

    except Exception as e:  # top-level boundary: record the failure, don't crash
        print(f"Error mapping relationships: {e}")
        relationships["error"] = str(e)
        # The failure may have happened before the output directory was created.
        os.makedirs("tool_output", exist_ok=True)
        with open("tool_output/relationships.json", "w", encoding="utf-8") as f:
            json.dump(relationships, f, indent=2)

    finally:
        conn.close()

def find_implicit_relationships(cursor, tables, relationships):
    """Find relationships based on matching column names.

    For every unordered pair of tables, each shared column name that looks like
    a key is recorded if the declared types match case-insensitively. Key-like
    names end in ``id`` or ``_code``, or are exactly ``code`` or ``key``.
    Matches are stored as a list under ``relationships["implicit_relationships"]``.
    That key is set only when at least one match is found.

    Args:
        cursor: An open ``sqlite3`` cursor on the database being mapped.
        tables: Table names to compare; pairs are reported in this order.
        relationships: Result dict, updated in place.
    """
    # Read each table's schema once instead of once per pair.
    schemas = []
    for table in tables:
        quoted = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted})")
        # table_info rows are (cid, name, type, ...): map column name -> declared type.
        schemas.append((table, {col[1]: col[2] for col in cursor.fetchall()}))

    implicit = []
    for i, (table1, cols1) in enumerate(schemas):
        for table2, cols2 in schemas[i + 1:]:
            # Sorted so output (and the report's first-10 excerpt) is stable across runs.
            for col in sorted(cols1.keys() & cols2.keys()):
                lowered = col.lower()
                looks_like_key = lowered.endswith(("id", "_code")) or lowered in ("code", "key")
                # SQLite type names are case-insensitive ("INTEGER" == "integer").
                if looks_like_key and cols1[col].upper() == cols2[col].upper():
                    implicit.append({
                        "table1": table1,
                        "table2": table2,
                        "column": col,
                        "type": cols1[col],
                        "confidence": "high",
                        "join_syntax": f"{table1}.{col} = {table2}.{col}"
                    })

    if implicit:
        relationships["implicit_relationships"] = implicit

def generate_relationship_report(relationships):
    """Generate a human-readable (Markdown-formatted) relationship report.

    The report is written to ``tool_output/relationship_map.txt``; the directory
    is created if needed. Sections with no data are omitted. At most the first
    10 implicit relationships are listed.

    Args:
        relationships: Mapping with ``direct_relationships``, ``junction_tables``,
            ``join_paths`` and ``relationship_graph`` keys, and optionally
            ``implicit_relationships`` (the structure ``map_relationships`` builds).
    """
    report = ["# DATABASE RELATIONSHIP MAP\n\n"]

    # Direct relationships
    if relationships["direct_relationships"]:
        report.append("## Direct Foreign Key Relationships\n\n")
        for rel in relationships["direct_relationships"]:
            report.append(f"- **{rel['from_table']}.{rel['from_column']}** → ")
            report.append(f"**{rel['to_table']}.{rel['to_column']}**\n")
            report.append(f"  ```sql\n  {rel['join_syntax']}\n  ```\n")
        report.append("\n")

    # Junction tables
    if relationships["junction_tables"]:
        report.append("## Junction Tables (Many-to-Many)\n\n")
        for junction in relationships["junction_tables"]:
            report.append(f"### {junction['table']}\n")
            report.append(f"Connects: {' ↔ '.join(junction['connects'])}\n\n")

            if junction["additional_data"]:
                report.append(f"Additional columns: {', '.join(junction['additional_data'])}\n\n")

            report.append("Join columns:\n")
            for table, cols in junction["join_columns"].items():
                report.append(f"- To {table}: {junction['table']}.{cols['from']} → {table}.{cols['to']}\n")
            report.append("\n")

    # Multi-hop paths
    if relationships["join_paths"]:
        report.append("## Multi-Table Join Paths\n\n")
        for path in relationships["join_paths"].values():
            report.append(f"### {path['from']} ↔ {path['to']} (via {path['via']})\n\n")
            report.append("Join sequence:\n")
            for i, join in enumerate(path["join_sequence"], 1):
                report.append(f"{i}. {join}\n")
            report.append("\nSQL Template:\n")
            report.append(f"```sql\n{path['sql_template']}\n```\n\n")

    # Implicit relationships
    if relationships.get("implicit_relationships"):
        report.append("## Potential Implicit Relationships\n\n")
        report.append("Based on matching column names:\n\n")
        for impl in relationships["implicit_relationships"][:10]:
            report.append(f"- {impl['table1']}.{impl['column']} ↔ ")
            report.append(f"{impl['table2']}.{impl['column']} ")
            report.append(f"({impl['type']}, {impl['confidence']} confidence)\n")
        report.append("\n")

    # Relationship graph summary
    if relationships["relationship_graph"]:
        report.append("## Table Connectivity\n\n")
        for table, connected in relationships["relationship_graph"].items():
            if connected:
                # dict.fromkeys de-duplicates while keeping first-seen order, so the
                # output is stable across runs (set order depends on string hashing).
                unique_connected = list(dict.fromkeys(connected))
                report.append(f"- **{table}** connects to: {', '.join(unique_connected)}\n")

    os.makedirs("tool_output", exist_ok=True)
    # Explicit UTF-8: the report contains arrows (→, ↔) that some platform default
    # encodings (e.g. cp1252 on Windows) cannot encode.
    with open("tool_output/relationship_map.txt", "w", encoding="utf-8") as f:
        f.writelines(report)

if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides the default database path.
    if len(sys.argv) > 1:
        map_relationships(sys.argv[1])
    else:
        map_relationships()