#!/usr/bin/env python3
"""
Relationship Mapper - Enhanced join path analysis with cardinality detection
Focuses on finding and validating all possible join paths
"""

import sqlite3
import json
import os
from typing import Dict, List, Set, Tuple
from collections import defaultdict, deque

def connect_db(db_path: str) -> sqlite3.Connection:
    """Open a SQLite connection to the database at ``db_path`` and return it."""
    connection = sqlite3.connect(db_path)
    return connection

def build_relationship_graph(conn: sqlite3.Connection) -> Tuple[Dict[str, List[Dict]], List[Dict]]:
    """Build the relationship graph for every table listed in sqlite_master.

    Two passes are made:

    1. Explicit relationships are read from ``PRAGMA foreign_key_list``.
    2. Implicit relationships are inferred for columns that share a name
       across tables (names ending in ``_id`` / ``ID`` or equal to ``id``)
       and pass ``is_valid_implicit_relationship``.

    Args:
        conn: Open SQLite connection to inspect.

    Returns:
        A ``(relationships, all_relationships)`` tuple.  ``relationships``
        maps a table name to the relationship dicts that start at that table,
        including the generated ``*_reverse`` entries used for bidirectional
        traversal.  ``all_relationships`` is the flat list of forward
        relationships only (explicit foreign keys and inferred ones).
    """
    def quote(identifier: str) -> str:
        # Double-quote identifiers so names with spaces or keywords stay valid.
        return '"' + identifier.replace('"', '""') + '"'

    cursor = conn.cursor()

    # Get all tables
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]

    relationships = defaultdict(list)
    all_relationships = []
    explicit_relationships = []

    # First pass: explicit foreign keys.
    # foreign_key_list rows are (id, seq, table, from, to, ...).
    for table in tables:
        cursor.execute(f"PRAGMA foreign_key_list({quote(table)})")
        for fk in cursor.fetchall():
            fk_seq, to_table, from_column, to_column = fk[1], fk[2], fk[3], fk[4]

            if to_column is None:
                # "REFERENCES parent" without a column list targets the
                # parent's primary key; table_info rows carry the 1-based
                # position within the primary key at index 5.
                cursor.execute(f"PRAGMA table_info({quote(to_table)})")
                pk_columns = {col[5]: col[1] for col in cursor.fetchall() if col[5] > 0}
                to_column = pk_columns.get(fk_seq + 1)

            relationship = {
                'from_table': table,
                'from_column': from_column,
                'to_table': to_table,
                'to_column': to_column,
                'type': 'foreign_key',
                'confidence': 'EXPLICIT'
            }

            # Analyze cardinality
            cardinality = analyze_relationship_cardinality(
                conn, table, from_column, to_table, to_column
            )
            relationship.update(cardinality)

            relationships[table].append(relationship)
            all_relationships.append(relationship)
            explicit_relationships.append(relationship)

            # Add reverse relationship for bidirectional traversal
            relationships[to_table].append({
                'from_table': to_table,
                'from_column': to_column,
                'to_table': table,
                'to_column': from_column,
                'type': 'foreign_key_reverse',
                'confidence': 'EXPLICIT',
                'cardinality': reverse_cardinality(cardinality['cardinality'])
            })

    # Second pass: collect which tables share each column name.
    column_map = {}
    for table in tables:
        cursor.execute(f"PRAGMA table_info({quote(table)})")
        for row in cursor.fetchall():
            column_map.setdefault(row[1], []).append(table)

    # Check for implicit relationships on id-like columns shared by 2+ tables.
    for col_name, table_list in column_map.items():
        id_like = col_name.endswith(('_id', 'ID')) or col_name == 'id'
        if len(table_list) < 2 or not id_like:
            continue

        for i, t1 in enumerate(table_list):
            for t2 in table_list[i + 1:]:
                # Skip pairs already linked by an explicit foreign key on this
                # column, whichever table holds the key.
                pair = {t1, t2}
                existing = any(
                    {r['from_table'], r['to_table']} == pair
                    and col_name in (r['from_column'], r['to_column'])
                    for r in explicit_relationships
                )
                if existing:
                    continue

                # Verify this is a plausible relationship by checking data
                if not is_valid_implicit_relationship(conn, t1, t2, col_name):
                    continue

                relationship = {
                    'from_table': t1,
                    'from_column': col_name,
                    'to_table': t2,
                    'to_column': col_name,
                    'type': 'implicit',
                    'confidence': 'INFERRED'
                }

                cardinality = analyze_relationship_cardinality(
                    conn, t1, col_name, t2, col_name
                )
                relationship.update(cardinality)

                relationships[t1].append(relationship)
                # Record the inferred relationship in the flat list too, so
                # callers filtering on confidence == 'INFERRED' can see it.
                all_relationships.append(relationship)
                relationships[t2].append({
                    'from_table': t2,
                    'from_column': col_name,
                    'to_table': t1,
                    'to_column': col_name,
                    'type': 'implicit_reverse',
                    'confidence': 'INFERRED',
                    'cardinality': reverse_cardinality(cardinality['cardinality'])
                })

    return dict(relationships), all_relationships

