from pcdet.ops.roiaware_pool3d import roiaware_pool3d_utils
from pcdet.models.model_utils import graph_utils

import os 
import cv2 
import numpy as np 
import torch 
import pickle 

# Color palette for write_ply_color: row k holds the RGB color, as floats in
# [0, 1], used for label id k. Rows given as x / 255 are 8-bit RGB values; the
# trailing comment names the row index and, where the values match one, the
# CSS named color.
COLORS = np.asarray(
    [
        (0.3, 0.3, 0.3),                      # 0  dark grey
        (1.0, 140 / 255, 0),                  # 1  darkorange
        (30 / 255, 144 / 255, 1.0),           # 2  dodgerblue
        (50 / 255, 205 / 255, 50 / 255),      # 3  limegreen
        (1.0, 0, 1.0),                        # 4  magenta
        (127 / 255, 255 / 255, 0),            # 5  chartreuse
        (255 / 255, 51 / 255, 51 / 255),      # 6  light red
        (1.0, 1.0, 0),                        # 7  yellow
        (0.0, 0.8, 0.8),                      # 8  cyan-ish
        (210 / 255, 105 / 255, 30 / 255),     # 9  chocolate
        (0.8, 0.6, 0.2),                      # 10 ochre
        (148 / 255, 0, 211 / 255),            # 11 darkviolet
        (127 / 255, 255 / 255, 212 / 255),    # 12 aquamarine
        (255 / 255, 215 / 255, 0),            # 13 gold
        (0.0, 128 / 255, 0),                  # 14 green
        (154 / 255, 205 / 255, 50 / 255),     # 15 yellowgreen
        (230 / 255, 230 / 255, 250 / 255),    # 16 lavender
        (240 / 255, 230 / 255, 140 / 255),    # 17 khaki
        (176 / 255, 224 / 255, 230 / 255),    # 18 powderblue
        (255 / 255, 99 / 255, 71 / 255),      # 19 tomato
        (0, 1, 1),                            # 20 cyan
        (100 / 255, 149 / 255, 237 / 255),    # 21 cornflowerblue
        (255 / 255, 105 / 255, 180 / 255),    # 22 hotpink
    ],
    dtype=np.float32,
)

def write_ply_color(points, labels,  out_filename, label_color=False):
    """Write points with per-vertex RGB colors to an ASCII PLY file.

    Args:
        points: (N, >=3) array; columns 0-2 are written as x, y, z.
        labels: when ``label_color`` is False, (N,) integer ids used to pick a
            row of the module-level ``COLORS`` palette (ids must lie in
            ``[0, len(COLORS) - 1]``); the [0, 1] palette value is scaled by
            255 and truncated to an integer. When ``label_color`` is True, an
            (N, >=3) array whose columns 0-2 are written directly as red,
            green, blue (truncated to integers by ``%d``).
        out_filename: path of the PLY file to create (overwritten if present).
        label_color: selects how ``labels`` is interpreted (see above).
    """
    num_points = points.shape[0]
    header = (
        "ply\n"
        "format ascii 1.0\n"
        "element vertex %d\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "property uchar red\n"
        "property uchar green\n"
        "property uchar blue\n"
        "end_header\n"
    ) % num_points

    # `with` guarantees the file is closed even if a write fails midway.
    with open(out_filename, 'w') as fout:
        fout.write(header)
        for i in range(num_points):
            if label_color:
                # Explicit RGB per point; the blue channel previously re-used
                # column 1 (green) by mistake.
                red, green, blue = labels[i, 0], labels[i, 1], labels[i, 2]
            else:
                color = COLORS[labels[i]] * 255
                red, green, blue = color[0], color[1], color[2]
            fout.write('%f %f %f %d %d %d\n' % (
                points[i, 0], points[i, 1], points[i, 2], red, green, blue))


