import torch
import numpy as np
import time
from matplotlib.patches import Ellipse, Rectangle
import matplotlib.pyplot as plt
from typing import List, Tuple, Optional

def suggest_scale(bbox, scale_factor=0.8, image_size=(512,512)):
    """Suggest a GMM kernel scale from a pixel-space bounding box.

    The box width and height are normalized by the image width and height,
    the larger of the two is multiplied by ``scale_factor``, and the result
    is rounded to two decimals.

    Args:
        bbox: ``[xmin, ymin, xmax, ymax]`` in pixels.
        scale_factor: Multiplier applied to the larger normalized side.
        image_size: ``(width, height)`` of the image in pixels.

    Returns:
        float: the suggested scale.
    """
    img_w, img_h = image_size
    x_lo, y_lo, x_hi, y_hi = bbox
    largest_side = max((x_hi - x_lo) / img_w, (y_hi - y_lo) / img_h)
    return round(largest_side * scale_factor, 2)

def bbox_to_centers_and_scales_anisotropic(
    bboxes,                   # List[List[float]], each [xmin, ymin, xmax, ymax] in pixel
    image_size=(512, 512),    # (W, H)
    scale_factor=0.4,         # fraction of bbox size to turn into σ
    min_sigma=1e-3            # numerical lower-bound
):
    """
    Convert pixel bboxes to normalized centers and anisotropic sigmas.

    Args:
        bboxes: Sequence / array / tensor of ``[xmin, ymin, xmax, ymax]`` in
            pixels. Extra trailing columns are ignored. An empty input yields
            empty ``[0, 2]`` outputs.
        image_size: ``(W, H)`` of the pixel space the boxes live in.
        scale_factor: Multiplier turning the box side (as a fraction of the
            image side) into a sigma.
        min_sigma: Lower bound for every sigma, so degenerate (zero-size)
            boxes do not produce a zero divisor downstream.

    Return:
        centers : float32 Tensor [N, 2]. A pixel center ``c`` is mapped to
                  ``c / (side - 1) * 2 - 1``, which is in [-1, 1] for
                  boxes inside the image.
        sigmas  : float32 Tensor [N, 2] (σ_x, σ_y) =
                  ``(box_w / W, box_h / H) * scale_factor``, clamped to
                  ``min_sigma``.
    """
    W, H = image_size
    boxes = torch.as_tensor(bboxes, dtype=torch.float32)
    if boxes.numel() == 0:
        # An empty list becomes a 1-D tensor that cannot be column-indexed.
        boxes = boxes.reshape(0, 4)

    # Box centers in pixel coordinates.
    cx = (boxes[:, 0] + boxes[:, 2]) * 0.5
    cy = (boxes[:, 1] + boxes[:, 3]) * 0.5

    # Box sides as a fraction of the image size, scaled into sigmas.
    bw = (boxes[:, 2] - boxes[:, 0]) / W
    bh = (boxes[:, 3] - boxes[:, 1]) / H
    sigma_x = torch.clamp(bw * scale_factor, min=min_sigma)
    sigma_y = torch.clamp(bh * scale_factor, min=min_sigma)

    # Pixel centers -> [-1, 1] using the (side - 1) denominator.
    cx_norm = cx / (W - 1) * 2.0 - 1.0
    cy_norm = cy / (H - 1) * 2.0 - 1.0

    centers = torch.stack([cx_norm, cy_norm], dim=-1)   # [N, 2]
    sigmas = torch.stack([sigma_x, sigma_y], dim=-1)    # [N, 2]
    return centers, sigmas

def bbox_to_centers_and_scales(bboxes, scale_factor=0.8, image_size=(512, 512)):
    """Convert pixel-space bboxes into GMM centers and suggested scales.

    Args:
        bboxes: Iterable of ``[xmin, ymin, xmax, ymax]`` in pixels.
        scale_factor: Forwarded to ``suggest_scale`` for each box.
        image_size: ``(width, height)``; defaults to 512x512.

    Returns:
        tuple ``(centers, scales)``:
            centers: list of ``[x_center, y_center]``, where each midpoint is
                divided by the image side and mapped to [-1, 1], then rounded
                to two decimals.
            scales: list of ``suggest_scale`` results, one per box.
    """
    img_w, img_h = image_size
    centers, scales = [], []

    def _to_unit(lo, hi, extent):
        # midpoint -> fraction of the image side -> [-1, 1], 2 decimals
        return round(((lo + hi) / 2 / extent) * 2 - 1, 2)

    for box in bboxes:
        scales.append(suggest_scale(box, scale_factor, image_size))
        x_lo, y_lo, x_hi, y_hi = box
        centers.append([_to_unit(x_lo, x_hi, img_w), _to_unit(y_lo, y_hi, img_h)])

    return centers, scales

