"""
Ensemble-based reference selection methods for cross-lingual alignment.

This module provides different ensemble selection strategies including:
- Standard voting-based ensemble selection
- Bernoulli trial-based ensemble selection with posterior distributions

Note: voting matrices are maintained in-memory as scipy.sparse matrices (no save/load).
"""

import numpy as np
import torch
import argparse
import time
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
from loguru import logger
from scipy.stats import beta as beta_dist
from scipy.sparse import lil_matrix, csr_matrix

from graph_utils.distance_encoder import compute_distance_encoding
from utils.retrieval_util import find_mutual_pairs, deduplicate_pairs
from utils.clustering import Clusterer
from utils.graph_util import get_dists
from utils.memory_util import (
    estimate_matrix_memory_gb,
    get_available_memory_gb,
    log_memory_usage
)
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.mixture import GaussianMixture


# =============================================================================
# ADAPTIVE OVERLAP INFERENCE METHODS
# =============================================================================
# These methods infer the number of true overlapping pairs from the posterior
# probability distribution, instead of using a fixed threshold.

def estimate_overlap_otsu(posterior_means: np.ndarray) -> tuple:
    """
    Use Otsu's method to find the threshold that maximizes inter-class variance.

    The posterior means are binned into a 100-bin histogram over [0, 1]
    (values outside that range are ignored by the histogram). For every
    candidate split "bins 0..t vs. bins t+1..end" the inter-class variance is
    computed, and the split with the largest variance wins (first one on ties).

    The returned threshold is the UPPER EDGE of bin t, so that
    ``posterior_means > threshold`` selects exactly the values that fall in
    the upper class of the optimized partition.

    Args:
        posterior_means: Array of posterior mean probabilities for each pair

    Returns:
        threshold: Optimal threshold value (float)
        n_selected: Number of pairs strictly above threshold (int)
        method_info: Dict with diagnostic information
    """
    posterior_means = np.asarray(posterior_means)
    if len(posterior_means) == 0:
        return 0.5, 0, {'method': 'otsu', 'status': 'empty_input'}

    # Discretize to histogram (100 bins between 0 and 1)
    n_bins = 100
    hist, bin_edges = np.histogram(posterior_means, bins=n_bins, range=(0, 1))
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Normalize histogram to a probability mass function
    hist = hist.astype(float)
    total = hist.sum()
    if total == 0:
        # Nothing landed inside [0, 1]
        return 0.5, 0, {'method': 'otsu', 'status': 'zero_histogram'}
    hist = hist / total

    # Cumulative class weight and cumulative first moment per candidate split
    w0 = np.cumsum(hist)   # Weight of class 0 (bins 0..t)
    w1 = 1 - w0            # Weight of class 1 (bins t+1..end)
    cumsum_mean = np.cumsum(hist * bin_centers)
    global_mean = cumsum_mean[-1]

    # Splits that leave one class (numerically) empty keep variance 0
    valid = (w0 >= 1e-10) & (w1 >= 1e-10)
    inter_class_var = np.zeros(n_bins)
    mu0 = cumsum_mean[valid] / w0[valid]
    mu1 = (global_mean - cumsum_mean[valid]) / w1[valid]
    inter_class_var[valid] = w0[valid] * w1[valid] * (mu0 - mu1) ** 2

    # Find the split with maximum inter-class variance
    optimal_idx = int(np.argmax(inter_class_var))

    # Class 0 contains the whole bin `optimal_idx`, so cut at its upper edge
    # (cutting at the bin center would leak the upper half of that bin into
    # the selected set).
    threshold = float(bin_edges[optimal_idx + 1])
    n_selected = int(np.sum(posterior_means > threshold))

    return threshold, n_selected, {
        'method': 'otsu',
        'inter_class_variance': float(inter_class_var[optimal_idx]),
        'status': 'success'
    }


def _weighted_gaussian_boundary(mu_t, std_t, w_t, mu_f, std_f, w_f):
    """Return the x in (mu_f, mu_t) where w_t*N(mu_t, std_t) == w_f*N(mu_f, std_f).

    Falls back to the midpoint of the two means when no such point exists
    between them or the parameters are degenerate.
    """
    midpoint = float((mu_t + mu_f) / 2)
    if min(std_t, std_f, w_t, w_f) <= 0:
        return midpoint

    # Equating the two weighted log-densities gives a*x^2 + b*x + c = 0 with:
    a = 1 / (2 * std_f**2) - 1 / (2 * std_t**2)
    b = mu_t / std_t**2 - mu_f / std_f**2
    c = (mu_f**2 / (2 * std_f**2) - mu_t**2 / (2 * std_t**2)
         + np.log((std_f * w_t) / (std_t * w_f)))

    if abs(a) > 1e-10:
        discriminant = b**2 - 4 * a * c
        if discriminant < 0:
            return midpoint
        root = np.sqrt(discriminant)
        candidates = [(-b + root) / (2 * a), (-b - root) / (2 * a)]
    elif abs(b) > 1e-10:
        # Equal variances: the equation degenerates to a linear one
        candidates = [-c / b]
    else:
        return midpoint

    # Only an intersection that lies between the two means is a usable cut
    for x in candidates:
        if mu_f < x < mu_t:
            return float(x)
    return midpoint


def estimate_overlap_gmm(posterior_means: np.ndarray, n_components: int = 2) -> tuple:
    """
    Use Gaussian Mixture Model to separate true pairs from false pairs.

    Fits an ``n_components`` GMM to the posterior means. The component with
    the highest mean represents true pairs; ``n_selected`` is the number of
    samples that ``GaussianMixture.predict`` assigns to that component.

    The reported threshold is the point between the "true" component and the
    component with the next-highest mean where their weighted densities are
    equal (i.e. where the model is indifferent between the two). If no such
    point lies between the two means, the midpoint of the means is used.

    Args:
        posterior_means: Array of posterior mean probabilities for each pair
        n_components: Number of mixture components (default: 2 for true/false)

    Returns:
        threshold: Estimated threshold (intersection of the two weighted Gaussians)
        n_selected: Number of pairs classified as "true" (int)
        method_info: Dict with diagnostic information
    """
    posterior_means = np.asarray(posterior_means)
    if len(posterior_means) < n_components * 2:
        return 0.5, 0, {'method': 'gmm', 'status': 'insufficient_data'}

    if n_components < 2:
        # A true/false split needs at least two components
        message = f"n_components must be >= 2, got {n_components}"
        logger.warning(f"GMM fitting failed: {message}")
        return 0.5, 0, {'method': 'gmm', 'status': f'failed: {message}'}

    X = posterior_means.reshape(-1, 1)

    try:
        gmm = GaussianMixture(n_components=n_components, random_state=42, max_iter=200)
        gmm.fit(X)

        # Get component parameters (1-D data, so one variance per component)
        means = gmm.means_.flatten()
        stds = np.sqrt(gmm.covariances_.flatten())
        weights = gmm.weights_

        # "True" component = highest mean; its competitor = next-highest mean.
        # For two components this is simply "the other one".
        true_component = int(np.argmax(means))
        others = np.delete(np.arange(len(means)), true_component)
        false_component = int(others[np.argmax(means[others])])

        # Count pairs assigned to the true component
        labels = gmm.predict(X)
        n_selected = int(np.sum(labels == true_component))

        if means[true_component] > means[false_component]:
            threshold = _weighted_gaussian_boundary(
                means[true_component], stds[true_component], weights[true_component],
                means[false_component], stds[false_component], weights[false_component])
        else:
            # Components are not separated at all
            threshold = 0.5

        return threshold, n_selected, {
            'method': 'gmm',
            'means': means.tolist(),
            'stds': stds.tolist(),
            'weights': weights.tolist(),
            'true_component': true_component,
            'bic': float(gmm.bic(X)),
            'status': 'success'
        }

    except Exception as e:
        # Deliberate best-effort: any fitting problem degrades to "no selection"
        logger.warning(f"GMM fitting failed: {e}")
        return 0.5, 0, {'method': 'gmm', 'status': f'failed: {e}'}


def estimate_overlap_elbow(posterior_means: np.ndarray, sensitivity: float = 1.0) -> tuple:
    """
    Find the "elbow" in sorted posterior means to determine cutoff.

    Sorts posterior means in descending order, approximates the second
    derivative with finite differences on a normalized x-axis, and picks the
    position with the most negative position-weighted curvature.

    Args:
        posterior_means: Array of posterior mean probabilities
        sensitivity: Scales the exponential position weight
            ``exp(-i / (n * sensitivity))``; larger values make the weight
            decay more slowly, so later elbows are penalized less
            (default: 1.0)

    Returns:
        threshold: Posterior mean value at the elbow (float)
        n_selected: Number of pairs up to and including the elbow (int)
        method_info: Dict with diagnostic information
    """
    posterior_means = np.asarray(posterior_means)
    if len(posterior_means) < 3:
        return 0.5, 0, {'method': 'elbow', 'status': 'insufficient_data'}

    # Sort in descending order
    sorted_means = np.sort(posterior_means)[::-1]
    n = len(sorted_means)

    # Finite differences on a normalized x-axis (spacing 1/n)
    dx = 1.0 / n
    slope = np.diff(sorted_means) / dx      # First derivative
    curvature = np.diff(slope) / dx         # Second derivative, length n - 2 >= 1

    # Weight by position to prefer earlier elbows
    position_weight = np.exp(-np.arange(len(curvature)) / (n * sensitivity))
    weighted_curvature = curvature * position_weight

    # Most negative weighted curvature = sharpest downward bend
    elbow_idx = int(np.argmin(weighted_curvature))

    # curvature[i] is centered on sorted_means[i + 1] (two diffs shift by one)
    threshold = float(sorted_means[elbow_idx + 1])
    n_selected = elbow_idx + 2  # Points at indices 0..elbow_idx+1

    return threshold, n_selected, {
        'method': 'elbow',
        'elbow_index': elbow_idx,
        'max_curvature': float(curvature[elbow_idx]),
        'status': 'success'
    }


def estimate_overlap_expected_count(posterior_means: np.ndarray) -> tuple:
    """
    Derive the selection size from the sum of the posterior means.

    Each posterior mean is read as P(true pair | data), so their sum is the
    expected number of true pairs. That sum is rounded to the nearest integer
    k (clamped to [0, number of candidates]) and the top-k pairs are taken.

    Args:
        posterior_means: Array of posterior mean probabilities

    Returns:
        threshold: The posterior mean of the k-th largest pair, or 0.5 when k is 0
        n_selected: k, the rounded and clamped expected count
        method_info: Dict with diagnostic information
    """
    n_candidates = len(posterior_means)
    if n_candidates == 0:
        return 0.5, 0, {'method': 'expected_count', 'status': 'empty_input'}

    # E[# true pairs] = sum of per-pair probabilities
    raw_expectation = np.sum(posterior_means)

    # Round, then keep the count within [0, n_candidates]
    k = int(np.round(raw_expectation))
    k = min(max(k, 0), n_candidates)

    # The k-th largest value is the implied cut-off; with k == 0 there is none
    descending = np.sort(posterior_means)[::-1]
    threshold = descending[k - 1] if k > 0 else 0.5

    diagnostics = {
        'method': 'expected_count',
        'expected_count_raw': float(raw_expectation),
        'status': 'success'
    }
    return threshold, k, diagnostics


def estimate_overlap_gap_statistic(posterior_means: np.ndarray, min_gap_ratio: float = 0.1) -> tuple:
    """
    Cut the sorted posterior means at the first significant gap.

    The values are sorted in descending order and the drop between each pair
    of neighbours is measured. A drop is "significant" when it is strictly
    larger than ``min_gap_ratio`` times the largest drop. The cut is placed at
    the first significant drop; if there is none, the upper half is selected.

    Args:
        posterior_means: Array of posterior mean probabilities
        min_gap_ratio: Fraction of the largest gap a gap must exceed to count
            as significant

    Returns:
        threshold: Value immediately below the chosen gap (or the value at
            index n // 2 in the fallback case)
        n_selected: Number of pairs above the gap
        method_info: Dict with diagnostic information
    """
    if len(posterior_means) < 3:
        return 0.5, 0, {'method': 'gap', 'status': 'insufficient_data'}

    descending = np.sort(posterior_means)[::-1]
    drops = -np.diff(descending)  # Non-negative because the data is descending

    largest_drop = np.max(drops)
    is_significant = drops > (largest_drop * min_gap_ratio)
    found_gap = np.any(is_significant)

    if found_gap:
        # argmax on a boolean mask yields the first True (earliest cut-off)
        cut = np.argmax(is_significant)
        threshold = descending[cut + 1]
        n_selected = cut + 1
        gap_index = int(cut)
        status = 'success'
    else:
        # No drop stands out: fall back to the upper half
        n_selected = len(descending) // 2
        threshold = descending[n_selected] if n_selected < len(descending) else 0.5
        gap_index = -1
        status = 'fallback'

    return threshold, n_selected, {
        'method': 'gap',
        'max_gap': float(largest_drop),
        'gap_index': gap_index,
        'status': status
    }


