#!/bin/bash
# Pipeline for one ADPA round of the llama3.2-1b deita / dpo-mix-7k setup:
#   0. (skipped by default) self-generate with the previous-round student and
#      build student_adpa_dataset_original_${EPOCH_NAME},
#   1. precompute DPO-teacher and reference-teacher logits on train/test,
#   2. merge them into adpa_dataset_${EPOCH_NAME},
#   3. launch distillation training (scripts/run_distill_dpo.py).
#
# Environment (optional overrides):
#   BEFORE_EPOCH_NAME  tag of the previous-round student model (default 1epoch)
#   EPOCH_NAME         tag appended to this round's data dirs (default 2epoch)
set -euo pipefail
set -x

# Tag of the student checkpoint from the previous round (generation input).
BEFORE_EPOCH_NAME="${BEFORE_EPOCH_NAME:-1epoch}"
# Tag appended to the data directories produced by this round.
EPOCH_NAME="${EPOCH_NAME:-2epoch}"
# Stage 0 (opt-in): generate responses with the previous-round student
# (./model/student_adpa_${BEFORE_EPOCH_NAME}) on the train and test splits of
# argilla/dpo-mix-7k, then build student_adpa_dataset_original_${EPOCH_NAME}
# with those generations supplied as the rejected responses. The logit
# precomputation stages below read that dataset, so it must already exist when
# this stage is skipped.
#
# Skipped by default (it was previously disabled by wrapping it in a heredoc,
# which also hid it from linting). Set RUN_SELF_GENERATION=1 to run it.
if [[ "${RUN_SELF_GENERATION:-0}" == 1 ]]; then
  self_gen_dir="./data/llama3.2-1b-deita-dpomix/student_init_self_generation_${EPOCH_NAME}"

  # Both splits write into the same output directory.
  for gen_split in train test; do
    CUDA_VISIBLE_DEVICES=0,1 python utils/vllm_generate.py \
      --model "./model/student_adpa_${BEFORE_EPOCH_NAME}" \
      --data argilla/dpo-mix-7k \
      --dataset_split "$gen_split" \
      --prompt_key chosen \
      --out_dir "$self_gen_dir" \
      --apply_template True
  done

  python utils/form_preference_dataset.py \
    --original-dataset argilla/dpo-mix-7k \
    --rejected-train "${self_gen_dir}/dpo-mix-7k-train.jsonl" \
    --rejected-test "${self_gen_dir}/dpo-mix-7k-test.jsonl" \
    --output-dir "./data/llama3.2-1b-deita-dpomix/student_adpa_dataset_original_${EPOCH_NAME}"
fi


# Stage 1: score the "rejected" conversations of the student preference
# dataset with both teachers (./model/dpo_teacher and ./model/ref_teacher) on
# the train and test splits. Each pass writes to its own --save-to directory;
# the merge stage consumes all four.
data_root=./data/llama3.2-1b-deita-dpomix
student_dataset="${data_root}/student_adpa_dataset_original_${EPOCH_NAME}"
dpo_train_logits="${data_root}/dpomix7k-dpoteacher-train-student_${EPOCH_NAME}"
dpo_test_logits="${data_root}/dpomix7k-dpoteacher-test-student_${EPOCH_NAME}"
ref_train_logits="${data_root}/dpomix7k-refteacher-train-student_${EPOCH_NAME}"
ref_test_logits="${data_root}/dpomix7k-refteacher-test-student_${EPOCH_NAME}"

#######################################
# Launch one accelerate-based logit precomputation pass over the "rejected"
# conversations of the student dataset, formatted with Llama-3-style header
# tokens.
# Globals:   student_dataset (read)
# Arguments: $1 - accelerate --num_processes
#            $2 - precompute script to launch
#            $3 - dataset split (train / test)
#            $4 - teacher model directory
#            $5 - output directory (--save-to)
#######################################
precompute_teacher_logits() {
  local -r num_processes=$1 script=$2 split=$3 model=$4 save_to=$5
  # The single-quoted delimiters are passed verbatim, including the literal
  # `\\n\\n`; presumably the Python side unescapes them — confirm.
  CUDA_VISIBLE_DEVICES=0,1,2,3 python -m accelerate.commands.launch \
    --num_processes="$num_processes" \
    --main_process_port 29501 \
    "$script" \
    --data "$student_dataset" \
    --split "$split" \
    --model "$model" \
    --conversation-key rejected \
    --user-begin '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\\n\\n' \
    --user-end '<|eot_id|>' \
    --assistant-begin '<|start_header_id|>assistant<|end_header_id|>\\n\\n' \
    --assistant-end '<|eot_id|>' \
    --save-to "$save_to" \
    --pad-token-id 128001 \
    --max-tokens-per-batch 2048
}

