import torch
import os
import glob
import tqdm
import numpy as np
import torch.distributed as dist
from pcdet.config import cfg
from pcdet.models import load_data_to_gpu
from pcdet.utils import common_utils, commu_utils, memory_ensemble_utils, dynamic_threshold
import pickle as pkl
import re
from multiprocessing import Manager
from pcdet.models.model_utils.dsnorm import set_ds_target, set_ds_source
import copy
import wandb
from pcdet.models.model_utils.model_nms_utils import class_agnostic_nms

# Pseudo labels currently in use. Backed by a multiprocessing Manager dict
# (presumably so dataloader worker processes see updates — verify with the
# dataset code). NOTE: values fetched from a Manager dict are copies; mutate
# them locally and re-assign to persist changes.
PSEUDO_LABELS = Manager().dict()

# Pseudo labels produced in the current round; merged into PSEUDO_LABELS (and
# then cleared) by gather_and_dump_pseudo_label_result().
NEW_PSEUDO_LABELS = {}

# Not used in the visible part of this file — kept for external users.
TMP_PSEUDO_LABELS = {}

# Per-domain feature banks ({'source': {...}, 'target': {...}}) handed to the
# prototype manager. Manager-backed: see the NOTE on PSEUDO_LABELS.
FEATURES_BANK = Manager().dict()
# Per-rank scratch banks mapping str(class_id) -> feature tensor.
FEATURES_BANK_TMP_SRC = {}
FEATURES_BANK_TMP_TAR = {}

# def filter_pseudo_label_by_score(rois, roi_scores, roi_labels, roi_head_features):
#     """
#     Filter pseudo label by score threshold and nms.
#     Args:
#         rois: pseudo label boxes
#         roi_scores: pseudo label scores
#         roi_labels: pseudo label labels
#     """

#     roi_scores = torch.sigmoid(roi_scores)
#     batch_size = rois.shape[0]
#     pred_dicts = []
#     nms_config = cfg.MODEL.ROI_HEAD.NMS_CONFIG.TEST
#     score_thresh = np.array(cfg.SELF_TRAIN.SCORE_THRESH)
#     score_thresh = torch.tensor(score_thresh).to(roi_scores.device)
#     neg_thresh = np.array(cfg.SELF_TRAIN.NEG_THRESH)
#     neg_thresh = torch.tensor(neg_thresh).to(roi_scores.device)
#     for index in range(batch_size):
#         batch_mask = index
#         mask = roi_scores[batch_mask] > score_thresh[roi_labels[batch_mask] - 1]
#         box_preds = rois[batch_mask][mask]
#         cur_roi_scores = roi_scores[batch_mask][mask]
#         cur_roi_labels = roi_labels[batch_mask][mask]
#         cur_roi_features = roi_head_features[batch_mask][mask]
#         selected, selected_scores = class_agnostic_nms(
#             box_scores=cur_roi_scores, box_preds=box_preds, nms_config=nms_config
#         )
#         background_mask = roi_scores[batch_mask] <= neg_thresh[roi_labels[batch_mask] - 1]
#         neg_features = roi_head_features[batch_mask][background_mask]
#         pred_dicts.append({
#             'pred_boxes': box_preds[selected],
#             'pred_scores': cur_roi_scores[selected],
#             'pred_labels': cur_roi_labels[selected],
#             'pred_head_features': cur_roi_features[selected],
#             'neg_features': neg_features
#         })

#     rois = torch.cat([pred_dict['pred_boxes'] for pred_dict in pred_dicts], dim=0)
#     roi_scores = torch.cat([pred_dict['pred_scores'] for pred_dict in pred_dicts], dim=0)
#     roi_labels = torch.cat([pred_dict['pred_labels'] for pred_dict in pred_dicts], dim=0)
#     roi_head_features = torch.cat([pred_dict['pred_head_features'] for pred_dict in pred_dicts], dim=0)
#     neg_features = torch.cat([pred_dict['neg_features'] for pred_dict in pred_dicts], dim=0)

#     return rois, roi_scores, roi_labels, roi_head_features, neg_features
def _sample_background_features(frame_feats, frame_scores, frame_labels, neg_thresh, pos_labels, empty_feats):
    """
    Randomly sample background RoI features of one frame, class-balanced
    against the kept positive pseudo labels.

    A RoI counts as background when its (sigmoid) score is <= the NEG_THRESH
    entry of its predicted class. For every class that has at least one kept
    positive, up to ``count(pos_labels == cls)`` background features predicted
    as that class are drawn without replacement (``torch.randperm``).

    Returns:
        Tensor of sampled background features, or ``empty_feats`` if none.
    """
    background = frame_scores <= neg_thresh[frame_labels - 1]
    bg_feats = frame_feats[background]
    if bg_feats.size(0) == 0:
        return empty_feats
    bg_labels = frame_labels[background]

    sampled = []
    for cls_id in range(1, len(cfg.CLASS_NAMES) + 1):
        num_pos = int((pos_labels == cls_id).sum())
        if num_pos == 0:
            continue
        cls_idx = (bg_labels == cls_id).nonzero().squeeze(1)
        if cls_idx.numel() == 0:
            continue
        cls_idx = cls_idx[torch.randperm(cls_idx.size(0))[:num_pos]]
        sampled.append(bg_feats[cls_idx])
    return torch.cat(sampled, dim=0) if sampled else empty_feats


