#!/usr/bin/env python
"""
Parallel Data Preprocessing for Transformer Model Activations with Sharding Support

This script processes text datasets through transformer models to extract intermediate
activations (hidden states) from specified layers, utilizing multiple GPUs for parallel
processing. The extracted activations are saved in sharded HDF5 files (or, with
--output_format arrow/parquet, in sharded Arrow/Parquet files) for efficient storage,
retrieval, and distributed processing.

Sharding Structure:
==================
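The dataset is split into multiple HDF5 files (shards), each containing a subset of sequences.
Sequences are assigned to shards in the order in which their batches are handed to the writer.
With more than one GPU this interleaves the GPUs' ranges, so "sequence N" below means the N-th
sequence written, not necessarily row N of the input dataset.
For example, with 1M sequences and 100k sequences per shard, you get 10 shard files:
- output_L26_shard_00000.h5 (sequences 0-99,999)
- output_L26_shard_00001.h5 (sequences 100,000-199,999)
- ...
- output_L26_shard_00009.h5 (sequences 900,000-999,999)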

Each Shard File Structure:
==========================
1. Main Dataset: "non_padding_cache"
   - Shape: (total_tokens_in_shard, hidden_dim)
   - Type: float32
   - Description: Concatenated activations of all kept tokens in this shard (padding tokens,
     and BOS tokens when --remove_bos is set, are excluded)

2. Indexing Datasets:
   a. "start": Starting indices within this shard's non_padding_cache
   b. "end": Ending indices within this shard's non_padding_cache
   c. "length": Number of non-padding tokens per sequence
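   d. "global_seq_start": Write-order position of the first sequence in this shard
   e. "global_seq_end": Write-order position of the last sequence + 1 in this shard
      (these equal dataset row indices only when batches are written in dataset order,
      e.g. single-GPU runs)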

3. Metadata (attributes):
   - "shard_id": Index of this shard
   - "total_shards": Total number of shards
   - "sequences_in_shard": Number of sequences in this shard
   - "hidden_dim": Dimension of the hidden states
   - "layer": Layer index the activations were taken from

Example Usage:
=============
```python
import h5py
import glob

# Load all shards
shard_files = sorted(glob.glob('output_L26_shard_*.h5'))

# Process each shard
for shard_file in shard_files:
    with h5py.File(shard_file, 'r') as f:
        # Access shard metadata
        shard_id = f.attrs['shard_id']
        global_start = f['global_seq_start'][()]
        global_end = f['global_seq_end'][()]

        # Access activations in this shard
        activations = f['non_padding_cache'][:]
        starts = f['start'][:]
        ends = f['end'][:]

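        # Process sequences in this shard
        for local_idx in range(len(starts)):
            # Write-order position; equals the dataset row only when batches were
            # written in dataset order (e.g. a single-GPU run)
            global_idx = global_start + local_idx
            seq_activations = activations[starts[local_idx]:ends[local_idx]]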
```
"""

import os
import argparse
import numpy as np
import torch
import torch.multiprocessing as mp
import h5py
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer
from datasets import load_dataset, Dataset
import math
from threading import Lock
import time
import pyarrow as pa
import pyarrow.parquet as pq
import json

def autodetect_pad_id(tokenizer):
    """Autodetect the padding token ID from the tokenizer.

    Resolution order:
      1. ``tokenizer.pad_token_id``
      2. ``tokenizer.eos_token_id``
      3. the id of ``<|endoftext|>`` if it is a single entry of the vocabulary
      4. ``0`` as a last resort

    Step 3 looks the token up in ``tokenizer.get_vocab()`` rather than tokenizing the
    string. For vocabularies that lack the token, tokenizing returns the id of its first
    sub-word piece (e.g. ``<``), which would then be masked out everywhere as "padding".

    Args:
        tokenizer: a Hugging Face tokenizer (or any object exposing ``pad_token_id``,
            ``eos_token_id`` and optionally ``get_vocab()``).

    Returns:
        The integer token id to treat as padding.
    """
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id

    if tokenizer.eos_token_id is not None:
        return tokenizer.eos_token_id

    try:
        vocab = tokenizer.get_vocab()
    except (AttributeError, NotImplementedError):
        # Tokenizer cannot expose its vocabulary; fall through to the default.
        vocab = {}

    if '<|endoftext|>' in vocab:
        return vocab['<|endoftext|>']

    return 0

def autodetect_bos_id(tokenizer):
    """Autodetect the BOS (beginning of sequence) token ID from the tokenizer.

    Uses ``tokenizer.bos_token_id`` when set. Otherwise falls back to the common BOS
    spellings ``<|startoftext|>`` and ``<s>`` (in that order), but only when they are
    single entries of ``tokenizer.get_vocab()``. Tokenizing the literal string instead
    would return the id of its first sub-word piece (e.g. ``<``) for vocabularies that
    lack the token. With ``--remove_bos`` that token would then be stripped from every
    sequence.

    Args:
        tokenizer: a Hugging Face tokenizer (or any object exposing ``bos_token_id`` and
            optionally ``get_vocab()``).

    Returns:
        The BOS token id, or ``None`` if no BOS token can be identified.
    """
    if tokenizer.bos_token_id is not None:
        return tokenizer.bos_token_id

    try:
        vocab = tokenizer.get_vocab()
    except (AttributeError, NotImplementedError):
        # Without a vocabulary mapping there is nothing reliable to look up.
        return None

    for candidate in ('<|startoftext|>', '<s>'):
        if candidate in vocab:
            return vocab[candidate]

    return None  # Return None if no BOS token is found

def process_batch_on_gpu(rank, args, dataset, start_idx, end_idx, output_queue, progress_queue):
    """Extract activations for dataset rows ``[start_idx, end_idx)`` on GPU ``rank``.

    For every batch of ``args.batch_size`` rows the texts are tokenized and run through
    the model up to the deepest requested layer. The ``hook_resid_post`` activations of
    all kept tokens are collected; padding is dropped, and so is BOS when
    ``--remove_bos`` is set. Each batch is put on ``output_queue`` as a dict with keys:

    * ``batch_idx`` -- dataset index of the batch's first row
    * ``batch_data`` -- ``{layer: float32 numpy array of kept-token activations}``
    * ``lengths`` -- number of kept tokens per row
    * ``truncated`` -- number of rows whose last padded position is a kept token

    One item is put on ``progress_queue`` per finished batch.

    On any exception the traceback is printed. The batches this worker will no longer
    produce are still reported on ``progress_queue`` so the progress monitor can finish.
    Finally ``None`` is put on ``output_queue`` to signal the failure.
    """
    import traceback

    num_batches = 0  # batches assigned to this worker
    batch_num = 0    # batches fully reported so far
    try:
        num_batches = len(range(start_idx, end_idx, args.batch_size))

        # Set device for this process
        device = f"cuda:{rank}"
        torch.cuda.set_device(rank)

        # Load model and tokenizer for this GPU
        tl_model = HookedTransformer.from_pretrained_no_processing(
            args.model_path, device=device
        ).to(device)
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        tl_model.eval()

        pad_id = autodetect_pad_id(tokenizer) if args.pad_id is None else args.pad_id
        # The tokenizer must pad with exactly the id that is masked out below. Otherwise
        # padding fails outright (tokenizer has no pad token) or padded positions leak
        # into the output (--pad_id differs from the tokenizer's own pad token).
        if tokenizer.pad_token_id != pad_id:
            tokenizer.pad_token = tokenizer.convert_ids_to_tokens(pad_id)

        # Get bos_id if removal is requested
        bos_id = None
        if args.remove_bos:
            if args.bos_id is None:
                bos_id = autodetect_bos_id(tokenizer)
                if bos_id is None and rank == 0:
                    print("Warning: Could not detect BOS token ID, BOS removal disabled")
            else:
                bos_id = args.bos_id

        # Build hook names
        layers = args.layers
        hook_names = [f"blocks.{L}.hook_resid_post" for L in layers]
        max_layer = max(layers)

        if args.truncate_to_max_length:
            tokenizer_kwargs = dict(padding='max_length', truncation=True, max_length=args.max_length)
        else:
            tokenizer_kwargs = dict(padding='longest')

        # Process assigned data range
        for i in range(start_idx, end_idx, args.batch_size):
            batch_end = min(i + args.batch_size, end_idx)
            batch_texts = dataset["train"][i:batch_end]["text"]

            toks = tokenizer(batch_texts, return_tensors="pt", **tokenizer_kwargs).input_ids.to(device)

            with torch.no_grad():
                _, cache = tl_model.run_with_cache(
                    toks,
                    stop_at_layer=max_layer + 1,
                    names_filter=hook_names
                )

            # Tokens to drop: padding, plus BOS when removal is enabled (bos_id is only
            # ever set when --remove_bos was given).
            exclude_mask = toks == pad_id
            if bos_id is not None:
                exclude_mask |= toks == bos_id

            keep_mask = ~exclude_mask
            lengths = keep_mask.sum(dim=1)

            # .float() keeps the output float32 (no-op for float32 models) and makes
            # bf16 caches convertible to numpy at all.
            batch_data = {
                L: cache[hook][keep_mask].float().cpu().numpy()
                for L, hook in zip(layers, hook_names)
            }

            # Send results to main process
            output_queue.put({
                'batch_idx': i,  # This is the global sample index for this batch
                'batch_data': batch_data,
                'lengths': lengths.cpu().numpy(),
                'truncated': keep_mask[:, -1].sum().item() if keep_mask.shape[1] > 0 else 0
            })

            # Update progress
            progress_queue.put(1)
            batch_num += 1

            # Release GPU memory held by the activation cache before the next batch
            del cache
            torch.cuda.empty_cache()

    except Exception as e:
        print(f"Error in GPU {rank}: {e}")
        traceback.print_exc()
        # Settle this worker's share of the progress count so the monitor cannot hang.
        for _ in range(num_batches - batch_num):
            progress_queue.put(1)
        output_queue.put(None)

class ShardWriter:
    """Write per-layer activations into a series of HDF5 shard files.

    For the current shard one HDF5 file per layer is open. Each file contains:

    * ``non_padding_cache`` -- float32 array of shape (tokens_in_shard, hidden_dim),
      grown by appending each batch's kept-token activations.
    * ``start`` / ``end`` / ``length`` -- per-sequence row range into ``non_padding_cache``.
    * ``global_idx`` -- dataset row index of each sequence (the ``batch_idx`` of its batch
      plus its offset within that batch).
    * ``global_seq_start`` / ``global_seq_end`` -- range of write positions covered by the
      shard. Sequences are stored in the order ``write_batch`` receives them, so these
      match dataset indices only when batches arrive in dataset order. ``global_idx`` is
      always exact.
    * attributes ``shard_id``, ``total_shards``, ``sequences_in_shard``, ``hidden_dim``
      and ``layer``.
    """

    def __init__(self, args, dataset_len, hidden_dim):
        self.args = args
        self.dataset_len = dataset_len
        self.hidden_dim = hidden_dim
        self.layers = args.layers

        # Calculate sharding info
        self.sequences_per_shard = args.sequences_per_shard
        self.num_shards = math.ceil(dataset_len / self.sequences_per_shard)

        # State of the shard currently open for writing
        self.current_shard = 0
        self.sequences_in_current_shard = 0
        self.current_shard_handles = None  # {layer: open h5py.File}
        self.current_shard_dsets = None    # {layer: "non_padding_cache" dataset}
        self.current_shard_parts = None    # {layer: {"start"/"end"/"length"/"global_idx": dataset}}
        self.current_shard_start_seq = 0   # write position of the shard's first sequence
        self.current_end = 0               # rows already written to the current shard

        print(f"Will create {self.num_shards} shards with up to {self.sequences_per_shard} sequences each")

        # Open first shard
        self._open_new_shard()

    def _get_shard_filename(self, layer, shard_id):
        """Generate shard filename for a given layer and shard.

        Only a trailing ``.h5``/``.hdf5``/``.arrow``/``.parquet`` is treated as the
        extension, so names containing dots (e.g. ``Gemma2.2-2B``) are preserved. Names
        without a known extension get ``.h5``.
        """
        base = self.args.output_files[self.layers.index(layer)]
        for ext in ('.h5', '.hdf5', '.arrow', '.parquet'):
            if base.endswith(ext):
                return f"{base[:-len(ext)]}_shard_{shard_id:05d}{ext}"
        return f"{base}_shard_{shard_id:05d}.h5"

    def _open_new_shard(self):
        """Open a new shard for writing (closing the previous one if still open)."""
        if self.current_shard_handles:
            self._close_current_shard()

        self.current_shard_handles = {}
        self.current_shard_dsets = {}
        self.current_shard_parts = {}

        # Calculate sequences in this shard
        remaining = self.dataset_len - self.current_shard_start_seq
        seqs_in_shard = min(self.sequences_per_shard, remaining)

        for L in self.layers:
            filename = self._get_shard_filename(L, self.current_shard)
            os.makedirs(os.path.dirname(filename) if os.path.dirname(filename) else '.', exist_ok=True)

            f = h5py.File(filename, "w")

            # Store metadata
            f.attrs['shard_id'] = self.current_shard
            f.attrs['total_shards'] = self.num_shards
            f.attrs['sequences_in_shard'] = seqs_in_shard
            f.attrs['hidden_dim'] = self.hidden_dim
            f.attrs['layer'] = L

            # Main dataset: resizable along the token axis, uncompressed for fast appends.
            chunk_size = min(self.args.batch_size * 10, 1000)
            self.current_shard_dsets[L] = f.create_dataset(
                "non_padding_cache",
                shape=(0, self.hidden_dim),
                maxshape=(None, self.hidden_dim),
                dtype=np.float32,
                chunks=(chunk_size, self.hidden_dim)
            )

            # Per-sequence indexing datasets
            self.current_shard_parts[L] = {
                "start": f.create_dataset("start", shape=(seqs_in_shard,), dtype=np.int64),
                "end": f.create_dataset("end", shape=(seqs_in_shard,), dtype=np.int64),
                "length": f.create_dataset("length", shape=(seqs_in_shard,), dtype=np.int64),
                "global_idx": f.create_dataset("global_idx", shape=(seqs_in_shard,), dtype=np.int64),
            }

            # Write-order range covered by this shard (see class docstring)
            f.create_dataset("global_seq_start", data=self.current_shard_start_seq)
            f.create_dataset("global_seq_end", data=self.current_shard_start_seq + seqs_in_shard)

            self.current_shard_handles[L] = f

        self.sequences_in_current_shard = 0
        self.current_end = 0
        print(f"Opened shard {self.current_shard} for sequences {self.current_shard_start_seq}-{self.current_shard_start_seq + seqs_in_shard - 1}")

    def _close_current_shard(self):
        """Close the current shard files (no-op if none are open)."""
        if not self.current_shard_handles:
            return
        for L in self.layers:
            filename = self._get_shard_filename(L, self.current_shard)
            self.current_shard_handles[L].close()
            print(f"Closed shard {self.current_shard}, layer {L}: {filename}")
        self.current_shard_handles = None

    def write_batch(self, batch_idx, batch_data, lengths):
        """Write a batch of sequences, rolling over to new shards as they fill up.

        Args:
            batch_idx: dataset index of the batch's first sequence.
            batch_data: ``{layer: array}`` whose rows are the kept-token activations of
                all sequences in the batch, concatenated in order.
            lengths: number of rows belonging to each sequence.

        Returns:
            ``batch_idx // args.batch_size`` (kept for backward compatibility).
        """
        lengths = np.asarray(lengths, dtype=np.int64)
        num_seqs = len(lengths)

        # Row offset of each sequence inside the flattened batch arrays (computed once,
        # instead of re-summing the prefix for every sequence and layer).
        token_offsets = np.zeros(num_seqs + 1, dtype=np.int64)
        np.cumsum(lengths, out=token_offsets[1:])

        pos = 0
        while pos < num_seqs:
            if self.sequences_in_current_shard >= self.sequences_per_shard:
                # Close before advancing the counter so the log names the right shard.
                self._close_current_shard()
                self.current_shard += 1
                self.current_shard_start_seq += self.sequences_in_current_shard
                self._open_new_shard()

            # As many sequences from this batch as still fit in the current shard
            take = min(self.sequences_per_shard - self.sequences_in_current_shard, num_seqs - pos)
            self._write_span(batch_idx, batch_data, lengths, token_offsets, pos, pos + take)
            pos += take

        return batch_idx // self.args.batch_size

    def _write_span(self, batch_idx, batch_data, lengths, token_offsets, lo, hi):
        """Append sequences ``lo:hi`` of one batch to the currently open shard."""
        first_local = self.sequences_in_current_shard
        last_local = first_local + (hi - lo)
        span_lengths = lengths[lo:hi]

        # Shard-local row ranges continue after the rows already in this shard.
        starts = self.current_end + (token_offsets[lo:hi] - token_offsets[lo])
        ends = starts + span_lengths
        global_ids = batch_idx + np.arange(lo, hi, dtype=np.int64)
        row_lo, row_hi = int(token_offsets[lo]), int(token_offsets[hi])

        for L in self.layers:
            parts = self.current_shard_parts[L]
            parts["start"][first_local:last_local] = starts
            parts["end"][first_local:last_local] = ends
            parts["length"][first_local:last_local] = span_lengths
            parts["global_idx"][first_local:last_local] = global_ids

            chunk = batch_data[L][row_lo:row_hi]
            if chunk.shape[0]:
                ds = self.current_shard_dsets[L]
                old_size = ds.shape[0]
                new_size = old_size + chunk.shape[0]
                ds.resize((new_size, self.hidden_dim))
                ds[old_size:new_size] = chunk

        self.current_end += row_hi - row_lo
        self.sequences_in_current_shard = last_local

    def close_all(self):
        """Close all open shard files."""
        self._close_current_shard()
        print(f"All {self.current_shard + 1} shards have been written")


class ArrowShardWriter:
    """Write per-layer activations as Arrow IPC or Parquet shards for HuggingFace ``datasets``.

    Rows of the current shard are buffered in memory (one list per layer) and written
    when the shard is full or on ``close_all``. Each layer gets its own directory
    ``<output base>_arrow`` holding ``data-NNNNN.<arrow|parquet>`` files plus
    ``dataset_info.json`` and ``README.md``.

    Every row is one sequence with columns:

    * ``activations`` -- one float32 vector of length ``hidden_dim`` per kept token
    * ``sequence_length``
    * ``layer``
    * ``global_idx`` -- dataset row index of the sequence
    * ``shard_id``
    """

    def __init__(self, args, dataset_len, hidden_dim):
        self.args = args
        self.dataset_len = dataset_len
        self.hidden_dim = hidden_dim
        self.layers = args.layers

        # Calculate sharding info
        self.sequences_per_shard = args.sequences_per_shard
        self.num_shards = math.ceil(dataset_len / self.sequences_per_shard)

        # Current shard info
        self.current_shard = 0
        self.sequences_in_current_shard = 0
        self.current_shard_data = {L: [] for L in self.layers}  # {layer: [row dict, ...]}
        self.current_shard_start_seq = 0  # write position of the shard's first sequence

        # Output format
        self.use_parquet = args.arrow_format == 'parquet'

        print(f"Will create {self.num_shards} Arrow/Parquet shards with up to {self.sequences_per_shard} sequences each")

        # Create output directories
        for L in self.layers:
            base_dir = self._get_output_dir(L)
            os.makedirs(base_dir, exist_ok=True)

    def _get_output_dir(self, layer):
        """Get output directory for a given layer (``<base>_arrow``).

        Only a trailing ``.h5``/``.hdf5``/``.arrow``/``.parquet`` is stripped, so names
        containing dots (e.g. ``Gemma2.2-2B``) are preserved.
        """
        base = self.args.output_files[self.layers.index(layer)]
        for ext in ('.h5', '.hdf5', '.arrow', '.parquet'):
            if base.endswith(ext):
                base = base[:-len(ext)]
                break
        return base + "_arrow"

    def _get_shard_filename(self, layer, shard_id):
        """Generate shard filename for a given layer and shard."""
        output_dir = self._get_output_dir(layer)
        ext = 'parquet' if self.use_parquet else 'arrow'
        return os.path.join(output_dir, f"data-{shard_id:05d}.{ext}")

    def _build_table(self, rows):
        """Return a ``pyarrow.Table`` for one layer's (non-empty) buffered rows.

        ``activations`` is stored as ``list<fixed_size_list<float32, hidden_dim>>`` built
        directly from the numpy buffers. Values therefore stay float32 (as advertised in
        ``dataset_info.json``) and no per-float Python objects are created.
        """
        dim = self.hidden_dim
        seq_lengths = np.array([row['sequence_length'] for row in rows], dtype=np.int64)
        token_offsets = np.concatenate(([0], np.cumsum(seq_lengths))).astype(np.int64)
        if token_offsets[-1] > np.iinfo(np.int32).max:
            raise ValueError(
                f"Shard {self.current_shard} holds {int(token_offsets[-1])} tokens, more than an "
                "Arrow list column can index; use a smaller --sequences_per_shard"
            )

        flat = np.concatenate(
            [np.asarray(row['activations'], dtype=np.float32).reshape(-1, dim) for row in rows],
            axis=0,
        )
        token_vectors = pa.FixedSizeListArray.from_arrays(pa.array(flat.reshape(-1)), dim)
        activations = pa.ListArray.from_arrays(
            pa.array(token_offsets.astype(np.int32), type=pa.int32()), token_vectors
        )

        def int_column(key):
            return pa.array([row[key] for row in rows], type=pa.int64())

        return pa.table({
            'activations': activations,
            'sequence_length': pa.array(seq_lengths, type=pa.int64()),
            'layer': int_column('layer'),
            'global_idx': int_column('global_idx'),
            'shard_id': int_column('shard_id'),
        })

    def _write_current_shard(self):
        """Write current shard data to Arrow/Parquet files and clear the buffers."""
        if not self.current_shard_data[self.layers[0]]:
            return

        for L in self.layers:
            table = self._build_table(self.current_shard_data[L])
            output_file = self._get_shard_filename(L, self.current_shard)

            if self.use_parquet:
                pq.write_table(table, output_file)
            else:
                # Arrow IPC *file* format: one self-contained file per shard.
                with pa.OSFile(output_file, 'wb') as sink:
                    with pa.ipc.new_file(sink, table.schema) as writer:
                        writer.write_table(table)

            print(f"Wrote shard {self.current_shard}, layer {L}: {output_file}")

            # Clear data for next shard
            self.current_shard_data[L] = []

    def write_batch(self, batch_idx, batch_data, lengths):
        """Buffer a batch of sequences, flushing full shards to disk as needed.

        Args:
            batch_idx: dataset index of the batch's first sequence.
            batch_data: ``{layer: array}`` whose rows are the kept-token activations of
                all sequences in the batch, concatenated in order.
            lengths: number of rows belonging to each sequence.

        Returns:
            ``batch_idx // args.batch_size`` (kept for backward compatibility).
        """
        lengths = np.asarray(lengths)
        # Row offset of each sequence inside the flattened batch arrays.
        token_offsets = np.concatenate(([0], np.cumsum(lengths)))

        for offset in range(len(lengths)):
            # Check if we need a new shard
            if self.sequences_in_current_shard >= self.sequences_per_shard:
                self._write_current_shard()

                # Reset for new shard
                self.current_shard += 1
                self.current_shard_start_seq += self.sequences_in_current_shard
                self.sequences_in_current_shard = 0

            seq_length = int(lengths[offset])
            # The true dataset row, independent of the order batches arrive in.
            global_idx = int(batch_idx) + offset
            lo, hi = int(token_offsets[offset]), int(token_offsets[offset + 1])

            for L in self.layers:
                self.current_shard_data[L].append({
                    'activations': batch_data[L][lo:hi],
                    'sequence_length': seq_length,
                    'layer': int(L),
                    'global_idx': global_idx,
                    'shard_id': int(self.current_shard)
                })

            self.sequences_in_current_shard += 1

        return batch_idx // self.args.batch_size

    def close_all(self):
        """Write any remaining data and create metadata files."""
        # Write final shard if there's data
        if self.current_shard_data[self.layers[0]]:
            self._write_current_shard()

        # Create metadata files for each layer
        for L in self.layers:
            output_dir = self._get_output_dir(L)

            # Create dataset_info.json
            info = {
                "dataset_name": f"activations_layer_{L}",
                "hidden_dim": self.hidden_dim,
                "layer": L,
                "num_shards": self.current_shard + 1,
                "total_sequences": self.current_shard_start_seq + self.sequences_in_current_shard,
                "format": "parquet" if self.use_parquet else "arrow",
                "features": {
                    "activations": {
                        "dtype": "float32",
                        "shape": ["sequence_length", self.hidden_dim]
                    },
                    "sequence_length": {"dtype": "int64"},
                    "layer": {"dtype": "int64"},
                    "global_idx": {"dtype": "int64"},
                    "shard_id": {"dtype": "int64"}
                }
            }

            with open(os.path.join(output_dir, "dataset_info.json"), "w") as f:
                json.dump(info, f, indent=2)

            # Create README
            ext = 'parquet' if self.use_parquet else 'arrow'
            readme = f"""# Activation Dataset - Layer {L}

This dataset contains transformer activations from layer {L}.

## Loading the dataset

```python
from datasets import load_dataset

# Load all shards
dataset = load_dataset('{ext}', data_files='{output_dir}/data-*.{ext}')

# Or load specific shards
dataset = load_dataset('{ext}', data_files=[
    '{output_dir}/data-00000.{ext}',
    '{output_dir}/data-00001.{ext}'
])
```

## Features
- `activations`: Tensor of shape [sequence_length, {self.hidden_dim}]
- `sequence_length`: Number of tokens in the sequence
- `layer`: Layer number ({L})
- `global_idx`: Global sequence index
- `shard_id`: Shard file index

## Metadata
- Total sequences: {self.current_shard_start_seq + self.sequences_in_current_shard}
- Number of shards: {self.current_shard + 1}
- Hidden dimension: {self.hidden_dim}
"""

            with open(os.path.join(output_dir, "README.md"), "w") as f:
                f.write(readme)

        print(f"All {self.current_shard + 1} Arrow/Parquet shards have been written")

def writer_process(args, dataset_len, output_queue, num_gpus, total_batches):
    """Consume GPU results from ``output_queue`` and write them to sharded output files.

    Each result is a dict with keys ``batch_idx``, ``batch_data``, ``lengths`` and
    ``truncated`` (as produced by ``process_batch_on_gpu``), or ``None`` when a worker
    failed. Batches are written in arrival order; the shard writers record each
    sequence's dataset index from ``batch_idx``.

    The loop ends when one of these happens:

    * ``dataset_len`` sequences have been written
    * all ``num_gpus`` workers have reported failure
    * no result arrives for 60 seconds after data started flowing

    Termination counts sequences rather than batches. The per-GPU batch counts need not
    add up to any precomputed total, so ``total_batches`` only sizes the progress bar.
    After a single worker failure the queue keeps being drained, so surviving workers
    never block on a full queue.
    """
    import queue

    layers = args.layers
    shard_writer = None
    failed_workers = 0
    batches_written = 0
    samples_processed = 0
    total_trunc = 0

    pbar = tqdm(total=total_batches, desc="Writing batches")
    try:
        while samples_processed < dataset_len:
            # Workers load the model before producing anything, so only enforce the
            # timeout once the first batch has arrived.
            timeout = None if shard_writer is None else 60
            try:
                result = output_queue.get(timeout=timeout)
            except queue.Empty:
                print(f"Timeout waiting for results: wrote {samples_processed}/{dataset_len} samples "
                      f"in {batches_written}/{total_batches} batches")
                break

            if result is None:
                failed_workers += 1
                print(f"Error: a GPU worker failed ({failed_workers}/{num_gpus} workers failed)")
                if failed_workers >= num_gpus:
                    break
                continue

            if shard_writer is None:
                # hidden_dim is only known once the first activations arrive
                hidden_dim = result['batch_data'][layers[0]].shape[1]
                if args.output_format in ['arrow', 'parquet']:
                    shard_writer = ArrowShardWriter(args, dataset_len, hidden_dim)
                else:
                    shard_writer = ShardWriter(args, dataset_len, hidden_dim)

            samples_processed += len(result['lengths'])
            total_trunc += result['truncated']

            shard_writer.write_batch(
                result['batch_idx'],
                result['batch_data'],
                result['lengths']
            )
            batches_written += 1
            pbar.update(1)

            # Periodic status update
            if batches_written % 10 == 0:
                pbar.set_postfix({
                    'shard': shard_writer.current_shard,
                    'truncated': f"{total_trunc/samples_processed:.2%}" if samples_processed > 0 else "0.00%"
                })
    finally:
        pbar.close()
        # Always close the shard files, even if writing raised, so they stay readable.
        if shard_writer is not None:
            if samples_processed > 0:
                print(f"Final truncation rate: {total_trunc/samples_processed:.2%} ({total_trunc}/{samples_processed} sequences)")
            shard_writer.close_all()

    if shard_writer is None:
        print("Error: Failed to get first batch")

def progress_monitor(progress_queue, total_batches):
    """Block until ``total_batches`` progress ticks have arrived, showing a progress bar.

    GPU workers put one item on ``progress_queue`` per finished batch; the item's value
    is ignored, only the count matters.
    """
    with tqdm(total=total_batches, desc="Processing batches") as bar:
        for _ in range(total_batches):
            progress_queue.get()
            bar.update(1)

def main():
    """Parse CLI arguments, split the dataset across GPUs and run workers plus the writer.

    Spawns one ``process_batch_on_gpu`` process per GPU, each over a contiguous range of
    dataset rows, and a single ``writer_process`` that writes the shards. It then waits
    for all of them.
    """
    parser = argparse.ArgumentParser(description="Parallel sharded dataset processing with transformer model")
    parser.add_argument("--max_length", type=int, default=2048, help="Maximum sequence length")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size per GPU")
    parser.add_argument("--dataset", type=str, default=None, help="Hugging Face dataset to use (e.g., 'timaeus/pile-github')")
    parser.add_argument("--arrow_dir", type=str, default=None, help="Directory containing arrow files to load")
    parser.add_argument("--truncate_to_max_length", action="store_true", help="Truncate sequences to max length.")
    parser.add_argument("--layers", type=int, nargs='+', default=[26], help="List of layers to hook into")
    parser.add_argument("--output_files", type=str, nargs='+', default=["Pile-github_Qwen2.5-1.5B_L26.h5"],
                        help="Output HDF5 file paths (will be sharded)")
    parser.add_argument("--model_path", type=str, default="Qwen/Qwen2.5-1.5B", help="Path to the model")
    parser.add_argument("--pad_id", type=int, default=None, help="Padding token ID (default: autodetect from tokenizer)")
    parser.add_argument("--remove_bos", action="store_true", help="Remove BOS (beginning of sequence) tokens from activations")
    parser.add_argument("--bos_id", type=int, default=None, help="BOS token ID (default: autodetect from tokenizer if --remove_bos is set)")
    parser.add_argument("--num_gpus", type=int, default=None, help="Number of GPUs to use (default: all available)")
    parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to process (for testing)")
    parser.add_argument("--sequences_per_shard", type=int, default=100000,
                        help="Number of sequences per shard file (default: 100000)")
    parser.add_argument("--output_format", type=str, default="h5",
                        choices=["h5", "arrow", "parquet"],
                        help="Output format: h5 (HDF5), arrow, or parquet (default: h5)")
    args = parser.parse_args()

    # ArrowShardWriter reads args.arrow_format to pick the file type
    if args.output_format in ('arrow', 'parquet'):
        args.arrow_format = args.output_format

    # Validate with real exceptions: `assert` is stripped when Python runs with -O.
    if len(args.layers) != len(args.output_files):
        raise ValueError(
            f"Must provide one output file per layer. Got {len(args.layers)} layers and {len(args.output_files)} files."
        )
    if args.batch_size <= 0:
        raise ValueError("--batch_size must be a positive integer")
    if args.sequences_per_shard <= 0:
        raise ValueError("--sequences_per_shard must be a positive integer")

    # Ensure either dataset or arrow_dir is specified
    if args.dataset is None and args.arrow_dir is None:
        raise ValueError("Must specify either --dataset for Hugging Face datasets or --arrow_dir for local arrow files")

    if args.dataset is not None and args.arrow_dir is not None:
        raise ValueError("Cannot specify both --dataset and --arrow_dir. Choose one.")

    # Determine number of GPUs
    if args.num_gpus is None:
        num_gpus = torch.cuda.device_count()
    else:
        num_gpus = min(args.num_gpus, torch.cuda.device_count())

    if num_gpus <= 0:
        print("No GPUs available. Please use the single-GPU version.")
        return

    print(f"Using {num_gpus} GPUs for parallel processing")
    print(f"Output format: {args.output_format}")
    print(f"Sequences per shard: {args.sequences_per_shard}")

    # Load dataset
    if args.dataset:
        print(f"Loading dataset from Hugging Face: {args.dataset}")
        ds = load_dataset(args.dataset)
    else:
        print(f"Loading dataset from arrow files in: {args.arrow_dir}")
        import glob

        arrow_files = sorted(glob.glob(os.path.join(args.arrow_dir, "*.arrow")))
        if not arrow_files:
            raise ValueError(f"No arrow files found in {args.arrow_dir}")

        print(f"Found {len(arrow_files)} arrow files")

        # Memory-map the arrow files instead of loading everything into RAM
        ds = load_dataset(
            "arrow",
            data_files={"train": arrow_files},
            split={"train": "train"},
            keep_in_memory=False
        )

        # Wrap in dictionary to match expected format
        if not isinstance(ds, dict):
            ds = {"train": ds}

        print(f"Loaded dataset with {len(ds['train'])} samples from arrow files (memory-mapped)")

    # `--max_samples 0` must mean "nothing", so test against None rather than truthiness.
    if args.max_samples is not None:
        dataset_len = max(0, min(len(ds["train"]), args.max_samples))
    else:
        dataset_len = len(ds["train"])

    print(f"Processing {dataset_len} samples")
    if dataset_len == 0:
        # Without any batch the writer would wait forever for its first result.
        print("Nothing to process.")
        return
    print(f"Will create {math.ceil(dataset_len / args.sequences_per_shard)} shard files")

    # Idle workers would only load the model for nothing
    if num_gpus > dataset_len:
        print(f"Only {dataset_len} samples: using {dataset_len} GPUs")
        num_gpus = dataset_len

    # Contiguous, near-equal ranges; the first `remainder` GPUs take one extra sample.
    samples_per_gpu = dataset_len // num_gpus
    remainder = dataset_len % num_gpus

    gpu_ranges = []
    start = 0
    for i in range(num_gpus):
        size = samples_per_gpu + (1 if i < remainder else 0)
        gpu_ranges.append((start, start + size))
        start += size

    print("GPU work distribution:")
    for i, (start, end) in enumerate(gpu_ranges):
        print(f"  GPU {i}: samples {start}-{end} ({end-start} samples)")

    # Set up multiprocessing
    mp.set_start_method('spawn', force=True)

    # Bounded result queue so GPUs can run ahead of the writer without unbounded memory.
    output_queue = mp.Queue(maxsize=num_gpus * 4)
    progress_queue = mp.Queue()

    # Each GPU splits its own range into batches, so the total is the sum of per-GPU
    # ceilings (which can exceed ceil(dataset_len / batch_size)).
    total_batches = sum(math.ceil((end - start) / args.batch_size) for start, end in gpu_ranges)

    # Start writer process
    writer = mp.Process(
        target=writer_process,
        args=(args, dataset_len, output_queue, num_gpus, total_batches)
    )
    writer.start()

    # Start GPU processing processes
    processes = []
    for rank in range(num_gpus):
        start_idx, end_idx = gpu_ranges[rank]
        p = mp.Process(
            target=process_batch_on_gpu,
            args=(rank, args, ds, start_idx, end_idx, output_queue, progress_queue)
        )
        p.start()
        processes.append(p)
        # Small delay to avoid simultaneous model loading
        time.sleep(2)

    # Monitor progress
    progress_monitor(progress_queue, total_batches)

    # Wait for all GPU processes to complete
    for rank, p in enumerate(processes):
        p.join()
        if p.exitcode != 0:
            print(f"Warning: GPU {rank} process exited with code {p.exitcode}")

    # The writer exits on its own once every sequence has been written
    writer.join()
    if writer.exitcode != 0:
        print(f"Warning: writer process exited with code {writer.exitcode}")

    print("Processing complete!")

if __name__ == "__main__":
    main()