def reverse_cardinality(cardinality: str) -> str:
    """Return the cardinality as seen from the other end of the relationship.

    'one-to-many' and 'many-to-one' swap; 'one-to-one' and 'many-to-many'
    are returned unchanged; any other value yields 'unknown'.
    """
    flipped = {
        'one-to-many': 'many-to-one',
        'many-to-one': 'one-to-many',
    }
    # Symmetric cardinalities read the same from either side.
    for symmetric in ('one-to-one', 'many-to-many'):
        flipped[symmetric] = symmetric

    try:
        return flipped[cardinality]
    except KeyError:
        return 'unknown'

def analyze_relationship_cardinality(conn: sqlite3.Connection, t1: str, c1: str, t2: str, c2: str) -> Dict:
    """Classify the cardinality of the relationship ``t1.c1 -> t2.c2``.

    A side counts as unique when its number of distinct column values equals
    its row count.  When neither side is unique, the relationship is reported
    as many-to-many only if at least one non-NULL key value is repeated on
    both sides; otherwise it falls back to one-to-many.

    Args:
        conn: Open SQLite connection.
        t1: Source table name.
        c1: Source column name.
        t2: Target table name.
        c2: Target column name.

    Returns:
        On success, a dict with ``cardinality``, ``from_unique``,
        ``to_unique``, ``from_distinct`` and ``to_distinct``.  If a name is
        missing or SQLite rejects a query, ``{'cardinality': 'unknown',
        'error': <message>}``.
    """
    if None in (t1, c1, t2, c2):
        # e.g. a foreign key whose target column could not be determined
        return {
            'cardinality': 'unknown',
            'error': 'table or column name is missing'
        }

    def quote(identifier: str) -> str:
        # Double-quote identifiers so names with spaces or keywords stay valid.
        return '"' + str(identifier).replace('"', '""') + '"'

    src_table, src_col = quote(t1), quote(c1)
    dst_table, dst_col = quote(t2), quote(c2)
    cursor = conn.cursor()

    try:
        # Columns are always alias-qualified: it disambiguates identical
        # column names on both sides and self-referential tables.
        cursor.execute(
            f"SELECT COUNT(DISTINCT src.{src_col}), COUNT(*) FROM {src_table} AS src"
        )
        t1_distinct, t1_total = cursor.fetchone()

        cursor.execute(
            f"SELECT COUNT(DISTINCT dst.{dst_col}), COUNT(*) FROM {dst_table} AS dst"
        )
        t2_distinct, t2_total = cursor.fetchone()

        # Check for uniqueness
        t1_unique = t1_distinct == t1_total
        t2_unique = t2_distinct == t2_total

        # Determine cardinality
        if t1_unique and t2_unique:
            cardinality = "one-to-one"
        elif t1_unique:
            cardinality = "one-to-many"
        elif t2_unique:
            cardinality = "many-to-one"
        else:
            # Truly many-to-many only if some key value occurs more than once
            # on BOTH sides; counting distinct target keys per source key
            # would always give 1 because the join forces them to be equal.
            cursor.execute(f"""
                SELECT EXISTS (
                    SELECT 1
                    FROM (
                        SELECT src.{src_col} AS join_key
                        FROM {src_table} AS src
                        WHERE src.{src_col} IS NOT NULL
                        GROUP BY src.{src_col}
                        HAVING COUNT(*) > 1
                    ) AS src_dups
                    JOIN (
                        SELECT dst.{dst_col} AS join_key
                        FROM {dst_table} AS dst
                        WHERE dst.{dst_col} IS NOT NULL
                        GROUP BY dst.{dst_col}
                        HAVING COUNT(*) > 1
                    ) AS dst_dups
                    ON src_dups.join_key = dst_dups.join_key
                )
            """)
            multi_mapping = bool(cursor.fetchone()[0])
            cardinality = "many-to-many" if multi_mapping else "one-to-many"

        return {
            'cardinality': cardinality,
            'from_unique': t1_unique,
            'to_unique': t2_unique,
            'from_distinct': t1_distinct,
            'to_distinct': t2_distinct
        }
    except sqlite3.Error as e:
        # Best effort: a missing table/column must not abort the whole scan.
        return {
            'cardinality': 'unknown',
            'error': str(e)
        }

