#!/bin/bash

# Print the length of a Hugging Face dataset that was saved to disk.
#
# Arguments: $1 - path handed to datasets.load_from_disk
# Outputs:   len(dataset) on stdout; the usage message on stderr
# Returns:   1 when no path is given, otherwise the exit status of python
get_dataset_length() {
  local dataset_path="${1:-}"

  # The usage text goes to stderr so that callers capturing stdout, e.g.
  # n=$(get_dataset_length "$path"), never mistake it for a length.
  if [ -z "$dataset_path" ]; then
    echo "Usage: get_dataset_length <path_to_dataset>" >&2
    return 1
  fi

  # The path travels as argv[1] instead of being spliced into the program,
  # and the program is single-quoted so the shell expands nothing inside it.
  python -c '
import sys
from datasets import load_from_disk

dataset_path = sys.argv[1]
dataset = load_from_disk(dataset_path)
print(len(dataset))
' "$dataset_path"
}

# Run the data-cleansing module (src.train.data_cleansing) that creates the
# "green and red" dataset.
#
# Globals:   R1_DATASET (read, messages only), SANITY_CHECK, MARGIN,
#            DATA_DIR, DATASET (read)
# Outputs:   two progress lines and the command line on stdout, followed by
#            whatever the Python module prints
# Returns:   the exit status of the Python module
#
# NOTE(review): the messages name $R1_DATASET while the command receives
# $DATASET - confirm the two are meant to differ.
create_green_red_dataset() {
  echo "Creating green and red dataset $R1_DATASET..."
  echo "Output will be saved to: $DATA_DIR/$R1_DATASET"

  # An argument array instead of a string passed to eval: every value reaches
  # the module as exactly one argument, even if it is empty or contains
  # whitespace, glob characters or shell syntax.
  local -a cmd=(
    python -m src.train.data_cleansing
    --sanity_check "$SANITY_CHECK"
    --margin_threshold "$MARGIN"
    --data_dir "$DATA_DIR"
    --dataset "$DATASET"
  )

  # Log a shell-quoted rendering so the printed line can be pasted back into
  # a shell unchanged.
  local shown
  printf -v shown '%q ' "${cmd[@]}"
  echo "${shown% }"

  "${cmd[@]}"
}

# Run the src.train.data_remove_rejection module on $DATASET.
# (Presumably this produces a variant of the dataset without rejected
# samples - confirm against the Python module.)
#
# Globals:   SANITY_CHECK, DATA_DIR, DATASET (read)
# Outputs:   the command line on stdout, followed by whatever the Python
#            module prints
# Returns:   the exit status of the Python module
create_no_rejective_dataset() {
  # An argument array instead of a string passed to eval: every value reaches
  # the module as exactly one argument, even if it is empty or contains
  # whitespace, glob characters or shell syntax.
  local -a cmd=(
    python -m src.train.data_remove_rejection
    --sanity_check "$SANITY_CHECK"
    --data_dir "$DATA_DIR"
    --dataset "$DATASET"
  )

  # Log a shell-quoted rendering so the printed line can be pasted back into
  # a shell unchanged.
  local shown
  printf -v shown '%q ' "${cmd[@]}"
  echo "${shown% }"

  "${cmd[@]}"
}

