import torch
import os
import glob
import tqdm
import numpy as np
import torch.distributed as dist
from torch.nn.utils import clip_grad_norm_
from pcdet.utils import common_utils, commu_utils
from pcdet.utils import self_training_utils
from pcdet.config import cfg
from pcdet.models.model_utils.dsnorm import set_ds_source, set_ds_target

from .train_utils import save_checkpoint, checkpoint_state

from multiprocessing import Manager

# Module-level stores for target-domain pseudo labels, keyed by frame id.
#
# TARGET_PSEUDO_LABELS is a dict proxy served by a multiprocessing Manager that
# is started when this module is imported. train_one_epoch_st only reads from
# it (the pseudo labels kept from earlier passes).
# NOTE(review): a plain dict was used here previously; presumably the proxy is
# needed so other processes can see the labels -- confirm.
TARGET_PSEUDO_LABELS = Manager().dict()

# Plain, process-local dict that train_one_epoch_st fills with the pseudo
# labels built (and merged with the stored ones) during the current pass.
TARGET_PSEUDO_LABELS_TMP = {}
# NOTE(review): nothing in this part of the file copies TARGET_PSEUDO_LABELS_TMP
# back into TARGET_PSEUDO_LABELS; that presumably happens in the epoch-level
# driver -- confirm.

# Alternatives tried earlier: a Manager dict for the temporary store as well,
# and importing TARGET_PSEUDO_LABELS / commu_all_gather from
# pcdet.utils.self_training_utils instead of defining the stores here.

