import torch
import torch.nn.functional as F
import numpy as np
import logging
from collections import OrderedDict

from defense.robosac import cal_robosac_consensus, sample_agents
from defense.mdag_grouping import (
    mdag_grouping_strategy,
    extract_positions_from_pairwise_matrix,
    compute_dynamic_threshold,
    UncertaintyAwareAdaptiveThreshold
)

# Module logger: the root logger, so that handlers configured in the main
# (entry-point) script also receive the records emitted here.
logger = logging.getLogger(None)


def _prepare_subset_features_with_grad(cav_content, cav_idx, model, device):
    """Encode the voxel input and keep only the spatial features of ``cav_idx``.

    All voxel inputs from ``cav_content['processed_lidar']`` are passed to
    ``model.pillar_vfe`` and then ``model.scatter``. Their return values are
    ignored; ``scatter`` is expected to store ``'spatial_features'`` in the
    dict it receives. Afterwards only the rows of ``spatial_features`` picked
    by ``cav_idx`` are kept. The pairwise transformation matrix of the first
    sample is cropped to the same agents (rows and columns) and re-batched.

    No ``torch.no_grad`` is used, so autograd history is kept whenever grad
    mode is enabled by the caller.

    Args:
        cav_content: dict with ``processed_lidar`` (``voxel_features``,
            ``voxel_coords``, ``voxel_num_points``) and ``pairwise_t_matrix``.
        cav_idx: sequence of agent indices to keep.
        model: object exposing ``pillar_vfe`` and ``scatter``.
        device: device on which the index / ``record_len`` tensors are built.

    Returns:
        The feature dict, with ``spatial_features`` restricted to ``cav_idx``
        and ``record_len`` equal to ``[len(cav_idx)]``.
    """
    selection = torch.tensor(cav_idx, device=device)
    lidar = cav_content['processed_lidar']

    # Crop the agent-to-agent transforms to the chosen agents, then re-add the batch axis.
    cropped_t = cav_content['pairwise_t_matrix'][0]
    cropped_t = torch.index_select(cropped_t, 0, selection)
    cropped_t = torch.index_select(cropped_t, 1, selection)

    feature_dict = {
        'voxel_features': lidar['voxel_features'],
        'voxel_coords': lidar['voxel_coords'],
        'voxel_num_points': lidar['voxel_num_points'],
        'record_len': torch.tensor([len(cav_idx)], device=device),
        'pairwise_t_matrix': cropped_t.unsqueeze(0),
    }

    model.pillar_vfe(feature_dict)
    model.scatter(feature_dict)

    all_spatial = feature_dict['spatial_features']
    feature_dict['spatial_features'] = torch.index_select(all_spatial, 0, selection)
    return feature_dict


def compute_similarity_score(feature1, feature2):
    """Return the cosine similarity S_L of two feature tensors.

    Both tensors are flattened with ``view`` into a single row (so they must
    hold the same number of elements and be view-compatible). The result is a
    Python float in [-1, 1]; 1 means the features point in the same direction.
    """
    row_a = feature1.view(1, -1)
    row_b = feature2.view(1, -1)
    cosine = F.cosine_similarity(row_a, row_b, dim=1, eps=1e-6)
    return cosine.item()


def _psm_mean_gradient(model, spatial_features, wrt, t_matrix, device):
    """Return d(mean(model(x)['psm'])) / d(wrt) for a single-agent forward pass.

    ``spatial_features`` gets a leading batch axis and is passed to ``model``
    together with ``record_len=[1]`` and the given ``pairwise_t_matrix``.
    """
    voxel_dict = {
        'spatial_features': spatial_features.unsqueeze(0),
        'record_len': torch.tensor([1], device=device),
        'pairwise_t_matrix': t_matrix,
    }
    loss = model(voxel_dict)['psm'].mean()
    return torch.autograd.grad(loss, wrt, create_graph=False)[0]


def compute_gradient_consistency(feature_ego, feature_k, loss_fn, model, batch_data):
    """Compute the causal gradient consistency G_L (gradient-consistency method).

        G_L = <G_0^L, G_k^L> / (||G_0^L|| * ||G_k^L||)

    where G_0^L = dL/dA_0^L and G_k^L = dL/dA_k^L. The objective L is the mean
    of the model's ``'psm'`` output, used as a surrogate for the task loss.

    Each feature is run through ``model`` on its own as a single-agent sample
    with an identity ``pairwise_t_matrix``. Grad mode is force-enabled inside
    this function, so it also works when called under ``torch.no_grad()``.

    Args:
        feature_ego: ego feature, expected shape (C, H, W).
        feature_k: a single feature (C, H, W) or a stack (N, C, H, W). A stack
            is fused by its mean over N before the forward pass.
        loss_fn: unused; kept for interface compatibility.
        model: callable taking a dict with ``spatial_features``, ``record_len``
            and ``pairwise_t_matrix`` and returning a dict containing ``'psm'``.
        batch_data: unused; kept for interface compatibility.

    Returns:
        Cosine similarity of the two gradients, a Python float in [-1, 1].
    """
    device = feature_ego.device

    # Shape (B=1, L=1, L=1, 4, 4): identity transform, because each pass sees one agent.
    identity_t_matrix = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).unsqueeze(0)

    # The defense typically runs at inference time. Under torch.no_grad()
    # autograd.grad would raise, and callers would silently fall back to 0.0.
    with torch.enable_grad():
        ego_var = feature_ego.clone().detach().requires_grad_(True)
        grad_ego = _psm_mean_gradient(model, ego_var, ego_var, identity_t_matrix, device)

        k_var = feature_k.clone().detach().requires_grad_(True)
        # A stack of candidates is fused by averaging. The gradient is still
        # taken w.r.t. the un-fused stack.
        fused_k = k_var.mean(dim=0) if k_var.dim() == 4 else k_var
        grad_k = _psm_mean_gradient(model, fused_k, k_var, identity_t_matrix, device)

    if grad_k.dim() == 4:
        # Each slice receives dL/d(fused) / N. Averaging keeps the direction,
        # which is all the cosine similarity below depends on.
        grad_k = grad_k.mean(dim=0)

    grad_similarity = F.cosine_similarity(
        grad_ego.view(1, -1),
        grad_k.view(1, -1),
        dim=1,
        eps=1e-6,
    ).item()
    return grad_similarity


