#!/bin/bash
# Slurm batch script: train a sparse autoencoder (SAE) with TopK=600 and the
# tunable-threshold (JumpReLU) option enabled, using
# Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/core/train_entry_sharded.py.
#
# Submit with:  sbatch <this script>
#
# NOTE: the --output/--error paths below are relative to the job's working
# directory (the submission directory by default), and Slurm does not create
# missing directories for them. Make sure
# Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/slurm_output/ exists before submitting;
# the `mkdir -p` further down runs too late to help the current job.
# In the file names, %j is replaced by the Slurm job ID.
#SBATCH --job-name=topk600-jumprelu
#SBATCH --partition=gpu
#SBATCH --gres=gpu:h200:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=192G
#SBATCH --time=24:00:00
#SBATCH --output=Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/slurm_output/topk600-jumprelu-%j.out
#SBATCH --error=Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/slurm_output/topk600-jumprelu-%j.err
#SBATCH --requeue

echo '================================================================================'
echo 'SAE Training with TopK=600 + TunableJumpReLU'
echo '================================================================================'
cd ${SLURM_SUBMIT_DIR}
echo "Submit Directory: ${SLURM_SUBMIT_DIR}"
echo "Running on host: $(hostname)"
echo "Time: $(date)"
echo "SLURM_NODES: ${SLURM_NODELIST}"
echo "Number of GPUs: ${SLURM_GPUS}"
nvidia-smi --query-gpu=name,memory.total --format=csv
echo '================================================================================'
echo ''

# Set the working directory
cd <PROJECT_DIR>/SAE_Dymistified

# Load conda environment
module load miniconda
conda activate <ENV_NAME>

################################################################################
# CONFIGURATION SECTION - TopK=600 with TunableJumpReLU
################################################################################

# Checkpoint configuration
CKPT_FOLDER="<SCRATCH_DIR>/Pile_gemma2_2b_L12_topk600_jumprelu/checkpoints"

# Create output directories if they don't exist.
# NOTE: Slurm opens the --output/--error files before this script starts, so
# creating slurm_output here only benefits later submissions (best effort).
mkdir -p <PROJECT_DIR>/SAE_Dymistified/Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/slurm_output
# The checkpoint folder is required for training, so fail fast if it can't be made.
mkdir -p "${CKPT_FOLDER}" || { echo "ERROR: cannot create checkpoint folder '${CKPT_FOLDER}'" >&2; exit 1; }

# ---------------------------------------------------------------------------
# Data / loader settings
# ---------------------------------------------------------------------------
# Shard glob; it is handed to --shard_pattern without shell expansion
# (presumably expanded by the training script -- confirm there).
SHARD_PATTERN="<SCRATCH_DIR>/Pile_gemma2_2b_L12_500k_samples/Pile-uncopyrighted_Gemma2-2B_L12_500k_shard_*.h5"
BATCH_SIZE=8192        # --batch_size
NUM_WORKERS=2          # --num_workers
MAX_EPOCHS=1           # --max_epochs
SPLIT_RATIO=0.999      # --split_ratio

# High-norm filtering: --remove_high_norm / --filter_cache_size are only
# passed when REMOVE_HIGH_NORM is non-empty; set it to "" to disable.
REMOVE_HIGH_NORM=10.0
FILTER_CACHE_SIZE=10000

# Data normalization: the literal value "true" adds --normalize_by_mean.
NORMALIZE_BY_MEAN=true

# ---------------------------------------------------------------------------
# Model / optimizer
# ---------------------------------------------------------------------------
NUM_NEURONS=65536      # --num_neurons
ACTIVATION="relu"      # --activation
LR=1e-4                # --lr
L1_DECAY=0.0           # --L1_decay (no L1 penalty; TopK provides sparsity)
TOPK=600               # --topk

