# Occupancy functions

from agentdriver.utils.geometry import location_to_pixel_coordinate, pixel_coordinate_to_location, CAR_LENGTH, CAR_WIDTH, MAP_METER, GRID_SIZE
import numpy as np
from skimage.draw import polygon

OCC_TH = 0.1

get_occupancy_at_locations_for_timestep_info = {
    "name": "get_occupancy_at_locations_for_timestep",
    "description": "Get the probability whether a list of locations [(x_1, y_1), ..., (x_n, y_n)] is occupied at the timestep t. If the location is out of the occupancy prediction scope, return None",
    "parameters": {
        "type": "object",
        "properties": {
            "locations": {
                "type": "array",
                "items": {
                    "type": "array",
                    "items": {
                        "type": "number"
                    },
                    "minItems": 2,
                    "maxItems": 2
                },
                "description": "occupancy at the locations [(x_1, y_1), ..., (x_n, y_n)]",
            },
            "timestep": {
                "type": "integer",
                "minimum": 0,
                "maximum": 4,
                "multipleOf" : 1,
                "description": "time step t in the occupancy flow, must be one of [0, 1, 2, 3, 4], which denotes the future occupancy at [0s, 0.5s, 1s, 1.5s, 2s].",
            },
        },
        "required": ["locations", "timestep"],
    },
}

def get_occupancy_at_locations_for_timestep(locations, timestep, data_dict):
    occ_map = data_dict["occupancy"].transpose(0, 2, 1)
    occ_list = []
    prompts = "Occupancy information:\n"

    for location in locations:
        x, y = location 
        X, Y, valid = location_to_pixel_coordinate(x, y)
        T = timestep
        # deal with exceptions
        if not valid or T < 0 or T >= 5:
            prompts = None
            return prompts, False

        occ = occ_map[T, X, Y]
        if occ > OCC_TH:
            prompts += f"Location ({x:.2f}, {y:.2f}) is occupied at timestep {timestep}\n"
        else:
            prompts += f"Location ({x:.2f}, {y:.2f}) is not occupied at timestep {timestep}\n"
        occ_list.append(occ)
    return prompts, occ_list


check_occupancy_for_planned_trajectory_info = {
    "name": "check_occupancy_for_planned_trajectory",
    "description": "Evaluate whether the planned trajectory [(x_1, y_1), ..., (x_n, y_n)] collides with other objects.",
    "parameters": {
        "type": "object",
        "properties": {
            "trajectory": {
                "type": "array",
                "items": {
                    "type": "array",
                    "items": {
                        "type": "number"
                    },
                    "minItems": 2,
                    "maxItems": 2
                },
                "minItems": 6,
                "maxItems": 7,
                "description": "the planned trajectory [(x_1, y_1), ..., (x_n, y_n)]",
            },
        },
        "required": ["trajectory"],
    },
}

def check_occupancy_for_planned_trajectory(trajectory, data_dict):
    occ_map = data_dict["occupancy"].transpose(0, 2, 1)
    prompts = "Check collision of the planned trajectory:\n"

    collision = False
    for timestep, location in enumerate(trajectory):
        x, y = location 
        X, Y, valid = location_to_pixel_coordinate(x, y)
        if not valid: # trajectory out of range
            continue
        T = timestep + 1 # We assume the time step starting from 1
        if T >= 5:
            continue

        occ = occ_map[T, X, Y]
        if occ > OCC_TH:
            prompts += f"Waypoint ({x:.2f}, {y:.2f}) collides at timestep {T}\n"
            collision = True
        else:
            continue
    if not collision:
        prompts += f"The planned trajectory does not collide with any other objects.\n"
    return prompts, collision