def compute_ssim(feature_ego, feature_k):
    """Compute a structural similarity (SSIM) score between feature maps (SSIM method).

    Local means, variances and the covariance are computed over an 11x11
    uniform window (``avg_pool2d`` with zero padding, stride 1). The
    per-position SSIM map is averaged over channels and positions, and the
    result is clamped to [-1, 1].

    Args:
        feature_ego: ego feature, expected shape (C, H, W).
        feature_k: a single feature (C, H, W) or a stack (N, C, H, W); a stack
            is fused by its mean over N before computing SSIM.

    Returns:
        SSIM score as a Python float in [-1, 1], where 1 means identical structure.

    Raises:
        ValueError: if the averaged SSIM is NaN/Inf (e.g. non-finite input
            features). Raising lets callers use their fallback; otherwise the
            clamp would turn NaN into 1.0 ("identical").
    """
    fused_feature_k = feature_k.mean(dim=0) if feature_k.dim() == 4 else feature_k

    # Standard SSIM stabilizers (K1=0.01, K2=0.03) with an assumed dynamic range of 1.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    kernel_size = 11
    padding = kernel_size // 2

    def _local_mean(x):
        # Box window instead of the Gaussian window of reference SSIM. Stride 1
        # plus this padding keeps the spatial size unchanged.
        return F.avg_pool2d(x, kernel_size, stride=1, padding=padding)

    f_ego = feature_ego.unsqueeze(0)  # (1, C, H, W)
    f_k = fused_feature_k.unsqueeze(0)  # (1, C, H, W)

    mu_ego = _local_mean(f_ego)
    mu_k = _local_mean(f_k)
    mu_ego_sq = mu_ego ** 2
    mu_k_sq = mu_k ** 2
    mu_ego_k = mu_ego * mu_k

    # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y] per window.
    sigma_ego_sq = _local_mean(f_ego ** 2) - mu_ego_sq
    sigma_k_sq = _local_mean(f_k ** 2) - mu_k_sq
    sigma_ego_k = _local_mean(f_ego * f_k) - mu_ego_k

    numerator = (2 * mu_ego_k + C1) * (2 * sigma_ego_k + C2)
    denominator = (mu_ego_sq + mu_k_sq + C1) * (sigma_ego_sq + sigma_k_sq + C2)
    ssim_map = numerator / (denominator + 1e-8)

    ssim_mean = ssim_map.mean()
    # Python's min()/max() would pass NaN through the clamp as 1.0, reporting
    # corrupted features as structurally identical. Fail loudly instead.
    if not torch.isfinite(ssim_mean):
        raise ValueError(f"SSIM is not finite ({ssim_mean.item()}); check features for NaN/Inf")

    # Clamp to [-1, 1] for consistency with the other similarity terms.
    return max(-1.0, min(1.0, ssim_mean.item()))

def compute_activation_energy_shift(feature_ego, feature_k):
    """Compute the activation energy shift E_L (a distance metric).

        E_L = 1 - exp(-||A_k - A_0||_2 / sigma),   sigma = sqrt(numel(A_0))

    The L2 distance is squashed non-linearly into [0, 1]: 0 means identical
    features, and values approach 1 as the features diverge. Unlike the old
    max-value normalization, this does not assume a bounded feature range.
    sigma grows with the feature size, because larger maps naturally produce
    larger raw distances.

    Args:
        feature_ego: reference (ego) feature tensor.
        feature_k: candidate feature tensor, same shape as ``feature_ego``.

    Returns:
        The normalized shift as a Python float in [0, 1]. Empty features give
        0.0, instead of the NaN a zero sigma would produce.
    """
    numel = feature_ego.numel()
    if numel == 0:
        return 0.0

    # L2 tends to capture energy/magnitude shifts better than L1.
    energy_dist = torch.norm(feature_k - feature_ego, p=2).item()
    sigma = float(np.sqrt(numel))
    energy_shift_normalized = float(1.0 - np.exp(-energy_dist / sigma))

    # Per-call diagnostics: DEBUG level, with lazy formatting, so normal runs are not flooded.
    logger.debug(
        "[DEBUG energy_shift] raw=%.6f, sigma=%.6f, normalized=%.6f",
        energy_dist, sigma, energy_shift_normalized,
    )
    return energy_shift_normalized

# def compute_activation_energy_shift(feature_ego, feature_k):
#     """
#     Compute Activation Energy Shift E_L.
    
#     E_L = ||A_k^L - A_0^L||_1
    
#     Normalized to [0, 1] range.
#     """
#     # Compute L1 distance
#     energy_shift = torch.norm(feature_k - feature_ego, p=1).item()
    
