# Map functions

from agentdriver.utils.geometry import location_to_pixel_coordinate, pixel_coordinate_to_location, GRID_SIZE
import numpy as np

LANE_CATEGORYS = ['divider', 'ped_crossing', 'boundary']

get_drivable_at_locations_info = {
    "name": "get_drivable_at_locations",
    "description": "Get the drivability at the locations [(x_1, y_1), ..., (x_n, y_n)]. If the location is out of the map scope, return None",
    "parameters": {
        "type": "object",
        "properties":  {
            "locations": {
                "type": "array",
                "items": {
                    "type": "array",
                    "items": {
                        "type": "number"
                    },
                    "minItems": 2,
                    "maxItems": 2
                },
                "description": "the locations [(x_1, y_1), ..., (x_n, y_n)] to be queried"
            },
        },
        "required": ["locations"]
    }
}

def get_drivable_at_locations(locations, data_dict):
    drivable_map = data_dict["map"]["drivable"].T
    prompts = "Drivability of selected locations:\n"
    drivable = []
    for x, y in locations:
        X, Y, valid = location_to_pixel_coordinate(x, y)
        if not valid:
            drivable.append(True)
            continue
        else:
            if drivable_map[X, Y]:
                prompts += f"Location ({x:.2f}, {y:.2f}) is drivable\n"
            else:
                prompts += f"Location ({x:.2f}, {y:.2f}) is not drivable\n"
            drivable.append(drivable_map[X, Y])
    return prompts, drivable

check_drivable_of_planned_trajectory_info = {
    "name": "check_drivable_of_planned_trajectory",
    "description": "Check the drivability at the planned trajectory",
    "parameters": {
        "type": "object",
        "properties":  {
            "trajectory": {
                "type": "array",
                "items": {
                    "type": "array",
                    "items": {
                        "type": "number"
                    },
                    "minItems": 2,
                    "maxItems": 2
                },
                "minItems": 6,
                "maxItems": 7,
                "description": "the planned trajectory [(x_1, y_1), ..., (x_n, y_n)] to be queried",
            },
        },
        "required": ["trajectory"],
    },
}

def check_drivable_of_planned_trajectory(trajectory, data_dict):
    drivable_map = data_dict["map"]["drivable"].T
    prompts = "Drivability of the planned trajectory:\n"
    drivable = []
    all_drivable = True
    for timestep, waypoint in enumerate(trajectory):
        x, y = waypoint
        T = timestep + 1 # assume time step starting from 1
        X, Y, valid = location_to_pixel_coordinate(x, y)
        if not valid:
            drivable.append(True)
        else:
            if not drivable_map[X, Y]:
                prompts += f"Waypoint ({x:.2f}, {y:.2f}) is not drivable at time step {T}\n"
                all_drivable = False
            drivable.append(drivable_map[X, Y])
    if all_drivable:
        prompts += f"All waypoints of the planned trajectory are in drivable regions\n"
    return prompts, drivable

def check_drivable_of_planned_trajectory_and_surrounding(trajectory, data_dict):
    drivable_map = data_dict["map"]["drivable"].T
    prompts = "Drivability of the planned trajectory:\n"
    drivable = []
    all_drivable = True
    for timestep, waypoint in enumerate(trajectory):
        x, y = waypoint
        T = timestep + 1 # assume time step starting from 1
        X, Y, valid = location_to_pixel_coordinate(x, y)
        if not valid:
            drivable.append(True)
        else:
            if not drivable_map[X, Y]:
                prompts += f"Waypoint ({x:.2f}, {y:.2f}) is not drivable at time step {T}\n"

                # check surrounding
                surrounding = drivable_map[X-1:X+2, Y-1:Y+2]
                if True in surrounding:
                    index_x, index_y = np.where(surrounding)
                    index_X, index_Y = index_x + X - 1, index_y + Y - 1
                    prompts += f"- Surrounding drivable regions: {[(pixel_coordinate_to_location(x, y)[:-1]) for x, y in zip(index_X, index_Y)]}\n"
                all_drivable = False
            drivable.append(drivable_map[X, Y])
    if all_drivable:
        prompts += f"All waypoints of the planned trajectory are in drivable regions\n"
    return prompts, drivable

