"""
MDAG (Maximum Geometric Tension) Grouping Strategy
基于视角差异度的正交对抗分组策略

Implementation based on the paper's methodology:
1. Spatial Grid & Candidate Pool
2. Viewpoint Diversity Matrix
3. Orthogonal Adversarial Grouping
4. Fusion & Discrepancy Amplification
"""

import torch
import numpy as np
import torch.nn.functional as F
import logging
from typing import List, Tuple, Dict

# 获取logger实例（使用root logger，这样可以使用在主文件中配置的handlers）
logger = logging.getLogger()


def compute_viewpoint_diversity(positions: torch.Tensor, center_point: torch.Tensor = None) -> torch.Tensor:
    """
    计算视角差异度矩阵 (Viewpoint Diversity Matrix)
    
    对于列表 L_i 中的任意两辆车 u 和 v，计算它们对网格中心点 P_i 的视角差异度 θ_u,v:
    θ_u,v = arccos((V_u→P · V_v→P) / (|V_u→P||V_v→P|))
    
    Args:
        positions: Tensor of shape (n, 2) containing (x, y) positions of vehicles
        center_point: Tensor of shape (2,) containing the center point P. 
                     If None, uses the centroid of all vehicle positions.
        
    Returns:
        diversity_matrix: Tensor of shape (n, n) containing pairwise viewpoint diversity angles
    """
    n = positions.shape[0]
    diversity_matrix = torch.zeros(n, n, device=positions.device)
    
    # If no center point specified, use the centroid of all positions
    if center_point is None:
        center_point = positions.mean(dim=0)
    
    # Compute vectors from each vehicle TO the center point
    vectors_to_center = center_point.unsqueeze(0) - positions  # Shape: (n, 2)
    
    # Normalize vectors
    norms = torch.norm(vectors_to_center, dim=1, keepdim=True)  # Shape: (n, 1)
    vectors_to_center_normalized = vectors_to_center / (norms + 1e-6)  # Shape: (n, 2)
    
    # Compute pairwise angles between viewing vectors
    for i in range(n):
        for j in range(i + 1, n):
            # Get normalized vectors from vehicle i and j to center point
            vec_i = vectors_to_center_normalized[i]  # Shape: (2,)
            vec_j = vectors_to_center_normalized[j]  # Shape: (2,)
            
            # Check if either vehicle is too close to center (degenerate case)
            if norms[i] > 1e-6 and norms[j] > 1e-6:
                # Compute cosine of angle between the two viewing directions
                cos_theta = torch.dot(vec_i, vec_j)
                cos_theta = torch.clamp(cos_theta, -1.0, 1.0)
                
                # Compute angle in radians
                theta = torch.acos(cos_theta)
                
                diversity_matrix[i, j] = theta
                diversity_matrix[j, i] = theta
    
    return diversity_matrix


def isomorphic_viewpoint_grouping(diversity_matrix: torch.Tensor, agent_num: int, 
                                ego_idx: int = 0, exclude_ego: bool = True) -> Tuple[List[int], List[int]]:
    """
    同构视野分组 (Isomorphic Viewpoint Grouping, IVG)
    [ICML/CVPR Recommended Strategy for Defense]

    Paper Motivation:
    "To distinguish between 'true blind spots' and 'malicious attacks', we must ensure 
    that the two verification groups possess statistically equivalent Fields of View (FoV). 
    Random or Maximum-Diversity grouping creates coverage imbalance, leading to false positives."

    Logic:
    1. Sort all neighbors based on their angular position relative to the Ego.
    2. Perform Interleaved Assignment (Stratified Sampling).
    
    Args:
        diversity_matrix: (N, N) Tensor, diversity_matrix[i, j] is the angle diff.
        agent_num: Total number of agents in the scene.
        ego_idx: Index of the ego vehicle.
        exclude_ego: If True, Ego acts as the verifier and is not in any group.
        
    Returns:
        group1: List[int], Indices for Verification Group A.
        group2: List[int], Indices for Verification Group B.
    """
    
    # 1. 筛选候选人 (Candidates)
    if exclude_ego:
        candidates = [i for i in range(agent_num) if i != ego_idx]
    else:
        # 在防御任务中，通常 Ego 必须作为独立的裁判，不建议 Ego 入组
        # 但为了兼容性保留此分支
        candidates = list(range(agent_num))
        
    num_candidates = len(candidates)
    
    # 边界情况处理 (Edge Cases)
    if num_candidates == 0:
        return [], []
    if num_candidates == 1:
        # 只有一个人，无法进行双盲对比。
        # 策略：把它放到 Group 1，Group 2 为空。上层逻辑检测到 Group 2 为空时，
        # 只能降级为单车检查 (Trust or Verify with History)。
        return [candidates[0]], []

    # 2. 几何排序 (Geometric Sorting)
    # 获取所有候选人相对于 Ego 的角度偏差
    # diversity_matrix[ego_idx, k] 代表 k 号车和 Ego 的视角夹角
    # 我们不仅需要大小，最好是有符号的角度（如果矩阵只有绝对值也没关系，按差异大小排也可以）
    # 假设 diversity_matrix 存的是 0~pi 的差异值
    angle_diffs = diversity_matrix[ego_idx, candidates]
    
    # argsort 返回的是排序后的索引
    # 这一步保证了相邻的 indices 代表空间位置接近的车辆
    sorted_indices = torch.argsort(angle_diffs)
    
    group1 = []
    group2 = []
    
    # 3. 交错分配 (Interleaved Assignment)
    # 就像发牌一样：G1 一张，G2 一张，G1 一张...
    # 这样能保证 Group 1 和 Group 2 在“近处、远处、左边、右边”都有分布
    for i in range(num_candidates):
        # 找到原始的 agent ID
        idx_in_candidates = sorted_indices[i]
        agent_id = candidates[idx_in_candidates]
        
        if i % 2 == 0:
            group1.append(agent_id)
        else:
            group2.append(agent_id)
            
    return group1, group2