def get_centered_patch_and_coords(fixed_noise_partial, box_h, box_w):
    """
    Cut a centered patch from ``fixed_noise_partial`` and locate it inside a box.

    The patch is meant to replace only the central region of the box; the rest
    of the box is left untouched.

    Args:
        fixed_noise_partial: 4-D tensor ``[B, C, H, W]``; only batch element 0
            is read.
        box_h, box_w: Height / width of the target box, in the same pixel
            units as H / W.

    Returns:
        tuple ``(patch, y1, y2, x1, x2)``:
            patch: ``[C, patch_h, patch_w]`` slice (a view) taken from the
                center of ``fixed_noise_partial[0]``, with
                ``patch_h = min(H, box_h)`` and ``patch_w = min(W, box_w)``.
            y1, y2, x1, x2: start / exclusive-end coordinates of the centered
                window inside the box where the patch belongs.
    """
    _, _, H, W = fixed_noise_partial.shape

    # The patch can be no larger than either the source tensor or the box.
    patch_h = min(H, box_h)
    patch_w = min(W, box_w)

    # Centered window inside the box.
    box_patch_y1 = (box_h - patch_h) // 2
    box_patch_x1 = (box_w - patch_w) // 2
    box_patch_y2 = box_patch_y1 + patch_h
    box_patch_x2 = box_patch_x1 + patch_w

    # Centered window inside fixed_noise_partial.
    src_y1 = (H - patch_h) // 2
    src_x1 = (W - patch_w) // 2

    patch = fixed_noise_partial[0, :, src_y1:src_y1 + patch_h, src_x1:src_x1 + patch_w]
    return patch, box_patch_y1, box_patch_y2, box_patch_x1, box_patch_x2

def create_position_grid(image_size):
    """Build a normalized coordinate grid spanning [-1, 1] on both axes.

    Pixel index ``i`` along an axis of length ``n`` maps to
    ``i / (n - 1) * 2 - 1``, so the first and last pixels sit exactly on -1
    and 1. An axis of length 1 maps to 0 (the center) instead of the NaN
    produced by dividing by zero.

    Args:
        image_size: ``(H, W)`` tuple, or a single int for a square grid.

    Returns:
        float32 tensor of shape ``[H, W, 2]`` on the default device, with
        ``grid[y, x] = [x_coord, y_coord]``.
    """
    if isinstance(image_size, int):
        H = W = image_size
    else:
        H, W = image_size

    def _axis(n):
        # 1-D normalized coordinates for one axis.
        if n <= 1:
            return torch.zeros(n, dtype=torch.float32)
        return (torch.arange(n).float() / (n - 1)) * 2 - 1

    # indexing='ij': the first output varies along rows (y), the second along
    # columns (x).
    yy, xx = torch.meshgrid(_axis(H), _axis(W), indexing='ij')
    return torch.stack([xx, yy], dim=-1)


