#!/bin/bash
#
# Driver for the mutation / solver data-generation run: activates the conda
# environment, moves into the project checkout, then launches
# scripts/generate_data.py and reports how long it took.

# Abort the script as soon as any command exits non-zero.
set -o errexit

# Load conda's shell integration so that `conda activate` is available in this
# non-interactive shell, then switch to the project environment.
. /nlp/scr/anonymous/miniconda3/etc/profile.d/conda.sh
conda activate buggen

# Every relative path used later (scripts/, data/, outputs/) is resolved from
# the project root.
cd /nlp/scr/anonymous/projects/attacker_solver

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Model and data selection.
MODEL_NAME='Qwen/Qwen2.5-Coder-7B-Instruct'
DATASET='kodcode-complete'
# For a mini run, point this at "data/splits/kodcode_complete_splits_100.json".
SPLITS_FILE='data/splits/kodcode_complete_splits_1000.json'

# Iteration counts handed to the generation step; they are also embedded in
# the repo / directory names built below.
NUM_MUTATOR_ITERS=10
NUM_SOLVER_ITERS=5

# Training knobs. In this file they are only referenced by the commented-out
# training stage.
TRAIN=false
TRAIN_EXAMPLES=50000
VAL_EXAMPLES=10000

# Naming of the produced artifacts.
HF_ENTITY=anonymous
export HF_ENTITY
HF_REPO_NAME="kodcode_qwen7b_att_iter0_dpo_att${NUM_MUTATOR_ITERS}_sol${NUM_SOLVER_ITERS}"
DATA_SAVE_DIR="data/mutations/${HF_REPO_NAME}"
MODEL_SAVE_DIR="outputs/${HF_REPO_NAME}"

# Exported so child processes inherit it (presumably read by vLLM to choose
# how worker processes are started — confirm against the vLLM docs).
VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_WORKER_MULTIPROC_METHOD

# Record when the whole run started; the total duration is reported at the
# end of the script.
TOTAL_START_TIME=$(date +%s)

echo "Generating mutations and solver results for ${DATASET}: ${NUM_MUTATOR_ITERS} mutations and ${NUM_SOLVER_ITERS} solves"

PT1_START_TIME=$(date +%s)

# Command-line arguments for the generation step. They are collected in an
# array with every expansion quoted so that each configured value reaches the
# script as exactly one argument: a value containing whitespace or glob
# characters is not split or expanded, and an empty value is still passed
# instead of silently shifting the following flag into its place.
#
# NOTE(review): --max_new_tokens equals --max_model_len. If the latter bounds
# prompt plus completion, the full generation budget can never be used —
# confirm against scripts/generate_data.py.
generate_args=(
  --mode generate
  --model_name "${MODEL_NAME}"
  --dataset "${DATASET}"
  --splits_file "${SPLITS_FILE}"
  --max_new_tokens 2048
  --max_model_len 2048
  --attacker_temperature 0.7
  --solver_temperature 0.7
  --num_mutator_iters "${NUM_MUTATOR_ITERS}"
  --num_solver_iters "${NUM_SOLVER_ITERS}"
  --gpu_memory_utilization 0.7
  --save_dir "${DATA_SAVE_DIR}"
  --enforce_eager
  --seed 1
)

python scripts/generate_data.py "${generate_args[@]}"

# Wall-clock duration of the generation step only.
PT1_END_TIME=$(date +%s)
GEN_TIME=$((PT1_END_TIME - PT1_START_TIME))
echo "Time taken for generation: ${GEN_TIME} seconds"