def orthogonal_grouping(diversity_matrix: torch.Tensor, agent_num: int, 
                       ego_idx: int = 0, exclude_ego: bool = False) -> Tuple[List[int], List[int]]:
    """
    正交对抗分组 (Orthogonal Adversarial Grouping)
    
    贪心策略：在每个网格的候选人中，优先选取 θ 接近 90° (正交) 或 180° (对向) 的车辆配对成组。
    
    Group 1 (强对抗组): 包含视角差异最大的车辆 (例如: Ego + 侧车)。这组最难欺骗。
    Group 2 (弱对抗组): 剩余的、视角较为平行的车辆 (例如: Ego + 前车)。作为补充。
    
    Args:
        diversity_matrix: Tensor of shape (n, n) containing viewpoint diversity angles
        agent_num: Total number of agents
        ego_idx: Index of ego vehicle (default: 0)
        exclude_ego: If True, groups will not include ego (for recursive grouping)
        
    Returns:
        group1: List of agent indices with maximum diversity (strong adversarial group)
        group2: List of agent indices with remaining agents (weak adversarial group)
    """
    n = diversity_matrix.shape[0]
    
    # Target angles for orthogonal grouping: 90° (π/2) or 180° (π)
    target_orthogonal = np.pi / 2  # 90 degrees
    target_opposite = np.pi  # 180 degrees
    
    # Compute how close each angle is to orthogonal or opposite
    # Score: higher is better (closer to 90° or 180°)
    scores = torch.zeros_like(diversity_matrix)
    for i in range(n):
        for j in range(n):
            if i != j:
                angle = diversity_matrix[i, j]
                # Distance to 90° or 180°
                dist_to_orthogonal = torch.abs(angle - target_orthogonal)
                dist_to_opposite = torch.abs(angle - target_opposite)
                # Use minimum distance (closer to either 90° or 180°)
                min_dist = torch.min(dist_to_orthogonal, dist_to_opposite)
                # Convert to score (smaller distance = higher score)
                scores[i, j] = 1.0 / (min_dist + 1e-6)
    
    if exclude_ego:
        # 新策略：分组时不包含 ego，只对其他车辆分组
        non_ego_agents = [i for i in range(agent_num) if i != ego_idx]
        
        if len(non_ego_agents) == 0:
            return [], []
        elif len(non_ego_agents) == 1:
            return [non_ego_agents[0]], []
        elif len(non_ego_agents) == 2:
            return [non_ego_agents[0]], [non_ego_agents[1]]
        else:
            # 计算非 ego 车辆之间的视角差异度
            # 选择与 ego 视角差异最大的车辆作为 Group 1
            ego_scores = scores[ego_idx, :].clone()
            ego_scores[ego_idx] = -float('inf')  # Exclude ego itself
            
            best_agent = torch.argmax(ego_scores).item()
            group1 = [best_agent]
            
            # 剩余车辆进入 Group 2
            group2 = [i for i in non_ego_agents if i != best_agent]
            
            return group1, group2
    else:
        # 原策略：分组包含 ego
        group1 = [ego_idx]
        group2 = [ego_idx]
        
        # Get scores for all agents relative to ego
        ego_scores = scores[ego_idx, :].clone()
        ego_scores[ego_idx] = -float('inf')  # Exclude ego itself
        
        # Select agent with highest score for group 1
        if agent_num > 1:
            best_agent = torch.argmax(ego_scores).item()
            group1.append(best_agent)
            
            # Remaining agents go to group 2
            for i in range(agent_num):
                if i != ego_idx and i != best_agent:
                    group2.append(i)
        
        # If group2 only has ego, add at least one more agent
        if len(group2) == 1 and agent_num > 1:
            # Find agent with lowest score (most parallel)
            ego_scores[best_agent] = -float('inf')  # Exclude already selected
            if torch.max(ego_scores) > -float('inf'):
                second_agent = torch.argmax(ego_scores).item()
                group2.append(second_agent)
        
        return group1, group2



# def compute_dynamic_threshold(features: torch.Tensor, base_threshold: float = 0.6,
#                               sensitivity: float = 1.0) -> float:
#     """
#     计算动态自适应阈值
    
#     基于特征的统计特性动态调整阈值：
#     - 计算特征之间的相似度分布
#     - 根据分布的均值和标准差自适应调整阈值
    
#     Args:
#         features: Tensor of shape (n, C, H, W) containing spatial features
#         base_threshold: Base threshold value (default: 0.5)
#         sensitivity: Sensitivity factor for threshold adjustment (default: 1.0)
        
#     Returns:
#         dynamic_threshold: Adaptively computed threshold
#     """
#     n = features.shape[0]
    
#     if n <= 1:
#         return base_threshold
    
#     # Compute pairwise cosine similarities
#     features_flat = features.view(n, -1)
#     features_normalized = F.normalize(features_flat, p=2, dim=1)
    
#     similarity_matrix = torch.mm(features_normalized, features_normalized.t())
    
#     # Get upper triangular part (excluding diagonal)
#     mask = torch.triu(torch.ones_like(similarity_matrix), diagonal=1).bool()
#     similarities = similarity_matrix[mask]
    
