import torch
import os
import glob
import tqdm
import numpy as np
import torch.distributed as dist
from pcdet.config import cfg
from pcdet.models import load_data_to_gpu
from pcdet.utils import common_utils, commu_utils, memory_ensemble_utils, dynamic_threshold
import pickle as pkl
import re
from multiprocessing import Manager
from pcdet.models.model_utils.dsnorm import set_ds_target
import copy
import wandb

# Pseudo labels currently in use, kept in a multiprocessing Manager dict proxy
# (presumably so other processes such as dataloader workers observe updates
# -- TODO confirm against the dataset code).
# NOTE(review): the plain-dict assignment is immediately overwritten by the
# Manager-backed dict on the following line.
PSEUDO_LABELS = {}
PSEUDO_LABELS = Manager().dict()

# TEACHER_PSEUDO_LABELS = {}
# TEACHER_PSEUDO_LABELS = Manager().dict()
# TMP_TEACHER_PSEUDO_LABELS = {}

# Pseudo labels of the epoch being generated. gather_and_dump_pseudo_label_result
# gathers this dict across ranks, dumps it to disk, copies it into
# PSEUDO_LABELS and finally clears it.
NEW_PSEUDO_LABELS = {}

# Scratch storage for the 'pred_labels' / 'pred_scores' arrays collected while
# save_pseudo_label_epoch estimates dynamic thresholds; cleared afterwards.
TMP_PSEUDO_LABELS = {}

# TARGET_PSEUDO_LABELS = {}
# TARGET_PSEUDO_LABELS = Manager().dict()
# if cfg.get('SELF_TRAIN', None) and cfg.SELF_TRAIN.get('TARGET_LEARNING', None):
#     TARGET_PSEUDO_LABELS = {}
#     TARGET_PSEUDO_LABELS = Manager().dict()
# from self_training_utils_mine import vis
# def commu_all_gather():
#     commu_utils.synchronize()

#     if dist.is_initialized():
#         part_pseudo_labels_list = commu_utils.all_gather(TARGET_PSEUDO_LABELS)

#         new_pseudo_label_dict = {}
#         for pseudo_labels in part_pseudo_labels_list:
#             new_pseudo_label_dict.update(pseudo_labels)

#         TARGET_PSEUDO_LABELS.update(new_pseudo_label_dict)
#     return None

def check_already_exsit_pseudo_label(ps_label_dir, start_epoch):
    """
    Restore pseudo labels into ``PSEUDO_LABELS`` when training starts or resumes.

    Two sources are tried, in this order:

    1. If ``start_epoch`` is 0 and ``cfg.SELF_TRAIN.INIT_PS`` points to an
       existing pickle, that file is loaded. On local rank 0 its content is
       additionally saved as ``ps_label_e0.pkl`` inside ``ps_label_dir``.
    2. Otherwise the ``ps_label_e<N>.pkl`` files in ``ps_label_dir`` are
       visited from most to least recently modified, and the first one whose
       epoch ``N`` is not larger than ``start_epoch`` is loaded.

    Args:
        ps_label_dir: dir to save pseudo label results pkls.
        start_epoch: epoch the training (re)starts from.

    Returns:
        Path of the pickle that was loaded, or None if nothing was loaded.
    """
    # Support initial pseudo labels given by cfg. cfg is only consulted when
    # starting from epoch 0.
    init_ps_path = cfg.SELF_TRAIN.get('INIT_PS', None) if start_epoch == 0 else None
    if init_ps_path and os.path.exists(init_ps_path):
        # NOTE: pickle must only be used on trusted files.
        with open(init_ps_path, 'rb') as f:
            init_ps_label = pkl.load(f)
        PSEUDO_LABELS.update(init_ps_label)

        if cfg.LOCAL_RANK == 0:
            ps_path = os.path.join(ps_label_dir, "ps_label_e0.pkl")
            with open(ps_path, 'wb') as f:
                # PSEUDO_LABELS is a Manager dict proxy: pickling the proxy
                # stores a reference to the (short-lived) manager process
                # instead of the labels, so dump a plain-dict copy.
                pkl.dump(PSEUDO_LABELS.copy(), f)

        return init_ps_path

    ps_label_list = glob.glob(os.path.join(ps_label_dir, 'ps_label_e*.pkl'))
    if not ps_label_list:
        return None

    # Most recently written result first.
    ps_label_list.sort(key=os.path.getmtime, reverse=True)
    for cur_pkl in ps_label_list:
        # Parse the epoch from the file name only, so that directory names
        # cannot interfere; skip files that merely share the prefix.
        matched = re.fullmatch(r'ps_label_e(\d+)\.pkl', os.path.basename(cur_pkl))
        if matched is None:
            continue

        # load pseudo label and return
        if int(matched.group(1)) <= start_epoch:
            with open(cur_pkl, 'rb') as f:
                latest_ps_label = pkl.load(f)
            PSEUDO_LABELS.update(latest_ps_label)
            return cur_pkl

    return None