def is_valid_implicit_relationship(conn: sqlite3.Connection, t1: str, t2: str, col: str) -> bool:
    """Decide from the data whether ``t1.col`` plausibly references ``t2.col``.

    The relationship is accepted when strictly more than 50% of the distinct
    non-NULL values of ``t1.col`` also occur in ``t2.col``.

    Args:
        conn: Open SQLite connection.
        t1: Table whose values are looked up.
        t2: Table the values are looked up in.
        col: Column name shared by both tables.

    Returns:
        True if the overlap exceeds 50%; False otherwise, including when
        SQLite rejects the query (e.g. unknown table or column).
    """
    def quote(identifier: str) -> str:
        # Double-quote identifiers so names with spaces or keywords stay valid.
        return '"' + str(identifier).replace('"', '""') + '"'

    left, right, column = quote(t1), quote(t2), quote(col)
    cursor = conn.cursor()

    try:
        # Distinct non-NULL values of t1.col that also exist in t2.col
        cursor.execute(f"""
            SELECT COUNT(*) FROM (
                SELECT DISTINCT a.{column} FROM {left} AS a
                WHERE a.{column} IS NOT NULL
                AND a.{column} IN (SELECT b.{column} FROM {right} AS b)
            )
        """)
        matching = cursor.fetchone()[0]

        cursor.execute(
            f"SELECT COUNT(DISTINCT a.{column}) FROM {left} AS a WHERE a.{column} IS NOT NULL"
        )
        total = cursor.fetchone()[0]
    except sqlite3.Error:
        # Only database errors mean "not a relationship"; interrupts and
        # programming errors are allowed to propagate.
        return False

    # If more than 50% of values match, consider it a valid relationship
    return matching > 0 and (matching / max(total, 1)) > 0.5

