#!/usr/bin/env python3
"""
Graph converter for radiation therapy optimization problem.
Created using subagent_prompt.md version: v_02

The underlying problem is radiation therapy treatment planning: deliver the
required intensity to every tissue cell while minimizing both the total
beam-on time and the number of shape matrices used.

Key challenges:
- Balancing intensity requirements across a 2D grid
- Minimizing both beam time and number of shape matrices
- Shape matrix constraints that create complex dependencies
- Non-uniform intensity distributions create varying difficulty
"""

import sys
import json
import math
import networkx as nx
from pathlib import Path


def _intensity_matrix(json_data):
    """Read and validate the intensity grid from the parsed instance data.

    ``Intensity`` may be a flat row-major list of ``m * n`` values, or a nested
    list of ``m`` rows with ``n`` values each (the nested form is how MiniZinc's
    JSON data format encodes a 2-D array).

    Returns:
        Tuple ``(m, n, flat, matrix)``: the grid dimensions, the values as one
        row-major list, and the same values as an ``m``-row list of rows.

    Raises:
        ValueError: If a dimension is negative, the nested form does not have
            ``m`` rows of ``n`` entries, or the number of values is not ``m * n``.
    """
    m = json_data.get('m', 0)  # rows
    n = json_data.get('n', 0)  # columns
    raw = json_data.get('Intensity', [])

    # Negative dimensions could otherwise pass the size check (e.g. -1 * -2 == 2).
    if m < 0 or n < 0:
        raise ValueError(f"Grid dimensions must be non-negative, got {m}x{n}")

    if raw and all(isinstance(row, list) for row in raw):
        # Nested form: validate the shape before flattening so ragged rows
        # cannot silently shift values into the wrong cells.
        if len(raw) != m or any(len(row) != n for row in raw):
            raise ValueError(
                f"Intensity must have {m} rows of {n} entries each to match grid size {m}x{n}"
            )
        flat = [value for row in raw for value in row]
    else:
        flat = list(raw)

    if len(flat) != m * n:
        raise ValueError(f"Intensity array size {len(flat)} doesn't match grid size {m}x{n}")

    matrix = [flat[i * n:(i + 1) * n] for i in range(m)]
    return m, n, flat, matrix


def _add_cell_nodes(G, intensity, m, n, max_intensity):
    """Add one type-0 node per grid cell, weighted by intensity and centrality."""
    # Grid geometry does not depend on the cell, so compute it once.
    center_i, center_j = (m - 1) / 2, (n - 1) / 2
    max_dist = math.sqrt(center_i ** 2 + center_j ** 2)

    for i in range(m):
        for j in range(n):
            # Higher-intensity cells are treated as more critical.
            intensity_weight = intensity[i][j] / max_intensity if max_intensity > 0 else 0

            # Central cells get a bonus; a 1x1 grid (max_dist == 0) gets a neutral 0.5.
            if max_dist > 0:
                dist_from_center = math.sqrt((i - center_i) ** 2 + (j - center_j) ** 2)
                centrality_bonus = 1.0 - dist_from_center / max_dist
            else:
                centrality_bonus = 0.5

            # The exponent 0.7 is a non-linear scaling that emphasizes high-intensity cells.
            cell_weight = 0.3 * centrality_bonus + 0.7 * (intensity_weight ** 0.7)
            G.add_node(f'cell_{i}_{j}', type=0, weight=min(cell_weight, 1.0))


def _add_intensity_constraints(G, intensity, m, n, max_intensity):
    """Add a type-1 requirement node, linked to its cell, for each positive-intensity cell."""
    for i in range(m):
        for j in range(n):
            cell_intensity = intensity[i][j]
            if cell_intensity <= 0:
                continue
            # cell_intensity > 0 implies max_intensity >= cell_intensity > 0,
            # so neither division below can be by zero.
            complexity = math.log(1 + cell_intensity) / math.log(1 + max_intensity)
            constraint_id = f'intensity_constraint_{i}_{j}'
            G.add_node(constraint_id, type=1, weight=complexity)
            G.add_edge(f'cell_{i}_{j}', constraint_id, weight=cell_intensity / max_intensity)


def _add_row_constraints(G, intensity, total_intensity, max_intensity):
    """Add a type-1 node per row with positive total intensity, linked to its positive cells."""
    for i, row in enumerate(intensity):
        row_total = sum(row)
        if row_total <= 0:
            continue
        # row_total > 0 implies the row is non-empty, so max/min are safe.
        row_variation = max(row) - min(row)
        complexity = 0.5 * (row_total / total_intensity) + 0.5 * (row_variation / max_intensity)
        constraint_id = f'row_constraint_{i}'
        G.add_node(constraint_id, type=1, weight=min(complexity, 1.0))
        for j, value in enumerate(row):
            if value > 0:
                # Edge weight is the cell's share of the row total.
                G.add_edge(f'cell_{i}_{j}', constraint_id, weight=value / row_total)