#     # 直接除以理论最大值
#     # 假设特征值范围在 [-1, 1] 或已标准化，最大差异为 2
#     max_possible_shift = 2.0 * feature_ego.numel()
#     energy_shift_normalized = min(1.0, energy_shift / max_possible_shift)
    
#     # Debug output
#     logger.info(f"[DEBUG energy_shift] raw={energy_shift:.6f}, "
#                 f"max_possible={max_possible_shift:.0f}, "
#                 f"normalized={energy_shift_normalized:.6f}")
    
#     return energy_shift_normalized



def compute_group_cps(cav_content, group_indices, reference_feature, model, device,
                      attacker_idx, perturbation, lambda1, lambda2, lambda3, compute_gradients, batch_data,
                      use_ssim=False):
    """Compute the CPS score of one group of agents (the group excludes the ego).

    The group members' spatial features are averaged, and the result is
    compared with the ego-only ``reference_feature`` using three terms:

    * S: cosine similarity, mapped from [-1, 1] to [0, 1] and clamped;
    * G: gradient consistency (or SSIM when ``use_ssim``), mapped the same way;
    * E: activation energy shift, as returned by ``compute_activation_energy_shift``.

        CPS = (lambda1 * (1 - S) + lambda2 * (1 - G) + lambda3 * E) / (lambda1 + lambda2 + lambda3)

    Lower values mean the group is more consistent with the ego view.

    Args:
        cav_content: per-sample content dict (the caller passes ``batch_data['ego']``).
        group_indices: list of agent indices in the group; must not contain the ego.
        reference_feature: ego-only reference feature.
        model: perception model used for feature extraction (and for gradients).
        device: device used for index tensors.
        attacker_idx: attacker index. If it is in the group and ``perturbation``
            is not None, its feature is replaced by ``perturbation``.
        perturbation: adversarial feature, or None.
        lambda1, lambda2, lambda3: weights of the S, G and E terms.
        compute_gradients: whether to compute the G term. If False, G is 0.0.
        batch_data: forwarded to ``compute_gradient_consistency``.
        use_ssim: use SSIM instead of gradient consistency for G
            (True = SSIM, False = gradient consistency).

    Returns:
        (cps_score, fused_feature), where fused_feature is the mean of the group
        members' features. It does NOT include the ego. An empty group returns
        ``(float('inf'), None)``.
    """
    if len(group_indices) == 0:
        return float('inf'), None

    # Prepare features for the agents under test only (ego deliberately excluded).
    voxel_feature_dict = _prepare_subset_features_with_grad(cav_content, group_indices, model, device)

    # Apply the perturbation if the attacker is in this group.
    if attacker_idx in group_indices and perturbation is not None:
        local_idx = group_indices.index(attacker_idx)
        voxel_feature_dict['spatial_features'][local_idx] = perturbation

    # Fuse the group's features (ego not included).
    candidate_features = voxel_feature_dict['spatial_features']  # Shape: (num_candidates, C, H, W)
    fused_feature = candidate_features.mean(dim=0)  # Shape: (C, H, W)

    # Compare the fused group feature with the ego feature.
    similarity_score = compute_similarity_score(fused_feature, reference_feature)

    # When G is not computed it stays 0.0. That maps to a neutral 0.5 below,
    # adding a constant 0.5 * lambda2 to every group's weighted sum.
    gradient_consistency = 0.0
    if compute_gradients:
        try:
            if use_ssim:
                # SSIM method: structural similarity.
                logger.info("[CPS] 使用SSIM方法计算特征一致性")
                gradient_consistency = compute_ssim(reference_feature, candidate_features)
            else:
                # Gradient-consistency method.
                logger.info("[CPS] 使用梯度一致性方法")
                gradient_consistency = compute_gradient_consistency(
                    reference_feature, candidate_features, None, model, batch_data
                )
        except Exception as e:
            # Best-effort term: fall back to the neutral value rather than abort screening.
            logger.warning(f"[WARNING] Gradient/SSIM computation failed: {e}")
            gradient_consistency = 0.0

    # Energy shift between the fused group feature and the ego feature.
    energy_shift_normalized = compute_activation_energy_shift(reference_feature, fused_feature)

    logger.info(f"[DEBUG CPS] before normalization: similarity_score={similarity_score:.6f}, gradient_consistency={gradient_consistency:.6f}, energy_shift_normalized={energy_shift_normalized:.6f}")
    # Map both cosine-style scores from [-1, 1] to [0, 1].
    similarity_score = (similarity_score + 1.0) / 2.0
    gradient_score = (gradient_consistency + 1.0) / 2.0
    logger.info(f"[DEBUG CPS] after normalization: similarity_score={similarity_score:.6f}, gradient_score={gradient_score:.6f}")
    # Clamp for numerical safety.
    similarity_score = max(0.0, min(1.0, similarity_score))
    gradient_score = max(0.0, min(1.0, gradient_score))
    logger.info(f"[DEBUG CPS] after clamping: similarity_score={similarity_score:.6f}, gradient_score={gradient_score:.6f}")

    sim_component = lambda1 * (1.0 - similarity_score)
    grad_component = lambda2 * (1.0 - gradient_score)
    energy_component = lambda3 * energy_shift_normalized

    # Weighted average of the three dissimilarity terms.
    cps_score = (sim_component + grad_component + energy_component) / (lambda1 + lambda2 + lambda3)

    logger.info(f"[DEBUG CPS] sim_score={similarity_score:.6f}, grad_cons={gradient_score:.6f}, "
                f"energy_norm={energy_shift_normalized:.6f}")
    logger.info(f"[DEBUG CPS] components: sim={sim_component:.6f}, grad={grad_component:.6f}, "
                f"energy={energy_component:.6f}, total={cps_score:.6f}")

    return cps_score, fused_feature