def _update_dynamic_thresholds(model, val_loader):
    """
    Estimate per-class score thresholds from the model's predictions on
    ``val_loader`` and write them to ``cfg.SELF_TRAIN.SCORE_THRESH`` and
    ``cfg.SELF_TRAIN.NEG_THRESH`` (one entry per class, in class order).

    Every class starts at ``[0.99, 0.99]``. A class with more than two
    collected scores gets its pair re-estimated with the method selected by
    ``cfg.SELF_TRAIN.DYNAMIC_THRESHOLD.METHOD`` ('jnb', 'gmm', 'bmm' or
    'topk'); any other method name leaves the defaults untouched.

    Args:
        model: model in eval mode, returning ``(pred_dicts, ret_dict)``.
        val_loader: data loader whose dataset exposes ``class_names``.
    """
    target_dataloader_iter = iter(val_loader)
    num_batches = len(val_loader)
    commu_utils.synchronize()

    pred_scores = np.array([])
    pred_labels = np.array([])
    for _ in range(num_batches):
        target_batch = next(target_dataloader_iter)

        with torch.no_grad():
            load_data_to_gpu(target_batch)
            pred_dicts, ret_dict = model(target_batch)

        # Pool the scores / labels of every sample of the batch together.
        for pred_dict in pred_dicts:
            if 'pred_scores' in pred_dict:
                pred_scores = np.concatenate(
                    (pred_scores, pred_dict['pred_scores'].detach().cpu().numpy()))
            if 'pred_labels' in pred_dict:
                pred_labels = np.concatenate(
                    (pred_labels, pred_dict['pred_labels'].detach().cpu().numpy()))

    TMP_PSEUDO_LABELS['pred_labels'] = pred_labels
    TMP_PSEUDO_LABELS['pred_scores'] = pred_scores
    commu_utils.synchronize()
    if dist.is_initialized():
        # Concatenate the per-rank arrays key by key.
        part_pseudo_labels_list = commu_utils.all_gather(TMP_PSEUDO_LABELS)

        gathered = {}
        for part in part_pseudo_labels_list:
            for key, values in part.items():
                if key in gathered:
                    gathered[key] = np.concatenate((gathered[key], values))
                else:
                    gathered[key] = values

        TMP_PSEUDO_LABELS.update(gathered)

    pred_scores = TMP_PSEUDO_LABELS['pred_scores']
    pred_labels = TMP_PSEUDO_LABELS['pred_labels']

    class_names = val_loader.dataset.class_names
    # [score threshold, negative threshold] for every class of the dataset.
    threshold_pool = {name: [0.99, 0.99] for name in class_names}
    all_final_score = {name: [] for name in class_names}
    for label, score in zip(pred_labels, pred_scores):
        if label > 0:
            all_final_score[class_names[int(label) - 1]].append(score)

    for category, scores in all_final_score.items():
        if len(scores) <= 2:
            continue
        method = cfg.SELF_TRAIN.DYNAMIC_THRESHOLD.METHOD
        if method == 'jnb':
            threshold_pool[category][0], threshold_pool[category][1] = \
                dynamic_threshold.get_jnb_threshold(scores)
        elif method == 'gmm':
            threshold_pool[category][0], threshold_pool[category][1] = \
                dynamic_threshold.get_gmm_threshold(scores)
        elif method == 'bmm':
            threshold_pool[category][0], threshold_pool[category][1] = \
                dynamic_threshold.get_bmm_threshold(scores)
        elif method == 'topk':
            percentile = cfg.SELF_TRAIN.DYNAMIC_THRESHOLD.THRESHOLD * 100
            threshold_pool[category][0] = np.percentile(scores, percentile)
            threshold_pool[category][1] = np.percentile(scores, percentile)
    print("threshold_pool: ")
    print(threshold_pool)

    # The thresholds are indexed by ``label - 1`` elsewhere in this file, so
    # they need one entry per class.
    cfg.SELF_TRAIN.NEG_THRESH = [threshold_pool[name][1] for name in class_names]
    cfg.SELF_TRAIN.SCORE_THRESH = [threshold_pool[name][0] for name in class_names]

    commu_utils.synchronize()
    TMP_PSEUDO_LABELS.clear()
    torch.cuda.empty_cache()


