#!/bin/bash

# Fail fast: abort the pipeline as soon as an unguarded command returns non-zero.
set -o errexit

# Activate the project's conda environment and run from the repository root,
# so the relative scripts/, data/ and outputs/ paths used below resolve.
. /nlp/scr/anonymous/miniconda3/etc/profile.d/conda.sh
conda activate buggen
cd -- /nlp/scr/anonymous/projects/attacker_solver

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Any value other than "true" stops the pipeline after validation (no DPO runs).
TRAIN=true

# Base model, benchmark name and split definition shared by every stage.
MODEL_NAME="Qwen/Qwen2.5-Coder-7B-Instruct"
DATASET="bigcodebench-complete"
SPLITS_FILE="data/splits/bigcodebench_complete_splits_100.json"

# Counts forwarded to generation as --num_mutator_iters / --num_solver_iters.
NUM_MUTATOR_ITERS=10
NUM_SOLVER_ITERS=5

# Mini-run sizes for DPO training (a full run uses 2000 train / 500 val).
TRAIN_EXAMPLES=100
VAL_EXAMPLES=10

# Exported so child processes see it; presumably the Hugging Face namespace
# used when pushing datasets/models — confirm in the Python scripts.
export HF_ENTITY=anonymous

# Repository name and output directories all derive from the iteration counts.
printf -v HF_REPO_NAME 'bigcodebench_qwen7b_att_iter0_dpo_att%s_sol%s' \
  "${NUM_MUTATOR_ITERS}" "${NUM_SOLVER_ITERS}"
DATA_SAVE_DIR="data/mutations/${HF_REPO_NAME}"
MODEL_SAVE_DIR="outputs/${HF_REPO_NAME}"

# Read by vLLM from the environment (worker multiprocessing start method).
export VLLM_WORKER_MULTIPROC_METHOD=spawn

# Record the pipeline start time; the total elapsed time is reported at the end.
TOTAL_START_TIME=$(date +%s)

echo "Generating mutations and solver results for ${DATASET}: ${NUM_MUTATOR_ITERS} mutations and ${NUM_SOLVER_ITERS} solves"

# Stage 1: sample mutations (attacker) and solver attempts; results are written
# under DATA_SAVE_DIR (the validation stage reads mutations.json and
# solver_results.json from there).
# NOTE(review): --max_new_tokens equals --max_model_len. If max_model_len bounds
# prompt + completion (as in vLLM), a full 2048-token completion can never fit
# next to a prompt — confirm these limits are intended.
PT1_START_TIME=$(date +%s)

python scripts/generate_data.py \
    --mode generate \
    --model_name "${MODEL_NAME}" \
    --dataset "${DATASET}" \
    --splits_file "${SPLITS_FILE}" \
    --max_new_tokens 2048 \
    --max_model_len 2048 \
    --attacker_temperature 0.7 \
    --solver_temperature 0.7 \
    --num_mutator_iters "${NUM_MUTATOR_ITERS}" \
    --num_solver_iters "${NUM_SOLVER_ITERS}" \
    --gpu_memory_utilization 0.7 \
    --save_dir "${DATA_SAVE_DIR}" \
    --enforce_eager \
    --seed 1

PT1_END_TIME=$(date +%s)
GEN_TIME=$((PT1_END_TIME - PT1_START_TIME))
echo "Time taken for generation: ${GEN_TIME} seconds"



echo "Verifying mutations and solutions..."

# Stage 2: validate the stage-1 outputs read from DATA_SAVE_DIR and save the
# result under DATA_SAVE_DIR/validated. With --push_to_hub this presumably also
# uploads to the Hub repo HF_REPO_NAME — confirm in scripts/generate_data.py.
PT2_START_TIME=$(date +%s)

python scripts/generate_data.py \
    --mode validate \
    --model_name "${MODEL_NAME}" \
    --dataset "${DATASET}" \
    --splits_file "${SPLITS_FILE}" \
    --mutations_path "${DATA_SAVE_DIR}/mutations.json" \
    --solver_results_path "${DATA_SAVE_DIR}/solver_results.json" \
    --save_path "${DATA_SAVE_DIR}/validated" \
    --hf_repo_name "${HF_REPO_NAME}" \
    --push_to_hub \
    --seed 1

PT2_END_TIME=$(date +%s)
VAL_TIME=$((PT2_END_TIME - PT2_START_TIME))
echo "Time taken for validation: ${VAL_TIME} seconds"


