#!/usr/bin/env bash

# ======================================================================================
# Training Script for RO-GRPO
#
# Description:
# This script launches the training process for the experiments described in the paper
# "Balancing the Experts: Unlocking LoRA-MoE for GRPO via Mechanism-Aware Rewards".
# It supports training for both unimodal and multimodal mathematical reasoning tasks,
# and allows selection of different reward strategies, including the baselines and
# our proposed RO-GRPO methods.
#
# All paths are relative or placeholders. Please ensure that the models and datasets
# are downloaded and placed in the specified locations before running.
# ======================================================================================

# --- Configuration ---
# Fail fast: abort on any unhandled command failure, on use of an unset
# variable, and when any stage of a pipeline fails.
set -euo pipefail

# Number of GPUs (processes) per node. Defaults to 8; can be overridden from the
# environment, e.g. `NPROC_PER_NODE=4 bash train.sh unimodal ro_grpo_smooth`.
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"

# --- Helper Functions ---

#######################################
# Print command-line usage.
# Only called when argument validation fails, so the text is written to stderr
# to keep it apart from normal program output.
# Arguments: none (uses $0 as the script name)
# Outputs:   usage text on stderr
#######################################
usage() {
    cat >&2 <<EOF
Usage: $0 <task_type> <reward_type>

Arguments:
  task_type:      Specifies the task for training.
                  Options: 'unimodal', 'multimodal'
  reward_type:    Specifies the reward strategy to use.
                  Options: 'lora_baseline', 'lora_moe_baseline', 'ro_grpo_smooth', 'ro_grpo_relative'

Example:
  # Run unimodal training with the RO-GRPO (Smooth) reward
  bash $0 unimodal ro_grpo_smooth

  # Run multimodal training for the LoRA-MoE baseline (no routing reward)
  bash $0 multimodal lora_moe_baseline
EOF
}

# --- Argument Validation ---
# Exactly two positional arguments are required: <task_type> <reward_type>.
# Diagnostics go to stderr.
if [[ "$#" -ne 2 ]]; then
    echo "Error: Invalid number of arguments." >&2
    usage
    exit 1
fi

TASK_TYPE=$1
REWARD_TYPE=$2
# Per-run timestamp, used below to make output and log paths unique.
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# --- Set Task-Specific Parameters ---
# Initial values; every accepted task_type below overwrites all three.
MODEL_PATH=""
DATASET_PATH=""
NUM_EPOCHS=1

case "$TASK_TYPE" in
    unimodal)
        # As described in the paper, for unimodal tasks we use Qwen2.5-7B-Instruct
        # and fine-tune on a subset of the NuminaMath-TIR dataset.
        MODEL_PATH="path/to/your/Qwen2.5-7B-Instruct" # Placeholder: e.g., ./models/Qwen2.5-7B-Instruct
        DATASET_PATH="path/to/your/NuminaMath-TIR"     # Placeholder: local path or Hugging Face dataset ID
        NUM_EPOCHS=1
        ;;
    multimodal)
        # For multimodal tasks, we use Qwen2.5-VL-7B-Instruct and fine-tune on Geometry3k.
        MODEL_PATH="path/to/your/Qwen2.5-VL-7B-Instruct" # Placeholder: e.g., ./models/Qwen2.5-VL-7B-Instruct
        DATASET_PATH="path/to/your/Geometry3k"           # Placeholder: e.g., ./data/geometry3k
        NUM_EPOCHS=3
        ;;
    *)
        # Unknown task: report on stderr, show usage, and abort.
        echo "Error: Invalid task_type '$TASK_TYPE'." >&2
        usage
        exit 1
        ;;
esac

# --- Set Reward-Specific and Model-Specific Parameters ---
# Defaults, shared by the LoRA-MoE baseline:
#   - no routing reward (ROUTING_REWARD_FUNC empty); ROUTING_REWARD_WEIGHT is
#     the weight paired with ROUTING_REWARD_FUNC when one is selected;
#   - LoRA rank 8;
#   - LoRA targets chosen by a regex over module names: the gate/up/down MLP
#     projections of every numbered decoder layer.
ROUTING_REWARD_FUNC=""
ROUTING_REWARD_WEIGHT=0
LORA_RANK=8 # Default for MoE models (E experts * rank 8)
TARGET_MODULES='.*model\.layers\.[0-9]+\.mlp\.(gate_proj|down_proj|up_proj)$'

