#!/bin/bash
#SBATCH --job-name=sae-train
#SBATCH --partition=gpu
#SBATCH --gres=gpu:h200:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=192G
#SBATCH --time=24:00:00
#SBATCH --output=Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/slurm_output/sae-%j.out
#SBATCH --error=Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/slurm_output/sae-%j.err
#SBATCH --requeue

# SLURM batch script: trains a sparse autoencoder (SAE) on the sharded HDF5
# data selected by SHARD_PATTERN, via core/train_entry_sharded.py.
#
# The --output/--error paths above are relative to the job's working directory
# (the submit directory by default), and SLURM opens those files before this
# script starts. The Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/slurm_output directory
# must therefore already exist at submission time. The mkdir in the
# configuration section only helps later submissions.

echo '================================================================================'
echo 'SAE Training with Efficient Memory Block Loading'
echo '================================================================================'
# Abort instead of silently continuing from an unexpected directory.
cd "${SLURM_SUBMIT_DIR}" || { echo "ERROR: cannot cd to SLURM_SUBMIT_DIR='${SLURM_SUBMIT_DIR}'" >&2; exit 1; }
echo "Submit Directory: ${SLURM_SUBMIT_DIR}"
echo "Running on host: $(hostname)"
echo "Time: $(date)"
echo "SLURM_NODES: ${SLURM_NODELIST}"
# SLURM_GPUS is only set when the job requests --gpus/-G. This script uses
# --gres, so fall back to the per-node GPU count.
echo "Number of GPUs: ${SLURM_GPUS:-${SLURM_GPUS_ON_NODE:-unknown}}"
nvidia-smi --query-gpu=name,memory.total --format=csv
echo '================================================================================'
echo ''

# Set the working directory; stop if it is unavailable.
cd <PROJECT_DIR>/SAE_Dymistified || { echo "ERROR: cannot cd to project directory" >&2; exit 1; }

# Load conda environment
module load miniconda || { echo "ERROR: 'module load miniconda' failed" >&2; exit 1; }
conda activate <ENV_NAME>

################################################################################
# CONFIGURATION SECTION - Modify these variables as needed
################################################################################

# NOTE(review): the script/log directory names refer to Qwen2.5-1.5B hook-mlp-out,
# but the data shards, checkpoint folder and EXP_NAME below refer to Gemma2-2B L12.
# Confirm this pairing is intentional before launching.

# Checkpoint configuration (defined early for directory creation)
CKPT_FOLDER="<SCRATCH_DIR>/Pile_gemma2_2b_L12_500k_samples/checkpoints"

# Create output directories if they don't exist.
# The slurm_output directory only helps future submissions: SLURM opens this
# job's --output/--error files before the script runs.
mkdir -p <PROJECT_DIR>/SAE_Dymistified/Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/slurm_output
# CKPT_FOLDER is handed to the trainer as --ckpt_folder, so fail fast here
# rather than discovering the problem after training has started.
mkdir -p "${CKPT_FOLDER}" || { echo "ERROR: cannot create checkpoint folder '${CKPT_FOLDER}'" >&2; exit 1; }

# Data configuration
SHARD_PATTERN="<SCRATCH_DIR>/Pile_gemma2_2b_L12_500k_samples/Pile-uncopyrighted_Gemma2-2B_L12_500k_shard_*.h5"
BATCH_SIZE=8192
NUM_WORKERS=2
MAX_EPOCHS=1
SPLIT_RATIO=0.999

# High norm filtering (set to empty to disable; when empty, neither
# --remove_high_norm nor --filter_cache_size is passed to the trainer)
# For Qwen/Gemma models with activation spikes, use 10.0
# For aggressive filtering, use 5.0
REMOVE_HIGH_NORM=10.0
FILTER_CACHE_SIZE=10000

# Data normalization (set to true to normalize by mean norm)
# This divides all data by the mean norm computed from sampled data.
# It helps stabilize training by keeping input magnitudes consistent.
# Only the exact string "true" enables it.
NORMALIZE_BY_MEAN=true

