import numpy as np
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.load_data import load_npy

def _pairwise_knn_distances(X, k_max, chunk_size=1000):
    X = np.asarray(X, float)
    n = X.shape[0]
    
    # Initialize result array
    knn_distances = np.zeros((n, k_max))
    
    # Process in chunks to avoid memory issues
    for i in range(0, n, chunk_size):
        end_i = min(i + chunk_size, n)
        chunk = X[i:end_i]
        
        # Compute distances from this chunk to all points
        # Using broadcasting: ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x,y>
        chunk_norms = np.sum(chunk**2, axis=1, keepdims=True)  # (chunk_size, 1)
        all_norms = np.sum(X**2, axis=1, keepdims=True).T     # (1, n)
        dot_products = chunk @ X.T                             # (chunk_size, n)
        
        # Squared distances
        distances_sq = chunk_norms + all_norms - 2 * dot_products
        distances_sq = np.maximum(distances_sq, 0)  # Ensure non-negative
        distances = np.sqrt(distances_sq)
        
        # Set self-distances to infinity
        for j, global_idx in enumerate(range(i, end_i)):
            distances[j, global_idx] = np.inf
        
        # Get k_max nearest distances for each point in chunk
        idx = np.argpartition(distances, kth=k_max-1, axis=1)[:, :k_max]
        rows = np.arange(len(chunk))[:, None]
        chunk_knn = distances[rows, idx]
        
        # Sort distances for each point
        chunk_knn.sort(axis=1)
        knn_distances[i:end_i] = chunk_knn
    
    return knn_distances  # shape (n, k_max)

