from pcdet.ops.roiaware_pool3d import roiaware_pool3d_utils
from pcdet.models.model_utils import graph_utils

import os 
import numpy as np 
import torch 
import pickle 

# Per-label RGB palette with components in [0, 1], stored as float32 and
# indexed by integer label (row k is the color for label k).
# write_ply_color multiplies these components by 255 when writing PLY colors.
# The trailing "# k" markers give the row index of each entry.
COLORS = np.array([
        [0.3, 0.3, 0.3], # 0
        [1.0,140.0/255.0,0], # 1
        [30.0/255.0,144.0/255.0, 1.0], # 2
        [50.0/255.0,205.0/255.0,50.0/255.0], # 3
        [1.0,0,1.0], # 4
        [127/255.0,255/255.0,0], # 5
        [255.0/255.0, 51.0/255.0, 51.0/255.0 ], # 6
        [1.0,1.0,0], # 7
        [0.0, 0.8, 0.8], # 8
        [210/255.0,105/255.0,30/255.0], # 9
        [0.8, 0.6, 0.2], # 10
        [148/255.0,0,211/255.0], # 11
        [127/255.0,255/255.0,212/255.0], # 12
        [255/255.0,215/255.0,0], # 13
        [0.0, 128.0/255.0, 0], # 14
        [154/255.0,205/255.0,50/255.0], # 15
        [230/255.0,230/255.0,250/255.0], # 16
        [240/255.0,230/255.0,140/255.0], # 17
        [176/255.0,224/255.0,230/255.0], # 18
        [255/255.0,99/255.0,71/255.0], # 19
        [0,1,1], # 20
        [100/255.0,149/255.0,237/255.0], # 21
        [255/255.0,105/255.0,180/255.0] # 22
        ]).astype(np.float32)

def write_ply_color(points, labels, out_filename, label_color=False):
    """Write a colored point cloud to an ASCII PLY file.

    Args:
        points: array indexed as ``points[i, 0:3]``; the first three columns
            of each of the N rows are written as the vertex x, y, z.
        labels: when ``label_color`` is False, N integer labels used to index
            the module-level ``COLORS`` palette (valid range
            ``0 .. len(COLORS) - 1``); when True, an array indexed as
            ``labels[i, 0:3]`` whose first three columns are written directly
            as the red, green and blue values.
        out_filename: path of the PLY file to create (overwritten if present).
        label_color: selects how ``labels`` is interpreted (see above).
    """
    num_points = points.shape[0]
    with open(out_filename, 'w') as fout:
        # PLY header: N vertices with float xyz and uchar rgb properties.
        fout.write("ply\n")
        fout.write("format ascii 1.0\n")
        fout.write("element vertex %d\n" % num_points)
        fout.write("property float x\n")
        fout.write("property float y\n")
        fout.write("property float z\n")
        fout.write("property uchar red\n")
        fout.write("property uchar green\n")
        fout.write("property uchar blue\n")
        fout.write("end_header\n")
        for i in range(num_points):
            if label_color:
                # Explicit per-point color: columns 0, 1, 2 are R, G, B.
                red, green, blue = labels[i, 0], labels[i, 1], labels[i, 2]
            else:
                # Palette lookup; COLORS components are in [0, 1] and are
                # scaled to 0..255 (the %d conversion truncates).
                color = COLORS[labels[i]]
                red, green, blue = color[0] * 255, color[1] * 255, color[2] * 255
            fout.write('%f %f %f %d %d %d\n' % (
                points[i, 0], points[i, 1], points[i, 2], red, green, blue))


