#!/usr/bin/env python3
"""
Simple benchmark runner for a specific dataset.
This script:
1. Takes a dataset name as input
2. Finds all transformed and original files for that dataset
3. Runs benchmarks with the same metrics as run_all_benchmarks.py
4. Saves results to CSV
"""

import os
import subprocess
import re
import pandas as pd
import argparse
from datetime import datetime

# Index types to benchmark, in the order they are run. Any name containing
# 'pano' is treated as a Panorama variant by this script: it is additionally
# given the transformed train/test files and is run at more than one nlevels.
INDEX_TYPES = ['ivfpq', 'ivfpq_pano']

def get_fvec_dimension_and_count(fvec_path):
    """Return ``(dimension, vector_count)`` for the fvec file at *fvec_path*.

    The dimension is decoded from the first 4 bytes of the file as a
    little-endian unsigned integer. The vector count is derived from the file
    size, assuming every record occupies ``4 + 4 * dimension`` bytes.

    Returns ``(0, 0)`` if the header is shorter than 4 bytes or if reading
    the file fails for any reason (the error is printed).
    """
    header_len = 4
    try:
        with open(fvec_path, 'rb') as handle:
            header = handle.read(header_len)
            if len(header) != header_len:
                return 0, 0
            # seek() returns the new absolute offset; seeking to the end
            # (whence=2) therefore yields the file size.
            total_bytes = handle.seek(0, 2)
    except Exception as e:
        print(f"Error reading fvec info from {fvec_path}: {e}")
        return 0, 0

    dimension = int.from_bytes(header, byteorder='little')
    # One record = 4-byte dimension prefix + `dimension` 4-byte floats.
    record_bytes = 4 + dimension * 4
    return dimension, total_bytes // record_bytes

def _find_last_match(filenames, prefix, suffixes):
    """Return the lexicographically last name with *prefix* and one of *suffixes*, or None."""
    matches = [name for name in filenames
               if name.startswith(prefix) and name.endswith(suffixes)]
    return max(matches) if matches else None


def find_dataset_files(dataset_name, transformed_dir, original_dir):
    """
    Find all files for a specific dataset.

    Transformed files are matched as ``<dataset>_train_Cayley*`` and
    ``<dataset>_test_Cayley*`` with a ``.fvec`` or ``.fvecs`` extension;
    original files as ``<dataset>_base*`` and ``<dataset>_query*`` with a
    ``.fvec`` extension only.

    When several files match one pattern, the lexicographically last name is
    chosen. ``os.listdir`` returns entries in arbitrary order, so picking by
    name keeps the selection reproducible across runs and filesystems.

    Args:
        dataset_name: Name of the dataset (e.g., "sift")
        transformed_dir: Directory containing transformed files
        original_dir: Directory containing original files

    Returns:
        Dict with train_path, test_path, original_base_path,
        original_query_path, or None if a directory cannot be listed or any
        of the four files is missing (a message is printed in either case).
    """
    try:
        transformed_names = os.listdir(transformed_dir)
    except OSError as e:
        print(f"Error accessing transformed directory {transformed_dir}: {e}")
        return None

    try:
        original_names = os.listdir(original_dir)
    except OSError as e:
        print(f"Error accessing original directory {original_dir}: {e}")
        return None

    transformed_suffixes = ('.fvec', '.fvecs')
    # NOTE(review): originals are matched on '.fvec' only, unlike the
    # transformed files above - confirm whether '.fvecs' should be accepted.
    original_suffixes = ('.fvec',)

    # (result key, directory, directory listing, filename prefix, suffixes)
    lookups = (
        ('train_path', transformed_dir, transformed_names,
         f"{dataset_name}_train_Cayley", transformed_suffixes),
        ('test_path', transformed_dir, transformed_names,
         f"{dataset_name}_test_Cayley", transformed_suffixes),
        ('original_base_path', original_dir, original_names,
         f"{dataset_name}_base", original_suffixes),
        ('original_query_path', original_dir, original_names,
         f"{dataset_name}_query", original_suffixes),
    )

    files = {}
    for key, directory, names, prefix, suffixes in lookups:
        match = _find_last_match(names, prefix, suffixes)
        files[key] = os.path.join(directory, match) if match is not None else None

    # Check if all required files were found
    missing_files = [key for key, value in files.items() if value is None]
    if missing_files:
        print(f"Missing files for dataset '{dataset_name}': {missing_files}")
        return None

    return files

