#!/usr/bin/env python3
"""
Streamed training and transformation for ANN .fvecs datasets using linear transforms.

- Discovers dataset pairs: <name>_base.fvec[s], <name>_query.fvec[s]
- Loads only fractions of base for training/eval to respect memory limits
- Trains specified linear transform models in mini-batches
- Saves models to /mnt/device/models as <dataset>_<method>_<timestamp>.pth
- Transforms FULL base (train) and query (test) in streamed batches
  and saves to /mnt/device/transformed as
  <dataset>_train_<method>_<timestamp>.fvecs and
  <dataset>_test_<method>_<timestamp>.fvecs

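Memory budget: the train/eval slice is capped by --train-mem-gb (default 3 GB,
scaled by --train-safety-factor), and streamed transforms use batches of about
--transform-mb MB (default 64), so peak RAM stays within budget.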
"""

import os
import sys
import math
import time
import argparse
import logging
from datetime import datetime
from typing import Iterator, Tuple, List, Optional, Dict

import numpy as np
import struct
from tqdm import tqdm

import torch
import gc

from config import get_default_config
from linear_transform import (
    LinearTransformTrainer,
    SVDProjectionTransform,
    PenaltyTransform,
    ManifoldTransform,
    ExponentialMapTransform,
    CayleyTransform,
    GivensRotationTransform,
    HouseholderTransform,
    LossFunctions,
    EvaluationMetrics,
)


# ==========================
# Dataset allowlist
# ==========================
# Allowlist of dataset base names (e.g. 'gist', 'sift', 'glove100d').
# Only datasets named here are processed; an empty list processes every
# dataset pair discovered in --data-dir.
# Enable/disable entries by (un)commenting them. Keep a trailing comma on
# every entry so uncommenting one never concatenates two adjacent strings.
SELECTED_DATASETS: List[str] = [
    'glove300d_1_2m',
    # 'cifar10',
    # 'deep',
    # 'fashionmnist',
    # 'gist',
    # 'glove50d',
    # 'glove100d',
    # 'glove300d',
    # 'nytimes',
    # 'sift100m',
    # 'yorck',
]


# ==========================
# Logging utilities
# ==========================
def setup_logger(log_dir: str, run_name: str) -> logging.Logger:
    """Create (or reset) a named logger that writes to a file and to stdout.

    The log file is ``<log_dir>/<run_name>.log``; ``log_dir`` is created if
    needed. Calling this again with the same ``run_name`` replaces the
    logger's handlers instead of stacking duplicates, and closes the old ones
    so their file descriptors are released.

    Args:
        log_dir: Directory that holds the log file.
        run_name: Logger name and log-file stem.

    Returns:
        The configured logger (INFO level, not propagating to the root logger).
    """
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(run_name)
    logger.setLevel(logging.INFO)
    # Keep records out of the root logger's handlers (avoids double output).
    logger.propagate = False

    # Detach AND close existing handlers: removeHandler alone leaves a
    # FileHandler's stream open, leaking one descriptor per re-run.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    log_path = os.path.join(log_dir, f"{run_name}.log")
    fmt = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    for handler in (logging.FileHandler(log_path), logging.StreamHandler(sys.stdout)):
        handler.setLevel(logging.INFO)
        handler.setFormatter(fmt)
        logger.addHandler(handler)

    logger.info(f"Logging to {log_path}")
    return logger


# ==========================
# .fvecs streaming I/O
# ==========================
def get_fvecs_meta(file_path: str) -> Tuple[int, int]:
    """Return ``(dim, num_vectors)`` for an .fvecs file without loading its data.

    Each record is a little-endian uint32 dimension header followed by ``dim``
    4-byte floats. The dimension is read from the first record, and the vector
    count is derived from the file size. Trailing bytes that do not form a
    complete record are ignored (best effort).

    Raises:
        ValueError: if the file is shorter than one header or declares dim 0.
    """
    file_size = os.path.getsize(file_path)
    with open(file_path, 'rb') as f:
        dim_bytes = f.read(4)
    if len(dim_bytes) != 4:
        raise ValueError(f"Invalid fvecs file: {file_path}")
    dim = struct.unpack('<I', dim_bytes)[0]
    if dim == 0:
        # A zero dimension makes every "record" a bare 4-byte header, so the
        # derived count would be meaningless.
        raise ValueError(f"Invalid fvecs file (dimension 0): {file_path}")
    record_size = 4 + 4 * dim
    # Floor division silently drops an incomplete trailing record.
    return int(dim), int(file_size // record_size)


def read_fvecs_range(
    file_path: str,
    start_index: int,
    count: int,
    ) -> np.ndarray:
    """Read vectors ``[start_index, start_index + count)`` from an .fvecs file.

    The range is clipped to the complete records in the file, so a request
    past the end (or a non-positive ``count``) returns an empty ``(0, dim)``
    array. The result (``count * dim`` float32 values) is allocated up front
    and filled from bounded (~64 MiB) reads decoded with NumPy. This is far
    faster than unpacking every record with ``struct``, while adding at most
    one read-sized buffer of memory. Use for modest ranges only.

    Returns:
        A writable, C-contiguous float32 array of shape ``(n, dim)``.

    Raises:
        ValueError: if ``start_index`` is negative, the file is truncated
            inside the range, or a record's dimension header differs from the
            first record's.
    """
    dim, total = get_fvecs_meta(file_path)
    if start_index < 0:
        raise ValueError(f"start_index must be >= 0, got {start_index}")
    end_index = min(start_index + count, total)
    n = end_index - start_index
    if n <= 0:
        return np.empty((0, dim), dtype=np.float32)

    record_size = 4 + 4 * dim
    # Bound the transient raw buffer so large ranges don't double peak RAM.
    rows_per_read = max(1, (64 * 1024 * 1024) // record_size)
    vectors = np.empty((n, dim), dtype=np.float32)
    with open(file_path, 'rb') as f:
        f.seek(start_index * record_size)
        for lo in range(0, n, rows_per_read):
            rows = min(rows_per_read, n - lo)
            nbytes = rows * record_size
            raw = f.read(nbytes)
            if len(raw) != nbytes:
                raise ValueError(
                    f"Truncated fvecs file {file_path}: expected {nbytes} bytes at vector {start_index + lo}, got {len(raw)}"
                )
            # View the bytes as a (rows, dim + 1) matrix: column 0 is the
            # uint32 header, columns 1.. are the float32 components.
            headers = np.frombuffer(raw, dtype='<u4').reshape(rows, dim + 1)[:, 0]
            bad = np.flatnonzero(headers != dim)
            if bad.size:
                raise ValueError(
                    f"Dimension mismatch in {file_path}: expected {dim}, got {int(headers[bad[0]])}"
                )
            # Copy out of the read-only bytes into the native float32 result.
            vectors[lo:lo + rows] = np.frombuffer(raw, dtype='<f4').reshape(rows, dim + 1)[:, 1:]
    return vectors


def stream_fvecs(
    file_path: str,
    batch_vectors: int,
    start_index: int = 0,
    max_vectors: int = -1,
    ) -> Iterator[np.ndarray]:
    """Yield consecutive ``(n_batch, dim)`` float32 batches from an .fvecs file.

    Only one batch (plus its raw bytes while it is decoded) is held at a time.
    Each batch is read with a single ``read`` call and decoded with NumPy
    rather than unpacked record by record.

    Args:
        file_path: Path to the .fvecs file.
        batch_vectors: Maximum vectors per batch; must be >= 1.
        start_index: Index of the first vector to yield; must be >= 0.
        max_vectors: Maximum total vectors to yield; ``<= 0`` means "all".

    Yields:
        Writable, C-contiguous float32 arrays; the last batch may be shorter.

    Raises:
        ValueError: raised on first iteration, since this is a generator. Cases:
            ``batch_vectors < 1`` (previously this looped forever yielding empty
            batches), ``start_index < 0``, a truncated file, or a record whose
            dimension header differs from the first record's.
    """
    dim, total = get_fvecs_meta(file_path)
    batch_vectors = int(batch_vectors)
    if batch_vectors < 1:
        raise ValueError(f"batch_vectors must be >= 1, got {batch_vectors}")
    if start_index < 0:
        raise ValueError(f"start_index must be >= 0, got {start_index}")

    remaining = total - start_index
    if max_vectors > 0:
        remaining = min(remaining, max_vectors)
    if remaining <= 0:
        return

    record_size = 4 + 4 * dim
    with open(file_path, 'rb') as f:
        f.seek(start_index * record_size)
        produced = 0
        while produced < remaining:
            n = min(batch_vectors, remaining - produced)
            nbytes = n * record_size
            raw = f.read(nbytes)
            if len(raw) != nbytes:
                raise ValueError(
                    f"Truncated fvecs file {file_path}: expected {nbytes} bytes at vector {start_index + produced}, got {len(raw)}"
                )
            # Column 0 of each (dim + 1)-wide record is the uint32 header.
            headers = np.frombuffer(raw, dtype='<u4').reshape(n, dim + 1)[:, 0]
            bad = np.flatnonzero(headers != dim)
            if bad.size:
                raise ValueError(
                    f"Dimension mismatch in {file_path}: expected {dim}, got {int(headers[bad[0]])}"
                )
            values = np.frombuffer(raw, dtype='<f4').reshape(n, dim + 1)[:, 1:]
            batch = np.array(values, dtype=np.float32, order='C')
            # Release the raw buffer before handing control to the consumer.
            del raw, headers, values
            produced += n
            yield batch


def write_fvecs_stream(
    file_path: str,
    arrays_iter: Iterator[np.ndarray],
    expected_dim: Optional[int] = None,
    ) -> None:
    """Write 2-D batches from ``arrays_iter`` to ``file_path`` in .fvecs format.

    The file is created (or truncated) even if the iterator yields nothing.
    Each row is stored as a little-endian uint32 dimension header followed by
    its float32 values. Rows are encoded in vectorized slices of about 64 MiB
    and written with one call per slice, instead of two writes per row.

    Args:
        file_path: Output path; its parent directory is created if needed.
        arrays_iter: Iterable of ``(n, dim)`` arrays; empty arrays are skipped.
            Values are cast to float32.
        expected_dim: Required row width; if ``None`` it is taken from the
            first non-empty batch.

    Raises:
        ValueError: if a batch is not 2-D or its width differs from ``expected_dim``.
    """
    parent = os.path.dirname(file_path)
    if parent:
        # os.makedirs('') raises, so only create a directory when one is named.
        os.makedirs(parent, exist_ok=True)
    with open(file_path, 'wb') as f:
        for arr in arrays_iter:
            if arr.size == 0:
                continue
            if arr.ndim != 2:
                raise ValueError(f"Expected a 2-D batch, got shape {arr.shape}")
            if expected_dim is None:
                expected_dim = arr.shape[1]
            if arr.shape[1] != expected_dim:
                raise ValueError(f"Output dim mismatch: got {arr.shape[1]}, expected {expected_dim}")
            rows_per_write = max(1, (64 * 1024 * 1024) // (4 * (expected_dim + 1)))
            for lo in range(0, arr.shape[0], rows_per_write):
                part = arr[lo:lo + rows_per_write]
                records = np.empty((part.shape[0], expected_dim + 1), dtype='<f4')
                # Same bytes viewed as uint32 so column 0 holds the integer header.
                records.view('<u4')[:, 0] = expected_dim
                records[:, 1:] = part  # casts to float32, handles any strides
                f.write(records.tobytes())
            # Drop references before pulling the next batch so peak memory stays
            # around one batch; refcounting frees them, no gc.collect() needed.
            del arr, part, records


# ==========================
# Utility helpers
# ==========================
def bytes_for_vectors(num_vectors: int, dim: int) -> int:
    """Return the byte size of ``num_vectors`` float32 vectors of width ``dim``.

    Both arguments are truncated with ``int()`` first, so float inputs are accepted.
    """
    bytes_per_float32 = 4
    total_floats = int(num_vectors) * int(dim)
    return total_floats * bytes_per_float32


def choose_stream_batch_size(dim: int, target_mb: int = 256) -> int:
    """Return how many float32 vectors of width ``dim`` fit in ~``target_mb`` MiB.

    The result is never below 1, so very wide vectors still make progress.
    """
    budget_bytes = target_mb * 1024 * 1024
    per_vector_bytes = 4 * dim
    fitting = budget_bytes // per_vector_bytes
    return int(max(fitting, 1))


def ensure_writable_paths(models_dir: str, transformed_dir: str, logger: logging.Logger) -> None:
    """Create the model and transformed-output directories if they are missing.

    A missing directory is created with ``os.makedirs``. If that is denied,
    ``sudo -n mkdir -p`` is attempted, followed by ``sudo -n chown -R`` to the
    current uid:gid. ``-n`` makes sudo fail instead of prompting, so this
    fallback only works where passwordless sudo is configured. Failures are
    logged, not raised (best effort), so later writes into the directory
    surface the real error. An existing path that is not writable only
    produces a warning.
    """
    import subprocess  # local: only needed on the rare sudo fallback path

    for d in (models_dir, transformed_dir):
        if os.path.exists(d):
            if not os.access(d, os.W_OK):
                logger.warning(f"Directory {d} exists but is not writable by the current user")
            continue
        try:
            os.makedirs(d, exist_ok=True)
            continue
        except PermissionError:
            logger.info(f"Attempting sudo mkdir/chown for {d}")

        # Argument lists (no shell) avoid injection through the path, and no
        # sudo password is ever embedded in the script.
        owner = f"{os.getuid()}:{os.getgid()}"
        commands = (
            ['sudo', '-n', 'mkdir', '-p', '--', d],
            ['sudo', '-n', 'chown', '-R', '--', owner, d],
        )
        for cmd in commands:
            try:
                result = subprocess.run(cmd, capture_output=True, text=True, check=False)
            except OSError as e:  # e.g. sudo not installed
                logger.error(f"Could not run sudo {cmd[2]} for {d}: {e}")
                break
            if result.returncode != 0:
                logger.error(
                    f"sudo {cmd[2]} failed for {d} (exit {result.returncode}): {result.stderr.strip()}"
                )
                break


def dataset_pairs_from_dir(data_dir: str) -> List[Tuple[str, str, str]]:
    """Return ``(dataset_name, base_path, query_path)`` for every complete pair.

    A dataset is included only when both ``<name>_base.fvec[s]`` and
    ``<name>_query.fvec[s]`` exist in ``data_dir``. File names are scanned in
    sorted order, so when both extensions exist the ``.fvecs`` file wins
    deterministically, regardless of ``os.listdir`` order. Results are sorted by
    dataset name.
    """
    base_suffixes = ('_base.fvecs', '_base.fvec')
    query_suffixes = ('_query.fvecs', '_query.fvec')

    def strip_suffix(name: str, suffixes: Tuple[str, ...]) -> Optional[str]:
        # Return the dataset stem if ``name`` ends with one of ``suffixes``.
        for s in suffixes:
            if name.endswith(s):
                return name[:-len(s)]
        return None

    bases: Dict[str, str] = {}
    queries: Dict[str, str] = {}
    # Sorting puts 'x_base.fvec' before 'x_base.fvecs', so the latter overwrites.
    for fname in sorted(os.listdir(data_dir)):
        path = os.path.join(data_dir, fname)
        stem = strip_suffix(fname, base_suffixes)
        if stem is not None:
            bases[stem] = path
            continue
        stem = strip_suffix(fname, query_suffixes)
        if stem is not None:
            queries[stem] = path

    return [(k, bases[k], queries[k]) for k in sorted(bases.keys() & queries.keys())]


# ==========================
# Core pipeline
# ==========================
def train_method(
    method: str,
    config: Dict,
    train_np: np.ndarray,
    eval_np: np.ndarray,
    logger: logging.Logger,
    ) -> Tuple[torch.nn.Module, Dict, Dict, Optional[np.ndarray], Optional[np.ndarray], int]:
    """Train one linear-transform method on in-memory train/eval slices.

    Args:
        method: Display name of the method; must be a key of ``model_classes``
            below (e.g. ``'SVD Projection'``, ``'Penalty Method'``).
        config: Result of ``get_default_config()``. ``config['training']`` must
            provide LOSS_TYPE, TARGET_A, EPOCHS, LEARNING_RATE and N_PAIRS_CHECK.
            ``config['penalty']['LAMBDA']`` is read for the penalty method.
        train_np: Training vectors; 2-D, with ``train_np.shape[1]`` as the dim.
        eval_np: Evaluation vectors with the same dim.
        logger: Logger for progress and warnings.

    Returns:
        ``(trained_model, history, metrics, None, None, output_dim)``. The two
        ``None`` slots keep the historical 6-tuple shape; the transformed
        train/eval arrays are deliberately dropped here to reduce peak memory.
        ``metrics`` is empty if evaluation failed.

    Raises:
        ValueError: for an unsupported ``method`` or unknown loss type.
    """
    # Every method is constructed the same way; only the class differs.
    model_classes = {
        'SVD Projection': SVDProjectionTransform,
        'Penalty Method': PenaltyTransform,
        'Manifold Optimization': ManifoldTransform,
        'Exponential Map': ExponentialMapTransform,
        'Cayley Transform': CayleyTransform,
        'Givens Rotations': GivensRotationTransform,
        'Householder Reflections': HouseholderTransform,
    }
    if method not in model_classes:
        raise ValueError(f"Unsupported method: {method}")

    training = config['training']
    target_a = training.TARGET_A
    if training.LOSS_TYPE == 'cumulative_energy':
        loss_fn = lambda v: LossFunctions.cumulative_energy_shape(v, target_a)
    elif training.LOSS_TYPE == 'exponential':
        loss_fn = lambda v: LossFunctions.exponential_shape_loss(v, target_a)
    else:
        raise ValueError(f"Unknown loss type: {training.LOSS_TYPE}")

    trainer = LinearTransformTrainer(config)
    dim = train_np.shape[1]
    model = model_classes[method](dim)
    # Only the penalty method uses an explicit orthogonality weight.
    ortho_lambda = config['penalty']['LAMBDA'] if method == 'Penalty Method' else None
    data_train = torch.from_numpy(train_np).float()
    data_eval = torch.from_numpy(eval_np).float()
    output_dim = dim  # every supported method maps dim -> dim

    logger.info(f"Training {method}: train {data_train.shape}, eval {data_eval.shape}, epochs={training.EPOCHS}, lr={training.LEARNING_RATE}")

    transformed_train, transformed_eval, train_time, history, trained_model = trainer.train_model(
        method, model, data_train, data_eval, loss_fn, training.EPOCHS, training.LEARNING_RATE, ortho_lambda
    )

    logger.info(f"{method} training time: {train_time:.2f}s")

    # Basic evaluation metrics on the eval split; failures are non-fatal.
    try:
        eval_orig_np = data_eval.numpy()
        metrics: Dict = {}
        metrics['energy'] = EvaluationMetrics.calculate_cumulative_energy(transformed_eval)
        metrics['distance'] = EvaluationMetrics.evaluate_distance_preservation(
            eval_orig_np, transformed_eval, training.N_PAIRS_CHECK
        )
    except Exception as e:
        logger.warning(f"Evaluation metrics failed for {method}: {e}")
        metrics = {}

    # Free the large transformed arrays before returning to reduce peak memory.
    del transformed_train, transformed_eval
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return trained_model, history, metrics, None, None, output_dim


def transform_streamed(
    model: torch.nn.Module,
    input_path: str,
    output_path: str,
    logger: logging.Logger,
    show_tqdm: bool = True,
    device: Optional[torch.device] = None,
    target_mb: int = 64,
    force_cpu: bool = False,
    transform_frac: float = 1.0,
    ) -> Tuple[int, int]:
    """Apply ``model`` to every vector of an .fvecs file and write an .fvecs result.

    ``model`` is put in eval mode and moved to ``device``. When ``device`` is
    None, it defaults to CUDA if available, unless ``force_cpu`` is set.
    With ``transform_frac < 1.0``, the work is delegated to
    ``transform_chunked``. Otherwise the input is streamed in batches of
    roughly ``target_mb`` MB and written through ``write_fvecs_stream``.

    Returns:
        ``(num_input_vectors, output_dim)``; the output dim is assumed equal
        to the input dim.
    """
    model.eval()
    if device is None:
        use_cuda = not force_cpu and torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu')
    model.to(device)

    dim_in, total = get_fvecs_meta(input_path)
    out_dim = dim_in  # linear transforms preserve dimension

    if transform_frac < 1.0:
        # Large explicit chunks instead of small streamed batches.
        return transform_chunked(model, input_path, output_path, transform_frac, device, logger, show_tqdm)

    vectors_per_batch = choose_stream_batch_size(out_dim, target_mb=target_mb)

    def _apply(batch: np.ndarray) -> np.ndarray:
        # Run one batch through the model and bring the result back to host RAM.
        with torch.no_grad():
            inputs = torch.from_numpy(batch).float().to(device)
            result = model(inputs).detach().cpu().numpy()
            del inputs  # release the device copy promptly
            if torch.cuda.is_available() and device.type == 'cuda':
                torch.cuda.empty_cache()
        return result

    def _transformed_batches() -> Iterator[np.ndarray]:
        # Lazily read, transform and hand over one batch at a time.
        source = stream_fvecs(input_path, vectors_per_batch)
        for chunk in tqdm(source, desc=f"Transform {os.path.basename(input_path)}", disable=not show_tqdm):
            yield _apply(chunk)

    write_fvecs_stream(output_path, _transformed_batches(), expected_dim=out_dim)
    try:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        pass
    return total, out_dim


def transform_chunked(
    model: torch.nn.Module,
    input_path: str,
    output_path: str,
    chunk_fraction: float,
    device: torch.device,
    logger: logging.Logger,
    show_tqdm: bool = True,
    ) -> Tuple[int, int]:
    """Transform an .fvecs file in chunks of ``chunk_fraction`` of its vectors.

    Each chunk is read with ``read_fvecs_range``, passed through ``model`` on
    ``device`` under ``torch.no_grad()``, and appended to ``output_path``.
    The output file is created/truncated up front, so an empty input still
    yields an (empty) output file, like the streaming path does.

    Returns:
        ``(num_vectors, dim)``; the output dim is assumed equal to the input dim.
    """
    logger.info(f"Using chunked transformation with {chunk_fraction*100:.1f}% chunks")

    dim, total_vectors = get_fvecs_meta(input_path)
    chunk_size = max(1, int(total_vectors * chunk_fraction))

    logger.info(f"Dataset: {total_vectors} vectors, dim={dim}, chunk_size={chunk_size}")

    # Create/truncate the output once; every chunk below is appended.
    write_fvecs_chunk(output_path, np.empty((0, dim), dtype=np.float32), append=False)
    on_cuda = device.type == 'cuda'

    processed = 0
    with tqdm(total=total_vectors, desc=f"Transform {os.path.basename(input_path)}", disable=not show_tqdm) as pbar:
        while processed < total_vectors:
            current_chunk_size = min(chunk_size, total_vectors - processed)
            chunk_data = read_fvecs_range(input_path, processed, current_chunk_size)

            with torch.no_grad():
                chunk_tensor = torch.from_numpy(chunk_data).float().to(device)
                transformed_data = model(chunk_tensor).cpu().numpy()
                # Free device memory immediately; only meaningful on CUDA.
                del chunk_tensor
                if on_cuda:
                    torch.cuda.empty_cache()

            write_fvecs_chunk(output_path, transformed_data, append=True)

            # Chunks can be large, so release them before reading the next one.
            del chunk_data, transformed_data
            gc.collect()

            processed += current_chunk_size
            pbar.update(current_chunk_size)

    return total_vectors, dim


def write_fvecs_chunk(file_path: str, vectors: np.ndarray, append: bool = False) -> None:
    """Write (or append) a 2-D array of vectors to an .fvecs file.

    Each row becomes a little-endian uint32 dimension header followed by its
    float32 values. Rows are encoded in vectorized ~64 MiB slices and written
    with one call per slice, so memory overhead stays bounded even for very
    large chunks. Writing a ``(0, dim)`` array with ``append=False`` just
    creates/truncates the file.

    Args:
        file_path: Output path; its parent directory is created if needed.
        vectors: ``(n, dim)`` array; values are cast to float32.
        append: Append to an existing file instead of truncating it.

    Raises:
        ValueError: if ``vectors`` is not 2-D.
    """
    parent = os.path.dirname(file_path)
    if parent:
        # os.makedirs('') raises, so only create a directory when one is named.
        os.makedirs(parent, exist_ok=True)

    vectors = np.asarray(vectors)
    if vectors.ndim != 2:
        raise ValueError(f"Expected a 2-D array of vectors, got shape {vectors.shape}")
    n, dim = vectors.shape
    rows_per_write = max(1, (64 * 1024 * 1024) // (4 * (dim + 1)))

    with open(file_path, 'ab' if append else 'wb') as f:
        for lo in range(0, n, rows_per_write):
            part = vectors[lo:lo + rows_per_write]
            records = np.empty((part.shape[0], dim + 1), dtype='<f4')
            # Same bytes viewed as uint32 so column 0 holds the integer header.
            records.view('<u4')[:, 0] = dim
            records[:, 1:] = part  # casts to float32, handles any strides
            f.write(records.tobytes())


def _resolve_train_eval_counts(
    num_base: int,
    dim_base: int,
    args: argparse.Namespace,
    logger: logging.Logger,
    ) -> Tuple[int, int]:
    """Choose how many base vectors to load for training and evaluation.

    Starts from auto (size-based) or manual fractions. Both counts are then
    shrunk proportionally to fit ``--train-mem-gb`` (inflated by
    ``--train-safety-factor``) and the optional ``--max-train-eval`` hard cap.
    """
    if getattr(args, 'auto_fractions', False):
        if num_base > 45000:
            # Large datasets: fixed slices (40k train + 5k eval always fit).
            train_count = 40000
            eval_count = 5000
            logger.info(f"Auto fractions (LARGE): train={train_count}, eval={eval_count}")
        else:
            # Small datasets: 80% train, remaining 20% eval.
            train_count = max(1, int(0.8 * num_base))
            eval_count = max(1, num_base - train_count)
            logger.info(f"Auto fractions (SMALL): train=80% ({train_count}), eval=20% ({eval_count})")
    else:
        f_train = args.f_train
        f_eval = args.f_eval
        train_count = int(num_base * f_train)
        eval_count = int(num_base * f_eval)
        logger.info(f"Manual fractions: train={f_train*100:.2f}% ({train_count}), eval={f_eval*100:.2f}% ({eval_count})")
    if train_count < 1 or eval_count < 1:
        logger.warning(f"Very small train/eval sizes (train={train_count}, eval={eval_count}). Consider increasing fractions.")

    # Memory guard: numpy + tensor + gradients duplicate the raw slice size.
    max_train_bytes = int(args.train_mem_gb * (1024**3))
    duplication_factor = float(getattr(args, 'train_safety_factor', 3.5))
    bytes_needed = int(duplication_factor * bytes_for_vectors(train_count + eval_count, dim_base))
    if bytes_needed > max_train_bytes:
        scale = max_train_bytes / max(1, bytes_needed)
        new_train = max(1, int(train_count * scale))
        new_eval = max(1, int(eval_count * scale))
        logger.info(f"Capping train/eval due to memory: {train_count}->{new_train}, {eval_count}->{new_eval}")
        train_count, eval_count = new_train, new_eval

    # Hard cap on the absolute number of vectors, if provided (0 disables).
    max_total = getattr(args, 'max_train_eval', 0)
    if max_total and (train_count + eval_count) > max_total:
        total = train_count + eval_count
        scale = max_total / float(total)
        new_train = max(1, int(train_count * scale))
        new_eval = max(1, int(eval_count * scale))
        logger.info(f"Hard cap applied: total {total}->{max_total} vectors => train {train_count}->{new_train}, eval {eval_count}->{new_eval}")
        train_count, eval_count = new_train, new_eval

    return train_count, eval_count


def _selected_methods(requested: List[str]) -> List[str]:
    """Map CLI method keys to trainer display names, in a fixed canonical order."""
    catalogue = [
        ('svd', 'SVD Projection'),
        ('penalty', 'Penalty Method'),
        ('manifold', 'Manifold Optimization'),
        ('exponential', 'Exponential Map'),
        ('cayley', 'Cayley Transform'),
        ('givens', 'Givens Rotations'),
        ('householder', 'Householder Reflections'),
    ]
    return [name for key, name in catalogue if key in requested]


def _save_model(
    model: torch.nn.Module,
    method: str,
    dataset: str,
    timestamp: str,
    base_path: str,
    config: Dict,
    out_dim: int,
    models_dir: str,
    logger: logging.Logger,
    ) -> str:
    """Save ``model`` plus run metadata to ``<models_dir>/<dataset>_<method>_<timestamp>.pth``."""
    model_fname = f"{dataset}_{method.replace(' ', '_')}_{timestamp}.pth"
    model_path = os.path.join(models_dir, model_fname)
    training = config['training']
    # Plain-Python config snapshot (avoids pickling project config objects).
    save_config = {
        'loss_type': training.LOSS_TYPE,
        'target_a': training.TARGET_A,
        'epochs': training.EPOCHS,
        'learning_rate': training.LEARNING_RATE,
        'batch_size': training.BATCH_SIZE,
        'penalty_lambda': config.get('penalty', {}).get('LAMBDA', None),
    }
    model_data = {
        'model_state_dict': model.state_dict(),
        'model_class': type(model).__name__,
        'method_name': method,
        'timestamp': timestamp,
        'original_file': base_path,
        'config': save_config,
        'input_dim': out_dim,
        'target_dim': out_dim,
    }
    torch.save(model_data, model_path)
    logger.info(f"Saved model -> {model_path}")
    return model_path


def _log_training_summary(method: str, history: Dict, metrics: Dict, logger: logging.Logger) -> None:
    """Log per-epoch loss/penalty and the eval distance correlation.

    Missing or non-numeric history entries are logged as ``n/a`` instead of
    raising, so a logging problem cannot abort the remaining methods.
    """
    def _fmt(values, idx: int) -> str:
        try:
            return f"{float(values[idx]):.6f}"
        except (IndexError, KeyError, TypeError, ValueError):
            return "n/a"

    if history:
        epochs = history.get('epoch', [])
        losses = history.get('loss', [])
        penalties = history.get('ortho_penalty', [])
        for ep in range(len(epochs)):
            logger.info(
                f"epoch={epochs[ep]} loss={_fmt(losses, ep)} "
                f"ortho_penalty={_fmt(penalties, ep)}"
            )
    if metrics:
        try:
            corr = metrics.get('distance', {}).get('corr_coef', None)
            logger.info(f"Eval metrics [{method}]: dist_corr={corr}")
        except Exception:
            logger.info(f"Eval metrics [{method}] available keys: {list(metrics.keys())}")


def process_dataset(
    dataset: str,
    base_path: str,
    query_path: str,
    args: argparse.Namespace,
    logger: logging.Logger,
    ) -> None:
    """Train each requested method on one dataset, save it, and transform the data.

    For every method selected in ``args.methods``:
    1. Load a train slice and an eval slice from the start of the base file.
    2. Train the method and save the model to ``args.models_dir``.
    3. Stream-transform the FULL base and query files into ``args.transformed_dir``.
    4. Log the training history and eval metrics.
    """
    logger.info(f"\n===== Dataset: {dataset} =====")
    logger.info(f"Base:  {base_path}")
    logger.info(f"Query: {query_path}")

    dim_base, num_base = get_fvecs_meta(base_path)
    logger.info(f"Base meta: {num_base} vectors, dim={dim_base}")

    config = get_default_config()
    training = config['training']
    training.LOSS_TYPE = args.loss_type
    training.EPOCHS = args.epochs
    training.LEARNING_RATE = args.learning_rate
    training.BATCH_SIZE = args.batch_size

    # Plain dict (not a local class) so the config stays picklable.
    # NOTE(review): if get_default_config() already defines 'penalty', the
    # --penalty-lambda CLI value is ignored — confirm that is intended.
    if 'penalty' not in config:
        config['penalty'] = {'LAMBDA': args.penalty_lambda}

    train_count, eval_count = _resolve_train_eval_counts(num_base, dim_base, args, logger)
    methods = _selected_methods(args.methods)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    transform_kwargs = dict(
        show_tqdm=True,
        target_mb=int(args.transform_mb),
        force_cpu=bool(args.cpu_transform),
        transform_frac=args.transform_frac,
    )

    for method in methods:
        slug = method.replace(' ', '_')
        logger.info(f"\n--- Training method: {method} ---")
        # Reload the slices per method so nothing large survives between methods.
        logger.info(f"Loading training slice: first {train_count} vectors")
        train_np = read_fvecs_range(base_path, 0, train_count)
        logger.info(f"Loading eval slice: next {eval_count} vectors")
        eval_np = read_fvecs_range(base_path, train_count, eval_count)

        model, history, metrics, _, _, out_dim = train_method(method, config, train_np, eval_np, logger)
        _save_model(model, method, dataset, timestamp, base_path, config, out_dim, args.models_dir, logger)

        # Free training/eval arrays before the full transforms to reduce peak RAM.
        del train_np, eval_np
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        base_out = os.path.join(args.transformed_dir, f"{dataset}_train_{slug}_{timestamp}.fvecs")
        query_out = os.path.join(args.transformed_dir, f"{dataset}_test_{slug}_{timestamp}.fvecs")

        logger.info(f"Transforming FULL base -> {base_out}")
        total_base, out_dim_b = transform_streamed(model, base_path, base_out, logger, **transform_kwargs)
        logger.info(f"Base transformed: {total_base} vectors to dim={out_dim_b}")

        logger.info(f"Transforming FULL query -> {query_out}")
        total_query, out_dim_q = transform_streamed(model, query_path, query_out, logger, **transform_kwargs)
        logger.info(f"Query transformed: {total_query} vectors to dim={out_dim_q}")

        _log_training_summary(method, history, metrics, logger)


def main():
    """CLI entry point: parse arguments, discover dataset pairs, process each one.

    A failure on one dataset is logged and the run continues with the next;
    a summary of failed datasets is logged at the end.
    """
    parser = argparse.ArgumentParser(description='Train and transform all .fvecs datasets using linear transforms (streamed)')
    parser.add_argument('--data-dir', required=True, help='Directory containing *_base.fvec[s] and *_query.fvec[s] files')
    parser.add_argument('--models-dir', default='/mnt/device/models', help='Output directory for models')
    parser.add_argument('--transformed-dir', default='/mnt/device/transformed', help='Output directory for transformed fvecs')
    parser.add_argument('--log-dir', default='/mnt/device/logs', help='Log directory')
    # NOTE: --auto-fractions is already the default; the flag is kept for CLI compatibility.
    parser.add_argument('--auto-fractions', action='store_true', default=True, help='Automatically choose train/eval fractions based on dataset size')
    parser.add_argument('--no-auto-fractions', dest='auto_fractions', action='store_false', help='Disable auto fractions and use manual fractions')
    parser.add_argument('--f-train', dest='f_train', type=float, default=0.005, help='Manual fraction of base used for training (used if --no-auto-fractions)')
    parser.add_argument('--f-eval', dest='f_eval', type=float, default=0.001, help='Manual fraction of base used for eval (used if --no-auto-fractions)')
    parser.add_argument('--batch-size', type=int, default=256, help='Training mini-batch size')
    parser.add_argument('--epochs', type=int, default=50, help='Training epochs for linear transform models')
    parser.add_argument('--learning-rate', type=float, default=0.005, help='Learning rate for linear transform models')
    parser.add_argument('--loss-type', choices=['cumulative_energy', 'exponential'], default='cumulative_energy', help='Loss function type')
    parser.add_argument('--penalty-lambda', type=float, default=1.0, help='Orthogonality penalty weight for penalty method')
    parser.add_argument('--methods', nargs='+', default=['svd', 'penalty'], choices=['svd', 'penalty', 'manifold', 'exponential', 'cayley', 'givens', 'householder'], help='Linear transform methods to train')
    parser.add_argument('--train-mem-gb', type=float, default=3.0, help='Max RAM (GB) for train+eval slice')
    parser.add_argument('--transform-mb', type=int, default=64, help='Approx MB of RAM for streamed transform batches')
    parser.add_argument('--train-safety-factor', type=float, default=3.5, help='Multiplier for RAM estimate during training sizing')
    parser.add_argument('--max-train-eval', type=int, default=0, help='Hard cap on total (train+eval) vectors; 0 disables')
    parser.add_argument('--cpu-transform', action='store_true', help='Run streamed transforms on CPU to reduce GPU memory spikes')
    parser.add_argument('--transform-frac', type=float, default=1.0, help='Fraction of dataset to process at once during transformation (0.0-1.0, default: 1.0 for full streaming)')

    args = parser.parse_args()

    # Validate arguments before creating any output (errors go to stderr).
    if not (0.0 < args.transform_frac <= 1.0):
        print(f"Error: --transform-frac must be between 0.0 and 1.0, got {args.transform_frac}", file=sys.stderr)
        sys.exit(1)

    run_name = f"transform_linear_all_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    logger = setup_logger(args.log_dir, run_name)

    ensure_writable_paths(args.models_dir, args.transformed_dir, logger)

    if not os.path.isdir(args.data_dir):
        logger.error(f"Data directory does not exist or is not a directory: {args.data_dir}")
        sys.exit(1)

    all_pairs = dataset_pairs_from_dir(args.data_dir)
    # Apply the allowlist only when it is non-empty.
    pairs = [p for p in all_pairs if p[0] in SELECTED_DATASETS] if SELECTED_DATASETS else all_pairs
    if not pairs:
        if all_pairs:
            # Pairs exist but the allowlist filtered them all out.
            logger.error(
                f"No dataset pairs in {args.data_dir} match SELECTED_DATASETS={SELECTED_DATASETS}; "
                f"available: {[p[0] for p in all_pairs]}"
            )
        else:
            logger.error(f"No dataset pairs found in {args.data_dir}")
        sys.exit(1)

    logger.info(f"Found {len(pairs)} dataset pairs")
    failed: List[str] = []
    for dataset, base_path, query_path in tqdm(pairs, desc='Datasets'):
        try:
            process_dataset(dataset, base_path, query_path, args, logger)
        except Exception as e:
            # Keep going with the remaining datasets; the traceback is logged.
            logger.exception(f"Failed on dataset {dataset}: {e}")
            failed.append(dataset)

    if failed:
        logger.error(f"{len(failed)}/{len(pairs)} dataset(s) failed: {failed}")


if __name__ == '__main__':
    main()



# Example invocation (no trailing whitespace after any line-continuation backslash):
# python3 transform_linear_all_data.py \
#   --data-dir /mnt/device/datasets \
#   --models-dir /mnt/device/models \
#   --transformed-dir /mnt/device/transformed \
#   --log-dir /mnt/device/logs \
#   --train-mem-gb 7.0 \
#   --methods svd penalty manifold exponential cayley givens householder \
#   --loss-type cumulative_energy \
#   --epochs 100 \
#   --learning-rate 0.001 \
#   --transform-frac 0.1

