import re
import ast
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np


# Header line that starts a new node, e.g. ``node_3:``.
_NODE_HEADER_RE = re.compile(r'^node_(\d+)\s*:$')
# Attribute line inside a node, e.g. ``name: pick cup``.
_ATTRIBUTE_RE = re.compile(r'^\s*(\w+):\s*(.*)')


def _parse_edge_list(value):
    """Convert the raw text after ``edge:`` into a list of ``node_<id>`` names.

    ``value`` is evaluated with ``ast.literal_eval``, which accepts Python
    literals only and never executes code. Handling by parsed type:

    - list/tuple/set: one edge per element.
    - a single int or string: treated as a one-element list.
    - anything unparsable, or any other type: an empty list.

    Elements that already start with ``node_`` are kept as they are, so
    ``['node_1']`` does not turn into ``'node_node_1'``.
    """
    try:
        parsed = ast.literal_eval(value)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Malformed edge lists are tolerated and mean "no dependencies".
        return []

    # bool is a subclass of int, but ``edge: True`` is not a node id.
    if isinstance(parsed, bool):
        return []
    # Wrap a bare scalar. Otherwise an int raises TypeError when iterated,
    # and a string is split into characters ('12' -> node_1, node_2).
    if isinstance(parsed, (int, str)):
        parsed = [parsed]
    if not isinstance(parsed, (list, tuple, set, frozenset)):
        return []

    edges = []
    for ref in parsed:
        ref = str(ref).strip()
        if not ref:
            continue
        edges.append(ref if ref.startswith('node_') else f'node_{ref}')
    return edges


def parse_dag(input_text):
    """Parse a plain-text DAG description into a dict of node attributes.

    Input format:

    - A header line ``node_<n>:`` starts a node. A repeated header replaces
      the earlier node with that name.
    - Each following ``key: value`` line attaches an attribute to that node.
    - The ``edge`` key is parsed by ``_parse_edge_list`` and stored under
      ``'edges'`` as a list of ``node_<id>`` names.
    - Every other value is kept as a raw string.
    - Lines before the first header, and blank lines, are ignored.

    Returns:
        dict mapping node name -> attribute dict. Each attribute dict
        always contains an ``'edges'`` list.
    """
    nodes = {}
    current_node = None

    for line in input_text.splitlines():
        line = line.strip()
        if not line:
            continue

        if match := _NODE_HEADER_RE.match(line):
            current_node = f"node_{match.group(1)}"
            nodes[current_node] = {'edges': []}
            continue

        if current_node is None:
            continue

        if match := _ATTRIBUTE_RE.match(line):
            key, value = match.groups()
            if key == 'edge':
                nodes[current_node]['edges'] = _parse_edge_list(value)
            else:
                nodes[current_node][key] = value

    return nodes


# Fill colour per node ``type``; unknown types are drawn white.
NODE_TYPE_COLORS = {
    'pick': '#90C9F9',
    'place': '#B0E57C',
    'press': '#A2D2FF',
    'slide_open': '#FF9A9E',
    'slide_close': '#FF9A9E',
    'flap_open': '#FFDA9E',
    'flap_close': '#FFB347',
    'push_to': '#D291BC',
    'lift_from': '#AED9E0',
    'insert_into': '#B19CD9',
    'open_cap': '#FFB6C1',
    'close_cap': '#FA8072',
    'wipe': '#C1E1C1',
    'stick_on': '#FFC0CB',
    'pour_into': '#DDA0DD',
    'cut': '#FF6961',
    'stir': '#FFD1DC',
    'press_open': '#F4A460',
    'press_close': '#DEB887',
    'write': '#C0C0C0',
    'inspect': '#7FFFD4',
    'bind': '#DA70D6',
    'weld': '#CD5C5C',
    'scan': '#40E0D0',
    'tighten': '#E6E6FA',
    'align': '#FFE4B5',
    'assemble': '#98FB98',
    'drill': '#9370DB',
    'mark': '#FFA07A',
    'task_completion': '#FFD700',
}