def run_benchmark(train_path, test_path, original_base_path, original_query_path, nb, nq, index_type, nlevels, M, epsilon, efSearch, work_dir):
    """Run a single benchmark test with appropriate file paths for each index type.

    Invokes ``faiss/perf_tests/bench_ivf.py`` (relative to *work_dir*) as a
    subprocess and parses its stdout with ``parse_benchmark_output``.

    Args:
        train_path: Transformed base vectors; passed to the command only for
            index types containing 'pano', but always recorded as the
            result's ``file_path``.
        test_path: Transformed query vectors; passed only for 'pano' types.
        original_base_path: Original base vectors (``--csv-path``).
        original_query_path: Original query vectors (``--csv-test-path``).
        nb: Value for ``--nb``.
        nq: Value for ``--nq``.
        index_type: Value for ``--index-type``.
        nlevels: Value for ``--nlevels``.
        M: Value for ``--M``; also stored in the returned result.
        epsilon: Value for ``--epsilon``.
        efSearch: Recorded in the result only; it is not passed to the command.
        work_dir: Working directory for the subprocess and root of the
            locally built FAISS/Annoy paths added to PYTHONPATH.

    Returns:
        The parsed result dict with the actual ``M`` and the ``command``
        string added, or None if the command fails, cannot be started, or
        its output cannot be parsed.
    """
    print(f"Running benchmark for {index_type} with nlevels={nlevels}, M={M}, epsilon={epsilon}")

    # For ivfpq_pano: use both original and transformed files
    # For ivfpq: use only original files (no trans paths)
    uses_transformed = 'pano' in index_type

    print(f"  Original Base:     {original_base_path}")
    print(f"  Original Query:    {original_query_path}")
    if uses_transformed:
        print(f"  Transformed Train: {train_path}")
        print(f"  Transformed Test:  {test_path}")
    else:
        print(f"  (No transformed files for non-pano index)")

    cmd = [
        'python3', 'faiss/perf_tests/bench_ivf.py',
        '--nb', str(nb),
        '--nq', str(nq),
        '--nprobe', '3',
        '--nlist', '10',
        '--num-threads', '1',
        '--nlevels', str(nlevels),
        '--M', str(M),
        '--epsilon', str(epsilon),
        '--csv-path', original_base_path,
        '--csv-test-path', original_query_path,
    ]
    if uses_transformed:
        cmd += [
            '--trans-csv-path', train_path,
            '--trans-csv-test-path', test_path,
        ]
    cmd += ['--index-type', index_type]

    # Broad catch on purpose: one failing run (bad path, missing interpreter,
    # undecodable output, ...) must not abort the whole benchmark sweep.
    try:
        # Command string for the console and for logging in the result row
        cmd_string = " ".join(cmd)

        print("---------------------------------------")
        print("Command to run manually:")
        print(cmd_string)
        print("---------------------------------------")

        # Set PYTHONPATH to use locally built FAISS and Annoy
        env = os.environ.copy()
        local_faiss_path = os.path.join(work_dir, 'faiss', 'build', 'faiss', 'python', 'build', 'lib')
        local_annoy_path = os.path.join(work_dir, 'annoy', 'build', 'lib.linux-x86_64-cpython-312')
        path_entries = [local_annoy_path, local_faiss_path]
        inherited_pythonpath = env.get('PYTHONPATH')
        # Only append the inherited value when it is set, so PYTHONPATH never
        # ends with a dangling separator (an empty path entry).
        if inherited_pythonpath:
            path_entries.append(inherited_pythonpath)
        env['PYTHONPATH'] = os.pathsep.join(path_entries)

        result = subprocess.run(cmd, capture_output=True, text=True, cwd=work_dir, env=env)

        if result.returncode != 0:
            print(f"Benchmark failed for {index_type} with nlevels={nlevels}: {result.stderr}")
            return None

        parsed_result = parse_benchmark_output(result.stdout, index_type, train_path, nb, nq, nlevels, epsilon, efSearch)
        if parsed_result:
            # Record the M that was actually benchmarked; the parser's own
            # value is a fixed default, not the one passed via --M.
            parsed_result['M'] = M
            # Add command to the results for logging
            parsed_result['command'] = cmd_string
        return parsed_result

    except Exception as e:
        print(f"Exception during benchmark for {index_type} with nlevels={nlevels}: {e}")
        return None

