import torch
from torch import nn
import torch.nn.functional as F

from ...utils import box_utils, loss_utils
from .point_head_template import PointHeadTemplate
from ...ops.torch_hash import RadiusGraph
from pcdet.models.model_utils import graph_utils
from pcdet.models.backbones_3d import ShapeGFEncoder, ShapeGFDecoder
from torch_scatter import scatter
import torch_cluster 
from chamferdist import ChamferDistance  

import numpy as np
from scipy.sparse.csgraph import connected_components
from scipy.sparse import csr_matrix
from torch_cluster import radius_graph
from sklearn.neighbors import NearestNeighbors as NN


class PointSegHead_InsClsGen(PointHeadTemplate):
    """Point/voxel segmentation head with optional instance-classification and
    shape-generation auxiliary branches.

    The head always predicts per-entry semantic logits with ``cls_layers``.
    Depending on ``model_cfg`` it additionally:

    * ``INS_CLS``: pools per-entry embeddings into per-instance embeddings
      (``scatter`` max or mean over the instance ids read from ``batch_dict``)
      and classifies every instance with ``ins_cls_layers``;
    * ``MULTI_SCALE`` / ``APPEND_POINTS``: extends the embedding input with
      nearest-neighbour features from U-Net decoder levels and/or the offsets
      of 32 neighbouring raw points;
    * ``ADD_GENERATOR``: encodes partially masked instance point clouds with
      ``ShapeGFEncoder``, decodes them with ``ShapeGFDecoder`` and measures a
      bidirectional Chamfer distance against the full instance.

    Derived from the PV-RCNN keypoint segmentation head
    (https://arxiv.org/abs/1912.13192).
    """

    def __init__(self, runtime_cfg, model_cfg, **kwargs):
        num_class = runtime_cfg['num_seg_classes']
        input_channels = runtime_cfg['input_channels']
        # Set before super().__init__(): build_losses() reads it, and the parent
        # constructor presumably already calls build_losses() (TODO confirm).
        self.cls_weights = model_cfg.get("CLS_WEIGHTS", None)

        super().__init__(model_cfg=model_cfg,
                         num_class=num_class)
        self.scale = runtime_cfg.get('scale', 1.0)
        self.assign_to_point = model_cfg.get("ASSIGN_TO_POINT", False)
        self.embedding_channel = model_cfg.get("EMBEDDING_CHANNEL", 64)
        self.proj_channel = model_cfg.get("PROJ_CHANNEL", 64)
        self.n_views = model_cfg.get("N_VIEWS", 50)
        self.use_prev = True
        self.ins_cls = model_cfg.get("INS_CLS", False)
        self.ins_cls_weight = model_cfg.get("INS_CLS_LOSS_WEIGHT", False)
        self.global_fea_op = model_cfg.get("GLOBAL_FEA_OP", "max")
        self.n_ins_cls = model_cfg.get("N_INS_CLS", 7)
        self.more_channel = model_cfg.get("MORE_CHANNEL", False)
        self.multi_scale = model_cfg.get("MULTI_SCALE", False)
        self.append_points = model_cfg.get("APPEND_POINTS", False)
        self.append_rawpoint_features = model_cfg.get("APPEND_RAWPOINT_FEATURES", False)
        self.raw_point_feature_channel = model_cfg.get("RAW_POINT_FEATURE_CHANNEL", 64)
        self.ref_all = model_cfg.get("REF_ALL", False)
        self.ins_cls_chosen = model_cfg.get("INS_CLS_CHOSEN", None)

        self.add_generator = model_cfg.get("ADD_GENERATOR", False)
        self.generator_weight = model_cfg.get("GENERATOR_WEIGHT", False)

        self.chamferDist = ChamferDistance()

        self.cls_layers = self.make_fc_layers(
            fc_cfg=[int(c * self.scale) for c in self.model_cfg.CLS_FC],
            input_channels=input_channels,
            output_channels=num_class,
            dropout=self.dropout
        )

        if self.n_ins_cls == 12:
            # Map 12 selected semantic classes (of 20) to instance classes 1..12;
            # all other semantic classes map to 0.
            self.sem_cls = [1, 2, 3, 4, 5, 6, 7, 8, 14, 16, 18, 19]
            # Plain attribute (not a registered buffer, so the state_dict is
            # unchanged); created on CPU and moved to the label device in forward().
            self.ins_cls_mapping = torch.zeros(20)
            for ins_cls_id, sem_cls_id in enumerate(self.sem_cls, start=1):
                self.ins_cls_mapping[sem_cls_id] = ins_cls_id
        else:
            self.ins_cls_mapping = None

        # Input of the per-entry embedding MLP: backbone features, optionally
        # extended with neighbour point offsets and multi-scale U-Net features.
        embed_input_channel = input_channels
        embed_fc = self.model_cfg.EMBED_FC
        if self.append_points:
            embed_input_channel += 32 * 3  # 32 neighbours x (dx, dy, dz)
            embed_fc = [128, 128, 64] if self.more_channel else [128, 64]
        if self.multi_scale:
            # channels of the three U-Net levels concatenated in forward()
            embed_input_channel += 32 + 64 + 128
            embed_fc = [256, 128, 64] if self.more_channel else [256, 128]

        if self.ins_cls:
            self.embed_layers = self.make_fc_layers(
                fc_cfg=[int(c * self.scale) for c in embed_fc],
                input_channels=embed_input_channel,
                output_channels=self.embedding_channel,
                dropout=False,
            )

        if self.append_rawpoint_features:
            self.raw_point_layers = self.make_fc_layers(
                fc_cfg=[int(c * self.scale) for c in self.model_cfg.RAW_POINT_FC],
                input_channels=6,
                output_channels=self.raw_point_feature_channel,
                dropout=False,
            )
            ins_cls_input_channel = self.embedding_channel + self.raw_point_feature_channel
            ins_cls_fc = [128, 64, 64]
        else:
            ins_cls_input_channel = self.embedding_channel
            ins_cls_fc = self.model_cfg.INS_CLS_FC
            if self.more_channel:
                ins_cls_fc = [128, 64, 64, 32]
        if self.ins_cls:
            self.ins_cls_layers = self.make_fc_layers(
                fc_cfg=[int(c * self.scale) for c in ins_cls_fc],
                input_channels=ins_cls_input_channel,
                output_channels=self.n_ins_cls,
                dropout=False,
            )

        self.gen_encoder = ShapeGFEncoder(model_cfg.get("GEN_ENCODER", None), runtime_cfg)
        self.gen_decoder = ShapeGFDecoder(model_cfg.get("GEN_DECODER", None), runtime_cfg)
        self.n_points_gen = model_cfg.get("N_POINTS_GEN", None)
        # Generator input channels: 3 centred coordinates + 32 feature channels
        # (presumably the point-feature width; TODO confirm against input_channels).
        # A further validity-flag channel is appended when batching.
        self.n_channel_gen_input = 32 + 3

        self.build_losses(self.model_cfg.LOSS_CONFIG)
        self.target_assigner_cfg = self.model_cfg.get("TARGET_ASSIGNER", None)
        if self.target_assigner_cfg is not None:
            max_num_points = self.target_assigner_cfg.get("MAX_NUM_POINTS", None)
            self.radius_graph = RadiusGraph(max_num_points=max_num_points, ndim=3)

        self.graph = graph_utils.KNNGraph({}, dict(NUM_NEIGHBORS=1))
        if self.append_points:
            self.ins_graph = graph_utils.KNNGraph({}, dict(NUM_NEIGHBORS=32))

    def build_losses(self, losses_cfg):
        """Instantiate the configured losses into ``self.losses``.

        ``LOSS`` / ``WEIGHT`` may be scalars or lists (normalised to lists in place).
        The loss at index 1 (used for instance classification in ``get_loss``)
        receives ``CLS_WEIGHTS`` as class weights when they are configured.
        """
        if not isinstance(losses_cfg['LOSS'], list):
            losses_cfg['LOSS'] = [losses_cfg['LOSS']]
        if not isinstance(losses_cfg['WEIGHT'], list):
            losses_cfg['WEIGHT'] = [losses_cfg['WEIGHT']]
        self.loss_names = losses_cfg['LOSS']
        self.losses = nn.ModuleList()
        self.loss_weight = []
        for i, (loss, weight) in enumerate(zip(losses_cfg['LOSS'], losses_cfg['WEIGHT'])):
            if i == 1 and self.cls_weights is not None:
                self.losses.append(
                    loss_utils.LOSSES[loss](weight=torch.Tensor(self.cls_weights).cuda(), loss_cfg=losses_cfg)
                )
            else:
                self.losses.append(loss_utils.LOSSES[loss](loss_cfg=losses_cfg))
            self.loss_weight.append(weight)

    def get_cls_loss(self, preds, labels, tb_dict=None, index=0, prefix=None):
        """Evaluate ``self.losses[index]`` on (preds, labels) and log it.

        Returns:
            (loss tensor, tb_dict) where tb_dict gains ``<loss_name>`` (or
            ``<prefix>/<loss_name>``) with the scalar loss value.
        """
        if tb_dict is None:
            tb_dict = {}
        loss_name = self.loss_names[index]
        loss_this = self.losses[index](preds, labels)
        key = loss_name if prefix is None else f'{prefix}/{loss_name}'
        tb_dict[key] = loss_this.item()
        return loss_this, tb_dict

    def get_loss(self, tb_dict=None, prefix=None):
        """Combine segmentation, generator and instance-classification losses.

        Total = seg loss + GENERATOR_WEIGHT * 0.001 * Chamfer distance (only when
        ADD_GENERATOR is on and the distance is positive)
        + INS_CLS_LOSS_WEIGHT * instance-classification loss (only with INS_CLS).
        Also logs per-class / foreground / background / mean IoU.
        """
        tb_dict = {} if tb_dict is None else tb_dict
        point_cls_labels = self.forward_ret_dict[self.gt_seg_cls_label_key].view(-1).long()
        point_cls_preds = self.forward_ret_dict['pred_seg_cls_logits'].view(-1, self.num_class)
        point_loss_cls, tb_dict_seg = self.get_cls_loss(point_cls_preds, point_cls_labels)
        device = point_cls_preds.device

        loss_gen = torch.zeros(1, device=device)
        total_cdis = self.forward_ret_dict.get('total_cdis', 0.0)
        if self.add_generator and total_cdis > 0:
            loss_gen = total_cdis * 0.001

        if self.ins_cls:
            ins_cls_labels = self.forward_ret_dict['ins_cls_label'].view(-1).long()
            ins_cls_preds = self.forward_ret_dict['ins_cls_pred'].view(-1, self.n_ins_cls)
            loss_ins_cls, _ = self.get_cls_loss(ins_cls_preds, ins_cls_labels, index=1)
        else:
            loss_ins_cls = torch.zeros(1, device=device)

        point_loss = point_loss_cls + self.generator_weight * loss_gen + self.ins_cls_weight * loss_ins_cls

        tb_dict.update(tb_dict_seg)
        tb_dict['loss_recons'] = loss_gen.item()
        tb_dict['loss_seg'] = point_loss_cls.item()
        tb_dict['loss_ins_cls'] = loss_ins_cls.item()

        # Batch-summed intersection / union counts per class.
        _, iou_total = self.get_iou_statistics()
        ups, downs = iou_total['ups'], iou_total['downs']
        ious = ups / torch.clamp(downs, min=1.0)

        pfx = '' if prefix is None else f'{prefix}/'
        for i in range(self.num_class):
            if downs[i] > 0:
                tb_dict[f'{pfx}per_class/IoU_{i}'] = ious[i]
        # Classes 1-4 are reported as foreground, classes 5+ as background.
        tb_dict[f'{pfx}IoU_FG'] = ups[1:5].sum() / torch.clamp(downs[1:5].sum(), min=1.0)
        tb_dict[f'{pfx}IoU_BG'] = ups[5:].sum() / torch.clamp(downs[5:].sum(), min=1.0)
        tb_dict[f'{pfx}mIoU'] = ious.mean()
        tb_dict[f'{pfx}loss'] = point_loss.item()

        return point_loss, tb_dict

    def get_iou_statistics(self):
        """Count per-class intersections ("ups") and unions ("downs").

        Predictions at entries whose ground truth is class 0 are overwritten with
        0 before counting.

        Returns:
            iou_dicts: one ``dict(ups=..., downs=...)`` per sample, each holding
                ``(num_class,)`` count tensors.
            iou_dict: the same counts summed over the batch (values stay ``None``
                for an empty batch).
        """
        iou_dicts = []
        iou_dict = dict(ups=None, downs=None)
        for pred_dict in self.get_evaluation_results():
            pred_labels = pred_dict['point_wise']['pred_segmentation_label']
            gt_labels = pred_dict['point_wise']['gt_segmentation_label']
            ups = pred_labels.new_zeros(self.num_class)
            downs = pred_labels.new_zeros(self.num_class)
            pred_labels[gt_labels == 0] = 0
            for cls in range(self.num_class):
                pred_mask = pred_labels == cls
                gt_mask = gt_labels == cls
                ups[cls] = (pred_mask & gt_mask).sum()
                downs[cls] = (pred_mask | gt_mask).sum()

            iou_dict['ups'] = ups if iou_dict['ups'] is None else iou_dict['ups'] + ups
            iou_dict['downs'] = downs if iou_dict['downs'] is None else iou_dict['downs'] + downs
            iou_dicts.append(dict(ups=ups, downs=downs))
        return iou_dicts, iou_dict

    def get_evaluation_results(self):
        """Split the stored forward outputs into one record dict per sample.

        Only entries with ground-truth label >= 0 are kept in the predicted / gt
        segmentation labels.
        NOTE(review): ``point_xyz`` and ``instance_label_back`` are not filtered by
        that mask, so they may be longer than the label tensors — confirm intended.
        """
        pred_scores = torch.sigmoid(self.forward_ret_dict['pred_seg_cls_logits'])
        batch_idx = self.forward_ret_dict['batch_idx']
        point_bxyz = self.forward_ret_dict['point_bxyz']
        # forward() only stores this with INS_CLS enabled or in training mode.
        ins_cls_label = self.forward_ret_dict.get('ins_cls_label')

        pred_dicts = []
        for i in range(self.forward_ret_dict['batch_size']):
            bs_mask = batch_idx == i
            point_xyz = point_bxyz[bs_mask, 1:4]
            _, pred_labels = pred_scores[bs_mask].max(-1)
            gt_labels = self.forward_ret_dict[self.gt_seg_cls_label_key][bs_mask]
            ins_labels = torch.zeros(gt_labels.shape)  # placeholder instance labels
            valid_mask = gt_labels >= 0
            record_dict = dict(
                point_wise=dict(
                    gt_segmentation_label=gt_labels[valid_mask],
                    pred_segmentation_label=pred_labels[valid_mask],
                    point_xyz=point_xyz,
                    instance_label_back=ins_labels,
                ),
                object_wise=dict(),
                voxel_wise=dict(),
                scene_wise=dict(
                    num_seg_class=self.num_class,
                    ins_cls_label=ins_cls_label,
                ),
            )
            pred_dicts.append(record_dict)
        return pred_dicts

    def assign_targets(self, target_assigner_cfg, batch_dict):
        """Transfer segmentation labels from reference points to query points.

        Each query point takes the label of its nearest reference point within
        ``RADIUS`` (one neighbour, sorted by distance); queries with no neighbour
        get label 0.
        """
        ref_label = batch_dict[target_assigner_cfg["REF_SEGMENTATION_LABEL"]]
        ref_bxyz = batch_dict[target_assigner_cfg["REF_POINT_BXYZ"]]
        query_bxyz = batch_dict[target_assigner_cfg["QUERY_POINT_BXYZ"]]
        query_label_key = target_assigner_cfg["QUERY_SEGMENTATION_LABEL"]

        radius = target_assigner_cfg["RADIUS"]
        er, eq = self.radius_graph(ref_bxyz, query_bxyz, radius, 1, sort_by_dist=True)

        query_label = ref_label.new_full(query_bxyz.shape[:1], 0)  # by default, assuming class 0 is ignored
        query_label[eq] = ref_label[er]

        batch_dict[query_label_key] = query_label

    def _batch_offset_instance_labels(self, batch_dict):
        """Make instance ids unique across the samples of a batch.

        Falls back to ``voxel_instance_label`` / ``instance_label`` when the
        ``*_back`` variants are absent (and stores them under the ``*_back`` keys).

        Returns:
            instance_label: per-entry ids ``voxel_instance_label_back - 1`` plus a
                running per-sample offset; negative values mean "no instance".
            pt_instance_label: per-point ``instance_label_back`` with the same offsets.
        """
        if 'voxel_instance_label_back' not in batch_dict:
            batch_dict['voxel_instance_label_back'] = batch_dict['voxel_instance_label']
            batch_dict['instance_label_back'] = batch_dict['instance_label']

        instance_label_or = batch_dict['voxel_instance_label_back'] - 1
        pt_instance_label_or = batch_dict['instance_label_back']
        batch_dim = batch_dict[self.batch_key][:, 0]
        point_batch_dim = batch_dict['point_bxyz'][:, 0]

        instance_label = instance_label_or.clone()
        pt_instance_label = pt_instance_label_or.clone()

        # NOTE(review): ``> 0`` leaves id 0 un-offset, so id 0 of every sample
        # shares the same global id — confirm this is intended.
        last_num = 0
        for b_i in range(batch_dict['batch_size']):
            ins_mask_tmp = (batch_dim == b_i) & (instance_label_or > 0)
            if ins_mask_tmp.sum() == 0:
                continue
            pt_ins_mask_tmp = (point_batch_dim == b_i) & (pt_instance_label_or > 0)
            instance_label[ins_mask_tmp] += last_num
            pt_instance_label[pt_ins_mask_tmp] += last_num
            last_num = instance_label[ins_mask_tmp].max() + 1
        return instance_label, pt_instance_label

    def _append_multi_scale_features(self, batch_dict, cls_input, ins_mask):
        """Concatenate the nearest U-Net level-4/3/2 feature to every selected entry."""
        query_bxyz = batch_dict['spconv_unet_up_bcenter5'][ins_mask]
        for level in (4, 3, 2):
            ref_bxyz = batch_dict[f'spconv_unet_up_bcenter{level}']
            ref_feat = batch_dict[f'spconv_unet_up_feat{level}']
            if ref_bxyz.shape[0] > 0 and query_bxyz.shape[0] > 0:
                e_ref, _ = self.graph(ref_bxyz, query_bxyz)
                query_feat = ref_feat[e_ref]
            else:
                query_feat = cls_input.new_zeros(cls_input.shape[0], ref_feat.shape[1])
            cls_input = torch.cat([cls_input, query_feat], dim=1)
        return cls_input

    def _append_neighbor_points(self, batch_dict, cls_input, ins_mask, pt_instance_label):
        """Concatenate offsets to the 32 nearest raw points (zeroed beyond 0.5 m)."""
        num_neighbors = 32  # must match ins_graph NUM_NEIGHBORS and embed_input_channel
        query_bxyz = batch_dict['spconv_unet_up_bcenter5'][ins_mask]
        if self.ref_all:
            ref_bxyz = batch_dict['point_bxyz']
        else:
            ref_bxyz = batch_dict['point_bxyz'][pt_instance_label >= 0]

        if ref_bxyz.shape[0] > 0 and query_bxyz.shape[0] > 0:
            e_ref, _ = self.ins_graph(ref_bxyz, query_bxyz)
            neighbors = ref_bxyz[e_ref][:, 1:].reshape(query_bxyz.shape[0], num_neighbors, 3)
            offsets = neighbors - query_bxyz[:, 1:].unsqueeze(1)
            dis = torch.sqrt(torch.square(offsets).sum(2))
            offsets = offsets * (dis < 0.5).unsqueeze(-1).to(offsets.dtype)
            appended_points = offsets.reshape(query_bxyz.shape[0], num_neighbors * 3)
        else:
            appended_points = cls_input.new_zeros(cls_input.shape[0], num_neighbors * 3)
        return torch.cat([cls_input, appended_points], dim=1)

    def _classify_instances(self, batch_dict, cls_input, instance_label_few, seg_label_few):
        """Predict one class per instance and derive its target from the majority label.

        Writes ``ins_cls_pred`` (logits) and ``ins_cls_label`` (targets) into batch_dict.
        The target is the rounded mean semantic label of the instance, optionally
        remapped by ``ins_cls_mapping`` and restricted to the configured classes.
        """
        if self.append_rawpoint_features:
            # The original code referenced an undefined ``raw_pt_embeddings`` here.
            raise NotImplementedError(
                'APPEND_RAWPOINT_FEATURES together with INS_CLS is not supported: '
                'raw point embeddings are never computed in forward()')

        point_embeddings = self.embed_layers(cls_input)
        reduce = 'mean' if self.global_fea_op == 'mean' else 'max'
        ins_embeddings = scatter(point_embeddings, instance_label_few, dim=0, reduce=reduce)

        # scatter output has one row per id in [0, max id]; keep only ids present.
        present_ids = torch.unique(instance_label_few)
        ins_cls_pred = self.ins_cls_layers(ins_embeddings)[present_ids]

        gt_ins_cls_label = scatter(seg_label_few.float(), instance_label_few, dim=0, reduce='mean')
        gt_ins_cls_label = torch.round(gt_ins_cls_label)[present_ids]

        if self.ins_cls_mapping is not None:
            mapping = self.ins_cls_mapping.to(gt_ins_cls_label.device)
            gt_ins_cls_label = mapping[gt_ins_cls_label.long()]

        if self.ins_cls_chosen is not None:
            cls_mask = gt_ins_cls_label == 1
            for cls_i in self.ins_cls_chosen:
                cls_mask = cls_mask | (gt_ins_cls_label == cls_i)
            gt_ins_cls_label = gt_ins_cls_label[cls_mask] - self.ins_cls_chosen[0] + 1
            gt_ins_cls_label[gt_ins_cls_label < 0] = 0
        else:
            cls_mask = (gt_ins_cls_label > 0) & (gt_ins_cls_label <= self.n_ins_cls)
            gt_ins_cls_label = gt_ins_cls_label[cls_mask] - 1

        batch_dict['ins_cls_pred'] = ins_cls_pred[cls_mask]
        batch_dict['ins_cls_label'] = gt_ins_cls_label

    def _run_generator(self, points_input, points_gt):
        """Encode/decode ``points_input`` and return the bidirectional Chamfer distance."""
        gen_latent = self.gen_encoder(points_input)
        gen_output = self.gen_decoder(gen_latent, points_input)
        return self.chamferDist(gen_output, points_gt, bidirectional=True)

    @staticmethod
    def _resample_points(points, num):
        """Return exactly ``num`` rows of ``points``.

        If there are enough rows, a random subset is taken. Otherwise whole copies
        are stacked, followed by a random subset for the remainder.
        """
        n = points.shape[0]
        if n >= num:
            return points[torch.randperm(n)[:num]]
        n_rep, n_add = divmod(num, n)
        parts = [points] * n_rep
        if n_add:
            parts.append(points[torch.randperm(n)[:n_add]])
        return torch.cat(parts, dim=0)

    def _make_generator_sample(self, points, features):
        """Centre one instance, carve out a random ball, and resample input / target.

        Args:
            points: (N, 3) instance coordinates without the batch column.
            features: (N, C) features aligned with ``points``.
        Returns:
            gen_input: (<= n_points_gen, 3 + C) masked centred coordinates + features.
            points_gt: (n_points_gen, 3) full centred instance, tiled or subsampled.
        """
        center = points.mean(0, keepdim=True)
        points_gt = points - center
        points_input = points - center

        # Remove every point within a random radius in [0, 0.8) of a random point;
        # if fewer than 100 survive, the first 100 are forced back in.
        mask_idx = torch.randint(0, points_input.shape[0], size=(1,))[0]
        mask_center = points_input[mask_idx]
        radius = torch.rand(2)[0] * 0.8
        keep = torch.sqrt(torch.square(points_input - mask_center).sum(1)) > radius
        if keep.sum() < 100:
            keep[:100] = True
        points_input = points_input[keep]
        features = features[keep]

        gen_input = torch.cat((points_input, features), 1)
        gen_input = gen_input[torch.randperm(gen_input.shape[0])[:self.n_points_gen]]
        return gen_input, self._resample_points(points_gt, self.n_points_gen)

    def _generator_forward(self, batch_dict, point_features, instance_label):
        """Return the generator's Chamfer distance for this batch.

        With ADD_GENERATOR enabled, up to 9 instances (at most 3 per majority
        semantic class) are selected. Selection prefers large (>= 200 entries)
        instances of the listed classes, then any listed class, then any large
        instance, then anything. When ADD_GENERATOR is disabled, or nothing
        qualifies, the generator is run on all-zero inputs instead.
        """
        device = point_features.device
        n_in = self.n_channel_gen_input + 1  # +1 validity-flag channel
        if not self.add_generator:
            return self._run_generator(
                torch.zeros(4, self.n_points_gen, n_in, device=device),
                torch.zeros(4, self.n_points_gen, 3, device=device),
            )

        ins_cls_list = [1, 2, 3, 4] if self.n_ins_cls == 13 else [1, 4, 5]
        max_per_class = 3
        max_samples = 9
        min_voxels = 200
        seg_label = batch_dict['voxel_segmentation_label']
        coords = batch_dict[self.batch_key]

        # (instance id, majority semantic label, entry count) for every id >= 1.
        candidates = []
        ins_ids = torch.unique(instance_label)
        for ins_id in ins_ids[ins_ids >= 1].tolist():
            ins_mask = instance_label == ins_id
            majority_seg = int(torch.mode(seg_label[ins_mask])[0])
            candidates.append((ins_id, majority_seg, int(ins_mask.sum())))

        has_cls = any(seg in ins_cls_list for _, seg, _ in candidates)
        has_large = any(n > min_voxels for _, _, n in candidates)
        has_large_cls = any(seg in ins_cls_list and n > min_voxels for _, seg, n in candidates)

        def is_eligible(seg, n_voxels):
            if has_large_cls:
                return seg in ins_cls_list and n_voxels >= min_voxels
            if has_cls:
                return seg in ins_cls_list
            if has_large:
                return n_voxels >= min_voxels
            return True

        total_points_input = torch.zeros(max_samples, self.n_points_gen, n_in, device=device)
        total_points_gt = torch.zeros(max_samples, self.n_points_gen, 3, device=device)
        per_class = {}
        total_cnt = 0
        for ins_id, seg, n_voxels in candidates:
            if not is_eligible(seg, n_voxels) or per_class.get(seg, 0) >= max_per_class:
                continue
            per_class[seg] = per_class.get(seg, 0) + 1

            ins_mask = instance_label == ins_id
            gen_input, points_gt = self._make_generator_sample(
                coords[ins_mask][:, 1:], point_features[ins_mask])
            total_points_input[total_cnt, :gen_input.shape[0], :-1] = gen_input
            total_points_input[total_cnt, :gen_input.shape[0], -1] = 1  # mark valid rows
            total_points_gt[total_cnt] = points_gt
            total_cnt += 1
            if total_cnt >= max_samples:
                break

        if total_cnt == 0:
            return self._run_generator(
                torch.zeros(2, self.n_points_gen, n_in, device=device),
                torch.zeros(2, self.n_points_gen, 3, device=device),
            )

        total_points_input = total_points_input[:total_cnt]
        total_points_gt = total_points_gt[:total_cnt]
        if total_cnt == 1:
            # Duplicate a lone sample so the generator always sees >= 2 samples.
            total_points_input = total_points_input.repeat(2, 1, 1)
            total_points_gt = total_points_gt.repeat(2, 1, 1)
        return self._run_generator(total_points_input, total_points_gt)

    def forward(self, batch_dict):
        """Predict semantic labels and run the enabled auxiliary branches.

        Args:
            batch_dict: reads ``self.point_feature_key`` (features for ``cls_layers``),
                ``self.batch_key`` (coordinates, column 0 = batch index),
                ``'point_bxyz'``, ``'batch_size'``, ``'voxel_segmentation_label'``,
                ``'voxel_instance_label'`` / ``'instance_label'`` (or their ``*_back``
                variants), and the ``'spconv_unet_up_*'`` tensors when INS_CLS is
                combined with MULTI_SCALE or APPEND_POINTS.
        Returns:
            batch_dict, updated with ``pred_seg_cls_logits``, ``pred_seg_cls_confidences``,
            ``pred_seg_cls_labels`` and ``total_cdis``. With INS_CLS it also gains
            ``ins_cls_pred`` / ``ins_cls_label``, and with TARGET_ASSIGNER the query
            labels. The tensors used by ``get_loss`` / ``get_evaluation_results`` are
            stored in ``self.forward_ret_dict``.
        """
        point_features = batch_dict[self.point_feature_key]
        device = point_features.device
        point_pred_logits = self.cls_layers(point_features)  # (total_points, num_class)
        point_pred_scores = torch.sigmoid(point_pred_logits)

        ret_dict = {'pred_seg_cls_logits': point_pred_logits}
        ret_dict['pred_seg_cls_confidences'], ret_dict['pred_seg_cls_labels'] = point_pred_scores.max(dim=-1)
        batch_dict.update(ret_dict)

        instance_label, pt_instance_label = self._batch_offset_instance_labels(batch_dict)

        if self.ins_cls:
            ins_mask = instance_label >= 0
            cls_input = point_features[ins_mask].clone()
            if self.multi_scale:
                cls_input = self._append_multi_scale_features(batch_dict, cls_input, ins_mask)
            if self.append_points:
                cls_input = self._append_neighbor_points(batch_dict, cls_input, ins_mask, pt_instance_label)
            self._classify_instances(
                batch_dict, cls_input, instance_label[ins_mask],
                batch_dict['voxel_segmentation_label'][ins_mask])

        batch_dict['total_cdis'] = self._generator_forward(batch_dict, point_features, instance_label)

        if self.target_assigner_cfg is not None:
            self.assign_targets(self.target_assigner_cfg, batch_dict)

        if self.gt_seg_cls_label_key in batch_dict:
            ret_dict[self.gt_seg_cls_label_key] = batch_dict[self.gt_seg_cls_label_key]

        ret_dict['batch_idx'] = batch_dict[self.batch_key][:, 0].round().long()
        ret_dict['point_bxyz'] = batch_dict[self.batch_key]
        ret_dict['instance_label_back'] = batch_dict['instance_label_back']

        if self.assign_to_point and not self.training:
            # Transfer every per-entry output to the raw points via the nearest entry.
            query_bxyz = batch_dict['point_bxyz']
            e_ref, e_query = self.graph(batch_dict[self.batch_key], query_bxyz)
            ret_dict = {
                key: scatter(value[e_ref], e_query, dim=0,
                             dim_size=query_bxyz.shape[0], reduce='max')
                for key, value in ret_dict.items()
            }
            ret_dict['point_bxyz'] = batch_dict['point_bxyz']

        ret_dict['batch_size'] = batch_dict['batch_size']
        if self.ins_cls:
            ret_dict['ins_cls_label'] = batch_dict['ins_cls_label']
            ret_dict['ins_cls_pred'] = batch_dict['ins_cls_pred']
        elif self.training:
            # Placeholders; get_loss ignores them when INS_CLS is disabled.
            ret_dict['ins_cls_pred'] = torch.zeros(1, 7, device=device)
            ret_dict['ins_cls_label'] = torch.zeros(1, device=device)
        if 'total_cdis' in batch_dict:
            ret_dict['total_cdis'] = batch_dict['total_cdis']

        self.forward_ret_dict = ret_dict
        return batch_dict


