import numpy as np
import copy
from .box_utils import remove_points_in_boxes3d, enlarge_box3d
from pcdet.ops.iou3d_nms import iou3d_nms_utils
from pcdet.config import cfg
from pcdet.datasets.augmentor.augmentor_utils import get_points_in_box
from pcdet.utils.box_utils import mask_boxes_outside_range_numpy

def process_batch_data(batch, idx):
    """Extract the ground-truth boxes and points of one frame from a collated batch.

    Works with torch tensors (detached, moved to CPU and converted to NumPy) as
    well as with NumPy arrays (used as-is).

    Args:
        batch (dict): collated batch. ``batch['gt_boxes'][idx]`` is sliced as a
            2-D ``(num_boxes, >=7)`` array and ``batch['points']`` as a 2-D array
            whose column 0 holds the frame index within the batch.
        idx (int): frame index within the batch.

    Returns:
        tuple: ``(gt_box, points)``. ``gt_box`` holds the first 7 columns of the
        frame's boxes with all-zero rows removed. ``points`` holds the frame's
        points without the leading frame-index column.
    """
    def _to_numpy(array):
        # Tensors expose .cpu(); NumPy arrays do not and are returned unchanged.
        if hasattr(array, 'cpu'):
            return array.detach().cpu().numpy()
        return array

    gt_box = _to_numpy(batch['gt_boxes'][idx][:, :7])
    # Drop all-zero rows (presumably padding added when boxes are collated).
    gt_box = gt_box[~np.all(gt_box == 0, axis=1)]

    all_points = batch['points']
    points = _to_numpy(all_points[all_points[:, 0] == idx][:, 1:])
    return gt_box, points

def _laser_mix_pitch(points, boxes, pitch_range):
    """Return the clipped pitch angle (radians) of each point and each box centre.

    Pitch is ``atan2(z - 1.8, sqrt(x**2 + y**2))``. The 1.8 offset is presumably
    the LiDAR mounting height; confirm. Values are clipped just inside
    ``pitch_range`` so every point/box falls into one of the LaserMix bands.
    """
    low, high = pitch_range[0] + 1e-5, pitch_range[1] - 1e-5

    rho = np.sqrt(points[:, 0] ** 2 + points[:, 1] ** 2)
    pitch = np.clip(np.arctan2(-1.8 + points[:, 2], rho), low, high)

    rho_box = np.sqrt(boxes[:, 0] ** 2 + boxes[:, 1] ** 2)
    pitch_box = np.clip(np.arctan2(-1.8 + boxes[:, 2], rho_box), low, high)
    return pitch, pitch_box


