#!/usr/bin/env python3
"""
Graph converter for rel2onto problem.
Created using subagent_prompt.md version: v_02

This problem is about matching relational database attributes to ontology nodes 
while building a minimum-cost tree that connects the matched nodes.
Key challenges: 
- Bipartite matching with costs (attributes to domain nodes)
- Tree connectivity constraints  
- Combined optimization of matching and tree costs
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _attribute_weight(match_costs, attr_index, nbV, max_possible=10000):
    """Return the normalised mean positive matching cost of one attribute.

    The attribute's costs are read as the slice
    ``match_costs[attr_index * nbV:(attr_index + 1) * nbV]``; only strictly
    positive entries contribute to the mean.  ``max_possible`` is the
    normalisation ceiling (the default was chosen from data inspection).
    Returns 0.5 when there is no cost data or the slice would run past the
    end of ``match_costs``.
    """
    row_end = (attr_index + 1) * nbV
    if not match_costs or row_end > len(match_costs):
        return 0.5
    positive = [c for c in match_costs[attr_index * nbV:row_end] if c > 0]
    mean_cost = sum(positive) / max(len(positive), 1)
    return min(mean_cost / max_possible, 1.0)


def _match_edge_weight(match_costs, max_match_cost, attr_index, node_id, nbV):
    """Return the weight of the edge between a dnode and a matching constraint.

    Lower matching cost means a stronger connection: positive costs map to
    ``max(1 - cost / max_match_cost, 0.1)`` and non-positive costs to 0.9.
    Returns 0.5 when the cost cannot be looked up.
    """
    if not match_costs:
        return 0.5
    # NOTE(review): assumes match_costs is a flat, row-major nbA x nbV list
    # indexed by 1-based node ids -- confirm against the DZN -> JSON export.
    cost_idx = attr_index * nbV + (node_id - 1)
    # Reject negative indices explicitly; Python would otherwise wrap around
    # and silently read a cost from the end of the list.
    if not 0 <= cost_idx < len(match_costs):
        return 0.5
    cost = match_costs[cost_idx]
    if cost > 0:
        return max(1.0 - cost / max_match_cost, 0.1)
    return 0.9  # Zero cost means perfect match


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the rel2onto instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data. Keys read: nbA, nbV, nbE,
            tails, heads, ws, cnodes, dnodes, anodes, attribute_names,
            match_costs. Missing keys default to 0 / empty list.

    Returns:
        An undirected networkx Graph whose nodes carry ``type`` and ``weight``
        attributes and whose edges carry a ``weight`` attribute.

    Strategy: Create bipartite graph with explicit constraint nodes
    - Variable nodes (type 0): ``attr_<name>`` per attribute and
      ``tree_node_<id>`` per node in cnodes/dnodes/anodes
    - Constraint nodes (type 1): ``match_constraint_<name>`` per attribute,
      ``tree_connectivity``, ``tree_root`` and ``edge_constraint_<tail>_<head>``
      per problem edge
    - Edge weights: reflect matching costs and tree edge costs
    - Model both the assignment problem and tree connectivity

    The tails/heads/ws arrays are consumed in lockstep and truncated to the
    shortest of the three.
    """

    # Extract data
    nbA = json_data.get('nbA', 0)
    nbV = json_data.get('nbV', 0)
    nbE = json_data.get('nbE', 0)

    tails = json_data.get('tails', [])
    heads = json_data.get('heads', [])
    ws = json_data.get('ws', [])

    # Get node sets
    cnodes = json_data.get('cnodes', [])
    dnodes = json_data.get('dnodes', [])
    anodes = json_data.get('anodes', [])

    # Get matching data
    attribute_names = json_data.get('attribute_names', [])
    match_costs = json_data.get('match_costs', [])

    # Loop-invariant values, computed once instead of once per edge/attribute.
    attr_labels = [
        attribute_names[a] if a < len(attribute_names) else a + 1
        for a in range(nbA)
    ]
    edges = list(zip(tails, heads, ws))
    max_edge_cost = max(ws) if ws else 1
    max_match_cost = max(match_costs) if match_costs else 1
    cnode_set = set(cnodes)
    dnode_set = set(dnodes)

    # Create graph
    G = nx.Graph()

    # Add attribute nodes (type 0) - decision variables for matching.
    # Weight by average matching cost (normalized) - higher cost = more important
    for a, label in enumerate(attr_labels):
        G.add_node(f'attr_{label}', type=0,
                   weight=_attribute_weight(match_costs, a, nbV))

    # Add tree nodes (type 0) - variables for tree structure
    all_tree_nodes = set(cnodes + dnodes + anodes)

    # Number of problem edges touching each node (a self-loop counts once).
    incident = {}
    for tail, head in zip(tails, heads):
        incident[tail] = incident.get(tail, 0) + 1
        if head != tail:
            incident[head] = incident.get(head, 0) + 1
    # Guard nbV == 0 so degenerate instances do not divide by zero.
    avg_degree = max(nbE / nbV, 1) if nbV else 1

    for node_id in all_tree_nodes:
        # Weight by connectivity potential and node type
        if node_id in cnode_set:
            # Connection nodes - weight by potential edges
            weight = min(incident.get(node_id, 0) / avg_degree, 1.0)
        elif node_id in dnode_set:
            # Decision nodes that can be matched - higher importance
            weight = 0.8
        else:  # anodes
            weight = 0.6

        G.add_node(f'tree_node_{node_id}', type=0, weight=weight)

    # Add matching constraint nodes (type 1) - one per attribute.
    # A fixed medium tightness is used for every matching constraint.
    for label in attr_labels:
        G.add_node(f'match_constraint_{label}', type=1, weight=0.7)

    # Add tree constraint nodes (type 1) - model tree connectivity.
    # Also created when tree nodes exist despite nbV == 0, so the edges added
    # further down never create these nodes implicitly without attributes.
    if nbV > 0 or all_tree_nodes:
        G.add_node('tree_connectivity', type=1, weight=0.9)
        G.add_node('tree_root', type=1, weight=0.8)

    # Edge constraints: weight by edge cost (normalized) - higher cost = more
    # constraining. When every cost is zero no edge is constraining, so use
    # 0.0 instead of dividing by zero.
    for tail, head, cost in edges:
        tightness = min(cost / max_edge_cost, 1.0) if max_edge_cost else 0.0
        G.add_node(f'edge_constraint_{tail}_{head}', type=1, weight=tightness)

    # Add bipartite edges: attributes to matching constraints
    for label in attr_labels:
        G.add_edge(f'attr_{label}', f'match_constraint_{label}', weight=0.8)

    # Add bipartite edges: tree nodes to tree constraints
    for node_id in all_tree_nodes:
        tree_node = f'tree_node_{node_id}'
        G.add_edge(tree_node, 'tree_connectivity', weight=0.6)
        # Prefer cnodes as root candidates
        root_weight = 0.8 if node_id in cnode_set else 0.4
        G.add_edge(tree_node, 'tree_root', weight=root_weight)

    # Connect tree nodes to their edge constraints.
    # NOTE(review): an endpoint that is not listed in cnodes/dnodes/anodes is
    # created implicitly by add_edge without type/weight -- confirm that the
    # three sets always cover every tail/head id.
    for tail, head, _cost in edges:
        constraint_name = f'edge_constraint_{tail}_{head}'
        G.add_edge(f'tree_node_{tail}', constraint_name, weight=0.7)
        G.add_edge(f'tree_node_{head}', constraint_name, weight=0.7)

    # Add matching participation edges: dnodes to matching constraints.
    # Each dnode can potentially be matched to any attribute.
    for node_id in dnodes:
        for a, label in enumerate(attr_labels):
            weight = _match_edge_weight(
                match_costs, max_match_cost, a, node_id, nbV)
            G.add_edge(f'tree_node_{node_id}', f'match_constraint_{label}',
                       weight=weight)

    # Add structural edges between tree nodes based on problem graph,
    # using exponential decay so cheaper edges get weights closer to 1.
    for tail, head, cost in edges:
        if cost > 0:
            weight = math.exp(-3.0 * cost / max_edge_cost)
        else:
            weight = 1.0
        G.add_edge(f'tree_node_{tail}', f'tree_node_{head}', weight=weight)

    return G


def main():
    """Command-line entry point.

    Expects exactly three positional arguments: the .mzn file, the .dzn file
    and the JSON file holding the parsed instance data. Prints a usage line
    and exits with status 1 otherwise. Only the JSON file is opened; the
    .mzn path is forwarded to build_graph and the .dzn path is not used.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    mzn_file, _dzn_file, json_file = cli_args

    # Read the parsed instance data
    with open(json_file, 'r') as handle:
        instance = json.load(handle)

    # Convert the instance and report the size of the resulting graph
    graph = build_graph(mzn_file, instance)
    print(f"Graph built: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")


# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()