"""
Example script demonstrating how to use different defense methods:
1. Original RoboSAC sampling
2. CPS defense with RoboSAC sampling
3. CPS defense with MDAG grouping and a static threshold
4. CPS defense with MDAG grouping and a dynamic adaptive threshold
"""

import torch
from defense.cps_defense import cps_defense
from defense.robosac import robosac


def example_robosac_defense(batch_data, model, dataset, perturbation, attacker_idx=1):
    """
    Example 1: run the original RoboSAC defense on a single batch.

    Prints a banner, forwards the arguments to ``robosac`` with a fixed
    ``sampling_budget`` of 10, then prints the shape of each returned
    prediction value (or ``None`` when that value is ``None``).

    Args:
        batch_data: Forwarded unchanged to ``robosac``.
        model: Forwarded unchanged to ``robosac``.
        dataset: Forwarded unchanged to ``robosac``.
        perturbation: Forwarded unchanged to ``robosac``.
        attacker_idx: Forwarded unchanged to ``robosac`` (default 1).

    Returns:
        The ``(pred_box_tensor, pred_score, gt_box_tensor)`` triple exactly
        as produced by ``robosac``.
    """
    rule = "=" * 60
    print(rule)
    print("Example 1: RoboSAC Defense (Original)")
    print(rule)

    # Collect the call options first so the invocation itself stays short.
    robosac_options = dict(
        batch_data=batch_data,
        model=model,
        dataset=dataset,
        perturbation=perturbation,
        attacker_idx=attacker_idx,
        sampling_budget=10,
    )
    boxes, scores, gt_boxes = robosac(**robosac_options)

    # Either prediction may be None, so guard before reading ``.shape``.
    box_shape = None if boxes is None else boxes.shape
    print(f"Prediction boxes: {box_shape}")
    score_shape = None if scores is None else scores.shape
    print(f"Prediction scores: {score_shape}")
    print()

    return boxes, scores, gt_boxes


def example_cps_defense_robosac(batch_data, model, dataset, perturbation, attacker_idx=1):
    """
    Example 2: run ``cps_defense`` with RoboSAC sampling and a static threshold.

    Prints a banner, calls ``cps_defense`` with ``use_mdag=False`` and
    ``use_dynamic_threshold=False``, then prints the returned CPS score and
    the shape of each prediction value (or ``None`` when that value is
    ``None``).

    Args:
        batch_data: Forwarded unchanged to ``cps_defense``.
        model: Forwarded unchanged to ``cps_defense``.
        dataset: Forwarded unchanged to ``cps_defense``.
        perturbation: Forwarded unchanged to ``cps_defense``.
        attacker_idx: Forwarded unchanged to ``cps_defense`` (default 1).

    Returns:
        The ``(pred_box_tensor, pred_score, gt_box_tensor, cps_score)``
        4-tuple exactly as produced by ``cps_defense``.
    """
    rule = "=" * 60
    print(rule)
    print("Example 2: CPS Defense with RoboSAC Sampling")
    print(rule)

    cps_options = dict(
        batch_data=batch_data,
        model=model,
        dataset=dataset,
        perturbation=perturbation,
        attacker_idx=attacker_idx,
        sampling_budget=10,
        lambda1=1.0,
        lambda2=1.0,
        lambda3=1.0,
        tau=0.5,  # static threshold
        compute_gradients=False,
        use_mdag=False,  # RoboSAC sampling rather than MDAG grouping
        use_dynamic_threshold=False,
    )
    boxes, scores, gt_boxes, best_cps = cps_defense(**cps_options)

    print(f"Best CPS score: {best_cps:.4f}")
    # Either prediction may be None, so guard before reading ``.shape``.
    box_shape = None if boxes is None else boxes.shape
    print(f"Prediction boxes: {box_shape}")
    score_shape = None if scores is None else scores.shape
    print(f"Prediction scores: {score_shape}")
    print()

    return boxes, scores, gt_boxes, best_cps