def parse_benchmark_output(output, index_type, file_path, nb, nq, nlevels, epsilon, efSearch, M=32):
    """Parse the benchmark output to extract key metrics.

    Args:
        output: Captured stdout text of the benchmark command.
        index_type: Index type; level metrics are only parsed when it
            contains 'pano'.
        file_path, nb, nq, nlevels, epsilon, efSearch: Copied verbatim into
            the result for later analysis.
        M: Value stored under the ``'M'`` key. Defaults to 32, the value that
            was previously hard-coded, so existing callers are unaffected.

    Returns:
        Dict with the run parameters plus ``wall_time_ms``, ``qps``,
        ``queries``, ``recall`` and ``avg_level_percent``. A metric whose
        line is absent from *output* stays None. Returns None (after printing
        the error) if *output* is not text or a matched number is malformed.
    """
    results = {
        'index_type': index_type,
        'file_path': file_path,
        'nb': nb,
        'nq': nq,
        'nlevels': nlevels,
        'epsilon': epsilon,
        'wall_time_ms': None,
        'qps': None,
        'queries': None,
        'recall': None,
        'avg_level_percent': None,
        'M': M,
        'efSearch': efSearch,
    }

    # (result key, pattern, converter) for metrics every index type reports
    metric_patterns = (
        ('wall_time_ms', r'Wall time: ([\d.]+) ms', float),
        ('qps', r'QPS: ([\d.]+)', float),
        ('queries', r'Queries: (\d+)', int),
        ('recall', r'Recall: ([\d.]+)', float),
    )

    # (pattern, scale) candidates for avg_level_percent on pano indexes.
    # Applied in order, so a later match overrides an earlier one; the last
    # two are multiplied by 100 before being stored.
    level_patterns = (
        (r'avg_level %: ([\d.]+)%', 1),
        (r'total_active: ([\d.]+)', 100),
        (r'Avg level active: ([\d.]+)', 100),
    )

    try:
        for key, pattern, convert in metric_patterns:
            match = re.search(pattern, output)
            if match:
                results[key] = convert(match.group(1))

        # Parse average level metrics for pano indexes
        if 'pano' in index_type:
            for pattern, scale in level_patterns:
                match = re.search(pattern, output)
                if match:
                    results['avg_level_percent'] = float(match.group(1)) * scale

    except (TypeError, ValueError) as e:
        # TypeError: output is not a string; ValueError: '[\d.]+' matched
        # something that is not a valid number (e.g. "1.2.3").
        print(f"Error parsing benchmark output for {index_type} with nlevels={nlevels}: {e}")
        return None

    return results