case "$REWARD_TYPE" in
    lora_baseline)
        # Standard LoRA baseline (GRPO (LoRA))
        # As per the paper, rank is 16 to maintain parameter parity. No routing reward is used.
        LORA_RANK=16
        ;;
    lora_moe_baseline)
        # LoRA-MoE baseline (GRPO (LoRA-MoE))
        # No routing reward is applied; the defaults above are used unchanged.
        ;;
    ro_grpo_smooth)
        # Our proposed RO-GRPO (Smooth) method.
        # The custom reward function 'routing_reward_smooth' will be defined in routing_reward.py
        ROUTING_REWARD_FUNC="routing_reward_smooth"
        ROUTING_REWARD_WEIGHT=0.2 # Global scaling coefficient w_route, as mentioned in the paper
        ;;
    ro_grpo_relative)
        # Our proposed RO-GRPO (Relative) method.
        # The custom reward function 'routing_reward_relative' will be defined in routing_reward.py
        ROUTING_REWARD_FUNC="routing_reward_relative"
        ROUTING_REWARD_WEIGHT=0.2 # Global scaling coefficient w_route
        ;;
    *)
        # Unknown strategy: report on stderr, show usage, and abort.
        echo "Error: Invalid reward_type '$REWARD_TYPE'." >&2
        usage
        exit 1
        ;;
esac

# --- Define Output and Logging Paths ---
# Every run gets its own timestamped log file and checkpoint directory.
LOG_FILE="./logs/train_${TASK_TYPE}_${REWARD_TYPE}_${TIMESTAMP}.log"
OUTPUT_DIR="./output/${TASK_TYPE}/${REWARD_TYPE}/${TIMESTAMP}"
# Create the directory holding LOG_FILE (./logs). OUTPUT_DIR is not created here.
mkdir -p "${LOG_FILE%/*}"

# --- Display Run Configuration ---
# Print a banner summarizing the resolved settings before training starts.
cat <<EOF
===================================================
Starting RO-GRPO Training
===================================================
Task Type:              $TASK_TYPE
Reward Strategy:        $REWARD_TYPE
Model Path:             $MODEL_PATH
Dataset Path:           $DATASET_PATH
LoRA Rank:              $LORA_RANK
Num Train Epochs:       $NUM_EPOCHS
Output Directory:       $OUTPUT_DIR
Log File:               $LOG_FILE
===================================================
EOF

# --- Execute Training Command ---
# Launches GRPO training with the ms-swift CLI (`swift rlhf`).
#
# NOTE(review): flag names follow the ms-swift 3.x CLI (`--model`, `--deepspeed`,
# `--target_regex`, `--system`, `--num_generations`, `--max_completion_length`,
# `--external_plugins`); confirm them against the installed swift version.

# Reward functions and their weights are parallel lists and must have the same
# length. The routing reward and its weight are appended only when a routing
# strategy was selected. Previously the baselines passed 2 functions but 3 weights.
reward_funcs=(format math_acc)
reward_weights=(1 1)
if [[ -n "${ROUTING_REWARD_FUNC}" ]]; then
    reward_funcs+=("${ROUTING_REWARD_FUNC}")
    reward_weights+=("${ROUTING_REWARD_WEIGHT}")
fi

# Arguments are kept in an array so every value (paths, the module regex with
# its glob characters) reaches swift verbatim, with no word splitting or globbing.
swift_args=(
    --rlhf_type grpo
    # The path/ID goes to --model; --model_type expects a model-type name, not a path.
    --model "${MODEL_PATH}"
    --dataset "${DATASET_PATH}"
    --output_dir "${OUTPUT_DIR}"

    --reward_funcs "${reward_funcs[@]}"
    --reward_weights "${reward_weights[@]}"

    --train_type lora
    --lora_rank "${LORA_RANK}"
    --lora_alpha 32
    # TARGET_MODULES is a regular expression, so it is passed as a regex target.
    --target_regex "${TARGET_MODULES}"

    --num_train_epochs "${NUM_EPOCHS}"
    --per_device_train_batch_size 16
    --gradient_accumulation_steps 1
    --learning_rate 1e-5
    --warmup_ratio 0.01

    --max_length 8192
    --max_completion_length 1024
    --temperature 1.0
    # GRPO group size (completions sampled per prompt). This was previously passed
    # as --num_beams, which configures beam search rather than the group size.
    --num_generations 8

    --torch_dtype bfloat16
    --attn_impl flash_attn
    --deepspeed zero2

    --save_strategy steps
    --eval_strategy steps
    --save_steps 100
    --eval_steps 100
    --save_only_model true

    --logging_steps 1
    --report_to none
    --log_completions true

    --beta 0.001
    # Placeholder relative path; override with SYSTEM_PROMPT_PATH.
    --system "${SYSTEM_PROMPT_PATH:-./system_prompt.txt}"
)

# The custom routing_reward_* functions come from an external plugin file
# (placeholder default ./routing_reward.py; override with ROUTING_PLUGIN_PATH).
if [[ -n "${ROUTING_REWARD_FUNC}" ]]; then
    swift_args+=(--external_plugins "${ROUTING_PLUGIN_PATH:-./routing_reward.py}")
fi

# Mirror all trainer output into LOG_FILE, then exit with swift's status (not tee's).
NPROC_PER_NODE="${NPROC_PER_NODE}" swift rlhf "${swift_args[@]}" 2>&1 | tee "${LOG_FILE}"
exit "${PIPESTATUS[0]}"