def find_all_join_paths(relationships: Dict, start: str, end: str, max_length: int = 4) -> List[List[Dict]]:
    """Find join paths between two tables with a breadth-first search.

    Args:
        relationships: Mapping of table name to the relationship dicts that
            start at that table (each with 'from_table', 'to_table' and
            'confidence').
        start: Table the path starts from.
        end: Table the path must reach.
        max_length: Maximum number of hops in a returned path.

    Returns:
        A list of paths, each a list of relationship dicts, sorted by length
        and then by number of EXPLICIT hops (more first).  ``[[]]`` when
        ``start == end``; ``[]`` when no path exists.
    """
    if start == end:
        return [[]]

    paths = []
    queue = deque([(start, [])])
    expanded = set()

    while queue:
        current_table, current_path = queue.popleft()

        if len(current_path) >= max_length:
            continue

        # Expand each sequence of table hops only once.
        path_key = tuple(r['from_table'] + '->' + r['to_table'] for r in current_path)
        if path_key in expanded:
            continue
        expanded.add(path_key)

        # Every table already on this path, INCLUDING the current one, so a
        # self-referential relationship is never taken as a hop.
        on_path = {start}
        on_path.update(r['to_table'] for r in current_path)

        for relationship in relationships.get(current_table, []):
            next_table = relationship['to_table']

            # Avoid cycles
            if next_table in on_path:
                continue

            new_path = current_path + [relationship]

            if next_table == end:
                paths.append(new_path)
            else:
                queue.append((next_table, new_path))

    # Sort paths by length and confidence
    paths.sort(key=lambda p: (len(p), -sum(1 for r in p if r['confidence'] == 'EXPLICIT')))

    return paths

def generate_join_templates(paths: List[List[Dict]]) -> List[Dict]:
    """Generate a SQL join template for each path.

    Each path is treated as a chain: the first hop's ``from_table`` becomes
    ``t1`` and hop ``k`` joins its ``to_table`` as ``t{k+1}`` against the
    alias introduced by the previous hop.  Aliases are positional, so every
    hop is emitted with a distinct alias even when a path visits the same
    table twice (e.g. a self-referential relationship).

    Args:
        paths: Paths as lists of relationship dicts with 'from_table',
            'from_column', 'to_table', 'to_column' and 'confidence'
            (optionally 'cardinality').

    Returns:
        One template dict per path.  An empty path yields a placeholder with
        only 'path_length', 'template' and 'confidence'.
    """
    templates = []

    for path in paths:
        if not path:
            templates.append({
                'path_length': 0,
                'template': '-- Tables are the same, no join needed',
                'confidence': 'HIGH'
            })
            continue

        joins = [f"FROM {path[0]['from_table']} t1"]
        confidence_score = 0

        for hop, rel in enumerate(path, start=1):
            prev_alias = f"t{hop}"
            new_alias = f"t{hop + 1}"
            joins.append(
                f"JOIN {rel['to_table']} {new_alias} "
                f"ON {prev_alias}.{rel['from_column']} = {new_alias}.{rel['to_column']}"
            )

            # Explicit hops count fully, anything else counts half.
            confidence_score += 1 if rel['confidence'] == 'EXPLICIT' else 0.5

        # With the scoring above the score is never below len(path) * 0.5,
        # so 'LOW' is not reachable; it is kept as a defensive fallback.
        if confidence_score == len(path):
            confidence = 'HIGH'
        elif confidence_score >= len(path) * 0.5:
            confidence = 'MEDIUM'
        else:
            confidence = 'LOW'

        templates.append({
            'path_length': len(path),
            'template': '\n'.join(joins),
            'relationships': [f"{r['from_table']}.{r['from_column']} -> {r['to_table']}.{r['to_column']}" for r in path],
            'cardinalities': [r.get('cardinality', 'unknown') for r in path],
            'confidence': confidence,
            'has_many_to_many': any(r.get('cardinality') == 'many-to-many' for r in path)
        })

    return templates

