import torch
import torch.nn.functional as F
from collections import OrderedDict

from defense.robosac import cal_robosac_consensus, sample_agents


def _prepare_subset_features(cav_content, cav_idx, model, device):
    """Assemble the model input dict for the agents listed in ``cav_idx``.

    The raw voxel tensors under ``cav_content['processed_lidar']`` are passed
    through untouched (they are not filtered by agent); only the pairwise
    transform matrix and the resulting spatial features are restricted to the
    selected agents.

    Args:
        cav_content: Mapping holding ``'processed_lidar'`` (with the keys
            ``'voxel_features'``, ``'voxel_coords'``, ``'voxel_num_points'``)
            and ``'pairwise_t_matrix'``.
        cav_idx: Sequence of agent indices to keep.
        model: Object exposing ``pillar_vfe`` and ``scatter``. Both are called
            for their side effects on the dict; their return values are ignored.
        device: Device on which the index and ``record_len`` tensors are created.

    Returns:
        The populated dict, whose ``'spatial_features'`` entry holds only the
        rows selected by ``cav_idx``.
    """
    selector = torch.tensor(cav_idx, device=device)
    lidar = cav_content['processed_lidar']

    subset = {
        key: lidar[key]
        for key in ('voxel_features', 'voxel_coords', 'voxel_num_points')
    }
    subset['record_len'] = torch.tensor([len(cav_idx)], device=device)

    # Keep only the selected agents along both leading axes of the first
    # batch entry, then restore the leading batch axis.
    transforms = cav_content['pairwise_t_matrix'][0]
    transforms = transforms.index_select(0, selector).index_select(1, selector)
    subset['pairwise_t_matrix'] = transforms.unsqueeze(0)

    # Relies on both stages writing into `subset` in place; in particular
    # 'spatial_features' must exist once `scatter` has run.
    model.pillar_vfe(subset)
    model.scatter(subset)

    subset['spatial_features'] = subset['spatial_features'].index_select(0, selector)
    return subset


def feature_guard(batch_data, model, dataset, perturbation, attacker_idx=1, sampling_budget=10):
    """
    Feature-map consensus defense inspired by ROBOSAC but operating on intermediate fusion features.

    The defender samples multiple collaboration subsets (each always including
    agent 0, the ego), averages their intermediate spatial features, and keeps
    the subset whose averaged feature has the highest cosine similarity to the
    ego-only reference feature. The model is then run on that subset.

    Args:
        batch_data: Batch mapping; ``batch_data['ego']`` must provide
            ``'cav_num'``, ``'processed_lidar'`` and ``'pairwise_t_matrix'``.
        model: Detection model. It is called on the chosen feature dict and must
            expose ``pillar_vfe`` and ``scatter``.
        dataset: Object providing ``post_process(batch_data, output_dict)``.
        perturbation: Adversarial feature written over the attacker's spatial
            feature whenever the attacker is part of a sampled subset, or
            ``None`` to evaluate without an attack.
        attacker_idx: Global index of the attacking agent.
        sampling_budget: Maximum number of subsets to sample.

    Returns:
        Tuple ``(pred_box_tensor, pred_score, gt_box_tensor, best_score)``.
        ``best_score`` is the cosine similarity of the retained subset, or
        ``1.0`` when no collaborative subset could be scored and the ego-only
        features are used instead.
    """
    cav_content = batch_data['ego']
    agent_num = cav_content['cav_num']
    if perturbation is not None:
        device = perturbation.device
    else:
        device = cav_content['processed_lidar']['voxel_features'].device

    # Reference feature using only the ego vehicle; flattened once because it
    # is compared against every candidate.
    base_voxel_dict = _prepare_subset_features(cav_content, [0], model, device)
    reference_flat = base_voxel_dict['spatial_features'][0].reshape(1, -1)

    # Presumably the number of collaborators to draw per sample (confirm
    # against defense.robosac); clamped to at least 1 and at most agent_num - 1.
    s = cal_robosac_consensus(agent_num, sampling_budget, num_attackers=1)
    s = max(1, min(s, agent_num - 1))

    # Start below every attainable cosine similarity. Starting at 1.0 (the
    # upper bound of cosine similarity) would reject every sampled subset and
    # silently reduce the defense to ego-only inference.
    best_subset = None
    best_score = float('-inf')

    iterations = 0
    while agent_num > 1 and iterations < sampling_budget:
        iterations += 1
        sampled_agents = sample_agents(agent_num, s)
        # The ego (index 0) always participates; sorting keeps local indices
        # aligned with ascending global agent order.
        cav_idx = sorted(set([0] + sampled_agents))

        voxel_feature_dict = _prepare_subset_features(cav_content, cav_idx, model, device)

        if perturbation is not None and attacker_idx in cav_idx:
            # Translate the attacker's global index into its row in this subset.
            local_idx = cav_idx.index(attacker_idx)
            voxel_feature_dict['spatial_features'][local_idx] = perturbation

        fused_flat = voxel_feature_dict['spatial_features'].mean(dim=0, keepdim=True).reshape(1, -1)
        score = F.cosine_similarity(fused_flat, reference_flat, dim=1, eps=1e-6).item()

        if score > best_score:
            best_score = score
            best_subset = voxel_feature_dict

    if best_subset is None:
        # No candidate was scored (single agent, zero budget, or all scores
        # were NaN): fall back to the ego-only features, whose similarity to
        # the reference is reported as 1.0.
        best_subset = base_voxel_dict
        best_score = 1.0

    output_dict = OrderedDict()
    output_dict['ego'] = model(best_subset)

    pred_box_tensor, pred_score, gt_box_tensor = dataset.post_process(batch_data, output_dict)

    return pred_box_tensor, pred_score, gt_box_tensor, best_score

