#!/bin/bash

# Load the devtoolset-7 software collection into this shell.
. /opt/rh/devtoolset-7/enable

# Fixed environment for the run:
#   WANDB_MODE         - set to "disabled" (presumably switches off Weights & Biases logging)
#   TRANSFORMERS_CACHE - directory given to Hugging Face transformers as its cache location
#   NUM_NODES          - copied from SLURM_JOB_NUM_NODES; empty when that variable is unset
WANDB_MODE=disabled
TRANSFORMERS_CACHE="/data/private_models/xx_models/huggingface"
NUM_NODES=${SLURM_JOB_NUM_NODES}
export WANDB_MODE TRANSFORMERS_CACHE NUM_NODES

#######################################
# Print one "<hostname> slots=<gpu_count>" line per node of the current
# SLURM job, where gpu_count is the number of lines printed by
# `nvidia-smi --list-gpus` on that node (queried over ssh).
# Globals:
#   SLURM_JOB_NODELIST (read) - compact nodelist expanded via scontrol
# Arguments:
#   None
# Outputs:
#   Host lines on stdout; diagnostics on stderr.
# Returns:
#   0 on success; 1 if the nodelist could not be expanded or if any node
#   could not be queried (such a node is still printed, with slots=0 or
#   whatever partial count was obtained, so stdout keeps its shape).
#######################################
function makehostfile() {
    local node_list node gpu_count status=0

    # Expand the compact nodelist into one hostname per line. The value is
    # quoted so a bracketed nodelist (e.g. "node[01-04]") is never treated as
    # a glob, and it is only passed when non-empty so scontrol never receives
    # an empty argument.
    if ! node_list=$(scontrol show hostnames ${SLURM_JOB_NODELIST:+"$SLURM_JOB_NODELIST"}); then
        printf 'makehostfile: "scontrol show hostnames" failed\n' >&2
        return 1
    fi

    while IFS= read -r node; do
        [[ -n "$node" ]] || continue

        # pipefail is enabled only inside the substitution's subshell so a
        # failing ssh is reported even though wc itself succeeds. ssh reads
        # from /dev/null so it cannot consume the loop's input.
        if ! gpu_count=$(set -o pipefail; ssh "$node" nvidia-smi --list-gpus </dev/null | wc -l); then
            printf 'makehostfile: could not list GPUs on %s\n' "$node" >&2
            status=1
        fi

        # Some wc implementations pad the count with blanks.
        gpu_count=${gpu_count//[[:space:]]/}
        printf '%s slots=%s\n' "$node" "$gpu_count"
    done <<< "$node_list"

    return "$status"
}
# Write the generated host list to configs/hostfile, overwriting any previous
# file. The path is relative to the current working directory, and the
# configs/ directory must already exist for the redirection to succeed.
makehostfile > configs/hostfile

# Positional arguments, in order:
#   $1 MAIN_PORT - port given to accelerate as --main_process_port
#   $2 NUM_TRAIN - forwarded to the training script as --num_train
#   $3 POLICY    - forwarded as --policy_model
#   $4 S         - suffix appended to "microsoft/deberta-v3-" to form MODEL
#   $5           - directory name under ../../proxy/models/ holding the reward checkpoint
#   $6 SEED      - forwarded as --seed
#   $7 NUM_GPU   - given to accelerate as --num_processes
# No validation is performed: a missing argument leaves its variable empty.
MAIN_PORT=$1
NUM_TRAIN=$2
POLICY=$3
S=$4
SEED=$6
NUM_GPU=$7

# Values derived from the arguments; GOLD_PATH is fixed.
MODEL="microsoft/deberta-v3-${S}"
REWARD_PATH="../../proxy/models/$5"
GOLD_PATH="../../proxy/models/gold_ensemble/"

export MAIN_PORT NUM_TRAIN POLICY S MODEL REWARD_PATH GOLD_PATH SEED NUM_GPU

# Launch train_merged_chatbot.py through `accelerate launch`.
#
# Every expansion is quoted and passed as its own array element, so a value
# is never word-split or glob-expanded, and an empty value reaches the
# program as an explicit empty argument instead of silently vanishing and
# shifting the next flag into its place.

# Options consumed by `accelerate launch` itself.
launch_opts=(
    --num_processes "$NUM_GPU"
    --num_machines "$NUM_NODES"
    --main_process_port "$MAIN_PORT"
    --config_file configs/default_accelerate_config.yaml
)

# Options forwarded to the training script. The gold model name is fixed to
# microsoft/deberta-v3-large, while the reward model name comes from $MODEL.
train_opts=(
    --reward_model "$MODEL"
    --gold_model microsoft/deberta-v3-large
    --reward_checkpoint_path "$REWARD_PATH"
    --gold_checkpoint_path "$GOLD_PATH"
    --policy_model "$POLICY"
    --num_train "$NUM_TRAIN"
    --seed "$SEED"
)

# WANDB__SERVICE_WAIT is set for this one command only
# (presumably the wandb service start-up timeout in seconds; confirm).
WANDB__SERVICE_WAIT=3000 accelerate launch "${launch_opts[@]}" \
    train_merged_chatbot.py "${train_opts[@]}"

