#!/usr/bin/env python3
"""
Analyzes database relationships, foreign keys, and junction tables.
Identifies all paths between tables for accurate JOIN generation.
"""

import sqlite3
import json
import os
from collections import defaultdict

_OUTPUT_DIR = "tool_output"

# Shared column names (besides *_id / *Id suffixes) that hint at an implicit join.
_JOIN_KEY_NAMES = ("id", "ID", "code", "Code", "key")


def _quote_identifier(name):
    """Return *name* as a double-quoted SQLite identifier.

    PRAGMA arguments cannot be bound as parameters, so table names must be
    interpolated; quoting them (with embedded quotes doubled) keeps names that
    contain spaces, hyphens, or reserved words from breaking the statement.
    """
    return '"' + name.replace('"', '""') + '"'


def _pragma_rows(cursor, pragma, table):
    """Run ``PRAGMA <pragma>(<table>)`` and return all result rows."""
    cursor.execute(f"PRAGMA {pragma}({_quote_identifier(table)})")
    return cursor.fetchall()


def _collect_foreign_keys(cursor, tables, relationships):
    """Fill ``foreign_keys`` and ``table_connections`` from declared foreign keys."""
    for table in tables:
        fks = _pragma_rows(cursor, "foreign_key_list", table)
        if not fks:
            continue
        entries = []
        relationships["foreign_keys"][table] = entries
        # Row layout: (id, seq, table, from, to, on_update, on_delete, match).
        for fk_id, seq, to_table, from_column, to_column, *_ in fks:
            entries.append({
                "from_column": from_column,
                "to_table": to_table,
                "to_column": to_column,
                # SQLite's numeric FK id (shared by every column of a composite
                # key), not a declared constraint name; key kept for compatibility.
                "constraint_name": fk_id,
            })
            # A composite FK yields one row per column; link the tables once per FK.
            if seq == 0:
                relationships["table_connections"][table].append(to_table)
                relationships["table_connections"][to_table].append(table)


def _find_junction_tables(columns_by_table, relationships):
    """Append a ``junction_tables`` entry for each table with 2+ distinct FKs."""
    for table, fks in relationships["foreign_keys"].items():
        # Count FK constraints, not rows: a single composite FK spans several
        # rows and must not on its own make a table look like a junction.
        target_by_fk = {}
        for fk in fks:
            target_by_fk.setdefault(fk["constraint_name"], fk["to_table"])
        if len(target_by_fk) < 2:
            continue

        fk_columns = {fk["from_column"] for fk in fks}
        # Junction tables typically have few columns beyond the FKs; undeclared
        # *_id columns are treated as key-like rather than as extra data.
        non_fk_columns = [
            col[1]
            for col in columns_by_table[table]
            if col[1] not in fk_columns and not col[1].endswith("_id")
        ]
        relationships["junction_tables"].append({
            "table": table,
            "connects": list(target_by_fk.values()),
            "foreign_keys": fks,
            "additional_columns": non_fk_columns,
            "is_pure_junction": len(non_fk_columns) <= 2,  # Allow for timestamps
        })


def _find_implicit_paths(tables, columns_by_table, relationships):
    """Record table pairs that share key-like column names in ``relationship_paths``."""
    # table_info rows are (cid, name, type, notnull, dflt_value, pk).
    types_by_table = {
        table: {col[1]: col[2] for col in columns_by_table[table]}
        for table in tables
    }
    for table1 in tables:
        columns1 = types_by_table[table1]
        for table2 in tables:
            # Visit each unordered pair once, keyed as "<smaller>-<larger>".
            if table1 >= table2:
                continue
            columns2 = types_by_table[table2]
            potential_joins = [
                {"column": col, "type_match": columns1[col] == columns2[col]}
                # Sorted so the JSON output is identical between runs.
                for col in sorted(columns1.keys() & columns2.keys())
                if col.endswith(("_id", "Id")) or col in _JOIN_KEY_NAMES
            ]
            if potential_joins:
                relationships["relationship_paths"][f"{table1}-{table2}"] = {
                    "tables": [table1, table2],
                    "potential_join_columns": potential_joins,
                    "confidence": (
                        "high" if any(j["type_match"] for j in potential_joins) else "medium"
                    ),
                }


