import torch
import torch.nn.functional as F
import os
import glob
import tqdm
from torch.nn.utils import clip_grad_norm_
import copy
import numpy as np
import wandb
import random
from multiprocessing import Manager

from pcdet.utils import common_utils, commu_utils
from pcdet.utils import self_training_utils, mix_domain_utils
from pcdet.config import cfg
from pcdet.models.model_utils.dsnorm import set_ds_source, set_ds_target, set_ds_mid
from pcdet.models import load_data_to_gpu
from .train_utils import save_checkpoint, checkpoint_state
from pcdet.datasets.augmentor.augmentor_utils import get_points_in_box
from pcdet.utils.box_utils import remove_points_in_boxes3d, enlarge_box3d, \
    boxes3d_kitti_lidar_to_fakelidar, boxes_to_corners_3d
from pcdet.ops.iou3d_nms import iou3d_nms_utils
from .optimization import build_optimizer, build_scheduler

# One manager server backs every cross-process dict in this module. Previously each
# `Manager().dict()` call started its own server process; a single server hosts both.
# NOTE(review): this executes at import time, so any process that freshly imports this
# module (e.g. under the "spawn" start method) starts its own manager — confirm intended.
_SHARED_MANAGER = Manager()

# Cross-process dict (manager proxy) and a plain process-local dict.
# Presumably target-domain pseudo labels keyed per frame — verify against writers elsewhere.
TARGET_PSEUDO_LABELS = _SHARED_MANAGER.dict()
TARGET_PSEUDO_LABELS_TMP = {}

# Prototype state; not touched in the visible part of this file.
GLOBAL_PROTOTYPES = None
GLOBAL_PROTOTYPES_INITIALIZED = False

# Per-class pseudo-label object boxes / point sets used by COMPLE_AUG; indexed by class
# name in train_one_epoch_st and populated elsewhere.
GLOBAL_PS_BOXES = None
GLOBAL_PS_POINTS = None

# Detached teacher (EMA model) outputs from the latest MT_CONSISTENCY / DTS step.
GLOBAL_BATCH_DICT_TEACHER = None

# Cross-process feature bank (manager proxy) plus process-local source/target scratch dicts.
FEATURES_BANK = _SHARED_MANAGER.dict()
FEATURES_BANK_TMP_SRC = {}
FEATURES_BANK_TMP_TAR = {}

# Per-class source-domain boxes, their in-box points, and the normalized sampling
# probabilities over those boxes (rebuilt each step when SRC_COMPLE_AUG is enabled).
GLOBAL_SRC_BOXES = None
GLOBAL_SRC_POINTS = None
GLOBAL_SRC_BOX_SAMPLE_PROB = None

