#!/usr/bin/env bash
# Supervised fine-tuning + evaluation driver.
#   step1: `accelerate launch train_sft.py` on train_file, saving under
#          model_save_path.
#   step2: `accelerate launch distributed_eval_task.py` on the final-epoch
#          checkpoint, once per entry of task_list.
# Uses bash arrays and ${!array[@]}, so it must be run with bash.
# Fill in the empty values and the *_HERE placeholders before running.

exp_name=""

n_epochs='1'

# accelerator config
num_processes='8'
main_process_port='8897'
config_file="./ds_config/default_config_deepspeed_ga2.yaml"

# training arguments
train_file=""
# NOTE(review): inference_file is not passed to any command in this script.
inference_file=""
model_train_path=""
model_save_path="sft_outputs/${exp_name}/"

batch_size="2"
eval_batch_size="1"
# NOTE(review): presumably has to agree with the DeepSpeed config above (its
# file name ends in "_ga2") — confirm when changing either one.
gradient_accumulation_steps="2"
max_input_length="4096"
num_workers="8"
learning_rate="1e-5"
weight_decay="0"
# NOTE(review): negative value; its meaning is defined by train_sft.py — confirm.
warmup_step="-100"
clip_grad_norm="1"
seed="42"

logging_epoch_freq="1"
# NOTE(review): larger than n_epochs, so presumably no in-training evaluation
# happens — confirm in train_sft.py.
evaluating_epoch_freq="100"
saving_epoch_freq="1"
logging_step_freq="5"

# wandb config
wandb_log="True"
wandb_project="agentenv"
wandb_run_name="${exp_name}"

# environment parameters
data_len="200"
timeout="2400"

# eval
# task_list, test_file_list, max_round_list and env_server_base_list are
# parallel arrays: entry i of each list describes task_list[i]. Every list
# therefore needs exactly one entry per task, in the same order (bash arrays
# are 0-based; a missing entry expands to an empty string).
task_list=("webshop" "alfworld" "textcraft" "sciworld")
# eval parameters
test_file_list=("./data/test/webshop_test.json" "./data/test/alfworld_test.json" "./data/test/textcraft_test.json" "./data/test/sciworld_test_small.json")
do_sample="False"
# NOTE(review): temperature and sample_num are not passed to any command in
# this script.
temperature="1.0"
sample_num="1"
max_round_list=("10" "30" "20" "30")
# One environment server base URL per task, in task_list order.
env_server_base_list=("WEBSHOP_ENV_URL_HERE" "ALFWORLD_ENV_URL_HERE" "TEXTCRAFT_ENV_URL_HERE" "SCIWORLD_ENV_URL_HERE")

# Refuse to launch a multi-process training job with settings it cannot use.
# ${var:?msg} makes a non-interactive shell print msg and exit when var is
# unset or empty.
: "${exp_name:?exp_name must be set (it names the output directory and wandb run)}"
: "${train_file:?train_file must be set}"
: "${model_train_path:?model_train_path must be set}"

mkdir -p "${model_save_path}" || {
        echo "[train] cannot create output directory: ${model_save_path}" >&2
        exit 1
}

# step1: train
# Index into the parallel arrays task_list / max_round_list /
# env_server_base_list whose entries are passed to train_sft.py as
# --task_name / --max_round / --env_server_base. Bash arrays are 0-based, so
# 1 selects the second task ("alfworld" in the default task_list).
train_task_index=1
if [[ -z "${env_server_base_list[$train_task_index]:-}" ]]; then
        echo "[train] warning: env_server_base_list has no entry for task '${task_list[$train_task_index]:-}' (index ${train_task_index}); passing an empty --env_server_base" >&2
fi

# All output goes to train.log. A failed run aborts the script, so step2 never
# evaluates a checkpoint that was not produced.
if ! accelerate launch \
        --config_file "${config_file}" \
        --num_processes="${num_processes}" \
        --main_process_port="${main_process_port}" \