def recursive_group_screening(cav_content, candidate_agents, reference_feature, model, device,
                              attacker_idx, perturbation, lambda1, lambda2, lambda3, tau,
                              compute_gradients, batch_data, diversity_matrix, max_depth=4, depth=0,
                              use_ssim=False):
    """Recursively split candidate agents into groups and return the benign ones.

    Terminal cases:
    * a single agent, or
    * a multi-agent group at ``depth >= max_depth``.

    In a terminal case the whole set is scored with ``compute_group_cps`` and
    kept only if its CPS is below ``tau``.

    Otherwise the candidates (together with the ego at sub-index 0) are split
    into two groups by ``isomorphic_viewpoint_grouping`` on the viewpoint
    diversity sub-matrix. Each non-empty group is then scored:
    * groups scoring below ``tau`` are accepted whole;
    * suspicious groups with more than one member are screened again one
      level deeper;
    * suspicious single agents are dropped.

    Args:
        candidate_agents: list of agent indices to screen (ego excluded).
        reference_feature: ego-only reference feature.
        tau: CPS threshold; scores strictly below it are benign.
        diversity_matrix: agent-by-agent viewpoint diversity matrix, indexed
            by original agent indices.
        max_depth: depth at which groups are judged whole instead of split.
        depth: current recursion depth.
        use_ssim: use SSIM instead of gradient consistency.
        (The remaining arguments are forwarded to ``compute_group_cps``.)

    Returns:
        List of agent indices judged benign, in Group1-then-Group2 order.
    """
    def _score(agents):
        # CPS of a subset; the fused feature is not needed here.
        score, _ = compute_group_cps(
            cav_content, agents, reference_feature, model, device,
            attacker_idx, perturbation, lambda1, lambda2, lambda3, compute_gradients, batch_data,
            use_ssim=use_ssim
        )
        return score

    count = len(candidate_agents)
    if count == 0:
        return []

    if count == 1 or depth >= max_depth:
        # Terminal case: judge the whole set at once.
        score = _score(candidate_agents)
        subject = f"Agent {candidate_agents[0]}" if count == 1 else f"Group {candidate_agents}"
        benign = score < tau
        verdict = "benign" if benign else "suspicious"
        logger.info(f"[MDAG] Depth {depth}: {subject} is {verdict} (CPS: {score:.4f})")
        return candidate_agents if benign else []

    from defense.mdag_grouping import isomorphic_viewpoint_grouping

    # Sub-index 0 is the ego (needed for viewpoint diversity); 1.. are the candidates.
    sub_indices = [0] + candidate_agents
    sub_diversity = diversity_matrix[sub_indices, :][:, sub_indices]
    local_group1, local_group2 = isomorphic_viewpoint_grouping(
        sub_diversity, len(sub_indices), ego_idx=0, exclude_ego=True
    )

    def _to_global(local_ids):
        # Map sub-indices back to agent ids, dropping the ego and any out-of-range entry.
        return [sub_indices[i] for i in local_ids if 0 < i < len(sub_indices)]

    group1 = _to_global(local_group1)
    group2 = _to_global(local_group2)
    logger.info(f"[MDAG] Depth {depth}: Grouping {candidate_agents} -> Group1: {group1}, Group2: {group2}")

    benign_agents = []
    for label, group in (("Group1", group1), ("Group2", group2)):
        if len(group) == 0:
            continue

        score = _score(group)
        logger.info(f"[MDAG] Depth {depth}: {label} {group} CPS: {score:.4f} (threshold: {tau:.4f})")

        if score < tau:
            logger.info(f"[MDAG] Depth {depth}: {label} {group} is benign, adding to benign list")
            benign_agents.extend(group)
            continue

        logger.info(f"[MDAG] Depth {depth}: {label} {group} is suspicious, recursing...")
        if len(group) <= 1:
            logger.info(f"[MDAG] Depth {depth}: {label} {group} is suspicious, but only one vehicle, skipping...")
            continue

        benign_agents.extend(recursive_group_screening(
            cav_content, group, reference_feature, model, device,
            attacker_idx, perturbation, lambda1, lambda2, lambda3, tau,
            compute_gradients, batch_data, diversity_matrix, max_depth, depth + 1,
            use_ssim=use_ssim
        ))

    return benign_agents