# ---------------------------------------------------------------------------
# TopK + TunableJumpReLU options
# Boolean switches become bare flags only when set to exactly "true".
# ---------------------------------------------------------------------------
USE_TUNABLE_THRESHOLD=true   # --use_tunable_threshold_activation
TUNE_B_ENC=true              # --tune_b_enc: train b_enc (important for topk)
TUNE_B_DEC=true              # --tune_b_dec
ADJUST_B_ENC=false           # --adjust_b_enc: kept off for topk
FREQ_THRESHOLD_HIGH=0.0      # --freq_threshold_high (0.0 for topk)
NUM_GROUPS=1                 # --num_groups (single group for topk)
INIT_FTH=0.1                 # --init_FTH
END_FTH=0.001                # --end_FTH
FACTOR_UP=0.01               # --factor_up
FACTOR_DOWN=0.01             # --factor_down
DIVIDE_BY=1.0                # --divide_by

# Encoder-bias clamping is deliberately left unconfigured for topk:
# CLAMP_B_ENC_MAX / CLAMP_B_ENC_MIN are never passed (None on the Python side).

# ---------------------------------------------------------------------------
# Experiment tracking / reproducibility
# ---------------------------------------------------------------------------
WANDB_PROJECT="<WANDB_PROJECT>"
WANDB_ENTITY="<WANDB_ENTITY>"
EXP_NAME="SAE-Gemma2-L12-topk600-tunablejumprelu"
SEED=42

################################################################################
# BUILD COMMAND
################################################################################

# Print a human-readable summary of the run configuration to stdout.
# Globals (read): SHARD_PATTERN BATCH_SIZE NUM_WORKERS MAX_EPOCHS LR
#   NUM_NEURONS TOPK USE_TUNABLE_THRESHOLD CKPT_FOLDER WANDB_PROJECT
#   WANDB_ENTITY EXP_NAME REMOVE_HIGH_NORM FILTER_CACHE_SIZE NORMALIZE_BY_MEAN
print_config() {
    echo "Configuration:"
    echo "- Dataset: ${SHARD_PATTERN}"
    echo "- Batch Size: ${BATCH_SIZE}"
    echo "- Workers: ${NUM_WORKERS}"
    echo "- Max Epochs: ${MAX_EPOCHS}"
    echo "- Learning Rate: ${LR}"
    echo "- Number of Neurons: ${NUM_NEURONS}"
    echo "- TopK: ${TOPK}"
    # Report the real flag value rather than a hard-coded "ENABLED".
    if [[ "${USE_TUNABLE_THRESHOLD}" == true ]]; then
        echo "- Tunable Threshold (JumpReLU): ENABLED"
    else
        echo "- Tunable Threshold (JumpReLU): DISABLED"
    fi
    echo "- Checkpoint Folder: ${CKPT_FOLDER}/"
    echo "- Wandb Project: ${WANDB_PROJECT}"
    echo "- Wandb Entity: ${WANDB_ENTITY}"
    echo "- Experiment Name: ${EXP_NAME}"

    if [[ -n "${REMOVE_HIGH_NORM}" ]]; then
        echo "- High Norm Filtering: ${REMOVE_HIGH_NORM} (cache: ${FILTER_CACHE_SIZE} samples)"
    else
        echo "- High Norm Filtering: DISABLED"
    fi

    if [[ "${NORMALIZE_BY_MEAN}" == true ]]; then
        echo "- Mean Normalization: ENABLED (dividing by mean norm)"
    else
        echo "- Mean Normalization: DISABLED"
    fi
}

print_config
echo ""
# Use the configured k so the banner never goes stale when TOPK changes.
echo "Starting training with TopK=${TOPK} + TunableJumpReLU..."
echo "================================================================================"

# Build the command as a bash array rather than a string passed to eval, so
# each value reaches Python as exactly one argument (no re-parsing,
# word-splitting or globbing of values such as SHARD_PATTERN).
cd <PROJECT_DIR>/SAE_Dymistified || { echo "ERROR: cannot cd to project directory" >&2; exit 1; }
# Append the package dir without leaving a leading ':' when PYTHONPATH was
# empty; an empty entry can put the current directory on sys.path.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}<PROJECT_DIR>/SAE_Dymistified/Pile-Qwen2.5-1.5B-hook-mlp-out-SAE"
CMD=(
    python Pile-Qwen2.5-1.5B-hook-mlp-out-SAE/core/train_entry_sharded.py
    --shard_pattern "${SHARD_PATTERN}"
    --batch_size "${BATCH_SIZE}"
    --num_workers "${NUM_WORKERS}"
    --max_epochs "${MAX_EPOCHS}"
    --split_ratio "${SPLIT_RATIO}"
    --num_neurons "${NUM_NEURONS}"
    --activation "${ACTIVATION}"
    --lr "${LR}"
    --L1_decay "${L1_DECAY}"
    --topk "${TOPK}"
    --wandb_project "${WANDB_PROJECT}"
    --wandb_entity "${WANDB_ENTITY}"
    --exp_name "${EXP_NAME}"
    --seed "${SEED}"
    --ckpt_folder "${CKPT_FOLDER}"
)