def train_one_epoch_st(model, optimizer, source_reader, target_loader, model_func, lr_scheduler,
                       accumulated_iter, optim_cfg, rank, tbar, total_it_each_epoch,
                       dataloader_iter, tb_log=None, leave_pbar=False, ema_model=None, cur_epoch=None):
    if total_it_each_epoch == len(target_loader):
        dataloader_iter = iter(target_loader)

    if rank == 0:
        pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)

    ps_bbox_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))
    ign_ps_bbox_nmeter = common_utils.NAverageMeter(len(cfg.CLASS_NAMES))
    loss_meter = common_utils.AverageMeter()
    st_loss_meter = common_utils.AverageMeter()

    cls_loss_meter = common_utils.AverageMeter()
    reg_loss_meter = common_utils.AverageMeter()

    bda_loss_meter = common_utils.AverageMeter()

    mid_loss_meter = common_utils.AverageMeter()

    if cfg.SELF_TRAIN.get('ProtoAlign', None):
        proto_loss_meter = common_utils.AverageMeter()

    disp_dict = {}                                                                                                                                        

    aggregation_interval = cfg.SELF_TRAIN.get('AGGREGATION_INTERVAL', None)  # 配置聚合间隔
    last_model_state = model.state_dict() if aggregation_interval else None  # 用于存储上一次的模型状态

    for cur_it in range(total_it_each_epoch):
        lr_scheduler.step(accumulated_iter)

        try:
            cur_lr = float(optimizer.lr)
        except:
            cur_lr = optimizer.param_groups[0]['lr']

        if tb_log is not None:
            tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter)

        model.train()
        if ema_model is not None:
            ema_model.train()

        optimizer.zero_grad()

        if cfg.SELF_TRAIN.SRC.USE_DATA:
            source_batch = source_reader.read_data()
            source_batch['domain'] = 'source'
        
        if cfg.SELF_TRAIN.TAR.USE_DATA:
            try:
                target_batch = next(dataloader_iter)
            except StopIteration:
                dataloader_iter = iter(target_loader)
                target_batch = next(dataloader_iter)
                print('new iters')
        
        if cfg.SELF_TRAIN.get('MT', None):
            batch_target1, batch_target2 = target_batch # s,t

            # add tag for target domain
            # batch_target1['batch_type'] = 'target'
            # batch_target2['batch_type'] = 'target'
            
            ema_model.train()
            if cfg.SELF_TRAIN.get('DSNORM', None):
                ema_model.apply(set_ds_target)
                model.apply(set_ds_target)
            
            if cfg.SELF_TRAIN.get('MT_CONSISTENCY', None) or cfg.SELF_TRAIN.get('DTS', None):
                # forward teacher model first
                # from pcdet.models import load_data_to_gpu
                load_data_to_gpu(batch_target2)
                batch_target2 = ema_model(batch_target2, return_batch_dict=True)
                batch_dict_teacher = {}

                for key in ['rois_mt', 'roi_head_features_mt', 'roi_scores_mt', 'roi_cls_preds_mt']:
                    batch_dict_teacher[key] = batch_target2[key].detach().clone()
                
                batch_dict_teacher['iou_scores_mt'] = batch_dict_teacher['roi_scores_mt']
                batch_dict_teacher['cls_scores_mt'] = batch_dict_teacher['roi_scores_mt']

                global GLOBAL_BATCH_DICT_TEACHER
                GLOBAL_BATCH_DICT_TEACHER = batch_dict_teacher

            target_batch = batch_target1

            '''batch ADABN stage'''
            if cfg.SELF_TRAIN.get('ADABN', None):
                target_batch_raw = copy.deepcopy(target_batch)
                load_data_to_gpu(target_batch_raw)
                if ema_model is not None:
                    ema_model.train()
                    if cfg.SELF_TRAIN.get('DSNORM', None):
                        ema_model.apply(set_ds_target)
                    with torch.no_grad():
                        batch_dict_teacher = ema_model(target_batch_raw, return_batch_dict=True)
                else:
                    model.train()
                    if cfg.SELF_TRAIN.get('DSNORM', None):
                        model.apply(set_ds_target)
                    with torch.no_grad():
                        batch_dict_teacher = model(target_batch_raw, return_batch_dict=True)
                        
        target_batch['domain'] = 'target'


        # if cfg.SELF_TRAIN.get('COMPLE_AUG_SRC', None):
        # if cfg.SELF_TRAIN.get('SRC_COMPLE_AUG', None) and np.random.rand() < cfg.SELF_TRAIN.SRC_COMPLE_AUG.get('SRC_PROB', 0.5):
        #     # global GLOBAL_SRC_BOXES, GLOBAL_SRC_POINTS
        #     global GLOBAL_SRC_BOXES, GLOBAL_SRC_POINTS, GLOBAL_SRC_BOX_SAMPLE_PROB
        #     src_boxes = GLOBAL_SRC_BOXES
        #     src_points = GLOBAL_SRC_POINTS
        #     old_gt_boxes = src_boxes[sampled_classes[k]][np.random.randint(len(src_boxes[sampled_classes[k]]))]
        #     old_gt_boxes = np.array(old_gt_boxes)
        #     cur_obj_points = src_points[sampled_classes[k]][np.random.randint(len(src_points[sampled_classes[k]]))]
        #     src_boxes_sample_prob = GLOBAL_SRC_BOX_SAMPLE_PROB
        #     cur_class = sampled_classes[k]
        #     # 根据 box 采样概率采样 box
        #     box_idx = np.random.choice(len(src_boxes[cur_class]), p=src_boxes_sample_prob[cur_class])
        #     old_gt_boxes = src_boxes[cur_class][box_idx]
        #     cur_obj_points = src_points[cur_class][box_idx]
        #     min_num = np.min(src_boxes_sample_prob[cur_class]) * 2

        #     if src_boxes_sample_prob[cur_class][box_idx] > min_num:
        #         # 被采样到的采样概率降低
        #         src_boxes_sample_prob[cur_class][box_idx] = src_boxes_sample_prob[cur_class][box_idx] / 2
        #         src_boxes_sample_prob[cur_class] = src_boxes_sample_prob[cur_class] / np.sum(src_boxes_sample_prob[cur_class])
        #     GLOBAL_SRC_BOX_SAMPLE_PROB = src_boxes_sample_prob

            # old_gt_boxes = src_boxes[sampled_classes[k]][np.random.randint(len(src_boxes[sampled_classes[k]]))]
            # old_gt_boxes = np.array(old_gt_boxes)
            # cur_obj_points = src_points[sampled_classes[k]][np.random.randint(len(src_points[sampled_classes[k]]))]            
            # pass

        '''训练source数据，计算loss，更新梯度'''
        if cfg.SELF_TRAIN.SRC.USE_DATA and cfg.SELF_TRAIN.SRC.get('USE_GRAD', None):
            if cfg.SELF_TRAIN.get('DSNORM', None):
                model.apply(set_ds_source)
            model.train()
            loss, tb_dict, disp_dict, batch_dict_source = model_func(model, source_batch)
            loss = cfg.SELF_TRAIN.SRC.get('LOSS_WEIGHT', 1.0) * loss
            loss_src = loss
            # loss.backward(retain_graph=True)
            loss.backward()
            
            if cfg.SELF_TRAIN.get('ProtoAlign', None):
                loss_name_list = ['loss_proto', 'loss_feat', 'loss_feat_intra', 'loss_feat_inter', 'loss_align', 'loss_align_intra', 'loss_align_inter', 'loss_bg']
                for key in tb_dict.keys():
                    if key in loss_name_list:
                        if key == 'loss_proto':
                            lamda_proto = cfg.SELF_TRAIN.ProtoAlign.get('lambda_proto', 1.0)
                            loss -= lamda_proto * tb_dict[key]
                        disp_dict.update({key: "{:.3f}".format(tb_dict[key])})
            loss_meter.update(loss.item())
            disp_dict.update({'loss': "{:.3f}({:.3f})".format(loss_meter.val, loss_meter.avg)})
                

            # print('source raw start')

        if not cfg.SELF_TRAIN.SRC.get('USE_GRAD', None):
            optimizer.zero_grad()
        
        if cfg.SELF_TRAIN.get('COMPLE_AUG', None):
            data_dict_list = []
            batch_size = target_batch['gt_boxes'].shape[0]
            def process_batch_data(batch, idx):
                """Extract frame ``idx`` from a collated batch as NumPy arrays.

                Args:
                    batch: collated batch dict; reads 'gt_boxes', 'points' and
                        'unreliable_gt_boxes'. Values may be torch tensors (CPU or GPU)
                        or NumPy arrays.
                    idx: frame index inside the batch; points are selected by matching
                        their column 0 against ``idx``.

                Returns:
                    (gt_box, points, unreliable_gt_boxes): the two box arrays with all-zero
                    rows removed (presumably collate padding), and the frame's points with
                    the leading batch-index column dropped.
                """
                def _to_numpy(value):
                    # Explicit conversion instead of the former bare ``except:``, which also
                    # swallowed unrelated errors (including KeyboardInterrupt) and could hand
                    # back a live tensor when ``.numpy()`` refused a grad-tracking tensor.
                    if isinstance(value, torch.Tensor):
                        return value.detach().cpu().numpy()
                    return np.asarray(value)

                def _drop_zero_rows(boxes):
                    return boxes[~np.all(boxes == 0, axis=1)]

                gt_box = _drop_zero_rows(_to_numpy(batch['gt_boxes'][idx]))

                all_points = batch['points']
                # Mask on the native container first so only this frame's points are copied.
                points = _to_numpy(all_points[all_points[:, 0] == idx][:, 1:])

                unreliable_gt_boxes = _drop_zero_rows(_to_numpy(batch['unreliable_gt_boxes'][idx]))

                return gt_box, points, unreliable_gt_boxes
            
            def add_sampled_ps_boxes_to_scene(data_dict, sampled_ps_boxes,
                                            sampled_ps_pnts, sampled_classes):
                """Paste template object points into the scene, fitted to ``data_dict['gt_boxes']``.

                For each row ``k`` the template points ``sampled_ps_pnts[k]`` are moved into
                the canonical frame of template box ``sampled_ps_boxes[k]`` (translate to its
                center, undo its heading), scaled so the template's l/w/h matches destination
                box ``data_dict['gt_boxes'][k]``, rotated to that box's heading and translated
                to its center. Row ``k`` of both box arrays is assumed to describe the same
                object slot (the caller builds both in class order).

                When ``cfg.SELF_TRAIN.SRC_COMPLE_AUG`` is enabled, each template is, with
                probability ``SRC_PROB`` (default 0.5), replaced by a source-domain object of
                the same class drawn from ``GLOBAL_SRC_BOXES`` / ``GLOBAL_SRC_POINTS`` using
                ``GLOBAL_SRC_BOX_SAMPLE_PROB``. A drawn object's probability is halved and the
                class distribution renormalized while it exceeds twice the class minimum;
                this mutates ``GLOBAL_SRC_BOX_SAMPLE_PROB`` in place.

                Args:
                    data_dict: dict with 'gt_boxes' (K, >=7) destination boxes laid out as
                        center[0:3], l/w/h[3:6], heading[6], and 'points' (N, C) scene points.
                    sampled_ps_boxes: (K, >=7) template boxes the object points came from.
                    sampled_ps_pnts: length-K sequence of per-object point arrays.
                    sampled_classes: length-K class names (used for source replacement).

                Returns:
                    Array of pasted object points followed by the original scene points.
                    Object points keep whatever feature columns they carry, truncated or
                    zero-padded to the scene width C.
                """
                gt_boxes = data_dict['gt_boxes']
                points = data_dict['points']
                num_feats = points.shape[-1]

                src_aug_cfg = cfg.SELF_TRAIN.get('SRC_COMPLE_AUG', None)
                # SRC_COMPLE_AUG may be a sub-config carrying SRC_PROB or a bare flag.
                src_prob = src_aug_cfg.get('SRC_PROB', 0.5) if hasattr(src_aug_cfg, 'get') else 0.5

                def _draw_source_object(cls_name):
                    """Draw one cached source object of ``cls_name``; None when none exist."""
                    src_boxes = GLOBAL_SRC_BOXES[cls_name]
                    if len(src_boxes) == 0:
                        return None
                    probs = GLOBAL_SRC_BOX_SAMPLE_PROB
                    box_idx = np.random.choice(len(src_boxes), p=probs[cls_name])
                    # Make frequently drawn objects rarer, but never below twice the minimum.
                    if probs[cls_name][box_idx] > np.min(probs[cls_name]) * 2:
                        probs[cls_name][box_idx] = probs[cls_name][box_idx] / 2
                        probs[cls_name] = probs[cls_name] / np.sum(probs[cls_name])
                    return src_boxes[box_idx], GLOBAL_SRC_POINTS[cls_name][box_idx]

                pasted = []
                for k, (template_box, template_pts) in enumerate(zip(sampled_ps_boxes, sampled_ps_pnts)):
                    if src_aug_cfg and np.random.rand() < src_prob:
                        drawn = _draw_source_object(sampled_classes[k])
                        if drawn is not None:
                            template_box, template_pts = drawn

                    obj_pts = np.array(template_pts)  # private copy; the cached points stay untouched
                    obj_center, lwh, ry = template_box[0:3], template_box[3:6], template_box[6]
                    new_box = gt_boxes[k]
                    new_center, new_lwh, new_ry = new_box[0:3], new_box[3:6], new_box[6]
                    scale_lwh = new_lwh / lwh

                    # Only xyz are geometric; any extra feature columns pass through unchanged.
                    obj_pts[:, 0:3] -= obj_center
                    obj_pts = common_utils.rotate_points_along_z(np.expand_dims(obj_pts, axis=0), -ry).squeeze(0)
                    obj_pts[:, 0:3] *= scale_lwh
                    obj_pts = common_utils.rotate_points_along_z(np.expand_dims(obj_pts, axis=0), new_ry).squeeze(0)
                    # No extra z-offset: after scaling the template bottom already sits at the
                    # destination box bottom. The former "(new_h - h) / 2" lift (from in-place
                    # scale augmentation, where the box itself is re-seated) shifted points off
                    # their label box and mutated the caller's gt_boxes through a view.
                    obj_pts[:, 0:3] += new_center
                    pasted.append(obj_pts)

                if not pasted:
                    return points

                # Bring every object to the scene width: keep its own columns, zero-fill the rest.
                ps_pnts = np.zeros((sum(p.shape[0] for p in pasted), num_feats))
                row = 0
                for obj_pts in pasted:
                    n_rows = obj_pts.shape[0]
                    n_cols = min(num_feats, obj_pts.shape[1])
                    ps_pnts[row:row + n_rows, :n_cols] = obj_pts[:, :n_cols]
                    row += n_rows

                return np.concatenate([ps_pnts, points], axis=0)

            import pickle as pkl

            global GLOBAL_PS_BOXES, GLOBAL_PS_POINTS
            ps_boxes = GLOBAL_PS_BOXES
            ps_points = GLOBAL_PS_POINTS

            if cfg.SELF_TRAIN.get('SRC_COMPLE_AUG', None):
                source_boxes_total = {}
                source_points_total = {}
                source_boxes_sample_prob_total = {}
                for cls in cfg.CLASS_NAMES:
                    source_boxes_total[cls] = []
                    source_points_total[cls] = []
                    source_boxes_sample_prob_total[cls] = []
                source_batch_size = source_batch['gt_boxes'].shape[0]
                for frame_idx in range(source_batch_size):
                    try:
                        s_gt_box = source_batch['gt_boxes'][frame_idx].cpu().numpy()
                    except:
                        s_gt_box = source_batch['gt_boxes'][frame_idx]
                    s_gt_box = s_gt_box[~np.all(s_gt_box == 0, axis=1)]  # 去除全为 0 的行
                    try:
                        source_points = source_batch['points'][source_batch['points'][:, 0] == frame_idx][:, 1:].cpu().numpy()
                    except:
                        source_points = source_batch['points'][source_batch['points'][:, 0] == frame_idx][:, 1:]
                    for box in s_gt_box:
                        points_in_box, mask = get_points_in_box(source_points, box[:7])
                        cls = cfg.CLASS_NAMES[int(box[-1]) - 1]
                        source_boxes_total[cls].append(box)
                        source_points_total[cls].append(points_in_box)
                        source_boxes_sample_prob_total[cls].append(len(points_in_box))
                
                # source_boxes_sample_prob_total 为每个类别的点云数量的反比例
                for cls in cfg.CLASS_NAMES:
                    source_boxes_sample_prob_total[cls] = np.array(source_boxes_sample_prob_total[cls], dtype=np.float64)
                    source_boxes_sample_prob_total[cls] = 1 / (source_boxes_sample_prob_total[cls] + 1e-6)
                    source_boxes_sample_prob_total[cls] = source_boxes_sample_prob_total[cls] / np.sum(source_boxes_sample_prob_total[cls])
                    

                global GLOBAL_SRC_BOXES, GLOBAL_SRC_POINTS, GLOBAL_SRC_BOX_SAMPLE_PROB
                GLOBAL_SRC_BOXES = source_boxes_total
                GLOBAL_SRC_POINTS = source_points_total
                GLOBAL_SRC_BOX_SAMPLE_PROB = source_boxes_sample_prob_total
                

            for frame_idx in range(batch_size):
                mix_data = {}


                
                t_gt_box, target_points, unreliable_gt_boxes = process_batch_data(target_batch, frame_idx)
                T_pos = np.array(cfg.SELF_TRAIN.SCORE_THRESH)
                T_neg = np.array(cfg.SELF_TRAIN.NEG_THRESH)

                t_gt_box_unreliable = unreliable_gt_boxes[:, :7]
                t_gt_labels_unreliable = np.abs(unreliable_gt_boxes[:, -2]).astype(np.int32)
                t_gt_scores_unreliable = unreliable_gt_boxes[:, -1]

                # 移除点云中的目标框点
                target_points = remove_points_in_boxes3d(target_points, enlarge_box3d(t_gt_box_unreliable, extra_width=[1, 0.5, 0.5]))


                comple_box = []
                comple_class = []

                if cfg.SELF_TRAIN.COMPLE_AUG != True and cfg.SELF_TRAIN.COMPLE_AUG.get('Balance', None):
                    # 保持各类别数量平衡
                    num_classes = len(cfg.CLASS_NAMES)
                    num_per_class = {}
                    for i in range(num_classes):
                        num_per_class[str(i + 1)] = np.sum(t_gt_labels_unreliable == i + 1)

                    min_num = min(num_per_class.values())
                    num_per_class = {k: min_num for k, v in num_per_class.items()}
                    # 按照score进行排序，取前面的一部分
                    t_gt_scores_unreliable = np.array(t_gt_scores_unreliable)
                    t_gt_labels_unreliable = np.array(t_gt_labels_unreliable)
                    sort_idx = np.argsort(t_gt_scores_unreliable)[::-1]
                    t_gt_box_unreliable = t_gt_box_unreliable[sort_idx]
                    t_gt_labels_unreliable = t_gt_labels_unreliable[sort_idx]
                    t_gt_scores_unreliable = t_gt_scores_unreliable[sort_idx]
                    for i in range(num_classes):
                        idx = np.where(t_gt_labels_unreliable == i + 1)[0][:min_num]
                        for cur_box, cur_class in zip(t_gt_box_unreliable[idx], t_gt_labels_unreliable[idx]):
                            comple_box.append(cur_box)
                            comple_class.append(cur_class)


                else:
                    for box, box_score, box_label in zip(t_gt_box_unreliable, t_gt_scores_unreliable, t_gt_labels_unreliable):
                        prob_replace = (box_score - T_neg[box_label - 1]) / (T_pos[box_label - 1] - T_neg[box_label - 1])

                        if np.random.rand() < prob_replace:
                            comple_box.append(box)
                            comple_class.append(box_label)
                


                comple_box = np.array(comple_box)
                if len(comple_box) < 1:
                    # data_dict_list.append(target_batch)
                    data_dict = {'points': target_points, 'frame_id': target_batch['frame_id'][frame_idx], 'gt_boxes': t_gt_box}
                    
                    data_dict = target_loader.dataset.prepare_data_easy(data_dict=data_dict)
                    # data_dict = target_loader.dataset.prepare_data_easy(data_dict=data_dict)
                    data_dict_list.append(data_dict)
                    continue
                comple_class = np.array(comple_class)
                comple_box = comple_box[~np.all(comple_box == 0, axis=1)]  # 去除全为 0 的行
                comple_box = np.concatenate([comple_box, comple_class.reshape(-1, 1)], axis=1)

                # 把 comple_box 里的 box 按照最后一项（类别）进行排序
                comple_box = comple_box[comple_box[:, -1].argsort()]

                t_gt_box = np.concatenate([t_gt_box, comple_box], axis=0)
                # t_gt_box = np.concatenate([t_gt_box, comple_box], axis=0)

                class_names = cfg.CLASS_NAMES
                # 统计 comple_class 里面的各个类别数量
                comple_class = np.array(comple_class)
                comple_class_dict = {class_names[i]: np.sum(comple_class == (i + 1)) for i in range(len(class_names))}

                sample_groups = {}
                sample_class_num = {}
                for class_name, sample_num in comple_class_dict.items():
                    if class_name not in cfg.CLASS_NAMES:
                        continue
                    sample_class_num[class_name] = sample_num
                    sample_groups[class_name] = {
                        'sample_num': sample_num
                    }

                tmp_data_dict = {}
                tmp_data_dict['gt_boxes'] = comple_box
                # tmp_data_dict['gt_names'] = t_gt_names_unreliable.astype(str)
                tmp_data_dict['points'] = target_points
                # tmp_data_dict['gt_scores'] = t_gt_scores_unreliable

                # gt_boxes = t_gt_box
                # gt_names = data_dict['gt_names'].astype(str)
                # existed_boxes = gt_boxes
                sampled_class, total_sampled_pnts = [], []
                existed_boxes = None

                # sample_groups = cfg.SELF_TRAIN.PS_SAMPLING.SAMPLE_GROUPS
                for class_name, sample_group in sample_groups.items():

                    if int(sample_group['sample_num']) > 0:
                        sample_idx = np.random.randint(len(ps_boxes[class_name]), size=int(sample_group['sample_num']))
                        sampled_boxes = np.stack([ps_boxes[class_name][i] for i in sample_idx]).astype(np.float32)[:,:7] # (ps box num, 7)
                        sampled_pnts = [ps_points[class_name][i] for i in sample_idx]

                        valid_mask = np.ones(sampled_boxes.shape[0], dtype=bool)
                        valid_sampled_boxes = sampled_boxes[valid_mask]
                        existed_boxes = np.concatenate((existed_boxes, valid_sampled_boxes), axis=0) if existed_boxes is not None else valid_sampled_boxes

                        # sampled classes
                        sampled_class.extend([class_name for i in range(len(valid_mask))])

                        # sampled points
                        for mask_idx in range(valid_mask.shape[0]):
                            total_sampled_pnts.append(sampled_pnts[mask_idx])

                sampled_gt_boxes = existed_boxes
                if sampled_gt_boxes.shape[0] > 0:
                    points = add_sampled_ps_boxes_to_scene(tmp_data_dict,sampled_gt_boxes, total_sampled_pnts, sampled_class)
                
                # target_points = np.concatenate([points, target_points], axis=0)
                target_points = points
                data_dict = {'points': target_points, 'frame_id': target_batch['frame_id'][frame_idx], 'gt_boxes': t_gt_box}
                    
                data_dict = target_loader.dataset.prepare_data_easy(data_dict=data_dict)
                # data_dict = target_loader.dataset.prepare_data_easy(data_dict=data_dict)
                data_dict_list.append(data_dict)
                # return data_dict
            data_dict = target_loader.dataset.collate_batch(data_dict_list)
            for key in target_batch.keys():
                if key not in data_dict.keys():
                    data_dict[key] = target_batch[key]
            target_batch = data_dict

        if cfg.SELF_TRAIN.get('MID', None) and source_batch['gt_boxes'].shape[0] == target_batch['gt_boxes'].shape[0]:
            mix_method = cfg.SELF_TRAIN.MID.get('MIX_METHOD', 'laser')
            if mix_method == 'laser':
                mid_batch = mix_domain_utils.laser_mix_domain(source_batch, target_batch, target_loader)
            elif mix_method == 'cut':
                mid_batch = mix_domain_utils.cut_mix_domain(source_batch, target_batch, target_loader)
            elif mix_method == 'polar':
                mid_batch = mix_domain_utils.polar_mix_domain(source_batch, target_batch, target_loader)

            if cfg.SELF_TRAIN.get('DSNORM', None):
                model.apply(set_ds_mid)
            model.train()
            mid_loss, mid_tb_dict, mid_disp_dict, batch_dict_mid = model_func(model, mid_batch)
            mid_loss.backward(retain_graph=True)
            mid_loss_meter.update(mid_loss.item())
            disp_dict.update({'mid_loss': "{:.3f}({:.3f})".format(mid_loss_meter.val, mid_loss_meter.avg)})
            
            
            
            # loss_mid, tb_dict, disp_dict, batch_dict_source = model_func(model, mid_batch)
            # loss_mid = cfg.SELF_TRAIN.MID.get('LOSS_WEIGHT', 1.0) * loss_mid
            # loss_src = loss
            # loss.backward(retain_graph=True)
            # loss_meter.update(loss.item())
            # disp_dict.update({'loss': "{:.3f}({:.3f})".format(loss_meter.val, loss_meter.avg)})
            # print('source raw start')

            # mid_tb_dict = {}
            # mid_tb_dict.update({'mid_loss': mid_loss_meter.val})
            # mid_disp_dict = {}
            # mid_disp_dict.update(mid_loss_i)

        if cfg.SELF_TRAIN.TAR.USE_DATA:
           
        
            if cfg.SELF_TRAIN.get('DSNORM', None):
                if ema_model is not None:
                    ema_model.apply(set_ds_target)
                model.apply(set_ds_target)
                model.train()
            
             # parameters for save pseudo label on the fly
            if cfg.SELF_TRAIN.get('MT_CONSISTENCY', None):
                st_loss, st_tb_dict, st_disp_dict, batch_dict_target = model_func(model, target_batch, batch_dict_teacher)
            # elif cfg.SELF_TRAIN.get('TARGET_LEARNING', None) or cfg.SELF_TRAIN.get('BDA', None):
            #     st_loss, st_tb_dict, st_disp_dict, target_batch_dict = model_func(model, target_batch, batch_dict_teacher=pseudo_label_target if cfg.SELF_TRAIN.get('TARGET_LEARNING', None) or cfg.SELF_TRAIN.get('BDA', None) else None)
            elif cfg.SELF_TRAIN.get('DTS', None):
                batch_dict_teacher = GLOBAL_BATCH_DICT_TEACHER
                st_loss, st_tb_dict, st_disp_dict, batch_dict_target = model_func(model, target_batch, batch_dict_teacher=batch_dict_teacher)
            else:
                st_loss, st_tb_dict, st_disp_dict, batch_dict_target = model_func(model, target_batch)                      

            st_loss = cfg.SELF_TRAIN.TAR.get('LOSS_WEIGHT', 1.0) * st_loss
            if cfg.SELF_TRAIN.get('ProtoAlign', None):
                loss_name_list = ['loss_proto', 'loss_feat', 'loss_feat_intra', 'loss_feat_inter', 'loss_align', 'loss_align_intra', 'loss_align_inter', 'loss_bg']
                for key in st_tb_dict.keys():
                    if key in loss_name_list:
                        if key == 'loss_proto':
                            lamda_proto = cfg.SELF_TRAIN.ProtoAlign.get('lambda_proto', 1.0)
                            st_loss -= lamda_proto * st_tb_dict[key]
                        st_disp_dict.update({key: "{:.3f}".format(st_tb_dict[key])})
            st_loss_meter.update(st_loss.item())


            if cfg.SELF_TRAIN.get('Triplet', None):
                # 提取特征和标签
                source_features = batch_dict_source['roi_head_features']  # (B, N_s, D)
                source_labels = batch_dict_source['roi_labels']           # (B, N_s)
                target_features = batch_dict_target['roi_head_features']  # (B, N_t, D)
                target_labels = batch_dict_target['roi_labels']           # (B, N_t)

                batch_size, num_target_rois, feat_dim = target_features.shape
                margin = 1.0
                triplet_loss = 0.0
                valid_count = 0

                lambda_loss_triplet = cfg.SELF_TRAIN.Triplet.get('lambda_loss_triplet', 0.1)
                target_lambda_loss_inter = cfg.SELF_TRAIN.Triplet.get('target_lambda_loss_inter', 1.0)
                target_lambda_loss_intra = cfg.SELF_TRAIN.Triplet.get('target_lambda_loss_intra', 1.0)
                source_lambda_loss_inter = cfg.SELF_TRAIN.Triplet.get('source_lambda_loss_inter', 1.0)
                source_lambda_loss_intra = cfg.SELF_TRAIN.Triplet.get('source_lambda_loss_intra', 1.0)

                if source_features.shape[0] == target_features.shape[0]:
                    for b in range(batch_size):
                        src_feat_b = source_features[b]  # (N_s, D)
                        src_label_b = source_labels[b]    # (N_s,)
                        tgt_feat_b = target_features[b]   # (N_t, D)
                        tgt_label_b = target_labels[b]    # (N_t,)

                        num_target = tgt_feat_b.shape[0]
                        num_source = src_feat_b.shape[0]

                        # 处理目标域锚点
                        for i in range(num_target):
                            anchor_feat = tgt_feat_b[i]
                            anchor_label = tgt_label_b[i].item()

                            # 跨域三元组 (目标锚点, 源正样本, 目标负样本)
                            src_same_mask = (src_label_b == anchor_label).clone()
                            if src_same_mask.sum() > 0:
                                src_pos = src_feat_b[src_same_mask][torch.argmax(torch.norm(src_feat_b[src_same_mask] - anchor_feat, dim=1))]
                                tgt_diff_mask = (tgt_label_b != anchor_label)
                                if tgt_diff_mask.sum() > 0:
                                    tgt_neg = tgt_feat_b[tgt_diff_mask][torch.argmin(torch.norm(tgt_feat_b[tgt_diff_mask] - anchor_feat, dim=1))]
                                    loss = F.relu(torch.norm(anchor_feat - src_pos) - torch.norm(anchor_feat - tgt_neg) + margin)
                                    triplet_loss += target_lambda_loss_inter * loss
                                    valid_count += 1

                            # 目标域内三元组 (目标锚点, 目标正样本, 目标负样本)
                            tgt_same_mask = (tgt_label_b == anchor_label).clone()
                            tgt_same_mask[i] = False
                            if tgt_same_mask.sum() > 0:
                                tgt_pos = tgt_feat_b[tgt_same_mask][torch.argmax(torch.norm(tgt_feat_b[tgt_same_mask] - anchor_feat, dim=1))]
                                tgt_diff_mask = (tgt_label_b != anchor_label)
                                if tgt_diff_mask.sum() > 0:
                                    tgt_neg = tgt_feat_b[tgt_diff_mask][torch.argmin(torch.norm(tgt_feat_b[tgt_diff_mask] - anchor_feat, dim=1))]
                                    loss = F.relu(torch.norm(anchor_feat - tgt_pos) - torch.norm(anchor_feat - tgt_neg) + margin)
                                    triplet_loss += target_lambda_loss_intra * loss
                                    valid_count += 1

                        # 处理源域锚点
                        for i in range(num_source):
                            anchor_feat = src_feat_b[i]
                            anchor_label = src_label_b[i].item()

                            # 源域内三元组 (源锚点, 源正样本, 源负样本)
                            src_same_mask = (src_label_b == anchor_label).clone()
                            src_same_mask[i] = False
                            if src_same_mask.sum() > 0:
                                src_pos = src_feat_b[src_same_mask][torch.argmax(torch.norm(src_feat_b[src_same_mask] - anchor_feat, dim=1))]
                                src_diff_mask = (src_label_b != anchor_label)
                                if src_diff_mask.sum() > 0:
                                    src_neg = src_feat_b[src_diff_mask][torch.argmin(torch.norm(src_feat_b[src_diff_mask] - anchor_feat, dim=1))]
                                    loss = F.relu(torch.norm(anchor_feat - src_pos) - torch.norm(anchor_feat - src_neg) + margin)
                                    triplet_loss += source_lambda_loss_intra * loss
                                    valid_count += 1

                            # 跨域三元组 (源锚点, 目标正样本, 源负样本)
                            tgt_same_mask = (tgt_label_b == anchor_label).clone()
                            if tgt_same_mask.sum() > 0:
                                tgt_pos = tgt_feat_b[tgt_same_mask][torch.argmax(torch.norm(tgt_feat_b[tgt_same_mask] - anchor_feat, dim=1))]
                                src_diff_mask = (src_label_b != anchor_label)
                                if src_diff_mask.sum() > 0:
                                    src_neg = src_feat_b[src_diff_mask][torch.argmin(torch.norm(src_feat_b[src_diff_mask] - anchor_feat, dim=1))]
                                    loss = F.relu(torch.norm(anchor_feat - tgt_pos) - torch.norm(anchor_feat - src_neg) + margin)
                                    triplet_loss += source_lambda_loss_inter * loss
                                    valid_count += 1


                    # 归一化损失
                    if valid_count > 0:
                        triplet_loss /= valid_count
                    else:
                        triplet_loss = 0.0

                    st_loss += lambda_loss_triplet * triplet_loss

            # '''训练source数据，计算loss，更新梯度'''
            # if (cfg.SELF_TRAIN.SRC.USE_DATA and cfg.SELF_TRAIN.SRC.get('USE_GRAD', None)):
            #     st_loss = st_loss + loss_src

            # if cfg.SELF_TRAIN.get('MID', None):
            #     st_loss = st_loss + mid_loss

            torch.autograd.set_detect_anomaly(True)

            st_loss.backward()

            if 'pos_ps_bbox' in target_batch.keys():
                # count number of used ps bboxes in this batch
                pos_pseudo_bbox = target_batch['pos_ps_bbox'].mean(dim=0).cpu().numpy()
                ign_pseudo_bbox = target_batch['ign_ps_bbox'].mean(dim=0).cpu().numpy()
            else:
                pos_pseudo_bbox = torch.zeros(3)
                ign_pseudo_bbox = torch.zeros(3)
            ps_bbox_nmeter.update(pos_pseudo_bbox.tolist())
            ign_ps_bbox_nmeter.update(ign_pseudo_bbox.tolist())
            pos_ps_result = ps_bbox_nmeter.aggregate_result()
            ign_ps_result = ign_ps_bbox_nmeter.aggregate_result()

            st_tb_dict = common_utils.add_prefix_to_dict(st_tb_dict, 'st_')
            disp_dict.update(common_utils.add_prefix_to_dict(st_disp_dict, 'st_'))
            disp_dict.update({'st_loss': "{:.3f}({:.3f})".format(st_loss_meter.val, st_loss_meter.avg),
                              'pos_ps_box': pos_ps_result,
                              'ign_ps_box': ign_ps_result})

            if cfg.SELF_TRAIN.get('MID', None):
                mid_tb_dict = common_utils.add_prefix_to_dict(mid_tb_dict, 'mid_')
                disp_dict.update(common_utils.add_prefix_to_dict(mid_disp_dict, 'mid_'))
                disp_dict.update({'mid_loss': "{:.3f}({:.3f})".format(mid_loss_meter.val, mid_loss_meter.avg)})

            if cfg.SELF_TRAIN.get('TARGET_LEARNING', None):
                cls_tb_dict = common_utils.add_prefix_to_dict(cls_tb_dict, 'cls_')
                disp_dict.update(common_utils.add_prefix_to_dict(cls_disp_dict, 'cls_'))
                disp_dict.update({'cls_loss': "{:.3f}({:.3f})".format(cls_loss_meter.val, cls_loss_meter.avg)})

                reg_tb_dict = common_utils.add_prefix_to_dict(reg_tb_dict, 'reg_')
                disp_dict.update(common_utils.add_prefix_to_dict(reg_disp_dict, 'reg_'))
                disp_dict.update({'reg_loss': "{:.3f}({:.3f})".format(reg_loss_meter.val, reg_loss_meter.avg)})


            if cfg.SELF_TRAIN.get('BDA', None):
                bda_tb_dict = common_utils.add_prefix_to_dict(st_tb_dict, 'bda_')
                disp_dict.update(common_utils.add_prefix_to_dict(st_disp_dict, 'bda_'))
                disp_dict.update({'bda_loss': "{:.3f}({:.3f})".format(bda_loss_meter.val, bda_loss_meter.avg)})
            
        clip_grad_norm_(model.parameters(), optim_cfg.GRAD_NORM_CLIP)
        # torch.cuda.synchronize()  # 等待所有GPU完成反向传播
        optimizer.step()
        accumulated_iter += 1


        # 检查是否需要执行模型聚合
        if aggregation_interval and (accumulated_iter % aggregation_interval == 0):
            if last_model_state is not None:
                ema_alpha = cfg.SELF_TRAIN.WEAK_MODEL_AGGREGATION.EMA_ALPHA
                current_model_state = model.state_dict()
                for key in last_model_state.keys():
                    last_model_state[key] = (
                        ema_alpha * last_model_state[key]
                        + (1 - ema_alpha) * current_model_state[key]
                    )
                model.load_state_dict(last_model_state)

        if ema_model is not None:
            update_ema_variables(model, ema_model, model_cfg=ema_model.model_cfg, cur_iter=accumulated_iter)
            # ema_rampup_start, ema_start = cfg.SELF_TRAIN.EMA_LEARNING.EMA_EPOCH
            # assert ema_rampup_start <= ema_start
            # epoch_id = cur_epoch
            # if epoch_id < ema_rampup_start:
            #     pass
            # elif (epoch_id >= ema_rampup_start) and (epoch_id < ema_start):
            #     # if accumulated_iter % ssl_cfg.TEACHER.NUM_ITERS_PER_UPDATE == 0:
            #         # if dist:
            #             #if rank == 0:
            #     update_ema_variables(student_model.module.onepass, teacher_model.module.onepass, ssl_cfg.TEACHER.RAMPUP_EMA_MOMENTUM, accumulated_iter)
            #         # else:
            #             # update_ema_variables(student_model, teacher_model, ssl_cfg.TEACHER.RAMPUP_EMA_MOMENTUM, accumulated_iter)
            # elif epoch_id >= ema_start:
            #     if accumulated_iter % ssl_cfg.TEACHER.NUM_ITERS_PER_UPDATE == 0:
            #         if dist:
            #             #if rank == 0:
            #             update_ema_variables_with_fixed_momentum(student_model.module.onepass, teacher_model.module.onepass, ssl_cfg.TEACHER.EMA_MOMENTUM)
            #         else:
            #             update_ema_variables_with_fixed_momentum(student_model, teacher_model, ssl_cfg.TEACHER.EMA_MOMENTUM)
            # else:
            #     raise Exception('Impossible condition for EMA update')

        if cfg.SELF_TRAIN.get('ProtoAlign', None) and cfg.SELF_TRAIN.ProtoAlign.get('use_gmm', None):
            import torch.distributed as dist
            # 获取原型管理器
            if dist.is_initialized():
                prototype_manager = model.module.prototype_manager
            else:
                prototype_manager = model.prototype_manager

            # 遍历所有类别进行参数同步
            for cls_str in prototype_manager.gmm_prototypes.keys():
                # 同步均值参数
                mean_tensor = prototype_manager.gmm_prototypes[cls_str].data
                dist.all_reduce(mean_tensor, op=dist.ReduceOp.SUM)
                mean_tensor /= dist.get_world_size()
                prototype_manager.gmm_prototypes[cls_str].data.copy_(mean_tensor)

                # 同步协方差参数
                cov_tensor = prototype_manager.gmm_covariances[cls_str].data
                dist.all_reduce(cov_tensor, op=dist.ReduceOp.SUM)
                cov_tensor /= dist.get_world_size()
                prototype_manager.gmm_covariances[cls_str].data.copy_(cov_tensor)

                # 同步权重参数
                weight_tensor = prototype_manager.gmm_weights[cls_str].data
                dist.all_reduce(weight_tensor, op=dist.ReduceOp.SUM)
                weight_tensor /= dist.get_world_size()
                prototype_manager.gmm_weights[cls_str].data.copy_(weight_tensor)
            # new_weight = new_weight.cpu().numpy()
            # new_weight = new_weight / np.sum(new_weight)
            # new_weight = np.round(new_weight, decimals=3)
            # new_weight[-1] = 1 - np.sum(new_weight[:-1])
            # new_weight = torch.tensor(new_weight, device=device)
            
        if cfg.SELF_TRAIN.get('DSNORM', None):
            # for emabf, bf in zip(ema_model.named_buffers(), model_named_buffers):
            #     emaname, emavalue = emabf
            #     name, value = bf
            #     assert emaname == name, 'name not equal:{} , {}'.format(emaname, name)
            #     # if 'running_mean' in name or 'running_var' in name:
            #     if cfg.SELF_TRAIN.get('EMA_BN', None):
            #         # alpha = 0.95
            #         emavalue.data = emavalue.data * alpha + value.data * (1 - alpha)
            #     elif cfg.SELF_TRAIN.get('EMA_COPY', None):
            #         emavalue.data = value.data
            # 手动同步BN层
            import torch.distributed as dist
            if dist.is_initialized():
                for bf in model.named_buffers():
                    name, value = bf
                    if 'running_mean' in name or 'running_var' in name:
                        data = value.data.clone()
                        dist.all_reduce(data, op=dist.ReduceOp.SUM)
                        data /= dist.get_world_size()
                        value.data.copy_(data)


        # log to console and tensorboard
        if rank == 0:
            pbar.update()
            pbar.set_postfix(dict(total_it=accumulated_iter, pos_ps_box=pos_ps_result,
                                  ign_ps_box=ign_ps_result))
            tbar.set_postfix(disp_dict)
            tbar.refresh()

            if tb_log is not None:
                tb_log.add_scalar('meta_data/learning_rate', cur_lr, accumulated_iter)
                wandb.log({'meta_data/learning_rate': cur_lr},
                          step=int(accumulated_iter))
                if cfg.SELF_TRAIN.SRC.USE_DATA and cfg.SELF_TRAIN.SRC.get('USE_GRAD', None):
                    tb_log.add_scalar('train/loss', loss, accumulated_iter)
                    for key, val in tb_dict.items():
                        tb_log.add_scalar('train/' + key, val, accumulated_iter)
                        wandb.log({key: val}, step=int(accumulated_iter))

                if cfg.SELF_TRAIN.TAR.USE_DATA:
                    tb_log.add_scalar('train/st_loss', st_loss, accumulated_iter)
                    for key, val in st_tb_dict.items():
                        tb_log.add_scalar('train/' + key, val, accumulated_iter)
                        wandb.log({key: val}, step=int(accumulated_iter))


    if rank == 0:
        pbar.close()
        for i, class_names in enumerate(target_loader.dataset.class_names):
            tb_log.add_scalar(
                'ps_box/pos_%s' % class_names, ps_bbox_nmeter.meters[i].avg, cur_epoch)
            tb_log.add_scalar(
                'ps_box/ign_%s' % class_names, ign_ps_bbox_nmeter.meters[i].avg, cur_epoch)
            wandb.log({'ps_box/pos_%s' % class_names:
                           ps_bbox_nmeter.meters[i].avg})
            wandb.log({'ps_box/ign_%s' % class_names:
                           ign_ps_bbox_nmeter.meters[i].avg})    

    return accumulated_iter