def write_result_incremental(results_csv: str, result_row: dict) -> None:
    """Append a single result row to CSV, writing the header when needed.

    The header is written when *results_csv* does not exist yet or exists but
    is empty (e.g. a pre-created file), so the CSV always starts with column
    names. Failures are reported as a printed warning and never raised, so a
    write problem does not abort the benchmark run.

    Args:
        results_csv: Path of the CSV file to append to.
        result_row: Mapping of column name to value for the new row.
    """
    try:
        needs_header = (
            not os.path.exists(results_csv)
            or os.path.getsize(results_csv) == 0
        )
        row_df = pd.DataFrame([result_row])
        row_df.to_csv(results_csv, mode='a', header=needs_header, index=False)
    except Exception as exc:
        print(f"Warning: failed to append results to {results_csv}: {exc}")

def get_valid_m_values(dimension):
    """Return, in ascending order, every multiple of 8 that divides *dimension* evenly.

    Candidates run from 8 up to and including *dimension*; the list is empty
    when no such value exists (for example when *dimension* is below 8).
    """
    step = 8
    candidates = range(step, dimension + 1, step)
    return [candidate for candidate in candidates if dimension % candidate == 0]

def _format_metric(value, spec):
    """Format a numeric metric with *spec*, or return 'N/A' when it is None/NaN."""
    if pd.isna(value):
        return 'N/A'
    return format(value, spec)


