#!/usr/bin/env python3
"""
Graph converter for amaze3_2 problem.
# Converter created with subagent_prompt.md v_02

This problem is about connecting pairs of endpoints on a grid with non-crossing paths.
Each pair must be connected by a path, endpoints have exactly one neighbor, 
interior points have exactly two neighbors. Paths cannot cross.

Key challenges: 
- Path routing and avoiding crossings
- Grid connectivity constraints
- Balancing path lengths and conflicts
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the amaze problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference; not read by this function)
        json_data: Dict containing parsed DZN data. Keys read: 'X', 'Y', 'N',
            'end_points_start_x', 'end_points_start_y',
            'end_points_end_x', 'end_points_end_y'. Missing keys default to
            0 / empty list.

    Returns:
        networkx.Graph whose nodes carry `type` and `weight` attributes and
        whose edges carry a `weight` attribute.

    Strategy: Create a typed graph modeling grid cells, pairs, and constraints
    - Type 0 nodes: Grid cells (decision variables), named cell_{x}_{y}
    - Type 1 nodes: Constraints (endpoint, connectivity, non-crossing)
    - Type 2 nodes: Pairs (resources that need paths), named pair_{i}
    - Edges model participation in constraints and resource usage; pairs whose
      bounding boxes overlap are additionally joined by a conflict edge, so
      the graph is not strictly bipartite.

    Endpoint arrays of unequal length are tolerated: a start (or end) point is
    only used when both of its coordinates are present, and a pair only exists
    when both its start and its end point are present.
    """
    neighbor_offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))

    # Extract problem parameters
    X = json_data.get('X', 0)
    Y = json_data.get('Y', 0)
    N = json_data.get('N', 0)

    start_x = json_data.get('end_points_start_x', [])
    start_y = json_data.get('end_points_start_y', [])
    end_x = json_data.get('end_points_end_x', [])
    end_y = json_data.get('end_points_end_y', [])

    # zip() stops at the shorter list, so ragged x/y arrays cannot cause an
    # IndexError further down.
    starts = list(zip(start_x, start_y))
    ends = list(zip(end_x, end_y))
    n_start = min(N, len(starts))
    n_end = min(N, len(ends))
    n_pairs = min(n_start, n_end)

    # Set lookup replaces a per-cell scan over every pair.
    endpoint_cells = set(starts) | set(ends)

    # Bounding box (min_x, max_x, min_y, max_y) of each complete pair; reused
    # by the crossing, resource-usage and conflict passes below.
    boxes = []
    for i in range(n_pairs):
        (sx, sy), (ex, ey) = starts[i], ends[i]
        boxes.append((min(sx, ex), max(sx, ex), min(sy, ey), max(sy, ey)))

    G = nx.Graph()

    # Type 0 nodes: Grid cells with position-based weights
    for x in range(1, X + 1):
        for y in range(1, Y + 1):
            # Central cells are more constrained (higher weight)
            centrality = 1.0 - (abs(x - X // 2) + abs(y - Y // 2)) / (X + Y)
            # Endpoints are more critical
            bonus = 0.3 if (x, y) in endpoint_cells else 0.0
            G.add_node(f'cell_{x}_{y}', type=0, weight=min(centrality + bonus, 1.0))

    # Type 2 nodes: Pairs (resources)
    max_grid_dist = X + Y - 2
    for i in range(n_pairs):
        (sx, sy), (ex, ey) = starts[i], ends[i]
        # Manhattan distance between endpoints
        dist = abs(ex - sx) + abs(ey - sy)
        # Longer paths are more difficult (exponential scaling)
        if max_grid_dist > 0:
            difficulty = 1.0 - math.exp(-3.0 * dist / max_grid_dist)
        else:
            difficulty = 0.5
        G.add_node(f'pair_{i+1}', type=2, weight=difficulty)

    # Type 1 nodes: Constraints

    # 1. Endpoint constraints (high weight - critical)
    for i in range(max(n_start, n_end)):
        if i < n_start:
            G.add_node(f'endpoint_start_{i+1}', type=1, weight=0.9)
        if i < n_end:
            G.add_node(f'endpoint_end_{i+1}', type=1, weight=0.9)

    # 2. Cell connectivity constraints (medium weight)
    for x in range(1, X + 1):
        for y in range(1, Y + 1):
            neighbor_count = sum(
                1 for dx, dy in neighbor_offsets
                if 1 <= x + dx <= X and 1 <= y + dy <= Y
            )
            # Cells with fewer neighbors are more constrained
            constraint_weight = 1.0 - (neighbor_count / 4.0)
            G.add_node(f'connectivity_{x}_{y}', type=1, weight=constraint_weight)

    # 3. Non-crossing constraints (variable weight based on crossing potential)
    # NOTE(review): these nodes are never given any edges, so they stay
    # isolated in the returned graph - confirm whether that is intended.
    crossing_id = 0
    for x in range(1, X + 1):
        for y in range(1, Y + 1):
            # Number of pairs whose bounding box contains this cell
            crossing_potential = sum(
                1 for lo_x, hi_x, lo_y, hi_y in boxes
                if lo_x <= x <= hi_x and lo_y <= y <= hi_y
            )
            # More than one covering pair implies N >= 2, so N is non-zero here.
            if crossing_potential > 1:
                weight = min(crossing_potential / N, 1.0)
                G.add_node(f'crossing_{crossing_id}', type=1, weight=weight)
                crossing_id += 1

    # Edges: Model constraint participation and resource usage

    # 1. Endpoint constraint edges
    for i in range(max(n_start, n_end)):
        pair_node = f'pair_{i+1}'
        # add_edge() would silently create an attribute-less pair node, so the
        # pair edge is only added when the pair node really exists.
        has_pair = i < n_pairs

        if i < n_start:
            start_cell = f'cell_{starts[i][0]}_{starts[i][1]}'
            endpoint_constraint = f'endpoint_start_{i+1}'
            if start_cell in G.nodes:
                G.add_edge(start_cell, endpoint_constraint, weight=1.0)
                if has_pair:
                    G.add_edge(pair_node, endpoint_constraint, weight=0.8)

        if i < n_end:
            end_cell = f'cell_{ends[i][0]}_{ends[i][1]}'
            endpoint_constraint = f'endpoint_end_{i+1}'
            if end_cell in G.nodes:
                G.add_edge(end_cell, endpoint_constraint, weight=1.0)
                if has_pair:
                    G.add_edge(pair_node, endpoint_constraint, weight=0.8)

    # 2. Connectivity constraint edges (cells to their connectivity constraints)
    for x in range(1, X + 1):
        for y in range(1, Y + 1):
            connectivity = f'connectivity_{x}_{y}'

            # Connect cell to its own connectivity constraint
            G.add_edge(f'cell_{x}_{y}', connectivity, weight=0.7)

            # Connect in-grid neighboring cells to this connectivity constraint
            for dx, dy in neighbor_offsets:
                nx_pos, ny_pos = x + dx, y + dy
                if 1 <= nx_pos <= X and 1 <= ny_pos <= Y:
                    G.add_edge(f'cell_{nx_pos}_{ny_pos}', connectivity, weight=0.4)

    # 3. Resource usage edges (pairs to cells they might use)
    for i, (lo_x, hi_x, lo_y, hi_y) in enumerate(boxes):
        pair_node = f'pair_{i+1}'
        (sx, sy), (ex, ey) = starts[i], ends[i]
        span = abs(ex - sx) + abs(ey - sy)

        # Connect pair to cells in its bounding box
        for x in range(lo_x, hi_x + 1):
            for y in range(lo_y, hi_y + 1):
                cell = f'cell_{x}_{y}'
                # Endpoints outside the grid yield box cells that do not exist.
                if cell not in G.nodes:
                    continue
                # Weight based on distance from nearest endpoint (exponential decay)
                nearest = min(abs(x - sx) + abs(y - sy), abs(x - ex) + abs(y - ey))
                weight = math.exp(-2.0 * nearest / span) if span > 0 else 1.0
                G.add_edge(pair_node, cell, weight=weight)

    # 4. Conflict edges between pairs that might compete for the same cells
    for i in range(n_pairs):
        min_x1, max_x1, min_y1, max_y1 = boxes[i]
        area1 = (max_x1 - min_x1 + 1) * (max_y1 - min_y1 + 1)
        for j in range(i + 1, n_pairs):
            min_x2, max_x2, min_y2, max_y2 = boxes[j]

            # Size of the intersection of the two bounding boxes
            overlap_x = max(0, min(max_x1, max_x2) - max(min_x1, min_x2) + 1)
            overlap_y = max(0, min(max_y1, max_y2) - max(min_y1, min_y2) + 1)
            overlap_area = overlap_x * overlap_y

            if overlap_area > 0:
                # Each box covers at least one cell, so total_area >= 2.
                total_area = area1 + (max_x2 - min_x2 + 1) * (max_y2 - min_y2 + 1)
                conflict_weight = min(2.0 * overlap_area / total_area, 1.0)
                G.add_edge(f'pair_{i+1}', f'pair_{j+1}', weight=conflict_weight)

    return G


def main():
    """
    Command-line entry point.

    Expects exactly three arguments: <mzn_file> <dzn_file> <json_file>.
    Loads the JSON file, builds the graph with build_graph() and prints its
    node and edge counts. Prints a usage line and exits with status 1 when
    the argument count is wrong.
    """
    if len(sys.argv) != 4:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    # The .dzn path is part of the command line but is not read here; the
    # instance data comes from the JSON file.
    mzn_file, _dzn_file, json_file = sys.argv[1:]

    # Load JSON data. JSON is UTF-8 by specification, so do not depend on the
    # platform's default encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    # Build graph
    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()