import numpy as np
from scipy.spatial.distance import cdist
import torch
from dataclasses import dataclass
from typing import Dict, Optional
import argparse


def compute_accuracy_recall(ref_indices1, ref_indices2, ind_nonref):
    """
    Score discovered pairs against each other.

    A pair counts as correct when the element of ``ref_indices1`` equals the
    element of ``ref_indices2`` at the same position.

    Args:
        ref_indices1: Sequence of indices for the first side of each pair.
        ref_indices2: Sequence of indices for the second side of each pair.
        ind_nonref: Sequence whose length is the recall denominator.

    Returns:
        Tuple ``(accuracy, recall, correct)`` where ``accuracy`` is
        ``correct / len(ref_indices1)`` and ``recall`` is
        ``correct / len(ind_nonref)``; each ratio is 0 when its denominator
        is empty.
    """
    correct = 0
    for left, right in zip(ref_indices1, ref_indices2):
        if left == right:
            correct += 1

    n_predicted = len(ref_indices1)
    n_candidates = len(ind_nonref)

    if n_predicted > 0:
        accuracy = correct / n_predicted
    else:
        accuracy = 0

    if n_candidates > 0:
        recall = correct / n_candidates
    else:
        recall = 0

    return accuracy, recall, correct


def analyze_distance_based_accuracy(
    ref_indices1: np.ndarray,
    ref_indices2: np.ndarray,
    emb1: np.ndarray,
    emb2: np.ndarray,
    anchor_indices: np.ndarray,
    distance_metric: str = 'cosine',
    use_gpu: bool = False,
    device: torch.device = None,
    percentile_ranges: list = None
) -> dict:
    """
    Analyze accuracy breakdown by distance to supervised anchor points.

    Computes per-pair correctness (``ref_indices1[i] == ref_indices2[i]``) and
    correlates it with the pair's distances to the anchors. For every pair the
    distances from its point in space 1 to the anchors in space 1 and from its
    point in space 2 to the anchors in space 2 are concatenated; the mean and
    the minimum of that combined row are the pair's "average" and "minimum"
    distance. Pairs are then bucketed by percentile of the average distance.

    Args:
        ref_indices1: Discovered pair indices from embedding 1 (shape: [n_pairs]).
            Any array-like is accepted and converted with ``np.asarray``.
        ref_indices2: Discovered pair indices from embedding 2 (shape: [n_pairs])
        emb1: Full embedding matrix 1 (shape: [n_emb1, d])
        emb2: Full embedding matrix 2 (shape: [n_emb2, d])
        anchor_indices: Indices of supervised anchors (shape: [n_anchors])
                       These are the ORIGINAL supervised reference points (ref_ind)
        distance_metric: 'cosine' or 'euclidean' (forwarded to ``get_dists``)
        use_gpu: Whether to use GPU for computation
        device: PyTorch device for GPU computation
        percentile_ranges: List of (min%, max%) tuples.
            Default: [(0,20), (20,40), (40,60), (60,80), (80,100)]

    Returns:
        dict with keys:
            - 'avg_distance_correlation': Pearson correlation for average distance
            - 'min_distance_correlation': Pearson correlation for minimum distance
            - 'avg_distance_p_value': p-value for average distance correlation
            - 'min_distance_p_value': p-value for minimum distance correlation
            - 'percentile_breakdown': dict mapping range tuple to {
                'accuracy': float,
                'n_pairs': int,
                'n_correct': int
              }
            - 'avg_distances': array of average distances for each pair
            - 'min_distances': array of minimum distances for each pair
            - 'correctness': array of 1/0 for each pair

        Correlations and p-values are NaN when there are fewer than three
        pairs, when all pairs share the same correctness, or when the
        correlation computation fails. With no pairs or no anchors the arrays
        are empty and every bucket is reported with zero counts.

    Raises:
        ValueError: If ``ref_indices1`` and ``ref_indices2`` differ in shape.
    """
    from scipy.stats import pearsonr
    from utils.graph_util import get_dists
    from loguru import logger

    # Default percentile ranges
    if percentile_ranges is None:
        percentile_ranges = [(0, 20), (20, 40), (40, 60), (60, 80), (80, 100)]

    # Handle empty pairs case
    if len(ref_indices1) == 0 or len(anchor_indices) == 0:
        return {
            'avg_distance_correlation': np.nan,
            'min_distance_correlation': np.nan,
            'avg_distance_p_value': np.nan,
            'min_distance_p_value': np.nan,
            'percentile_breakdown': {range_tuple: {'accuracy': 0.0, 'n_pairs': 0, 'n_correct': 0}
                                    for range_tuple in percentile_ranges},
            'avg_distances': np.array([]),
            'min_distances': np.array([]),
            'correctness': np.array([])
        }

    # Coerce to arrays so plain lists work: `list == list` yields a single
    # bool, which has no `.astype` and would crash below.
    ref_indices1 = np.asarray(ref_indices1)
    ref_indices2 = np.asarray(ref_indices2)
    if ref_indices1.shape != ref_indices2.shape:
        raise ValueError(
            "ref_indices1 and ref_indices2 must have the same shape, got "
            f"{ref_indices1.shape} and {ref_indices2.shape}"
        )

    n_pairs = len(ref_indices1)

    # Per-pair correctness as 1.0 / 0.0
    correctness = (ref_indices1 == ref_indices2).astype(np.float32)

    # Extract anchor embeddings (same anchor indices are used in both spaces)
    anchor_emb1 = emb1[anchor_indices]
    anchor_emb2 = emb2[anchor_indices]

    # Compute distances in chunks so the [chunk, n_anchors] matrices stay small
    avg_distances_list = []
    min_distances_list = []

    chunk_size = min(1000, n_pairs)  # n_pairs > 0 here, so the step is never 0

    for chunk_start in range(0, n_pairs, chunk_size):
        chunk_end = min(chunk_start + chunk_size, n_pairs)

        # Get pair embeddings for this chunk
        chunk_emb1 = emb1[ref_indices1[chunk_start:chunk_end]]
        chunk_emb2 = emb2[ref_indices2[chunk_start:chunk_end]]

        try:
            # Distances in space 1: [chunk_size, n_anchors]
            dist_matrix1 = get_dists(
                chunk_emb1, anchor_emb1,
                metric=distance_metric,
                use_gpu=use_gpu,
                device=device
            )

            # Distances in space 2: [chunk_size, n_anchors]
            dist_matrix2 = get_dists(
                chunk_emb2, anchor_emb2,
                metric=distance_metric,
                use_gpu=use_gpu,
                device=device
            )

        except RuntimeError as e:
            # Only out-of-memory errors are retried on CPU; anything else propagates.
            if "out of memory" not in str(e).lower():
                raise
            logger.warning("GPU OOM in distance analysis, falling back to CPU")
            if use_gpu:
                torch.cuda.empty_cache()
            dist_matrix1 = get_dists(chunk_emb1, anchor_emb1,
                                    metric=distance_metric, use_gpu=False, device=None)
            dist_matrix2 = get_dists(chunk_emb2, anchor_emb2,
                                    metric=distance_metric, use_gpu=False, device=None)

        # Convert to numpy if needed
        if torch.is_tensor(dist_matrix1):
            dist_matrix1 = dist_matrix1.cpu().numpy()
        if torch.is_tensor(dist_matrix2):
            dist_matrix2 = dist_matrix2.cpu().numpy()

        # Combine distances from both spaces: [chunk_size, 2*n_anchors]
        combined_distances = np.concatenate([dist_matrix1, dist_matrix2], axis=1)

        # Compute statistics per pair
        avg_distances_list.append(np.mean(combined_distances, axis=1))
        min_distances_list.append(np.min(combined_distances, axis=1))

    # Concatenate all chunks
    avg_distances = np.concatenate(avg_distances_list)
    min_distances = np.concatenate(min_distances_list)

    # Compute correlations; undefined when correctness is constant or there
    # are fewer than three samples.
    avg_corr, avg_p_value = np.nan, np.nan
    min_corr, min_p_value = np.nan, np.nan
    if len(np.unique(correctness)) >= 2 and len(correctness) >= 3:
        try:
            avg_corr, avg_p_value = pearsonr(correctness, avg_distances)
            min_corr, min_p_value = pearsonr(correctness, min_distances)
        except Exception as e:
            logger.warning(f"Pearson correlation computation failed: {e}")
            avg_corr, avg_p_value = np.nan, np.nan
            min_corr, min_p_value = np.nan, np.nan

    # Bucket by percentiles (using average distance)
    percentile_breakdown = {}
    for min_pct, max_pct in percentile_ranges:
        min_threshold = np.percentile(avg_distances, min_pct)
        max_threshold = np.percentile(avg_distances, max_pct)

        # Buckets are half-open [min, max) except a bucket ending at 100,
        # which also includes the maximum value.
        if max_pct == 100:
            mask = (avg_distances >= min_threshold) & (avg_distances <= max_threshold)
        else:
            mask = (avg_distances >= min_threshold) & (avg_distances < max_threshold)

        n_pairs_in_range = np.sum(mask)
        if n_pairs_in_range > 0:
            n_correct = np.sum(correctness[mask])
            accuracy = n_correct / n_pairs_in_range
        else:
            n_correct = 0
            accuracy = 0.0

        percentile_breakdown[(min_pct, max_pct)] = {
            'accuracy': float(accuracy),
            'n_pairs': int(n_pairs_in_range),
            'n_correct': int(n_correct)
        }

    return {
        'avg_distance_correlation': float(avg_corr),
        'min_distance_correlation': float(min_corr),
        'avg_distance_p_value': float(avg_p_value),
        'min_distance_p_value': float(min_p_value),
        'percentile_breakdown': percentile_breakdown,
        'avg_distances': avg_distances,
        'min_distances': min_distances,
        'correctness': correctness
    }