def main():
    """Sweep the M parameter over every index type and record QPS vs Recall.

    Steps: parse the command line, locate the dataset files, read dimension
    and vector count from the transformed train file, derive the valid M
    values, then run one benchmark per (index type, nlevels, M) combination.
    Each successful result is appended to the results CSV immediately and a
    summary table is printed at the end. Returns early (after printing the
    reason) if files are missing, the fvec header cannot be read, or no valid
    M value exists.
    """
    # Generate timestamp for default filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    parser = argparse.ArgumentParser(description="Generate QPS vs Recall data by varying M parameter for IVFPQ")
    parser.add_argument("--dataset-name", required=True,
                       help="Name of the dataset to benchmark (e.g., 'sift', 'glove')")
    parser.add_argument("--transformed-dir", default="/mnt/device/linear",
                       help="Directory containing transformed fvec files")
    parser.add_argument("--original-dir", default="/mnt/device/datasets",
                       help="Directory containing original fvec files")
    parser.add_argument("--output-dir", default="/mnt/device/transformed/temp",
                       help="Directory for output files")
    parser.add_argument("--results-csv", default=f"qps_recall_data_{timestamp}.csv",
                       help="Output CSV file for QPS vs Recall data")
    parser.add_argument("--nq", type=int, default=1000,
                       help="Number of query points for benchmarks")
    parser.add_argument("--work-dir", default="/home/name/panorama",
                       help="Working directory for running benchmark commands")

    args = parser.parse_args()

    print(f"Generating QPS vs Recall data for dataset: {args.dataset_name}")
    print(f"Results will be saved to: {args.results_csv}")

    # Find all files for the specified dataset
    dataset_files = find_dataset_files(args.dataset_name, args.transformed_dir, args.original_dir)
    if dataset_files is None:
        print(f"Could not find all required files for dataset '{args.dataset_name}'")
        return

    print("\nFound dataset files:")
    for key, path in dataset_files.items():
        print(f"  {key}: {path}")

    # Get dimension and vector count
    dimension, nb = get_fvec_dimension_and_count(dataset_files['train_path'])
    if dimension == 0 or nb == 0:
        print(f"Failed to read dimension/count from {dataset_files['train_path']}")
        return

    print(f"\nDataset info: {nb} vectors with dimension {dimension}")

    # Get valid M values that are factors of dimension AND multiples of 8
    valid_m_values = get_valid_m_values(dimension)
    print(f"Valid M values for dimension {dimension} (multiples of 8): {valid_m_values}")

    if not valid_m_values:
        print(f"No valid M values found for dimension {dimension}")
        return

    # Create output directory if it doesn't exist
    # NOTE(review): the results CSV path is used as given and is not placed
    # under output_dir - confirm whether that is intended.
    os.makedirs(args.output_dir, exist_ok=True)

    all_results = []

    # Fixed parameters for QPS vs Recall study
    epsilon = 1.0  # Fixed
    efSearch = 1  # Fixed

    # Run benchmarks
    print(f"\n{'='*80}")
    print("Generating QPS vs Recall data by varying M parameter")
    # nlist/nprobe mirror the values hard-coded in the benchmark command.
    print(f"Fixed parameters: epsilon={epsilon}, nlist=10, nprobe=3")
    print("For ivfpq_pano: testing both 1 level and 8 levels")
    print("For ivfpq: testing 8 levels (fixed)")
    print(f"{'='*80}")

    for index_type in INDEX_TYPES:
        print(f"\n--- Testing {index_type} ---")

        # Determine which nlevels to test
        if 'pano' in index_type:
            nlevels_list = [1, 8]  # Test both 1 and 8 levels for panorama
        else:
            nlevels_list = [8]     # Only 8 levels for regular ivfpq

        for nlevels in nlevels_list:
            print(f"\n--- {index_type} with {nlevels} level(s) ---")

            for M in valid_m_values:
                print(f"\nBenchmarking {index_type} with nlevels={nlevels}, M={M} (bits per sub-vector: {8 if M <= 64 else 'N/A'})")

                result = run_benchmark(
                    dataset_files['train_path'],
                    dataset_files['test_path'],
                    dataset_files['original_base_path'],
                    dataset_files['original_query_path'],
                    nb,
                    args.nq,
                    index_type,
                    nlevels,
                    M,
                    epsilon,
                    efSearch,
                    args.work_dir
                )

                if not result:
                    print(f"✗ Failed - nlevels={nlevels}, M={M}")
                    continue

                # Make sure the row carries the M that was benchmarked; this
                # sweep is meaningless if every row reports the same M.
                result['M'] = M

                # Add compression info for analysis
                result['compression_ratio'] = dimension / M  # Higher M = lower ratio (less compression)
                result['bits_per_vector'] = M * 8 if M <= 64 else M * 16  # Estimate

                all_results.append(result)
                write_result_incremental(args.results_csv, result)
                # Any metric may be None when its line was absent from the
                # benchmark output, so format defensively instead of crashing.
                print(f"✓ Success - nlevels={nlevels}, M={M}, "
                      f"QPS: {_format_metric(result.get('qps'), '.1f')}, "
                      f"Recall: {_format_metric(result.get('recall'), '.3f')}, "
                      f"Wall time: {_format_metric(result.get('wall_time_ms'), '.1f')}ms")

    # Print final summary
    print(f"\n{'='*80}")
    print(f"QPS vs Recall data generation complete for dataset: {args.dataset_name}")
    print(f"Total benchmark runs: {len(all_results)}")
    print(f"Results saved to: {args.results_csv}")
    print(f"{'='*80}")

    if not all_results:
        print("No successful benchmark runs!")
        return

    results_df = pd.DataFrame(all_results)
    print("\nQPS vs Recall Summary:")
    print("=" * 60)

    for index_type in INDEX_TYPES:
        type_results = results_df[results_df['index_type'] == index_type]
        if len(type_results) == 0:
            continue

        print(f"\n{index_type.upper()}:")
        # nlevels is shown so that runs of the same M at different level
        # counts can be told apart.
        print(f"{'M':<4} {'Levels':<7} {'QPS':<8} {'Recall':<8} {'Compression':<12} {'Bits/Vector':<12}")
        print("-" * 58)

        for _, row in type_results.iterrows():
            print(f"{row['M']:<4} {row['nlevels']:<7} "
                  f"{_format_metric(row['qps'], '.1f'):<8} "
                  f"{_format_metric(row['recall'], '.3f'):<8} "
                  f"{row['compression_ratio']:<12.1f} {row['bits_per_vector']:<12.0f}")

    print(f"\nData for plotting saved to: {args.results_csv}")
    print("Columns: index_type, nlevels, M, qps, recall, compression_ratio, bits_per_vector")

# Run the benchmark sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