def _consistency_ensemble(gt_infos_a, gt_infos_b, memory_ensemble_cfg):
    """Merge stored pseudo labels with freshly predicted ones (single class).

    Boxes of ``gt_infos_a`` are matched to boxes of ``gt_infos_b`` by 3D IoU.
    Matched boxes are blended with an exponential moving average, unmatched
    stored boxes age through ``memory_counter`` and unmatched new boxes are
    appended.

    Args:
        gt_infos_a: previously stored pseudo labels, a dict with
            gt_boxes: (N, 9) [x, y, z, dx, dy, dz, heading, label, scores]
            cls_scores: (N) or None
            iou_scores: (N) or None
            cls_preds: indexed by box along the first axis
            memory_counter: (N)
            NOTE: the arrays of this dict are updated in place.
        gt_infos_b: pseudo labels of the current prediction, same layout with
            M boxes.
        memory_ensemble_cfg: config providing ``IOU_THRESH`` and
            ``MEMORY_VOTING.IGNORE_THRESH`` / ``MEMORY_VOTING.RM_THRESH``.

    Returns:
        dict with the same keys describing the K merged boxes.
    """
    # Function-scope import, exactly where the original code imported it.
    from pcdet.ops.iou3d_nms import iou3d_nms_utils

    gt_box_a, _ = common_utils.check_numpy_to_torch(gt_infos_a['gt_boxes'])
    gt_box_b, _ = common_utils.check_numpy_to_torch(gt_infos_b['gt_boxes'])
    gt_box_a, gt_box_b = gt_box_a.cuda(), gt_box_b.cuda()

    # These names alias the arrays inside gt_infos_a; the indexed assignments
    # below therefore write into gt_infos_a as well.
    new_gt_box = gt_infos_a['gt_boxes']
    new_cls_scores = gt_infos_a['cls_scores']
    new_iou_scores = gt_infos_a['iou_scores']
    new_cls_preds = gt_infos_a['cls_preds']
    new_memory_counter = gt_infos_a['memory_counter']

    has_cls_scores = gt_infos_a['cls_scores'] is not None
    has_iou_scores = gt_infos_a['iou_scores'] is not None
    has_cls_preds = gt_infos_a['cls_preds'] is not None

    # One side has no boxes: nothing to match.
    if gt_box_b.shape[0] == 0:
        gt_infos_a['memory_counter'] += 1
        return gt_infos_a
    elif gt_box_a.shape[0] == 0:
        return gt_infos_b

    ema_alpha = cfg.SELF_TRAIN.TARGET_LEARNING.EMA_ALPHA

    # Pairwise IoU, rows = stored boxes (a), columns = current boxes (b).
    iou_matrix = iou3d_nms_utils.boxes_iou3d_gpu(gt_box_a[:, :7], gt_box_b[:, :7]).cpu()

    ious, match_idx = torch.max(iou_matrix, dim=1)
    ious, match_idx = ious.numpy(), match_idx.numpy()
    gt_box_a, gt_box_b = gt_box_a.cpu().numpy(), gt_box_b.cpu().numpy()

    # Column 0: index into a, column 1: index of its best match in b.
    match_pairs_idx = np.concatenate((
        np.arange(gt_box_a.shape[0]).reshape(-1, 1),
        match_idx.reshape(-1, 1)), axis=1)

    #########################################################
    # Keep only the pairs whose IoU reaches the threshold
    #########################################################
    iou_mask = (ious >= memory_ensemble_cfg.IOU_THRESH)

    matching_selected = match_pairs_idx[iou_mask]
    idx_a = matching_selected[:, 0]
    idx_b = matching_selected[:, 1]
    gt_box_selected_a = gt_box_a[idx_a]
    gt_box_selected_b = gt_box_b[idx_b]

    # Pairs in which the current prediction is more confident.
    score_mask = gt_box_selected_a[:, 8] < gt_box_selected_b[:, 8]

    # EMA update of the class predictions of matched boxes.
    cls_preds_a = gt_infos_a['cls_preds'][idx_a]
    cls_preds_b = gt_infos_b['cls_preds'][idx_b]
    new_cls_preds[idx_a] = ema_alpha * cls_preds_a + (1 - ema_alpha) * cls_preds_b

    # EMA update of the box geometry (first 7 columns) of matched boxes; the
    # label and score columns of the stored box are kept.
    new_gt_box[idx_a, :7] = ema_alpha * gt_box_selected_a[:, :7] + \
        (1 - ema_alpha) * gt_box_selected_b[:, :7]

    # Where the current prediction scores higher, take its scores and class
    # predictions (this overrides the EMA value written above for those rows).
    better_a = matching_selected[score_mask, 0]
    better_b = matching_selected[score_mask, 1]
    if has_cls_scores:
        new_cls_scores[better_a] = gt_infos_b['cls_scores'][better_b]
    if has_iou_scores:
        new_iou_scores[better_a] = gt_infos_b['iou_scores'][better_b]
    if has_cls_preds:
        new_cls_preds[better_a] = gt_infos_b['cls_preds'][better_b]

    # Matched boxes were seen again: reset their miss counter.
    new_memory_counter[idx_a] = 0

    #######################################################
    # Stored boxes without a match: count the miss
    #######################################################
    disappear_idx = (ious < memory_ensemble_cfg.IOU_THRESH).nonzero()[0]
    new_memory_counter[disappear_idx] += 1

    # Boxes missed IGNORE_THRESH times or more get label -1.
    ignore_mask = new_memory_counter >= memory_ensemble_cfg.MEMORY_VOTING.IGNORE_THRESH
    new_gt_box[ignore_mask, 7] = -1

    # Boxes missed RM_THRESH times or more are dropped.
    remain_mask = new_memory_counter < memory_ensemble_cfg.MEMORY_VOTING.RM_THRESH
    new_gt_box = new_gt_box[remain_mask]
    new_memory_counter = new_memory_counter[remain_mask]
    if has_cls_scores:
        new_cls_scores = new_cls_scores[remain_mask]
    if has_iou_scores:
        new_iou_scores = new_iou_scores[remain_mask]
    if has_cls_preds:
        new_cls_preds = new_cls_preds[remain_mask]

    # Current boxes that overlap no stored box are appended as new boxes.
    ious_b2a, _ = torch.max(iou_matrix, dim=0)
    ious_b2a = ious_b2a.numpy()

    newboxes_idx = (ious_b2a < memory_ensemble_cfg.IOU_THRESH).nonzero()[0]
    if newboxes_idx.shape[0] != 0:
        new_gt_box = np.concatenate((new_gt_box, gt_infos_b['gt_boxes'][newboxes_idx, :]), axis=0)
        if has_cls_scores:
            new_cls_scores = np.concatenate((new_cls_scores, gt_infos_b['cls_scores'][newboxes_idx]), axis=0)
        if has_iou_scores:
            new_iou_scores = np.concatenate((new_iou_scores, gt_infos_b['iou_scores'][newboxes_idx]), axis=0)

        new_cls_preds = np.concatenate((new_cls_preds, gt_infos_b['cls_preds'][newboxes_idx]), axis=0)
        new_memory_counter = np.concatenate(
            (new_memory_counter, gt_infos_b['memory_counter'][newboxes_idx]), axis=0)

    return {
        'gt_boxes': new_gt_box,
        'cls_scores': new_cls_scores if has_cls_scores else None,
        'iou_scores': new_iou_scores if has_iou_scores else None,
        'cls_preds': new_cls_preds,
        'memory_counter': new_memory_counter
    }