def save_pseudo_label_epoch(model, val_loader, rank, leave_pbar, ps_label_dir, cur_epoch):
    """
    Generate pseudo label with given model.

    The model is switched to eval mode, optionally used to re-estimate the
    score thresholds (``cfg.SELF_TRAIN.DYNAMIC_THRESHOLD``), and then run
    over ``val_loader``; every batch is handed to ``save_pseudo_label_batch``.
    Finally the labels are gathered and dumped through
    ``gather_and_dump_pseudo_label_result``.

    Args:
        model: model to predict result for pseudo label
        val_loader: data_loader to predict pseudo label
        rank: process rank
        leave_pbar: tqdm bar controller
        ps_label_dir: dir to save pseudo label
        cur_epoch: epoch the pseudo labels are generated for
    """
    val_dataloader_iter = iter(val_loader)
    total_it_each_epoch = len(val_loader)

    if rank == 0:
        pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar,
                         desc='generate_ps_e%d' % cur_epoch, dynamic_ncols=True)

    pos_ps_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))
    ign_ps_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))

    if cfg.SELF_TRAIN.get('DSNORM', None):
        model.apply(set_ds_target)

    model.eval()

    if cfg.SELF_TRAIN.get('DYNAMIC_THRESHOLD', None):
        _update_dynamic_thresholds(model, val_loader)

    for cur_it in range(total_it_each_epoch):
        try:
            target_batch = next(val_dataloader_iter)
        except StopIteration:
            # Restart the iterator that is actually consumed by this loop;
            # otherwise every later step would replay the first batch.
            val_dataloader_iter = iter(val_loader)
            target_batch = next(val_dataloader_iter)

        # predict boxes for target_batch without tracking gradients
        with torch.no_grad():
            load_data_to_gpu(target_batch)
            if isinstance(target_batch, list):
                target_batch = target_batch[-1]

            pred_dicts, ret_dict = model(target_batch)

        pos_ps_batch_nmeters, ign_ps_batch_nmeters = save_pseudo_label_batch(
            target_batch, pred_dicts=pred_dicts,
            need_update=(cfg.SELF_TRAIN.get('MEMORY_ENSEMBLE', None) and
                         cfg.SELF_TRAIN.MEMORY_ENSEMBLE.ENABLED and
                         cur_epoch > 0),
            val_loader=val_loader,
            model=model
        )

        # log to console and tensorboard
        pos_ps_nmeter.update(pos_ps_batch_nmeters)
        ign_ps_nmeter.update(ign_ps_batch_nmeters)
        pos_ps_result = pos_ps_nmeter.aggregate_result()
        ign_ps_result = ign_ps_nmeter.aggregate_result()

        disp_dict = {'pos_ps_box': pos_ps_result, 'ign_ps_box': ign_ps_result}

        if rank == 0:
            pbar.update()
            pbar.set_postfix(disp_dict)
            pbar.refresh()

    if rank == 0:
        pbar.close()

        # NOTE(review): the progressive sampling schedule below only updates
        # cfg on rank 0 -- confirm other ranks are not expected to follow it.
        if cfg.SELF_TRAIN.get('PROGRESSIVE_SAMPLING', None) and \
                cfg.SELF_TRAIN.PROGRESSIVE_SAMPLING.ENABLE and \
                cur_epoch != cfg.OPTIMIZATION.NUM_EPOCHS:
            gt_reduce = cfg.SELF_TRAIN.PROGRESSIVE_SAMPLING.GT_REDUCE
            ps_grow = cfg.SELF_TRAIN.PROGRESSIVE_SAMPLING.PS_GROW

            # Sample more pseudo-labelled objects ...
            if cfg.SELF_TRAIN.get('PS_SAMPLING', None):
                for k in cfg.SELF_TRAIN.PS_SAMPLING.SAMPLE_GROUPS:
                    cfg.SELF_TRAIN.PS_SAMPLING.SAMPLE_GROUPS[k] += ps_grow

            # ... and fewer ground-truth objects ("<name>:<count>" entries),
            # never going below zero.
            gt_sampling_cfg = cfg.DATA_CONFIG_TAR.DATA_AUGMENTOR.AUG_CONFIG_LIST[0]
            if gt_sampling_cfg.NAME == 'gt_sampling' and \
                    'gt_sampling' not in cfg.DATA_CONFIG_TAR.DATA_AUGMENTOR.DISABLE_AUG_LIST:
                new_sample_groups = []
                for group in gt_sampling_cfg.SAMPLE_GROUPS:
                    fields = group.split(":")
                    new_sample_num = max(int(fields[-1]) - gt_reduce, 0)
                    new_sample_groups.append(fields[0] + ':' + str(new_sample_num))
                gt_sampling_cfg.SAMPLE_GROUPS = new_sample_groups

    gather_and_dump_pseudo_label_result(rank, ps_label_dir, cur_epoch)


def _dump_ps_sampling_database(ps_label_dir, cur_epoch):
    """
    Build the per-class pseudo-label box / point collections used for
    pseudo-label sampling from ``NEW_PSEUDO_LABELS`` and pickle them to
    ``ps_box_e<epoch>.pkl`` and ``ps_point_e<epoch>.pkl`` in ``ps_label_dir``.

    When ``cfg.SELF_TRAIN.DIVERSITY_SAMPLING`` is set, the classes listed in
    its ``SAMPLE_CLASSES`` are down-sampled first and the kept related-box
    counts are written to ``ps_diverse_e<epoch>.pkl``.
    """
    ps_point_path = os.path.join(ps_label_dir, "ps_point_e{}.pkl".format(cur_epoch))
    ps_box_path = os.path.join(ps_label_dir, "ps_box_e{}.pkl".format(cur_epoch))
    class_names = cfg.CLASS_NAMES
    use_diversity = cfg.SELF_TRAIN.get('DIVERSITY_SAMPLING', None)

    ps_boxes_dict = {class_name: [] for class_name in class_names}
    ps_points_dict = {class_name: [] for class_name in class_names}
    if use_diversity:
        ps_related_counts_dict = {class_name: [] for class_name in class_names}

    for frame in NEW_PSEUDO_LABELS.values():
        ps_boxes = frame['gt_boxes']
        if 'quality_mask' in frame:
            quality_mask = frame['quality_mask'].astype(np.bool_)
        else:
            quality_mask = np.ones(ps_boxes.shape[0]).astype(np.bool_)

        # Keep boxes with a positive label (second-to-last column) ...
        pos_ps_mask = ps_boxes[:, -2] > 0
        # ... that also pass the quality mask.
        remain_mask = quality_mask[pos_ps_mask]

        pos_points = [frame['gt_points'][idx]
                      for idx, is_pos in enumerate(pos_ps_mask) if is_pos]
        remain_points = [pos_points[idx]
                         for idx, keep in enumerate(remain_mask) if keep]

        ps_boxes = ps_boxes[pos_ps_mask][remain_mask]
        ps_labels = ps_boxes[:, -2]
        ps_names = [class_names[int(class_id - 1)] for class_id in ps_labels.tolist()]

        for name, box, points in zip(ps_names, ps_boxes, remain_points):
            ps_boxes_dict[name].append(box)
            ps_points_dict[name].append(points)

        if use_diversity:
            related_boxes_count = frame['related_box_count'][pos_ps_mask][remain_mask]
            for name, count in zip(ps_names, related_boxes_count):
                ps_related_counts_dict[name].append(count)

    # Print how many ps-labels are available for sampling
    print('Distribution of PS-Labels for sampling at epoch {}:'.format(cur_epoch))
    for cls in ps_boxes_dict.keys():
        print('# of ps labels of class {} is {}'.format(cls, len(ps_boxes_dict[cls])))
        key = 'PS-Label before div-sampling/' + cls
        wandb.log({key: len(ps_boxes_dict[cls])})

    # ------------------ OBC-based Diversity Down-sampling ------------------
    if use_diversity:
        # Imported here because scipy is not imported at module level.
        from scipy.stats import gaussian_kde

        ps_diverse_path = os.path.join(ps_label_dir, "ps_diverse_e{}.pkl".format(cur_epoch))
        for class_name in cfg.SELF_TRAIN.DIVERSITY_SAMPLING.SAMPLE_CLASSES:
            if len(ps_boxes_dict[class_name]) < 2:
                # A density estimate needs at least two samples; leave the
                # class untouched instead of crashing.
                continue
            ps_boxes = np.stack(ps_boxes_dict[class_name])
            related_box_counts = np.stack(ps_related_counts_dict[class_name])

            # Sample inversely to the estimated density of the related-box
            # counts, so that rare counts are kept more often.
            kde = gaussian_kde(related_box_counts.T)
            p = 1 / kde.pdf(related_box_counts.T)
            p /= np.sum(p)
            sample_size = int(len(related_box_counts) / cfg.SELF_TRAIN.DIVERSITY_SAMPLING.DOWNSAMPLE_RATE)
            sample_idx = np.random.choice(np.arange(len(related_box_counts)),
                                          size=sample_size, replace=False, p=p)

            # first 7 columns of each kept box
            ps_boxes_dict[class_name] = ps_boxes[sample_idx].astype(np.float32)[:, :7]
            ps_points_dict[class_name] = [ps_points_dict[class_name][i] for i in sample_idx]
            ps_related_counts_dict[class_name] = related_box_counts[sample_idx]

        with open(ps_diverse_path, 'wb') as f:
            pkl.dump(ps_related_counts_dict, f)

        # Print how many ps-labels are left after down-sampling
        print('Distribution of PS-Labels for DIVERSE sampling at epoch {}:'.format(cur_epoch))
        for cls in ps_boxes_dict.keys():
            print('# of ps labels of class {} is {}'.format(cls, len(ps_boxes_dict[cls])))
            key = 'PS-Label after div-sampling/' + cls
            wandb.log({key: len(ps_boxes_dict[cls])})

    with open(ps_box_path, 'wb') as f:
        pkl.dump(ps_boxes_dict, f)
    with open(ps_point_path, 'wb') as f:
        pkl.dump(ps_points_dict, f)