def levina_bickel_mle(X, k1=None, k2=None, pooled=True, eps=1e-15, bias_correction=True):
    """Estimate intrinsic dimension from nearest-neighbour distance ratios.

    For every ``k`` in ``[k1, k2]`` the per-point log-ratio
    ``log(r_k / r_1)`` (k-th over first neighbour distance) is computed,
    trimmed with a 1.5 * IQR rule, and converted into the estimate
    ``log(k) / mean(log-ratio)``.  The per-k estimates are then summarised
    by a plain mean, an inverse-variance weighted mean, an IQR-trimmed mean
    and a median.

    Args:
        X: Array-like of shape ``(n_samples, n_features)``.
        k1: Smallest neighbourhood size; chosen from ``n_samples`` when None.
        k2: Largest neighbourhood size; chosen from ``n_samples`` when None
            and always clamped to ``n_samples - 1``.
        pooled: Not used in the computation; echoed back in the result.
        eps: Floor applied to distances.  Distances at or below it are
            treated as invalid, and it also guards divisions.
        bias_correction: When True, each per-k estimate is multiplied by
            ``1 - (k - 1) / (2 n) - var / (2 mean^2 n)`` where ``n`` is the
            number of log-ratios kept for that ``k``.

    Returns:
        Dict with keys ``m_per_k``, ``m_avg_over_k``, ``m_weighted_avg``,
        ``m_robust_avg``, ``m_median``, ``k_range``, ``k_values``,
        ``variances``, ``pooled`` and ``n_valid_k``.  When no ``k`` yields a
        usable estimate, the scalar estimates are NaN, the sequences are
        empty and ``m_inv_avg_over_k`` (NaN) is included as well.
    """
    X = np.asarray(X, float)
    n_samples = X.shape[0]

    # Adaptive k-range selection based on the sample size.
    if k1 is None:
        k1 = max(5, min(15, int(0.015 * n_samples)))
    if k2 is None:
        k2 = max(k1 + 15, min(100, int(0.1 * n_samples), n_samples - 1))

    k2 = min(k2, n_samples - 1)

    nan = float('nan')

    def empty_result():
        # Carries every key of the success result so callers can read any of
        # them without special-casing the "no estimate" outcome.
        return {
            "m_per_k": np.array([]),
            "m_avg_over_k": nan,
            "m_inv_avg_over_k": nan,
            "m_weighted_avg": nan,
            "m_robust_avg": nan,
            "m_median": nan,
            "k_range": (k1, k2),
            "k_values": [],
            "variances": np.array([]),
            "pooled": pooled,
            "n_valid_k": 0
        }

    # An empty k-range cannot produce estimates; skip the O(n^2) distance
    # computation entirely.
    if k2 < k1:
        return empty_result()

    r = np.maximum(_pairwise_knn_distances(X, k2), eps)
    r_1 = r[:, 0]  # first-neighbour distance is the reference for every k

    mks, variances, k_values = [], [], []

    for k in range(k1, k2 + 1):
        r_k = r[:, k - 1]
        min_points = max(20, k)

        # Distances that were floored to eps (i.e. duplicates) are excluded.
        valid_points = (r_k > eps) & (r_1 > eps)
        if np.sum(valid_points) < min_points:
            continue

        log_ratios = np.log(r_k[valid_points] / r_1[valid_points])

        # Drop log-ratios outside the 1.5 * IQR fences.
        q75, q25 = np.percentile(log_ratios, [75, 25])
        iqr = q75 - q25
        lower_bound = q25 - 1.5 * iqr
        upper_bound = q75 + 1.5 * iqr
        valid_logs = log_ratios[(log_ratios >= lower_bound) & (log_ratios <= upper_bound)]

        # Needed for both the bias correction and the variance below, so it
        # must be defined regardless of ``bias_correction``.
        n_valid = len(valid_logs)
        if n_valid < min_points:
            continue

        mean_log = np.mean(valid_logs)
        var_log = np.var(valid_logs, ddof=1)

        if abs(mean_log) < eps:
            continue

        m_k = np.log(k) / mean_log

        if bias_correction:
            bias_term1 = (k - 1) / (2 * n_valid)
            bias_term2 = var_log / (2 * mean_log**2 * n_valid)
            m_k = m_k * (1 - bias_term1 - bias_term2)

        # Relative variance of the mean log-ratio, used for weighting.
        variance = var_log / (mean_log**2 * n_valid)

        mks.append(m_k)
        variances.append(variance)
        k_values.append(k)

    if len(mks) == 0:
        return empty_result()

    mks = np.array(mks)
    variances = np.array(variances)

    # Inverse-variance weighted average.
    weights = 1.0 / (variances + eps)
    weights = weights / np.sum(weights)
    weighted_avg = np.sum(mks * weights)

    # Median as a robust central estimate.
    median_estimate = np.median(mks)

    # Mean of the per-k estimates that fall inside the 1.5 * IQR fences.
    q75, q25 = np.percentile(mks, [75, 25])
    iqr = q75 - q25
    robust_mask = (mks >= q25 - 1.5 * iqr) & (mks <= q75 + 1.5 * iqr)
    robust_estimates = mks[robust_mask]

    if len(robust_estimates) > 0:
        robust_mean = np.mean(robust_estimates)
    else:
        robust_mean = np.mean(mks)

    return {
        "m_per_k": mks,
        "m_avg_over_k": float(np.mean(mks)),
        "m_weighted_avg": float(weighted_avg),
        "m_robust_avg": float(robust_mean),
        "m_median": float(median_estimate),
        "k_range": (k1, k2),
        "k_values": k_values,
        "variances": variances,
        "pooled": pooled,
        "n_valid_k": len(mks)
    }