def cps_defense(batch_data, model, dataset, perturbation, attacker_idx=1, sampling_budget=10,
                lambda1=1.0, lambda2=1.0, lambda3=1.0, tau=0.68, compute_gradients=False,
                use_mdag=False, use_dynamic_threshold=False, threshold_sensitivity=1.0,
                threshold_calculator=None, use_ssim=False):
    """
    Comprehensive Protection Score (CPS) defense combining three metrics:
    1. Feature similarity (S_L): cosine similarity between features
    2. Gradient consistency (G_L): causal gradient consistency OR SSIM structural similarity
    3. Activation energy shift (E_L): L1 distance between features
    
    CPS(k) = Σ_L α_L (λ1(1 - S_L) + λ2(1 - G_L) + λ3 E_L)
    
    Args:
        batch_data: Input batch data
        model: The cooperative perception model
        dataset: Dataset for post-processing
        perturbation: Adversarial perturbation
        attacker_idx: Index of the attacker
        sampling_budget: Number of sampling iterations
        lambda1: Weight for similarity score
        lambda2: Weight for gradient consistency (or SSIM)
        lambda3: Weight for activation energy shift
        tau: Threshold for CPS score (used if use_dynamic_threshold=False)
        compute_gradients: Whether to compute gradient consistency or SSIM (expensive)
        use_mdag: Whether to use MDAG grouping instead of RoboSAC sampling
        use_dynamic_threshold: Whether to use dynamic adaptive threshold
        threshold_sensitivity: Sensitivity factor for dynamic threshold computation
        threshold_calculator: AdaptiveThresholdCalculator instance for temporal tracking
        use_ssim: 是否使用SSIM代替梯度一致性 (True=使用SSIM, False=使用梯度一致性)
    
    Returns:
        pred_box_tensor, pred_score, gt_box_tensor, best_cps_score, defense_info
        where defense_info is a dict containing:
            - 'threshold': dynamic threshold used
            - 'benign_agents': list of benign agent indices
            - 'malicious_agents': list of malicious agent indices
    """
    cav_content = batch_data['ego']
    agent_num = cav_content['cav_num']
    device = perturbation.device if perturbation is not None else cav_content['processed_lidar']['voxel_features'].device

    # Reference feature using only the ego vehicle (A0)
    base_voxel_dict = _prepare_subset_features_with_grad(cav_content, [0], model, device)
    reference_feature = base_voxel_dict['spatial_features'][0]
    
    # Initialize defense info dictionary
    defense_info = {
        'threshold': tau,  # Will be updated if dynamic threshold is used
        'benign_agents': [],
        'malicious_agents': [],
        'mu': None,  # Mean of CPS scores in the window
        'sigma': None  # Standard deviation of CPS scores in the window
    }
    is_ego_only = False

    # Determine grouping strategy: MDAG or RoboSAC
    if use_mdag:
        # MDAG grouping strategy with recursive screening
        # First, get all features for position and feature extraction
        all_cav_idx = list(range(agent_num))
        all_voxel_dict = _prepare_subset_features_with_grad(cav_content, all_cav_idx, model, device)
        all_features = all_voxel_dict['spatial_features']
        
        # Extract positions from pairwise transformation matrix
        # Only use the first agent_num positions (actual number of agents)
        pairwise_t_matrix = cav_content['pairwise_t_matrix'][0][:agent_num, :agent_num]
        positions = extract_positions_from_pairwise_matrix(pairwise_t_matrix)
        
        # Compute viewpoint diversity matrix
        from defense.mdag_grouping import compute_viewpoint_diversity
        diversity_matrix = compute_viewpoint_diversity(positions)
        
        # Get current threshold for this frame
        # If using dynamic threshold and calculator exists, use its current threshold
        # Otherwise use base threshold
        if use_dynamic_threshold and threshold_calculator is not None:
            # Check if using UA-AT or original calculator
            if isinstance(threshold_calculator, UncertaintyAwareAdaptiveThreshold):
                # UA-AT: use T_final if available, otherwise base_threshold
                tau = threshold_calculator.T_final if threshold_calculator.T_final is not None else threshold_calculator.base_threshold
                stats = threshold_calculator.get_statistics()
                defense_info['mu'] = stats.get('mu_t')
                defense_info['sigma'] = stats.get('sigma_t')
                defense_info['T_stat'] = stats.get('T_stat')
                defense_info['U_ego'] = stats.get('U_ego')
                logger.info(f"[MDAG UA-AT] Using threshold: {tau:.4f}, U_ego={stats.get('U_ego')}")
            else:
                # Original calculator
                tau = threshold_calculator.current_threshold
                stats = threshold_calculator.get_statistics()
                defense_info['mu'] = stats['mu']
                defense_info['sigma'] = stats['sigma']
                logger.info(f"[MDAG] Using dynamic threshold from calculator: {tau:.4f}, mu={stats['mu']}, sigma={stats['sigma']}")
        else:
            # Use base threshold for first frame or when dynamic threshold is disabled
            logger.info(f"[MDAG] Using base threshold: {tau:.4f}")
            # When not using dynamic threshold, mu and sigma are None
            defense_info['mu'] = None
            defense_info['sigma'] = None
        
        # Update defense info with the threshold used
        defense_info['threshold'] = tau
        
        logger.info(f"[MDAG] Agent num: {agent_num}, Positions shape: {positions.shape}, Features shape: {all_features.shape}")
        
        # 获取所有非 ego 车辆作为初始候选
        candidate_agents = list(range(1, agent_num))  # 不包含 ego (index 0)
        
        logger.info(f"[MDAG] Starting recursive screening for agents: {candidate_agents}")
        if use_ssim:
            logger.info(f"[MDAG] 【使用SSIM方法】进行特征一致性计算")
        else:
            logger.info(f"[MDAG] 【使用梯度一致性方法】进行特征一致性计算")
        
        # 递归分组筛查，找出所有良性车辆
        benign_agents = recursive_group_screening(
            cav_content, candidate_agents, reference_feature, model, device,
            attacker_idx, perturbation, lambda1, lambda2, lambda3, tau,
            compute_gradients, batch_data, diversity_matrix, max_depth=3, depth=0,
            use_ssim=use_ssim
        )
        
        logger.info(f"[MDAG] Final benign agents: {benign_agents}")
        
        # 记录良性和恶意车辆
        defense_info['benign_agents'] = benign_agents
        # 恶意车辆 = 所有候选车辆 - 良性车辆
        defense_info['malicious_agents'] = [agent for agent in candidate_agents if agent not in benign_agents]
        logger.info(f"[MDAG] Malicious agents: {defense_info['malicious_agents']}")
        
        # 最终融合：ego + 所有筛查出的良性车辆（用于最终检测）
        # 但是计算 CPS score 时，只比较良性车辆和 ego
        if len(benign_agents) > 0:
            final_group = [0] + benign_agents  # ego + 良性车辆（用于最终融合检测）
            logger.info(f"[MDAG] Final fusion group for detection: {final_group}")
            
            # 准备良性车辆的特征（不包含 ego）用于 CPS 计算
            benign_voxel_dict = _prepare_subset_features_with_grad(cav_content, benign_agents, model, device)
            
            # Apply perturbation if attacker is in benign agents
            if attacker_idx in benign_agents and perturbation is not None:
                local_idx = benign_agents.index(attacker_idx)
                benign_voxel_dict['spatial_features'][local_idx] = perturbation
            
            # Compute final CPS score: 比较良性车辆融合后的特征和 ego 特征
            benign_features = benign_voxel_dict['spatial_features']  # Shape: (num_benign, C, H, W)
            fused_benign_feature = benign_features.mean(dim=0)  # Shape: (C, H, W)
            similarity_score = compute_similarity_score(fused_benign_feature, reference_feature)
            
            if compute_gradients:
                try:
                    if use_ssim:
                        # 【使用SSIM方法】计算结构相似性
                        logger.info(f"[MDAG Final] 使用SSIM方法计算特征一致性")
                        gradient_consistency = compute_ssim(reference_feature, benign_features)
                    else:
                        # 【使用梯度一致性方法】计算梯度一致性
                        logger.info(f"[MDAG Final] 使用梯度一致性方法")
                        gradient_consistency = compute_gradient_consistency(
                            reference_feature, benign_features, None, model, batch_data
                        )
                except Exception as e:
                    logger.warning(f"[WARNING] Gradient/SSIM computation failed: {e}")
                    gradient_consistency = 0.0
            else:
                gradient_consistency = 0.0
            
            # 计算能量偏移：比较良性车辆融合后的特征和 ego 特征
            energy_shift_normalized = compute_activation_energy_shift(reference_feature, fused_benign_feature)
            
            # 准备最终融合用的特征（ego + 良性车辆）
            final_voxel_dict = _prepare_subset_features_with_grad(cav_content, final_group, model, device)
            if attacker_idx in final_group and perturbation is not None:
                local_idx = final_group.index(attacker_idx)
                final_voxel_dict['spatial_features'][local_idx] = perturbation
            best_subset = final_voxel_dict
            

            similarity_score = (similarity_score + 1.0) / 2.0
            gradient_score = (gradient_consistency + 1.0) / 2.0
            # print(f"[DEBUG CPS] after normalization: similarity_score={similarity_score:.6f}, gradient_score={gradient_score:.6f}")
            # 3. 确保数值安全
            similarity_score = max(0.0, min(1.0, similarity_score))
            gradient_score = max(0.0, min(1.0, gradient_score))
            # print(f"[DEBUG CPS] after clamping: similarity_score={similarity_score:.6f}, gradient_score={gradient_score:.6f}")
            # Compute CPS components
            sim_component = lambda1 * (1.0 - similarity_score)
            grad_component = lambda2 * (1.0 - gradient_score)
            energy_component = lambda3 * energy_shift_normalized
            
            best_cps_score = (sim_component + grad_component + energy_component) / (lambda1+lambda2+lambda3)
            
            logger.info(f"[MDAG] Final fusion - sim={similarity_score:.4f}, grad={gradient_score:.4f}, energy={energy_shift_normalized:.4f}")
            logger.info(f"[MDAG] Final fusion - components: sim={sim_component:.4f}, grad={grad_component:.4f}, energy={energy_component:.4f}")
            logger.info(f"[MDAG] Final fusion CPS score: {best_cps_score:.4f}")
        else:
            # 没有找到良性车辆，使用 ego-only
            logger.info(f"[MDAG] No benign agents found, using ego-only")
            best_subset = base_voxel_dict
            best_cps_score = 0.0
            is_ego_only = True
        
        # New recursive screening logic is complete, skip to output
        pass
    else:
        # Original RoboSAC sampling strategy
        # Get current threshold for this frame
        if use_dynamic_threshold and threshold_calculator is not None:
            # Check if using UA-AT or original calculator
            if isinstance(threshold_calculator, UncertaintyAwareAdaptiveThreshold):
                # UA-AT: use T_final if available, otherwise base_threshold
                tau = threshold_calculator.T_final if threshold_calculator.T_final is not None else threshold_calculator.base_threshold
                stats = threshold_calculator.get_statistics()
                defense_info['mu'] = stats.get('mu_t')
                defense_info['sigma'] = stats.get('sigma_t')
                defense_info['T_stat'] = stats.get('T_stat')
                defense_info['U_ego'] = stats.get('U_ego')
                logger.info(f"[RoboSAC UA-AT] Using threshold: {tau:.4f}, U_ego={stats.get('U_ego')}")
            else:
                # Original calculator
                tau = threshold_calculator.current_threshold
                stats = threshold_calculator.get_statistics()
                defense_info['mu'] = stats['mu']
                defense_info['sigma'] = stats['sigma']
                logger.info(f"[RoboSAC] Using dynamic threshold from calculator: {tau:.4f}, mu={stats['mu']}, sigma={stats['sigma']}")
        else:
            # Use base threshold for first frame or when dynamic threshold is disabled
            logger.info(f"[RoboSAC] Using base threshold: {tau:.4f}")
            # When not using dynamic threshold, mu and sigma are None
            defense_info['mu'] = None
            defense_info['sigma'] = None
        
        # Update defense info with the threshold used
        defense_info['threshold'] = tau
        
        if use_ssim:
            logger.info(f"[RoboSAC] 【使用SSIM方法】进行特征一致性计算")
        else:
            logger.info(f"[RoboSAC] 【使用梯度一致性方法】进行特征一致性计算")
        
        s = cal_robosac_consensus(agent_num, sampling_budget, num_attackers=1)
        s = max(1, min(s, agent_num - 1))
        
        # Initialize with ego-only as baseline (safe fallback)
        best_subset = base_voxel_dict
        best_cps_score = 0.0  # Ego compared to itself has CPS = 0
        best_cav_idx = [0]  # Track the best agent combination

        iterations = 0
        while iterations < sampling_budget and agent_num > 1:
            iterations += 1
            sampled_agents = sample_agents(agent_num, s)
            cav_idx = sorted(set([0] + sampled_agents))

            voxel_feature_dict = _prepare_subset_features_with_grad(cav_content, cav_idx, model, device)

            # Apply perturbation if attacker is in the subset
            if attacker_idx in cav_idx and perturbation is not None:
                local_idx = cav_idx.index(attacker_idx)
                voxel_feature_dict['spatial_features'][local_idx] = perturbation

            # Compute CPS score for each layer (here we use single layer for simplicity)
            # In multi-layer case, you would sum over all layers with α_L weights
            
            fused_feature = voxel_feature_dict['spatial_features'].mean(dim=0)
            
            # 1. Similarity score (S_L): higher is more similar
            similarity_score = compute_similarity_score(fused_feature, reference_feature)
            
            # 2. Gradient consistency (G_L) or SSIM: higher is more consistent
            if compute_gradients:
                try:
                    if use_ssim:
                        # 【使用SSIM方法】计算结构相似性
                        gradient_consistency = compute_ssim(reference_feature, fused_feature)
                    else:
                        # 【使用梯度一致性方法】
                        gradient_consistency = compute_gradient_consistency(
                            reference_feature, fused_feature, None, model, batch_data
                        )
                except:
                    # Fallback if gradient/SSIM computation fails
                    gradient_consistency = 0.0
            else:
                # Skip gradient/SSIM computation for efficiency
                gradient_consistency = 0.0
            
            # 3. Activation energy shift (E_L): lower is more similar
            energy_shift_normalized = compute_activation_energy_shift(reference_feature, fused_feature)
            
    
            logger.info(f"[DEBUG CPS] before normalization: similarity_score={similarity_score:.6f}, gradient_consistency={gradient_consistency:.6f}, energy_shift_normalized={energy_shift_normalized:.6f}")
            # 2. 归一化到[0,1]
            similarity_score = (similarity_score + 1.0) / 2.0
            gradient_score = (gradient_consistency + 1.0) / 2.0
            logger.info(f"[DEBUG CPS] after normalization: similarity_score={similarity_score:.6f}, gradient_score={gradient_score:.6f}")
            # 3. 确保数值安全
            similarity_score = max(0.0, min(1.0, similarity_score))
            gradient_score = max(0.0, min(1.0, gradient_score))
            logger.info(f"[DEBUG CPS] after clamping: similarity_score={similarity_score:.6f}, gradient_score={gradient_score:.6f}")
        
            # Compute CPS score (lower is better)
            # CPS(k) = λ1(1 - S_L) + λ2(1 - G_L) + λ3 E_L
            # Note: We want LOW CPS for trustworthy agents
            sim_component = lambda1 * (1.0 - similarity_score)
            grad_component = lambda2 * (1.0 - gradient_score)
            energy_component = lambda3 * energy_shift_normalized
            
            cps_score = (sim_component + grad_component + energy_component) / (lambda1+lambda2+lambda3)
    
            logger.info(f"[RoboSAC] Iteration {iterations} - agents {cav_idx}")
            logger.info(f"[RoboSAC] sim={similarity_score:.4f}, grad={gradient_score:.4f}, energy={energy_shift_normalized:.4f}")
            logger.info(f"[RoboSAC] components: sim={sim_component:.4f}, grad={grad_component:.4f}, energy={energy_component:.4f}, total={cps_score:.4f}")
            
            # Select subset with lowest CPS score (most trustworthy)
            # Only replace ego-only baseline if this subset is significantly better
            if cps_score < best_cps_score:
                best_cps_score = cps_score
                best_subset = voxel_feature_dict
                best_cav_idx = cav_idx  # Track the best agent combination
            
            # If CPS score exceeds threshold tau, this subset is suspicious
            # We keep the ego-only baseline in that case
        
        # Update defense_info for RoboSAC
        # best_cav_idx contains the selected agents (including ego at index 0)
        # benign agents are all non-ego agents in best_cav_idx
        defense_info['benign_agents'] = [idx for idx in best_cav_idx if idx != 0]
        # malicious agents are all other non-ego agents
        all_agents = list(range(1, agent_num))
        defense_info['malicious_agents'] = [idx for idx in all_agents if idx not in best_cav_idx]
        
        # 检查是否为ego-only
        if len(best_cav_idx) == 1 and best_cav_idx[0] == 0:
            is_ego_only = True
            best_cps_score = 0.0
            logger.info(f"[RoboSAC] Fallback to ego-only (CPS will be skipped in statistics)")
        else:
            is_ego_only = False
        
        logger.info(f"[RoboSAC] Final selected agents: {best_cav_idx}")
        logger.info(f"[RoboSAC] Benign agents: {defense_info['benign_agents']}")
        logger.info(f"[RoboSAC] Malicious agents: {defense_info['malicious_agents']}")

    # Check if best CPS score exceeds threshold
    # If CPS(k) > tau, the k-th agent is suspicious
    # Fallback to ego-only if no good subset found or all subsets are suspicious
    if best_subset is None or best_cps_score > tau:
        best_subset = base_voxel_dict
        best_cps_score = 0.0  # Ego-only baseline
        is_ego_only = True  # Mark as ego-only fallback
    
    # Update threshold calculator with final fusion CPS score for next frame
    if use_dynamic_threshold and threshold_calculator is not None:
        try:
            # Check if using UA-AT or original calculator
            is_ua_at = isinstance(threshold_calculator, UncertaintyAwareAdaptiveThreshold)
            
            if is_ua_at:
                # UA-AT requires ego_features parameter
                # Use reference_feature which is the ego feature (already computed at the beginning)
                # reference_feature shape: [C, H, W]
                if reference_feature is None:
                    logger.error("[CPS Defense UA-AT] ERROR: reference_feature is None, cannot update UA-AT")
                else:
                    ego_features = reference_feature
                    
                    # Check if malicious agents were detected
                    detected_malicious = len(defense_info.get('malicious_agents', [])) > 0
                    
                    if is_ego_only:
                        # ego-only case: skip CPS=0 from statistics
                        threshold_calculator.update(
                            current_cps_scores=[],
                            ego_features=ego_features,
                            is_ego_only=True,
                            detected_malicious=False,
                            current_threshold=defense_info.get('threshold', tau)
                        )
                        logger.info(f"[CPS Defense UA-AT] Skipped updating (ego-only)")
                    else:
                        # Normal case: update with CPS score and ego features
                        # 传入当前阈值，用于判断CPS是否应该加入良性窗口（CPS越大越恶意）
                        threshold_calculator.update(
                            current_cps_scores=[best_cps_score],
                            ego_features=ego_features,
                            is_ego_only=False,
                            detected_malicious=detected_malicious,
                            current_threshold=defense_info.get('threshold', tau)
                        )
                        # Update defense_info with latest statistics
                        stats = threshold_calculator.get_statistics()
                        defense_info['mu'] = stats.get('mu_t')
                        defense_info['sigma'] = stats.get('sigma_t')
                        defense_info['T_stat'] = stats.get('T_stat')
                        defense_info['U_ego'] = stats.get('U_ego')
                        defense_info['T_final'] = stats.get('T_final')
                        logger.info(f"[CPS Defense UA-AT] Updated: CPS={best_cps_score:.4f}, "
                                   f"U_ego={stats.get('U_ego'):.4f}, T_final={stats.get('T_final'):.4f}")
            else:
                # Original AdaptiveThresholdCalculator
                if is_ego_only:
                    threshold_calculator.update([], is_ego_only=True)
                    logger.info(f"[CPS Defense] Skipped updating threshold calculator (ego-only)")
                else:
                    threshold_calculator.update([best_cps_score], is_ego_only=False)
                    # Update defense_info with latest statistics after update
                    stats = threshold_calculator.get_statistics()
                    defense_info['mu'] = stats['mu']
                    defense_info['sigma'] = stats['sigma']
                    logger.info(f"[CPS Defense] Updated threshold calculator with final CPS score: {best_cps_score:.4f}, mu={stats['mu']}, sigma={stats['sigma']}")
        except Exception as e:
            logger.error(f"[CPS Defense] ERROR updating threshold calculator: {e}")
            logger.error(f"[CPS Defense] Threshold calculator type: {type(threshold_calculator)}")
            logger.error(f"[CPS Defense] Reference feature type: {type(reference_feature) if 'reference_feature' in locals() else 'Not defined'}")
            import traceback
            logger.error(traceback.format_exc())
    
    output_dict = OrderedDict()
    output_dict['ego'] = model(best_subset)

    pred_box_tensor, pred_score, gt_box_tensor = dataset.post_process(batch_data, output_dict)

    return pred_box_tensor, pred_score, gt_box_tensor, best_cps_score, defense_info