# Stage 3 (only when TRAIN is "true"): DPO-train the mutator (attacker) on the
# validated data, then DPO-train the solver (defense) starting from the
# mutator's "best" checkpoint.
if [[ "$TRAIN" == true ]]; then
  PT3_START_TIME=$(date +%s)

  # torchrun rendezvous port: one of 10 ports chosen from this script's PID.
  # The two runs are sequential, so they share it.
  PORT=$((56430 + $$ % 10))

  # Options identical for both DPO runs; kept in one array so the mutator and
  # solver runs cannot drift apart. Each element is passed as a separate word.
  dpo_common_args=(
    --dataset "${DATASET}"
    --splits_file "${SPLITS_FILE}"
    --method ours
    --num_train_examples "${TRAIN_EXAMPLES}"
    --num_val_examples "${VAL_EXAMPLES}"
    --margin_threshold 0.4
    --iteration 0
    --max_model_len 2048
    --max_tokens 2048
    --learning_rate 5e-6
    --num_train_epochs 3
    --per_device_train_batch_size 8
    --gradient_accumulation_steps 4
    --eval_steps 10
    --save_total_limit 3
    --beta 0.1
    --warmup_ratio 0.1
    --weight_decay 0.01
    --gradient_checkpointing
    --fp16
    --optim "adamw_torch"
    --loss_type "sigmoid"
    --lr_scheduler_type "cosine"
    --lora_r 16
    --lora_alpha 32
    --lora_dropout 0.05
    --wandb_project dpo
    --deepspeed
    --zero_stage 2
    --seed 42
  )

  # Train mutator (attack) on the "response" column.
  WANDB__SERVICE_WAIT=300 torchrun --master_port="${PORT}" --nnodes=1 --nproc_per_node=4 scripts/dpo_train.py \
    --model "${MODEL_NAME}" \
    --input_column response \
    --save_dir "${MODEL_SAVE_DIR}" \
    --hf_repo_name "${HF_REPO_NAME}" \
    "${dpo_common_args[@]}"

  PT3_END_TIME=$(date +%s)
  MUTATOR_TRAIN_TIME=$((PT3_END_TIME - PT3_START_TIME))
  echo "Time taken for mutator DPO training: ${MUTATOR_TRAIN_TIME} seconds"

  # First directory named "best" under the mutator output.
  MUTATOR_CHECKPOINT=$(find "${MODEL_SAVE_DIR}" -type d -name "best" | head -n 1)
  echo "Mutator checkpoint: ${MUTATOR_CHECKPOINT}"
  # Without a checkpoint, the solver run would receive an empty
  # --model_peft_checkpoint value; stop here with a clear message instead.
  if [[ -z "${MUTATOR_CHECKPOINT}" ]]; then
    echo "Error: no 'best' checkpoint directory found under ${MODEL_SAVE_DIR}" >&2
    exit 1
  fi

  # Train solver (defense) starting from mutator checkpoint
  echo "Starting solver DPO training..."
  PT4_START_TIME=$(date +%s)

  SOLVER_HF_REPO_NAME="bigcodebench_qwen7b_solver_iter0_att${NUM_MUTATOR_ITERS}_sol${NUM_SOLVER_ITERS}"
  SOLVER_MODEL_SAVE_DIR="outputs/${SOLVER_HF_REPO_NAME}"

  # Same base model plus the mutator checkpoint (presumably a LoRA adapter
  # loaded by dpo_train.py — confirm), trained on the "solutions" column.
  WANDB__SERVICE_WAIT=300 torchrun --master_port="${PORT}" --nnodes=1 --nproc_per_node=4 scripts/dpo_train.py \
    --model "${MODEL_NAME}" \
    --model_peft_checkpoint "${MUTATOR_CHECKPOINT}" \
    --input_column solutions \
    --save_dir "${SOLVER_MODEL_SAVE_DIR}" \
    --hf_repo_name "${SOLVER_HF_REPO_NAME}" \
    "${dpo_common_args[@]}"

  PT4_END_TIME=$(date +%s)
  SOLVER_TRAIN_TIME=$((PT4_END_TIME - PT4_START_TIME))
  echo "Time taken for solver DPO training: ${SOLVER_TRAIN_TIME} seconds"
fi

# Overall wall-clock summary, measured from TOTAL_START_TIME (set before generation).
TOTAL_END_TIME=$(date +%s)
TOTAL_TIME=$(( TOTAL_END_TIME - TOTAL_START_TIME ))
printf 'Total time taken: %s seconds\n' "${TOTAL_TIME}"
