#!/bin/bash
# Driver for a numbered, resumable training pipeline: supervised fine-tuning,
# output generation, annotation, reward-model training and PPO.
# Abort as soon as any command exits with a non-zero status.
set -o errexit

# Defaults; each can be overridden by a command-line option further down.
step=1                        # first pipeline step to execute
run_ensemble=0                # switched to 1 by -e (not read in the visible part of this script)
dataset='tldr'                # dataset name, also used in data/datasets/<dataset>/...
model_config='gpt2'           # selects configs/<model_config>/*.yaml
instance='ml.p4d.24xlarge'    # AWS compute instance type

#######################################
# Print the command-line synopsis to stdout.
#######################################
usage() {
  cat <<EOF
Usage: $(basename "$0") [-c model_config] [-d dataset] [-i instance] [-s step] [-w wandb_run_id] [-g|-r|-p|-a] [-e] [-h]
  -c NAME  model config to use (configs/NAME/*.yaml)
  -d NAME  dataset name
  -i TYPE  AWS compute instance
  -s N     start from step N
  -w ID    WandB PPO run id
  -g       start at step 2 (generate outputs from an already fine-tuned model)
  -r       start at step 7 (reward model training)
  -p       start at step 8 (PPO only)
  -a       start at step 11 (alignment graph only)
  -e       use the ensemble reward model for the PPO run
  -h       show this help and exit
EOF
}

#######################################
# Parse the command-line options.
# Globals:   model_config, wandb_run_id, step, dataset, instance,
#            run_ensemble (written); OPTIND, OPTARG (written by getopts)
# Arguments: the script's command-line arguments
# Outputs:   one confirmation line per recognised option on stdout
# Exits:     0 after -h; 1 on an unknown option, a missing option argument,
#            or a non-integer -s value
#######################################
parse_options() {
  local opt
  while getopts 'grpaec:s:w:d:i:h' opt; do
    case "$opt" in
      c)
        model_config="$OPTARG"
        echo "Setting model config to '${model_config}'"
        ;;

      w)
        wandb_run_id="$OPTARG"
        echo "Setting WandB PPO run id to '${wandb_run_id}'"
        ;;

      s)
        # The step is later compared with 'test -le'; a non-integer would make
        # every comparison fail and silently skip the whole pipeline.
        if ! [[ "$OPTARG" =~ ^-?[0-9]+$ ]]; then
          echo "Option -s expects an integer step, got '${OPTARG}'" >&2
          exit 1
        fi
        step="$OPTARG"
        echo "Starting from STEP ${step}"
        ;;

      d)
        dataset="$OPTARG"
        echo "Setting the dataset to '${dataset}'"
        ;;

      i)
        instance="$OPTARG"
        echo "Setting the AWS compute instance to '${instance}'"
        ;;

      g)
        step=2
        echo "Skipping fine-tuning and starting with generating new outputs from an already fine-tuned model"
        ;;

      r)
        step=7
        echo "Skipping fine-tuning and starting with reward model training on an existing Claude preference dataset"
        ;;

      p)
        step=8
        echo "Skipping SFT and reward model training and running PPO only"
        ;;

      a)
        step=11
        echo "Producing only the alignment graph"
        ;;

      e)
        run_ensemble=1
        echo "Setting ensemble reward model for the PPO run"
        ;;

      h)
        usage
        exit 0
        ;;

      *)
        # getopts reports unknown options and missing arguments as '?'
        # (after printing its own diagnostic).
        usage >&2
        exit 1
        ;;
    esac
  done
}

parse_options "$@"
# Drop the parsed options from the script's own positional parameters.
shift "$((OPTIND - 1))"

# Directories under data/datasets/<dataset>/ used by the SFT and PPO stages.
# The last path component is the '.output_directory' value that yq reads from
# the corresponding stage config of the selected model.
sft_output_dir="data/datasets/${dataset}/$(yq '.output_directory' "configs/${model_config}/sft.yaml")"
ppo_output_dir="data/datasets/${dataset}/$(yq '.output_directory' "configs/${model_config}/ppo.yaml")"
# Fall back to the current Unix timestamp when no run id was supplied via -w.
wandb_run_id="${wandb_run_id:-$(date +%s)}"

# Each stage below runs only when the requested starting step is less than or
# equal to the stage number, so the pipeline can be resumed part-way through.

# STEP 1
# Run supervised fine-tuning with the specified configuration (configs/$model_config/sft.yaml)
if [ "$step" -le 1 ]; then
  python run_sagemaker.py src/training/sft.py -i "$instance" -c "$model_config" "DATASET=${dataset}"
fi

# STEP 2
# Generate the outputs from the fine-tuned model
if [ "$step" -le 2 ]; then
  python run_sagemaker.py src/scripts/generate.py -i "$instance" -c "$model_config" -d f "DATASET=${dataset}"
fi

# STEP 3
# Download the generated outputs from S3
if [ "$step" -le 3 ]; then
  python src/scripts/sync_data.py download "$sft_output_dir"
fi

# STEP 4
# Start annotated_outputs.json as a plain copy of the generated outputs
if [ "$step" -le 4 ]; then
  cp "${sft_output_dir}/outputs.json" "${sft_output_dir}/annotated_outputs.json"
fi

# STEP 5
# Annotate the generated samples against the human labels with Gold RM (or Claude)
# NOTE(review): unlike steps 1-2, no -i "$instance" is passed here - confirm this is intentional.
if [ "$step" -le 5 ]; then
  python run_sagemaker.py src/scripts/annotate.py -d f -c "$model_config" "DATASET=${dataset}"
fi

if [ "$step" -le 6 ]; then
  # STEP 6
  # Upload the annotated data to S3, only when running annotation locally
  # python src/scripts/sync_data.py upload $sft_output_dir
  #
  # The upload above is disabled. Bash rejects an 'if' branch that contains
  # only comments, so keep an explicit no-op here.
  :
fi

# STEP 7
# Learn the reward model from the annotated data
if [ "$step" -le 7 ]; then
  python run_sagemaker.py src/training/reward_model.py -i "$instance" -c "$model_config" "DATASET=${dataset}"
fi

# STEP 8
# Run PPO with the learned SFT and reward models
# NOTE(review): run_ensemble (set by -e) is not consulted here, and no
# -i "$instance" is passed - confirm both are intentional.
if [ "$step" -le 8 ]; then
  python run_sagemaker.py src/training/ppo.py -d f -c "$model_config" -w "$wandb_run_id" "DATASET=${dataset}"
fi

# STEP 9
# Download evaluation data from various PPO policy iterations
if [ "$step" -le 9 ]; then
  python src/scripts/ppo_outputs.py "uncertainty-rlhf/Uncertainty RLHF/$wandb_run_id" -o "$ppo_output_dir"
fi

# STEP 10
# Start the PPO annotated_outputs.json as a plain copy of the PPO outputs
if [ "$step" -le 10 ]; then
  cp "${ppo_output_dir}/outputs.json" "${ppo_output_dir}/annotated_outputs.json"
fi

# STEP 11 (currently disabled: with this step commented out, `-a` sets step=11 and nothing runs)
# Produce the alignment graph
# python src/scripts/alignment.py $ppo_output_dir/annotated_outputs.json