def _merge_pseudo_labels_by_class(gt_infos_a, gt_infos_b, memory_ensemble_cfg):
    """Run ``_consistency_ensemble`` separately for every class and concatenate.

    The class of a box is the absolute value of the second-to-last column of
    ``gt_boxes`` (ignored boxes carry a negative label).

    Args:
        gt_infos_a: stored pseudo labels of one frame.
        gt_infos_b: current pseudo labels of the same frame.
        memory_ensemble_cfg: forwarded to ``_consistency_ensemble``.

    Returns:
        dict with the merged pseudo labels; ``gt_infos_a`` itself when neither
        side contains any box.
    """
    classes_a = np.unique(np.abs(gt_infos_a['gt_boxes'][:, -2]))
    classes_b = np.unique(np.abs(gt_infos_b['gt_boxes'][:, -2]))

    # Previously this result was overwritten by an empty dict right afterwards.
    if max(classes_a.shape[0], classes_b.shape[0]) == 0:
        return gt_infos_a

    merged_infos = {}
    for class_id in np.union1d(classes_a, classes_b):
        mask_a = np.abs(gt_infos_a['gt_boxes'][:, -2]) == class_id
        gt_infos_a_i = common_utils.mask_dict(gt_infos_a, mask_a)

        mask_b = np.abs(gt_infos_b['gt_boxes'][:, -2]) == class_id
        gt_infos_b_i = common_utils.mask_dict(gt_infos_b, mask_b)

        class_infos = _consistency_ensemble(gt_infos_a_i, gt_infos_b_i, memory_ensemble_cfg)
        merged_infos = common_utils.concatenate_array_inside_dict(merged_infos, class_infos)

    return merged_infos


def _extract_teacher_pseudo_labels(batch_dict_teacher, index):
    """Turn the teacher outputs of one sample into a pseudo-label dict.

    Reads the ``*_mt`` entries of ``batch_dict_teacher`` for sample ``index``,
    picks the score configured by ``NMS_CONFIG.SCORE_TYPE``, runs
    class-agnostic NMS, applies the per-class ``cfg.SELF_TRAIN.SCORE_THRESH``
    and drops boxes that have a zero size in any dimension.

    Args:
        batch_dict_teacher: dict holding ``rois_mt``, ``roi_scores_mt``,
            ``roi_iou_scores_mt``, ``roi_cls_preds_mt`` and ``roi_labels_mt``.
        index: position of the sample inside the batch.

    Returns:
        dict with ``gt_boxes`` (K, 9) [x, y, z, dx, dy, dz, heading, label,
        score], ``cls_scores``, ``iou_scores``, ``cls_preds`` and a zeroed
        ``memory_counter``.

    Raises:
        ValueError: ``SCORE_TYPE`` is ``'hybrid_iou_cls'`` but no
            ``SCORE_WEIGHTS`` are configured.
        NotImplementedError: unknown ``SCORE_TYPE``.
    """
    # Function-scope import, as in the original code.
    from pcdet.models.model_utils import model_nms_utils

    nms_cfg = cfg.MODEL.POST_PROCESSING.NMS_CONFIG

    pred_boxes = batch_dict_teacher['rois_mt'][index, :, :7]  # [x, y, z, dx, dy, dz, heading]
    cls_scores = torch.sigmoid(batch_dict_teacher['roi_scores_mt'][index]).squeeze()
    iou_scores = torch.sigmoid(batch_dict_teacher['roi_iou_scores_mt'][index]).squeeze()
    cls_preds = batch_dict_teacher['roi_cls_preds_mt'][index]
    label_preds = batch_dict_teacher['roi_labels_mt'][index]

    score_type = nms_cfg.get('SCORE_TYPE', None)
    if score_type is None or score_type == 'iou':
        box_scores = iou_scores
    elif score_type == 'cls':
        box_scores = cls_scores
    elif score_type == 'hybrid_iou_cls':
        score_weights = nms_cfg.get('SCORE_WEIGHTS', None)
        if not score_weights:
            # Used to surface as a NameError on the undefined weights.
            raise ValueError("SCORE_TYPE 'hybrid_iou_cls' requires NMS_CONFIG.SCORE_WEIGHTS")
        weight_iou = score_weights.iou
        weight_cls = score_weights.cls
        # Tolerant comparison: exact float equality can fail for valid weights.
        assert np.isclose(weight_iou + weight_cls, 1)
        box_scores = weight_iou * iou_scores + weight_cls * cls_scores
    else:
        raise NotImplementedError

    # NOTE(review): NMS_THRESH is passed as the *score* threshold here, as in
    # the original code -- confirm that this is intended.
    selected, selected_scores = model_nms_utils.class_agnostic_nms(
        box_scores=box_scores, box_preds=pred_boxes,
        nms_config=nms_cfg,
        score_thresh=nms_cfg.NMS_THRESH
    )

    pred_boxes = pred_boxes[selected].detach().cpu().numpy()
    box_scores = selected_scores.detach().cpu().numpy()
    label_preds = label_preds[selected].detach().cpu().numpy()
    cls_preds = cls_preds[selected].detach().cpu().numpy()

    # Per-class score threshold (labels are 1-based).
    labels_remove_scores = np.array(cfg.SELF_TRAIN.SCORE_THRESH)[label_preds - 1]
    remain_mask = box_scores >= labels_remove_scores
    label_preds = label_preds[remain_mask]
    box_scores = box_scores[remain_mask]
    pred_boxes = pred_boxes[remain_mask]
    cls_preds = cls_preds[remain_mask]

    # Drop boxes with a zero extent in any of dx, dy, dz.
    nonzero_size_mask = (pred_boxes[:, 3:6] != 0).all(1)
    pred_boxes = pred_boxes[nonzero_size_mask]
    box_scores = box_scores[nonzero_size_mask]
    label_preds = label_preds[nonzero_size_mask]
    cls_preds = cls_preds[nonzero_size_mask]

    gt_box = np.concatenate((pred_boxes,
                             label_preds.reshape(-1, 1),
                             box_scores.reshape(-1, 1)), axis=1)

    # Both score entries hold the score that was selected above.
    return {
        'gt_boxes': gt_box,
        'cls_scores': box_scores,
        'iou_scores': box_scores,
        'cls_preds': cls_preds,
        'memory_counter': np.zeros(gt_box.shape[0])
    }