def filter_pseudo_label_by_score(rois, roi_scores, roi_labels, roi_head_features):
    """
    Filter pseudo labels by per-class score threshold and class-agnostic NMS.

    Args:
        rois: pseudo label boxes, indexed per frame as ``rois[b]`` with one row
            per RoI (assumed (B, N, box_dim) — TODO confirm box_dim).
        roi_scores: (B, N) raw scores; ``torch.sigmoid`` is applied here.
        roi_labels: (B, N) 1-based class labels (used to index SCORE_THRESH /
            NEG_THRESH with ``label - 1``).
        roi_head_features: (B, N, D) RoI head features.

    Returns:
        rois, roi_scores, roi_labels, roi_head_features: the RoIs that pass
            ``cfg.SELF_TRAIN.SCORE_THRESH`` and NMS, concatenated over frames.
        neg_features: class-balanced background features when
            ``cfg.SELF_TRAIN.ProtoAlign.use_backgroud`` is set, otherwise an
            empty tensor shaped (0, D).
    """
    roi_scores = torch.sigmoid(roi_scores)
    device = roi_scores.device
    batch_size = rois.shape[0]
    nms_config = cfg.MODEL.ROI_HEAD.NMS_CONFIG.TEST
    score_thresh = torch.tensor(np.array(cfg.SELF_TRAIN.SCORE_THRESH)).to(device)
    use_background = cfg.SELF_TRAIN.ProtoAlign.get('use_backgroud', False)
    neg_thresh = torch.tensor(np.array(cfg.SELF_TRAIN.NEG_THRESH)).to(device) if use_background else None
    # Keep the trailing feature dims so "no features" results concatenate
    # cleanly with real ones (a bare torch.tensor([]) would be shaped (0,)).
    empty_feats = roi_head_features.new_zeros((0,) + tuple(roi_head_features.shape[2:]))

    kept_boxes, kept_scores, kept_labels, kept_feats, neg_feats = [], [], [], [], []
    for b in range(batch_size):
        frame_scores = roi_scores[b]
        frame_labels = roi_labels[b]
        frame_feats = roi_head_features[b]

        above = frame_scores > score_thresh[frame_labels - 1]
        box_preds = rois[b][above]
        cur_scores = frame_scores[above]
        cur_labels = frame_labels[above]
        cur_feats = frame_feats[above]
        selected, _ = class_agnostic_nms(
            box_scores=cur_scores, box_preds=box_preds, nms_config=nms_config
        )
        sel_labels = cur_labels[selected]

        kept_boxes.append(box_preds[selected])
        kept_scores.append(cur_scores[selected])
        kept_labels.append(sel_labels)
        kept_feats.append(cur_feats[selected])
        if use_background:
            neg_feats.append(_sample_background_features(
                frame_feats, frame_scores, frame_labels, neg_thresh, sel_labels, empty_feats))
        else:
            neg_feats.append(empty_feats)

    return (
        torch.cat(kept_boxes, dim=0),
        torch.cat(kept_scores, dim=0),
        torch.cat(kept_labels, dim=0),
        torch.cat(kept_feats, dim=0),
        torch.cat(neg_feats, dim=0),
    )

def post_data_augmentation(input_dict, src_loader):
    """
    Re-run per-frame dataset processing on an already collated batch.

    The batch is split back into frames (points / beam labels are selected by
    the frame index stored in their first column, boxes by the leading batch
    axis). Each frame goes through
    ``src_loader.dataset.prepare_data_proto_align``, and the results are
    re-collated with ``src_loader.dataset.collate_batch``. Keys of
    ``input_dict`` missing from the new batch are copied over unchanged.

    Args:
        input_dict: collated batch with at least 'points', 'gt_boxes' and
            'frame_id'; 'beam_labels' is optional.
        src_loader: loader whose ``dataset`` provides
            ``prepare_data_proto_align`` and ``collate_batch``.

    Returns:
        The re-collated batch dict.
    """
    dataset = src_loader.dataset
    all_points = input_dict['points']
    has_beams = 'beam_labels' in input_dict

    frames = []
    for frame_idx in range(input_dict['gt_boxes'].shape[0]):
        frame_boxes = input_dict['gt_boxes'][frame_idx]
        candidate = {
            'points': all_points[all_points[:, 0] == frame_idx][:, 1:],
            'frame_id': input_dict['frame_id'][frame_idx],
            'gt_boxes': frame_boxes[:, :7],
            'gt_classes': frame_boxes[:, -1],
        }
        if has_beams:
            beams = input_dict['beam_labels']
            candidate['beam_labels'] = beams[beams[:, 0] == frame_idx][:, 1:].squeeze(1)

        # Drop entries that ended up as None.
        frame = {name: value for name, value in candidate.items() if value is not None}
        if 'beam_labels' in frame:
            assert len(frame['beam_labels']) == len(frame['points'])
            frame['beam_labels'] = frame['beam_labels'].astype(np.int32)

        frames.append(dataset.prepare_data_proto_align(data_dict=frame))

    new_batch = dataset.collate_batch(frames)
    for name, value in input_dict.items():
        if name not in new_batch:
            new_batch[name] = value
    return new_batch

def check_already_exsit_pseudo_label(ps_label_dir, start_epoch):
    """
    Restore pseudo labels into PSEUDO_LABELS when (re)starting self-training.

    If ``start_epoch`` is 0 and ``cfg.SELF_TRAIN.INIT_PS`` names an existing
    file, that file is loaded into PSEUDO_LABELS, and rank 0 also writes the
    labels to ``<ps_label_dir>/ps_label_e0.pkl``. Otherwise the most recently
    modified ``ps_label_e<epoch>.pkl`` in ``ps_label_dir`` whose epoch is
    <= ``start_epoch`` is loaded. Files whose name does not match that pattern
    are ignored.

    Args:
        ps_label_dir: dir holding the pseudo label result pkls.
        start_epoch: epoch training starts (or resumes) from.

    Returns:
        Path of the loaded pickle, or None if nothing was loaded.
    """
    # support init ps_label given by cfg
    if start_epoch == 0 and cfg.SELF_TRAIN.get('INIT_PS', None):
        if os.path.exists(cfg.SELF_TRAIN.INIT_PS):
            with open(cfg.SELF_TRAIN.INIT_PS, 'rb') as f:
                init_ps_label = pkl.load(f)
            PSEUDO_LABELS.update(init_ps_label)

            if cfg.LOCAL_RANK == 0:
                ps_path = os.path.join(ps_label_dir, "ps_label_e0.pkl")
                with open(ps_path, 'wb') as f:
                    # PSEUDO_LABELS may be a Manager proxy: pickling the proxy
                    # stores a reference to the manager, not the labels, so
                    # dump a plain-dict snapshot instead.
                    pkl.dump(PSEUDO_LABELS.copy(), f)

            return cfg.SELF_TRAIN.INIT_PS

    ps_label_list = glob.glob(os.path.join(ps_label_dir, 'ps_label_e*.pkl'))
    ps_label_list.sort(key=os.path.getmtime, reverse=True)
    for cur_pkl in ps_label_list:
        # Match only the file name so directory names cannot confuse the parse.
        match = re.fullmatch(r'ps_label_e(\d+)\.pkl', os.path.basename(cur_pkl))
        if match is None or int(match.group(1)) > start_epoch:
            continue

        with open(cur_pkl, 'rb') as f:
            latest_ps_label = pkl.load(f)
        PSEUDO_LABELS.update(latest_ps_label)
        return cur_pkl

    return None