def get_box_points(box):
    """Compute the eight corners of a 3D box rotated about the z axis.

    Args:
        box: numpy array indexed as ``box[0:3]`` = center, ``box[3:6]`` =
            full extents along the box's local x/y/z axes, ``box[6]`` = yaw
            (radians).

    Returns:
        (8, 3) array of corner coordinates. Corner ``k`` takes its sign
        pattern from the bits of ``k`` (bit 2 -> x, bit 1 -> y, bit 0 -> z; a
        0 bit selects the negative half extent), so corner 0 is (-, -, -) and
        corner 7 is (+, +, +) before rotation.
    """
    half_extent = np.asarray(box[3:6]) / 2
    sign_table = np.array(
        [[1 if (k >> bit) & 1 else -1 for bit in (2, 1, 0)] for k in range(8)],
        dtype=half_extent.dtype,
    )
    local_corners = sign_table * half_extent

    theta = -box[6]
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([
        [cos_t, -sin_t, 0],
        [sin_t, cos_t, 0],
        [0, 0, 1]])
    # Right-multiplying row vectors by Rz(-yaw) equals applying Rz(+yaw) to
    # column vectors, i.e. the corners are rotated by +yaw about z.
    return np.matmul(local_corners, rotation) + box[0:3].reshape(-1, 3)


def _out_file(out_path, frame_id, suffix):
    """Return ``<out_path>/<frame_id>_<suffix>``, the naming scheme of every per-frame dump file."""
    return os.path.join(out_path, '%s_%s' % (frame_id, suffix))


def _false_labels(pred_labels, gt_labels):
    """Return an int32 array that is 1 where a labeled point is misclassified.

    Points whose ground-truth label is 0 are treated as unlabeled and are
    never flagged (same result as comparing ``pred * (gt != 0)`` with ``gt``).
    """
    return ((gt_labels != 0) & (pred_labels != gt_labels)).astype(np.int32)


def _point_box_cls(point_box_ids, box_cls_label):
    """Map per-point box indices to the class id of the box containing each point.

    Args:
        point_box_ids: (N,) integer array of box indices, where -1 marks a
            point that lies in no box.
        box_cls_label: (M,) array of class ids, one per box.

    Returns:
        (N,) array (dtype of ``box_cls_label``) with the owning box's class
        id, or 0 for points outside every box. Also works when M == 0.
    """
    point_cls = np.zeros(point_box_ids.shape, dtype=box_cls_label.dtype)
    inside = point_box_ids >= 0
    point_cls[inside] = box_cls_label[point_box_ids[inside]]
    return point_cls


def _box_corners(boxes, box_cls_label):
    """Return (M, 8, 3) corners of ``boxes`` and an (M, 8) int32 array repeating each box's class id."""
    num_boxes = boxes.shape[0]
    corners = np.zeros((num_boxes, 8, 3))
    corner_cls = np.zeros((num_boxes, 8), np.int32)
    # Iterate over *these* boxes; the original pred-box loop used the gt box count.
    for j in range(num_boxes):
        corners[j] = get_box_points(boxes[j])
        corner_cls[j, :] = box_cls_label[j]
    return corners, corner_cls