def infer_overlap_adaptive(
    posterior_stats: dict,
    method: str = 'ensemble',
    fallback_threshold: float = 0.5,
    min_pairs: int = 1,
    max_pairs_ratio: float = 1.0
) -> tuple:
    """
    Pick the overlapping pairs from a posterior distribution without a fixed cut-off.

    Main entry point for adaptive overlap inference. A single estimator can be
    requested by name, or all of them can be run and combined by taking the
    median of their (positive) selection counts.

    Args:
        posterior_stats: Dict mapping pair_key -> {'posterior_mean': float, ...};
            a missing 'posterior_mean' is treated as 0.0
        method: Selection method:
            - 'otsu': Otsu's thresholding
            - 'gmm': Gaussian Mixture Model
            - 'elbow': Elbow/knee detection
            - 'expected': Expected count from posterior sum
            - 'gap': Gap statistic
            - 'ensemble': Median over all of the above (default)
            Any other value falls back to counting pairs above
            ``fallback_threshold``.
        fallback_threshold: Threshold reported/used when no estimator applies
        min_pairs: Lower bound on the number of selected pairs
        max_pairs_ratio: Upper bound as a fraction of all candidates (1.0 = all)

    Returns:
        selected_pair_keys: Pair keys of the top-n pairs by posterior mean
        threshold_used: Threshold associated with the selection
        method_info: Dict with the diagnostics of every estimator that ran
    """
    if not posterior_stats:
        return [], fallback_threshold, {'status': 'empty_input'}

    # Keep keys and means index-aligned
    pair_keys = list(posterior_stats.keys())
    posterior_means = np.array([
        posterior_stats[key].get('posterior_mean', 0.0) for key in pair_keys
    ])

    if len(posterior_means) == 0:
        return [], fallback_threshold, {'status': 'no_posterior_means'}

    max_pairs = int(len(posterior_means) * max_pairs_ratio)

    # Candidate indices, best posterior mean first
    ranking = np.argsort(posterior_means)[::-1]

    # Name -> estimator; insertion order is the order the ensemble runs them in
    estimators = {
        'otsu': estimate_overlap_otsu,
        'gmm': estimate_overlap_gmm,
        'elbow': estimate_overlap_elbow,
        'expected': estimate_overlap_expected_count,
        'gap': estimate_overlap_gap_statistic,
    }

    all_methods_info = {}

    if method == 'ensemble':
        per_method = {}
        for label, estimator in estimators.items():
            est_threshold, est_count, est_info = estimator(posterior_means)
            per_method[label] = (est_threshold, est_count)
            all_methods_info[label] = est_info

        # Median of the positive counts is robust to a single outlier method
        positive_counts = [count for _, count in per_method.values() if count > 0]
        if positive_counts:
            n_selected = int(np.median(positive_counts))
        else:
            n_selected = len(posterior_means) // 2

        n_selected = max(min_pairs, min(n_selected, max_pairs))

        # Threshold = posterior mean of the last pair that makes the cut
        if n_selected > 0:
            threshold_used = posterior_means[ranking[n_selected - 1]]
        else:
            threshold_used = fallback_threshold

        all_methods_info['ensemble'] = {
            'n_from_methods': {label: outcome[1] for label, outcome in per_method.items()},
            'final_n_selected': n_selected,
            'method': 'median_ensemble'
        }
    else:
        estimator = estimators.get(method)
        if estimator is not None:
            threshold_used, n_selected, info = estimator(posterior_means)
            all_methods_info[method] = info
        else:
            # Unknown method name: fixed threshold
            threshold_used = fallback_threshold
            n_selected = np.sum(posterior_means > threshold_used)
            all_methods_info['fallback'] = {'threshold': threshold_used}

    # Enforce bounds (min_pairs wins if the two bounds conflict)
    n_selected = max(min_pairs, min(n_selected, max_pairs))

    # Select top-n pairs by posterior mean
    selected_pair_keys = [pair_keys[i] for i in ranking[:n_selected]]

    logger.debug(f"Adaptive overlap inference ({method}): selected {len(selected_pair_keys)} pairs "
                f"from {len(pair_keys)} candidates, threshold={threshold_used:.4f}")

    return selected_pair_keys, threshold_used, all_methods_info


def precompute_full_distance_matrices(emb1_unique, emb2_unique, ref_emb1, ref_emb2,
                                       ori_ref_emb1, ori_ref_emb2, args, device, use_gpu,
                                       is_normalized=False):
    """
    Compute the distance encodings against the full reference sets once, up front.

    Each ensemble works on ``ref_emb[subset_indices]``, so the per-ensemble
    encodings are just column subsets of the encodings against the full
    reference set. Computing the full matrices once replaces
    O(n_ensembles * n_unique * n_subset) work with O(n_unique * n_ref).

    Args:
        emb1_unique, emb2_unique: Unique embeddings to compute distances for
        ref_emb1, ref_emb2: Full reference embeddings
        ori_ref_emb1, ori_ref_emb2: Original reference embeddings (for
            concat_seed_pairs); encodings against them are only computed when
            both are not None
        args: Namespace providing ``distance_metric`` and, optionally,
            ``transformation``, ``transformation_params``, ``multi_gpu_config``
            and the legacy ``use_rbf_distance_encoding`` / ``rbf_sigma``
        device: Computation device
        use_gpu: Whether to use GPU
        is_normalized: Forwarded to ``compute_distance_encoding``

    Returns:
        full_dist_vec1: Encoding of emb1_unique against ref_emb1
        full_dist_vec2: Encoding of emb2_unique against ref_emb2
        ori_dist_vec1: Encoding of emb1_unique against ori_ref_emb1, or None
        ori_dist_vec2: Encoding of emb2_unique against ori_ref_emb2, or None
    """
    # Optional settings default to None when absent from args
    transformation = getattr(args, 'transformation', None)
    transformation_params = getattr(args, 'transformation_params', None)
    multi_gpu_config = getattr(args, 'multi_gpu_config', None)

    # Legacy flags: use_rbf_distance_encoding / rbf_sigma map onto the 'rbf' transformation
    legacy_rbf = transformation is None and getattr(args, 'use_rbf_distance_encoding', False)
    if legacy_rbf:
        transformation = 'rbf'
        legacy_sigma = getattr(args, 'rbf_sigma', None)
        if legacy_sigma is not None:
            transformation_params = {'sigma': legacy_sigma}

    logger.debug(f"Precomputing full distance matrices: ({len(emb1_unique)}, {len(ref_emb1)}) and ({len(emb2_unique)}, {len(ref_emb2)})")

    def _encode(embeddings, references):
        # All four encodings share every setting except the two inputs
        return compute_distance_encoding(
            emb=embeddings, ref_embeddings=references, distance_metric=args.distance_metric,
            use_gpu=use_gpu, device=device, multi_gpu_config=multi_gpu_config,
            transformation=transformation, transformation_params=transformation_params,
            is_normalized=is_normalized)

    full_dist_vec1 = _encode(emb1_unique, ref_emb1)
    full_dist_vec2 = _encode(emb2_unique, ref_emb2)

    # Encodings against the original references are optional
    if ori_ref_emb1 is None or ori_ref_emb2 is None:
        return full_dist_vec1, full_dist_vec2, None, None

    logger.debug(f"Precomputing ori_ref distance matrices: ({len(emb1_unique)}, {len(ori_ref_emb1)})")
    ori_dist_vec1 = _encode(emb1_unique, ori_ref_emb1)
    ori_dist_vec2 = _encode(emb2_unique, ori_ref_emb2)

    return full_dist_vec1, full_dist_vec2, ori_dist_vec1, ori_dist_vec2


def process_ensemble_with_precomputed_distances(
    ensemble_idx, full_dist_vec1, full_dist_vec2, ori_dist_vec1, ori_dist_vec2,
    ind_emb1_unique, ind_emb2_unique, ref_subset_indices, args, device, use_gpu,
    concat_seed_pairs, anchor_mode):
    """
    Run one ensemble member on top of the precomputed full distance matrices.

    The per-ensemble distance vectors are obtained by selecting the columns
    ``ref_subset_indices`` from the full matrices, which is equivalent to
    recomputing distances against ``ref[subset_indices]``:

        compute_distances(emb, ref[subset_indices]) == compute_distances(emb, ref)[:, subset_indices]

    Args:
        ensemble_idx: Index of this ensemble (returned unchanged)
        full_dist_vec1: Precomputed distance matrix for side 1 (torch.Tensor or array)
        full_dist_vec2: Precomputed distance matrix for side 2 (same kind as side 1)
        ori_dist_vec1/2: Precomputed distances to ori_ref (or None)
        ind_emb1_unique, ind_emb2_unique: Original indices, forwarded to find_mutual_pairs
        ref_subset_indices: Which columns to extract for this ensemble
        args, device, use_gpu: Forwarded to find_mutual_pairs; use_gpu also
            selects the torch code path for normalization/export
        concat_seed_pairs, anchor_mode: ori_ref columns are prepended only when
            concat_seed_pairs is set and anchor_mode is "supervised" or "ood"

    Returns:
        (ensemble_idx, subset_mutual_pairs, subset_accuracy, mutual_nn, dist_vec1_cpu, dist_vec2_cpu)
    """
    # Column selection replaces a per-ensemble distance computation
    if isinstance(full_dist_vec1, torch.Tensor):
        column_ids = torch.tensor(ref_subset_indices, dtype=torch.long, device=full_dist_vec1.device)
        subset1 = full_dist_vec1[:, column_ids].clone()
        subset2 = full_dist_vec2[:, column_ids].clone()
    else:
        subset1 = full_dist_vec1[:, ref_subset_indices].copy()
        subset2 = full_dist_vec2[:, ref_subset_indices].copy()

    # Optionally prepend the ori_ref columns (seed pairs come first)
    prepend_seed = (concat_seed_pairs
                    and anchor_mode in ("supervised", "ood")
                    and ori_dist_vec1 is not None)
    if prepend_seed:
        if isinstance(subset1, torch.Tensor):
            seed1 = torch.as_tensor(ori_dist_vec1, device=subset1.device, dtype=subset1.dtype)
            seed2 = torch.as_tensor(ori_dist_vec2, device=subset2.device, dtype=subset2.dtype)
            subset1 = torch.cat((seed1, subset1), dim=1)
            subset2 = torch.cat((seed2, subset2), dim=1)
        else:
            subset1 = np.concatenate((ori_dist_vec1, subset1), axis=1)
            subset2 = np.concatenate((ori_dist_vec2, subset2), axis=1)

    # The torch path is taken only when use_gpu is set AND the data is a tensor.
    # NOTE(review): a tensor combined with use_gpu=False goes through the numpy
    # expressions below — confirm callers never pass CUDA tensors that way.
    torch_path = use_gpu and isinstance(subset1, torch.Tensor)

    # L2-normalize each row
    if torch_path:
        subset1 = torch.nn.functional.normalize(subset1, p=2, dim=1)
        subset2 = torch.nn.functional.normalize(subset2, p=2, dim=1)
    else:
        subset1, subset2 = (
            matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8)
            for matrix in (subset1, subset2)
        )

    # Find mutual pairs
    subset_mutual_pairs, mutual_nn, correct = find_mutual_pairs(
        subset1, subset2, ind_emb1_unique, ind_emb2_unique, args, device, use_gpu)

    # Guard against division by zero when no mutual pairs were found
    if mutual_nn > 0:
        subset_accuracy = correct / mutual_nn
    else:
        subset_accuracy = 0.0

    # Hand the distance vectors back as numpy arrays
    if torch_path:
        dist_vec1_cpu = subset1.cpu().numpy()
        dist_vec2_cpu = subset2.cpu().numpy()
    else:
        dist_vec1_cpu = np.asarray(subset1)
        dist_vec2_cpu = np.asarray(subset2)

    return ensemble_idx, subset_mutual_pairs, subset_accuracy, mutual_nn, dist_vec1_cpu, dist_vec2_cpu


def _warn_competing_gpu_processes():
    """Best-effort diagnostic: log other compute processes reported by nvidia-smi."""
    try:
        import subprocess
        result = subprocess.run(['nvidia-smi', '--query-compute-apps=pid,used_memory',
                                 '--format=csv,noheader,nounits'],
                                capture_output=True, text=True, timeout=5)
        listing = result.stdout.strip()
        if result.returncode != 0 or not listing:
            return
        processes = listing.split('\n')
        if len(processes) <= 1:
            return
        logger.warning(f"Detected {len(processes)} competing GPU processes:")
        for proc in processes[:5]:  # Show first 5
            logger.warning(f"  - PID using GPU memory: {proc}")
    except Exception:
        pass  # Purely informational; never let detection break the caller


def calculate_safe_max_workers(n_ensembles, n_unique_samples, n_ref_subset, embedding_dim_avg, use_gpu, device, safety_factor=0.25, max_parallel_workers=None):
    """
    Derive a memory-safe number of parallel ensemble workers.

    A per-worker memory estimate (float32 embeddings for both sides, the
    distance matrices, and a 50% serialization allowance on the embeddings)
    is compared with the currently available memory. If even one worker does
    not fit with 50% headroom, sequential execution is requested.

    Args:
        n_ensembles: Number of ensembles; upper bound for the worker count
        n_unique_samples: Number of unique samples per side
        n_ref_subset: Number of reference points per ensemble
        embedding_dim_avg: Average embedding dimensionality
        use_gpu: Whether GPU computation is used
        device: Device object (its ``type`` is compared with 'cuda'), or None
        safety_factor: Fraction of available memory to hand to workers.
            Overridden with 0.2 on CUDA, and with 0.4 on CPU when fewer than
            three workers' worth of memory is available.
        max_parallel_workers: Optional hard limit on parallel workers (useful for large datasets
                             where multiprocessing overhead causes OOM even with memory-based limits)

    Returns:
        tuple: (max_workers, should_use_sequential)
            - max_workers: Number of safe parallel workers (1 to n_ensembles)
            - should_use_sequential: True if memory too low for parallel execution
    """
    on_cuda = use_gpu and device is not None and device.type == 'cuda'
    memory_type = "GPU" if on_cuda else "RAM"

    # Two sides, 4 bytes per value, converted to GiB
    ref_memory_gb = 2 * n_ref_subset * embedding_dim_avg * 4 / (1024**3)
    unique_memory_gb = 2 * n_unique_samples * embedding_dim_avg * 4 / (1024**3)

    # Distance matrices for both sides; the GPU path budgets half of that
    dist_matrix_memory_gb = 2 * estimate_matrix_memory_gb(n_unique_samples, n_ref_subset)
    if use_gpu:
        dist_matrix_memory_gb = dist_matrix_memory_gb * 0.5

    serialization_overhead_gb = (ref_memory_gb + unique_memory_gb) * 0.5
    memory_per_worker_gb = ref_memory_gb + unique_memory_gb + dist_matrix_memory_gb + serialization_overhead_gb

    available_memory_gb = get_available_memory_gb(use_gpu=use_gpu, device=device)

    # A single worker must fit with 50% headroom, otherwise go sequential
    min_required_memory_gb = memory_per_worker_gb * 1.5
    if available_memory_gb < min_required_memory_gb:
        logger.warning(f"CRITICAL MEMORY SHORTAGE: {available_memory_gb:.2f} GB {memory_type} available, "
                      f"but need {min_required_memory_gb:.2f} GB for safe parallel execution")
        logger.warning("Automatically switching to SEQUENTIAL mode (no multiprocessing)")
        logger.warning("This will be slower but won't cause OOM errors")
        if on_cuda:
            _warn_competing_gpu_processes()
        return 1, True

    # Pick the fraction of available memory that workers may use
    if on_cuda:
        safety_factor = 0.2
        logger.debug(f"GPU mode: using balanced safety_factor={safety_factor} for performance")
    elif available_memory_gb < memory_per_worker_gb * 3:
        # Tight on RAM: allow up to 40% of what is available
        adjusted_safety_factor = 0.4
        logger.debug(f"Low memory detected, increasing safety_factor: {safety_factor:.2f} -> {adjusted_safety_factor:.2f}")
        safety_factor = adjusted_safety_factor

    usable_memory_gb = available_memory_gb * safety_factor
    max_workers = max(1, int(usable_memory_gb / memory_per_worker_gb))
    max_workers = min(max_workers, n_ensembles)

    # Hard cap of 12 concurrent GPU workers, whatever the memory estimate says
    if on_cuda:
        gpu_max_workers = 12
        if max_workers > gpu_max_workers:
            logger.debug(f"GPU mode: capping workers at {gpu_max_workers} for optimal performance (calculated {max_workers})")
            max_workers = gpu_max_workers

    # Caller-supplied limit (ignored when None or non-positive)
    if max_parallel_workers is not None and 0 < max_parallel_workers < max_workers:
        logger.debug(f"Applying max_parallel_workers limit: {max_workers} -> {max_parallel_workers}")
        max_workers = max_parallel_workers

    logger.debug(f"Memory-aware worker calculation: {available_memory_gb:.2f} GB {memory_type} available, "
                f"{memory_per_worker_gb:.2f} GB per worker, max_workers={max_workers}/{n_ensembles}")
    if max_workers < n_ensembles:
        logger.warning(f"Reducing workers from {n_ensembles} to {max_workers} due to memory constraints")

    return max_workers, False  # Parallel mode with limited workers


