#!/usr/bin/env python3
"""
Enhanced Relationship Mapper Tool
Maps foreign key relationships and generates join paths.
Enhanced with path ranking and ID-to-name column mapping.
"""

import sqlite3
import json
import os
import re

# Column names ending in "id" (any case) with at least one character before it,
# e.g. CountryID, Country_ID, country_id.
_ID_COLUMN = re.compile(r'.+_?[Ii][Dd]$')
# The trailing "id" / "_id" (any case) removed to guess the referenced table name.
_ID_SUFFIX = re.compile(r'_?[Ii][Dd]$')


def _quoted(name):
    """Return *name* as a backtick-quoted SQLite identifier (embedded backticks doubled)."""
    return "`" + name.replace("`", "``") + "`"


def _collect_columns(cursor, tables):
    """Return ``(all_columns, primary_keys)`` read from ``PRAGMA table_info``.

    ``all_columns`` maps each table to its column names in declaration order.
    ``primary_keys`` maps each table to its primary-key columns, ordered by their
    position within the key (empty list when the table declares no primary key).
    """
    all_columns = {}
    primary_keys = {}
    for table in tables:
        cursor.execute(f"PRAGMA table_info({_quoted(table)})")
        columns = cursor.fetchall()
        all_columns[table] = [col[1] for col in columns]
        # col[5] ("pk") is the 1-based position in the primary key, 0 if not part of it.
        primary_keys[table] = [
            col[1] for col in sorted(columns, key=lambda c: c[5]) if col[5] > 0
        ]
    return all_columns, primary_keys


def _declared_foreign_keys(cursor, tables, primary_keys):
    """Return the foreign keys declared in the schema via ``PRAGMA foreign_key_list``.

    NOTE: each column pair of a composite foreign key is emitted as its own entry.
    """
    canonical = {t.lower(): t for t in tables}
    found = []
    for table in tables:
        cursor.execute(f"PRAGMA foreign_key_list({_quoted(table)})")
        for fk in cursor.fetchall():
            # Row layout: (id, seq, table, from, to, on_update, on_delete, match).
            # SQLite table names are case-insensitive; use the name as created so
            # later lookups keyed by table name (all_columns, tables) succeed.
            to_table = canonical.get(fk[2].lower(), fk[2])
            to_column = fk[4]
            if to_column is None:
                # "REFERENCES parent" without a column list targets the parent's
                # primary key; seq is this column's position within the key.
                parent_pk = primary_keys.get(to_table, [])
                if fk[1] < len(parent_pk):
                    to_column = parent_pk[fk[1]]
            found.append({
                "from_table": table,
                "from_column": fk[3],
                "to_table": to_table,
                "to_column": to_column,
                "join_sql": f"JOIN {to_table} ON {table}.{fk[3]} = {to_table}.{to_column}",
            })
    return found


def _inferred_foreign_keys(tables, all_columns, declared):
    """Infer foreign keys from column naming (``XID``, ``X_ID``, ``XId``, ``x_id``).

    The column's name minus its id suffix must equal a table name, optionally
    differing by a trailing "s". The target table must have an id-like column.
    Pairs already present in *declared* are skipped.
    """
    inferred = []
    for table in tables:
        for column in all_columns[table]:
            if not _ID_COLUMN.match(column) or column.lower() == 'id':
                continue
            potential_table = _ID_SUFFIX.sub('', column).lower()
            if not potential_table:
                continue

            for target_table in tables:
                target_lower = target_table.lower()
                if not (target_lower == potential_table
                        or target_lower == potential_table + 's'
                        or target_lower + 's' == potential_table):
                    continue

                target_cols = all_columns[target_table]
                if not any(col.lower() in ('id', f'{target_lower}id') for col in target_cols):
                    continue

                already_declared = any(
                    fk['from_table'] == table
                    and fk['from_column'] == column
                    and fk['to_table'] == target_table
                    for fk in declared
                )
                if already_declared:
                    continue

                target_id = next((c for c in ('ID', 'Id', 'id') if c in target_cols), None)
                if target_id is None:
                    target_id = next(
                        (col for col in target_cols if col.lower() == f'{target_lower}id'),
                        None,
                    )
                if not target_id:
                    continue

                # A table's own key column (e.g. Country.CountryID) matches its own
                # table name; it is not a reference, and the self-join would be invalid SQL.
                if target_table == table and target_id == column:
                    continue

                inferred.append({
                    "from_table": table,
                    "from_column": column,
                    "to_table": target_table,
                    "to_column": target_id,
                    "inferred": True,
                    "join_sql": f"JOIN {target_table} ON {table}.{column} = {target_table}.{target_id}",
                })
    return inferred


def _id_columns_needing_joins(tables, all_columns, foreign_keys):
    """List ID-like columns whose referenced table has a human-readable name column.

    A column qualifies when its name ends in "id" (or is "capital", special-cased)
    and a foreign key starts from it. A column in the target table qualifies as a
    name column if it contains "name" or is "title" / "label"; the first one is used.
    """
    entries = []
    for table in tables:
        for column in all_columns[table]:
            lowered = column.lower()
            if not (lowered.endswith('id') or lowered == 'capital'):
                continue
            for fk in foreign_keys:
                if fk["from_table"] != table or fk["from_column"] != column:
                    continue
                target_table = fk["to_table"]
                if target_table not in all_columns:
                    continue
                name_cols = [
                    col for col in all_columns[target_table]
                    if 'name' in col.lower() or col.lower() in ('title', 'label')
                ]
                if name_cols:
                    entries.append({
                        "id_column": f"{table}.{column}",
                        "join_to": target_table,
                        "get_column": name_cols[0],
                        "example_sql": f"SELECT {target_table}.{name_cols[0]} FROM {table} JOIN {target_table} ON {table}.{column} = {target_table}.{fk['to_column']}",
                    })
    return entries


