#!/bin/bash
#SBATCH --job-name=process-500k-sharded # Job name
#SBATCH --partition=gpu # Partition name
#SBATCH --gres=gpu:a40:4  # Request 4 GPUs for parallel processing
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=32  # More CPUs for parallel processing
#SBATCH --time=24:00:00  # Longer time for 500k samples
#SBATCH --output=Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing/slurm_output/process-500k-sharded-%j.out
#SBATCH --error=Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing/slurm_output/process-500k-sharded-%j.err
#SBATCH --requeue
#
# Run core/data_preprocess_parallel_sharded.py over SAMPLE_SIZE samples of the
# Pile-uncopyrighted arrow data for layer LAYER of MODEL_PATH, writing sharded
# HDF5 output under OUTPUT_DIR. Afterwards, list the matching shard files and
# print the attributes/datasets of the first shard.
#
# Usage: sbatch <this script>
#
# Notes:
# - Slurm opens the --output/--error files before this script runs and does
#   not create missing directories, so the slurm_output directory must exist
#   at submission time. The `mkdir -p slurm_output` below only coincides with
#   it when submitting from the SAE_Dymistified directory, and only helps
#   later submissions.
# - NOTE(review): with --requeue, a requeued job restarts this script from the
#   top. Whether the Python script resumes or overwrites existing shards is not
#   visible here — confirm before relying on requeue.
# - `set -u` is deliberately omitted: module/conda activation scripts commonly
#   reference unset variables.

set -eo pipefail

# Print an error message to stderr and abort the job.
die() {
  printf 'ERROR: %s\n' "$*" >&2
  exit 1
}

echo '-------------------------------'
cd "${SLURM_SUBMIT_DIR}" || die "cannot cd to SLURM_SUBMIT_DIR='${SLURM_SUBMIT_DIR}'"
echo "${SLURM_SUBMIT_DIR}"
echo "Running on host $(hostname)"
echo "Time is $(date)"
echo "SLURM_NODES are ${SLURM_NODELIST}"

# GPU count requested by the --gres directive above; keep the two in sync.
readonly REQUESTED_GPUS=4
# SLURM_GPUS is only set when --gpus is used (not --gres), so prefer
# SLURM_GPUS_ON_NODE and fall back to the requested count.
readonly NUM_GPUS=${SLURM_GPUS_ON_NODE:-${REQUESTED_GPUS}}
echo "Number of GPUs: ${NUM_GPUS}"
# Diagnostics only: a nvidia-smi failure should not abort the job.
nvidia-smi --query-gpu=name,memory.total --format=csv \
  || printf 'WARNING: nvidia-smi failed\n' >&2
echo '-------------------------------'
printf '\n\n\n'

# Not referenced below; presumably read by the Python script from the
# environment — TODO confirm.
export PROCS=${SLURM_CPUS_ON_NODE}

# Set the working directory (relative paths below resolve against it).
readonly PROJECT_DIR="/path/to/your/project/SAE_Dymistified/Pile-Qwen2.5-1.5B-hook-mlp-out-preprocessing"
cd "${PROJECT_DIR}" || die "cannot cd to ${PROJECT_DIR}"

module load miniconda || die "module load miniconda failed"
# Make `conda activate` work in this non-interactive batch shell even if the
# module does not install conda's shell function itself.
eval "$(conda shell.bash hook)"
conda activate INTERP || die "cannot activate conda environment INTERP"

# Configuration variables
readonly MODEL_NAME="gemma-2-2b"
readonly MODEL_PATH="google/${MODEL_NAME}"
readonly SAMPLE_SIZE=500000
readonly LAYER=12
readonly SEQUENCES_PER_SHARD=5000
# Sample count as used in file names: 500000 -> "500k".
if (( SAMPLE_SIZE % 1000 == 0 )); then
  readonly SAMPLE_TAG="$((SAMPLE_SIZE / 1000))k"
else
  readonly SAMPLE_TAG="${SAMPLE_SIZE}"