def laser_mix_domain(target_batch, source_batch, target_loader):
    """Build a mixed batch with LaserMix-style inter-domain mixing.

    For each frame pair (source frame i, target frame i):

    * The points inside every GT box are extracted.
    * The background is then cleaned by removing points inside the GT boxes
      enlarged by ``[1, 0.5, 0.5]``.
    * The pitch range ``cfg.SELF_TRAIN.MID.PITCH`` (degrees, default
      ``[-25, 3]``) is split into ``num_areas`` equal bands.
      ``num_areas`` is drawn once per batch from ``cfg.SELF_TRAIN.MID.NUM_AREAS``.
    * Bands alternate between source and target. Which domain takes the even
      bands is also drawn once per batch.
    * Each band contributes its background points, the boxes whose centre pitch
      falls in it, and those boxes' object points.

    Each mixed frame keeps the target ``frame_id`` and is passed through
    ``prepare_data_easy``. It is then trimmed to the keys needed for collation.
    Keys missing after ``collate_batch`` are copied from ``target_batch``.

    Returns:
        dict: the collated mixed batch.
    """
    data_dict_list = []
    batch_size = min(target_batch['gt_boxes'].shape[0], source_batch['gt_boxes'].shape[0])

    pitch_deg = cfg.SELF_TRAIN.MID.get('PITCH', [-25, 3])
    num_areas = np.random.choice(cfg.SELF_TRAIN.MID.get('NUM_AREAS', [3, 4, 5, 6]), size=1)[0]
    order = np.random.choice([0, 1], size=1)[0]

    # Shrink the range slightly so clipped angles stay strictly inside it.
    pitch_range = [np.deg2rad(pitch_deg[0]) + 1e-5, np.deg2rad(pitch_deg[1]) - 1e-5]
    # Band edges from the highest pitch down; band i is (edges[i + 1], edges[i]].
    angle_edges = np.linspace(pitch_range[1], pitch_range[0], num_areas + 1)

    for frame_idx in range(batch_size):
        s_gt_box, source_points = process_batch_data(source_batch, frame_idx)
        t_gt_box, target_points = process_batch_data(target_batch, frame_idx)

        # Object points are taken before the background is cleaned below.
        source_obj_points = [get_points_in_box(source_points, box[:7])[0] for box in s_gt_box]
        target_obj_points = [get_points_in_box(target_points, box[:7])[0] for box in t_gt_box]

        # Background only: drop points inside the (enlarged) GT boxes of each domain.
        source_points = remove_points_in_boxes3d(
            source_points, enlarge_box3d(s_gt_box, extra_width=[1, 0.5, 0.5]))
        target_points = remove_points_in_boxes3d(
            target_points, enlarge_box3d(t_gt_box, extra_width=[1, 0.5, 0.5]))

        src_pitch, src_pitch_box = _laser_mix_pitch(source_points, s_gt_box, pitch_range)
        tgt_pitch, tgt_pitch_box = _laser_mix_pitch(target_points, t_gt_box, pitch_range)

        mix_points, mix_boxes, mix_obj_points = [], [], []
        for i in range(num_areas):
            low, high = angle_edges[i + 1], angle_edges[i]
            if i % 2 == order:
                pts, pitch, boxes, pitch_box, obj_points = (
                    source_points, src_pitch, s_gt_box, src_pitch_box, source_obj_points)
            else:
                pts, pitch, boxes, pitch_box, obj_points = (
                    target_points, tgt_pitch, t_gt_box, tgt_pitch_box, target_obj_points)

            mix_points.append(pts[(pitch > low) & (pitch <= high)])
            box_mask = (pitch_box > low) & (pitch_box <= high)
            mix_boxes.append(boxes[box_mask])
            if box_mask.any():
                mix_obj_points.append(
                    np.concatenate([obj for obj, keep in zip(obj_points, box_mask) if keep]))

        mix_points = np.concatenate(mix_points)
        mix_boxes = np.concatenate(mix_boxes)
        if mix_obj_points:
            mix_points = np.concatenate([mix_points, np.concatenate(mix_obj_points)])

        # NOTE(review): a mix may end up with no boxes. It is still used as a
        # background-only sample. A fallback to the raw target frame was guarded by
        # the impossible condition `len(mix_boxes) < 0` (and would have used target
        # points already stripped of object points), so it has been removed.
        data_dict = {
            'points': mix_points,
            'frame_id': target_batch['frame_id'][frame_idx],
            'gt_boxes': mix_boxes[:, :7],
        }
        data_dict_list.append(target_loader.dataset.prepare_data_easy(data_dict=data_dict))

    # Keep only the keys needed for collation.
    keys_to_keep = {'points', 'frame_id', 'gt_boxes', 'non_sampled_gt_boxes_num', 'use_lead_xyz',
                    'voxels', 'voxel_coords', 'voxel_num_points'}
    for data in data_dict_list:
        for key in [k for k in data.keys() if k not in keys_to_keep]:
            del data[key]

    mid_batch = target_loader.dataset.collate_batch(data_dict_list)
    for key in target_batch.keys():
        if key not in mid_batch:
            mid_batch[key] = target_batch[key]
    return mid_batch

def check_aspect2D(crop_range, aspect_min):
    """Return whether the first two extents of ``crop_range`` are balanced enough.

    The ratio of the smaller to the larger of ``crop_range[0]`` and
    ``crop_range[1]`` must be at least ``aspect_min``.
    """
    extents = np.asarray(crop_range[:2])
    return extents.min() / extents.max() >= aspect_min

def cut_mix_domain(target_batch, source_batch, target_loader):
    """Build a mixed batch by applying point CutMix to each (source, target) frame pair.

    Frame i of the source batch is mixed with frame i of the target batch via
    ``inter_domain_point_cutmix``. Each mixed frame keeps the target ``frame_id``
    and is passed through ``prepare_data_easy``. It is then trimmed to the keys
    needed for collation. Keys missing after ``collate_batch`` are copied from
    ``target_batch``.

    Returns:
        dict: the collated mixed batch.
    """
    keys_to_keep = {'points', 'frame_id', 'gt_boxes', 'non_sampled_gt_boxes_num', 'use_lead_xyz',
                    'voxels', 'voxel_coords', 'voxel_num_points'}
    data_dict_list = []

    batch_size = min(target_batch['gt_boxes'].shape[0], source_batch['gt_boxes'].shape[0])
    for frame_idx in range(batch_size):
        s_gt_box, source_points = process_batch_data(source_batch, frame_idx)
        t_gt_box, target_points = process_batch_data(target_batch, frame_idx)

        data_source = {'points': source_points, 'gt_boxes': s_gt_box}
        data_target = {'points': target_points, 'gt_boxes': t_gt_box}
        cutmixed_data = inter_domain_point_cutmix(data_source, data_target, inc_method='cutmix')
        cutmixed_data['frame_id'] = target_batch['frame_id'][frame_idx]
        data_dict_list.append(target_loader.dataset.prepare_data_easy(data_dict=cutmixed_data))

    # Keep only the keys needed for collation.
    for data in data_dict_list:
        for key in [k for k in data.keys() if k not in keys_to_keep]:
            del data[key]

    mid_batch = target_loader.dataset.collate_batch(data_dict_list)
    for key in target_batch.keys():
        if key not in mid_batch:
            mid_batch[key] = target_batch[key]
    return mid_batch