def _append_roi_features(bank, roi_feats, roi_labels, neg_feats):
    """
    Append one batch of filtered RoI features to a per-class chunk list.

    ``bank`` maps ``str(class_id)`` (1-based, matching ``roi_labels``) to a list
    of CPU tensors; key ``str(len(cfg.CLASS_NAMES) + 1)`` collects ``neg_feats``
    (background features).
    """
    num_classes = len(cfg.CLASS_NAMES)
    for cls_id in range(1, num_classes + 1):
        bank.setdefault(str(cls_id), []).append(roi_feats[roi_labels == cls_id].cpu())
    bank.setdefault(str(num_classes + 1), []).append(neg_feats.cpu())


def _merge_feature_chunks(chunks_per_key, out_bank):
    """Concatenate each chunk list once and store the result in ``out_bank``."""
    for key, chunks in chunks_per_key.items():
        out_bank[key] = torch.cat(chunks, dim=0)


def _all_gather_feature_bank(bank):
    """Replace every class entry of ``bank`` with its concatenation over all ranks."""
    gathered = commu_utils.all_gather(bank)
    merged = {}
    for cls_id in range(1, len(cfg.CLASS_NAMES) + 2):
        key = str(cls_id)
        merged[key] = torch.cat([part[key] for part in gathered], dim=0)
    bank.update(merged)


def _broadcast_prototype_manager(prototype_manager):
    """Broadcast the GMM prototypes/covariances/weights/precision factors from rank 0, in place."""
    for key in prototype_manager.gmm_prototypes.keys():
        dist.broadcast(prototype_manager.gmm_prototypes[key].data, src=0)
        dist.broadcast(prototype_manager.gmm_covariances[key].data, src=0)
        dist.broadcast(prototype_manager.gmm_weights[key].data, src=0)
        dist.broadcast(prototype_manager.gmm_precisions_cholesky_[key].data, src=0)


def _apply_progressive_sampling(cur_epoch):
    """
    Progressive sampling schedule: grow PS_SAMPLING groups by PS_GROW and
    shrink the target gt_sampling groups by GT_REDUCE (floored at 0).
    """
    prog_cfg = cfg.SELF_TRAIN.get('PROGRESSIVE_SAMPLING', None)
    if not (prog_cfg and prog_cfg.ENABLE and cur_epoch != cfg.OPTIMIZATION.NUM_EPOCHS):
        return
    gt_reduce = prog_cfg.GT_REDUCE
    ps_grow = prog_cfg.PS_GROW
    if cfg.SELF_TRAIN.get('PS_SAMPLING', None):
        for k in cfg.SELF_TRAIN.PS_SAMPLING.SAMPLE_GROUPS:
            cfg.SELF_TRAIN.PS_SAMPLING.SAMPLE_GROUPS[k] += ps_grow
    augmentor_cfg = cfg.DATA_CONFIG_TAR.DATA_AUGMENTOR
    first_aug = augmentor_cfg.AUG_CONFIG_LIST[0]
    if first_aug.NAME == 'gt_sampling' and 'gt_sampling' not in augmentor_cfg.DISABLE_AUG_LIST:
        new_sample_groups = []
        for group in first_aug.SAMPLE_GROUPS:
            new_num = int(group.split(":")[-1]) - gt_reduce
            new_sample_groups.append(group.split(":")[0] + ':' + (str(0) if new_num < 0 else str(new_num)))
        first_aug.SAMPLE_GROUPS = new_sample_groups


