import torch
import torch.nn.functional as F
import os
import glob
import tqdm
from torch.nn.utils import clip_grad_norm_
import copy
import numpy as np
import wandb
import random
from multiprocessing import Manager

from pcdet.utils import common_utils, commu_utils
from pcdet.utils import self_training_utils
from pcdet.config import cfg
from pcdet.models.model_utils.dsnorm import set_ds_source, set_ds_target, set_ds_mid
from pcdet.models import load_data_to_gpu
from .train_utils import save_checkpoint, checkpoint_state
from pcdet.datasets.augmentor.augmentor_utils import get_points_in_box
from pcdet.utils.box_utils import remove_points_in_boxes3d, enlarge_box3d, \
    boxes3d_kitti_lidar_to_fakelidar, boxes_to_corners_3d
from pcdet.ops.iou3d_nms import iou3d_nms_utils
from .optimization import build_optimizer, build_scheduler

# Module-level shared state.
#
# Dict proxy served by a multiprocessing Manager; note that the Manager (and
# its server process) is created as a side effect of importing this module.
# Presumably holds pseudo labels for target-domain frames -- it is not
# referenced in the visible part of this file, confirm against its users.
TARGET_PSEUDO_LABELS = Manager().dict()
# Plain process-local dict (not shared through the Manager).
TARGET_PSEUDO_LABELS_TMP = {}
# Prototype storage and its "already initialised" flag; both start unset.
GLOBAL_PROTOTYPES = None
GLOBAL_PROTOTYPES_INITIALIZED = False

# Pseudo-label object bank read by the COMPLE_AUG branch of
# train_one_epoch_st: both are indexed first by class name and then by an
# integer sample index. They must be populated elsewhere before that branch
# runs, since they are None here.
GLOBAL_PS_BOXES = None
GLOBAL_PS_POINTS = None

# Detached copies of the teacher (EMA) model outputs ('rois_mt',
# 'roi_head_features_mt', 'roi_scores_mt', 'roi_iou_scores_mt',
# 'roi_cls_preds_mt') written by train_one_epoch_st after the teacher forward
# pass and read back in its DTS branch.
GLOBAL_BATCH_DICT_TEACHER = None

def train_one_epoch_st(model, optimizer, source_reader, target_loader, model_func, lr_scheduler,
                       accumulated_iter, optim_cfg, rank, tbar, total_it_each_epoch,
                       dataloader_iter, tb_log=None, leave_pbar=False, ema_model=None, cur_epoch=None):
    if total_it_each_epoch == len(target_loader):
        dataloader_iter = iter(target_loader)

    if rank == 0:
        pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)

    ps_bbox_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))
    ign_ps_bbox_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))
    loss_meter = common_utils.AverageMeter()
    st_loss_meter = common_utils.AverageMeter()

    cls_loss_meter = common_utils.AverageMeter()
    reg_loss_meter = common_utils.AverageMeter()

    bda_loss_meter = common_utils.AverageMeter()

    mid_loss_meter = common_utils.AverageMeter()

    disp_dict = {}                                                                                                                                        

    aggregation_interval = cfg.SELF_TRAIN.get('AGGREGATION_INTERVAL', None)  # 配置聚合间隔
    last_model_state = model.state_dict() if aggregation_interval else None  # 用于存储上一次的模型状态

    for cur_it in range(total_it_each_epoch):
        lr_scheduler.step(accumulated_iter)

        try:
            cur_lr = float(optimizer.lr)
        except:
            cur_lr = optimizer.param_groups[0]['lr']

        if tb_log is not None:
            tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter)

        model.train()
        if ema_model is not None:
            ema_model.train()

        optimizer.zero_grad()

        if cfg.SELF_TRAIN.SRC.USE_DATA:
            source_batch = source_reader.read_data()
        
        if cfg.SELF_TRAIN.TAR.USE_DATA:
            try:
                target_batch = next(dataloader_iter)
            except StopIteration:
                dataloader_iter = iter(target_loader)
                target_batch = next(dataloader_iter)
                print('new iters')
        
        if cfg.SELF_TRAIN.get('MT', None):
            batch_target1, batch_target2 = target_batch # s,t

            # add tag for target domain
            # batch_target1['batch_type'] = 'target'
            # batch_target2['batch_type'] = 'target'
            
            ema_model.train()
            if cfg.SELF_TRAIN.get('DSNORM', None):
                ema_model.apply(set_ds_target)
            
            if cfg.SELF_TRAIN.get('MT_CONSISTENCY', None) or cfg.SELF_TRAIN.get('DTS', None):
                # forward teacher model first
                from pcdet.models import load_data_to_gpu
                load_data_to_gpu(batch_target2)
                batch_target2 = ema_model(batch_target2, return_batch_dict=True)
                batch_dict_teacher = {}

                for key in ['rois_mt', 'roi_head_features_mt', 'roi_scores_mt', 'roi_iou_scores_mt', 'roi_cls_preds_mt']:
                    batch_dict_teacher[key] = batch_target2[key].detach().clone()

                global GLOBAL_BATCH_DICT_TEACHER
                GLOBAL_BATCH_DICT_TEACHER = batch_dict_teacher

            target_batch = batch_target1


        '''训练source数据，计算loss，更新梯度'''
        if cfg.SELF_TRAIN.SRC.USE_DATA and cfg.SELF_TRAIN.SRC.get('USE_GRAD', None):
            if cfg.SELF_TRAIN.get('DSNORM', None):
                model.apply(set_ds_source)
            model.train()
            loss, tb_dict, disp_dict, batch_dict_source = model_func(model, source_batch)
            # loss = cfg.SELF_TRAIN.SRC.get('LOSS_WEIGHT', 1.0) * loss
            loss_src = loss
            loss.backward(retain_graph=True)
            loss_meter.update(loss.item())
            disp_dict.update({'loss': "{:.3f}({:.3f})".format(loss_meter.val, loss_meter.avg)})
            # print('source raw start')

        if not cfg.SELF_TRAIN.SRC.get('USE_GRAD', None):
            optimizer.zero_grad()

        
        if cfg.SELF_TRAIN.get('COMPLE_AUG', None):
            data_dict_list = []
            batch_size = target_batch['gt_boxes'].shape[0]
            def process_batch_data(batch, idx):
                """Extract one frame's boxes and points from a collated batch.

                Args:
                    batch: mapping with 'gt_boxes', 'points' and
                        'unreliable_gt_boxes' entries, each either a numpy
                        array or a torch tensor.
                    idx: index of the frame inside the batch; it is also the
                        value matched against column 0 of ``batch['points']``.

                Returns:
                    Tuple ``(gt_box, points, unreliable_gt_boxes)``. All-zero
                    rows are removed from both box arrays, and the leading
                    batch-index column is dropped from the points.
                """
                def _as_numpy(value):
                    # Tensors expose ``.cpu()``; numpy arrays do not and are
                    # passed through untouched. An explicit check replaces the
                    # former bare ``except:``, which also swallowed unrelated
                    # errors (and KeyboardInterrupt).
                    if hasattr(value, 'cpu'):
                        return value.detach().cpu().numpy()
                    return value

                gt_box = _as_numpy(batch['gt_boxes'][idx])
                # Drop all-zero rows (presumably batch padding).
                gt_box = gt_box[~np.all(gt_box == 0, axis=1)]

                all_points = batch['points']
                # Keep this frame's rows, then strip the batch-index column.
                points = _as_numpy(all_points[all_points[:, 0] == idx][:, 1:])

                unreliable_gt_boxes = _as_numpy(batch['unreliable_gt_boxes'][idx])
                # Drop all-zero rows here as well.
                unreliable_gt_boxes = unreliable_gt_boxes[~np.all(unreliable_gt_boxes == 0, axis=1)]

                return gt_box, points, unreliable_gt_boxes
            
            def add_sampled_ps_boxes_to_scene(data_dict, sampled_ps_boxes,
                                            sampled_ps_pnts, sampled_classes):
                """Fit sampled pseudo-label objects onto the scene boxes and merge their points.

                For every k, the points of sampled object k are expressed
                relative to ``sampled_ps_boxes[k]``, scaled per axis so that
                box takes the size of ``data_dict['gt_boxes'][k]``, rotated to
                that box's heading and translated to its centre.

                Args:
                    data_dict: mapping with 'gt_boxes' (row k is the
                        destination box of sampled object k; columns 0:3 are
                        read as centre, 3:6 as size, 6 as heading) and
                        'points' (the scene points).
                    sampled_ps_boxes: source boxes, same column layout.
                    sampled_ps_pnts: sequence whose item k holds the points of
                        sampled object k (assumed to be xyz only, since a
                        3-vector centre is subtracted from them).
                    sampled_classes: unused; kept so existing calls stay valid.

                Returns:
                    The transformed object points, padded with zero-valued
                    feature columns up to the scene point width, stacked in
                    front of the scene points. The scene points are returned
                    unchanged when no boxes were sampled.
                """
                scene_boxes = data_dict['gt_boxes']
                scene_points = data_dict['points']

                # Nothing to insert: np.concatenate would raise on an empty list.
                if sampled_ps_boxes.shape[0] == 0:
                    return scene_points

                placed_objects = []
                for k in range(sampled_ps_boxes.shape[0]):
                    src_center = sampled_ps_boxes[k, 0:3]
                    src_lwh = sampled_ps_boxes[k, 3:6]
                    src_ry = sampled_ps_boxes[k, 6]

                    dst_lwh = scene_boxes[k, 3:6]
                    dst_ry = scene_boxes[k, 6]
                    # Copy the centre: the slice is a view, and the in-place z
                    # shift below would otherwise write into
                    # data_dict['gt_boxes'] behind the caller's back.
                    dst_center = np.array(scene_boxes[k, 0:3])

                    # Per-axis factor that rescales the sampled box to the destination box.
                    scale_lwh = dst_lwh / src_lwh

                    # np.array makes a private copy that is safe to modify in place.
                    cur_obj_points = np.array(sampled_ps_pnts[k])

                    # Move to the source box frame, rescale, then rotate to the new heading.
                    cur_obj_points -= src_center
                    cur_obj_points = common_utils.rotate_points_along_z(np.expand_dims(cur_obj_points, axis=0), -src_ry).squeeze(0)
                    cur_obj_points *= scale_lwh
                    cur_obj_points = common_utils.rotate_points_along_z(np.expand_dims(cur_obj_points, axis=0), dst_ry).squeeze(0)

                    # Shift the centre by half the height change to avoid the
                    # object floating over the road.
                    dst_center[2] += (dst_lwh[2] - src_lwh[2]) / 2
                    cur_obj_points += dst_center
                    placed_objects.append(cur_obj_points)

                ps_pnts = np.concatenate(placed_objects, axis=0)
                # Inserted points carry no extra features; fill those columns with zeros.
                extra_dim = scene_points.shape[-1] - 3
                zero_dim = np.zeros((ps_pnts.shape[0], extra_dim))
                ps_pnts = np.concatenate((ps_pnts, zero_dim), axis=1)

                return np.concatenate([ps_pnts[:, :scene_points.shape[-1]], scene_points], axis=0)

            import pickle as pkl
            # Load PS Objects Labels
            # ps_box_list = glob.glob(os.path.join(cfg.SELF_TRAIN.PS_SAMPLING.PS_OBJECT_PATH, 'ps_box_e*.pkl'))
            # ps_pnt_list = glob.glob(os.path.join(cfg.SELF_TRAIN.PS_SAMPLING.PS_OBJECT_PATH,'ps_point_e*.pkl'))

            # ps_box_list.sort(key=os.path.getmtime, reverse=True)
            # with open(ps_box_list[0], 'rb') as f:
            #     ps_boxes = pkl.load(f)
            # ps_pnt_list.sort(key=os.path.getmtime, reverse=True)
            # with open(ps_pnt_list[0], 'rb') as f:
            #     ps_points = pkl.load(f)
            global GLOBAL_PS_BOXES, GLOBAL_PS_POINTS
            ps_boxes = GLOBAL_PS_BOXES
            ps_points = GLOBAL_PS_POINTS
            
            for frame_idx in range(batch_size):
                mix_data = {}
                
                t_gt_box, target_points, unreliable_gt_boxes = process_batch_data(target_batch, frame_idx)
                T_pos = np.array(cfg.SELF_TRAIN.SCORE_THRESH)
                T_neg = np.array(cfg.SELF_TRAIN.NEG_THRESH)

                t_gt_box_unreliable = unreliable_gt_boxes[:, :7]
                t_gt_labels_unreliable = np.abs(unreliable_gt_boxes[:, -2]).astype(np.int32)
                t_gt_scores_unreliable = unreliable_gt_boxes[:, -1]

                # 移除点云中的目标框点
                target_points = remove_points_in_boxes3d(target_points, enlarge_box3d(t_gt_box_unreliable, extra_width=[1, 0.5, 0.5]))

                comple_box = []
                comple_class = []

                for box, box_score, box_label in zip(t_gt_box_unreliable, t_gt_scores_unreliable, t_gt_labels_unreliable):
                    prob_replace = (box_score - T_neg[box_label - 1]) / (T_pos[box_label - 1] - T_neg[box_label - 1])

                    if np.random.rand() < prob_replace:
                        comple_box.append(box)
                        comple_class.append(box_label)

                comple_box = np.array(comple_box)
                if len(comple_box) < 1:
                    # data_dict_list.append(target_batch)
                    data_dict = {'points': target_points, 'frame_id': target_batch['frame_id'][frame_idx], 'gt_boxes': t_gt_box}
                    
                    data_dict = target_loader.dataset.prepare_data_easy(data_dict=data_dict)
                    # data_dict = target_loader.dataset.prepare_data_easy(data_dict=data_dict)
                    data_dict_list.append(data_dict)
                    continue
                comple_class = np.array(comple_class)
                comple_box = comple_box[~np.all(comple_box == 0, axis=1)]  # 去除全为 0 的行
                comple_box = np.concatenate([comple_box, comple_class.reshape(-1, 1)], axis=1)

                # 把 comple_box 里的 box 按照最后一项（类别）进行排序
                comple_box = comple_box[comple_box[:, -1].argsort()]

                t_gt_box = np.concatenate([t_gt_box, comple_box], axis=0)
                # t_gt_box = np.concatenate([t_gt_box, comple_box], axis=0)

                class_names = cfg.CLASS_NAMES
                # 统计 comple_class 里面的各个类别数量
                comple_class = np.array(comple_class)
                comple_class_dict = {class_names[i]: np.sum(comple_class == (i + 1)) for i in range(len(class_names))}
                


                sample_groups = {}
                sample_class_num = {}
                for class_name, sample_num in comple_class_dict.items():
                    if class_name not in cfg.CLASS_NAMES:
                        continue
                    sample_class_num[class_name] = sample_num
                    sample_groups[class_name] = {
                        'sample_num': sample_num
                    }

                tmp_data_dict = {}
                tmp_data_dict['gt_boxes'] = comple_box
                # tmp_data_dict['gt_names'] = t_gt_names_unreliable.astype(str)
                tmp_data_dict['points'] = target_points
                # tmp_data_dict['gt_scores'] = t_gt_scores_unreliable

                # gt_boxes = t_gt_box
                # gt_names = data_dict['gt_names'].astype(str)
                # existed_boxes = gt_boxes
                sampled_class, total_sampled_pnts = [], []
                existed_boxes = None

                # sample_groups = cfg.SELF_TRAIN.PS_SAMPLING.SAMPLE_GROUPS
                for class_name, sample_group in sample_groups.items():

                    if int(sample_group['sample_num']) > 0:
                        sample_idx = np.random.randint(len(ps_boxes[class_name]), size=int(sample_group['sample_num']))
                        sampled_boxes = np.stack([ps_boxes[class_name][i] for i in sample_idx]).astype(np.float32)[:,:7] # (ps box num, 7)
                        sampled_pnts = [ps_points[class_name][i] for i in sample_idx]

                        valid_mask = np.ones(sampled_boxes.shape[0], dtype=bool)
                        valid_sampled_boxes = sampled_boxes[valid_mask]
                        existed_boxes = np.concatenate((existed_boxes, valid_sampled_boxes), axis=0) if existed_boxes is not None else valid_sampled_boxes

                        # sampled classes
                        sampled_class.extend([class_name for i in range(len(valid_mask))])

                        # sampled points
                        for mask_idx in range(valid_mask.shape[0]):
                            total_sampled_pnts.append(sampled_pnts[mask_idx])

                sampled_gt_boxes = existed_boxes
                if sampled_gt_boxes.shape[0] > 0:
                    points = add_sampled_ps_boxes_to_scene(tmp_data_dict,sampled_gt_boxes, total_sampled_pnts, sampled_class)
                
                # target_points = np.concatenate([points, target_points], axis=0)
                target_points = points
                data_dict = {'points': target_points, 'frame_id': target_batch['frame_id'][frame_idx], 'gt_boxes': t_gt_box}
                    
                data_dict = target_loader.dataset.prepare_data_easy(data_dict=data_dict)
                # data_dict = target_loader.dataset.prepare_data_easy(data_dict=data_dict)
                data_dict_list.append(data_dict)
                # return data_dict
            data_dict = target_loader.dataset.collate_batch(data_dict_list)
            for key in target_batch.keys():
                if key not in data_dict.keys():
                    data_dict[key] = target_batch[key]
            target_batch = data_dict

        if cfg.SELF_TRAIN.get('MID', None) and source_batch['gt_boxes'].shape[0] == target_batch['gt_boxes'].shape[0]:
            data_dict_list = []
            batch_size = min(target_batch['gt_boxes'].shape[0], source_batch['gt_boxes'].shape[0])

            pitch_angles = cfg.SELF_TRAIN.MID.get('PITCH', [-25, 3])
            num_areas = cfg.SELF_TRAIN.MID.get('NUM_AREAS', [3, 4, 5, 6])
            num_areas = np.random.choice(num_areas, size=1)[0]
            order = np.random.choice([0, 1], size=1)[0]

            lower_bound = np.deg2rad(pitch_angles[0]) + 1e-5
            upper_bound = np.deg2rad(pitch_angles[1]) - 1e-5
            pitch_angles = [lower_bound, upper_bound]


            def process_batch_data(batch, idx):
                """Extract one frame's 7-column boxes and its points from a collated batch.

                Args:
                    batch: mapping with 'gt_boxes' and 'points' entries, each
                        either a numpy array or a torch tensor.
                    idx: index of the frame inside the batch; it is also the
                        value matched against column 0 of ``batch['points']``.

                Returns:
                    Tuple ``(gt_box, points)``: the first seven box columns
                    with all-zero rows removed, and the frame's points without
                    the leading batch-index column.
                """
                def _as_numpy(value):
                    # Tensors expose ``.cpu()``; numpy arrays do not and are
                    # passed through untouched. An explicit check replaces the
                    # former bare ``except:``, which also swallowed unrelated
                    # errors (and KeyboardInterrupt).
                    if hasattr(value, 'cpu'):
                        return value.detach().cpu().numpy()
                    return value

                gt_box = _as_numpy(batch['gt_boxes'][idx][:, :7])
                # Drop all-zero rows (presumably batch padding).
                gt_box = gt_box[~np.all(gt_box == 0, axis=1)]

                all_points = batch['points']
                # Keep this frame's rows, then strip the batch-index column.
                points = _as_numpy(all_points[all_points[:, 0] == idx][:, 1:])

                return gt_box, points

            for frame_idx in range(batch_size):
                mix_data = {}
                s_gt_box, single_pc_pnts = process_batch_data(source_batch, frame_idx)
                t_gt_box, target_points = process_batch_data(target_batch, frame_idx)

                # 提取 source 和 target 的点云
                source_obj_points = [get_points_in_box(single_pc_pnts, box[:7])[0] for box in s_gt_box]
                target_obj_points = [get_points_in_box(target_points, box[:7])[0] for box in t_gt_box]

                # 移除点云中的目标框点
                single_pc_pnts = remove_points_in_boxes3d(single_pc_pnts, enlarge_box3d(s_gt_box, extra_width=[1, 0.5, 0.5]))
                # single_pc_pnts = remove_points_in_boxes3d(single_pc_pnts, enlarge_box3d(t_gt_box, extra_width=[1, 0.5, 0.5]))
                # target_points = remove_points_in_boxes3d(target_points, enlarge_box3d(s_gt_box, extra_width=[1, 0.5, 0.5]))
                target_points = remove_points_in_boxes3d(target_points, enlarge_box3d(t_gt_box, extra_width=[1, 0.5, 0.5]))

                # 计算 pitch angle
                def compute_pitch(points, boxes):
                    """Return the clipped pitch angle of every point and of every box centre.

                    The pitch is ``arctan2(z - 1.8, sqrt(x**2 + y**2))`` taken
                    from columns 0..2 of each array, clipped to the interval
                    ``[pitch_angles[0] + 1e-5, pitch_angles[1] - 1e-5]``, where
                    ``pitch_angles`` comes from the enclosing scope.

                    Args:
                        points: 2-D array whose first three columns are x, y, z.
                        boxes: 2-D array whose first three columns are x, y, z.

                    Returns:
                        Tuple ``(pitch, pitch_box)`` of 1-D arrays, one angle
                        per row of ``points`` and ``boxes`` respectively.
                    """
                    lowest = pitch_angles[0] + 1e-5
                    highest = pitch_angles[1] - 1e-5

                    def _clipped_pitch(xyz):
                        # Horizontal distance from the origin.
                        planar_range = np.sqrt(xyz[:, 0]**2 + xyz[:, 1]**2)
                        # The constant 1.8 is subtracted from z before taking the angle
                        # (presumably a sensor-height offset -- TODO confirm).
                        angle = np.arctan2(-1.8 + xyz[:, 2], planar_range)
                        return np.clip(angle, lowest, highest)

                    return _clipped_pitch(points), _clipped_pitch(boxes)

                src_pitch, src_pitch_box = compute_pitch(single_pc_pnts, s_gt_box)
                tgt_pitch, tgt_pitch_box = compute_pitch(target_points, t_gt_box)

                # 划分区域并混合点云
                angle_list = np.linspace(pitch_angles[1], pitch_angles[0], num_areas + 1)
                mix_points, mix_boxes, mix_obj_points = [], [], []

                for i in range(num_areas):
                    start_angle = angle_list[i + 1] # / 180 * np.pi
                    end_angle = angle_list[i] # / 180 * np.pi

                    if i % 2 == order:
                        idx = (src_pitch > start_angle) & (src_pitch <= end_angle)
                        mix_points.append(single_pc_pnts[idx])
                        idx_b = (src_pitch_box > start_angle) & (src_pitch_box <= end_angle)
                        mix_boxes.append(s_gt_box[idx_b])
                        if idx_b.sum() > 0:
                            mix_obj_points.append(np.concatenate([arr for arr, flag in zip(source_obj_points, idx_b) if flag]))
                    else:
                        idx = (tgt_pitch > start_angle) & (tgt_pitch <= end_angle)
                        mix_points.append(target_points[idx])
                        idx_b = (tgt_pitch_box > start_angle) & (tgt_pitch_box <= end_angle)
                        mix_boxes.append(t_gt_box[idx_b])
                        if idx_b.sum() > 0:
                            mix_obj_points.append(np.concatenate([arr for arr, flag in zip(target_obj_points, idx_b) if flag]))

                # 拼接点云和检测框
                mix_points = np.concatenate(mix_points)
                mix_boxes = np.concatenate(mix_boxes)
                # mix_obj_points = np.concatenate(mix_obj_points) if len(mix_obj_points) > 0 else []
                # mix_points = np.concatenate([mix_points, mix_obj_points] if len(mix_obj_points) > 0 else mix_points)
                if len(mix_obj_points) > 0:
                    mix_obj_points = np.concatenate(mix_obj_points)
                    mix_points = np.concatenate([mix_points, mix_obj_points])
                    # 否则 mix_points 保持原样（已在前面拼接过）


                # 准备数据字典
                data_dict = {
                    'points': mix_points,
                    'frame_id': target_batch['frame_id'][frame_idx],
                    'gt_boxes': mix_boxes[:, :7]
                }
                if len(mix_boxes) < 0:
                    data_dict = {
                        'points': target_points,
                        'frame_id': target_batch['frame_id'][frame_idx],
                        'gt_boxes': t_gt_box[:, :7]
                    }
                data_dict = target_loader.dataset.prepare_data_easy(data_dict=data_dict)
                data_dict_list.append(data_dict)
        
            # 定义需要保留的键
            keys_to_keep = {'points', 'frame_id', 'gt_boxes', 'non_sampled_gt_boxes_num', 'use_lead_xyz', 'voxels', 'voxel_coords', 'voxel_num_points'}
            # print('data_dict_list', data_dict_list)
            # 遍历列表中的每个字典，移除不需要的键
            for data in data_dict_list:
                # print(type(data))
                keys_to_remove = [key for key in data.keys() if key not in keys_to_keep]
                for key in keys_to_remove:
                    del data[key]

            mid_batch = target_loader.dataset.collate_batch(data_dict_list)
            for key in target_batch.keys():
                if key not in mid_batch:
                    mid_batch[key] = target_batch[key]
            
            # if cfg.SELF_TRAIN.get('DSNORM', None):
            #     model.apply(set_ds_target)
            if cfg.SELF_TRAIN.get('DSNORM', None):
                model.apply(set_ds_mid)
            model.train()
            mid_loss, mid_tb_dict, mid_disp_dict, batch_dict_mid = model_func(model, mid_batch)
            mid_loss.backward(retain_graph=True)
            mid_loss_meter.update(mid_loss.item())
            disp_dict.update({'mid_loss': "{:.3f}({:.3f})".format(mid_loss_meter.val, mid_loss_meter.avg)})
            
            
            
            # loss_mid, tb_dict, disp_dict, batch_dict_source = model_func(model, mid_batch)
            # loss_mid = cfg.SELF_TRAIN.MID.get('LOSS_WEIGHT', 1.0) * loss_mid
            # loss_src = loss
            # loss.backward(retain_graph=True)
            # loss_meter.update(loss.item())
            # disp_dict.update({'loss': "{:.3f}({:.3f})".format(loss_meter.val, loss_meter.avg)})
            # print('source raw start')

            # mid_tb_dict = {}
            # mid_tb_dict.update({'mid_loss': mid_loss_meter.val})
            # mid_disp_dict = {}
            # mid_disp_dict.update(mid_loss_i)

        if cfg.SELF_TRAIN.TAR.USE_DATA:
            
            # '''batch ADABN stage'''
            # if cfg.SELF_TRAIN.get('ADABN', None) and cfg.get('MT', None) == None:
            #     target_batch_raw = copy.deepcopy(target_batch)
            #     load_data_to_gpu(target_batch_raw)
            #     if ema_model is not None:
            #         ema_model.train()
            #         if cfg.SELF_TRAIN.get('DSNORM', None):
            #             ema_model.apply(set_ds_target)
            #         with torch.no_grad():
            #             batch_dict_teacher = ema_model(target_batch_raw, return_batch_dict=True)
            #     else:
            #         model.train()
            #         if cfg.SELF_TRAIN.get('DSNORM', None):
            #             model.apply(set_ds_target)
            #         with torch.no_grad():
            #             batch_dict_teacher = model(target_batch_raw, return_batch_dict=True)

            if cfg.SELF_TRAIN.get('DSNORM', None):
                model.apply(set_ds_target)
            model.train()
            
            '''将源域 obj 注入 target域'''
            if cfg.SELF_TRAIN.get('OBJECT', None) and source_batch['gt_boxes'].shape[0] == target_batch['gt_boxes'].shape[0]:
                data_dict_list = []
                batch_size = min(target_batch['gt_boxes'].shape[0], source_batch['gt_boxes'].shape[0])

                for idx in range(batch_size):
                    mix_data = {}
                    try:
                        s_gt_box = source_batch['gt_boxes'][idx].cpu().numpy()
                    except:
                        s_gt_box = source_batch['gt_boxes'][idx]
                    try:
                        t_gt_box = target_batch['gt_boxes'][idx].cpu().numpy()
                    except:
                        t_gt_box = target_batch['gt_boxes'][idx]
                    # t_gt_box = batch_pred_dict[idx]['pred_boxes'].cpu().numpy()
                    # t_gt_box = np.concatenate([t_gt_box, t_gt_labels.reshape(-1, 1)], axis=1)
                    # t_gt_scores = batch_pred_dict[idx]['pred_scores'].cpu().numpy()
                    # gt_mask = t_gt_scores > cfg.SELF_TRAIN.SCORE_THRESH
                    # t_gt_box = t_gt_box[gt_mask]

                    try:
                        single_pc_pnts = \
                            source_batch['points'][
                                source_batch['points'][:, 0] == idx][:, 1:].cpu().numpy()
                    except:
                        single_pc_pnts = \
                            source_batch['points'][
                                source_batch['points'][:, 0] == idx][:, 1:]

                    num_source_box_points = []
                    # combine s_gt_box and PC points
                    ps_pnts_to_sample_src = None
                    for box in s_gt_box:
                        points_in_box, mask = get_points_in_box(single_pc_pnts, box[:7])
                        ps_pnts_to_sample_src = points_in_box if ps_pnts_to_sample_src is None else np.concatenate([ps_pnts_to_sample_src, points_in_box])
                        num_source_box_points.append(points_in_box.shape[0])
                    num_source_box_points = np.array(num_source_box_points)

                    try:
                        target_points = target_batch['points'][target_batch['points'][:, 0] == idx][:, 1:].cpu().numpy()
                    except:
                        target_points = target_batch['points'][target_batch['points'][:, 0] == idx][:, 1:]
                    
                    num_target_box_points = []
                    # combine t_gt_box and PC points
                    ps_pnts_to_sample = None
                    for box in t_gt_box:
                        points_in_box, mask = get_points_in_box(target_points, box[:7])
                        ps_pnts_to_sample = points_in_box if ps_pnts_to_sample is None else np.concatenate([ps_pnts_to_sample, points_in_box])
                        num_target_box_points.append(points_in_box.shape[0])
                    num_target_box_points = np.array(num_target_box_points)
                    
                    if cfg.SELF_TRAIN.OBJECT.get('HARD', None) and cfg.SELF_TRAIN.OBJECT.HARD:
                        # mean_target_box_points = np.mean(num_target_box_points)
                        threshold = cfg.SELF_TRAIN.OBJECT.threshold
                        mean_target_box_points = np.percentile(num_target_box_points, threshold)

                        hard_source_box_idx = np.where(num_source_box_points < mean_target_box_points)[0]

                        if len(hard_source_box_idx) < 5:
                            hard_source_box_idx = np.argsort(num_source_box_points)[:1]
                        
                        max_num = cfg.SELF_TRAIN.OBJECT.get('max', 10)
                        if len(hard_source_box_idx) > max_num:
                            hard_source_box_idx = np.random.choice(hard_source_box_idx, max_num, replace=False)
                        
                        s_gt_box = [s_gt_box[i] for i in hard_source_box_idx]
                        s_gt_box = np.array(s_gt_box)

                        from pcdet.ops.iou3d_nms import iou3d_nms_utils


                        s_gt_box, _ = common_utils.check_numpy_to_torch(s_gt_box)
                        t_gt_box, _ = common_utils.check_numpy_to_torch(t_gt_box)
                        iou_matrix = iou3d_nms_utils.boxes_bev_iou_cpu(s_gt_box[:, :7], t_gt_box[:, :7]).cpu().numpy()
                        s_gt_box = s_gt_box.cpu().numpy()
                        t_gt_box = t_gt_box.cpu().numpy()
                        related_boxes = iou_matrix > 0.01
                        # 选出 与 target gt_box 无关的 source gt_box
                        s_gt_box = s_gt_box[np.sum(related_boxes, axis=1) == 0]
                        s_gt_box = np.array(s_gt_box)

                    ps_pnts_to_sample = None
                    for box in s_gt_box:
                        points_in_box, mask = get_points_in_box(single_pc_pnts, box[:7])
                        ps_pnts_to_sample = points_in_box if ps_pnts_to_sample is None else np.concatenate([ps_pnts_to_sample, points_in_box])
                    
                    target_points = remove_points_in_boxes3d(target_points, enlarge_box3d(s_gt_box[:, :7], extra_width=[1, 0.5, 0.5]))

                    try:
                        target_points = np.concatenate([target_points, ps_pnts_to_sample])
                    except:
                        pass
                    t_gt_box = np.concatenate([t_gt_box, s_gt_box])
                    # 将 新的points和gt_boxes 和 target_batch 原有的key结合
                    data_dict = {'points': target_points, 'frame_id': target_batch['frame_id'][idx], 'gt_boxes': t_gt_box[:, :7]}#, 'gt_names': gt_names}
                    
                    # if not (cfg.SELF_TRAIN.get('CONSISTENCY', None) and cfg.SELF_TRAIN.CONSISTENCY.get('REGION', None)):
                    data_dict = target_loader.dataset.prepare_data_easy(data_dict=data_dict)
                    data_dict_list.append(data_dict)
            
                target_batch_new = target_loader.dataset.collate_batch(data_dict_list)
                for key in target_batch.keys():
                    if key not in target_batch_new:
                        target_batch_new[key] = target_batch[key]
                target_batch = target_batch_new
            
        
            if cfg.SELF_TRAIN.get('DSNORM', None):
                if ema_model is not None:
                    ema_model.apply(set_ds_target)
                model.apply(set_ds_target)
            
             # parameters for save pseudo label on the fly
            if cfg.SELF_TRAIN.get('MT_CONSISTENCY', None):
                st_loss, st_tb_dict, st_disp_dict, batch_dict_target = model_func(model, target_batch, batch_dict_teacher)
            # elif cfg.SELF_TRAIN.get('TARGET_LEARNING', None) or cfg.SELF_TRAIN.get('BDA', None):
            #     st_loss, st_tb_dict, st_disp_dict, target_batch_dict = model_func(model, target_batch, batch_dict_teacher=pseudo_label_target if cfg.SELF_TRAIN.get('TARGET_LEARNING', None) or cfg.SELF_TRAIN.get('BDA', None) else None)
            elif cfg.SELF_TRAIN.get('DTS', None):
                batch_dict_teacher = GLOBAL_BATCH_DICT_TEACHER
                st_loss, st_tb_dict, st_disp_dict, batch_dict_target = model_func(model, target_batch, batch_dict_teacher=batch_dict_teacher)
            else:
                st_loss, st_tb_dict, st_disp_dict, batch_dict_target = model_func(model, target_batch)                      

            st_loss = cfg.SELF_TRAIN.TAR.get('LOSS_WEIGHT', 1.0) * st_loss
            st_loss_meter.update(st_loss.item())

            # if cfg.SELF_TRAIN.get('Triplet', None):
            #     batch_dict_source = batch_dict_source.detach()
            #     batch_dict_target = batch_dict_target
            #     source_features = batch_dict_source['roi_head_features']
            #     source_labels = batch_dict_source['roi_labels']
            #     target_features = batch_dict_target['roi_head_features']
            #     target_labels = batch_dict_target['roi_labels']
            # import torch
            # import torch.nn.functional as F


            if cfg.SELF_TRAIN.get('Triplet', None):
                # 提取特征和标签
                source_features = batch_dict_source['roi_head_features']  # (B, N_s, D)
                source_labels = batch_dict_source['roi_labels']           # (B, N_s)
                target_features = batch_dict_target['roi_head_features']  # (B, N_t, D)
                target_labels = batch_dict_target['roi_labels']           # (B, N_t)

                batch_size, num_target_rois, feat_dim = target_features.shape
                margin = 1.0
                triplet_loss = 0.0
                valid_count = 0

                lambda_loss_triplet = cfg.SELF_TRAIN.Triplet.get('lambda_loss_triplet', 0.1)
                target_lambda_loss_inter = cfg.SELF_TRAIN.Triplet.get('target_lambda_loss_inter', 1.0)
                target_lambda_loss_intra = cfg.SELF_TRAIN.Triplet.get('target_lambda_loss_intra', 1.0)
                source_lambda_loss_inter = cfg.SELF_TRAIN.Triplet.get('source_lambda_loss_inter', 1.0)
                source_lambda_loss_intra = cfg.SELF_TRAIN.Triplet.get('source_lambda_loss_intra', 1.0)

                if source_features.shape[0] == target_features.shape[0]:
                    for b in range(batch_size):
                        src_feat_b = source_features[b]  # (N_s, D)
                        src_label_b = source_labels[b]    # (N_s,)
                        tgt_feat_b = target_features[b]   # (N_t, D)
                        tgt_label_b = target_labels[b]    # (N_t,)

                        num_target = tgt_feat_b.shape[0]
                        num_source = src_feat_b.shape[0]

                        # 处理目标域锚点
                        for i in range(num_target):
                            anchor_feat = tgt_feat_b[i]
                            anchor_label = tgt_label_b[i].item()

                            # 跨域三元组 (目标锚点, 源正样本, 目标负样本)
                            src_same_mask = (src_label_b == anchor_label).clone()
                            if src_same_mask.sum() > 0:
                                src_pos = src_feat_b[src_same_mask][torch.argmax(torch.norm(src_feat_b[src_same_mask] - anchor_feat, dim=1))]
                                tgt_diff_mask = (tgt_label_b != anchor_label)
                                if tgt_diff_mask.sum() > 0:
                                    tgt_neg = tgt_feat_b[tgt_diff_mask][torch.argmin(torch.norm(tgt_feat_b[tgt_diff_mask] - anchor_feat, dim=1))]
                                    loss = F.relu(torch.norm(anchor_feat - src_pos) - torch.norm(anchor_feat - tgt_neg) + margin)
                                    triplet_loss += target_lambda_loss_inter * loss
                                    valid_count += 1

                            # 目标域内三元组 (目标锚点, 目标正样本, 目标负样本)
                            tgt_same_mask = (tgt_label_b == anchor_label).clone()
                            tgt_same_mask[i] = False
                            if tgt_same_mask.sum() > 0:
                                tgt_pos = tgt_feat_b[tgt_same_mask][torch.argmax(torch.norm(tgt_feat_b[tgt_same_mask] - anchor_feat, dim=1))]
                                tgt_diff_mask = (tgt_label_b != anchor_label)
                                if tgt_diff_mask.sum() > 0:
                                    tgt_neg = tgt_feat_b[tgt_diff_mask][torch.argmin(torch.norm(tgt_feat_b[tgt_diff_mask] - anchor_feat, dim=1))]
                                    loss = F.relu(torch.norm(anchor_feat - tgt_pos) - torch.norm(anchor_feat - tgt_neg) + margin)
                                    triplet_loss += target_lambda_loss_intra * loss
                                    valid_count += 1

                        # 处理源域锚点
                        for i in range(num_source):
                            anchor_feat = src_feat_b[i]
                            anchor_label = src_label_b[i].item()

                            # 源域内三元组 (源锚点, 源正样本, 源负样本)
                            src_same_mask = (src_label_b == anchor_label).clone()
                            src_same_mask[i] = False
                            if src_same_mask.sum() > 0:
                                src_pos = src_feat_b[src_same_mask][torch.argmax(torch.norm(src_feat_b[src_same_mask] - anchor_feat, dim=1))]
                                src_diff_mask = (src_label_b != anchor_label)
                                if src_diff_mask.sum() > 0:
                                    src_neg = src_feat_b[src_diff_mask][torch.argmin(torch.norm(src_feat_b[src_diff_mask] - anchor_feat, dim=1))]
                                    loss = F.relu(torch.norm(anchor_feat - src_pos) - torch.norm(anchor_feat - src_neg) + margin)
                                    triplet_loss += source_lambda_loss_intra * loss
                                    valid_count += 1

                            # 跨域三元组 (源锚点, 源正样本, 目标负样本)
                            # src_same_mask = (src_label_b == anchor_label).clone()
                            # if src_same_mask.sum() > 0:
                            #     src_pos = src_feat_b[src_same_mask][torch.argmax(torch.norm(src_feat_b[src_same_mask] - anchor_feat, dim=1))]
                            #     tgt_diff_mask = (tgt_label_b != anchor_label)
                            #     if tgt_diff_mask.sum() > 0:
                            #         tgt_neg = tgt_feat_b[tgt_diff_mask][torch.argmin(torch.norm(tgt_feat_b[tgt_diff_mask] - anchor_feat, dim=1))]
                            #         loss = F.relu(torch.norm(anchor_feat - src_pos) - torch.norm(anchor_feat - tgt_neg) + margin)
                            #         triplet_loss += source_lambda_loss_inter * loss
                            #         valid_count += 1
                            # 跨域三元组 (源锚点, 目标正样本, 源负样本)
                            tgt_same_mask = (tgt_label_b == anchor_label).clone()
                            if tgt_same_mask.sum() > 0:
                                tgt_pos = tgt_feat_b[tgt_same_mask][torch.argmax(torch.norm(tgt_feat_b[tgt_same_mask] - anchor_feat, dim=1))]
                                src_diff_mask = (src_label_b != anchor_label)
                                if src_diff_mask.sum() > 0:
                                    src_neg = src_feat_b[src_diff_mask][torch.argmin(torch.norm(src_feat_b[src_diff_mask] - anchor_feat, dim=1))]
                                    loss = F.relu(torch.norm(anchor_feat - tgt_pos) - torch.norm(anchor_feat - src_neg) + margin)
                                    triplet_loss += source_lambda_loss_inter * loss
                                    valid_count += 1


                    # 归一化损失
                    if valid_count > 0:
                        triplet_loss /= valid_count
                    else:
                        triplet_loss = 0.0

                    st_loss += lambda_loss_triplet * triplet_loss

            if cfg.SELF_TRAIN.get('ProtoAlign', None):
                
                from sklearn.mixture import GaussianMixture
                from sklearn.cluster import KMeans
                from scipy.optimize import linear_sum_assignment
                import torch.distributed as dist

                ###########################
                # 1. Alignment methods (prototype matching algorithms)
                ###########################

                def match_prototypes_greedy(old_prototypes, new_prototypes):
                    """
                    Greedily align new prototypes to old prototypes by Euclidean distance.

                    Old prototypes are visited in row order; each one claims the closest
                    new prototype that has not been claimed yet (ties go to the lowest
                    index). If no remaining distance is comparable (all NaN/inf), the
                    first unclaimed new prototype is used so that the result is still a
                    permutation of ``new_prototypes``.

                    Args:
                        old_prototypes: Tensor of shape (M, D).
                        new_prototypes: Tensor of shape (M, D).

                    Returns:
                        Tensor shaped like ``new_prototypes`` whose row ``i`` is the new
                        prototype matched to ``old_prototypes[i]``.
                    """
                    M = old_prototypes.shape[0]
                    # Move the full (M, M) distance matrix to the host once; reading it
                    # element by element with .item() costs one device sync per entry.
                    dist_rows = torch.cdist(old_prototypes, new_prototypes, p=2).detach().cpu().tolist()
                    aligned_new = torch.zeros_like(new_prototypes)
                    assigned_new = set()
                    for i in range(M):
                        row = dist_rows[i]
                        min_val = float('inf')
                        min_idx = -1
                        for j in range(M):
                            if j in assigned_new:
                                continue
                            d = row[j]
                            if d < min_val:
                                min_val = d
                                min_idx = j
                        if min_idx < 0:
                            # Every remaining distance was NaN/inf. Without this fallback
                            # index -1 would silently reuse the last prototype and leave
                            # another one unmatched.
                            min_idx = next(j for j in range(M) if j not in assigned_new)
                        assigned_new.add(min_idx)
                        aligned_new[i] = new_prototypes[min_idx]
                    return aligned_new

                def match_prototypes_hungarian(old_prototypes, new_prototypes):
                    """
                    Align new prototypes to old prototypes with the Hungarian algorithm.

                    The pairwise Euclidean distance matrix is used as the assignment cost
                    and solved with ``scipy.optimize.linear_sum_assignment``.

                    Args:
                        old_prototypes: Tensor of shape (M, D).
                        new_prototypes: Tensor of shape (M, D).

                    Returns:
                        Tensor of new prototypes reordered by the column indices of the
                        optimal assignment.
                    """
                    pairwise_cost = torch.cdist(old_prototypes, new_prototypes, p=2)
                    # SciPy works on host NumPy arrays, so detach and copy to the CPU.
                    cost_np = pairwise_cost.detach().cpu().numpy()
                    _, matched_cols = linear_sum_assignment(cost_np)
                    return new_prototypes[matched_cols]

                def match_prototypes(old_prototypes, new_prototypes, alignment_method="greedy"):
                    """
                    Align new prototypes to old ones with the configured matcher.

                    Args:
                        old_prototypes: Tensor passed through to the matcher.
                        new_prototypes: Tensor passed through to the matcher.
                        alignment_method: "hungarian" selects the Hungarian matcher; any
                            other value falls back to the greedy matcher.

                    Returns:
                        Whatever the selected matcher returns.
                    """
                    matcher = (
                        match_prototypes_hungarian
                        if alignment_method == "hungarian"
                        else match_prototypes_greedy
                    )
                    return matcher(old_prototypes, new_prototypes)

                ###########################
                # 2. Prototype update methods
                ###########################

                def update_prototypes_recluster(features, global_prototype, M, config, device='cuda'):
                    """
                    Update one class's prototypes by re-clustering the current features.

                    The local features are clustered into ``M`` centroids (GaussianMixture,
                    falling back to KMeans if that raises). Under torch.distributed the
                    centroids are averaged over the ranks that could cluster, then matched
                    to the old prototypes and blended with a moving average.

                    A rank that cannot cluster (no features, or fewer than ``M`` of them)
                    still joins the collectives with a zero contribution. Returning early
                    on such a rank would leave the other ranks blocked in ``all_reduce``.

                    Args:
                        features: Tensor of shape (N, D), features of one class.
                        global_prototype: Tensor of shape (M, D), the old prototypes.
                        M: int, number of prototypes/clusters for this class.
                        config: dict-like read with ``.get``:
                            "alignment_method": "greedy" (default) or "hungarian".
                            "coff": weight kept on the old prototype (default 0.9); the
                                aligned new centroid gets ``1 - coff``.
                        device: device the prototypes and centroids are moved to.

                    Returns:
                        Tensor of shape (M, D). The input ``global_prototype`` is returned
                        when no rank could cluster.
                    """
                    distributed = dist.is_initialized()
                    n_local = features.shape[0]
                    # Both GaussianMixture and KMeans need at least M samples to fit M clusters.
                    can_cluster = n_local > 0 and n_local >= M

                    if not can_cluster and not distributed:
                        return global_prototype

                    global_prototype = global_prototype.to(device)

                    if can_cluster:
                        # 1) Cluster locally with GMM, falling back to KMeans.
                        features_np = features.detach().cpu().numpy()
                        try:
                            gmm = GaussianMixture(n_components=M, covariance_type='diag', reg_covar=1e-6, random_state=0)
                            gmm.fit(features_np)
                            centroids = gmm.means_
                        except Exception as e:
                            print(f"[Rank {dist.get_rank() if dist.is_initialized() else 0}] GMM 聚类失败: {e}. 使用 KMeans 替代。")
                            kmeans = KMeans(n_clusters=M, init='k-means++', max_iter=300, n_init=10, random_state=0)
                            kmeans.fit(features_np)
                            centroids = kmeans.cluster_centers_
                        local_new_prototypes = torch.tensor(centroids, dtype=features.dtype, device=device)
                    else:
                        # Zero placeholder with the same dtype the other ranks reduce with.
                        local_new_prototypes = torch.zeros(
                            (M, global_prototype.shape[1]), dtype=features.dtype, device=device)

                    # 2) Average the local centroids over the ranks that produced some.
                    # NOTE(review): centroids are averaged row by row, but each rank's
                    # clustering orders its components independently; confirm this
                    # index-wise average is intended.
                    if distributed:
                        proto_sum = local_new_prototypes.detach().clone()
                        contributors = torch.tensor(
                            [1.0 if can_cluster else 0.0], dtype=torch.float32, device=device)
                        dist.all_reduce(proto_sum, op=dist.ReduceOp.SUM)
                        dist.all_reduce(contributors, op=dist.ReduceOp.SUM)
                        num_contributors = contributors.item()
                        if num_contributors == 0:
                            # No rank had enough data for this class; keep the old prototypes.
                            return global_prototype
                        global_new_prototypes = proto_sum / num_contributors
                    else:
                        global_new_prototypes = local_new_prototypes

                    global_new_prototypes = global_new_prototypes.to(device)

                    # All-zero old prototypes are treated as "not set yet": adopt the new ones.
                    if torch.allclose(global_prototype, torch.zeros_like(global_prototype)):
                        updated = global_new_prototypes
                    else:
                        # 3) Align with the old prototypes, then blend with a moving average.
                        alignment_method = config.get("alignment_method", "greedy")
                        aligned_new = match_prototypes(global_prototype, global_new_prototypes, alignment_method=alignment_method)

                        coff = config.get("coff", 0.9)
                        updated = coff * global_prototype + (1 - coff) * aligned_new
                    return updated

                def update_prototypes_old_gmm(features, global_prototype, config, device='cuda'):
                    """
                    Update prototypes online by assigning features to the existing prototypes.

                    Each feature is assigned to its nearest prototype (Euclidean distance).
                    For every prototype that received features, the increment
                    ``(1 - coff) * (mean_feature - prototype)`` is computed. Under
                    torch.distributed the increments and per-prototype "has data" flags are
                    summed over ranks, and each prototype moves by the average increment of
                    the ranks that had data for it.

                    A rank without features still joins the collectives with zeros.
                    Returning early on such a rank would leave the other ranks blocked in
                    ``all_reduce``.

                    Args:
                        features: Tensor of shape (N, D).
                        global_prototype: Tensor of shape (M, D).
                        config: dict-like; "coff" (default 0.9) is the weight kept on the
                            old prototype.
                        device: kept for interface compatibility. Work tensors are
                            allocated on ``global_prototype``'s device so that they can
                            always be combined with it.

                    Returns:
                        Tensor of shape (M, D) with the updated prototypes. When not
                        distributed and ``features`` is empty, the input
                        ``global_prototype`` is returned as is.
                    """
                    # Imported first so that `dist` means the module for the whole function.
                    import torch.distributed as dist

                    distributed = dist.is_initialized()
                    has_features = features.shape[0] > 0
                    if not has_features and not distributed:
                        return global_prototype

                    M = global_prototype.shape[0]
                    coff = config.get("coff", 0.9)
                    updated_prototype = global_prototype.clone()

                    # Per-prototype local increment, and a 0/1 flag marking whether this
                    # rank had any feature assigned to that prototype.
                    local_increments = torch.zeros_like(global_prototype)  # (M, D)
                    counts = torch.zeros(M, dtype=torch.float32, device=global_prototype.device)

                    if has_features:
                        # NOTE(review): features are not detached here, so the returned
                        # prototypes may carry autograd history; confirm this is intended.
                        distances = torch.cdist(features, global_prototype, p=2)  # (N, M)
                        assignments = torch.argmin(distances, dim=1)  # (N,)
                        for m in range(M):
                            idx = (assignments == m)
                            if idx.sum() > 0:
                                mean_feature = features[idx].mean(dim=0)
                                local_increments[m] = (1 - coff) * (mean_feature - global_prototype[m])
                                counts[m] = 1.0

                    # Sum increments and flags over all ranks.
                    if distributed:
                        dist.all_reduce(local_increments, op=dist.ReduceOp.SUM)
                        dist.all_reduce(counts, op=dist.ReduceOp.SUM)

                    # Average over the ranks that had data; untouched prototypes stay as is.
                    for m in range(M):
                        if counts[m] > 0:
                            updated_prototype[m] = global_prototype[m] + local_increments[m] / counts[m]

                    return updated_prototype

                ###########################
                # 3. Data gathering helper (works with and without torch.distributed)
                ###########################

                def gather_all_data(local_features, local_labels):
                    """
                    Gather features and labels from every process.

                    Without an initialized torch.distributed process group the inputs are
                    returned unchanged. Otherwise each rank contributes CPU copies of its
                    tensors through a single ``all_gather_object`` call, and the pieces are
                    concatenated in rank order. Shipping tensors (not nested Python lists)
                    preserves the (0, D) shape of empty contributions and avoids per-element
                    conversion.

                    Args:
                        local_features: Tensor of shape (N, D).
                        local_labels: Tensor of shape (N,).

                    Returns:
                        (all_features, all_labels): Tensors of shape (N_total, D) and
                        (N_total,), on the device and with the dtype of the local inputs.
                    """
                    if not dist.is_initialized():
                        return local_features, local_labels

                    world_size = dist.get_world_size()
                    # Detached CPU copies pickle cleanly and do not pin device memory on peers.
                    local_payload = (local_features.detach().cpu(), local_labels.detach().cpu())
                    gathered = [None for _ in range(world_size)]
                    dist.all_gather_object(gathered, local_payload)

                    all_features = torch.cat([feats for feats, _ in gathered], dim=0).to(
                        device=local_features.device, dtype=local_features.dtype)
                    all_labels = torch.cat([lbls for _, lbls in gathered], dim=0).to(
                        device=local_labels.device, dtype=local_labels.dtype)
                    return all_features, all_labels

                ###########################
                # 4. Global prototype update entry point (updates the prototypes of every class)
                ###########################

                def update_global_prototypes(features, labels, global_prototypes, config, device='cuda'):
                    """
                    Update the prototypes of every class in ``global_prototypes``.

                    For each class key, the rows of ``features`` whose label equals that key
                    are handed to the configured per-class update function together with the
                    class's current prototypes.

                    Args:
                        features: Tensor of shape (N, D).
                        labels: Tensor of shape (N,).
                        global_prototypes: dict of the form
                            { cls: {'prototypes': Tensor (M, D)} }.
                        config: dict-like read with ``.get``:
                            "update_method": "old_gmm" selects the online update; any other
                                value (default "recluster") selects re-clustering.
                            Remaining keys are consumed by the per-class update functions.
                        device: forwarded to the per-class update functions.

                    Returns:
                        New dict with the same keys and layout holding the updated prototypes.
                    """
                    refreshed = {}
                    for cls, entry in global_prototypes.items():
                        # Only this process's features of the current class.
                        cls_features = features[labels == cls]  # (N_local_cls, D)
                        previous = entry['prototypes']
                        num_prototypes = previous.shape[0]

                        if config.get("update_method", "recluster") == "old_gmm":
                            current = update_prototypes_old_gmm(cls_features, previous, config, device=device)
                        else:
                            current = update_prototypes_recluster(
                                cls_features, previous, num_prototypes, config, device=device)
                        refreshed[cls] = {'prototypes': current}

                    return refreshed

                ###########################
                # 5. Broadcast helper: synchronize the global prototype dict across processes
                ###########################

                def broadcast_global_prototypes(global_prototypes, src=0):
                    """
                    Broadcast the global prototype object from ``src`` to every process.

                    Args:
                        global_prototypes: Python object (e.g. a dict) to synchronize.
                        src: rank whose object is sent to the others.

                    Returns:
                        The object held by ``src`` when torch.distributed is initialized;
                        otherwise the input object itself.
                    """
                    if dist.is_initialized():
                        # broadcast_object_list fills the list in place on receiving ranks.
                        holder = [global_prototypes]
                        dist.broadcast_object_list(holder, src=src)
                        return holder[0]
                    return global_prototypes

                ###########################
                # 6. Example update entry point (DDP or non-DDP)
                ###########################

                def example_update(global_prototypes, source_features, source_labels, target_features, target_labels, config, device='cuda'):
                    """
                    Update the global prototypes from one batch of source and target ROI data.

                    Source and target features/labels are flattened over the batch dimension,
                    concatenated, and passed to ``update_global_prototypes``. The result is
                    also stored in the module-level ``GLOBAL_PROTOTYPES``.

                    Args:
                        global_prototypes: dict of current global prototypes.
                        source_features: Tensor of shape (B, N_s, D).
                        source_labels: Tensor of shape (B, N_s).
                        target_features: Tensor of shape (B, N_t, D).
                        target_labels: Tensor of shape (B, N_t).
                        config: dict-like configuration forwarded to the update.
                        device: device forwarded to the update.

                    Returns:
                        dict with the updated global prototypes.
                    """
                    global GLOBAL_PROTOTYPES

                    # Unpacking three dimensions also rejects feature tensors that are not 3-D.
                    _, _, feat_dim = source_features.shape
                    _, _, _ = target_features.shape

                    # 1) Flatten this process's data to (rows, D) / (rows,) and stack
                    #    source rows ahead of target rows.
                    flat_src_feats = source_features.view(-1, feat_dim)
                    flat_src_lbls = source_labels.view(-1)
                    flat_tgt_feats = target_features.view(-1, feat_dim)
                    flat_tgt_lbls = target_labels.view(-1)
                    merged_feats = torch.cat([flat_src_feats, flat_tgt_feats], dim=0)
                    merged_lbls = torch.cat([flat_src_lbls, flat_tgt_lbls], dim=0)

                    # 2) Per-class update on the local data.
                    refreshed = update_global_prototypes(merged_feats, merged_lbls, global_prototypes, config, device=device)

                    # 3) Publish the result through the module-level variable.
                    GLOBAL_PROTOTYPES = refreshed
                    return refreshed

                ###########################
                # 7. DDP offline initialization of the global prototypes
                ###########################

                # def read_source_target_batches(cfg, source_reader, target_loader, dataloader_iter):
                #     """
                #     根据配置读取源域和目标域数据，返回 (source_batch, target_batch, dataloader_iter)。
                #     这里假设返回的数据为字典格式，其中包含 'roi_head_features' 和 'roi_labels'
                #     """
                #     source_batch = None
                #     target_batch = None
                #     if cfg.SELF_TRAIN.SRC.USE_DATA:
                #         source_batch = source_reader.read_data()
                #     if cfg.SELF_TRAIN.TAR.USE_DATA:
                #         try:
                #             target_batch = next(dataloader_iter)
                #         except StopIteration:
                #             dataloader_iter = iter(target_loader)
                #             target_batch = next(dataloader_iter)
                #             print('new iters')
                #     return source_batch, target_batch, dataloader_iter

                # def offline_initialize_prototypes_dict_ddp(feature_bank, prototype_nums, device='cuda'):
                #     """
                #     离线初始化每个类别的原型，适用于 DDP 环境。
                #     各进程先构建 feature_bank（格式：{cls: Tensor(N, D)}），然后通过 all_gather_object 收集所有进程的 feature_bank，
                #     rank 0 合并后利用 KMeans 进行聚类初始化，得到全局原型字典，再广播给所有进程。
                #     """
                #     if not dist.is_initialized():
                #         # 非分布式情况下，直接在本进程内初始化
                #         merged_bank = feature_bank
                #     else:
                #         world_size = dist.get_world_size()
                #         local_bank = feature_bank  # 每个进程自己的字典
                #         bank_list = [None for _ in range(world_size)]
                #         dist.all_gather_object(bank_list, local_bank)
                #         merged_bank = {}
                #         for bank in bank_list:
                #             for cls, feats in bank.items():
                #                 if cls in merged_bank:
                #                     merged_bank[cls] = torch.cat([merged_bank[cls], feats.cpu()], dim=0)
                #                 else:
                #                     merged_bank[cls] = feats.cpu()
                #     # rank 0 进行离线聚类初始化
                #     if dist.is_initialized():
                #         rank = dist.get_rank()
                #     else:
                #         rank = 0
                #     if rank == 0:
                #         global_prototypes = {}
                #         for cls, feats in merged_bank.items():
                #             n_components = prototype_nums.get(cls, 1)
                #             feats_np = feats.numpy()
                #             kmeans = KMeans(n_clusters=n_components, init='k-means++', max_iter=300, n_init=10, random_state=42)
                #             kmeans.fit(feats_np)
                #             centroids = kmeans.cluster_centers_
                #             centroids_tensor = torch.tensor(centroids, dtype=feats.dtype, device=device)
                #             global_prototypes[cls] = {'prototypes': centroids_tensor}
                #     else:
                #         global_prototypes = None
                #     global_prototypes = broadcast_global_prototypes(global_prototypes, src=0)
                #     return global_prototypes

                # def ddp_offline_initialize(cfg, source_reader, target_loader, dataloader_iter, device='cuda'):
                #     """
                #     离线初始化全局原型：
                #     1. 各进程读取源域和目标域数据，提取 ROI 特征和标签；
                #     2. 按类别构建 feature_bank（字典格式 {cls: Tensor(N, D)}）；
                #     3. 调用 offline_initialize_prototypes_dict_ddp 进行全局初始化；
                #     4. 返回全局原型和更新后的 dataloader_iter。
                #     """
                #     source_batch, target_batch, dataloader_iter = read_source_target_batches(cfg, source_reader, target_loader, dataloader_iter)
                #     if source_batch is None or target_batch is None:
                #         raise ValueError("必须同时读取到源域和目标域数据进行初始化")
                #     src_feats = source_batch['roi_head_features']  # (B, N, D)
                #     src_lbls = source_batch['roi_labels']           # (B, N)
                #     tgt_feats = target_batch['roi_head_features']   # (B, N, D)
                #     tgt_lbls = target_batch['roi_labels']            # (B, N)
                #     B_s, N_s, D = src_feats.shape
                #     B_t, N_t, _ = tgt_feats.shape
                #     src_feats = src_feats.view(-1, D)
                #     src_lbls  = src_lbls.view(-1)
                #     tgt_feats = tgt_feats.view(-1, D)
                #     tgt_lbls  = tgt_lbls.view(-1)
                #     all_feats = torch.cat([src_feats, tgt_feats], dim=0)
                #     all_lbls  = torch.cat([src_lbls, tgt_lbls], dim=0)
                #     # 构建 feature_bank：按照类别划分
                #     feature_bank = {}
                #     for feat, lbl in zip(all_feats, all_lbls):
                #         cls = int(lbl.item())
                #         if cls in feature_bank:
                #             feature_bank[cls] = torch.cat([feature_bank[cls], feat.unsqueeze(0)], dim=0)
                #         else:
                #             feature_bank[cls] = feat.unsqueeze(0)
                #     prototype_nums = cfg.PROTOTYPE_NUMS  # 例如 {0: 3, 1: 2, ...}
                #     global_prototypes = offline_initialize_prototypes_dict_ddp(feature_bank, prototype_nums, device=device)
                #     return global_prototypes, dataloader_iter

                def random_initialize_prototypes_dict(prototype_nums, D, device='cuda'):
                    """
                    Build the initial global prototype dict, identical on every process.

                    Without torch.distributed each class gets an all-zero (num, D) tensor on
                    ``device``. With torch.distributed, rank 0 draws ``torch.randn`` tensors
                    on the CPU, the dict is broadcast to all ranks, and every rank moves the
                    tensors to ``device``.

                    NOTE(review): the two branches initialize differently (zeros vs. randn);
                    confirm this asymmetry is intended.

                    Args:
                        prototype_nums: dict { cls: num_prototypes }, e.g. {0: 3, 1: 2, ...}.
                        D: int, feature dimension.
                        device: target device, e.g. 'cuda' or 'cpu'.

                    Returns:
                        dict of the form { cls: {'prototypes': Tensor (num_prototypes, D)} }.
                    """
                    if dist.is_initialized():
                        # Only rank 0 generates values; the other ranks receive them.
                        if dist.get_rank() == 0:
                            shared = {
                                cls: {'prototypes': torch.randn(num, D, device='cpu', requires_grad=False)}
                                for cls, num in prototype_nums.items()
                            }
                        else:
                            shared = None

                        shared = broadcast_global_prototypes(shared, src=0)
                        for entry in shared.values():
                            entry['prototypes'] = entry['prototypes'].to(device=device)
                        return shared

                    # Single-process case: create the tensors directly on the target device.
                    return {
                        cls: {'prototypes': torch.zeros(num, D, device=device)}
                        for cls, num in prototype_nums.items()
                    }

                def initialize_global_prototypes(init_method, prototype_nums, D, device='cuda', **kwargs):
                    """
                    Initialize the module-level global prototypes exactly once.

                    If ``GLOBAL_PROTOTYPES_INITIALIZED`` is already set, the stored
                    ``GLOBAL_PROTOTYPES`` is returned without looking at ``init_method``.

                    Args:
                        init_method: 'random' is implemented; 'offline' is a placeholder.
                        prototype_nums: dict { cls: num_prototypes }.
                        D: int, feature dimension.
                        device: device for the prototype tensors.
                        **kwargs: accepted for a future offline initializer; unused here.

                    Returns:
                        dict, the global prototypes.

                    Raises:
                        NotImplementedError: if ``init_method`` is 'offline'.
                        ValueError: if ``init_method`` is neither 'random' nor 'offline'.
                    """
                    global GLOBAL_PROTOTYPES, GLOBAL_PROTOTYPES_INITIALIZED

                    # Already initialized: reuse the stored prototypes.
                    if GLOBAL_PROTOTYPES_INITIALIZED:
                        return GLOBAL_PROTOTYPES

                    if init_method == 'offline':
                        # Placeholder: an offline initializer would need offline data and is
                        # expected to yield the global prototype dict.
                        raise NotImplementedError("offline 初始化方法请根据实际需求实现")
                    if init_method != 'random':
                        raise ValueError("Invalid init_method. Use 'random' or 'offline'.")

                    GLOBAL_PROTOTYPES = random_initialize_prototypes_dict(prototype_nums, D, device=device)
                    GLOBAL_PROTOTYPES_INITIALIZED = True
                    return GLOBAL_PROTOTYPES
                
                def compute_feature_to_prototype_loss(features, labels, global_prototypes,
                                                      margin_inter=0.5,
                                                      lambda_intra=1.0, lambda_inter=1.0):
                    """
                    Compute the feature-to-prototype loss, made of two terms:

                    1. Intra-class alignment: for every class that has samples, the mean squared
                       Euclidean distance between each sample and its nearest prototype of that
                       class; the per-class means are averaged.
                    2. Inter-class separation: for every (class with samples, other class) pair,
                       the mean of ``relu(margin_inter - distance) ** 2`` over all
                       sample/prototype pairs; the per-pair means are averaged.

                    Samples whose label is not a key of ``global_prototypes`` are ignored.

                    Args:
                        features: Tensor of shape (N, D), sample features.
                        labels: Tensor of shape (N,), integer class label of each sample.
                        global_prototypes: dict ``{cls: {'prototypes': Tensor (M, D)}}``.
                        margin_inter: float, distances to foreign prototypes below this value
                            are penalised.
                        lambda_intra: float, weight of the intra-class term.
                        lambda_inter: float, weight of the inter-class term.

                    Returns:
                        Scalar Tensor on ``features.device``:
                        ``lambda_intra * loss_intra + lambda_inter * loss_inter``.
                        It is 0 when no sample matches any prototype class.
                    """
                    # Always follow the features' own device. The previous fallback to a bare
                    # 'cuda' device outside distributed mode broke CPU runs and non-default GPUs.
                    device = features.device
                    loss_intra = torch.tensor(0.0, device=device)
                    loss_inter = torch.tensor(0.0, device=device)
                    count_intra = 0
                    count_inter = 0

                    for cls, proto_dict in global_prototypes.items():
                        cls_mask = (labels == cls)
                        if not cls_mask.any():
                            continue
                        feats_c = features[cls_mask]  # (n_c, D)

                        # torch.cdist needs both operands on the same device with the same dtype.
                        prototypes = proto_dict['prototypes'].to(device=feats_c.device, dtype=feats_c.dtype)
                        dists = torch.cdist(feats_c, prototypes, p=2)  # (n_c, M)
                        min_dists, _ = torch.min(dists, dim=1)
                        # Pull every sample towards its nearest prototype of the same class.
                        loss_intra = loss_intra + torch.mean(min_dists ** 2)
                        count_intra += 1

                        # Push the samples of this class away from the prototypes of all other classes.
                        for other_cls, other_proto_dict in global_prototypes.items():
                            if other_cls == cls:
                                continue
                            prototypes_other = other_proto_dict['prototypes'].to(
                                device=feats_c.device, dtype=feats_c.dtype)  # (M_other, D)
                            dists_other = torch.cdist(feats_c, prototypes_other, p=2)  # (n_c, M_other)
                            # Only distances smaller than the margin contribute.
                            penalty = F.relu(margin_inter - dists_other)
                            loss_inter = loss_inter + torch.mean(penalty ** 2)
                            count_inter += 1

                    # Average over the number of accumulated terms; untouched sums stay at zero.
                    if count_intra > 0:
                        loss_intra = loss_intra / count_intra
                    if count_inter > 0:
                        loss_inter = loss_inter / count_inter

                    return lambda_intra * loss_intra + lambda_inter * loss_inter
                # Extract the ROI-head features and ROI labels of the source and target batches
                source_features = batch_dict_source['roi_head_features']  # (B, N_s, D)
                source_labels = batch_dict_source['roi_labels']           # (B, N_s)
                target_features = batch_dict_target['roi_head_features']  # (B, N_t, D)
                target_labels = batch_dict_target['roi_labels']           # (B, N_t)
                prototype_nums = cfg.SELF_TRAIN.ProtoAlign.get('PROTOTYPE_NUMS', [3, 1, 1])
                prototype_nums = {i + 1: prototype_nums[i] for i in range(len(prototype_nums))}
                D = source_features.shape[-1]
                device = source_features.device
                config = cfg.SELF_TRAIN.ProtoAlign.get('CONFIG', {
                                        "update_method": "recluster",        # 可选："recluster"（重新聚类更新）或 "old_gmm"（利用旧 GMM 推理在线更新）
                                        "alignment_method": "hungarian",     # 对于 "recluster" 方法可选："greedy" 或 "hungarian"
                                        "coff": 0.9,                         # 移动平均更新系数
                                        'init_method': 'random'              # 可选 'random' 或 'offline'
                                    })
                global_prototypes = initialize_global_prototypes(config['init_method'], prototype_nums, D, device=device)
                global_prototypes = example_update(global_prototypes, source_features, source_labels, target_features, target_labels, config, device=device)
                
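                # 2. Compute the feature-to-prototype alignment loss.
                # Merge the source- and target-domain ROI features and labels of the current batch; they are the loss inputs.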
                B_s, N_s, D = source_features.shape
                B_t, N_t, _ = target_features.shape
                # Flatten to 2-D tensors of shape (B*N, D)
                src_feats = source_features.view(-1, D)
                src_lbls  = source_labels.view(-1)
                tgt_feats = target_features.view(-1, D)
                tgt_lbls  = target_labels.view(-1)
                all_feats = torch.cat([src_feats, tgt_feats], dim=0)
                all_lbls = torch.cat([src_lbls, tgt_lbls], dim=0)
                
                # Hyper-parameters of the feature-to-prototype loss:
                margin_inter = 0.5    # 非本类原型距离阈值
                lambda_intra = 1.0    # Intra-class 损失权重（这里用每个样本与其所属类别最近原型距离的平方）
                lambda_inter = 1.0    # Inter-class 损失权重
                
                proto_loss = compute_feature_to_prototype_loss(all_feats, all_lbls, global_prototypes,
                                                                margin_inter=margin_inter,
                                                                lambda_intra=lambda_intra,
                                                                lambda_inter=lambda_inter)
                
                # 3. Combine all losses (assuming the detection loss consists of loss_source and st_loss)
                # total_loss = loss_source + st_loss + proto_loss
                lambda_proto = cfg.SELF_TRAIN.ProtoAlign.get('LAMBDA', 1.0)
                st_loss = st_loss + lambda_proto * proto_loss

            # '''Train on source data, compute the loss, update the gradients'''
            # if (cfg.SELF_TRAIN.SRC.USE_DATA and cfg.SELF_TRAIN.SRC.get('USE_GRAD', None)):
            #     st_loss = st_loss + loss_src

            # if cfg.SELF_TRAIN.get('MID', None):
            #     st_loss = st_loss + mid_loss

            torch.autograd.set_detect_anomaly(True)

            st_loss.backward()

            if 'pos_ps_bbox' in target_batch.keys():
                # count number of used ps bboxes in this batch
                pos_pseudo_bbox = target_batch['pos_ps_bbox'].mean(dim=0).cpu().numpy()
                ign_pseudo_bbox = target_batch['ign_ps_bbox'].mean(dim=0).cpu().numpy()
            else:
                pos_pseudo_bbox = torch.zeros(3)
                ign_pseudo_bbox = torch.zeros(3)
            ps_bbox_nmeter.update(pos_pseudo_bbox.tolist())
            ign_ps_bbox_nmeter.update(ign_pseudo_bbox.tolist())
            pos_ps_result = ps_bbox_nmeter.aggregate_result()
            ign_ps_result = ign_ps_bbox_nmeter.aggregate_result()

            st_tb_dict = common_utils.add_prefix_to_dict(st_tb_dict, 'st_')
            disp_dict.update(common_utils.add_prefix_to_dict(st_disp_dict, 'st_'))
            disp_dict.update({'st_loss': "{:.3f}({:.3f})".format(st_loss_meter.val, st_loss_meter.avg),
                              'pos_ps_box': pos_ps_result,
                              'ign_ps_box': ign_ps_result})

            if cfg.SELF_TRAIN.get('MID', None):
                mid_tb_dict = common_utils.add_prefix_to_dict(mid_tb_dict, 'mid_')
                disp_dict.update(common_utils.add_prefix_to_dict(mid_disp_dict, 'mid_'))
                disp_dict.update({'mid_loss': "{:.3f}({:.3f})".format(mid_loss_meter.val, mid_loss_meter.avg)})

            if cfg.SELF_TRAIN.get('TARGET_LEARNING', None):
                cls_tb_dict = common_utils.add_prefix_to_dict(cls_tb_dict, 'cls_')
                disp_dict.update(common_utils.add_prefix_to_dict(cls_disp_dict, 'cls_'))
                disp_dict.update({'cls_loss': "{:.3f}({:.3f})".format(cls_loss_meter.val, cls_loss_meter.avg)})

                reg_tb_dict = common_utils.add_prefix_to_dict(reg_tb_dict, 'reg_')
                disp_dict.update(common_utils.add_prefix_to_dict(reg_disp_dict, 'reg_'))
                disp_dict.update({'reg_loss': "{:.3f}({:.3f})".format(reg_loss_meter.val, reg_loss_meter.avg)})


            if cfg.SELF_TRAIN.get('BDA', None):
                bda_tb_dict = common_utils.add_prefix_to_dict(st_tb_dict, 'bda_')
                disp_dict.update(common_utils.add_prefix_to_dict(st_disp_dict, 'bda_'))
                disp_dict.update({'bda_loss': "{:.3f}({:.3f})".format(bda_loss_meter.val, bda_loss_meter.avg)})
            
        clip_grad_norm_(model.parameters(), optim_cfg.GRAD_NORM_CLIP)
        # torch.cuda.synchronize()  # wait for all GPUs to finish the backward pass
        optimizer.step()
        accumulated_iter += 1


        # Check whether model aggregation has to run at this iteration
        if aggregation_interval and (accumulated_iter % aggregation_interval == 0):
            if last_model_state is not None:
                ema_alpha = cfg.SELF_TRAIN.WEAK_MODEL_AGGREGATION.EMA_ALPHA
                current_model_state = model.state_dict()
                for key in last_model_state.keys():
                    last_model_state[key] = (
                        ema_alpha * last_model_state[key]
                        + (1 - ema_alpha) * current_model_state[key]
                    )
                model.load_state_dict(last_model_state)

        if ema_model is not None:
            update_ema_variables(model, ema_model, model_cfg=ema_model.model_cfg, cur_iter=accumulated_iter)
            # ema_rampup_start, ema_start = cfg.SELF_TRAIN.EMA_LEARNING.EMA_EPOCH
            # assert ema_rampup_start <= ema_start
            # epoch_id = cur_epoch
            # if epoch_id < ema_rampup_start:
            #     pass
            # elif (epoch_id >= ema_rampup_start) and (epoch_id < ema_start):
            #     # if accumulated_iter % ssl_cfg.TEACHER.NUM_ITERS_PER_UPDATE == 0:
            #         # if dist:
            #             #if rank == 0:
            #     update_ema_variables(student_model.module.onepass, teacher_model.module.onepass, ssl_cfg.TEACHER.RAMPUP_EMA_MOMENTUM, accumulated_iter)
            #         # else:
            #             # update_ema_variables(student_model, teacher_model, ssl_cfg.TEACHER.RAMPUP_EMA_MOMENTUM, accumulated_iter)
            # elif epoch_id >= ema_start:
            #     if accumulated_iter % ssl_cfg.TEACHER.NUM_ITERS_PER_UPDATE == 0:
            #         if dist:
            #             #if rank == 0:
            #             update_ema_variables_with_fixed_momentum(student_model.module.onepass, teacher_model.module.onepass, ssl_cfg.TEACHER.EMA_MOMENTUM)
            #         else:
            #             update_ema_variables_with_fixed_momentum(student_model, teacher_model, ssl_cfg.TEACHER.EMA_MOMENTUM)
            # else:
            #     raise Exception('Impossible condition for EMA update')
        # log to console and tensorboard
        if rank == 0:
            pbar.update()
            pbar.set_postfix(dict(total_it=accumulated_iter, pos_ps_box=pos_ps_result,
                                  ign_ps_box=ign_ps_result))
            tbar.set_postfix(disp_dict)
            tbar.refresh()

            if tb_log is not None:
                tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter)
                wandb.log({'meta_data/learning_rate': cur_lr},
                          step=int(accumulated_iter))
                if cfg.SELF_TRAIN.SRC.USE_DATA and cfg.SELF_TRAIN.SRC.get('USE_GRAD', None):
                    tb_log.add_scalar('train/loss', loss, accumulated_iter)
                    for key, val in tb_dict.items():
                        tb_log.add_scalar('train/' + key, val, accumulated_iter)
                        wandb.log({key: val}, step=int(accumulated_iter))

                if cfg.SELF_TRAIN.TAR.USE_DATA:
                    tb_log.add_scalar('train/st_loss', st_loss, accumulated_iter)
                    for key, val in st_tb_dict.items():
                        tb_log.add_scalar('train/' + key, val, accumulated_iter)
                        wandb.log({key: val}, step=int(accumulated_iter))


    if rank == 0:
        pbar.close()
        for i, class_names in enumerate(target_loader.dataset.class_names):
            tb_log.add_scalar(
                'ps_box/pos_%s' % class_names, ps_bbox_nmeter.meters[i].avg, cur_epoch)
            tb_log.add_scalar(
                'ps_box/ign_%s' % class_names, ign_ps_bbox_nmeter.meters[i].avg, cur_epoch)
            wandb.log({'ps_box/pos_%s' % class_names:
                           ps_bbox_nmeter.meters[i].avg})
            wandb.log({'ps_box/ign_%s' % class_names:
                           ign_ps_bbox_nmeter.meters[i].avg})    

    return accumulated_iter

def update_ema_variables(model, ema_model, model_cfg=None, cur_iter=0):
    """Update ``ema_model`` (teacher) in place from ``model`` (student).

    Parameters are blended as ``ema = alpha * ema + (1 - alpha) * param``.
    ``alpha`` comes from ``model_cfg['EMA_MODEL_ALPHA']``; when
    ``cfg.SELF_TRAIN.HSSDA_EMA`` is set it is replaced by a keep rate that ramps
    linearly from 0.8 to 0.999 over the first 1000 iterations.

    Buffers are handled according to the config:
      * ``cfg.SELF_TRAIN.EMA_BN``: floating-point buffers are blended with the same
        ``alpha``; non-floating buffers (e.g. integer counters) are copied, so
        their dtype is preserved instead of being promoted to float.
      * ``cfg.SELF_TRAIN.EMA_COPY`` (only if ``EMA_BN`` is unset): buffers are copied.
      * neither: buffers are left untouched.
    Values are always written into the existing EMA buffers, so the teacher never
    shares storage with the student.

    Args:
        model: student model; if it has a ``module`` attribute, its buffers are read
            from ``model.module`` so that names match the unwrapped ``ema_model``.
        ema_model: teacher model, modified in place.
        model_cfg: mapping providing ``'EMA_MODEL_ALPHA'``; required.
        cur_iter: current iteration, only used by the ``HSSDA_EMA`` ramp.

    Raises:
        AssertionError: ``model_cfg`` is None, or buffer names of the two models
            do not line up.
    """
    assert model_cfg is not None

    multiplier = 1.0

    alpha = model_cfg['EMA_MODEL_ALPHA']
    alpha = 1 - multiplier * (1 - alpha)
    if cfg.SELF_TRAIN.get('HSSDA_EMA', None):
        ema_keep_rate = 0.999
        change_global_step = 1000
        if cur_iter < change_global_step:
            # Linear ramp of the keep rate from 0.8 up to ema_keep_rate.
            alpha = (ema_keep_rate - 0.8) / change_global_step * cur_iter + 0.8
        else:
            alpha = ema_keep_rate

    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

    blend_buffers = cfg.SELF_TRAIN.get('EMA_BN', None)
    copy_buffers = cfg.SELF_TRAIN.get('EMA_COPY', None)

    model_named_buffers = model.module.named_buffers() if hasattr(model, 'module') else model.named_buffers()
    for (ema_name, ema_value), (name, value) in zip(ema_model.named_buffers(), model_named_buffers):
        assert ema_name == name, 'name not equal:{} , {}'.format(ema_name, name)
        if blend_buffers:
            if ema_value.is_floating_point():
                ema_value.data.copy_(ema_value.data * alpha + value.data * (1 - alpha))
            else:
                # Blending an integer buffer with a float would change its dtype.
                ema_value.data.copy_(value.data)
        elif copy_buffers:
            # copy_ instead of rebinding .data: rebinding would alias the student's storage.
            ema_value.data.copy_(value.data)


def train_model_st(model, optimizer, source_loader, target_loader, model_func, lr_scheduler, optim_cfg,
                   start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, ps_label_dir,
                   source_sampler=None, target_sampler=None, lr_warmup_scheduler=None, ckpt_save_interval=1,
                   max_ckpt_save_num=50, merge_all_iters_to_one_epoch=False, logger=None, ema_model=None, 
                   source_loader_detect=None, source_sampler_detect=None, source_model=None, dist=None, pretrained=None):
    """Drive the self-training loop from ``start_epoch`` up to ``total_epochs``.

    Per epoch, in this order (each optional step is gated by ``cfg.SELF_TRAIN``):
      1. regenerate pseudo labels on the target set and run the follow-up hooks
         (COMPLE_AUG, LOAD_OPTIMIZER_AFTER_PSEUDO_LABELING, OBJECT.rdu, PLS, ADABN_EPOCH);
      2. re-prepare the target data augmentor (PROG_AUG);
      3. train one epoch with ``train_one_epoch_st``;
      4. merge the pseudo labels collected in ``TARGET_PSEUDO_LABELS_TMP`` into
         ``TARGET_PSEUDO_LABELS`` (TARGET_LEARNING);
      5. on rank 0, save a checkpoint every ``ckpt_save_interval`` epochs.

    Args:
        model, optimizer, model_func, optim_cfg, tb_log: forwarded to ``train_one_epoch_st``.
        source_loader, source_sampler: wrapped into a ``common_utils.DataReader``.
        target_loader: target-domain loader; its dataset is switched between
            ``eval()`` and ``train()`` around pseudo-label generation.
        lr_scheduler, lr_warmup_scheduler: the warm-up scheduler is used while
            ``cur_epoch < optim_cfg.WARMUP_EPOCH`` (if given), the main one otherwise.
        start_epoch, total_epochs, start_iter: resume position and end of training.
        rank: process rank; only rank 0 writes checkpoints.
        ckpt_save_dir: checkpoint directory; must support ``/`` and ``str()``
            (presumably a ``pathlib.Path`` -- confirm with the caller).
        ps_label_dir: directory handed to the pseudo-label utilities.
        target_sampler: if given, ``set_epoch`` is called each epoch.
        ckpt_save_interval, max_ckpt_save_num: checkpoint cadence / retention.
        merge_all_iters_to_one_epoch: see the inline comment where it is handled.
        logger: only used to report an already existing pseudo-label file; it must
            not be None in that case.
        ema_model: optional teacher; preferred for pseudo-label generation.
        dist: object whose ``is_initialized()`` is queried when
            ``cfg.SELF_TRAIN.TARGET_LEARNING`` is set (presumably ``torch.distributed``
            -- confirm). NOTE(review): the default ``None`` would raise there.
        source_loader_detect, source_sampler_detect, source_model, pretrained:
            not referenced in the visible part of this function.
    """
    accumulated_iter = start_iter
    # Reader around the source loader/sampler; it is handed to train_one_epoch_st below.
    source_reader = common_utils.DataReader(source_loader, source_sampler)
    source_reader.construct_iter()

    # When resuming: check whether pseudo labels were already generated.
    # The helper's result is only logged here (None means nothing was found).
    ps_pkl = self_training_utils.check_already_exsit_pseudo_label(ps_label_dir, start_epoch)
    if ps_pkl is not None:
        logger.info('==> Loading pseudo labels from {}'.format(ps_pkl))

    # When resuming with progressive augmentation: replay the augmentor updates
    # scheduled for the epochs that were already trained.
    if cfg.SELF_TRAIN.get('PROG_AUG', None) and cfg.SELF_TRAIN.PROG_AUG.ENABLED and \
        start_epoch > 0:
        for cur_epoch in range(start_epoch):
            if cur_epoch in cfg.SELF_TRAIN.PROG_AUG.UPDATE_AUG:
                target_loader.dataset.data_augmentor.re_prepare(
                    augmentor_configs=None, intensity=cfg.SELF_TRAIN.PROG_AUG.SCALE)


    with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True,
                     leave=(rank == 0)) as tbar:
        total_it_each_epoch = len(target_loader)
        if merge_all_iters_to_one_epoch:
            # The dataset is told to merge all epochs; the iterations per epoch are then
            # the loader length divided by the number of epochs.
            assert hasattr(target_loader.dataset, 'merge_all_iters_to_one_epoch')
            target_loader.dataset.merge_all_iters_to_one_epoch(merge=True, epochs=total_epochs)
            total_it_each_epoch = len(target_loader) // max(total_epochs, 1)

        dataloader_iter = iter(target_loader)

        for cur_epoch in tbar:
            # NOTE(review): the source reader's epoch is only advanced when a target
            # sampler is given -- confirm this coupling is intended.
            if target_sampler is not None:
                target_sampler.set_epoch(cur_epoch)
                source_reader.set_cur_epoch(cur_epoch)

            # Pick the scheduler for this epoch (warm-up first, if configured).
            if lr_warmup_scheduler is not None and cur_epoch < optim_cfg.WARMUP_EPOCH:
                cur_scheduler = lr_warmup_scheduler
            else:
                cur_scheduler = lr_scheduler

            # Regenerate pseudo labels at the explicitly listed epochs, or every
            # UPDATE_PSEUDO_LABEL_INTERVAL epochs (never at epoch 0 through the interval rule).
            if (cur_epoch in cfg.SELF_TRAIN.UPDATE_PSEUDO_LABEL) or \
                    ((cur_epoch % cfg.SELF_TRAIN.UPDATE_PSEUDO_LABEL_INTERVAL == 0)
                     and cur_epoch != 0):
                target_loader.dataset.eval()
                # The EMA (teacher) model generates the labels when it is available.
                self_training_utils.save_pseudo_label_epoch(
                    ema_model if ema_model else model, target_loader, rank,
                    leave_pbar=True, ps_label_dir=ps_label_dir, cur_epoch=cur_epoch)
                
                if cfg.SELF_TRAIN.get('COMPLE_AUG', None):
                    import pickle as pkl
                    # Load the most recently modified pseudo-object box / point dumps and
                    # publish them through the module-level globals.
                    ps_box_list = glob.glob(os.path.join(cfg.SELF_TRAIN.PS_SAMPLING.PS_OBJECT_PATH, 'ps_box_e*.pkl'))
                    ps_pnt_list = glob.glob(os.path.join(cfg.SELF_TRAIN.PS_SAMPLING.PS_OBJECT_PATH,'ps_point_e*.pkl'))

                    # NOTE(review): indexing [0] raises IndexError if no matching file exists.
                    ps_box_list.sort(key=os.path.getmtime, reverse=True)
                    with open(ps_box_list[0], 'rb') as f:
                        ps_boxes = pkl.load(f)
                    ps_pnt_list.sort(key=os.path.getmtime, reverse=True)
                    with open(ps_pnt_list[0], 'rb') as f:
                        ps_points = pkl.load(f)

                    global GLOBAL_PS_BOXES, GLOBAL_PS_POINTS
                    GLOBAL_PS_BOXES = ps_boxes
                    GLOBAL_PS_POINTS = ps_points
                
                if cfg.SELF_TRAIN.get('LOAD_OPTIMIZER_AFTER_PSEUDO_LABELING', None):
                    # Restart the learning-rate schedule: reset every group's lr and rebuild
                    # both schedulers (the rebound locals stay in effect for later epochs).
                    for g in optimizer.param_groups:
                        g['lr'] = cfg.OPTIMIZATION.LR
                    lr_scheduler, lr_warmup_scheduler = build_scheduler(
                        optimizer,
                        total_iters_each_epoch=total_it_each_epoch,
                        total_epochs=cfg.OPTIMIZATION.NUM_EPOCHS,
                        last_epoch=0, optim_cfg=cfg.OPTIMIZATION
                    )
                    if lr_warmup_scheduler is not None and \
                            cur_epoch < optim_cfg.WARMUP_EPOCH:
                        cur_scheduler = lr_warmup_scheduler
                    else:
                        cur_scheduler = lr_scheduler
                    # NOTE(review): assigned but never read in the visible code;
                    # accumulated_iter itself is not reset.
                    new_accumulated_iter = 1

                if cfg.SELF_TRAIN.get('OBJECT', None) and cfg.SELF_TRAIN.OBJECT.get('rdu', None):
                    # Lower the configured OBJECT.max by one after each pseudo-label update.
                    cfg.SELF_TRAIN.OBJECT.max = cfg.SELF_TRAIN.OBJECT.max - 1

            

                if cfg.SELF_TRAIN.get('PLS', None):
                    # Print the SIZE_RES entry of the second augmentor config for inspection.
                    print('**********************Update Scale**********************')   
                    print('mean (std)|length     |width      |height')
                    print(cfg.DATA_CONFIG.DATA_AUGMENTOR.AUG_CONFIG_LIST[1]['SIZE_RES'])

                target_loader.dataset.train()

                if cfg.SELF_TRAIN.get('ADABN_EPOCH', None):
                    # One gradient-free pass over target batches with the model in train mode
                    # (presumably to refresh normalisation statistics, AdaBN-style -- confirm).
                    model.train()
                    if ema_model is not None:
                        ema_model.train()
                    if cfg.SELF_TRAIN.get('DSNORM', None):
                        model.apply(set_ds_target)
                    # NOTE(review): applied to the EMA model even when DSNORM is off -- confirm.
                    if ema_model is not None:
                        ema_model.apply(set_ds_target)
                    for cur_it in range(total_it_each_epoch):
                        try:
                            target_batch = next(dataloader_iter)
                        except StopIteration:
                            # Loader exhausted: restart it and keep going.
                            dataloader_iter = iter(target_loader)
                            target_batch = next(dataloader_iter)
                            print('new iters')
                        with torch.no_grad():
                            load_data_to_gpu(target_batch)
                            _ = model(target_batch)[0]
                    if ema_model is not None:
                        update_ema_variables(model, ema_model, model_cfg=ema_model.model_cfg)
            
            # Curriculum data augmentation: re-prepare the augmentor at the configured epochs.
            if cfg.SELF_TRAIN.get('PROG_AUG', None) and cfg.SELF_TRAIN.PROG_AUG.ENABLED and \
                (cur_epoch in cfg.SELF_TRAIN.PROG_AUG.UPDATE_AUG):
                target_loader.dataset.data_augmentor.re_prepare(
                    augmentor_configs=None, intensity=cfg.SELF_TRAIN.PROG_AUG.SCALE)

            # Train one epoch; the returned iteration counter carries over to the next epoch.
            accumulated_iter = train_one_epoch_st(
                model, optimizer, source_reader, target_loader, model_func,
                lr_scheduler=cur_scheduler,
                accumulated_iter=accumulated_iter, optim_cfg=optim_cfg,
                rank=rank, tbar=tbar, tb_log=tb_log,
                leave_pbar=(cur_epoch + 1 == total_epochs),
                total_it_each_epoch=total_it_each_epoch,
                dataloader_iter=dataloader_iter, ema_model=ema_model, cur_epoch=cur_epoch
            )

            commu_utils.synchronize()

            if cfg.SELF_TRAIN.get('TARGET_LEARNING', None):
                if dist.is_initialized():
                    # Gather every process's temporary pseudo labels and merge them locally.
                    part_pseudo_labels_list = commu_utils.all_gather(TARGET_PSEUDO_LABELS_TMP)

                    new_pseudo_label_dict = {}
                    for pseudo_labels in part_pseudo_labels_list:
                        new_pseudo_label_dict.update(pseudo_labels)

                    TARGET_PSEUDO_LABELS_TMP.update(new_pseudo_label_dict)
                
                # Replace the shared pseudo-label dict with this epoch's labels and
                # empty the temporary one for the next epoch.
                commu_utils.synchronize()
                TARGET_PSEUDO_LABELS.clear()
                TARGET_PSEUDO_LABELS.update(TARGET_PSEUDO_LABELS_TMP)
                TARGET_PSEUDO_LABELS_TMP.clear()
                commu_utils.synchronize()

            # save trained model
            trained_epoch = cur_epoch + 1
            if trained_epoch % ckpt_save_interval == 0 and rank == 0:

                ckpt_list = glob.glob(str(ckpt_save_dir / 'checkpoint_epoch_*.pth'))
                ckpt_list.sort(key=os.path.getmtime)

                # Delete the oldest files (by mtime) so that, once the new checkpoint is
                # written, at most max_ckpt_save_num remain.
                if ckpt_list.__len__() >= max_ckpt_save_num:
                    for cur_file_idx in range(0, len(ckpt_list) - max_ckpt_save_num + 1):
                        os.remove(ckpt_list[cur_file_idx])

                ckpt_name = ckpt_save_dir / ('checkpoint_epoch_%d' % trained_epoch)
                state = checkpoint_state(model, optimizer, trained_epoch, accumulated_iter)

                save_checkpoint(state, filename=ckpt_name)