def write_output(pred_dicts, batch_dict, result_dir, rare_classes):
    """Dump visualization/debug files for every sample that contains a rare class.

    For each sample ``ii``, the coarse point-wise labels are transferred onto
    the refined point set through a 1-nearest-neighbour graph. If any id in
    ``rare_classes`` occurs among the transferred ground-truth labels, files
    are written to ``<result_dir>/rare_classes_cycle_train/<frame_id><cls_suffix>/``:
    segmentation, box and box-corner PLY files, BEV PNGs, box-refinement PLY
    files, and ``<frame_id>.pkl`` with the raw arrays. Other samples are skipped.

    Args:
        pred_dicts: list of per-sample dicts with ``'point_wise'``,
            ``'voxel_wise'`` and ``'object_wise'`` sub-dicts of tensors.
        batch_dict: batch inputs/intermediates. NOTE(review): ``gt_boxes``,
            ``gt_box_cls_label`` and ``frame_id`` are indexed per sample, but
            the box-refinement keys (``box_points_xyz``, ``pred_boxes_refine``,
            ``final_score``, ``instance_label``, ...) are used whole, which
            looks like it assumes batch size 1 — confirm.
        result_dir: root output directory.
        rare_classes: iterable of integer class ids that trigger a dump.
    """
    graph = graph_utils.KNNGraph({}, dict(NUM_NEIGHBORS=1))
    out_prefix = 'rare_classes_cycle_train'
    max_sampled_points = 100000  # cap on points in the sub-sampled segmentation PLYs

    for ii in range(len(pred_dicts)):
        point_wise = pred_dicts[ii]['point_wise']
        voxel_wise = pred_dicts[ii]['voxel_wise']
        object_wise = pred_dicts[ii]['object_wise']

        # Transfer coarse point labels onto the refined points: each refined
        # (query) point takes the label of its nearest coarse (reference) point.
        # A zero batch-index column is prepended (the graph appears to expect
        # [b, x, y, z] rows); new_zeros keeps it on the points' device/dtype.
        point_xyz = point_wise['point_xyz']
        point_xyz_refine = point_wise['point_xyz_refine']
        ref_bxyz = torch.cat((point_xyz.new_zeros((point_xyz.shape[0], 1)), point_xyz), -1)
        query_bxyz = torch.cat(
            (point_xyz_refine.new_zeros((point_xyz_refine.shape[0], 1)), point_xyz_refine), -1)
        e_ref, _ = graph(ref_bxyz, query_bxyz)
        # NOTE(review): assumes e_ref is ordered like the query points — confirm
        # against KNNGraph's edge ordering.
        e_ref = e_ref.cpu().numpy()
        pred_labels = point_wise['pred_segmentation_label'].cpu().numpy()[e_ref]
        gt_labels = point_wise['gt_segmentation_label'].cpu().numpy()[e_ref]
        points = point_xyz_refine.cpu().numpy()

        # Only samples containing at least one rare class are dumped.
        gt_classes = np.unique(gt_labels)
        classes = [cls for cls in rare_classes if cls in gt_classes]
        if not classes:
            continue
        cls_str = "_cls" + "".join("_%d" % cls for cls in classes)

        pred_labels_refine = point_wise['pred_segmentation_label_refine'].cpu().numpy()
        gt_labels_refine = point_wise['gt_segmentation_label_refine'].cpu().numpy()
        pred_bev_map = voxel_wise['pred_bev_segmentation_label'].cpu().numpy()
        gt_bev_map = voxel_wise['gt_bev_segmentation_label'].cpu().numpy()
        pred_boxes = object_wise['pred_box_attr'].cpu().numpy()
        pred_box_scores = object_wise['pred_box_scores'].cpu().numpy()
        pred_box_cls_label = object_wise['pred_box_cls_label'].cpu().numpy()
        gt_boxes = batch_dict['gt_boxes'][ii].cpu().numpy()
        gt_box_cls_label = batch_dict['gt_box_cls_label'][ii].cpu().numpy()

        frame_id_or = batch_dict['frame_id'][ii][-1]
        out_path = os.path.join(result_dir, out_prefix, frame_id_or + cls_str)
        os.makedirs(out_path, exist_ok=True)
        print(frame_id_or, out_path)
        frame_id = frame_id_or[-3:]  # short id used as the file-name prefix

        # --- point-wise segmentation ---
        sample_ids = np.random.permutation(points.shape[0])[:max_sampled_points]
        sampled_points = points[sample_ids, :]
        write_ply_color(sampled_points, gt_labels[sample_ids],
                        _out_file(out_path, frame_id, 'seg_gt.ply'))
        write_ply_color(sampled_points, pred_labels[sample_ids],
                        _out_file(out_path, frame_id, 'seg_pred.ply'))
        write_ply_color(points, pred_labels_refine,
                        _out_file(out_path, frame_id, 'seg_pred_refined.ply'))

        false_labels = _false_labels(pred_labels, gt_labels)
        write_ply_color(sampled_points, false_labels[sample_ids],
                        _out_file(out_path, frame_id, 'seg_false.ply'))
        write_ply_color(points, _false_labels(pred_labels_refine, gt_labels_refine),
                        _out_file(out_path, frame_id, 'seg_false_refine.ply'))

        for cls in classes:
            cls_ids = gt_labels == cls
            write_ply_color(points[cls_ids, :], gt_labels[cls_ids],
                            _out_file(out_path, frame_id, 'cls_%d.ply' % cls))

        # --- box level: color each refined point by the class of its enclosing box ---
        print(batch_dict['gt_boxes'].shape)
        query_points = point_xyz_refine.unsqueeze(0)
        # Slice sample ii (keeping the batch dim) so the boxes pair with this
        # sample's points; passing every sample's boxes mismatched batch sizes.
        gt_point_box_ids = roiaware_pool3d_utils.points_in_boxes_gpu(
            query_points, batch_dict['gt_boxes'][ii:ii + 1, :, 0:7])[0].cpu().numpy()
        pred_point_box_ids = roiaware_pool3d_utils.points_in_boxes_gpu(
            query_points, object_wise['pred_box_attr'].unsqueeze(0))[0].cpu().numpy()
        write_ply_color(points, _point_box_cls(gt_point_box_ids, gt_box_cls_label),
                        _out_file(out_path, frame_id, 'box_gt.ply'))
        write_ply_color(points, _point_box_cls(pred_point_box_ids, pred_box_cls_label),
                        _out_file(out_path, frame_id, 'box_pred.ply'))

        corner_points_gt, corner_points_gt_cls = _box_corners(gt_boxes, gt_box_cls_label)
        corner_points_pred, corner_points_pred_cls = _box_corners(pred_boxes, pred_box_cls_label)
        write_ply_color(corner_points_gt.reshape((-1, 3)), corner_points_gt_cls.reshape(-1),
                        _out_file(out_path, frame_id, 'box_corners_gt.ply'))
        write_ply_color(corner_points_pred.reshape((-1, 3)), corner_points_pred_cls.reshape(-1),
                        _out_file(out_path, frame_id, 'box_corners_pred.ply'))

        # --- BEV maps: label ids scaled by 255/3 into grey levels (presumably ids 0-3) ---
        cv2.imwrite(_out_file(out_path, frame_id, 'bev_gt.png'), gt_bev_map * 255.0 / 3.0)
        cv2.imwrite(_out_file(out_path, frame_id, 'bev_pred.png'), pred_bev_map * 255.0 / 3.0)

        # --- box-refinement stage ---
        box_points_xyz_in = batch_dict['box_points_xyz']
        num_input = box_points_xyz_in.shape[0]
        num_pt = box_points_xyz_in.shape[1]
        num_box = batch_dict['num_box']
        # Shift the first num_box groups of points by their refined box centers.
        # Work on a copy: updating batch_dict['box_points_xyz'] in place leaked
        # into the caller and re-added the offset for every later rare sample.
        box_points_xyz_or = box_points_xyz_in.clone()
        box_points_xyz_or[:num_box] += batch_dict['pred_boxes_refine'][:, 0:3].unsqueeze(1)
        box_points_xyz = box_points_xyz_or.reshape(-1, 3).cpu().numpy()

        gt_seg_label = batch_dict['gt_seg_label'].reshape(-1).cpu().numpy()
        gt_point_mask = object_wise['gt_mask'].reshape(-1).cpu().numpy()
        # Per-box class repeated for each of the box's num_pt points.
        gt_pred_box_cls = object_wise['gt_box_cls'].unsqueeze(1).repeat(1, num_pt).reshape(-1).cpu().numpy()
        box_points_first_stage_logits = batch_dict['box_points_first_stage_logits']
        box_points_seg = torch.sigmoid(box_points_first_stage_logits).max(-1).indices
        box_points_seg = box_points_seg.reshape(-1).cpu().numpy()

        write_ply_color(box_points_xyz, gt_point_mask,
                        _out_file(out_path, frame_id, 'refine_mask_gt.ply'))
        write_ply_color(box_points_xyz, gt_pred_box_cls,
                        _out_file(out_path, frame_id, 'refine_box_cls_gt.ply'))
        write_ply_color(box_points_xyz, box_points_seg,
                        _out_file(out_path, frame_id, 'refine_seg_first.ply'))
        write_ply_color(box_points_xyz, gt_seg_label,
                        _out_file(out_path, frame_id, 'refine_seg_gt.ply'))

        second_stage_score = batch_dict['second_stage_score']
        final_score = batch_dict['final_score']
        pred_mask_logits = batch_dict['pred_point_mask_logits']
        pred_box_cls_logits = batch_dict['pred_box_cls_logits']
        box_points_seg_final = final_score.max(-1).indices.reshape(-1).cpu().numpy()
        box_points_seg_second = second_stage_score.max(-1).indices.reshape(-1).cpu().numpy()
        pred_mask = object_wise['pred_mask'].reshape(-1).cpu().numpy()
        pred_box_cls = object_wise['pred_box_cls'].unsqueeze(1).repeat(1, num_pt).reshape(-1).cpu().numpy()

        write_ply_color(box_points_xyz, box_points_seg_second,
                        _out_file(out_path, frame_id, 'refine_seg_second.ply'))
        write_ply_color(box_points_xyz, box_points_seg_final,
                        _out_file(out_path, frame_id, 'refine_seg_final.ply'))
        write_ply_color(box_points_xyz, pred_mask,
                        _out_file(out_path, frame_id, 'refine_mask_pred.ply'))
        write_ply_color(box_points_xyz, pred_box_cls,
                        _out_file(out_path, frame_id, 'refine_box_cls_pred.ply'))

        out_dict = {
            'points': points,
            'gt_seg_label': gt_labels,
            'gt_instance_label': batch_dict['instance_label'].cpu().numpy(),
            'gt_boxes': gt_boxes,
            'gt_corner_points': corner_points_gt,
            'gt_box_cls_label': gt_box_cls_label,
            'pred_boxes': pred_boxes,
            'pred_corner_points': corner_points_pred,
            'pred_box_scores': pred_box_scores,
            'pred_box_cls_label': pred_box_cls_label,
            'pred_seg_label': pred_labels,
            'pred_seg_logits': batch_dict['pred_seg_cls_logits'].cpu().numpy()[e_ref],
            'pred_seg_label_refine': pred_labels_refine,
            'pred_seg_logits_refine': batch_dict['pred_seg_cls_logits_refine'].cpu().numpy(),
            'prerefine_gt_seg_label': batch_dict['gt_seg_label'].cpu().numpy(),
            'prerefine_seg_label': box_points_seg.reshape((num_input, num_pt)),
            # NOTE(review): stored under "prob" but these are the raw logits.
            'prerefine_seg_prob': box_points_first_stage_logits.cpu().numpy(),
            'prerefine_gt_point_mask': object_wise['gt_mask'].cpu().numpy(),
            'prerefine_gt_box_cls': object_wise['gt_box_cls'].cpu().numpy(),
            'prerefine_input_box': batch_dict['pred_boxes_refine'].cpu().numpy(),
            'refined_points': box_points_xyz_or.cpu().numpy(),
            'refined_pred_point_mask': object_wise['pred_mask'].cpu().numpy(),
            'refined_pred_box_cls': object_wise['pred_box_cls'].cpu().numpy(),
            'refined_pred_point_mask_logits': pred_mask_logits.cpu().numpy(),
            'refined_pred_box_cls_logits': pred_box_cls_logits.cpu().numpy(),
            # NOTE(review): only element [0] of final_score is stored, while
            # 'refined_seg_label' covers all inputs — confirm this is intended.
            'refined_seg_prob': batch_dict['final_score'][0].cpu().numpy(),
            'refined_seg_label': box_points_seg_final.reshape((num_input, num_pt)),
        }

        with open(os.path.join(out_path, frame_id_or + '.pkl'), 'wb') as f:
            pickle.dump(out_dict, f)

        print('saved', frame_id)

                
        