# Add optional high norm filtering
if [[ -n "${REMOVE_HIGH_NORM}" ]]; then
    CMD+=(--remove_high_norm "${REMOVE_HIGH_NORM}")
    CMD+=(--filter_cache_size "${FILTER_CACHE_SIZE}")
fi

# Add optional mean normalization
if [[ "${NORMALIZE_BY_MEAN}" == true ]]; then
    CMD+=(--normalize_by_mean)
fi

# Add advanced training options (boolean flags only when set to "true")
if [[ "${USE_TUNABLE_THRESHOLD}" == true ]]; then
    CMD+=(--use_tunable_threshold_activation)
fi

if [[ "${TUNE_B_ENC}" == true ]]; then
    CMD+=(--tune_b_enc)
fi

if [[ "${TUNE_B_DEC}" == true ]]; then
    CMD+=(--tune_b_dec)
fi

if [[ "${ADJUST_B_ENC}" == true ]]; then
    CMD+=(--adjust_b_enc)
fi

CMD+=(
    --freq_threshold_high "${FREQ_THRESHOLD_HIGH}"
    --num_groups "${NUM_GROUPS}"
    --init_FTH "${INIT_FTH}"
    --end_FTH "${END_FTH}"
    --factor_up "${FACTOR_UP}"
    --factor_down "${FACTOR_DOWN}"
    --divide_by "${DIVIDE_BY}"
)

# Log the exact (shell-quoted) command for reproducibility, then execute it.
echo "Command: $(printf '%q ' "${CMD[@]}")"
"${CMD[@]}"
TRAIN_STATUS=$?

# Propagate a training failure as the job's exit status; otherwise the script
# would end with a successful echo and Slurm would mark the job COMPLETED.
if (( TRAIN_STATUS != 0 )); then
    echo "ERROR: training exited with status ${TRAIN_STATUS} at $(date)" >&2
    exit "${TRAIN_STATUS}"
fi

echo ""
echo "================================================================================"
echo "Training completed at $(date)"
echo "================================================================================"

# Print up to 10 lines of `ls -lh` for one experiment's checkpoint
# directories, or "No checkpoints found" when none match. (The previous
# `ls | head || echo` form could never print the fallback, because the
# pipeline's status is that of `head`.)
# Arguments: $1 - checkpoint root folder; $2 - experiment name (dirs "$2-*")
list_checkpoints() {
    local folder=$1
    local name=$2
    local -a dirs
    dirs=( "${folder}/${name}"-*/ )
    # Without nullglob an unmatched glob stays literal, so check the first hit.
    if [[ ! -d "${dirs[0]}" ]]; then
        echo "No checkpoints found"
        return 0
    fi
    ls -lh -- "${dirs[@]}" 2>/dev/null | head -10
}

# List checkpoint files
echo ""
echo "Generated checkpoint files in ${CKPT_FOLDER}:"
list_checkpoints "${CKPT_FOLDER}" "${EXP_NAME}"

echo ""
echo "================================================================================"
echo "PERFORMANCE NOTES:"
echo "================================================================================"
echo "TopK=${TOPK} with TunableJumpReLU configuration:"
echo "- Using TopK sparsity constraint with k=${TOPK}"
echo "- TunableJumpReLU activation enabled for adaptive thresholding"
echo "- tune_b_enc=true to allow encoder bias optimization"
echo "- No L1 regularization (using TopK instead)"
echo "================================================================================"