def update_ema_variables(model, ema_model, model_cfg=None, cur_iter=0):
    """Update ``ema_model`` in place as an exponential moving average of ``model``.

    Every parameter becomes ``alpha * ema + (1 - alpha) * student``. ``alpha`` is
    ``model_cfg['EMA_MODEL_ALPHA']`` unless ``cfg.SELF_TRAIN.HSSDA_EMA`` is set,
    in which case it ramps linearly from 0.8 (iteration 0) to 0.999
    (iteration 1000) and stays at 0.999 afterwards.

    Buffers (e.g. normalization running statistics) follow the global config:
      * ``cfg.SELF_TRAIN.EMA_BN``: floating-point buffers are blended with the
        same ``alpha``; non-floating buffers (e.g. ``num_batches_tracked``) are
        copied, because averaging an integer counter is meaningless and would
        silently turn it into a float tensor.
      * ``cfg.SELF_TRAIN.EMA_COPY``: buffers are copied value-by-value into the
        teacher's own storage (the teacher never aliases the student's tensors).
      * otherwise: buffers are left untouched (names are still checked).

    Args:
        model: student model, optionally wrapped in a container that exposes
            the real network as ``.module`` (e.g. DistributedDataParallel).
        ema_model: unwrapped teacher model with the same architecture; it is
            modified in place.
        model_cfg: mapping that must contain ``'EMA_MODEL_ALPHA'``.
        cur_iter: current global iteration; only used by the HSSDA ramp.

    Raises:
        AssertionError: if ``model_cfg`` is None or buffer names do not match.
    """
    assert model_cfg is not None

    alpha = model_cfg['EMA_MODEL_ALPHA']
    if cfg.SELF_TRAIN.get('HSSDA_EMA', None):
        # Linear warm-up of the keep rate, then a fixed keep rate.
        ema_keep_rate = 0.999
        change_global_step = 1000
        if cur_iter < change_global_step:
            alpha = (ema_keep_rate - 0.8) / change_global_step * cur_iter + 0.8
        else:
            alpha = ema_keep_rate

    # Unwrap DDP-style containers once so parameter order and buffer names
    # line up with the (unwrapped) teacher.
    student = model.module if hasattr(model, 'module') else model

    for ema_param, param in zip(ema_model.parameters(), student.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

    ema_bn = cfg.SELF_TRAIN.get('EMA_BN', None)
    ema_copy = cfg.SELF_TRAIN.get('EMA_COPY', None)
    for (ema_name, ema_buf), (name, buf) in zip(ema_model.named_buffers(), student.named_buffers()):
        assert ema_name == name, 'name not equal:{} , {}'.format(ema_name, name)
        if ema_bn:
            if ema_buf.is_floating_point():
                ema_buf.data.mul_(alpha).add_(buf.data, alpha=1 - alpha)
            else:
                # Integer counters are copied, keeping their dtype.
                ema_buf.data.copy_(buf.data)
        elif ema_copy:
            # copy_ keeps the teacher's own storage; assigning ``.data`` would
            # make teacher and student share one tensor, so later in-place
            # running-stat updates of either model would leak into the other.
            ema_buf.data.copy_(buf.data)


def train_model_st(model, optimizer, source_loader, target_loader, model_func, lr_scheduler, optim_cfg,
                   start_epoch, total_epochs, start_iter, rank, tb_log, ckpt_save_dir, ps_label_dir,
                   source_sampler=None, target_sampler=None, lr_warmup_scheduler=None, ckpt_save_interval=1,
                   max_ckpt_save_num=50, merge_all_iters_to_one_epoch=False, logger=None, ema_model=None, 
                   source_loader_detect=None, source_sampler_detect=None, source_model=None, dist=None, pretrained=None):
    """Run the self-training (pseudo-label) loop for epochs ``start_epoch .. total_epochs - 1``.

    For every epoch this:
      1. sets the epoch on ``target_sampler`` and ``source_reader`` (only when a
         target sampler is given) and selects the warm-up or main LR scheduler;
      2. regenerates pseudo labels when ``cur_epoch`` is listed in
         ``cfg.SELF_TRAIN.UPDATE_PSEUDO_LABEL`` or is a non-zero multiple of
         ``cfg.SELF_TRAIN.UPDATE_PSEUDO_LABEL_INTERVAL`` (labels come from
         ``ema_model`` when given, otherwise from ``model``). It then runs the
         config-driven follow-ups below (object-bank reload, LR/scheduler reset,
         threshold/probability schedules, optional AdaBN pass);
      3. re-prepares the target augmentor at ``PROG_AUG.UPDATE_AUG`` epochs;
      4. trains one epoch with ``train_one_epoch_st`` and synchronizes ranks;
      5. on rank 0, saves a checkpoint every ``ckpt_save_interval`` epochs and
         deletes the oldest ones (by mtime) so that at most
         ``max_ckpt_save_num`` remain. This assumes ``save_checkpoint`` writes
         ``<name>.pth`` so the glob sees it (TODO confirm).

    Several branches mutate the global ``cfg`` (``OBJECT.max``,
    ``SRC_COMPLE_AUG.SRC_PROB``, ``SCORE_THRESH``) on every pseudo-label update.

    Args:
        model: student model being trained.
        optimizer: optimizer for ``model``. Its LR may be reset when
            ``LOAD_OPTIMIZER_AFTER_PSEUDO_LABELING`` is set.
        source_loader, source_sampler: wrapped into a ``common_utils.DataReader``.
        target_loader, target_sampler: target-domain loader and its sampler.
        model_func: forwarded unchanged to ``train_one_epoch_st``.
        lr_scheduler, lr_warmup_scheduler: main and optional warm-up schedulers.
            They may be rebuilt after pseudo-labeling.
        optim_cfg: optimization config. ``WARMUP_EPOCH`` is read here.
        start_epoch, total_epochs: epoch range to run.
        start_iter: initial value of the global iteration counter.
        rank: process rank. Only rank 0 saves checkpoints.
        tb_log: tensorboard writer, forwarded to ``train_one_epoch_st``.
        ckpt_save_dir: checkpoint directory. It is combined with ``/``, so a
            ``pathlib.Path`` is expected.
        ps_label_dir: directory for pseudo-label files.
        ckpt_save_interval, max_ckpt_save_num: checkpoint cadence and retention.
        merge_all_iters_to_one_epoch: expand the target dataset to cover all
            epochs and derive a per-epoch iteration budget from it.
        logger: used via ``.info`` when existing pseudo labels are found, so it
            must not be None in that case.
        ema_model: optional teacher model used for pseudo-labeling.
        source_loader_detect, source_sampler_detect, source_model, dist,
        pretrained: not referenced in the visible portion of this function.
    """
    accumulated_iter = start_iter
    source_reader = common_utils.DataReader(source_loader, source_sampler)
    source_reader.construct_iter()

    # Resume support: look for pseudo labels already generated for start_epoch.
    # NOTE(review): the returned path is only logged here. Loading the labels
    # into shared state is presumably done inside the helper (TODO confirm).
    ps_pkl = self_training_utils.check_already_exsit_pseudo_label(ps_label_dir, start_epoch)
    if ps_pkl is not None:
        logger.info('==> Loading pseudo labels from {}'.format(ps_pkl))

    # Resume support: replay the progressive-augmentation updates of the
    # epochs that were already trained, so augmentation intensity matches.
    if cfg.SELF_TRAIN.get('PROG_AUG', None) and cfg.SELF_TRAIN.PROG_AUG.ENABLED and \
        start_epoch > 0:
        for cur_epoch in range(start_epoch):
            if cur_epoch in cfg.SELF_TRAIN.PROG_AUG.UPDATE_AUG:
                target_loader.dataset.data_augmentor.re_prepare(
                    augmentor_configs=None, intensity=cfg.SELF_TRAIN.PROG_AUG.SCALE)

    with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True,
                     leave=(rank == 0)) as tbar:
        total_it_each_epoch = len(target_loader)
        if merge_all_iters_to_one_epoch:
            # The dataset is expanded to cover all epochs; split its length
            # back into a per-epoch iteration budget.
            assert hasattr(target_loader.dataset, 'merge_all_iters_to_one_epoch')
            target_loader.dataset.merge_all_iters_to_one_epoch(merge=True, epochs=total_epochs)
            total_it_each_epoch = len(target_loader) // max(total_epochs, 1)

        # Target iterator shared by the AdaBN pass below and each epoch call.
        # train_one_epoch_st builds a fresh one itself when
        # total_it_each_epoch == len(target_loader).
        dataloader_iter = iter(target_loader)

        for cur_epoch in tbar:
            if target_sampler is not None:
                target_sampler.set_epoch(cur_epoch)
                source_reader.set_cur_epoch(cur_epoch)

            # Use the warm-up scheduler during the first optim_cfg.WARMUP_EPOCH epochs.
            if lr_warmup_scheduler is not None and cur_epoch < optim_cfg.WARMUP_EPOCH:
                cur_scheduler = lr_warmup_scheduler
            else:
                cur_scheduler = lr_scheduler

            # Update pseudo labels at explicitly listed epochs, or every
            # UPDATE_PSEUDO_LABEL_INTERVAL epochs (the interval rule skips epoch 0).
            if (cur_epoch in cfg.SELF_TRAIN.UPDATE_PSEUDO_LABEL) or \
                    ((cur_epoch % cfg.SELF_TRAIN.UPDATE_PSEUDO_LABEL_INTERVAL == 0)
                     and cur_epoch != 0):
            
                # Prototype-based variants (GMM / k-means) also need the source
                # loader and the student model.
                # NOTE(review): source_loader.dataset is put in eval mode here but
                # is not switched back to train() in the lines shown.
                if cfg.SELF_TRAIN.get('ProtoAlign', None) and (cfg.SELF_TRAIN.ProtoAlign.get('use_gmm', None) or cfg.SELF_TRAIN.ProtoAlign.get('use_kmeans', None)):
                    from pcdet.utils import self_training_utils_proto
                    source_loader.dataset.eval()
                    target_loader.dataset.eval()
                    self_training_utils_proto.save_pseudo_label_epoch(
                    ema_model if ema_model else model, target_loader, source_loader, rank,
                    leave_pbar=True, ps_label_dir=ps_label_dir, cur_epoch=cur_epoch, train_model=model)
                else:
                    target_loader.dataset.eval()
                    self_training_utils.save_pseudo_label_epoch(
                        ema_model if ema_model else model, target_loader, rank,
                        leave_pbar=True, ps_label_dir=ps_label_dir, cur_epoch=cur_epoch)


                if cfg.SELF_TRAIN.get('COMPLE_AUG', None):
                    import pickle as pkl
                    # Load the most recently modified pseudo-label object boxes
                    # and points (ps_box_e*.pkl / ps_point_e*.pkl) into the
                    # module-level GLOBAL_PS_BOXES / GLOBAL_PS_POINTS.
                    # Raises IndexError if either glob matches no file.
                    # NOTE: pickle.load can execute arbitrary code; only point
                    # PS_OBJECT_PATH at trusted files.
                    ps_box_list = glob.glob(os.path.join(cfg.SELF_TRAIN.PS_SAMPLING.PS_OBJECT_PATH, 'ps_box_e*.pkl'))
                    ps_pnt_list = glob.glob(os.path.join(cfg.SELF_TRAIN.PS_SAMPLING.PS_OBJECT_PATH,'ps_point_e*.pkl'))

                    ps_box_list.sort(key=os.path.getmtime, reverse=True)
                    with open(ps_box_list[0], 'rb') as f:
                        ps_boxes = pkl.load(f)
                    ps_pnt_list.sort(key=os.path.getmtime, reverse=True)
                    with open(ps_pnt_list[0], 'rb') as f:
                        ps_points = pkl.load(f)

                    # This `global` declaration applies to the whole function body.
                    global GLOBAL_PS_BOXES, GLOBAL_PS_POINTS
                    GLOBAL_PS_BOXES = ps_boxes
                    GLOBAL_PS_POINTS = ps_points
                
                # Reset every param group's LR to cfg.OPTIMIZATION.LR and rebuild
                # the schedulers from scratch (last_epoch=0). The rebuilt
                # schedulers also replace lr_scheduler / lr_warmup_scheduler for
                # all later epochs.
                if cfg.SELF_TRAIN.get('LOAD_OPTIMIZER_AFTER_PSEUDO_LABELING', None):
                    for g in optimizer.param_groups:
                        g['lr'] = cfg.OPTIMIZATION.LR
                    lr_scheduler, lr_warmup_scheduler = build_scheduler(
                        optimizer,
                        total_iters_each_epoch=total_it_each_epoch,
                        total_epochs=cfg.OPTIMIZATION.NUM_EPOCHS,
                        last_epoch=0, optim_cfg=cfg.OPTIMIZATION
                    )
                    if lr_warmup_scheduler is not None and \
                            cur_epoch < optim_cfg.WARMUP_EPOCH:
                        cur_scheduler = lr_warmup_scheduler
                    else:
                        cur_scheduler = lr_scheduler
                    # NOTE(review): assigned but never read in the visible code.
                    # accumulated_iter is not reset. Confirm whether a reset was intended.
                    new_accumulated_iter = 1

                # Decrement OBJECT.max in the global cfg at every pseudo-label update.
                if cfg.SELF_TRAIN.get('OBJECT', None) and cfg.SELF_TRAIN.OBJECT.get('rdu', None):
                    cfg.SELF_TRAIN.OBJECT.max = cfg.SELF_TRAIN.OBJECT.max - 1

                # Lower the source-object probability by 0.05 while it is positive.
                # The value can end up slightly below 0 once.
                if cfg.SELF_TRAIN.get('SRC_COMPLE_AUG', None):
                    if cfg.SELF_TRAIN.SRC_COMPLE_AUG.SRC_PROB > 0:
                        cfg.SELF_TRAIN.SRC_COMPLE_AUG.SRC_PROB = cfg.SELF_TRAIN.SRC_COMPLE_AUG.SRC_PROB - 0.05

                # Adaptive threshold: raise every score threshold by 0.02.
                if cfg.SELF_TRAIN.get('ada_threshold', None):
                    cfg.SELF_TRAIN.SCORE_THRESH = (np.array(cfg.SELF_TRAIN.SCORE_THRESH) + 0.02).tolist()
            

                # Print the current SIZE_RES of augmentor config entry 1. This
                # presumably is the size-resampling augmentor (TODO confirm index).
                if cfg.SELF_TRAIN.get('PLS', None):
                    print('**********************Update Scale**********************')   
                    print('mean (std)|length     |width      |height')
                    print(cfg.DATA_CONFIG.DATA_AUGMENTOR.AUG_CONFIG_LIST[1]['SIZE_RES'])

                # Back to training mode for the target dataset after pseudo-labeling.
                target_loader.dataset.train()

                # AdaBN-style pass: gradient-free forward passes in train mode over
                # one epoch of target batches, presumably so normalization running
                # statistics adapt to the target domain. They are then pushed to
                # the EMA model. This consumes the shared dataloader_iter.
                if cfg.SELF_TRAIN.get('ADABN_EPOCH', None):
                    model.train()
                    if ema_model is not None:
                        ema_model.train()
                    if cfg.SELF_TRAIN.get('DSNORM', None):
                        model.apply(set_ds_target)
                    # NOTE(review): unlike the student, ema_model gets set_ds_target
                    # even when DSNORM is off. Confirm this is intended.
                    if ema_model is not None:
                        ema_model.apply(set_ds_target)
                    for cur_it in range(total_it_each_epoch):
                        try:
                            target_batch = next(dataloader_iter)
                        except StopIteration:
                            dataloader_iter = iter(target_loader)
                            target_batch = next(dataloader_iter)
                            print('new iters')
                        with torch.no_grad():
                            load_data_to_gpu(target_batch)
                            _ = model(target_batch)[0]
                    # cur_iter is not passed, so it defaults to 0. This only
                    # matters for the HSSDA_EMA keep-rate ramp.
                    if ema_model is not None:
                        update_ema_variables(model, ema_model, model_cfg=ema_model.model_cfg)
            
            # Curriculum data augmentation: re-prepare the target augmentor with
            # intensity PROG_AUG.SCALE at the epochs listed in UPDATE_AUG.
            if cfg.SELF_TRAIN.get('PROG_AUG', None) and cfg.SELF_TRAIN.PROG_AUG.ENABLED and \
                (cur_epoch in cfg.SELF_TRAIN.PROG_AUG.UPDATE_AUG):
                target_loader.dataset.data_augmentor.re_prepare(
                    augmentor_configs=None, intensity=cfg.SELF_TRAIN.PROG_AUG.SCALE)

            # Train one epoch; returns the updated global iteration counter.
            accumulated_iter = train_one_epoch_st(
                model, optimizer, source_reader, target_loader, model_func,
                lr_scheduler=cur_scheduler,
                accumulated_iter=accumulated_iter, optim_cfg=optim_cfg,
                rank=rank, tbar=tbar, tb_log=tb_log,
                leave_pbar=(cur_epoch + 1 == total_epochs),
                total_it_each_epoch=total_it_each_epoch,
                dataloader_iter=dataloader_iter, ema_model=ema_model, cur_epoch=cur_epoch
            )

            # Wait for all ranks before checkpointing.
            commu_utils.synchronize()


            # Save the trained model (rank 0 only). First delete the oldest
            # checkpoints by mtime so that, after this save, at most
            # max_ckpt_save_num remain.
            trained_epoch = cur_epoch + 1
            if trained_epoch % ckpt_save_interval == 0 and rank == 0:

                ckpt_list = glob.glob(str(ckpt_save_dir / 'checkpoint_epoch_*.pth'))
                ckpt_list.sort(key=os.path.getmtime)

                if ckpt_list.__len__() >= max_ckpt_save_num:
                    for cur_file_idx in range(0, len(ckpt_list) - max_ckpt_save_num + 1):
                        os.remove(ckpt_list[cur_file_idx])

                ckpt_name = ckpt_save_dir / ('checkpoint_epoch_%d' % trained_epoch)
                state = checkpoint_state(model, optimizer, trained_epoch, accumulated_iter)

                save_checkpoint(state, filename=ckpt_name)