def _has_pseudo_boxes(gt_infos):
    """Return True when ``gt_infos`` is a pseudo-label dict with at least one box."""
    if gt_infos is None:
        return False
    gt_boxes = gt_infos.get('gt_boxes', None)
    return gt_boxes is not None and gt_boxes.shape[0] > 0


def train_one_epoch_st(model, optimizer, source_reader, target_loader, model_func, lr_scheduler,
                       accumulated_iter, optim_cfg, rank, tbar, total_it_each_epoch,
                       dataloader_iter, tb_log=None, leave_pbar=False, ema_model=None, cur_epoch=None):
    """Run one self-training epoch over source and/or target data.

    Per iteration: optional supervised step on a source batch, optional
    self-training step on a target batch (with optional pseudo-label
    construction into ``TARGET_PSEUDO_LABELS_TMP``), gradient clipping,
    optimizer step and, if ``cfg.SELF_TRAIN.AGGREGATION_INTERVAL`` is set, a
    periodic EMA aggregation of the model weights.

    Args:
        model: network being trained.
        optimizer: optimizer; its learning rate is read from ``optimizer.lr``
            or, failing that, from ``optimizer.param_groups[0]['lr']``.
        source_reader: object whose ``read_data()`` yields a source batch.
        target_loader: iterable target-domain data loader.
        model_func: callable ``(model, batch, ...)`` returning
            ``(loss, tb_dict, disp_dict, batch_dict)``.
        lr_scheduler: scheduler stepped with the global iteration counter.
        accumulated_iter: global iteration counter at the start of the epoch.
        optim_cfg: config providing ``GRAD_NORM_CLIP``.
        rank: process rank; only rank 0 drives the progress bars and the
            per-iteration scalar logging.
        tbar: outer (epoch) progress bar.
        total_it_each_epoch: number of iterations to run.
        dataloader_iter: iterator over ``target_loader``; replaced by a fresh
            one when the epoch spans the whole loader or it is exhausted.
        tb_log: optional tensorboard-style writer with ``add_scalar``.
        leave_pbar: forwarded to ``tqdm`` as ``leave``.
        ema_model: unused.
        cur_epoch: epoch index used as the step of the end-of-epoch scalars.

    Returns:
        The updated global iteration counter.
    """
    if total_it_each_epoch == len(target_loader):
        dataloader_iter = iter(target_loader)

    if rank == 0:
        pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)

    ps_bbox_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))
    ign_ps_bbox_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))
    loss_meter = common_utils.AverageMeter()
    st_loss_meter = common_utils.AverageMeter()

    cls_loss_meter = common_utils.AverageMeter()
    reg_loss_meter = common_utils.AverageMeter()

    bda_loss_meter = common_utils.AverageMeter()

    disp_dict = {}

    use_dsnorm = cfg.SELF_TRAIN.get('DSNORM', None)
    target_learning = cfg.SELF_TRAIN.get('TARGET_LEARNING', None)
    use_bda = cfg.SELF_TRAIN.get('BDA', None)

    # Interval (in iterations) between two model aggregations.
    aggregation_interval = cfg.SELF_TRAIN.get('AGGREGATION_INTERVAL', None)
    # Snapshot of the previous aggregated weights. state_dict() hands out
    # tensors that share storage with the live parameters, so they are cloned;
    # without the copy the first aggregation would blend the model with itself.
    if aggregation_interval:
        last_model_state = {key: value.detach().clone() for key, value in model.state_dict().items()}
    else:
        last_model_state = None

    for cur_it in range(total_it_each_epoch):
        lr_scheduler.step(accumulated_iter)

        try:
            cur_lr = float(optimizer.lr)
        except Exception:
            # Optimizers without a usable ``lr`` attribute.
            cur_lr = optimizer.param_groups[0]['lr']

        if tb_log is not None:
            tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter)

        model.train()

        optimizer.zero_grad()
        if cfg.SELF_TRAIN.SRC.USE_DATA:
            # forward source data with labels
            source_batch = source_reader.read_data()

            if use_dsnorm:
                model.apply(set_ds_source)

            if cfg.SELF_TRAIN.SRC.get('SEP_LOSS_WEIGHTS', None):
                source_batch['SEP_LOSS_WEIGHTS'] = cfg.SELF_TRAIN.SRC.SEP_LOSS_WEIGHTS

            loss, tb_dict, disp_dict, _ = model_func(model, source_batch)
            loss = cfg.SELF_TRAIN.SRC.get('LOSS_WEIGHT', 1.0) * loss
            loss.backward()
            loss_meter.update(loss.item())
            disp_dict.update({'loss': "{:.3f}({:.3f})".format(loss_meter.val, loss_meter.avg)})

            # The source pass may be run for its side effects only, in which
            # case its gradients are discarded.
            if not cfg.SELF_TRAIN.SRC.get('USE_GRAD', None):
                optimizer.zero_grad()

        if cfg.SELF_TRAIN.TAR.USE_DATA:
            try:
                target_batch = next(dataloader_iter)
            except StopIteration:
                dataloader_iter = iter(target_loader)
                target_batch = next(dataloader_iter)
                print('new iters')

            if use_dsnorm:
                model.apply(set_ds_target)

            if 'MT' in cfg.DATA_CONFIG_TAR.DATASET:
                # The loader yields a pair of batches; the first one is trained
                # on, the second one is forwarded below to get teacher boxes.
                batch_target1, batch_target2 = target_batch
                target_batch = batch_target1

            if cfg.SELF_TRAIN.TAR.get('SEP_LOSS_WEIGHTS', None):
                target_batch['SEP_LOSS_WEIGHTS'] = cfg.SELF_TRAIN.TAR.SEP_LOSS_WEIGHTS

            # Stored pseudo labels of the current batch; only passed on when
            # every frame of the batch has at least one stored box.
            pseudo_label_target = None
            if target_learning:
                pseudo_label_target = []
                for index in range(target_batch['frame_id'].shape[0]):
                    stored_infos = TARGET_PSEUDO_LABELS.get(target_batch['frame_id'][index], None)
                    if _has_pseudo_boxes(stored_infos):
                        pseudo_label_target.append(stored_infos)
                if len(pseudo_label_target) < target_batch['frame_id'].shape[0]:
                    pseudo_label_target = None

            if use_bda:
                import copy
                pseudo_label_target = copy.deepcopy(target_batch)

            st_loss, st_tb_dict, st_disp_dict, target_batch_dict = model_func(
                model, target_batch, batch_dict_teacher=pseudo_label_target)

            # Build the pseudo labels of this batch from the teacher outputs.
            if target_learning:
                if 'MT' in cfg.DATA_CONFIG_TAR.DATASET:
                    # teacher boxes are predicted in train mode
                    model.train()
                    if use_dsnorm:
                        model.apply(set_ds_target)
                    from pcdet.models import load_data_to_gpu
                    load_data_to_gpu(batch_target2)
                    # NOTE(review): this extra forward runs with autograd
                    # enabled although its outputs are only converted to numpy;
                    # torch.no_grad() may be applicable -- confirm.
                    batch_dict_teacher = model(batch_target2, return_batch_dict=True)
                else:
                    batch_dict_teacher = target_batch_dict
                    batch_dict_teacher['rois_mt'][:, :, :7], _ = common_utils.reverse_augmentation(
                        batch_dict_teacher['rois_mt'][:, :, :7], batch_dict_teacher)

                for index in range(batch_dict_teacher['rois_mt'].shape[0]):
                    gt_infos = _extract_teacher_pseudo_labels(batch_dict_teacher, index)

                    frame_id = batch_dict_teacher['frame_id'][index]
                    # Single lookup: every access to the shared dict is a
                    # round trip to the manager.
                    stored_infos = TARGET_PSEUDO_LABELS.get(frame_id, None)
                    if _has_pseudo_boxes(stored_infos):
                        TARGET_PSEUDO_LABELS_TMP[frame_id] = _merge_pseudo_labels_by_class(
                            stored_infos, gt_infos, cfg.SELF_TRAIN.MEMORY_ENSEMBLE)
                    else:
                        TARGET_PSEUDO_LABELS_TMP[frame_id] = gt_infos

            st_loss = cfg.SELF_TRAIN.TAR.get('LOSS_WEIGHT', 1.0) * st_loss
            st_loss.backward()
            st_loss_meter.update(st_loss.item())

            # Read before st_tb_dict gets its keys prefixed below.
            cls_loss_meter.update(st_tb_dict.get('cls_loss', 0))
            reg_loss_meter.update(st_tb_dict.get('reg_loss', 0))
            if use_bda:
                bda_loss_meter.update(st_tb_dict.get('bda_loss', 0))

            # count number of used ps bboxes in this batch
            pos_pseudo_bbox = target_batch['pos_ps_bbox'].mean(dim=0).cpu().numpy()
            ign_pseudo_bbox = target_batch['ign_ps_bbox'].mean(dim=0).cpu().numpy()
            ps_bbox_nmeter.update(pos_pseudo_bbox.tolist())
            ign_ps_bbox_nmeter.update(ign_pseudo_bbox.tolist())
            pos_ps_result = ps_bbox_nmeter.aggregate_result()
            ign_ps_result = ign_ps_bbox_nmeter.aggregate_result()

            st_tb_dict = common_utils.add_prefix_to_dict(st_tb_dict, 'st_')
            disp_dict.update(common_utils.add_prefix_to_dict(st_disp_dict, 'st_'))
            disp_dict.update({'st_loss': "{:.3f}({:.3f})".format(st_loss_meter.val, st_loss_meter.avg),
                              'pos_ps_box': pos_ps_result,
                              'ign_ps_box': ign_ps_result})

            if target_learning:
                disp_dict.update({'cls_loss': "{:.3f}({:.3f})".format(cls_loss_meter.val, cls_loss_meter.avg)})
                disp_dict.update({'reg_loss': "{:.3f}({:.3f})".format(reg_loss_meter.val, reg_loss_meter.avg)})

            if use_bda:
                # NOTE(review): these two calls re-prefix the st_ dictionaries
                # instead of BDA-specific ones. They are kept unchanged because
                # add_prefix_to_dict is defined elsewhere and may modify its
                # argument in place -- confirm the intent.
                common_utils.add_prefix_to_dict(st_tb_dict, 'bda_')
                disp_dict.update(common_utils.add_prefix_to_dict(st_disp_dict, 'bda_'))
                disp_dict.update({'bda_loss': "{:.3f}({:.3f})".format(bda_loss_meter.val, bda_loss_meter.avg)})

        clip_grad_norm_(model.parameters(), optim_cfg.GRAD_NORM_CLIP)
        optimizer.step()
        accumulated_iter += 1

        # Periodic aggregation: blend the previous aggregated weights with the
        # current ones and load the result back into the model.
        if aggregation_interval and (accumulated_iter % aggregation_interval == 0):
            ema_alpha = cfg.SELF_TRAIN.WEAK_MODEL_AGGREGATION.EMA_ALPHA
            current_model_state = model.state_dict()
            for key in last_model_state.keys():
                last_model_state[key] = (
                    ema_alpha * last_model_state[key]
                    + (1 - ema_alpha) * current_model_state[key]
                )
            model.load_state_dict(last_model_state)

        # log to console and tensorboard
        if rank == 0:
            pbar.update()
            postfix = dict(total_it=accumulated_iter)
            # The pseudo-box statistics only exist when target data is used.
            if cfg.SELF_TRAIN.TAR.USE_DATA:
                postfix.update(pos_ps_box=pos_ps_result, ign_ps_box=ign_ps_result)
            pbar.set_postfix(postfix)
            tbar.set_postfix(disp_dict)
            tbar.refresh()

            if tb_log is not None:
                tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter)
                if cfg.SELF_TRAIN.SRC.USE_DATA:
                    tb_log.add_scalar('train/loss', loss, accumulated_iter)
                    for key, val in tb_dict.items():
                        tb_log.add_scalar('train/' + key, val, accumulated_iter)
                if cfg.SELF_TRAIN.TAR.USE_DATA:
                    tb_log.add_scalar('train/st_loss', st_loss, accumulated_iter)
                    for key, val in st_tb_dict.items():
                        tb_log.add_scalar('train/' + key, val, accumulated_iter)
    if rank == 0:
        pbar.close()
        # tb_log is optional; every other use of it is guarded as well.
        if tb_log is not None:
            for i, class_name in enumerate(target_loader.dataset.class_names):
                tb_log.add_scalar(
                    'ps_box/pos_%s' % class_name, ps_bbox_nmeter.meters[i].avg, cur_epoch)
                tb_log.add_scalar(
                    'ps_box/ign_%s' % class_name, ign_ps_bbox_nmeter.meters[i].avg, cur_epoch)

    return accumulated_iter


