"""
PASAC (Probability-Agnostic Sample Consensus) Defense

PASAC uses a recursive binary splitting approach to identify benign agents:
1. Split agent list into two halves (G1 and G2)
2. Calculate the CCLoss (Collaborative Consistency Loss) score for each group fused with ego
3. If CCLoss > threshold, the group likely contains benign agents -> recurse
4. If CCLoss <= threshold, the group likely contains attackers -> skip
5. Continue until all benign agents are identified

Reference: CP-Guard (https://github.com/CP-Security/CP-Guard)

Key differences from ROBOSAC:
- PASAC uses recursive binary splitting
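- PASAC calls its consistency score CCLoss; in this file it is computed with ROBOSAC's associate_2_detections helper (iou_threshold=0.3), not a feature-level loss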
- PASAC is deterministic (no random sampling)
"""

import torch
import numpy as np
import logging
from collections import OrderedDict
from defense.robosac_util import associate_2_detections

# Get logger instance
logger = logging.getLogger()


def get_ego_pred(batch_data, model, dataset):
    """Run the model on the ego agent alone and return the baseline detections.

    Args:
        batch_data: Batch dict; only ``batch_data['ego']`` is read here
            (``processed_lidar`` voxel tensors and ``pairwise_t_matrix``).
        model: Object exposing ``pillar_vfe(dict)``, ``scatter(dict)`` and
            ``__call__(dict)``. The first two are expected to fill the dict in
            place, since their return values are not used.
        dataset: Object exposing ``post_process(batch_data, output_dict)``.

    Returns:
        Tuple ``(pred_box_tensor, pred_score, gt_box_tensor, ego_spatial_feature)``:
        the three values returned by ``dataset.post_process`` followed by a
        clone of entry 0 of the scattered ``spatial_features``.
    """
    ego = batch_data['ego']
    lidar = ego['processed_lidar']
    voxel_features = lidar['voxel_features']

    # Single-agent input: record_len is fixed to 1 and the transform matrix is
    # cut down to the leading 2x2 block of the first batch entry.
    # NOTE(review): a 2x2 block is kept although record_len is 1 - confirm
    # this is what the fusion module expects.
    inputs = {
        'voxel_features': voxel_features,
        'voxel_coords': lidar['voxel_coords'],
        'voxel_num_points': lidar['voxel_num_points'],
        'record_len': torch.tensor([1]).to(voxel_features.device),
        'pairwise_t_matrix': ego['pairwise_t_matrix'][0, :2, :2, ...].unsqueeze(0),
    }
    model.pillar_vfe(inputs)
    model.scatter(inputs)

    # Entry 0 of the scattered features is the ego agent. A detached copy is
    # handed back to the caller; the model itself sees a batch of one.
    ego_row = inputs['spatial_features'][0]
    ego_spatial_feature = ego_row.clone()
    inputs['spatial_features'] = ego_row.unsqueeze(0)

    outputs = OrderedDict(ego=model(inputs))
    boxes, scores, gt_boxes = dataset.post_process(batch_data, outputs)
    return boxes, scores, gt_boxes, ego_spatial_feature


def get_fused_pred(batch_data, model, dataset, agent_list, perturbation, attacker_idx):
    """Run the model on ego plus a subset of agents and return the fused detections.

    The order of operations follows the original note that this mirrors
    ROBOSAC: features are built for every agent, the perturbation is written
    at ``attacker_idx`` in that full array, and only then is the subset
    selected. An attacker outside ``agent_list`` therefore has no effect on
    the result.

    Args:
        batch_data: Batch dict; only ``batch_data['ego']`` is read here.
        model: Object exposing ``pillar_vfe(dict)``, ``scatter(dict)`` and
            ``__call__(dict)``.
        dataset: Object exposing ``post_process(batch_data, output_dict)``.
        agent_list: List of agent indices to fuse. Ego (index 0) is prepended
            when absent; the caller's list is not modified.
        perturbation: Value written over the attacker's spatial feature, or
            ``None`` to leave the features untouched.
        attacker_idx: Index into the full (unselected) spatial feature array.

    Returns:
        Tuple ``(pred_box_tensor, pred_score, gt_box_tensor, fused_spatial_feature)``
        where the last item is the mean of the selected spatial features over
        dim 0.
    """
    ego = batch_data['ego']
    lidar = ego['processed_lidar']
    device = lidar['voxel_features'].device

    # Ego must always take part in the fusion.
    members = agent_list if 0 in agent_list else [0] + agent_list
    selector = torch.tensor(members).to(device)

    # Keep only the rows and columns of the pairwise transforms that belong
    # to the selected agents.
    t_rows = torch.index_select(ego['pairwise_t_matrix'][0], dim=0, index=selector)
    t_selected = torch.index_select(t_rows, dim=1, index=selector)

    inputs = {
        'voxel_features': lidar['voxel_features'],
        'voxel_coords': lidar['voxel_coords'],
        'voxel_num_points': lidar['voxel_num_points'],
        'record_len': torch.tensor([len(members)]).to(device),
        'pairwise_t_matrix': t_selected.unsqueeze(0),
    }
    model.pillar_vfe(inputs)
    model.scatter(inputs)

    # The perturbation is written in place into the full feature array,
    # before any agent is dropped.
    all_features = inputs['spatial_features']
    if perturbation is not None:
        all_features[attacker_idx] = perturbation

    selected = torch.index_select(all_features, dim=0, index=selector)
    inputs['spatial_features'] = selected

    # Mean over the selected agents, returned alongside the detections.
    fused_spatial_feature = selected.mean(dim=0)

    outputs = OrderedDict(ego=model(inputs))
    boxes, scores, gt_boxes = dataset.post_process(batch_data, outputs)
    return boxes, scores, gt_boxes, fused_spatial_feature


def calculate_ccloss(ego_pred_boxes, fused_pred_boxes, ego_feature=None, fused_feature=None, use_features=False):
    """
    Score how well the fused detections agree with the ego-only detections.

    Despite the "CCLoss" name, the value is whatever ``associate_2_detections``
    (the helper imported from ``defense.robosac_util``) returns for the two box
    sets, called with ``iou_threshold=0.3`` instead of that helper's default.
    Callers in this file treat a HIGHER value as MORE agreement.

    Args:
        ego_pred_boxes: Ego-only prediction bounding boxes (may be None).
        fused_pred_boxes: Fused prediction bounding boxes (may be None).
        ego_feature: Ignored; kept for API compatibility.
        fused_feature: Ignored; kept for API compatibility.
        use_features: Ignored; no feature-level scoring is implemented here.

    Returns:
        ccloss: The score returned by ``associate_2_detections`` (documented
            by the original author as a Jaccard index in [0, 1] - confirm
            against the helper).
    """
    # Box counts are logged for debugging only; a missing prediction counts as 0.
    len_ego, len_fused = (
        len(boxes) if boxes is not None else 0
        for boxes in (ego_pred_boxes, fused_pred_boxes)
    )
    logger.info(f"[CCLoss] ego_boxes: {len_ego}, fused_boxes: {len_fused}")

    # 0.3 is looser than the helper's default, which the original author
    # found too strict for a defense similarity check.
    ccloss = associate_2_detections(ego_pred_boxes, fused_pred_boxes, iou_threshold=0.3)
    logger.info(f"[CCLoss] Result: {ccloss:.4f}")
    return ccloss


def _pasac_group_score(batch_data, model, dataset, group, ego_pred_boxes, ego_feature,
                       perturbation, attacker_idx):
    """Fuse ego with ``group`` and return that fusion's CCLoss score against the ego-only boxes."""
    fused_pred, _, _, fused_feature = get_fused_pred(
        batch_data, model, dataset, [0] + group, perturbation, attacker_idx
    )
    return calculate_ccloss(ego_pred_boxes, fused_pred, ego_feature, fused_feature)


def pasac_recursive(batch_data, model, dataset, agent_list, ego_pred_boxes, ego_feature,
                    perturbation, attacker_idx, threshold, depth=0, max_depth=20):
    """
    Recursively execute PASAC algorithm using binary splitting.

    Algorithm:
    1. An empty list yields no agents; reaching ``max_depth`` also yields none.
    2. A single agent is scored on its own (ego + agent) and kept only if its
       CCLoss is strictly greater than ``threshold``.
    3. Otherwise the list is split into two halves (G1 and G2) and each half
       is scored as a group (ego + half).
    4. A half with CCLoss > threshold is treated as consistent and is examined
       further by recursion.
    5. A half with CCLoss <= threshold is dropped as a whole. Its members are
       never scored individually, so benign agents that share a half with an
       attacker are dropped with it.

    A consistent half that holds exactly one agent is accepted immediately:
    the recursive call would only repeat the same ego + agent forward pass.
    The shortcut is skipped when the recursive call would have hit
    ``max_depth``, so the depth limit behaves as before.

    Args:
        batch_data: Input batch data
        model: The cooperative perception model
        dataset: Dataset for post-processing
        agent_list: List of agent indices to check (excluding ego)
        ego_pred_boxes: Ego-only prediction boxes (reference)
        ego_feature: Ego spatial feature; forwarded to ``calculate_ccloss``
        perturbation: Adversarial perturbation
        attacker_idx: Index of the attacker
        threshold: CCLoss threshold for accepting a group
        depth: Current recursion depth (also used to indent log lines)
        max_depth: Maximum recursion depth to prevent infinite recursion

    Returns:
        benign_agents: List of identified benign agent indices
    """
    indent = "  " * depth

    if not agent_list:
        return []

    if depth >= max_depth:
        logger.warning(f"{indent}[PASAC] Max recursion depth reached, returning empty")
        return []

    if len(agent_list) == 1:
        # Leaf case: score this agent fused with ego only.
        agent_id = agent_list[0]
        ccloss = _pasac_group_score(
            batch_data, model, dataset, [agent_id], ego_pred_boxes, ego_feature,
            perturbation, attacker_idx
        )
        logger.info(f"{indent}[PASAC] Testing single agent {agent_id}: CCLoss = {ccloss:.4f} (threshold: {threshold})")

        if ccloss > threshold:
            logger.info(f"{indent}[PASAC] Agent {agent_id} is BENIGN")
            return [agent_id]
        logger.info(f"{indent}[PASAC] Agent {agent_id} is MALICIOUS")
        return []

    # With at least two agents both halves are non-empty.
    mid = len(agent_list) // 2
    G1 = agent_list[:mid]
    G2 = agent_list[mid:]

    logger.info(f"{indent}[PASAC] Depth {depth}: Splitting {agent_list} -> G1: {G1}, G2: {G2}")

    benign_agents = []
    for label, group in (("G1", G1), ("G2", G2)):
        ccloss = _pasac_group_score(
            batch_data, model, dataset, group, ego_pred_boxes, ego_feature,
            perturbation, attacker_idx
        )
        logger.info(f"{indent}[PASAC] {label} {group}: CCLoss = {ccloss:.4f} (threshold: {threshold})")

        # Keep the `>` test as the accepting branch so that a NaN score is
        # rejected, exactly as before.
        if ccloss > threshold:
            if len(group) == 1 and depth + 1 < max_depth:
                # The group score already is this agent's individual score.
                logger.info(f"{indent}[PASAC] {label} {group} is CONSISTENT, agent {group[0]} is BENIGN")
                benign_agents.append(group[0])
            else:
                logger.info(f"{indent}[PASAC] {label} {group} is CONSISTENT, recursing...")
                benign_agents.extend(
                    pasac_recursive(batch_data, model, dataset, group, ego_pred_boxes, ego_feature,
                                    perturbation, attacker_idx, threshold, depth + 1, max_depth)
                )
        else:
            logger.info(f"{indent}[PASAC] {label} {group} is INCONSISTENT, discarding")

    return benign_agents


def pasac(batch_data, model, dataset, perturbation, attacker_idx=1,
          threshold=0.3, max_depth=20):
    """
    PASAC defense entry point.

    Uses recursive binary splitting to identify benign agents:
    1. Get ego-only prediction as baseline
    2. Recursively split agent groups and check their CCLoss score
       (box-level, see ``calculate_ccloss``)
    3. If group CCLoss > threshold, recurse to verify subgroup consistency
    4. If group CCLoss <= threshold, discard the group (assumed to contain attackers)
    5. Use ego + all benign agents for final prediction

    Every non-ego agent that is not returned as benign is reported as
    malicious, including agents that were dropped only because they shared a
    rejected group with another agent.

    Args:
        batch_data: Input batch data containing CAV information;
            ``batch_data['ego']['cav_num']`` gives the number of agents
        model: The cooperative perception model
        dataset: Dataset for post-processing
        perturbation: Adversarial perturbation tensor
        attacker_idx: Index of the attacker (for applying perturbation)
        threshold: CCLoss threshold for accepting a group (default: 0.3)
        max_depth: Maximum recursion depth

    Returns:
        pred_box_tensor: Predicted bounding boxes
        pred_score: Prediction confidence scores
        gt_box_tensor: Ground truth bounding boxes
        defense_info: Dictionary with keys ``threshold``, ``benign_agents``,
            ``malicious_agents`` and ``recursion_depth``
    """
    agent_num = batch_data['ego']['cav_num']

    # Initialize defense info
    defense_info = {
        'threshold': threshold,
        'benign_agents': [],
        'malicious_agents': [],
        # Kept for output compatibility; nothing in this function updates it.
        'recursion_depth': 0
    }

    logger.info(f"[PASAC] Starting PASAC defense with {agent_num} agents, threshold: {threshold}")

    # Step 1: Get ego-only prediction as baseline (including spatial feature)
    ego_pred_box_tensor, ego_pred_score, gt_box_tensor, ego_feature = get_ego_pred(batch_data, model, dataset)

    ego_box_count = len(ego_pred_box_tensor) if ego_pred_box_tensor is not None else 0
    logger.info(f"[PASAC] Ego-only prediction: {ego_box_count} boxes, feature shape: {ego_feature.shape}")

    # If only ego, return ego prediction
    if agent_num <= 1:
        logger.info("[PASAC] Only ego agent, returning ego-only prediction")
        return ego_pred_box_tensor, ego_pred_score, gt_box_tensor, defense_info

    # Step 2: Get all non-ego agents
    all_agent_list = list(range(1, agent_num))  # Exclude ego (index 0)
    logger.info(f"[PASAC] Non-ego agents to check: {all_agent_list}")

    # Step 3: Recursively identify benign agents. ego_feature is passed
    # through, but the score itself is computed from predicted boxes.
    benign_agents = pasac_recursive(
        batch_data, model, dataset, all_agent_list, ego_pred_box_tensor, ego_feature,
        perturbation, attacker_idx, threshold, depth=0, max_depth=max_depth
    )

    # Update defense info: everything not accepted is reported as malicious.
    benign_lookup = set(benign_agents)
    defense_info['benign_agents'] = benign_agents
    defense_info['malicious_agents'] = [
        agent_id for agent_id in all_agent_list
        if agent_id not in benign_lookup
    ]

    logger.info(f"[PASAC] Final benign agents: {defense_info['benign_agents']}")
    logger.info(f"[PASAC] Final malicious agents: {defense_info['malicious_agents']}")

    # Step 4: Final prediction using ego + all benign agents
    if not benign_agents:
        # No benign agents found, use ego-only
        logger.info("[PASAC] No benign agents found, returning ego-only prediction")
        return ego_pred_box_tensor, ego_pred_score, gt_box_tensor, defense_info

    # Use ego + all benign agents for final prediction
    final_agents = [0] + benign_agents
    logger.info(f"[PASAC] Final fusion with agents: {final_agents}")

    pred_box_tensor, pred_score, gt_box_tensor, _ = get_fused_pred(
        batch_data, model, dataset, final_agents, perturbation, attacker_idx
    )

    return pred_box_tensor, pred_score, gt_box_tensor, defense_info


# Keep the old simple version for backward compatibility
def pasac_simple(batch_data, model, dataset, perturbation, attacker_idx=1,
                 consensus_threshold=0.3, max_agents_to_add=None):
    """
    Simple PASAC - Tests each agent individually (non-recursive version).

    Each non-ego agent is fused with ego on its own and the result is compared
    with the ego-only detections through ``associate_2_detections``. For the
    recursive implementation, use pasac().

    Differences from pasac() that are visible in this file:
    - an agent is accepted when its score is >= ``consensus_threshold``
      (pasac() requires a strictly greater score);
    - ``associate_2_detections`` is called with its default IoU threshold
      (pasac() goes through ``calculate_ccloss``, which passes 0.3).

    Args:
        batch_data: Input batch data; ``batch_data['ego']['cav_num']`` gives
            the number of agents
        model: The cooperative perception model
        dataset: Dataset for post-processing
        perturbation: Adversarial perturbation tensor
        attacker_idx: Index of the attacker (for applying perturbation)
        consensus_threshold: Minimum score for accepting an agent
        max_agents_to_add: Accepted for backward compatibility; this
            implementation never reads it, so it has no effect.

    Returns:
        pred_box_tensor: Predicted bounding boxes
        pred_score: Prediction confidence scores
        gt_box_tensor: Ground truth bounding boxes
        defense_info: Dictionary with keys ``threshold``, ``benign_agents``,
            ``malicious_agents`` and ``agent_scores`` (per-agent score)
    """
    agent_num = batch_data['ego']['cav_num']

    defense_info = {
        'threshold': consensus_threshold,
        'benign_agents': [],
        'malicious_agents': [],
        'agent_scores': {}
    }

    # The ego spatial feature is not needed for the box-level comparison.
    ego_pred_box_tensor, ego_pred_score, gt_box_tensor, _ = get_ego_pred(batch_data, model, dataset)

    if agent_num <= 1:
        return ego_pred_box_tensor, ego_pred_score, gt_box_tensor, defense_info

    candidate_ids = list(range(1, agent_num))
    accepted_agents = []

    for agent_id in candidate_ids:
        pred_box_tensor, _, _, _ = get_fused_pred(
            batch_data, model, dataset, [0, agent_id], perturbation, attacker_idx
        )

        iou_score = associate_2_detections(ego_pred_box_tensor, pred_box_tensor)
        defense_info['agent_scores'][agent_id] = iou_score

        logger.info(f"[PASAC-Simple] Agent {agent_id}: IoU = {iou_score:.4f}")

        if iou_score >= consensus_threshold:
            accepted_agents.append(agent_id)
            logger.info(f"[PASAC-Simple] Agent {agent_id}: ACCEPTED")
        else:
            logger.info(f"[PASAC-Simple] Agent {agent_id}: REJECTED")

    accepted_lookup = set(accepted_agents)
    defense_info['benign_agents'] = accepted_agents
    defense_info['malicious_agents'] = [
        agent_id for agent_id in candidate_ids
        if agent_id not in accepted_lookup
    ]

    if not accepted_agents:
        # Nothing passed the check: fall back to the ego-only result.
        return ego_pred_box_tensor, ego_pred_score, gt_box_tensor, defense_info

    final_agents = [0] + accepted_agents
    pred_box_tensor, pred_score, gt_box_tensor, _ = get_fused_pred(
        batch_data, model, dataset, final_agents, perturbation, attacker_idx
    )

    return pred_box_tensor, pred_score, gt_box_tensor, defense_info