def inter_domain_point_cutmix(data_source, data_target, inc_method, pc_range=None,
                              min_target_points=10000, max_attempts=100):
    """Replace a random BEV rectangle of the source scene with the target scene.

    The crop size per axis is a random fraction (0.5-1.0) of the point-cloud
    range. It is re-drawn (at most 101 times) until its aspect ratio passes
    ``check_aspect2D(..., 0.75)``. The crop is centred on a random source point.
    The centre is re-sampled until the crop holds more than ``min_target_points``
    target points, for at most ``max_attempts`` tries; the last sample is used if
    the threshold is never met.

    The result contains:

    * target points inside the crop plus source points outside it;
    * source boxes outside the crop plus target boxes inside it, as judged by
      ``mask_boxes_outside_range_numpy`` over the crop's xy extent and the
      z-extent of ``pc_range``.

    Args:
        data_source (dict): ``'points'`` (2-D, xyz in the first columns) and
            ``'gt_boxes'`` arrays. The original docstring refers to this as Waymo.
        data_target (dict): same keys. The original docstring refers to this as
            NuScenes. All other keys are deep-copied into the result.
        inc_method: unused; kept for call-site compatibility.
        pc_range (sequence of 6 floats, optional): ``[x_min, y_min, z_min,
            x_max, y_max, z_max]``. Defaults to ``[-75.2, -75.2, -2, 75.2, 75.2, 4]``.
        min_target_points (int): required number of target points in the crop.
        max_attempts (int): maximum number of crop centres to try (at least 1).

    Returns:
        dict: the mixed sample.

    Raises:
        ValueError: if ``data_source`` has no points to centre the crop on.
        AssertionError: if the mixed sample ends up with no points.
    """
    if pc_range is None:
        pc_range = [-75.2, -75.2, -2, 75.2, 75.2, 4]
    pc_range = np.array(pc_range, dtype=np.float32)

    src_points, tgt_points = data_source['points'], data_target['points']
    src_boxes, tgt_boxes = data_source['gt_boxes'], data_target['gt_boxes']
    if len(src_points) == 0:
        raise ValueError('data_source must contain at least one point to centre the crop on')

    # Draw the crop extent; re-draw while it is too elongated (bounded).
    crop_range = 0.5 + (np.random.rand(2) * 0.5)
    for _ in range(101):
        if check_aspect2D(crop_range, 0.75):
            break
        crop_range = 0.5 + (np.random.rand(2) * 0.5)

    half_extent = (pc_range[3:5] - pc_range[0:2]) * crop_range / 2.0

    def _inside(points, low, high):
        xy = points[:, :2]
        return np.all((xy > low) & (xy < high), axis=1)

    # Bounded: a sparse target scan must not hang the data pipeline.
    for _ in range(max(1, max_attempts)):
        sample_center = src_points[np.random.choice(len(src_points)), 0:3]
        max_xy = sample_center[:2] + half_extent
        min_xy = sample_center[:2] - half_extent
        inside_target = _inside(tgt_points, min_xy, max_xy)
        if np.sum(inside_target) > min_target_points:
            break
    inside_source = _inside(src_points, min_xy, max_xy)

    points = np.concatenate([tgt_points[inside_target], src_points[~inside_source]], axis=0)
    assert len(points) != 0, 'new generated cutmixed_data should contain more than 0 point after inter_domain point_cutmix!'

    region_range = [min_xy[0], min_xy[1], pc_range[2], max_xy[0], max_xy[1], pc_range[5]]
    src_in_region = mask_boxes_outside_range_numpy(src_boxes, region_range, min_num_corners=1)
    tgt_in_region = mask_boxes_outside_range_numpy(tgt_boxes, region_range, min_num_corners=1)
    # Concatenating arrays (not lists of rows) keeps a 2-D shape even with no boxes.
    gt_boxes = np.concatenate([src_boxes[~src_in_region], tgt_boxes[tgt_in_region]], axis=0)

    # Carry over the target's other keys without deep-copying its points/boxes.
    cutmixed_data = {key: copy.deepcopy(value) for key, value in data_target.items()
                     if key not in ('points', 'gt_boxes')}
    cutmixed_data['points'] = points
    cutmixed_data['gt_boxes'] = gt_boxes
    return cutmixed_data



# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Written by Aoran Xiao, 09:43 2022/03/05
# Wish for world peace!

# Modified by Yecheol Kim:
# converted the label section from the Semantic KITTI format to make it compatible
# with NuScenes.

import copy

import numpy as np

from pcdet.utils import box_utils
from pcdet.ops.roiaware_pool3d import roiaware_pool3d_utils
from pcdet.ops.iou3d_nms import iou3d_nms_utils