def _as_index_set(ans):
    """Turn one nearest-neighbour result into a set of hashable Python indices."""
    # A tuple is interpreted as (distances, indices); only the indices matter.
    if isinstance(ans, tuple):
        ans = ans[1]
    # Tensor elements hash by identity, so equal indices from two tensors
    # would never match inside a set; go through NumPy instead.
    if torch.is_tensor(ans):
        ans = ans.detach().cpu().numpy()
    if isinstance(ans, np.ndarray):
        # Flatten so multi-dimensional index arrays contribute scalars,
        # not (unhashable) rows.
        return set(ans.ravel().tolist())
    return set(ans)


def cal_recall(ans1, ans2):
    """
    Measure the overlap between two nearest-neighbour results.

    Despite its name, the value returned is the Jaccard index of the two
    index sets: ``|A ∩ B| / |A ∪ B|``.

    Args:
        ans1: Result from feature 1. Either a ``(distances, indices)`` tuple,
            a NumPy array or torch tensor of indices (any shape), or another
            iterable of hashable indices.
        ans2: Result from feature 2, same accepted forms as ``ans1``.

    Returns:
        Overlap ratio in [0, 1]; 0.0 when both index sets are empty.
    """
    indices1 = _as_index_set(ans1)
    indices2 = _as_index_set(ans2)

    union = len(indices1 | indices2)
    if union == 0:
        return 0.0
    return len(indices1 & indices2) / union

def compute_completeness(R: np.ndarray, S: np.ndarray, distance_metric: str = 'euclidean', threshold: float = 0.1):
    """
    Compute completeness and the fraction of points of S that lie within a threshold of R.

    completeness = max_{v in S} min_{u in R} dist(u, v)

    Args:
        R: Set R, shape (n_r, d).
        S: Set S, shape (n_s, d).
        distance_metric: Metric name passed to ``scipy.spatial.distance.cdist``
            (default 'euclidean').
        threshold: Distance threshold.

    Returns:
        Tuple of
        - the completeness value, and
        - the proportion (between 0 and 1) of points in S whose distance to
          their nearest point in R is <= ``threshold``.
    """
    # Pairwise distances from S to R, shape (|S|, |R|); reduce each row to the
    # distance between that point of S and its nearest point of R.
    nearest_in_R = cdist(S, R, metric=distance_metric).min(axis=1)

    # Worst case over S of those nearest distances.
    completeness = nearest_in_R.max()

    # Share of S that is covered within the threshold.
    covered = np.sum(nearest_in_R <= threshold)
    return completeness, covered / len(S)

def compute_soundness(R: np.ndarray, S: np.ndarray, distance_metric: str = 'euclidean', threshold: np.ndarray = None):
    """
    Compute soundness and the fraction of points of R that lie within a threshold of S.

    soundness = max_{u in R} min_{v in S} dist(u, v)

    Args:
        R: Set R, shape (n_r, d).
        S: Set S, shape (n_s, d).
        distance_metric: Metric name passed to ``scipy.spatial.distance.cdist``
            (default 'euclidean').
        threshold: Distance threshold. May be a scalar or an array that
            broadcasts against the (n_r,) vector of nearest distances.
            ``None`` (the default) falls back to 0.1, the same default that
            ``compute_completeness`` uses; comparing against ``None`` itself
            would raise a ``TypeError``.

    Returns:
        Tuple of
        - the soundness value, and
        - the proportion (between 0 and 1) of points in R whose distance to
          their nearest point in S is <= ``threshold``.
    """
    if threshold is None:
        threshold = 0.1

    # Distance matrix from R to S, shape (|R|, |S|).
    distances = cdist(R, S, metric=distance_metric)
    # For every u in R, the distance to its nearest point in S.
    min_distances = distances.min(axis=1)
    # Worst case over R of those nearest distances.
    soundness = min_distances.max()
    # Number of points of R within the threshold, then as a proportion of R.
    num_within_threshold = np.sum(min_distances <= threshold)
    proportion_within_threshold = num_within_threshold / len(R)
    return soundness, proportion_within_threshold

def topk_mean(m, k, inplace=False):
    """
    Return, for every row of ``m``, the mean of the ``k`` values picked by
    repeated arg-max selection.

    Each round takes the current row maximum, adds it to the running total and
    overwrites that entry with the global minimum of the matrix, so a later
    round picks the next-largest value.

    Args:
        m: 2-D ``torch.Tensor`` or NumPy array.
        k: Number of selection rounds; ``k <= 0`` yields zeros.
        inplace: When True the overwriting is done directly on ``m``;
            otherwise a copy is used and ``m`` is left untouched.

    Returns:
        1-D tensor/array with one mean per row (same backend as ``m``).
    """
    if isinstance(m, torch.Tensor):
        n_rows = m.shape[0]
        totals = torch.zeros(n_rows, dtype=m.dtype, device=m.device)
        if k <= 0:
            return totals
        work = m if inplace else m.clone()
        floor = work.min()
        rows = torch.arange(n_rows, device=m.device)
        for _ in range(k):
            cols = work.argmax(dim=1)
            totals += work[rows, cols]
            # Knock the selected entries out so the next round finds the runner-up.
            work[rows, cols] = floor
        return totals / k

    # NumPy path
    n_rows = m.shape[0]
    totals = np.zeros(n_rows, dtype=m.dtype)
    if k <= 0:
        return totals
    work = m if inplace else np.array(m)
    rows = np.arange(n_rows)
    cols = np.empty(n_rows, dtype=int)
    floor = work.min()
    for _ in range(k):
        np.argmax(work, axis=1, out=cols)
        totals += work[rows, cols]
        work[rows, cols] = floor
    return totals / k

