# Utils for collision optimization
import numpy as np
from agentdriver.utils.geometry import BEV_H, BEV_W, rotate_bbox, location_to_pixel_coordinate
import cv2

debug = True
def det2occ(data_dict):
    # n_future= 6
    # obj_maps = [[] for _ in range(n_future)]   
    # for obj in objects:
    #     x, y, z, dx, dy, dz, rotation_z, rotation_y, rotation_x = obj["bbox"]
    #     cx, cy = x, y

    #     rotated_corners = rotate_bbox(0, 0, dx, dy, rotation_z)
        

    #     for ts, pt in enumerate(obj["traj"][:n_future]):
    #         cx, cy = pt[0], pt[1]
    #         agent_final_corners = [(cx + x_prime, cy + y_prime) for x_prime, y_prime in rotated_corners]
    #         agent_final_corners = np.array(agent_final_corners) * 4.

    #         pseudo_corners = agent_final_corners - agent_final_corners.min(axis=0).round()

    #         rr, cc = polygon(pseudo_corners[:,1], pseudo_corners[:,0])
    #         rc = np.concatenate([rr[:,None], cc[:,None]], axis=-1)
    #         rc = rc + agent_final_corners.min(axis=0)[::-1].round()
    #         obj_maps[ts].append(rc / 4.)
    objects = data_dict["objects"]
    

    n_future= 6
    segmentation = np.zeros((n_future, BEV_H, BEV_W))
    poly_region_per_ts = [[] for _ in range(n_future)]
    for obj in objects:
        x, y, z, dx, dy, dz, rotation_z, rotation_y, rotation_x = obj["bbox"]
        cx, cy = x, y

        rotated_corners = rotate_bbox(0, 0, dx, dy, rotation_z)
        
        for ts, pt in enumerate(obj["traj"][:n_future]):
            cx, cy = pt[0], pt[1]
            agent_final_corners = [(cx + x_prime, cy + y_prime) for x_prime, y_prime in rotated_corners]
            agent_final_corners = np.array(agent_final_corners)
            for i in range(len(agent_final_corners)):
                agent_final_corners[i, 0], agent_final_corners[i, 1], _ = location_to_pixel_coordinate(agent_final_corners[i, 0], agent_final_corners[i, 1])
            poly_region_per_ts[ts].append(agent_final_corners)

    for ts in range(n_future):
        for poly in poly_region_per_ts[ts]:
            cv2.fillPoly(segmentation[ts], [poly.astype(np.int32).reshape((-1, 1, 2))], 1.0)

        if debug:
            occ = data_dict["occupancy"]
            gt_occ = data_dict["gt_occ"]
            import matplotlib.pyplot as plt
            rgb_image = np.zeros((*segmentation[ts].shape, 3), dtype=np.uint8)

            rgb_image[occ[min(ts+1, occ.shape[0]-1)]>0.1, 0] = 255
            rgb_image[segmentation[ts]==True, 1] = 255
            # rgb_image[gt_occ[0][min(ts+1, gt_occ.shape[0])]==True, 2] = 255 
            plt.figure()
            plt.imshow(rgb_image)
            # plt.imshow(occ_map[ts], cmap='gray')
            plt.savefig(f'vis_debug/occ_compare_occ&det{ts}.png')
            plt.close()

            rgb_image = np.zeros((*segmentation[ts].shape, 3), dtype=np.uint8)
            rgb_image[segmentation[ts]==True, 1] = 255
            rgb_image[gt_occ[0][min(ts+1, gt_occ.shape[0]-1)]==True, 2] = 255 
            plt.figure()
            plt.imshow(rgb_image)
            # plt.imshow(occ_map[ts], cmap='gray')
            plt.savefig(f'vis_debug/occ_compare_gt&det{ts}.png')
            plt.close()            

    segmentation = np.concatenate((segmentation[0:1], segmentation), axis=0)

    return segmentation