def sig_polar(x):
    """Logistic curve on [0, 1] with steepness 6, centred at 0.5 (``sig_polar(0.5) == 0.5``)."""
    centered = 2 * x - 1  # map [0, 1] onto [-1, 1]
    return 1 / (1 + np.exp(-6 * centered))

def is_overlap(x1, x2, x3, x4):
    """Return True if the closed intervals [x1, x2] and [x3, x4] intersect.

    Endpoints may be given in either order; touching intervals count as overlapping.
    """
    lo_a, hi_a = (x2, x1) if x2 < x1 else (x1, x2)
    lo_b, hi_b = (x4, x3) if x4 < x3 else (x3, x4)
    return not (hi_a < lo_b or hi_b < lo_a)

def swap(pt1, pt2, start_angle, end_angle, label1, label2, inc_method='center', use_pitch=False):
    """Replace the azimuth sector (start_angle, end_angle) of scene 1 with scene 2.

    Azimuth is computed as ``-atan2(y, x)``. Boxes are assigned to the sector
    according to ``inc_method``:

    * ``'center'``: a box is in the sector if its centre is (both scenes).
    * ``'corner'``: scene-1 boxes with *any* BEV corner in the sector are removed;
      scene-2 boxes with *all* corners in the sector are added.
    * ``'corner_del'``: like ``'corner'``. In addition, points inside boxes that
      straddle the sector boundary are first removed from both scenes.

    Scene-1 points inside the sector are replaced by scene-2 points inside it.
    If ``use_pitch`` is set, scene-2 points outside the sector are also added
    when they meet both conditions:

    * their horizontal distance is > 1;
    * their pitch ``-atan2(z, dist)`` lies outside the pitch span of scene-1
      points with horizontal distance > 1.

    Returns:
        tuple: ``(points, boxes)`` of the mixed scene.

    Raises:
        NotImplementedError: for an unknown ``inc_method``.
    """
    n_label1 = label1.shape[0]
    if inc_method == 'center':
        yaw1 = -np.arctan2(label1[:, 1], label1[:, 0])
        yaw2 = -np.arctan2(label2[:, 1], label2[:, 0])
        idx1 = np.where((yaw1 > start_angle) & (yaw1 < end_angle))
        idx2 = np.where((yaw2 > start_angle) & (yaw2 < end_angle))
    elif inc_method in ('corner', 'corner_del'):
        corner1 = box_utils.boxes_to_corners_3d(label1)[:, :, :2]
        corner2 = box_utils.boxes_to_corners_3d(label2)[:, :, :2]
        yaw1 = -np.arctan2(corner1[:, :, 1], corner1[:, :, 0])
        yaw2 = -np.arctan2(corner2[:, :, 1], corner2[:, :, 0])
        idx1 = np.any((yaw1 > start_angle) & (yaw1 < end_angle), axis=1)
        idx2 = np.all((yaw2 > start_angle) & (yaw2 < end_angle), axis=1)
        if inc_method == 'corner_del':
            idx1_a = np.all((yaw1 > start_angle) & (yaw1 < end_angle), axis=1)
            idx2_a = np.any((yaw2 > start_angle) & (yaw2 < end_angle), axis=1)
            # When only some corners are within the sector, remove the points in that box.
            pt1 = box_utils.remove_points_in_boxes3d(pt1, label1[idx1 != idx1_a][:, :7])
            pt2 = box_utils.remove_points_in_boxes3d(pt2, label2[idx2 != idx2_a][:, :7])
        idx1, idx2 = np.where(idx1), np.where(idx2)
    else:
        # Previously the exception was only constructed, not raised, which led to
        # an UnboundLocalError further down.
        raise NotImplementedError(f'unsupported inc_method: {inc_method!r}')

    comp1 = np.setdiff1d(np.arange(n_label1), idx1[0])
    label1_out = np.concatenate((label1[comp1], label2[idx2]))

    # Horizontal angle of each point.
    yaw1 = -np.arctan2(pt1[:, 1], pt1[:, 0])
    yaw2 = -np.arctan2(pt2[:, 1], pt2[:, 0])

    # Select the points in the sector.
    idx1 = np.where((yaw1 > start_angle) & (yaw1 < end_angle))
    if use_pitch:
        dis1 = np.sqrt(np.power(pt1[:, 0], 2) + np.power(pt1[:, 1], 2))
        dis2 = np.sqrt(np.power(pt2[:, 0], 2) + np.power(pt2[:, 1], 2))
        mask1 = dis1 > 1
        mask2 = dis2 > 1
        pitch1 = -np.arctan2(pt1[:, 2], dis1)
        pitch2 = -np.arctan2(pt2[:, 2], dis2)
        # NOTE(review): raises ValueError if no scene-1 point has dis1 > 1.
        pitch1_min, pitch1_max = pitch1[mask1].min(), pitch1[mask1].max()
        idx2_1 = np.where(((yaw2 < start_angle) | (yaw2 > end_angle))
                          & ((pitch2 < pitch1_min) | (pitch2 > pitch1_max)) & mask2)
        idx2_2 = np.where((yaw2 > start_angle) & (yaw2 < end_angle))
        idx2 = np.hstack((idx2_1[0], idx2_2[0]))
    else:
        idx2 = np.where((yaw2 > start_angle) & (yaw2 < end_angle))

    # Swap: drop scene-1 sector points, add the selected scene-2 points.
    pt1_out = np.delete(pt1, idx1, axis=0)
    pt1_out2 = np.concatenate((pt1_out, pt2[idx2]))

    return pt1_out2, label1_out