def generate_anisotropic_gmm_noise(
    noise,               # Tensor [B,C,H,W]  original Gaussian noise ~N(0,1)
    centers,             # Tensor [N,2] in [-1,1]
    sigmas,              # Tensor [N,2] (σ_x, σ_y) in normalized coord
    weight=0.3,     # overall weight multiplier
    energy_norm=True     # keep global variance ≈ 1
):
    """
    Inject a sum of anisotropic Gaussian kernels into the noise volume.

    Kernel i is ``exp(-0.5 * (((x - cx_i) / sx_i)**2 + ((y - cy_i) / sy_i)**2))``,
    evaluated on the [-1, 1] grid from ``create_position_grid``. Every kernel
    is multiplied by the same scalar ``weight``. The summed map is added to
    every batch element and channel.

    Args:
        noise: Tensor ``[B, C, H, W]``. Its dtype and device are used for all
            computation; it is not modified in place.
        centers: ``[N, 2]`` kernel centers (x, y). Tensors, arrays and nested
            lists are accepted.
        sigmas: ``[N, 2]`` per-kernel ``(sigma_x, sigma_y)`` in the same units
            as the grid.
        weight: Scalar amplitude shared by all kernels. Large values produce
            visible bright spots.
        energy_norm: If True, divide the result by
            ``sqrt(1 + sum_i weight**2 * pi * sx_i * sy_i / 4)``. Each term is
            the mean of one squared kernel over the [-1, 1]^2 domain; overlap
            between kernels and truncation at the border are ignored.

    Returns:
        Tensor ``[B, C, H, W]`` with the dtype and device of ``noise``.
    """
    dtype, device = noise.dtype, noise.device
    B, C, H, W = noise.shape

    # Position grid, shared by all kernels.
    pos_grid = create_position_grid((H, W)).to(device=device, dtype=dtype)      # [H,W,2]
    pos_grid = pos_grid.view(1, 1, H, W, 2)                                     # [1,1,H,W,2]
    centers = torch.as_tensor(centers, device=device, dtype=dtype).reshape(-1, 1, 1, 1, 2)  # [N,1,1,1,2]
    sigmas = torch.as_tensor(sigmas, device=device, dtype=dtype).reshape(-1, 1, 1, 1, 2)    # [N,1,1,1,2]

    # Per-kernel Gaussian: exp(-0.5 * (((x-cx)/sx)^2 + ((y-cy)/sy)^2))
    exponent = -0.5 * ((pos_grid - centers) / sigmas).pow(2).sum(-1)           # [N,1,H,W]
    kernels = torch.exp(exponent)                                               # [N,1,H,W]

    # One scalar amplitude for every kernel; there is no per-object size
    # adaptation in this function.
    kernel_weights = torch.as_tensor(weight).to(device=device, dtype=dtype)
    kernels = kernels * kernel_weights                                          # [N,1,H,W]

    # Sum over objects and broadcast to channels and batch.
    kernel_sum = kernels.sum(0)                                                 # [1,H,W]
    kernel_sum = kernel_sum.expand(C, H, W)                                     # [C,H,W]
    kernel_sum = kernel_sum.unsqueeze(0).expand(B, -1, -1, -1)                  # [B,C,H,W]

    # Add, then (optionally) re-normalize the expected second moment.
    injected = noise + kernel_sum
    if energy_norm:
        density = (np.pi * sigmas[..., 0] * sigmas[..., 1]).view(-1) / 4.0     # [N]
        var_mul = 1.0 + (kernel_weights.view(-1) ** 2 * density).sum()
        injected = injected / torch.sqrt(var_mul)

    return injected

def generate_anisotropic_gmm_noise_bbox_normalized(
        noise,                # [B,C,H,W]  ~N(0,1)
        centers,              # [N,2]  in [-1,1]
        sigmas,               # [N,2]  σ_x, σ_y
        weight = 0.3,    # base amplitude
        beta = 0.4,    # exponent of the small-area boost
        energy_norm = True,
        area_threshold = 0.008,  # σ_x·σ_y below which the boost applies
):
    """Inject anisotropic Gaussian kernels whose amplitude grows for small objects.

    Kernel i is ``exp(-0.5 * (((x - cx_i) / sx_i)**2 + ((y - cy_i) / sy_i)**2))``
    on the [-1, 1] grid from ``create_position_grid``. Its amplitude depends on
    ``a_i = sx_i * sy_i``:

        w_i = weight                                    if a_i >= area_threshold
        w_i = weight * (area_threshold / a_i) ** beta   otherwise

    and is then clamped to ``[0.5 * weight, 1.6 * weight]``. A count is printed
    whenever some kernel falls below ``area_threshold``.

    Args:
        noise: Tensor ``[B, C, H, W]``; not modified in place.
        centers: ``[N, 2]`` kernel centers (tensor, array or nested list).
        sigmas: ``[N, 2]`` per-kernel ``(sigma_x, sigma_y)``.
        weight: Base amplitude.
        beta: Exponent of the boost for small kernels.
        energy_norm: If True, divide by
            ``sqrt(1 + sum_i w_i**2 * pi * sx_i * sy_i / 4)``, ignoring kernel
            overlap and truncation at the border.
        area_threshold: Area below which the boost applies (default 0.008).

    Returns:
        Tensor ``[B, C, H, W]`` with the dtype and device of ``noise``.
    """
    dtype, device = noise.dtype, noise.device
    _, C, H, W = noise.shape

    # ---------- broadcast -------------------------------------------
    pos_grid = create_position_grid((H, W)).to(device, dtype)                    # [H,W,2]
    pos_grid = pos_grid.view(1, 1, H, W, 2)                                      # [1,1,H,W,2]
    centers = torch.as_tensor(centers, device=device, dtype=dtype).reshape(-1, 1, 1, 1, 2)  # [N,1,1,1,2]
    sigmas = torch.as_tensor(sigmas, device=device, dtype=dtype).reshape(-1, 1, 1, 1, 2)    # [N,1,1,1,2]

    # ---------- per-kernel Gaussian --------------------------------
    kernels = torch.exp(-0.5 * ((pos_grid - centers) / sigmas).pow(2).sum(-1))  # [N,1,H,W]

    # ---------- size-adaptive weight -------------------------------
    area = (sigmas[..., 0] * sigmas[..., 1]).view(-1)                           # [N]
    n_small = int((area < area_threshold).sum())
    if n_small > 0:
        print(f"小于{area_threshold}的box: {n_small}")
    w_i = torch.where(
        area >= area_threshold,
        torch.full_like(area, weight),
        weight * (area_threshold / (area + 1e-6)).pow(beta),
    )
    # With weight=0.3 this keeps w_i within [0.15, 0.48].
    w_i = w_i.clamp(min=0.5 * weight, max=1.6 * weight)

    kernels = kernels * w_i.view(-1, 1, 1, 1)
    kernel_sum = kernels.sum(0)                            # [1,H,W]  sum over objects
    kernel_sum = kernel_sum.expand(C, H, W).unsqueeze(0)   # [1,C,H,W]
    injected = noise + kernel_sum                          # broadcasts over the batch

    # ---------- global var renorm ----------------------------------
    if energy_norm:
        density = (np.pi * sigmas[..., 0] * sigmas[..., 1]).view(-1) / 4.0
        var_mul = 1.0 + (w_i ** 2 * density).sum()
        injected = injected / torch.sqrt(var_mul)

    return injected


