import argparse, os, time, torch, random
import numpy as np
import pandas as pd
import open3d as o3d
import logging
from datetime import datetime
import attack_utils.inference_attack_utils as attack_utils
import cv2

from tqdm import tqdm
from torch.utils.data import DataLoader
from opencood.tools import train_utils
from opencood.tools import inference_utils
from opencood.data_utils.datasets import build_dataset
from opencood.utils import eval_utils
from opencood.utils.common_utils import torch_tensor_to_numpy
from opencood.visualization import vis_utils
from opencood.hypes_yaml import yaml_utils
from attack_utils import get_adv_loss
from attack_utils.adv_loss import focalLoss, PALoss
from defense.defender import cp_defense_lucia
from defense.robosac import robosac
from defense.cps_defense import cps_defense
from defense.feature_guard import feature_guard
from defense.pasac import pasac, pasac_simple

# Seed Python's built-in ``random`` module at import time so that any
# randomness drawn from it is reproducible across runs.
# NOTE: only ``random`` is seeded here; numpy and torch RNGs are not.
random.seed(0)

def setup_logging(model_dir, attack_mode, loss, defense, use_dynamic_threshold, use_ssim):
    """
    Configure the root logger to write to both the console and a log file.

    Handlers already attached to the root logger are removed AND closed
    first, so calling this function more than once neither duplicates output
    nor leaves earlier log files open.

    Args:
        model_dir: Directory in which the log file is created (created if
            it does not exist yet).
        attack_mode: Attack-mode tag embedded in the log file name.
        loss: Loss/attack name embedded in the log file name.
        defense: Defense name embedded in the log file name.
        use_dynamic_threshold: Flag value embedded in the log file name.
        use_ssim: Flag value embedded in the log file name.

    Returns:
        logging.Logger: this module's logger (``__name__``). It propagates to
        the configured root logger, so other modules can log through
        ``logging.getLogger(...)`` as well.
    """
    # Timestamped file name recording the experiment settings.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filename = (
        f'attack_{attack_mode}_{loss}_defense_{defense}'
        f'_threshold_{use_dynamic_threshold}_use_ssim={use_ssim}_{timestamp}.log'
    )
    # FileHandler raises FileNotFoundError if the parent directory is missing.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    log_filepath = os.path.join(model_dir, log_filename)

    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    date_format = '%Y-%m-%d %H:%M:%S'
    formatter = logging.Formatter(log_format, datefmt=date_format)

    root_logger = logging.getLogger()
    # Remove existing handlers to avoid duplicate output; close them too so
    # a previously opened log file is not leaked on repeated calls.
    for old_handler in list(root_logger.handlers):
        root_logger.removeHandler(old_handler)
        old_handler.close()
    root_logger.setLevel(logging.INFO)

    # File output.
    file_handler = logging.FileHandler(log_filepath, encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    # Console output.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)

    root_logger.addHandler(file_handler)
    root_logger.addHandler(console_handler)

    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized. Log file: {log_filepath}")

    return logger

def test_parser(args=None):
    """
    Build the command-line interface of this script and parse the options.

    Args:
        args: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly like
            ``argparse.ArgumentParser.parse_args``, so existing callers are
            unaffected.

    Returns:
        argparse.Namespace: the parsed options.

    Flags that are enabled by default (``--cps_compute_grad``, ``--use_ssim``,
    ``--use_mdag``, ``--use_dynamic_threshold``, ``--use_ua_at``) each have a
    ``--no_<name>`` counterpart that disables them.
    """
    parser = argparse.ArgumentParser(description="Adversarial attacks and defenses on OPV2V")

    def add_default_on_flag(name, help_text):
        # ``action='store_true', default=True`` alone yields a flag that can
        # never be False. Keep ``--<name>`` (still accepted, still sets True)
        # for backward compatibility and add ``--no_<name>`` that writes
        # False to the same destination.
        parser.add_argument(f'--{name}', dest=name, action='store_true', default=True,
                            help=f'{help_text} (default: enabled)')
        parser.add_argument(f'--no_{name}', dest=name, action='store_false',
                            help=f'disable --{name}')

    # ---- Model, visualization and output options ----
    parser.add_argument('--model_dir', type=str, required=True,
                        help='directory of the model')
    parser.add_argument('--show_vis', action='store_true',
                        help='specify to show image visualization result')
    parser.add_argument('--show_sequence', action='store_true',
                        help='specify to show video visualization result; '
                             'it cannot be set together with --show_vis')
    parser.add_argument('--save_video', action='store_true',
                        help='specify to save video visualization result')
    parser.add_argument('--save_vis', action='store_true',
                        help='specify to save visualization result')
    parser.add_argument('--save_npy', action='store_true',
                        help='specify to save prediction and gt result '
                             'in npy_test file')
    parser.add_argument('--save_perturb', action='store_true',
                        help='specify to save perturbed feature for each frame')
    parser.add_argument('--save_feature_maps', action='store_true',
                        help='specify to save per-agent feature response maps')
    parser.add_argument('--feature_map_dir', type=str, default='',
                        help='custom directory for saved feature maps (default under model_dir)')
    parser.add_argument('--feature_map_reduction', type=str, default='mean_abs',
                        choices=['mean_abs', 'max_abs', 'sum_squares', 'mean', 'sum'],
                        help='channel reduction method when creating response maps')
    parser.add_argument('--feature_map_cmap', type=str, default='plasma',
                        help='matplotlib colormap used for response map visualization')
    parser.add_argument('--feature_map_scale', type=int, default=4,
                        help='upsampling factor for saved response maps (>=1)')
    parser.add_argument('--save_tsne', action='store_true',
                        help='save t-SNE visualization of feature distribution (benign vs attacked)')

    # ---- Attack options ----
    parser.add_argument('--iter', type=int, default=10, help='number of iterations for the attack, default 10')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate for the attack, default 0.1')
    parser.add_argument('--model', type=str, default='AttentiveFusion',
                        help='choose from AttentiveFusion, CoAlign, Where2comm, V2VAM '
                             '(needed for fetching attention), default: AttentiveFusion')
    # main() only configures the 'tor' and 'mor' modes; anything else would
    # leave required state undefined, so reject it at parse time.
    parser.add_argument('--attack_mode', type=str, default='tor', choices=['tor', 'mor'],
                        help='choose from tor, mor (targeted or mass object removal), default: tor')
    parser.add_argument('--skip', type=int, help="Skip the first x number of frames, default: 0", default=0)
    parser.add_argument('--data_dir', type=str,
                        help="overwrite the test dataset directory specified in the model hypes", default="")
    parser.add_argument('--target_id',
                        help='Specify the ID whose bbox the attacker attempts to suppress; '
                             'choose from: <TARGET ID>, random, in, out',
                        default=-1)
    parser.add_argument('--loss', type=str, default='sombra',
                        choices=['sombra', 'pa', 'bim', 'pgd', 'cw'],
                        help='choose from sombra, pa, bim, pgd, cw, default: sombra')
    parser.add_argument('--pgd_epsilon', type=float, default=0.3, help='epsilon for PGD attack (L_inf bound), default: 0.3')
    parser.add_argument('--cw_confidence', type=float, default=0.0, help='confidence (kappa) for C&W attack, default: 0.0')
    parser.add_argument('--cw_c', type=float, default=1.0, help='weight c for C&W attack loss, default: 1.0')

    # ---- Defense options ----
    parser.add_argument('--defense', action='store_true', help='specify if use Lucia defense')
    parser.add_argument('--robosac', action='store_true', help='specify if use ROBOSAC as defense')
    parser.add_argument('--pasac', action='store_true', help='specify if use PASAC as defense (recursive binary splitting)')
    parser.add_argument('--pasac_threshold', type=float, default=0.3,
                        help='CCLoss threshold for PASAC defense (default: 0.3)')
    parser.add_argument('--pasac_simple', action='store_true',
                        help='use simple PASAC (individual testing) instead of recursive')
    parser.add_argument('--feature_guard', action='store_true',
                        help='specify if use feature-map consensus defense')
    parser.add_argument('--cps_defense', action='store_true',
                        help='specify if use CPS (Comprehensive Protection Score) defense')
    parser.add_argument('--cps_lambda1', type=float, default=0.3,
                        help='weight for similarity score in CPS')
    parser.add_argument('--cps_lambda2', type=float, default=0.3,
                        help='weight for gradient consistency / SSIM in CPS')
    parser.add_argument('--cps_lambda3', type=float, default=0.4,
                        help='weight for activation energy shift in CPS')
    add_default_on_flag('cps_compute_grad', 'compute gradient consistency (expensive)')
    add_default_on_flag('use_ssim', 'use SSIM instead of gradient consistency in CPS')
    add_default_on_flag('use_mdag', 'use MDAG grouping instead of RoboSAC sampling')
    add_default_on_flag('use_dynamic_threshold', 'use dynamic adaptive threshold')
    parser.add_argument('--threshold_sensitivity', type=float, default=1.0,
                        help='sensitivity for dynamic threshold')
    add_default_on_flag('use_ua_at',
                        'use UA-AT (Uncertainty-Aware Adaptive Thresholding) '
                        'instead of basic adaptive threshold')
    parser.add_argument('--ua_at_gamma', type=float, default=2.5,
                        help='gamma for UA-AT 3-Sigma rule (default: 2.5)')
    parser.add_argument('--ua_at_beta', type=float, default=0.15,
                        help='beta for UA-AT uncertainty weight (default: 0.15)')
    parser.add_argument('--ua_at_method', type=str, default='entropy',
                        choices=['entropy', 'variance'],
                        help='uncertainty calculation method for UA-AT (default: entropy)')

    # ---- Misc ----
    parser.add_argument('--async_mode', action='store_true', help='specify if use asynchronous communication')
    parser.add_argument('--exclude_attn', action='store_true', help='specify if exclude attention in the SOMBRA attack, default: False')
    opt = parser.parse_args(args)
    return opt


# The ``test_`` prefix is historical; this is a CLI helper, not a test.
# Stop pytest from collecting it if this module is ever imported by tests.
test_parser.__test__ = False

def main():
    opt = test_parser()
    assert not (opt.show_vis and opt.show_sequence), 'you can only visualize ' \
                                                'the results in single ' \
                                                'image mode or video mode'
    
    # 确定防御方法名称（用于日志文件名）
    if opt.defense:
        defense = 'lucia'
    elif opt.robosac:
        defense = 'robosac'
    elif opt.pasac:
        defense = 'pasac'
    elif opt.feature_guard:
        defense = 'feature_guard'
    elif opt.cps_defense:
        defense = 'cps_defense' # Cerberus
    else:
        defense = 'no_defense'
    
    # 初始化日志系统（需要在其他操作之前初始化，以便后续可以使用logger）
    logger = setup_logging(
        opt.model_dir, opt.attack_mode, opt.loss, defense,
        opt.use_dynamic_threshold, opt.use_ssim
    )
    attacker_agent_index = 1
    feature_map_root = None
    if opt.save_feature_maps:
        feature_dir_parts = [
            f'feature_response_iter{opt.iter}',
            f'lr{opt.lr}',
            opt.attack_mode,
            opt.loss,
            f'defense_{defense}'
        ]
        if opt.cps_defense or opt.defense:
            feature_dir_parts.append(f'threshold_{opt.use_dynamic_threshold}')
            if opt.use_dynamic_threshold:
                feature_dir_parts.append(f'ua_at_method={opt.ua_at_method}')
            feature_dir_parts.append(f'use_ssim={opt.use_ssim}')
        elif opt.pasac:
            feature_dir_parts.append(f'threshold_{opt.pasac_threshold}')
            if opt.pasac_simple:
                feature_dir_parts.append('simple')
        if opt.use_mdag:
            feature_dir_parts.append('mdag')
        feature_dir_name = '_'.join(feature_dir_parts)
        feature_map_root = opt.feature_map_dir if opt.feature_map_dir else os.path.join(opt.model_dir, feature_dir_name)
        os.makedirs(feature_map_root, exist_ok=True)
        logger.info(f"Feature response maps will be saved to: {feature_map_root}")
    
    hypes = yaml_utils.load_yaml(None, opt)
    logger.info('Dataset Building')
    number_stat = {'pred':[], 'gt':[], 'tp':[], 'fp':[], 'threshold':[], 'benign_agents':[], 'malicious_agents':[]}
    
    # Add mu and sigma fields only for CPS defense
    if opt.cps_defense:
        number_stat['mu'] = []
        number_stat['sigma'] = []

    # Specify attack to be targeted or mass object removal
    if opt.attack_mode == 'tor':
        try:
            opt.target_id = int(opt.target_id) # If specified integer ID
        except:
            pass
        hypes['remove_id'] = opt.target_id
        targeted = True
        dataset_worker = 0 # Ensure each frame has a randomized target without error
        remove_success = 0
        number_stat['target_detected'] = []
    elif opt.attack_mode == 'mor':
        targeted = False
        dataset_worker = 16 # Speed up data fetching
        orr_total = []

    # Building specifided loss function
    logger.info('Building adversarial loss function')
    attn_loss = None
    attack_method = 'bim'  # Default attack method
    if opt.loss == 'sombra':
        loss = focalLoss(hypes['loss']['args'], targeted)
        attn_loss = get_adv_loss(opt.model) if not opt.exclude_attn else None
    elif opt.loss == 'pa':
        loss = PALoss(targeted)
    elif opt.loss == 'bim':
        from attack_utils.untargeted_loss import untargetedAttack
        loss = untargetedAttack(hypes['loss']['args'])
        attack_method = 'bim'
        opt.attack_mode = 'bim'
    elif opt.loss == 'pgd':
        from attack_utils.untargeted_loss import PGDAttack
        loss = PGDAttack(hypes['loss']['args'], epsilon=opt.pgd_epsilon, random_start=True)
        attack_method = 'pgd'
        opt.attack_mode = 'pgd'
        logger.info(f'PGD Attack: epsilon={opt.pgd_epsilon}')
    elif opt.loss == 'cw':
        from attack_utils.untargeted_loss import CWAttack
        loss = CWAttack(hypes['loss']['args'], confidence=opt.cw_confidence, c=opt.cw_c)
        attack_method = 'cw'
        opt.attack_mode = 'cw'
        logger.info(f'C&W Attack: confidence={opt.cw_confidence}, c={opt.cw_c}')
    else:
        logger.error('Invalid loss function')
        exit()
    
    defender = cp_defense_lucia if opt.defense else None

    if defense != 'no_defense':
        number_stat['defense_time'] = []
        number_stat['trust_score'] = []
    
    # Initialize threshold calculator for CPS defense with dynamic threshold
    threshold_calculator = None
    if opt.cps_defense and opt.use_dynamic_threshold:
        if opt.use_ua_at:
            # Use UA-AT (Uncertainty-Aware Adaptive Thresholding)
            from defense.mdag_grouping import UncertaintyAwareAdaptiveThreshold
            threshold_calculator = UncertaintyAwareAdaptiveThreshold(
                window_size=10,
                gamma=opt.ua_at_gamma,
                beta=opt.ua_at_beta,
                base_threshold=0.30,
                min_threshold=0.25,
                max_threshold=0.75,
                warmup_frames=3,
                uncertainty_method=opt.ua_at_method
            )
            logger.info(f"[CPS Defense] Initialized UncertaintyAwareAdaptiveThreshold (UA-AT)")
            logger.info(f"[CPS Defense] UA-AT params: gamma={opt.ua_at_gamma}, beta={opt.ua_at_beta}, method={opt.ua_at_method}")
        else:
            # Use original AdaptiveThresholdCalculator
            from defense.mdag_grouping import AdaptiveThresholdCalculator
            threshold_calculator = AdaptiveThresholdCalculator(
                window_size=10, 
                k=2.5, 
                base_threshold=0.68,
                max_sigma_ratio=2.3
            )
            logger.info(f"[CPS Defense] Initialized AdaptiveThresholdCalculator with window_size=10")
        

    # Load dataset
    if opt.data_dir != "":
        hypes['validate_dir'] = opt.data_dir
    opencood_dataset = build_dataset(hypes, visualize=True, train=False)
    data_len = len(opencood_dataset)
    logger.info(f"{data_len} samples found.")
    data_loader = DataLoader(opencood_dataset,
                            batch_size=1,
                            num_workers=dataset_worker,
                            collate_fn=opencood_dataset.collate_batch_test,
                            shuffle=False,
                            pin_memory=False,
                            drop_last=False)
    
    # Simulate asynchronous communication
    if opt.async_mode:
        hypes.update({'wild_setting':{
            'async': True,
            'async_mode': 'sim',
            'async_overhead': 100,
            'seed': 2025,
            'backbone_delay': 10,
            'data_size': 1.06,
            'loc_err': False,
            'ryp_std': 0.0,
            'transmission_speed': 27,
            'xyz_std': 0.0,
        }})
        opencood_dataset_async = build_dataset(hypes, visualize=True, train=False)
        data_loader_async = DataLoader(opencood_dataset_async,
                                batch_size=1,
                                num_workers=dataset_worker,
                                collate_fn=opencood_dataset.collate_batch_test,
                                shuffle=False,
                                pin_memory=False,
                                drop_last=False)
        async_dat_iter = iter(data_loader_async)

    logger.info('Creating Model')
    model = train_utils.create_model(hypes)
    # we assume gpu is necessary
    if torch.cuda.is_available():
        model.cuda()
        logger.info(f'✓ Using GPU: {torch.cuda.get_device_name(0)}')
        logger.info(f'  GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB')
    else:
        logger.warning('⚠ WARNING: CUDA not available, running on CPU (this will be very slow!)')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f'Device: {device}')

    logger.info('Loading Model from checkpoint')
    saved_path = opt.model_dir
    _, model = train_utils.load_saved_model(saved_path, model)
    model.eval()

    # Create the dictionary for evaluation
    result_stat = {0.3: {'tp': [], 'fp': [], 'gt': 0},
                   0.5: {'tp': [], 'fp': [], 'gt': 0},
                   0.7: {'tp': [], 'fp': [], 'gt': 0}}

    # For visualization purposes
    video_writer_defense = None  # Video after defense (with real-time display)
    video_writer_attack = None   # Video after attack (offscreen rendering only)
    video_filename_attack = None
    video_filename_defense = None
    temp_frame_path = None
    
    if opt.show_sequence or opt.save_video:
        # Only show real-time window if show_sequence is enabled
        # save_video will use offscreen rendering
        if opt.show_sequence:
            vis = o3d.visualization.Visualizer()
            # Set window size to match video output size for better framing
            vis.create_window(width=1920, height=1080)

            vis.get_render_option().background_color = [0.05, 0.05, 0.05]
            vis.get_render_option().point_size = 1.0
            vis.get_render_option().line_width = 50.0  # Much thicker lines for better visibility
            vis.get_render_option().show_coordinate_frame = True

            # used to visualize lidar points
            vis_pcd = o3d.geometry.PointCloud()
            # used to visualize object bounding box, maximum 50
            vis_aabbs_gt = []
            vis_aabbs_pred = []
            for _ in range(50):
                vis_aabbs_gt.append(o3d.geometry.LineSet())
                vis_aabbs_pred.append(o3d.geometry.LineSet())
        
        # Setup video writers if saving video
        if opt.save_video:
            # Video settings
            fps = 10  # Frames per second
            frame_width = 1920
            frame_height = 1080
            # Use mp4v codec which is more compatible for mp4 format on Linux/WSL than avc1
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            
            # 1. Video for AFTER ATTACK (offscreen rendering, no real-time display)
            video_name_attack_parts = [
                f'video_after_attack_iter{opt.iter}',
                f'lr{opt.lr}',
                f'{opt.attack_mode}',
                f'{opt.loss}',
                f'defense_{defense}'
            ]
            if opt.cps_defense or opt.defense:
                video_name_attack_parts.append(f'thresh_{opt.use_dynamic_threshold}')
                if opt.use_dynamic_threshold:
                    video_name_attack_parts.append(f'ua_{opt.ua_at_method}')
                video_name_attack_parts.append(f'ssim_{opt.use_ssim}')
            if opt.use_mdag:
                video_name_attack_parts.append('mdag')
            
            video_filename_attack = os.path.join(
                opt.model_dir,
                '_'.join(video_name_attack_parts) + '.mp4'
            )
            # Ensure output directory exists (it should, but just in case)
            if not os.path.exists(opt.model_dir):
                os.makedirs(opt.model_dir)
                
            video_writer_attack = cv2.VideoWriter(video_filename_attack, fourcc, fps, (frame_width, frame_height))
            if not video_writer_attack.isOpened():
                logger.error(f"Failed to open video writer for: {video_filename_attack}. Try changing codec or path.")
            else:
                logger.info(f"Video (after attack) will be saved to: {video_filename_attack}")
            
            # 2. Video for AFTER DEFENSE (with real-time display)
            video_name_defense_parts = [
                f'video_after_defense_iter{opt.iter}',
                f'lr{opt.lr}',
                f'{opt.attack_mode}',
                f'{opt.loss}',
                f'defense_{defense}'
            ]
            if opt.cps_defense or opt.defense:
                video_name_defense_parts.append(f'thresh_{opt.use_dynamic_threshold}')
                if opt.use_dynamic_threshold:
                    video_name_defense_parts.append(f'ua_{opt.ua_at_method}')
                video_name_defense_parts.append(f'ssim_{opt.use_ssim}')
            if opt.use_mdag:
                video_name_defense_parts.append('mdag')
            
            video_filename_defense = os.path.join(
                opt.model_dir,
                '_'.join(video_name_defense_parts) + '.mp4'
            )
            video_writer_defense = cv2.VideoWriter(video_filename_defense, fourcc, fps, (frame_width, frame_height))
            if not video_writer_defense.isOpened():
                logger.error(f"Failed to open video writer for: {video_filename_defense}")
            else:
                logger.info(f"Video (after defense) will be saved to: {video_filename_defense}")
            
            # Temporary path for capturing frames
            temp_frame_path = os.path.join(opt.model_dir, 'temp_frame.png')
    
    # Initialize t-SNE data collection lists
    tsne_benign_features = []
    tsne_attacked_features = []
    
    # Begin evaluation
    progress_bar = tqdm(enumerate(data_loader),
                        total=data_len,
                        desc='Frames',
                        dynamic_ncols=True,
                        smoothing=0.01)
    for i, batch_data in progress_bar:
        # Simulate asynchronous communication
        if opt.async_mode:
            async_data = next(async_dat_iter)
            
        else:
            async_data = batch_data
            async_data = train_utils.to_device(async_data, device)

        if batch_data['ego']['cav_num'] < 2 or i < opt.skip: # A very limited number of datapoints only have 1 CAV
            continue

        # 输出车辆数量
        cav_num = batch_data['ego']['cav_num']
        logger.info(f"Frame {i}: 车辆数量 (cav_num) = {cav_num}")
        
        # Initialize save path variables for this frame
        vis_save_path_after_attack = None
        vis_save_path = None

        batch_data = train_utils.to_device(batch_data, device)
        async_data = train_utils.to_device(async_data, device)
        gt_box_tensor = opencood_dataset.post_processor.generate_gt_bbx(batch_data)

        if opt.attack_mode == 'tor':
            target_bbox_tensor = opencood_dataset.post_processor.generate_gt_bbx(batch_data, \
                                                                                 selected_id=[opencood_dataset.actual_remove_id])
            batch_data['ego']['label_dict'] = batch_data['ego']['target_label_dict']
            batch_data['ego']['object_bbx_center'] = batch_data['ego']['target_object_bbox']
            if opt.async_mode:
                async_data['ego']['label_dict'] = async_data['ego']['target_label_dict']
                async_data['ego']['object_bbx_center'] = async_data['ego']['target_object_bbox']
        elif opt.attack_mode == 'mor':
            if opt.loss == 'sombra':
                attack_utils.get_empty_target(batch_data, opencood_dataset.post_processor)
                if opt.async_mode:
                    attack_utils.get_empty_target(async_data, opencood_dataset.post_processor)

        # =====================================================================
        # STEP 1: Save "No Attack" baseline visualization (before any attack)
        # =====================================================================
        if opt.save_vis:
            # Perform clean inference (no attack) to get baseline results
            clean_pred_box_tensor, clean_pred_score, clean_gt_box_tensor = \
                inference_utils.inference_intermediate_fusion(batch_data, model, opencood_dataset)
            
            # Build directory name for no-attack baseline
            vis_dir_name_no_attack = f'vis_no_attack_baseline_{opt.attack_mode}_{opt.loss}'
            vis_save_path_no_attack = os.path.join(opt.model_dir, vis_dir_name_no_attack)
            if not os.path.exists(vis_save_path_no_attack):
                os.makedirs(vis_save_path_no_attack)
            vis_save_path_no_attack = os.path.join(vis_save_path_no_attack, '%05d.png' % i)
            
            # Only save if the file doesn't already exist (avoid redundant saves across runs)
            if not os.path.exists(vis_save_path_no_attack):
                # Handle None tensors
                clean_pred_for_vis = clean_pred_box_tensor if clean_pred_box_tensor is not None else torch.empty((0, 8, 3))
                clean_gt_for_vis = clean_gt_box_tensor if clean_gt_box_tensor is not None else gt_box_tensor
                if clean_gt_for_vis is None:
                    clean_gt_for_vis = torch.empty((0, 8, 3))
                
                opencood_dataset.visualize_result(
                    clean_pred_for_vis,
                    clean_gt_for_vis,
                    batch_data['ego']['origin_lidar'],
                    False,  # Don't show, just save
                    vis_save_path_no_attack,
                    dataset=opencood_dataset,
                    target_tensor=None
                )
                logger.info(f"[Frame {i}] Saved no-attack baseline visualization")

        # =====================================================================
        # STEP 2: Perform attack
        # =====================================================================
        (pred_box_tensor,
         pred_score,
         _,
         attacker_feature_adv,
         trust_scores,
         defense_time,
         raw_agent_features) = attack_utils.inference_intermediate_fusion_attack(
            batch_data,
            model,
            opencood_dataset, attn_loss, opt.iter, opt.lr, criterion=loss,
            cav_max=2, defender=defender, attacker_index=attacker_agent_index,
            attack_method=attack_method
        ) #cav_max=2 to limit attacker knowledge to only victim and itself
        
        if opt.save_feature_maps and raw_agent_features is not None and feature_map_root is not None:
            upsample_factor = max(1, opt.feature_map_scale)
            frame_prefix = os.path.join(feature_map_root, f'{i:05d}')
            agent_count = raw_agent_features.shape[0]
            for agent_idx in range(agent_count):
                if agent_idx == 0:
                    agent_label = 'ego'
                elif agent_idx == attacker_agent_index:
                    agent_label = 'attacker'
                else:
                    agent_label = f'agent{agent_idx}'
                save_path = f'{frame_prefix}_{agent_label}_pre_attack.png'
                vis_utils.save_feature_response_map(
                    raw_agent_features[agent_idx],
                    save_path,
                    reduction=opt.feature_map_reduction,
                    cmap=opt.feature_map_cmap,
                    upsample_factor=upsample_factor,
                    title=f'{agent_label} pre-attack'
                )
            if attacker_feature_adv is not None and attacker_agent_index < agent_count:
                vis_utils.save_feature_response_map(
                    attacker_feature_adv,
                    f'{frame_prefix}_attacker_post_attack.png',
                    reduction=opt.feature_map_reduction,
                    cmap=opt.feature_map_cmap,
                    upsample_factor=upsample_factor,
                    title='attacker post-attack'
                )
                delta_tensor = attacker_feature_adv - raw_agent_features[attacker_agent_index]
                vis_utils.save_feature_response_map(
                    delta_tensor,
                    f'{frame_prefix}_attacker_delta.png',
                    reduction='mean',
                    cmap='seismic',
                    upsample_factor=upsample_factor,
                    title='attacker delta',
                    symmetric=True
                )
                
                # --- New: Visualize Fusion (Average) Effects ---
                # 1. Clean Fused: Mean of raw features
                clean_fused = raw_agent_features.mean(dim=0)
                vis_utils.save_feature_response_map(
                    clean_fused,
                    f'{frame_prefix}_fused_clean.png',
                    reduction=opt.feature_map_reduction,
                    cmap=opt.feature_map_cmap,
                    upsample_factor=upsample_factor,
                    title='Fused (Clean)'
                )
                
                # 2. Attacked Fused (No Defense): Mean of Ego + Attacker_Adv + Others
                # Replace attacker feature in the set with the adversarial one
                attacked_features = raw_agent_features.clone()
                attacked_features[attacker_agent_index] = attacker_feature_adv
                attacked_fused = attacked_features.mean(dim=0)
                vis_utils.save_feature_response_map(
                    attacked_fused,
                    f'{frame_prefix}_fused_attacked_no_defense.png',
                    reduction=opt.feature_map_reduction,
                    cmap=opt.feature_map_cmap,
                    upsample_factor=upsample_factor,
                    title='Fused (Attacked, No Defense)'
                )
                
                # 3. Defended Fused (Approximation using trust_scores if available)
                if trust_scores is not None and (opt.cps_defense or opt.defense):
                    try:
                        # Handle different formats of trust_scores
                        if isinstance(trust_scores, list):
                            weights = torch.tensor(trust_scores, device=attacked_features.device, dtype=torch.float32)
                        elif isinstance(trust_scores, torch.Tensor):
                            weights = trust_scores.to(attacked_features.device).float()
                        else:
                            weights = None
                        
                        if weights is not None:
                            # Reshape weights to [N, 1, 1] for broadcasting
                            if weights.dim() == 1:
                                weights = weights.view(-1, 1, 1)
                            
                            # Normalize weights to sum to 1 for visualization consistency?
                            # Or just weighted average. If sum is small, it might look dark, which is correct (suppression).
                            
                            defended_fused = (attacked_features * weights).sum(dim=0) / (weights.sum(dim=0) + 1e-6)
                            
                            vis_utils.save_feature_response_map(
                                defended_fused,
                                f'{frame_prefix}_fused_defended.png',
                                reduction=opt.feature_map_reduction,
                                cmap=opt.feature_map_cmap,
                                upsample_factor=upsample_factor,
                                title='Fused (Defended)'
                            )
                    except Exception as e:
                        print(f"Warning: Failed to visualize defended fusion: {e}")
                # -----------------------------------------------
        
        # Collect t-SNE data if requested
        if opt.save_tsne and raw_agent_features is not None:
            # Collect benign features (all agents except attacker)
            for agent_idx in range(raw_agent_features.shape[0]):
                if agent_idx != attacker_agent_index:
                    # Apply Global Average Pooling (GAP) to get a feature vector
                    feature_vector = raw_agent_features[agent_idx].mean(dim=(1, 2)).detach().cpu().numpy()  # (C,)
                    tsne_benign_features.append(feature_vector)
            
            # Collect attacked feature
            if attacker_feature_adv is not None:
                # Apply GAP to attacked feature
                attacked_vector = attacker_feature_adv.mean(dim=(1, 2)).detach().cpu().numpy()  # (C,)
                tsne_attacked_features.append(attacked_vector)
        
        # 保存攻击后、防御前的可视化（用于诊断）
        if opt.save_vis:
            # Build directory name with defense method - simplified and always include defense name
            vis_dir_name_attack = f'vis_after_attack_iter{opt.iter}_lr{opt.lr}_{opt.attack_mode}_{opt.loss}_defense_{defense}'
            
            # Add defense-specific details
            if opt.cps_defense or opt.defense:
                vis_dir_name_attack += f'_threshold_{opt.use_dynamic_threshold}'
                if opt.use_dynamic_threshold:
                    vis_dir_name_attack += f'_ua_at_method={opt.ua_at_method}'
                vis_dir_name_attack += f'_use_ssim={opt.use_ssim}'
            
            if opt.use_mdag:
                vis_dir_name_attack += '_mdag'
            
            vis_save_path_after_attack = os.path.join(opt.model_dir, vis_dir_name_attack)
            if not os.path.exists(vis_save_path_after_attack):
                os.makedirs(vis_save_path_after_attack)
            vis_save_path_after_attack = os.path.join(vis_save_path_after_attack, '%05d.png' % i)
            
            # 保存攻击后的结果 - 即使pred或gt为空也保存
            logger.info(f"[Frame {i}] Saving after attack - pred: {len(pred_box_tensor) if pred_box_tensor is not None else 0}, gt: {len(gt_box_tensor) if gt_box_tensor is not None else 0}")
            
            # 如果pred_box_tensor为None，创建一个空的tensor
            if pred_box_tensor is None:
                pred_box_tensor = torch.empty((0, 8, 3))
            if gt_box_tensor is None:
                gt_box_tensor = torch.empty((0, 8, 3))
            
            opencood_dataset.visualize_result(pred_box_tensor,
                                            gt_box_tensor,
                                            batch_data['ego']['origin_lidar'],
                                            False,  # 不显示，只保存
                                            vis_save_path_after_attack,
                                            dataset=opencood_dataset,
                                            target_tensor=None)
        
        # Capture frame for "after attack" video
        if opt.save_video and video_writer_attack is not None:
            # If save_vis is enabled, reuse the saved image instead of re-rendering
            if opt.save_vis and vis_save_path_after_attack and os.path.exists(vis_save_path_after_attack):
                # Reuse the image saved by save_vis
                frame_attack = cv2.imread(vis_save_path_after_attack)
                if frame_attack is not None:
                    if frame_attack.shape[1] != 1920 or frame_attack.shape[0] != 1080:
                        frame_attack = cv2.resize(frame_attack, (1920, 1080))
                    video_writer_attack.write(frame_attack)
                    if i % 50 == 0:
                        logger.info(f"[Frame {i}] Attack video frame written (reused save_vis), shape: {frame_attack.shape}")
            else:
                # Need to render separately for video
                temp_attack_frame = os.path.join(opt.model_dir, f'temp_attack_frame_{i}.png')
                
                pred_tensor_attack = pred_box_tensor if pred_box_tensor is not None else torch.empty((0, 8, 3))
                gt_tensor_attack = gt_box_tensor if gt_box_tensor is not None else torch.empty((0, 8, 3))
                
                opencood_dataset.visualize_result(pred_tensor_attack,
                                                gt_tensor_attack,
                                                batch_data['ego']['origin_lidar'],
                                                False,
                                                temp_attack_frame,
                                                dataset=opencood_dataset,
                                                target_tensor=None)
                
                frame_attack = cv2.imread(temp_attack_frame)
                if frame_attack is not None:
                    if frame_attack.shape[1] != 1920 or frame_attack.shape[0] != 1080:
                        frame_attack = cv2.resize(frame_attack, (1920, 1080))
                    video_writer_attack.write(frame_attack)
                    if i % 50 == 0:
                        logger.info(f"[Frame {i}] Attack video frame written, shape: {frame_attack.shape}")
                else:
                    logger.warning(f"[Frame {i}] Failed to read attack frame from: {temp_attack_frame}")
                    
                # Clean up temporary frame file
                if os.path.exists(temp_attack_frame):
                    os.remove(temp_attack_frame)
        
        # 处理防御方法（互斥的）
        if opt.defense:
            # LUCIA defense - 在 inference_intermediate_fusion_attack 中已经应用
            number_stat['trust_score'].append(str(trust_scores.tolist()))
            number_stat['defense_time'].append(defense_time)
            # LUCIA defense 没有阈值和车辆信息
            number_stat['threshold'].append(None)
            number_stat['benign_agents'].append(None)
            number_stat['malicious_agents'].append(None)
        elif opt.robosac:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            start = time.time()
            pred_box_tensor, pred_score, gt_boxes_tensor = robosac(
                batch_data, model, opencood_dataset, attacker_feature_adv,
                attacker_idx=attacker_agent_index, sampling_budget=10
            )
            torch.cuda.synchronize()
            defense_time = time.time() - start
            number_stat['defense_time'].append(defense_time)
            number_stat['trust_score'].append(None)  # ROBOSAC doesn't provide trust score
            # ROBOSAC 没有阈值和车辆信息
            number_stat['threshold'].append(None)
            number_stat['benign_agents'].append(None)
            number_stat['malicious_agents'].append(None)
        elif opt.pasac:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            start = time.time()
            if opt.pasac_simple:
                # Simple PASAC: individual agent testing
                pred_box_tensor, pred_score, gt_boxes_tensor, pasac_info = pasac_simple(
                    batch_data, model, opencood_dataset, attacker_feature_adv, 
                    attacker_idx=attacker_agent_index, consensus_threshold=opt.pasac_threshold
                )
            else:
                # Recursive PASAC: binary splitting (original algorithm from CP-Guard)
                pred_box_tensor, pred_score, gt_boxes_tensor, pasac_info = pasac(
                    batch_data, model, opencood_dataset, attacker_feature_adv, 
                    attacker_idx=attacker_agent_index, threshold=opt.pasac_threshold
                )
            torch.cuda.synchronize()
            defense_time = time.time() - start
            number_stat['defense_time'].append(defense_time)
            number_stat['trust_score'].append(str(pasac_info.get('agent_scores', {})))
            # PASAC 提供阈值和车辆信息
            number_stat['threshold'].append(pasac_info.get('threshold'))
            number_stat['benign_agents'].append(str(pasac_info.get('benign_agents', [])))
            number_stat['malicious_agents'].append(str(pasac_info.get('malicious_agents', [])))
        elif opt.feature_guard:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            start = time.time()
            pred_box_tensor, pred_score, _, feature_score = feature_guard(
                batch_data, model, opencood_dataset, attacker_feature_adv,
                attacker_idx=attacker_agent_index, sampling_budget=10
            )
            torch.cuda.synchronize()
            defense_time = time.time() - start
            number_stat['defense_time'].append(defense_time)
            number_stat['trust_score'].append(feature_score)
            # Feature Guard 没有阈值和车辆信息
            number_stat['threshold'].append(None)
            number_stat['benign_agents'].append(None)
            number_stat['malicious_agents'].append(None)
        elif opt.cps_defense:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            start = time.time()
            pred_box_tensor, pred_score, _, cps_score, defense_info = cps_defense(
                batch_data, model, opencood_dataset, attacker_feature_adv, 
                attacker_idx=attacker_agent_index, sampling_budget=10,
                lambda1=opt.cps_lambda1, lambda2=opt.cps_lambda2, lambda3=opt.cps_lambda3,
                compute_gradients=opt.cps_compute_grad,
                use_ssim=opt.use_ssim,
                use_mdag=opt.use_mdag,  # 新增
                use_dynamic_threshold=opt.use_dynamic_threshold,  # 新增
                threshold_sensitivity=opt.threshold_sensitivity,  # 新增
                threshold_calculator=threshold_calculator  # 新增
            )
            torch.cuda.synchronize()
            defense_time = time.time() - start
            number_stat['defense_time'].append(defense_time)
            number_stat['trust_score'].append(cps_score)
            # 保存阈值和车辆信息
            number_stat['threshold'].append(defense_info['threshold'])
            number_stat['benign_agents'].append(str(defense_info['benign_agents']))
            number_stat['malicious_agents'].append(str(defense_info['malicious_agents']))
            # 保存均值和标准差（mu和sigma在defense_info中，如果使用动态阈值则有值，否则为None）
            number_stat['mu'].append(defense_info.get('mu'))
            number_stat['sigma'].append(defense_info.get('sigma'))
        else:
            # 没有使用任何防御
            if 'threshold' in number_stat:
                number_stat['threshold'].append(None)
                number_stat['benign_agents'].append(None)
                number_stat['malicious_agents'].append(None)

        eval_utils.caluclate_tp_fp(pred_box_tensor,
                                    pred_score,
                                    gt_box_tensor,
                                    result_stat,
                                    0.3)
        tp, fp, gt = eval_utils.caluclate_tp_fp(pred_box_tensor,
                                    pred_score,
                                    gt_box_tensor,
                                    result_stat,
                                    0.5)
        eval_utils.caluclate_tp_fp(pred_box_tensor,
                                    pred_score,
                                    gt_box_tensor,
                                    result_stat,
                                    0.7)
        pred_num = 0 if pred_box_tensor is None else len(pred_box_tensor)
        gt_num = 0 if gt_box_tensor is None else len(gt_box_tensor)

        number_stat['pred'].append(int(pred_num))
        number_stat['gt'].append(int(gt_num))
        number_stat['tp'].append(int(tp))
        number_stat['fp'].append(int(fp))
        
        # 记录车辆数量
        if 'cav_num' not in number_stat:
            number_stat['cav_num'] = []
        number_stat['cav_num'].append(int(cav_num))

        if targeted:
            num_detected = eval_utils.caluclate_tp_fp(pred_box_tensor, pred_score, target_bbox_tensor, result_stat, 0, write=False)
            number_stat['target_detected'].append(num_detected)
            if num_detected < 1:
                remove_success += 1
            progress_bar.set_postfix({'TargetDetected': int(num_detected), 'FP': int(fp)})
        else:
            denom = gt_num if gt_num > 0 else 1
            orr = 1 - min(pred_num, gt_num) / denom
            orr_total.append(orr)
            progress_bar.set_postfix({'ORR': f'{orr:.3f}'})

        if opt.save_npy:
            npy_save_path = os.path.join(
                opt.model_dir,
                f'npy_iter{opt.iter}_lr{opt.lr}_{opt.attack_mode}_{defense}_{opt.loss}_use_ssim={opt.use_ssim}'
            )
            if not os.path.exists(npy_save_path):
                os.makedirs(npy_save_path)
            inference_utils.save_prediction_gt(pred_box_tensor,
                                                gt_box_tensor,
                                                batch_data['ego'][
                                                    'origin_lidar'][0],
                                                i,
                                                npy_save_path,
                                                pred_score)
        if opt.save_perturb:
            pertub_save_path = os.path.join(
                opt.model_dir,
                f'perturb_iter{opt.iter}_lr{opt.lr}_{opt.attack_mode}_{defense}_{opt.loss}_use_ssim={opt.use_ssim}'
            )
            if not os.path.exists(pertub_save_path):
                os.makedirs(pertub_save_path)
            if attacker_feature_adv is not None:
                perturbation_np = torch_tensor_to_numpy(attacker_feature_adv)
                np.save(os.path.join(pertub_save_path, '%04d_perturb.npy' % i), perturbation_np)

        if opt.show_vis or opt.save_vis:
            vis_save_path = ''
            if opt.save_vis:
                # Build directory name with defense method - simplified and always include defense name
                vis_dir_name = f'vis_after_defense_iter{opt.iter}_lr{opt.lr}_{opt.attack_mode}_{opt.loss}_defense_{defense}'
                
                # Add defense-specific details
                if opt.cps_defense or opt.defense:
                    vis_dir_name += f'_threshold_{opt.use_dynamic_threshold}'
                    if opt.use_dynamic_threshold:
                        vis_dir_name += f'_ua_at_method={opt.ua_at_method}'
                    vis_dir_name += f'_use_ssim={opt.use_ssim}'
                
                if opt.use_mdag:
                    vis_dir_name += '_mdag'
                
                vis_save_path = os.path.join(opt.model_dir, vis_dir_name)
                if not os.path.exists(vis_save_path):
                    os.makedirs(vis_save_path)
                vis_save_path = os.path.join(vis_save_path, '%05d.png' % i)

            # 添加调试信息
            pred_shape = pred_box_tensor.shape if pred_box_tensor is not None else None
            gt_shape = gt_box_tensor.shape if gt_box_tensor is not None else None
            logger.info(f"[Frame {i}] After defense - pred_shape: {pred_shape}, gt_shape: {gt_shape}, pred_num: {pred_num}, gt_num: {gt_num}")
            
            if pred_box_tensor is not None and gt_box_tensor is not None:
                # 检查是否为空
                if len(pred_box_tensor) == 0:
                    logger.info(f"[Frame {i}] pred_box_tensor is empty! No detections to visualize.")
                if len(gt_box_tensor) == 0:
                    logger.info(f"[Frame {i}] gt_box_tensor is empty! No ground truth to visualize.")
                
                opencood_dataset.visualize_result(pred_box_tensor,
                                                gt_box_tensor,
                                                batch_data['ego'][
                                                    'origin_lidar'],
                                                opt.show_vis,
                                                vis_save_path,
                                                dataset=opencood_dataset,
                                                target_tensor=None)
            else:
                logger.info(f"[Frame {i}] Skipping visualization - pred_box_tensor or gt_box_tensor is None")
        # Real-time visualization with show_sequence
        if opt.show_sequence:
            pcd, pred_o3d_box, gt_o3d_box = \
                vis_utils.visualize_inference_sample_dataloader(
                    pred_box_tensor,
                    gt_box_tensor,
                    batch_data['ego']['origin_lidar'],
                    vis_pcd,
                    mode='constant'
                    )
            if i == 0:
                vis.add_geometry(pcd)
                if pred_o3d_box is not None:
                    vis_utils.linset_assign_list(vis,
                                                vis_aabbs_pred,
                                                pred_o3d_box,
                                                update_mode='add')

                vis_utils.linset_assign_list(vis,
                                                vis_aabbs_gt,
                                                gt_o3d_box,
                                                update_mode='add')
                
                # Set camera view to fill the frame on first iteration
                vis.poll_events()
                vis.update_renderer()
                
                # Calculate center of all geometries
                all_points = []
                pcd_points = np.asarray(pcd.points)
                if len(pcd_points) > 0:
                    all_points.append(pcd_points)
                
                # Get points from all bounding boxes
                if pred_o3d_box is not None:
                    for box in pred_o3d_box:
                        box_points = np.asarray(box.points)
                        if len(box_points) > 0:
                            all_points.append(box_points)
                
                for box in gt_o3d_box:
                    box_points = np.asarray(box.points)
                    if len(box_points) > 0:
                        all_points.append(box_points)
                
                # Calculate center
                if len(all_points) > 0:
                    all_points_combined = np.vstack(all_points)
                    center = np.mean(all_points_combined, axis=0)
                else:
                    center = np.array([0, 0, 0])
                
                ctr = vis.get_view_control()
                # Set a bird's eye view with slight angle for better depth perception
                ctr.set_zoom(0.2)  # Much larger zoom (smaller value = more zoom in)
                ctr.set_lookat(center.tolist())  # Look at center of all geometries
                ctr.set_front([0.3, 0.3, -0.9])  # View from above with slight angle for depth
                ctr.set_up([0, 1, 0])  # Y-axis up

            if pred_o3d_box is not None:
                vis_utils.linset_assign_list(vis,
                                            vis_aabbs_pred,
                                            pred_o3d_box)
            vis_utils.linset_assign_list(vis,
                                            vis_aabbs_gt,
                                            gt_o3d_box)
            vis.update_geometry(pcd)
            vis.poll_events()
            vis.update_renderer()
            
            # Periodic memory cleanup to prevent segfault
            if i % 50 == 0:
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            
            time.sleep(0.001)
        
        # Save defense video
        if opt.save_video and video_writer_defense is not None:
            # If save_vis is enabled and image exists, reuse it instead of re-rendering
            if opt.save_vis and vis_save_path and os.path.exists(vis_save_path):
                # Reuse the image saved by save_vis
                frame_defense = cv2.imread(vis_save_path)
                if frame_defense is not None:
                    if frame_defense.shape[1] != 1920 or frame_defense.shape[0] != 1080:
                        frame_defense = cv2.resize(frame_defense, (1920, 1080))
                    video_writer_defense.write(frame_defense)
                    if i % 50 == 0:
                        logger.info(f"[Frame {i}] Defense video frame written (reused save_vis), shape: {frame_defense.shape}")
            else:
                # Need to render separately for video
                temp_defense_frame = os.path.join(opt.model_dir, f'temp_defense_frame_{i}.png')
                
                pred_tensor_defense = pred_box_tensor if pred_box_tensor is not None else torch.empty((0, 8, 3))
                gt_tensor_defense = gt_box_tensor if gt_box_tensor is not None else torch.empty((0, 8, 3))
                
                opencood_dataset.visualize_result(pred_tensor_defense,
                                                gt_tensor_defense,
                                                batch_data['ego']['origin_lidar'],
                                                False,
                                                temp_defense_frame,
                                                dataset=opencood_dataset,
                                                target_tensor=None)
                
                frame_defense = cv2.imread(temp_defense_frame)
                if frame_defense is not None:
                    if frame_defense.shape[1] != 1920 or frame_defense.shape[0] != 1080:
                        frame_defense = cv2.resize(frame_defense, (1920, 1080))
                    video_writer_defense.write(frame_defense)
                    if i % 50 == 0:
                        logger.info(f"[Frame {i}] Defense video frame written, shape: {frame_defense.shape}")
                else:
                    logger.warning(f"[Frame {i}] Failed to read defense frame from: {temp_defense_frame}")
                
                # Clean up temporary frame file
                if os.path.exists(temp_defense_frame):
                    os.remove(temp_defense_frame)
            
            # Force garbage collection periodically
            if i % 50 == 0:
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            
        
    # Construct evaluation filename suffix
    eval_suffix_parts = [opt.attack_mode]
    
    if opt.attack_mode == 'tor':
        eval_suffix_parts.append(f"_{opt.target_id}")
        
    # Add PGD/CW specific parameters
    if opt.loss == 'pgd':
        eval_suffix_parts.append(f"_eps{opt.pgd_epsilon}")
    elif opt.loss == 'cw':
        eval_suffix_parts.append(f"_conf{opt.cw_confidence}_c{opt.cw_c}")
        
    eval_suffix_parts.append(f"_defense_{defense}")
    
    # Add defense-specific parameters
    if opt.cps_defense or opt.defense:
        # CPS Defense and LUCIA use threshold, ua_at_method, use_ssim
        eval_suffix_parts.append(f"_threshold_{opt.use_dynamic_threshold}")
        eval_suffix_parts.append(f",ua_at_method={opt.ua_at_method},use_ssim={opt.use_ssim}")
    elif opt.pasac:
        # PASAC uses pasac_threshold and optionally pasac_simple flag
        eval_suffix_parts.append(f"_threshold_{opt.pasac_threshold}")
        if opt.pasac_simple:
            eval_suffix_parts.append(",simple=True")
    elif opt.robosac:
        # ROBOSAC doesn't have specific parameters in the current implementation
        pass
    elif opt.feature_guard:
        # Feature Guard doesn't have specific parameters in the current implementation
        pass
    
    eval_suffix = "".join(eval_suffix_parts)

    eval_utils.eval_final_results(
        result_stat,
        opt.model_dir,
        opt.loss,
        eval_suffix,
        iter=opt.iter,
        lr=opt.lr
    )

    # Ensure all lists have the same length before creating DataFrame
    base_length = len(number_stat['pred'])
    for key in number_stat:
        if len(number_stat[key]) != base_length:
            logger.info(f"Warning: {key} has length {len(number_stat[key])}, expected {base_length}. Padding with None.")
            # Pad with None to match base_length
            while len(number_stat[key]) < base_length:
                number_stat[key].append(None)
    
    stat_pd = pd.DataFrame(number_stat)
    total_rows = stat_pd.shape[0]

    # ------------------------------------------------------------------
    # 计算防御时间的平均值，并单独保存（区分是否使用SSIM）
    # ------------------------------------------------------------------
    if 'defense_time' in number_stat:
        valid_def_times = [t for t in number_stat['defense_time'] if t is not None]
        avg_defense_time = float(np.mean(valid_def_times)) if len(valid_def_times) > 0 else None

        # 保存到单独的文件，文件名包含 use_ssim 标记
        # 基础部分：defense_time_summary_{attack_mode}
        avg_time_filename_parts = [
            f"defense_time_summary_{opt.attack_mode}"
        ]
        
        # 如果是 TOR 模式，必须加入 target_id (random/in/out) 以防覆盖
        if opt.attack_mode == 'tor':
            avg_time_filename_parts.append(f"_visibility_{opt.target_id}")
            
        avg_time_filename_parts.append(f"_defense_{defense}")
        
        # Add PGD/CW specific parameters to filename
        if opt.loss == 'pgd':
            avg_time_filename_parts.append(f'_eps{opt.pgd_epsilon}')
        elif opt.loss == 'cw':
            avg_time_filename_parts.append(f'_conf{opt.cw_confidence}_c{opt.cw_c}')
        
        # Add defense-specific parameters
        if opt.cps_defense or opt.defense:
            avg_time_filename_parts.append(f"_threshold_{opt.use_dynamic_threshold}")
            avg_time_filename_parts.append(f"_ua_at_method={opt.ua_at_method}_use_ssim={opt.use_ssim}.txt")
        elif opt.pasac:
            avg_time_filename_parts.append(f"_threshold_{opt.pasac_threshold}")
            if opt.pasac_simple:
                avg_time_filename_parts.append("_simple.txt")
            else:
                avg_time_filename_parts.append(".txt")
        else:
            avg_time_filename_parts.append(".txt")
        
        avg_time_filename = os.path.join(opt.model_dir, "".join(avg_time_filename_parts))
        with open(avg_time_filename, "w") as f:
            # Write defense-specific parameters
            if opt.cps_defense or opt.defense:
                f.write(f"use_ssim: {opt.use_ssim}\n")
            elif opt.pasac:
                f.write(f"pasac_threshold: {opt.pasac_threshold}\n")
                if opt.pasac_simple:
                    f.write(f"pasac_simple: True\n")
            f.write(f"frames: {len(valid_def_times)}\n")
            f.write(f"avg_defense_time_per_frame: {avg_defense_time}\n")
        logger.info(f"Saved average defense time to {avg_time_filename}")
    if not targeted:
        result_name_parts = [
            f'{opt.attack_mode}_result_iter{opt.iter}_lr{opt.lr}_loss{opt.loss}'
        ]
        
        # Add PGD/CW specific parameters to filename
        if opt.loss == 'pgd':
            result_name_parts.append(f'eps{opt.pgd_epsilon}')
        elif opt.loss == 'cw':
            result_name_parts.append(f'conf{opt.cw_confidence}_c{opt.cw_c}')
            
        result_name_parts.append(f'defense_{defense}')
        
        # Add defense-specific parameters
        if opt.cps_defense or opt.defense:
            result_name_parts.append(f'threshold_{opt.use_dynamic_threshold}')
            result_name_parts.append(f'ua_at_method={opt.ua_at_method}_use_ssim={opt.use_ssim}.txt')
        elif opt.pasac:
            result_name_parts.append(f'threshold_{opt.pasac_threshold}')
            if opt.pasac_simple:
                result_name_parts.append('simple.txt')
            else:
                result_name_parts.append('.txt')
        else:
            result_name_parts.append('.txt')
        
        result_name = os.path.join(opt.model_dir, '_'.join(result_name_parts))
        
        f = open(result_name, "w")
        with pd.option_context('mode.use_inf_as_na', True):
            asr_0 = (stat_pd['pred'] <= 0).sum() / total_rows
            asr_1 = (stat_pd['pred'] <= 1).sum() / total_rows
            orr_final = sum(orr_total) / len(orr_total)
            result = f'ASR_0: {asr_0}, ASR_1: {asr_1}, ORR: {orr_final}'
            logger.info(result)
            f.write(result)
            f.close()
            
            # CSV filename construction
            csv_name_parts = [
                f'det_result_iter{opt.iter}_lr{opt.lr}_{opt.attack_mode}_{opt.loss}'
            ]
            if opt.loss == 'pgd':
                csv_name_parts.append(f'eps{opt.pgd_epsilon}')
            elif opt.loss == 'cw':
                csv_name_parts.append(f'conf{opt.cw_confidence}_c{opt.cw_c}')
                
            csv_name_parts.append(f'defense_{defense}')
            
            # Add defense-specific parameters
            if opt.cps_defense or opt.defense:
                csv_name_parts.append(f'threshold_{opt.use_dynamic_threshold}')
                csv_name_parts.append(f'ua_at_method={opt.ua_at_method},use_ssim={opt.use_ssim}.csv')
            elif opt.pasac:
                csv_name_parts.append(f'threshold_{opt.pasac_threshold}')
                if opt.pasac_simple:
                    csv_name_parts.append('simple.csv')
                else:
                    csv_name_parts.append('.csv')
            else:
                csv_name_parts.append('.csv')
            
            df_filename = os.path.join(opt.model_dir, '_'.join(csv_name_parts))

    else:
        result_name_parts = [
            f'{opt.attack_mode}_result_visibility_{opt.target_id}_iter{opt.iter}_lr{opt.lr}_loss{opt.loss}'
        ]
        
        # Add PGD/CW specific parameters to filename
        if opt.loss == 'pgd':
            result_name_parts.append(f'eps{opt.pgd_epsilon}')
        elif opt.loss == 'cw':
            result_name_parts.append(f'conf{opt.cw_confidence}_c{opt.cw_c}')
            
        result_name_parts.append(f'defense_{defense}')
        
        # Add defense-specific parameters
        if opt.cps_defense or opt.defense:
            result_name_parts.append(f'threshold_{opt.use_dynamic_threshold}')
            result_name_parts.append(f'ua_at_method={opt.ua_at_method},use_ssim={opt.use_ssim}.txt')
        elif opt.pasac:
            result_name_parts.append(f'threshold_{opt.pasac_threshold}')
            if opt.pasac_simple:
                result_name_parts.append('simple.txt')
            else:
                result_name_parts.append('.txt')
        else:
            result_name_parts.append('.txt')
        
        result_name = os.path.join(opt.model_dir, '_'.join(result_name_parts))
        
        f = open(result_name, "w")
        fp_threshold = 2 # We require the attack to not introduce above average FP (2) after attack to minimize suspicion
        asr = ((stat_pd['target_detected'] < 1) & (stat_pd['fp'] < fp_threshold)).sum() / total_rows
        logger.info(f'ASR: {asr}')
        f.write(f'ASR: {asr}')
        f.close()
        
        # CSV filename construction
        csv_name_parts = [
            f'det_result_iter{opt.iter}_lr{opt.lr}_{opt.attack_mode}_visibility_{opt.target_id}_{opt.loss}'
        ]
        if opt.loss == 'pgd':
            csv_name_parts.append(f'eps{opt.pgd_epsilon}')
        elif opt.loss == 'cw':
            csv_name_parts.append(f'conf{opt.cw_confidence}_c{opt.cw_c}')
            
        csv_name_parts.append(f'defense_{defense}')
        
        # Add defense-specific parameters
        if opt.cps_defense or opt.defense:
            csv_name_parts.append(f'threshold_{opt.use_dynamic_threshold}')
            csv_name_parts.append(f'ua_at_method={opt.ua_at_method},use_ssim={opt.use_ssim}.csv')
        elif opt.pasac:
            csv_name_parts.append(f'threshold_{opt.pasac_threshold}')
            if opt.pasac_simple:
                csv_name_parts.append('simple.csv')
            else:
                csv_name_parts.append('.csv')
        else:
            csv_name_parts.append('.csv')
        
        df_filename = os.path.join(opt.model_dir, '_'.join(csv_name_parts))

    
    stat_pd.to_csv(df_filename, index=False)

    logger.info(f"Results saved to {result_name} and {df_filename}")

    # Generate t-SNE visualization if requested
    if opt.save_tsne:
        if len(tsne_benign_features) > 0 and len(tsne_attacked_features) > 0:
            # Convert lists to numpy arrays
            benign_array = np.array(tsne_benign_features)
            attacked_array = np.array(tsne_attacked_features)
            
            # Create t-SNE output directory
            tsne_dir_parts = [
                f'tsne_vis_iter{opt.iter}',
                f'lr{opt.lr}',
                opt.attack_mode,
                opt.loss,
                f'defense_{defense}'
            ]
            if opt.cps_defense or opt.defense:
                tsne_dir_parts.append(f'threshold_{opt.use_dynamic_threshold}')
                if opt.use_dynamic_threshold:
                    tsne_dir_parts.append(f'ua_{opt.ua_at_method}')
                tsne_dir_parts.append(f'ssim_{opt.use_ssim}')
            if opt.use_mdag:
                tsne_dir_parts.append('mdag')
            
            tsne_dir_name = '_'.join(tsne_dir_parts)
            tsne_save_dir = os.path.join(opt.model_dir, tsne_dir_name)
            if not os.path.exists(tsne_save_dir):
                os.makedirs(tsne_save_dir)
            
            tsne_save_path = os.path.join(tsne_save_dir, 'tsne_feature_separation.png')
            
            logger.info(f"[t-SNE] Generating visualization with {len(benign_array)} benign and {len(attacked_array)} attacked features...")
            
            # Call the standalone t-SNE plotting function
            vis_utils.plot_tsne_separation(
                benign_features=benign_array,
                attacked_features=attacked_array,
                save_path=tsne_save_path
            )
            
            logger.info(f"✓ t-SNE visualization saved to: {tsne_save_path}")
        else:
            logger.warning("[t-SNE] Insufficient data collected. Skipping t-SNE visualization.")
            logger.warning(f"  Benign samples: {len(tsne_benign_features)}, Attacked samples: {len(tsne_attacked_features)}")

    if opt.show_sequence:
        vis.destroy_window()
    
    # Release video writers and clean up
    if opt.save_video:
        if video_writer_attack is not None:
            video_writer_attack.release()
            if video_filename_attack:
                logger.info(f"✓ Video (after attack) saved to: {video_filename_attack}")
        if video_writer_defense is not None:
            video_writer_defense.release()
            if video_filename_defense:
                logger.info(f"✓ Video (after defense) saved to: {video_filename_defense}")
        
        # Clean up temporary frame file
        if temp_frame_path and os.path.exists(temp_frame_path):
            os.remove(temp_frame_path)

    # Clean up global visualizer to release OpenGL resources
    if opt.save_vis or opt.save_video:
        # import opencood.visualization.vis_utils as vis_utils  # REMOVED: Causes UnboundLocalError
        if hasattr(vis_utils, '_global_vis') and vis_utils._global_vis is not None:
            try:
                vis_utils._global_vis.destroy_window()
                vis_utils._global_vis = None
                vis_utils._global_vis_geoms = []
            except:
                pass
        import gc
        gc.collect()
        logger.info("Global visualizer cleaned up successfully!")

if __name__ == "__main__":
    # Run the evaluation pipeline only when executed as a script,
    # not when this module is imported by another one.
    main()


        