def swap_with_range(pt1, pt2, start_angle, end_angle, label1, label2, pc_range, polar_dis=None):
    """Replace a radially limited part of an azimuth sector of scene 1 with scene 2.

    A distance threshold ``dis_th`` is drawn uniformly from ``[0, pc_range[3])``.
    If ``dis_th / pc_range[3] > 0.4``, the swapped region is the part of the
    sector closer than ``dis_th``; otherwise it is the part farther than
    ``dis_th``. Azimuth is ``-atan2(y, x)``. Points and box centres of scene 1 in
    the region are replaced by those of scene 2.

    Args:
        polar_dis: accepted for call-site compatibility and ignored.

    Returns:
        tuple: ``(points, boxes)`` of the mixed scene 1.
    """
    # Pick the distance threshold.
    dis_th = np.random.random() * pc_range[3]
    keep_near = dis_th / pc_range[3] > 0.4

    def _region_mask(xy):
        yaw = -np.arctan2(xy[:, 1], xy[:, 0])
        dis = np.sqrt(xy[:, 0] ** 2 + xy[:, 1] ** 2)
        in_sector = (yaw > start_angle) & (yaw < end_angle)
        in_band = (dis < dis_th) if keep_near else (dis > dis_th)
        return in_sector & in_band

    # Swap points: scene-1 points outside the region plus scene-2 points inside it.
    pt_mask1, pt_mask2 = _region_mask(pt1), _region_mask(pt2)
    pt1_out = np.concatenate((pt1[~pt_mask1], pt2[pt_mask2]))

    # Swap boxes by their centre, with the same rule.
    box_mask1, box_mask2 = _region_mask(label1), _region_mask(label2)
    label1_out = np.concatenate((label1[~box_mask1], label2[box_mask2]))

    return pt1_out, label1_out
    
def rotate_copy(pts, labels, Omega, labels2):
    """Create rotated copies of the instances in ``labels`` together with their points.

    For every angle ``omega`` in ``Omega``:

    * The boxes are rotated about the z-axis. Centres become ``xyz @ R`` with
      ``R = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]``, and the heading
      (column 6) increases by ``omega``.
    * Rotated boxes with non-zero BEV IoU against any box accepted so far are
      discarded. The accepted set starts as ``labels2`` and grows with each
      angle's surviving copies.
    * Points of ``pts`` inside the surviving *unrotated* boxes are rotated the
      same way. Only their xyz columns are filled; any further columns are left
      as zeros.

    Returns:
        tuple: ``(pts_copy, labels_copy)`` concatenated over all angles.
    """
    accepted = [labels2]
    copied_points = []
    copied_boxes = []

    for omega in Omega:
        c, s = np.cos(omega), np.sin(omega)
        rot = np.array([[c, s, 0],
                        [-s, c, 0],
                        [0, 0, 1]])

        rotated = copy.deepcopy(labels)
        rotated[:, :3] = np.dot(labels[:, :3], rot)
        rotated[:, 6] += omega

        # Keep only copies that do not touch any box already in the scene.
        iou = iou3d_nms_utils.boxes_bev_iou_cpu(np.concatenate(accepted, axis=0)[:, :7],
                                                rotated[:, :7])
        free = iou.sum(0) == 0
        rotated = rotated[free]
        copied_boxes.append(rotated)
        accepted.append(rotated)

        # Points of the kept instances, taken from the unrotated boxes.
        in_boxes = roiaware_pool3d_utils.points_in_boxes_cpu(pts[:, :3], labels[free][:, :7])
        inst_points = pts[in_boxes.sum(0) != 0]
        moved = np.zeros_like(inst_points)
        moved[:, :3] = np.dot(inst_points[:, :3], rot)
        copied_points.append(moved)

    return np.concatenate(copied_points, axis=0), np.concatenate(copied_boxes, axis=0)