train_sft.py \
        --train_file "${train_file}" \
        --model_train_path "${model_train_path}" \
        --model_save_path "${model_save_path}" \
        --task_name "${task_list[$train_task_index]}" \
        --batch_size "${batch_size}" \
        --eval_batch_size "${eval_batch_size}" \
        --n_epochs "${n_epochs}" \
        --num_workers "${num_workers}" \
        --learning_rate "${learning_rate}" \
        --weight_decay "${weight_decay}" \
        --warmup_step "${warmup_step}" \
        --clip_grad_norm "${clip_grad_norm}" \
        --evaluating_epoch_freq "${evaluating_epoch_freq}" \
        --logging_epoch_freq "${logging_epoch_freq}" \
        --saving_epoch_freq "${saving_epoch_freq}" \
        --logging_step_freq "${logging_step_freq}" \
        --seed "${seed}" \
        --max_input_length "${max_input_length}" \
        --max_round "${max_round_list[$train_task_index]}" \
        --gradient_accumulation_steps "${gradient_accumulation_steps}" \
        --wandb_log "${wandb_log}" \
        --wandb_project "${wandb_project}" \
        --wandb_run_name "${wandb_run_name}" \
        --env_server_base "${env_server_base_list[$train_task_index]:-}" \
        --data_len "${data_len}" \
        --timeout "${timeout}" \
        > "${model_save_path}/train.log" 2>&1; then
        echo "[train] training failed; see ${model_save_path}/train.log" >&2
        exit 1
fi


# step2: eval on test dataset
# Evaluate the final-epoch checkpoint on every task in task_list.
# test_file_list, max_round_list and env_server_base_list are indexed in
# parallel with task_list. A task with a missing entry, or whose evaluation
# exits non-zero, is reported and counted, and the remaining tasks still run.
# Each task writes eval_<task>.jsonl and eval_<task>.log under model_save_path.
eval_model_path="${model_save_path}/train_epoch_${n_epochs}"
if [[ ! -e "${eval_model_path}" ]]; then
        echo "[eval] checkpoint not found: ${eval_model_path}" >&2
        exit 1
fi

eval_failures=0
for index in "${!task_list[@]}"; do
        cur_task="${task_list[$index]}"
        # ':-' keeps a too-short list from silently passing unset values;
        # the emptiness check below reports it instead.
        cur_test_file="${test_file_list[$index]:-}"
        cur_max_round="${max_round_list[$index]:-}"
        cur_env_server_base="${env_server_base_list[$index]:-}"
        cur_eval_output_file="${model_save_path}/eval_${cur_task}.jsonl"
        cur_eval_log_file="${model_save_path}/eval_${cur_task}.log"

        if [[ -z "${cur_test_file}" || -z "${cur_max_round}" || -z "${cur_env_server_base}" ]]; then
                echo "[eval] skipping ${cur_task} (index ${index}): missing test file, max round, or env server base" >&2
                eval_failures=$((eval_failures + 1))
                continue
        fi

        if ! accelerate launch \
                --config_file "${config_file}" \
                --num_processes="${num_processes}" \
                --main_process_port="${main_process_port}" \
        ./distributed_eval_scripts/distributed_eval_task.py \
                --model_path "${eval_model_path}" \
                --output_file "${cur_eval_output_file}" \
                --inference_file "${cur_test_file}" \
                --task_name "${cur_task}" \
                --eval_batch_size "${eval_batch_size}" \
                --num_workers "${num_workers}" \
                --seed "${seed}" \
                --do_sample "${do_sample}" \
                --max_round "${cur_max_round}" \
                --env_server_base "${cur_env_server_base}" \
                --data_len "${data_len}" \
                --timeout "${timeout}" \
                > "${cur_eval_log_file}" 2>&1; then
                echo "[eval] ${cur_task} failed; see ${cur_eval_log_file}" >&2
                eval_failures=$((eval_failures + 1))
        fi
done

# Report overall failure through the exit status instead of only reflecting
# the last task's result.
if (( eval_failures > 0 )); then
        echo "[eval] ${eval_failures} of ${#task_list[@]} task(s) failed or were skipped" >&2
        exit 1
fi