def _join_paths(foreign_keys):
    """Group single-hop join snippets by ``"<from>_to_<to>"``.

    Declared keys get ``confidence`` "high"; inferred keys get "medium".
    """
    join_patterns = {}
    for fk in foreign_keys:
        key = f"{fk['from_table']}_to_{fk['to_table']}"
        join_patterns.setdefault(key, []).append({
            "sql": f"FROM {fk['from_table']} JOIN {fk['to_table']} ON {fk['from_table']}.{fk['from_column']} = {fk['to_table']}.{fk['to_column']}",
            "via_column": fk['from_column'],
            "confidence": "medium" if fk.get('inferred') else "high",
        })
    return join_patterns


def _ambiguous_columns(all_columns):
    """Map every column name found in more than one table to those tables and a hint."""
    occurrences = {}
    for table, columns in all_columns.items():
        for column in columns:
            occurrences.setdefault(column, []).append(table)
    return {
        column: {
            "appears_in": tables_list,
            "recommendation": f"Always qualify with table name: {', '.join([f'{t}.{column}' for t in tables_list])}",
        }
        for column, tables_list in occurrences.items()
        if len(tables_list) > 1
    }


def _chained_joins(foreign_keys):
    """Build two-hop join patterns A -> B -> C from pairs of foreign keys.

    Pairs returning to the start table (A == C) are skipped. Pairs involving a
    self-referencing key are also skipped: they would join one table twice
    without aliases, which is invalid SQL.
    """
    chains = []
    for first in foreign_keys:
        if first["from_table"] == first["to_table"]:
            continue
        for second in foreign_keys:
            if second["from_table"] == second["to_table"]:
                continue
            if first["to_table"] != second["from_table"] or first["from_table"] == second["to_table"]:
                continue
            chains.append({
                "pattern": f"{first['from_table']} -> {first['to_table']} -> {second['to_table']}",
                "sql": (
                    f"FROM {first['from_table']}\n"
                    f"JOIN {first['to_table']} ON {first['from_table']}.{first['from_column']} = {first['to_table']}.{first['to_column']}\n"
                    f"JOIN {second['to_table']} ON {second['from_table']}.{second['from_column']} = {second['to_table']}.{second['to_column']}"
                ),
                "use_case": f"Joining {first['from_table']} to {second['to_table']} via {first['to_table']}",
            })
    return chains


def map_relationships(db_path="database.sqlite"):
    """Map table relationships and generate optimal join paths.

    Reads the schema of the SQLite database at *db_path*. Writes the result to
    ``tool_output/relationships.json``, creating the directory if needed.

    Keys in the output:
      - ``foreign_keys``: declared keys, plus keys inferred from ID-style column
        names (those carry ``"inferred": True``).
      - ``join_paths``, ``common_joins``: one- and two-hop join SQL snippets.
      - ``id_columns_needing_joins``: ID columns whose target table has a name column.
      - ``ambiguous_columns``: column names shared by several tables.
      - ``error`` (only on failure): the message of any exception raised while
        inspecting the schema. The exception is recorded rather than re-raised,
        and the file is still written.

    NOTE: sqlite3.connect creates an empty database if *db_path* does not exist.
    """
    os.makedirs("tool_output", exist_ok=True)

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    relationships = {
        "foreign_keys": [],
        "join_paths": {},
        "common_joins": [],
        "id_columns_needing_joins": [],  # ID columns that should be joined to get names
        "ambiguous_columns": {}
    }

    try:
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        # Skip SQLite's internal tables (sqlite_sequence, sqlite_stat1, ...).
        tables = [row[0] for row in cursor.fetchall() if not row[0].startswith("sqlite_")]

        all_columns, primary_keys = _collect_columns(cursor, tables)

        foreign_keys = relationships["foreign_keys"]
        foreign_keys.extend(_declared_foreign_keys(cursor, tables, primary_keys))
        foreign_keys.extend(_inferred_foreign_keys(tables, all_columns, foreign_keys))

        relationships["id_columns_needing_joins"] = _id_columns_needing_joins(
            tables, all_columns, foreign_keys
        )
        relationships["join_paths"] = _join_paths(foreign_keys)
        relationships["ambiguous_columns"] = _ambiguous_columns(all_columns)
        relationships["common_joins"] = _chained_joins(foreign_keys)

    except Exception as e:
        # Best effort: record the failure in the output instead of crashing.
        relationships["error"] = str(e)

    finally:
        conn.close()

    # Write output
    with open("tool_output/relationships.json", "w", encoding="utf-8") as f:
        json.dump(relationships, f, indent=2)

    print(f"Relationship mapping complete: {len(relationships['foreign_keys'])} foreign keys found")
    print(f"Found {len(relationships['ambiguous_columns'])} ambiguous columns")
    print(f"Identified {len(relationships['id_columns_needing_joins'])} ID columns needing name joins")
    print("Results saved to tool_output/relationships.json")

if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the database path. With no argument,
    # the slice is empty and map_relationships falls back to its default path.
    map_relationships(*sys.argv[1:2])