#!/usr/bin/env bash
# Iterative in-context RL driver: each round serves the current model with
# vLLM, collects BabyAI rollouts through BALROG, gathers feedback and
# log-probabilities, then fine-tunes with LLaMA-Factory (see the loop below).
#
# Environment variables read by this script:
#   RESULT         output directory (wiped and recreated here)
#   Home           directory containing the incontext_RL checkout
#   TEMPERATURE, TASKNAME, BATCH_SIZE, MAX_ENV_STEPS, PROMPT_DIR, ALPHA,
#   LEARNING_RATE  experiment settings forwarded to the Python tools

# RESULT is deleted recursively below and prefixes every later path; refuse
# to run with it unset/empty rather than operating on "/data", "/checkpoint".
: "${RESULT:?RESULT must be set to the experiment output directory}"
printf '%s\n' "$RESULT"
# Start from a clean output directory (-f: no error if it does not exist yet).
rm -rf -- "$RESULT"
mkdir -p -- "$RESULT"

# hyperparameters
base_model_name=meta-llama/Llama-3.2-3B-Instruct

# Every path below lives under RESULT; never run with it unset/empty (which
# would turn e.g. the data cleanup into "rm -r /data").
: "${RESULT:?RESULT must be set to the experiment output directory}"

# Root of the incontext_RL checkout. Honors $Home when the caller sets it and
# falls back to $HOME otherwise (an unset $Home would resolve every path to
# /incontext_RL/...).
repo_root="${Home:-$HOME}/incontext_RL"
collect_dir=$repo_root/collect
balrog_dir=$repo_root/BALROG
factory_dir=$repo_root/LLaMA-Factory

# Number of serve -> collect -> train rounds (default 10, as before).
num_iterations=${NUM_ITERATIONS:-10}
# Round i (i >= 1) serves/resumes from $RESULT/checkpoint-$((i * steps_per_iteration)),
# the checkpoint written by the previous round's training.
# NOTE(review): must agree with the checkpoint cadence in llama3_base.yaml — confirm.
steps_per_iteration=20

# Background jobs of a non-interactive shell ignore SIGINT, so without this an
# interrupted or failed run would leave the vLLM server (and the background
# eval) holding the GPUs.
cleanup() {
  if [[ -n "${eval_pid:-}" ]]; then
    kill "$eval_pid" 2>/dev/null
  fi
  if [[ -n "${SERVER_PID:-}" ]]; then
    kill "$SERVER_PID" 2>/dev/null
  fi
}
trap cleanup EXIT