def write_output(pred_dicts, batch_dict, result_dir, rare_classes):
    """Dump visualization files for frames whose ground truth has rare classes.

    For each entry of ``pred_dicts``, the ground-truth segmentation labels are
    checked for any class in ``rare_classes``. Frames containing none are
    skipped. Each remaining frame gets a directory
    ``<result_dir>/rare_classes/<frame_id>_cls_<c1>_<c2>...`` holding
    (``<id>`` is the last three characters of the frame id):

    * ``<id>_seg_gt.ply``    - up to 100000 randomly sampled points, GT colors
    * ``<id>_seg_pred.ply``  - all points, predicted-label colors
    * ``<id>_ins.ply``       - sampled points colored by instance label
                               (only when ``instance_label_back`` is present)
    * ``<id>_seg_false.ply`` - sampled points labeled 1 where GT is non-zero
                               and the prediction differs, else 0
    * ``<id>_cls_<c>.ply``   - all GT points of each rare class ``c`` found
    * ``<frame_id>.pkl``     - pickled dict with the full points and labels

    Args:
        pred_dicts: sequence of per-frame dicts; ``['point_wise']`` must hold
            ``gt_segmentation_label``, ``pred_segmentation_label`` and
            ``point_xyz`` (objects providing ``.cpu().numpy()``, presumably
            torch tensors), and optionally ``instance_label_back``.
        batch_dict: dict whose ``'frame_id'`` entry is indexed in parallel
            with ``pred_dicts``; frame ids are used as strings.
        result_dir: root directory for the output.
        rare_classes: iterable of integer class ids to look for.
    """
    out_prefix = 'rare_classes'
    for ii, pred_dict in enumerate(pred_dicts):
        point_wise = pred_dict['point_wise']
        gt_labels = point_wise['gt_segmentation_label'].cpu().numpy()
        pred_labels = point_wise['pred_segmentation_label'].cpu().numpy()
        points = point_wise['point_xyz'].cpu().numpy()
        ins_labels = None
        if 'instance_label_back' in point_wise:
            ins_labels = point_wise['instance_label_back'].cpu().numpy()

        gt_classes = np.unique(gt_labels)
        classes = [cls for cls in rare_classes if cls in gt_classes]
        if not classes:
            # Only frames that contain at least one rare class are dumped.
            continue
        cls_str = "_cls" + "".join("_%d" % (cls) for cls in classes)

        frame_id_or = batch_dict['frame_id'][ii]
        out_path = os.path.join(result_dir, out_prefix, frame_id_or + cls_str)
        os.makedirs(out_path, exist_ok=True)
        print(frame_id_or, out_path)
        # Short id used as the file-name prefix of the .ply outputs.
        frame_id = frame_id_or[-3:]

        sample_ids = np.random.permutation(points.shape[0])[:100000]
        write_ply_color(points[sample_ids, :], gt_labels[sample_ids],
                        os.path.join(out_path, '%s_seg_gt.ply' % (frame_id)))
        # Debug output: number of mispredicted points within the sample.
        print((pred_labels[sample_ids] != gt_labels[sample_ids]).sum())
        write_ply_color(points, pred_labels,
                        os.path.join(out_path, '%s_seg_pred.ply' % (frame_id)))
        if ins_labels is not None:
            # Instance ids are wrapped into the palette range with % 22.
            # A separate file name keeps the '_seg_pred.ply' dump above from
            # being overwritten.
            write_ply_color(points[sample_ids, :], ins_labels[sample_ids] % 22,
                            os.path.join(out_path, '%s_ins.ply' % (frame_id)))

        # Label 0 is ignored: multiplying by the mask turns predictions on
        # gt==0 points into 0, which always matches the GT there.
        valid_mask = (gt_labels != 0).astype(np.int32)
        false_labels = (pred_labels * valid_mask != gt_labels).astype(np.int32)
        write_ply_color(points[sample_ids, :], false_labels[sample_ids],
                        os.path.join(out_path, '%s_seg_false.ply' % (frame_id)))

        for cls in classes:
            cls_ids = (gt_labels == cls)
            write_ply_color(points[cls_ids, :], gt_labels[cls_ids],
                            os.path.join(out_path, '%s_cls_%d.ply' % (frame_id, cls)))

        out_dict = {
            'points': points,
            'gt_seg_label': gt_labels,
            'pred_seg_label': pred_labels,
        }
        if ins_labels is not None:
            out_dict['gt_instance_label'] = ins_labels

        with open(os.path.join(out_path, frame_id_or + '.pkl'), 'wb') as f:
            pickle.dump(out_dict, f)
        print('saved', frame_id)

                
        
