#!/usr/bin/env bash
# Run configuration for the iterative COMAL training loop on Qwen2.5-7B.
#
# The script relies on Bash-only syntax ([[ ]], (( )), C-style for), hence the
# explicit bash shebang. Strict mode makes a failed pipeline stage abort the
# run instead of letting later stages consume missing or stale outputs.
set -euo pipefail

# Experiment name; per-iteration outputs are written to exps/$NAME/<iteration>.
NAME=comal_qwen
# Checkpoint the first iteration samples from and trains from; it is replaced
# by the newly trained checkpoint at the end of every iteration.
OLD_CKPT=Qwen/Qwen2.5-7B-Instruct
# Checkpoint used for the reference-model logprobs; it is replaced by the
# newest checkpoint every UPDATE_INTERVAL iterations.
REF_CKPT=Qwen/Qwen2.5-7B-Instruct

# GPU_IDS is the single source of truth for the device set. GPU_LIST (the same
# ids separated by spaces, for --gpuids) and NUM_GPUS are derived from it so the
# three values cannot drift apart when the device set is edited.
GPU_IDS=0,1,2,3,4,5,6,7
IFS=',' read -r -a _gpu_id_array <<<"$GPU_IDS"
GPU_LIST="${_gpu_id_array[*]}"
NUM_GPUS=${#_gpu_id_array[@]}

# When SKIP is 0, the data-generation stages (sampling, scoring, pairing,
# logprobs) are skipped on the first iteration only; any non-zero value runs
# them on every iteration.
SKIP=1
# Number of iterations to run, and the index of the first one.
NUM_ITERS=18
ITER_START=0
# Value passed to the trainer as --tau_eta_ratio.
RATIO=0.3333333333333333
# The reference checkpoint is refreshed every UPDATE_INTERVAL iterations.
UPDATE_INTERVAL=6

# vLLM settings inherited by every python process launched below.
export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_USE_V1=1

echo "GPU_IDS: $GPU_IDS"

# Iterative COMAL training loop.
#
# Each iteration:
#   1. samples responses from the current policy for the test prompts and one
#      shard of the training prompts,
#   2. scores the samples with two reward models (ArmoRM and Skywork-RM),
#   3. builds preference pairs from the mixed scores,
#   4. annotates the pairs with logprobs from the current policy and from the
#      reference model,
#   5. trains with inpo.py and downcasts the resulting checkpoint,
#   6. generates AlpacaEval outputs and compares them with the SFT baseline
#      and, when available, with the previous checkpoint,
#   7. promotes the new checkpoint to OLD_CKPT, and to REF_CKPT every
#      UPDATE_INTERVAL iterations.
#
# Globals read:    NAME, OLD_CKPT, REF_CKPT, GPU_IDS, GPU_LIST, NUM_GPUS, SKIP,
#                  NUM_ITERS, ITER_START, RATIO, UPDATE_INTERVAL
# Globals written: ITER, FDIR, DATA_ITER, OLD_MODEL_PT, REF_MODEL_PT, CKPT,
#                  OLD_CKPT, REF_CKPT

# Reward-model checkpoints used for scoring and for the AlpacaEval comparisons.
ARMORM_PT=RLHFlow/ArmoRM-Llama3-8B-v0.1
SKYWORKRM_PT=Skywork/Skywork-Reward-Llama-3.1-8B-v0.2
# Tokenizer passed to every stage that takes --tokenizer_pt.
TOKENIZER_PT=Qwen/Qwen2.5-7B
# Number of training prompt shards (data/prompts/train_<k>.jsonl); iterations
# cycle through them. Overridable from the environment, defaults to 6.
: "${NUM_TRAIN_SHARDS:=6}"

# GPU_LIST is a space-separated string; split it once into an array so it can
# be passed to --gpuids as separate words without unquoted expansion.
read -r -a GPU_ARRAY <<<"$GPU_LIST"

#######################################
# Print the model path to load for a checkpoint reference.
# Checkpoints produced by this script keep their weights under <ckpt>/model;
# anything else (e.g. a hub model id) is used as given.
# Arguments: $1 - checkpoint directory or hub model id
# Outputs:   the resolved model path on stdout
#######################################
resolve_model_pt() {
    local ckpt=$1
    if [[ -e "$ckpt/model" ]]; then
        printf '%s\n' "$ckpt/model"
    else
        printf '%s\n' "$ckpt"
    fi
}

#######################################
# Sample 5 responses per prompt from the current policy ($OLD_MODEL_PT).
# Arguments: $1 - prompt file, $2 - output samples file
#######################################
sample_responses() {
    CUDA_VISIBLE_DEVICES=$GPU_IDS python sampling.py \
        --num_gpus "$NUM_GPUS" \
        --model_type qwen \
        --model_pt "$OLD_MODEL_PT" \
        --tokenizer_pt "$TOKENIZER_PT" \
        --num_samples 5 \
        --top_p 0.95 \
        --input_dir "$1" \
        --output_dir "$2" \
        --gpuids "${GPU_ARRAY[@]}" \
        --num_workers 8 \
        --batch_size 64
}

#######################################
# Score sampled responses pointwise with one reward model.
# Arguments: $1 - source prompt file, $2 - samples file, $3 - output file,
#            $4 - reward model type, $5 - reward model checkpoint
#######################################
score_samples() {
    CUDA_VISIBLE_DEVICES=$GPU_IDS python scoring.py \
        --src_dir "$1" \
        --input_dir "$2" \
        --output_dir "$3" \
        --gpuids "${GPU_ARRAY[@]}" \
        --model_pt "$5" \
        --batch_size 32 \
        --score_mode pointwise \
        --model_type "$4" \
        --num_workers 8
}

#######################################
# Build preference pairs for one split from both reward models' scores.
# Arguments: $1 - iteration directory, $2 - split name (train or test)
#######################################
make_pairs() {
    local dir=$1 split=$2
    python data_processing.py \
        --task make_output_pair_from_mixrms \
        --input_dirs "$dir/$split.samples.scores.skyworkrm.jsonl" "$dir/$split.samples.scores.armorm.jsonl" \
        --output_dir "$dir/$split.pairs.jsonl" \
        --num_workers 32 \
        --tokenizer_pt "$TOKENIZER_PT" \
        --model_type qwen
}

#######################################
# Run get_logprobs.py over a pairs file with the given model.
# Arguments: $1 - input file, $2 - output file, $3 - model path,
#            remaining - extra flags forwarded verbatim (e.g. --mode nash)
#######################################
compute_logprobs() {
    local input=$1 output=$2 model_pt=$3
    shift 3
    CUDA_VISIBLE_DEVICES=$GPU_IDS python get_logprobs.py \
        --input_dir "$input" \
        --gpuids "${GPU_ARRAY[@]}" \
        --output_dir "$output" \
        --model_type qwen \
        --model_pt "$model_pt" \
        --tokenizer_pt "$TOKENIZER_PT" \
        "$@" \
        --batch_size 8
}

#######################################
# Compare two sets of AlpacaEval outputs with a single reward model.
# Arguments: $1 - baseline outputs (sys1), $2 - candidate outputs (sys2),
#            $3 - result file, $4 - reward model type,
#            $5 - reward model checkpoint
#######################################
compare_with_rm() {
    CUDA_VISIBLE_DEVICES=$GPU_IDS python eval.py \
        --output_dir "$3" \
        --sys1_dir "$1" \
        --sys2_dir "$2" \
        --num_gpus "$NUM_GPUS" \
        --batch_size 8 \
        --task eval_alpaca \
        --evaluator rm \
        --model_type "$4" \
        --model_pt "$5"
}

#######################################
# Compare candidate AlpacaEval outputs against a baseline with both reward
# models, then combine the two results with the mixrms evaluator.
# Writes <prefix>.armorm.json, <prefix>.skyworkrm.json and <prefix>.json.
# Arguments: $1 - baseline outputs (sys1), $2 - candidate outputs (sys2),
#            $3 - result path prefix
#######################################
compare_on_alpaca() {
    local baseline=$1 candidate=$2 prefix=$3
    compare_with_rm "$baseline" "$candidate" "$prefix.armorm.json" armorm "$ARMORM_PT"
    compare_with_rm "$baseline" "$candidate" "$prefix.skyworkrm.json" skyworkrm "$SKYWORKRM_PT"
    CUDA_VISIBLE_DEVICES=$GPU_IDS python eval.py \
        --output_dir "$prefix.json" \
        --sys1_dir "$baseline" \
        --sys2_dir "$candidate" \
        --task eval_alpaca \
        --evaluator mixrms \
        --skyworkrm_dir "$prefix.skyworkrm.json" \
        --armorm_dir "$prefix.armorm.json"
}

for ((ITER = ITER_START; ITER < ITER_START + NUM_ITERS; ITER++)); do
    echo "$NAME ITER $ITER"
    FDIR="exps/$NAME/$ITER"
    mkdir -p "$FDIR"
    echo "FDIR: $FDIR"
    # Training prompts are sharded; cycle through the shards across iterations.
    DATA_ITER=$((ITER % NUM_TRAIN_SHARDS))
    OLD_MODEL_PT=$(resolve_model_pt "$OLD_CKPT")

    # Data generation runs on every iteration, except that SKIP=0 skips it on
    # the first iteration (its outputs under $FDIR/data must then already exist).
    if ((SKIP != 0 || ITER != ITER_START)); then
        train_prompts="data/prompts/train_${DATA_ITER}.jsonl"
        test_prompts="data/prompts/test.jsonl"

        # generate samples
        echo "Generating samples"
        sample_responses "$test_prompts" "$FDIR/test.samples.jsonl"
        sample_responses "$train_prompts" "$FDIR/train.samples.jsonl"

        # score samples
        echo "Scoring samples"
        score_samples "$train_prompts" "$FDIR/train.samples.jsonl" \
            "$FDIR/train.samples.scores.armorm.jsonl" armorm "$ARMORM_PT"
        score_samples "$test_prompts" "$FDIR/test.samples.jsonl" \
            "$FDIR/test.samples.scores.armorm.jsonl" armorm "$ARMORM_PT"
        score_samples "$train_prompts" "$FDIR/train.samples.jsonl" \
            "$FDIR/train.samples.scores.skyworkrm.jsonl" skyworkrm "$SKYWORKRM_PT"
        score_samples "$test_prompts" "$FDIR/test.samples.jsonl" \
            "$FDIR/test.samples.scores.skyworkrm.jsonl" skyworkrm "$SKYWORKRM_PT"

        # processing
        echo "Processing"
        make_pairs "$FDIR" train
        make_pairs "$FDIR" test

        mkdir -p "$FDIR/data"
        # get logprobs
        echo "Getting logprobs using the latest model"
        compute_logprobs "$FDIR/train.pairs.jsonl" "$FDIR/data/train.jsonl" "$OLD_MODEL_PT"
        compute_logprobs "$FDIR/test.pairs.jsonl" "$FDIR/data/test.jsonl" "$OLD_MODEL_PT"

        REF_MODEL_PT=$(resolve_model_pt "$REF_CKPT")

        # The reference pass reads and rewrites the same files in place.
        echo "Getting logprobs using the ref model"
        compute_logprobs "$FDIR/data/train.jsonl" "$FDIR/data/train.jsonl" "$REF_MODEL_PT" --mode nash
        compute_logprobs "$FDIR/data/test.jsonl" "$FDIR/data/test.jsonl" "$REF_MODEL_PT" --mode nash
    fi

    # training
    echo "Training COMAL"

    # One training process per visible GPU.
    CUDA_VISIBLE_DEVICES=$GPU_IDS accelerate launch --mixed_precision bf16 \
        --num_machines 1 \
        --num_processes "$NUM_GPUS" \
        --use_deepspeed \
        --deepspeed_config_file deepspeed.conf \
        --main_process_port 29770 \
        inpo.py \
        --epoch 1 \
        --eta 0.002 \
        --tau_eta_ratio "$RATIO" \
        --dataset "$FDIR/data" \
        --model_type Qwen/Qwen2.5-7B \
        --exp_name "$FDIR/ckpts" \
        --pretrained "$OLD_MODEL_PT" \
        --gradient_checkpointing \
        --lr_schedule cosine \
        --accumulate_step 4 \
        --eval_interval 400 \
        --max_lr 0.0000005 \
        --batch_size 2 \
        -l

    CKPT=$FDIR/ckpts
    python downcast.py "$CKPT"

    # Every later stage loads $CKPT/model; stop here rather than promoting a
    # checkpoint that was never written.
    if [[ ! -e "$CKPT/model" ]]; then
        echo "error: $CKPT/model is missing after training; aborting" >&2
        exit 1
    fi

    # evaluate
    echo "generating samples"
    CUDA_VISIBLE_DEVICES=$GPU_IDS python eval.py \
        --num_gpus "$NUM_GPUS" \
        --model_type qwen \
        --model_pt "$CKPT/model" \
        --tokenizer_pt "$TOKENIZER_PT" \
        --num_samples 1 \
        --temperature 0.9 \
        --top_p 0.95 \
        --output_dir "$CKPT/alpacaeval_output.jsonl" \
        --task gen_alpaca \
        --num_workers 8 \
        --gpuids "${GPU_ARRAY[@]}"

    echo "Comparing with sft on alpaca"
    compare_on_alpaca exps/tulu_qwen_sft/alpacaeval_output.jsonl \
        "$CKPT/alpacaeval_output.jsonl" "$CKPT/alpacaeval_vs_sft"

    # Only checkpoints evaluated by a previous iteration have AlpacaEval
    # outputs; an initial hub model id does not, so skip the comparison then.
    if [[ -e "$OLD_CKPT/alpacaeval_output.jsonl" ]]; then
        echo "Comparing with previous ckpt on alpaca"
        compare_on_alpaca "$OLD_CKPT/alpacaeval_output.jsonl" \
            "$CKPT/alpacaeval_output.jsonl" "$CKPT/alpacaeval_vs_previous"
    fi

    OLD_CKPT=$CKPT

    if (((ITER + 1) % UPDATE_INTERVAL == 0)); then
        echo "Updating the reference model"
        REF_CKPT=$CKPT
    fi
done