for ((i = 0; i < num_iterations; i++)); do

  # Choose the model to serve: the base model in round 0, afterwards the
  # previous round's checkpoint, moved to the fixed path $RESULT/checkpoint.
  if ((i == 0)); then
    model_name=$base_model_name
  else
    model_name=$RESULT/checkpoint
    old_folder=$RESULT/checkpoint-$((i * steps_per_iteration))
    if [[ ! -d "$old_folder" ]]; then
      printf 'error: expected checkpoint %s not found\n' "$old_folder" >&2
      exit 1
    fi
    rm -rf -- "${RESULT:?}/checkpoint"
    mv -- "$old_folder" "$model_name"
  fi

  # launch the server (background, tensor-parallel over GPUs 0,1)
  CUDA_VISIBLE_DEVICES=0,1 vllm serve "$model_name" --tensor-parallel-size 2 --port 8000 --dtype bfloat16 &
  SERVER_PID=$!

  # Fresh rollout directory for this round.
  rm -rf -- "${RESULT:?}/data"
  # mv "$RESULT/data" "$RESULT/data_${i}"
  mkdir -p -- "$RESULT/data"

  # wait for the server to be ready
  cd -- "$collect_dir" || exit 1
  python waiting.py --model_name "$model_name"

  # Collect data against the server: an 8-episode run (tmp/ + visualization/,
  # images saved) in the background, and the 128-episode run that fills
  # $RESULT/data/ in the foreground.
  cd -- "$balrog_dir" || exit 1
  python eval.py \
    agent.type=naive \
    agent.max_image_history=0 \
    agent.max_text_history=2 \
    eval.num_workers=8 \
    client.client_name=vllm \
    client.model_id="$model_name" \
    client.base_url=http://0.0.0.0:8000/v1 \
    client.generate_kwargs.temperature="$TEMPERATURE" \
    tasks.babyai_tasks="$TASKNAME" \
    eval.data_dir="$RESULT/tmp/" \
    eval.output_dir="$RESULT/visualization/" \
    eval.num_episodes.babyai=8 \
    eval.batch_size="$BATCH_SIZE" \
    eval.save_images=True \
    envs.babyai_kwargs.max_steps="$MAX_ENV_STEPS" &
  eval_pid=$!
  python eval.py \
    agent.type=naive \
    agent.max_image_history=0 \
    agent.max_text_history=2 \
    eval.num_workers=32 \
    client.client_name=vllm \
    client.model_id="$model_name" \
    client.base_url=http://0.0.0.0:8000/v1 \
    client.generate_kwargs.temperature="$TEMPERATURE" \
    tasks.babyai_tasks="$TASKNAME" \
    eval.prompt_dir="$PROMPT_DIR" \
    eval.data_dir="$RESULT/data/" \
    eval.output_dir="$RESULT/history/" \
    eval.num_episodes.babyai=128 \
    eval.batch_size="$BATCH_SIZE" \
    envs.babyai_kwargs.max_steps="$MAX_ENV_STEPS"
  # The background run must finish before plotting and before its server is
  # shut down; previously it was never waited on.
  wait "$eval_pid"
  unset eval_pid

  # plot win rate curve (all remaining Python tools run from collect/)
  cd -- "$collect_dir" || exit 1
  python plot.py --file_path "$RESULT"

  # if [ $i -gt 0 ]; then
  #   # kill the server
  #   kill $SERVER_PID
  #   while kill -0 $SERVER_PID; do
  #     sleep 3
  #   done

  #   # launch the server
  #   vllm serve $base_model_name --tensor-parallel-size 2 --port 8000 --dtype bfloat16 &
  #   SERVER_PID=$!

  #   # collect verbal feedback
  #   cd $Home/incontext_RL/collect
  #   python waiting.py --model_name $base_model_name
  # fi

  # NOTE(review): for i > 0 the server on :8000 is serving $model_name (the
  # checkpoint) while feedback is requested with $base_model_name; the
  # commented-out block above re-served the base model for this step.
  # Confirm collect_feedback.py does not require the served model id to match.
  python collect_feedback.py \
    --model_name "$base_model_name" \
    --file_path "$RESULT/data/" \
    --batch_size "$BATCH_SIZE"

  # Kill the server and block until it has exited, so GPUs 0,1 are free
  # before training (which uses the same devices).
  kill "$SERVER_PID"
  wait "$SERVER_PID" 2>/dev/null
  unset SERVER_PID

  # collect log probability
  python collect_prob.py \
    --model_name "$base_model_name" \
    --file_path "$RESULT/data/" \
    --temperature "$TEMPERATURE" \
    --batch_size "$BATCH_SIZE" \
    --alpha "$ALPHA"

  # print entropy
  python print_entropy.py --file_path "$RESULT"

  # Build this round's training config from the LLaMA-Factory template.
  old_config_path=$factory_dir/examples/train_full/llama3_base.yaml
  config_file_path=$RESULT/data/llama3_base.yaml
  cp -- "$old_config_path" "$config_file_path"
  sed -i "s|model_name_or_path: .*|model_name_or_path: $model_name|" "$config_file_path" # set base model name
  sed -i -E "s|^(output_dir: ).*|\1${RESULT}|" "$config_file_path" # set output dir
  sed -i "/^### dataset/a dataset_file_dir: $RESULT/data/result.json" "$config_file_path" # set dataset
  if ((i != 0)); then
    sed -i "/^### output/a resume_from_checkpoint: $model_name" "$config_file_path" # resume from previous round
  fi

  per_device_batch_size=8
  num_gpus=2
  mini_batch_size=2
  gradient_accumulation_steps=$(( BATCH_SIZE / per_device_batch_size / num_gpus / mini_batch_size ))
  # Integer division gives 0 when BATCH_SIZE < per_device * gpus * mini_batch
  # (32); a zero accumulation step count is not a valid training config.
  if ((gradient_accumulation_steps < 1)); then
    printf 'warning: BATCH_SIZE=%s too small, using gradient_accumulation_steps=1\n' "$BATCH_SIZE" >&2
    gradient_accumulation_steps=1
  fi

  # sed backreferences are single digits, so "\1${value}" is group 1 + value.
  sed -i -E "s|^(per_device_train_batch_size: ).*|\1${per_device_batch_size}|" "$config_file_path" # set per_device_train_batch_size
  sed -i -E "s|^(gradient_accumulation_steps: ).*|\1${gradient_accumulation_steps}|" "$config_file_path" # set gradient_accumulation_steps
  sed -i -E "s|^(max_samples: ).*|\1$((BATCH_SIZE))|" "$config_file_path" # set max_samples
  sed -i -E "s|^(learning_rate: ).*|\1${LEARNING_RATE}|" "$config_file_path" # set learning_rate
  sed -i -E "s|^(num_train_epochs: ).*|\1200.0|" "$config_file_path" # set num_train_epochs

  # train
  cd -- "$factory_dir" || exit 1
  FORCE_TORCHRUN=1 CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train "$config_file_path"

done

# rm -r $RESULT/checkpoint*
