from Base.Types import Vector
from Config.config import Config

import matplotlib.pyplot as plt
from Base.Map import Map
from Base.Types import Node
from typing import Optional


def estimated_delivery_time(store_position: Vector, customer_position: Vector):
    """
    Estimate the delivery time from a store to a customer.

    The straight-line gap between the two positions is divided by
    ``Config.DELIVERY_MAN_MIN_SPEED`` and then multiplied by 3.
    NOTE(review): the factor 3 is presumably a safety margin; confirm.

    Args:
        store_position: position of the store.
        customer_position: position of the customer.

    Returns:
        (|store_position - customer_position| / DELIVERY_MAN_MIN_SPEED) * 3
    """
    straight_line_gap = (store_position - customer_position).length()
    travel_time_at_min_speed = straight_line_gap / Config.DELIVERY_MAN_MIN_SPEED
    return travel_time_at_min_speed * 3

def estimated_delivery_step(map_obj: Map, store_position: Vector, customer_position: Vector):
    """
    Return the map's edge distance between a store and a customer.

    Both positions are wrapped in ``Node`` objects and handed to
    ``map_obj.get_edge_distance_between_two_points``; its result is returned
    unchanged. NOTE(review): the exact metric (graph path length, step count,
    Manhattan distance, ...) is defined by ``Map``, not here; confirm there.

    Args:
        map_obj: the map providing the edge-distance computation.
        store_position: position of the store.
        customer_position: position of the customer.

    Returns:
        Whatever ``get_edge_distance_between_two_points`` returns.
    """
    edge_distance_of = map_obj.get_edge_distance_between_two_points
    return edge_distance_of(Node(store_position), Node(customer_position))

# Scatter style for each node type that visualize_map draws. Dict order is the
# drawing order, and therefore also the legend order.
_NODE_STYLES = {
    "normal": {"c": 'blue', "s": 30, "label": 'Normal Nodes'},
    "intersection": {"c": 'red', "s": 50, "label": 'Intersections'},
    "supply": {"c": 'green', "s": 80, "marker": '*', "label": 'Supply Points'},
}


def _draw_map_edges(ax, map_obj: Map) -> None:
    """Draw every edge of ``map_obj`` onto ``ax`` as a single gray LineCollection."""
    # One collection instead of one Line2D per edge: far fewer artists to
    # create and render on maps with many edges.
    from matplotlib.collections import LineCollection

    segments = [
        [(edge.node1.position.x, edge.node1.position.y),
         (edge.node2.position.x, edge.node2.position.y)]
        for edge in map_obj.edges
    ]
    if segments:
        ax.add_collection(
            LineCollection(segments, colors='gray', alpha=0.5, linewidths=1)
        )
        # add_collection only updates the data limits; refresh the view limits.
        ax.autoscale_view()


def _draw_map_nodes(ax, map_obj: Map) -> bool:
    """Scatter the known node types of ``map_obj`` onto ``ax``; return True if any were drawn."""
    grouped = {node_type: [] for node_type in _NODE_STYLES}
    for node in map_obj.nodes:
        bucket = grouped.get(node.type)
        if bucket is not None:  # nodes of any other type are not drawn
            bucket.append(node)

    drew_any = False
    for node_type, nodes in grouped.items():
        if not nodes:
            continue
        ax.scatter(
            [node.position.x for node in nodes],
            [node.position.y for node in nodes],
            **_NODE_STYLES[node_type],
        )
        drew_any = True
    return drew_any


def visualize_map(map_obj: Map, save_path: Optional[str] = None):
    """
    Visualize the map object.

    Edges are drawn as gray line segments. Nodes whose ``type`` is "normal",
    "intersection" or "supply" are drawn as blue dots, red dots and green
    stars respectively; nodes with any other type are skipped.

    Args:
        map_obj: Map object. Its ``edges`` are read for ``node1.position`` /
            ``node2.position`` and its ``nodes`` for ``type`` and ``position``
            (each position exposing ``x`` and ``y``).
        save_path: Optional, the path to save the image (at 300 dpi). If it is
            falsy (None or an empty string), the image is displayed with
            ``plt.show()`` instead.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        _draw_map_edges(ax, map_obj)
        has_nodes = _draw_map_nodes(ax, map_obj)

        # set the properties of the figure
        ax.set_title('Map Visualization')
        ax.set_xlabel('X Coordinate')
        ax.set_ylabel('Y Coordinate')
        if has_nodes:
            # Only the node scatters carry labels; legend() without any
            # labeled artist would just emit a matplotlib warning.
            ax.legend()
        ax.grid(True, linestyle='--', alpha=0.7)

        # keep the aspect ratio
        ax.axis('equal')
    except Exception:
        # Don't leave a half-drawn figure registered with pyplot.
        plt.close(fig)
        raise

    # save or display the figure
    if save_path:
        try:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if saving fails (bad path, permissions, ...).
            plt.close(fig)
    else:
        plt.show()