import heapq
import itertools
import json
import random
import sys

from PyQt5.QtCore import Qt
from PyQt5.QtGui import QColor, QImage, QPainter, QPen
from PyQt5.QtWidgets import QApplication, QWidget

from simworld.map.map import Edge, Road
from simworld.map.map import Node as SimworldNode
from simworld.utils.vector import Vector

class Node(SimworldNode):
    """A simworld map node that additionally carries an ``obstacle`` flag.

    The flag starts as False. In this module, ``Map.get_route`` (with
    ``avoid_obstacle=True``) and ``Map.get_closest_non_obstacle_node`` skip
    nodes whose flag is True.
    """

    def __init__(self, position: Vector, type: str = 'normal'):
        super(Node, self).__init__(position, type)
        # Every node starts out passable.
        self.obstacle: bool = False

class Map:
    """Undirected pedestrian graph (sidewalks and crosswalks) built around roads.

    ``nodes`` and ``edges`` are sets. ``adjacency_list`` maps each node to the
    list of its neighbours, and every edge is recorded in both directions.
    """

    # Pixel margin left around the drawing in viewers and exported images.
    _DRAW_MARGIN = 50
    # Lateral waypoint offsets, in multiples of 'pysbench.waypoints_normal_distance'.
    _SIDEWALK_OFFSETS = (4, 1, -1, -3)
    _CROSSWALK_OFFSETS = (3, 1, -1, -3)
    # Labels given to a sidewalk layer, ordered by increasing distance to the nearest road.
    _SIDEWALK_LABELS = ('sidewalk_near_road', 'normal', 'sidewalk_middle', 'sidewalk_far_road')

    def __init__(self, config, traffic_signals: list = None):
        """Create an empty map.

        Args:
            config: Mapping indexed by the methods below with the keys
                'pysbench.sidewalk_offset', 'pysbench.waypoints_distance' and
                'pysbench.waypoints_normal_distance'.
            traffic_signals: Stored on the instance as-is; this class does not read it.
        """
        self.nodes = set()
        self.edges = set()
        self.roads = []
        self.adjacency_list = {}
        self.config = config
        self.traffic_signals = traffic_signals

    # ------------------------------------------------------------------ graph

    def add_node(self, node: "Node"):
        """Add ``node`` with an empty neighbour list (resets any existing list)."""
        self.nodes.add(node)
        self.adjacency_list[node] = []

    def add_edge(self, edge: "Edge"):
        """Add ``edge`` and record each endpoint as the other's neighbour.

        Both endpoints must already have been added with ``add_node``;
        otherwise a KeyError is raised.
        """
        self.edges.add(edge)
        self.adjacency_list[edge.node1].append(edge.node2)
        self.adjacency_list[edge.node2].append(edge.node1)

    def has_edge(self, edge: "Edge"):
        """Return True if an edge equal to ``edge`` is in the map.

        NOTE(review): whether Edge(a, b) matches Edge(b, a) depends on
        simworld's Edge equality/hash - confirm.
        """
        return edge in self.edges

    def _remove_edge(self, edge):
        """Remove ``edge`` and one adjacency entry in each direction, if present."""
        if edge not in self.edges:
            return
        self.edges.remove(edge)
        if edge.node2 in self.adjacency_list.get(edge.node1, []):
            self.adjacency_list[edge.node1].remove(edge.node2)
        if edge.node1 in self.adjacency_list.get(edge.node2, []):
            self.adjacency_list[edge.node2].remove(edge.node1)

    # ---------------------------------------------------------------- queries

    def get_closest_node(self, position: "Vector"):
        """Return the node nearest to ``position`` (first one on ties), or None if empty."""
        return min(self.nodes, key=lambda node: position.distance(node.position), default=None)

    def get_closest_non_obstacle_node(self, position: "Vector"):
        """Return the nearest node whose ``obstacle`` flag is not set, or None if there is none."""
        passable = (node for node in self.nodes if not getattr(node, "obstacle", False))
        return min(passable, key=lambda node: position.distance(node.position), default=None)

    def get_random_node(self):
        """Return a uniformly random node (raises IndexError on an empty map)."""
        return random.choice(list(self.nodes))

    def get_route(self, start: "Node", end: "Node", avoid_obstacle: bool = False):
        """Return the shortest path from ``start`` to ``end`` using A*.

        Edge cost and heuristic are both the straight-line distance between
        node positions.

        Args:
            start: First node of the path.
            end: Last node of the path.
            avoid_obstacle: If True, nodes whose ``obstacle`` flag is set are
                never entered.

        Returns:
            The list of nodes from ``start`` to ``end`` (both included), or
            None if ``end`` is unreachable.
        """
        # Heap entries are (f, g, seq, node). ``seq`` breaks (f, g) ties so
        # heapq never has to compare node objects, which may not be orderable.
        seq = itertools.count()
        open_heap = [(start.position.distance(end.position), 0, next(seq), start)]
        came_from = {}
        g_score = {start: 0}
        closed_set = set()

        while open_heap:
            _, _, _, current = heapq.heappop(open_heap)
            if current == end:
                return self._reconstruct_path(start, end, came_from)
            if current in closed_set:
                continue  # stale heap entry
            closed_set.add(current)

            for neighbor in self.adjacency_list.get(current, []):
                if avoid_obstacle and getattr(neighbor, "obstacle", False):
                    continue
                tentative_g = g_score[current] + current.position.distance(neighbor.position)
                if tentative_g < g_score.get(neighbor, float('inf')):
                    came_from[neighbor] = current
                    g_score[neighbor] = tentative_g
                    f_score = tentative_g + neighbor.position.distance(end.position)
                    heapq.heappush(open_heap, (f_score, tentative_g, next(seq), neighbor))
        return None

    def _reconstruct_path(self, start, end, came_from):
        """Walk ``came_from`` back from ``end`` to ``start``.

        Returns the start-to-end node list, or None if the chain breaks
        before reaching ``start``.
        """
        path = [end]
        node = end
        while node != start:
            if node not in came_from:
                return None
            node = came_from[node]
            path.append(node)
        path.reverse()
        return path

    # ----------------------------------------------------------- construction

    def initialize_map_from_file(self, file_path: str, fine_grained: bool = False):
        """Build sidewalk nodes and edges from the road layout in a JSON file.

        The file must contain a 'roads' list whose items have 'start' and
        'end' points with 'x'/'y' keys. Coordinates are multiplied by 100
        (presumably metres to centimetres - confirm).

        Args:
            file_path: Path to the JSON layout.
            fine_grained: If True, every edge is afterwards split into
                waypoint layers (see ``_interpolate_nodes``). Otherwise one
                'normal' node is placed at the middle of each sidewalk.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        new_roads = [
            Road(Vector(item['start']['x'] * 100, item['start']['y'] * 100),
                 Vector(item['end']['x'] * 100, item['end']['y'] * 100))
            for item in data.get('roads', [])
        ]
        self.roads.extend(new_roads)

        if new_roads:
            offset = self.config['pysbench.sidewalk_offset']
            for road in new_roads:
                self._add_road_sidewalks(road, offset, fine_grained)

        self._connect_adjacent_roads()

        if fine_grained:
            self._interpolate_nodes()

    def _add_road_sidewalks(self, road, offset, fine_grained):
        """Add the four 'intersection' corner nodes around ``road`` and their edges.

        Corners 0-1 lie on one side of the road and corners 3-2 on the other.
        Corners 0-3 and 1-2 are joined across the road (length 2 * offset).
        """
        normal = Vector(road.direction.y, -road.direction.x)
        corners = [
            road.start - normal * offset + road.direction * offset,
            road.end - normal * offset - road.direction * offset,
            road.end + normal * offset - road.direction * offset,
            road.start + normal * offset + road.direction * offset,
        ]
        nodes = [Node(point, 'intersection') for point in corners]
        for node in nodes:
            self.add_node(node)

        if fine_grained:
            # Plain sidewalk edges; they are subdivided later.
            pairs = [(nodes[0], nodes[1]), (nodes[2], nodes[3])]
        else:
            # Coarse map: route each sidewalk through a 'normal' midpoint node.
            mid1 = Node((corners[0] + corners[1]) * 0.5, 'normal')
            mid2 = Node((corners[2] + corners[3]) * 0.5, 'normal')
            self.add_node(mid1)
            self.add_node(mid2)
            pairs = [(nodes[0], mid1), (mid1, nodes[1]), (nodes[2], mid2), (mid2, nodes[3])]
        pairs += [(nodes[0], nodes[3]), (nodes[1], nodes[2])]

        for node_a, node_b in pairs:
            self.add_edge(Edge(node_a, node_b))

    def _connect_adjacent_roads(self) -> None:
        """Link 'intersection' nodes closer than 2 * sidewalk_offset + 100 to each other."""
        threshold = self.config['pysbench.sidewalk_offset'] * 2 + 100
        for n1, n2 in itertools.combinations(self.get_intersection_nodes(), 2):
            if (n1.position.distance(n2.position) < threshold
                    and not self.has_edge(Edge(n1, n2))):
                self.add_edge(Edge(n1, n2))

    def _interpolate_nodes(self):
        """Replace every existing edge by layers of four lateral waypoints.

        Layers are placed every 'pysbench.waypoints_distance' along the edge.
        A single midpoint layer is used when fewer than two steps fit. Each
        layer is fully connected to its neighbours, the end nodes act as
        single-node layers, and the original edge is removed. An edge whose
        length equals 2 * sidewalk_offset (within 1e-3) yields 'crosswalk'
        nodes. Other edges yield sidewalk nodes, relabelled by
        ``_label_sidewalk_layer``.
        """
        waypoints_distance = self.config['pysbench.waypoints_distance']
        normal_step = self.config['pysbench.waypoints_normal_distance']
        sidewalk_offset = self.config['pysbench.sidewalk_offset']

        # Snapshot: the loop adds new edges and removes the original ones.
        for edge in list(self.edges):
            start = edge.node1.position
            end = edge.node2.position
            direction = (end - start).normalize()
            normal = Vector(-direction.y, direction.x)
            length = start.distance(end)
            num_points = int(length // waypoints_distance)

            is_crosswalk = abs(length - 2 * sidewalk_offset) < 1e-3
            node_type = 'crosswalk' if is_crosswalk else 'sidewalk'

            if num_points < 2:
                centers = [start + direction * (length / 2)]
                # NOTE(review): this branch uses the sidewalk offsets even for
                # crosswalk edges, whereas the multi-layer branch uses the
                # symmetric crosswalk offsets. Kept as-is - confirm intent.
                offsets = self._SIDEWALK_OFFSETS
            else:
                centers = [start + direction * (i * waypoints_distance)
                           for i in range(1, num_points)]
                offsets = self._CROSSWALK_OFFSETS if is_crosswalk else self._SIDEWALK_OFFSETS

            layers = [[edge.node1]]
            for center in centers:
                layer = [Node(center + normal * (k * normal_step), node_type) for k in offsets]
                if node_type == 'sidewalk':
                    self._label_sidewalk_layer(layer)
                for node in layer:
                    self.add_node(node)
                layers.append(layer)
            layers.append([edge.node2])

            for current_layer, next_layer in zip(layers, layers[1:]):
                for node_a in current_layer:
                    for node_b in next_layer:
                        if not self.has_edge(Edge(node_a, node_b)):
                            self.add_edge(Edge(node_a, node_b))

            self._remove_edge(edge)

    def _label_sidewalk_layer(self, layer):
        """Set node types by increasing distance to the nearest road (stable on ties)."""
        ranked = sorted(layer, key=self._nearest_road_distance)
        for node, label in zip(ranked, self._SIDEWALK_LABELS):
            node.type = label

    # ------------------------------------------------------------- node types

    def _nodes_of_type(self, node_type):
        """Return all nodes whose ``type`` attribute equals ``node_type``."""
        return [node for node in self.nodes if getattr(node, 'type', None) == node_type]

    def get_sidewalk_near_road_nodes(self):
        return self._nodes_of_type('sidewalk_near_road')

    def get_sidewalk_middle_nodes(self):
        return self._nodes_of_type('sidewalk_middle')

    def get_sidewalk_far_road_nodes(self):
        return self._nodes_of_type('sidewalk_far_road')

    def get_normal_nodes(self):
        return self._nodes_of_type('normal')

    def get_crosswalk_nodes(self):
        return self._nodes_of_type('crosswalk')

    def get_intersection_nodes(self):
        return self._nodes_of_type('intersection')

    # --------------------------------------------------------------- geometry

    def _point_to_segment_distance(self, p, a, b):
        """Return the Euclidean distance from point ``p`` to segment ``a``-``b``."""
        ab = b - a
        ap = p - a
        ab_len2 = ab.x ** 2 + ab.y ** 2
        if ab_len2 == 0:
            return ap.length()
        # Parameter of the projection of p onto ab, clamped to the segment.
        t = max(0, min(1, (ap.x * ab.x + ap.y * ab.y) / ab_len2))
        proj = a + ab * t
        return (p - proj).length()

    def _nearest_road_distance(self, node):
        """Return the distance from ``node`` to the closest road segment (ValueError if no roads)."""
        return min(self._point_to_segment_distance(node.position, road.start, road.end)
                   for road in self.roads)

    # --------------------------------------------------------------- sampling

    def find_node_pair_with_edge_distance(self, edge_count: int = 1, max_trials: int = 1000):
        """Return a random pair whose shortest route has exactly ``edge_count`` edges.

        The first node is always a 'normal' node. The second is any other node,
        or another 'normal' node when ``edge_count == 2``. For
        ``edge_count == 4`` the route must also contain two consecutive
        'intersection' nodes (after the first hop), each with more than two
        neighbours. "Shortest" means the distance-weighted ``get_route`` path.

        Args:
            edge_count: Required number of edges on the route.
            max_trials: Number of random start nodes to try.

        Returns:
            ``(start_node, target_node)``, or ``(None, None)`` if none was found.
        """
        normal_nodes = self.get_normal_nodes()
        if not normal_nodes or len(self.nodes) < 2:
            return None, None

        candidates = normal_nodes if edge_count == 2 else list(self.nodes)
        for _ in range(max_trials):
            node_a = random.choice(normal_nodes)
            for node_b in candidates:
                if node_b is node_a:
                    continue
                route = self.get_route(node_a, node_b, avoid_obstacle=False)
                if not route or len(route) != edge_count + 1:
                    continue
                if edge_count != 4 or self._has_inner_intersection_pair(route):
                    return node_a, node_b
        return None, None

    def _has_inner_intersection_pair(self, route):
        """True if two consecutive nodes after route[0] are intersections of degree > 2."""
        def is_junction(node):
            return (getattr(node, 'type', None) == 'intersection'
                    and len(self.adjacency_list.get(node, [])) > 2)

        return any(is_junction(n1) and is_junction(n2)
                   for n1, n2 in zip(route[1:-1], route[2:]))

    # ---------------------------------------------------------------- drawing

    @staticmethod
    def _bounds(nodes):
        """Return (min_x, max_x, min_y, max_y) of node positions, or all zeros if empty."""
        xs = [node.position.x for node in nodes]
        ys = [node.position.y for node in nodes]
        if not xs:
            return 0, 0, 0, 0
        return min(xs), max(xs), min(ys), max(ys)

    @classmethod
    def _make_projector(cls, bounds, width, height, zoom=1.0):
        """Return a function mapping a world position to integer pixel coordinates.

        The bounds are fitted (aspect-preserving) into ``width`` x ``height``
        minus ``_DRAW_MARGIN`` on each side, then multiplied by ``zoom``.
        """
        min_x, max_x, min_y, max_y = bounds
        margin = cls._DRAW_MARGIN
        scale_x = (width - 2 * margin) / (max_x - min_x) if max_x > min_x else 1
        scale_y = (height - 2 * margin) / (max_y - min_y) if max_y > min_y else 1
        scale = min(scale_x, scale_y) * zoom

        def project(position):
            return (int(margin + (position.x - min_x) * scale),
                    int(margin + (position.y - min_y) * scale))

        return project

    @staticmethod
    def _draw_edges(painter, edges, project):
        """Draw every edge as a thin light-gray line."""
        painter.setPen(QPen(QColor(200, 200, 200), 1))
        for edge in edges:
            painter.drawLine(*project(edge.node1.position), *project(edge.node2.position))

    @classmethod
    def _draw_by_type(cls, painter, nodes, edges, project):
        """Draw edges, then nodes coloured by their ``type`` (gray if unknown)."""
        type_colors = {
            'sidewalk_near_road': Qt.green,
            'sidewalk_middle': Qt.yellow,
            'sidewalk_far_road': Qt.blue,
            'crosswalk': Qt.magenta,
            'intersection': Qt.red,
        }
        cls._draw_edges(painter, edges, project)
        for node in nodes:
            color = type_colors.get(getattr(node, 'type', None), Qt.gray)
            painter.setPen(QPen(color, 6))
            painter.drawPoint(*project(node.position))

    @classmethod
    def _draw_route(cls, painter, nodes, edges, route, project):
        """Draw edges, the highlighted route, obstacles, other nodes, then route nodes.

        Route nodes are green, with the start in red and the end in blue.
        """
        route = route or []
        cls._draw_edges(painter, edges, project)

        if len(route) > 1:
            painter.setPen(QPen(QColor(255, 69, 0), 4))  # OrangeRed
            for n1, n2 in zip(route, route[1:]):
                painter.drawLine(*project(n1.position), *project(n2.position))

        route_set = set(route)
        for node in nodes:
            if getattr(node, "obstacle", False) and node not in route_set:
                painter.setPen(QPen(QColor(255, 140, 0), 10))  # Orange
                painter.drawPoint(*project(node.position))

        for node in nodes:
            if node not in route_set and not getattr(node, "obstacle", False):
                painter.setPen(QPen(QColor(180, 180, 180), 4))
                painter.drawPoint(*project(node.position))

        if route:
            painter.setPen(QPen(QColor(0, 200, 0), 8))  # intermediate nodes: green
            for node in route[1:-1]:
                painter.drawPoint(*project(node.position))
            painter.setPen(QPen(QColor(255, 0, 0), 10))  # start: red
            painter.drawPoint(*project(route[0].position))
            painter.setPen(QPen(QColor(0, 0, 255), 10))  # end: blue
            painter.drawPoint(*project(route[-1].position))

    def _show_viewer(self, title, draw_scene):
        """Open a pan (left-drag) / zoom (wheel) window and block until it closes.

        ``draw_scene(painter, project)`` is called on every repaint.
        """
        margin = self._DRAW_MARGIN
        make_projector = self._make_projector
        bounds = self._bounds(self.nodes)

        class _MapViewer(QWidget):
            def __init__(self):
                super().__init__()
                self.setMinimumSize(800, 800)
                self.setWindowTitle(title)
                self.scale = 1.0
                self.offset_x = 0
                self.offset_y = 0
                self.last_mouse_pos = None

            def paintEvent(self, event):
                painter = QPainter(self)
                try:
                    painter.setRenderHint(QPainter.Antialiasing)
                    project = make_projector(bounds, self.width(), self.height(), self.scale)
                    painter.translate(self.offset_x, self.offset_y)
                    draw_scene(painter, project)
                finally:
                    painter.end()

            def wheelEvent(self, event):
                factor = 1.15 if event.angleDelta().y() > 0 else 0.85
                old_scale = self.scale
                self.scale *= factor
                # Keep the point under the cursor fixed. The scaled content
                # starts at offset + margin, so measure the cursor from there.
                mouse_pos = event.pos()
                dx = mouse_pos.x() - self.offset_x - margin
                dy = mouse_pos.y() - self.offset_y - margin
                self.offset_x -= dx * (self.scale - old_scale) / old_scale
                self.offset_y -= dy * (self.scale - old_scale) / old_scale
                self.update()

            def mousePressEvent(self, event):
                if event.button() == Qt.LeftButton:
                    self.last_mouse_pos = event.pos()

            def mouseMoveEvent(self, event):
                if self.last_mouse_pos is not None:
                    delta = event.pos() - self.last_mouse_pos
                    self.offset_x += delta.x()
                    self.offset_y += delta.y()
                    self.last_mouse_pos = event.pos()
                    self.update()

            def mouseReleaseEvent(self, event):
                if event.button() == Qt.LeftButton:
                    self.last_mouse_pos = None

        app = QApplication.instance() or QApplication(sys.argv)
        viewer = _MapViewer()
        viewer.show()
        app.exec_()

    def _export_image(self, output_path, width, height, draw_scene):
        """Render ``draw_scene`` onto a white image and save it; return QImage.save's result."""
        image = QImage(width, height, QImage.Format_RGB32)
        image.fill(Qt.white)
        painter = QPainter(image)
        try:
            painter.setRenderHint(QPainter.Antialiasing)
            project = self._make_projector(self._bounds(self.nodes), width, height)
            draw_scene(painter, project)
        finally:
            painter.end()
        return image.save(output_path)

    def visualize_by_type(self):
        """Show the map with nodes coloured by type in an interactive window."""
        nodes, edges = list(self.nodes), list(self.edges)
        self._show_viewer(
            'Map Visualization by Type',
            lambda painter, project: self._draw_by_type(painter, nodes, edges, project))

    def visualize_route(self, route):
        """Show the map with ``route`` (a node list, may be None) highlighted."""
        nodes, edges = list(self.nodes), list(self.edges)
        self._show_viewer(
            'Map Visualization by Route',
            lambda painter, project: self._draw_route(painter, nodes, edges, route, project))

    def export_map_by_type(self, output_path: str, width: int = 800, height: int = 800):
        """Save the by-type map rendering to an image file.

        Args:
            output_path: Path of the output image.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            True if the image was written (the result of ``QImage.save``).
        """
        return self._export_image(
            output_path, width, height,
            lambda painter, project: self._draw_by_type(painter, self.nodes, self.edges, project))

    def export_route(self, route, output_path: str, width: int = 800, height: int = 800):
        """Save the route rendering to an image file.

        Args:
            route: Node list to highlight (may be None or empty).
            output_path: Path of the output image.
            width: Image width in pixels.
            height: Image height in pixels.

        Returns:
            True if the image was written (the result of ``QImage.save``).
        """
        return self._export_image(
            output_path, width, height,
            lambda painter, project: self._draw_route(painter, self.nodes, self.edges, route, project))