def twonn(X, trim=(0.1, 0.9), eps=1e-12):
    """Estimate intrinsic dimension from second/first neighbour distance ratios.

    For every point the ratio ``mu = r_2 / r_1`` is formed.  Ratios that are
    non-finite, not greater than ``1 + eps`` or not below ``1e6`` are
    discarded, and the rest are trimmed with a 1.5 * IQR rule (the trim is
    skipped if it would leave fewer than 10 ratios).

    Args:
        X: Array-like of shape ``(n_samples, n_features)``.
        trim: ``(low, high)`` fractions of the sorted ratios that bound the
            segment used for the CDF line fit.
        eps: Floor applied to distances and guard used in the filters.

    Returns:
        Dict with
        ``m_twonn_mle``: ``n / sum(log mu)`` scaled by ``1 - 1 / (2 n)``;
        ``m_twonn_cdf_slope``: slope of the least-squares line through
        ``(log mu, -log(1 - F))`` on the trimmed segment, NaN when fewer
        than 20 ratios remain or the segment has fewer than 10 points;
        ``mu``: the unfiltered ratios for all points;
        ``n_valid``: number of ratios that survived filtering.
        Both estimates are NaN when fewer than 10 ratios pass the filters.
    """
    nan = float('nan')

    r = np.maximum(_pairwise_knn_distances(X, 2), eps)
    # r is already floored at eps, so the denominator cannot be zero.
    mu = r[:, 1] / r[:, 0]

    valid_mask = np.isfinite(mu) & (mu > 1 + eps) & (mu < 1e6)
    valid_mu = mu[valid_mask]

    if len(valid_mu) < 10:
        return {
            "m_twonn_mle": nan,
            "m_twonn_cdf_slope": nan,
            "mu": mu,
            "n_valid": len(valid_mu)
        }

    # Remove outliers using the IQR rule, unless that leaves too few ratios.
    q75, q25 = np.percentile(valid_mu, [75, 25])
    iqr = q75 - q25
    inlier_mask = (valid_mu >= q25 - 1.5 * iqr) & (valid_mu <= q75 + 1.5 * iqr)
    robust_mu = valid_mu[inlier_mask]

    if len(robust_mu) < 10:
        robust_mu = valid_mu

    # Every retained ratio is finite and greater than 1, so the logs are
    # finite and at least 10 of them exist; no further filtering is needed.
    n = len(robust_mu)
    m_mle_raw = n / np.sum(np.log(robust_mu))
    m_mle = m_mle_raw * (1 - 1.0 / (2 * n))

    if n < 20:
        return {
            "m_twonn_mle": float(m_mle),
            "m_twonn_cdf_slope": nan,
            "mu": mu,
            "n_valid": n
        }

    # Empirical CDF with median-rank plotting positions (i - 0.3)/(n + 0.4),
    # which keep F strictly inside (0, 1) so that log(1 - F) stays finite.
    mu_sorted = np.sort(robust_mu)
    F = (np.arange(1, n + 1) - 0.3) / (n + 0.4)
    x_all = np.log(mu_sorted)
    y_all = -np.log(np.maximum(1 - F, eps))

    # Fit only the central part of the distribution.
    i_lo = max(1, int(np.floor(trim[0] * n)))
    i_hi = min(n - 1, int(np.ceil(trim[1] * n)))

    x, y = x_all[i_lo:i_hi], y_all[i_lo:i_hi]

    if len(x) < 10:
        slope = nan
    else:
        # Ordinary least squares with an intercept term.
        X_matrix = np.vstack([x, np.ones(len(x))]).T
        try:
            coeffs = np.linalg.lstsq(X_matrix, y, rcond=None)[0]
            slope = float(coeffs[0])
        except np.linalg.LinAlgError:
            # lstsq failed to converge: fall back to a fit through the origin.
            slope = float(np.dot(x, y) / np.dot(x, x)) if np.dot(x, x) > 0 else nan

    return {
        "m_twonn_mle": float(m_mle),
        "m_twonn_cdf_slope": slope,
        "mu": mu,
        "n_valid": n
    }

