#!/bin/bash

# Example training script for Likert-Scale Preference Learning
# This demonstrates how to train different algorithms on HelpSteer datasets

# Configuration
# Shared settings; the training commands further down read these variables.

# Parallelism and batch sizing. Both batch sizes are derived from the GPU count:
#   batch_size      = 2 * numberOfGpus * gradient_accumulation_steps
#   eval_batch_size = 3 * numberOfGpus
gradient_accumulation_steps=4
numberOfGpus=8
eval_batch_size=$(( numberOfGpus * 3 ))
batch_size=$(( numberOfGpus * gradient_accumulation_steps * 2 ))

# Sequence length limits
maxPromptLength=1600
maxLength=2048

# Optimisation: lr is passed as lr=..., ordinal_lr as loss.ordinalLr=...
ordinal_lr=1e-3
lr=1e-6
n_epochs=8

# Evaluation cadence and checkpointing
saveCheckpoints=True
eval_every=19230

# Example 1: Ordinal Symmetric on HelpSteer2
# The key=value overrides are collected in an array so that each one reaches
# trainReward.py as exactly one argument, in the order listed here.
ordinal_symmetric_overrides=(
    # Model and dataset selection.
    model=llama
    # Quoted on purpose: an unquoted [hs2] is a glob bracket expression, so the
    # shell could replace (or, with nullglob/failglob, drop or reject) the
    # argument before Python receives it.
    'datasets=[hs2]'

    # Loss selection and loss-specific settings.
    loss=ordinal
    loss.levels=4
    "loss.ordinalLr=$ordinal_lr"
    loss.schedulerGamma=0.995
    loss.beta=0.
    loss.symmetrize=true
    loss.ordinal_update_interval=1
    loss.ordinal_l2_weight=5

    # Run name.
    exp_name=example_ordinal_symmetric

    # Batching, trainer and logging settings. Values written as "$var" come
    # from the Configuration section at the top of this script; they are
    # quoted so each stays a single argument.
    "gradient_accumulation_steps=$gradient_accumulation_steps"
    "batch_size=$batch_size"
    "eval_batch_size=$eval_batch_size"
    trainer=FSDPTrainer
    sample_during_eval=false
    minimum_log_interval_secs=10
    model.fsdp_policy_mp=bfloat16
    "saveCheckpoints=$saveCheckpoints"
    "eval_every=$eval_every"
    "max_length=$maxLength"
    "max_prompt_length=$maxPromptLength"
    "n_epochs=$n_epochs"
    "lr=$lr"
)

# -u keeps Python's stdout/stderr unbuffered. The script path is relative, so
# it is resolved against the current working directory. The exit status is
# not checked here.
python3 -u ../src/trainReward.py "${ordinal_symmetric_overrides[@]}"

# Example 2: All-Threshold on HelpSteer3
# The key=value overrides are collected in an array so that each one reaches
# trainReward.py as exactly one argument, in the order listed here.
all_threshold_overrides=(
    # Model and dataset selection.
    model=zephyr
    # Quoted on purpose: an unquoted [hs3] is a glob bracket expression, so the
    # shell could replace (or, with nullglob/failglob, drop or reject) the
    # argument before Python receives it.
    'datasets=[hs3]'

    # Loss selection and loss-specific settings.
    loss=allThreshold
    loss.levels=7
    loss.offset=3
    loss.ordinal_update_interval=1
    loss.ordinal_l2_weight=0.1
    "loss.ordinalLr=$ordinal_lr"
    loss.schedulerGamma=0.995
    loss.beta=0.0
    loss.symmetrize=false
    loss.makeScoresPositive=false
    loss.symmetrizeDataset=false

    # Run name.
    exp_name=example_all_threshold

    # Batching, trainer and logging settings. Values written as "$var" come
    # from the Configuration section at the top of this script; they are
    # quoted so each stays a single argument.
    "gradient_accumulation_steps=$gradient_accumulation_steps"
    "batch_size=$batch_size"
    "eval_batch_size=$eval_batch_size"
    trainer=FSDPTrainer
    sample_during_eval=false
    minimum_log_interval_secs=10
    model.fsdp_policy_mp=bfloat16
    "saveCheckpoints=$saveCheckpoints"
    "eval_every=$eval_every"
    "max_length=$maxLength"
    "max_prompt_length=$maxPromptLength"
    "n_epochs=$n_epochs"
    "lr=$lr"
)

# -u keeps Python's stdout/stderr unbuffered. The script path is relative, so
# it is resolved against the current working directory. The exit status is
# not checked here.
python3 -u ../src/trainReward.py "${all_threshold_overrides[@]}"

# Example 3: MarginBT baseline
# The key=value overrides are collected in an array so that each one reaches
# trainReward.py as exactly one argument, in the order listed here.
margin_bt_overrides=(
    # Model and dataset selection.
    model=mistral
    # Quoted on purpose: an unquoted [hs3] is a glob bracket expression, so the
    # shell could replace (or, with nullglob/failglob, drop or reject) the
    # argument before Python receives it.
    'datasets=[hs3]'

    # Loss selection and loss-specific settings.
    loss=marginBT
    loss.beta=0

    # Run name.
    exp_name=example_marginBT

    # Batching, trainer and logging settings. Values written as "$var" come
    # from the Configuration section at the top of this script; they are
    # quoted so each stays a single argument.
    "gradient_accumulation_steps=$gradient_accumulation_steps"
    "batch_size=$batch_size"
    "eval_batch_size=$eval_batch_size"
    trainer=FSDPTrainer
    sample_during_eval=false
    minimum_log_interval_secs=10
    model.fsdp_policy_mp=bfloat16
    "saveCheckpoints=$saveCheckpoints"
    "eval_every=$eval_every"
    "max_length=$maxLength"
    "max_prompt_length=$maxPromptLength"
    "n_epochs=$n_epochs"
    "lr=$lr"
)

# -u keeps Python's stdout/stderr unbuffered. The script path is relative, so
# it is resolved against the current working directory. The exit status is
# not checked here.
python3 -u ../src/trainReward.py "${margin_bt_overrides[@]}"