def example_cps_defense_mdag_static(batch_data, model, dataset, perturbation, attacker_idx=1):
    """
    Example 3: run ``cps_defense`` with MDAG grouping and a static threshold.

    Prints a banner, calls ``cps_defense`` with ``use_mdag=True`` and
    ``use_dynamic_threshold=False``, then prints the returned CPS score and
    the shape of each prediction value (or ``None`` when that value is
    ``None``).

    Args:
        batch_data: Forwarded unchanged to ``cps_defense``.
        model: Forwarded unchanged to ``cps_defense``.
        dataset: Forwarded unchanged to ``cps_defense``.
        perturbation: Forwarded unchanged to ``cps_defense``.
        attacker_idx: Forwarded unchanged to ``cps_defense`` (default 1).

    Returns:
        The ``(pred_box_tensor, pred_score, gt_box_tensor, cps_score)``
        4-tuple exactly as produced by ``cps_defense``.
    """
    rule = "=" * 60
    print(rule)
    print("Example 3: CPS Defense with MDAG Grouping (Static Threshold)")
    print(rule)

    cps_options = dict(
        batch_data=batch_data,
        model=model,
        dataset=dataset,
        perturbation=perturbation,
        attacker_idx=attacker_idx,
        # NOTE(review): the original comment says the budget is not used in
        # MDAG mode and is passed only for compatibility - confirm against
        # cps_defense.
        sampling_budget=10,
        lambda1=1.0,
        lambda2=1.0,
        lambda3=1.0,
        tau=0.5,  # static threshold
        compute_gradients=False,
        use_mdag=True,  # MDAG grouping rather than RoboSAC sampling
        use_dynamic_threshold=False,  # keep the threshold static
    )
    boxes, scores, gt_boxes, best_cps = cps_defense(**cps_options)

    print(f"Best CPS score: {best_cps:.4f}")
    # Either prediction may be None, so guard before reading ``.shape``.
    box_shape = None if boxes is None else boxes.shape
    print(f"Prediction boxes: {box_shape}")
    score_shape = None if scores is None else scores.shape
    print(f"Prediction scores: {score_shape}")
    print()

    return boxes, scores, gt_boxes, best_cps


def example_cps_defense_mdag_dynamic(batch_data, model, dataset, perturbation, attacker_idx=1):
    """
    Example 4: run ``cps_defense`` with MDAG grouping and a dynamic threshold.

    Prints a banner, calls ``cps_defense`` with ``use_mdag=True``,
    ``use_dynamic_threshold=True`` and ``threshold_sensitivity=1.0``, then
    prints the returned CPS score and the shape of each prediction value
    (or ``None`` when that value is ``None``).

    Args:
        batch_data: Forwarded unchanged to ``cps_defense``.
        model: Forwarded unchanged to ``cps_defense``.
        dataset: Forwarded unchanged to ``cps_defense``.
        perturbation: Forwarded unchanged to ``cps_defense``.
        attacker_idx: Forwarded unchanged to ``cps_defense`` (default 1).

    Returns:
        The ``(pred_box_tensor, pred_score, gt_box_tensor, cps_score)``
        4-tuple exactly as produced by ``cps_defense``.
    """
    rule = "=" * 60
    print(rule)
    print("Example 4: CPS Defense with MDAG Grouping (Dynamic Threshold)")
    print(rule)

    cps_options = dict(
        batch_data=batch_data,
        model=model,
        dataset=dataset,
        perturbation=perturbation,
        attacker_idx=attacker_idx,
        # NOTE(review): the original comment says the budget is not used in
        # MDAG mode and is passed only for compatibility - confirm against
        # cps_defense.
        sampling_budget=10,
        lambda1=1.0,
        lambda2=1.0,
        lambda3=1.0,
        tau=0.5,  # base threshold; per the original note it is adjusted dynamically
        compute_gradients=False,
        use_mdag=True,  # MDAG grouping rather than RoboSAC sampling
        use_dynamic_threshold=True,  # dynamic instead of static threshold
        threshold_sensitivity=1.0,  # sensitivity for the dynamic threshold
    )
    boxes, scores, gt_boxes, best_cps = cps_defense(**cps_options)

    print(f"Best CPS score: {best_cps:.4f}")
    # Either prediction may be None, so guard before reading ``.shape``.
    box_shape = None if boxes is None else boxes.shape
    print(f"Prediction boxes: {box_shape}")
    score_shape = None if scores is None else scores.shape
    print(f"Prediction scores: {score_shape}")
    print()

    return boxes, scores, gt_boxes, best_cps