def generate_uniform_gmm_noise(noise, centers, scale, weight):
    """Add isotropic Gaussian bumps of one shared width to ``noise``.

    Each center contributes ``exp(-||p - c||^2 / (2 * scale**2))`` on the
    [-1, 1] grid from ``create_position_grid``. The summed map is multiplied
    by ``weight``, added to every batch element and channel, and the result is
    divided by ``sqrt(1 + N * weight**2 * alpha)`` with
    ``alpha = 1 / (4 * pi * scale**2)``.

    Args:
        noise: Tensor ``[B, C, H, W]``; not modified in place.
        centers: Sequence of ``[x, y]`` centers in [-1, 1] (lists or tensors).
        scale: Shared Gaussian width in grid units (controls object size).
        weight: Amplitude of the added kernels.

    Returns:
        Tensor ``[B, C, H, W]`` on ``noise``'s device.
    """
    # Use noise's own device so pos_grid and kernel_sum (allocated like
    # noise) always live on the same device.
    device = noise.device
    H, W = noise.shape[-2:]
    pos_grid = create_position_grid((H, W)).to(device)  # [H,W,2]
    kernel_sum = torch.zeros_like(noise[0, 0])           # [H,W]
    for center in centers:
        c = torch.as_tensor(center, device=device).view(1, 1, 2)
        kernel_sum += torch.exp(-((pos_grid - c).pow(2).sum(-1)) / (2 * scale ** 2))

    n = len(centers)
    # NOTE(review): 1/(4*pi*scale^2) is the squared L2 norm of a *normalized*
    # 2-D Gaussian pdf. The kernels above have unit peak, and their mean square
    # over the [-1,1]^2 grid is about pi*scale^2/4 instead. The value is kept
    # unchanged because callers' `weight` values may be tuned against it;
    # confirm the intended normalization.
    alpha = 1 / (4 * np.pi * scale ** 2)
    normalization = torch.sqrt(torch.tensor(1 + n * weight ** 2 * alpha))

    return (noise + weight * kernel_sum.unsqueeze(0).unsqueeze(0)) / normalization