def identify_junction_tables(conn: sqlite3.Connection, relationships: List[Dict]) -> List[Dict]:
    """Identify tables that look like junction (link) tables.

    A table qualifies when it is the source of at least two relationships of
    type 'foreign_key' and has at most two columns beyond those foreign keys.

    Args:
        conn: Open SQLite connection, used to list tables and their columns.
        relationships: Flat list of relationship dicts with 'from_table',
            'from_column', 'to_table', 'to_column' and 'type'.

    Returns:
        One dict per junction table with 'table', 'connects' (referenced
        tables, in first-seen order) and 'foreign_keys' (readable
        ``column -> table.column`` strings).
    """
    cursor = conn.cursor()

    # Get all tables
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]

    # Group explicit foreign keys by source table once instead of rescanning
    # the whole relationship list for every table.
    fks_by_table = {}
    for rel in relationships:
        if rel.get('type') == 'foreign_key':
            fks_by_table.setdefault(rel['from_table'], []).append(rel)

    junction_tables = []

    for table in tables:
        table_fks = fks_by_table.get(table, [])
        if len(table_fks) < 2:
            continue

        # Quote the name so tables with spaces or keyword names are accepted.
        quoted_table = '"' + table.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_table})")
        columns = cursor.fetchall()

        # Junction tables typically have 2-3 foreign keys and not many other columns
        if len(columns) <= len(table_fks) + 2:
            # dict.fromkeys de-duplicates while keeping a deterministic order
            connected_tables = list(dict.fromkeys(fk['to_table'] for fk in table_fks))
            junction_tables.append({
                'table': table,
                'connects': connected_tables,
                'foreign_keys': [f"{fk['from_column']} -> {fk['to_table']}.{fk['to_column']}" for fk in table_fks]
            })

    return junction_tables

def main():
    """Map the relationships of ./database.sqlite and write a JSON report.

    Builds the relationship graph, identifies junction tables, computes up to
    three join-path templates for every ordered pair of tables (t1 < t2) and
    writes the result to ./tool_output/relationship_map.json, then prints a
    short summary.
    """
    db_path = "./database.sqlite"
    output_dir = "./tool_output"
    os.makedirs(output_dir, exist_ok=True)

    conn = connect_db(db_path)
    try:
        # Build relationship graph
        relationships, all_relationships = build_relationship_graph(conn)

        # Identify junction tables
        junction_tables = identify_junction_tables(conn, all_relationships)

        # Get all tables for comprehensive path finding
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in cursor.fetchall()]
    finally:
        # Everything below works on in-memory data; release the connection
        # even if one of the analysis steps raised.
        conn.close()

    # Find important join paths
    important_paths = {}
    for t1 in tables:
        for t2 in tables:
            if not t1 < t2:  # Each unordered pair is handled once
                continue

            paths = find_all_join_paths(relationships, t1, t2, max_length=3)
            if not (paths and paths[0]):  # Need at least one non-empty path
                continue

            templates = generate_join_templates(paths[:3])  # Top 3 paths
            if templates:
                important_paths[f"{t1}_to_{t2}"] = {
                    'from': t1,
                    'to': t2,
                    'best_path': templates[0],
                    'alternative_paths': templates[1:]
                }

    def count_cardinality(label: str) -> int:
        # Number of forward relationships classified with the given cardinality
        return sum(1 for r in all_relationships if r.get('cardinality') == label)

    # Create comprehensive report
    relationship_report = {
        'total_relationships': len(all_relationships),
        'junction_tables': junction_tables,
        'explicit_relationships': [r for r in all_relationships if r['confidence'] == 'EXPLICIT'],
        'inferred_relationships': [r for r in all_relationships if r['confidence'] == 'INFERRED'],
        'join_paths': important_paths,
        'cardinality_summary': {
            'one_to_one': count_cardinality('one-to-one'),
            'one_to_many': count_cardinality('one-to-many'),
            'many_to_one': count_cardinality('many-to-one'),
            'many_to_many': count_cardinality('many-to-many')
        }
    }

    # Write output
    output_path = os.path.join(output_dir, "relationship_map.json")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(relationship_report, f, indent=2)

    print("Relationship mapping complete")
    print(f"Found {len(all_relationships)} relationships")
    print(f"Identified {len(junction_tables)} junction tables")
    print(f"Generated {len(important_paths)} join path templates")

# Run the mapper only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()