def run_all_examples(batch_data, model, dataset, perturbation, attacker_idx=1):
    """
    Run all four defense examples in order, print a summary, and return
    their results.

    Each example is invoked with the same arguments. After all four have
    run, a summary is printed that repeats the CPS score (element 3 of the
    result tuple) for the three ``cps_defense``-based examples.

    Args:
        batch_data: Forwarded unchanged to every example function.
        model: Forwarded unchanged to every example function.
        dataset: Forwarded unchanged to every example function.
        perturbation: Forwarded unchanged to every example function.
        attacker_idx: Forwarded unchanged to every example function
            (default 1).

    Returns:
        tuple: ``(robosac_result, cps_robosac_result, mdag_static_result,
        mdag_dynamic_result)`` - the return values of
        ``example_robosac_defense``, ``example_cps_defense_robosac``,
        ``example_cps_defense_mdag_static`` and
        ``example_cps_defense_mdag_dynamic`` respectively. Previously the
        results were discarded (and the RoboSAC result was never used), so
        callers had no way to compare the methods programmatically.
    """
    print("\n" + "=" * 60)
    print("Running All Defense Method Examples")
    print("=" * 60 + "\n")

    # Example 1: RoboSAC
    robosac_result = example_robosac_defense(
        batch_data, model, dataset, perturbation, attacker_idx
    )

    # Example 2: CPS + RoboSAC
    cps_robosac_result = example_cps_defense_robosac(
        batch_data, model, dataset, perturbation, attacker_idx
    )

    # Example 3: CPS + MDAG (Static Threshold)
    mdag_static_result = example_cps_defense_mdag_static(
        batch_data, model, dataset, perturbation, attacker_idx
    )

    # Example 4: CPS + MDAG (Dynamic Threshold)
    mdag_dynamic_result = example_cps_defense_mdag_dynamic(
        batch_data, model, dataset, perturbation, attacker_idx
    )

    print("=" * 60)
    print("Summary")
    print("=" * 60)
    print("1. RoboSAC: Random sampling with consensus")
    print("2. CPS + RoboSAC: CPS metrics with random sampling")
    # Index 3 of each CPS result tuple is the CPS score.
    print(f"   - CPS Score: {cps_robosac_result[3]:.4f}")
    print("3. CPS + MDAG (Static): Viewpoint-based grouping with fixed threshold")
    print(f"   - CPS Score: {mdag_static_result[3]:.4f}")
    print("4. CPS + MDAG (Dynamic): Viewpoint-based grouping with adaptive threshold")
    print(f"   - CPS Score: {mdag_dynamic_result[3]:.4f}")
    print("=" * 60)

    # Hand the raw results back so callers can compare the methods themselves.
    return robosac_result, cps_robosac_result, mdag_static_result, mdag_dynamic_result


if __name__ == "__main__":
    # Running this file directly executes no defense; it only prints usage
    # notes. Real use requires loading actual batch_data, model, dataset and
    # perturbation objects and calling the example functions above.
    usage_lines = (
        "This is an example script demonstrating the API.",
        "Please integrate these functions into your actual evaluation pipeline.",
        "\nKey parameters:",
        "  - use_mdag: True/False (MDAG grouping vs RoboSAC sampling)",
        "  - use_dynamic_threshold: True/False (Dynamic vs static threshold)",
        "  - threshold_sensitivity: float (Sensitivity for dynamic threshold)",
    )
    for usage_line in usage_lines:
        print(usage_line)