def visualize_gmm_components(noise,centers,save_dir="/openseg_blob/zhaoyaqi/Count-FLUX/basecode_flux/kernel_noise/",img_name_prefix=""):
    """Save a heat map of ``noise[0, 0]`` with the GMM centers marked.

    Args:
        noise: Tensor ``[B, C, H, W]``; only batch 0, channel 0 is drawn
            (color range clipped to [-2, 2]).
        centers: Sequence of ``[x, y]`` centers in [-1, 1].
        save_dir: Directory the PNG is written to (must exist).
        img_name_prefix: Prefix for the timestamped file name.
    """
    import matplotlib.pyplot as plt
    noise = noise.float()
    plt.figure(figsize=(6, 6), dpi=150)
    # Global style.
    plt.rcParams.update({
        'figure.facecolor': 'white',
        'axes.facecolor': 'white',
        'axes.grid': False,
        'font.family': 'sans-serif'
    })

    plt.imshow(noise[0,0].cpu().numpy(),
              cmap='RdYlBu_r',
              vmin=-2,
              vmax=2)
    cbar = plt.colorbar(label='Noise Value')
    cbar.ax.tick_params(labelsize=10)
    cbar.set_label('Noise Value', size=11, weight='bold')
    # Small margin around the image.
    plt.margins(0.1)
    for center in centers:
        # [-1, 1] -> pixel coordinates for the marker.
        px = int((center[0] + 1) / 2 * noise.shape[-1])
        py = int((center[1] + 1) / 2 * noise.shape[-2])
        plt.plot(px, py, 'r*', markersize=15, markeredgewidth=2)
        plt.text(px+15, py, f"{center}",
                color='white', fontsize=11, fontweight='bold',
                bbox=dict(facecolor='red', alpha=0.7, edgecolor='none', pad=1))
    plt.title("Enhanced Combined Noise with Centers", pad=20, fontsize=13, fontweight='bold')
    plt.axis('off')
    image_name = f"{img_name_prefix}_train_flux_visual_gmm_components_{len(centers)}_{time.strftime('%Y%m%d_%H%M%S')}"
    plt.savefig(f"{save_dir}/{image_name}.png", bbox_inches='tight', dpi=150)
    # Release the figure; otherwise repeated calls keep every figure alive.
    plt.close()


def visualize_gmm_components_ellipse(
    noise, centers, sigmas, bbox_info, save_dir="/openseg_blob/zhaoyaqi/Count-FLUX/basecode_flux/kernel_noise/", title="Elliptical GMM Noise",
    bbox_image_size=(1024, 1024),
):
    """Save a plot of the noise with each Gaussian kernel drawn as an ellipse.

    Args:
        noise   : Tensor [B,C,H,W]; only ``noise[0, 0]`` is drawn
                  (color range clipped to [-2, 2]).
        centers : Tensor/List [[cx,cy], ...] kernel centers in [-1, 1].
        sigmas  : Tensor/List [[σx,σy], ...] per-kernel sigmas in the same
                  [-1, 1] grid units the kernels are evaluated in.
        bbox_info : List [[xmin,ymin,xmax,ymax], ...] bounding boxes expressed
                  in a ``bbox_image_size`` pixel space.
        save_dir: Directory the PNG is written to (must exist).
        title   : Figure title.
        bbox_image_size: ``(W, H)`` of the pixel space of ``bbox_info``; boxes
                  are rescaled from it to the noise resolution. Defaults to
                  1024x1024.
    """
    noise = noise.float().cpu()
    H, W = noise.shape[-2:]
    src_w, src_h = bbox_image_size

    # Global style.
    plt.figure(figsize=(8, 8), dpi=150)
    plt.rcParams.update({
        'figure.facecolor': 'white',
        'axes.facecolor': 'white',
        'axes.grid': False,
        'font.family': 'sans-serif'
    })

    # Noise background.
    plt.imshow(noise[0, 0], cmap="RdYlBu_r", vmin=-2, vmax=2)
    cbar = plt.colorbar(label='Noise Value')
    cbar.ax.tick_params(labelsize=10)
    cbar.set_label('Noise Value', size=11, weight='bold')

    # One color per kernel, shared by its ellipse, marker, label and bbox.
    colors = plt.cm.rainbow(np.linspace(0, 1, len(centers)))
    for i, ((cx, cy), (sx, sy), bbox, color) in enumerate(zip(centers, sigmas, bbox_info, colors)):
        # Plain floats so tensors (including CUDA ones) can be plotted.
        cx, cy, sx, sy = float(cx), float(cy), float(sx), float(sy)

        # [-1, 1] -> pixel coordinates.
        px = (cx + 1) * 0.5 * (W - 1)
        py = (cy + 1) * 0.5 * (H - 1)

        # One grid unit spans (W - 1) / 2 pixels. A ±2σ ellipse (diameter
        # 4σ) is therefore 4 * σ * (W - 1) / 2 = 2 * σ * (W - 1) pixels wide.
        ex = 2 * sx * (W - 1)  # width
        ey = 2 * sy * (H - 1)  # height

        ell = Ellipse(
            (px, py), width=ex, height=ey,
            edgecolor=color, facecolor="none",
            lw=2, alpha=0.9
        )
        plt.gca().add_patch(ell)

        # Center marker and label.
        plt.plot(px, py, "*", color=color, ms=12, mew=2)
        plt.text(px+10, py+10, f"K{i}\n(σx={sx:.2f}, σy={sy:.2f})",
                color=color, fontsize=8, fontweight='bold',
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=1))

        # Bounding box, rescaled from bbox_image_size to the noise resolution.
        x1, y1, x2, y2 = (float(v) for v in bbox)
        x1 = x1 * W / src_w
        y1 = y1 * H / src_h
        x2 = x2 * W / src_w
        y2 = y2 * H / src_h
        rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
                           fill=False, edgecolor=color,
                           linestyle='--', linewidth=2, alpha=0.7)
        plt.gca().add_patch(rect)

    plt.title(title, fontsize=13, pad=10, fontweight='bold')
    plt.axis("off")
    plt.tight_layout()

    fname = f"{save_dir}/gmm_ellipses_{len(centers)}_{time.strftime('%Y%m%d_%H%M%S')}.png"
    plt.savefig(fname, bbox_inches="tight", dpi=150)
    print(f"[Saved] {fname}")
    plt.close()