get_lane_category_at_locations_info = {
    "name": "get_lane_category_at_locations",
    "description": "Get the lane category at the locations [(x_1, y_1), ..., (x_n, y_n)]. If the location is out of the map scope, return None",
    "parameters": {
        "type": "object",
        "properties":  {
            "locations": {
                "type": "array",
                "items": {
                    "type": "array",
                    "items": {
                        "type": "number"
                    },
                    "minItems": 2,
                    "maxItems": 2
                },
                "description": "the locations [(x_1, y_1), ..., (x_n, y_n)] to be queried",
            },
            "return_score": {
                "type": "boolean",
                "description": "whether to return the probability score of the lane category",
            },
        },
        "required": ["locations", "return_score"],
    },
}

def get_lane_category_at_locations(locations, data_dict, return_score=True):
    lane_map = data_dict["map"]["lane"].transpose(0, 2, 1)
    lane_score_map = data_dict["map"]["lane_probs"].transpose(0, 2, 1)
    prompts = "Lane category of selected locations:\n"
    lane_category = []
    for x, y in locations:
        X, Y, valid = location_to_pixel_coordinate(x, y)
        if not valid:
            lane_category.append(None)
            continue
        else:
            lane_category.append(lane_map[:, X, Y])
            cat_index = np.where(lane_map[:, X, Y])[0]
            if len(cat_index) == 0:
                prompts += f"Location ({x:.2f}, {y:.2f}) has no lane category\n"
            else:
                cat_prompt = ', '.join(LANE_CATEGORYS[i] for i in cat_index)
                score_prompt = ', '.join(f"{lane_score_map[i, X, Y]:.2f}" for i in cat_index)
                if return_score:
                    prompts += f"Location ({x:.2f}, {y:.2f}) has lane category {cat_prompt} with probability score {score_prompt}\n"
                else:
                    prompts += f"Location ({x:.2f}, {y:.2f}) has lane category {cat_prompt}\n"
    return prompts, lane_category

get_distance_to_shoulder_at_locations_info = {
    "name": "get_distance_to_shoulder_at_locations",
    "description": "Get the distance to both sides of road shoulders at the locations [(x_1, y_1), ..., (x_n, y_n)]. If the location is out of the map scope, return None",
    "parameters": {
        "type": "object",
        "properties":  {
            "locations": {
                "type": "array",
                "items": {
                    "type": "array",
                    "items": {
                        "type": "number"
                    },
                    "minItems": 2,
                    "maxItems": 2
                },
                "description": "the locations [(x_1, y_1), ..., (x_n, y_n)] to be queried",
            },
        },
        "required": ["locations"],
    },
}

def get_distance_to_shoulder_at_locations(locations, data_dict):
    boundary_map = data_dict["map"]["lane"][2].T
    prompts = "Distance to both sides of road shoulders of selected locations:\n"
    distance_to_shoulder = []
    for x, y in locations:
        X, Y, valid = location_to_pixel_coordinate(x, y)
        if not valid:
            distance_to_shoulder.append(None)
            continue
        else:
            Y_max = Y+5 if Y+5 < boundary_map.shape[1] else boundary_map.shape[1] - 1
            Y_min = Y-5 if Y-5 >= 0 else 0
            # find left nearest boundary
            ind_x = np.where(boundary_map[:X, Y_min:Y_max])[0]
            if len(ind_x) == 0:
                left_shoulder = None
            else:
                left_shoulder = (X - np.max(ind_x)) * GRID_SIZE
            # find right nearest boundary
            ind_x = np.where(boundary_map[X:, Y_min : Y_max])[0] + X
            if len(ind_x) == 0:
                right_shoulder = None
            else:
                right_shoulder = (np.min(ind_x) - X) * GRID_SIZE
            distance_to_shoulder.append((left_shoulder, right_shoulder))
            if left_shoulder is not None and right_shoulder is not None:
                prompts += f"Location ({x:.2f}, {y:.2f}) distance to left shoulder is {left_shoulder}m and right shoulder is {right_shoulder}m\n"
            elif left_shoulder is None and right_shoulder is None:
                prompts += f"Location ({x:.2f}, {y:.2f}) distance to shoulders are uncertain\n"
            elif left_shoulder is None and right_shoulder is not None:
                prompts += f"Location ({x:.2f}, {y:.2f}) distance to left shoulder is uncertain and distance to right shoulder is {right_shoulder}m\n"
            elif left_shoulder is not None and right_shoulder is None:
                prompts += f"Location ({x:.2f}, {y:.2f}) distance to left shoulder is {left_shoulder}m and distance to right shoulder is uncertain\n"
            else:
                raise Exception("Should not reach here")
    return prompts, distance_to_shoulder


get_current_shoulder_info = {
    "name": "get_current_shoulder",
    "description": "Get the distance to both sides of road shoulders for the current ego-vehicle location.",
    "parameters": {
      "type": "object",
      "properties": {},
      "required": [],
    },
}