def polarmix(pts1, labels1, pts2, labels2, swap_range, Omega, pc_range, polar_dis, inc_method,
             use_pitch):
    """Apply PolarMix: azimuth-sector swapping followed by instance rotate-pasting.

    Args:
        pts1: source-domain points (the scene being modified).
        labels1: source-domain labels (gt boxes).
        pts2: target-domain points (the donor scene).
        labels2: target-domain labels (gt boxes).
        swap_range (list[list[float]]): azimuth sectors ``[start, end]`` (radians)
            to take from the donor scene.
        Omega (list[float]): instance-level rotate-pasting angles (radians).
        pc_range: point-cloud range; only used by the ``'RAND'`` swap.
        polar_dis (str): ``'FULL'`` swaps whole sectors (``swap``). ``'RAND'``
            swaps a radially limited part of each sector (``swap_with_range``).
            Any other value skips swapping.
        inc_method (str): box-to-sector rule for ``'FULL'`` swapping.
        use_pitch (bool): pitch-based extra donor points for ``'FULL'`` swapping.

    Returns:
        tuple: ``(pts_out, labels_out)`` polarmix points and labels.
    """
    pts_out, labels_out = pts1, labels1
    # Both stages are gated by `np.random.random() < 1.0`, which is always true.
    # The threshold acts as an application-probability knob.
    if np.random.random() < 1.0:
        for start_angle, end_angle in swap_range:
            if polar_dis == 'FULL':
                pts_out, labels_out = swap(pts_out, pts2,
                                           start_angle=start_angle,
                                           end_angle=end_angle,
                                           label1=labels_out, label2=labels2,
                                           inc_method=inc_method,
                                           use_pitch=use_pitch)
            elif polar_dis == 'RAND':
                # polar_dis is not forwarded: swap_with_range does not use it.
                pts_out, labels_out = swap_with_range(pts_out, pts2,
                                                      start_angle=start_angle,
                                                      end_angle=end_angle,
                                                      label1=labels_out, label2=labels2,
                                                      pc_range=pc_range)

    # Rotate-pasting of donor instances that do not collide with existing boxes.
    if np.random.random() < 1.0:
        pts_copy, labels_copy = rotate_copy(pts2, labels2, Omega, labels_out)
        # Clear the space the pasted instances occupy, then paste them.
        pts_out = box_utils.remove_points_in_boxes3d(pts_out, labels_copy[:, :7])
        pts_out = np.concatenate((pts_out, pts_copy), axis=0)
        labels_out = np.concatenate((labels_out, labels_copy), axis=0)

    return pts_out, labels_out

def inter_domain_point_polarmix(data_dict_source, data_dict_target, polarmix_rot_copy_num, polarmix_degree,
                                train_percent, update_methods, pc_range, polar_dis, inc_method, use_pitch):
    """Run PolarMix between a source and a target sample.

    One swap sector is drawn per entry of ``update_methods``. Its width
    ``prand_degree`` (radians) is derived from ``polarmix_degree``
    (``[low, high]``, or a single value used for both):

    * ``'FIX'``: ``low``.
    * ``'RAND'``: uniform in ``[low, high]``.
    * ``'ASC'``: linear in ``train_percent``.
    * ``'ASC_SIG'``: ``sig_polar(train_percent)``.
    * ``'DESC'``: linearly decreasing in ``train_percent``.

    Each sector's start is drawn uniformly in ``[-pi, pi)``, with up to 100 tries
    to avoid overlapping earlier sectors; if all tries overlap, no sector is
    added. Sectors crossing ``pi`` are split into two. The first
    ``polarmix_rot_copy_num`` angles of ``[0, U(0, 2pi/3), U(2pi/3, 4pi/3)]`` are
    used for rotate-pasting.

    Returns:
        dict: deep copy of ``data_dict_target`` with the mixed ``'points'`` and
        ``'gt_boxes'``.

    Raises:
        TypeError: if ``polarmix_degree`` is neither a number nor a non-empty list/tuple.
        ValueError: for an unknown update method.
    """
    if isinstance(polarmix_degree, (int, float, np.integer, np.floating)):
        p_degree = [polarmix_degree, polarmix_degree]
    elif isinstance(polarmix_degree, (list, tuple)) and len(polarmix_degree) > 0:
        high = polarmix_degree[1] if len(polarmix_degree) > 1 else polarmix_degree[0]
        p_degree = [polarmix_degree[0], high]
    else:
        raise TypeError(f'polarmix_degree must be a number or a non-empty list, got {polarmix_degree!r}')

    swap_range = []
    for update_method in update_methods:
        if update_method == 'FIX':
            prand_degree = p_degree[0]
        elif update_method == 'RAND':
            prand_degree = np.random.uniform(p_degree[0], p_degree[1])
        elif update_method == 'ASC':
            prand_degree = p_degree[0] + (p_degree[1] - p_degree[0]) * train_percent
        elif update_method == 'ASC_SIG':
            prand_degree = p_degree[0] + (p_degree[1] - p_degree[0]) * sig_polar(train_percent)
        elif update_method == 'DESC':
            prand_degree = p_degree[1] - (p_degree[1] - p_degree[0]) * train_percent
        else:
            # Previously an unknown method silently reused the previous width (or crashed).
            raise ValueError(f'unknown polarmix update method: {update_method!r}')

        num_swap = len(swap_range)
        for _ in range(100):
            swap_st = (np.random.random() * 2 - 1) * np.pi  # -pi ~ pi
            overlapping = any(
                is_overlap(swap_range[i][0], swap_range[i][1], swap_st, swap_st + prand_degree)
                for i in range(num_swap))
            if not overlapping:
                swap_range.append([swap_st, swap_st + prand_degree])
                break

        # Split sectors that wrap past pi into [start, pi] and [-pi, end - 2pi].
        num_swap = len(swap_range)
        for i in range(num_swap):
            if swap_range[i][1] > np.pi:
                swap_range.append([-np.pi, swap_range[i][1] - (np.pi * 2)])
                swap_range[i][1] = np.pi

    Omega = [0, np.random.random() * np.pi * 2 / 3, (np.random.random() + 1) * np.pi * 2 / 3]  # x3
    Omega = Omega[:polarmix_rot_copy_num]
    pts_out, labels_out = polarmix(data_dict_source['points'],
                                   data_dict_source['gt_boxes'],
                                   data_dict_target['points'],
                                   data_dict_target['gt_boxes'],
                                   swap_range, Omega, pc_range, polar_dis, inc_method, use_pitch)
    cutmixed_data = copy.deepcopy(data_dict_target)
    cutmixed_data['points'] = pts_out
    cutmixed_data['gt_boxes'] = labels_out

    return cutmixed_data