def box_area(box):
    """Return the area of one ``[x1, y1, x2, y2]`` box; inverted extents count as 0."""
    w = box[2] - box[0]
    w = w if w > 0.0 else 0.0
    h = box[3] - box[1]
    h = h if h > 0.0 else 0.0
    return w * h

def inter_matrix(boxes):
    """Return the (N, N) matrix of pairwise intersection areas.

    Args:
        boxes: (N, 4+) array-like of ``[x1, y1, x2, y2, ...]``; only the first
            four columns are used. An empty input yields a (0, 0) array.

    Returns:
        float32 array; entry (i, j) is the overlap area of boxes i and j
        (0 when they do not overlap).
    """
    b = np.asarray(boxes, dtype=np.float32)
    if b.size == 0:
        # An empty list becomes a 1-D array that cannot be column-indexed.
        b = b.reshape(0, 4)
    x1 = np.maximum(b[:, None, 0], b[None, :, 0])
    y1 = np.maximum(b[:, None, 1], b[None, :, 1])
    x2 = np.minimum(b[:, None, 2], b[None, :, 2])
    y2 = np.minimum(b[:, None, 3], b[None, :, 3])
    return np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)

def iou_inter_matrix(boxes):
    """Return ``(iou, inter, areas)`` for a set of boxes.

    Args:
        boxes: (N, 4+) array-like of ``[x1, y1, x2, y2, ...]``, converted to
            float32; only the first four columns are used. An empty input
            yields arrays of shape (0, 0), (0, 0) and (0,).

    Returns:
        iou:   (N, N) float32 pairwise IoU (``1e-6`` is added to the union to
               avoid division by zero).
        inter: (N, N) float32 pairwise intersection areas.
        areas: (N,) float32 box areas; negative widths/heights count as zero.
    """
    b = np.asarray(boxes, dtype=np.float32)
    if b.size == 0:
        # An empty list becomes a 1-D array that cannot be column-indexed.
        b = b.reshape(0, 4)

    # Vectorized per-box area with inverted extents clipped to zero.
    areas = np.clip(b[:, 2] - b[:, 0], 0, None) * np.clip(b[:, 3] - b[:, 1], 0, None)

    # Pairwise intersection.
    x1 = np.maximum(b[:, None, 0], b[None, :, 0])
    y1 = np.maximum(b[:, None, 1], b[None, :, 1])
    x2 = np.minimum(b[:, None, 2], b[None, :, 2])
    y2 = np.minimum(b[:, None, 3], b[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)

    union = areas[:, None] + areas[None, :] - inter
    iou = inter / (union + 1e-6)
    return iou, inter, areas

def iou_matrix(boxes):
    """Return the pairwise IoU matrix and per-box areas.

    Args:
        boxes: (N, 4+) array-like of ``[x1, y1, x2, y2, ...]``; only the first
            four columns are used. Lists are accepted. Non-float input is
            promoted to float32, and float input keeps its dtype.

    Returns:
        tuple ``(ious, areas)``:
            ious: (N, N) float32 array (``1e-6`` is added to the union to
                avoid division by zero).
            areas: (N,) array of box areas; negative widths/heights count as
                zero.
    """
    b = np.asarray(boxes)
    if not np.issubdtype(b.dtype, np.floating):
        b = b.astype(np.float32)
    if b.size == 0:
        # An empty list becomes a 1-D array that cannot be column-indexed.
        b = b.reshape(0, 4)

    areas = np.clip(b[:, 2] - b[:, 0], 0, None) * np.clip(b[:, 3] - b[:, 1], 0, None)

    # Pairwise intersection, vectorized instead of a per-row Python loop.
    x1 = np.maximum(b[:, None, 0], b[None, :, 0])
    y1 = np.maximum(b[:, None, 1], b[None, :, 1])
    x2 = np.minimum(b[:, None, 2], b[None, :, 2])
    y2 = np.minimum(b[:, None, 3], b[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)

    union = areas[:, None] + areas[None, :] - inter
    ious = (inter / (union + 1e-6)).astype(np.float32)
    return ious, areas


def filter_boxes_iou(boxes, iou_thresh=0.5):
    """
    Drop large, background-like boxes that heavily overlap a smaller box.

    Boxes are visited from smallest to largest area. Each box not yet
    suppressed is kept. Every strictly larger box whose IoU with it exceeds
    ``iou_thresh`` is then suppressed.

    Args:
        boxes:  (N,4) list/ndarray  [x1,y1,x2,y2]
        iou_thresh: float, IoU above which suppression happens

    Returns:
        list of kept boxes ``[x1, y1, x2, y2]`` (float32 values), listed in
        ascending-area order.
    """
    box_arr = np.asarray(boxes, dtype=np.float32)
    pairwise_iou, box_areas = iou_matrix(box_arr)

    kept_idx = []
    dropped = np.zeros(len(box_arr), dtype=bool)
    for cur in box_areas.argsort():   # smallest area first
        if dropped[cur]:
            continue
        kept_idx.append(cur)
        # Larger boxes overlapping the current (smaller) one are treated as background.
        dropped |= (pairwise_iou[cur] > iou_thresh) & (box_areas > box_areas[cur])

    return box_arr[kept_idx].tolist()

def keep_subject_boxes_iou_contain(
    boxes,
    iou_thresh: float | None = 0.5,
    contain_thresh: float | None = 0.7,
):
    """
    Keep small / subject‑like boxes.
    Drop a larger box if:
        1) IoU with a smaller one ≥ iou_thresh      (set None or >=2 to disable)
        2) IoA (inter / smaller_area) ≥ contain_thresh
           (handles large‑contains‑small)            (set None or >=2 to disable)
    Returns:
        tuple: (kept_boxes, kept_indices)
            - kept_boxes: List of kept boxes in original order
            - kept_indices: List of indices of kept boxes in original order
    """
    boxes = np.asarray(boxes, dtype=np.float32)
    iou, inter, areas = iou_inter_matrix(boxes)

    # treat "None" or threshold >= 2  as "turn off this rule"
    use_iou      = (iou_thresh      is not None) and (iou_thresh      < 2.0)
    use_contain  = (contain_thresh  is not None) and (contain_thresh  < 2.0)

    order   = areas.argsort()                    # small → large
    keep    = []
    removed = np.zeros(len(boxes), bool)

    for idx in order:
        if removed[idx]:
            continue
        keep.append(idx)

        larger = (areas > areas[idx]) & (~removed)
        if not larger.any():
            continue

        # ---------------- IoU rule ------------------------------------
        if use_iou:
            cond_iou = (iou[idx] >= iou_thresh) & larger
        else:
            cond_iou = np.zeros_like(larger)

        # ---------------- IoA rule ------------------------------------
        if use_contain:
            ioa           = inter[idx] / (areas[idx] + 1e-6)
            cond_contain  = (ioa >= contain_thresh) & larger
        else:
            cond_contain = np.zeros_like(larger)

        removed |= (cond_iou | cond_contain)

    # 按原始顺序返回保留的boxes和索引
    keep = sorted(keep)  # 对保留的索引排序，以恢复原始顺序
    return boxes[keep].tolist(), keep

def visualize_box_filtering_comparison(
    original_boxes,
    filtered_boxes,
    image_size=(512, 512),
    img=None,
    save_dir="/openseg_blob/zhaoyaqi/Count-FLUX/basecode_flux/kernel_noise/",
    title="Box Filtering Comparison"
):
    """Save a side-by-side plot of boxes before (left) and after (right) filtering.

    Each original box gets its own color and a ``Box {i}`` label, where ``i``
    is its index in ``original_boxes``. On the right, kept boxes reuse that
    color and label, and removed boxes are drawn as faint dashed outlines.
    Filtered boxes are matched to originals with a small tolerance because
    the filtering helpers round-trip coordinates through float32.

    Args:
        original_boxes: List[[xmin,ymin,xmax,ymax], ...] before filtering.
        filtered_boxes: List[[xmin,ymin,xmax,ymax], ...] after filtering
            (boxes only, e.g. the first element returned by
            ``keep_subject_boxes_iou_contain``).
        image_size: (W,H) of the canvas.
        img: Optional H×W×3 background image; a white canvas otherwise.
        save_dir: Directory the PNG is written to (must exist).
        title: Figure title.
    """
    W, H = image_size
    originals = [[float(v) for v in box] for box in original_boxes]
    kept = [[float(v) for v in box] for box in filtered_boxes]

    def _same(a, b):
        # Tolerant box equality (float32 round-trip safe).
        return len(a) == len(b) and np.allclose(a, b, rtol=1e-6, atol=1e-3)

    def _draw(ax, box, color, label=None, **rect_style):
        # Draw one box outline and, optionally, its label just above it.
        x1, y1, x2, y2 = box[:4]
        ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1,
                               fill=False, edgecolor=color, **rect_style))
        if label is not None:
            ax.text(x1, y1-5, label, color=color,
                    fontsize=10, fontweight='bold',
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=1))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8), dpi=150)

    # Global style.
    plt.rcParams.update({
        'figure.facecolor': 'white',
        'axes.facecolor': 'white',
        'axes.grid': False,
        'font.family': 'sans-serif'
    })

    colors = plt.cm.rainbow(np.linspace(0, 1, len(originals)))

    # Left panel: every original box.
    ax1.imshow(img if img is not None else np.ones((H, W, 3), dtype=np.float32))
    ax1.set_title("Original Boxes", fontsize=13, pad=10, fontweight='bold')
    for i, box in enumerate(originals):
        _draw(ax1, box, colors[i], f"Box {i}", linestyle='-', linewidth=2)

    # Right panel: removed boxes dashed, kept boxes solid.
    ax2.imshow(img if img is not None else np.ones((H, W, 3), dtype=np.float32))
    ax2.set_title("Filtered Boxes", fontsize=13, pad=10, fontweight='bold')
    for i, box in enumerate(originals):
        if not any(_same(box, fb) for fb in kept):
            _draw(ax2, box, colors[i], linestyle='--', linewidth=1, alpha=0.5)
    for k, box in enumerate(kept):
        j = next((i for i, ob in enumerate(originals) if _same(ob, box)), None)
        if j is None:
            # Not among the originals: no assigned color, label by its own index.
            _draw(ax2, box, 'black', f"Box {k}", linestyle='-', linewidth=2)
        else:
            _draw(ax2, box, colors[j], f"Box {j}", linestyle='-', linewidth=2)

    for ax in [ax1, ax2]:
        ax.set_xlim(0, W)
        ax.set_ylim(H, 0)  # image coordinates: y grows downward
        ax.axis('off')

    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()

    fname = f"{save_dir}/box_filtering_comparison_{time.strftime('%Y%m%d_%H%M%S')}.png"
    fig.savefig(fname, bbox_inches="tight", dpi=150)
    print(f"[Saved] {fname}")
    plt.close(fig)

# Example
if __name__ == "__main__":
    # Sample bounding boxes in pixel coordinates.
    original_boxes = [
        [10, 10, 200, 200],   # large background box
        [50, 50, 120, 160],   # subject 1
        [15, 15, 180, 190],   # another large overlapping box
        [400, 100, 480, 200]  # subject 2
    ]
    # keep_subject_boxes_iou_contain returns (kept_boxes, kept_indices);
    # only the boxes are plotted.
    filtered_boxes, _kept_indices = keep_subject_boxes_iou_contain(original_boxes, iou_thresh=0.5, contain_thresh=0.7)
    visualize_box_filtering_comparison(original_boxes, filtered_boxes)