def _as_faiss_array(vec):
    """Return a fresh C-contiguous float32 NumPy copy of ``vec`` (tensor or array-like)."""
    if torch.is_tensor(vec):
        vec = vec.detach().cpu().numpy()
    # Always copy: faiss.normalize_L2 works in place and must never touch the
    # caller's data.
    return np.array(vec, dtype=np.float32, order='C')


def _k_smallest_sorted(dists, k):
    """Return (distances, indices) of the k smallest entries of each row, ascending."""
    # kth = k - 1 (not k) keeps the call valid when k equals the number of columns.
    knn_indices = np.argpartition(dists, k - 1, axis=1)[:, :k]
    knn_dists = np.take_along_axis(dists, knn_indices, axis=1)

    # argpartition leaves the selected block unordered; sort within each row.
    order = np.argsort(knn_dists, axis=1)
    knn_indices = np.take_along_axis(knn_indices, order, axis=1)
    knn_dists = np.take_along_axis(knn_dists, order, axis=1)
    return knn_dists, knn_indices


def _faiss_topk(dist_vec1, dist_vec2, k, metric, use_gpu, device):
    """Exact k-NN search with a flat FAISS index; returns (distances, indices)."""
    import warnings

    import faiss

    database = _as_faiss_array(dist_vec1)
    queries = _as_faiss_array(dist_vec2)
    dim = database.shape[1]

    if metric == 'cosine':
        # Inner product of unit vectors equals cosine similarity.
        faiss.normalize_L2(database)
        faiss.normalize_L2(queries)
        build_index = faiss.IndexFlatIP
    else:
        build_index = faiss.IndexFlatL2

    result = None
    if use_gpu:
        try:
            gpu_res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(gpu_res, 0, build_index(dim))
            index.add(database)
            result = index.search(queries, k)
        except Exception as exc:
            # Best effort: e.g. a CPU-only FAISS build on a CUDA machine.
            # Report it and fall through to the CPU index.
            warnings.warn(
                f"FAISS GPU search failed ({exc!r}); falling back to FAISS CPU",
                RuntimeWarning,
            )

    if result is None:
        index = build_index(dim)
        index.add(database)
        result = index.search(queries, k)

    knn_dists, knn_indices = result

    if metric == 'cosine':
        # IndexFlatIP reports similarities (larger = closer). Convert them to
        # cosine distances so that, as on every other path of get_topk,
        # smaller means closer.
        knn_dists = 1.0 - knn_dists

    if use_gpu:
        # Keep the return type of the GPU configuration consistent (tensors
        # on `device`), whether the GPU or the CPU index answered.
        knn_indices = torch.from_numpy(knn_indices).to(device)
        knn_dists = torch.from_numpy(knn_dists).to(device)

    return knn_dists, knn_indices


def get_topk(dist_vec1, dist_vec2=None, k=5, metric='euclidean', return_dist=False, csls_neighborhood=0, use_faiss=True):
    """
    Get the top k nearest neighbors for each row in dist_vec2 in the space of dist_vec1.

    Automatically detects GPU availability and uses appropriate computation backend:
    - GPU with CUDA: Uses PyTorch tensors on GPU + FAISS GPU (if available)
    - CPU: Uses NumPy arrays + FAISS CPU or NumPy partitioning

    Backend selection:
    - ``csls_neighborhood > 0``: a full distance matrix is computed with
      ``get_dists``, CSLS-corrected, and the k smallest corrected values are
      taken per row.
    - otherwise, with ``use_faiss`` and a supported metric: a flat FAISS
      index is used. If the GPU index cannot be built or searched, a warning
      is issued and the CPU index is used instead.
    - otherwise: a full distance matrix from ``get_dists`` is searched directly.

    Args:
        dist_vec1: Reference vectors (database)
        dist_vec2: Query vectors (if None, uses dist_vec1)
        k: Number of nearest neighbors
        metric: Distance metric ('euclidean', 'cosine')
        return_dist: Whether to return distances along with indices
        csls_neighborhood: CSLS correction neighborhood size (0 to disable)
        use_faiss: Whether to use FAISS for k-NN search when possible

    Returns:
        ``knn_indices`` of shape (n_queries, k), or ``(knn_indices, knn_dists)``
        when ``return_dist`` is True. Results are torch tensors when CUDA is
        available and NumPy arrays otherwise. Rows are ordered from nearest
        to farthest.

        Distance values depend on the backend: on the FAISS path 'euclidean'
        yields squared L2 distances (the IndexFlatL2 convention) and 'cosine'
        yields ``1 - cosine similarity``; on the CSLS path the values are the
        CSLS-corrected distances.
    """
    from utils.graph_util import get_dists

    if dist_vec2 is None:
        dist_vec2 = dist_vec1

    # Auto-detect GPU availability
    use_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0
    device = torch.device('cuda' if use_gpu else 'cpu')

    # Convert to appropriate format based on GPU availability
    if use_gpu:
        # Use GPU with PyTorch tensors
        if isinstance(dist_vec1, np.ndarray):
            dist_vec1 = torch.from_numpy(dist_vec1).float().to(device)
        elif isinstance(dist_vec1, torch.Tensor):
            dist_vec1 = dist_vec1.float().to(device)

        if isinstance(dist_vec2, np.ndarray):
            dist_vec2 = torch.from_numpy(dist_vec2).float().to(device)
        elif isinstance(dist_vec2, torch.Tensor):
            dist_vec2 = dist_vec2.float().to(device)
    else:
        # Use CPU with NumPy arrays
        if isinstance(dist_vec1, torch.Tensor):
            dist_vec1 = dist_vec1.cpu().numpy()
        if isinstance(dist_vec2, torch.Tensor):
            dist_vec2 = dist_vec2.cpu().numpy()

    if csls_neighborhood > 0:
        # CSLS: penalise each similarity by half the mean similarity of the
        # query's and of the reference's own nearest neighbourhoods.
        if use_gpu:
            dists = get_dists(dist_vec2, dist_vec1, metric, use_gpu=True, device=device)
            sim = -dists

            knn_sim_fwd = torch.topk(sim, k=csls_neighborhood, dim=1)[0].mean(dim=1)
            knn_sim_bwd = torch.topk(sim.T, k=csls_neighborhood, dim=1)[0].mean(dim=1)
            sim = sim - knn_sim_fwd.unsqueeze(1)/2 - knn_sim_bwd.unsqueeze(0)/2

            # Convert back to distances
            dists = -sim
            knn_dists, knn_indices = torch.topk(dists, k, largest=False)
        else:
            dists = get_dists(dist_vec2, dist_vec1, metric, use_gpu=False)
            sim = -dists

            knn_sim_fwd = topk_mean(sim, k=csls_neighborhood)
            knn_sim_bwd = topk_mean(sim.T, k=csls_neighborhood)
            sim = sim - knn_sim_fwd[:, np.newaxis]/2 - knn_sim_bwd[np.newaxis, :]/2

            # Convert back to distances
            dists = -sim
            knn_dists, knn_indices = _k_smallest_sorted(dists, k)

    elif use_faiss and metric in ('euclidean', 'cosine'):
        # No CSLS correction - use FAISS for efficient k-NN search
        knn_dists, knn_indices = _faiss_topk(dist_vec1, dist_vec2, k, metric, use_gpu, device)

    else:
        # Manual computation without FAISS
        if use_gpu:
            dists = get_dists(dist_vec2, dist_vec1, metric, use_gpu=True, device=device)
            knn_dists, knn_indices = torch.topk(dists, k, largest=False)
        else:
            dists = get_dists(dist_vec2, dist_vec1, metric, use_gpu=False)
            knn_dists, knn_indices = _k_smallest_sorted(dists, k)

    if return_dist:
        return knn_indices, knn_dists
    return knn_indices