def get_current_shoulder(data_dict):
    boundary_map = data_dict["map"]["lane"][2].T
    prompts = "Distance to both sides of road shoulders of current ego-vehicle location:\n"
    distance_to_shoulder = []
    x, y = 0.0, 0.0
    X, Y, valid = location_to_pixel_coordinate(x, y)
    if not valid:
        distance_to_shoulder.append(None)
        prompts = None
    else:
        Y_max = Y+5 if Y+5 < boundary_map.shape[1] else boundary_map.shape[1] - 1
        Y_min = Y-5 if Y-5 >= 0 else 0
        # find left nearest boundary
        ind_x = np.where(boundary_map[:X, Y_min:Y_max])[0]
        if len(ind_x) == 0:
            left_shoulder = None
        else:
            left_shoulder = (X - np.max(ind_x)) * GRID_SIZE
        # find right nearest boundary
        ind_x = np.where(boundary_map[X:, Y_min : Y_max])[0] + X
        if len(ind_x) == 0:
            right_shoulder = None
        else:
            right_shoulder = (np.min(ind_x) - X) * GRID_SIZE
        distance_to_shoulder.append((left_shoulder, right_shoulder))
        if left_shoulder is not None and right_shoulder is not None:
            prompts += f"Current ego-vehicle's distance to left shoulder is {left_shoulder}m and right shoulder is {right_shoulder}m\n"
        elif left_shoulder is None and right_shoulder is None:
            prompts += f"Current ego-vehicle's distance to shoulders are uncertain\n"
        elif left_shoulder is None and right_shoulder is not None:
            prompts += f"Current ego-vehicle's distance to left shoulder is uncertain and distance to right shoulder is {right_shoulder}m\n"
        elif left_shoulder is not None and right_shoulder is None:
            prompts += f"Current ego-vehicle's distance to left shoulder is {left_shoulder}m and distance to right shoulder is uncertain\n"
        else:
            raise Exception("Should not reach here")
    return prompts, distance_to_shoulder


# get_current_center_line_info = {
#     "name": "get_current_center_line",
#     "description": "Get the current center line that the ego-vehicle is driving on. If there is no such lane, return None",
#     "parameters": {
#     },
# }

get_distance_to_lane_divider_at_locations_info = {
    "name": "get_distance_to_lane_divider_at_locations",
    "description": "Get the distance to both sides of road lane_dividers at the locations [(x_1, y_1), ..., (x_n, y_n)]. If the location is out of the map scope, return None",
    "parameters": {
        "type": "object",
        "properties":  {
            "locations": {
                "type": "array",
                "items": {
                    "type": "array",
                    "items": {
                        "type": "number"
                    },
                    "minItems": 2,
                    "maxItems": 2
                },
                "description": "the locations [(x_1, y_1), ..., (x_n, y_n)] to be queried",
            },
        },
        "required": ["locations"],
    },
}

def get_distance_to_lane_divider_at_locations(locations, data_dict):
    boundary_map = data_dict["map"]["lane"][0].T
    prompts = "Get distance to both sides of road lane_dividers of selected locations:\n"
    distance_to_lane_divider = []
    for x, y in locations:
        X, Y, valid = location_to_pixel_coordinate(x, y)
        if not valid:
            distance_to_lane_divider.append(None)
            continue
        else:
            Y_max = Y+5 if Y+5 < boundary_map.shape[1] else boundary_map.shape[1] - 1
            Y_min = Y-5 if Y-5 >= 0 else 0
            # find left nearest lane divider
            ind_x = np.where(boundary_map[:X, Y_min:Y_max])[0]
            if len(ind_x) == 0:
                left_lane_divider = None
            else:
                left_lane_divider = (X - np.max(ind_x)) * GRID_SIZE
            # find right nearest lane divider
            ind_x = np.where(boundary_map[X:, Y_min : Y_max])[0] + X
            if len(ind_x) == 0:
                right_lane_divider = None
            else:
                right_lane_divider = (np.min(ind_x) - X) * GRID_SIZE
            distance_to_lane_divider.append((left_lane_divider, right_lane_divider))
            if left_lane_divider is not None and right_lane_divider is not None:
                prompts += f"Location ({x:.2f}, {y:.2f}) distance to left lane_divider is {left_lane_divider}m and right lane_divider is {right_lane_divider}m\n"
            elif left_lane_divider is None and right_lane_divider is None:
                prompts += f"Location ({x:.2f}, {y:.2f}) distance to lane_dividers are uncertain\n"
            elif left_lane_divider is None and right_lane_divider is not None:
                prompts += f"Location ({x:.2f}, {y:.2f}) distance to left lane_divider is uncertain and distance to right lane_divider is {right_lane_divider}m\n"
            elif left_lane_divider is not None and right_lane_divider is None:
                prompts += f"Location ({x:.2f}, {y:.2f}) distance to left lane_divider is {left_lane_divider}m and distance to right lane_divider is uncertain\n"
            else:
                raise Exception("Should not reach here")
    return prompts, distance_to_lane_divider