def run_ensembles_in_batches(ensemble_args, max_workers, n_ensembles, ctx):
    """
    Run ensemble jobs in consecutive batches of worker processes.

    Each batch gets its own ProcessPoolExecutor sized to the batch, so at most
    ``max_workers`` jobs are alive at a time. Between batches the parent process
    garbage-collects and, when CUDA is available, releases cached GPU memory.

    Args:
        ensemble_args: Sequence of argument tuples for ``run_single_ensemble_gpu``.
            Element 0 of each tuple is the ensemble index and element 10 is the
            GPU id (may be None).
        max_workers: Maximum number of jobs per batch. Values below 1 are
            treated as 1. May be raised to ``2 * n_gpus`` when more than one
            distinct GPU id appears in ``ensemble_args``.
        n_ensembles: Number of leading entries of ``ensemble_args`` to run.
            Clamped to ``len(ensemble_args)``.
        ctx: Multiprocessing context handed to ProcessPoolExecutor.

    Returns:
        Dict mapping ensemble index to the tuple
        ``(subset_mutual_pairs, subset_accuracy, mutual_nn, dist_vec1, dist_vec2)``.

    Raises:
        Any exception raised inside a worker is re-raised by ``future.result()``.
    """
    import gc

    gpu_id_pos = 10  # position of gpu_id inside each args tuple
    results_dict = {}

    # Never schedule more jobs than were actually supplied; an empty batch would
    # otherwise ask ProcessPoolExecutor for zero workers, which it rejects.
    n_available = len(ensemble_args) if ensemble_args else 0
    if n_ensembles > n_available:
        logger.warning(f"Requested {n_ensembles} ensembles but only {n_available} argument tuples were provided")
        n_ensembles = n_available

    # Collect the distinct GPU ids referenced by the jobs
    gpu_ids_in_use = set()
    if ensemble_args:
        gpu_ids_in_use = {
            args[gpu_id_pos]
            for args in ensemble_args
            if len(args) > gpu_id_pos and args[gpu_id_pos] is not None
        }
    n_gpus = len(gpu_ids_in_use) if gpu_ids_in_use else 1

    # A non-positive worker count would divide by zero below
    if max_workers < 1:
        logger.warning(f"max_workers={max_workers} is not positive; using 1 worker per batch")
        max_workers = 1

    # For multi-GPU, increase batch size to better utilize both GPUs
    if n_gpus > 1:
        # Use at least 2x the number of GPUs to ensure both are busy
        effective_max_workers = max(max_workers, n_gpus * 2)
        logger.debug(f"Multi-GPU detected ({n_gpus} GPUs): increasing batch size from {max_workers} to {effective_max_workers} for better utilization")
        max_workers = effective_max_workers

    n_batches = (n_ensembles + max_workers - 1) // max_workers
    logger.debug(f"Running {n_ensembles} ensembles in {n_batches} batch(es) with {max_workers} workers per batch")

    for batch_idx in range(n_batches):
        batch_start = batch_idx * max_workers
        batch_end = min(batch_start + max_workers, n_ensembles)
        batch_size = batch_end - batch_start
        batch_args = ensemble_args[batch_start:batch_end]

        logger.debug(f"Batch {batch_idx + 1}/{n_batches}: ensembles {batch_start}-{batch_end-1}")

        # Log GPU distribution in this batch
        if n_gpus > 1:
            gpu_counts = {}
            for args in batch_args:
                gpu_id = args[gpu_id_pos] if len(args) > gpu_id_pos else 0
                gpu_counts[gpu_id] = gpu_counts.get(gpu_id, 0) + 1
            logger.debug(f"  GPU distribution in batch: {gpu_counts}")

        # Clean up before starting new batch (only between batches, not first batch)
        if batch_idx > 0:
            gc.collect()
            if torch.cuda.is_available():
                device_count = torch.cuda.device_count()
                # Clear the devices the jobs actually target rather than
                # 0..n_gpus-1. Ids outside the visible device range are mapped
                # to device 0, matching the worker's own fallback.
                cache_targets = {
                    gpu_id if 0 <= gpu_id < device_count else 0
                    for gpu_id in gpu_ids_in_use
                } or {0}
                for gpu_id in sorted(cache_targets):
                    with torch.cuda.device(gpu_id):
                        torch.cuda.empty_cache()
                time.sleep(0.3)

        with ProcessPoolExecutor(max_workers=batch_size, mp_context=ctx) as executor:
            # Submit all jobs at once for better parallelism across GPUs
            future_to_idx = {
                executor.submit(run_single_ensemble_gpu, args): args[0]
                for args in batch_args
            }

            for future in as_completed(future_to_idx):
                ensemble_idx, subset_mutual_pairs, subset_accuracy, mutual_nn, dist_vec1, dist_vec2 = future.result()
                results_dict[ensemble_idx] = (subset_mutual_pairs, subset_accuracy, mutual_nn, dist_vec1, dist_vec2)

        logger.debug(f"Batch {batch_idx + 1}/{n_batches} completed")

    return results_dict


def _pairwise_distance_matrix(ref_emb1, distance_metric, use_gpu, device):
    """Return ``(dist_matrix, on_torch)`` holding all-pairs distances between rows of ref_emb1."""
    if torch.cuda.is_available() and use_gpu:
        # GPU-accelerated distance computation; the matrix stays a torch tensor
        ref_emb_tensor = torch.from_numpy(ref_emb1).float().to(device)
        if distance_metric == 'cosine':
            ref_emb_normalized = torch.nn.functional.normalize(ref_emb_tensor, p=2, dim=1)
            dist_matrix = 1 - torch.mm(ref_emb_normalized, ref_emb_normalized.T)
        else:
            dist_matrix = torch.cdist(ref_emb_tensor, ref_emb_tensor, p=2)
        return dist_matrix, True

    # CPU fallback
    if distance_metric == 'cosine':
        ref_emb_normalized = ref_emb1 / (np.linalg.norm(ref_emb1, axis=1, keepdims=True) + 1e-8)
        dist_matrix = 1 - np.dot(ref_emb_normalized, ref_emb_normalized.T)
    else:
        dist_matrix = euclidean_distances(ref_emb1, ref_emb1)
    return dist_matrix, False


def _furthest_point_subsets(dist_matrix, on_torch, n_ref, n_ensembles, subset_size):
    """Build n_ensembles subsets by greedy farthest-point sampling over a shared pool of unused points."""
    subset_indices_list = []

    # Points already placed in any subset are withheld until the pool runs dry
    available_pool = set(range(n_ref))

    for _ in range(n_ensembles):
        # Start with a random point from the pool, resetting it if exhausted
        if not available_pool:
            available_pool = set(range(n_ref))
        first_idx = np.random.choice(list(available_pool))
        subset_indices = [first_idx]
        available_pool.discard(first_idx)

        # Running minimum distance from every point to the current subset.
        # Updating it per pick avoids re-gathering an (available x subset)
        # block of the distance matrix on every step.
        if on_torch:
            min_dist_to_subset = dist_matrix[:, int(first_idx)].clone()
        else:
            min_dist_to_subset = dist_matrix[:, first_idx].copy()

        for _ in range(subset_size - 1):
            if not available_pool:
                # Pool exhausted: reopen every point not yet in this subset
                available_pool = set(range(n_ref)) - set(subset_indices)

            if not available_pool:
                # subset_size > n_ref: allow repeats
                candidate_idx = np.random.choice(n_ref)
            else:
                available_array = np.array(list(available_pool), dtype=np.int64)
                if on_torch:
                    available_tensor = torch.as_tensor(
                        available_array, dtype=torch.long, device=dist_matrix.device)
                    farthest_pos = min_dist_to_subset[available_tensor].argmax().item()
                else:
                    farthest_pos = np.argmax(min_dist_to_subset[available_array])
                # Point whose nearest subset member is farthest away
                candidate_idx = available_array[farthest_pos]

            subset_indices.append(candidate_idx)
            available_pool.discard(candidate_idx)

            if on_torch:
                min_dist_to_subset = torch.minimum(min_dist_to_subset, dist_matrix[:, int(candidate_idx)])
            else:
                min_dist_to_subset = np.minimum(min_dist_to_subset, dist_matrix[:, candidate_idx])

        subset_indices_list.append(np.array(subset_indices))

    return subset_indices_list


def _nearest_neighbor_subsets(dist_matrix, on_torch, n_ref, n_ensembles, subset_size):
    """Return ``(subsets, seed_indices)``: each subset is a random seed plus its nearest neighbours."""
    # Seeds are distinct whenever enough reference points exist; repeats are
    # only permitted when more ensembles than points are requested.
    seed_indices = np.random.choice(n_ref, size=n_ensembles, replace=n_ensembles > n_ref)
    k = min(subset_size, n_ref)

    subset_indices_list = []
    for seed_idx in seed_indices:
        # k smallest distances from the seed (the seed itself is at distance ~0)
        if on_torch:
            _, nearest_indices = torch.topk(dist_matrix[int(seed_idx)], k=k, largest=False)
            nearest_indices = nearest_indices.cpu().numpy()
        else:
            nearest_indices = np.argsort(dist_matrix[seed_idx])[:k]
        subset_indices_list.append(nearest_indices)

    return subset_indices_list, seed_indices


