import numpy as np
import torch
import os
import sys
from pathlib import Path
from loguru import logger
from scipy.optimize import linear_sum_assignment

# Add parent directory to path for imports when running as script
if __name__ == "__main__":
    sys.path.insert(0, str(Path(__file__).parent.parent))

from utils.retrieval_util import topk_mean
from sklearn.preprocessing import normalize

# Try to import CuPy for GPU acceleration
try:
    import cupy as cp
    CUPY_AVAILABLE = True
    # Note: CuML CCA is not available in current versions and has dependency issues
    # Using CPU CCA with GPU distance computation instead
    CUML_CCA_AVAILABLE = False
except ImportError:
    CUPY_AVAILABLE = False
    CUML_CCA_AVAILABLE = False


def _l2_normalize_rows(m):
    if isinstance(m, torch.Tensor):
        norms = torch.norm(m, dim=1, keepdim=True)
        norms = torch.clamp(norms, min=1e-12)
        return m / norms
    else:
        # Handle numpy arrays
        norms = np.linalg.norm(m, axis=1, keepdims=True)
        norms = np.maximum(norms, 1e-12)  # Prevent division by zero
        return m / norms


def pca_gpu(X: np.ndarray, n_components: int, device: torch.device = None, niter: int = 2):
    """
    PCA via randomized low-rank SVD (``torch.pca_lowrank``), on GPU when available.

    Args:
        X: Input data (n_samples, n_features) as numpy array
        n_components: Number of principal components to compute. ``torch.pca_lowrank``
            requires this to be at most min(n_samples, n_features).
        device: torch device to run on. If None, CUDA is used when available,
            otherwise CPU.
        niter: Number of subspace iterations passed to ``torch.pca_lowrank``
            (default: 2). Larger values trade speed for accuracy.

    Returns:
        Tuple of (X_transformed, explained_variance_ratio):
            - X_transformed: Projected data (n_samples, n_components) as numpy array
            - explained_variance_ratio: Variance explained by each component as numpy
              array. All zeros when the centered data has zero total variance.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    X_tensor = torch.tensor(X, dtype=torch.float32, device=device)

    # Center explicitly so that the variance computation below uses the same data
    X_mean = X_tensor.mean(dim=0)
    X_centered = X_tensor - X_mean

    # Randomized SVD: X_centered ≈ U @ diag(S) @ V.T
    U, S, V = torch.pca_lowrank(X_centered, q=n_components, center=False, niter=niter)

    # Projection onto the components: U * S (equivalent to X_centered @ V)
    X_transformed = U * S

    # Explained variance ratio = S_i^2 / ||X_centered||_F^2
    total_var = (X_centered ** 2).sum()
    explained_var = S ** 2
    if total_var > 0:
        ratio = explained_var / total_var
    else:
        # Degenerate input (e.g. all rows identical): avoid 0/0 -> NaN
        ratio = torch.zeros_like(explained_var)
    explained_variance_ratio = ratio.cpu().numpy()

    return X_transformed.cpu().numpy(), explained_variance_ratio


def init_anchors_pca(emb1_unique, emb2_unique, ind_emb1_unique, ind_emb2_unique, min_anchors, init_subset=None, n_components=None, cache_dir=None, cache_key=None, init_ratio=0.1, use_gpu=True):
    """
    Initialize anchors using per-space PCA with greedy nearest-neighbor matching.

    Algorithm:
    1. Take the first ``sim_size`` rows of each space (``sim_size = min(n1, n2)``,
       further capped by ``init_subset``) and L2-normalize them.
    2. Apply PCA separately to each space:
       - use ``pca_gpu`` when ``use_gpu`` is set and CUDA is available;
       - otherwise, or if that fails, use sklearn PCA;
       - if sklearn fails, retry with half the components;
       - as a last resort, use the normalized embeddings truncated to a common width.
    3. L2-normalize the PCA-transformed rows.
    4. For each X row, find the Y row with the highest cosine similarity.
    5. Walk the X rows from best to worst score, keeping a pair only if its Y row is
       not taken yet, until ``keep_num`` pairs are collected. Fewer pairs are
       returned when many X rows share the same best Y row.

    NOTE(review): the two PCA bases are fitted independently, so step 4 compares
    coordinates expressed in unrelated bases. Component order and sign are not
    aligned across spaces. Confirm this is the intended heuristic.

    Args:
        emb1_unique: Embeddings from first space (numpy array; ``.copy()`` is called on it)
        emb2_unique: Embeddings from second space
        ind_emb1_unique: Original indices for emb1 (indexed with integer arrays)
        ind_emb2_unique: Original indices for emb2
        min_anchors: Lower bound on ``keep_num``
        init_subset: Maximum subset size for computational efficiency
        n_components: Number of PCA components
            (default: ``min(d1, d2, sim_size) // 2``, clamped to [1, 50])
        cache_dir: Directory under which the selected anchors are cached
        cache_key: Sub-directory name under ``cache_dir`` for this configuration
        init_ratio: ``keep_num = max(min_anchors, int(len(ind_emb1_unique) * init_ratio))``,
            capped at ``sim_size`` (default: 0.1)
        use_gpu: Try GPU PCA via ``pca_gpu`` first (default: True)

    Returns:
        ref_indices1, ref_indices2: Selected anchor indices. When there are no rows to
        match, empty slices of the index arrays are returned.

    Raises:
        RuntimeError: if scikit-learn is not installed.
    """
    # sklearn is needed for the CPU path and for the GPU-failure fallback
    try:
        from sklearn.decomposition import PCA
    except ImportError as err:
        raise RuntimeError("scikit-learn is required for PCA initialization") from err

    n1 = emb1_unique.shape[0]
    n2 = emb2_unique.shape[0]
    sim_size = min(n1, n2)

    if init_subset is not None:
        sim_size = min(sim_size, init_subset)

    if sim_size <= 0:
        # Nothing to match (an empty space or a non-positive init_subset). Return empty
        # anchors instead of failing inside PCA / argmax on an empty matrix.
        logger.warning("PCA+NN anchors: no samples available for initialization")
        return ind_emb1_unique[:0], ind_emb2_unique[:0]

    X = emb1_unique[:sim_size].copy()
    Y = emb2_unique[:sim_size].copy()

    # Step 1: Normalize the data (L2 normalization)
    X_norm = _l2_normalize_rows(X)
    Y_norm = _l2_normalize_rows(Y)
    logger.debug(f"Normalized embeddings before PCA")

    # Determine number of PCA components
    if n_components is None:
        # Half the smallest of (d1, d2, samples), kept within [1, 50] for efficiency
        n_components = min(X_norm.shape[1], Y_norm.shape[1], sim_size) // 2
        n_components = max(1, min(n_components, 50))

    logger.debug(f"PCA initialization with {n_components} components on {sim_size} samples")

    # Step 2: Apply PCA separately to each normalized embedding space
    pca_success = False

    # Try GPU PCA first if requested and CUDA is present
    if use_gpu and torch.cuda.is_available():
        try:
            device = torch.device('cuda')
            logger.debug(f"Using GPU-accelerated PCA on device: {device}")

            X_c, var_ratio_x = pca_gpu(X_norm, n_components, device)
            Y_c, var_ratio_y = pca_gpu(Y_norm, n_components, device)

            pca_success = True
            logger.debug(f"GPU PCA explained variance - X: {var_ratio_x.sum():.3f}, Y: {var_ratio_y.sum():.3f}")
        except Exception as e:
            # Best-effort: fall through to the CPU sklearn path below
            logger.warning(f"GPU PCA failed: {e}. Falling back to CPU sklearn PCA.")

    # CPU path using sklearn PCA
    if not pca_success:
        try:
            # PCA for first embedding space
            pca_x = PCA(n_components=n_components)
            X_c = pca_x.fit_transform(X_norm)

            # PCA for second embedding space
            pca_y = PCA(n_components=n_components)
            Y_c = pca_y.fit_transform(Y_norm)

            pca_success = True
            logger.debug(f"PCA explained variance - X: {pca_x.explained_variance_ratio_.sum():.3f}, Y: {pca_y.explained_variance_ratio_.sum():.3f}")
        except Exception as e:
            logger.warning(f"PCA fitting failed: {e}. Trying reduced components.")
            # Try with fewer components
            try:
                reduced_components = max(1, n_components // 2)
                logger.debug(f"Retrying PCA with {reduced_components} components")

                pca_x = PCA(n_components=reduced_components)
                X_c = pca_x.fit_transform(X_norm)

                pca_y = PCA(n_components=reduced_components)
                Y_c = pca_y.fit_transform(Y_norm)

                n_components = reduced_components
                pca_success = True
            except Exception as e2:
                logger.warning(f"Reduced PCA also failed: {e2}. Using normalized embeddings.")
                # Fallback: no PCA; truncate to a common width so X @ Y.T is defined
                X_c = X_norm.copy()
                Y_c = Y_norm.copy()
                if X_c.shape[1] != Y_c.shape[1]:
                    min_dim = min(X_c.shape[1], Y_c.shape[1])
                    X_c = X_c[:, :min_dim]
                    Y_c = Y_c[:, :min_dim]

    # Step 3: Normalize PCA-transformed embeddings
    X_c_norm = _l2_normalize_rows(X_c)
    Y_c_norm = _l2_normalize_rows(Y_c)

    # Step 4: cosine similarity (rows are unit length, so a dot product suffices)
    similarity_matrix = X_c_norm @ Y_c_norm.T

    # For each point in X, find its nearest neighbor in Y
    nearest_neighbors_y = np.argmax(similarity_matrix, axis=1)
    nn_scores = similarity_matrix[np.arange(sim_size), nearest_neighbors_y]

    # Rank X rows by their best score, highest first. A stable sort of the negated
    # scores keeps tied rows in ascending row order.
    order = np.argsort(-nn_scores, kind="stable")

    # Step 5: Select top-k pairs based on init_ratio
    keep_num = max(min_anchors, int(len(ind_emb1_unique) * init_ratio))
    keep_num = min(keep_num, sim_size)  # Don't exceed available pairs

    # Greedy one-to-one on Y: the best-scoring X row claims each Y row first
    seen_y = set()
    src_list, trg_list = [], []
    for src_idx in order:
        trg_idx = nearest_neighbors_y[src_idx]
        if trg_idx in seen_y:
            continue
        seen_y.add(trg_idx)
        src_list.append(src_idx)
        trg_list.append(trg_idx)
        if len(src_list) >= keep_num:
            break

    # Integer dtype keeps fancy indexing valid even for an empty selection
    sel_src = np.asarray(src_list, dtype=np.intp)
    sel_trg = np.asarray(trg_list, dtype=np.intp)
    pair_scores = nn_scores[sel_src]

    # Map back to original indices
    ref_indices1 = ind_emb1_unique[sel_src]
    ref_indices2 = ind_emb2_unique[sel_trg]

    # Precision treats equal original indices as a correct match
    correct = (ref_indices1 == ref_indices2).sum().item()
    precision = correct / len(ref_indices1) if len(ref_indices1) > 0 else 0
    logger.debug(f"PCA+NN anchors: {len(ref_indices1)} pairs, "
               f"correct={correct}, "
               f"precision={precision:.3f}")

    # Log PCA statistics
    logger.debug(f"PCA+NN anchors: {len(ref_indices1)} pairs, "
               f"avg_score={np.mean(pair_scores):.3f}, "
               f"min_score={np.min(pair_scores):.3f}, "
               f"max_score={np.max(pair_scores):.3f}, "
               f"init_ratio={init_ratio:.3f}")

    # Best-effort cache of the selected anchors: failures are logged, not raised
    if cache_dir is not None and cache_key is not None:
        try:
            path = os.path.join(cache_dir, cache_key)
            os.makedirs(path, exist_ok=True)
            pca_ref_file1 = os.path.join(path, "pca_ref_indices1.npy")
            pca_ref_file2 = os.path.join(path, "pca_ref_indices2.npy")
            np.save(pca_ref_file1, ref_indices1)
            np.save(pca_ref_file2, ref_indices2)
            logger.debug(f"Cached PCA anchors")
        except Exception as e:
            logger.warning(f"Failed to cache PCA anchors: {e}")

    return ref_indices1, ref_indices2


def init_anchors_procrustes(emb1_unique, emb2_unique, ind_emb1_unique, ind_emb2_unique, min_anchors, init_subset=None, n_components=None, cache_dir=None, cache_key=None, init_ratio=0.1, use_pca_first=True):
    """
    Initialize anchors using Orthogonal Procrustes alignment plus greedy nearest-neighbor matching.

    Orthogonal Procrustes finds the orthogonal matrix R minimizing ||X @ R - Y||_F.
    Different embedding widths are reconciled first, either with per-space PCA or by
    zero-padding the narrower space.

    Algorithm:
    1. Take the first ``sim_size`` rows of each space (``min(n1, n2)``, further capped
       by ``init_subset``) and L2-normalize them.
    2. If d1 != d2:
       - use_pca_first=True: PCA each space to ``n_components`` dimensions.
       - use_pca_first=False: zero-pad the narrower space to max(d1, d2).
       If d1 == d2, the normalized embeddings are used as-is and ``n_components`` is ignored.
    3. Fit R with ``scipy.linalg.orthogonal_procrustes``, using the identity if the fit fails.
    4. Compute X_aligned = X @ R, then L2-normalize both sides.
    5. For each X row, take the most cosine-similar Y row.
    6. Walk X rows from best to worst score, keeping a pair only if its Y row is still
       unused, until ``keep_num`` pairs are collected.

    NOTE(review): orthogonal_procrustes treats row i of X and row i of Y as
    corresponding points. If the first ``sim_size`` rows are aligned by ground truth,
    this uses supervision; if they are not, R is fitted to arbitrary pairings.
    Confirm which is intended.

    Args:
        emb1_unique: Embeddings from first space (n_samples, d1); ``.copy()`` is called on it
        emb2_unique: Embeddings from second space (n_samples, d2)
        ind_emb1_unique: Original indices for emb1
        ind_emb2_unique: Original indices for emb2
        min_anchors: Lower bound on ``keep_num``
        init_subset: Maximum subset size for computational efficiency
        n_components: PCA target dimensionality, used only when d1 != d2 and
            use_pca_first. Default: ``min(d1, d2, sim_size) // 2``, clamped to [1, 50].
            An integer larger than ``min(d1, d2, sim_size)`` is clamped to that value.
        cache_dir: Directory under which the selected anchors are cached
        cache_key: Sub-directory name under ``cache_dir`` for this configuration
        init_ratio: ``keep_num = max(min_anchors, int(len(ind_emb1_unique) * init_ratio))``,
            capped at ``sim_size`` (default: 0.1)
        use_pca_first: If True, use PCA to handle dimension mismatch; otherwise pad with zeros

    Returns:
        ref_indices1, ref_indices2: Selected anchor indices. When there are no rows to
        match, empty slices of the index arrays are returned.

    Raises:
        RuntimeError: if scipy or scikit-learn is not installed.
    """
    try:
        from scipy.linalg import orthogonal_procrustes
        from sklearn.decomposition import PCA
    except ImportError as err:
        raise RuntimeError("scipy and scikit-learn are required for Procrustes initialization") from err

    n1 = emb1_unique.shape[0]
    n2 = emb2_unique.shape[0]
    sim_size = min(n1, n2)

    if init_subset is not None:
        sim_size = min(sim_size, init_subset)

    if sim_size <= 0:
        # Nothing to align; return empty anchors instead of failing in PCA/Procrustes
        logger.warning("Procrustes anchors: no samples available for initialization")
        return ind_emb1_unique[:0], ind_emb2_unique[:0]

    X = emb1_unique[:sim_size].copy()
    Y = emb2_unique[:sim_size].copy()

    # Step 1: Normalize the data (L2 normalization)
    X_norm = _l2_normalize_rows(X)
    Y_norm = _l2_normalize_rows(Y)
    logger.debug(f"Procrustes: Normalized embeddings - X: {X_norm.shape}, Y: {Y_norm.shape}")

    d1, d2 = X_norm.shape[1], Y_norm.shape[1]

    # Step 2: Handle dimension mismatch
    if d1 != d2:
        logger.debug(f"Procrustes: Dimension mismatch detected (d1={d1}, d2={d2})")

        if use_pca_first:
            # Option A: Use PCA to project to common dimensionality.
            # sklearn PCA rejects integer n_components above min(n_samples, n_features).
            max_components = min(d1, d2, sim_size)
            if n_components is None:
                n_components = max(1, min(max_components // 2, 50))
            elif isinstance(n_components, (int, np.integer)) and n_components > max_components:
                logger.warning(
                    f"Procrustes: n_components={n_components} exceeds min(d1, d2, samples)="
                    f"{max_components}; clamping"
                )
                n_components = max_components

            logger.debug(f"Procrustes: Using PCA to project to {n_components} dimensions")

            pca_x = PCA(n_components=n_components)
            X_projected = pca_x.fit_transform(X_norm)

            pca_y = PCA(n_components=n_components)
            Y_projected = pca_y.fit_transform(Y_norm)

            logger.debug(f"Procrustes: PCA explained variance - X: {pca_x.explained_variance_ratio_.sum():.3f}, Y: {pca_y.explained_variance_ratio_.sum():.3f}")
        else:
            # Option B: Pad the narrower side with zero columns to match dimensions
            target_dim = max(d1, d2)
            logger.debug(f"Procrustes: Padding to {target_dim} dimensions")

            if d1 < target_dim:
                padding = np.zeros((sim_size, target_dim - d1))
                X_projected = np.hstack([X_norm, padding])
                Y_projected = Y_norm
            else:
                padding = np.zeros((sim_size, target_dim - d2))
                Y_projected = np.hstack([Y_norm, padding])
                X_projected = X_norm
    else:
        # Dimensions match, no projection needed
        logger.debug(f"Procrustes: Dimensions match ({d1}), no projection needed")
        X_projected = X_norm
        Y_projected = Y_norm

    # Step 3: Apply Orthogonal Procrustes (R minimizes ||X @ R - Y||_F).
    # scipy's second return value is the sum of singular values of X.T @ Y.
    try:
        R, scale = orthogonal_procrustes(X_projected, Y_projected)
        logger.debug(f"Procrustes: Found rotation matrix with scale={scale:.4f}")
    except Exception as e:
        logger.warning(f"Procrustes: Optimization failed: {e}. Using identity transform.")
        R = np.eye(X_projected.shape[1])
        scale = np.inf

    # Step 4: Transform X using the rotation matrix
    X_aligned = X_projected @ R

    # Normalize the aligned embeddings so the dot product below is cosine similarity
    X_aligned_norm = _l2_normalize_rows(X_aligned)
    Y_projected_norm = _l2_normalize_rows(Y_projected)

    # Step 5: Use nearest neighbor matching
    similarity_matrix = X_aligned_norm @ Y_projected_norm.T

    # For each point in X, find its nearest neighbor in Y
    nearest_neighbors_y = np.argmax(similarity_matrix, axis=1)
    nn_scores = similarity_matrix[np.arange(sim_size), nearest_neighbors_y]

    # Rank X rows by best score, highest first; a stable sort keeps ties in row order
    order = np.argsort(-nn_scores, kind="stable")

    # Step 6: Select top-k pairs based on init_ratio
    keep_num = max(min_anchors, int(len(ind_emb1_unique) * init_ratio))
    keep_num = min(keep_num, sim_size)

    # Greedy one-to-one on Y: the best-scoring X row claims each Y row first
    seen_y = set()
    src_list, trg_list = [], []
    for src_idx in order:
        trg_idx = nearest_neighbors_y[src_idx]
        if trg_idx in seen_y:
            continue
        seen_y.add(trg_idx)
        src_list.append(src_idx)
        trg_list.append(trg_idx)
        if len(src_list) >= keep_num:
            break

    # Integer dtype keeps fancy indexing valid even for an empty selection
    sel_src = np.asarray(src_list, dtype=np.intp)
    sel_trg = np.asarray(trg_list, dtype=np.intp)
    pair_scores = nn_scores[sel_src]

    # Map back to original indices
    ref_indices1 = ind_emb1_unique[sel_src]
    ref_indices2 = ind_emb2_unique[sel_trg]

    # Precision treats equal original indices as a correct match
    correct = (ref_indices1 == ref_indices2).sum().item()
    precision = correct / len(ref_indices1) if len(ref_indices1) > 0 else 0
    logger.debug(f"Procrustes anchors: {len(ref_indices1)} pairs, "
               f"correct={correct}, "
               f"precision={precision:.3f}")

    # Log statistics
    logger.debug(f"Procrustes anchors: avg_score={np.mean(pair_scores):.3f}, "
               f"min_score={np.min(pair_scores):.3f}, "
               f"max_score={np.max(pair_scores):.3f}, "
               f"procrustes_scale={scale:.4f}")

    # Best-effort cache of the selected anchors: failures are logged, not raised
    if cache_dir is not None and cache_key is not None:
        try:
            path = os.path.join(cache_dir, cache_key)
            os.makedirs(path, exist_ok=True)
            proc_ref_file1 = os.path.join(path, "procrustes_ref_indices1.npy")
            proc_ref_file2 = os.path.join(path, "procrustes_ref_indices2.npy")
            np.save(proc_ref_file1, ref_indices1)
            np.save(proc_ref_file2, ref_indices2)
            logger.debug(f"Cached Procrustes anchors")
        except Exception as e:
            logger.warning(f"Failed to cache Procrustes anchors: {e}")

    return ref_indices1, ref_indices2


def init_anchors_vecmap_struct(emb1_unique, emb2_unique, ind_emb1_unique, ind_emb2_unique,
                               min_anchors, csls_k=10, init_subset=None, init_ratio=0.1):
    """
    Initialize anchors using VecMap's unsupervised structural-similarity approach.

    Each space is described by its intra-space similarity matrix ``(U * S) @ U.T``,
    built from the thin SVD of the embeddings. If the SVD fails, cosine similarity
    is used instead. Each row of that matrix is sorted, so it becomes a distribution
    of similarities independent of how the other points are ordered, and is then
    L2-normalized. The two spaces are compared via dot products of these sorted rows,
    optionally CSLS-corrected. One-to-one pairs come from the Hungarian algorithm.

    Number of anchors kept:
    - If there are at most ``min_anchors`` pairs, keep all of them.
    - Otherwise, rank pairs by score and find the FIRST drop between consecutive
      scores larger than mean + 2.5 * std of all drops. If that drop occurs at
      rank >= min_anchors, keep every pair above it.
    - Otherwise, keep ``max(min_anchors, int(len(ind_emb1_unique) * init_ratio))``.
    The result is capped at the number of pairs.

    Args:
        emb1_unique, emb2_unique: Unique embeddings from both spaces (``.copy()`` is called on them)
        ind_emb1_unique, ind_emb2_unique: Original indices for emb1 and emb2
        min_anchors: Minimum number of anchors
        csls_k: Neighborhood size for CSLS scoring; 0 disables CSLS (default: 10)
        init_subset: Maximum subset size for computational efficiency
        init_ratio: Ratio of ind_emb1_unique size to use for keep_num calculation (default: 0.1)

    Returns:
        ref_indices1, ref_indices2: Selected anchor indices for both spaces. When there
        are no rows to match, empty slices of the index arrays are returned.
    """
    logger.debug(f"VecMap structural initialization with CSLS k={csls_k}")

    n1 = emb1_unique.shape[0]
    n2 = emb2_unique.shape[0]
    sim_size = min(n1, n2)

    if init_subset is not None:
        sim_size = min(sim_size, init_subset)

    if sim_size <= 0:
        # Nothing to match; return empty anchors instead of running SVD on an empty matrix
        logger.warning("VecMap structural anchors: no samples available for initialization")
        return ind_emb1_unique[:0], ind_emb2_unique[:0]

    # Take subset for computation
    x = emb1_unique[:sim_size].copy()
    z = emb2_unique[:sim_size].copy()

    # Structural similarity per space: (U * S) @ U.T from the thin SVD
    try:
        u, s, _ = np.linalg.svd(x, full_matrices=False)
        xsim = (u * s) @ u.T
    except np.linalg.LinAlgError as e:
        logger.warning(f"SVD failed for space 1: {e}. Using direct similarity.")
        # Fallback: use cosine similarity
        x_norm = _l2_normalize_rows(x)
        xsim = x_norm @ x_norm.T

    try:
        u, s, _ = np.linalg.svd(z, full_matrices=False)
        zsim = (u * s) @ u.T
    except np.linalg.LinAlgError as e:
        logger.warning(f"SVD failed for space 2: {e}. Using direct similarity.")
        # Fallback: use cosine similarity
        z_norm = _l2_normalize_rows(z)
        zsim = z_norm @ z_norm.T

    # Sort each row in place: rows become similarity distributions (VecMap's key step)
    xsim.sort(axis=1)
    zsim.sort(axis=1)

    # Normalize the sorted similarity distributions
    xsim = _l2_normalize_rows(xsim)
    zsim = _l2_normalize_rows(zsim)

    # Match the normalized similarity distributions across spaces
    sim = xsim @ zsim.T

    # CSLS correction: subtract half of each row's and each column's top-k mean similarity.
    # Assumes topk_mean returns one value per row of its input — TODO confirm against utils.
    if csls_k > 0:
        knn_sim_fwd = topk_mean(sim, k=csls_k)
        knn_sim_bwd = topk_mean(sim.T, k=csls_k)
        sim = sim - knn_sim_fwd[:, np.newaxis] / 2 - knn_sim_bwd / 2

    # Hungarian algorithm for one-to-one matching (maximize similarity)
    row_ind, col_ind = linear_sum_assignment(-sim)
    pair_scores = sim[row_ind, col_ind]

    # Rank assigned pairs by score, best first
    order = np.argsort(-pair_scores)
    sorted_scores = pair_scores[order]

    if len(sorted_scores) > min_anchors:
        # Drops between consecutive ranked scores (non-negative)
        score_diff = -np.diff(sorted_scores)

        if len(score_diff) > 0:
            mean_diff = np.mean(score_diff)
            std_diff = np.std(score_diff) if len(score_diff) > 1 else 0

            # Position of the FIRST drop exceeding mean + 2.5 std (not the largest one).
            # +1 because drop j sits between ranks j and j + 1.
            big_drops = np.flatnonzero(score_diff > (mean_diff + 2.5 * std_diff))
            significant_drop_idx = int(big_drops[0]) + 1 if big_drops.size else None

            if significant_drop_idx is not None and significant_drop_idx >= min_anchors:
                keep_num = significant_drop_idx
            else:
                # Use init_ratio based on ind_emb1_unique size
                keep_num = max(min_anchors, int(len(ind_emb1_unique) * init_ratio))
        else:
            keep_num = min_anchors
    else:
        keep_num = len(order)

    # Don't exceed available pairs
    keep_num = min(keep_num, len(order))

    keep = order[:keep_num]

    sel_src = row_ind[keep]
    sel_trg = col_ind[keep]

    # Map back to original indices
    ref_indices1 = ind_emb1_unique[sel_src]
    ref_indices2 = ind_emb2_unique[sel_trg]

    # Log statistics. Guard the empty selection (e.g. min_anchors == 0) so neither the
    # precision division nor the mean runs on zero elements, matching the sibling initializers.
    correct = (ref_indices1 == ref_indices2).sum()
    n_kept = len(ref_indices1)
    precision = correct / n_kept if n_kept > 0 else 0
    avg_score = pair_scores[keep].mean() if n_kept > 0 else float("nan")
    logger.debug(f"VecMap structural anchors: {n_kept} pairs, "
               f"correct={correct}, precision={precision:.3f}, "
               f"avg_score={avg_score:.3f}")

    return ref_indices1, ref_indices2


def init_anchors_distance_profile(
    emb1_unique: np.ndarray,
    emb2_unique: np.ndarray,
    ind_emb1_unique: np.ndarray,
    ind_emb2_unique: np.ndarray,
    k: int = None,
    top_pairs_ratio: float = 0.1,
    distance_metric: str = "cosine",
    profile_metric: str = "soft_jaccard",
    use_gpu: bool = False,
    batch_size: int = 1000,
    init_subset: int = None,
    verbose: bool = True,
    csls_k: int = 10,
    soft_jaccard_epsilon: float = 0.01,
) -> tuple:
    """
    Select anchor pairs by matching points whose local distance profiles agree.

    **RECOMMENDED METHOD**: default initialization. Points are matched by the
    shape of their neighborhood (their distance signature) instead of by raw
    coordinates, which makes the comparison meaningful across different spaces.

    Procedure:
    1. Optionally restrict both spaces to their first ``init_subset`` rows (and to
       the shorter of the two).
    2. Build a distance profile for every point inside its own space. The chunked
       routine is used when an N x N float32 matrix would exceed 5 GB; otherwise
       the dense routine is used.
    3. Find, in both directions, each point's nearest neighbor in the other space
       by comparing profiles with ``profile_metric``.
    4. Keep only mutual nearest neighbors and score each by the mean of its two
       profile distances.
    5. Return the ``top_pairs_ratio`` fraction (at least one) with the smallest
       averaged distance. NOTE(review): this assumes lower values from
       ``_find_nearest_neighbors_by_profile`` mean "more similar" for every
       profile_metric, including 'soft_jaccard' — confirm against that helper.

    Args:
        emb1_unique: Embeddings from first space (N1, D1)
        emb2_unique: Embeddings from second space (N2, D2)
        ind_emb1_unique: Original indices for emb1
        ind_emb2_unique: Original indices for emb2
        k: Neighbors per distance profile (None = all distances;
           50-100 recommended for efficiency)
        top_pairs_ratio: Fraction of points to keep as anchors (default: 0.1)
        distance_metric: Metric for profiles within each space
                        ('cosine', 'euclidean', 'csls'). Default: 'cosine'
        profile_metric: Metric for comparing profiles across spaces. Default: 'soft_jaccard'
                       Options: 'euclidean', 'cosine', 'correlation', 'csls', 'soft_jaccard'
        use_gpu: Use CUDA when available
        batch_size: Batch size for the dense profile and profile-matching helpers (default: 1000)
        init_subset: Cap on the number of points used; None or <= 0 means no cap
        verbose: Emit debug-level progress messages (default: True)
        csls_k: Neighborhood size for CSLS adjustment (default: 10)
        soft_jaccard_epsilon: Bin size for soft Jaccard similarity (default: 0.01).
                            Smaller values give finer granularity, larger values more
                            tolerance; 0.005-0.02 is typical for normalized distances.

    Returns:
        ref_indices1, ref_indices2: Matched anchor indices. When no mutual nearest
        neighbors exist, two empty int32 arrays are returned.

    Example:
        >>> ref_idx1, ref_idx2 = init_anchors_distance_profile(
        ...     emb1, emb2, ind1, ind2,
        ...     k=50, profile_metric='soft_jaccard',
        ...     soft_jaccard_epsilon=0.01, top_pairs_ratio=0.1
        ... )
    """
    if verbose:
        initial_k = "all" if k is None else str(k)
        logger.debug(f"Distance profile initialization with k={initial_k}, top_pairs_ratio={top_pairs_ratio}")

    n1, n2 = emb1_unique.shape[0], emb2_unique.shape[0]
    sim_size = min(n1, n2)
    if init_subset is not None and init_subset > 0:
        sim_size = min(sim_size, init_subset)

    # Slice only when some rows are actually dropped; otherwise keep the inputs as-is
    if sim_size < max(n1, n2):
        if verbose:
            logger.debug(f"Using subset of {sim_size} points (full: {n1}, {n2})")
        head = slice(None, sim_size)
        emb1_part, emb2_part = emb1_unique[head], emb2_unique[head]
        ind1_part, ind2_part = ind_emb1_unique[head], ind_emb2_unique[head]
    else:
        emb1_part, emb2_part = emb1_unique, emb2_unique
        ind1_part, ind2_part = ind_emb1_unique, ind_emb2_unique

    knn_label = "all" if k is None else f"{k}-NN"
    if verbose:
        logger.debug(f"Computing {knn_label} distance profiles for emb1...")

    from utils.memory_util import warn_if_memory_insufficient
    device = torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")

    # Switch to chunked computation when a dense float32 N x N matrix would exceed 5 GB
    memory_threshold_gb = 5.0
    n_total = max(len(emb1_part), len(emb2_part))
    estimated_memory_gb = n_total * n_total * 4 / (1024**3)
    use_chunked = estimated_memory_gb > memory_threshold_gb

    def _profiles(points):
        # Route to the chunked or dense profile builder chosen above
        if use_chunked:
            return _compute_knn_distance_profiles_chunked(
                points, k=k, metric=distance_metric,
                use_gpu=use_gpu, device=device, chunk_size=None, use_faiss=True, csls_k=csls_k
            )
        return _compute_knn_distance_profiles(
            points, k=k, metric=distance_metric,
            use_gpu=use_gpu, batch_size=batch_size, sort=True, csls_k=csls_k
        )

    if use_chunked:
        logger.debug(f"Dataset size {n_total} would require {estimated_memory_gb:.2f} GB, using chunked computation")
    else:
        warn_if_memory_insufficient(
            n_total, operation_name="Distance profile computation",
            use_gpu=use_gpu, device=device, auto_chunk=False
        )
    profiles1 = _profiles(emb1_part)

    if verbose:
        logger.debug(f"Computing {knn_label} distance profiles for emb2...")
    profiles2 = _profiles(emb2_part)

    # Nearest profile in the other space, computed in both directions
    if verbose:
        logger.debug(f"Finding nearest neighbors using {profile_metric} metric...")
    nn_options = dict(
        metric=profile_metric,
        top_n=1,
        use_gpu=use_gpu,
        batch_size=batch_size,
        csls_k=csls_k,
        soft_jaccard_epsilon=soft_jaccard_epsilon,
    )
    nn_indices_1to2, nn_distances_1to2 = _find_nearest_neighbors_by_profile(profiles1, profiles2, **nn_options)
    nn_indices_2to1, nn_distances_2to1 = _find_nearest_neighbors_by_profile(profiles2, profiles1, **nn_options)

    # Keep (i, j) only when i -> j and j -> i agree (mutual nearest neighbors)
    n_part1 = len(ind1_part)
    forward_matches = ((i, nn_indices_1to2[i, 0]) for i in range(n_part1))
    mutual_pairs = [(i, j) for i, j in forward_matches if nn_indices_2to1[j, 0] == i]

    if not mutual_pairs:
        logger.warning("No mutual nearest neighbors found!")
        return np.array([], dtype=np.int32), np.array([], dtype=np.int32)

    # Confidence score: mean of the two directional profile distances
    mutual_distances = np.array([
        (nn_distances_1to2[i, 0] + nn_distances_2to1[j, 0]) / 2.0
        for i, j in mutual_pairs
    ])
    sorted_indices = mutual_distances.argsort()

    n_pairs_to_keep = min(
        max(1, int(top_pairs_ratio * min(n_part1, len(ind_emb1_unique)))),
        len(mutual_pairs),
    )
    top_indices = sorted_indices[:n_pairs_to_keep]

    chosen = [mutual_pairs[t] for t in top_indices]
    sel_src = np.array([src for src, _ in chosen])
    sel_trg = np.array([trg for _, trg in chosen])

    # Translate subset positions back to original indices
    ref_indices1 = ind1_part[sel_src]
    ref_indices2 = ind2_part[sel_trg]

    # Precision treats equal original indices as a correct match
    correct = (ref_indices1 == ref_indices2).sum()
    precision = correct / len(ref_indices1) if len(ref_indices1) > 0 else 0

    if verbose:
        top_distances = mutual_distances[top_indices]
        logger.debug(f"Distance profile anchors: {len(ref_indices1)} pairs")
        logger.debug(f"Mutual NN pairs found: {len(mutual_pairs)}")
        logger.debug(f"Correct matches: {correct}, precision: {precision:.3f}")
        logger.debug(f"Distance range: [{top_distances.min():.6f}, {top_distances.max():.6f}]")
        logger.debug(f"Mean distance: {top_distances.mean():.6f}")

    return ref_indices1, ref_indices2


def _compute_knn_distance_profiles_chunked(
    emb: np.ndarray,
    k: int = None,
    metric: str = "cosine",
    use_gpu: bool = False,
    device: torch.device = None,
    chunk_size: int = None,
    use_faiss: bool = True,
    csls_k: int = 10,
) -> np.ndarray:
    """
    Memory-efficient chunked computation of k-NN distance profiles.

    This function avoids computing the full N×N distance matrix by processing
    the data in chunks. It uses FAISS when available (and ``k`` is given) for
    efficient k-NN search.

    Each returned row holds the ascending distances from one point to its
    nearest neighbours, with the point itself excluded. The FAISS path and the
    chunked fallback report the same distance definitions: cosine distance
    ``1 - cos_sim`` and plain (non-squared) Euclidean distance.

    Args:
        emb: Embeddings array (N, D)
        k: Number of nearest neighbors (None = use all distances, slower;
            FAISS is not used in that case)
        metric: Distance metric. 'cosine' and 'euclidean' are supported;
            any other value (including 'csls') is computed as cosine in the
            chunked path, with a warning.
        use_gpu: Use GPU acceleration if available (only when ``device`` is
            a CUDA device)
        device: PyTorch device for GPU computation
        chunk_size: Number of rows to process at once (auto-determined if None)
        use_faiss: Use FAISS for k-NN search (highly recommended)
        csls_k: Neighborhood size for CSLS metric (not used by this function)

    Returns:
        float32 distance profiles (N, k_actual) where k_actual = min(k, N-1)
        if k is provided, else N-1. Returns an (N, 0) array when k_actual <= 0.
    """
    from scipy.spatial.distance import cdist
    from utils.memory_util import compute_optimal_chunk_size, log_memory_usage

    n = emb.shape[0]

    # If k is None, use all distances (excluding self) - this will be slow and memory-intensive
    if k is None:
        k_actual = n - 1
        logger.warning(
            f"Computing full distance profiles (k=None) for {n} points. "
            f"This requires {n * (n-1) * 4 / 1e9:.2f} GB memory. "
            f"Consider setting k to a smaller value (e.g., k=50)."
        )
    else:
        k_actual = min(k, n - 1)

    # Empty / singleton inputs or k <= 0: there are no neighbours to report.
    if k_actual <= 0:
        return np.zeros((n, 0), dtype=np.float32)

    on_cuda = bool(use_gpu and device is not None and device.type == 'cuda')

    # Determine chunk size if not provided
    if chunk_size is None:
        chunk_size = compute_optimal_chunk_size(
            n=n, m=n, use_gpu=use_gpu, device=device,
            min_chunk_size=100, max_chunk_size=5000
        )
        logger.debug(f"Auto-determined chunk size: {chunk_size} for {n} points")
    # A non-positive step would make range() below raise.
    chunk_size = max(1, int(chunk_size))

    log_memory_usage("Before k-NN computation", use_gpu=use_gpu, device=device)

    # Try FAISS first if available and metric is supported
    if use_faiss and k is not None and metric in ['cosine', 'euclidean']:
        try:
            import faiss
            logger.debug("FAISS successfully imported for k-NN computation")

            emb_np = emb.astype(np.float32) if emb.dtype != np.float32 else emb
            if metric == 'cosine':
                # Inner product on L2-normalised rows == cosine similarity.
                data = normalize(emb_np, norm='l2', axis=1)
                index_cls = faiss.IndexFlatIP
            else:
                data = emb_np
                index_cls = faiss.IndexFlatL2
            # FAISS requires C-contiguous float32 input.
            data = np.ascontiguousarray(data, dtype=np.float32)

            knn_dists = None
            if on_cuda:
                try:
                    gpu_id = device.index if device.index is not None else torch.cuda.current_device()
                    gpu_res = faiss.StandardGpuResources()
                    index = faiss.index_cpu_to_gpu(gpu_res, gpu_id, index_cls(data.shape[1]))
                    index.add(data)
                    # k+1 to include self, which is dropped below
                    knn_dists, _ = index.search(data, k_actual + 1)
                    logger.debug("Using GPU FAISS for k-NN search")
                except Exception as e:
                    logger.warning(f"GPU FAISS failed: {e}, falling back to CPU FAISS")
                    knn_dists = None

            if knn_dists is None:
                index = index_cls(data.shape[1])
                index.add(data)
                knn_dists, _ = index.search(data, k_actual + 1)
                logger.debug("Using CPU FAISS for k-NN search")

            if metric == 'cosine':
                # Convert to distances (1 - cosine_sim)
                knn_dists = 1.0 - knn_dists
            else:
                # IndexFlatL2 reports *squared* L2 distances; take the root so the
                # result matches torch.cdist / scipy cdist in the fallback below.
                knn_dists = np.sqrt(np.maximum(knn_dists, 0.0))

            # Remove self from k-NN results (first column is self, distance ~0)
            knn_distances = knn_dists[:, 1:k_actual + 1].astype(np.float32)

            log_memory_usage("After FAISS k-NN computation", use_gpu=use_gpu, device=device)
            return knn_distances

        except ImportError as e:
            logger.warning(f"FAISS not available (ImportError: {e}), falling back to chunked computation")
            logger.debug("Tip: Install FAISS with: pip install faiss-cpu (or faiss-gpu for GPU support)")
        except AttributeError as e:
            # Common with NumPy 2.x incompatibility
            logger.warning(f"FAISS import failed due to compatibility issue: {e}")
            logger.warning("This is likely due to NumPy 2.x incompatibility with FAISS")
            logger.warning("Options: (1) downgrade NumPy: pip install 'numpy<2.0' or (2) wait for FAISS update")
            logger.debug("Falling back to chunked computation (slower but memory-efficient)")
        except Exception as e:
            logger.warning(f"FAISS failed with unexpected error: {e}, falling back to chunked computation")

    # Fallback: Chunked computation without FAISS
    logger.debug(f"Using chunked k-NN computation with chunk_size={chunk_size}")

    # Preprocess embeddings based on metric
    if metric == "cosine":
        emb_proc = normalize(emb, norm='l2', axis=1).astype(np.float32)
        compute_cosine = True
    elif metric == "euclidean":
        emb_proc = emb.astype(np.float32)
        compute_cosine = False
    else:
        # CSLS not supported in chunked mode without FAISS
        logger.warning(f"Metric '{metric}' not fully optimized for chunked computation, using cosine")
        emb_proc = normalize(emb, norm='l2', axis=1).astype(np.float32)
        compute_cosine = True

    # Upload the full embedding matrix once, not once per chunk.
    emb_t = None
    if on_cuda:
        try:
            emb_t = torch.from_numpy(emb_proc).to(device)
        except Exception as e:
            logger.warning(f"Could not move embeddings to {device}: {e}, using CPU")

    knn_distances = np.empty((n, k_actual), dtype=np.float32)

    for chunk_start in range(0, n, chunk_size):
        chunk_end = min(chunk_start + chunk_size, n)
        chunk_emb = emb_proc[chunk_start:chunk_end]

        # Distances from this chunk to all points: shape (chunk, n)
        dist_chunk = None
        if emb_t is not None:
            try:
                chunk_t = emb_t[chunk_start:chunk_end]
                if compute_cosine:
                    dist_chunk = (1.0 - torch.mm(chunk_t, emb_t.T)).cpu().numpy()
                else:
                    dist_chunk = torch.cdist(chunk_t, emb_t, p=2).cpu().numpy()
            except Exception as e:
                logger.warning(f"GPU computation failed in chunk {chunk_start}-{chunk_end}: {e}, using CPU")
        if dist_chunk is None:
            if compute_cosine:
                dist_chunk = 1.0 - chunk_emb @ emb_proc.T
            else:
                dist_chunk = cdist(chunk_emb, emb_proc, metric='euclidean')

        # Mask each row's self-distance by index so it can never be selected,
        # instead of assuming self is the smallest value after sorting.
        local = np.arange(chunk_end - chunk_start)
        dist_chunk[local, chunk_start + local] = np.inf

        if k_actual < n - 1:
            # Keep only the k smallest per row before sorting.
            dist_chunk = np.partition(dist_chunk, k_actual - 1, axis=1)[:, :k_actual]
        knn_distances[chunk_start:chunk_end] = np.sort(dist_chunk, axis=1)[:, :k_actual]

        if (chunk_end - chunk_start) >= 1000 or chunk_end == n:
            logger.debug(f"Processed k-NN for points {chunk_start} to {chunk_end} / {n}")

    log_memory_usage("After chunked k-NN computation", use_gpu=use_gpu, device=device)
    return knn_distances


def _compute_knn_distance_profiles(
    emb: np.ndarray,
    k: int = None,
    metric: str = "cosine",
    use_gpu: bool = False,
    batch_size: int = 1000,
    sort: bool = True,
    csls_k: int = 10,
) -> np.ndarray:
    """
    Compute distance profiles for each point.

    Args:
        emb: Embeddings array (N, D)
        k: Number of nearest neighbors (None = use all distances)
        metric: Distance metric ('cosine', 'euclidean', 'csls')
        use_gpu: Use GPU acceleration if available
        batch_size: Batch size for GPU computation
        sort: Sort distances in ascending order
        csls_k: Neighborhood size for CSLS metric

    Returns:
        Distance profiles (N, k_actual) where k_actual = k if k is provided, else n-1
    """
    n = emb.shape[0]

    # If k is None, use all distances (excluding self)
    if k is None:
        k_actual = n - 1
    else:
        k_actual = min(k, n - 1)

    # Compute pairwise distance matrix
    if metric == "csls":
        # CSLS operates on cosine similarity
        emb_norm = normalize(emb, norm='l2', axis=1)
        k_csls = max(1, min(csls_k, n - 1))

        if use_gpu:
            try:
                emb_t = torch.from_numpy(emb_norm).float().cuda()
                sim_matrix = torch.mm(emb_t, emb_t.T)

                # Row-wise mean of top-k similarities (excluding self)
                row_topk_vals = torch.topk(sim_matrix, k=k_csls + 1, dim=1, largest=True).values[:, 1:]
                r_x = row_topk_vals.mean(dim=1, keepdim=True)

                # Column-wise mean of top-k similarities (excluding self)
                col_topk_vals = torch.topk(sim_matrix, k=k_csls + 1, dim=0, largest=True).values[1:, :]
                r_y = col_topk_vals.mean(dim=0, keepdim=True)

                # CSLS adjustment
                csls_scores = 2.0 * sim_matrix - r_x - r_y
                dist_matrix = -csls_scores.cpu().numpy()
            except:
                sim_matrix = emb_norm @ emb_norm.T

                # Row-wise top-k means (excluding self)
                row_idx = np.argpartition(sim_matrix, -(k_csls + 1), axis=1)[:, -(k_csls + 1):]
                row_topk = np.take_along_axis(sim_matrix, row_idx, axis=1)
                row_topk = np.sort(row_topk, axis=1)[:, :-1][:, -k_csls:]
                r_x = row_topk.mean(axis=1, keepdims=True)

                # Column-wise top-k means (excluding self)
                col_idx = np.argpartition(sim_matrix, -(k_csls + 1), axis=0)[-(k_csls + 1):, :]
                rows = col_idx
                cols = np.broadcast_to(np.arange(sim_matrix.shape[1]), rows.shape)
                col_topk = sim_matrix[rows, cols]
                col_topk = np.sort(col_topk, axis=0)[:-1, :][-k_csls:, :]
                r_y = col_topk.mean(axis=0, keepdims=True)

                # CSLS adjustment
                csls_matrix = 2.0 * sim_matrix - r_x - r_y
                dist_matrix = -csls_matrix
        else:
            sim_matrix = emb_norm @ emb_norm.T

            # Row-wise top-k means (excluding self)
            row_idx = np.argpartition(sim_matrix, -(k_csls + 1), axis=1)[:, -(k_csls + 1):]
            row_topk = np.take_along_axis(sim_matrix, row_idx, axis=1)
            row_topk = np.sort(row_topk, axis=1)[:, :-1][:, -k_csls:]
            r_x = row_topk.mean(axis=1, keepdims=True)

            # Column-wise top-k means (excluding self)
            col_idx = np.argpartition(sim_matrix, -(k_csls + 1), axis=0)[-(k_csls + 1):, :]
            rows = col_idx
            cols = np.broadcast_to(np.arange(sim_matrix.shape[1]), rows.shape)
            col_topk = sim_matrix[rows, cols]
            col_topk = np.sort(col_topk, axis=0)[:-1, :][-k_csls:, :]
            r_y = col_topk.mean(axis=0, keepdims=True)

            # CSLS adjustment
            csls_matrix = 2.0 * sim_matrix - r_x - r_y
            dist_matrix = -csls_matrix
    elif metric == "cosine":
        emb_norm = normalize(emb, norm='l2', axis=1)
        if use_gpu:
            try:
                emb_t = torch.from_numpy(emb_norm).cuda()
                sim_matrix = torch.mm(emb_t, emb_t.T)
                dist_matrix = (1.0 - sim_matrix).cpu().numpy()
            except:
                sim_matrix = emb_norm @ emb_norm.T
                dist_matrix = 1.0 - sim_matrix
        else:
            sim_matrix = emb_norm @ emb_norm.T
            dist_matrix = 1.0 - sim_matrix
    elif metric == "euclidean":
        if use_gpu:
            try:
                emb_t = torch.from_numpy(emb).float().cuda()
                dist_matrix = torch.cdist(emb_t, emb_t, p=2).cpu().numpy()
            except:
                from scipy.spatial.distance import cdist
                dist_matrix = cdist(emb, emb, metric='euclidean')
        else:
            from scipy.spatial.distance import cdist
            dist_matrix = cdist(emb, emb, metric='euclidean')
    else:
        raise ValueError(f"Unknown metric: {metric}")

    # Extract distance profiles
    knn_distances = np.zeros((n, k_actual), dtype=np.float32)

    for i in range(n):
        dists = dist_matrix[i]

        if k_actual == n - 1:
            # Use all distances (excluding self)
            knn_vals = np.sort(dists)[1:]  # Exclude self (first element after sorting)
        elif k_actual + 1 < n:
            # Get k+1 smallest (including self), then exclude self
            knn_idx = np.argpartition(dists, k_actual + 1)[:k_actual + 1]
            knn_vals = dists[knn_idx]
            knn_vals = np.sort(knn_vals)[1:k_actual + 1]
        else:
            # Fallback: sort all and take k_actual after excluding self
            knn_vals = np.sort(dists)[1:k_actual + 1]

        knn_distances[i] = knn_vals

    return knn_distances


def _extract_and_bin_profiles(
    profiles: np.ndarray,
    global_min: float,
    global_max: float,
    epsilon: float,
) -> np.ndarray:
    """
    Extract and bin distance profiles into histogram representation.

    This converts profiles from shape (N, K) to histogram matrix (N, num_bins)
    where each row contains the frequency count for each bin.

    Args:
        profiles: Distance profiles (N, K)
        global_min: Global minimum value for consistent binning
        global_max: Global maximum value for consistent binning
        epsilon: Bin size

    Returns:
        Histogram matrix (N, num_bins) with frequency counts
    """
    n = profiles.shape[0]

    # Determine number of bins
    n_bins = max(10, int((global_max - global_min) / epsilon) + 1)

    # Bin the profiles
    bins = np.floor((profiles - global_min) / epsilon).astype(np.int32)
    bins = np.clip(bins, 0, n_bins - 1)

    # Build histogram matrix
    histogram_matrix = np.zeros((n, n_bins), dtype=np.float32)

    for i in range(n):
        for bin_id in bins[i]:
            histogram_matrix[i, bin_id] += 1.0

    return histogram_matrix


def _histogram_jaccard_numpy(H1: np.ndarray, H2: np.ndarray, max_elements: int = 1 << 24) -> np.ndarray:
    """Pairwise weighted Jaccard ``sum(min)/sum(max)`` between histogram rows, on the CPU.

    Uses the identity ``sum(max(a, b)) = sum(a) + sum(b) - sum(min(a, b))`` so
    only the element-wise minimum is materialised. Rows of ``H1`` are processed
    in blocks so the temporary (rows, n2, n_bins) array stays under
    ``max_elements`` entries. Pairs with an empty union get similarity 0.
    """
    n1, n_bins = H1.shape
    n2 = H2.shape[0]
    out = np.zeros((n1, n2), dtype=np.float32)
    if n1 == 0 or n2 == 0:
        return out

    s1 = H1.sum(axis=1)
    s2 = H2.sum(axis=1)
    rows_per_block = max(1, max_elements // max(1, n2 * n_bins))

    for start in range(0, n1, rows_per_block):
        stop = min(start + rows_per_block, n1)
        inter = np.minimum(H1[start:stop, None, :], H2[None, :, :]).sum(axis=2)
        union = s1[start:stop, None] + s2[None, :] - inter
        np.divide(inter, union, out=out[start:stop], where=union > 0)
    return out


def _histogram_jaccard_torch(H1: np.ndarray, H2: np.ndarray, batch_size: int) -> np.ndarray:
    """CUDA version of :func:`_histogram_jaccard_numpy`, processing H2 in column chunks.

    A chunk that hits a CUDA out-of-memory error is recomputed on the CPU;
    any other RuntimeError propagates to the caller.
    """
    device = torch.device("cuda")
    n1, n_bins = H1.shape
    n2 = H2.shape[0]
    similarity_matrix = np.zeros((n1, n2), dtype=np.float32)
    if n1 == 0 or n2 == 0:
        return similarity_matrix

    # Auto-determine chunk size based on available memory. The dominant
    # temporary is the (n1, chunk, n_bins) float32 element-wise minimum.
    try:
        free_mem, _ = torch.cuda.mem_get_info()
        # Use at most 30% of free memory for safety
        max_chunk_mem = min(free_mem * 0.3, 10e9)
        bytes_per_element = 4  # float32
        estimated_chunk_size = int(max_chunk_mem / (n1 * n_bins * bytes_per_element * 2))
        chunk_size = max(1, min(estimated_chunk_size, batch_size, 100))
        logger.debug(f"GPU chunk size: {chunk_size} (n_bins={n_bins})")
    except Exception:
        chunk_size = 1

    # Upload H1 once; it is reused for every H2 chunk.
    H1_gpu = torch.from_numpy(H1).float().to(device)
    s1 = H1_gpu.sum(dim=1, keepdim=True)  # (n1, 1)

    for start_idx in range(0, n2, chunk_size):
        end_idx = min(start_idx + chunk_size, n2)
        try:
            H2_chunk = torch.from_numpy(H2[start_idx:end_idx]).float().to(device)
            intersection = torch.minimum(H1_gpu.unsqueeze(1), H2_chunk.unsqueeze(0)).sum(dim=2)
            union = s1 + H2_chunk.sum(dim=1).unsqueeze(0) - intersection
            batch_sim = torch.where(union > 0, intersection / union, torch.zeros_like(intersection))
            similarity_matrix[:, start_idx:end_idx] = batch_sim.cpu().numpy()
            del H2_chunk, intersection, union, batch_sim
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            logger.warning(f"GPU OOM at chunk {start_idx}. Falling back to CPU for this chunk...")
            torch.cuda.empty_cache()
            similarity_matrix[:, start_idx:end_idx] = _histogram_jaccard_numpy(H1, H2[start_idx:end_idx])

    del H1_gpu, s1
    torch.cuda.empty_cache()
    return similarity_matrix


def _compute_soft_jaccard_similarity(
    profiles1: np.ndarray,
    profiles2: np.ndarray,
    epsilon: float = 0.1,
    use_gpu: bool = False,
    batch_size: int = 1000,
) -> np.ndarray:
    """
    Compute soft jaccard similarity between distance profiles using histogram-based approach.

    Treats each profile as a multiset of values (not an ordered vector). It bins
    them with width ``epsilon`` over the joint value range of both inputs and
    computes the weighted Jaccard index ``sum(min(h1, h2)) / sum(max(h1, h2))``
    of the bin-count histograms.

    Args:
        profiles1: First set of profiles (N1, K)
        profiles2: Second set of profiles (N2, K)
        epsilon: Bin width (must be > 0)
        use_gpu: Use CUDA via torch if available; on failure the CPU path is used
        batch_size: Upper bound on the number of profiles2 rows per GPU chunk

    Returns:
        float32 similarity matrix (N1, N2) with jaccard indices in [0, 1]
        (all zeros if either input has no values)

    Raises:
        ValueError: If the profile widths differ or ``epsilon`` is not positive.
    """
    n1, k1 = profiles1.shape
    n2, k2 = profiles2.shape

    if k1 != k2:
        raise ValueError(f"Profile dimensions must match: {k1} != {k2}")
    if epsilon <= 0:
        raise ValueError(f"epsilon must be positive, got {epsilon}")

    # min()/max() are undefined on empty arrays; empty profiles share no bins.
    if profiles1.size == 0 or profiles2.size == 0:
        return np.zeros((n1, n2), dtype=np.float32)

    # Determine global min/max for consistent binning
    global_min = min(profiles1.min(), profiles2.min())
    global_max = max(profiles1.max(), profiles2.max())

    n_bins = max(10, int((global_max - global_min) / epsilon) + 1)

    logger.debug(f"Soft Jaccard: binning profiles with epsilon={epsilon}, n_bins={n_bins}")

    # Extract and bin all profiles into histograms
    H1 = _extract_and_bin_profiles(profiles1, global_min, global_max, epsilon)
    H2 = _extract_and_bin_profiles(profiles2, global_min, global_max, epsilon)

    logger.debug(f"Soft Jaccard: histogram shapes H1={H1.shape}, H2={H2.shape}")

    if use_gpu:
        if torch.cuda.is_available():
            try:
                return _histogram_jaccard_torch(H1, H2, batch_size)
            except Exception as e:
                logger.warning(f"GPU soft jaccard failed: {e}, falling back to CPU")
        else:
            logger.warning("GPU requested but not available, falling back to CPU")

    # CPU implementation using vectorized histogram operations
    return _histogram_jaccard_numpy(H1, H2)


def _find_nearest_neighbors_by_profile(
    profiles1: np.ndarray,
    profiles2: np.ndarray,
    metric: str = "euclidean",
    top_n: int = 1,
    use_gpu: bool = False,
    batch_size: int = 1000,
    csls_k: int = 10,
    soft_jaccard_epsilon: float = 0.1,
) -> tuple:
    """Find nearest neighbors between two sets based on their distance profiles."""
    n1 = profiles1.shape[0]
    n2 = profiles2.shape[0]

    if metric == "csls":
        # CSLS for profile comparison
        p1_norm = normalize(profiles1, norm='l2', axis=1)
        p2_norm = normalize(profiles2, norm='l2', axis=1)

        k_row = max(1, min(csls_k, n2))
        k_col = max(1, min(csls_k, n1))

        if use_gpu:
            try:
                p1_t = torch.from_numpy(p1_norm).float().cuda()
                p2_t = torch.from_numpy(p2_norm).float().cuda()

                # Full similarity matrix (N1 x N2)
                sim = torch.mm(p1_t, p2_t.T)

                # Row-wise mean of top-k similarities
                row_topk_vals = torch.topk(sim, k=k_row, dim=1, largest=True).values
                r_x = row_topk_vals.mean(dim=1, keepdim=True)

                # Column-wise mean of top-k similarities
                col_topk_vals = torch.topk(sim, k=k_col, dim=0, largest=True).values
                r_y = col_topk_vals.mean(dim=0, keepdim=True)

                csls_scores = 2.0 * sim - r_x - r_y
                dist = -csls_scores

                if top_n == 1:
                    dist_vals, idx = torch.min(dist, dim=1)
                    indices = idx.cpu().numpy().reshape(-1, 1)
                    distances = dist_vals.cpu().numpy().reshape(-1, 1)
                else:
                    dist_vals, idx = torch.topk(dist, k=top_n, dim=1, largest=False)
                    indices = idx.cpu().numpy()
                    distances = dist_vals.cpu().numpy()
            except:
                sim_matrix = p1_norm @ p2_norm.T

                # Row-wise top-k means
                row_idx = np.argpartition(sim_matrix, -k_row, axis=1)[:, -k_row:]
                row_topk = np.take_along_axis(sim_matrix, row_idx, axis=1)
                r_x = row_topk.mean(axis=1, keepdims=True)

                # Column-wise top-k means
                col_idx = np.argpartition(sim_matrix, -k_col, axis=0)[-k_col:, :]
                rows = col_idx
                cols = np.broadcast_to(np.arange(sim_matrix.shape[1]), rows.shape)
                col_topk = sim_matrix[rows, cols]
                r_y = col_topk.mean(axis=0, keepdims=True)

                csls_matrix = 2.0 * sim_matrix - r_x - r_y
                dist_matrix = -csls_matrix

                indices = np.zeros((n1, top_n), dtype=np.int64)
                distances = np.zeros((n1, top_n), dtype=np.float32)

                for i in range(n1):
                    if top_n == 1:
                        idx = np.argmin(dist_matrix[i])
                        indices[i, 0] = idx
                        distances[i, 0] = dist_matrix[i, idx]
                    else:
                        idx = np.argpartition(dist_matrix[i], top_n)[:top_n]
                        idx_sorted = idx[np.argsort(dist_matrix[i, idx])]
                        indices[i] = idx_sorted
                        distances[i] = dist_matrix[i, idx_sorted]
        else:
            sim_matrix = p1_norm @ p2_norm.T

            # Row-wise top-k means
            row_idx = np.argpartition(sim_matrix, -k_row, axis=1)[:, -k_row:]
            row_topk = np.take_along_axis(sim_matrix, row_idx, axis=1)
            r_x = row_topk.mean(axis=1, keepdims=True)

            # Column-wise top-k means
            col_idx = np.argpartition(sim_matrix, -k_col, axis=0)[-k_col:, :]
            rows = col_idx
            cols = np.broadcast_to(np.arange(sim_matrix.shape[1]), rows.shape)
            col_topk = sim_matrix[rows, cols]
            r_y = col_topk.mean(axis=0, keepdims=True)

            csls_matrix = 2.0 * sim_matrix - r_x - r_y
            dist_matrix = -csls_matrix

            indices = np.zeros((n1, top_n), dtype=np.int64)
            distances = np.zeros((n1, top_n), dtype=np.float32)

            for i in range(n1):
                if top_n == 1:
                    idx = np.argmin(dist_matrix[i])
                    indices[i, 0] = idx
                    distances[i, 0] = dist_matrix[i, idx]
                else:
                    idx = np.argpartition(dist_matrix[i], top_n)[:top_n]
                    idx_sorted = idx[np.argsort(dist_matrix[i, idx])]
                    indices[i] = idx_sorted
                    distances[i] = dist_matrix[i, idx_sorted]

        return indices, distances

    elif metric == "euclidean":
        if use_gpu:
            try:
                p1_t = torch.from_numpy(profiles1).float().cuda()
                p2_t = torch.from_numpy(profiles2).float().cuda()
                dist = torch.cdist(p1_t, p2_t, p=2)

                if top_n == 1:
                    dist_vals, idx = torch.min(dist, dim=1)
                    indices = idx.cpu().numpy().reshape(-1, 1)
                    distances = dist_vals.cpu().numpy().reshape(-1, 1)
                else:
                    dist_vals, idx = torch.topk(dist, k=top_n, dim=1, largest=False)
                    indices = idx.cpu().numpy()
                    distances = dist_vals.cpu().numpy()
            except:
                from scipy.spatial.distance import cdist
                dist_matrix = cdist(profiles1, profiles2, metric='euclidean')
                indices = np.zeros((n1, top_n), dtype=np.int64)
                distances = np.zeros((n1, top_n), dtype=np.float32)
                for i in range(n1):
                    if top_n == 1:
                        idx = np.argmin(dist_matrix[i])
                        indices[i, 0] = idx
                        distances[i, 0] = dist_matrix[i, idx]
                    else:
                        idx = np.argpartition(dist_matrix[i], top_n)[:top_n]
                        idx_sorted = idx[np.argsort(dist_matrix[i, idx])]
                        indices[i] = idx_sorted
                        distances[i] = dist_matrix[i, idx_sorted]
        else:
            from scipy.spatial.distance import cdist
            dist_matrix = cdist(profiles1, profiles2, metric='euclidean')
            indices = np.zeros((n1, top_n), dtype=np.int64)
            distances = np.zeros((n1, top_n), dtype=np.float32)
            for i in range(n1):
                if top_n == 1:
                    idx = np.argmin(dist_matrix[i])
                    indices[i, 0] = idx
                    distances[i, 0] = dist_matrix[i, idx]
                else:
                    idx = np.argpartition(dist_matrix[i], top_n)[:top_n]
                    idx_sorted = idx[np.argsort(dist_matrix[i, idx])]
                    indices[i] = idx_sorted
                    distances[i] = dist_matrix[i, idx_sorted]

    elif metric == "cosine":
        p1_norm = normalize(profiles1, norm='l2', axis=1)
        p2_norm = normalize(profiles2, norm='l2', axis=1)

        if use_gpu:
            try:
                p1_t = torch.from_numpy(p1_norm).float().cuda()
                p2_t = torch.from_numpy(p2_norm).float().cuda()
                sim = torch.mm(p1_t, p2_t.T)
                dist = 1.0 - sim

                if top_n == 1:
                    dist_vals, idx = torch.min(dist, dim=1)
                    indices = idx.cpu().numpy().reshape(-1, 1)
                    distances = dist_vals.cpu().numpy().reshape(-1, 1)
                else:
                    dist_vals, idx = torch.topk(dist, k=top_n, dim=1, largest=False)
                    indices = idx.cpu().numpy()
                    distances = dist_vals.cpu().numpy()
            except:
                sim_matrix = p1_norm @ p2_norm.T
                dist_matrix = 1.0 - sim_matrix
                indices = np.zeros((n1, top_n), dtype=np.int64)
                distances = np.zeros((n1, top_n), dtype=np.float32)
                for i in range(n1):
                    if top_n == 1:
                        idx = np.argmin(dist_matrix[i])
                        indices[i, 0] = idx
                        distances[i, 0] = dist_matrix[i, idx]
                    else:
                        idx = np.argpartition(dist_matrix[i], top_n)[:top_n]
                        idx_sorted = idx[np.argsort(dist_matrix[i, idx])]
                        indices[i] = idx_sorted
                        distances[i] = dist_matrix[i, idx_sorted]
        else:
            sim_matrix = p1_norm @ p2_norm.T
            dist_matrix = 1.0 - sim_matrix
            indices = np.zeros((n1, top_n), dtype=np.int64)
            distances = np.zeros((n1, top_n), dtype=np.float32)
            for i in range(n1):
                if top_n == 1:
                    idx = np.argmin(dist_matrix[i])
                    indices[i, 0] = idx
                    distances[i, 0] = dist_matrix[i, idx]
                else:
                    idx = np.argpartition(dist_matrix[i], top_n)[:top_n]
                    idx_sorted = idx[np.argsort(dist_matrix[i, idx])]
                    indices[i] = idx_sorted
                    distances[i] = dist_matrix[i, idx_sorted]

    elif metric == "correlation":
        # Pearson correlation distance
        from scipy.spatial.distance import cdist
        dist_matrix = cdist(profiles1, profiles2, metric='correlation')

        indices = np.zeros((n1, top_n), dtype=np.int64)
        distances = np.zeros((n1, top_n), dtype=np.float32)

        for i in range(n1):
            if top_n == 1:
                idx = np.argmin(dist_matrix[i])
                indices[i, 0] = idx
                distances[i, 0] = dist_matrix[i, idx]
            else:
                idx = np.argpartition(dist_matrix[i], top_n)[:top_n]
                idx_sorted = idx[np.argsort(dist_matrix[i, idx])]
                indices[i] = idx_sorted
                distances[i] = dist_matrix[i, idx_sorted]

        return indices, distances

    elif metric == "soft_jaccard":
        # Soft Jaccard similarity using binning
        similarity_matrix = _compute_soft_jaccard_similarity(
            profiles1, profiles2,
            epsilon=soft_jaccard_epsilon,
            use_gpu=use_gpu,
            batch_size=batch_size
        )

        # Convert similarity to distance (1 - similarity)
        dist_matrix = 1.0 - similarity_matrix

        indices = np.zeros((n1, top_n), dtype=np.int64)
        distances = np.zeros((n1, top_n), dtype=np.float32)

        for i in range(n1):
            if top_n == 1:
                idx = np.argmin(dist_matrix[i])
                indices[i, 0] = idx
                distances[i, 0] = dist_matrix[i, idx]
            else:
                idx = np.argpartition(dist_matrix[i], top_n)[:top_n]
                idx_sorted = idx[np.argsort(dist_matrix[i, idx])]
                indices[i] = idx_sorted
                distances[i] = dist_matrix[i, idx_sorted]

        return indices, distances

    else:
        raise ValueError(f"Metric {metric} not implemented")

    return indices, distances


def main():
    """Command-line smoke test for the anchor initialization methods.

    Steps:
      1. Load two embedding matrices of the same dataset (``--emb1`` and
         ``--emb2``) from ``--base_dir``.
      2. Load cached partition indices from ``--cache_dir`` when present,
         otherwise cluster ``emb1`` and build a new partition with
         ``DataPartitioner``.
      3. Shuffle the two index sets and gather the matching embedding rows.
      4. Run the initializer selected by ``--init_method``.
      5. Log how many returned anchor pairs have equal indices, as a fraction
         of the returned pairs (accuracy) and of ``ind_nonref`` (recall).

    Raises:
        ValueError: if ``--dataset`` or ``--partition`` is not supported.
    """
    import argparse
    from sklearn.preprocessing import LabelEncoder
    from utils.load_data import load_npy, DataPartitioner
    from utils.clustering import Clusterer

    def str2bool(value):
        """Parse a boolean command-line value such as ``true``/``false``.

        ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
        every non-empty string (including ``"False"``), so a flag parsed that
        way can never be switched off from the command line.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Test anchor initialization methods")
    parser.add_argument("--dataset", type=str, default="scifact")
    parser.add_argument("--base_dir", type=str, default="embeddings")
    parser.add_argument("--cache_dir", type=str, default="cache/ind")
    parser.add_argument("--emb1", type=str, default="mistral")
    parser.add_argument("--emb2", type=str, default="openai")
    parser.add_argument("--partition", type=str, default="random")
    parser.add_argument("--overlap_ratio", type=float, default=0.3)
    parser.add_argument("--n_clusters", type=int, default=5)
    parser.add_argument("--cluster_method", type=str, default="kmeans")
    # type=list would split a CLI token into single characters ("01" -> ['0', '1']);
    # parse one or more integer cluster ids instead.
    parser.add_argument("--nonref_clu_choices", type=int, nargs="+", default=[0],
                        help="Cluster ids passed to DataPartitioner as nonref_clu_choices (default: [0])")
    parser.add_argument("--use_gpu", type=str2bool, default=True,
                        help="Passed to the clusterer and the PCA / distance-profile "
                             "initializers; accepts true/false (default: True)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for numpy's global RNG used to shuffle indices "
                             "(default: unseeded)")

    # Anchor initialization parameters
    parser.add_argument("--init_method", type=str, default="distance_profile",
                        choices=["pca", "procrustes", "vecmap_struct", "distance_profile"])
    parser.add_argument("--min_anchors", type=int, default=50)
    parser.add_argument("--init_subset", type=int, default=5000)
    parser.add_argument("--init_ratio", type=float, default=0.1)
    parser.add_argument("--n_components", type=int, default=100)
    parser.add_argument("--csls_neighborhood", type=int, default=10)
    parser.add_argument("--use_pca_first", action="store_true",
                        help="For Procrustes: use PCA first (default: False)")

    # Distance profile specific parameters
    parser.add_argument("--distance_profile_k", type=int, default=50,
                        help="Number of nearest neighbors for distance profile (default: 50)")
    parser.add_argument("--distance_profile_metric", type=str, default="cosine",
                        choices=["cosine", "euclidean", "csls"],
                        help="Distance metric for computing profiles (default: cosine)")
    parser.add_argument("--profile_comparison_metric", type=str, default="soft_jaccard",
                        choices=["euclidean", "cosine", "correlation", "csls", "soft_jaccard"],
                        help="Metric for comparing profiles between sets (default: soft_jaccard)")
    parser.add_argument("--soft_jaccard_epsilon", type=float, default=0.01,
                        help="Epsilon for soft Jaccard similarity binning (default: 0.01)")
    parser.add_argument("--top_pairs_ratio", type=float, default=0.1,
                        help="Ratio of top confidence pairs to keep (default: 0.1)")

    args = parser.parse_args()

    if args.seed is not None:
        np.random.seed(args.seed)

    # Load embeddings (file layout depends on the dataset family)
    base_dir = args.base_dir
    if args.dataset in ["scifact", "scidocs", "fiqa", "nfcorpus", "arguana"]:
        emb1 = load_npy(base_dir, f"corpus_embeddings_{args.emb1}_{args.dataset}.npy")
        emb2 = load_npy(base_dir, f"corpus_embeddings_{args.emb2}_{args.dataset}.npy")
    elif args.dataset in ["Citeseer", "Cora", "PubMed"]:
        emb1 = load_npy(base_dir, f"emb_{args.dataset}/{args.emb1}.npy")
        emb2 = load_npy(base_dir, f"emb_{args.dataset}/{args.emb2}.npy")
    elif args.dataset in ["biorxiv", "alloprof", "big_patent", "arxivp2p"]:
        emb1 = load_npy(base_dir, f"{args.emb1}_{args.dataset}/{args.emb1}_{args.dataset}_embeddings.npy")
        emb2 = load_npy(base_dir, f"{args.emb2}_{args.dataset}/{args.emb2}_{args.dataset}_embeddings.npy")
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    logger.debug(f"Loaded embeddings: emb1.shape={emb1.shape}, emb2.shape={emb2.shape}")

    # Resolve the cache location of the partition indices
    if args.partition == "cluster_partial":
        ind_file_name = os.path.join(args.cache_dir, f"{args.dataset}_{args.emb1}_{args.partition}{args.n_clusters}_{args.nonref_clu_choices}")
    elif args.partition == "random":
        ind_file_name = os.path.join(args.cache_dir, f"{args.dataset}_{args.partition}_{args.overlap_ratio}")
    else:
        # Previously fell through and crashed on an unbound ind_file_name.
        raise ValueError(f"Unsupported partition: {args.partition}")

    if os.path.exists(os.path.join(ind_file_name, "ind1.npy")):
        ind_emb1_unique = load_npy(ind_file_name, "ind1")
        ind_emb2_unique = load_npy(ind_file_name, "ind2")
        # Ground truth for the cached case: indices present in both sets.
        ind_nonref = np.intersect1d(ind_emb1_unique, ind_emb2_unique)
        logger.debug("Loaded cached partition indices")
    else:
        # NOTE(review): a freshly created partition is not written back to
        # cache_dir here, so the cached branch is only hit if it was saved elsewhere.
        clusterer = Clusterer(method=args.cluster_method, n_clusters=args.n_clusters, use_gpu=args.use_gpu)
        labels = clusterer.fit(emb1)
        labels = LabelEncoder().fit_transform(labels)
        data_partitioner = DataPartitioner(labels, partition_type=args.partition,
                                           nonref_clu_choices=args.nonref_clu_choices,
                                           overlap_ratio=args.overlap_ratio)
        ind_emb1_unique = data_partitioner.ind_emb1_unique
        ind_emb2_unique = data_partitioner.ind_emb2_unique
        ind_nonref = data_partitioner.ind_emb1_nonref
        logger.debug("Created new partition")

    logger.debug(f"Partition stats: ind_emb1_unique={len(ind_emb1_unique)}, "
                 f"ind_emb2_unique={len(ind_emb2_unique)}, "
                 f"ind_nonref={len(ind_nonref)}")

    # Shuffle the index arrays in place, then gather rows in that same order so
    # row i of emb*_unique corresponds to ind_emb*_unique[i].
    np.random.shuffle(ind_emb1_unique)
    np.random.shuffle(ind_emb2_unique)

    emb1_unique = emb1[ind_emb1_unique]
    emb2_unique = emb2[ind_emb2_unique]

    # Test anchor initialization
    logger.debug(f"\n{'='*60}")
    logger.debug(f"Testing {args.init_method.upper()} anchor initialization")
    logger.debug(f"{'='*60}")

    if args.init_method == "pca":
        cache_key = f"{args.dataset}_{args.partition}_{args.overlap_ratio}"
        ref_indices1, ref_indices2 = init_anchors_pca(
            emb1_unique, emb2_unique,
            ind_emb1_unique, ind_emb2_unique,
            min_anchors=args.min_anchors,
            init_subset=args.init_subset,
            n_components=args.n_components,
            cache_dir=args.cache_dir,
            cache_key=cache_key,
            init_ratio=args.init_ratio,
            use_gpu=args.use_gpu
        )
    elif args.init_method == "procrustes":
        cache_key = f"{args.dataset}_{args.partition}_{args.overlap_ratio}_procrustes"
        ref_indices1, ref_indices2 = init_anchors_procrustes(
            emb1_unique, emb2_unique,
            ind_emb1_unique, ind_emb2_unique,
            min_anchors=args.min_anchors,
            init_subset=args.init_subset,
            n_components=args.n_components,
            cache_dir=args.cache_dir,
            cache_key=cache_key,
            init_ratio=args.init_ratio,
            use_pca_first=args.use_pca_first
        )
    elif args.init_method == "vecmap_struct":
        ref_indices1, ref_indices2 = init_anchors_vecmap_struct(
            emb1_unique, emb2_unique,
            ind_emb1_unique, ind_emb2_unique,
            min_anchors=args.min_anchors,
            csls_k=args.csls_neighborhood,
            init_subset=args.init_subset,
            init_ratio=args.init_ratio
        )
    elif args.init_method == "distance_profile":
        ref_indices1, ref_indices2 = init_anchors_distance_profile(
            emb1_unique, emb2_unique,
            ind_emb1_unique, ind_emb2_unique,
            k=args.distance_profile_k,
            top_pairs_ratio=args.top_pairs_ratio,
            distance_metric=args.distance_profile_metric,
            profile_metric=args.profile_comparison_metric,
            use_gpu=args.use_gpu,
            batch_size=1000,
            init_subset=args.init_subset,
            verbose=True,
            csls_k=args.csls_neighborhood,
            soft_jaccard_epsilon=args.soft_jaccard_epsilon,
        )

    # A pair counts as correct when both sides refer to the same original index.
    correct = (ref_indices1 == ref_indices2).sum()
    accuracy = correct / len(ref_indices1) if len(ref_indices1) > 0 else 0
    recall = correct / len(ind_nonref) if len(ind_nonref) > 0 else 0

    logger.debug(f"\n{'='*60}")
    logger.debug("RESULTS")
    logger.debug(f"{'='*60}")
    logger.debug(f"Number of anchors: {len(ref_indices1)}")
    logger.debug(f"Correct matches: {correct}")
    logger.debug(f"Accuracy: {accuracy:.4f} ({correct}/{len(ref_indices1)})")
    logger.debug(f"Recall: {recall:.4f} ({correct}/{len(ind_nonref)})")
    logger.debug(f"Ground truth size: {len(ind_nonref)}")
    logger.debug(f"{'='*60}")

if __name__ == "__main__":
    # Run the CLI smoke test only when executed as a script, not on import.
    main()
