#!/usr/bin/env python3
"""
Script to evaluate all transformed datasets and calculate evaluation metrics.

This script reads original datasets (.fvec) and transformed datasets (.fvecs) and calculates:
- Distance correlation (when original data is available)
- Distance correlation with exp(-d) transformation (when original data is available)
- Energy compaction at 10%, 25%, and 50%
- KNN recall@10 over 10% sampled test points (when original data is available)

Outputs results to a CSV file for analysis.
"""

import os
import numpy as np
import pandas as pd
import struct
import itertools
from pathlib import Path
from typing import Tuple, Dict, List, Optional
import time

# Optional FAISS for fast KNN
try:
    import faiss  # type: ignore
    HAS_FAISS = True
except Exception:
    HAS_FAISS = False

def load_fvec(filename: str) -> np.ndarray:
    """
    Load a headered .fvec file into a float32 matrix.

    Layout read by this function: an 8-byte header holding two little-endian
    uint32 values (num_vectors, dim), followed by num_vectors * dim
    little-endian float32 values stored row by row.

    Args:
        filename: Path of the file to read.

    Returns:
        Array of shape (num_vectors, dim) with dtype float32. A header that
        announces zero vectors yields an empty (0, dim) array.

    Raises:
        ValueError: If the header is shorter than 8 bytes, if dim is 0 or
            larger than 1,000,000, or if the payload is shorter than the
            header announces.
    """
    with open(filename, 'rb') as f:
        header = f.read(8)
        if len(header) != 8:
            raise ValueError("Invalid file format: header too short")
        num_vectors, dim = struct.unpack('<II', header)
        if dim == 0 or dim > 1000000:
            raise ValueError(f"Invalid dimension: {dim} (file may be corrupted)")
        print(f"  Loading {num_vectors} vectors of dimension {dim}")

        # Validate the payload size up front so a corrupted header fails fast
        # instead of triggering a huge read/allocation.
        bytes_per_vector = dim * 4
        available = os.fstat(f.fileno()).st_size - len(header)
        if available < num_vectors * bytes_per_vector:
            complete, leftover = divmod(available, bytes_per_vector)
            raise ValueError(
                f"Incomplete vector {complete}: expected {bytes_per_vector} bytes, got {leftover}"
            )
        if num_vectors == 0:
            return np.empty((0, dim), dtype=np.float32)

        # One bulk read replaces the per-vector struct.unpack loop.
        flat = np.fromfile(f, dtype='<f4', count=num_vectors * dim)
    if flat.size != num_vectors * dim:
        # Defensive: the file shrank between the size check and the read.
        raise ValueError(
            f"Incomplete vector {flat.size // dim}: expected {bytes_per_vector} bytes, "
            f"got {(flat.size % dim) * 4}"
        )
    return flat.reshape(num_vectors, dim).astype(np.float32, copy=False)

def load_fvecs(filename: str) -> np.ndarray:
    """
    Load a .fvecs file in which every vector carries its own 4-byte dim prefix.

    Each record is a little-endian uint32 dimension followed by that many
    little-endian float32 values. Reading is best-effort: parsing stops at the
    first incomplete record (truncated dim prefix or truncated vector body)
    and everything read up to that point is returned.

    Args:
        filename: Path of the file to read.

    Returns:
        float32 array built from the complete records. With no complete
        record the result is an empty 1-D array (callers check ``size``).
    """
    vectors = []
    with open(filename, 'rb') as f:
        while True:
            dim_bytes = f.read(4)
            # A short prefix (0-3 bytes) means end of data; treating 1-3 stray
            # bytes like EOF avoids a struct.error on a truncated tail.
            if len(dim_bytes) < 4:
                break
            (dim,) = struct.unpack('<I', dim_bytes)
            vector_bytes = f.read(dim * 4)
            if len(vector_bytes) != dim * 4:
                break
            vectors.append(struct.unpack(f'<{dim}f', vector_bytes))
    return np.array(vectors, dtype=np.float32)