def pca_intrinsic_dim(X, threshold=0.95):
    """Estimate intrinsic dimension from the PCA explained-variance spectrum.

    Args:
        X: Array-like of shape ``(n_samples, n_features)``.
        threshold: Cumulative explained-variance fraction to reach.

    Returns:
        Dict with
        ``m_pca_threshold``: smallest number of leading components whose
        cumulative explained variance reaches ``threshold``.  If the
        threshold is never reached (rounding, or ``threshold > 1``) all
        components are counted;
        ``m_pca_effective``: participation ratio
        ``(sum p_i)^2 / sum p_i^2`` of the explained-variance ratios;
        ``explained_variance_ratio``: per-component variance fractions;
        ``threshold``: the threshold that was used.
        For data without any variance the threshold dimension is 0, the
        effective dimension is NaN and the ratios are all zero.
    """
    X = np.asarray(X, dtype=float)
    X_centered = X - np.mean(X, axis=0)

    # Singular values of the centred data give the component variances
    # (up to a common factor) without forming the covariance matrix.
    _, s, _ = np.linalg.svd(X_centered, full_matrices=False)
    component_variance = s**2
    total_variance = np.sum(component_variance)

    # Constant (or empty) data: the ratios below would be 0/0.
    if not total_variance > 0:
        return {
            "m_pca_threshold": 0,
            "m_pca_effective": float('nan'),
            "explained_variance_ratio": np.zeros_like(component_variance),
            "threshold": threshold
        }

    explained_variance = component_variance / total_variance
    cumulative_variance = np.cumsum(explained_variance)

    # First index whose cumulative variance is >= threshold.  searchsorted
    # returns len(...) when no entry qualifies, which is clamped so that the
    # answer becomes "all components" rather than silently 1.
    first_reaching = int(np.searchsorted(cumulative_variance, threshold))
    intrinsic_dim = min(first_reaching + 1, len(cumulative_variance))

    # Effective dimension (participation ratio).
    effective_dim = float(
        np.sum(explained_variance) ** 2 / np.sum(explained_variance**2)
    )

    return {
        "m_pca_threshold": intrinsic_dim,
        "m_pca_effective": effective_dim,
        "explained_variance_ratio": explained_variance,
        "threshold": threshold
    }

def correlation_dimension(X, k_max=50, eps_range=None):
    """Estimate a correlation-type dimension from nearest-neighbour distances.

    For each radius the fraction of stored k-nearest-neighbour distances
    that are ``<= radius`` is computed; the estimate is the slope of the
    least-squares line through ``(log radius, log fraction)``.  Only the
    ``k_max`` nearest neighbours of every point are counted, not all pairs.

    Args:
        X: Array of shape ``(n_samples, n_features)``.
        k_max: Number of neighbours per point; clamped to ``n_samples - 1``.
        eps_range: Iterable of radii.  When None, 20 log-spaced radii between
            the 1st and 90th percentile of the positive neighbour distances
            are used.

    Returns:
        ``{"m_correlation_dim": slope}``.  The value is NaN when there are no
        neighbours, no positive distances, fewer than 5 usable radii, or the
        line fit fails.
    """
    nan = float('nan')

    n_samples = X.shape[0]
    k_max = min(k_max, n_samples - 1)
    if k_max < 1:
        # Fewer than two points: there are no neighbour distances at all.
        return {"m_correlation_dim": nan}

    r = _pairwise_knn_distances(X, k_max)

    if eps_range is None:
        r_valid = r[r > 0]
        if r_valid.size == 0:
            return {"m_correlation_dim": nan}

        eps_min, eps_max = np.percentile(r_valid, [1, 90])
        eps_range = np.logspace(np.log10(eps_min), np.log10(eps_max), 20)

    n_distances = n_samples * k_max
    log_eps = []
    log_c = []

    for radius in eps_range:
        count = np.count_nonzero(r <= radius)
        # Radii that are non-positive or capture nothing have no finite
        # logarithm and are skipped.
        if radius > 0 and count > 0:
            log_eps.append(np.log(radius))
            log_c.append(np.log(count / n_distances))

    if len(log_eps) < 5:
        return {"m_correlation_dim": nan}

    log_eps = np.array(log_eps)
    log_c = np.array(log_c)

    # Least-squares line with intercept; the slope is the estimate.
    X_matrix = np.vstack([log_eps, np.ones(len(log_eps))]).T
    try:
        coeffs = np.linalg.lstsq(X_matrix, log_c, rcond=None)[0]
        correlation_dim = float(coeffs[0])
    except np.linalg.LinAlgError:
        correlation_dim = nan

    return {"m_correlation_dim": correlation_dim}

