from matplotlib import pyplot as plt
import numpy as np

import opencood.visualization.simple_plot3d.canvas_3d as canvas_3d
import opencood.visualization.simple_plot3d.canvas_bev as canvas_bev

def _to_numpy(data):
    """Move a tensor, or every tensor in a list, to the CPU and convert to numpy."""
    if isinstance(data, list):
        return [item.cpu().numpy() for item in data]
    return data.cpu().numpy()


def _uncertainty_labels(uncertainty):
    """
    Build one text label per row of the uncertainty tensor.

    Returns None when the number of columns is not 2, 3 or 7, in which case
    the caller keeps its default (empty) labels.
    """
    # np.exp allocates a new array, so the in-place scaling below never
    # writes into memory shared with the caller's tensor.
    # NOTE(review): exp() followed by sqrt() suggests the input is a
    # log-variance -- confirm against the model head.
    uncertainty_np = np.exp(uncertainty.cpu().numpy())

    # Column holding the yaw uncertainty for each supported layout;
    # the 2-column layout has no angle term.
    angle_column = {2: None, 3: 2, 7: 6}
    num_dims = uncertainty_np.shape[1]
    if num_dims not in angle_column:
        return None

    # NOTE(review): presumably the squared diagonal of a 1.6 x 3.9 anchor,
    # used to de-normalize the x/y terms -- confirm.
    d_a_square = 1.6 ** 2 + 3.9 ** 2
    uncertainty_np[:, :2] *= d_a_square
    # yaw angle is in radian, it's the same in g2o SE2's setting.
    uncertainty_np = np.sqrt(uncertainty_np)

    angle_idx = angle_column[num_dims]
    labels = []
    for row in uncertainty_np:
        text = f'x_u:{row[0]:.3f} y_u:{row[1]:.3f}'
        if angle_idx is not None:
            text += f' a_u:{row[angle_idx]:.3f}'
        labels.append(text)
    return labels


def _draw_points(canvas, points_np, colors, **draw_kwargs):
    """
    Project point cloud(s) onto the canvas and draw the points that land on it.

    For a list of clouds, `colors` is a palette that is cycled per cloud;
    for a single cloud, `colors` is forwarded unchanged.
    """
    if isinstance(points_np, list):
        for idx, cloud in enumerate(points_np):
            canvas_xy, valid_mask = canvas.get_canvas_coords(cloud)
            # Cycle the palette so more clouds than colors does not crash.
            canvas.draw_canvas_points(canvas_xy[valid_mask],
                                      colors=colors[idx % len(colors)],
                                      **draw_kwargs)
    else:
        canvas_xy, valid_mask = canvas.get_canvas_coords(points_np)
        canvas.draw_canvas_points(canvas_xy[valid_mask], colors=colors,
                                  **draw_kwargs)


