#!/bin/bash

# Train the grouter (dilute_grouter.py) with checkpoint support.
# Interactively offers three launch modes: train from scratch, resume from a
# specific checkpoint, or resume from the latest checkpoint.
#
# Relative paths used by this script (the PYTHONPATH entries, the model config
# and tokenizer paths) resolve against the current working directory, so run
# it from the directory that contains Megatron-LM/ and grouter_ep_optimizer/.
#
# Environment overrides:
#   CHECKPOINT_DIR  where checkpoints are written/read
#                   (default: checkpoints/dilute_grouter)

# Checkpoint directory; an exported CHECKPOINT_DIR takes precedence.
CHECKPOINT_DIR="${CHECKPOINT_DIR:-checkpoints/dilute_grouter}"
# Note: this replaces (does not extend) any inherited PYTHONPATH.
export PYTHONPATH="Megatron-LM:grouter_ep_optimizer/grouter/structure"

# Ask a yes/no question and read a single keypress (no Enter needed).
# Arguments: $1 - question text; " (y/n) " is appended to it
# Globals:   REPLY (written by read)
# Returns:   0 if the key was y or Y, 1 for anything else (including EOF)
confirm() {
    local question=$1
    read -r -n 1 -p "${question} (y/n) "
    # read -n 1 returns right after the keypress, so finish the prompt line.
    printf '\n'
    case $REPLY in
        [Yy]) return 0 ;;
        *) return 1 ;;
    esac
}


# Root directory of the tokenized C4 shards; an exported C4_HOME takes
# precedence over this default.
C4_HOME="${C4_HOME:-/workspace/Megatron-LM-router/dataset}"

# Build a whitespace-separated "<weight> <prefix> <weight> <prefix> ..." blend
# string for --data-prefix. Brace expansion keeps the zero padding, so shards
# dsv2-c4-0000 .. dsv2-c4-0030 (31 shards) are each given weight 0.04.
# NOTE(review): 0.04 = 1/25, but with 31 shards the weights sum to 1.24, not 1.
# Either 25 shards were intended or the weights are presumably normalized by
# the Megatron blended-dataset loader -- confirm which.
DATA_BLEND=""
for i in {0000..0030}; do
    DATA_BLEND="${DATA_BLEND} 0.04 ${C4_HOME}/dsv2-c4-${i}_text_document"
done

# DATA_BLEND is a whitespace-separated "<weight> <prefix> ..." string. Split it
# into separate words here, without pathname expansion, so each weight and
# prefix becomes its own argument.
read -r -a DATA_BLEND_ARGS <<< "$DATA_BLEND"

# Arguments shared by every launch mode. Bash allows comments inside a compound
# array assignment, so no line continuations are needed.
COMMON_ARGS=(
    # Distill layer range
    --moe-layer-start 1
    --moe-layer-end 2
    # Teacher model
    --model-name deepseekv2
    --grouter-type mla
    --model-config-path grouter_ep_optimizer/grouter/structure/mla_config.json
    # Distill training settings
    --warmup 3000
    --lr 0.0005
    --total-steps 10000
    --batch-size 4
    --max-length 4096
    --transpose
    # Gradient accumulation
    --gradient-accumulation-steps 2
    # Logging
    --log-csv
    --log-interval 1
    # Checkpointing (quoted so an empty or space-containing path stays one argument)
    --checkpoint-dir "$CHECKPOINT_DIR"
    --checkpoint-interval 5000
    # Megatron dataloader
    --use-megatron
    --tokenizer-type HuggingFaceTokenizer
    --tokenizer-model model_home/deepseek-v2-lite
    --random-seed 42
    # Alternative data source (unused): --data-path /path/to/c4_data/en
    --data-prefix "${DATA_BLEND_ARGS[@]}"
)

# Launch dilute_grouter.py on 8 local processes via torchrun. The
# option-specific flags given as arguments come first, followed by COMMON_ARGS.
# The script then exits with torchrun's exit status, so a failed run is
# reported as a failure to whoever invoked the script.
# Arguments: option-specific flags placed before COMMON_ARGS
# Globals:   COMMON_ARGS (read)
run_training() {
    torchrun --nproc-per-node 8 dilute_grouter.py \
        "$@" \
        "${COMMON_ARGS[@]}"
    exit "$?"
}

# Option 1: Start training from scratch with checkpoint saving
if confirm "Option 1: Start training from scratch with checkpoint saving. Execute?"; then
    echo "Starting training from scratch with checkpoint saving..."
    run_training --save-latest
fi

# Option 2: Resume training from a specific checkpoint.
# NOTE(review): with --total-steps 10000 in COMMON_ARGS, a step-25000 checkpoint
# is presumably never produced by this config; export RESUME_CKPT to choose the
# checkpoint file to resume from.
if confirm "Option 2: Resuming training from specific checkpoint. Execute?"; then
    RESUME_CKPT="${RESUME_CKPT:-${CHECKPOINT_DIR}/checkpoint_step_25000.pt}"
    # Fail fast here rather than after spinning up 8 torchrun workers.
    if [[ ! -f "$RESUME_CKPT" ]]; then
        echo "Checkpoint not found: $RESUME_CKPT (set RESUME_CKPT to override)" >&2
        exit 1
    fi
    echo "Resuming training from specific checkpoint..."
    run_training --resume-from "$RESUME_CKPT"
fi

# Option 3: Resume training from the latest checkpoint automatically
if confirm "Option 3: Resume training from latest checkpoint. Execute?"; then
    echo "Resuming training from latest checkpoint..."
    run_training --resume-from-last-ckpt
fi

# No option was chosen: report it on stderr and signal failure.
echo "No option was selected for execution." >&2
exit 1