#     if similarities.numel() == 0:
#         return base_threshold
    
#     # Compute statistics
#     mean_sim = similarities.mean().item()
#     # Use unbiased=False to get population std (divide by n instead of n-1)
#     # This is more stable when we have few samples
#     std_sim = similarities.std(unbiased=False).item()
    
#     # Debug: Print intermediate values
#     print(f"[DEBUG compute_dynamic_threshold] n={n}, num_pairs={similarities.numel()}")
#     print(f"[DEBUG compute_dynamic_threshold] mean_sim={mean_sim:.6f}, std_sim={std_sim:.6f}")
#     print(f"[DEBUG compute_dynamic_threshold] similarity range: [{similarities.min().item():.6f}, {similarities.max().item():.6f}]")
    
#     # Dynamic threshold calculation
#     # Original formula: base + sensitivity * (mean - std)
#     # 
#     # Problem: If mean_sim and std_sim are both small and close, 
#     # then (mean - std) ≈ 0, and threshold stays at base_threshold (0.5)
#     #
#     # Improved formula: Use mean_sim directly as a scaling factor
#     # - High mean_sim (similar features) -> increase threshold
#     # - Low mean_sim (diverse features) -> decrease threshold
#     # - Use std_sim as a penalty term for uncertainty
#     #
#     # New formula: base + sensitivity * (mean_sim - 0.5) - sensitivity * std_sim * 0.5
#     # This ensures the threshold changes even when mean_sim and std_sim are both small
    
#     # Calculate adjustment based on mean similarity (centered around 0.5)
#     # mean_sim ranges from -1 to 1, but typically 0 to 1 for normalized features
#     mean_adjustment = sensitivity * (mean_sim - 0.5) * 0.4  # Scale to [-0.2, 0.2] range
#     std_penalty = sensitivity * std_sim * 0.2  # Penalty for high variance
#     adjustment = mean_adjustment - std_penalty
    
#     dynamic_threshold = base_threshold + adjustment
    
#     print(f"[DEBUG compute_dynamic_threshold] base_threshold={base_threshold:.6f}, sensitivity={sensitivity:.6f}")
#     print(f"[DEBUG compute_dynamic_threshold] mean_adjustment={mean_adjustment:.6f}, std_penalty={std_penalty:.6f}")
#     print(f"[DEBUG compute_dynamic_threshold] adjustment={adjustment:.6f} (mean_sim - std_sim = {mean_sim - std_sim:.6f})")
#     print(f"[DEBUG compute_dynamic_threshold] before clamp: {dynamic_threshold:.6f}")
    
#     # Clamp to reasonable range [0.1, 0.9]
#     dynamic_threshold = max(0.1, min(0.9, dynamic_threshold))
    
#     print(f"[DEBUG compute_dynamic_threshold] final threshold: {dynamic_threshold:.6f}")
    
#     return dynamic_threshold

# class AdaptiveThresholdCalculator:
#     """
#     基于时间窗口的自适应阈值计算器
    
#     使用滑动窗口维护历史 CPS 分数值，动态调整可疑度阈值
    
#     注意：这个计算器是用于 CPS 分数的阈值，而不是特征相似度的阈值
#     """
#     def __init__(self, window_size: int = 10, k: float = 1.0, 
#                  base_threshold: float = 0.68, max_sigma_ratio: float = 2.5,
#                  min_threshold: float = 0.4, max_threshold: float = 0.9):
#         """
#         Args:
#             window_size: 时间窗口大小
#             k: 标准差系数 (建议 0.8-1.5，用于 CPS 阈值)
#             base_threshold: 初始基础阈值 (建议 0.5-0.7，针对 CPS 分数)
#             max_sigma_ratio: 方差与均值的最大比例 (sigma <= mu / max_sigma_ratio)
#             min_threshold: 阈值下限 (CPS 分数下限)
#             max_threshold: 阈值上限 (CPS 分数上限)
#         """
#         self.window_size = window_size
#         self.k = k
#         self.base_threshold = base_threshold
#         self.max_sigma_ratio = max_sigma_ratio
#         self.min_threshold = min_threshold
#         self.max_threshold = max_threshold
        
#         # 滑动窗口存储历史 CPS 分数
#         self.window = []
#         # 存储每一帧的阈值（用于分析）
#         self.threshold_history = []
        
#         # 当前统计量
#         self.mu_0 = None
#         self.sigma_0 = None
#         self.current_threshold = base_threshold
    
#     def update(self, current_cps_scores: list) -> float:
#         """
#         更新滑动窗口并计算新的阈值
        
#         计算逻辑：
#         - CPS 分数通常在 [0.5, 0.9] 范围，值越大表示越可疑
#         - 阈值用于判断：CPS >= threshold 的车辆被标记为可疑
#         - 公式: threshold = mu + k * sigma
#           - 高于 (均值 + k倍标准差) 的车辆被认为是异常可疑的
#           - 这是标准的统计异常检测方法
        
#         Args:
#             current_cps_scores: 当前帧所有车辆的 CPS 分数列表
            
#         Returns:
#             threshold: 更新后的阈值
#         """
#         # 如果当前帧有 CPS 分数，计算平均值并加入窗口
#         if len(current_cps_scores) > 0:
#             avg_cps = np.mean(current_cps_scores)
#             self.window.append(avg_cps)
        
#         # Remove oldest value if window exceeds size
#         if len(self.window) > self.window_size:
#             self.window.pop(0)
        
#         # Update mu_0, sigma_0, and threshold only when window is full
#         if len(self.window) == self.window_size:
#             self.mu_0 = np.mean(self.window)
#             self.sigma_0 = np.std(self.window)
            