def _knn_ratio_estimate(points, k):
    """Return ``log(k) / mean(log(r_k / r_1))`` for one point set, or None."""
    r = _pairwise_knn_distances(points, k)

    # Column k - 1 is the k-th nearest neighbour, column 0 the nearest one.
    r_k = r[:, k - 1]
    r_1 = r[:, 0]

    # Duplicate points give zero distances; silence the resulting 0/0 and
    # x/0 warnings because those rows are filtered out right below.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratios = r_k / r_1
    valid_mask = (r_k > 0) & (r_1 > 0) & np.isfinite(ratios)
    if np.sum(valid_mask) < 10:
        return None

    log_ratios = np.log(ratios[valid_mask])

    # Keep only log-ratios inside the 1.5 * IQR fences.
    q75, q25 = np.percentile(log_ratios, [75, 25])
    iqr = q75 - q25
    inlier_mask = (log_ratios >= q25 - 1.5 * iqr) & (log_ratios <= q75 + 1.5 * iqr)
    robust_log_ratios = log_ratios[inlier_mask]

    if len(robust_log_ratios) < 10:
        return None

    mean_log = np.mean(robust_log_ratios)
    if abs(mean_log) > 1e-12:
        return np.log(k) / mean_log
    return None


def mle_id_estimator(X, k=20, n_subsamples=10, random_state=None):
    """Hill-type intrinsic-dimension estimate aggregated over subsamples.

    Each subsample yields ``log(k) / mean(log(r_k / r_1))`` where ``r_k`` and
    ``r_1`` are the distances to the k-th and first nearest neighbour.  The
    per-subsample estimates are IQR-trimmed and summarised.

    Args:
        X: Array-like of shape ``(n_samples, n_features)``.
        k: Neighbour rank used in the ratio; clamped so that every subsample
            has at least ``k`` other points.
        n_subsamples: Number of subsample estimates to draw.
        random_state: Seed or ``numpy.random.Generator`` for the subsampling.
            When None the global ``numpy.random`` state is used.

    Returns:
        Dict with ``m_mle_hill`` (median of the retained estimates),
        ``m_mle_hill_mean`` (their mean) and ``estimates`` (the retained
        estimates as a list).  Both summaries are NaN and the list is empty
        when no subsample yields an estimate.
    """
    nan = float('nan')

    X = np.asarray(X, float)
    n_samples = X.shape[0]

    # Subsample for robustness; datasets of up to 1000 points are used whole.
    subsample_size = min(n_samples, max(1000, n_samples // 5))

    # The k-th neighbour must exist inside every subsample.
    k = min(k, subsample_size - 1)
    if k < 1:
        return {"m_mle_hill": nan, "m_mle_hill_mean": nan, "estimates": []}

    rng = np.random if random_state is None else np.random.default_rng(random_state)

    if n_samples > subsample_size:
        raw_estimates = []
        for _ in range(n_subsamples):
            idx = rng.choice(n_samples, subsample_size, replace=False)
            raw_estimates.append(_knn_ratio_estimate(X[idx], k))
    elif n_subsamples > 0:
        # Without subsampling every repetition sees the same data, so the
        # estimate is computed once and repeated to keep the output shape.
        raw_estimates = [_knn_ratio_estimate(X, k)] * n_subsamples
    else:
        raw_estimates = []

    estimates = [value for value in raw_estimates if value is not None]

    if len(estimates) == 0:
        return {"m_mle_hill": nan, "m_mle_hill_mean": nan, "estimates": []}

    estimates = np.array(estimates)

    # Remove outlier estimates across subsamples.
    q75, q25 = np.percentile(estimates, [75, 25])
    iqr = q75 - q25
    robust_mask = (estimates >= q25 - 1.5 * iqr) & (estimates <= q75 + 1.5 * iqr)
    robust_estimates = estimates[robust_mask]

    if len(robust_estimates) == 0:
        robust_estimates = estimates

    return {
        "m_mle_hill": float(np.median(robust_estimates)),
        "m_mle_hill_mean": float(np.mean(robust_estimates)),
        "estimates": robust_estimates.tolist()
    }

def load_embeddings_from_dist2vec():
    """Load the two SciFact corpus embedding matrices, optionally de-duplicated.

    The embeddings are read from the ``embeddings`` directory that sits one
    level above this file's directory.  Index arrays ``ind1`` / ``ind2`` are
    looked up in that parent directory; when present they select the rows to
    keep.  A ``None`` result from ``load_npy`` is treated as "no index
    available" for that embedding.  If loading or applying the indices
    raises, a message is printed and both full matrices are returned.

    Returns:
        Tuple ``(emb1_unique, emb2_unique)``.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    base_dir = os.path.join(project_root, "embeddings")

    # Load the main embeddings.
    emb1 = load_npy(base_dir, "corpus_embeddings_nv-embed_scifact.npy")
    emb2 = load_npy(base_dir, "corpus_embeddings_openai_scifact.npy")

    try:
        ind_emb1_unique = load_npy(project_root, "ind1")
        ind_emb2_unique = load_npy(project_root, "ind2")
        emb1_unique = emb1 if ind_emb1_unique is None else emb1[ind_emb1_unique]
        emb2_unique = emb2 if ind_emb2_unique is None else emb2[ind_emb2_unique]
    except Exception:
        # Best-effort: any failure falls back to the full matrices, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        print("Could not load indices, using all embeddings")
        emb1_unique = emb1
        emb2_unique = emb2

    return emb1_unique, emb2_unique

def calculate_intrinsic_dimensions():
    """Run every estimator on both stored embeddings and print a report.

    Loads the two embedding matrices, applies the Levina-Bickel, TwoNN, PCA,
    correlation-dimension and Hill-type estimators to each, prints the
    individual results and a consensus value per embedding.

    Returns:
        Dict with keys ``emb1`` and ``emb2``; each maps to a dict holding the
        per-method result dicts under ``levina_bickel``, ``twonn``, ``pca``,
        ``correlation`` and ``hill`` plus the float ``consensus``.
    """
    nan = float('nan')

    def get_consensus_estimate(results_dict, methods_to_use=None):
        """Median of the non-NaN headline estimates after IQR trimming."""
        if methods_to_use is None:
            methods_to_use = ['m_robust_avg', 'm_twonn_mle', 'm_pca_effective', 'm_mle_hill']

        estimates = []
        for method_key in methods_to_use:
            for result_dict in results_dict.values():
                if method_key in result_dict and not np.isnan(result_dict[method_key]):
                    estimates.append(result_dict[method_key])

        if len(estimates) == 0:
            return nan

        estimates = np.array(estimates)
        # Remove outliers; with a zero IQR every value is kept.
        q75, q25 = np.percentile(estimates, [75, 25])
        iqr = q75 - q25
        if iqr > 0:
            robust_mask = (estimates >= q25 - 1.5 * iqr) & (estimates <= q75 + 1.5 * iqr)
            robust_estimates = estimates[robust_mask]
            if len(robust_estimates) > 0:
                return float(np.median(robust_estimates))

        return float(np.median(estimates))

    def estimate_and_report(number, embedding):
        """Run all estimators on one embedding and print their results."""
        print(f"\nCalculating intrinsic dimension for embedding {number}...")
        result_lb = levina_bickel_mle(embedding)
        result_twonn = twonn(embedding)
        result_pca = pca_intrinsic_dim(embedding)
        result_corr = correlation_dimension(embedding)
        result_hill = mle_id_estimator(embedding)

        # The fallback must be a float: a string default cannot be rendered
        # with the ':.2f' format and would raise ValueError.
        print(f"Results for Embedding {number}:")
        print(f"  Levina-Bickel robust avg: {result_lb.get('m_robust_avg', nan):.2f}")
        print(f"  Levina-Bickel weighted avg: {result_lb.get('m_weighted_avg', nan):.2f}")
        print(f"  Levina-Bickel median: {result_lb.get('m_median', nan):.2f}")
        print(f"  TwoNN MLE: {result_twonn['m_twonn_mle']:.2f}")
        print(f"  TwoNN CDF slope: {result_twonn['m_twonn_cdf_slope']:.2f}")
        print(f"  PCA threshold (95%): {result_pca['m_pca_threshold']:.2f}")
        print(f"  PCA effective dim: {result_pca['m_pca_effective']:.2f}")
        print(f"  Correlation dimension: {result_corr['m_correlation_dim']:.2f}")
        print(f"  Hill estimator: {result_hill['m_mle_hill']:.2f}")

        return {
            'levina_bickel': result_lb,
            'twonn': result_twonn,
            'pca': result_pca,
            'correlation': result_corr,
            'hill': result_hill
        }

    print("Loading embeddings...")
    emb1_unique, emb2_unique = load_embeddings_from_dist2vec()

    print(f"Embedding 1 shape: {emb1_unique.shape}")
    print(f"Embedding 2 shape: {emb2_unique.shape}")

    report1 = estimate_and_report(1, emb1_unique)
    report2 = estimate_and_report(2, emb2_unique)

    # The correlation dimension is reported but not part of the consensus.
    consensus_methods = ('levina_bickel', 'twonn', 'pca', 'hill')
    consensus1 = get_consensus_estimate({name: report1[name] for name in consensus_methods})
    consensus2 = get_consensus_estimate({name: report2[name] for name in consensus_methods})

    print("\nConsensus Estimates:")
    print(f"  Embedding 1: {consensus1:.2f}")
    print(f"  Embedding 2: {consensus2:.2f}")

    report1['consensus'] = consensus1
    report2['consensus'] = consensus2

    return {'emb1': report1, 'emb2': report2}

def calculate_single_embedding_intrinsic_dim(X, methods=('levina_bickel', 'twonn', 'pca', 'hill')):
    """Calculate intrinsic dimension of one embedding matrix with several methods.

    Each requested method runs independently; a failure is printed and
    recorded as ``{'error': message}`` without stopping the others.

    Args:
        X: Embedding matrix exposing ``.shape``.
        methods: Collection of method names to run, chosen from
            ``'levina_bickel'``, ``'twonn'``, ``'pca'`` and ``'hill'``.

    Returns:
        Dict mapping every requested method name to its result dict (or
        error dict), plus ``'consensus'``: the median of the non-NaN headline
        estimates after 1.5 * IQR trimming, or NaN if there are none.
    """
    nan = float('nan')

    print(f"Input embedding shape: {X.shape}")
    results = {}

    if 'levina_bickel' in methods:
        print("Computing Levina-Bickel MLE...")
        try:
            lb_result = levina_bickel_mle(X)
            results['levina_bickel'] = lb_result
            # Float fallback: a string default cannot be formatted with
            # ':.2f' and would turn a valid result into a spurious error.
            print(f"  Robust avg: {lb_result.get('m_robust_avg', nan):.2f}")
            print(f"  Weighted avg: {lb_result.get('m_weighted_avg', nan):.2f}")
            print(f"  Median: {lb_result.get('m_median', nan):.2f}")
        except Exception as e:
            print(f"  Error in Levina-Bickel: {e}")
            results['levina_bickel'] = {'error': str(e)}

    if 'twonn' in methods:
        print("Computing TwoNN...")
        try:
            twonn_result = twonn(X)
            results['twonn'] = twonn_result
            print(f"  MLE: {twonn_result['m_twonn_mle']:.2f}")
            print(f"  CDF slope: {twonn_result['m_twonn_cdf_slope']:.2f}")
        except Exception as e:
            print(f"  Error in TwoNN: {e}")
            results['twonn'] = {'error': str(e)}

    if 'pca' in methods:
        print("Computing PCA intrinsic dimension...")
        try:
            pca_result = pca_intrinsic_dim(X)
            results['pca'] = pca_result
            print(f"  PCA threshold (95%): {pca_result['m_pca_threshold']:.2f}")
            print(f"  PCA effective dim: {pca_result['m_pca_effective']:.2f}")
        except Exception as e:
            print(f"  Error in PCA: {e}")
            results['pca'] = {'error': str(e)}

    if 'hill' in methods:
        print("Computing Hill estimator...")
        try:
            hill_result = mle_id_estimator(X)
            results['hill'] = hill_result
            print(f"  Hill estimator: {hill_result['m_mle_hill']:.2f}")
        except Exception as e:
            print(f"  Error in Hill: {e}")
            results['hill'] = {'error': str(e)}

    # (method name, key of its headline estimate, label used in the report)
    headline_estimates = (
        ('levina_bickel', 'm_robust_avg', 'levina_bickel_robust'),
        ('twonn', 'm_twonn_mle', 'twonn_mle'),
        ('pca', 'm_pca_effective', 'pca_effective'),
        ('hill', 'm_mle_hill', 'hill_mle'),
    )

    estimates = []
    method_names = []
    for method_name, key, label in headline_estimates:
        result_dict = results.get(method_name)
        if result_dict is None or 'error' in result_dict or key not in result_dict:
            continue
        value = result_dict[key]
        # A NaN from one method would otherwise turn the whole consensus
        # into NaN via percentile/median.
        if np.isnan(value):
            continue
        estimates.append(value)
        method_names.append(label)

    if len(estimates) > 0:
        estimates = np.array(estimates)
        consensus = float(np.median(estimates))

        # Re-take the median without outliers when the spread is non-zero.
        q75, q25 = np.percentile(estimates, [75, 25])
        iqr = q75 - q25
        if iqr > 0:
            robust_mask = (estimates >= q25 - 1.5 * iqr) & (estimates <= q75 + 1.5 * iqr)
            robust_estimates = estimates[robust_mask]
            if len(robust_estimates) > 0:
                consensus = float(np.median(robust_estimates))

        print(f"\nConsensus estimate: {consensus:.2f}")
        print(f"Used methods: {method_names}")
    else:
        consensus = nan
        print("\nNo valid estimates obtained")

    results['consensus'] = consensus
    return results

if __name__ == "__main__":
    # Script entry point: analyse each embedding separately (they may have
    # different dimensionality) and print a short summary at the end.
    print("Loading embeddings...")
    emb1_unique, emb2_unique = load_embeddings_from_dist2vec()

    # Each banner is a blank line, a rule of 60 '=' characters, the title
    # and a second rule.
    print("\n" + "=" * 60, "EMBEDDING 1 (NV-Embed)", "=" * 60, sep="\n")
    result1 = calculate_single_embedding_intrinsic_dim(emb1_unique)

    print("\n" + "=" * 60, "EMBEDDING 2 (OpenAI)", "=" * 60, sep="\n")
    result2 = calculate_single_embedding_intrinsic_dim(emb2_unique)

    print("\n" + "=" * 60, "SUMMARY", "=" * 60, sep="\n")
    print(f"Embedding 1 consensus: {result1['consensus']:.2f}")
    print(f"Embedding 2 consensus: {result2['consensus']:.2f}")

    # Kept at module level so the results can be inspected interactively.
    results = dict(emb1=result1, emb2=result2)