def train_model_st(model, optimizer, source_loader, target_loader, model_func, lr_scheduler, optim_cfg,
                   start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, ps_label_dir,
                   source_sampler=None, target_sampler=None, lr_warmup_scheduler=None, ckpt_save_interval=1,
                   max_ckpt_save_num=50, merge_all_iters_to_one_epoch=False, logger=None, ema_model=None):
    """Drive the self-training loop for epochs ``start_epoch`` .. ``total_epochs - 1``.

    For every epoch this function, in order:

    1. sets the epoch on ``target_sampler`` and the source reader (only when
       ``target_sampler`` is given);
    2. picks the warm-up or the regular LR scheduler;
    3. calls ``self_training_utils.save_pseudo_label_epoch`` when the epoch is
       listed in ``cfg.SELF_TRAIN.UPDATE_PSEUDO_LABEL`` or is a non-zero
       multiple of ``cfg.SELF_TRAIN.UPDATE_PSEUDO_LABEL_INTERVAL``;
    4. re-prepares the target data augmentor when ``cfg.SELF_TRAIN.PROG_AUG``
       is enabled and the epoch is listed in ``PROG_AUG.UPDATE_AUG``;
    5. trains one epoch through ``train_one_epoch_st``;
    6. when ``cfg.SELF_TRAIN.TARGET_LEARNING`` is set, merges the contents of
       ``TARGET_PSEUDO_LABELS_TMP`` (gathered from all ranks if
       ``torch.distributed`` is initialized) into ``TARGET_PSEUDO_LABELS``
       and clears the temporary dict;
    7. on rank 0, every ``ckpt_save_interval`` epochs, removes the oldest
       checkpoints so that at most ``max_ckpt_save_num`` remain and saves a
       new one.

    Args:
        model: Model being trained; forwarded to the epoch routine, to the
            pseudo-label routine and to ``checkpoint_state``.
        optimizer: Optimizer forwarded to the epoch routine and stored in
            the checkpoint.
        source_loader: Source-domain loader, wrapped in
            ``common_utils.DataReader``.
        target_loader: Target-domain loader; its length defines the number
            of iterations per epoch.
        model_func: Forwarded unchanged to ``train_one_epoch_st``.
        lr_scheduler: Scheduler used outside the warm-up phase.
        optim_cfg: Optimization config; ``WARMUP_EPOCH`` is read here when a
            warm-up scheduler is supplied.
        start_epoch: First epoch to run (non-zero when resuming).
        total_epochs: Exclusive upper bound of the epoch range.
        start_iter: Initial value of the global iteration counter.
        rank: Process rank; rank 0 keeps the progress bar and writes
            checkpoints.
        tb_log: Forwarded unchanged to ``train_one_epoch_st``.
        ckpt_save_dir: Directory for checkpoints; must support the ``/``
            operator (e.g. ``pathlib.Path``).
        ps_label_dir: Pseudo-label directory passed to the
            ``self_training_utils`` helpers.
        source_sampler: Optional sampler handed to the source ``DataReader``.
        target_sampler: Optional sampler whose ``set_epoch`` is called at the
            start of each epoch.
        lr_warmup_scheduler: Optional scheduler used while
            ``cur_epoch < optim_cfg.WARMUP_EPOCH``.
        ckpt_save_interval: Save a checkpoint every this many epochs.
        max_ckpt_save_num: Maximum number of checkpoint files kept on disk.
        merge_all_iters_to_one_epoch: When true, asks the target dataset to
            merge all iterations and divides the loader length by
            ``total_epochs`` to get the per-epoch iteration count.
        logger: Optional logger; when ``None`` the informational message
            about an existing pseudo-label file is skipped.
        ema_model: Forwarded unchanged to ``train_one_epoch_st``.

    Returns:
        None.
    """
    accumulated_iter = start_iter
    source_reader = common_utils.DataReader(source_loader, source_sampler)
    source_reader.construct_iter()

    # for continue training.
    # if already exist generated pseudo label result
    ps_pkl = self_training_utils.check_already_exsit_pseudo_label(ps_label_dir, start_epoch)
    # `logger` defaults to None, so guard the call instead of crashing on resume.
    if ps_pkl is not None and logger is not None:
        logger.info('==> Loading pseudo labels from {}'.format(ps_pkl))

    # The curriculum-augmentation switch is read once and reused below.
    prog_aug_cfg = cfg.SELF_TRAIN.get('PROG_AUG', None)
    prog_aug_enabled = bool(prog_aug_cfg) and bool(prog_aug_cfg.ENABLED)

    # for continue training: replay the augmentor updates of the epochs that
    # were already trained (the range is empty when start_epoch <= 0).
    if prog_aug_enabled:
        for past_epoch in range(start_epoch):
            if past_epoch in cfg.SELF_TRAIN.PROG_AUG.UPDATE_AUG:
                target_loader.dataset.data_augmentor.re_prepare(
                    augmentor_configs=None, intensity=cfg.SELF_TRAIN.PROG_AUG.SCALE)

    with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True,
                     leave=(rank == 0)) as tbar:
        total_it_each_epoch = len(target_loader)
        if merge_all_iters_to_one_epoch:
            assert hasattr(target_loader.dataset, 'merge_all_iters_to_one_epoch')
            target_loader.dataset.merge_all_iters_to_one_epoch(merge=True, epochs=total_epochs)
            total_it_each_epoch = len(target_loader) // max(total_epochs, 1)

        dataloader_iter = iter(target_loader)
        for cur_epoch in tbar:
            if target_sampler is not None:
                target_sampler.set_epoch(cur_epoch)
                source_reader.set_cur_epoch(cur_epoch)

            # train one epoch
            if lr_warmup_scheduler is not None and cur_epoch < optim_cfg.WARMUP_EPOCH:
                cur_scheduler = lr_warmup_scheduler
            else:
                cur_scheduler = lr_scheduler

            # update pseudo label
            if (cur_epoch in cfg.SELF_TRAIN.UPDATE_PSEUDO_LABEL) or \
                    ((cur_epoch % cfg.SELF_TRAIN.UPDATE_PSEUDO_LABEL_INTERVAL == 0)
                     and cur_epoch != 0):
                # The dataset is switched to eval mode for the pseudo-label
                # pass and back to train mode afterwards.
                target_loader.dataset.eval()
                self_training_utils.save_pseudo_label_epoch(
                    model, target_loader, rank,
                    leave_pbar=True, ps_label_dir=ps_label_dir, cur_epoch=cur_epoch
                )
                target_loader.dataset.train()

            # curriculum data augmentation
            if prog_aug_enabled and (cur_epoch in cfg.SELF_TRAIN.PROG_AUG.UPDATE_AUG):
                target_loader.dataset.data_augmentor.re_prepare(
                    augmentor_configs=None, intensity=cfg.SELF_TRAIN.PROG_AUG.SCALE)

            accumulated_iter = train_one_epoch_st(
                model, optimizer, source_reader, target_loader, model_func,
                lr_scheduler=cur_scheduler,
                accumulated_iter=accumulated_iter, optim_cfg=optim_cfg,
                rank=rank, tbar=tbar, tb_log=tb_log,
                leave_pbar=(cur_epoch + 1 == total_epochs),
                total_it_each_epoch=total_it_each_epoch,
                dataloader_iter=dataloader_iter, ema_model=ema_model, cur_epoch=cur_epoch
            )

            commu_utils.synchronize()

            if cfg.SELF_TRAIN.get('TARGET_LEARNING', None):
                if dist.is_initialized():
                    # Collect every rank's temporary pseudo labels and fold
                    # them into the local temporary dict.
                    part_pseudo_labels_list = commu_utils.all_gather(TARGET_PSEUDO_LABELS_TMP)

                    new_pseudo_label_dict = {}
                    for pseudo_labels in part_pseudo_labels_list:
                        new_pseudo_label_dict.update(pseudo_labels)

                    TARGET_PSEUDO_LABELS_TMP.update(new_pseudo_label_dict)

                # Replace the shared pseudo labels with this epoch's results,
                # then reset the temporary store for the next epoch.
                commu_utils.synchronize()
                TARGET_PSEUDO_LABELS.clear()
                TARGET_PSEUDO_LABELS.update(TARGET_PSEUDO_LABELS_TMP)
                TARGET_PSEUDO_LABELS_TMP.clear()
                commu_utils.synchronize()

            # save trained model
            trained_epoch = cur_epoch + 1
            if trained_epoch % ckpt_save_interval == 0 and rank == 0:

                ckpt_list = glob.glob(str(ckpt_save_dir / 'checkpoint_epoch_*.pth'))
                ckpt_list.sort(key=os.path.getmtime)

                # Drop the oldest files so that, after the save below, at most
                # max_ckpt_save_num checkpoints remain. Slicing (rather than
                # indexing) stays in bounds even for max_ckpt_save_num <= 0.
                if len(ckpt_list) >= max_ckpt_save_num:
                    for stale_ckpt in ckpt_list[:len(ckpt_list) - max_ckpt_save_num + 1]:
                        os.remove(stale_ckpt)

                ckpt_name = ckpt_save_dir / ('checkpoint_epoch_%d' % trained_epoch)
                state = checkpoint_state(model, optimizer, trained_epoch, accumulated_iter)

                save_checkpoint(state, filename=ckpt_name)