#             # 方差不能超过 mu 的一定比例
#             if self.sigma_0 > self.mu_0 / self.max_sigma_ratio:
#                 self.sigma_0 = self.mu_0 / self.max_sigma_ratio
            
#             # 更新阈值: threshold = mu + k * sigma
#             # 注意：这里是加法！高于均值+k倍标准差的被认为是异常
#             self.current_threshold = self.mu_0 + self.k * self.sigma_0
            
#             # Clamp to reasonable range
#             self.current_threshold = max(self.min_threshold, min(self.max_threshold, self.current_threshold))
        
#         # Store the current threshold for this frame
#         self.threshold_history.append(float(self.current_threshold))
        
#         return self.current_threshold
    
#     def reset(self):
#         """重置计算器状态"""
#         self.window = []
#         self.threshold_history = []
#         self.mu_0 = None
#         self.sigma_0 = None
#         self.current_threshold = self.base_threshold


class AdaptiveThresholdCalculator:
    """
    改进的基于时间窗口的自适应阈值计算器
    
    使用滑动窗口维护历史 CPS 分数值，动态调整可疑度阈值
    """
    def __init__(self, window_size: int = 10, k: float = 1.0, 
                 base_threshold: float = 0.68, max_sigma_ratio: float = 2.5,
                 min_threshold: float = 0.5, max_threshold: float = 0.9,
                 warmup_frames: int = 3):
        """
        Args:
            window_size: 时间窗口大小
            k: 标准差系数 (建议 0.8-1.5，用于 CPS 阈值)
            base_threshold: 初始基础阈值 (建议 0.6-0.75，针对 CPS 分数)
            max_sigma_ratio: 方差与均值的最大比例 (sigma <= mu / max_sigma_ratio)
            min_threshold: 阈值下限 (CPS 分数下限)
            max_threshold: 阈值上限 (CPS 分数上限)
            warmup_frames: 预热帧数，在此之前使用 base_threshold
        """
        self.window_size = window_size
        self.k = k
        self.base_threshold = base_threshold
        self.max_sigma_ratio = max_sigma_ratio
        self.min_threshold = min_threshold
        self.max_threshold = max_threshold
        self.warmup_frames = warmup_frames
        
        # 滑动窗口存储历史 CPS 分数
        self.window = []
        # 存储每一帧的阈值（用于分析）
        self.threshold_history = []
        
        # 当前统计量
        self.mu_0 = None
        self.sigma_0 = None
        self.current_threshold = base_threshold
        
        # 帧计数器
        self.frame_count = 0
    
    def update(self, current_cps_scores: list, is_ego_only: bool = False) -> float:
        """
        更新滑动窗口并计算新的阈值
        
        改进点:
        1. 添加预热阶段,避免初期数据不足导致阈值异常
        2. 即使窗口未满也更新阈值(使用可用数据)
        3. 确保阈值有合理的下界,避免出现0
        4. 使用更稳健的标准差限制
        5. 跳过ego-only情况,避免CPS=0污染统计
        
        Args:
            current_cps_scores: 当前帧所有车辆的 CPS 分数列表
            is_ego_only: 是否为ego-only情况(回退到仅使用ego)
            
        Returns:
            threshold: 更新后的阈值
        """
        self.frame_count += 1
        
        # 如果是ego-only情况,不加入统计窗口(CPS=0会污染统计)
        if is_ego_only:
            logger.info(f"[Threshold] Frame {self.frame_count}: Ego-only fallback detected, skipping CPS=0 from statistics")
            self.threshold_history.append(float(self.current_threshold))
            return self.current_threshold
        
        # 如果当前帧有 CPS 分数，计算平均值并加入窗口
        if len(current_cps_scores) > 0:
            avg_cps = np.mean(current_cps_scores)
            # 额外保护: 如果平均CPS接近0,可能是异常情况,不加入窗口
            if avg_cps < 0.01:
                logger.info(f"[Threshold] Frame {self.frame_count}: CPS too low ({avg_cps:.6f}), skipping")
                self.threshold_history.append(float(self.current_threshold))
                return self.current_threshold
            self.window.append(avg_cps)
        
        # Remove oldest value if window exceeds size
        if len(self.window) > self.window_size:
            self.window.pop(0)
        
        # 预热阶段：使用 base_threshold
        if self.frame_count <= self.warmup_frames:
            logger.info(f"[Threshold] Warmup phase ({self.frame_count}/{self.warmup_frames}), using base_threshold={self.base_threshold:.4f}")
            self.threshold_history.append(float(self.base_threshold))
            return self.base_threshold
        
        # 至少需要2个数据点才能计算标准差
        if len(self.window) < 2:
            logger.info(f"[Threshold] Insufficient data (n={len(self.window)}), using base_threshold={self.base_threshold:.4f}")
            self.threshold_history.append(float(self.base_threshold))
            return self.base_threshold
        
        # 计算统计量
        self.mu_0 = np.mean(self.window)
        self.sigma_0 = np.std(self.window)
        
        # 改进1: 确保标准差有最小值，避免阈值过低
        min_sigma = self.mu_0 * 0.05  # 标准差至少是均值的5%
        self.sigma_0 = max(self.sigma_0, min_sigma)
        
        # 改进2: 更合理的标准差上限
        max_allowed_sigma = self.mu_0 / self.max_sigma_ratio
        if self.sigma_0 > max_allowed_sigma:
            self.sigma_0 = max_allowed_sigma
        
        # 更新阈值: threshold = mu + k * sigma
        self.current_threshold = self.mu_0 + self.k * self.sigma_0
        
        # 改进3: 确保阈值不低于一个安全下界
        # 即使历史CPS很低，阈值也不应该太低
        safe_lower_bound = max(self.min_threshold, self.base_threshold * 0.8)
        self.current_threshold = max(safe_lower_bound, self.current_threshold)
        
        # Clamp to reasonable range
        self.current_threshold = max(self.min_threshold, min(self.max_threshold, self.current_threshold))
        
        # Debug output
        logger.info(f"[Threshold] Frame {self.frame_count}: window_size={len(self.window)}/{self.window_size}")
        logger.info(f"[Threshold] mu={self.mu_0:.4f}, sigma={self.sigma_0:.4f}, k={self.k:.2f}")
        logger.info(f"[Threshold] threshold before safe_lower_bound = {self.mu_0:.4f} + {self.k:.2f} * {self.sigma_0:.4f} = {self.mu_0 + self.k * self.sigma_0:.4f}")
        logger.info(f"[Threshold] threshold = {self.mu_0:.4f} + {self.k:.2f} * {self.sigma_0:.4f} = {self.current_threshold:.4f}")
        logger.info(f"[Threshold] safe_lower_bound={safe_lower_bound:.4f}")
        
        # Store the current threshold for this frame
        self.threshold_history.append(float(self.current_threshold))
        
        return self.current_threshold
    
    def reset(self):
        """重置计算器状态"""
        self.window = []
        self.threshold_history = []
        self.mu_0 = None
        self.sigma_0 = None
        self.current_threshold = self.base_threshold
        self.frame_count = 0
    
    def get_statistics(self):
        """获取当前统计信息"""
        return {
            'frame_count': self.frame_count,
            'window_size': len(self.window),
            'mu': self.mu_0,
            'sigma': self.sigma_0,
            'current_threshold': self.current_threshold,
            'threshold_history': self.threshold_history,
            'window_values': self.window.copy()
        }


