#!/usr/bin/env python3
"""
Script to analyze datasets and compute statistics on fvecs files.
Collects train/test pairs, computes L2 norms, distances, and statistical measures.
Generates histograms of distance distributions for each dataset.
"""

import os
import numpy as np
import pandas as pd
from scipy import stats
import struct
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

# Module-level plotting defaults, applied on import: start from Matplotlib's
# 'default' style, then select seaborn's "husl" color palette.
# NOTE(review): keep this order -- plt.style.use() presumably resets rcParams,
# so calling it after set_palette would likely discard the palette.
plt.style.use('default')
sns.set_palette("husl")

def count_vectors_in_fvecs(filename):
    """
    Count the complete vectors in an fvecs file without reading their data.

    Each fvecs record is a little-endian int32 dimension ``d`` followed by
    ``d`` 4-byte floats. Only the 4-byte headers are read; each data section
    is skipped with ``seek``. Counting stops at the first malformed record --
    a truncated header, a negative dimension, or a data section that runs
    past the end of the file -- so a damaged tail is ignored rather than
    raising.

    Args:
        filename: Path to the fvecs file.

    Returns:
        Number of complete records before the first malformed one.
    """
    file_size = os.path.getsize(filename)
    count = 0
    with open(filename, 'rb') as f:
        while True:
            header = f.read(4)
            # A short read (0-3 bytes) means end of file or a truncated header;
            # struct.unpack would raise on fewer than 4 bytes.
            if len(header) < 4:
                break
            dim = struct.unpack('<i', header)[0]
            if dim < 0:
                break
            record_end = f.tell() + dim * 4
            # The vector data would extend past EOF: incomplete record.
            if record_end > file_size:
                break
            f.seek(record_end)
            count += 1

    return count

def read_fvecs(filename, max_vectors=None):
    """
    Read a prefix of an fvecs file into a 2-D float32 NumPy array.

    fvecs format: each record is a little-endian int32 dimension ``d``
    followed by ``d`` little-endian float32 components.

    Only the first 20% of the vectors are read, capped at 5,000,000 and, if
    given, at ``max_vectors``. A non-empty file always yields at least one
    vector (unless ``max_vectors`` <= 0), so tiny files are not silently
    reduced to nothing.

    Args:
        filename: Path to the fvecs file.
        max_vectors: Optional additional upper bound on vectors to read.

    Returns:
        ``np.float32`` array of shape ``(n, d)``. When nothing is read, an
        empty 1-D float32 array is returned.

    Raises:
        ValueError: if the records read do not all share the same dimension
            (a ragged result cannot be represented as a 2-D array).
    """
    # First, count total vectors to determine how many to read
    print(f"    Counting vectors in {os.path.basename(filename)}...")
    total_vectors = count_vectors_in_fvecs(filename)
    print(f"    Total vectors in file: {total_vectors}")

    max_limit = 5_000_000
    vectors_to_read = min(int(total_vectors * 0.2), max_limit, total_vectors)
    if total_vectors > 0:
        # int(0.2 * n) is 0 for n < 5; still read something from tiny files.
        vectors_to_read = max(vectors_to_read, 1)
    if max_vectors is not None:
        vectors_to_read = min(vectors_to_read, max_vectors)
    vectors_to_read = max(vectors_to_read, 0)

    # Guard against ZeroDivisionError on an empty file.
    percent = vectors_to_read / total_vectors * 100 if total_vectors else 0.0
    print(f"    Reading first {vectors_to_read} vectors ({percent:.1f}% of total)")

    if vectors_to_read == 0:
        return np.array([], dtype=np.float32)

    # The first header fixes the record stride; every record is then one
    # int32 header word plus ``dim`` float32 words, all 4 bytes wide.
    with open(filename, 'rb') as f:
        dim = struct.unpack('<i', f.read(4))[0]
    record_len = dim + 1

    # One bulk read instead of a struct.unpack per vector.
    raw = np.fromfile(filename, dtype='<f4', count=vectors_to_read * record_len)
    n_read = raw.size // record_len  # drop any trailing partial record
    records = raw[:n_read * record_len].reshape(n_read, record_len)

    # Reinterpret the header column as int32 and make sure the fixed stride
    # assumption holds for every record read.
    dims = records.view('<i4')[:, 0]
    mismatched = dims != dim
    if np.any(mismatched):
        bad = int(np.argmax(mismatched))
        raise ValueError(
            f"{filename}: record {bad} has dimension {int(dims[bad])}, "
            f"expected {dim}; variable-dimension fvecs files are not supported"
        )

    # Drop the header column; astype makes a contiguous native float32 copy.
    return records[:, 1:].astype(np.float32)