def check_occupancy_for_planned_trajectory_and_surrounding(trajectory, data_dict):
    occ_map = data_dict["occupancy"].transpose(0, 2, 1)
    prompts = "Check collision of the planned trajectory:\n"

    collision = False
    for timestep, location in enumerate(trajectory):
        x, y = location 
        X, Y, valid = location_to_pixel_coordinate(x, y)
        if not valid: # trajectory out of range
            continue
        T = timestep + 1 # We assume the time step starting from 1
        if T >= 5:
            continue

        occ = occ_map[T, X, Y]
        if occ > OCC_TH:
            prompts += f"Waypoint ({x:.2f}, {y:.2f}) collides at timestep {T}\n"

            # check surrounding
            surrounding = occ_map[T, X-1:X+2, Y-1:Y+2]
            if True in surrounding:
                index_x, index_y = np.where(surrounding)
                index_X, index_Y = index_x + X - 1, index_y + Y - 1
                prompts += f"- Surrounding not occupied region: {[(pixel_coordinate_to_location(x, y)[:-1]) for x, y in zip(index_X, index_Y)]}\n"
            collision = True
        else:
            continue
    if not collision:
        prompts += f"The planned trajectory does not collide with any other objects.\n"
    return prompts, collision

def check_collision(car_length, car_width, trajectory, occ_map):
        pts = np.array([
                [-car_length / 2. + 0.5, car_width / 2.],
                [car_length / 2. + 0.5, car_width / 2.],
                [car_length / 2. + 0.5, -car_width / 2.],
                [-car_length / 2. + 0.5, -car_width / 2.],
            ])    

        pts = (pts - (- MAP_METER)) / GRID_SIZE
        pts[:, [0, 1]] = pts[:, [1, 0]]

        rr, cc = polygon(pts[:,1], pts[:,0])
        rc = np.concatenate([rr[:,None], cc[:,None]], axis=-1) # all points inside the box (car)

        n_future = occ_map.shape[0] # trajectory.shape[0]   since we only have 4 future occupancy

        trajectory = trajectory * np.array([-1, 1])
        trajectory = trajectory[:, np.newaxis, :] # (n_future, 1, 2)

        trajectory[:,:,[0,1]] = trajectory[:,:,[1,0]]
        trajectory = trajectory / GRID_SIZE
        trajectory = trajectory + rc # (n_future, 32, 2) # all points during the trajectory

        r = trajectory[:,:,0].astype(np.int32) # (n_future, 32) decompose the points into row
        r = np.clip(r, 0, occ_map.shape[1] - 1)

        c = trajectory[:,:,1].astype(np.int32) # (n_future, 32) decompose the points into column
        c = np.clip(c, 0, occ_map.shape[2] - 1)

        collision = np.full(trajectory.shape[0], False) # we set the length of collision same as the length of trajectory though we only check 4 timesteps
        for t in range(n_future):
            rr = r[t]
            cc = c[t]
            I = np.logical_and(
                np.logical_and(rr >= 0, rr < occ_map.shape[1]),
                np.logical_and(cc >= 0, cc < occ_map.shape[2]),
            )
            collision[t] = np.any(occ_map[t, rr[I], cc[I]] > OCC_TH)

        return collision

def check_occupancy_for_planned_trajectory_correct(trajectory, data_dict, safe_margin=1., token=None):
    '''
    trajs: torch.Tensor (B, n_future, 2)
    segmentation: torch.Tensor (B, n_future, 200, 200)
    '''
    occ_map = data_dict["occupancy"]

    occ_map = np.fliplr(occ_map.transpose(1,2,0)).transpose(2,0,1)
    occ_map = occ_map[1:] # remove the current timestep
    if occ_map.shape[0] == 4: # if we only have 4 future occupancy
        # New shape
        new_shape = (6, 200, 200)

        # Initialize the new array with the new shape
        expanded_array = np.zeros(new_shape)

        # Copy the original data
        expanded_array[:4] = occ_map

        # Assume that the conditions in the last second continue
        expanded_array[4] = occ_map[-1]
        expanded_array[5] = occ_map[-1]
        occ_map = expanded_array
    
    collision_t = check_collision(CAR_LENGTH+safe_margin, CAR_WIDTH+safe_margin, trajectory, occ_map)

    return collision_t