def visualize(pred_box_tensor, gt_tensor, pcd, pc_range, save_path, method='3d', vis_gt_box=True, vis_pred_box=True, 
              left_hand=False, uncertainty=None, pcd_score=None, cluster_pcd=None):
    """
    Visualize the prediction, ground truth with point cloud together.
    They may be flipped in y axis. Since carla is left hand coordinate, while kitti is right hand.

    Parameters
    ----------
    pred_box_tensor : torch.Tensor or None
        (N, 8, 3) prediction. Skipped when None.

    gt_tensor : torch.Tensor or None
        (N, 8, 3) groundtruth bbx. Skipped when None.

    pcd : torch.Tensor or list of torch.Tensor
        PointCloud, (N, 4). A list is drawn one cloud per color.

    pc_range : list
        [xmin, ymin, zmin, xmax, ymax, zmax]. Values are truncated to int.

    save_path : str
        Save the visualization results to given path.

    method : str, 'bev' or '3d'
        Which canvas to render on.

    vis_gt_box : bool
        Whether to draw the ground truth boxes.

    vis_pred_box : bool
        Whether to draw the predicted boxes.

    left_hand : bool
        Forwarded to the canvas constructor.

    uncertainty : torch.Tensor, optional
        Per-box uncertainty with 2, 3 or 7 columns. When given, each
        predicted box is labelled with its x / y (and yaw, if present)
        uncertainty. Other column counts leave the labels empty.

    pcd_score : optional
        Currently unused; kept for interface compatibility.

    cluster_pcd : torch.Tensor or list of torch.Tensor, optional
        Extra points drawn with a larger radius. Only rendered for 'bev'.

    Raises
    ------
    NotImplementedError
        If `method` is neither 'bev' nor '3d'.
    """
    # Fail fast, before any tensor is copied to the host.
    if method not in ('bev', '3d'):
        raise NotImplementedError(f"Not Completed for {method} visualization.")

    pc_range = [int(i) for i in pc_range]

    pcd_np = _to_numpy(pcd)
    if not isinstance(pcd, list):
        pcd_colors = None
    elif len(pcd) > 2:
        # Red, purple and blue mark points on vehicles (translated from the
        # original comment; the entries listed first are red, white, blue).
        pcd_colors = [(255, 0, 0), (255, 255, 255), (0, 0, 255),
                      (136, 72, 152), (191, 121, 78), (235, 216, 66),
                      (0, 149, 217), (108, 53, 36), (118, 145, 100)]
    else:
        pcd_colors = [(255, 255, 255), (191, 121, 78)]

    if cluster_pcd is not None:
        cluster_pcd_np = _to_numpy(cluster_pcd)
        if isinstance(cluster_pcd, list):
            cluster_pcd_colors = [(231, 230, 33), (229, 171, 190), (0, 164, 151)]
        else:
            cluster_pcd_colors = None

    # A missing tensor means there is nothing to draw, even if the flag is on.
    draw_pred = vis_pred_box and pred_box_tensor is not None
    draw_gt = vis_gt_box and gt_tensor is not None

    if draw_pred:
        pred_box_np = pred_box_tensor.cpu().numpy()
        pred_name = [''] * pred_box_np.shape[0]
        if uncertainty is not None:
            labels = _uncertainty_labels(uncertainty)
            if labels is not None:
                pred_name = labels

    if draw_gt:
        gt_box_np = gt_tensor.cpu().numpy()
        gt_name = [''] * gt_box_np.shape[0]

    if method == 'bev':
        # 10 canvas pixels per unit of pc_range; rows span y, columns span x.
        canvas = canvas_bev.Canvas_BEV_heading_right(
            canvas_shape=((pc_range[4] - pc_range[1]) * 10,
                          (pc_range[3] - pc_range[0]) * 10),
            canvas_x_range=(pc_range[0], pc_range[3]),
            canvas_y_range=(pc_range[1], pc_range[4]),
            left_hand=left_hand,
        )
        _draw_points(canvas, pcd_np, pcd_colors)

        if cluster_pcd is not None:
            # paint cluster with larger points in canvas
            _draw_points(canvas, cluster_pcd_np, cluster_pcd_colors, radius=3)

        box_line_thickness = 2
        if draw_gt:
            canvas.draw_boxes(gt_box_np, colors=(0, 255, 0), texts=gt_name,
                              box_line_thickness=box_line_thickness)
        if draw_pred:
            canvas.draw_boxes(pred_box_np, colors=(255, 0, 0), texts=pred_name,
                              box_line_thickness=box_line_thickness)
    else:  # method == '3d'
        canvas = canvas_3d.Canvas_3D(left_hand=left_hand)
        _draw_points(canvas, pcd_np, pcd_colors)

        if draw_gt:
            canvas.draw_boxes(gt_box_np, colors=(0, 255, 0), texts=gt_name)
        if draw_pred:
            canvas.draw_boxes(pred_box_np, colors=(255, 0, 0), texts=pred_name)

    plt.axis("off")

    plt.imshow(canvas.canvas)

    plt.tight_layout()
    plt.savefig(save_path, transparent=True, dpi=800, pad_inches=0.0)
    # Clear the figure so successive calls do not draw on top of each other.
    plt.clf()