def load_original_test(dataset: str, original_dir: str) -> Optional[np.ndarray]:
    """
    Load the original vectors stored as ``<dataset>_query.fvec`` in original_dir.

    The headered loader (load_fvec) is attempted first. If it raises, the same
    file is re-read with the per-vector-prefix loader (load_fvecs).

    Returns:
        The loaded array, or None when the file does not exist or when both
        loaders fail (an empty parse from the fallback counts as a failure).
    """
    query_path = os.path.join(original_dir, f"{dataset}_query.fvec")
    if not os.path.exists(query_path):
        return None

    # First attempt: file with a global (num, dim) header.
    try:
        print(f"  Trying .fvec header load: {query_path}")
        return load_fvec(query_path)
    except Exception as header_error:
        print(f"  Fallback to fvecs-style parsing due to: {header_error}")

    # Second attempt: every vector prefixed with its own dimension.
    try:
        parsed = load_fvecs(query_path)
        if parsed.size == 0:
            raise ValueError("Parsed empty data from fvecs loader")
        print(f"  Loaded with fvecs-style: shape={parsed.shape}")
        return parsed
    except Exception as fallback_error:
        print(f"  Failed to load original test vectors: {fallback_error}")
        return None

def get_fvecs_shape(filename: str) -> Tuple[int, int]:
    """
    Get the shape of a .fvecs file without loading all data.

    Only the first 4-byte dimension prefix is read; the vector count is
    derived from the file size, which presumes every record has that same
    dimension (not verified here).

    Args:
        filename: Path of the .fvecs file.

    Returns:
        (n_vectors, dimensions). (0, 0) when the file is too short to hold a
        complete dimension prefix.
    """
    with open(filename, 'rb') as f:
        dim_bytes = f.read(4)
        # Fewer than 4 bytes cannot be decoded as a dimension prefix.
        if len(dim_bytes) < 4:
            return (0, 0)
        (dimensions,) = struct.unpack('<I', dim_bytes)
        file_size = os.fstat(f.fileno()).st_size
    # Each record is a 4-byte prefix plus `dimensions` float32 values.
    bytes_per_vector = 4 + (dimensions * 4)
    return (file_size // bytes_per_vector, dimensions)

def calculate_cumulative_energy(vectors: np.ndarray) -> np.ndarray:
    """
    Return the cumulative share of mean squared value per dimension.

    Entry i is the sum of the per-dimension mean squares of dimensions 0..i
    divided by the sum over all dimensions. When that total is zero, an
    all-zero array with one entry per dimension is returned instead.
    """
    per_dim_energy = np.square(vectors).mean(axis=0)
    energy_sum = per_dim_energy.sum()
    if energy_sum == 0:
        return np.zeros(vectors.shape[1])
    return per_dim_energy.cumsum() / energy_sum

def evaluate_distance_preservation(data_orig: np.ndarray, data_transformed: np.ndarray, 
                                 n_pairs_check: int = 1000) -> Dict:
    """
    Compare pairwise Euclidean distances before and after a transformation.

    Row i of data_orig is matched with row i of data_transformed; only the
    first min(rows) rows of each are used. If the dimensions differ, the
    original data is zero-padded (transformed is wider) or truncated to the
    leading columns (transformed is narrower) before distances are taken.

    Args:
        data_orig: 2-D array of original vectors.
        data_transformed: 2-D array of transformed vectors.
        n_pairs_check: Maximum number of distinct point pairs to evaluate.
            When more pairs exist, this many are sampled without replacement.

    Returns:
        Dict with 'mean_abs_diff', 'max_abs_diff' and 'corr_coef' (Pearson
        correlation of the two distance lists). All three are NaN when no
        pair with a distance above 1e-9 in either space is available.
        'corr_coef' is 1.0 for a single pair; when either distance list is
        (nearly) constant it is 1.0 if the lists agree and 0.0 otherwise,
        matching evaluate_distance_preservation_exp.
    """
    n_samples = min(data_orig.shape[0], data_transformed.shape[0])
    orig_dim = data_orig.shape[1]
    trans_dim = data_transformed.shape[1]
    if trans_dim > orig_dim:
        padding = np.zeros((data_orig.shape[0], trans_dim - orig_dim), dtype=np.float32)
        data_orig = np.concatenate((data_orig, padding), axis=1)
    elif trans_dim < orig_dim:
        data_orig = data_orig[:, :trans_dim]

    # Choose the pairs (i < j) WITHOUT materialising all n*(n-1)/2 of them:
    # a list of every combination needs gigabytes for ~10k vectors.
    total_pairs = n_samples * (n_samples - 1) // 2
    if total_pairs <= n_pairs_check:
        first, second = np.triu_indices(n_samples, k=1)
    else:
        # Seed from the legacy global RNG so np.random.seed() still controls
        # the sample, while Generator.choice avoids an O(total_pairs) buffer.
        rng = np.random.default_rng(np.random.randint(0, 2**31 - 1))
        picked = rng.choice(total_pairs, size=n_pairs_check, replace=False)
        # Decode linear indices in itertools.combinations order: row i starts
        # at offset i*n - i*(i+1)/2 and holds the pairs (i, i+1) .. (i, n-1).
        rows = np.arange(n_samples - 1, dtype=np.int64)
        row_starts = rows * n_samples - rows * (rows + 1) // 2
        first = np.searchsorted(row_starts, picked, side='right') - 1
        second = picked - row_starts[first] + first + 1

    # Vectorised distances; statistics are done in float64.
    original_distances = np.linalg.norm(
        data_orig[first] - data_orig[second], axis=1).astype(np.float64)
    transformed_distances = np.linalg.norm(
        data_transformed[first] - data_transformed[second], axis=1).astype(np.float64)

    # Drop pairs that coincide in both spaces (e.g. duplicated points).
    keep = (original_distances > 1e-9) | (transformed_distances > 1e-9)
    original_distances = original_distances[keep]
    transformed_distances = transformed_distances[keep]
    if original_distances.size == 0:
        return {'mean_abs_diff': np.nan, 'max_abs_diff': np.nan, 'corr_coef': np.nan}

    dist_diff = np.abs(original_distances - transformed_distances)
    if original_distances.size <= 1:
        corr_coef = 1.0
    elif np.std(original_distances) < 1e-12 or np.std(transformed_distances) < 1e-12:
        # np.corrcoef would return NaN for a constant input.
        corr_coef = 1.0 if float(np.mean(dist_diff)) < 1e-9 else 0.0
    else:
        corr_coef = float(np.corrcoef(original_distances, transformed_distances)[0, 1])
    return {
        'mean_abs_diff': float(np.mean(dist_diff)),
        'max_abs_diff': float(np.max(dist_diff)),
        'corr_coef': corr_coef
    }

def evaluate_distance_preservation_exp(data_orig: np.ndarray, data_transformed: np.ndarray, 
                                     n_pairs_check: int = 1000) -> Dict:
    """
    Compare exp(-d) of pairwise Euclidean distances before and after a transformation.

    Row i of data_orig is matched with row i of data_transformed; only the
    first min(rows) rows of each are used. If the dimensions differ, the
    original data is zero-padded (transformed is wider) or truncated to the
    leading columns (transformed is narrower) before distances are taken.

    Args:
        data_orig: 2-D array of original vectors.
        data_transformed: 2-D array of transformed vectors.
        n_pairs_check: Maximum number of distinct point pairs to evaluate.
            When more pairs exist, this many are sampled without replacement.

    Returns:
        Dict with 'mean_abs_diff', 'max_abs_diff' and 'corr_coef' computed on
        the exp(-distance) values. All three are NaN when no pair with a
        distance above 1e-9 in either space is available. 'corr_coef' is 1.0
        for a single pair; when either list is (nearly) constant it is 1.0 if
        the lists agree and 0.0 otherwise.
    """
    n_samples = min(data_orig.shape[0], data_transformed.shape[0])
    orig_dim = data_orig.shape[1]
    trans_dim = data_transformed.shape[1]
    if trans_dim > orig_dim:
        padding = np.zeros((data_orig.shape[0], trans_dim - orig_dim), dtype=np.float32)
        data_orig = np.concatenate((data_orig, padding), axis=1)
    elif trans_dim < orig_dim:
        data_orig = data_orig[:, :trans_dim]

    # Choose the pairs (i < j) WITHOUT materialising all n*(n-1)/2 of them:
    # a list of every combination needs gigabytes for ~10k vectors.
    total_pairs = n_samples * (n_samples - 1) // 2
    if total_pairs <= n_pairs_check:
        first, second = np.triu_indices(n_samples, k=1)
    else:
        # Seed from the legacy global RNG so np.random.seed() still controls
        # the sample, while Generator.choice avoids an O(total_pairs) buffer.
        rng = np.random.default_rng(np.random.randint(0, 2**31 - 1))
        picked = rng.choice(total_pairs, size=n_pairs_check, replace=False)
        # Decode linear indices in itertools.combinations order: row i starts
        # at offset i*n - i*(i+1)/2 and holds the pairs (i, i+1) .. (i, n-1).
        rows = np.arange(n_samples - 1, dtype=np.int64)
        row_starts = rows * n_samples - rows * (rows + 1) // 2
        first = np.searchsorted(row_starts, picked, side='right') - 1
        second = picked - row_starts[first] + first + 1

    # Distances are promoted to float64 so the 1e-12 zero-variance guard
    # below is meaningful (float32 rounding noise is far larger than that).
    dist_orig = np.linalg.norm(
        data_orig[first] - data_orig[second], axis=1).astype(np.float64)
    dist_transformed = np.linalg.norm(
        data_transformed[first] - data_transformed[second], axis=1).astype(np.float64)

    # Drop pairs that coincide in both spaces (e.g. duplicated points).
    keep = (dist_orig > 1e-9) | (dist_transformed > 1e-9)
    if not keep.any():
        return {'mean_abs_diff': np.nan, 'max_abs_diff': np.nan, 'corr_coef': np.nan}
    exp_original_distances = np.exp(-dist_orig[keep])
    exp_transformed_distances = np.exp(-dist_transformed[keep])

    dist_diff = np.abs(exp_original_distances - exp_transformed_distances)
    # Zero-variance guardrail: if either vector is (nearly) constant, avoid NaN from corrcoef
    std_orig = float(np.std(exp_original_distances))
    std_trans = float(np.std(exp_transformed_distances))
    if exp_original_distances.size <= 1:
        corr_coef = 1.0
    elif std_orig < 1e-12 or std_trans < 1e-12:
        # If they are (nearly) identical after transform, treat as perfect correlation; else 0.0
        corr_coef = 1.0 if float(np.mean(dist_diff)) < 1e-9 else 0.0
    else:
        corr_coef = float(np.corrcoef(exp_original_distances, exp_transformed_distances)[0, 1])
    return {
        'mean_abs_diff': float(np.mean(dist_diff)),
        'max_abs_diff': float(np.max(dist_diff)),
        'corr_coef': corr_coef
    }

def compute_recall_at_k(original: np.ndarray, transformed: np.ndarray, sample_fraction: float = 0.1, k: int = 10) -> float:
    """
    Compute average recall@k of kNN in the transformed space against kNN in the original space.

    Both arrays are cut to their common number of rows n. Queries are a random
    subset of the points (fixed seed 42, at least one query, at most n); the
    search set is all n points. The query itself is removed from both
    neighbour lists. FAISS (exact IndexFlatL2) is used when HAS_FAISS is true,
    otherwise a batched NumPy brute-force search.

    Args:
        original: 2-D array of original vectors.
        transformed: 2-D array of transformed vectors (row i matches row i).
        sample_fraction: Fraction of the n points used as queries.
        k: Number of neighbours per query.

    Returns:
        Mean over queries of |true_neighbours & predicted_neighbours| / k.
        When fewer than k other points exist, k is reduced to n - 1 so the
        recall still ranges over [0, 1]. NaN when n < 2 or k < 1.
    """
    n = min(original.shape[0], transformed.shape[0])
    if n < 2 or k < 1:
        # No non-self neighbour can exist / nothing is requested.
        return float('nan')
    original = original[:n]
    transformed = transformed[:n]
    num_queries = min(n, max(1, int(n * sample_fraction)))
    rng = np.random.default_rng(42)
    query_idx = rng.choice(n, size=num_queries, replace=False)

    # A point has at most n - 1 neighbours besides itself; fetch one extra
    # result so the query itself can be dropped afterwards.
    k_eff = min(k, n - 1)
    fetch = k_eff + 1

    def search(data: np.ndarray) -> np.ndarray:
        """Return, per query, the indices of its `fetch` nearest points in `data`."""
        if HAS_FAISS:
            base = np.ascontiguousarray(data, dtype=np.float32)
            index = faiss.IndexFlatL2(base.shape[1])
            index.add(base)
            _, neighbours = index.search(base[query_idx], fetch)
            return neighbours
        # Numpy fallback (compute distances in batches)
        queries = data[query_idx]
        x2 = np.sum(data * data, axis=1)[np.newaxis, :]  # loop-invariant
        batch = 128
        all_idx = []
        for start in range(0, queries.shape[0], batch):
            q = queries[start:start + batch]
            # squared distances: (q - x)^2 = q^2 + x^2 - 2 q·x
            q2 = np.sum(q * q, axis=1, keepdims=True)
            dists = q2 + x2 - 2 * np.dot(q, data.T)
            # kth = fetch - 1 puts the `fetch` smallest first and stays in
            # bounds even when fetch == n.
            idx = np.argpartition(dists, kth=fetch - 1, axis=1)[:, :fetch]
            part = np.take_along_axis(dists, idx, axis=1)
            order = np.argsort(part, axis=1)
            all_idx.append(np.take_along_axis(idx, order, axis=1))
        return np.vstack(all_idx)

    def neighbour_sets(neighbours: np.ndarray) -> list:
        """Drop the query itself (and FAISS's -1 padding) and keep k_eff ids per query."""
        sets = []
        for qi, row in zip(query_idx, neighbours):
            kept = [int(idx) for idx in row if idx != qi and idx >= 0]
            sets.append(set(kept[:k_eff]))
        return sets

    true_sets = neighbour_sets(search(original))
    pred_sets = neighbour_sets(search(transformed))

    # Compute recall per query
    recalls = [len(s_true & s_pred) / float(k_eff)
               for s_true, s_pred in zip(true_sets, pred_sets)]
    return float(np.mean(recalls)) if recalls else float('nan')

def get_dataset_info() -> List[Tuple[str, str, str]]:
    """
    List every (dataset, method, filename pattern) combination to evaluate.

    The pattern has the form ``<dataset>_test_<method>_*.fvecs``. Entries are
    ordered dataset by dataset, with all methods of one dataset adjacent.
    """
    dataset_names = (
        'yorck', 'sift100m', 'nytimes', 'glove50d', 'glove300d',
        'glove2m300', 'glove100d', 'fashionmnist', 'cifar10', 'gist1m',
    )
    # Only the Cayley transform is listed; 'Wavelet_Triplet' and
    # 'Orthogonal_Wavelet_Triplet' are currently left out.
    method_names = ('Cayley_Transform',)
    return [
        (name, method, f"{name}_test_{method}_*.fvecs")
        for name in dataset_names
        for method in method_names
    ]

def find_transformed_file(dataset: str, method: str, transformed_dir: str) -> str:
    """
    Locate the transformed .fvecs file for a dataset/method pair.

    Candidates are the entries of transformed_dir whose name starts with
    ``<dataset>_test_<method>_`` and ends with ``.fvecs``. When several match,
    the lexicographically last name is chosen.

    Returns:
        Path of the chosen file (transformed_dir joined with its name).

    Raises:
        FileNotFoundError: If no entry matches.
    """
    prefix = f"{dataset}_test_{method}_"
    candidates = [
        name for name in os.listdir(transformed_dir)
        if name.startswith(prefix) and name.endswith('.fvecs')
    ]
    if candidates:
        # max() on strings picks the same name a sort would place last.
        return os.path.join(transformed_dir, max(candidates))
    raise FileNotFoundError(f"No transformed file found for {dataset} with method {method}")

def evaluate_single_dataset(dataset: str, method: str, 
                          original_dir: str, transformed_dir: str) -> Dict:
    """
    Evaluate one dataset/method combination and collect its metrics.

    Loads the transformed vectors and, if available, the original vectors,
    trims both to a common number of rows, then fills a result dict with:
    dataset, method, original_dim, transformed_dim, n_samples,
    distance_correlation, distance_correlation_exp, knn_recall_k10 (the last
    three are NaN without original data) and energy_10pct / energy_25pct /
    energy_50pct of the transformed data.

    Returns:
        The result dict, or None when any step raises (the error is printed).
    """
    print(f"\nEvaluating {dataset} with {method}...")
    try:
        transformed_path = find_transformed_file(dataset, method, transformed_dir)
        print(f"  Loading transformed data from {transformed_path}")
        transformed = load_fvecs(transformed_path)
        original = load_original_test(dataset, original_dir)
        has_original = original is not None

        original_dim = original.shape[1] if has_original else None
        n_samples = transformed.shape[0]
        if has_original:
            # Evaluate only rows present in both arrays.
            n_samples = min(original.shape[0], n_samples)
            original = original[:n_samples]
        transformed = transformed[:n_samples]
        print(f"  Using {n_samples} samples for evaluation")

        results = {
            'dataset': dataset,
            'method': method,
            'original_dim': original_dim,
            'transformed_dim': transformed.shape[1],
            'n_samples': n_samples
        }

        if has_original:
            print("  Calculating distance preservation...")
            results['distance_correlation'] = evaluate_distance_preservation(
                original, transformed)['corr_coef']
            print("  Calculating distance preservation (exp(-d))...")
            results['distance_correlation_exp'] = evaluate_distance_preservation_exp(
                original, transformed)['corr_coef']
            print("  Calculating KNN recall@10 on 10% test samples...")
            results['knn_recall_k10'] = compute_recall_at_k(
                original, transformed, sample_fraction=0.10, k=10)
        else:
            # Metrics that need the original vectors cannot be computed.
            for missing in ('distance_correlation', 'distance_correlation_exp', 'knn_recall_k10'):
                results[missing] = np.nan

        print("  Calculating energy compaction...")
        energy = calculate_cumulative_energy(transformed)
        n_dims = len(energy)

        def energy_at(fraction: float, min_dims: int):
            # Cumulative energy held by the leading `fraction` of dimensions;
            # NaN when there are fewer than min_dims dimensions.
            if n_dims >= min_dims:
                return energy[int(fraction * n_dims) - 1]
            return np.nan

        results['energy_10pct'] = energy_at(0.10, 10)
        results['energy_25pct'] = energy_at(0.25, 4)
        results['energy_50pct'] = energy_at(0.50, 2)
        print(f"  ✓ Completed evaluation")
        return results
    except Exception as e:
        print(f"  ✗ Error evaluating {dataset} with {method}: {str(e)}")
        return None

def main(original_dir: str = "/home/ubuntu/temp_datasets/fvecs",
         transformed_dir: str = "/home/ubuntu/new_linear",
         output_file: str = "transformed_datasets_evaluation_results_nonlinear.csv") -> None:
    """
    Evaluate every dataset/method combination, write a CSV and print summaries.

    Args:
        original_dir: Directory holding the original ``<dataset>_query.fvec`` files.
        transformed_dir: Directory holding the transformed ``.fvecs`` files.
        output_file: Path of the CSV file to write.
            NOTE(review): the default name says "nonlinear" while the default
            transformed_dir is "new_linear" — confirm this is intended.

    Prints an error and returns early when either directory is missing or when
    no combination could be evaluated. The defaults reproduce the previously
    hard-coded locations, so calling ``main()`` behaves as before.
    """
    banner = "=" * 80
    print(banner)
    print("EVALUATING ALL TRANSFORMED DATASETS")
    print(banner)
    if not os.path.exists(original_dir):
        print(f"Error: Original datasets directory not found: {original_dir}")
        return
    if not os.path.exists(transformed_dir):
        print(f"Error: Transformed datasets directory not found: {transformed_dir}")
        return

    combinations = get_dataset_info()
    print(f"Found {len(combinations)} dataset-method combinations to evaluate")
    all_results = []
    for dataset, method, _pattern in combinations:
        result = evaluate_single_dataset(dataset, method, original_dir, transformed_dir)
        # Failed evaluations return None and are simply left out.
        if result is not None:
            all_results.append(result)
    successful_evaluations = len(all_results)
    print("\n" + banner)
    print("EVALUATION COMPLETED")
    print(f"Successfully evaluated: {successful_evaluations}/{len(combinations)} combinations")
    print(banner)
    if not all_results:
        print("No successful evaluations. Exiting.")
        return

    metrics = ['distance_correlation', 'distance_correlation_exp', 'knn_recall_k10',
               'energy_10pct', 'energy_25pct', 'energy_50pct']
    df = pd.DataFrame(all_results)
    column_order = ['dataset', 'method'] + metrics + ['original_dim', 'transformed_dim', 'n_samples']
    df = df[column_order]
    df.to_csv(output_file, index=False)
    print(f"\nResults saved to: {output_file}")

    print("\n" + banner)
    print("EVALUATION SUMMARY")
    print(banner)
    # Same mean/std aggregation for both groupings.
    print("\nResults by Method:")
    method_summary = df.groupby('method').agg(
        {metric: ['mean', 'std'] for metric in metrics}).round(4)
    print(method_summary)
    print("\nResults by Dataset:")
    dataset_summary = df.groupby('dataset').agg(
        {metric: ['mean', 'std'] for metric in metrics}).round(4)
    print(dataset_summary)

    print("\nBest Performing Methods:")
    for metric in metrics:
        valid_data = df[df[metric].notna()]
        if len(valid_data) > 0:
            best_row = df.loc[valid_data[metric].idxmax()]
            print(f"  {metric}: {best_row['method']} on {best_row['dataset']} = {best_row[metric]:.4f}")
        else:
            print(f"  {metric}: No valid data available")
    print(f"\nDetailed results available in: {output_file}")

    print("\n" + banner)
    print("ENERGY COMPACTION ANALYSIS")
    print(banner)
    energy_metrics = ['energy_10pct', 'energy_25pct', 'energy_50pct']
    for metric in energy_metrics:
        valid_data = df[df[metric].notna()]
        if len(valid_data) > 0:
            print(f"\n{metric} comparison:")
            method_comparison = valid_data.groupby('method')[metric].agg(['mean', 'std']).round(4)
            print(method_comparison)
            best_method = valid_data.loc[valid_data[metric].idxmax()]
            print(f"  Best: {best_method['method']} on {best_method['dataset']} = {best_method[metric]:.4f}")

# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