def find_train_test_pairs(datasets_dir):
    """
    Find all base/query (train/test) fvecs pairs under ``datasets_dir``.

    A file is a base (train) set when its name contains ``'base'`` and ends in
    ``.fvec`` or ``.fvecs``. Its query (test) counterpart is looked up in the
    same directory by substituting ``'query'`` for ``'base'``. All occurrences
    are substituted first (``sift_base.fvecs`` -> ``sift_query.fvecs``), then
    only the last occurrence (so ``database_base.fvecs`` can match
    ``database_query.fvecs``).

    Directories and files are visited in sorted order. The returned list -- and
    therefore the order in which datasets consume NumPy's seeded global RNG --
    is then the same on every file system.

    Args:
        datasets_dir: Root directory to search (str or path-like).

    Returns:
        List of ``(dataset_name, train_path, test_path)`` tuples. The dataset
        name is the part of the base filename before the first underscore;
        when that is empty, the containing directory's name is used.
    """
    train_test_pairs = []

    for root, dirs, files in os.walk(datasets_dir):
        # Sorting ``dirs`` in place controls the order os.walk descends.
        dirs.sort()
        for f in sorted(files):
            if 'base' not in f or not f.endswith(('.fvec', '.fvecs')):
                continue
            train_path = os.path.join(root, f)

            # Candidate query filenames, most specific convention first.
            test_patterns = [f.replace('base', 'query')]
            head, _, tail = f.rpartition('base')
            last_only = head + 'query' + tail
            if last_only not in test_patterns:
                test_patterns.append(last_only)

            test_path = None
            for pattern in test_patterns:
                candidate = os.path.join(root, pattern)
                if os.path.exists(candidate):
                    test_path = candidate
                    break

            if test_path is None:
                print(f"Warning: Train file {train_path} found but no corresponding test file")
                continue

            # Prefix before the first underscore; fall back to the directory
            # name for files like "_base.fvecs".
            dataset_name = f.split('_')[0] or os.path.basename(os.path.normpath(root))

            train_test_pairs.append((dataset_name, train_path, test_path))
            print(f"Found train/test pair for {dataset_name}:")
            print(f"  Train: {train_path}")
            print(f"  Test: {test_path}")

    print(f"\nFound {len(train_test_pairs)} train/test pairs")
    return train_test_pairs

def plot_distance_histogram(distances, dataset_name, output_dir="plots"):
    """
    Plot a distance histogram next to a normal Q-Q plot and save it as a PNG.

    The left panel is a 50-bin density histogram with the mean and median
    marked. The right panel is a Q-Q plot against the normal distribution
    (``scipy.stats.probplot``). A figure title summarizes mean, std, min, max
    and sample size.

    Args:
        distances: Array-like of distances.
        dataset_name: Used in the titles and in the output filename
            ``<dataset_name>_distance_distribution.png``.
        output_dir: Directory to write into; created if missing.

    Returns:
        Path of the saved PNG file.

    Raises:
        ValueError: if ``distances`` is empty.
    """
    distances = np.asarray(distances)
    if distances.size == 0:
        raise ValueError(f"No distances to plot for dataset {dataset_name}")

    os.makedirs(output_dir, exist_ok=True)

    mean_dist = np.mean(distances)
    std_dist = np.std(distances)
    median_dist = np.median(distances)
    min_dist = np.min(distances)
    max_dist = np.max(distances)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    try:
        # Panel 1: density histogram with mean/median markers.
        ax1.hist(distances, bins=50, density=True, alpha=0.7, color='skyblue', edgecolor='black')
        ax1.set_xlabel('L2 Distance')
        ax1.set_ylabel('Density')
        ax1.set_title(f'{dataset_name} - Distance Distribution')
        ax1.grid(True, alpha=0.3)
        ax1.axvline(mean_dist, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_dist:.2f}')
        ax1.axvline(median_dist, color='green', linestyle='--', linewidth=2, label=f'Median: {median_dist:.2f}')
        ax1.legend()

        # Panel 2: Q-Q plot for a visual normality check.
        stats.probplot(distances, dist="norm", plot=ax2)
        ax2.set_title(f'{dataset_name} - Q-Q Plot (Normal Distribution)')
        ax2.grid(True, alpha=0.3)

        fig.suptitle(f'Dataset: {dataset_name}\n'
                     f'Mean: {mean_dist:.3f}, Std: {std_dist:.3f}, '
                     f'Min: {min_dist:.3f}, Max: {max_dist:.3f}\n'
                     f'Sample size: {len(distances):,} distances',
                     fontsize=12, y=0.98)

        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        filepath = os.path.join(output_dir, f"{dataset_name}_distance_distribution.png")
        fig.savefig(filepath, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)

    print(f"    Saved histogram: {filepath}")
    return filepath