def save_pseudo_label_epoch(model, val_loader, source_loader, rank, leave_pbar, ps_label_dir, cur_epoch, train_model):
    """
    Generate pseudo labels for one epoch and (re)fit the prototype GMMs.

    For every target batch of ``val_loader`` the model runs with
    ``collect_features_only=True``. Its predictions go through
    ``save_pseudo_label_batch``, and its score/NMS-filtered RoI head features
    are collected per class. One source batch (re-processed by
    ``post_data_augmentation``) is consumed per target batch to collect source
    features the same way. The banks are merged across ranks, and rank 0 fits
    the GMMs via ``prototype_manager.gmm_update``. The parameters are then
    broadcast, and ``gather_and_dump_pseudo_label_result`` writes this epoch's
    pseudo labels.

    Args:
        model: model to predict pseudo labels and features with.
        val_loader: target-domain data loader to predict pseudo labels on.
        source_loader: source-domain data loader for source features.
        rank: process rank; only rank 0 shows progress, fits GMMs and prints.
        leave_pbar: tqdm ``leave`` flag.
        ps_label_dir: dir to save pseudo labels to.
        cur_epoch: current epoch.
        train_model: model holding ``prototype_manager`` (under ``.module`` when
            torch.distributed is initialized).
    """
    val_dataloader_iter = iter(val_loader)
    total_it_each_epoch = len(val_loader)
    src_dataloader_iter = iter(source_loader)

    if dist.is_initialized():
        prototype_manager = train_model.module.prototype_manager
    else:
        prototype_manager = train_model.prototype_manager

    pbar = None
    if rank == 0:
        pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar,
                         desc='generate_ps_e%d' % cur_epoch, dynamic_ncols=True)

    pos_ps_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))
    ign_ps_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))

    use_dsnorm = cfg.SELF_TRAIN.get('DSNORM', None)
    if use_dsnorm:
        model.apply(set_ds_target)

    model.eval()
    if rank == 0:
        print('==> Start collecting features for prototype alignment')

    # Per-class lists of feature chunks; concatenated once after the loop.
    tar_chunks, src_chunks = {}, {}
    for cur_it in range(total_it_each_epoch):
        try:
            target_batch = next(val_dataloader_iter)
        except StopIteration:
            # Rebind the iterator actually used by the loop.
            val_dataloader_iter = iter(val_loader)
            target_batch = next(val_dataloader_iter)

        # generate pseudo labels and collect target features
        with torch.no_grad():
            if isinstance(target_batch, list):
                target_batch = target_batch[-1]
            target_batch['collect_features_only'] = True
            if use_dsnorm:
                model.apply(set_ds_target)

            load_data_to_gpu(target_batch)
            pred_dicts, _, target_batch_dict = model(target_batch)

            _, _, tgt_labels, tgt_feats, tgt_neg_feats = filter_pseudo_label_by_score(
                target_batch_dict['rois'], target_batch_dict['roi_scores'],
                target_batch_dict['roi_labels'], target_batch_dict['roi_head_features']
            )
            _append_roi_features(tar_chunks, tgt_feats, tgt_labels, tgt_neg_feats)

        pos_ps_batch_nmeters, ign_ps_batch_nmeters = save_pseudo_label_batch(
            target_batch, pred_dicts=pred_dicts,
            need_update=(cfg.SELF_TRAIN.get('MEMORY_ENSEMBLE', None) and
                         cfg.SELF_TRAIN.MEMORY_ENSEMBLE.ENABLED and
                         cur_epoch > 0),
            val_loader=val_loader,
            model=model
        )

        # log to console
        pos_ps_nmeter.update(pos_ps_batch_nmeters)
        ign_ps_nmeter.update(ign_ps_batch_nmeters)
        disp_dict = {'pos_ps_box': pos_ps_nmeter.aggregate_result(),
                     'ign_ps_box': ign_ps_nmeter.aggregate_result()}
        if rank == 0:
            pbar.update()
            pbar.set_postfix(disp_dict)
            pbar.refresh()

        # collect source features
        try:
            source_batch = next(src_dataloader_iter)
        except StopIteration:
            src_dataloader_iter = iter(source_loader)
            source_batch = next(src_dataloader_iter)
        source_batch = post_data_augmentation(source_batch, source_loader)  # for RBRS and ROS
        source_batch['collect_features_only'] = True

        if use_dsnorm:
            model.apply(set_ds_source)
        with torch.no_grad():
            load_data_to_gpu(source_batch)
            source_batch_dict = model(source_batch)[-1]
            _, _, src_labels, src_feats, src_neg_feats = filter_pseudo_label_by_score(
                source_batch_dict['rois'], source_batch_dict['roi_scores'],
                source_batch_dict['roi_labels'], source_batch_dict['roi_head_features']
            )
            _append_roi_features(src_chunks, src_feats, src_labels, src_neg_feats)

    _merge_feature_chunks(tar_chunks, FEATURES_BANK_TMP_TAR)
    _merge_feature_chunks(src_chunks, FEATURES_BANK_TMP_SRC)
    del tar_chunks, src_chunks

    if rank == 0:
        pbar.close()
        # NOTE(review): this cfg schedule only runs on rank 0; confirm other
        # ranks do not need the updated sampling config.
        _apply_progressive_sampling(cur_epoch)

    # merge src and tgt features across ranks
    commu_utils.synchronize()
    if dist.is_initialized():
        _all_gather_feature_bank(FEATURES_BANK_TMP_TAR)
        _all_gather_feature_bank(FEATURES_BANK_TMP_SRC)
    commu_utils.synchronize()

    FEATURES_BANK.clear()
    # Store snapshots so clearing the scratch dicts cannot affect the bank.
    FEATURES_BANK['source'] = dict(FEATURES_BANK_TMP_SRC)
    FEATURES_BANK['target'] = dict(FEATURES_BANK_TMP_TAR)
    FEATURES_BANK_TMP_SRC.clear()
    FEATURES_BANK_TMP_TAR.clear()
    commu_utils.synchronize()
    if rank == 0:
        print('==> Finish collecting features for prototype alignment')
        if not cfg.SELF_TRAIN.ProtoAlign.get('use_backgroud', False):
            # remove background features
            bg_key = str(len(cfg.CLASS_NAMES) + 1)
            for domain in ('source', 'target'):
                # FEATURES_BANK is a Manager dict: indexing returns a copy, so
                # the modified copy must be written back to take effect.
                domain_bank = FEATURES_BANK[domain]
                if bg_key in domain_bank:
                    domain_bank.pop(bg_key)
                    FEATURES_BANK[domain] = domain_bank

        prototype_manager.gmm_update(FEATURES_BANK)
    commu_utils.synchronize()

    if dist.is_initialized():
        if rank == 0:
            print("Broadcasting prototype manager parameters from rank 0")
        commu_utils.synchronize()
        _broadcast_prototype_manager(prototype_manager)
        commu_utils.synchronize()

    prototype_manager.init = True
    FEATURES_BANK.clear()
    FEATURES_BANK_TMP_SRC.clear()
    FEATURES_BANK_TMP_TAR.clear()
    torch.cuda.empty_cache()  # release cached GPU memory
    if rank == 0:
        print('==> Finish broadcasting prototype manager parameters')
        print('==> Prototype Feature Counts: {}'.format(prototype_manager.feature_counts))

    gather_and_dump_pseudo_label_result(rank, ps_label_dir, cur_epoch)