def knn_hit_ratio(dist_vec1, dist_vec2, k=5, metric='euclidean'):
    """
    Fraction of query rows whose own row index is among their k nearest neighbours.

    Row ``i`` of ``dist_vec2`` is searched against ``dist_vec1`` with
    ``get_topk``; it counts as a hit when index ``i`` appears in the returned
    neighbour list.

    Args:
        dist_vec1: Reference vectors (NumPy array or torch tensor).
        dist_vec2: Query vectors, row-aligned with ``dist_vec1``.
        k: Number of nearest neighbours to inspect.
        metric: Distance metric forwarded to ``get_topk``.

    Returns:
        Hit ratio as a Python float in [0, 1]; 0.0 when there are no queries.
    """
    knn_indices = get_topk(dist_vec1, dist_vec2, k, metric)

    # get_topk returns CUDA tensors or NumPy arrays depending on the machine,
    # independent of where the inputs live. Compare in NumPy so the check
    # works for every combination.
    if torch.is_tensor(knn_indices):
        knn_indices = knn_indices.detach().cpu().numpy()
    knn_indices = np.asarray(knn_indices)

    n_queries = knn_indices.shape[0]
    if n_queries == 0:
        return 0.0

    # Check if the correct index is among the k nearest for each row
    expected = np.arange(n_queries)[:, np.newaxis]
    return float((knn_indices == expected).any(axis=1).mean())

def deduplicate_pairs(pairs):
    """
    Keep only the pairs that are the closest match for both of their elements.

    For every first-position element the pair with the smallest distance is
    remembered, and likewise for every second-position element (on equal
    distances the earlier pair wins). A pair survives only when it is the
    remembered pair on both sides.

    Args:
        pairs: Iterable of ``(first, second, distance)`` triples.

    Returns:
        List of surviving ``(first, second, distance)`` tuples, ordered by the
        first appearance of their first-position element.
    """
    closest_for_first = {}   # first element  -> its best (first, second, distance)
    closest_for_second = {}  # second element -> its best (first, second, distance)

    for first, second, distance in pairs:
        record = (first, second, distance)

        incumbent = closest_for_first.get(first)
        if incumbent is None or distance < incumbent[2]:
            closest_for_first[first] = record

        incumbent = closest_for_second.get(second)
        if incumbent is None or distance < incumbent[2]:
            closest_for_second[second] = record

    # A pair is kept when the best pair of its second element is that same pair.
    return [
        record
        for record in closest_for_first.values()
        if closest_for_second.get(record[1]) == record
    ]
        

def _count_correct_pairs(pairs, ind_emb1_unique, ind_emb2_unique):
    """Count pairs (i, j, _) whose original ids agree: ind_emb2_unique[i] == ind_emb1_unique[j]."""
    return sum(1 for i, j, _ in pairs if ind_emb2_unique[i] == ind_emb1_unique[j])


