#!/bin/bash

# Parse command-line flags into the globals used by the rest of the script:
#   --use-deepspeed <true|false>  -> DEEPSPEED_ENABLED
#   --wandb-key <key>             -> WANDB_API_KEY
#   --config-path <path>          -> EXPERIMENT_SETTINGS_PATH
# Every flag takes exactly one value. An unknown flag, or a flag that is the
# last argument and so has no value, is reported on stderr and the script
# exits with status 1.
parse_args() {
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --use-deepspeed | --wandb-key | --config-path)
        # `shift 2` with only one argument left fails *without* shifting,
        # which would make this loop spin forever; reject it explicitly.
        if [[ $# -lt 2 ]]; then
          echo "Missing value for argument: $1" >&2
          exit 1
        fi
        case "$1" in
          --use-deepspeed) DEEPSPEED_ENABLED="$2" ;;
          --wandb-key) WANDB_API_KEY="$2" ;;
          --config-path) EXPERIMENT_SETTINGS_PATH="$2" ;;
        esac
        shift 2
        ;;
      *)
        echo "Unknown argument: $1" >&2
        exit 1
        ;;
    esac
  done
}

parse_args "$@"

# Validate the parsed arguments.
# WANDB_API_KEY and EXPERIMENT_SETTINGS_PATH must be non-empty. WANDB_API_KEY
# may come from --wandb-key or from the inherited environment.
# DEEPSPEED_ENABLED is optional and defaults to false. "true"/"false" are
# accepted in any letter case and normalised to lowercase, so the launch step's
# comparison against `true` behaves as the caller intended. Any other value is
# rejected instead of silently falling back to the non-DeepSpeed path.
# Usage and error messages go to stderr; invalid input exits with status 1.
usage() {
  echo "Usage: $0 --use-deepspeed <true|false> --wandb-key <key> --config-path <path>" >&2
}

if [[ -z "$WANDB_API_KEY" || -z "$EXPERIMENT_SETTINGS_PATH" ]]; then
  usage
  exit 1
fi

case "$DEEPSPEED_ENABLED" in
  [Tt][Rr][Uu][Ee])
    DEEPSPEED_ENABLED=true
    ;;
  [Ff][Aa][Ll][Ss][Ee] | '')
    DEEPSPEED_ENABLED=false
    ;;
  *)
    echo "Invalid value for --use-deepspeed: $DEEPSPEED_ENABLED (expected true or false)" >&2
    usage
    exit 1
    ;;
esac

# Environment setup before training.
# Install/upgrade flash-attn with the same interpreter that later runs the
# trainer (`python -m pip`). This puts the package in that interpreter's
# environment rather than in whichever `pip` happens to be first on PATH.
# As before, a failed install does not stop the run, but it is now reported.
# NOTE(review): flash-attn's install instructions use `--no-build-isolation`;
# consider adding it if source builds fail here.
if ! python -m pip install -U flash-attn; then
  echo "Warning: 'pip install -U flash-attn' failed; continuing anyway." >&2
fi

export TOKENIZERS_PARALLELISM="false"
export WANDB_API_KEY

# Number of CUDA devices torch can see. In this script only the DeepSpeed
# launch path reads it. It is left empty if the python/torch query fails.
NUM_GPUS=$(python -c "import torch; print(torch.cuda.device_count())")


# Launch training, then confirm the trainer output directory exists.
# With DeepSpeed enabled, the `src` module runs under the `deepspeed` launcher
# with --num_gpus=NUM_GPUS. Otherwise it runs as a single `python -m src`
# process.
# If training fails, the script exits with the trainer's exit status. Without
# this, the final `ls train_output/trainer` could still print "Done!" using
# output left over from an earlier run.
train_status=0
if [[ "$DEEPSPEED_ENABLED" == true ]]; then
  # NUM_GPUS is empty when the torch query failed; catch that here instead
  # of handing `--num_gpus=` to the launcher.
  if ! [[ "$NUM_GPUS" =~ ^[0-9]+$ ]]; then
    echo "Could not determine GPU count for DeepSpeed (got '${NUM_GPUS}')." >&2
    exit 1
  fi
  deepspeed --num_gpus="$NUM_GPUS" --no_local_rank --module src train_sft \
    --experiment_settings_path "$EXPERIMENT_SETTINGS_PATH" || train_status=$?
else
  python -m src train_sft \
    --experiment_settings_path "$EXPERIMENT_SETTINGS_PATH" || train_status=$?
fi

if (( train_status != 0 )); then
  echo "Training failed with exit status ${train_status}." >&2
  exit "$train_status"
fi

ls train_output/trainer && echo "Done!"