def generate_ensemble_subsets(ref_emb1, n_ensembles, subset_size, strategy='random', distance_metric='cosine', use_gpu=False, device=None):
    """
    Generate n_ensembles subsets of reference embeddings based on different strategies.

    Args:
        ref_emb1: Reference embeddings (n_ref, dim)
        n_ensembles: Number of ensembles (subsets) to generate
        subset_size: Size of each subset. Clamped to n_ref for 'random' and
            'nearest'. For 'cluster' it is only used to size the random fill-in
            subsets created when clustering yields fewer than n_ensembles
            non-empty clusters; a non-positive value falls back to
            n_ref // n_ensembles (at least 1).
        strategy: One of ['random', 'cluster', 'furthest', 'nearest']
        distance_metric: Distance metric for 'furthest' and 'nearest' strategies
            ('cosine', anything else is treated as Euclidean)
        use_gpu: If True and CUDA is available, pairwise distances for
            'furthest' and 'nearest' are computed with torch
        device: Torch device used for the torch distance computation

    Returns:
        List of n_ensembles arrays, each containing indices of subset members

    Raises:
        ValueError: If strategy is not one of the supported names
    """
    n_ref = len(ref_emb1)

    if strategy == 'random':
        # Random sampling without replacement; never ask for more points than exist
        sample_size = min(subset_size, n_ref)
        return [
            np.random.choice(n_ref, size=sample_size, replace=False)
            for _ in range(n_ensembles)
        ]

    if strategy == 'cluster':
        # Cluster-based: cluster into n_ensembles clusters, each cluster is one subset
        clusterer = Clusterer(method='kmeans', n_clusters=n_ensembles, use_gpu=False)
        cluster_labels = clusterer.fit(ref_emb1)

        # Group cluster members, dropping empty clusters
        subset_indices_list = []
        for cluster_id in range(n_ensembles):
            cluster_member_indices = np.where(cluster_labels == cluster_id)[0]
            if len(cluster_member_indices) > 0:
                subset_indices_list.append(cluster_member_indices)

        # Replace empty clusters with random samples. Callers may pass
        # subset_size=0 for this strategy, which would produce empty subsets,
        # so fall back to the mean cluster size in that case.
        if subset_size is not None and subset_size > 0:
            fill_size = subset_size
        else:
            fill_size = max(1, n_ref // max(1, n_ensembles))
        fill_size = min(fill_size, n_ref)
        while len(subset_indices_list) < n_ensembles:
            subset_indices_list.append(np.random.choice(n_ref, size=fill_size, replace=False))

        logger.debug(f"Cluster strategy: created {len(subset_indices_list)} clusters")
        for i, indices in enumerate(subset_indices_list):
            logger.debug(f"  Cluster {i}: {len(indices)} members")

        return subset_indices_list

    if strategy == 'furthest':
        # Furthest points: each subset contains points that are maximally far apart
        dist_matrix, on_torch = _pairwise_distance_matrix(ref_emb1, distance_metric, use_gpu, device)
        subset_indices_list = _furthest_point_subsets(
            dist_matrix, on_torch, n_ref, n_ensembles, subset_size)

        logger.debug(f"Furthest points strategy: created {len(subset_indices_list)} subsets of size {subset_size}")

        return subset_indices_list

    if strategy == 'nearest':
        # Nearest neighbors: each ensemble is one seed point and its nearest neighbors
        dist_matrix, on_torch = _pairwise_distance_matrix(ref_emb1, distance_metric, use_gpu, device)
        subset_indices_list, seed_indices = _nearest_neighbor_subsets(
            dist_matrix, on_torch, n_ref, n_ensembles, subset_size)

        logger.debug(f"Nearest neighbors strategy: created {len(subset_indices_list)} subsets of size {subset_size}")
        for i, indices in enumerate(subset_indices_list):
            logger.debug(f"  Ensemble {i}: seed={seed_indices[i]}, {len(indices)} neighbors")

        return subset_indices_list

    raise ValueError(f"Unknown ensemble strategy: {strategy}")


def run_single_ensemble_gpu(args_tuple):
    """
    Execute one ensemble iteration; designed to be submitted to a process pool.

    The single positional argument is a flat tuple so it can be pickled:
    18 fields in the legacy layout, or 19 when a trailing ``is_normalized``
    flag is present (it defaults to False for the legacy layout).

    Reference subset selection:
    - ``ref_subset_indices`` is None: draw ``subset_size`` reference rows at
      random without replacement
    - otherwise: use the supplied indices as-is (e.g. cluster-based selection)

    Returns:
        Tuple ``(ensemble_idx, subset_mutual_pairs, subset_accuracy, mutual_nn,
        dist_vec1, dist_vec2)``; the two distance-vector entries are converted
        to numpy when they were produced as torch tensors on the GPU path.
    """
    import gc

    # Peel off the optional trailing flag, then unpack the 18 common fields
    if len(args_tuple) == 19:
        *common_fields, is_normalized = args_tuple
    else:
        common_fields, is_normalized = args_tuple, False  # Default for backward compatibility

    (ensemble_idx, ref_emb1, ref_emb2, emb1_unique, emb2_unique,
     ind_emb1_unique, ind_emb2_unique, subset_size, args_dict, use_gpu, gpu_id,
     ref_indices1, ref_indices2, ori_ref_emb1, ori_ref_emb2, anchor_mode, concat_seed_pairs,
     ref_subset_indices) = common_fields

    args = argparse.Namespace(**args_dict)

    # Pick the compute device; an unusable gpu_id falls back to GPU 0
    if not (use_gpu and torch.cuda.is_available()):
        device = torch.device("cpu")
    elif gpu_id is not None and gpu_id < torch.cuda.device_count():
        device = torch.device(f"cuda:{gpu_id}")
    else:
        device = torch.device("cuda:0")

    # No pre-selected subset: sample reference rows at random
    if ref_subset_indices is None:
        ref_subset_indices = np.random.choice(len(ref_emb1), size=subset_size, replace=False)

    subset_ref_emb1 = ref_emb1[ref_subset_indices]
    subset_ref_emb2 = ref_emb2[ref_subset_indices]

    # Optionally prepend the original seed references
    if concat_seed_pairs and anchor_mode in ("supervised", "ood"):
        subset_ref_emb1 = np.concatenate((ori_ref_emb1, subset_ref_emb1))
        subset_ref_emb2 = np.concatenate((ori_ref_emb2, subset_ref_emb2))

    # Transformation settings, honouring the deprecated use_rbf_distance_encoding flag
    transformation = getattr(args, 'transformation', None)
    transformation_params = getattr(args, 'transformation_params', None)
    multi_gpu_config = getattr(args, 'multi_gpu_config', None)
    if transformation is None and getattr(args, 'use_rbf_distance_encoding', False):
        transformation = 'rbf'
        legacy_sigma = getattr(args, 'rbf_sigma', None)
        if legacy_sigma is not None:
            transformation_params = {'sigma': legacy_sigma}

    # Both sides are encoded with identical settings; only the inputs differ
    encoding_options = dict(
        distance_metric=args.distance_metric,
        use_gpu=use_gpu,
        device=device,
        multi_gpu_config=multi_gpu_config,
        transformation=transformation,
        transformation_params=transformation_params,
        is_normalized=is_normalized,
    )
    dist_vec1_subset = compute_distance_encoding(
        emb=emb1_unique, ref_embeddings=subset_ref_emb1, **encoding_options)
    dist_vec2_subset = compute_distance_encoding(
        emb=emb2_unique, ref_embeddings=subset_ref_emb2, **encoding_options)

    # L2-normalise rows here, once, before the mutual-pair search
    vectors_on_torch = use_gpu and isinstance(dist_vec1_subset, torch.Tensor)
    if vectors_on_torch:
        dist_vec1_subset = torch.nn.functional.normalize(dist_vec1_subset, p=2, dim=1)
        dist_vec2_subset = torch.nn.functional.normalize(dist_vec2_subset, p=2, dim=1)
    else:
        row_norms1 = np.linalg.norm(dist_vec1_subset, axis=1, keepdims=True)
        dist_vec1_subset = dist_vec1_subset / (row_norms1 + 1e-8)
        row_norms2 = np.linalg.norm(dist_vec2_subset, axis=1, keepdims=True)
        dist_vec2_subset = dist_vec2_subset / (row_norms2 + 1e-8)

    # Mutual-pair search via the shared utility
    subset_mutual_pairs, mutual_nn, correct = find_mutual_pairs(
        dist_vec1_subset, dist_vec2_subset, ind_emb1_unique, ind_emb2_unique, args, device,
        use_gpu)

    if mutual_nn > 0:
        subset_accuracy = correct / mutual_nn
    else:
        subset_accuracy = 0.0

    # Results cross the process boundary as numpy arrays
    if vectors_on_torch:
        dist_vec1_cpu = dist_vec1_subset.cpu().numpy()
        dist_vec2_cpu = dist_vec2_subset.cpu().numpy()
    else:
        dist_vec1_cpu, dist_vec2_cpu = dist_vec1_subset, dist_vec2_subset

    # Drop local references and release cached memory before returning
    del dist_vec1_subset, dist_vec2_subset, subset_ref_emb1, subset_ref_emb2
    gc.collect()
    if use_gpu and torch.cuda.is_available():
        torch.cuda.empty_cache()

    return ensemble_idx, subset_mutual_pairs, subset_accuracy, mutual_nn, dist_vec1_cpu, dist_vec2_cpu


def ensemble_reference_selection_voting(ref_emb1, ref_emb2, emb1_unique, emb2_unique, ind_emb1_unique, ind_emb2_unique,
                                        args, device, ind_nonref, n_ensembles=10, subset_ratio=0.3,
                                        ref_indices1=None, ref_indices2=None,
                                        vote_threshold=0.6, ori_ref_emb1=None, ori_ref_emb2=None, return_vote_matrix=False,
                                        ensemble_strategy='random', skip_adaptive_scaling=False):
    """
    Ensemble-based reference selection method using voting.

    For each ensemble:
    1. Sample/select a subset of points from ref_emb1/ref_emb2 based on strategy
    2. Compute distance vectors using the subset
    3. Find mutual pairs and compute accuracy
    4. Track which mutual pairs appear most frequently across ensembles
    5. Return the final ensemble accuracy/recall

    Args:
        ref_emb1, ref_emb2: Reference embeddings
        emb1_unique, emb2_unique: Unique embeddings to compute distance vectors for
        ind_emb1_unique, ind_emb2_unique: Original indices of unique embeddings
        args: Arguments containing topk, distance_metric, etc.
        device: Device for computations
        n_ensembles: Number of ensemble runs (also number of clusters for 'cluster' strategy)
        subset_ratio: Percentage of reference points to use in each ensemble (for 'random' and 'furthest')
        ensemble_strategy: Strategy for selecting reference subsets:
            - 'random': Random sampling (default)
            - 'cluster': Cluster into n_ensembles clusters, each cluster votes
            - 'furthest': Each subset contains maximally dispersed points
        return_vote_matrix: If True, return vote matrix along with mutual pairs

    Returns:
        mutual_pair: Final mutual pairs selected by ensemble (list of (idx2, idx1, dist))
        vote_matrix (optional): Sparse CSR matrix (n2, n1) if return_vote_matrix=True
    """

    if len(ref_emb1) == 0:
        if return_vote_matrix:
            # Return empty sparse matrix
            n1 = len(emb1_unique)
            n2 = len(emb2_unique)
            return [], csr_matrix((n2, n1), dtype=np.int32)
        else:
            return []

    n_ref = len(ref_emb1)

    use_gpu = device.type == 'cuda' if hasattr(device, 'type') else False

    # For cluster strategy, subset_size is determined by clustering
    # For random and furthest, use subset_ratio
    if ensemble_strategy == 'cluster':
        subset_size = None  # Will be determined by cluster sizes
    else:
        subset_size = max(1, int(n_ref * subset_ratio))

    # Generate subsets based on strategy
    logger.debug(f"Generating {n_ensembles} subsets using '{ensemble_strategy}' strategy")
    if ensemble_strategy == 'cluster':
        # For cluster strategy, we don't need to specify subset_size
        # Each cluster will be its own subset
        subset_indices_list = generate_ensemble_subsets(
            ref_emb1, n_ensembles, subset_size=0, strategy=ensemble_strategy,
            distance_metric=args.distance_metric, 
            use_gpu=use_gpu,
            device=device
        )
    else:
        subset_indices_list = generate_ensemble_subsets(
            ref_emb1, n_ensembles, subset_size, strategy=ensemble_strategy,
            distance_metric=args.distance_metric,
            use_gpu=use_gpu,
            device=device
        )

    # Initialize sparse vote matrix: vote_matrix[i, j] = number of votes for pair (i, j)
    # Use lil_matrix for efficient incremental construction
    n1 = len(emb1_unique)
    n2 = len(emb2_unique)
    vote_matrix = lil_matrix((n2, n1), dtype=np.int32)  # Sparse matrix: (emb2_unique, emb1_unique)

    mutual_pair_dist = dict()
    subset_accuracies = []
    all_subset_pairs = []

    use_gpu = device.type == 'cuda' if hasattr(device, 'type') else False
    enable_parallel = getattr(args, 'enable_parallel_ensemble', True)
    n_gpus = torch.cuda.device_count() if use_gpu else 1

    logger.debug(f"Running ensemble reference selection: {n_ensembles} ensembles using '{ensemble_strategy}' strategy")
    logger.debug(f"Parameters: n_ensembles={n_ensembles}, subset_ratio={subset_ratio:.2f}, vote_threshold={vote_threshold:.2f}")
    logger.debug(f"Sparse vote matrix shape: {vote_matrix.shape} (emb2_unique x emb1_unique)")
    logger.debug(f"Using {'GPU' if use_gpu else 'CPU'} acceleration with {'parallel' if enable_parallel else 'sequential'} execution")

    start_time = time.time()

    # OPTIMIZATION: Pre-normalize embeddings once for cosine distance to avoid redundant normalization
    # in each ensemble worker. This is a major speedup for large n_ensembles.
    is_normalized = False
    if args.distance_metric == 'cosine':
        logger.debug("Pre-normalizing embeddings for cosine distance (avoids redundant normalization in workers)")
        # Normalize unique embeddings
        emb1_unique = emb1_unique / (np.linalg.norm(emb1_unique, axis=1, keepdims=True) + 1e-8)
        emb2_unique = emb2_unique / (np.linalg.norm(emb2_unique, axis=1, keepdims=True) + 1e-8)
        # Normalize reference embeddings
        ref_emb1 = ref_emb1 / (np.linalg.norm(ref_emb1, axis=1, keepdims=True) + 1e-8)
        ref_emb2 = ref_emb2 / (np.linalg.norm(ref_emb2, axis=1, keepdims=True) + 1e-8)
        # Normalize original reference embeddings if present
        if ori_ref_emb1 is not None:
            ori_ref_emb1 = ori_ref_emb1 / (np.linalg.norm(ori_ref_emb1, axis=1, keepdims=True) + 1e-8)
        if ori_ref_emb2 is not None:
            ori_ref_emb2 = ori_ref_emb2 / (np.linalg.norm(ori_ref_emb2, axis=1, keepdims=True) + 1e-8)
        is_normalized = True

    # OPTIMIZATION: Precompute full distance matrices ONCE instead of per-ensemble
    # This is the key optimization: computing distances once and extracting subsets by indexing
    # Complexity reduction: O(n_ensembles * n_unique * n_subset) -> O(n_unique * n_ref)
    use_precomputed_distances = getattr(args, 'use_precomputed_distances', True)  # Enable by default
    full_dist_vec1, full_dist_vec2, ori_dist_vec1, ori_dist_vec2 = None, None, None, None

    if use_precomputed_distances:
        logger.debug("OPTIMIZATION: Using precomputed distance matrices (compute once, index per-ensemble)")
        full_dist_vec1, full_dist_vec2, ori_dist_vec1, ori_dist_vec2 = precompute_full_distance_matrices(
            emb1_unique, emb2_unique, ref_emb1, ref_emb2,
            ori_ref_emb1, ori_ref_emb2, args, device, use_gpu, is_normalized)

    # With precomputed distances, sequential execution is often faster than parallel
    # because the per-ensemble work (column indexing + normalization + find_mutual_pairs)
    # is lightweight compared to multiprocessing overhead
    if use_precomputed_distances and full_dist_vec1 is not None:
        logger.debug("Using sequential execution with precomputed distances (optimal for this case)")
        enable_parallel = False

    if enable_parallel and n_ensembles > 1:
        # Calculate safe max_workers based on memory and user limits
        emb_dim1 = emb1_unique.shape[1] if len(emb1_unique.shape) > 1 else 1
        emb_dim2 = emb2_unique.shape[1] if len(emb2_unique.shape) > 1 else 1
        embedding_dim_avg = (emb_dim1 + emb_dim2) // 2

        # Get the average subset size
        avg_subset_size = int(np.mean([len(s) for s in subset_indices_list]))

        max_parallel_workers = getattr(args, 'max_parallel_workers', None)
        max_workers, should_use_sequential = calculate_safe_max_workers(
            n_ensembles=n_ensembles,
            n_unique_samples=len(emb1_unique),
            n_ref_subset=avg_subset_size,
            embedding_dim_avg=embedding_dim_avg,
            use_gpu=use_gpu,
            device=device,
            safety_factor=0.25,
            max_parallel_workers=max_parallel_workers
        )

        # Force sequential execution if memory too low
        if should_use_sequential:
            logger.warning(f"Memory too low for parallel execution, using sequential mode")
            enable_parallel = False

        # Use multiprocessing with spawn for both CPU and GPU
        try:
            if use_gpu and torch.cuda.is_available():
                # For GPU, use ProcessPoolExecutor with spawn method for true parallelism
                logger.debug(f"Using ProcessPoolExecutor for GPU-based parallel ensembles across {n_gpus} GPU(s)")
                logger.debug(f"max_workers={max_workers} (limited from {n_ensembles} ensembles)")

                # Set spawn method for multiprocessing
                ctx = mp.get_context('spawn')

                args_dict = vars(args)  # Convert args to dictionary for pickling

                # OPTIMIZATION 2.1: Ensure arrays are float32 to avoid dtype conversion overhead
                # in each worker process
                if not isinstance(ref_emb1, torch.Tensor):
                    ref_emb1 = np.asarray(ref_emb1, dtype=np.float32)
                    ref_emb2 = np.asarray(ref_emb2, dtype=np.float32)
                    emb1_unique = np.asarray(emb1_unique, dtype=np.float32)
                    emb2_unique = np.asarray(emb2_unique, dtype=np.float32)

                # Prepare arguments for each ensemble
                ensemble_args = []
                for ensemble_idx in range(n_ensembles):
                    gpu_id = ensemble_idx % n_gpus if n_gpus > 1 else 0
                    # Use pre-generated subset indices
                    ref_subset_indices = subset_indices_list[ensemble_idx]
                    ensemble_args.append((
                        ensemble_idx, ref_emb1, ref_emb2, emb1_unique, emb2_unique,
                        ind_emb1_unique, ind_emb2_unique, len(ref_subset_indices), args_dict, True, gpu_id,
                        ref_indices1, ref_indices2, ori_ref_emb1, ori_ref_emb2, args.anchor_mode, args.concat_seed_pairs,
                        ref_subset_indices, is_normalized  # Pass pre-normalization flag
                    ))

                with ProcessPoolExecutor(max_workers=max_workers, mp_context=ctx) as executor:
                    future_to_idx = {
                        executor.submit(run_single_ensemble_gpu, args): args[0]
                        for args in ensemble_args
                    }

                    results_dict = {}
                    for future in as_completed(future_to_idx):
                        ensemble_idx, subset_mutual_pairs, subset_accuracy, mutual_nn, _, _ = future.result()
                        results_dict[ensemble_idx] = (subset_mutual_pairs, subset_accuracy, mutual_nn)

                    # Process results in order
                    for ensemble_idx in range(n_ensembles):
                        if ensemble_idx in results_dict:
                            subset_mutual_pairs, subset_accuracy, mutual_nn = results_dict[ensemble_idx]

                            # Update vote matrix
                            for i, nearest_i, dist_between_pair in subset_mutual_pairs:
                                vote_matrix[i, nearest_i] += 1
                                pair_key = (i, nearest_i)
                                mutual_pair_dist[pair_key] = dist_between_pair

                            subset_accuracies.append(subset_accuracy)
                            all_subset_pairs.append(subset_mutual_pairs)

                            logger.info(f"Ensemble {ensemble_idx+1}: {mutual_nn} mutual pairs, accuracy: {subset_accuracy:.3f}")

            else:
                # For CPU, use ProcessPoolExecutor with spawn method
                logger.debug("Using ProcessPoolExecutor for CPU-based parallel ensembles")
                logger.debug(f"max_workers={max_workers} (limited from {n_ensembles} ensembles)")

                # Set spawn method for multiprocessing
                ctx = mp.get_context('spawn')

                args_dict = vars(args)  # Convert args to dictionary for pickling

                # Prepare arguments for each ensemble
                ensemble_args = []
                for ensemble_idx in range(n_ensembles):
                    # Use pre-generated subset indices
                    ref_subset_indices = subset_indices_list[ensemble_idx]
                    ensemble_args.append((
                        ensemble_idx, ref_emb1, ref_emb2, emb1_unique, emb2_unique,
                        ind_emb1_unique, ind_emb2_unique, len(ref_subset_indices), args_dict, False, None,
                        ref_indices1, ref_indices2, ori_ref_emb1, ori_ref_emb2, args.anchor_mode, args.concat_seed_pairs,
                        ref_subset_indices, is_normalized  # Pass pre-normalization flag
                    ))

                with ProcessPoolExecutor(max_workers=max_workers, mp_context=ctx) as executor:
                    future_to_idx = {
                        executor.submit(run_single_ensemble_gpu, args): args[0]
                        for args in ensemble_args
                    }

                    results_dict = {}
                    for future in as_completed(future_to_idx):
                        ensemble_idx, subset_mutual_pairs, subset_accuracy, mutual_nn, _, _ = future.result()
                        results_dict[ensemble_idx] = (subset_mutual_pairs, subset_accuracy, mutual_nn)

                    # Process results in order
                    for ensemble_idx in range(n_ensembles):
                        if ensemble_idx in results_dict:
                            subset_mutual_pairs, subset_accuracy, mutual_nn = results_dict[ensemble_idx]

                            # Update vote matrix
                            for i, nearest_i, dist_between_pair in subset_mutual_pairs:
                                vote_matrix[i, nearest_i] += 1
                                pair_key = (i, nearest_i)
                                mutual_pair_dist[pair_key] = dist_between_pair

                            subset_accuracies.append(subset_accuracy)
                            all_subset_pairs.append(subset_mutual_pairs)

                            logger.info(f"Ensemble {ensemble_idx+1}: {mutual_nn} mutual pairs, accuracy: {subset_accuracy:.3f}")

        except Exception as e:
            logger.warning(f"Parallel execution failed: {e}. Falling back to sequential execution.")
            enable_parallel = False

    if not enable_parallel or n_ensembles <= 1:
        # Sequential execution with GPU optimization
        aggressive_memory_clear = getattr(args, 'aggressive_memory_clear', False)
        for ensemble_idx in range(n_ensembles):
            # Use pre-generated subset indices
            ref_subset_indices = subset_indices_list[ensemble_idx]

            # OPTIMIZATION: Use precomputed distances if available (much faster!)
            if use_precomputed_distances and full_dist_vec1 is not None:
                _, subset_mutual_pairs, subset_accuracy, mutual_nn, _, _ = process_ensemble_with_precomputed_distances(
                    ensemble_idx, full_dist_vec1, full_dist_vec2, ori_dist_vec1, ori_dist_vec2,
                    ind_emb1_unique, ind_emb2_unique, ref_subset_indices, args, device, use_gpu,
                    args.concat_seed_pairs, args.anchor_mode)
            else:
                # Fallback to legacy method (computes distances per-ensemble)
                args_tuple = (
                    ensemble_idx, ref_emb1, ref_emb2, emb1_unique, emb2_unique,
                    ind_emb1_unique, ind_emb2_unique, len(ref_subset_indices), vars(args), use_gpu, 0,
                    ref_indices1, ref_indices2, ori_ref_emb1, ori_ref_emb2, args.anchor_mode, args.concat_seed_pairs,
                    ref_subset_indices, is_normalized  # Pass pre-normalization flag
                )
                _, subset_mutual_pairs, subset_accuracy, mutual_nn, _, _ = run_single_ensemble_gpu(args_tuple)

            # Update vote matrix
            for i, nearest_i, dist_between_pair in subset_mutual_pairs:
                vote_matrix[i, nearest_i] += 1
                pair_key = (i, nearest_i)
                mutual_pair_dist[pair_key] = dist_between_pair

            subset_accuracies.append(subset_accuracy)
            all_subset_pairs.append(subset_mutual_pairs)

            logger.info(f"Ensemble {ensemble_idx+1}: {mutual_nn} mutual pairs, accuracy: {subset_accuracy:.3f}")

            # Aggressive memory clearing for large datasets
            if aggressive_memory_clear:
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    elapsed_time = time.time() - start_time
    logger.debug(f"Ensemble computation completed in {elapsed_time:.2f} seconds")

    # Convert to CSR format for efficient operations
    vote_matrix = vote_matrix.tocsr()

    # Select pairs using vote threshold
    min_votes = max(1, int(n_ensembles * vote_threshold))

    # Count total number of pairs with votes > 0
    total_pairs_with_votes = vote_matrix.nnz
    min_threshold_pairs = max(5, total_pairs_with_votes // 20)

    # Start with threshold, then fall back if needed
    frequent_pairs = []
    final_min_votes = min_votes

    # OPTIMIZED: Convert to COO once, then filter with vectorized operations
    cx = vote_matrix.tocoo()
    coo_rows, coo_cols, coo_data = cx.row, cx.col, cx.data

    for min_votes_loop in range(min_votes, 0, -1):
        # Vectorized filtering instead of Python loop
        mask = coo_data >= min_votes_loop
        n_candidates = mask.sum()

        logger.debug(f"Threshold {min_votes_loop}: Found {n_candidates} pairs with >= {min_votes_loop} votes")

        # Use this threshold if we have enough pairs or if we're at minimum threshold
        if n_candidates >= min_threshold_pairs or min_votes_loop == 1:
            # Build candidate pairs only when we find the right threshold
            filtered_rows = coo_rows[mask]
            filtered_cols = coo_cols[mask]
            filtered_data = coo_data[mask]
            # Sort by vote count (descending)
            sort_idx = np.argsort(-filtered_data)
            candidate_pairs = [((filtered_rows[i], filtered_cols[i]), filtered_data[i]) for i in sort_idx]
            frequent_pairs = candidate_pairs
            final_min_votes = min_votes_loop
            break

    logger.debug(f"Selected {len(frequent_pairs)} pairs using min_votes={final_min_votes} (initial_min_votes={min_votes}, out of {total_pairs_with_votes} total pairs)")

    # Log sparse vote matrix statistics
    if total_pairs_with_votes > 0:
        vote_data = vote_matrix.data
        max_votes = vote_data.max()
        mean_votes = vote_data.mean()
        sparsity = 1 - (total_pairs_with_votes / (n2 * n1))
        logger.debug(f"Sparse vote matrix statistics: non-zero entries={total_pairs_with_votes}/{n2*n1}, sparsity={sparsity:.4f}, max votes={max_votes}, mean votes (non-zero)={mean_votes:.2f}")
    else:
        logger.debug(f"Sparse vote matrix statistics: no pairs received any votes")

    # Extract indices from frequent_pairs: list of ((idx2, idx1), votes)
    pairs_only = [pair for pair, _ in frequent_pairs]
    mutual_pair_dist = [(pair[0], pair[1], mutual_pair_dist[pair]) for pair in pairs_only]
    mutual_pair = deduplicate_pairs(mutual_pair_dist)

    if return_vote_matrix:
        return mutual_pair, vote_matrix
    else:
        return mutual_pair


def compute_pair_weight_from_anchor_distance(dist_vec1, dist_vec2, pair_idx1, pair_idx2):
    """
    Score a candidate pair by how close it sits to its nearest shared anchor.

    The anchor-distance rows of the two points are added element-wise, the
    smallest sum is taken as the distance to the "nearest shared anchor", and
    that value ``d`` is mapped to ``1 / (1 + d)``.

    Args:
        dist_vec1: Distance matrix for space1, indexed as ``dist_vec1[pair_idx1]``
            (one row of per-anchor distances per point)
        dist_vec2: Distance matrix for space2, indexed as ``dist_vec2[pair_idx2]``
        pair_idx1: Row index into dist_vec1
        pair_idx2: Row index into dist_vec2

    Returns:
        weight: Python float; larger means closer to a shared anchor. The value
            lies in (0, 1] provided the distances are non-negative (assumed,
            not checked here).
    """
    # Per-anchor distance rows of the two points being paired
    anchor_row_a = dist_vec1[pair_idx1]
    anchor_row_b = dist_vec2[pair_idx2]

    # Smallest summed distance over all anchors = nearest shared anchor
    nearest_shared = (anchor_row_a + anchor_row_b).min()

    # Inverse mapping: zero distance -> 1.0, large distance -> towards 0.0
    return float(1.0 / (1.0 + nearest_shared))


def ensemble_reference_selection_bernoulli(ref_emb1, ref_emb2, emb1_unique, emb2_unique, ind_emb1_unique, ind_emb2_unique,
                                           args, device, n_ensembles=10, subset_ratio=0.3,
                                           ref_indices1=None, ref_indices2=None, ori_ref_emb1=None, ori_ref_emb2=None,
                                           pair_history=None, posterior_threshold=0.1, ensemble_strategy='random',
                                           use_distance_weighting=True, distance_filter_percentile=0.2,
                                           posterior_strategy='iteration_based', current_iteration=1, max_iterations=100,
                                           total_ensembles_run=None, overlap_inference_method='threshold'):
    """
    Ensemble-based reference selection method using Bernoulli trials with posterior distributions.

    Overall flow:
    1. Generate one reference subset per ensemble (random/cluster/furthest strategy)
    2. Phase 1: run every ensemble and collect the mutual pairs each one finds
    3. Phase 1.5 (optional): keep only the candidate pairs whose average distance to the
       subset that first discovered them is within ``distance_filter_percentile``
    4. Phase 2: treat each ensemble as a Bernoulli trial per candidate pair
       (found -> alpha += 1, not found -> beta += 1) and update ``pair_history``
    5. Compute Beta posterior statistics for every pair in ``pair_history``, draw one
       sample per pair, and select pairs either by threshold or by adaptive inference

    Pair keys are ``(i, nearest_i)`` tuples; Phase 1.5 uses ``i`` as a row index into
    ``emb2_unique`` and ``nearest_i`` as a row index into ``emb1_unique``.

    Args:
        ref_emb1, ref_emb2: Reference embeddings
        emb1_unique, emb2_unique: Unique embeddings to compute distance vectors for
        ind_emb1_unique, ind_emb2_unique: Original indices of unique embeddings
        args: Arguments containing distance_metric, anchor_mode, concat_seed_pairs, etc.
            Optional attributes read via getattr: enable_parallel_ensemble,
            use_precomputed_distances, max_parallel_workers, aggressive_memory_clear.
        device: Device for computations
        n_ensembles: Number of ensemble runs (also number of clusters for 'cluster' strategy)
        subset_ratio: Percentage of reference points to use in each ensemble (for 'random' and 'furthest')
        ref_indices1, ref_indices2: Forwarded unchanged to run_single_ensemble_gpu
        ori_ref_emb1, ori_ref_emb2: Original reference embeddings; normalized alongside the
            others for cosine distance and forwarded to the ensemble workers
        pair_history: Dictionary tracking Beta parameters (alpha, beta) for each pair.
            Mutated in place and returned; a new dict is created when None.
        posterior_threshold: Base threshold. Used as the fallback threshold for the adaptive
            inference methods; in 'threshold' mode it is only logged (the effective
            threshold is derived from the iteration count, see below).
        ensemble_strategy: Strategy for selecting reference subsets:
            - 'random': Random sampling (default)
            - 'cluster': Cluster into n_ensembles clusters, each cluster votes
            - 'furthest': Each subset contains maximally dispersed points
        use_distance_weighting: If True, run the Phase 1.5 distance filter before the
            Bernoulli updates. Updates themselves are always binary (+1.0 success / +1.0 failure).
        distance_filter_percentile: Float in (0, 1]. When use_distance_weighting=True,
            only keep pairs in top X percentile by average distance to anchors (default: 0.2 = top 20%).
        posterior_strategy: NOTE(review): currently not read by this function; 'threshold'
            mode always uses the iteration-based threshold regardless of this value.
        current_iteration: Current iteration number (1-based)
        max_iterations: NOTE(review): currently not read by this function.
        total_ensembles_run: Total number of ensembles run across all iterations; when None,
            ``current_iteration * n_ensembles`` is used instead
        overlap_inference_method: Method for inferring which pairs are true overlaps:
            - 'threshold': keep pairs whose posterior sample exceeds
              ``(2 * current_iteration + 1) / (2 + total_ensembles)`` (default), capped at
              half the number of current candidate pairs
            - 'adaptive': Use adaptive inference combining multiple methods
            - 'otsu': Use Otsu's thresholding
            - 'gmm': Use Gaussian Mixture Model
            - 'elbow': Use elbow/knee detection
            - 'expected': Use expected count from posterior sum
            - 'gap': Use gap statistic

    Returns:
        A 4-tuple (also on the empty-reference early exit):
        mutual_pair: Deduplicated (i, nearest_i, dist) tuples for the selected pairs that
            were observed in this call
        pair_history: Updated Beta parameters for next iteration
        final_posterior_stats: Dictionary with credibility metrics for each selected pair
        pair_voting_refs: Maps each selected pair to the sorted reference-subset indices of
            all ensembles that voted for it in this call
    """

    # Initialize pair history if not provided (first iteration)
    if pair_history is None:
        pair_history = {}

    if len(ref_emb1) == 0:
        # Keep the same 4-tuple shape as the normal return so callers can unpack
        # uniformly, and hand back the caller's history instead of dropping it.
        return [], pair_history, {}, {}

    n_ref = len(ref_emb1)

    use_gpu = device.type == 'cuda' if hasattr(device, 'type') else False

    # Generate subsets based on strategy
    logger.debug(f"Generating {n_ensembles} subsets using '{ensemble_strategy}' strategy for Bernoulli trials")
    if ensemble_strategy == 'cluster':
        # For cluster strategy, we don't need to specify subset_size
        # Each cluster will be its own subset
        subset_indices_list = generate_ensemble_subsets(
            ref_emb1, n_ensembles, subset_size=0, strategy=ensemble_strategy,
            distance_metric=args.distance_metric,
            use_gpu=use_gpu,
            device=device
        )
        # Calculate average subset size from generated clusters for memory estimation
        subset_size = int(np.mean([len(indices) for indices in subset_indices_list]))
        logger.debug(f"Cluster strategy: average subset size = {subset_size}")
    else:
        # For random and furthest, the subset size comes from subset_ratio
        subset_size = max(5, int(n_ref * subset_ratio))
        subset_indices_list = generate_ensemble_subsets(
            ref_emb1, n_ensembles, subset_size, strategy=ensemble_strategy,
            distance_metric=args.distance_metric,
            use_gpu=use_gpu,
            device=device
        )

    enable_parallel = getattr(args, 'enable_parallel_ensemble', True)
    n_gpus = torch.cuda.device_count() if use_gpu else 1

    logger.debug(f"Running Bernoulli trial ensemble selection: {n_ensembles} ensembles using '{ensemble_strategy}' strategy")
    logger.debug(f"Parameters: n_ensembles={n_ensembles}, subset_ratio={subset_ratio:.2f}, current pairs tracked: {len(pair_history)}")
    logger.debug(f"Using {'GPU' if use_gpu else 'CPU'} acceleration with {'parallel' if enable_parallel else 'sequential'} execution")

    start_time = time.time()

    # OPTIMIZATION: Pre-normalize embeddings once for cosine distance to avoid redundant normalization
    # in each ensemble worker. This is a major speedup for large n_ensembles.
    is_normalized = False
    if args.distance_metric == 'cosine':
        logger.debug("Pre-normalizing embeddings for cosine distance (avoids redundant normalization in workers)")
        # Normalize unique embeddings
        emb1_unique = emb1_unique / (np.linalg.norm(emb1_unique, axis=1, keepdims=True) + 1e-8)
        emb2_unique = emb2_unique / (np.linalg.norm(emb2_unique, axis=1, keepdims=True) + 1e-8)
        # Normalize reference embeddings
        ref_emb1 = ref_emb1 / (np.linalg.norm(ref_emb1, axis=1, keepdims=True) + 1e-8)
        ref_emb2 = ref_emb2 / (np.linalg.norm(ref_emb2, axis=1, keepdims=True) + 1e-8)
        # Normalize original reference embeddings if present
        if ori_ref_emb1 is not None:
            ori_ref_emb1 = ori_ref_emb1 / (np.linalg.norm(ori_ref_emb1, axis=1, keepdims=True) + 1e-8)
        if ori_ref_emb2 is not None:
            ori_ref_emb2 = ori_ref_emb2 / (np.linalg.norm(ori_ref_emb2, axis=1, keepdims=True) + 1e-8)
        is_normalized = True

    # OPTIMIZATION: Precompute full distance matrices ONCE instead of per-ensemble
    use_precomputed_distances = getattr(args, 'use_precomputed_distances', True)
    full_dist_vec1, full_dist_vec2, ori_dist_vec1, ori_dist_vec2 = None, None, None, None

    if use_precomputed_distances:
        logger.debug("OPTIMIZATION: Using precomputed distance matrices (compute once, index per-ensemble)")
        full_dist_vec1, full_dist_vec2, ori_dist_vec1, ori_dist_vec2 = precompute_full_distance_matrices(
            emb1_unique, emb2_unique, ref_emb1, ref_emb2,
            ori_ref_emb1, ori_ref_emb2, args, device, use_gpu, is_normalized)
        # With precomputed distances, sequential is faster than parallel
        logger.debug("Using sequential execution with precomputed distances (optimal for this case)")
        enable_parallel = False

    # Collect all candidate pairs across ensembles and their distances
    all_candidate_pairs = set()
    pair_distances = {}
    pair_discovery_map = {}  # Maps pair_key -> ensemble_idx that first discovered it
    pair_ensemble_votes = {}  # Maps pair_key -> list of ensemble indices that found this pair
    ensemble_mutual_pairs = {}  # Store mutual pairs for reuse in Phase 2 (avoid re-running ensembles)

    def _record_ensemble_result(ensemble_idx, subset_mutual_pairs, subset_accuracy, mutual_nn):
        # Fold one ensemble's Phase 1 output into the shared candidate bookkeeping.
        ensemble_mutual_pairs[ensemble_idx] = subset_mutual_pairs

        for i, nearest_i, dist_between_pair in subset_mutual_pairs:
            pair_key = (i, nearest_i)
            all_candidate_pairs.add(pair_key)
            # Last ensemble to report a pair wins for the stored distance
            pair_distances[pair_key] = dist_between_pair
            # Track which ensemble discovered this pair (first discovery only)
            pair_discovery_map.setdefault(pair_key, ensemble_idx)

        logger.info(f"Ensemble {ensemble_idx+1}: {mutual_nn} mutual pairs, accuracy: {subset_accuracy:.3f}")

    # Run ensembles to collect all possible candidate pairs (Phase 1)
    ran_parallel = False
    if enable_parallel and n_ensembles > 1 and use_gpu and torch.cuda.is_available():
        # Check memory first before attempting parallel execution
        ctx = mp.get_context('spawn')
        args_dict = vars(args)

        # Calculate embedding dimensions for memory estimation
        emb_dim1 = ref_emb1.shape[1] if len(ref_emb1.shape) > 1 else ref_emb1.shape[0]
        emb_dim2 = ref_emb2.shape[1] if len(ref_emb2.shape) > 1 else ref_emb2.shape[0]
        embedding_dim_avg = (emb_dim1 + emb_dim2) // 2

        # Calculate safe max_workers and check if should use sequential
        max_parallel_workers = getattr(args, 'max_parallel_workers', None)
        max_workers, should_use_sequential = calculate_safe_max_workers(
            n_ensembles=n_ensembles,
            n_unique_samples=len(emb1_unique),
            n_ref_subset=subset_size,
            embedding_dim_avg=embedding_dim_avg,
            use_gpu=use_gpu,
            device=device,
            safety_factor=0.25,
            max_parallel_workers=max_parallel_workers
        )

        if should_use_sequential:
            # Leave ran_parallel False so the sequential loop below actually runs
            logger.warning(f"Phase 1: Memory too low for parallel execution, using sequential mode")
            enable_parallel = False
        else:
            # Parallel GPU execution
            logger.debug(f"Phase 1: Collecting candidate pairs in parallel across {n_gpus} GPU(s)")

            ensemble_args = []
            for ensemble_idx in range(n_ensembles):
                gpu_id = ensemble_idx % n_gpus if n_gpus > 1 else 0
                # Use pre-generated subset indices
                ref_subset_indices = subset_indices_list[ensemble_idx]
                ensemble_args.append((
                    ensemble_idx, ref_emb1, ref_emb2, emb1_unique, emb2_unique,
                    ind_emb1_unique, ind_emb2_unique, len(ref_subset_indices), args_dict, True, gpu_id,
                    ref_indices1, ref_indices2, ori_ref_emb1, ori_ref_emb2, args.anchor_mode, args.concat_seed_pairs,
                    ref_subset_indices, is_normalized  # Pass pre-normalization flag
                ))

            # Run in batches if memory-constrained
            results_dict = run_ensembles_in_batches(ensemble_args, max_workers, n_ensembles, ctx)
            ran_parallel = True

            # Collect results in ensemble order; ensembles missing from results_dict are skipped
            for ensemble_idx in range(n_ensembles):
                if ensemble_idx in results_dict:
                    # The per-ensemble distance vectors are not needed after Phase 1
                    subset_mutual_pairs, subset_accuracy, mutual_nn, _, _ = results_dict[ensemble_idx]
                    _record_ensemble_result(ensemble_idx, subset_mutual_pairs, subset_accuracy, mutual_nn)

    if not ran_parallel:
        # Sequential execution (default path, and fallback when memory is too low for parallel)
        import gc

        logger.debug("Phase 1: Collecting candidate pairs sequentially")
        aggressive_memory_clear = getattr(args, 'aggressive_memory_clear', False)
        for ensemble_idx in range(n_ensembles):
            # Use pre-generated subset indices
            ref_subset_indices = subset_indices_list[ensemble_idx]

            # OPTIMIZATION: Use precomputed distances if available (much faster!)
            if use_precomputed_distances and full_dist_vec1 is not None:
                _, subset_mutual_pairs, subset_accuracy, mutual_nn, _, _ = process_ensemble_with_precomputed_distances(
                    ensemble_idx, full_dist_vec1, full_dist_vec2, ori_dist_vec1, ori_dist_vec2,
                    ind_emb1_unique, ind_emb2_unique, ref_subset_indices, args, device, use_gpu,
                    args.concat_seed_pairs, args.anchor_mode)
            else:
                # Fallback to legacy method
                args_tuple = (
                    ensemble_idx, ref_emb1, ref_emb2, emb1_unique, emb2_unique,
                    ind_emb1_unique, ind_emb2_unique, len(ref_subset_indices), vars(args), use_gpu, 0,
                    ref_indices1, ref_indices2, ori_ref_emb1, ori_ref_emb2, args.anchor_mode, args.concat_seed_pairs,
                    ref_subset_indices, is_normalized  # Pass pre-normalization flag
                )
                _, subset_mutual_pairs, subset_accuracy, mutual_nn, _, _ = run_single_ensemble_gpu(args_tuple)

            _record_ensemble_result(ensemble_idx, subset_mutual_pairs, subset_accuracy, mutual_nn)

            # Aggressive memory clearing for large datasets
            if aggressive_memory_clear:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    logger.debug(f"Total unique candidate pairs found: {len(all_candidate_pairs)}")

    # Phase 1.5: Compute average distances to DISCOVERING SUBSETS for filtering (if enabled)
    if use_distance_weighting and len(all_candidate_pairs) > 0:
        logger.debug(f"Computing average distances to discovering subsets for {len(all_candidate_pairs)} candidate pairs...")

        full_anchor_emb1 = ref_emb1
        full_anchor_emb2 = ref_emb2

        # Group pairs by discovering ensemble for efficient batch processing
        ensemble_to_pairs = {}
        for pair_key in all_candidate_pairs:
            discovering_ensemble = pair_discovery_map.get(pair_key, 0)  # Default to 0 if missing
            ensemble_to_pairs.setdefault(discovering_ensemble, []).append(pair_key)

        pair_avg_distances = {}

        # Process each ensemble's pairs in batch
        for discovering_ensemble, pairs_in_group in ensemble_to_pairs.items():
            # Get the discovering subset indices
            discovering_subset_indices = subset_indices_list[discovering_ensemble]

            # Extract subset anchor embeddings
            subset_anchor_emb1 = full_anchor_emb1[discovering_subset_indices]  # (subset_size, d)
            subset_anchor_emb2 = full_anchor_emb2[discovering_subset_indices]  # (subset_size, d)

            # Convert pairs to arrays for batch processing
            indices_i = np.array([i for i, _ in pairs_in_group])
            indices_nearest_i = np.array([nearest_i for _, nearest_i in pairs_in_group])

            # Get batch embeddings
            batch_emb2 = emb2_unique[indices_i]  # (n_pairs, d)
            batch_emb1 = emb1_unique[indices_nearest_i]  # (n_pairs, d)

            # Compute distances to discovering subset only (use is_normalized since pre-normalized earlier)
            if use_gpu and torch.cuda.is_available():
                try:
                    batch_emb1_t = torch.from_numpy(batch_emb1).to(device).float()
                    batch_emb2_t = torch.from_numpy(batch_emb2).to(device).float()
                    subset_anchor_emb1_t = torch.from_numpy(subset_anchor_emb1).to(device).float()
                    subset_anchor_emb2_t = torch.from_numpy(subset_anchor_emb2).to(device).float()

                    dist_matrix1 = get_dists(batch_emb1_t, subset_anchor_emb1_t,
                                            metric=args.distance_metric, use_gpu=True, device=device,
                                            is_normalized=is_normalized)
                    dist_matrix2 = get_dists(batch_emb2_t, subset_anchor_emb2_t,
                                            metric=args.distance_metric, use_gpu=True, device=device,
                                            is_normalized=is_normalized)

                    dist_matrix1 = dist_matrix1.cpu().numpy()
                    dist_matrix2 = dist_matrix2.cpu().numpy()
                except Exception as e:
                    # Deliberate best-effort: any GPU failure degrades to the CPU path
                    logger.warning(f"GPU distance computation failed for ensemble {discovering_ensemble}, falling back to CPU: {e}")
                    dist_matrix1 = get_dists(batch_emb1, subset_anchor_emb1, metric=args.distance_metric, use_gpu=False, is_normalized=is_normalized)
                    dist_matrix2 = get_dists(batch_emb2, subset_anchor_emb2, metric=args.distance_metric, use_gpu=False, is_normalized=is_normalized)
            else:
                dist_matrix1 = get_dists(batch_emb1, subset_anchor_emb1, metric=args.distance_metric, use_gpu=False, is_normalized=is_normalized)
                dist_matrix2 = get_dists(batch_emb2, subset_anchor_emb2, metric=args.distance_metric, use_gpu=False, is_normalized=is_normalized)

            # Compute average distances per pair
            for local_idx, pair_key in enumerate(pairs_in_group):
                # Combine distances from both spaces
                all_distances = np.concatenate([dist_matrix1[local_idx], dist_matrix2[local_idx]])
                pair_avg_distances[pair_key] = np.mean(all_distances)

        # Filter pairs: keep only top distance_filter_percentile (closest to discovering subsets)
        all_distances_array = np.array(list(pair_avg_distances.values()))
        threshold_distance = np.percentile(all_distances_array, distance_filter_percentile * 100)

        filtered_candidate_pairs = {
            pair_key for pair_key, dist in pair_avg_distances.items()
            if dist <= threshold_distance
        }

        logger.debug(f"Distance filtering (using discovering subsets): "
                   f"kept {len(filtered_candidate_pairs)}/{len(all_candidate_pairs)} pairs "
                   f"within top {distance_filter_percentile*100:.1f}% (threshold distance: {threshold_distance:.4f})")
        logger.debug(f"Average distance range: [{all_distances_array.min():.4f}, {all_distances_array.max():.4f}], "
                   f"mean: {all_distances_array.mean():.4f}")

        # Update all_candidate_pairs to filtered set
        all_candidate_pairs = filtered_candidate_pairs

    # Phase 2: Use stored results from Phase 1 (OPTIMIZED - no re-running ensembles)
    # This reuses ensemble_mutual_pairs computed in Phase 1, avoiding redundant computation
    logger.debug("Phase 2: Running Bernoulli trials using stored Phase 1 results (optimized)")

    # Initialize pair history for all candidate pairs with a uniform Beta(1, 1) prior
    for pair_key in all_candidate_pairs:
        if pair_key not in pair_history:
            pair_history[pair_key] = {'alpha': 1.0, 'beta': 1.0}

    # Iterate over stored results from Phase 1 (no GPU computation needed)
    for ensemble_idx in range(n_ensembles):
        if ensemble_idx not in ensemble_mutual_pairs:
            # Ensemble produced no result (e.g. missing from the parallel results): no trial
            continue

        # Convert to set for fast lookup
        found_pairs = {(i, nearest_i) for i, nearest_i, _ in ensemble_mutual_pairs[ensemble_idx]}

        # Update Beta parameters for each candidate pair
        for pair_key in all_candidate_pairs:
            # Bernoulli trial: success if pair found in this ensemble, failure otherwise
            # Note: Distance filtering already happened in Phase 1.5 if enabled
            if pair_key in found_pairs:
                # Binary weighting: success
                pair_history[pair_key]['alpha'] += 1.0
                # Track which ensemble voted for this pair
                pair_ensemble_votes.setdefault(pair_key, []).append(ensemble_idx)
            else:
                pair_history[pair_key]['beta'] += 1.0  # Failure

    # Sample from posterior distributions to select final pairs
    selected_pairs = []
    posterior_stats = {}

    # First, compute posterior stats for ALL pairs in pair_history
    # (this includes pairs carried over from earlier iterations)
    for pair_key, params in pair_history.items():
        alpha, beta = params['alpha'], params['beta']
        posterior_mean = alpha / (alpha + beta)
        posterior_var = (alpha * beta) / ((alpha + beta)**2 * (alpha + beta + 1))
        posterior_std = np.sqrt(posterior_var)
        credible_interval = beta_dist.interval(0.95, alpha, beta)
        sample = np.random.beta(alpha, beta)

        posterior_stats[pair_key] = {
            'posterior_mean': posterior_mean,
            'posterior_std': posterior_std,
            'credible_interval_95': credible_interval,
            # alpha/beta start at 1.0, so subtract the prior to recover raw counts
            'n_successes': alpha - 1,
            'n_trials': alpha + beta - 2,
            'sample': sample
        }

    # Check if using adaptive overlap inference methods
    if overlap_inference_method != 'threshold':
        # Map method names for adaptive inference (unknown names fall back to 'ensemble')
        method_map = {
            'adaptive': 'ensemble',
            'otsu': 'otsu',
            'gmm': 'gmm',
            'elbow': 'elbow',
            'expected': 'expected',
            'gap': 'gap'
        }
        inference_method = method_map.get(overlap_inference_method, 'ensemble')

        # Use adaptive inference to select pairs
        selected_pair_keys, adaptive_threshold, method_info = infer_overlap_adaptive(
            posterior_stats,
            method=inference_method,
            fallback_threshold=posterior_threshold,
            min_pairs=1,
            max_pairs_ratio=1.0
        )

        # Build selected_pairs list from selected keys
        for pair_key in selected_pair_keys:
            stats = posterior_stats[pair_key]
            selected_pairs.append((pair_key, stats['posterior_mean'], stats['posterior_std']**2, stats['sample']))

        logger.debug(f"Adaptive overlap inference ({overlap_inference_method}): selected {len(selected_pairs)} pairs, threshold={adaptive_threshold:.4f}")
        if 'ensemble' in method_info:
            logger.debug(f"  Method estimates: {method_info['ensemble'].get('n_from_methods', {})}")

    else:
        # Original threshold-based selection (preserved exactly as before)
        # Compute effective threshold using iteration-based strategy
        if total_ensembles_run is not None:
            total_ensembles_with_current = total_ensembles_run
        else:
            total_ensembles_with_current = current_iteration * n_ensembles

        effective_threshold = (2 * current_iteration + 1) / (2 + total_ensembles_with_current)

        logger.debug(f"Iteration-based posterior: iteration {current_iteration}, "
                   f"total_ensembles={total_ensembles_with_current}, "
                   f"effective_threshold={effective_threshold:.6f} (base={posterior_threshold:.3f})")

        # Select pairs based on threshold using precomputed posterior_stats
        for pair_key, stats in posterior_stats.items():
            if stats['sample'] > effective_threshold:
                selected_pairs.append((pair_key, stats['posterior_mean'], stats['posterior_std']**2, stats['sample']))

        # Sort by posterior sample (or could use posterior mean)
        selected_pairs.sort(key=lambda x: x[3], reverse=True)

        # Limit to reasonable number of pairs: at most half of the current candidates
        max_pairs = min(len(selected_pairs), int(len(all_candidate_pairs) * 0.5))
        selected_pairs = selected_pairs[:max_pairs]

        logger.debug(f"Selected {len(selected_pairs)} pairs using Bernoulli trial posterior sampling (iteration_based strategy)")
        logger.debug(f"Effective threshold: {effective_threshold:.6f} (base: {posterior_threshold}), max pairs: {max_pairs}")

    # Convert back to the expected format
    mutual_pair = []
    final_posterior_stats = {}
    pair_voting_refs = {}  # Track which references voted for each pair

    for pair_key, _post_mean, _post_var, _sample in selected_pairs:
        # Pairs known only from earlier iterations have no distance in this call; skip them
        if pair_key not in pair_distances:
            continue

        i, nearest_i = pair_key
        mutual_pair.append((i, nearest_i, pair_distances[pair_key]))
        final_posterior_stats[pair_key] = posterior_stats[pair_key]

        # Collect all reference indices from the ensembles that voted for this pair
        voting_ref_indices = set()
        for ensemble_idx in pair_ensemble_votes.get(pair_key, []):
            voting_ref_indices.update(subset_indices_list[ensemble_idx])

        # Store the reference indices that voted
        pair_voting_refs[pair_key] = sorted(voting_ref_indices)

    mutual_pair = deduplicate_pairs(mutual_pair)

    elapsed_time = time.time() - start_time
    logger.debug(f"Bernoulli trial ensemble computation completed in {elapsed_time:.2f} seconds")

    return mutual_pair, pair_history, final_posterior_stats, pair_voting_refs


def _collect_parallel_cluster_results(ensemble_args, max_workers):
    """
    Run one ``run_single_ensemble_gpu`` task per cluster in a spawned process pool.

    Args:
        ensemble_args: List of positional-argument tuples, one per cluster, in the
            layout expected by ``run_single_ensemble_gpu``.
        max_workers: Maximum number of worker processes.

    Returns:
        Dict mapping cluster_id -> (cluster_mutual_pairs, cluster_accuracy, mutual_nn).

    Raises:
        Any exception raised by a worker (re-raised by ``future.result()``) or by
        the process pool itself; the caller decides how to fall back.
    """
    ctx = mp.get_context('spawn')
    results = {}
    with ProcessPoolExecutor(max_workers=max_workers, mp_context=ctx) as executor:
        futures = [
            executor.submit(run_single_ensemble_gpu, worker_args)
            for worker_args in ensemble_args
        ]
        for future in as_completed(futures):
            # The worker returns six values; only the first four are used here.
            cluster_id, cluster_mutual_pairs, cluster_accuracy, mutual_nn, _, _ = future.result()
            results[cluster_id] = (cluster_mutual_pairs, cluster_accuracy, mutual_nn)
    return results


def ensemble_reference_selection_cluster_voting(ref_emb1, ref_emb2, emb1_unique, emb2_unique, ind_emb1_unique, ind_emb2_unique,
                                                 args, device, ind_nonref, n_clusters=None,
                                                 ref_indices1=None, ref_indices2=None,
                                                 vote_threshold=0.6, ori_ref_emb1=None, ori_ref_emb2=None, return_vote_matrix=False,
                                                 use_clustering_gpu=False):
    """
    Ensemble-based reference selection using cluster-based voting.

    Instead of random sampling, this method:
    1. Clusters the reference embeddings (``ref_emb1`` only) into k clusters with k-means
    2. Runs ``run_single_ensemble_gpu`` once per non-empty cluster, passing that
       cluster's member indices; every mutual pair it returns counts as one vote
    3. Aggregates votes across all clusters in a sparse vote matrix
    4. Keeps pairs whose vote count reaches a threshold, lowering the threshold
       one vote at a time until enough pairs qualify
    5. Runs cluster votes in parallel when enabled, falling back to sequential
       execution if the parallel run raises

    Args:
        ref_emb1, ref_emb2: Reference embeddings
        emb1_unique, emb2_unique: Unique embeddings to compute distance vectors for
        ind_emb1_unique, ind_emb2_unique: Original indices of unique embeddings
        args: Arguments object. Reads ``distance_metric``, ``anchor_mode`` and
            ``concat_seed_pairs``; optionally ``ensemble_subset_ratio`` (default 0.3)
            and ``enable_parallel_ensemble`` (default True). ``vars(args)`` is
            forwarded to the worker.
        device: Device for computations; GPU is used when ``device.type == 'cuda'``
        ind_nonref: Accepted for interface compatibility; not used in this function
        n_clusters: Number of clusters (if None, auto-determined from subset_ratio).
            Clamped to the range [1, number of reference points].
        ref_indices1, ref_indices2: Forwarded unchanged to the worker
        vote_threshold: Minimum vote ratio (relative to the number of non-empty
            clusters) to accept a pair
        ori_ref_emb1, ori_ref_emb2: Original reference embeddings, forwarded to the worker
        return_vote_matrix: If True, return vote matrix along with mutual pairs
        use_clustering_gpu: Use GPU for clustering

    Returns:
        mutual_pair: Final mutual pairs selected by cluster ensemble (after
            ``deduplicate_pairs``)
        vote_matrix (optional): Sparse CSR matrix of shape
            ``(len(emb2_unique), len(emb1_unique))`` if return_vote_matrix=True
    """

    if len(ref_emb1) == 0:
        if return_vote_matrix:
            n1 = len(emb1_unique)
            n2 = len(emb2_unique)
            return [], csr_matrix((n2, n1), dtype=np.int32)
        return []

    n_ref = len(ref_emb1)

    # Auto-determine number of clusters based on subset_ratio if not provided
    # Each cluster will act as one "ensemble" with all its members
    subset_ratio = getattr(args, 'ensemble_subset_ratio', 0.3)
    if n_clusters is None:
        desired_subset_size = max(1, int(n_ref * subset_ratio))
        # Create clusters such that average cluster size ≈ desired_subset_size
        n_clusters = max(2, min(n_ref // desired_subset_size, 20))

    # Ensure n_clusters is valid: never more than the number of reference points,
    # and never below 1 (a caller-supplied 0 would otherwise divide by zero below).
    n_clusters = max(1, min(n_clusters, n_ref))

    logger.debug(f"Cluster-based ensemble: clustering {n_ref} reference points into {n_clusters} clusters")
    logger.debug(f"Target cluster size: ~{n_ref // n_clusters} (subset_ratio={subset_ratio:.2f})")

    # Cluster reference embeddings using emb1 only
    ref_emb_combined = ref_emb1

    clusterer = Clusterer(method='kmeans', n_clusters=n_clusters, use_gpu=use_clustering_gpu)
    cluster_labels = clusterer.fit(ref_emb_combined)

    # Group cluster members, dropping clusters that received no points
    cluster_members = {}
    for cluster_id in range(n_clusters):
        cluster_member_indices = np.where(cluster_labels == cluster_id)[0]
        if len(cluster_member_indices) > 0:
            cluster_members[cluster_id] = cluster_member_indices

    logger.debug(f"Created {len(cluster_members)} non-empty clusters")
    for cluster_id, members in cluster_members.items():
        logger.debug(f"  Cluster {cluster_id}: {len(members)} members")

    # Initialize sparse vote matrix (LIL format supports cheap element-wise updates)
    n1 = len(emb1_unique)
    n2 = len(emb2_unique)
    vote_matrix = lil_matrix((n2, n1), dtype=np.int32)

    pair_distances = {}
    cluster_accuracies = []

    use_gpu = device.type == 'cuda' if hasattr(device, 'type') else False
    enable_parallel = getattr(args, 'enable_parallel_ensemble', True)
    n_gpus = torch.cuda.device_count() if use_gpu else 1

    logger.debug(f"Running cluster-based ensemble: {len(cluster_members)} clusters")
    logger.debug(f"Parameters: vote_threshold={vote_threshold:.2f}")
    logger.debug(f"Using {'GPU' if use_gpu else 'CPU'} acceleration with {'parallel' if enable_parallel else 'sequential'} execution")

    start_time = time.time()

    # OPTIMIZATION: Pre-normalize embeddings once for cosine distance to avoid redundant normalization
    # in each cluster worker. This is a major speedup for large number of clusters.
    is_normalized = False
    if args.distance_metric == 'cosine':
        logger.debug("Pre-normalizing embeddings for cosine distance (avoids redundant normalization in workers)")
        # Normalize unique embeddings
        emb1_unique = emb1_unique / (np.linalg.norm(emb1_unique, axis=1, keepdims=True) + 1e-8)
        emb2_unique = emb2_unique / (np.linalg.norm(emb2_unique, axis=1, keepdims=True) + 1e-8)
        # Normalize reference embeddings
        ref_emb1 = ref_emb1 / (np.linalg.norm(ref_emb1, axis=1, keepdims=True) + 1e-8)
        ref_emb2 = ref_emb2 / (np.linalg.norm(ref_emb2, axis=1, keepdims=True) + 1e-8)
        # Normalize original reference embeddings if present
        if ori_ref_emb1 is not None:
            ori_ref_emb1 = ori_ref_emb1 / (np.linalg.norm(ori_ref_emb1, axis=1, keepdims=True) + 1e-8)
        if ori_ref_emb2 is not None:
            ori_ref_emb2 = ori_ref_emb2 / (np.linalg.norm(ori_ref_emb2, axis=1, keepdims=True) + 1e-8)
        is_normalized = True

    args_dict = vars(args)

    def build_worker_args(cluster_id, member_indices, worker_use_gpu, gpu_id):
        # Single definition of the positional layout passed to run_single_ensemble_gpu,
        # shared by the GPU-parallel, CPU-parallel and sequential paths.
        return (
            cluster_id, ref_emb1, ref_emb2, emb1_unique, emb2_unique,
            ind_emb1_unique, ind_emb2_unique, len(member_indices), args_dict, worker_use_gpu, gpu_id,
            ref_indices1, ref_indices2, ori_ref_emb1, ori_ref_emb2, args.anchor_mode, args.concat_seed_pairs,
            member_indices, is_normalized  # Pass pre-normalization flag
        )

    def record_cluster_result(cluster_id, cluster_mutual_pairs, cluster_accuracy, mutual_nn):
        # One vote per mutual pair reported by this cluster; the last reported
        # distance for a pair wins.
        for i, nearest_i, dist_between_pair in cluster_mutual_pairs:
            vote_matrix[i, nearest_i] += 1
            pair_distances[(i, nearest_i)] = dist_between_pair
        cluster_accuracies.append(cluster_accuracy)
        logger.debug(f"Cluster {cluster_id}: {mutual_nn} mutual pairs, accuracy: {cluster_accuracy:.3f}")

    parallel_results = None
    if enable_parallel and len(cluster_members) > 1:
        # Parallel execution across clusters
        try:
            if use_gpu and torch.cuda.is_available():
                logger.debug(f"Using ProcessPoolExecutor for GPU-based parallel cluster voting across {n_gpus} GPU(s)")
                # Distribute clusters round-robin over the available GPUs
                ensemble_args = [
                    build_worker_args(
                        cluster_id, member_indices, True,
                        cluster_idx % n_gpus if n_gpus > 1 else 0
                    )
                    for cluster_idx, (cluster_id, member_indices) in enumerate(cluster_members.items())
                ]
                max_workers = len(cluster_members)
            else:
                # CPU parallel execution
                logger.debug("Using ProcessPoolExecutor for CPU-based parallel cluster voting")
                ensemble_args = [
                    build_worker_args(cluster_id, member_indices, False, None)
                    for cluster_id, member_indices in cluster_members.items()
                ]
                max_workers = min(len(cluster_members), mp.cpu_count())

            parallel_results = _collect_parallel_cluster_results(ensemble_args, max_workers)
        except Exception as e:
            # Deliberate best-effort: any pool/worker failure degrades to the sequential path.
            # No votes have been recorded yet, so the fallback cannot double count.
            logger.warning(f"Parallel execution failed: {e}. Falling back to sequential execution.")
            parallel_results = None

    if parallel_results is not None:
        # Process results in cluster order so vote/distance updates are deterministic
        for cluster_id in cluster_members:
            if cluster_id in parallel_results:
                cluster_mutual_pairs, cluster_accuracy, mutual_nn = parallel_results[cluster_id]
                record_cluster_result(cluster_id, cluster_mutual_pairs, cluster_accuracy, mutual_nn)
    else:
        # Sequential execution
        logger.debug("Running cluster voting sequentially")
        for cluster_id, member_indices in cluster_members.items():
            args_tuple = build_worker_args(cluster_id, member_indices, use_gpu, 0)
            _, cluster_mutual_pairs, cluster_accuracy, mutual_nn, _, _ = run_single_ensemble_gpu(args_tuple)
            record_cluster_result(cluster_id, cluster_mutual_pairs, cluster_accuracy, mutual_nn)

    elapsed_time = time.time() - start_time
    logger.debug(f"Cluster-based ensemble computation completed in {elapsed_time:.2f} seconds")

    # Convert to CSR format for efficient operations
    vote_matrix = vote_matrix.tocsr()

    # Select pairs using vote threshold
    min_votes = max(1, int(len(cluster_members) * vote_threshold))

    # Count total number of pairs with votes > 0
    total_pairs_with_votes = vote_matrix.nnz
    min_threshold_pairs = max(5, total_pairs_with_votes // 20)

    # Rank all voted pairs once (stable sort, highest votes first); each threshold
    # below is then just a filter over this ranking.
    coo = vote_matrix.tocoo()
    ranked_pairs = sorted(
        (((i, j), votes) for i, j, votes in zip(coo.row, coo.col, coo.data)),
        key=lambda entry: entry[1],
        reverse=True,
    )

    # Start with threshold, then fall back if needed
    frequent_pairs = []
    final_min_votes = min_votes

    for min_votes_loop in range(min_votes, 0, -1):
        candidate_pairs = [entry for entry in ranked_pairs if entry[1] >= min_votes_loop]

        logger.debug(f"Threshold {min_votes_loop}: Found {len(candidate_pairs)} pairs with >= {min_votes_loop} votes")

        # Use this threshold if we have enough pairs or if we're at minimum threshold
        if len(candidate_pairs) >= min_threshold_pairs or min_votes_loop == 1:
            frequent_pairs = candidate_pairs
            final_min_votes = min_votes_loop
            break

    logger.debug(f"Selected {len(frequent_pairs)} pairs using min_votes={final_min_votes} (initial_min_votes={min_votes}, out of {total_pairs_with_votes} total pairs)")

    # Log statistics
    if total_pairs_with_votes > 0:
        vote_data = vote_matrix.data
        max_votes = vote_data.max()
        mean_votes = vote_data.mean()
        sparsity = 1 - (total_pairs_with_votes / (n2 * n1))
        logger.debug(f"Sparse vote matrix statistics: non-zero entries={total_pairs_with_votes}/{n2*n1}, sparsity={sparsity:.4f}, max votes={max_votes}, mean votes (non-zero)={mean_votes:.2f}")
    else:
        logger.debug("Sparse vote matrix statistics: no pairs received any votes")

    # Log average cluster accuracy
    if cluster_accuracies:
        avg_cluster_accuracy = np.mean(cluster_accuracies)
        logger.debug(f"Average cluster accuracy: {avg_cluster_accuracy:.3f}")

    # Attach the recorded distance to every selected pair, then deduplicate
    selected_with_dist = [
        (pair[0], pair[1], pair_distances[pair]) for pair, _ in frequent_pairs
    ]
    mutual_pair = deduplicate_pairs(selected_with_dist)

    if return_vote_matrix:
        return mutual_pair, vote_matrix
    return mutual_pair


def ensemble_reference_selection_all_points_mnn(ref_emb1, ref_emb2, emb1_unique, emb2_unique, ind_emb1_unique, ind_emb2_unique,
                                                  args, device, ref_indices1=None, ref_indices2=None,
                                                  ori_ref_emb1=None, ori_ref_emb2=None):
    """
    Baseline: find mutual nearest neighbors using every reference point at once.

    There is no subsampling, voting, thresholding or posterior estimation here.
    The function encodes each unique point by its distances to the full set of
    anchors and hands the two encodings to ``find_mutual_pairs``.

    Anchor choice:
    - If ``args.anchor_mode == 'cluster_representative'`` and both
      ``ori_ref_emb1`` and ``ori_ref_emb2`` are given, those are the anchors.
    - Otherwise ``ref_emb1`` / ``ref_emb2`` are the anchors.

    Args:
        ref_emb1, ref_emb2: Reference embeddings (all of them are used)
        emb1_unique, emb2_unique: Unique embeddings to encode
        ind_emb1_unique, ind_emb2_unique: Original indices of the unique embeddings,
            forwarded to ``find_mutual_pairs``
        args: Arguments object; ``anchor_mode`` and ``distance_metric`` are read
            here, and the object is forwarded to ``find_mutual_pairs``
        device: Device for computations; GPU is used when ``device.type == 'cuda'``
        ref_indices1, ref_indices2: Accepted for interface compatibility; not used
            in this function
        ori_ref_emb1, ori_ref_emb2: Original reference embeddings (anchor mode)

    Returns:
        The pair list produced by ``find_mutual_pairs`` (described upstream as
        ``(idx2, idx1, dist)`` tuples), or ``[]`` when there are no reference points.
    """

    reference_count = len(ref_emb1)
    if reference_count == 0:
        logger.warning("No reference points available for all-points MNN baseline")
        return []

    logger.debug(f"Running all-points MNN baseline with {reference_count} reference points")
    logger.debug(f"Computing distance encodings for {len(emb1_unique)} and {len(emb2_unique)} unique points")

    on_gpu = hasattr(device, 'type') and device.type == 'cuda'

    t0 = time.time()

    # Local import kept as in the original implementation
    from graph_utils.distance_encoder import compute_distance_encoding

    logger.debug("Computing distance encodings using all reference points...")

    # Decide which embeddings act as anchors for the distance encoding
    representatives_available = ori_ref_emb1 is not None and ori_ref_emb2 is not None
    if args.anchor_mode == 'cluster_representative' and representatives_available:
        anchors_side1, anchors_side2 = ori_ref_emb1, ori_ref_emb2
        logger.debug(f"Using cluster representative anchor mode with {len(anchors_side1)} anchors")
    else:
        anchors_side1, anchors_side2 = ref_emb1, ref_emb2

    # Both sides are encoded with the same metric / device settings
    encoding_options = {
        'distance_metric': args.distance_metric,
        'use_gpu': on_gpu,
        'device': device,
    }
    encoded_side1 = compute_distance_encoding(emb1_unique, ref_embeddings=anchors_side1, **encoding_options)
    encoded_side2 = compute_distance_encoding(emb2_unique, ref_embeddings=anchors_side2, **encoding_options)

    logger.debug(f"Distance encoding computed: dist_vec1 shape={encoded_side1.shape}, dist_vec2 shape={encoded_side2.shape}")

    # Mutual nearest-neighbor search on the distance encodings
    logger.debug("Finding mutual nearest neighbors...")
    found_pairs, n_mutual, n_correct = find_mutual_pairs(
        encoded_side1, encoded_side2, ind_emb1_unique, ind_emb2_unique,
        args, device, use_gpu=on_gpu
    )

    duration = time.time() - t0

    logger.debug(f"All-points MNN baseline completed in {duration:.2f} seconds")
    logger.debug(f"Found {len(found_pairs)} mutual nearest neighbor pairs")
    logger.debug(f"Mutual NN count: {n_mutual}, Correct: {n_correct}")

    # Accuracy is only reported in supervised mode with at least one correct pair
    if args.anchor_mode == "supervised" and n_correct > 0:
        if n_mutual > 0:
            hit_rate = n_correct / n_mutual
        else:
            hit_rate = 0.0
        logger.debug(f"Supervised accuracy: {hit_rate:.3f} ({n_correct}/{n_mutual})")

    return found_pairs
