#!/usr/bin/env python3
"""
Graph converter for Multi-Agent Path Finding (MAPF) problem.
Created using subagent_prompt.md version: v_02

This problem is about finding collision-free paths for multiple agents on a grid map.
Key challenges: agent conflicts, path coordination, temporal constraints, grid topology
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _parse_agent_pairs(as_data, k, rel_len):
    """Return exactly ``k`` ``(start, goal)`` tuples read from *as_data*.

    *as_data* is expected to be flattened as ``[s1, g1, s2, g2, ...]``. If it
    arrives as nested pairs (``[[s1, g1], [s2, g2], ...]``, e.g. a 2-D DZN
    array), the pairs are flattened first so both layouts are accepted.
    Agents whose entries are missing get the placeholder ``(1, rel_len)``.
    """
    flat = []
    for item in as_data:
        if isinstance(item, (list, tuple)):
            flat.extend(item)
        else:
            flat.append(item)

    pairs = []
    for i in range(k):
        if 2 * i + 1 < len(flat):
            pairs.append((flat[2 * i], flat[2 * i + 1]))
        else:
            # Fallback if data incomplete
            pairs.append((1, rel_len))
    return pairs


def _position_congestion(agent_pairs, rel_len):
    """Return ``{pos: congestion}`` for every position ``1..rel_len``.

    Each agent adds 1.0 to a position it starts or ends on, and 0.3 to every
    position whose index lies between its start and goal (inclusive) when
    start != goal. The index range is only a crude stand-in for a real path.
    """
    congestion_by_pos = {}
    for pos in range(1, rel_len + 1):
        congestion = 0.0
        # Agents starting or ending here (summed first, as in the original order).
        for start, goal in agent_pairs:
            if start == pos or goal == pos:
                congestion += 1.0
        # Agents that might pass through (very rough approximation).
        for start, goal in agent_pairs:
            if start != goal and min(start, goal) <= pos <= max(start, goal):
                congestion += 0.3
        congestion_by_pos[pos] = congestion
    return congestion_by_pos


def _pair_conflict_score(pair_a, pair_b):
    """Return a heuristic (unclipped) conflict score between two agents.

    Adds 0.3 for starts within 3 indices, 0.3 for goals within 3 indices,
    0.4 when one agent's start is within 2 of the other's goal, and up to 0.4
    for the overlap of their inclusive start..goal index ranges.
    """
    start_a, goal_a = pair_a
    start_b, goal_b = pair_b

    score = 0.0
    if abs(start_a - start_b) <= 3:
        score += 0.3
    if abs(goal_a - goal_b) <= 3:
        score += 0.3
    if abs(start_a - goal_b) <= 2 or abs(goal_a - start_b) <= 2:
        score += 0.4  # Crossing paths

    # Number of integers shared by the two inclusive index ranges; computed
    # arithmetically instead of intersecting two materialized sets.
    lo = max(min(start_a, goal_a), min(start_b, goal_b))
    hi = min(max(start_a, goal_a), max(start_b, goal_b))
    overlap = max(0, hi - lo + 1)
    if overlap > 1:
        score += 0.4 * min(overlap / 5.0, 1.0)
    return score


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the MAPF problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; it is not read here).
        json_data: Dict containing parsed DZN data. Keys read:
            'k'       - number of agents,
            'as'      - start/goal positions, flattened ``[s1, g1, ...]`` or
                        nested ``[[s1, g1], ...]``,
            'rel_len' - treated as the number of grid positions; positions
                        are assumed to be numbered 1..rel_len (TODO confirm
                        against the .mzn model).

    Returns:
        networkx.Graph in which every node carries ``type`` and ``weight``:
        - Agent nodes (type 0): each agent, weighted by start/goal distance
          and proximity to other agents' starts/goals.
        - Position nodes (type 2): positions 1..rel_len, weighted by a crude
          congestion estimate.
        - Goal constraint nodes (type 1): one per agent.
        - Collision constraint nodes (type 1): positions whose estimated
          congestion is >= 2.0.
        - One global 'coordination_constraint' node (type 1).
        Edges to positions outside 1..rel_len are skipped, so no node is
        created without attributes.
    """
    k = json_data.get('k', 0)  # number of agents
    rel_len = json_data.get('rel_len', 0)  # number of grid positions
    agent_pairs = _parse_agent_pairs(json_data.get('as') or [], k, rel_len)
    num_agents = len(agent_pairs)  # always k (or 0 when k <= 0)

    G = nx.Graph()

    # Agent nodes (type 0) - weighted by path complexity.
    max_distance = rel_len  # approximation of maximum possible distance
    for a, (start, goal) in enumerate(agent_pairs):
        distance_factor = abs(goal - start) / max_distance if max_distance > 0 else 0.5

        # Agents with nearby start/goal indices are likely to conflict.
        conflict_factor = 0.0
        for b, (other_start, other_goal) in enumerate(agent_pairs):
            if b == a:
                continue
            start_conflict = 1.0 / (1.0 + abs(start - other_start))
            goal_conflict = 1.0 / (1.0 + abs(goal - other_goal))
            conflict_factor += (start_conflict + goal_conflict) / 2
        conflict_factor = min(conflict_factor / k, 1.0) if k > 1 else 0.0

        agent_weight = min(0.3 + 0.4 * distance_factor + 0.3 * conflict_factor, 1.0)
        G.add_node(f'agent_{a}', type=0, weight=agent_weight)

    # Position nodes (type 2) - weighted by expected congestion.
    position_congestion = _position_congestion(agent_pairs, rel_len)
    # Guard against an all-zero (or empty) congestion map, which would
    # otherwise divide by zero below (e.g. k == 0, or all agents off-grid).
    max_congestion = max(position_congestion.values(), default=0.0) or 1.0
    for pos, congestion in position_congestion.items():
        # Exponential scaling emphasizes high-congestion positions.
        pos_weight = 0.1 + 0.9 * (1.0 - math.exp(-2.0 * congestion / max_congestion))
        G.add_node(f'pos_{pos}', type=2, weight=pos_weight)

    # Goal constraint nodes (type 1) - one per agent.
    for a, (start, goal) in enumerate(agent_pairs):
        goal_congestion = position_congestion.get(goal, 0.0)
        distance_factor = abs(goal - start) / max_distance if max_distance > 0 else 0.5
        # Harder goals are those with high congestion and long distances.
        goal_weight = min(
            0.4 + 0.3 * (goal_congestion / max_congestion) + 0.3 * distance_factor,
            1.0,
        )
        G.add_node(f'goal_constraint_{a}', type=1, weight=goal_weight)

    # Collision constraint nodes (type 1) for positions with high expected traffic.
    # Congestion >= 2.0 is only reachable when agents exist, so k > 0 here.
    high_congestion_positions = [
        pos for pos, cong in position_congestion.items() if cong >= 2.0
    ]
    for pos in high_congestion_positions:
        temporal_weight = 0.5 + 0.5 * min(position_congestion[pos] / k, 1.0)
        G.add_node(f'collision_constraint_{pos}', type=1, weight=temporal_weight)

    # Overall coordination constraint - more agents = harder coordination.
    coordination_weight = 0.3 + 0.7 * min(k / 20.0, 1.0)
    G.add_node('coordination_constraint', type=1, weight=coordination_weight)

    # Agent -> goal constraint edges.
    for a in range(num_agents):
        G.add_edge(f'agent_{a}', f'goal_constraint_{a}', weight=1.0)

    # Agent -> position edges (start, goal and positions near the midpoint).
    for a, (start, goal) in enumerate(agent_pairs):
        agent = f'agent_{a}'
        # Only link to positions that exist as nodes; add_edge would
        # otherwise create an attribute-less position node.
        if start in position_congestion:
            G.add_edge(agent, f'pos_{start}', weight=0.9)
        if goal in position_congestion:
            G.add_edge(agent, f'pos_{goal}', weight=1.0)

        # Weight decays with distance from the midpoint of the start..goal
        # index range; with the 0.1 cut-off only positions within about
        # 2.2 indices of that midpoint receive an edge.
        midpoint = (start + goal) / 2
        for pos in range(min(start, goal) + 1, max(start, goal)):
            if pos in position_congestion:
                path_weight = 0.3 * math.exp(-0.5 * abs(pos - midpoint))
                if path_weight > 0.1:
                    G.add_edge(agent, f'pos_{pos}', weight=path_weight)

    # Position -> collision constraint edges.
    for pos in high_congestion_positions:
        constraint_weight = min(position_congestion[pos] / k, 1.0)
        G.add_edge(f'pos_{pos}', f'collision_constraint_{pos}', weight=constraint_weight)

    # Agent -> coordination constraint edges, weighted by the agent's own weight.
    for a in range(num_agents):
        agent = f'agent_{a}'
        G.add_edge(agent, 'coordination_constraint', weight=G.nodes[agent]['weight'])

    # Agent <-> agent conflict edges (only significant conflicts).
    for a in range(num_agents):
        for b in range(a + 1, num_agents):
            conflict_score = _pair_conflict_score(agent_pairs[a], agent_pairs[b])
            if conflict_score > 0.2:
                G.add_edge(f'agent_{a}', f'agent_{b}', weight=min(conflict_score, 1.0))

    return G


def main():
    """Command-line entry point: build the graph for one instance and report its size.

    Expects three positional arguments: the .mzn model path, the .dzn data
    path (accepted for interface compatibility but not read here) and the
    JSON file holding the parsed DZN data. Prints a usage message to stderr
    and exits with status 1 when the argument count is wrong.
    """
    if len(sys.argv) != 4:
        # Usage errors go to stderr so stdout carries only real output.
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    mzn_file, _dzn_file, json_file = sys.argv[1:4]

    # JSON exchanged between systems is UTF-8 (RFC 8259); don't rely on the
    # platform's locale-dependent default encoding.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    G = build_graph(mzn_file, json_data)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    # Run the command-line converter only when executed as a script;
    # importing this module just exposes build_graph().
    main()