def _sample_queries(query_vectors, n_query_samples):
    """Return at most ``n_query_samples`` rows of ``query_vectors``.

    Subsampling draws row indices without replacement from NumPy's global
    random state, so it follows ``np.random.seed``. When there are no more
    rows than requested, the input is returned unchanged.
    """
    if len(query_vectors) > n_query_samples:
        indices = np.random.choice(len(query_vectors), n_query_samples, replace=False)
        return query_vectors[indices]
    return query_vectors


def _distances_to_base(base_vectors, query, out, chunk_size=262_144):
    """Write the L2 distance from ``query`` to each row of ``base_vectors`` into ``out``.

    Rows are processed ``chunk_size`` at a time, so the ``base - query``
    temporary holds at most ``chunk_size x dim`` elements instead of the whole
    base set. Each row's norm is computed independently, so chunking does not
    change the values. Returns ``out``.
    """
    for start in range(0, len(base_vectors), chunk_size):
        stop = start + chunk_size
        out[start:stop] = np.linalg.norm(base_vectors[start:stop] - query, axis=1)
    return out


def _beta_ks_p_value(distances, eps=1e-9):
    """Return the KS p-value of min-max scaled ``distances`` against a fitted Beta.

    The distances are scaled to [0, 1] and clipped into [eps, 1 - eps],
    because the Beta fit with loc=0 and scale=1 fixed needs data strictly
    inside (0, 1). The Beta's shape parameters are then fitted with
    ``scipy.stats.beta.fit``. Returns NaN when all distances are equal (the
    scaling would divide by zero). Raises ValueError for empty input (from
    ``np.min``).
    """
    lo = float(np.min(distances))
    hi = float(np.max(distances))
    if hi == lo:
        return float('nan')
    # Scale in float64. In float32, 1 - 1e-9 rounds to exactly 1.0, so the
    # clip would leave the maximum at 1.0, and beta.fit with floc/fscale fixed
    # rejects data that is not strictly inside (0, 1).
    scaled = (np.asarray(distances, dtype=np.float64) - lo) / (hi - lo)
    scaled = np.clip(scaled, eps, 1 - eps)
    a, b, loc, scale = stats.beta.fit(scaled, floc=0, fscale=1)
    # NOTE(review): the Beta parameters are estimated from the same sample
    # being tested, so the standard KS p-value is not exact (it tends to be
    # conservative). With millions of distances, almost any deviation will
    # also be significant.
    _, p_value = stats.kstest(scaled, 'beta', args=(a, b, loc, scale))
    return p_value