def _add_beam_time_constraint(G, intensity, m, n, total_intensity, avg_intensity, max_intensity):
    """Add the global type-1 beam-time node and link every above-average cell to it."""
    cell_count = m * n
    # Normalize by problem size; an empty grid gets weight 0.0 instead of dividing by zero.
    weight = math.tanh(total_intensity / (cell_count * 10)) if cell_count else 0.0
    G.add_node('beam_time_constraint', type=1, weight=weight)

    spread = max_intensity - avg_intensity
    for i in range(m):
        for j in range(n):
            value = intensity[i][j]
            if value > avg_intensity:
                impact = (value - avg_intensity) / spread if spread > 0 else 0.5
                G.add_edge(f'cell_{i}_{j}', 'beam_time_constraint', weight=impact)


def _add_adjacency_constraints(G, intensity, m, n, avg_intensity, max_intensity):
    """Add a type-1 node for each 4-connected pair of above-average cells."""
    for i in range(m):
        for j in range(n):
            here = intensity[i][j]
            if here <= avg_intensity:
                continue
            # Looking only right and down visits each unordered pair exactly once,
            # in the same order as the left/up cell of the pair in row-major order.
            for ni, nj in ((i, j + 1), (i + 1, j)):
                if ni >= m or nj >= n:
                    continue
                there = intensity[ni][nj]
                if there <= avg_intensity:
                    continue
                constraint_id = f'adjacency_{i}_{j}_{ni}_{nj}'
                pair_total = here + there
                # Node weight is the pair's mean intensity relative to the maximum.
                G.add_node(constraint_id, type=1, weight=pair_total / (2 * max_intensity))
                G.add_edge(f'cell_{i}_{j}', constraint_id, weight=here / pair_total)
                G.add_edge(f'cell_{ni}_{nj}', constraint_id, weight=there / pair_total)


def build_graph(mzn_file, json_data):
    """
    Build graph representation of the radiation therapy problem instance.

    Args:
        mzn_file: Path to .mzn file (for reference only; not read here)
        json_data: Dict containing parsed DZN data with keys 'm' (rows),
            'n' (columns) and 'Intensity' (flat row-major list of m*n values,
            or a nested list of m rows of n values)

    Returns:
        networkx.Graph whose nodes carry 'type' and 'weight' attributes and
        whose edges carry a 'weight' attribute.

    Raises:
        ValueError: If the dimensions are negative or 'Intensity' does not
            match the m x n grid.

    Strategy: Create a bipartite graph with:
    - Cell nodes (type 0): one per grid cell, including zero-intensity cells
    - Constraint nodes (type 1): per-cell intensity requirements, per-row
      shape-matrix constraints, one global beam-time constraint, and
      constraints for adjacent above-average cells
    - Edge weights reflect intensity requirements and spatial relationships
    - Node weights reflect cell criticality and constraint complexity
    """
    m, n, intensity_flat, intensity = _intensity_matrix(json_data)

    # Problem-wide statistics used for normalization.
    max_intensity = max(intensity_flat) if intensity_flat else 1
    total_intensity = sum(intensity_flat)
    avg_intensity = total_intensity / len(intensity_flat) if intensity_flat else 0

    G = nx.Graph()
    _add_cell_nodes(G, intensity, m, n, max_intensity)
    _add_intensity_constraints(G, intensity, m, n, max_intensity)
    _add_row_constraints(G, intensity, total_intensity, max_intensity)
    _add_beam_time_constraint(G, intensity, m, n, total_intensity, avg_intensity, max_intensity)
    _add_adjacency_constraints(G, intensity, m, n, avg_intensity, max_intensity)
    return G


def main(argv=None):
    """Command-line entry point: build the graph for one instance and report its size.

    Args:
        argv: Argument list *without* the program name, i.e.
            ``[mzn_file, dzn_file, json_file]``. Defaults to ``sys.argv[1:]``.

    Exits with status 1, writing a message to stderr, on a usage error, an
    unreadable or malformed JSON file, or instance data rejected by
    ``build_graph`` with ValueError.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 3:
        print("Usage: python converter.py <mzn_file> <dzn_file> <json_file>", file=sys.stderr)
        sys.exit(1)

    # The .dzn path is accepted for CLI compatibility; the data is read from the JSON file.
    mzn_file, _dzn_file, json_file = args

    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        print(f"Error: cannot load JSON data from {json_file}: {err}", file=sys.stderr)
        sys.exit(1)

    try:
        G = build_graph(mzn_file, json_data)
    except ValueError as err:
        print(f"Error: invalid instance data: {err}", file=sys.stderr)
        sys.exit(1)

    # Graph is returned by build_graph for direct feature extraction
    print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


if __name__ == "__main__":
    # Run the command-line entry point only when this file is executed directly, not on import.
    main()