class UncertaintyAwareAdaptiveThreshold:
    """
    UA-AT (Uncertainty-Aware Adaptive Thresholding)
    
    动态阈值 T_t 结合两部分:
    1. 历史统计分布 (时间维度): 基于3-Sigma准则的历史正常分数
    2. 当前场景不确定性 (空间维度): Ego特征的不确定性
    
    最终公式: T_final = T_stat * (1 + β * Norm(U_ego))
    其中: T_stat = μ_t + γ * σ_t
    """
    
    def __init__(self, 
                 window_size: int = 10,
                 gamma: float = 3.0,
                 beta: float = 0.5,
                 base_threshold: float = 0.54,
                 min_threshold: float = 0.35,
                 max_threshold: float = 0.95,
                 warmup_frames: int = 3,
                 uncertainty_method: str = 'entropy'):
        """
        Args:
            window_size: 滑动窗口大小 (维护N帧历史正常分数)
            gamma: 3-Sigma准则的敏感度系数 (通常取3, 覆盖99.7%正常分布)
            beta: 不确定性调节系数 (控制U_ego对阈值的影响程度)
            base_threshold: 预热阶段的基础阈值
            min_threshold: 阈值下限
            max_threshold: 阈值上限
            warmup_frames: 预热帧数
            uncertainty_method: 不确定性计算方法 ('entropy' 或 'variance')
        """
        self.window_size = window_size
        self.gamma = gamma
        self.beta = beta
        self.base_threshold = base_threshold
        self.min_threshold = min_threshold
        self.max_threshold = max_threshold
        self.warmup_frames = warmup_frames
        self.uncertainty_method = uncertainty_method
        
        # 滑动窗口: 存储历史良性(Benign)的因果差异分数
        self.benign_window = []
        
        # 统计量
        self.mu_t = None  # 历史均值
        self.sigma_t = None  # 历史标准差
        self.T_stat = None  # 统计阈值
        self.U_ego = None  # Ego不确定性
        self.T_final = None  # 最终阈值
        
        # 历史记录
        self.threshold_history = []
        self.uncertainty_history = []
        self.frame_count = 0
    
    def compute_ego_uncertainty(self, ego_features: torch.Tensor) -> float:
        """
        计算 Ego 特征的不确定性 U_ego
        
        两种方法:
        1. Entropy (熵): 衡量特征分布的混乱程度
           U_ego = -1/(H*W) * Σ p(f_ij) * log(p(f_ij))
        
        2. Variance (方差): 衡量特征的波动程度
           U_ego = Var(F_ego)
        
        Args:
            ego_features: Ego特征图, shape: [C, H, W] 或 [B, C, H, W]
            
        Returns:
            uncertainty: 归一化的不确定性值 [0, 1]
        """
        # 确保是4D tensor [B, C, H, W]
        if ego_features.dim() == 3:
            ego_features = ego_features.unsqueeze(0)
        
        B, C, H, W = ego_features.shape
        
        if self.uncertainty_method == 'entropy':
            # 方法1: 基于熵的不确定性
            # 将特征展平为 [B, C, H*W]
            features_flat = ego_features.view(B, C, -1)
            
            # 对每个空间位置计算概率分布 (softmax across channels)
            probs = F.softmax(features_flat, dim=1)  # [B, C, H*W]
            
            # 计算熵: H = -Σ p*log(p)
            log_probs = torch.log(probs + 1e-10)  # 避免log(0)
            entropy = -(probs * log_probs).sum(dim=1)  # [B, H*W]
            
            # 平均熵作为不确定性
            U_ego = entropy.mean().item()
            
            # 归一化: 熵的理论最大值是 log(C)
            U_ego_norm = U_ego / np.log(C)
            
        elif self.uncertainty_method == 'variance':
            # 方法2: 基于方差的不确定性 (更简单直接)
            # 计算每个通道的方差，然后平均
            variance = ego_features.var(dim=[2, 3])  # [B, C]
            U_ego = variance.mean().item()
            
            # 归一化: 使用tanh将方差映射到[0,1]
            # 假设方差在0-10范围内比较常见
            U_ego_norm = np.tanh(U_ego / 5.0)
        
        else:
            raise ValueError(f"Unknown uncertainty method: {self.uncertainty_method}")
        
        # 确保在[0, 1]范围内
        U_ego_norm = np.clip(U_ego_norm, 0.0, 1.0)
        
        return U_ego_norm
    
    def update(self, 
               current_cps_scores: List[float],
               ego_features: torch.Tensor,
               is_ego_only: bool = False,
               detected_malicious: bool = False,
               current_threshold: float = None) -> float:
        """
        更新阈值
        
        改进逻辑:
        1. 如果当前帧检测到恶意车辆 (detected_malicious=True), 不加入良性窗口
        2. 如果是ego-only情况, 跳过 (CPS=0会污染统计)
        3. 只有 CPS < current_threshold 的分数才加入良性窗口（CPS越大越恶意）
        4. 计算统计阈值 T_stat = μ_t + γ * σ_t
        5. 计算Ego不确定性 U_ego
        6. 最终阈值 T_final = T_stat * (1 + β * U_ego)
        
        Args:
            current_cps_scores: 当前帧所有车辆的CPS分数（CPS越大越恶意）
            ego_features: Ego特征图 (用于计算不确定性)
            is_ego_only: 是否为ego-only回退情况
            detected_malicious: 是否检测到恶意车辆
            current_threshold: 当前使用的阈值（用于判断CPS是否应该加入良性窗口）
            
        Returns:
            T_final: 最终自适应阈值
        """
        self.frame_count += 1
        
        # === 步骤1: 更新良性窗口 (只有正常帧才加入) ===
        # 重要：CPS 越大越恶意，所以只有 CPS < threshold 的才应该加入良性窗口
        # 
        # 关键需求：只有ego车辆是良性时，不更新cps得分窗口
        # 原因：窗口是检测出来的所有良性车辆（含ego）的融合CPS分
        #      如果只有ego是良性，CPS分估计就是0，这样弄进去反而会污染窗口，导致得不到正确的窗口
        # 
        # 判断逻辑：
        # 1. is_ego_only=True: 只有ego是良性，跳过窗口更新（CPS=0会污染统计）
        # 2. detected_malicious=True: 检测到恶意车辆，跳过窗口更新
        # 3. len(current_cps_scores)==0: 没有CPS分数，跳过
        # 4. 其他情况：有多个良性车辆（包括ego），可以更新窗口
        if is_ego_only:
            # ego-only情况：只有ego是良性，不更新窗口（CPS=0会污染统计）
            logger.info(f"[UA-AT] Frame {self.frame_count}: Ego-only detected (only ego is benign), skipping window update to avoid CPS=0 pollution")
        elif len(current_cps_scores) == 0:
            # 没有CPS分数，跳过
            logger.info(f"[UA-AT] Frame {self.frame_count}: No CPS scores provided, skipping window update")
        else:
            avg_cps = np.mean(current_cps_scores)
            
            # 获取当前阈值（用于判断是否应该加入良性窗口）
            # 如果没有提供 current_threshold，使用 T_final（如果已计算）或 base_threshold
            threshold_for_check = current_threshold
            if threshold_for_check is None:
                threshold_for_check = self.T_final if self.T_final is not None else self.base_threshold
            
            # 只有 CPS < threshold 的才加入良性窗口（CPS越大越恶意）
            # 同时过滤异常低的CPS (可能是异常情况，如CPS=0)
            if 0.01 <= avg_cps < threshold_for_check:
                self.benign_window.append(avg_cps)
                logger.info(f"[UA-AT] Frame {self.frame_count}: Added benign score {avg_cps:.4f} to window (threshold: {threshold_for_check:.4f})")
            else:
                logger.info(f"[UA-AT] Frame {self.frame_count}: Skipped CPS {avg_cps:.4f} (not benign: avg_cps >= {threshold_for_check:.4f} or < 0.01)")
        
        # 维护窗口大小
        if len(self.benign_window) > self.window_size:
            self.benign_window.pop(0)
        
        # === 步骤2: 预热阶段 ===
        if self.frame_count <= self.warmup_frames:
            logger.info(f"[UA-AT] Warmup phase ({self.frame_count}/{self.warmup_frames}), "
                       f"using base_threshold={self.base_threshold:.4f}")
            self.T_final = self.base_threshold
            self.threshold_history.append(self.T_final)
            self.uncertainty_history.append(0.0)
            self.U_ego = 0.0
            self.T_stat = self.base_threshold
            return self.T_final
        
        # === 步骤3: 计算统计阈值 T_stat ===
        if len(self.benign_window) < 2:
            # 数据不足, 使用基础阈值
            self.T_stat = self.base_threshold
            logger.info(f"[UA-AT] Insufficient benign data (n={len(self.benign_window)}), "
                       f"using base T_stat={self.T_stat:.4f}")
        else:
            # 计算历史均值和标准差
            self.mu_t = np.mean(self.benign_window)
            self.sigma_t = np.std(self.benign_window)
            
            # 添加最小标准差保护 (避免σ过小导致阈值过低)
            min_sigma = self.mu_t * 0.05
            if self.sigma_t < min_sigma:
                logger.info(f"[UA-AT] σ_t too small ({self.sigma_t:.4f}), using min_sigma={min_sigma:.4f}")
                self.sigma_t = min_sigma
            
            # 3-Sigma准则: T_stat = μ_t + γ * σ_t
            # 重要：CPS分数在[0,1]范围内，所以阈值也应该在[0,1]范围内
            # 先计算原始值，然后clip到合理范围（不超过max_threshold）
            self.T_stat = self.mu_t + self.gamma * self.sigma_t
            
            # 确保T_stat不超过max_threshold（因为CPS最大是1）
            # 但也不要低于min_threshold
            self.T_stat = np.clip(self.T_stat, self.min_threshold, self.max_threshold)
            
            logger.info(f"[UA-AT] T_stat = μ_t + γ*σ_t = {self.mu_t:.4f} + {self.gamma:.2f}*{self.sigma_t:.4f} = {self.T_stat:.4f} (clamped to [{self.min_threshold}, {self.max_threshold}])")
        
        # === 步骤4: 计算Ego不确定性 U_ego ===
        self.U_ego = self.compute_ego_uncertainty(ego_features)
        logger.info(f"[UA-AT] U_ego ({self.uncertainty_method}) = {self.U_ego:.4f}")
        
        # === 步骤5: 计算最终阈值 T_final ===
        # 改进公式：使用相对调整而不是绝对调整，确保阈值永远不会超过max_threshold
        # T_final = T_stat + β * U_ego * (max_threshold - T_stat)
        # 这样当U_ego=0时，T_final = T_stat
        # 当U_ego=1时，T_final = T_stat + β * (max_threshold - T_stat) ≤ max_threshold
        # 
        # 或者使用乘法形式但限制上限：
        # T_final = min(T_stat * (1 + β * U_ego), max_threshold)
        # 
        # 我们使用改进的相对调整公式，更稳定且保证不超过上限
        uncertainty_adjustment = self.beta * self.U_ego * (self.max_threshold - self.T_stat)
        self.T_final = self.T_stat + uncertainty_adjustment
        
        logger.info(f"[UA-AT] Uncertainty adjustment = β * U_ego * (max_threshold - T_stat)")
        logger.info(f"[UA-AT] Uncertainty adjustment = {self.beta:.2f} * {self.U_ego:.4f} * ({self.max_threshold:.4f} - {self.T_stat:.4f}) = {uncertainty_adjustment:.4f}")
        logger.info(f"[UA-AT] T_final = T_stat + adjustment = {self.T_stat:.4f} + {uncertainty_adjustment:.4f} = {self.T_final:.4f}")
        
        # === 步骤6: 阈值边界约束 ===
        # 双重保险：确保阈值在[0,1]范围内（CPS分数最大是1，所以阈值最大也应该是1）
        # 使用max_threshold作为上限（默认0.95，留一些安全边际）
        self.T_final = np.clip(self.T_final, self.min_threshold, self.max_threshold)
        
        logger.info(f"[UA-AT] T_final (clamped) = {self.T_final:.4f} (range: [{self.min_threshold}, {self.max_threshold}])")
        
        # === 记录历史 ===
        self.threshold_history.append(self.T_final)
        self.uncertainty_history.append(self.U_ego)
        
        return self.T_final
    
    def reset(self):
        """重置计算器状态"""
        self.benign_window = []
        self.mu_t = None
        self.sigma_t = None
        self.T_stat = None
        self.U_ego = None
        self.T_final = None
        self.threshold_history = []
        self.uncertainty_history = []
        self.frame_count = 0
    
    def get_statistics(self) -> Dict:
        """获取当前统计信息"""
        return {
            'frame_count': self.frame_count,
            'window_size': len(self.benign_window),
            'mu_t': self.mu_t,
            'sigma_t': self.sigma_t,
            'T_stat': self.T_stat,
            'U_ego': self.U_ego,
            'T_final': self.T_final,
            'threshold_history': self.threshold_history.copy(),
            'uncertainty_history': self.uncertainty_history.copy(),
            'benign_window': self.benign_window.copy(),
            'hyperparameters': {
                'gamma': self.gamma,
                'beta': self.beta,
                'window_size': self.window_size,
                'uncertainty_method': self.uncertainty_method
            }
        }