def compute_dataset_statistics(dataset_name, train_path, test_path, n_query_samples=100):
    """
    Compute distance statistics for one base/query dataset pair.

    Reads a prefix of both files with ``read_fvecs`` and samples up to
    ``n_query_samples`` query vectors using NumPy's global RNG. It then
    computes the L2 distance from each sampled query to every base vector and
    summarizes the pooled distances. Goodness of fit is measured with a
    Kolmogorov-Smirnov test against a Beta distribution fitted to the min-max
    scaled distances; the p-value is stored under the key
    ``normality_test_p_value``.

    Args:
        dataset_name: Label used in log output and in the result dict.
        train_path: Path of the base (train) fvecs file.
        test_path: Path of the query (test) fvecs file.
        n_query_samples: Maximum number of query vectors to use.

    Returns:
        Dict with ``dataset_name``, ``max_base_l2_norm``,
        ``normality_test_p_value`` (NaN when all distances are equal),
        ``avg_query_variance`` and the distance summary fields. Returns
        ``None`` if any step raised; the traceback is printed.
    """
    print(f"\nProcessing dataset: {dataset_name}")

    try:
        print("  Reading train (base) file...")
        base_vectors = read_fvecs(train_path)
        print(f"  Base vectors shape: {base_vectors.shape}")

        print("  Reading test (query) file...")
        query_vectors = read_fvecs(test_path)
        print(f"  Query vectors shape: {query_vectors.shape}")

        # Maximum L2 norm over the base set.
        max_base_norm = np.max(np.linalg.norm(base_vectors, axis=1))
        print(f"  Maximum L2 norm in base set: {max_base_norm:.6f}")

        sampled_queries = _sample_queries(query_vectors, n_query_samples)
        print(f"  Using {len(sampled_queries)} query points")

        # One row of distances per sampled query, in a single preallocated
        # array. Extending a Python list with every distance would cost tens
        # of bytes per value and does not scale to millions of base vectors.
        dist_dtype = np.result_type(base_vectors.dtype, sampled_queries.dtype, np.float32)
        distance_matrix = np.empty((len(sampled_queries), len(base_vectors)), dtype=dist_dtype)
        query_variances = []

        print("  Computing distances...")
        for i, query in enumerate(sampled_queries):
            row = _distances_to_base(base_vectors, query, distance_matrix[i])
            # Variance of this query's distances to the base set.
            query_variances.append(np.var(row))

            if (i + 1) % 10 == 0:
                print(f"    Processed {i + 1}/{len(sampled_queries)} queries")

        all_distances = distance_matrix.ravel()

        # # Plot histogram of distances
        # print("  Creating histogram...")
        # plot_path = plot_distance_histogram(all_distances, dataset_name)

        # Goodness of fit: KS test against a Beta fitted to the scaled
        # distances (a Beta fit, not a normality test, despite the result key).
        print("  Running normality test...")
        p_value = _beta_ks_p_value(all_distances)
        test_name = "Kolmogorov-Smirnov (Beta fit)"
        if np.isnan(p_value):
            print("  All distances are equal; Beta fit skipped (p-value is NaN)")
        print(f"  {test_name} test p-value: {p_value:.2e}")

        avg_query_variance = np.mean(query_variances)
        print(f"  Average query variance: {avg_query_variance:.6f}")

        distance_stats = {
            'mean_distance': np.mean(all_distances),
            'std_distance': np.std(all_distances),
            'median_distance': np.median(all_distances),
            'min_distance': np.min(all_distances),
            'max_distance': np.max(all_distances),
            'total_distances': len(all_distances)
        }

        return {
            'dataset_name': dataset_name,
            'max_base_l2_norm': max_base_norm,
            'normality_test_p_value': p_value,
            'avg_query_variance': avg_query_variance,
            # 'histogram_path': plot_path,
            **distance_stats
        }

    except Exception as e:
        # Per-dataset boundary: report the failure and let main() continue
        # with the remaining datasets.
        print(f"  Error processing {dataset_name}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

def main(datasets_dir="/home/ubuntu/datasets", output_file="dataset_analysis_results.csv"):
    """
    Analyze every base/query pair under ``datasets_dir`` and write a CSV summary.

    Seeds NumPy's global RNG with 42 so query sampling is reproducible, finds
    the dataset pairs, and computes per-dataset statistics. It then saves the
    successful results to ``output_file`` and prints a summary table.

    Args:
        datasets_dir: Root directory searched for fvecs pairs.
        output_file: Path of the CSV written when at least one dataset succeeds.
    """
    # Set random seed for reproducibility of the query sampling.
    np.random.seed(42)

    if not os.path.exists(datasets_dir):
        print(f"Error: Directory {datasets_dir} not found")
        return

    train_test_pairs = find_train_test_pairs(datasets_dir)
    if not train_test_pairs:
        print("No train/test pairs found!")
        return

    results = []
    for dataset_name, train_path, test_path in train_test_pairs:
        result = compute_dataset_statistics(dataset_name, train_path, test_path)
        # compute_dataset_statistics returns None on failure.
        if result is not None:
            results.append(result)

    if not results:
        print("No results to save!")
        return

    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)
    print(f"\nResults saved to {output_file}")
    print("\nSummary:")
    # Display subset of columns for readability
    display_cols = ['dataset_name', 'max_base_l2_norm', 'normality_test_p_value',
                    'avg_query_variance', 'mean_distance', 'std_distance']
    print(df[display_cols].to_string(index=False, float_format='%.6f'))

    # Histogram paths are only present when plotting is enabled; don't
    # announce saved histograms when none were produced.
    histogram_paths = [r['histogram_path'] for r in results if 'histogram_path' in r]
    if histogram_paths:
        print("\nHistograms saved:")
        for path in histogram_paths:
            print(f"  - {path}")


if __name__ == "__main__":
    main()