def _dump_ps_sampling_objects(ps_label_dir, cur_epoch):
    """
    Write per-class positive pseudo-label boxes and points from
    NEW_PSEUDO_LABELS to ``ps_box_e<epoch>.pkl`` / ``ps_point_e<epoch>.pkl``.
    Boxes are optionally OBC-based diversity down-sampled first, in which case
    the related-box counts also go to ``ps_diverse_e<epoch>.pkl``.
    Called on rank 0 only.
    """
    ps_point_path = os.path.join(ps_label_dir, "ps_point_e{}.pkl".format(cur_epoch))
    ps_box_path = os.path.join(ps_label_dir, "ps_box_e{}.pkl".format(cur_epoch))
    class_names = cfg.CLASS_NAMES
    use_diversity = cfg.SELF_TRAIN.get('DIVERSITY_SAMPLING', None)
    ps_boxes_dict = {class_name: [] for class_name in class_names}
    ps_points_dict = {class_name: [] for class_name in class_names}
    ps_related_counts_dict = {class_name: [] for class_name in class_names} if use_diversity else None

    for frame in NEW_PSEUDO_LABELS.values():
        ps_boxes = frame['gt_boxes']
        if 'quality_mask' in frame:
            quality_mask = frame['quality_mask'].astype(np.bool_)
        else:
            quality_mask = np.ones(ps_boxes.shape[0]).astype(np.bool_)
        # column -2 holds the label; only positive labels are kept
        pos_ps_mask = ps_boxes[:, -2] > 0
        quality_mask = quality_mask[pos_ps_mask]
        ps_boxes = ps_boxes[pos_ps_mask]
        ps_labels = ps_boxes[:, -2]
        pos_points = [frame['gt_points'][i] for i in range(pos_ps_mask.shape[0]) if pos_ps_mask[i]]
        if use_diversity:
            related_boxes_count = frame['related_box_count'][pos_ps_mask]

        # quality_mask is already restricted to the positive boxes
        remain_points = [pos_points[i] for i in range(quality_mask.shape[0]) if quality_mask[i]]
        ps_boxes = ps_boxes[quality_mask]
        ps_labels = ps_labels[quality_mask]
        if use_diversity:
            related_boxes_count = related_boxes_count[quality_mask]

        ps_names = [class_names[int(class_id - 1)] for class_id in ps_labels.tolist()]
        for inx, name in enumerate(ps_names):
            ps_boxes_dict[name].append(ps_boxes[inx])
            ps_points_dict[name].append(remain_points[inx])
            if use_diversity:
                ps_related_counts_dict[name].append(related_boxes_count[inx])

    # Print how many ps-labels are available for sampling
    print('Distribution of PS-Labels for sampling at epoch {}:'.format(cur_epoch))
    for cls in ps_boxes_dict.keys():
        print('# of ps labels of class {} is {}'.format(cls, len(ps_boxes_dict[cls])))
        key = 'PS-Label before div-sampling/' + cls
        wandb.log({key: len(ps_boxes_dict[cls])})

    # ------------------ OBC-based Diversity Down-sampling ------------------
    if use_diversity:
        import scipy.stats
        ps_diverse_path = os.path.join(ps_label_dir, "ps_diverse_e{}.pkl".format(cur_epoch))
        for class_name in cfg.SELF_TRAIN.DIVERSITY_SAMPLING.SAMPLE_CLASSES:
            if len(ps_boxes_dict[class_name]) == 0:
                # np.stack cannot stack an empty list; nothing to down-sample.
                continue
            ps_boxes = np.stack(ps_boxes_dict[class_name])
            related_box_counts = np.stack(ps_related_counts_dict[class_name])
            # sample inversely to the KDE density to favour rare configurations
            kde = scipy.stats.gaussian_kde(related_box_counts.T)
            p = 1 / kde.pdf(related_box_counts.T)
            p /= np.sum(p)
            sample_size = int(len(related_box_counts) / cfg.SELF_TRAIN.DIVERSITY_SAMPLING.DOWNSAMPLE_RATE)
            sample_idx = np.random.choice(np.arange(len(related_box_counts)), size=sample_size, replace=False, p=p)
            sampled_boxes = ps_boxes[sample_idx].astype(np.float32)[:, :7]  # (ps box num, 7)
            sampled_pnts = [ps_points_dict[class_name][i] for i in sample_idx]

            ps_boxes_dict[class_name] = sampled_boxes
            ps_points_dict[class_name] = sampled_pnts
            ps_related_counts_dict[class_name] = related_box_counts[sample_idx]

        with open(ps_diverse_path, 'wb') as f:
            pkl.dump(ps_related_counts_dict, f)
        print('Distribution of PS-Labels for DIVERSE sampling at epoch {}:'.format(cur_epoch))
        for cls in ps_boxes_dict.keys():
            print('# of ps labels of class {} is {}'.format(cls, len(ps_boxes_dict[cls])))
            key = 'PS-Label after div-sampling/' + cls
            wandb.log({key: len(ps_boxes_dict[cls])})

    with open(ps_box_path, 'wb') as f:
        pkl.dump(ps_boxes_dict, f)
    with open(ps_point_path, 'wb') as f:
        pkl.dump(ps_points_dict, f)


def _update_bda_distribution():
    """
    Blend the label distribution of NEW_PSEUDO_LABELS into
    ``cfg.SELF_TRAIN.BDA.DIST``: ``alpha * old + (1 - alpha) * smoothed_new``.
    """
    old_dist = cfg.SELF_TRAIN.BDA.DIST
    # NOTE(review): three bins are hard-coded; a label outside 1..3 raises KeyError.
    distribution = {0: 0, 1: 0, 2: 0}
    for value in NEW_PSEUDO_LABELS.values():
        for box in value['gt_boxes']:
            if box[-2] > -1:
                distribution[box[-2] - 1] += 1
    distribution = np.array(list(distribution.values()))

    # additive smoothing, then normalise
    epsilon = 1e-6
    distribution = (distribution + epsilon) / (np.sum(distribution) + epsilon * len(distribution))

    alpha = cfg.SELF_TRAIN.BDA.ALPHA
    distribution = alpha * np.array(old_dist) + (1 - alpha) * distribution
    cfg.SELF_TRAIN.BDA.DIST = list(distribution)


def _update_pls_stats(ps_label_dir, cur_epoch):
    """
    Set the target object-size residual (``SIZE_RES`` of AUG_CONFIG_LIST[1]) to
    the mean (l, w, h) of this epoch's dumped pseudo boxes minus
    ``cfg.SELF_TRAIN.PLS``, and swap 'normalize_object_size' for
    'random_object_scaling' in the disabled augmentations.
    """
    stat_subjects = ["length", "width", "height"]
    old_stats = np.array([cfg.SELF_TRAIN.PLS[dim] for dim in stat_subjects])

    # Read the file rank 0 dumped so every rank computes identical statistics.
    ps_path = os.path.join(ps_label_dir, "ps_label_e{}.pkl".format(cur_epoch))
    with open(ps_path, 'rb') as f:
        data = pkl.load(f)
    gt_boxes_list = []
    for frame in data.values():
        gt_boxes_list.extend(frame['gt_boxes'][:, 3:6])  # l, w, h
    new_stats = np.mean(gt_boxes_list, axis=0).reshape(1, 3)
    new_scale = new_stats - old_stats

    disable_list = cfg.DATA_CONFIG.DATA_AUGMENTOR.DISABLE_AUG_LIST
    if 'normalize_object_size' in disable_list:
        disable_list.remove('normalize_object_size')
        if 'random_object_scaling' not in disable_list:
            disable_list.append('random_object_scaling')
    cfg.DATA_CONFIG.DATA_AUGMENTOR.AUG_CONFIG_LIST[1]['SIZE_RES'] = new_scale.squeeze().tolist()