# Export the DPO_* settings for one DPO run by copying them from the
# variables named <prefix>_MODEL_NAME, <prefix>_BASE_MODEL, and so on.
#
# Arguments: $1 - variable-name prefix (e.g. "R1" reads R1_MODEL_NAME, ...)
# Globals:   <prefix>_{MODEL_NAME,BASE_MODEL,DATASET,OUTPUT_DIR,BATCH_SIZE,
#            GRAD_ACCUM_STEPS,LR,NUM_EPOCHS,BETA} (read)
#            DATA_DIR, DPO_* (exported)
# Returns:   0 on success, 1 if the prefix cannot form a variable name
run_setup_dpo() {
  local prefix="${1:-}"
  local valid_name='^[A-Za-z_][A-Za-z0-9_]*$'
  local field source_var

  # The prefix becomes part of a variable name, so refuse anything that would
  # not be one instead of letting the shell interpret it.
  if [[ ! "${prefix}_" =~ $valid_name ]]; then
    echo "run_setup_dpo: invalid variable prefix: '$prefix'" >&2
    return 1
  fi

  export DATA_DIR=data_cache

  # ${!name-} reads the variable whose name is stored in `name` (empty when it
  # is unset). Unlike `eval echo`, it copies the value verbatim: whitespace is
  # kept, globs are not expanded, and values such as "-n" survive.
  for field in MODEL_NAME BASE_MODEL DATASET OUTPUT_DIR BATCH_SIZE \
    GRAD_ACCUM_STEPS LR NUM_EPOCHS BETA; do
    source_var="${prefix}_${field}"
    export "DPO_${field}=${!source_var-}"
  done

  # -1 is passed on as --num_train_data; presumably it means "use all data" -
  # confirm against src.train.dpo.
  export DPO_DATA_SIZE=-1
}

# Launch one DPO training round (src.train.dpo) through `accelerate launch`,
# unless the output directory already exists.
#
# Globals:   ACC_CONFIG_FILE, EXP_NAME, DATA_DIR, MAX_STEPS, SEED,
#            SAVE_STRATEGY, SAVE_STEPS, SANITY_CHECK,
#            DPO_MODEL_NAME, DPO_BASE_MODEL, DPO_DATASET, DPO_OUTPUT_DIR,
#            DPO_BATCH_SIZE, DPO_GRAD_ACCUM_STEPS, DPO_LR, DPO_NUM_EPOCHS,
#            DPO_BETA, DPO_DATA_SIZE (read)
# Outputs:   progress lines and the command line on stdout
# Returns:   0 when the round is skipped, otherwise the exit status of
#            `accelerate launch`
run_dpo() {
  echo "Running DPO round"
  echo "Output will be saved to: logs/train_$DPO_MODEL_NAME.log"

  # An argument array instead of a string passed to eval: every value reaches
  # the trainer as exactly one argument, even if it is empty or contains
  # whitespace, glob characters or shell syntax.
  local -a cmd=(
    accelerate launch --config_file "$ACC_CONFIG_FILE"
    -m src.train.dpo
    --experiment_name "$EXP_NAME"
    --run_name "$DPO_MODEL_NAME"
    --model_name_or_path "$DPO_BASE_MODEL"
    --local_data_path "$DATA_DIR/$DPO_DATASET"
    --output_dir "$DPO_OUTPUT_DIR"
    --per_device_train_batch_size "$DPO_BATCH_SIZE"
    --gradient_accumulation_steps "$DPO_GRAD_ACCUM_STEPS"
    --model_dtype bfloat16
    --learning_rate "$DPO_LR"
    --num_train_epochs "$DPO_NUM_EPOCHS"
    --max_steps "$MAX_STEPS"
    --warmup_ratio 0.03
    --beta "$DPO_BETA"
    --seed "$SEED"
    --save_strategy "$SAVE_STRATEGY"
    --save_steps "$SAVE_STEPS"
    --sanity_check "$SANITY_CHECK"
    --logging_steps 10
    --num_train_data "$DPO_DATA_SIZE"
    --eval_strategy no
  )

  # The command is logged even when the round is skipped below. The
  # shell-quoted rendering can be pasted back into a shell unchanged.
  local shown
  printf -v shown '%q ' "${cmd[@]}"
  echo "${shown% }"

  # An existing output directory is taken as "already trained".
  if [ -d "$DPO_OUTPUT_DIR" ]; then
    echo "Model $DPO_MODEL_NAME already trained. Skipping this DPO round."
    return 0
  fi

  "${cmd[@]}"
}