def polar_mix_domain(target_batch, source_batch, target_loader):
    """Build a mixed batch by applying PolarMix to each (source, target) frame pair.

    The hyper-parameters are fixed here:

    * one rotate-copy angle (0);
    * up to three non-overlapping 1.570796-rad azimuth sectors taken from the
      target into the source (``'FULL'`` swapping);
    * ``'corner_del'`` box assignment;
    * no pitch-based extras.

    Each mixed frame keeps the target ``frame_id`` and is passed through
    ``prepare_data_easy``. It is then trimmed to the keys needed for collation.
    Keys missing after ``collate_batch`` are copied from ``target_batch``.
    """
    # Fixed PolarMix settings (identical for every frame).
    polarmix_rot_copy_num = 1
    polarmix_degree = 1.570796
    train_percent = 0.3
    update_methods = ['FIX', 'FIX', 'FIX']
    pc_range = [-75.2, -75.2, -2, 75.2, 75.2, 4]
    polar_dis = 'FULL'
    inc_method = 'corner_del'
    use_pitch = False

    samples = []
    n_frames = min(target_batch['gt_boxes'].shape[0], source_batch['gt_boxes'].shape[0])
    for frame_idx in range(n_frames):
        src_boxes, src_points = process_batch_data(source_batch, frame_idx)
        tgt_boxes, tgt_points = process_batch_data(target_batch, frame_idx)

        mixed = inter_domain_point_polarmix({'points': src_points, 'gt_boxes': src_boxes},
                                            {'points': tgt_points, 'gt_boxes': tgt_boxes},
                                            polarmix_rot_copy_num, polarmix_degree,
                                            train_percent, update_methods, pc_range,
                                            polar_dis, inc_method, use_pitch)
        mixed['frame_id'] = target_batch['frame_id'][frame_idx]
        samples.append(target_loader.dataset.prepare_data_easy(data_dict=mixed))

    # Strip everything collate_batch does not need.
    wanted = {'points', 'frame_id', 'gt_boxes', 'non_sampled_gt_boxes_num', 'use_lead_xyz',
              'voxels', 'voxel_coords', 'voxel_num_points'}
    for sample in samples:
        for key in list(sample.keys()):
            if key not in wanted:
                del sample[key]

    collated = target_loader.dataset.collate_batch(samples)
    for key, value in target_batch.items():
        if key not in collated:
            collated[key] = value
    return collated


import numpy as np
import copy

def shuffle_points(data_dict):
    """Randomly permute the rows of ``data_dict['points']``.

    The dict itself is updated (its ``'points'`` entry is replaced by the
    shuffled array) and returned.
    """
    order = np.random.permutation(data_dict['points'].shape[0])
    data_dict['points'] = data_dict['points'][order]
    return data_dict

def intra_domain_point_mixup(data_dict_1, data_dict_2, alpha=None):
    """Mix two scenes by random point subsampling (MixUp on point clouds).

    ``lam ~ Beta(alpha, alpha)``. Both inputs' points are shuffled in place
    (their ``'points'`` entries are replaced). The first ``int(N1 * lam)`` points
    of scene 1 and the first ``int(N2 * (1 - lam))`` points of scene 2 are then
    concatenated. All boxes of both scenes are kept. Other keys come from a deep
    copy of ``data_dict_1``.
    """
    mixed = copy.deepcopy(data_dict_1)

    lam = np.random.beta(alpha, alpha)

    first = shuffle_points(data_dict_1)['points']
    second = shuffle_points(data_dict_2)['points']
    n_first = int(first.shape[0] * lam)
    n_second = int(second.shape[0] * (1 - lam))

    mixed['points'] = np.concatenate((first[:n_first], second[:n_second]), axis=0)
    mixed['gt_boxes'] = np.concatenate((data_dict_1['gt_boxes'], data_dict_2['gt_boxes']), axis=0)
    return mixed