def _build_graph(nodes):
    """Build a DiGraph with an edge from each listed ``edges`` entry to its node.

    References to nodes that are not in ``nodes`` are reported and skipped.
    """
    G = nx.DiGraph()

    for node_id, attrs in nodes.items():
        node_type = attrs.get('type', 'default')
        label = f"{node_id}: {attrs.get('name', '')}\n({attrs.get('arm_num', '?')} arm)"
        G.add_node(node_id,
                   label=label,
                   color=NODE_TYPE_COLORS.get(node_type, '#FFFFFF'),
                   node_type=node_type)

    for node_id, attrs in nodes.items():
        # .get: node dicts built by hand may omit the 'edges' key.
        for edge in attrs.get('edges', ()):
            if edge in nodes:
                G.add_edge(edge, node_id)
            else:
                print(f"Warning: {node_id} references unknown node {edge}; edge skipped.")

    return G


def _layered_positions(layers, spacing_x=4.0, spacing_y=2.5):
    """Centre each topological layer horizontally and stack layers downwards."""
    pos = {}
    for layer_idx, layer in enumerate(layers):
        half_width = (len(layer) - 1) * spacing_x / 2
        x_coords = np.linspace(-half_width, half_width, num=len(layer))
        for node, x in zip(layer, x_coords):
            pos[node] = (x, -layer_idx * spacing_y)
    return pos


def _draw_graph(G, pos):
    """Draw nodes, curved arrows and multi-line labels onto the current figure."""
    nx.draw_networkx_nodes(
        G, pos,
        node_size=1000,
        node_color=[G.nodes[n]['color'] for n in G.nodes],
        edgecolors='#333333',
        linewidths=1.5,
        alpha=0.95,
    )
    nx.draw_networkx_edges(
        G, pos,
        arrows=True,
        arrowsize=25,
        arrowstyle='->,head_width=0.7,head_length=0.7',
        edge_color='#606060',
        width=2.0,
        # Slight curvature keeps edges between nodes in one column apart.
        connectionstyle='arc3,rad=0.15',
    )
    nx.draw_networkx_labels(
        G, pos,
        labels={n: G.nodes[n]['label'] for n in G.nodes},
        font_size=11,
        font_family='Arial',
        verticalalignment='center',
    )
    plt.axis('off')


def visualize_dag(nodes, output_path='fixed_dag.png', show=True):
    """Render the parsed DAG as a layered diagram.

    Nodes are grouped into topological generations. Each generation is one
    horizontal row, with earlier generations drawn higher.

    Args:
        nodes: mapping as returned by ``parse_dag``. Node name -> dict with
            optional ``type``, ``name``, ``arm_num`` and ``edges`` keys.
        output_path: file the figure is saved to (300 dpi). Pass a falsy
            value to skip saving.
        show: whether to call ``plt.show()``. When False the figure is closed
            after saving, so batch use does not accumulate open figures.

    Returns None. Prints an error and draws nothing if ``nodes`` is empty or
    the dependencies contain a cycle.
    """
    G = _build_graph(nodes)
    if G.number_of_nodes() == 0:
        print("Error: No nodes to visualize.")
        return

    try:
        layers = list(nx.topological_generations(G))
    except nx.NetworkXUnfeasible:
        print("Error: Dependencies form a cycle.")
        return

    pos = _layered_positions(layers)

    fig = plt.figure(figsize=(25, 16))
    try:
        _draw_graph(G, pos)
        if output_path:
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
    except Exception:
        # Don't leave a half-drawn figure open if drawing or saving fails.
        plt.close(fig)
        raise

    if show:
        plt.show()
    else:
        plt.close(fig)


# Paste the DAG description here.
# Start each node with a ``node_<n>:`` header, followed by ``key: value`` lines:
#   name, type, arm_num -- used for the label and colour
#   edge: [1, 2]        -- nodes drawn before this one; arrows point from
#                          them to this node
input_text = """
YOUR_DAG_NODES_HERE
"""

# Only run when executed as a script, not as a side effect of importing.
if __name__ == "__main__":
    dag_data = parse_dag(input_text)
    visualize_dag(dag_data)