def compute_dynamic_threshold(cps_scores: List[float], 
                              threshold_calculator: AdaptiveThresholdCalculator = None,
                              base_threshold: float = 0.6,
                              k: float = 1.0,
                              window_size: int = 10) -> float:
    """
    计算动态自适应阈值（基于时间窗口的 CPS 分数）
    
    基于历史 CPS 分数动态调整阈值：
    - 收集当前帧所有车辆的 CPS 分数
    - 使用滑动窗口维护历史 CPS 分数的平均值
    - 根据窗口内的均值和标准差自适应调整阈值
    
    阈值计算说明：
    - CPS 分数通常在 [0.5, 0.9] 范围
    - 阈值公式: threshold = mu + k * sigma (标准异常检测)
    - 高于阈值的车辆被标记为可疑
    - 建议参数：
      - base_threshold: 0.6-0.7 (初始阈值)
      - k: 0.8-1.5 (标准差系数，越大越严格)

    该函数暂未使用，使用的是 AdaptiveThresholdCalculator 类
    
    Args:
        cps_scores: 当前帧所有车辆的 CPS 分数列表
        threshold_calculator: AdaptiveThresholdCalculator instance
        base_threshold: Base threshold value (default: 0.6)
        k: 标准差系数 (default: 1.0)
        window_size: 时间窗口大小 (default: 10)
        
    Returns:
        dynamic_threshold: Adaptively computed threshold for CPS scores
    """
    # 如果没有提供计算器，创建一个临时的
    if threshold_calculator is None:
        logger.warning("[WARNING] No threshold_calculator provided. Creating temporary instance.")
        logger.warning("[WARNING] For proper temporal tracking, pass a persistent calculator instance.")
        threshold_calculator = AdaptiveThresholdCalculator(
            window_size=window_size, 
            k=k, 
            base_threshold=base_threshold,
            min_threshold=0.5,  # CPS 分数下限
            max_threshold=0.9   # CPS 分数上限
        )
    
    if len(cps_scores) == 0:
        return threshold_calculator.current_threshold
    
    # Update threshold using sliding window of CPS scores
    dynamic_threshold = threshold_calculator.update(cps_scores)
    
    # Debug output
    logger.debug(f"[DEBUG compute_dynamic_threshold] num_agents={len(cps_scores)}")
    logger.debug(f"[DEBUG compute_dynamic_threshold] current_cps_scores={[f'{s:.4f}' for s in cps_scores]}")
    logger.debug(f"[DEBUG compute_dynamic_threshold] avg_cps={np.mean(cps_scores):.6f}")
    logger.debug(f"[DEBUG compute_dynamic_threshold] window_size={len(threshold_calculator.window)}/{threshold_calculator.window_size}")
    if threshold_calculator.mu_0 is not None:
        logger.debug(f"[DEBUG compute_dynamic_threshold] mu_0={threshold_calculator.mu_0:.6f}, sigma_0={threshold_calculator.sigma_0:.6f}")
        logger.debug(f"[DEBUG compute_dynamic_threshold] k={threshold_calculator.k:.2f}")
        logger.debug(f"[DEBUG compute_dynamic_threshold] threshold = {threshold_calculator.mu_0:.6f} + {threshold_calculator.k:.2f} * {threshold_calculator.sigma_0:.6f} = {dynamic_threshold:.6f}")
    else:
        logger.debug(f"[DEBUG compute_dynamic_threshold] Window not full yet, using base_threshold={dynamic_threshold:.6f}")
    
    return dynamic_threshold


