#!/usr/bin/env python3
"""
Streamed training and transformation for ANN .fvecs datasets.

- Discovers dataset pairs: <name>_base.fvec[s], <name>_query.fvec[s]
- Loads only fractions of base for training/eval to respect memory limits
- Trains specified wavelet models in mini-batches
- Saves models to /mnt/device/models as <dataset>_<method>_<timestamp>.pth
- Transforms FULL base (train) and query (test) in streamed batches
  and saves to /mnt/device/transformed as
  <dataset>_train_<method>_<timestamp>.fvecs and
  <dataset>_test_<method>_<timestamp>.fvecs

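Memory budget: designed around a budget of roughly 8GB. The actual limits are set
by --train-mem-gb (default 3.0, caps the in-memory train+eval slice) and
--transform-mb (default 64, sizes the streamed transform batches).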
"""

import os
import sys
import math
import time
import argparse
import logging
from datetime import datetime
from typing import Iterator, Tuple, List, Optional, Dict

import numpy as np
import struct
from tqdm import tqdm

import torch
import gc

from config import get_default_config
from wavelet_trainer import (
    WaveletTransformTrainer,
    LearnableWaveletTransform,
    LearnableNLWaveletTransform,
    LearnableOrthogonalWaveletTransform,
    LossFunctions,
    EvaluationMetrics,
)


# ==========================
# Dataset allowlist
# ==========================
# Only datasets listed here (by base name, e.g., 'gist', 'sift', 'glove100d')
# will be processed. Leave empty to process all discovered datasets.
# main() keeps a discovered (name, base, query) pair only when its name -- the
# file name with the `_base` / `_query` suffix stripped -- appears in this list.
SELECTED_DATASETS: List[str] = [
    'cifar10',
    'deep',
    'fashionmnist',
    'glove50d',
    'glove100d',
    'glove300d',
    'nytimes',
    'sift',
    'yorck', 
    'gist',
    'glove2m300'
]


# ==========================
# Logging utilities
# ==========================
def setup_logger(log_dir: str, run_name: str) -> logging.Logger:
    """Create (or reconfigure) the logger named ``run_name``.

    The logger writes INFO-and-above records both to ``<log_dir>/<run_name>.log``
    and to stdout, and does not propagate to the root logger.

    Args:
        log_dir: Directory for the log file; created if missing.
        run_name: Logger name, also used as the log file stem.

    Returns:
        The configured logger. Calling this again with the same ``run_name``
        returns the same logger object with freshly installed handlers.
    """
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(run_name)
    logger.setLevel(logging.INFO)
    logger.propagate = False

    # Drop handlers left over from an earlier call with the same name.
    # removeHandler() alone would leave the previous log file open, so each
    # stale handler is closed as well (closing a StreamHandler does not close
    # the underlying stdout stream).
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    log_path = os.path.join(log_dir, f"{run_name}.log")
    fmt = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    fh = logging.FileHandler(log_path)
    ch = logging.StreamHandler(sys.stdout)
    for handler in (fh, ch):
        handler.setLevel(logging.INFO)
        handler.setFormatter(fmt)
        logger.addHandler(handler)

    logger.info(f"Logging to {log_path}")
    return logger


# ==========================
# .fvecs streaming I/O
# ==========================
def get_fvecs_meta(file_path: str) -> Tuple[int, int]:
    """Return ``(dim, num_vectors)`` for an .fvecs file without loading its data.

    Only the first 4-byte little-endian dimension header is read; the vector
    count is derived from the file size, assuming every record has that same
    dimension (``4 + 4 * dim`` bytes per record).

    Args:
        file_path: Path to the .fvecs file.

    Returns:
        ``(dim, num_vectors)``. Trailing bytes that do not form a complete
        record are ignored (floor division, best effort).

    Raises:
        ValueError: If the file is shorter than one header or the header
            declares a dimension of zero.
    """
    file_size = os.path.getsize(file_path)
    with open(file_path, 'rb') as f:
        dim_bytes = f.read(4)
    if len(dim_bytes) != 4:
        raise ValueError(f"Invalid fvecs file: {file_path}")
    dim = struct.unpack('<I', dim_bytes)[0]
    # The header is unsigned, so the only degenerate value is 0. Rejecting it
    # here avoids zero-width arrays and a division by zero in batch sizing later.
    if dim == 0:
        raise ValueError(f"Invalid fvecs file (zero dimension header): {file_path}")
    record_size = 4 + 4 * dim
    return dim, int(file_size // record_size)


def read_fvecs_range(
    file_path: str,
    start_index: int,
    count: int,
) -> np.ndarray:
    """Read vectors ``[start_index, start_index + count)`` from an .fvecs file.

    The range is clipped to the number of vectors in the file. The result
    holds ``count * dim`` float32 values in memory, so use this for modest
    ranges only.

    Args:
        file_path: Path to the .fvecs file.
        start_index: Index of the first vector to read (0-based).
        count: Maximum number of vectors to read.

    Returns:
        A writable float32 array of shape ``(n, dim)``; ``n`` is 0 when the
        requested range is empty or lies past the end of the file.

    Raises:
        ValueError: If ``start_index`` is negative, a record header disagrees
            with the file's dimension, or the file ends mid-record.
    """
    if start_index < 0:
        raise ValueError(f"start_index must be >= 0, got {start_index}")

    dim, total = get_fvecs_meta(file_path)
    end_index = min(start_index + count, total)
    if start_index >= end_index:
        return np.empty((0, dim), dtype=np.float32)

    n_vectors = end_index - start_index
    record_size = 4 + 4 * dim
    # Parse whole blocks of records at once instead of one vector per Python
    # iteration, but bound the raw byte buffer (~32 MB) so peak memory stays
    # close to the size of the output array.
    chunk_records = max(1, (32 * 1024 * 1024) // record_size)

    vectors = np.empty((n_vectors, dim), dtype=np.float32)
    with open(file_path, 'rb') as f:
        f.seek(start_index * record_size)
        filled = 0
        while filled < n_vectors:
            n = min(chunk_records, n_vectors - filled)
            raw = f.read(n * record_size)
            if len(raw) != n * record_size:
                raise ValueError(f"Unexpected end of file in {file_path}")
            # Column 0 of each record is the uint32 dimension header.
            headers = np.frombuffer(raw, dtype='<u4').reshape(n, dim + 1)[:, 0]
            bad = np.flatnonzero(headers != dim)
            if bad.size:
                raise ValueError(
                    f"Dimension mismatch in {file_path}: expected {dim}, got {int(headers[bad[0]])}"
                )
            # Remaining columns are the little-endian float32 payload.
            vectors[filled:filled + n] = np.frombuffer(raw, dtype='<f4').reshape(n, dim + 1)[:, 1:]
            filled += n
    return vectors


def stream_fvecs(
    file_path: str,
    batch_vectors: int,
    start_index: int = 0,
    max_vectors: int = -1,
) -> Iterator[np.ndarray]:
    """Yield float32 arrays of shape ``(n_batch, dim)`` from an .fvecs file.

    Vectors are read sequentially in batches so the whole file is never held
    in memory. Each batch is parsed from a single block read; the raw bytes of
    a batch are released before the batch is yielded.

    Args:
        file_path: Path to the .fvecs file.
        batch_vectors: Maximum number of vectors per yielded batch (>= 1).
        start_index: Index of the first vector to yield (0-based).
        max_vectors: Upper bound on the total number of vectors yielded;
            values <= 0 mean "until the end of the file".

    Yields:
        Writable float32 arrays; every batch except possibly the last has
        ``batch_vectors`` rows.

    Raises:
        ValueError: If ``batch_vectors`` < 1 (which previously looped forever
            yielding empty batches), ``start_index`` is negative, a record
            header disagrees with the file's dimension, or the file ends
            mid-record. As this is a generator, these are raised on the first
            ``next()`` call rather than at construction.
    """
    if batch_vectors < 1:
        raise ValueError(f"batch_vectors must be >= 1, got {batch_vectors}")
    if start_index < 0:
        raise ValueError(f"start_index must be >= 0, got {start_index}")

    dim, total = get_fvecs_meta(file_path)
    remaining = total - start_index
    if max_vectors > 0:
        remaining = min(remaining, max_vectors)

    record_size = 4 + 4 * dim
    with open(file_path, 'rb') as f:
        f.seek(start_index * record_size)
        produced = 0
        while produced < remaining:
            n = int(min(batch_vectors, remaining - produced))
            raw = f.read(n * record_size)
            if len(raw) != n * record_size:
                raise ValueError(f"Unexpected end of file in {file_path}")
            # Column 0 of each record is the uint32 dimension header.
            headers = np.frombuffer(raw, dtype='<u4').reshape(n, dim + 1)[:, 0]
            bad = np.flatnonzero(headers != dim)
            if bad.size:
                raise ValueError(
                    f"Dimension mismatch in {file_path}: expected {dim}, got {int(headers[bad[0]])}"
                )
            # Copy the payload into a fresh writable array; frombuffer views
            # over bytes are read-only.
            batch = np.empty((n, dim), dtype=np.float32)
            batch[:] = np.frombuffer(raw, dtype='<f4').reshape(n, dim + 1)[:, 1:]
            # Release the raw buffer before handing control to the consumer.
            del headers, bad, raw
            produced += n
            yield batch


def write_fvecs_stream(
    file_path: str,
    arrays_iter: Iterator[np.ndarray],
    expected_dim: Optional[int] = None,
) -> None:
    """Write a stream of 2-D arrays to ``file_path`` in .fvecs format.

    Each row becomes one record: a little-endian uint32 dimension header
    followed by the row as float32 values. Batches are consumed one at a time
    so only one batch (plus a bounded staging buffer) is in memory.

    Args:
        file_path: Output path; its parent directory is created if needed.
            A bare file name (no directory part) writes to the current
            directory.
        arrays_iter: Iterable of arrays of shape ``(n, dim)``. Empty arrays
            are skipped; other dtypes / memory layouts are converted to
            C-contiguous float32.
        expected_dim: Required row width. If ``None`` it is taken from the
            first non-empty batch.

    Raises:
        ValueError: If a non-empty batch is not 2-D or its width differs from
            ``expected_dim``.
    """
    out_dir = os.path.dirname(file_path)
    if out_dir:  # os.makedirs('') raises, so skip it for bare file names
        os.makedirs(out_dir, exist_ok=True)

    with open(file_path, 'wb') as f:
        for arr in arrays_iter:
            if arr.size == 0:
                continue
            if arr.ndim != 2:
                raise ValueError(f"Expected a 2-D batch, got shape {arr.shape}")
            if expected_dim is None:
                expected_dim = arr.shape[1]
            if arr.shape[1] != expected_dim:
                raise ValueError(f"Output dim mismatch: got {arr.shape[1]}, expected {expected_dim}")
            # No copy is made when the batch is already C-contiguous float32.
            arr = np.ascontiguousarray(arr, dtype=np.float32)

            dim = expected_dim
            record_size = 4 + 4 * dim
            # Assemble header+payload records in blocks of about 8 MB and
            # write each block with one call, instead of two writes per vector.
            rows_per_write = max(1, (8 * 1024 * 1024) // record_size)
            for lo in range(0, arr.shape[0], rows_per_write):
                part = arr[lo:lo + rows_per_write]
                records = np.empty((part.shape[0], dim + 1), dtype='<u4')
                records[:, 0] = dim
                # Reinterpret the float bits as uint32 so they are copied verbatim.
                records[:, 1:] = part.view(np.uint32)
                f.write(records.tobytes())

            # Encourage freeing the just-processed batch before the next one
            # is produced.
            del arr
            gc.collect()


# ==========================
# Utility helpers
# ==========================
def bytes_for_vectors(num_vectors: int, dim: int) -> int:
    """Return the number of bytes taken by ``num_vectors`` float32 vectors of width ``dim``."""
    float32_bytes = 4
    n_values = int(num_vectors) * int(dim)
    return n_values * float32_bytes


def choose_stream_batch_size(dim: int, target_mb: int = 256) -> int:
    """Return how many float32 vectors of width ``dim`` fit in about ``target_mb`` MB.

    The result is never smaller than 1, even when a single vector exceeds the
    target size.
    """
    budget_bytes = target_mb * 1024 * 1024
    per_vector_bytes = dim * 4
    vectors_in_budget = budget_bytes // per_vector_bytes
    return int(max(1, vectors_in_budget))


def ensure_writable_paths(models_dir: str, transformed_dir: str, logger: logging.Logger) -> None:
    """Create the output directories if they do not exist yet (best effort).

    A directory that cannot be created with the current permissions is retried
    through ``sudo -n`` (non-interactive) and then chown'ed to the current
    user. No password is embedded in the source or piped to sudo: when
    passwordless sudo is unavailable, the failure is logged and the directory
    must be prepared manually.

    Args:
        models_dir: Directory where trained models will be saved.
        transformed_dir: Directory where transformed .fvecs files will be saved.
        logger: Logger used to report the sudo fallback and any failure.
    """
    # Local import: only needed on the rarely taken sudo fallback path.
    import subprocess

    for d in (models_dir, transformed_dir):
        if os.path.exists(d):
            continue
        try:
            os.makedirs(d, exist_ok=True)
            continue
        except PermissionError:
            logger.info(f"Attempting sudo mkdir/chown for {d}")

        # Numeric ids avoid assuming that $USER is set or that a group with
        # the same name as the user exists.
        owner = f"{os.getuid()}:{os.getgid()}"
        # Argument lists (no shell) so a path containing spaces or shell
        # metacharacters cannot be interpreted as a command.
        commands = (
            ['sudo', '-n', 'mkdir', '-p', '--', d],
            ['sudo', '-n', 'chown', '-R', owner, '--', d],
        )
        for cmd in commands:
            try:
                result = subprocess.run(
                    cmd,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.PIPE,
                    check=False,
                )
            except OSError as exc:  # e.g. sudo is not installed
                logger.warning(f"Could not run {' '.join(cmd)}: {exc}")
                break
            if result.returncode != 0:
                detail = result.stderr.decode(errors='replace').strip()
                logger.warning(
                    f"Command failed ({result.returncode}): {' '.join(cmd)} {detail}"
                )
                break


def dataset_pairs_from_dir(data_dir: str) -> List[Tuple[str, str, str]]:
    """List the datasets in ``data_dir`` that have both a base and a query file.

    A file named ``<name>_base.fvecs`` / ``<name>_base.fvec`` is a base file
    and ``<name>_query.fvecs`` / ``<name>_query.fvec`` is a query file.

    Returns:
        ``(name, base_path, query_path)`` tuples, sorted by name, for every
        name that has both kinds of file.
    """
    suffixes_by_role = (
        ('base', ('_base.fvecs', '_base.fvec')),
        ('query', ('_query.fvecs', '_query.fvec')),
    )
    found: Dict[str, Dict[str, str]] = {'base': {}, 'query': {}}

    for fname in os.listdir(data_dir):
        if not fname.endswith(('.fvec', '.fvecs')):
            continue
        for role, suffixes in suffixes_by_role:
            # First matching suffix wins, mirroring the order listed above.
            hit = next((s for s in suffixes if fname.endswith(s)), None)
            if hit is not None:
                found[role][fname[:-len(hit)]] = os.path.join(data_dir, fname)

    paired_names = sorted(found['base'].keys() & found['query'].keys())
    return [(name, found['base'][name], found['query'][name]) for name in paired_names]


# ==========================
# Core pipeline
# ==========================
def pad_for_wavelet(data: np.ndarray, n_levels: int) -> np.ndarray:
    """Zero-pad the columns of ``data`` up to the next multiple of ``2**n_levels``.

    Returns ``data`` itself (no copy) when its width is already a multiple;
    otherwise returns a new array with float32 zero columns appended on the
    right.
    """
    n_rows, width = data.shape[0], data.shape[1]
    block = 2 ** n_levels
    n_blocks, _ = divmod(width + block - 1, block)
    padded_width = n_blocks * block
    if padded_width == width:
        return data
    zero_cols = np.zeros((n_rows, padded_width - width), dtype=np.float32)
    return np.hstack((data, zero_cols))


def padded_length(dim: int, n_levels: int) -> int:
    """Round ``dim`` up to the nearest multiple of ``2**n_levels``."""
    block = 2 ** n_levels
    n_blocks, _ = divmod(dim + block - 1, block)
    return n_blocks * block


def _resolve_method_recipe(method: str, config: Dict) -> Tuple:
    """Look up the model class, loss callables and config sections for ``method``.

    Returns:
        ``(model_cls, loss_shape, loss_dist, weights_key, schedule_key,
        uses_ortho_lambda)`` where ``weights_key`` names the config section
        providing GAMMA/BETA and ``schedule_key`` the section providing
        EPOCHS/LEARNING_RATE.

    Raises:
        ValueError: If ``method`` is not one of the supported method names.
    """
    # The loss callables read ``config`` when they are called, not when they
    # are created, so only the sections a method actually uses are touched.
    def energy_shape(v):
        return LossFunctions.cumulative_energy_shape(v, config['training'].TARGET_A)

    def variance_shape(v):
        return LossFunctions.variance_shape_loss(v, config['training'].TARGET_A)

    def exp_correlation(v1, v2):
        return LossFunctions.exp_distance_correlation(v1, v2, config['exp_correlation'].MAX_PAIRS)

    def ordering(v1, v2):
        return LossFunctions.ordering_preservation(v1, v2, config['force_order'])

    recipes = {
        'Wavelet Corr Exp': (
            LearnableWaveletTransform, energy_shape, exp_correlation, 'exp_correlation', 'wavelet', False),
        'Wavelet Triplet': (
            LearnableWaveletTransform, energy_shape, ordering, 'force_order', 'wavelet', False),
        'Wavelet Triplet Var': (
            LearnableWaveletTransform, variance_shape, ordering, 'force_order', 'wavelet', False),
        'Wavelet Triplet NL': (
            LearnableNLWaveletTransform, energy_shape, ordering, 'force_order', 'wavelet', False),
        'Orthogonal Wavelet Triplet': (
            LearnableOrthogonalWaveletTransform, energy_shape, ordering, 'force_order', 'ortho_wavelet', True),
    }
    if method not in recipes:
        raise ValueError(f"Unsupported method: {method}")
    return recipes[method]


def train_method(
    method: str,
    config: Dict,
    train_np: np.ndarray,
    eval_np: np.ndarray,
    logger: logging.Logger,
) -> Tuple[torch.nn.Module, Dict, Dict, Optional[np.ndarray], Optional[np.ndarray], int]:
    """Train one wavelet method on in-memory train/eval slices.

    Both slices are zero-padded to a multiple of ``2**N_LEVELS`` columns, the
    model for ``method`` is built with equal input and output width, and
    training is delegated to ``WaveletTransformTrainer.train_model``.

    Args:
        method: One of 'Wavelet Corr Exp', 'Wavelet Triplet',
            'Wavelet Triplet Var', 'Wavelet Triplet NL',
            'Orthogonal Wavelet Triplet'.
        config: Configuration mapping; the sections 'wavelet', 'training' and,
            depending on the method, 'exp_correlation', 'force_order' and
            'ortho_wavelet' are read.
        train_np: Training vectors, shape ``(n_train, dim)``.
        eval_np: Evaluation vectors, shape ``(n_eval, dim)``.
        logger: Logger for progress and warnings.

    Returns:
        ``(trained_model, history, metrics, None, None, output_dim)``. The two
        ``None`` slots are kept for backward compatibility: the transformed
        train/eval arrays are released inside this function to lower peak
        memory and are never returned. ``metrics`` is empty if evaluation
        failed.

    Raises:
        ValueError: If ``method`` is not supported.
    """
    trainer = WaveletTransformTrainer(config)
    (model_cls, loss_shape, loss_dist,
     weights_key, schedule_key, uses_ortho_lambda) = _resolve_method_recipe(method, config)

    n_levels = config['wavelet'].N_LEVELS
    filter_len = config['wavelet'].FILTER_LEN
    data_train = torch.from_numpy(pad_for_wavelet(train_np, n_levels)).float()
    data_eval = torch.from_numpy(pad_for_wavelet(eval_np, n_levels)).float()
    output_dim = int(data_train.shape[1])

    model = model_cls(output_dim, output_dim, filter_len, n_levels)
    gamma, beta = config[weights_key].GAMMA, config[weights_key].BETA
    epochs = config[schedule_key].EPOCHS
    lr = config[schedule_key].LEARNING_RATE
    ortho_lambda = config['ortho_wavelet'].LAMBDA if uses_ortho_lambda else None

    logger.info(f"Training {method}: train {data_train.shape}, eval {data_eval.shape}, epochs={epochs}, lr={lr}")

    transformed_train, transformed_eval, train_time, history, trained_model = trainer.train_model(
        method, model, data_train, data_eval, loss_shape, loss_dist, gamma, beta, epochs, lr, ortho_lambda
    )

    logger.info(f"{method} training time: {train_time:.2f}s")

    # Basic evaluation metrics on the eval split. A failure here is reported
    # but must not discard the trained model.
    try:
        metrics: Dict = {}
        metrics['energy'] = EvaluationMetrics.calculate_cumulative_energy(transformed_eval)
        metrics['geometric_ratio'] = EvaluationMetrics.calculate_geometric_ratio(
            metrics['energy'], config['training'].TARGET_A
        )
        metrics['distance'] = EvaluationMetrics.evaluate_distance_preservation(
            data_eval.numpy(), transformed_eval, config['training'].N_PAIRS_CHECK
        )
    except Exception as e:
        logger.warning(f"Evaluation metrics failed for {method}: {e}")
        metrics = {}

    # Free the large transformed arrays before returning to reduce peak memory.
    del transformed_train, transformed_eval
    gc.collect()
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except RuntimeError as e:
        # Cache clearing is only an optimisation; report it and carry on.
        logger.warning(f"Could not clear CUDA cache after training {method}: {e}")

    return trained_model, history, metrics, None, None, output_dim


def transform_streamed(
    model: torch.nn.Module,
    input_path: str,
    output_path: str,
    n_levels: int,
    logger: logging.Logger,
    show_tqdm: bool = True,
    device: Optional[torch.device] = None,
    target_mb: int = 64,
    force_cpu: bool = False,
) -> Tuple[int, int]:
    """Run ``model`` over every vector of ``input_path`` and write the result as .fvecs.

    The input is read in batches, each batch is zero-padded to a multiple of
    ``2**n_levels`` columns, passed through the model under ``torch.no_grad()``
    and streamed to ``output_path``.

    Args:
        model: Trained model; it is put in eval mode and moved to ``device``.
        input_path: Source .fvecs file.
        output_path: Destination .fvecs file.
        n_levels: Number of wavelet levels, which determines the padded width.
        logger: Accepted for interface symmetry; not used by this function.
        show_tqdm: Whether to display a progress bar.
        device: Device to run on. When ``None``, CPU is used if ``force_cpu``
            is set, otherwise CUDA when available and CPU as a fallback.
        target_mb: Approximate size of one batch, measured on the padded width.
        force_cpu: Only consulted when ``device`` is ``None``.

    Returns:
        ``(num_vectors, out_dim)``: the vector count reported by the input
        file's metadata and the padded width.
    """
    model.eval()
    if device is None:
        use_cuda = (not force_cpu) and torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu')
    model.to(device)

    dim_in, total = get_fvecs_meta(input_path)
    block = 2 ** n_levels
    n_blocks, _ = divmod(dim_in + block - 1, block)
    out_dim = n_blocks * block
    # Batch size follows the padded output width, not the raw input width.
    batch_size_vecs = choose_stream_batch_size(out_dim, target_mb=target_mb)

    def gen_transformed() -> Iterator[np.ndarray]:
        progress = tqdm(
            stream_fvecs(input_path, batch_size_vecs),
            desc=f"Transform {os.path.basename(input_path)}",
            disable=not show_tqdm,
        )
        for raw_batch in progress:
            missing_cols = out_dim - raw_batch.shape[1]
            if missing_cols != 0:
                # Zero-pad on the right up to the wavelet-compatible width.
                zeros = np.zeros((raw_batch.shape[0], missing_cols), dtype=np.float32)
                model_input = np.concatenate([raw_batch, zeros], axis=1)
            else:
                model_input = raw_batch
            with torch.no_grad():
                t = torch.from_numpy(model_input).float().to(device)
                y = model(t).detach().cpu().numpy()
                # Drop the input tensor right away to keep device memory low.
                del t
                if device.type == 'cuda' and torch.cuda.is_available():
                    torch.cuda.empty_cache()
            yield y

    write_fvecs_stream(output_path, gen_transformed(), expected_dim=out_dim)
    try:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        pass
    return total, out_dim


# CLI method key -> method name understood by train_method(), in training order.
_METHOD_KEY_TO_NAME = (
    ('triplet', 'Wavelet Triplet'),
    ('corr_exp', 'Wavelet Corr Exp'),
    ('triplet_var', 'Wavelet Triplet Var'),
    ('wavelet_nl_triplet', 'Wavelet Triplet NL'),
    ('ortho_wavelet_triplet', 'Orthogonal Wavelet Triplet'),
)


def _select_methods(args: argparse.Namespace, logger: logging.Logger) -> List[str]:
    """Translate ``args.methods`` into the ordered list of method names to train."""
    requested = getattr(args, 'methods', None)
    if not requested:
        # The namespace may not carry a `methods` attribute at all; fall back
        # to a single method instead of failing with AttributeError.
        requested = ['triplet']
        logger.info("No methods specified; defaulting to 'triplet'")
    elif isinstance(requested, str):
        # Split a comma-separated string so that e.g. 'triplet_var' is matched
        # exactly and not as a substring hit for 'triplet'.
        requested = [part.strip() for part in requested.split(',') if part.strip()]

    known_keys = {key for key, _ in _METHOD_KEY_TO_NAME}
    unknown = [key for key in requested if key not in known_keys]
    if unknown:
        logger.warning(f"Ignoring unknown method keys: {unknown}")
    return [name for key, name in _METHOD_KEY_TO_NAME if key in requested]


def _plan_split_sizes(
    num_base: int,
    dim_base: int,
    n_levels: int,
    args: argparse.Namespace,
    logger: logging.Logger,
) -> Tuple[int, int]:
    """Decide how many base vectors to use for training and for evaluation."""
    if getattr(args, 'auto_fractions', False):
        # If 2% is less than 20k points => small dataset: use 10% eval, rest train
        if int(0.02 * num_base) < 20000:
            eval_count = max(1, int(0.10 * num_base))
            train_count = max(1, num_base - eval_count)
            logger.info(f"Auto fractions (SMALL): eval=10% ({eval_count}), train=rest ({train_count})")
        else:
            # Large dataset: 2% train, 0.05% eval
            train_count = max(1, int(0.02 * num_base))
            eval_count = max(1, int(0.0005 * num_base))
            logger.info(f"Auto fractions (LARGE): train=2% ({train_count}), eval=0.05% ({eval_count})")
    else:
        f_train = args.f_train
        f_eval = args.f_eval
        train_count = int(num_base * f_train)
        eval_count = int(num_base * f_eval)
        logger.info(f"Manual fractions: train={f_train*100:.2f}% ({train_count}), eval={f_eval*100:.2f}% ({eval_count})")
    if train_count < 1 or eval_count < 1:
        logger.warning(f"Very small train/eval sizes (train={train_count}, eval={eval_count}). Consider increasing fractions.")

    # Memory guard: cap total training memory. The safety factor accounts for
    # padding and duplication (numpy + tensor + grads).
    max_train_bytes = int(args.train_mem_gb * (1024**3))
    padded_dim = padded_length(dim_base, n_levels)
    duplication_factor = float(getattr(args, 'train_safety_factor', 3.5))
    bytes_needed = int(duplication_factor * bytes_for_vectors(train_count + eval_count, padded_dim))
    if bytes_needed > max_train_bytes:
        scale = max_train_bytes / max(1, bytes_needed)
        new_train = max(1, int(train_count * scale))
        new_eval = max(1, int(eval_count * scale))
        logger.info(f"Capping train/eval due to memory: {train_count}->{new_train}, {eval_count}->{new_eval}")
        train_count, eval_count = new_train, new_eval

    # Hard cap by absolute number of vectors if provided
    max_total = getattr(args, 'max_train_eval', 0)
    if max_total and (train_count + eval_count) > max_total:
        total = train_count + eval_count
        scale = max_total / float(total)
        new_train = max(1, int(train_count * scale))
        new_eval = max(1, int(eval_count * scale))
        logger.info(f"Hard cap applied: total {total}->{max_total} vectors => train {train_count}->{new_train}, eval {eval_count}->{new_eval}")
        train_count, eval_count = new_train, new_eval

    return train_count, eval_count


def process_dataset(
    dataset: str,
    base_path: str,
    query_path: str,
    args: argparse.Namespace,
    logger: logging.Logger,
) -> None:
    """Train every requested method on one dataset and transform its files.

    For each method: load the first ``train_count`` base vectors for training
    and the following ``eval_count`` for evaluation, train, save the model to
    ``args.models_dir``, then stream-transform the full base and query files
    into ``args.transformed_dir``.

    Args:
        dataset: Dataset name, used as the prefix of output file names.
        base_path: Path to the base .fvecs file.
        query_path: Path to the query .fvecs file.
        args: Parsed CLI options. ``args.methods`` may be a list of method
            keys, a comma-separated string, or absent (defaults to
            ``['triplet']``).
        logger: Logger for progress output.
    """
    logger.info(f"\n===== Dataset: {dataset} =====")
    logger.info(f"Base:  {base_path}")
    logger.info(f"Query: {query_path}")

    dim_base, num_base = get_fvecs_meta(base_path)
    logger.info(f"Base meta: {num_base} vectors, dim={dim_base}")

    # Prepare config early to know padding levels
    config = get_default_config()
    n_levels = config['wavelet'].N_LEVELS

    train_count, eval_count = _plan_split_sizes(num_base, dim_base, n_levels, args, logger)

    config['training'].BATCH_SIZE = args.batch_size
    # NOTE(review): these overrides only touch config['wavelet']; the
    # 'Orthogonal Wavelet Triplet' method reads EPOCHS / LEARNING_RATE from
    # config['ortho_wavelet'] and is therefore unaffected -- confirm intended.
    config['wavelet'].EPOCHS = args.epochs
    config['wavelet'].LEARNING_RATE = args.learning_rate

    methods = _select_methods(args, logger)
    if not methods:
        logger.warning(f"No known methods selected for dataset {dataset}; nothing to do")
        return

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    for method in methods:
        logger.info(f"\n--- Training method: {method} ---")
        # Load train + eval slices per method to keep memory low
        logger.info(f"Loading training slice: first {train_count} vectors")
        train_np = read_fvecs_range(base_path, 0, train_count)
        logger.info(f"Loading eval slice: next {eval_count} vectors")
        eval_np = read_fvecs_range(base_path, train_count, eval_count)

        model, history, metrics, _, _, out_dim = train_method(method, config, train_np, eval_np, logger)

        # Save model
        method_tag = method.replace(' ', '_')
        model_path = os.path.join(args.models_dir, f"{dataset}_{method_tag}_{timestamp}.pth")
        model_data = {
            'model_state_dict': model.state_dict(),
            'model_class': type(model).__name__,
            'method_name': method,
            'timestamp': timestamp,
            'original_file': base_path,
            'config': config,
            'input_dim': out_dim,
            'target_dim': out_dim,
            'filter_len': config['wavelet'].FILTER_LEN,
            'n_levels': config['wavelet'].N_LEVELS,
        }
        torch.save(model_data, model_path)
        logger.info(f"Saved model -> {model_path}")

        # Free training/eval arrays before full transforms to reduce peak RAM
        del train_np, eval_np
        gc.collect()
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except RuntimeError as e:
            # Cache clearing is only an optimisation; report it and carry on.
            logger.warning(f"Could not clear CUDA cache: {e}")

        # Transform full base (train) and query (test)
        base_out = os.path.join(args.transformed_dir, f"{dataset}_train_{method_tag}_{timestamp}.fvecs")
        query_out = os.path.join(args.transformed_dir, f"{dataset}_test_{method_tag}_{timestamp}.fvecs")

        logger.info(f"Transforming FULL base -> {base_out}")
        total_base, out_dim_b = transform_streamed(
            model, base_path, base_out, n_levels, logger, show_tqdm=True, target_mb=int(args.transform_mb), force_cpu=bool(args.cpu_transform)
        )
        logger.info(f"Base transformed: {total_base} vectors to dim={out_dim_b}")

        logger.info(f"Transforming FULL query -> {query_out}")
        total_query, out_dim_q = transform_streamed(
            model, query_path, query_out, n_levels, logger, show_tqdm=True, target_mb=int(args.transform_mb), force_cpu=bool(args.cpu_transform)
        )
        logger.info(f"Query transformed: {total_query} vectors to dim={out_dim_q}")

        # Log training trajectory and evaluation values
        if history:
            for ep in range(len(history.get('epoch', []))):
                logger.info(
                    f"epoch={history['epoch'][ep]} total={history['total_loss'][ep]:.6f} "
                    f"shape={history['loss_shape'][ep]:.6f} dist={history['loss_dist'][ep]:.6f}"
                )
        if metrics:
            try:
                corr = metrics.get('distance', {}).get('corr_coef', None)
                geo = metrics.get('geometric_ratio', {}).get('mean_ratio', None)
                logger.info(f"Eval metrics [{method}]: dist_corr={corr} geom_ratio={geo}")
            except (AttributeError, TypeError):
                # A metric entry is not a mapping; fall back to listing what exists.
                logger.info(f"Eval metrics [{method}] available keys: {list(metrics.keys())}")


def main():
    """Parse CLI options, discover dataset pairs and process each of them.

    Exits with status 1 when no dataset pair is left after applying the
    ``SELECTED_DATASETS`` allowlist. A failure on one dataset is logged with
    its traceback and does not stop the remaining datasets.
    """
    # Method keys accepted by --methods; process_dataset() reads them from
    # args.methods.
    method_keys = ['triplet', 'corr_exp', 'triplet_var', 'wavelet_nl_triplet', 'ortho_wavelet_triplet']

    parser = argparse.ArgumentParser(description='Train and transform all .fvecs datasets in a directory (streamed)')
    parser.add_argument('--data-dir', required=True, help='Directory containing *_base.fvec[s] and *_query.fvec[s] files')
    parser.add_argument('--models-dir', default='/mnt/device/models', help='Output directory for models')
    parser.add_argument('--transformed-dir', default='/mnt/device/transformed', help='Output directory for transformed fvecs')
    parser.add_argument('--log-dir', default='/mnt/device/logs', help='Log directory')
    parser.add_argument('--auto-fractions', action='store_true', default=True, help='Automatically choose train/eval fractions based on dataset size')
    parser.add_argument('--no-auto-fractions', dest='auto_fractions', action='store_false', help='Disable auto fractions and use manual fractions')
    parser.add_argument('--f-train', dest='f_train', type=float, default=0.05, help='Manual fraction of base used for training (used if --no-auto-fractions)')
    parser.add_argument('--f-eval', dest='f_eval', type=float, default=0.01, help='Manual fraction of base used for eval (used if --no-auto-fractions)')
    parser.add_argument('--batch-size', type=int, default=256, help='Training mini-batch size')
    parser.add_argument('--epochs', type=int, default=50, help='Training epochs for wavelet models')
    parser.add_argument('--learning-rate', type=float, default=0.005, help='Learning rate for wavelet models')
    parser.add_argument('--train-mem-gb', type=float, default=3.0, help='Max RAM (GB) for train+eval slice')
    parser.add_argument('--transform-mb', type=int, default=64, help='Approx MB of RAM for streamed transform batches')
    parser.add_argument('--train-safety-factor', type=float, default=3.5, help='Multiplier for RAM estimate during training sizing')
    parser.add_argument('--max-train-eval', type=int, default=0, help='Hard cap on total (train+eval) vectors; 0 disables')
    parser.add_argument('--cpu-transform', action='store_true', help='Run streamed transforms on CPU to reduce GPU memory spikes')
    # Without this option args.methods did not exist and every dataset failed
    # with AttributeError inside process_dataset().
    parser.add_argument('--methods', nargs='+', choices=method_keys, default=['triplet'],
                        help='Methods to train (one or more of: ' + ', '.join(method_keys) + ')')

    args = parser.parse_args()

    run_name = f"transform_all_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    logger = setup_logger(args.log_dir, run_name)

    ensure_writable_paths(args.models_dir, args.transformed_dir, logger)

    pairs = dataset_pairs_from_dir(args.data_dir)
    # Apply allowlist filter if provided
    if SELECTED_DATASETS:
        pairs = [p for p in pairs if p[0] in SELECTED_DATASETS]
    if not pairs:
        logger.error(f"No dataset pairs found in {args.data_dir}")
        sys.exit(1)

    logger.info(f"Found {len(pairs)} dataset pairs")
    failed: List[str] = []
    for dataset, base_path, query_path in tqdm(pairs, desc='Datasets'):
        try:
            process_dataset(dataset, base_path, query_path, args, logger)
        except Exception as e:
            # Keep going: one broken dataset should not abort the whole run.
            logger.exception(f"Failed on dataset {dataset}: {e}")
            failed.append(dataset)

    if failed:
        logger.error(f"{len(failed)} of {len(pairs)} datasets failed: {failed}")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()


'''
python3 transform_all_data.py \
  --data-dir /mnt/device/datasets \
  --models-dir /mnt/device/models \
  --transformed-dir /mnt/device/transformed \
  --log-dir /mnt/device/logs \
  --train-mem-gb 7.0
'''