# Model configuration
NUM_NEURONS=65536
ACTIVATION="relu"
LR=1e-4
L1_DECAY=0.0

# Advanced training options (the boolean switches are enabled only by "true")
USE_TUNABLE_THRESHOLD=true
TUNE_B_DEC=true
ADJUST_B_ENC=true
FREQ_THRESHOLD_HIGH="mixed"
NUM_GROUPS=20
INIT_FTH=0.3
END_FTH=0.001
FACTOR_UP=0.01  # Same value as the "L26" script (presumably the layer-26 run) - NOTE(review): confirm
FACTOR_DOWN=0.01  # Same value as the "L26" script (presumably the layer-26 run) - NOTE(review): confirm
# Left unset so --normalize_batch_with_tanh_threshold is not passed (no tanh threshold).
# NORMALIZE_BATCH_THRESHOLD=150
DIVIDE_BY=1.0
# Set to empty to omit --clamp_b_enc_max entirely.
CLAMP_B_ENC_MAX=0.0

# Experiment tracking
WANDB_PROJECT="<WANDB_PROJECT>"
WANDB_ENTITY="<WANDB_ENTITY>"  # Your Wandb organization/team name
EXP_NAME="SAE-Gemma2-L12"
SEED=42

################################################################################
# PRINT CONFIGURATION
# Record the effective settings in the job log before training starts.
################################################################################

printf '%s\n' \
    "Configuration:" \
    "- Dataset: ${SHARD_PATTERN}" \
    "- Batch Size: ${BATCH_SIZE}" \
    "- Workers: ${NUM_WORKERS}" \
    "- Max Epochs: ${MAX_EPOCHS}" \
    "- Learning Rate: ${LR}" \
    "- Number of Neurons: ${NUM_NEURONS}" \
    "- Checkpoint Folder: ${CKPT_FOLDER}/" \
    "- Wandb Project: ${WANDB_PROJECT}" \
    "- Wandb Entity: ${WANDB_ENTITY}" \
    "- Experiment Name: ${EXP_NAME}"

# An empty REMOVE_HIGH_NORM means filtering is off.
if [[ -z "${REMOVE_HIGH_NORM}" ]]; then
    printf '%s\n' "- High Norm Filtering: DISABLED"
else
    printf '%s\n' "- High Norm Filtering: ${REMOVE_HIGH_NORM} (cache: ${FILTER_CACHE_SIZE} samples)"
fi

# Only the literal string "true" counts as enabled.
case "${NORMALIZE_BY_MEAN}" in
    true) printf '%s\n' "- Mean Normalization: ENABLED (dividing by mean norm)" ;;
    *)    printf '%s\n' "- Mean Normalization: DISABLED" ;;
esac

printf '\n%s\n%s\n' \
    "Starting training with Memory Block loading (optimized for best performance)..." \
    "================================================================================"

# Build the trainer command line and run it.
#
# Arguments are collected in a bash array rather than a string passed to eval.
# Each value therefore reaches Python as exactly one argv entry, with no second
# round of word-splitting, globbing or quote removal. In particular,
# SHARD_PATTERN is passed unexpanded, as the original escaped quoting intended
# (presumably it is globbed on the Python side).
cd <PROJECT_DIR>/SAE_Dymistified || { echo "ERROR: cannot cd to project directory" >&2; exit 1; }
# Add the SAE directory to PYTHONPATH so the entry script's imports resolve.
export PYTHONPATH="${PYTHONPATH}:<PROJECT_DIR>/SAE_Dymistified/Pile-Qwen2.5-1.5B-hook-mlp-out-SAE"

train_cmd=(
    python Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/core/train_entry_sharded.py
    --shard_pattern "${SHARD_PATTERN}"
    --batch_size "${BATCH_SIZE}"
    --num_workers "${NUM_WORKERS}"
    --max_epochs "${MAX_EPOCHS}"
    --split_ratio "${SPLIT_RATIO}"
    --num_neurons "${NUM_NEURONS}"
    --activation "${ACTIVATION}"
    --lr "${LR}"
    --L1_decay "${L1_DECAY}"
    --wandb_project "${WANDB_PROJECT}"
    --wandb_entity "${WANDB_ENTITY}"
    --exp_name "${EXP_NAME}"
    --seed "${SEED}"
    --ckpt_folder "${CKPT_FOLDER}"
)