def _update_bda_distribution():
    """
    Update ``cfg.SELF_TRAIN.BDA.DIST`` with the class distribution of the
    boxes in ``NEW_PSEUDO_LABELS``: the smoothed, normalised counts of labels
    1..3 are blended with the previous distribution using
    ``cfg.SELF_TRAIN.BDA.ALPHA`` as the weight of the old value.
    """
    old_dist = np.array(cfg.SELF_TRAIN.BDA.DIST)

    counts = {0: 0, 1: 0, 2: 0}
    for frame in NEW_PSEUDO_LABELS.values():
        for box in frame['gt_boxes']:
            # Only positive labels are class ids; label 0 would map to -1.
            if box[-2] > 0:
                counts[int(box[-2]) - 1] += 1
    distribution = np.array(list(counts.values()))

    # Smooth and normalise the counts so that empty classes stay non-zero.
    epsilon = 1e-6
    distribution = (distribution + epsilon) / (np.sum(distribution) + epsilon * len(distribution))

    alpha = cfg.SELF_TRAIN.BDA.ALPHA
    distribution = alpha * old_dist + (1 - alpha) * distribution

    cfg.SELF_TRAIN.BDA.DIST = list(distribution)


def _update_pls_size_statistics(ps_label_dir, cur_epoch):
    """
    Compare the mean of columns 3:6 of all dumped pseudo-label boxes of
    ``cur_epoch`` with ``cfg.SELF_TRAIN.PLS`` (length, width, height) and
    store the difference as ``SIZE_RES`` of the second source augmentor.
    Also enables 'normalize_object_size' if it was disabled, disabling
    'random_object_scaling' in exchange.
    """
    stat_subjects = ["length", "width", "height"]
    old_stats = np.array([cfg.SELF_TRAIN.PLS[dim] for dim in stat_subjects])

    ps_path = os.path.join(ps_label_dir, "ps_label_e{}.pkl".format(cur_epoch))
    with open(ps_path, 'rb') as f:
        data = pkl.load(f)

    # l, w, h of every box of every frame
    gt_boxes_list = []
    for frame in data.values():
        gt_boxes_list.extend(frame['gt_boxes'][:, 3:6])

    gt_boxes_avg = np.mean(gt_boxes_list, axis=0)
    new_stats = gt_boxes_avg.reshape(1, 3)
    new_scale = new_stats - old_stats

    disabled_augs = cfg.DATA_CONFIG.DATA_AUGMENTOR.DISABLE_AUG_LIST
    if 'normalize_object_size' in disabled_augs:
        disabled_augs.remove('normalize_object_size')
        if 'random_object_scaling' not in disabled_augs:
            disabled_augs.append('random_object_scaling')
    cfg.DATA_CONFIG.DATA_AUGMENTOR.AUG_CONFIG_LIST[1]['SIZE_RES'] = new_scale.squeeze().tolist()


