import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='pkg_resources')

import os
import sys
import faulthandler
import traceback
import gc
import argparse
import random
from typing import Tuple

import numpy as np
import torch
from sklearn.preprocessing import LabelEncoder
from loguru import logger

from utils.load_data import load_npy, DataPartitioner, save_npy, load_beir_qrels, BEIR_AVAILABLE, convert_global_to_local_indices
from utils.clustering import Clusterer
from graph_utils.cluster import Cluster
from utils.graph_util import get_dists
from utils.retrieval_util import compute_accuracy_recall, analyze_distance_based_accuracy
from utils.sample_methods import sample_ref_points
from utils.ensemble_selection import (
    ensemble_reference_selection_voting,
    ensemble_reference_selection_bernoulli,
    ensemble_reference_selection_all_points_mnn
)
from utils.adaptive_ensemble import compute_principled_fixed_strategy_params
from utils.procrustes_util import cluster_wise_procrustes_refinement


def optimize_for_large_datasets(args):
    """
    Tune execution (not the algorithm) when a large dataset is involved.

    If either ``args.dataset`` or ``args.ref_dataset`` is one of the large
    datasets listed below, set ``args.aggressive_memory_clear = True`` so that
    GPU memory is cleared between ensembles. Algorithm parameters
    (n_ensembles, subset_ratio, ...) are never modified. Missing attributes
    are treated as None; otherwise ``args`` is left unchanged.
    """
    heavy_datasets = {"scidocs", "fiqa"}
    primary = getattr(args, "dataset", None)
    secondary = getattr(args, "ref_dataset", None)

    # Guard clause: nothing to do unless one of the two datasets is large.
    if primary not in heavy_datasets and secondary not in heavy_datasets:
        return

    logger.debug(f"Large dataset detected ({primary or secondary}); optimizing execution strategy")

    # Flag read downstream to free GPU memory after each ensemble member.
    args.aggressive_memory_clear = True
    logger.debug("  - Enabled aggressive GPU memory clearing between ensembles")


def compute_annealed_ref_filter_ratio(iteration, max_iter, initial_ratio, final_ratio,
                                    annealing_type="linear", quality_history=None,
                                    min_ratio=0.1, max_ratio=1.0):
    """
    Compute an annealed ref_filter_ratio from iteration progress (and optionally quality metrics).

    Args:
        iteration: Current iteration (1-based). Values outside [1, max_iter] are
            clamped, so they never extrapolate beyond initial_ratio / final_ratio.
        max_iter: Maximum number of iterations.
        initial_ratio: ref_filter_ratio at the first iteration.
        final_ratio: ref_filter_ratio at the last iteration.
        annealing_type: One of "none", "linear", "exponential", "cosine", "quality_adaptive".
        quality_history: Sequence of (mean_quality, kept_quality) tuples, used only by
            "quality_adaptive"; fewer than 3 entries falls back to linear annealing.
        min_ratio, max_ratio: Bounds the annealed result is clamped to (defaults 0.1 and 1.0).

    Returns:
        float: The annealed ratio. For annealing_type == "none" or max_iter <= 1,
        initial_ratio is returned unchanged (not clamped).

    Raises:
        ValueError: If annealing_type is not recognised (only checked when max_iter > 1).
    """
    if annealing_type == "none" or max_iter <= 1:
        return initial_ratio

    # Normalised progress, clamped at BOTH ends so that iteration < 1 or
    # iteration > max_iter cannot push the ratio outside [initial, final].
    progress = min(max((iteration - 1) / (max_iter - 1), 0.0), 1.0)

    def _interpolate(fraction):
        # Linear interpolation between the initial and final ratio.
        return initial_ratio + (final_ratio - initial_ratio) * fraction

    if annealing_type == "linear":
        ratio = _interpolate(progress)

    elif annealing_type == "exponential":
        # Geometric interpolation: initial * (final / initial) ** progress.
        # Only defined for positive endpoints; otherwise fall back to linear.
        if final_ratio > 0 and initial_ratio > 0:
            ratio = initial_ratio * (final_ratio / initial_ratio) ** progress
        else:
            ratio = _interpolate(progress)

    elif annealing_type == "cosine":
        # Half-cosine S-curve: slow start and end, fastest change mid-way,
        # moving monotonically from initial_ratio to final_ratio.
        ratio = _interpolate(0.5 * (1 - np.cos(np.pi * progress)))

    elif annealing_type == "quality_adaptive":
        if quality_history is None or len(quality_history) < 3:
            # Not enough history yet: behave like linear annealing.
            ratio = _interpolate(progress)
        else:
            recent = quality_history[-3:]
            mean_qualities = [q[0] for q in recent]
            kept_qualities = [q[1] for q in recent]

            # With exactly three entries both lists are non-empty.
            quality_improving = mean_qualities[-1] > mean_qualities[0]
            high_quality = kept_qualities[-1] > 0.5

            if quality_improving and high_quality:
                # Good and improving: anneal slowly (less aggressive filtering).
                fraction = progress * 0.3
            elif quality_improving:
                # Improving but not yet high: moderate annealing.
                fraction = progress * 0.7
            elif high_quality:
                # High but stagnant: standard schedule.
                fraction = progress
            else:
                # Low and not improving: anneal faster, capped at the final ratio.
                fraction = min(progress * 1.5, 1.0)
            ratio = _interpolate(fraction)

    else:
        raise ValueError(f"Unknown annealing type: {annealing_type}")

    # Clamp into the configured bounds and return a plain Python float.
    return float(max(min_ratio, min(max_ratio, ratio)))