def gather_and_dump_pseudo_label_result(rank, ps_label_dir, cur_epoch):
    """
    Merge NEW_PSEUDO_LABELS across ranks, dump it on rank 0, run the optional
    PS-sampling export / BDA / PLS updates, then move NEW_PSEUDO_LABELS into
    PSEUDO_LABELS.

    Args:
        rank: process rank; files are written by rank 0 only.
        ps_label_dir: dir to write ``ps_label_e<epoch>.pkl`` (and sampling pkls) to.
        cur_epoch: epoch used in the file names.
    """
    commu_utils.synchronize()

    if dist.is_initialized():
        part_pseudo_labels_list = commu_utils.all_gather(NEW_PSEUDO_LABELS)

        new_pseudo_label_dict = {}
        for pseudo_labels in part_pseudo_labels_list:
            new_pseudo_label_dict.update(pseudo_labels)

        NEW_PSEUDO_LABELS.update(new_pseudo_label_dict)

    ps_sampling_enabled = cfg.SELF_TRAIN.get('PS_SAMPLING', None) and cfg.SELF_TRAIN.PS_SAMPLING.ENABLE
    if ps_sampling_enabled:
        cfg.SELF_TRAIN.PS_SAMPLING.PS_OBJECT_PATH = ps_label_dir

    # dump new pseudo label to given dir
    if rank == 0:
        ps_path = os.path.join(ps_label_dir, "ps_label_e{}.pkl".format(cur_epoch))
        with open(ps_path, 'wb') as f:
            pkl.dump(NEW_PSEUDO_LABELS, f)

        if ps_sampling_enabled:
            _dump_ps_sampling_objects(ps_label_dir, cur_epoch)

    if cfg.SELF_TRAIN.get('BDA', None):
        _update_bda_distribution()

    # ensure rank 0 has finished writing before any rank reads the pkl
    commu_utils.synchronize()

    if cfg.SELF_TRAIN.get('PLS', None):
        _update_pls_stats(ps_label_dir, cur_epoch)

    commu_utils.synchronize()
    PSEUDO_LABELS.clear()
    PSEUDO_LABELS.update(NEW_PSEUDO_LABELS)
    NEW_PSEUDO_LABELS.clear()


def _concat_box_label_score(boxes, labels, scores):
    """Stack boxes, labels and scores column-wise into one pseudo-label array.

    Args:
        boxes: (N, C) box parameters.
        labels: (N,) class labels (negative values mark ignored boxes).
        scores: (N,) confidence scores.

    Returns:
        (N, C + 2) array laid out as ``[box params..., label, score]``.
    """
    return np.concatenate((boxes, labels.reshape(-1, 1), scores.reshape(-1, 1)), axis=1)


def _filter_by_aug_consistency(input_dict, b_idx, pred_boxes, pred_labels, pred_scores,
                               pred_cls_scores, pred_iou_scores, val_loader, model,
                               high_score_thresh=0.99, match_iou_thresh=0.9):
    """Keep one frame's predictions that are re-detected on an augmented view.

    The frame's points are scaled by a random factor in [1.00, 1.05] and a random
    subset of beams is dropped. The model is then re-run on that view. Predictions
    with score >= ``high_score_thresh`` are always kept. Lower-scoring ones are kept
    only when some re-inferred box overlaps them with BEV IoU > ``match_iou_thresh``.

    Returns:
        gt_box: (K, C + 2) kept ``[box params..., label, score]``. Matched lower-score
            boxes come first, then the high-score boxes, each group in original order.
        cls_scores, iou_scores: the per-box scores of the kept boxes, same order.

    Raises:
        ValueError: if ``input_dict`` carries no 'beam_labels'.
    """
    from pcdet.ops.iou3d_nms import iou3d_nms_utils

    # Fall back to the detection score when the head provides no separate scores.
    if pred_cls_scores is None:
        pred_cls_scores = pred_scores
    if pred_iou_scores is None:
        pred_iou_scores = pred_scores

    # Column 0 of input_dict['points'] is the batch index; the rest is the point.
    frame_points = input_dict['points'][input_dict['points'][:, 0] == b_idx][:, 1:]
    points = frame_points.cpu().numpy().copy()

    # Global scaling: only the coordinates are scaled, not per-point features such
    # as intensity (assumes columns 0-2 are xyz, the OpenPCDet convention).
    noise_scale = np.random.uniform(1.00, 1.05)
    points[:, :3] *= noise_scale

    # Random beam dropout, configured by the first target augmentation entry.
    beam_cfg = cfg.DATA_CONFIG_TAR.DATA_AUGMENTOR.AUG_CONFIG_LIST[0]
    if 'beam_labels' not in input_dict:
        raise ValueError("SELF_TRAIN.AUG requires 'beam_labels' in the batch "
                         "for random beam dropout")
    beam_labels = input_dict['beam_labels']
    beam_label = beam_labels[beam_labels[:, 0] == b_idx][:, 1].cpu().numpy().astype(np.int32)
    beam_keep = np.random.rand(beam_cfg['BEAM']) < beam_cfg['BEAM_PROB']
    points = points[beam_keep[beam_label]]

    # Re-run the detector on the augmented single-frame batch.
    tgt_dataset = val_loader.dataset
    aug_data = tgt_dataset.prepare_data(data_dict={'points': points, 'frame_id': 0})
    aug_data = tgt_dataset.collate_batch([aug_data])
    load_data_to_gpu(aug_data)
    with torch.no_grad():
        aug_pred = model(aug_data)[0][0]

    # Map re-inferred boxes back to the original frame: centers and sizes were
    # scaled, the heading angle is scale-invariant and must not be divided.
    aug_boxes = aug_pred['pred_boxes'].cpu().numpy()[:, :7].copy()
    aug_boxes[:, :6] /= noise_scale

    middle_mask = pred_scores < high_score_thresh
    iou_matrix = iou3d_nms_utils.boxes_bev_iou_cpu(
        common_utils.check_numpy_to_torch(pred_boxes[middle_mask][:, :7])[0],
        common_utils.check_numpy_to_torch(aug_boxes)[0],
    ).cpu().numpy()
    matched = (iou_matrix > match_iou_thresh).any(axis=1)

    keep = np.concatenate((np.flatnonzero(middle_mask)[matched],
                           np.flatnonzero(~middle_mask)))
    gt_box = _concat_box_label_score(pred_boxes[keep], pred_labels[keep], pred_scores[keep])
    return gt_box, pred_cls_scores[keep], pred_iou_scores[keep]