def find_mutual_pairs(dist_vec1, dist_vec2, ind_emb1_unique, ind_emb2_unique, 
                             args, device, use_gpu=True):
    """
    Unified function to find mutual pairs using either GPU-optimized or CPU fallback method.

    For every row ``i`` of ``dist_vec2`` its ``args.topk`` nearest rows of
    ``dist_vec1`` are scanned in order; the first neighbour ``j`` whose own
    ``args.topk`` nearest rows of ``dist_vec2`` contain ``i`` forms the
    candidate pair ``(i, j, distance)``. Candidates are then filtered with
    ``deduplicate_pairs``.

    On the GPU path ``args.distance_metric`` is not forwarded to ``get_topk``:
    both inputs are L2-normalized (unless their mean norms are already within
    0.01 of 1) and searched with ``get_topk``'s default metric. The CPU path
    passes ``args.distance_metric`` through.

    Args:
        dist_vec1, dist_vec2: Distance vectors (NumPy arrays or torch tensors)
        ind_emb1_unique, ind_emb2_unique: Original indices for the rows of
            ``dist_vec1`` / ``dist_vec2``
        args: Arguments containing topk, distance_metric, csls_neighborhood
        device: GPU device (unused for CPU path)
        use_gpu: Whether to use GPU acceleration (also requires CUDA to be available)

    Returns:
        deduplicated_pairs: List of (i, nearest_i, distance) tuples, where
            ``i`` indexes ``dist_vec2`` and ``nearest_i`` indexes ``dist_vec1``
        mutual_nn: Number of deduplicated mutual pairs
        correct: Number of deduplicated pairs whose original indices match
    """
    if use_gpu and torch.cuda.is_available():
        if not torch.is_tensor(dist_vec1):
            dist_vec1 = torch.from_numpy(dist_vec1).to(device).float()
        else:
            dist_vec1 = dist_vec1.to(device)

        if not torch.is_tensor(dist_vec2):
            dist_vec2 = torch.from_numpy(dist_vec2).to(device).float()
        else:
            dist_vec2 = dist_vec2.to(device)

        # Skip redundant normalization: if the mean norm of both inputs is
        # close to 1.0 the vectors are treated as already normalized.
        vec1_norms = torch.norm(dist_vec1, p=2, dim=1)
        vec2_norms = torch.norm(dist_vec2, p=2, dim=1)
        if not (torch.abs(vec1_norms.mean() - 1.0) < 0.01 and torch.abs(vec2_norms.mean() - 1.0) < 0.01):
            dist_vec1 = torch.nn.functional.normalize(dist_vec1, p=2, dim=1)
            dist_vec2 = torch.nn.functional.normalize(dist_vec2, p=2, dim=1)

        nearest_ind1, nearest_dist1 = get_topk(dist_vec1, dist_vec2, k=args.topk,
                          return_dist=True,
                          csls_neighborhood=args.csls_neighborhood, use_faiss=True)
        nearest_ind2, _ = get_topk(dist_vec2, dist_vec1, k=args.topk,
                                return_dist=True,
                                csls_neighborhood=args.csls_neighborhood, use_faiss=True)

        n_points = dist_vec2.shape[0]

        # Ensure the k-NN results are tensors before the vectorized search
        if not torch.is_tensor(nearest_ind1):
            nearest_ind1 = torch.from_numpy(nearest_ind1).to(device)
        if not torch.is_tensor(nearest_ind2):
            nearest_ind2 = torch.from_numpy(nearest_ind2).to(device)
        if not torch.is_tensor(nearest_dist1):
            nearest_dist1 = torch.from_numpy(nearest_dist1).to(device)

        # Detect device from actual tensors (important for multi-GPU setups)
        actual_device = nearest_ind1.device

        # Move to same device if needed
        if nearest_ind2.device != actual_device:
            nearest_ind2 = nearest_ind2.to(actual_device)
        if nearest_dist1.device != actual_device:
            nearest_dist1 = nearest_dist1.to(actual_device)

        # nearest_ind1[i, k] is the k-th nearest neighbor of point i; the pair
        # is mutual when i appears in nearest_ind2[nearest_ind1[i, k]].
        idx = torch.arange(n_points, device=actual_device)
        idx_expanded = idx.unsqueeze(1).expand(-1, args.topk)  # (n_points, topk)

        # Gather the neighbor lists of all neighbors at once
        neighbors_flat = nearest_ind1.flatten()  # (n_points * topk,)
        neighbor_lists = nearest_ind2[neighbors_flat]  # (n_points * topk, topk)
        neighbor_lists = neighbor_lists.reshape(n_points, args.topk, args.topk)  # (n_points, topk, topk)

        # idx_expanded: (n_points, topk, 1) vs neighbor_lists: (n_points, topk, topk)
        is_mutual = (idx_expanded.unsqueeze(2) == neighbor_lists).any(dim=2)  # (n_points, topk)

        # For each point, find the first mutual neighbor (same result as
        # scanning neighbors in order and stopping at the first match);
        # non-mutual slots get the sentinel value topk.
        has_mutual = is_mutual.any(dim=1)  # (n_points,)
        first_mutual_k = torch.where(is_mutual,
                                      torch.arange(args.topk, device=actual_device).unsqueeze(0).expand(n_points, -1),
                                      torch.full((n_points, args.topk), args.topk, device=actual_device)).min(dim=1)[0]

        # Only points that have a mutual neighbor produce a pair
        valid_points = torch.where(has_mutual)[0]
        valid_k = first_mutual_k[has_mutual]

        i_indices = valid_points  # First element of pair
        j_indices = nearest_ind1[valid_points, valid_k]  # Second element of pair
        distances = nearest_dist1[valid_points, valid_k]  # Distance between them

        # One GPU->CPU transfer per tensor instead of one per pair
        candidate_pairs = list(zip(i_indices.cpu().numpy(),
                                   j_indices.cpu().numpy(),
                                   distances.cpu().numpy()))
    else:
        # CPU fallback
        nearest_ind1, nearest_dist1 = get_topk(dist_vec1, dist_vec2, k=args.topk,
                                   metric=args.distance_metric, return_dist=True,
                                   csls_neighborhood=args.csls_neighborhood)
        nearest_ind2, _ = get_topk(dist_vec2, dist_vec1, k=args.topk,
                                   metric=args.distance_metric, return_dist=True,
                                   csls_neighborhood=args.csls_neighborhood)

        # Convert tensors to numpy ONCE before the loop to avoid per-iteration .item() overhead
        if torch.is_tensor(nearest_ind1):
            nearest_ind1 = nearest_ind1.cpu().numpy()
        if torch.is_tensor(nearest_ind2):
            nearest_ind2 = nearest_ind2.cpu().numpy()
        if torch.is_tensor(nearest_dist1):
            nearest_dist1 = nearest_dist1.cpu().numpy()

        candidate_pairs = []
        for i, neighbors_of_i in enumerate(nearest_ind1):
            for k_idx in range(args.topk):
                nearest_i = int(neighbors_of_i[k_idx])
                if i in nearest_ind2[nearest_i]:
                    # (ind2, ind1, dist)
                    candidate_pairs.append((i, nearest_i, float(nearest_dist1[i, k_idx])))
                    break  # keep only the first mutual neighbor of i

    deduplicated_pairs = deduplicate_pairs(candidate_pairs)

    # Correctness is counted on the deduplicated pairs only.
    correct = _count_correct_pairs(deduplicated_pairs, ind_emb1_unique, ind_emb2_unique)

    return deduplicated_pairs, len(deduplicated_pairs), correct


def weighted_pearson_correlation(x, y, weights):
    """
    Weighted Pearson correlation between two equally long vectors.

    Every position contributes to the means, variances and covariance in
    proportion to its weight (weights are rescaled to sum to 1 first).

    Args:
        x, y: Arrays of same length (distance profiles)
        weights: Array of weights, one per position

    Returns:
        The weighted correlation coefficient, or 0.0 for degenerate input:
        a weight sum below 1e-10, or a product of weighted variances whose
        square root is below 1e-10.
    """
    xs = np.asarray(x)
    ys = np.asarray(y)
    raw_weights = np.asarray(weights)

    total_weight = np.sum(raw_weights)
    if total_weight < 1e-10:
        return 0.0
    norm_w = raw_weights / total_weight

    # Deviations from the weighted means
    dx = xs - np.sum(norm_w * xs)
    dy = ys - np.sum(norm_w * ys)

    covariance = np.sum(norm_w * dx * dy)
    variance_x = np.sum(norm_w * dx ** 2)
    variance_y = np.sum(norm_w * dy ** 2)

    scale = np.sqrt(variance_x * variance_y)
    if scale < 1e-10:
        return 0.0
    return covariance / scale


def compute_anchor_weights(dist_profiles1, dist_profiles2, method="rbf", params=None):
    """
    Derive one weight per anchor from how far points are from it on average.

    The per-anchor mean distance is taken in each space separately and the two
    means are averaged. Every non-uniform method maps a smaller average
    distance to a larger weight, following the idea that short distances are
    better preserved across embedding spaces.

    Args:
        dist_profiles1: Raw (unnormalized) distance profiles of space 1, shape (n1, k)
        dist_profiles2: Raw (unnormalized) distance profiles of space 2, shape (n2, k)
        method: How the average distance d is turned into a weight:
            - "rbf": exp(-d / sigma) (default)
            - "inverse": 1 / (1 + scale * d)
            - "sigmoid": 1 / (1 + exp(scale * (d - midpoint)))
            - "uniform" (or any unrecognized name): all weights equal to 1
        params: Optional dict of method-specific settings:
            - "rbf": {"sigma": float}; when missing, the median of the positive
              average distances is used (1.0 if none are positive)
            - "inverse": {"scale": float}, default 1.0
            - "sigmoid": {"scale": float, "midpoint": float}; scale defaults to
              1.0, midpoint to the median of the positive average distances
              (0.5 if none are positive)

    Returns:
        weights: Array of shape (k,), one weight per anchor
    """
    options = params or {}

    # Mean distance to each anchor, per space, then the midpoint of the two.
    mean_dist1 = np.mean(np.asarray(dist_profiles1), axis=0)  # (k,)
    mean_dist2 = np.mean(np.asarray(dist_profiles2), axis=0)  # (k,)
    anchor_dist = (mean_dist1 + mean_dist2) / 2

    if method == "rbf":
        # Exponential decay: weight ~1 for near anchors, ~0 for far ones.
        sigma = options.get("sigma")
        if sigma is None:
            # Median heuristic over strictly positive distances.
            nonzero = anchor_dist[anchor_dist > 0]
            sigma = np.median(nonzero) if len(nonzero) > 0 else 1.0
        # Floor sigma so the division below cannot blow up.
        return np.exp(-anchor_dist / max(sigma, 1e-8))

    if method == "inverse":
        return 1.0 / (1.0 + options.get("scale", 1.0) * anchor_dist)

    if method == "sigmoid":
        # Soft threshold centred on `center`.
        steepness = options.get("scale", 1.0)
        center = options.get("midpoint")
        if center is None:
            if np.any(anchor_dist > 0):
                center = np.median(anchor_dist[anchor_dist > 0])
            else:
                center = 0.5
        return 1.0 / (1.0 + np.exp(steepness * (anchor_dist - center)))

    # "uniform" and every unrecognized method: no distance weighting.
    return np.ones(len(anchor_dist))