def gather_and_dump_pseudo_label_result(rank, ps_label_dir, cur_epoch):
    """
    Merge the pseudo labels generated on all ranks, dump them and make them
    the active pseudo labels.

    Steps:
        1. All-gather ``NEW_PSEUDO_LABELS`` when torch.distributed is
           initialised.
        2. On rank 0, pickle the labels to ``ps_label_e<cur_epoch>.pkl`` and,
           when ``cfg.SELF_TRAIN.PS_SAMPLING`` is enabled, also dump the
           pseudo-label sampling collections.
        3. Optionally update ``cfg.SELF_TRAIN.BDA.DIST`` and the size
           statistics driven by ``cfg.SELF_TRAIN.PLS``.
        4. Replace the content of ``PSEUDO_LABELS`` with the new labels and
           clear ``NEW_PSEUDO_LABELS``.

    Args:
        rank: process rank; only rank 0 writes files.
        ps_label_dir: dir to save pseudo label results pkls.
        cur_epoch: epoch number used in the dumped file names.
    """
    commu_utils.synchronize()

    if dist.is_initialized():
        part_pseudo_labels_list = commu_utils.all_gather(NEW_PSEUDO_LABELS)

        new_pseudo_label_dict = {}
        for pseudo_labels in part_pseudo_labels_list:
            new_pseudo_label_dict.update(pseudo_labels)

        NEW_PSEUDO_LABELS.update(new_pseudo_label_dict)

    ps_sampling_enabled = cfg.SELF_TRAIN.get('PS_SAMPLING', None) and \
        cfg.SELF_TRAIN.PS_SAMPLING.ENABLE
    if ps_sampling_enabled:
        cfg.SELF_TRAIN.PS_SAMPLING.PS_OBJECT_PATH = ps_label_dir

    # dump new pseudo label to given dir
    if rank == 0:
        ps_path = os.path.join(ps_label_dir, "ps_label_e{}.pkl".format(cur_epoch))
        with open(ps_path, 'wb') as f:
            pkl.dump(NEW_PSEUDO_LABELS, f)

        if ps_sampling_enabled:
            _dump_ps_sampling_database(ps_label_dir, cur_epoch)

    if cfg.SELF_TRAIN.get('BDA', None):
        _update_bda_distribution()

    # Make sure rank 0 finished writing before any rank reads the dump back.
    commu_utils.synchronize()

    if cfg.SELF_TRAIN.get('PLS', None):
        _update_pls_size_statistics(ps_label_dir, cur_epoch)

    commu_utils.synchronize()
    PSEUDO_LABELS.clear()
    PSEUDO_LABELS.update(NEW_PSEUDO_LABELS)
    NEW_PSEUDO_LABELS.clear()