def _reference_distance_matrices(ref_emb1, ref_emb2, n_refs, distance_metric, device, use_gpu,
                                 multi_gpu_config):
    """Compute the (n_refs, n_refs) self-distance matrices of both reference sets.

    Strategy: multi-GPU ``get_dists`` when ``multi_gpu_config`` is enabled with GPU ids;
    row-chunked computation into float32 NumPy arrays when a full matrix is estimated to
    need more than 30% of available memory and n_refs > 1000; otherwise one ``get_dists``
    call per space (on ``device`` when ``use_gpu``). Outside the chunked path the return
    type is whatever ``get_dists`` produces.
    """
    multi_gpu_enabled = bool(
        multi_gpu_config and multi_gpu_config.get("enabled") and multi_gpu_config.get("gpu_ids")
    )
    if multi_gpu_enabled:
        dist1 = get_dists(ref_emb1, ref_emb1, metric=distance_metric, use_gpu=use_gpu,
                          device=device, multi_gpu_config=multi_gpu_config)
        dist2 = get_dists(ref_emb2, ref_emb2, metric=distance_metric, use_gpu=use_gpu,
                          device=device, multi_gpu_config=multi_gpu_config)
        return dist1, dist2

    from utils.memory_util import estimate_matrix_memory_gb, get_available_memory_gb
    estimated_memory = estimate_matrix_memory_gb(n_refs, n_refs)
    available_memory = get_available_memory_gb(use_gpu=use_gpu, device=device)
    # Switch to row chunks when a full matrix would take >30% of available memory.
    use_chunked = estimated_memory > (available_memory * 0.3)

    if use_chunked and n_refs > 1000:
        logger.debug(f"Using chunked distance computation for {n_refs} references (est. {estimated_memory:.2f} GB)")
        chunk_size = max(500, min(5000, n_refs // 5))
        dist1 = np.zeros((n_refs, n_refs), dtype=np.float32)
        dist2 = np.zeros((n_refs, n_refs), dtype=np.float32)

        if use_gpu:
            ref_t1 = torch.tensor(ref_emb1, device=device, dtype=torch.float32)
            ref_t2 = torch.tensor(ref_emb2, device=device, dtype=torch.float32)

        for start in range(0, n_refs, chunk_size):
            end = min(start + chunk_size, n_refs)
            if use_gpu:
                # Slice the tensors already on the device instead of re-uploading rows.
                chunk1 = get_dists(ref_t1[start:end], ref_t1, metric=distance_metric, use_gpu=True, device=device)
                chunk2 = get_dists(ref_t2[start:end], ref_t2, metric=distance_metric, use_gpu=True, device=device)
                dist1[start:end] = chunk1.cpu().numpy()
                dist2[start:end] = chunk2.cpu().numpy()
            else:
                dist1[start:end] = get_dists(ref_emb1[start:end], ref_emb1, metric=distance_metric, use_gpu=False)
                dist2[start:end] = get_dists(ref_emb2[start:end], ref_emb2, metric=distance_metric, use_gpu=False)
        return dist1, dist2

    if use_gpu:
        ref_t1 = torch.tensor(ref_emb1, device=device, dtype=torch.float32)
        ref_t2 = torch.tensor(ref_emb2, device=device, dtype=torch.float32)
        return (get_dists(ref_t1, ref_t1, metric=distance_metric, use_gpu=True, device=device),
                get_dists(ref_t2, ref_t2, metric=distance_metric, use_gpu=True, device=device))
    return (get_dists(ref_emb1, ref_emb1, metric=distance_metric, use_gpu=False),
            get_dists(ref_emb2, ref_emb2, metric=distance_metric, use_gpu=False))


def _row_correlations_without_diagonal(dist1, dist2):
    """Pearson correlation between row i of ``dist1`` and row i of ``dist2``, skipping entry (i, i).

    Both inputs are square (n, n) matrices. If ``dist1`` is a torch tensor the work is done
    on its device; otherwise both are treated as NumPy arrays (cast to float32). Both backends
    use the population (ddof=0) standard deviation so they agree. NaNs become 0.0.
    Returns a NumPy array of shape (n,).
    """
    n = dist1.shape[0]
    if isinstance(dist1, torch.Tensor):
        # A full 2-D boolean mask selects, row by row, every element except the
        # diagonal, so the flat result reshapes cleanly to (n, n - 1).
        off_diag = ~torch.eye(n, dtype=torch.bool, device=dist1.device)
        d1 = dist1[off_diag].reshape(n, n - 1)
        d2 = torch.as_tensor(dist2, device=dist1.device)[off_diag].reshape(n, n - 1)
        c1 = d1 - d1.mean(dim=1, keepdim=True)
        c2 = d2 - d2.mean(dim=1, keepdim=True)
        z1 = c1 / (c1.pow(2).mean(dim=1, keepdim=True).sqrt() + 1e-8)
        z2 = c2 / (c2.pow(2).mean(dim=1, keepdim=True).sqrt() + 1e-8)
        correlations = torch.nan_to_num((z1 * z2).mean(dim=1), nan=0.0)
        return correlations.cpu().numpy()

    off_diag = ~np.eye(n, dtype=bool)
    d1 = np.asarray(dist1, dtype=np.float32)[off_diag].reshape(n, n - 1)
    d2 = np.asarray(dist2, dtype=np.float32)[off_diag].reshape(n, n - 1)
    z1 = (d1 - d1.mean(axis=1, keepdims=True)) / (d1.std(axis=1, keepdims=True) + 1e-8)
    z2 = (d2 - d2.mean(axis=1, keepdims=True)) / (d2.std(axis=1, keepdims=True) + 1e-8)
    return np.nan_to_num((z1 * z2).mean(axis=1), nan=0.0)


def _mutual_pair_contributions(ref_indices1, ref_indices2, previous_mutual_pairs,
                               ind_emb1_unique, ind_emb2_unique):
    """Per-reference share of the previous mutual pairs, summed over both embedding spaces.

    Each element of ``previous_mutual_pairs`` is unpacked as (i, j, _); ``ind_emb1_unique[j]``
    is counted against ``ref_indices1`` and ``ind_emb2_unique[i]`` against ``ref_indices2``.
    The combined count is divided by the number of pairs. Returns float32 zeros when any
    tracking input is None or there are no pairs.
    """
    contributions = np.zeros(len(ref_indices1), dtype=np.float32)
    if previous_mutual_pairs is None or ind_emb1_unique is None or ind_emb2_unique is None:
        return contributions
    n_mutual_pairs = len(previous_mutual_pairs)
    if n_mutual_pairs == 0:
        return contributions

    # One pass over the pairs builds count tables; lookups below are O(1) each.
    ref1_counts = {}
    ref2_counts = {}
    for i_mutual, j_mutual, _ in previous_mutual_pairs:
        idx1 = ind_emb1_unique[j_mutual]
        idx2 = ind_emb2_unique[i_mutual]
        ref1_counts[idx1] = ref1_counts.get(idx1, 0) + 1
        ref2_counts[idx2] = ref2_counts.get(idx2, 0) + 1

    for pos, (r1, r2) in enumerate(zip(ref_indices1, ref_indices2)):
        contributions[pos] = (ref1_counts.get(r1, 0) + ref2_counts.get(r2, 0)) / n_mutual_pairs
    return contributions


def filter_references_by_pairwise_distance_quality(ref_indices1, ref_indices2, emb1, emb2,
                                                   distance_metric="cosine", top_k_ratio=0.7, device=None, return_metrics=False,
                                                   previous_mutual_pairs=None, ind_emb1_unique=None, ind_emb2_unique=None,
                                                   use_multi_gpu=False, gpu_ids=None, multi_gpu_config=None,
                                                   cached_dist_matrices=None):
    """
    Filter reference pairs by distance-correlation quality and mutual-pair contribution.

    Only reference-to-reference distances are used, so no non-reference points leak in.
    Each reference i gets a score equal to:
      (a) the Pearson correlation between its distances to the other references in
          emb1 and in emb2 (self-distance excluded), plus
      (b) its share of the previous iteration's mutual pairs (0 if not tracked).
    The top ``top_k_ratio`` fraction of references (at least 2) by score is kept.

    Args:
        ref_indices1, ref_indices2: Current reference indices (parallel sequences).
        emb1, emb2: Full embedding matrices.
        distance_metric: Metric name passed to ``get_dists``.
        top_k_ratio: Fraction of references to keep (0.7 keeps the best 70%).
        device: torch device; CUDA devices enable the GPU path.
        return_metrics: If True, also return quality metrics.
        previous_mutual_pairs: Previous iteration's mutual pairs, as (i, j, _) tuples.
        ind_emb1_unique, ind_emb2_unique: Local -> global index maps for those pairs.
        use_multi_gpu, gpu_ids: Build a multi-GPU config when ``multi_gpu_config`` is None.
        multi_gpu_config: Optional dict configuring multi-GPU chunking for ``get_dists``.
        cached_dist_matrices: The ``new_cache`` value returned by a previous call, or None.

    Returns:
        (filtered_ref_indices1, filtered_ref_indices2, new_cache), or, when
        return_metrics=True,
        (filtered_ref_indices1, filtered_ref_indices2, metrics, new_cache).
        Here metrics = (mean_quality, kept_quality, min_quality, max_quality), and
        new_cache = (reference_hash, dist_matrix1, dist_matrix2) can be passed back as
        ``cached_dist_matrices``.
        With fewer than 5 references, the indices come back unchanged, metrics are all
        0.0, and new_cache is None. The tuple has the same arity as the main path.
    """
    ref_indices1 = np.asarray(ref_indices1, dtype=np.int32)
    ref_indices2 = np.asarray(ref_indices2, dtype=np.int32)
    n_refs = len(ref_indices1)

    if n_refs < 5:
        # Too few references for a meaningful correlation. Keep the same tuple arity
        # as the main path (with an empty cache) so callers can always unpack it.
        if return_metrics:
            return ref_indices1, ref_indices2, (0.0, 0.0, 0.0, 0.0), None
        return ref_indices1, ref_indices2, None

    use_gpu = device is not None and device.type == 'cuda'

    # Reuse the previous call's matrices when the reference set is unchanged.
    current_hash = hash((tuple(ref_indices1.tolist()), tuple(ref_indices2.tolist())))
    dist_matrix1 = dist_matrix2 = None
    cache_hit = False
    if cached_dist_matrices is not None:
        prev_hash, cached_dist1, cached_dist2 = cached_dist_matrices
        if current_hash == prev_hash:
            logger.debug("Reusing cached reference distance matrices (no reference changes detected)")
            dist_matrix1, dist_matrix2 = cached_dist1, cached_dist2
            cache_hit = True

    if not cache_hit:
        if multi_gpu_config is None and use_multi_gpu and gpu_ids:
            multi_gpu_config = {
                "enabled": True,
                "gpu_ids": gpu_ids
            }
        # Gather reference rows only; non-reference points are never touched.
        ref_emb1 = emb1[ref_indices1]
        ref_emb2 = emb2[ref_indices2]
        dist_matrix1, dist_matrix2 = _reference_distance_matrices(
            ref_emb1, ref_emb2, n_refs, distance_metric, device, use_gpu, multi_gpu_config
        )

    # Criterion 1: distance-profile correlation. GPU tensors stay on the device;
    # otherwise move to host. The NumPy copies are what get cached.
    if not (use_gpu and isinstance(dist_matrix1, torch.Tensor)):
        if isinstance(dist_matrix1, torch.Tensor):
            dist_matrix1 = dist_matrix1.cpu().numpy()
        if isinstance(dist_matrix2, torch.Tensor):
            dist_matrix2 = dist_matrix2.cpu().numpy()
    correlations_np = _row_correlations_without_diagonal(dist_matrix1, dist_matrix2)

    # Criterion 2: contribution to the previous iteration's mutual pairs.
    mutual_contributions = _mutual_pair_contributions(
        ref_indices1, ref_indices2, previous_mutual_pairs, ind_emb1_unique, ind_emb2_unique
    )

    reference_scores = correlations_np + mutual_contributions

    # Keep the top-k% by combined score (at least 2 references).
    n_to_keep = max(2, int(top_k_ratio * n_refs))
    top_indices = np.argsort(reference_scores)[-n_to_keep:]

    filtered_indices1 = ref_indices1[top_indices]
    filtered_indices2 = ref_indices2[top_indices]

    mean_quality = float(np.mean(reference_scores))
    kept_quality = float(np.mean(reference_scores[top_indices]))
    min_quality = float(np.min(reference_scores))
    max_quality = float(np.max(reference_scores))

    n_filtered = n_refs - len(filtered_indices1)
    if n_filtered > 0:
        logger.debug(f"Filtered out {n_filtered} references using correlation + mutual NN contribution")
        logger.debug(f"Score range: [{min_quality:.4f}, {max_quality:.4f}]")
        logger.debug(f"Mean score: {mean_quality:.4f}, Kept score: {kept_quality:.4f}")

    new_cache = (current_hash, dist_matrix1, dist_matrix2)

    if return_metrics:
        return filtered_indices1, filtered_indices2, (mean_quality, kept_quality, min_quality, max_quality), new_cache
    return filtered_indices1, filtered_indices2, new_cache

# Ensemble functions moved to utils/ensemble_selection.py


def save_iteration_pair_stats(
    iteration, ref_indices1, ref_indices2, ind_nonref,
    emb1, emb2, ref_emb1, ref_emb2,
    posterior_stats, pair_voting_refs, distance_metric, output_dir, args,
    ind_emb1_unique, ind_emb2_unique
):
    """
    Write per-pair diagnostics for one iteration to ``<output_dir>/[<dataset>_]iter_XXX_stats.csv``.

    One row per pair (idx1, idx2):
    - is_correct: whether idx1 == idx2 (ground-truth match)
    - posterior_mean / posterior_std / n_trials: taken from ``posterior_stats`` when the
      pair is present there, otherwise written as ''
    - n_voting_refs: number of references recorded as voting for the pair
    - avg/min distance from idx1 (emb1) and idx2 (emb2) to the voting references;
      if no voting references are recorded, all references are used instead
    An iteration with no pairs writes a header-only CSV.

    Args:
        iteration: Iteration number, stored in every row and used in the file name.
        ref_indices1, ref_indices2: GLOBAL indices of the pairs (parallel sequences).
        ind_nonref: Unused; kept for interface compatibility.
        emb1, emb2: Full embedding matrices indexed by global index.
        ref_emb1, ref_emb2: Reference embeddings; distances are computed against these.
        posterior_stats, pair_voting_refs: Dicts keyed by LOCAL (local_idx2, local_idx1)
            tuples, i.e. the (i, nearest_i) format from find_mutual_pairs; either may be None.
            Voting-ref values are presumably positions into ref_emb1/ref_emb2 -- confirm.
        distance_metric: Metric name passed to ``get_dists``.
        output_dir: Directory for the CSV (created if missing).
        args: Namespace; ``args.dataset``, when present, prefixes the file name.
        ind_emb1_unique, ind_emb2_unique: Arrays mapping local index -> global index.
    """
    import pandas as pd

    os.makedirs(output_dir, exist_ok=True)

    columns = [
        'iteration', 'idx1', 'idx2', 'is_correct',
        'posterior_mean', 'posterior_std', 'n_trials', 'n_voting_refs',
        'avg_dist1_to_voting_refs', 'min_dist1_to_voting_refs',
        'avg_dist2_to_voting_refs', 'min_dist2_to_voting_refs',
    ]

    # Fancy indexing needs an integer dtype; np.array([]) on its own would be float64.
    pair_indices1 = np.asarray(ref_indices1, dtype=np.int64)
    pair_indices2 = np.asarray(ref_indices2, dtype=np.int64)

    # posterior_stats / pair_voting_refs are keyed by local indices: build global -> local maps.
    global_to_local_idx1 = {int(global_idx): local_idx for local_idx, global_idx in enumerate(ind_emb1_unique)}
    global_to_local_idx2 = {int(global_idx): local_idx for local_idx, global_idx in enumerate(ind_emb2_unique)}

    rows = []
    if len(pair_indices1) > 0:
        # One batched call per space; per the original layout each is (n_pairs, n_refs).
        all_dists1 = get_dists(emb1[pair_indices1], ref_emb1, metric=distance_metric, device='cpu')
        all_dists2 = get_dists(emb2[pair_indices2], ref_emb2, metric=distance_metric, device='cpu')
        if isinstance(all_dists1, torch.Tensor):
            all_dists1 = all_dists1.cpu().numpy()
        if isinstance(all_dists2, torch.Tensor):
            all_dists2 = all_dists2.cpu().numpy()

        for row_pos, (idx1, idx2) in enumerate(zip(pair_indices1.tolist(), pair_indices2.tolist())):
            local_idx1 = global_to_local_idx1.get(idx1)
            local_idx2 = global_to_local_idx2.get(idx2)
            # Both dicts use the (i, nearest_i) = (local_idx2, local_idx1) key format.
            pair_key = None
            if local_idx1 is not None and local_idx2 is not None:
                pair_key = (local_idx2, local_idx1)

            posterior_mean = posterior_std = n_trials = None
            if pair_key is not None and posterior_stats is not None and pair_key in posterior_stats:
                stats = posterior_stats[pair_key]
                posterior_mean = stats.get('posterior_mean', None)
                posterior_std = stats.get('posterior_std', None)
                n_trials = stats.get('n_trials', None)

            voting_ref_idx = None
            n_voting_refs = 0
            if pair_key is not None and pair_voting_refs is not None and pair_key in pair_voting_refs:
                voting_ref_idx = pair_voting_refs[pair_key]
                n_voting_refs = len(voting_ref_idx)

            dists1_to_refs = all_dists1[row_pos]
            dists2_to_refs = all_dists2[row_pos]
            if n_voting_refs > 0:
                # A tuple would be read by NumPy as a multi-dimensional index; use a list.
                selector = voting_ref_idx if isinstance(voting_ref_idx, np.ndarray) else list(voting_ref_idx)
                dists1_to_refs = dists1_to_refs[selector]
                dists2_to_refs = dists2_to_refs[selector]

            rows.append({
                'iteration': iteration,
                'idx1': idx1,
                'idx2': idx2,
                'is_correct': idx1 == idx2,
                'posterior_mean': posterior_mean if posterior_mean is not None else '',
                'posterior_std': posterior_std if posterior_std is not None else '',
                'n_trials': n_trials if n_trials is not None else '',
                'n_voting_refs': n_voting_refs,
                'avg_dist1_to_voting_refs': float(np.mean(dists1_to_refs)),
                'min_dist1_to_voting_refs': float(np.min(dists1_to_refs)),
                'avg_dist2_to_voting_refs': float(np.mean(dists2_to_refs)),
                'min_dist2_to_voting_refs': float(np.min(dists2_to_refs)),
            })

    # Explicit columns keep the header stable even when there are no rows.
    df = pd.DataFrame(rows, columns=columns)

    filename = f"iter_{iteration:03d}_stats.csv"
    if hasattr(args, 'dataset'):
        filename = f"{args.dataset}_{filename}"

    output_path = os.path.join(output_dir, filename)
    df.to_csv(output_path, index=False)
    logger.debug(f"Saved iteration {iteration} pair statistics to {output_path} ({len(rows)} pairs)")


def test_clu(args, seed=None):
    # Set random seed if provided
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    optimize_for_large_datasets(args)

    # Always use RBF distance transformation (sigma auto-computed downstream)
    args.transformation = "rbf"
    args.transformation_params = None

    # Multi-GPU setup
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    use_multi_gpu = n_gpus > 1 and args.use_gpu

    if use_multi_gpu:
        logger.debug(f"Multi-GPU mode enabled: {n_gpus} GPUs detected")
        device = torch.device("cuda:0")  # Primary device
        gpu_ids = list(range(n_gpus))
        logger.debug(f"Using GPUs: {gpu_ids}")
        for i in range(n_gpus):
            logger.debug(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        device = torch.device("cuda" if (torch.cuda.is_available() and args.use_gpu) else "cpu")
        gpu_ids = None

    logger.debug(f"Using device: {device}")
    logger.debug(f"args.use_gpu: {args.use_gpu}")
    logger.debug(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.debug(f"CUDA device count: {n_gpus}")
        if not use_multi_gpu and n_gpus > 0:
            logger.debug(f"CUDA device name: {torch.cuda.get_device_name()}")
    if args.use_gpu and not torch.cuda.is_available():
        logger.warning("GPU requested but CUDA not available, using CPU")

    # Store multi-GPU config in args for downstream functions
    args.use_multi_gpu = use_multi_gpu
    args.gpu_ids = gpu_ids
    args.n_gpus = n_gpus
    multi_gpu_chunk_size = getattr(args, "multi_gpu_chunk_size", None)
    if use_multi_gpu:
        args.multi_gpu_config = {
            "enabled": True,
            "gpu_ids": gpu_ids,
            "chunk_size": multi_gpu_chunk_size
        }
    else:
        args.multi_gpu_config = None

    # Apply memory-efficient mode if enabled
    if getattr(args, 'memory_efficient', False):
        logger.debug("Memory-efficient mode enabled: adjusting parameters for large datasets")

    base_dir = args.base_dir

    # Load embeddings
    if args.dataset in ["scifact", "scidocs", "fiqa", "nfcorpus", "arguana"]:
        emb1 = load_npy(base_dir, f"corpus_embeddings_{args.emb1}_{args.dataset}.npy")
        emb2 = load_npy(base_dir, f"corpus_embeddings_{args.emb2}_{args.dataset}.npy")
    elif args.dataset in ["cifar10"]:
        emb1 = load_npy(base_dir, f"emb_{args.dataset}/{args.dataset}_embeddings_{args.emb1}_{args.emb_dim1}.npy")
        emb2 = load_npy(base_dir, f"emb_{args.dataset}/{args.dataset}_embeddings_{args.emb2}_{args.emb_dim2}.npy")
    elif args.dataset in ["Citeseer", "Cora", "PubMed"]:
        emb1 = load_npy(base_dir, f"emb_{args.dataset}/{args.emb1}.npy")
        emb2 = load_npy(base_dir, f"emb_{args.dataset}/{args.emb2}.npy")
    elif args.dataset in ["biorxiv", "alloprof", "big_patent", "arxivp2p", "plsc", "wikicities", "stack_exchange"]:
        emb1 = load_npy(base_dir, f"{args.emb1}_{args.dataset}/{args.emb1}_{args.dataset}_embeddings.npy")
        emb2 = load_npy(base_dir, f"{args.emb2}_{args.dataset}/{args.emb2}_{args.dataset}_embeddings.npy")
    elif args.dataset in ["coco"]:
        emb1 = load_npy(base_dir, f"{args.dataset}_image_embeddings_{args.emb1}.npy")
        emb2 = load_npy(base_dir, f"{args.dataset}_text_embeddings_{args.emb1}.npy")
    elif args.dataset in ["StackExchangeClustering", "StackExchangeClustering.v2", "TwentyNewsgroups", "TwentyNewsgroupsClustering", "TwentyNewsgroupsClustering.v2", "RedditClustering.v2"]:
        emb1 = load_npy(base_dir, f"texts_embeddings_{args.emb1}_{args.dataset}.npy")
        emb2 = load_npy(base_dir, f"texts_embeddings_{args.emb2}_{args.dataset}.npy")
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    
    # Load indices
    # Add seed suffix to cache name if seed is provided
    seed_suffix = f"_seed{seed}" if seed is not None else ""

    if args.partition == "cluster_partial":
        ind_file_name = os.path.join(args.cache_dir, f"{args.dataset}_{args.emb1}_{args.partition}{args.n_clusters}_{args.nonref_clu_choices}{seed_suffix}")
    elif args.partition == "random":
        ind_file_name = os.path.join(args.cache_dir, f"{args.dataset}_{args.partition}_{args.overlap_ratio}{seed_suffix}")
    elif args.partition == "la2m":
        ind_file_name = os.path.join(args.cache_dir, f"{args.dataset}_la2m_{args.overlap_ratio}{seed_suffix}")

    if os.path.exists(os.path.join(ind_file_name, "ind1.npy")):
        ind_emb1_unique = load_npy(ind_file_name, "ind1")
        ind_emb2_unique = load_npy(ind_file_name, "ind2")
        ind_nonref = np.intersect1d(ind_emb1_unique, ind_emb2_unique)
        data_partitioner = DataPartitioner.from_indices(ind_emb1_unique, ind_emb2_unique, ind_nonref)

        # Compute cluster labels for Procrustes if enabled
        if args.use_procrustes:
            clusterer = Clusterer(method=args.cluster_method, n_clusters=args.n_clusters, use_gpu=args.use_gpu)
            emb1_cluster_labels = clusterer.fit(emb1)
            emb1_cluster_labels = LabelEncoder().fit_transform(emb1_cluster_labels)
        else:
            emb1_cluster_labels = None
    else:
        if args.partition == "la2m":
            # LA2M partition requires BEIR qrels
            if not BEIR_AVAILABLE:
                raise ImportError("BEIR is required for la2m partition. Install with: pip install beir")
            logger.debug(f"Creating LA2M partition for dataset {args.dataset}")
            qrels, dataset_index, dataset_obj = load_beir_qrels(args.dataset, data_path=args.base_dir)
            labels = np.zeros(len(emb1), dtype=np.int32)
            data_partitioner = DataPartitioner(
                labels=labels,
                total_ind=np.arange(len(emb1)),
                partition_type="la2m",
                overlap_ratio=args.overlap_ratio,
                qrels=qrels,
                dataset_index=dataset_index,
                dataset_obj=dataset_obj,
                select_top_1=True,
                remove_dup_answer=True
            )
            emb1_cluster_labels = None
        else:
            clusterer = Clusterer(method=args.cluster_method, n_clusters=args.n_clusters, use_gpu=args.use_gpu)
            labels = clusterer.fit(emb1)
            labels = LabelEncoder().fit_transform(labels)
            data_partitioner = DataPartitioner(labels, partition_type=args.partition, nonref_clu_choices=args.nonref_clu_choices, overlap_ratio=args.overlap_ratio)
            # Save cluster labels for later use in Procrustes refinement
            emb1_cluster_labels = labels

        ind_emb1_unique = data_partitioner.ind_emb1_unique
        ind_emb2_unique = data_partitioner.ind_emb2_unique
        ind_nonref = data_partitioner.ind_emb1_nonref

        # Always save partition indices to cache (including seed-specific ones)
        save_npy(ind_file_name, "ind1", ind_emb1_unique)
        save_npy(ind_file_name, "ind2", ind_emb2_unique)
        save_npy(ind_file_name, "ind_nonref", ind_nonref)
    
    ref_indices1 = np.array([], dtype=np.int32)
    ref_indices2 = np.array([], dtype=np.int32) 
    
    if args.anchor_mode.startswith("ood"):
        logger.debug("Using OOD anchor generation")
        ref_dataset = args.ref_dataset
        ref_emb1 = load_npy(base_dir, f"corpus_embeddings_{args.emb1}_{ref_dataset}.npy")
        ref_emb2 = load_npy(base_dir, f"corpus_embeddings_{args.emb2}_{ref_dataset}.npy")

        # Calculate overlap size (number of points to be matched)
        overlap_size = len(ind_nonref)

        if args.partition == "cluster_partial":
            ref_ind_file_name = os.path.join(args.cache_dir, f"{args.ref_dataset}_{args.emb1}_{args.partition}{args.n_clusters}_{args.nonref_clu_choices}{seed_suffix}")
        elif args.partition == "random":
            ref_ind_file_name = os.path.join(args.cache_dir, f"{args.ref_dataset}_{args.partition}_{args.overlap_ratio}{seed_suffix}")
        elif args.partition == "la2m":
            ref_ind_file_name = os.path.join(args.cache_dir, f"{args.ref_dataset}_la2m_{args.overlap_ratio}{seed_suffix}")

        # Determine cache key based on whether n_seeds is provided
        if args.n_seeds is not None:
            ref_cache_key = f"ref_ind_n{args.n_seeds}"
        else:
            ref_cache_key = f"ref_ind_{args.ref_ratio}"

        if os.path.exists(os.path.join(ref_ind_file_name, f"{ref_cache_key}.npy")):
            ref_ind = load_npy(ref_ind_file_name, ref_cache_key)
        else:
            if args.partition == "cluster_partial":
                total_candidates = len(ref_emb1)
                if total_candidates == 0:
                    ref_ind = np.array([], dtype=np.int32)
                else:
                    ref_emb1_cluster = Cluster(
                        ref_emb1,
                        np.arange(total_candidates),
                        args.n_clusters_overlap,
                        args.cluster_method,
                        graph_method=args.graph_method,
                        knn_k=args.knn_k,
                        sample=args.sample,
                        use_gpu=getattr(args, "use_gpu", True)
                    )
                    label_list = ref_emb1_cluster.label_list
                    if isinstance(label_list, torch.Tensor):
                        label_list = label_list.cpu().numpy()
                    else:
                        label_list = np.asarray(label_list)

                    cluster_choices = getattr(args, "nonref_clu_choices", None)
                    choices_array = None
                    if cluster_choices is not None:
                        if isinstance(cluster_choices, str):
                            stripped = cluster_choices.strip()
                            if stripped.endswith("]"):
                                stripped = stripped[1:-1]
                            if stripped:
                                try:
                                    choices_array = np.array([int(item.strip()) for item in stripped.split(',') if item.strip()], dtype=np.int32)
                                except ValueError:
                                    choices_array = None
                        else:
                            try:
                                choices_array = np.array(cluster_choices, dtype=np.int32)
                            except (TypeError, ValueError):
                                choices_array = None

                    if choices_array is not None and choices_array.size > 0:
                        available_labels = np.intersect1d(label_list, choices_array)
                        if available_labels.size == 0:
                            available_labels = label_list
                    else:
                        available_labels = label_list

                    available_labels = np.asarray(available_labels, dtype=np.int32)
                    if available_labels.size == 0:
                        ref_ind = np.array([], dtype=np.int32)
                    else:
                        # Determine target count: use n_seeds if provided, otherwise ref_ratio
                        if args.n_seeds is not None:
                            target_total = min(args.n_seeds, total_candidates)
                        else:
                            target_total = max(1, int(round(args.ref_ratio * overlap_size)))
                        selected_mask = np.zeros(total_candidates, dtype=bool)
                        selected_indices = []
                        per_cluster = max(1, int(np.ceil(target_total / len(available_labels))))

                        for label in available_labels:
                            cluster_indices = ref_emb1_cluster.get_ori_ind(int(label))
                            if len(cluster_indices) == 0:
                                continue
                            remaining_slots = target_total - len(selected_indices)
                            if remaining_slots <= 0:
                                break
                            n_select = min(len(cluster_indices), per_cluster, remaining_slots)
                            if n_select <= 0:
                                continue
                            if len(cluster_indices) <= n_select:
                                chosen = cluster_indices
                            else:
                                chosen = np.random.choice(cluster_indices, size=n_select, replace=False)
                            selected_indices.extend(chosen.tolist())
                            selected_mask[chosen] = True

                        if len(selected_indices) < target_total:
                            remaining_slots = target_total - len(selected_indices)
                            if remaining_slots > 0:
                                remaining_indices = np.where(~selected_mask)[0]
                                if len(remaining_indices) > 0:
                                    if len(remaining_indices) <= remaining_slots:
                                        extra = remaining_indices
                                    else:
                                        extra = np.random.choice(remaining_indices, size=remaining_slots, replace=False)
                                    selected_indices.extend(extra.tolist())

                        ref_ind = np.array(selected_indices[:target_total], dtype=np.int32)
            elif args.partition == "random" or args.partition == "la2m":
                # Determine target count: use n_seeds if provided, otherwise ref_ratio
                if args.n_seeds is not None:
                    target_total = args.n_seeds
                else:
                    target_total = max(1, int(round(args.ref_ratio * overlap_size)))
                # Ensure we don't try to select more than available
                target_total = min(target_total, len(ref_emb1))
                ref_ind = np.random.choice(np.arange(len(ref_emb1)), size=target_total, replace=False)
            # Always save ref_ind to cache (including seed-specific ones)
            save_npy(ref_ind_file_name, ref_cache_key, ref_ind)
        
        ori_ref_emb1 = ref_emb1[ref_ind]
        ori_ref_emb2 = ref_emb2[ref_ind]
        union_dataset_size = len(np.union1d(ind_emb1_unique, ind_emb2_unique))
        actual_ref_ratio = len(ref_ind) / union_dataset_size if union_dataset_size > 0 else 0


    else:  # supervised mode (original method)
        logger.debug("Using supervised anchor initialization")

        # Determine cache key based on whether n_seeds is provided
        if args.n_seeds is not None:
            sup_cache_key = f"ref_ind_n{args.n_seeds}"
        else:
            sup_cache_key = f"ref_ind_{args.ref_ratio}"

        # Only load/save from cache for random sampling
        if args.ref_method == "random":
            # Try to load from cache first (seed-specific if seed is set)
            ref_ind = load_npy(ind_file_name, sup_cache_key)
            if ref_ind is None:
                # Generate new ref_ind
                ref_ind = sample_ref_points(
                    method=args.ref_method,
                    embeddings=emb1,
                    candidate_indices=data_partitioner.ind_emb1_nonref,
                    ref_ratio=args.ref_ratio,
                    n_samples=args.n_seeds,
                    use_gpu=args.use_gpu if hasattr(args, 'use_gpu') else False
                )
                # Save to cache for random method
                save_npy(ind_file_name, sup_cache_key, ref_ind)
        else:
            # For non-random methods, always generate fresh (no caching)
            ref_ind = sample_ref_points(
                method=args.ref_method,
                embeddings=emb1,
                candidate_indices=data_partitioner.ind_emb1_nonref,
                ref_ratio=args.ref_ratio,
                n_samples=args.n_seeds,
                use_gpu=args.use_gpu if hasattr(args, 'use_gpu') else False
            )

        # For supervised mode, we use the same anchors for both embeddings
        ori_ref_emb1 = emb1[ref_ind]
        ori_ref_emb2 = emb2[ref_ind]
        
        # Calculate actual ref_ratio for supervised mode
        union_dataset_size = len(np.union1d(ind_emb1_unique, ind_emb2_unique))
        actual_ref_ratio = len(ref_ind) / union_dataset_size if union_dataset_size > 0 else 0

        # Remove supervised references from the partitioned data to avoid data leakage
        ind_emb1_unique = np.setdiff1d(ind_emb1_unique, ref_ind)
        ind_emb2_unique = np.setdiff1d(ind_emb2_unique, ref_ind)
        ind_nonref = np.setdiff1d(ind_nonref, ref_ind)

    # Extract embeddings after all anchor mode processing to ensure supervised refs are excluded
    emb1_unique = emb1[ind_emb1_unique]
    emb2_unique = emb2[ind_emb2_unique]

    logger.debug(f"overlap_ratio: {args.overlap_ratio}")

    initial_ref_size = None

    # Initialize tracking for quality-based annealing
    quality_history = []

    # Initialize convergence tracking
    mutual_nn_history = []
    convergence_threshold = 0.01  # Stop if mutual_nn_ratio change < 1%
    min_convergence_iters = 5 # Need at least 5 stable iterations
    gt_concat_min_ratio = 0.1  # Disable GT concat in ensemble subsets once mutual_nn_ratio drops below this

    # Initialize Bernoulli trials history if enabled
    pair_history = None
    posterior_stats = None
    pair_voting_refs = None  # Track which references voted for each pair

    # Track total ensembles run across all iterations (for adaptive posterior threshold)
    total_ensembles_run = 0
    concat_seed_pairs_enabled = args.concat_seed_pairs

    # Helper to compute automatic ensemble counts when not user-specified
    def compute_auto_n_ensembles(ref_size, subset_size, cap=50):
        """Heuristic ensemble count of roughly two passes over the reference set.

        Args:
            ref_size: number of reference points.
            subset_size: size of each ensemble subset; values below 1 are treated as 1.
            cap: upper bound on the returned count.

        Returns:
            1 when ``ref_size`` is non-positive; otherwise
            ``2 * (ref_size // subset_size)`` floored at 1 and capped at ``cap``.
        """
        effective_subset = subset_size if subset_size > 1 else 1
        if not ref_size > 0:
            return 1
        n_ensembles = int(ref_size // effective_subset * 2)
        return min(cap, max(1, n_ensembles))

    # Initialize distance cache for reference filtering
    ref_dist_cache = None

    # Iterative refinement loop with anchor updating
    iteration = 0
    while True:
        if iteration == 0:
            ref_emb1 = ori_ref_emb1
            ref_emb2 = ori_ref_emb2
        else:
            ref_emb1 = np.concatenate((ori_ref_emb1, emb1[ref_indices1]))
            ref_emb2 = np.concatenate((ori_ref_emb2, emb2[ref_indices2]))

        iteration += 1

        # Compute ensemble parameters with scaling for both strategies
        current_ref_size = len(ref_emb1)

        # Track initial reference size
        if initial_ref_size is None:
            initial_ref_size = current_ref_size

        # Compute growth ratio and scale factor for both strategies
        growth_ratio = current_ref_size / initial_ref_size if initial_ref_size > 0 else 1.0
        scale_factor = 1 + np.sqrt(0.4 * np.log1p(growth_ratio - 1))

        principled_params = compute_principled_fixed_strategy_params(
            growth_ratio=growth_ratio,
            initial_subset_ratio=args.ensemble_subset_ratio,
            initial_n_ensembles=args.ensemble_n_ensembles,
            initial_ref_size=initial_ref_size if initial_ref_size else 100,
            d_intrinsic=30.0,  # Default assumption for fixed strategy
            gamma=2.0,  # F2 score (recall-weighted)
            confidence=0.95
        )

        adaptive_subset_ratio = principled_params['subset_ratio']
        ensemble_subset_ratio = adaptive_subset_ratio
        adaptive_n_ensembles = principled_params['n_ensembles']

        logger.debug(f"Principled fixed strategy: subset_ratio {args.ensemble_subset_ratio:.3f} -> {adaptive_subset_ratio:.3f}, "
                   f"n_ensembles {args.ensemble_n_ensembles} -> {adaptive_n_ensembles}")
        logger.debug(f"  Derived params: c={principled_params['coefficient_c']:.3f}, β_r={principled_params['beta_r']:.2f}, "
                   f"s={principled_params['scale_factor']:.2f}")
        computed_n_ensembles = adaptive_n_ensembles

        # Set final parameters
        ensemble_n_ensembles = computed_n_ensembles
        logger.debug(f"Iteration {iteration}: ref_size={current_ref_size}, subset_ratio={adaptive_subset_ratio}, n_ensembles={ensemble_n_ensembles}")

        # Accumulate total ensembles run
        total_ensembles_run += ensemble_n_ensembles

        args.concat_seed_pairs = concat_seed_pairs_enabled

        # OOM retry logic: if we get OOM, reduce max_parallel_workers and retry this iteration
        oom_retry_count = 0
        max_oom_retries = 3
        original_max_parallel_workers = args.max_parallel_workers

        while oom_retry_count <= max_oom_retries:
            try:
                if args.use_all_points_mnn:
                    # Use all-points MNN baseline (no ensembles, voting, or Bernoulli trials)
                    mutual_pairs = ensemble_reference_selection_all_points_mnn(
                        ref_emb1, ref_emb2, emb1_unique, emb2_unique, ind_emb1_unique, ind_emb2_unique,
                        args, device,
                        ref_indices1=ref_indices1, ref_indices2=ref_indices2,
                        ori_ref_emb1=ori_ref_emb1, ori_ref_emb2=ori_ref_emb2
                    )
                elif args.use_bernoulli_trials:
                    # Use Bernoulli trial-based ensemble selection
                    mutual_pairs, pair_history, posterior_stats, pair_voting_refs = ensemble_reference_selection_bernoulli(
                        ref_emb1, ref_emb2, emb1_unique, emb2_unique, ind_emb1_unique, ind_emb2_unique,
                        args, device,
                        n_ensembles=ensemble_n_ensembles, subset_ratio=ensemble_subset_ratio,
                        ref_indices1=ref_indices1, ref_indices2=ref_indices2,
                        ori_ref_emb1=ori_ref_emb1, ori_ref_emb2=ori_ref_emb2,
                        pair_history=pair_history, posterior_threshold=args.posterior_threshold,
                        ensemble_strategy=args.ensemble_strategy,  # Use strategy from command line
                        use_distance_weighting=args.use_distance_weighting,
                        distance_filter_percentile=args.distance_filter_percentile,
                        posterior_strategy="iteration_based",
                        current_iteration=iteration,  # Pass current iteration
                        max_iterations=args.max_iter,  # Pass max iterations for normalization
                        total_ensembles_run=total_ensembles_run,  # Pass actual total ensembles run
                        overlap_inference_method=args.overlap_inference_method  # Method to infer overlapping pairs
                    )
                else:
                    # Use voting-based ensemble selection with configurable strategy
                    mutual_pairs = ensemble_reference_selection_voting(
                        ref_emb1, ref_emb2, emb1_unique, emb2_unique, ind_emb1_unique, ind_emb2_unique,
                        args, device, ind_nonref, vote_threshold=args.ensemble_vote_threshold,
                        n_ensembles=ensemble_n_ensembles, subset_ratio=ensemble_subset_ratio,
                        ref_indices1=ref_indices1, ref_indices2=ref_indices2,
                        ori_ref_emb1=ori_ref_emb1, ori_ref_emb2=ori_ref_emb2,
                        ensemble_strategy=args.ensemble_strategy  # Use strategy from command line
                    )
                # Success! Break out of retry loop
                break

            except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
                # Check if it's actually an OOM error
                if isinstance(e, RuntimeError) and "out of memory" not in str(e).lower():
                    raise  # Re-raise if not OOM

                oom_retry_count += 1

                # Clear CUDA cache and run garbage collection
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()

                if oom_retry_count > max_oom_retries:
                    logger.error(f"Iteration {iteration}: OOM retry limit ({max_oom_retries}) reached. Cannot continue.")
                    raise

                # Reduce max_parallel_workers for retry
                if args.max_parallel_workers is None:
                    # If it was None (auto), start with 4 workers
                    args.max_parallel_workers = 4
                else:
                    # Reduce by half, minimum 1
                    args.max_parallel_workers = max(1, args.max_parallel_workers // 2)

                logger.warning(f"Iteration {iteration}: OOM detected (retry {oom_retry_count}/{max_oom_retries}). "
                             f"Reducing max_parallel_workers to {args.max_parallel_workers} and retrying...")

        # Restore original max_parallel_workers for next iteration
        args.max_parallel_workers = original_max_parallel_workers
        
        # Update reference indices and embeddings
        new_ref_indices1 = np.array([ind_emb1_unique[mutual_pairs[i][1]] for i in range(len(mutual_pairs))])
        new_ref_indices2 = np.array([ind_emb2_unique[mutual_pairs[i][0]] for i in range(len(mutual_pairs))])
        
        # Convert ensemble result to expected format
        mutual_nn = len(mutual_pairs)
            
        if mutual_nn == 0:
            logger.warning("No mutual nearest neighbors, break")
            break

        total_points = len(emb1_unique)
        mutual_nn_ratio = mutual_nn / total_points if total_points > 0 else 0.0

        # Apply cluster-wise Procrustes transformation if enabled
        if args.use_procrustes and mutual_nn > 0 and emb1_cluster_labels is not None:
            # Get cluster labels for emb1_unique (subset)
            emb1_unique_cluster_labels = emb1_cluster_labels[ind_emb1_unique]

            # Define wrapper for finding mutual pairs using existing ensemble method
            def find_mutual_pairs_wrapper(emb1_cluster, emb2, ind1_cluster, ind2):
                """Find cosine-similarity mutual nearest neighbours between a cluster and emb2.

                Row i of ``emb1_cluster`` and row j of ``emb2`` form a mutual pair when
                j is i's most similar ``emb2`` row AND i is j's most similar cluster row.
                A single-shot CUDA computation is tried whenever torch reports CUDA as
                available; if it raises (e.g. out of memory), a chunked NumPy
                implementation is used instead.

                Args:
                    emb1_cluster: 2-D NumPy array or tensor, one embedding per row.
                    emb2: 2-D NumPy array or tensor, one embedding per row.
                    ind1_cluster: indices aligned with the rows of ``emb1_cluster``;
                        these are reported for the first element of each pair.
                    ind2: indices aligned with the rows of ``emb2``; reported for the
                        second element of each pair.

                Returns:
                    list[tuple[int, int]]: ``(ind1_cluster[i], ind2[j])`` for every mutual
                    pair, as Python ints; empty when either input has no rows.
                """
                import torch
                from loguru import logger

                def _as_host_array(x):
                    # The CPU path needs host arrays even if a (possibly CUDA) tensor was passed.
                    if torch.is_tensor(x):
                        return x.detach().cpu().numpy()
                    return np.asarray(x)

                def _gpu_mutual_rows(a, b, gpu_device):
                    # Returns (rows_in_a, rows_in_b) of the mutual pairs as NumPy arrays.
                    # All GPU tensors are locals of this helper, so they are released as
                    # soon as it returns or raises.
                    a_n = torch.nn.functional.normalize(torch.as_tensor(a).to(gpu_device).float(), p=2, dim=1)
                    b_n = torch.nn.functional.normalize(torch.as_tensor(b).to(gpu_device).float(), p=2, dim=1)
                    # One (n1, n2) similarity matrix serves both directions: row-wise argmax
                    # is cluster -> emb2, column-wise argmax is emb2 -> cluster. This halves
                    # peak memory compared with materialising the transposed product too.
                    sim = a_n @ b_n.T
                    nn_ab = sim.argmax(dim=1)
                    nn_ba = sim.argmax(dim=0)
                    rows = torch.arange(a_n.shape[0], device=gpu_device)
                    # Mutual condition: nn_ba[nn_ab[i]] == i, vectorised.
                    is_mutual = nn_ba[nn_ab] == rows
                    return rows[is_mutual].cpu().numpy(), nn_ab[is_mutual].cpu().numpy()

                def _cpu_mutual_rows(a, b, chunk_size=500):
                    # Chunked so that at most chunk_size rows of similarities are held at once.
                    a = _as_host_array(a)
                    b = _as_host_array(b)
                    a_n = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-8)
                    b_n = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-8)

                    nn_ab = np.empty(len(a_n), dtype=np.intp)
                    for start in range(0, len(a_n), chunk_size):
                        nn_ab[start:start + chunk_size] = np.argmax(a_n[start:start + chunk_size] @ b_n.T, axis=1)

                    nn_ba = np.empty(len(b_n), dtype=np.intp)
                    for start in range(0, len(b_n), chunk_size):
                        nn_ba[start:start + chunk_size] = np.argmax(b_n[start:start + chunk_size] @ a_n.T, axis=1)

                    rows = np.arange(len(a_n))
                    is_mutual = nn_ba[nn_ab] == rows
                    return rows[is_mutual], nn_ab[is_mutual]

                n1, n2 = len(emb1_cluster), len(emb2)
                if n1 == 0 or n2 == 0:
                    # argmax over an empty axis raises in both torch and NumPy.
                    return []

                rows1 = rows2 = None
                if torch.cuda.is_available():
                    try:
                        logger.debug(f"GPU mutual NN: cluster={n1} pts, emb2={n2} pts")
                        rows1, rows2 = _gpu_mutual_rows(emb1_cluster, emb2, torch.device('cuda'))
                    except Exception as e:
                        # Deliberate best-effort: any GPU failure (OOM, dtype, ...) falls back to CPU.
                        logger.warning(f"GPU mutual NN failed, falling back to CPU: {e}")
                    finally:
                        # Return cached blocks (including those from a failed attempt) before
                        # the CPU fallback or the caller's next GPU step.
                        torch.cuda.empty_cache()

                used_gpu = rows1 is not None
                if not used_gpu:
                    logger.debug("Using CPU mutual NN finding")
                    rows1, rows2 = _cpu_mutual_rows(emb1_cluster, emb2)

                # Same output type on both paths: Python ints from the original index arrays.
                mutual = [(int(ind1_cluster[i]), int(ind2[j])) for i, j in zip(rows1, rows2)]
                if used_gpu:
                    logger.debug(f"GPU mutual NN complete: found {len(mutual)} pairs")
                return mutual

            # Apply cluster-wise Procrustes refinement (includes deduplication internally)
            refined_ind1, refined_ind2, _ = cluster_wise_procrustes_refinement(
                emb1_unique,
                emb2_unique,
                ind_emb1_unique,
                ind_emb2_unique,
                new_ref_indices1,  # mutual pairs from emb1
                new_ref_indices2,  # mutual pairs from emb2
                emb1_unique_cluster_labels,
                find_mutual_pairs_wrapper,
                allow_scale=True,
                allow_translation=True,
                min_pairs_per_cluster=3,
                verbose=True,
                use_gpu=args.use_gpu,
                device=device
            )

            # Update reference indices with refined mutual NNs (keep original embeddings)
            if len(refined_ind1) > 0:
                ref_indices1 = refined_ind1
                ref_indices2 = refined_ind2
                logger.debug(f"Iteration {iteration}: Procrustes refinement found {len(refined_ind1)} refined mutual NN pairs")
            else:
                logger.warning(f"Iteration {iteration}: Procrustes refinement found no pairs, keeping original")

        else:
            ref_indices1 = new_ref_indices1
            ref_indices2 = new_ref_indices2

        # Track mutual NN ratio for convergence detection
        mutual_nn_history.append(mutual_nn_ratio)

        if concat_seed_pairs_enabled and mutual_nn_ratio < gt_concat_min_ratio:
            concat_seed_pairs_enabled = False
            logger.debug(f"Disabling seed pair concat in ensemble subsets: mutual_nn_ratio {mutual_nn_ratio:.4f} < {gt_concat_min_ratio:.4f}")
        
        # Check for convergence: stable mutual_nn_ratio over multiple iterations
        if len(mutual_nn_history) >= min_convergence_iters + 1:
            recent_ratios = mutual_nn_history[-(min_convergence_iters + 1):]
            ratio_changes = [abs(recent_ratios[i] - recent_ratios[i-1]) for i in range(1, len(recent_ratios))]
            max_change = max(ratio_changes) if ratio_changes else float('inf')
            
            if max_change < convergence_threshold:
                logger.debug(f"Stopping: converged (mutual_nn_ratio stable at {mutual_nn_ratio:.4f}, max change: {max_change:.4f})")
                break
        
        if len(mutual_pairs) == 0:
            logger.debug("Stopping: no mutual pairs")
            break
            
        if iteration >= args.max_iter:
            logger.debug(f"Stopping: reached maximum iterations ({args.max_iter})")
            break
        
        if(len(ref_indices1) != len(np.unique(ref_indices1))):
            logger.error(f"ref_indices1: {ref_indices1}")
            logger.error(f"ref_indices2: {ref_indices2}")
            raise ValueError("ref_indices1 has duplicates")
        if (len(ref_indices2) != len(np.unique(ref_indices2))):
            logger.error(f"ref_indices1: {ref_indices1}")
            logger.error(f"ref_indices2: {ref_indices2}")
            raise ValueError("ref_indices2 has duplicates")
                
        # Apply improved reference filtering based on pairwise distance quality
        if args.enable_ref_filtering and len(ref_indices1) >= 5:  # Only filter if enabled and we have enough references
            prev_accuracy, prev_recall, prev_correct = compute_accuracy_recall(ref_indices1, ref_indices2, ind_nonref)
            
            # Compute annealed ref_filter_ratio based on quality metrics
            if hasattr(args, 'ref_filter_annealing') and args.ref_filter_annealing != "none":
                initial_ratio = args.ref_filter_ratio
                final_ratio = getattr(args, 'ref_filter_final_ratio', None)
                if final_ratio is None:
                    final_ratio = min(1.0, args.ref_filter_ratio * 1.5)  # Increase filter ratio over time
                current_filter_ratio = compute_annealed_ref_filter_ratio(
                    iteration, args.max_iter, initial_ratio, final_ratio,
                    args.ref_filter_annealing, quality_history
                )
                logger.debug(f"Iteration {iteration}: Using annealed ref_filter_ratio={current_filter_ratio:.3f} (initial={initial_ratio:.3f}, final={final_ratio:.3f})")
            else:
                current_filter_ratio = args.ref_filter_ratio
            
            # Apply filtering and get quality metrics
            # Pass previous mutual pairs for contribution tracking (if available)
            previous_mutual_pairs_for_filtering = mutual_pairs if iteration > 1 else None
            
            filter_result = filter_references_by_pairwise_distance_quality(
                ref_indices1, ref_indices2, emb1, emb2,
                distance_metric=args.distance_metric, top_k_ratio=current_filter_ratio, device=device,
                return_metrics=True, previous_mutual_pairs=previous_mutual_pairs_for_filtering,
                ind_emb1_unique=ind_emb1_unique, ind_emb2_unique=ind_emb2_unique,
                use_multi_gpu=args.use_multi_gpu, gpu_ids=args.gpu_ids,
                multi_gpu_config=args.multi_gpu_config,
                cached_dist_matrices=ref_dist_cache
            )

            if len(filter_result) == 4:  # Got metrics + cache
                ref_indices1, ref_indices2, quality_metrics, ref_dist_cache = filter_result
                mean_quality, kept_quality, min_quality, max_quality = quality_metrics
                quality_history.append((mean_quality, kept_quality, min_quality, max_quality))
                logger.debug(f"Iteration {iteration}: Quality metrics - mean: {mean_quality:.4f}, kept: {kept_quality:.4f}")
            elif len(filter_result) == 3:  # No metrics but has cache
                ref_indices1, ref_indices2, ref_dist_cache = filter_result
                
            current_accuracy, current_recall, current_correct = compute_accuracy_recall(ref_indices1, ref_indices2, ind_nonref)
            logger.debug(f"Iteration {iteration}: prev_accuracy={prev_accuracy:.4f}, prev_recall={prev_recall:.4f}, current_accuracy={current_accuracy:.4f}, current_recall={current_recall:.4f}")
        # EVALUATION ONLY: ground truth used for monitoring, not for algorithm decisions
        accuracy, recall, correct = compute_accuracy_recall(ref_indices1, ref_indices2, ind_nonref)
        logger.info(f"Iteration {iteration}: accuracy={correct}/{len(ref_indices1)}={accuracy:.4f}, recall={correct}/{len(ind_nonref)}={recall:.4f}")

        # Perform distance-based analysis for supervised mode
        if args.anchor_mode == "supervised" and 'ref_ind' in locals():
            try:
                analysis_result = analyze_distance_based_accuracy(
                    ref_indices1=ref_indices1,
                    ref_indices2=ref_indices2,
                    emb1=emb1,
                    emb2=emb2,
                    anchor_indices=ref_ind,
                    distance_metric=args.distance_metric,
                    use_gpu=args.use_gpu,
                    device=device
                )

                avg_corr = analysis_result['avg_distance_correlation']
                min_corr = analysis_result['min_distance_correlation']

                logger.debug(f"DISTANCE-BASED ANALYSIS (Iteration {iteration})")
                if not np.isnan(avg_corr):
                    logger.debug(f"  Average distance to anchors: r={avg_corr:.4f}, p={analysis_result['avg_distance_p_value']:.4e}")
                else:
                    logger.debug(f"  Average distance to anchors: Cannot compute (insufficient variance)")
                if not np.isnan(min_corr):
                    logger.debug(f"  Minimum distance to anchors: r={min_corr:.4f}, p={analysis_result['min_distance_p_value']:.4e}")
                else:
                    logger.debug(f"  Minimum distance to anchors: Cannot compute (insufficient variance)")

                for (min_pct, max_pct), stats in sorted(analysis_result['percentile_breakdown'].items()):
                    range_str = f"{min_pct}-{max_pct}%"
                    acc_str = f"{stats['accuracy']:.4f}" if stats['n_pairs'] > 0 else "N/A"
                    logger.debug(f"  {range_str:<15} acc={acc_str}, pairs={stats['n_pairs']}, correct={stats['n_correct']}")

            except Exception as e:
                logger.warning(f"Distance-based analysis failed: {e}")

            ref_indices1_local = convert_global_to_local_indices(ref_indices1, ind_emb1_unique)
            ref_indices2_local = convert_global_to_local_indices(ref_indices2, ind_emb2_unique)

    accuracy, recall, correct = compute_accuracy_recall(ref_indices1, ref_indices2, ind_nonref)
    logger.debug(f"accuracy: {correct}/{len(ref_indices1)} = {accuracy}, recall: {correct}/{len(ind_nonref)} = {recall}")

    # Compute accuracy at threshold: select top-k pairs where k = total_points * overlap_ratio
    # This estimates accuracy assuming we know the true overlap ratio
    total_points = len(ind_emb1_unique) + len(ind_emb2_unique) - len(ind_nonref)
    expected_k = int(total_points * args.overlap_ratio)
    acc_at_threshold = None

    if expected_k > 0 and len(ref_indices1) > 0:
        # Build reverse mapping: global index -> local index
        global_to_local_idx1 = {global_idx: local_idx for local_idx, global_idx in enumerate(ind_emb1_unique)}
        global_to_local_idx2 = {global_idx: local_idx for local_idx, global_idx in enumerate(ind_emb2_unique)}

        # Get scores for each pair (prefer posterior_mean, fall back to voting count)
        pair_scores = []
        for idx1, idx2 in zip(ref_indices1, ref_indices2):
            local_idx1 = global_to_local_idx1.get(idx1, None)
            local_idx2 = global_to_local_idx2.get(idx2, None)
            score = 0.0

            if local_idx1 is not None and local_idx2 is not None:
                pair_key = (local_idx2, local_idx1)  # (i, nearest_i) format

                # Try posterior_stats first
                if posterior_stats is not None and pair_key in posterior_stats:
                    score = posterior_stats[pair_key].get('posterior_mean', 0.0)
                # Fall back to voting count
                elif pair_voting_refs is not None and pair_key in pair_voting_refs:
                    score = len(pair_voting_refs[pair_key])

            pair_scores.append((idx1, idx2, score))

        # Sort by score descending
        pair_scores.sort(key=lambda x: x[2], reverse=True)

        # Take top-k pairs
        top_k_pairs = pair_scores[:expected_k]

        # Compute accuracy for top-k
        top_k_correct = sum(1 for idx1, idx2, _ in top_k_pairs if idx1 == idx2)
        acc_at_threshold = top_k_correct / expected_k

        logger.debug(f"acc_at_threshold (k={expected_k}): {top_k_correct}/{expected_k} = {acc_at_threshold:.4f}")

    # Save ref_indices1 and ref_indices2 if save_ref_indices is enabled
    if getattr(args, 'save_ref_indices', False):
        # Determine base directory for saving
        if getattr(args, 'save_ref_indices_dir', None):
            # Use custom directory with partition subdirectory structure
            base_dir = args.save_ref_indices_dir
            partition_subdir = f"{args.dataset}_{args.partition}_{args.overlap_ratio}"
            emb_subdir_name = f"{args.emb1}_{args.emb2}_pred_ind"
            ref_indices_dir = os.path.join(base_dir, partition_subdir, emb_subdir_name)
        else:
            # Default: create subdirectory under the same directory where ind1.npy and ind2.npy are stored
            # Directory structure: cache/ind/{dataset}_{partition}_{overlap_ratio}/{emb1}_{emb2}_pred_ind/
            emb_subdir_name = f"{args.emb1}_{args.emb2}_pred_ind"
            ref_indices_dir = os.path.join(ind_file_name, emb_subdir_name)

        os.makedirs(ref_indices_dir, exist_ok=True)

        # Create filename based on method
        if args.use_all_points_mnn:
            method_str = "all_points_mnn"
        elif args.use_bernoulli_trials:
            method_str = "bernoulli"
        else:
            method_str = "voting"
        anchor_str = args.anchor_mode
        # Use n_seeds in filename if specified, otherwise use ref_ratio
        if getattr(args, 'n_seeds', None) is not None:
            filename = f"{anchor_str}_{method_str}_n{args.n_seeds}_pred_ind"
        else:
            filename = f"{anchor_str}_{method_str}_ref{args.ref_ratio}_pred_ind"

        # Save as .npy files
        if args.anchor_mode == "supervised":
            ref_indices1 = np.concatenate([ref_indices1, ref_ind], axis=0)
            ref_indices2 = np.concatenate([ref_indices2, ref_ind], axis=0)

        ref_indices1_path = os.path.join(ref_indices_dir, f"{filename}1.npy")
        ref_indices2_path = os.path.join(ref_indices_dir, f"{filename}2.npy")

        np.save(ref_indices1_path, ref_indices1)
        np.save(ref_indices2_path, ref_indices2)

        logger.debug(f"Saved ref_indices to {ref_indices1_path} and {ref_indices2_path}")
        logger.debug(f"ref_indices1 shape: {ref_indices1.shape}, ref_indices2 shape: {ref_indices2.shape}")

    return accuracy, recall, ref_indices1, ref_indices2


def str2bool(value):
    """Parse a boolean command-line value.

    Accepts, case-insensitively, ``true/1/yes/y/t/on`` as True and
    ``false/0/no/n/f/off`` as False. Any other string is rejected, so a typo
    such as ``--use_gpu flase`` fails loudly instead of silently meaning True.

    Args:
        value: The raw CLI string (a ``bool`` is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', '1', 'yes', 'y', 't', 'on'):
        return True
    if normalized in ('false', '0', 'no', 'n', 'f', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def build_arg_parser():
    """Build the command-line parser for this experiment script.

    Returns:
        argparse.ArgumentParser: Parser with all dataset, clustering,
        ensemble, filtering and output options.

    Note: argparse %-formats help strings, so a literal percent sign must be
    written as ``%%`` or ``--help`` raises ``ValueError``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="scifact")
    parser.add_argument("--ref_dataset", type=str, default="fiqa")
    parser.add_argument("--ref_ratio", type=float, default=0.01)
    parser.add_argument("--n_seeds", type=int, default=None,
                        help="Fixed number of seed pairs. If provided, overrides ref_ratio.")
    parser.add_argument("--base_dir", type=str, default="embeddings")
    parser.add_argument("--cache_dir", type=str, default="cache/ind")
    parser.add_argument("--emb1", type=str, default="mistral")
    parser.add_argument("--emb2", type=str, default="openai")
    parser.add_argument("--emb_dim1", type=int, default=768)
    parser.add_argument("--emb_dim2", type=int, default=768)

    parser.add_argument("--partition", type=str, default="random")
    parser.add_argument("--overlap_ratio", type=float, default=0.3)
    parser.add_argument("--init_ratio", type=float, default=0.1)
    parser.add_argument("--nonref_clu_choices", type=int, nargs='+', default=[0])
    parser.add_argument("--n_clusters", type=int, default=10)
    parser.add_argument("--n_clusters_overlap", type=int, default=20)

    parser.add_argument("--csls_neighborhood", type=int, default=50, help="use CSLS for dictionary induction")
    parser.add_argument("--cluster_method", type=str, default="kmeans")
    parser.add_argument("--distance_metric", type=str, default="cosine")
    parser.add_argument("--ref_method", type=str, default="random")
    parser.add_argument("--graph_method", type=str, default="knn")
    parser.add_argument("--knn_k", type=int, default=500)
    parser.add_argument("--topk", type=int, default=5)
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--multi_gpu_chunk_size", type=int, default=None,
                        help="Rows per GPU chunk when using multi-GPU distance computations (default: auto)")
    parser.add_argument("--sample", type=str2bool, default=False)

    # Anchor generation arguments
    parser.add_argument("--anchor_mode", type=str, default="supervised", choices=["supervised", "ood"],
                        help="Mode for anchor generation: supervised (original) or ood (out-of-distribution)")
    parser.add_argument("--concat_seed_pairs", type=str2bool, default=False,
                        help="Whether to concatenate initial seed pairs to reference embeddings in supervised/OOD mode")

    # Ensemble reference selection parameters
    parser.add_argument("--ensemble_n_ensembles", type=int, default=5,
                        help="Number of ensemble runs for reference selection (default: auto based on ref/subset sizes)")
    parser.add_argument("--ensemble_subset_ratio", type=float, default=0.4, help="Ratio of reference points to use in each ensemble")
    parser.add_argument("--max_parallel_workers", type=int, default=None, help="Max parallel workers for ensemble (None=auto, 2=recommended for large datasets like scidocs/fiqa)")
    parser.add_argument("--ensemble_vote_threshold", type=float, default=0.6, help="Vote threshold for ensemble selection (0.0=all pairs from any ensemble, 0.6=majority, 1.0=unanimous)")
    parser.add_argument("--ensemble_strategy", type=str, default="furthest", choices=['random', 'cluster', 'furthest', 'nearest'], help="Ensemble strategy for reference selection: furthest (default, dispersed anchors), random, cluster (localized anchors), or nearest (local neighborhoods)")

    # Training control parameters
    parser.add_argument("--max_iter", type=int, default=100, help="Maximum number of iterations")

    # Procrustes refinement parameters
    parser.add_argument("--use_procrustes", action="store_true", help="Apply orthogonal Procrustes transformation after finding mutual NNs to align embedding spaces")

    # Reference filtering parameters
    parser.add_argument("--enable_ref_filtering", type=str2bool, default=False, help="Enable reference filtering based on distance quality")
    # '%%' is required: argparse %-formats help text, and a bare '%)' breaks --help.
    parser.add_argument("--ref_filter_ratio", type=float, default=0.9, help="Keep top fraction of references (0.8 = keep top 80%%)")

    # Reference filtering annealing parameters
    parser.add_argument("--ref_filter_annealing", type=str, default="quality_adaptive",
                        choices=["none", "linear", "exponential", "cosine", "quality_adaptive"],
                        help="Annealing strategy for ref_filter_ratio")
    # Bernoulli trials ensemble selection
    parser.add_argument("--use_bernoulli_trials", action="store_true", help="Use Bernoulli trial-based ensemble selection with posterior distributions")
    parser.add_argument("--posterior_threshold", type=float, default=0.1, help="Posterior threshold for Bernoulli trial-based ensemble selection")
    parser.add_argument("--use_distance_weighting", default=False, action="store_true", help="Use distance-based weighting for Bernoulli trial-based ensemble selection")
    parser.add_argument("--distance_filter_percentile", type=float, default=0.2,
                        help="When use_distance_weighting is enabled, only consider pairs in top X percentile "
                             "closest to anchors (default: 0.2 = top 20%%)")
    parser.add_argument("--overlap_inference_method", type=str, default="otsu",
                        choices=["threshold", "adaptive", "otsu", "gmm", "elbow", "expected", "gap"],
                        help="Method to infer overlapping pairs from posterior distribution: "
                             "'threshold' (default) uses fixed/iteration-based threshold, "
                             "'adaptive' combines multiple methods (recommended for unknown overlap), "
                             "'otsu' uses Otsu's thresholding, 'gmm' uses Gaussian Mixture Model, "
                             "'elbow' uses elbow/knee detection, 'expected' uses sum of posteriors, "
                             "'gap' uses gap statistic")

    # All-points MNN baseline
    parser.add_argument("--use_all_points_mnn", action="store_true", help="Use all-points mutual nearest neighbor baseline (no ensembles, voting, or Bernoulli trials)")

    # Save reference indices
    parser.add_argument("--save_ref_indices", action="store_true", help="Save final reference indices to disk")
    parser.add_argument("--save_ref_indices_dir", type=str, default=None, help="Custom directory to save reference indices (overrides default cache location)")

    # Seed parameter for reproducibility
    parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility. If set, creates seed-specific cache files.")

    # Iteration statistics saving
    parser.add_argument("--save_iteration_stats", action="store_true", help="Save detailed pair statistics for each iteration to CSV files")
    parser.add_argument("--iteration_stats_dir", type=str, default="iteration_stats", help="Directory to save iteration statistics CSV files")

    # Debug logging
    parser.add_argument("--debug", action="store_true", help="Enable verbose debug logging (loguru output)")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    # Configure loguru. The pipeline reports almost everything via
    # logger.debug, so --debug must lower the sink level to DEBUG to show it.
    logger.remove()  # Remove default handler
    if args.debug:
        logger.add(sys.stderr, level="DEBUG")
    else:
        # Only show INFO from test_clu (iteration accuracy line) and from this
        # module-level script, plus WARNING+ (level number >= 30) from anywhere.
        logger.add(
            sys.stderr,
            level="INFO",
            filter=lambda record: record["level"].no >= 30 or record["function"] == "test_clu" or record["function"] == "<module>"
        )

    faulthandler.enable()
    try:
        # Use seed from args if provided (None means unseeded)
        accuracy, recall, ref_indices1, ref_indices2 = test_clu(args, seed=args.seed)
        logger.info(f"Final results: accuracy={accuracy:.4f}, recall={recall:.4f}")
    except Exception:
        logger.error("Exception during test_clu execution")
        traceback.print_exc()
        sys.exit(1)