# collision detection
def intra_domain_point_mixup_cd(data_dict_1, data_dict_2, alpha=None):
    """MixUp two scenes with box collision detection.

    Scene-2 boxes whose BEV IoU with any scene-1 box is positive are dropped.
    The scene-2 points inside those boxes are dropped with them. Then, with
    ``lam ~ Beta(alpha, alpha)``, both inputs' points are shuffled in place, and
    the first ``int(N1 * lam)`` points of scene 1 are concatenated with the first
    ``int(N2 * (1 - lam))`` points of scene 2. The result keeps all scene-1 boxes
    plus the non-colliding scene-2 boxes. Other keys come from a deep copy of
    ``data_dict_1``.

    Side effects: ``data_dict_1['points']`` and ``data_dict_2['points']`` are
    replaced (collision-filtered and/or shuffled).

    Collision detection is best effort. If it fails, a warning is emitted and the
    scenes are mixed without it. Box and point filtering are applied together,
    so a failure never leaves the two inconsistent.
    """
    new_data_dict = copy.deepcopy(data_dict_1)

    lam = np.random.beta(alpha, alpha)

    boxes_1 = data_dict_1['gt_boxes']
    boxes_2 = data_dict_2['gt_boxes']
    valid_boxes = boxes_2
    # With no boxes on either side there is nothing that can collide.
    if len(boxes_1) > 0 and len(boxes_2) > 0:
        try:
            iou = iou3d_nms_utils.boxes_bev_iou_cpu(boxes_1[:, 0:7], boxes_2[:, 0:7])
            colliding = iou.max(axis=0) > 0
            keep_points = np.ones(len(data_dict_2['points']), dtype=bool)
            for box in boxes_2[colliding]:
                _, in_box = get_points_in_box(data_dict_2['points'], box)
                keep_points &= ~in_box
        except Exception as err:  # best effort: fall back to plain mixing
            import warnings
            warnings.warn(f'collision detection skipped in intra_domain_point_mixup_cd: {err!r}')
        else:
            valid_boxes = boxes_2[~colliding]
            data_dict_2['points'] = data_dict_2['points'][keep_points]

    data_dict_1 = shuffle_points(data_dict_1)
    data_dict_2 = shuffle_points(data_dict_2)

    points_1, points_2 = data_dict_1['points'], data_dict_2['points']
    new_data_dict['points'] = np.concatenate((points_1[:int(points_1.shape[0] * lam)],
                                              points_2[:int(points_2.shape[0] * (1 - lam))]), axis=0)
    new_data_dict['gt_boxes'] = np.concatenate((data_dict_1['gt_boxes'], valid_boxes), axis=0)

    return new_data_dict

def mixup_mix_domain(target_batch, source_batch, target_loader):
    """Build a mixed batch by applying collision-aware point MixUp to each frame pair.

    Source frame i is mixed with target frame i via ``intra_domain_point_mixup_cd``
    with ``alpha = 2``. Each mixed frame keeps the target ``frame_id`` and is
    passed through ``prepare_data_easy``. It is then trimmed to the keys needed
    for collation. Keys missing after ``collate_batch`` are copied from
    ``target_batch``.
    """
    # Settings mirrored from the mixup config:
    # MIX_TYPE: mixup, COLLISION_DETECTION: True, MIXUP_PROB: 0.6,
    # GT_PROB: 0.3, MIXUP_TYPE: gt+ps_gt+ps, ALPHA: 2
    alpha = 2

    samples = []
    n_frames = min(target_batch['gt_boxes'].shape[0], source_batch['gt_boxes'].shape[0])
    for frame_idx in range(n_frames):
        src_boxes, src_points = process_batch_data(source_batch, frame_idx)
        tgt_boxes, tgt_points = process_batch_data(target_batch, frame_idx)

        mixed = intra_domain_point_mixup_cd({'points': src_points, 'gt_boxes': src_boxes},
                                            {'points': tgt_points, 'gt_boxes': tgt_boxes},
                                            alpha=alpha)
        mixed['frame_id'] = target_batch['frame_id'][frame_idx]
        samples.append(target_loader.dataset.prepare_data_easy(data_dict=mixed))

    # Strip everything collate_batch does not need.
    wanted = {'points', 'frame_id', 'gt_boxes', 'non_sampled_gt_boxes_num', 'use_lead_xyz',
              'voxels', 'voxel_coords', 'voxel_num_points'}
    for sample in samples:
        for key in list(sample.keys()):
            if key not in wanted:
                del sample[key]

    collated = target_loader.dataset.collate_batch(samples)
    for key, value in target_batch.items():
        if key not in collated:
            collated[key] = value
    return collated