def cps_defense_multilayer(batch_data, model, dataset, perturbation, attacker_idx=1,
                           sampling_budget=10, lambda1=1.0, lambda2=1.0, lambda3=1.0,
                           tau=0.5, layer_weights=None):
    """
    Multi-layer version of the CPS defense (not implemented yet).

    The intended score sums the single-layer CPS terms over several feature
    layers, each weighted by alpha_L:

        CPS(k) = Σ_L α_L (λ1(1 - S_L) + λ2(1 - G_L) + λ3 E_L)

    Here S_L is the similarity term, G_L the gradient-consistency term and
    E_L the activation-energy-shift term for layer L.

    Args:
        batch_data, model, dataset, perturbation, attacker_idx,
        sampling_budget, lambda1, lambda2, lambda3, tau:
            Presumably mirror the parameters of the single-layer CPS defense
            in this module. TODO confirm once implemented.
        layer_weights: Dict mapping layer names to weights alpha_L.

    Raises:
        NotImplementedError: Always. Extracting features from multiple layers
            and scoring each one is still to be written. Raising (rather than
            silently returning None) keeps callers that unpack the result
            from failing later with a confusing TypeError.
    """
    # TODO: Implement the multi-layer version if needed. This requires
    # extracting features from multiple layers and computing the CPS score
    # for each layer with its corresponding weight alpha_L.
    raise NotImplementedError(
        "cps_defense_multilayer is not implemented yet; "
        "use the single-layer CPS defense instead."
    )