# Optional high norm filtering; an empty REMOVE_HIGH_NORM omits both flags.
if [[ -n "${REMOVE_HIGH_NORM}" ]]; then
    train_cmd+=(--remove_high_norm "${REMOVE_HIGH_NORM}")
    train_cmd+=(--filter_cache_size "${FILTER_CACHE_SIZE}")
fi

# Boolean switches: each flag is added only when its variable is exactly "true".
if [[ "${NORMALIZE_BY_MEAN}" == true ]]; then
    train_cmd+=(--normalize_by_mean)
fi

if [[ "${USE_TUNABLE_THRESHOLD}" == true ]]; then
    train_cmd+=(--use_tunable_threshold_activation)
fi

if [[ "${TUNE_B_DEC}" == true ]]; then
    train_cmd+=(--tune_b_dec)
fi

if [[ "${ADJUST_B_ENC}" == true ]]; then
    train_cmd+=(--adjust_b_enc)
fi

train_cmd+=(
    --freq_threshold_high "${FREQ_THRESHOLD_HIGH}"
    --num_groups "${NUM_GROUPS}"
    --init_FTH "${INIT_FTH}"
    --end_FTH "${END_FTH}"
    --factor_up "${FACTOR_UP}"
    --factor_down "${FACTOR_DOWN}"
)

# Add normalize_batch_with_tanh_threshold only if set.
if [[ -n "${NORMALIZE_BATCH_THRESHOLD}" ]]; then
    train_cmd+=(--normalize_batch_with_tanh_threshold "${NORMALIZE_BATCH_THRESHOLD}")
fi

train_cmd+=(--divide_by "${DIVIDE_BY}")

if [[ -n "${CLAMP_B_ENC_MAX}" ]]; then
    train_cmd+=(--clamp_b_enc_max "${CLAMP_B_ENC_MAX}")
fi

# Log the exact, shell-quoted command so the run can be reproduced from the log.
printf 'Command:'
printf ' %q' "${train_cmd[@]}"
printf '\n'

# Execute the command. A failure is propagated as the job's exit status;
# otherwise the trailing echo statements would make the job report success
# even when training failed.
"${train_cmd[@]}"
train_status=$?
if (( train_status != 0 )); then
    echo "ERROR: training exited with status ${train_status} at $(date)" >&2
    exit "${train_status}"
fi

echo ""
echo "================================================================================"
echo "Training completed at $(date)"
echo "================================================================================"

# List the contents of this experiment's checkpoint directories (first 10 lines).
# Matches are tested explicitly. In a construct like `ls ... | head -10 || echo ...`
# the pipeline's status is head's, so the fallback message would never print.
echo ""
echo "Generated checkpoint files in ${CKPT_FOLDER}:"
ckpt_dirs=("${CKPT_FOLDER}/${EXP_NAME}"-*/)
# Without nullglob an unmatched glob stays literal, so check the first entry.
if [[ -d "${ckpt_dirs[0]}" ]]; then
    ls -lh -- "${ckpt_dirs[@]}" | head -10
else
    echo "No checkpoints found"
fi

echo ""
echo "================================================================================"
echo "PERFORMANCE NOTES:"
echo "================================================================================"
echo "The Memory Block DataModule is now the default and provides:"
echo "- Loads entire shards into memory for optimal I/O"
echo "- In-memory shuffling for better training dynamics"
echo "- Efficient high norm filtering when enabled"
echo "- Expected speed: 1-2M samples/sec with batch_size=8192"
echo ""
echo "To use alternative loading methods:"
echo "- Add --no_memory_block for direct loading (supports runtime shuffling)"
echo "- Add --use_buffered for legacy buffered loading (deprecated)"
echo "================================================================================"