#!/usr/bin/env bash
#
# Caption-rewrite pipeline for the Flickr30k training annotations.
#
#   1. (currently disabled) convert_train_file2lines.py converts the
#      annotation JSON into the prompt-lines file consumed by step 2.
#   2. llama_rewrite.py is launched with torchrun on a local LLaMA-2
#      checkpoint; it reads the prompt-lines file and writes a rewrite
#      lines file.
#   3. convert_lines2train_file.py combines the rewrite lines file with the
#      original annotation file into a new training annotation file (per
#      its argument names -- confirm against the script).
#
# Usage: <this script> [sample_mode]
#   sample_mode  forwarded to llama_rewrite.py --sample_mode and embedded in
#                the output file names; defaults to "bard" ("chatgpt" is the
#                other mode this script has been used with).
#
# Only POSIX sh features are used, so running it with `sh` keeps working.
set -eu

# Print an error to stderr and abort.
die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

readonly ann_file="/data/dataset/dataset_json/data/flickr30k_train.json"
readonly output_file="/data/dataset/dataset_json/data_rewrite/flickr30k_train_lines.json"

# Step 1 (disabled): build the prompt-lines file. Re-enable if
# ${output_file} does not exist yet.
# python3 convert_train_file2lines.py \
#     --ann_file "${ann_file}" \
#     --out_file "${output_file}"

# No trailing slash: paths below are built as "${LLAMA_FOLDER}/...".
readonly LLAMA_FOLDER="/data/robust_crossmodal-retrieval/llama_rewrite/llama_weights"
readonly model="llama-2-7b"

# Optional first argument selects the sampling mode; no argument keeps the
# previous default.
readonly sample_mode="${1:-bard}"

readonly rewrite_file="/data/dataset/dataset_json/data_rewrite/flickr30k_train_llama_${sample_mode}_lines.json"
readonly new_ann_file="/data/dataset/dataset_json/data_rewrite/flickr30k_train_llama_${sample_mode}.json"

# Fail fast on missing inputs. ${ann_file} is only read by the final step,
# which runs after the (presumably long) torchrun job.
[ -f "${ann_file}" ] || die "annotation file not found: ${ann_file}"
[ -f "${output_file}" ] ||
  die "prompt lines file not found: ${output_file} (run convert_train_file2lines.py first)"

# Step 2: rewrite the prompt lines with LLaMA (single process, single node).
torchrun --nproc_per_node 1 --master_port 12388 \
  llama_rewrite.py \
  --ckpt_dir "${LLAMA_FOLDER}/${model}" \
  --tokenizer_path "${LLAMA_FOLDER}/tokenizer.model" \
  --max_batch_size 100 --max_seq_len 400 \
  --prompt_filename "${output_file}" \
  --output_filename "${rewrite_file}" \
  --sample_mode "${sample_mode}" --temperature 0.9

# set -e already aborts if torchrun exits non-zero. Also check that the
# file the next step reads was actually produced, for a clearer error.
[ -f "${rewrite_file}" ] || die "rewrite output not produced: ${rewrite_file}"

# Step 3: turn the rewrite lines back into a training annotation file.
python3 convert_lines2train_file.py \
  --lines_file "${rewrite_file}" \
  --orig_train_file "${ann_file}" \
  --out_train_file "${new_ann_file}"