def _add_junction_paths(relationships):
    """Record the many-to-many path each two-FK junction table provides."""
    paths = relationships["relationship_paths"]
    for junction in relationships["junction_tables"]:
        if len(junction["connects"]) != 2:
            continue
        # Order the pair the same way implicit paths do so a given pair of
        # tables always maps to a single key.
        table1, table2 = sorted(junction["connects"])
        path_key = f"{table1}-{table2}"
        existing = paths.get(path_key)
        if existing is None:
            paths[path_key] = {
                "tables": [table1, table2],
                "via_junction": junction["table"],
                "junction_additional_data": junction["additional_columns"],
                "confidence": "high",
            }
        elif "via_junction" not in existing:
            # A name-based match already exists (often just a shared "id");
            # keep it, but do not lose the declared junction route.
            existing["via_junction"] = junction["table"]
            existing["junction_additional_data"] = junction["additional_columns"]


def _build_summary(tables, relationships):
    """Return the lines of the human-readable relationship summary."""
    summary = [
        "# Database Relationship Analysis\n",
        f"Total tables: {len(tables)}\n",
        f"Tables with foreign keys: {len(relationships['foreign_keys'])}\n",
        f"Junction tables found: {len(relationships['junction_tables'])}\n",
        f"Relationship paths: {len(relationships['relationship_paths'])}\n",
    ]

    if relationships["junction_tables"]:
        summary.append("\n## Junction Tables (Many-to-Many):\n")
        for jt in relationships["junction_tables"]:
            summary.append(f"- {jt['table']}: connects {' <-> '.join(jt['connects'])}")
            if jt['additional_columns']:
                summary.append(f"  (has extra data: {', '.join(jt['additional_columns'])})\n")
            else:
                summary.append("\n")

    if relationships["foreign_keys"]:
        summary.append("\n## Direct Foreign Keys:\n")
        for table, fks in relationships["foreign_keys"].items():
            for fk in fks:
                summary.append(f"- {table}.{fk['from_column']} -> {fk['to_table']}.{fk['to_column']}\n")

    return summary


def _write_json(relationships):
    """Write *relationships* to ``tool_output/relationships.json``, creating the directory."""
    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    with open(os.path.join(_OUTPUT_DIR, "relationships.json"), "w", encoding="utf-8") as f:
        json.dump(relationships, f, indent=2)


def analyze_relationships(db_path="database.sqlite"):
    """Analyze foreign keys, junction tables and join paths of a SQLite database.

    Results are written to ``tool_output/relationships.json`` (machine-readable)
    and ``tool_output/relationship_summary.txt`` (human-readable), relative to
    the current working directory.

    Args:
        db_path: Path to the SQLite database file. ``":memory:"`` and ``""``
            are passed straight to ``sqlite3.connect``.

    Returns:
        The relationships dict that was written to JSON. Failures are not
        raised: the error is printed and stored under the ``"error"`` key,
        alongside whatever partial results were gathered.
    """
    relationships = {
        "foreign_keys": {},
        "junction_tables": [],
        "relationship_paths": {},
        "table_connections": defaultdict(list),
    }

    conn = None
    try:
        # sqlite3.connect silently creates an empty database for a missing path;
        # report the missing file instead of analyzing (and leaving behind) one.
        if db_path not in (":memory:", "") and not os.path.exists(db_path):
            raise FileNotFoundError(f"Database not found: {db_path}")
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in cursor.fetchall()]
        # Fetch each table's columns once; both junction and implicit-path
        # detection reuse them.
        columns_by_table = {table: _pragma_rows(cursor, "table_info", table) for table in tables}

        _collect_foreign_keys(cursor, tables, relationships)
        _find_junction_tables(columns_by_table, relationships)
        _find_implicit_paths(tables, columns_by_table, relationships)
        _add_junction_paths(relationships)

        _write_json(relationships)
        summary_path = os.path.join(_OUTPUT_DIR, "relationship_summary.txt")
        with open(summary_path, "w", encoding="utf-8") as f:
            f.writelines(_build_summary(tables, relationships))

        print("Relationship analysis complete - results in tool_output/")

    except Exception as e:  # Top-level boundary: report the failure in the JSON output.
        print(f"Error analyzing relationships: {e}")
        relationships["error"] = str(e)
        _write_json(relationships)

    finally:
        if conn is not None:
            conn.close()

    return relationships

if __name__ == "__main__":
    # Script entry point. The database path is optional, so running the script
    # with no arguments still analyzes ./database.sqlite as before.
    import argparse

    parser = argparse.ArgumentParser(
        description="Analyze foreign keys, junction tables and join paths of a SQLite database."
    )
    parser.add_argument(
        "db_path",
        nargs="?",
        default="database.sqlite",
        help="path to the SQLite database (default: database.sqlite)",
    )
    analyze_relationships(parser.parse_args().db_path)