#!/usr/bin/env python3
"""
Transform datasets using pre-trained linear transform models without training.

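Takes a model saved by transform_linear_all_data.py and applies it to the
<dataset>_base.fvec[s] and <dataset>_query.fvec[s] files, writing the results
to <dataset>_train.fvecs and <dataset>_test.fvecs in the output directory.

Vectors are processed in chunks (a configurable fraction of each file at a
time) to avoid running out of memory.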
"""

import os
import sys
import argparse
import logging
from datetime import datetime
from typing import Iterator, Tuple, Optional

import numpy as np
import struct
from tqdm import tqdm
import torch
import gc

from linear_transform import (
    SVDProjectionTransform,
    PenaltyTransform,
    ManifoldTransform,
    ExponentialMapTransform,
    CayleyTransform,
    GivensRotationTransform,
    HouseholderTransform,
)


def setup_logger() -> logging.Logger:
    """Configure root logging to stdout at INFO level and return this module's logger.

    Configuration goes through ``logging.basicConfig``. Per the stdlib docs, that
    call leaves an already-configured root logger untouched.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        handlers=[stdout_handler],
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )
    return logging.getLogger(__name__)


def get_fvecs_meta(file_path: str) -> Tuple[int, int]:
    """Return (dim, num_vectors) for an .fvecs file without loading all data."""
    file_size = os.path.getsize(file_path)
    with open(file_path, 'rb') as f:
        dim_bytes = f.read(4)
        if len(dim_bytes) != 4:
            raise ValueError(f"Invalid fvecs file: {file_path}")
        dim = struct.unpack('<I', dim_bytes)[0]
    record_size = 4 + 4 * dim
    if record_size <= 0:
        raise ValueError("Invalid record size computed for fvecs")
    if file_size % record_size != 0:
        # Some files may have trailing data; floor division as best effort
        num = file_size // record_size
    else:
        num = file_size // record_size
    return dim, int(num)


def read_fvecs_chunk(
    file_path: str,
    start_index: int,
    count: int,
) -> np.ndarray:
    """Read records ``[start_index, start_index + count)`` from an .fvecs file.

    The range is clipped to the number of records in the file. When nothing
    remains, an empty ``(0, dim)`` array is returned.

    The whole range is fetched with a single ``read`` and decoded with NumPy,
    rather than one ``struct.unpack`` plus one ``read`` per vector.

    Args:
        file_path: Path to the .fvecs file.
        start_index: Index of the first record to read.
        count: Maximum number of records to read.

    Returns:
        A new, writable, C-contiguous ``float32`` array of shape ``(n, dim)``.

    Raises:
        ValueError: If any record's header dimension differs from the first
            record's, or if the file is shorter than its size implied.
    """
    dim, total = get_fvecs_meta(file_path)
    end_index = min(start_index + count, total)
    if start_index >= end_index:
        return np.empty((0, dim), dtype=np.float32)

    n = end_index - start_index
    record_size = 4 + 4 * dim
    expected_bytes = n * record_size
    with open(file_path, 'rb') as f:
        f.seek(start_index * record_size)
        buf = f.read(expected_bytes)
    if len(buf) != expected_bytes:
        # Only possible if the file shrank after get_fvecs_meta measured it.
        raise ValueError(
            f"Truncated fvecs file {file_path}: expected {expected_bytes} bytes "
            f"from record {start_index}, got {len(buf)}"
        )

    # Each record is a uint32 header followed by `dim` float32 values, so the
    # buffer is an (n, dim + 1) grid of 4-byte cells; column 0 holds the headers.
    headers = np.frombuffer(buf, dtype='<u4').reshape(n, dim + 1)[:, 0]
    mismatched = np.flatnonzero(headers != dim)
    if mismatched.size:
        d = int(headers[mismatched[0]])
        raise ValueError(f"Dimension mismatch in {file_path}: expected {dim}, got {d}")

    values = np.frombuffer(buf, dtype='<f4').reshape(n, dim + 1)[:, 1:]
    # np.array copies, yielding a writable contiguous array independent of `buf`.
    return np.array(values, dtype=np.float32, order='C')


def write_fvecs_chunk(file_path: str, vectors: np.ndarray, append: bool = False) -> None:
    """Write a 2-D array of vectors to an .fvecs file.

    Each row is stored as a little-endian uint32 dimension header followed by
    the row's values as float32.

    Args:
        file_path: Destination path. Missing parent directories are created.
        vectors: Array of shape ``(n, dim)``; values are cast to float32.
        append: Append to an existing file instead of truncating it.

    Raises:
        ValueError: If ``vectors`` is not two-dimensional.
    """
    vectors = np.asarray(vectors)
    if vectors.ndim != 2:
        raise ValueError(f"Expected a 2-D array of vectors, got shape {vectors.shape}")

    # os.path.dirname() is '' for a bare filename, and os.makedirs('') raises,
    # so only create a directory when the path actually names one.
    parent = os.path.dirname(file_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    n, dim = vectors.shape
    # Assemble all records in one buffer: column 0 holds the uint32 header,
    # written through a same-size integer view; columns 1.. hold the float32 data.
    records = np.empty((n, dim + 1), dtype='<f4')
    records.view('<u4')[:, 0] = dim
    records[:, 1:] = vectors

    with open(file_path, 'ab' if append else 'wb') as f:
        # One write of the contiguous buffer; memoryview avoids a tobytes() copy.
        f.write(memoryview(records))


def load_model(model_path: str, device: torch.device) -> torch.nn.Module:
    """Rebuild a trained linear-transform model from a saved checkpoint.

    The checkpoint is read as a mapping that provides:

    - ``'model_class'``: the transform's class name.
    - ``'input_dim'``: the constructor argument.
    - ``'model_state_dict'``: the trained weights.

    The model is instantiated with ``input_dim`` and its weights are restored.
    It is then moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``'model_class'`` is not one of the supported transforms.
        KeyError: If the checkpoint lacks one of the keys above.
    """
    log = logging.getLogger(__name__)
    log.info(f"Loading model from {model_path}")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location='cpu')
    class_name = checkpoint['model_class']
    input_dim = checkpoint['input_dim']

    # Supported checkpoint class names and the constructor used for each.
    known_transforms = (
        ('SVDProjectionTransform', SVDProjectionTransform),
        ('PenaltyTransform', PenaltyTransform),
        ('ManifoldTransform', ManifoldTransform),
        ('ExponentialMapTransform', ExponentialMapTransform),
        ('CayleyTransform', CayleyTransform),
        ('GivensRotationTransform', GivensRotationTransform),
        ('HouseholderTransform', HouseholderTransform),
    )
    constructor = next(
        (cls for name, cls in known_transforms if class_name == name),
        None,
    )
    if constructor is None:
        raise ValueError(f"Unknown model class: {class_name}")
    model = constructor(input_dim)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    log.info(f"Loaded {class_name} model with input_dim={input_dim}")
    return model


def transform_dataset_chunked(
    model: torch.nn.Module,
    input_path: str,
    output_path: str,
    chunk_fraction: float,
    device: torch.device,
    logger: logging.Logger,
) -> None:
    """Apply ``model`` to every vector of an .fvecs file, one chunk at a time.

    Each chunk holds ``chunk_fraction`` of the file's vectors. It is run through
    the model under ``torch.no_grad()`` on ``device`` and written to
    ``output_path``. The output is truncated on the first chunk and appended to
    afterwards, so only one chunk is held in memory at a time.

    Args:
        model: Transform to apply; called on a float32 tensor built from each chunk.
        input_path: Source .fvecs file.
        output_path: Destination .fvecs file; any existing file is overwritten.
        chunk_fraction: Fraction of the vectors per chunk. Every chunk holds at
            least one vector.
        device: Device the chunk tensors are moved to before calling ``model``.
        logger: Logger for progress messages.
    """
    logger.info(f"Transforming {input_path} -> {output_path} in {chunk_fraction*100:.1f}% chunks")

    dim, total_vectors = get_fvecs_meta(input_path)
    chunk_size = max(1, int(total_vectors * chunk_fraction))

    logger.info(f"Dataset: {total_vectors} vectors, dim={dim}")
    logger.info(f"Chunk size: {chunk_size} vectors ({chunk_fraction*100:.1f}%)")

    if total_vectors == 0:
        # The loop below would never run, so a stale output file from a previous
        # run would survive untouched. Write an empty file so the output matches
        # the (empty) input.
        write_fvecs_chunk(output_path, np.empty((0, dim), dtype=np.float32))

    processed = 0
    with tqdm(total=total_vectors, desc=f"Transforming {os.path.basename(input_path)}") as pbar:
        for chunk_idx, start in enumerate(range(0, total_vectors, chunk_size)):
            current_chunk_size = min(chunk_size, total_vectors - start)

            logger.info(f"Processing chunk {chunk_idx+1}: vectors {start} to {start + current_chunk_size - 1}")

            chunk_data = read_fvecs_chunk(input_path, start, current_chunk_size)

            with torch.no_grad():
                chunk_tensor = torch.from_numpy(chunk_data).float().to(device)
                transformed_tensor = model(chunk_tensor)
                transformed_data = transformed_tensor.cpu().numpy()

                # Drop device tensors before the next chunk is allocated.
                del chunk_tensor, transformed_tensor
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            # The first chunk truncates the output file; later chunks append.
            write_fvecs_chunk(output_path, transformed_data, append=(chunk_idx > 0))

            # Release host buffers promptly to keep peak memory near one chunk.
            del chunk_data, transformed_data
            gc.collect()

            processed += current_chunk_size
            pbar.update(current_chunk_size)

    logger.info(f"Transformation complete: {processed} vectors processed")


def _resolve_device(choice: str, logger: logging.Logger) -> torch.device:
    """Translate the ``--device`` choice into a ``torch.device``.

    ``'auto'`` picks CUDA when available and falls back to CPU otherwise. An
    explicit ``'cuda'`` request on a machine without CUDA exits with status 1
    here, instead of failing later when the model is moved to the device.
    """
    if choice == 'auto':
        if torch.cuda.is_available():
            device = torch.device('cuda')
            logger.info(f"Using CUDA: {torch.cuda.get_device_name()}")
        else:
            device = torch.device('cpu')
            logger.info("Using CPU")
        return device

    if choice == 'cuda' and not torch.cuda.is_available():
        logger.error("CUDA was requested (--device cuda) but is not available")
        sys.exit(1)

    device = torch.device(choice)
    logger.info(f"Using device: {device}")
    return device


def _find_dataset_file(data_dir: str, dataset_name: str, split: str) -> Optional[str]:
    """Return the first existing ``{dataset_name}_{split}.fvecs`` / ``.fvec`` path, or None.

    ``.fvecs`` is preferred over ``.fvec`` when both exist.
    """
    for ext in ('.fvecs', '.fvec'):
        candidate = os.path.join(data_dir, f"{dataset_name}_{split}{ext}")
        if os.path.exists(candidate):
            return candidate
    return None


def main():
    """Command-line entry point: transform a dataset's base and query files.

    Parses the arguments and validates the inputs. It then loads the model and
    writes ``<dataset>_train.fvecs`` (from the base file) and
    ``<dataset>_test.fvecs`` (from the query file) into the output directory.
    Exits with status 1 on invalid input or on any error during transformation.
    """
    parser = argparse.ArgumentParser(description='Transform datasets using pre-trained linear transform models')
    parser.add_argument('--model-path', required=True, help='Path to saved model (.pth file)')
    parser.add_argument('--dataset-name', required=True, help='Name of dataset (e.g., "gist", "sift")')
    parser.add_argument('--data-dir', required=True, help='Directory containing <dataset-name>_base.fvec[s] and <dataset-name>_query.fvec[s]')
    parser.add_argument('--output-dir', required=True, help='Output directory for transformed files')
    parser.add_argument('--frac', type=float, default=0.1, help='Fraction of dataset to process at once (0.0-1.0, default: 0.1)')
    parser.add_argument('--device', choices=['auto', 'cpu', 'cuda'], default='auto', help='Device to use for inference')
    parser.add_argument('--verbose', '-v', action='store_true', help='Enable verbose logging')

    args = parser.parse_args()

    logger = setup_logger()
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Validate arguments (a NaN --frac also fails this comparison).
    if not (0.0 < args.frac <= 1.0):
        logger.error(f"frac must be between 0.0 and 1.0, got {args.frac}")
        sys.exit(1)

    if not os.path.exists(args.model_path):
        logger.error(f"Model path does not exist: {args.model_path}")
        sys.exit(1)

    if not os.path.exists(args.data_dir):
        logger.error(f"Data directory does not exist: {args.data_dir}")
        sys.exit(1)

    device = _resolve_device(args.device, logger)

    base_path = _find_dataset_file(args.data_dir, args.dataset_name, 'base')
    if base_path is None:
        logger.error(f"Could not find base file for dataset '{args.dataset_name}' in {args.data_dir}")
        logger.error(f"Expected files: {args.dataset_name}_base.fvecs or {args.dataset_name}_base.fvec")
        sys.exit(1)

    query_path = _find_dataset_file(args.data_dir, args.dataset_name, 'query')
    if query_path is None:
        logger.error(f"Could not find query file for dataset '{args.dataset_name}' in {args.data_dir}")
        logger.error(f"Expected files: {args.dataset_name}_query.fvecs or {args.dataset_name}_query.fvec")
        sys.exit(1)

    logger.info(f"Found base file: {base_path}")
    logger.info(f"Found query file: {query_path}")

    os.makedirs(args.output_dir, exist_ok=True)

    # Base vectors become the "train" split and queries the "test" split.
    base_output = os.path.join(args.output_dir, f"{args.dataset_name}_train.fvecs")
    query_output = os.path.join(args.output_dir, f"{args.dataset_name}_test.fvecs")

    try:
        model = load_model(args.model_path, device)

        logger.info("\n=== Transforming Base (Train) File ===")
        transform_dataset_chunked(model, base_path, base_output, args.frac, device, logger)

        logger.info("\n=== Transforming Query (Test) File ===")
        transform_dataset_chunked(model, query_path, query_output, args.frac, device, logger)

        logger.info("\n=== Transformation Complete ===")
        logger.info(f"Base transformed: {base_path} -> {base_output}")
        logger.info(f"Query transformed: {query_path} -> {query_output}")

    except Exception as e:
        # Top-level boundary: log with traceback so failures are diagnosable.
        logger.exception(f"Error during transformation: {e}")
        sys.exit(1)


# Example usage:
#
#   python3 transform_without_train_linear.py \
#     --model-path /mnt/device/models/gist_SVD_Projection_20250903_025145.pth \
#     --dataset-name gist \
#     --data-dir /mnt/device/datasets \
#     --output-dir /mnt/device/transformed_inference \
#     --frac 0.1 \
#     --device auto \
#     --verbose
if __name__ == "__main__":
    main()