def compute_overlap_confidence(dist_vec1, dist_vec2, mutual_pairs, return_per_pair=False,
                               n_random_trials=10, seed=42,
                               use_distance_weighting=False,
                               weighting_method="rbf",
                               weighting_params=None,
                               raw_dist_profiles=None):
    """
    Compute confidence that mutual pairs represent true structural overlap.

    For truly corresponding pairs, their distance profiles to reference points
    should be correlated across the two embedding spaces. The mean correlation
    of the matched pairs is compared against a random baseline (the same
    profiles with the space-2 side shuffled) to detect spurious correlations
    that come from similar geometry alone.

    For each mutual pair (i, j, dist):
    - i indexes into dist_vec2 (embedding space 2)
    - j indexes into dist_vec1 (embedding space 1)
    - the correlation is taken between dist_vec1[j] and dist_vec2[i]

    Optionally uses distance-weighted correlation, where anchors with shorter
    average distances contribute more to the correlation.

    Args:
        dist_vec1: Distance vectors from emb1 to references (shape: n_emb1, n_refs);
            numpy array or torch tensor
        dist_vec2: Distance vectors from emb2 to references (shape: n_emb2, n_refs);
            numpy array or torch tensor
        mutual_pairs: List of (i, j, dist) tuples from find_mutual_pairs
        return_per_pair: If True, also return the per-pair correlations
        n_random_trials: Number of random shuffles for the baseline
        seed: Random seed for reproducibility
        use_distance_weighting: If True, weight correlation by anchor distances.
            Anchors with shorter average distances get higher weights.
        weighting_method: Method for computing weights from distances:
            - "rbf": exp(-d / sigma) - exponential decay (default)
            - "inverse": 1 / (1 + scale * d)
            - "sigmoid": 1 / (1 + exp(scale * (d - midpoint)))
            - "uniform": equal weights
        weighting_params: Dict of method-specific parameters
        raw_dist_profiles: Tuple (raw_dist1, raw_dist2) of unnormalized distance
            profiles for computing weights. Needed for use_distance_weighting=True;
            when it is missing a RuntimeWarning is issued and the plain
            (unweighted) Pearson correlation is used instead.

    Returns:
        mean_correlation: Average correlation across pairs
        std_correlation: Standard deviation of correlations
        correlation_lift: Difference from random baseline (key metric!)
        random_baseline: Mean correlation of random pairings
        (optional) per_pair_correlations: List of correlation values. Pairs
            whose correlation is NaN (e.g. a constant profile) are dropped, so
            this list can be shorter than mutual_pairs.

        All-zero values (and an empty list) are returned when there are no
        pairs or no pair yields a finite correlation.
    """
    def _degenerate_result():
        # Shared "nothing to measure" answer, shaped by return_per_pair.
        zeros = (0.0, 0.0, 0.0, 0.0)
        return (zeros + ([],)) if return_per_pair else zeros

    if len(mutual_pairs) == 0:
        return _degenerate_result()

    # Convert to numpy if needed
    if torch.is_tensor(dist_vec1):
        dist_vec1 = dist_vec1.cpu().numpy()
    if torch.is_tensor(dist_vec2):
        dist_vec2 = dist_vec2.cpu().numpy()

    # Compute anchor weights if distance weighting is enabled
    anchor_weights = None
    if use_distance_weighting:
        if raw_dist_profiles is None:
            # Previously this case was silently ignored; keep the unweighted
            # fallback but make it visible to the caller.
            import warnings
            warnings.warn(
                "use_distance_weighting=True but raw_dist_profiles is None; "
                "falling back to unweighted Pearson correlation",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            raw_dist1, raw_dist2 = raw_dist_profiles
            if torch.is_tensor(raw_dist1):
                raw_dist1 = raw_dist1.cpu().numpy()
            if torch.is_tensor(raw_dist2):
                raw_dist2 = raw_dist2.cpu().numpy()
            anchor_weights = compute_anchor_weights(
                raw_dist1, raw_dist2,
                method=weighting_method,
                params=weighting_params
            )

    def _profile_correlation(profile1, profile2):
        # One scoring rule for both the matched pairs and the shuffled
        # baseline, so the two are directly comparable.
        if anchor_weights is not None:
            return weighted_pearson_correlation(profile1, profile2, anchor_weights)
        # A constant profile makes corrcoef divide 0 by 0; the resulting NaN
        # is filtered by the caller, so the numpy warning is just noise.
        with np.errstate(divide="ignore", invalid="ignore"):
            return np.corrcoef(profile1, profile2)[0, 1]

    def _finite_correlations(rows2, rows1):
        # Correlate dist_vec1[j] with dist_vec2[i] for each (i, j); drop NaNs.
        scores = (
            _profile_correlation(dist_vec1[j], dist_vec2[i])
            for i, j in zip(rows2, rows1)
        )
        return [score for score in scores if not np.isnan(score)]

    # i is index in emb2, j is index in emb1
    indices_i = [i for i, _, _ in mutual_pairs]
    indices_j = [j for _, j, _ in mutual_pairs]

    correlations = _finite_correlations(indices_i, indices_j)
    if not correlations:
        return _degenerate_result()

    mean_corr = np.mean(correlations)
    std_corr = np.std(correlations)

    # Random baseline: shuffle which space-2 profile each space-1 profile is
    # compared with, and average the resulting correlations per trial.
    rng = np.random.RandomState(seed)
    random_correlations = []
    for _ in range(n_random_trials):
        trial_corrs = _finite_correlations(rng.permutation(indices_i), indices_j)
        if trial_corrs:
            random_correlations.append(np.mean(trial_corrs))

    random_baseline = np.mean(random_correlations) if random_correlations else 0.0
    correlation_lift = mean_corr - random_baseline

    if return_per_pair:
        return mean_corr, std_corr, correlation_lift, random_baseline, correlations
    return mean_corr, std_corr, correlation_lift, random_baseline


def interpret_overlap_confidence(mean_corr, std_corr, n_pairs, correlation_lift=None, random_baseline=None):
    """
    Turn overlap-confidence statistics into a decision, a level and a message.

    When both correlation_lift and random_baseline are supplied, the decision
    is driven by the lift (actual correlation minus random baseline):
    - lift > 0.15: "high" (overlap)
    - 0.02 < lift <= 0.15: "medium" (overlap)
    - 0.01 < lift <= 0.02: "low" (no overlap; pairs may be spurious)
    - lift <= 0.01: "none" (no overlap)

    Otherwise the legacy rule is applied to the raw mean correlation with
    cut-offs 0.6 / 0.4 / 0.2 for "high" / "medium" / "low".

    Args:
        mean_corr: Mean correlation from compute_overlap_confidence
        std_corr: Standard deviation of correlations
        n_pairs: Number of pairs analyzed
        correlation_lift: Difference from random baseline (primary metric)
        random_baseline: Mean correlation of random pairings

    Returns:
        has_overlap: bool - True only for the "high" and "medium" levels
        confidence_level: str - "high", "medium", "low", "none"
        message: str - human-readable interpretation
    """
    lift_mode = correlation_lift is not None and random_baseline is not None

    if lift_mode:
        score = correlation_lift
        high_cut, medium_cut, low_cut = 0.15, 0.02, 0.01
        lift_details = f", lift={correlation_lift:.3f}, baseline={random_baseline:.3f}"
    else:
        # Legacy behavior for backwards compatibility: judge the raw correlation.
        score = mean_corr
        high_cut, medium_cut, low_cut = 0.6, 0.4, 0.2
        lift_details = ""

    # Parenthesised statistics shared by every message variant.
    stats = f"(corr={mean_corr:.3f}±{std_corr:.3f}{lift_details}, n={n_pairs})"

    if score > high_cut:
        return True, "high", f"Strong overlap signal {stats}"
    if score > medium_cut:
        return True, "medium", f"Moderate overlap signal {stats}"
    if score > low_cut:
        return False, "low", f"Weak overlap signal {stats}, pairs may be spurious"
    return False, "none", f"No significant overlap detected {stats}"


def _find_mutual_pairs_cpu(dist_profiles1, dist_profiles2, indices1, indices2, topk=5):
    """
    CPU-only mutual pair finding using scipy (no faiss dependency).

    Uses cosine distance on normalized distance profiles.
    """
    # Compute pairwise distances between profiles
    # Since profiles are normalized, use euclidean which is equivalent to cosine
    pairwise_dists = cdist(dist_profiles1, dist_profiles2, metric='euclidean')

    n1, n2 = pairwise_dists.shape

    # Find top-k nearest neighbors from 1 -> 2
    nearest_from_1 = np.argsort(pairwise_dists, axis=1)[:, :topk]  # (n1, topk)

    # Find top-k nearest neighbors from 2 -> 1
    nearest_from_2 = np.argsort(pairwise_dists.T, axis=1)[:, :topk]  # (n2, topk)

    # Find mutual pairs
    mutual_pairs = []
    for i in range(n1):
        for rank, j in enumerate(nearest_from_1[i]):
            # Check if i is in j's top-k neighbors
            if i in nearest_from_2[j]:
                dist = pairwise_dists[i, j]
                mutual_pairs.append((i, j, dist))
                break  # Only take best mutual match per point

    return mutual_pairs


@dataclass
class OverlapDetectionResult:
    """Result of overlap detection between two embedding sets.

    Returned by detect_overlap_with_anchors. When detection stops early
    (too few anchors, no non-anchor points, or too few mutual pairs) the
    numeric metrics are 0.0, confidence_level is "none" and `message`
    states the reason.
    """
    # Final decision: True only for the "high" and "medium" confidence levels.
    has_overlap: bool
    confidence_level: str  # "high", "medium", "low", "none"
    # Mean pair correlation minus the random-pairing baseline (primary metric).
    correlation_lift: float
    # Mean correlation obtained from randomly shuffled pairings.
    random_baseline: float
    # Mean and standard deviation of the per-pair profile correlations.
    mean_correlation: float
    std_correlation: float
    # Number of mutual nearest-neighbour pairs that were found.
    n_mutual_pairs: int
    # detect_overlap_with_anchors divides by the smaller of the two
    # non-anchor point counts.
    mutual_pair_density: float  # n_mutual_pairs / n_non_anchor_points
    estimated_overlap_pct: float  # Estimated overlap percentage (0-100)
    message: str  # Human-readable interpretation


def estimate_overlap_from_lift(lift: float, calibration: str = "real") -> float:
    """
    Estimate overlap percentage from measured correlation lift.

    The estimate inverts a linear calibration `lift = a * overlap_pct + b`.
    Two calibration options are available:

    1. "real" (default): Calibrated on real embedding model pairs (6 models, 5 datasets)
        lift = 0.00534 * overlap_pct + 0.0100
        overlap_pct = (lift - 0.0100) / 0.00534
        R² = 0.85, MAE = 8.7%

    2. "synthetic": Original calibration from synthetic experiments
        lift = 0.0102 * overlap_pct - 0.0178
        overlap_pct = (lift + 0.0178) / 0.0102
        R² = 0.998 (on synthetic data only)

    Args:
        lift: Measured correlation lift
        calibration: "synthetic" for the original synthetic calibration; any
                    other value (including the default "real") selects the
                    real-embedding calibration (recommended)

    Returns:
        Estimated overlap percentage as a float clamped to [0, 100].
        A NaN lift carries no evidence of overlap and yields 0.0.
    """
    if calibration == "synthetic":
        # Coefficients from synthetic overlap experiments (R² = 0.998)
        slope, intercept = 0.0102, -0.0178
    else:  # "real" - default
        # Coefficients from real embedding experiments (R² = 0.85)
        # Calibrated on: kalm, mistral, e5, openai, gritlm, gte models
        # Datasets: arguana, scifact, fiqa, nfcorpus, scidocs
        slope, intercept = 0.00534, 0.0100

    # min()/max() do not order NaN: min(100.0, nan) returns 100.0, which used
    # to report an undefined lift as 100% overlap. Handle it explicitly.
    if np.isnan(lift):
        return 0.0

    estimated_pct = (lift - intercept) / slope
    return float(max(0.0, min(100.0, estimated_pct)))


def detect_overlap_with_anchors(
    emb1: np.ndarray,
    emb2: np.ndarray,
    anchor_indices1: np.ndarray,
    anchor_indices2: np.ndarray,
    distance_metric: str = 'cosine',
    topk: int = 5,
    min_mutual_pairs: int = 10,
    correlation_lift_threshold: float = 0.02,
    use_gpu: bool = False,
    device: torch.device = None,
    verbose: bool = True,
    use_distance_weighting: bool = False,
    weighting_method: str = "rbf",
    weighting_params: dict = None
) -> OverlapDetectionResult:
    """
    Decide whether two embedding sets share meaningful overlap, using only a
    set of supervised anchor pairs.

    Steps:
    1. Describe every non-anchor point by its distances to the anchors of its
       own space (its "distance profile") and L2-normalize each profile.
    2. Find mutual nearest neighbours between the two sets of profiles.
    3. Measure how much better the matched profiles correlate than randomly
       shuffled ones (the correlation lift).
    4. Map the lift to a decision, a confidence level and an overlap estimate.

    Args:
        emb1: First embedding set (n1, d1)
        emb2: Second embedding set (n2, d2)
        anchor_indices1: Indices of anchors in emb1 (k,)
        anchor_indices2: Corresponding anchor indices in emb2 (k,)
        distance_metric: 'cosine' or 'euclidean'; used for point-to-anchor distances
        topk: k for mutual k-NN matching
        min_mutual_pairs: Fewer mutual pairs than this yields a "none" result
        correlation_lift_threshold: Lift above which overlap is reported
            ("medium"); lifts above 0.15 are always "high"
        use_gpu: Use the GPU path (only taken when `device` is also given)
        device: PyTorch device
        verbose: Print diagnostic information
        use_distance_weighting: If True, weight correlation by anchor distances
            so that closer anchors contribute more
        weighting_method: Method for computing weights ("rbf", "inverse", "sigmoid", "uniform")
        weighting_params: Dict of method-specific parameters (e.g., {"sigma": 0.5})

    Returns:
        OverlapDetectionResult containing detection decision and metrics

    Raises:
        ValueError: If the two anchor index arrays differ in length
    """
    def _no_overlap(reason, n_pairs=0, density=0.0):
        # Every early exit reports the same zeroed metrics; only the reason
        # and the pair statistics gathered so far differ.
        return OverlapDetectionResult(
            has_overlap=False,
            confidence_level="none",
            correlation_lift=0.0,
            random_baseline=0.0,
            mean_correlation=0.0,
            std_correlation=0.0,
            n_mutual_pairs=n_pairs,
            mutual_pair_density=density,
            estimated_overlap_pct=0.0,
            message=reason
        )

    # --- Input validation -------------------------------------------------
    anchor_indices1 = np.asarray(anchor_indices1)
    anchor_indices2 = np.asarray(anchor_indices2)
    n_anchors = len(anchor_indices1)

    if len(anchor_indices2) != n_anchors:
        raise ValueError("anchor_indices1 and anchor_indices2 must have same length")
    if n_anchors < 3:
        return _no_overlap(f"Insufficient anchors ({n_anchors} < 3)")

    # --- Split each set into anchors and everything else -----------------
    rows1 = np.arange(len(emb1))
    rows2 = np.arange(len(emb2))
    non_anchor_indices1 = rows1[np.isin(rows1, anchor_indices1, invert=True)]
    non_anchor_indices2 = rows2[np.isin(rows2, anchor_indices2, invert=True)]

    if 0 in (len(non_anchor_indices1), len(non_anchor_indices2)):
        return _no_overlap("No non-anchor points to analyze")

    anchors1, anchors2 = emb1[anchor_indices1], emb2[anchor_indices2]            # (k, d)
    points1, points2 = emb1[non_anchor_indices1], emb2[non_anchor_indices2]      # (n-k, d)

    # The GPU path needs both the flag and a concrete device.
    on_gpu = use_gpu and device is not None

    # --- Distance profiles: non-anchor points vs. their own anchors -------
    if on_gpu:
        from utils.graph_util import get_dists

        def _on_device(array):
            return torch.tensor(array, device=device, dtype=torch.float32)

        t_points1, t_anchors1, t_points2, t_anchors2 = (
            _on_device(array) for array in (points1, anchors1, points2, anchors2)
        )
        dist_profiles1 = get_dists(t_points1, t_anchors1, metric=distance_metric,
                                   use_gpu=True, device=device)
        dist_profiles2 = get_dists(t_points2, t_anchors2, metric=distance_metric,
                                   use_gpu=True, device=device)
    else:
        # scipy keeps the CPU path free of the faiss dependency.
        dist_profiles1 = cdist(points1, anchors1, metric=distance_metric)
        dist_profiles2 = cdist(points2, anchors2, metric=distance_metric)

    # Keep unnormalized copies: anchor weights are derived from raw distances.
    raw_dist_profiles = None
    if use_distance_weighting:
        if torch.is_tensor(dist_profiles1):
            raw_dist_profiles = (
                dist_profiles1.cpu().numpy().copy(),
                dist_profiles2.cpu().numpy().copy(),
            )
        else:
            raw_dist_profiles = (dist_profiles1.copy(), dist_profiles2.copy())

    # L2-normalize each profile so profiles are compared by shape, not scale.
    if torch.is_tensor(dist_profiles1):
        dist_profiles1 = torch.nn.functional.normalize(dist_profiles1, p=2, dim=1)
        dist_profiles2 = torch.nn.functional.normalize(dist_profiles2, p=2, dim=1)
    else:
        row_norms1 = np.linalg.norm(dist_profiles1, axis=1, keepdims=True) + 1e-8
        row_norms2 = np.linalg.norm(dist_profiles2, axis=1, keepdims=True) + 1e-8
        dist_profiles1 = dist_profiles1 / row_norms1
        dist_profiles2 = dist_profiles2 / row_norms2

    # --- Mutual nearest neighbours in profile space -----------------------
    # NOTE(review): compute_overlap_confidence reads each pair as
    # (index into dist_vec2, index into dist_vec1, dist) - confirm that both
    # pair finders below emit pairs in that order.
    if on_gpu:
        # GPU-accelerated path with faiss
        knn_args = argparse.Namespace(
            topk=topk,
            distance_metric='cosine',
            csls_neighborhood=0
        )
        mutual_pairs, _, _ = find_mutual_pairs(
            dist_profiles1,
            dist_profiles2,
            non_anchor_indices1,
            non_anchor_indices2,
            knn_args,
            device,
            use_gpu
        )
    else:
        # CPU-only path using scipy (no faiss dependency)
        mutual_pairs = _find_mutual_pairs_cpu(
            dist_profiles1, dist_profiles2,
            non_anchor_indices1, non_anchor_indices2,
            topk=topk
        )

    n_pairs = len(mutual_pairs)
    n_non_anchors = min(len(non_anchor_indices1), len(non_anchor_indices2))
    mutual_pair_density = n_pairs / n_non_anchors if n_non_anchors > 0 else 0.0

    if verbose:
        print(f"[Overlap Detection] Found {n_pairs} mutual pairs "
              f"({mutual_pair_density:.1%} of {n_non_anchors} non-anchor points)")

    # Too few pairs make the correlation statistics unreliable.
    if n_pairs < min_mutual_pairs:
        return _no_overlap(
            f"Insufficient mutual pairs ({n_pairs} < {min_mutual_pairs})",
            n_pairs,
            mutual_pair_density,
        )

    # --- Correlation lift over the random baseline ------------------------
    mean_corr, std_corr, correlation_lift, random_baseline = compute_overlap_confidence(
        dist_profiles1,
        dist_profiles2,
        mutual_pairs,
        return_per_pair=False,
        n_random_trials=10,
        use_distance_weighting=use_distance_weighting,
        weighting_method=weighting_method,
        weighting_params=weighting_params,
        raw_dist_profiles=raw_dist_profiles
    )

    if verbose:
        if use_distance_weighting:
            print(f"[Overlap Detection] Using distance-weighted correlation (method={weighting_method})")
        print(f"[Overlap Detection] Correlation: {mean_corr:.3f} +/- {std_corr:.3f}, "
              f"Lift: {correlation_lift:.3f}, Baseline: {random_baseline:.3f}")

    # --- Decision tiers, checked from strongest to weakest ----------------
    if correlation_lift > 0.15:
        has_overlap, confidence_level = True, "high"
    elif correlation_lift > correlation_lift_threshold:
        has_overlap, confidence_level = True, "medium"
    elif correlation_lift > 0.01:
        has_overlap, confidence_level = False, "low"
    else:
        has_overlap, confidence_level = False, "none"

    # Only the human-readable text is taken from the interpreter helper.
    message = interpret_overlap_confidence(
        mean_corr, std_corr, n_pairs,
        correlation_lift=correlation_lift,
        random_baseline=random_baseline
    )[2]

    estimated_overlap_pct = estimate_overlap_from_lift(correlation_lift)

    if verbose:
        print(f"[Overlap Detection] Estimated overlap: {estimated_overlap_pct:.1f}%")

    return OverlapDetectionResult(
        has_overlap=has_overlap,
        confidence_level=confidence_level,
        correlation_lift=correlation_lift,
        random_baseline=random_baseline,
        mean_correlation=mean_corr,
        std_correlation=std_corr,
        n_mutual_pairs=n_pairs,
        mutual_pair_density=mutual_pair_density,
        estimated_overlap_pct=estimated_overlap_pct,
        message=message
    )