# Train splits use the *_for_train script with 2 processes; test splits use
# utils/precompute_logits.py with 4 (four GPUs are visible in every pass).
precompute_teacher_logits 2 utils/precompute_logits_for_train.py train \
  ./model/dpo_teacher "$dpo_train_logits"

# Remove the results_rank_*.jsonl files (presumably per-rank shards) from the
# DPO-teacher train output. -f keeps an unmatched glob (e.g. on a re-run) from
# aborting under `set -e`; :? refuses to run if the directory var is empty.
# NOTE(review): the ref-teacher train pass below runs the same script, but its
# shards are not removed — confirm whether that asymmetry is intended.
rm -f -- "${dpo_train_logits:?}"/results_rank_*.jsonl

precompute_teacher_logits 4 utils/precompute_logits.py test \
  ./model/dpo_teacher "$dpo_test_logits"

precompute_teacher_logits 2 utils/precompute_logits_for_train.py train \
  ./model/ref_teacher "$ref_train_logits"

precompute_teacher_logits 4 utils/precompute_logits.py test \
  ./model/ref_teacher "$ref_test_logits"




# Stage 2: combine the four teacher log-prob directories (DPO teacher and
# reference teacher, each on train and test) into adpa_dataset_${EPOCH_NAME}.
# The merge is configured with logits key rejected_compressed_probs, label key
# rejected_labels, and output key rejected_margin_logp_every.
# NOTE(review): the base dataset here is the hub copy of argilla/dpo-mix-7k,
# while the logits were computed on
# student_adpa_dataset_original_${EPOCH_NAME}, whose rejected responses come
# from the student's self-generation. Confirm the merge script aligns rows and
# rejected text correctly.
data_root=./data/llama3.2-1b-deita-dpomix
logp_prefix="${data_root}/dpomix7k"
logp_suffix="student_${EPOCH_NAME}"
python utils/merge_logits_adpa_dataset.py \
  --input-dataset-dict argilla/dpo-mix-7k \
  --dpo-teacher-logp-train "${logp_prefix}-dpoteacher-train-${logp_suffix}" \
  --ref-teacher-logp-train "${logp_prefix}-refteacher-train-${logp_suffix}" \
  --dpo-teacher-logp-test "${logp_prefix}-dpoteacher-test-${logp_suffix}" \
  --ref-teacher-logp-test "${logp_prefix}-refteacher-test-${logp_suffix}" \
  --save-to "${data_root}/adpa_dataset_${EPOCH_NAME}" \
  --logits-key rejected_compressed_probs \
  --label-key rejected_labels \
  --output-key rejected_margin_logp_every



# Stage 3: run scripts/run_distill_dpo.py under accelerate with the
# recipes/accelerate_config/deepspeed_zero3.yaml config on four visible GPUs.
# The training recipe can be overridden with ADPA_RECIPE=<path>.
# NOTE(review): the default recipe is tagged "1epoch" while this round's data
# is tagged ${EPOCH_NAME}. Confirm the recipe reads adpa_dataset_${EPOCH_NAME}
# and writes the intended model directory before reusing it across rounds.
adpa_recipe="${ADPA_RECIPE:-recipes/llama3.2-1b-deita-dpomix/student_adpa_1epoch.yaml}"
CUDA_VISIBLE_DEVICES=0,1,2,3 \
ACCELERATE_LOG_LEVEL=info \
DS_SKIP_CUDA_CHECK=1 \
python -m accelerate.commands.launch \
  --config_file recipes/accelerate_config/deepspeed_zero3.yaml \
  scripts/run_distill_dpo.py \
  "$adpa_recipe"