def _split_by_score_thresh(pred_boxes, pred_labels, pred_scores,
                           pred_cls_scores, pred_iou_scores):
    """Apply SELF_TRAIN.NEG_THRESH / SCORE_THRESH to one frame's predictions.

    Boxes scoring below their class's NEG_THRESH are removed. Unless SOFT_LABEL is
    set, boxes below their class's SCORE_THRESH are kept but marked as ignored by
    negating their label (in place), and are also returned separately as
    "unreliable" boxes.

    Returns:
        (gt_box, unreliable_gt_box, pred_labels, pred_scores, pred_cls_scores,
        pred_iou_scores). ``unreliable_gt_box`` is None when SOFT_LABEL is enabled.
        The score arrays are the filtered ones (or None where not provided).
    """
    neg_thresh = cfg.SELF_TRAIN.get('NEG_THRESH', None)
    if neg_thresh:
        remain_mask = pred_scores >= np.array(neg_thresh)[pred_labels - 1]
        pred_boxes = pred_boxes[remain_mask]
        pred_labels = pred_labels[remain_mask]
        pred_scores = pred_scores[remain_mask]
        if pred_cls_scores is not None:
            pred_cls_scores = pred_cls_scores[remain_mask]
        if pred_iou_scores is not None:
            pred_iou_scores = pred_iou_scores[remain_mask]

    unreliable_gt_box = None
    if not cfg.SELF_TRAIN.get('SOFT_LABEL', None):
        labels_ignore_scores = np.array(cfg.SELF_TRAIN.SCORE_THRESH)[pred_labels - 1]
        ignore_mask = np.logical_and(pred_scores < labels_ignore_scores, pred_labels > 0)
        pred_labels[ignore_mask] = -pred_labels[ignore_mask]
        unreliable_gt_box = _concat_box_label_score(pred_boxes[ignore_mask],
                                                    pred_labels[ignore_mask],
                                                    pred_scores[ignore_mask])

    gt_box = _concat_box_label_score(pred_boxes, pred_labels, pred_scores)
    return gt_box, unreliable_gt_box, pred_labels, pred_scores, pred_cls_scores, pred_iou_scores