def mdag_grouping_strategy(positions: torch.Tensor, features: torch.Tensor,
                           ego_idx: int = 0, use_dynamic_threshold: bool = True,
                           base_threshold: float = 0.5,
                           sensitivity: float = 1.0,
                           cps_scores: List[float] = None,
                           threshold_calculator: AdaptiveThresholdCalculator = None) -> Dict:
    """
    MDAG 完整分组策略   未使用
    
    Args:
        positions: Tensor of shape (n, 2) containing vehicle positions
        features: Tensor of shape (n, C, H, W) containing spatial features
        ego_idx: Index of ego vehicle
        use_dynamic_threshold: Whether to use dynamic threshold
        base_threshold: Base threshold for filtering
        sensitivity: Sensitivity for dynamic threshold (used as k parameter)
        cps_scores: List of CPS scores for all agents (required if use_dynamic_threshold=True)
        threshold_calculator: AdaptiveThresholdCalculator instance for temporal tracking
        
    Returns:
        result: Dictionary containing:
            - group1: Strong adversarial group indices
            - group2: Weak adversarial group indices
            - diversity_matrix: Viewpoint diversity matrix
            - threshold: Used threshold value
    """
    agent_num = positions.shape[0]
    
    # Step 1: Compute viewpoint diversity matrix
    diversity_matrix = compute_viewpoint_diversity(positions)
    
    # Step 2: Isomorphic viewpoint grouping for balanced FoV coverage
    group1, group2 = isomorphic_viewpoint_grouping(
        diversity_matrix, agent_num, ego_idx=ego_idx, exclude_ego=True
    )
    
    # Step 3: Compute dynamic threshold if enabled
    if use_dynamic_threshold:
        if cps_scores is None or len(cps_scores) == 0:
            logger.warning("[WARNING] mdag_grouping_strategy: use_dynamic_threshold=True but CPS scores not provided.")
            logger.warning("[WARNING] Falling back to base_threshold.")
            threshold = base_threshold
        else:
            # Compute dynamic threshold using CPS scores
            threshold = compute_dynamic_threshold(
                cps_scores,
                threshold_calculator=threshold_calculator,
                base_threshold=base_threshold,
                k=sensitivity
            )
    else:
        threshold = base_threshold
    
    return {
        'group1': group1,
        'group2': group2,
        'diversity_matrix': diversity_matrix,
        'threshold': threshold
    }


def extract_positions_from_pairwise_matrix(pairwise_t_matrix: torch.Tensor) -> torch.Tensor:
    """
    从 pairwise transformation matrix 中提取车辆位置
    
    Args:
        pairwise_t_matrix: Tensor of shape (n, n, 4, 4) containing transformation matrices
        
    Returns:
        positions: Tensor of shape (n, 2) containing (x, y) positions
    """
    n = pairwise_t_matrix.shape[0]
    positions = torch.zeros(n, 2, device=pairwise_t_matrix.device)
    
    # Extract translation components from transformation matrices
    # Position of agent i relative to ego is in pairwise_t_matrix[0, i, :2, 3]
    for i in range(n):
        positions[i, 0] = pairwise_t_matrix[0, i, 0, 3]  # x
        positions[i, 1] = pairwise_t_matrix[0, i, 1, 3]  # y
    
    return positions