def save_pseudo_label_batch(input_dict,
                            pred_dicts=None,
                            need_update=True,
                            val_loader=None,
                            model=None):
    """
    Save pseudo label for give batch.
    If model is given, use model to inference pred_dicts,
    otherwise, directly use given pred_dicts.

    Args:
        input_dict: batch data read from dataloader
        pred_dicts: Dict if not given model.
            predict results to be generated pseudo label and saved
        need_update: Bool.
            If set to true, use consistency matching to update pseudo label
    """
    pos_ps_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))
    ign_ps_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))

    batch_size = len(pred_dicts)
    for b_idx in range(batch_size):
        pred_cls_scores = pred_iou_scores = None
        if 'pred_boxes' in pred_dicts[b_idx]:
            # Exist predicted boxes passing self-training score threshold
            pred_boxes = pred_dicts[b_idx]['pred_boxes'].detach().cpu().numpy()
            pred_labels = pred_dicts[b_idx]['pred_labels'].detach().cpu().numpy()
            pred_scores = pred_dicts[b_idx]['pred_scores'].detach().cpu().numpy()
            if 'pred_cls_scores' in pred_dicts[b_idx]:
                pred_cls_scores = pred_dicts[b_idx]['pred_cls_scores'].detach().cpu().numpy()
            if 'pred_iou_scores' in pred_dicts[b_idx]:
                pred_iou_scores = pred_dicts[b_idx]['pred_iou_scores'].detach().cpu().numpy()

            if cfg.SELF_TRAIN.get('NO_FILTER', None):
                gt_box = np.concatenate((pred_boxes,
                                         pred_labels.reshape(-1, 1),
                                         pred_scores.reshape(-1, 1)), axis=1)
            elif cfg.SELF_TRAIN.get('AUG', None):
                from pcdet.datasets.augmentor.augmentor_utils import global_scaling
                batch_points = \
                    input_dict['points'][
                        input_dict['points'][:, 0] == b_idx][:,1:].cpu().numpy()
                
                single_pc_pnts = batch_points.copy()

                src_uniform = np.array([1.00, 1.05])
                noise_scale = np.random.uniform(src_uniform[0], src_uniform[1])
                single_pc_pnts *= noise_scale

                
                # from pcdet.datasets.augmentor.data_augmentor import get_polar_image, label_point_cloud_beam
                
                config = cfg.DATA_CONFIG_TAR.DATA_AUGMENTOR.AUG_CONFIG_LIST[0]
                points = single_pc_pnts
                if 'beam_labels' in input_dict: # for waymo and kitti datasets
                    beam_label = input_dict['beam_labels'][input_dict['beam_labels'][:, 0] == b_idx][:, 1].cpu().numpy()
                    beam_label = beam_label.astype(np.int32)

                # else:
                #     polar_image = get_polar_image(points)
                #     beam_label = label_point_cloud_beam(polar_image, points, config['BEAM'])

                beam_mask = np.random.rand(config['BEAM']) < config['BEAM_PROB']
                points_mask = beam_mask[beam_label]
                single_pc_pnts = points[points_mask]
                # - NAME: random_beam_downsample
                #     BEAM: 64
                #     BEAM_PROB: 0.80
                #     PHI_THRESHOLD: 0.03
                data_dict = {'points': single_pc_pnts, 'frame_id': 0}
                tgt_dataset = val_loader.dataset
                data_dict = tgt_dataset.prepare_data(data_dict=data_dict)
                data_dict = tgt_dataset.collate_batch([data_dict])
                load_data_to_gpu(data_dict)
                with torch.no_grad():
                    batch_pred_dict = model(data_dict)[0][0]
                pred_boxes_aug = batch_pred_dict['pred_boxes'].cpu().numpy()
                
                pred_boxes_aug = pred_boxes_aug / noise_scale

                pred_labels_aug = batch_pred_dict['pred_labels'].cpu().numpy()
                pred_labels_aug = pred_labels_aug.astype(np.int32)
                pred_scores_aug = batch_pred_dict['pred_scores'].cpu().numpy()
                pred_cls_scores_aug = batch_pred_dict['pred_cls_scores'].cpu().numpy()
                pred_iou_scores_aug = batch_pred_dict['pred_iou_scores'].cpu().numpy()
                # pred_cls_scores_aug = pred_cls_scores_aug
                # pred_iou_scores_aug = pred_iou_scores_aug
                # gt_box = np.concatenate((pred_boxes_aug,
                #                          pred_labels_aug.reshape(-1, 1),
                #                          pred_scores_aug.reshape(-1, 1)), axis=1)
                
                # gt_infos = {
                #     'gt_boxes': gt_box,
                #     'cls_scores': pred_cls_scores_aug,
                #     'iou_scores': pred_iou_scores_aug,
                #     'memory_counter': np.zeros(gt_box.shape[0])
                # }
                middle_mask = pred_scores < 0.99
                middle_pred_boxes = pred_boxes[middle_mask]
                middle_pred_labels = pred_labels[middle_mask]
                middle_pred_scores = pred_scores[middle_mask]
                middle_pred_cls_scores = pred_cls_scores[middle_mask]
                middle_pred_iou_scores = pred_iou_scores[middle_mask]

                high_pred_boxes = pred_boxes[~middle_mask]
                high_pred_labels = pred_labels[~middle_mask]
                high_pred_scores = pred_scores[~middle_mask]
                high_pred_cls_scores = pred_cls_scores[~middle_mask]
                high_pred_iou_scores = pred_iou_scores[~middle_mask]

                gt_boxes_raw = middle_pred_boxes

                gt_boxes = pred_boxes_aug
                
                from pcdet.ops.iou3d_nms import iou3d_nms_utils


                gt_boxes, _ = common_utils.check_numpy_to_torch(gt_boxes)
                gt_boxes_raw, _ = common_utils.check_numpy_to_torch(gt_boxes_raw)
                iou_matrix = iou3d_nms_utils.boxes_bev_iou_cpu(gt_boxes_raw[:, :7],
                                                gt_boxes).cpu().numpy()
                gt_boxes = gt_boxes.cpu().numpy()
                gt_boxes_raw = gt_boxes_raw.cpu().numpy()
                related_boxes = iou_matrix > 0.9 #  (9, 1024)

                gt_boxes = gt_boxes_raw[related_boxes.any(axis=1)]
                gt_labels = middle_pred_labels[related_boxes.any(axis=1)]
                gt_scores = middle_pred_scores[related_boxes.any(axis=1)]
                pred_cls_scores = middle_pred_cls_scores[related_boxes.any(axis=1)]
                pred_iou_scores = middle_pred_iou_scores[related_boxes.any(axis=1)]

                gt_boxes = np.concatenate((gt_boxes, high_pred_boxes), axis=0)
                gt_labels = np.concatenate((gt_labels, high_pred_labels), axis=0)
                gt_scores = np.concatenate((gt_scores, high_pred_scores), axis=0)
                pred_cls_scores = np.concatenate((pred_cls_scores, high_pred_cls_scores), axis=0)
                pred_iou_scores = np.concatenate((pred_iou_scores, high_pred_iou_scores), axis=0)

                gt_box = np.concatenate((gt_boxes,
                                            gt_labels.reshape(-1, 1),
                                            gt_scores.reshape(-1, 1)), axis=1)

                # gt_box_raw = np.concatenate((pred_boxes,
                #                             pred_labels.reshape(-1, 1),
                #                             pred_scores.reshape(-1, 1)), axis=1)
                
                # gt_infos_raw = {
                #     'gt_boxes': gt_box_raw,
                #     'cls_scores': pred_cls_scores,
                #     'iou_scores': pred_iou_scores,
                #     'memory_counter': np.zeros(gt_box_raw.shape[0])
                # }
                # ensemble_func = getattr(memory_ensemble_utils, cfg.SELF_TRAIN.MEMORY_ENSEMBLE.NAME)
                # gt_infos = memory_ensemble_utils.memory_ensemble(
                #     gt_infos, gt_infos_raw,
                #     cfg.SELF_TRAIN.MEMORY_ENSEMBLE, ensemble_func
                # )       
            else:
                # remove boxes under negative threshold
                if cfg.SELF_TRAIN.get('NEG_THRESH', None):
                    labels_remove_scores = np.array(cfg.SELF_TRAIN.NEG_THRESH)[pred_labels - 1]
                    remain_mask = pred_scores >= labels_remove_scores
                    pred_labels = pred_labels[remain_mask]
                    pred_scores = pred_scores[remain_mask]
                    pred_boxes = pred_boxes[remain_mask]
                    if 'pred_cls_scores' in pred_dicts[b_idx]:
                        pred_cls_scores = pred_cls_scores[remain_mask]
                    if 'pred_iou_scores' in pred_dicts[b_idx]:
                        pred_iou_scores = pred_iou_scores[remain_mask]

                if not cfg.SELF_TRAIN.get('SOFT_LABEL', None):

                    labels_ignore_scores = np.array(cfg.SELF_TRAIN.SCORE_THRESH)[pred_labels - 1]
                    # ignore_mask = pred_scores < labels_ignore_scores
                    # pred_labels[ignore_mask] = -1
                    ignore_mask = np.logical_and(pred_scores < labels_ignore_scores, pred_labels > 0)
                    pred_labels[ignore_mask] = -pred_labels[ignore_mask]

                    unreliable_pred_boxes = pred_boxes[ignore_mask]
                    unreliable_pred_labels = pred_labels[ignore_mask]
                    unreliable_pred_scores = pred_scores[ignore_mask]

                # if cfg.SELF_TRAIN.get('IE_AUG', None) and len(pred_boxes) > 0:

                #     pred_boxes_pre_nms = pred_dicts[b_idx]['pred_boxes_pre_nms'].detach().cpu().numpy()
                #     pred_labels_pre_nms = pred_dicts[b_idx]['pred_labels_pre_nms'].detach().cpu().numpy()
                #     pred_scores_pre_nms = pred_dicts[b_idx]['pred_scores_pre_nms'].detach().cpu().numpy()
                #     if 'pred_cls_scores_pre_nms' in pred_dicts[b_idx]:
                #         pred_cls_scores_pre_nms = pred_dicts[b_idx]['pred_cls_scores_pre_nms'].detach().cpu().numpy()
                #     if 'pred_iou_scores_pre_nms' in pred_dicts[b_idx]:
                #         pred_iou_scores_pre_nms = pred_dicts[b_idx]['pred_iou_scores_pre_nms'].detach().cpu().numpy()
                    
                #     # find closest iou between pre_nms boxes and pred_boxes
                #     from pcdet.ops.iou3d_nms import iou3d_nms_utils
                #     pred_boxes_pre_nms, _ = common_utils.check_numpy_to_torch(pred_boxes_pre_nms)
                #     pred_boxes, _ = common_utils.check_numpy_to_torch(pred_boxes)
                #     iou_matrix = iou3d_nms_utils.boxes_bev_iou_cpu(pred_boxes_pre_nms[:, :7],
                #                                                     pred_boxes)
                #     # iou_matrix = iou_matrix.cpu().numpy                    pred_boxes_pre_nms = pred_boxes_pre_nms.cpu().numpy()
                #     # pred_boxes = pred_boxes.cpu().numpy()
                #     # pred_boxes_pre_nms = pred_boxes_pre_nms.cpu().numpy()

                #     # 删除 iou 大于 0.999 的 pred_box_pre_nms 和 iou_matrix
                #     high_idx = iou_matrix > 0.999
                #     pred_boxes_pre_nms = pred_boxes_pre_nms[~high_idx.any(axis=1)]

                #     iou_matrix = iou_matrix[~high_idx.any(axis=1)]

                #     # additional_proposals = []
                #     # 在剩下的 pred_box_pre_nms 选择 iou 最大的 pred_box_pre_nms
                #     if iou_matrix.shape[0] > 0:
                #         closest_idx = np.argmax(iou_matrix, axis=0)

                #         # 根据iou最大的 pred_box_pre_nms 和 pred_box 生成插值和外推提议
                #         closest_proposal = pred_boxes_pre_nms[closest_idx]
                #         inter_center = pred_boxes[:, :3] - (closest_proposal[:, :3] + pred_boxes[:, :3]) / 2
                #         inter_other = pred_boxes[:, 3:]
                #         inter_proposal = np.concatenate((inter_center, inter_other), axis=1)
                #         extra_center = pred_boxes[:, :3] + (closest_proposal[:, :3] + pred_boxes[:, :3]) / 2
                #         extra_other = pred_boxes[:, 3:]
                #         extra_proposal = np.concatenate((extra_center, extra_other), axis=1)
                #         # 生成最终的提议
                #         pred_boxes = np.concatenate((pred_boxes, inter_proposal, extra_proposal), axis=0)
                #         pred_labels = np.concatenate((pred_labels, pred_labels, pred_labels), axis=0)
                #         pred_scores = np.concatenate((pred_scores, pred_scores, pred_scores), axis=0)

                #         pred_cls_scores = np.concatenate((pred_cls_scores, pred_cls_scores, pred_cls_scores), axis=0)
                #         pred_iou_scores = np.concatenate((pred_iou_scores, pred_iou_scores, pred_iou_scores), axis=0)


                    
                
                # if cfg.get('SELF_TRAIN', None) and cfg.SELF_TRAIN.get('SOFT_LABEL', None):
                #     hard_quality = pred_labels > labels_ignore_scores
                # hard_quality = remain_mask
                # pred_scores[remain_mask] = 1
                


                gt_box = np.concatenate((pred_boxes,
                                        pred_labels.reshape(-1, 1),
                                        pred_scores.reshape(-1, 1)), axis=1)
                
                unreliable_gt_box = np.concatenate((unreliable_pred_boxes,
                                        unreliable_pred_labels.reshape(-1, 1),
                                        unreliable_pred_scores.reshape(-1, 1)), axis=1)

        else:
            # no predicted boxes passes self-training score threshold
            gt_box = np.zeros((0, 9), dtype=np.float32)
            unreliable_gt_box = np.zeros((0, 9), dtype=np.float32)

        if pred_cls_scores is None:
            pred_cls_scores = pred_scores
        if pred_iou_scores is None:
            pred_iou_scores = pred_scores


        gt_infos = {
            'gt_boxes': gt_box,
            'unreliable_boxes': unreliable_gt_box,
            'cls_scores': pred_cls_scores,
            'iou_scores': pred_iou_scores,
            'memory_counter': np.zeros(gt_box.shape[0])
        }


        # # record pseudo label to pseudo label dict
        # if need_update:
        #     ensemble_func = getattr(memory_ensemble_utils, cfg.SELF_TRAIN.MEMORY_ENSEMBLE.NAME)
        #     gt_infos = ensemble_func(PSEUDO_LABELS[input_dict['frame_id'][b_idx]],
        #                              gt_infos, cfg.SELF_TRAIN.MEMORY_ENSEMBLE)
        # record pseudo label to pseudo label dict
        if need_update:
            unreliable_gt_box_bk = copy.deepcopy(unreliable_gt_box)
            gt_infos.pop('unreliable_boxes')

            # unreliable_ps_bk = copy.deepcopy(PSEUDO_LABELS[input_dict['frame_id'][b_idx]]['unreliable_boxes'])
            # PSEUDO_LABELS[input_dict['frame_id'][b_idx]].pop('unreliable_boxes')
            previous_gt_infos = copy.deepcopy(PSEUDO_LABELS[input_dict['frame_id'][b_idx]])
            previous_gt_infos.pop('unreliable_boxes')
            
            ensemble_func = getattr(memory_ensemble_utils, cfg.SELF_TRAIN.MEMORY_ENSEMBLE.NAME)
            gt_infos = memory_ensemble_utils.memory_ensemble(
                previous_gt_infos, gt_infos,
                cfg.SELF_TRAIN.MEMORY_ENSEMBLE, ensemble_func
            )

            gt_infos.update({'unreliable_boxes': unreliable_gt_box_bk})

        '''--------- Target ReD Sampling ---------'''
        gt_box = gt_infos['gt_boxes']
        gt_points = None
        if cfg.SELF_TRAIN.get('PS_SAMPLING',None) and \
                cfg.SELF_TRAIN.PS_SAMPLING.ENABLE:
            from pcdet.ops.roiaware_pool3d.roiaware_pool3d_utils import points_in_boxes_cpu
            gt_points = []
            batch_points = \
                input_dict['points'][
                    input_dict['points'][:,0]==b_idx][:,1:].cpu().numpy()
            internal_pnts_mask = \
                points_in_boxes_cpu(batch_points, gt_box[:, :7])
            for msk_idx in range(gt_box.shape[0]):
                gt_points.append(batch_points[internal_pnts_mask[msk_idx]==1])

        if cfg.SELF_TRAIN.get('PS_SAMPLING',  None) and cfg.SELF_TRAIN.PS_SAMPLING.ENABLE:
            gt_infos.update({'gt_points': gt_points})

        if cfg.get('SELF_TRAIN', None) and cfg.SELF_TRAIN.get('SOFT_LABEL', None):
            labels_ignore_scores = np.array(cfg.SELF_TRAIN.SCORE_THRESH)[pred_labels - 1]
            remain_mask = pred_scores > labels_ignore_scores
            pred_scores[remain_mask] = 1
        # counter the number of ignore boxes for each class
        for i in range(ign_ps_nmeter.n):
            num_total_boxes = (np.abs(gt_infos['gt_boxes'][:, 7]) == (i+1)).sum()
            ign_ps_nmeter.update((gt_infos['gt_boxes'][:, 7] == -(i+1)).sum(), index=i)
            pos_ps_nmeter.update(num_total_boxes - ign_ps_nmeter.meters[i].val, index=i)

        NEW_PSEUDO_LABELS[input_dict['frame_id'][b_idx]] = gt_infos

    return pos_ps_nmeter, ign_ps_nmeter


