"""
SSIM功能使用示例

展示如何在CPS防御中使用梯度一致性或SSIM方法
"""

from defense.cps_defense import cps_defense

# ============================================================================
# 示例1：使用梯度一致性方法（原方法，默认）
# ============================================================================
def example_gradient_consistency(batch_data, model, dataset, perturbation):
    """
    使用梯度一致性方法进行CPS防御
    
    特点：
    - 计算时间较长
    - GPU内存占用较高
    - 考虑特征对模型输出的因果影响
    """
    print("=" * 80)
    print("示例1：使用【梯度一致性方法】")
    print("=" * 80)
    
    pred_box, pred_score, gt_box, cps_score, defense_info = cps_defense(
        batch_data=batch_data,
        model=model,
        dataset=dataset,
        perturbation=perturbation,
        attacker_idx=1,
        sampling_budget=10,
        lambda1=1.0,
        lambda2=1.0,  # 梯度一致性权重
        lambda3=1.0,
        tau=0.68,
        compute_gradients=True,   # ✅ 开启一致性计算
        use_ssim=False,           # ✅ 使用梯度一致性（默认）
        use_mdag=True,
        use_dynamic_threshold=False
    )
    
    print(f"CPS分数: {cps_score:.4f}")
    print(f"良性车辆: {defense_info['benign_agents']}")
    print(f"恶意车辆: {defense_info['malicious_agents']}")
    print(f"使用阈值: {defense_info['threshold']:.4f}")
    
    return pred_box, pred_score, gt_box, cps_score, defense_info


# ============================================================================
# 示例2：使用SSIM方法（新方法）
# ============================================================================
def example_ssim(batch_data, model, dataset, perturbation):
    """
    使用SSIM方法进行CPS防御
    
    特点：
    - 计算速度快
    - GPU内存占用低
    - 专注于结构相似性
    """
    print("=" * 80)
    print("示例2：使用【SSIM方法】")
    print("=" * 80)
    
    pred_box, pred_score, gt_box, cps_score, defense_info = cps_defense(
        batch_data=batch_data,
        model=model,
        dataset=dataset,
        perturbation=perturbation,
        attacker_idx=1,
        sampling_budget=10,
        lambda1=1.0,
        lambda2=1.0,  # SSIM权重
        lambda3=1.0,
        tau=0.68,
        compute_gradients=True,   # ✅ 开启一致性计算
        use_ssim=True,            # ✅ 使用SSIM方法
        use_mdag=True,
        use_dynamic_threshold=False
    )
    
    print(f"CPS分数: {cps_score:.4f}")
    print(f"良性车辆: {defense_info['benign_agents']}")
    print(f"恶意车辆: {defense_info['malicious_agents']}")
    print(f"使用阈值: {defense_info['threshold']:.4f}")
    
    return pred_box, pred_score, gt_box, cps_score, defense_info


# ============================================================================
# 示例3：配合动态阈值使用SSIM（推荐）
# ============================================================================
def example_ssim_with_dynamic_threshold(batch_data, model, dataset, perturbation, 
                                       threshold_calculator):
    """
    使用SSIM + 动态阈值进行CPS防御（推荐配置）
    
    特点：
    - 计算速度快
    - 自适应阈值，提高鲁棒性
    - 适合实时场景
    """
    print("=" * 80)
    print("示例3：使用【SSIM方法 + 动态阈值】（推荐）")
    print("=" * 80)
    
    pred_box, pred_score, gt_box, cps_score, defense_info = cps_defense(
        batch_data=batch_data,
        model=model,
        dataset=dataset,
        perturbation=perturbation,
        attacker_idx=1,
        sampling_budget=10,
        lambda1=1.0,
        lambda2=1.0,
        lambda3=1.0,
        tau=0.68,  # 基础阈值
        compute_gradients=True,        # ✅ 开启一致性计算
        use_ssim=True,                 # ✅ 使用SSIM方法
        use_mdag=True,
        use_dynamic_threshold=True,    # ✅ 开启动态阈值
        threshold_calculator=threshold_calculator  # 传入阈值计算器
    )
    
    print(f"CPS分数: {cps_score:.4f}")
    print(f"良性车辆: {defense_info['benign_agents']}")
    print(f"恶意车辆: {defense_info['malicious_agents']}")
    print(f"使用阈值: {defense_info['threshold']:.4f}")
    if defense_info.get('mu') is not None:
        print(f"均值 μ: {defense_info['mu']:.4f}")
        print(f"标准差 σ: {defense_info['sigma']:.4f}")
    
    return pred_box, pred_score, gt_box, cps_score, defense_info


