#!/usr/bin/env python3
"""
Graph converter for amaze3 problem.
Created using subagent_prompt.md version: v_02

This problem is about connecting pairs of endpoints on a grid with non-intersecting paths.
Key challenges: path routing conflicts, endpoint placement, grid connectivity constraints.
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _endpoint_pairs(json_data, count):
    """Return at most *count* complete ``(sx, sy, ex, ey)`` endpoint tuples.

    The four coordinate arrays are zipped together, so a pair is only kept
    when all four of its coordinates are present; index ``i`` in the result
    corresponds to index ``i`` in each input array.
    """
    coords = zip(
        json_data.get('end_points_start_x') or [],
        json_data.get('end_points_start_y') or [],
        json_data.get('end_points_end_x') or [],
        json_data.get('end_points_end_y') or [],
    )
    # max(..., 0) keeps a negative count from turning into a "from the end" slice.
    return list(coords)[:max(count, 0)]


def _cell_weight(x, y, center, max_dist, pairs):
    """Return the node weight in [0.1, 1.0] for grid cell ``(x, y)``.

    The weight blends geometric centrality (30%) with a saturating measure
    of how many pairs have this cell on, or within a detour of 2 of, one of
    their shortest Manhattan routes (70%).
    """
    center_x, center_y = center
    if max_dist > 0:
        centrality = 1.0 - math.hypot(x - center_x, y - center_y) / max_dist
    else:
        # Single-cell grid: there is no notion of "more central".
        centrality = 0.5

    contention = 0.0
    for sx, sy, ex, ey in pairs:
        span = abs(ex - sx) + abs(ey - sy)
        via_cell = abs(x - sx) + abs(y - sy) + abs(x - ex) + abs(y - ey)
        if span > 0 and via_cell <= span + 2:
            # Short pairs have few alternative routes, so they count more.
            contention += 1.0 / span

    # 1 - exp(-c) grows with contention and saturates at 1, so cells near
    # many shortest routes receive the larger weight.
    weight = 0.3 * centrality + 0.7 * (1.0 - math.exp(-contention))
    return min(max(weight, 0.1), 1.0)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the amaze3 problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference; not read here)
        json_data: Dict containing parsed DZN data. Keys read: 'X', 'Y', 'N',
            'end_points_start_x', 'end_points_start_y', 'end_points_end_x',
            'end_points_end_y'. Missing keys default to 0 / empty list.

    Returns:
        networkx.Graph with the following layout:
        - 'cell_{x}_{y}' nodes (type 0) for x in 1..X, y in 1..Y, weighted by
          centrality and by proximity to the pairs' shortest routes
        - 'path_{i}' nodes (type 1), one per complete endpoint pair, weighted
          by log-scaled Manhattan length
        - 'excl_{x}_{y}' nodes (type 1), one per cell, all sharing the weight
          sqrt(N) / sqrt(X * Y)
        - cell--path edges for cells within a detour of 4 of the pair's
          Manhattan distance (weight 1.0 on a shortest route, exponentially
          decaying with the detour otherwise)
        - cell--excl edges of weight 1.0
        - path--path conflict edges when two pairs have endpoints within
          Manhattan distance 3 of each other

    Note: because of the path--path conflict edges the result is not
    strictly bipartite. All node and edge weights lie in [0, 1].
    """
    X = json_data.get('X', 0)  # Grid width
    Y = json_data.get('Y', 0)  # Grid height
    N = json_data.get('N', 0)  # Number of pairs

    pairs = _endpoint_pairs(json_data, N)
    cells = [(x, y) for x in range(1, X + 1) for y in range(1, Y + 1)]

    G = nx.Graph()

    # Variable nodes: grid cells (type 0).
    # Cells are 1-indexed, so the geometric centre of the grid is at
    # ((X + 1) / 2, (Y + 1) / 2) and the farthest cells are the corners.
    center = ((X + 1) / 2, (Y + 1) / 2)
    max_dist = math.hypot((X - 1) / 2, (Y - 1) / 2)
    for x, y in cells:
        G.add_node(f'cell_{x}_{y}', type=0,
                   weight=_cell_weight(x, y, center, max_dist, pairs))

    # Constraint nodes: path connectivity constraints (type 1).
    # Longer pairs are treated as harder, on a log scale normalised by X + Y.
    max_possible_length = X + Y
    for i, (sx, sy, ex, ey) in enumerate(pairs):
        path_length = abs(ex - sx) + abs(ey - sy)
        if max_possible_length > 0:
            difficulty = math.log(path_length + 1) / math.log(max_possible_length + 1)
        else:
            difficulty = 0.5
        # Clamp: endpoints lying outside the grid could otherwise exceed 1.
        G.add_node(f'path_{i}', type=1, weight=min(difficulty, 1.0))

    # Constraint nodes: cell exclusivity (type 1) - one per cell.
    total_cells = X * Y
    if total_cells > 0:
        # Clamp: more pairs than cells would otherwise exceed 1.
        exclusivity_weight = min(math.sqrt(N) / math.sqrt(total_cells), 1.0)
    else:
        exclusivity_weight = 0.5
    for x, y in cells:
        G.add_node(f'excl_{x}_{y}', type=1, weight=exclusivity_weight)

    # Edges: connect each cell to every path constraint it could serve.
    for i, (sx, sy, ex, ey) in enumerate(pairs):
        path_length = abs(ex - sx) + abs(ey - sy)
        for x, y in cells:
            # Extra length of routing through this cell vs. the direct distance.
            deviation = (abs(x - sx) + abs(y - sy)
                         + abs(x - ex) + abs(y - ey)) - path_length
            if deviation > 4:  # Too large a detour to be a plausible cell
                continue
            # 1.0 on a shortest route, exponential decay for detours.
            weight = 1.0 if deviation <= 0 else math.exp(-0.5 * deviation)
            G.add_edge(f'cell_{x}_{y}', f'path_{i}', weight=weight)

    # Edges: connect each cell to its own exclusivity constraint.
    for x, y in cells:
        G.add_edge(f'cell_{x}_{y}', f'excl_{x}_{y}', weight=1.0)

    # Conflict edges between path constraints whose endpoints lie close
    # together (Manhattan distance <= 3); closer endpoints weigh more.
    for i, (sxi, syi, exi, eyi) in enumerate(pairs):
        endpoints_i = ((sxi, syi), (exi, eyi))
        for j in range(i + 1, len(pairs)):
            sxj, syj, exj, eyj = pairs[j]
            endpoints_j = ((sxj, syj), (exj, eyj))
            min_dist = min(
                abs(x1 - x2) + abs(y1 - y2)
                for (x1, y1) in endpoints_i
                for (x2, y2) in endpoints_j
            )
            if min_dist <= 3:
                G.add_edge(f'path_{i}', f'path_{j}',
                           weight=math.exp(-0.3 * min_dist))

    return G


def main():
    """Command-line entry point.

    Expects exactly three arguments: the .mzn model path, the .dzn data path
    and the JSON file holding the parsed instance data. Loads the JSON,
    builds the graph and prints its node and edge counts. Prints a usage
    line and exits with status 1 when the argument count is wrong.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>")
        sys.exit(1)

    # The .dzn path is accepted for interface compatibility but not read.
    mzn_path, _dzn_path, json_path = cli_args

    with open(json_path, 'r') as handle:
        instance = json.load(handle)

    # Only a summary is printed; the graph itself is not written anywhere.
    graph = build_graph(mzn_path, instance)
    print(f"Graph built: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")


# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()