# echo "Verifying mutations and solutions..."
# 
# PT2_START_TIME=$(date +%s)
# 
# python scripts/generate_data.py \
#     --mode validate \
#     --model_name ${MODEL_NAME} \
#     --dataset ${DATASET} \
#     --splits_file ${SPLITS_FILE} \
#     --mutations_path ${DATA_SAVE_DIR}/mutations.json \
#     --solver_results_path ${DATA_SAVE_DIR}/solver_results.json \
#     --save_path ${DATA_SAVE_DIR}/validated \
#     --hf_repo_name ${HF_REPO_NAME} \
#     --push_to_hub \
#     --seed 1
# 
# PT2_END_TIME=$(date +%s)
# VAL_TIME=$((PT2_END_TIME - PT2_START_TIME))
# echo "Time taken for validation: ${VAL_TIME} seconds"
# 
# 
# if [ "$TRAIN" = true ]; then
#   PT3_START_TIME=$(date +%s)
# 
#   RANDOM_SEED=$$
#   PORT=$((56430 + RANDOM_SEED % 10))
# 
#   WANDB__SERVICE_WAIT=300 torchrun --master_port=$PORT --nnodes=1 --nproc_per_node=4 scripts/dpo_train.py \
#     --model ${MODEL_NAME} \
#     --dataset ${DATASET} \
#     --splits_file ${SPLITS_FILE} \
#     --input_column response \
#     --method ours \
#     --num_train_examples ${TRAIN_EXAMPLES} \
#     --num_val_examples ${VAL_EXAMPLES} \
#     --margin_threshold 1.0 \
#     --iteration 0 \
#     --save_dir ${MODEL_SAVE_DIR} \
#     --hf_repo_name ${HF_REPO_NAME} \
#     --max_model_len 1024 \
#     --max_tokens 1024 \
#     --learning_rate 1e-5 \
#     --num_train_epochs 3 \
#     --per_device_train_batch_size 16 \
#     --gradient_accumulation_steps 1 \
#     --eval_steps 200 \
#     --save_total_limit 3 \
#     --beta 0.1 \
#     --warmup_ratio 0.1 \
#     --weight_decay 0.01 \
#     --gradient_checkpointing \
#     --fp16 \
#     --optim "adamw_torch" \
#     --loss_type "sigmoid" \
#     --lr_scheduler_type "cosine" \
#     --lora_r 16 \
#     --lora_alpha 32 \
#     --lora_dropout 0.05 \
#     --wandb_project dpo \
#     --deepspeed \
#     --zero_stage 2 \
#     --seed 42
# 
#   PT3_END_TIME=$(date +%s)
#   MUTATOR_TRAIN_TIME=$((PT3_END_TIME - PT3_START_TIME))
#   echo "Time taken for mutator DPO training: ${MUTATOR_TRAIN_TIME} seconds"
# 
#   MUTATOR_CHECKPOINT=$(find ${MODEL_SAVE_DIR} -type d -name "best" | head -1)
#   echo "Mutator checkpoint: ${MUTATOR_CHECKPOINT}"
# 
#   # Train solver (defense) starting from mutator checkpoint
#   echo "Starting solver DPO training..."
#   PT4_START_TIME=$(date +%s)
# 
#   RANDOM_SEED=$$
#   PORT=$((56430 + RANDOM_SEED % 10))
#   SOLVER_HF_REPO_NAME="kodcode_qwen7b_solver_iter0_att${NUM_MUTATOR_ITERS}_sol${NUM_SOLVER_ITERS}"
#   SOLVER_MODEL_SAVE_DIR="outputs/${SOLVER_HF_REPO_NAME}"
# 
#   WANDB__SERVICE_WAIT=300 torchrun --master_port=$PORT --nnodes=1 --nproc_per_node=4 scripts/dpo_train.py \
#     --model ${MODEL_NAME} \
#     --model_peft_checkpoint ${MUTATOR_CHECKPOINT} \
#     --dataset ${DATASET} \
#     --splits_file ${SPLITS_FILE} \
#     --input_column solutions \
#     --method ours \
#     --num_train_examples ${TRAIN_EXAMPLES} \
#     --num_val_examples ${VAL_EXAMPLES} \
#     --margin_threshold 1.0 \
#     --iteration 0 \
#     --save_dir ${SOLVER_MODEL_SAVE_DIR} \
#     --hf_repo_name ${SOLVER_HF_REPO_NAME} \
#     --max_model_len 1024 \
#     --max_tokens 1024 \
#     --learning_rate 1e-5 \
#     --num_train_epochs 3 \
#     --per_device_train_batch_size 16 \
#     --gradient_accumulation_steps 1 \
#     --eval_steps 200 \
#     --save_total_limit 3 \
#     --beta 0.1 \
#     --warmup_ratio 0.1 \
#     --weight_decay 0.01 \
#     --gradient_checkpointing \
#     --fp16 \
#     --optim "adamw_torch" \
#     --loss_type "sigmoid" \
#     --lr_scheduler_type "cosine" \
#     --lora_r 16 \
#     --lora_alpha 32 \
#     --lora_dropout 0.05 \
#     --wandb_project dpo \
#     --deepspeed \
#     --zero_stage 2 \
#     --seed 42
# 
#   PT4_END_TIME=$(date +%s)
#   SOLVER_TRAIN_TIME=$((PT4_END_TIME - PT4_START_TIME))
#   echo "Time taken for solver DPO training: ${SOLVER_TRAIN_TIME} seconds"
# fi

# Report how long the whole script ran: the difference between now and
# TOTAL_START_TIME, which is captured earlier in the script.
TOTAL_END_TIME="$(date +%s)"
TOTAL_TIME=$(( TOTAL_END_TIME - TOTAL_START_TIME ))
printf '%s\n' "Total time taken: ${TOTAL_TIME} seconds"