# ============================================================================
# 示例4：关闭一致性计算（最快，但精度较低）
# ============================================================================
def example_no_consistency(batch_data, model, dataset, perturbation):
    """
    不使用一致性计算（梯度或SSIM都不计算）
    
    特点：
    - 最快速度
    - 最低GPU内存
    - 仅使用特征相似度和能量偏移
    """
    print("=" * 80)
    print("示例4：不使用一致性计算（最快）")
    print("=" * 80)
    
    pred_box, pred_score, gt_box, cps_score, defense_info = cps_defense(
        batch_data=batch_data,
        model=model,
        dataset=dataset,
        perturbation=perturbation,
        attacker_idx=1,
        sampling_budget=10,
        lambda1=1.0,
        lambda2=0.0,  # ⚠️ lambda2设为0，因为不计算一致性
        lambda3=1.0,
        tau=0.68,
        compute_gradients=False,  # ❌ 关闭一致性计算
        # use_ssim参数此时无效
        use_mdag=True,
        use_dynamic_threshold=False
    )
    
    print(f"CPS分数: {cps_score:.4f}")
    print(f"良性车辆: {defense_info['benign_agents']}")
    print(f"恶意车辆: {defense_info['malicious_agents']}")
    print(f"使用阈值: {defense_info['threshold']:.4f}")
    
    return pred_box, pred_score, gt_box, cps_score, defense_info


# ============================================================================
# 主函数：对比两种方法
# ============================================================================
def main():
    """
    主函数：对比梯度一致性和SSIM两种方法
    """
    import torch
    from defense.mdag_grouping import compute_dynamic_threshold
    
    # 假设你已经有了这些对象（实际使用时需要根据你的代码进行初始化）
    # batch_data = ...  # 你的批次数据
    # model = ...       # 你的融合模型
    # dataset = ...     # 你的数据集
    # perturbation = ... # 对抗扰动
    
    print("\n" + "=" * 80)
    print("CPS防御：梯度一致性 vs SSIM 方法对比")
    print("=" * 80 + "\n")
    
    # 场景1：GPU资源充足，对精度要求高
    print("\n【场景1】GPU资源充足，对精度要求高")
    print("推荐：使用梯度一致性方法")
    # results_grad = example_gradient_consistency(batch_data, model, dataset, perturbation)
    
    # 场景2：GPU资源有限，需要快速推理
    print("\n【场景2】GPU资源有限，需要快速推理")
    print("推荐：使用SSIM方法")
    # results_ssim = example_ssim(batch_data, model, dataset, perturbation)
    
    # 场景3：实时应用，需要自适应阈值
    print("\n【场景3】实时应用，需要自适应阈值")
    print("推荐：SSIM + 动态阈值")
    # threshold_calc = compute_dynamic_threshold(...)
    # results_ssim_dynamic = example_ssim_with_dynamic_threshold(
    #     batch_data, model, dataset, perturbation, threshold_calc
    # )
    
    # 场景4：极端速度需求
    print("\n【场景4】极端速度需求")
    print("推荐：关闭一致性计算")
    # results_fast = example_no_consistency(batch_data, model, dataset, perturbation)
    
    print("\n" + "=" * 80)
    print("对比总结")
    print("=" * 80)
    print("| 方法              | 速度 | GPU内存 | 精度  | 适用场景           |")
    print("|------------------|------|---------|-------|--------------------|")
    print("| 梯度一致性        | ★☆☆  | ★★★     | ★★★   | 高精度需求         |")
    print("| SSIM             | ★★★  | ★☆☆     | ★★☆   | 快速推理           |")
    print("| SSIM+动态阈值    | ★★★  | ★☆☆     | ★★★   | 实时应用（推荐）   |")
    print("| 无一致性计算      | ★★★  | ☆☆☆     | ★☆☆   | 极端速度需求       |")
    print("=" * 80)
    

# ============================================================================
# 在命令行脚本中使用的示例
# ============================================================================
def parse_args_example():
    """
    命令行参数解析示例
    """
    import argparse
    
    parser = argparse.ArgumentParser(description='CPS Defense with SSIM support')
    
    # ... 其他参数 ...
    
    # 添加SSIM相关参数
    parser.add_argument('--use_ssim', 
                       action='store_true',
                       default=False,
                       help='使用SSIM代替梯度一致性（更快，内存占用更低）')
    
    parser.add_argument('--compute_gradients',
                       action='store_true', 
                       default=False,
                       help='是否计算特征一致性（梯度或SSIM）')
    
    args = parser.parse_args()
    
    # 在代码中使用
    # results = cps_defense(
    #     ...,
    #     compute_gradients=args.compute_gradients,
    #     use_ssim=args.use_ssim,
    #     ...
    # )
    
    return args


if __name__ == '__main__':
    # 运行主函数查看对比
    main()
    
    # 命令行使用示例：
    print("\n" + "=" * 80)
    print("命令行使用示例")
    print("=" * 80)
    print("\n# 使用梯度一致性方法：")
    print("python your_script.py --compute_gradients")
    print("\n# 使用SSIM方法：")
    print("python your_script.py --compute_gradients --use_ssim")
    print("\n# 不使用一致性计算：")
    print("python your_script.py")
    print("=" * 80)