get_current_lane_divider_info = {
    "name": "get_current_lane_divider",
    "description": "Get the distance to both sides of road lane_dividers for the current ego-vehicle location",
    "parameters": {
      "type": "object",
      "properties": {},
      "required": [],
    },
}

def get_current_lane_divider(data_dict):
    boundary_map = data_dict["map"]["lane"][0].T
    prompts = "Get distance to both sides of road lane_dividers of current ego-vehicle location:\n"
    distance_to_lane_divider = []
    x, y = 0.0, 0.0
    X, Y, valid = location_to_pixel_coordinate(x, y)
    if not valid:
        distance_to_lane_divider.append(None)
        prompts = None
    else:
        Y_max = Y+5 if Y+5 < boundary_map.shape[1] else boundary_map.shape[1] - 1
        Y_min = Y-5 if Y-5 >= 0 else 0
        # find left nearest lane divider
        ind_x = np.where(boundary_map[:X, Y_min:Y_max])[0]
        if len(ind_x) == 0:
            left_lane_divider = None
        else:
            left_lane_divider = (X - np.max(ind_x)) * GRID_SIZE
        # find right nearest lane divider
        ind_x = np.where(boundary_map[X:, Y_min : Y_max])[0] + X
        if len(ind_x) == 0:
            right_lane_divider = None
        else:
            right_lane_divider = (np.min(ind_x) - X) * GRID_SIZE
        distance_to_lane_divider.append((left_lane_divider, right_lane_divider))
        if left_lane_divider is not None and right_lane_divider is not None:
            prompts += f"Current ego-vehicle's distance to left lane_divider is {left_lane_divider}m and distance to right lane_divider is {right_lane_divider}m\n"
        elif left_lane_divider is None and right_lane_divider is None:
            prompts += f"Current ego-vehicle's distance to both lane_dividers are uncertain\n"
        elif left_lane_divider is None and right_lane_divider is not None:
            prompts += f"Current ego-vehicle's distance to left lane_divider is uncertain and distance to right lane_divider is {right_lane_divider}m\n"
        elif left_lane_divider is not None and right_lane_divider is None:
            prompts += f"Current ego-vehicle's distance to left lane_divider is {left_lane_divider}m and distance to right lane_divider is uncertain\n"
        else:
            raise Exception("Should not reach here")
    return prompts, distance_to_lane_divider

get_nearest_pedestrian_crossing_info = {
    "name": "get_nearest_pedestrian_crossing",
    "description": "Get the location of the nearest pedestrian crossing to the ego-vehicle. If there is no such pedestrian crossing, return None",
    "parameters": {
      "type": "object",
      "properties": {},
      "required": [],
    },
}

def get_nearest_pedestrian_crossing(data_dict):
    boundary_map = data_dict["map"]["lane"][1].T
    prompts = "Get the nearest pedestrian crossing location:\n"
    distance_to_nearest_pedestrian_crossing = []
    X, Y, valid = location_to_pixel_coordinate(0.0, 0.0)
    if not valid:
        prompts = None
        return prompts, distance_to_nearest_pedestrian_crossing
    else:
        ind_X, ind_Y = np.where(boundary_map[:, Y:]) # Plz double check this
        ind_Y += Y
        if len(ind_X) == 0:
            prompts = None
            return prompts, distance_to_nearest_pedestrian_crossing
        else:
            dist = np.abs(ind_X - X) ** 2 + np.abs(ind_Y - Y) ** 2
            ind = np.argmin(dist)
            min_ped_crossing_X, min_ped_crossing_Y = ind_X[ind], ind_Y[ind]
            min_ped_crossing_x, min_ped_crossing_y, _ = pixel_coordinate_to_location(min_ped_crossing_X, min_ped_crossing_Y)
            prompts += f"The nearest pedestrian crossing is at ({min_ped_crossing_x:.2f}, {min_ped_crossing_y:.2f})\n"
            distance_to_nearest_pedestrian_crossing.append((min_ped_crossing_x, min_ped_crossing_y))
    return prompts, distance_to_nearest_pedestrian_crossing