fi
# NOTE(review): this HF datasets cache path ends in ".incomplete", which the
# datasets library uses for an unfinished build — confirm it is the intended
# (fully prepared) arrow directory.
readonly ARROW_DIR="/path/to/your/.cache/huggingface/datasets/monology___pile-uncopyrighted/default/0.0.0/3be90335b66f24456a5d6659d9c8d208c0357119.incomplete"
readonly OUTPUT_DIR="/path/to/your/scratch/Pile_gemma2_2b_L${LAYER}_${SAMPLE_TAG}_samples"
readonly OUTPUT_STEM="Pile-uncopyrighted_Gemma2-2B_L${LAYER}_${SAMPLE_TAG}"

# Create output directories if they don't exist
mkdir -p -- "${OUTPUT_DIR}" slurm_output

echo "Starting parallel sharded processing..."
echo "Processing ${SAMPLE_SIZE} samples with ${NUM_GPUS} GPUs"
echo "Output directory: ${OUTPUT_DIR}"

# Process SAMPLE_SIZE samples of layer LAYER on NUM_GPUS GPUs, with
# SEQUENCES_PER_SHARD sequences per shard (500000 / 5000 = 100 shards).
# Arguments are kept in an array so every value stays a single word.
preprocess_args=(
  --model_path "${MODEL_PATH}"
  --arrow_dir "${ARROW_DIR}"
  --max_length 1024
  --batch_size 64
  --layers "${LAYER}"
  --output_files "${OUTPUT_DIR}/${OUTPUT_STEM}.h5"
  --truncate_to_max_length
  --remove_bos
  --num_gpus "${NUM_GPUS}"
  --sequences_per_shard "${SEQUENCES_PER_SHARD}"
  --max_samples "${SAMPLE_SIZE}"
)
status=0
python core/data_preprocess_parallel_sharded.py "${preprocess_args[@]}" || status=$?
if (( status != 0 )); then
  # Propagate the failure so Slurm records the job as failed.
  printf 'ERROR: preprocessing failed with exit status %d\n' "${status}" >&2
  exit "${status}"
fi

echo "Processing completed"

# Collect generated files with a glob instead of parsing `ls`; nullglob makes
# an empty match yield an empty array rather than the literal pattern.
shopt -s nullglob
shard_files=("${OUTPUT_DIR}"/*"L${LAYER}"*.h5)
shopt -u nullglob

# List (up to 20 of) the generated shard files
printf '\nGenerated shard files:\n'
if (( ${#shard_files[@]} > 0 )); then
  ls -lh -- "${shard_files[@]:0:20}"
fi

# Show file count
printf '\nTotal number of shard files created:\n'
echo "${#shard_files[@]}"

# Quick validation - print the structure of the first shard file. The glob is
# passed as argv rather than spliced into the Python source, so quotes or
# backslashes in OUTPUT_DIR cannot break the code. A failure here is reported
# but does not fail the (already completed) job.
printf '\nChecking structure of first shard file:\n'
shard_glob="${OUTPUT_DIR}/${OUTPUT_STEM}_shard_*.h5"
python - "${shard_glob}" <<'PY' || printf 'WARNING: shard structure check failed\n' >&2
import glob
import sys

import h5py

files = sorted(glob.glob(sys.argv[1]))
if files:
    with h5py.File(files[0], 'r') as f:
        print(f'File: {files[0]}')
        print('Attributes:')
        for key, val in f.attrs.items():
            print(f'  {key}: {val}')
        print('Datasets:')
        for key, obj in f.items():
            if isinstance(obj, h5py.Dataset):
                print(f'  {key}: shape={obj.shape}, dtype={obj.dtype}')
            elif isinstance(obj, h5py.Group):
                # Groups have no shape and cannot be read with [()].
                print(f'  {key}: group with {len(obj)} members')
            else:
                print(f'  {key}: {type(obj).__name__}')
else:
    print('No shard files found')
PY

echo '-------------------------------'
echo "Processing completed at $(date)"