def save_pseudo_label_batch(input_dict,
                            pred_dicts=None,
                            need_update=True,
                            val_loader=None,
                            model=None):
    """
    Build pseudo labels for one batch from the given predictions and stage them
    in NEW_PSEUDO_LABELS, keyed by frame id.

    Predictions are filtered according to cfg.SELF_TRAIN:
      * NO_FILTER: every prediction is kept as-is.
      * AUG: predictions are kept when confirmed by re-running ``model`` on a
        scaled, beam-dropped copy of the frame (needs ``val_loader``, ``model``
        and 'beam_labels' in ``input_dict``).
      * otherwise: boxes below NEG_THRESH are dropped and, unless SOFT_LABEL is
        set, boxes below SCORE_THRESH are marked ignored by negating their label.

    Args:
        input_dict: batch data read from dataloader (reads 'frame_id', 'points'
            whose column 0 is the batch index, and for AUG 'beam_labels').
        pred_dicts: list with one prediction dict per frame ('pred_boxes',
            'pred_labels', 'pred_scores', optionally 'pred_cls_scores' and
            'pred_iou_scores'). A frame without 'pred_boxes' gets empty labels.
        need_update: Bool. If True, merge with the frame's existing entry in
            PSEUDO_LABELS through the configured memory ensemble.
        val_loader: loader whose dataset prepares the augmented view (AUG only).
        model: detector re-run on the augmented view (AUG only).

    Returns:
        (pos_ps_nmeter, ign_ps_nmeter): per-class meters of positive and ignored
        pseudo boxes.

    Raises:
        ValueError: if ``pred_dicts`` is None, or AUG is enabled and the batch
            carries no 'beam_labels'.
    """
    if pred_dicts is None:
        raise ValueError('save_pseudo_label_batch requires pred_dicts')

    pos_ps_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))
    ign_ps_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))

    for b_idx, pred_dict in enumerate(pred_dicts):
        frame_id = input_dict['frame_id'][b_idx]
        # Reset every per-frame value so nothing leaks from the previous frame.
        pred_cls_scores = pred_iou_scores = None
        unreliable_gt_box = None

        if 'pred_boxes' in pred_dict:
            # Exist predicted boxes passing self-training score threshold
            pred_boxes = pred_dict['pred_boxes'].detach().cpu().numpy()
            pred_labels = pred_dict['pred_labels'].detach().cpu().numpy()
            pred_scores = pred_dict['pred_scores'].detach().cpu().numpy()
            if 'pred_cls_scores' in pred_dict:
                pred_cls_scores = pred_dict['pred_cls_scores'].detach().cpu().numpy()
            if 'pred_iou_scores' in pred_dict:
                pred_iou_scores = pred_dict['pred_iou_scores'].detach().cpu().numpy()

            if cfg.SELF_TRAIN.get('NO_FILTER', None):
                gt_box = _concat_box_label_score(pred_boxes, pred_labels, pred_scores)
            elif cfg.SELF_TRAIN.get('AUG', None):
                gt_box, pred_cls_scores, pred_iou_scores = _filter_by_aug_consistency(
                    input_dict, b_idx, pred_boxes, pred_labels, pred_scores,
                    pred_cls_scores, pred_iou_scores, val_loader, model)
            else:
                (gt_box, unreliable_gt_box, pred_labels, pred_scores,
                 pred_cls_scores, pred_iou_scores) = _split_by_score_thresh(
                    pred_boxes, pred_labels, pred_scores, pred_cls_scores, pred_iou_scores)
        else:
            # no predicted boxes passes self-training score threshold
            pred_labels = np.zeros(0, dtype=np.int64)
            pred_scores = np.zeros(0, dtype=np.float32)
            gt_box = np.zeros((0, 9), dtype=np.float32)

        # Modes that do not split off unreliable boxes store an empty set.
        if unreliable_gt_box is None:
            unreliable_gt_box = np.zeros((0, gt_box.shape[1]), dtype=gt_box.dtype)

        if pred_cls_scores is None:
            pred_cls_scores = pred_scores
        if pred_iou_scores is None:
            pred_iou_scores = pred_scores

        gt_infos = {
            'gt_boxes': gt_box,
            'unreliable_boxes': unreliable_gt_box,
            'cls_scores': pred_cls_scores,
            'iou_scores': pred_iou_scores,
            'memory_counter': np.zeros(gt_box.shape[0]),
        }

        # record pseudo label to pseudo label dict
        if need_update:
            # Unreliable boxes are kept out of the memory ensemble: the current
            # ones are set aside and reattached, the previous ones are discarded.
            unreliable_boxes = gt_infos.pop('unreliable_boxes')
            previous_gt_infos = copy.deepcopy(PSEUDO_LABELS[frame_id])
            previous_gt_infos.pop('unreliable_boxes', None)

            ensemble_func = getattr(memory_ensemble_utils, cfg.SELF_TRAIN.MEMORY_ENSEMBLE.NAME)
            gt_infos = memory_ensemble_utils.memory_ensemble(
                previous_gt_infos, gt_infos,
                cfg.SELF_TRAIN.MEMORY_ENSEMBLE, ensemble_func
            )
            gt_infos['unreliable_boxes'] = unreliable_boxes

        '''--------- Target ReD Sampling ---------'''
        ps_sampling = cfg.SELF_TRAIN.get('PS_SAMPLING', None)
        if ps_sampling and ps_sampling.ENABLE:
            from pcdet.ops.roiaware_pool3d.roiaware_pool3d_utils import points_in_boxes_cpu
            gt_box = gt_infos['gt_boxes']
            batch_points = \
                input_dict['points'][
                    input_dict['points'][:, 0] == b_idx][:, 1:].cpu().numpy()
            internal_pnts_mask = points_in_boxes_cpu(batch_points, gt_box[:, :7])
            gt_infos['gt_points'] = [batch_points[internal_pnts_mask[msk_idx] == 1]
                                     for msk_idx in range(gt_box.shape[0])]

        if cfg.get('SELF_TRAIN', None) and cfg.SELF_TRAIN.get('SOFT_LABEL', None):
            # NOTE(review): pred_scores is modified after gt_infos was built, so this
            # only reaches the stored labels where cls_scores / iou_scores alias
            # pred_scores (the fallback above) - confirm this is the intent.
            labels_ignore_scores = np.array(cfg.SELF_TRAIN.SCORE_THRESH)[pred_labels - 1]
            pred_scores[pred_scores > labels_ignore_scores] = 1

        # counter the number of ignore boxes for each class (column 7 holds the
        # label for 7-parameter boxes; ignored boxes carry a negated label)
        ps_labels = gt_infos['gt_boxes'][:, 7]
        for cls_idx in range(ign_ps_nmeter.n):
            num_total_boxes = (np.abs(ps_labels) == (cls_idx + 1)).sum()
            ign_ps_nmeter.update((ps_labels == -(cls_idx + 1)).sum(), index=cls_idx)
            pos_ps_nmeter.update(num_total_boxes - ign_ps_nmeter.meters[cls_idx].val,
                                 index=cls_idx)

        NEW_PSEUDO_LABELS[frame_id] = gt_infos

    return pos_ps_nmeter, ign_ps_nmeter


def load_ps_label(frame_id):
    """
    Return the stored pseudo labels of one frame from PSEUDO_LABELS.

    :param frame_id: key of the frame's pseudo label in PSEUDO_LABELS
    :return gt_box: pseudo-label boxes laid out as [box params..., label, score],
        i.e. (N, 9) [x, y, z, l, w, h, ry, label, score] for 7-parameter boxes;
        a negative label marks an ignored box
    :return unreliable_gt_box: the entry's 'unreliable_boxes', same layout
    :raises ValueError: if no pseudo label is stored for ``frame_id``
    """
    try:
        # Fetch the entry exactly once: PSEUDO_LABELS is a multiprocessing Manager
        # dict, so every membership test / lookup is a round-trip to the manager
        # process, and each lookup transfers the whole entry.
        ps_info = PSEUDO_LABELS[frame_id]
    except KeyError:
        raise ValueError('Cannot find pseudo label for frame: %s' % frame_id) from None

    return ps_info['gt_boxes'], ps_info['unreliable_boxes']


def load_target(TARGET_PSEUDO_LABELS, num_cls_preds=3):
    """
    Merge the current PSEUDO_LABELS into ``TARGET_PSEUDO_LABELS`` and make the
    merged mapping the active pseudo labels.

    Frames already present in ``TARGET_PSEUDO_LABELS`` keep their target entry.
    Frames only present in PSEUDO_LABELS are copied over; if such an entry has no
    'cls_preds', a zero placeholder of shape (num_boxes, num_cls_preds) is added.

    Args:
        TARGET_PSEUDO_LABELS: mapping frame_id -> pseudo-label dict; updated in place.
        num_cls_preds: width of the zero 'cls_preds' placeholder (default 3, as before).
    """
    for key in list(PSEUDO_LABELS.keys()):
        # Direct membership test: O(1) on dicts and a single call on Manager
        # proxies, instead of materializing the key list for every frame.
        if key in TARGET_PSEUDO_LABELS:
            continue
        entry = PSEUDO_LABELS[key]
        if 'cls_preds' not in entry:
            entry['cls_preds'] = np.zeros((entry['gt_boxes'].shape[0], num_cls_preds))
        # No write-back into PSEUDO_LABELS is needed: it is replaced right below.
        TARGET_PSEUDO_LABELS[key] = entry

    PSEUDO_LABELS.clear()
    PSEUDO_LABELS.update(TARGET_PSEUDO_LABELS)