def load_ps_label(frame_id):
    """
    Fetch the stored pseudo labels of one frame from PSEUDO_LABELS.

    :param frame_id: key of the frame in PSEUDO_LABELS (file name of pseudo label)
    :return (gt_box, unreliable_gt_box): the 'gt_boxes' and 'unreliable_boxes'
        fields stored for this frame. gt boxes are laid out as
        (N, 9) [x, y, z, w, l, h, ry, label, scores].
    :raises ValueError: if no pseudo label is stored for frame_id
    """
    # PSEUDO_LABELS is created as Manager().dict() at module level, so each
    # membership test / item access is a separate request to the manager.
    # Fetch the entry once: this avoids repeated transfers of the same entry
    # and guarantees both returned fields come from the same snapshot.
    try:
        frame_labels = PSEUDO_LABELS[frame_id]
    except KeyError:
        raise ValueError('Cannot find pseudo label for frame: %s' % frame_id) from None

    # Kept outside the try block so that a missing field still surfaces as a
    # KeyError rather than being reported as a missing frame.
    return frame_labels['gt_boxes'], frame_labels['unreliable_boxes']


def load_target(TARGET_PSEUDO_LABELS):
    """
    Merge PSEUDO_LABELS into TARGET_PSEUDO_LABELS, then mirror the result back.

    Frames that exist in PSEUDO_LABELS but not in TARGET_PSEUDO_LABELS are
    added to TARGET_PSEUDO_LABELS; frames already present there keep their
    existing entry untouched. An added entry that has no 'cls_preds' field
    receives an all-zero array with one row per gt box and 3 columns.
    Finally PSEUDO_LABELS is emptied and refilled from TARGET_PSEUDO_LABELS,
    so both mappings hold the same frames afterwards.

    :param TARGET_PSEUDO_LABELS: dict-like mapping frame id -> label dict;
        updated in place
    :return: None
    """
    # .keys() is kept for the outer iteration: on a manager dict proxy it
    # returns a snapshot list, so the loop is independent of later changes.
    for key in PSEUDO_LABELS.keys():
        # Direct membership test instead of `in ....keys()`, which would
        # rebuild the full key list on every iteration for a proxy argument.
        if key in TARGET_PSEUDO_LABELS:
            continue

        # Single fetch per frame; with a manager proxy every item access
        # transfers the whole entry.
        entry = PSEUDO_LABELS[key]
        if 'cls_preds' not in entry:
            # NOTE(review): the width 3 is hard-coded; presumably the number
            # of classes -- confirm against the detector config.
            entry['cls_preds'] = np.zeros((entry['gt_boxes'].shape[0], 3))
        TARGET_PSEUDO_LABELS[key] = entry

    # No per-key write-back to PSEUDO_LABELS is needed: it is replaced
    # wholesale by the merged target mapping here.
    PSEUDO_LABELS.clear()
    PSEUDO_LABELS.update(TARGET_PSEUDO_LABELS)
