#!/bin/bash
# Runs test-case generation or code generation (chosen by TASK below).
# Usage: bash <this script> [GPU_ID (default 1)] MODEL_NAME
##################################################################################################
### **Default parameters**
# Datasets to run; uncomment an entry to add it to the sweep.
DATASETS=(
  "total"
  #"humaneval"
  #"mbpp"
)

# Generation hyperparameters, forwarded unchanged as CLI flags to TG_CG_main.py.
BATCH_SIZE=4      # --batch_size
NUM_SAMPLES=1     # --num_samples
MAX_LENGTH=4096   # --max_length
USE_QLORA=False   # --use_qlora (passed as the literal string "False")
##################################################################################################
### **Custom parameters**
# GPU to run on: first CLI argument, "1" when omitted. Exported so the python
# child process inherits it.
CUDA_VISIBLE_DEVICES=${1:-1}
export CUDA_VISIBLE_DEVICES
printf 'CUDA_VISIBLE_DEVICES=%s\n' "$CUDA_VISIBLE_DEVICES"

# Task to execute. Options: 'test_case_generation', 'code_generation'
TASK='code_generation'

# Models to run. The second CLI argument is intentionally left unquoted, so it is
# word-split (and glob-expanded): several space-separated model names may be
# passed in one argument. Uncomment entries below to add fixed models.
# shellcheck disable=SC2206
MODEL_NAMES=(
    ${2}
    #"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
    #"Qwen/Qwen3-14B"
    #"microsoft/Phi-4-reasoning-plus"
    #"meta-llama/Llama-3.1-8B-Instruct"

    #"microsoft/Phi-4-reasoning"
    #"mistralai/Mistral-Nemo-Base-2407"
)

# Mode used to filter the test cases; only the SFT and RL models are used for
# test case generation. Options: "SFT", "RL"
CODE_GEN_MODE="SFT"

# Whether to skip test cases that are already completed. Options: "True", "False"
SKIP_COMPLETED=True

# Whether to sample the test cases. Options: "True", "False"
SAMPLE_N="False"
##################################################################################################
# Select the prompt instructions for the chosen TASK. Uncomment entries to add
# more; every selected instruction is run for every dataset and model.
case "$TASK" in
  test_case_generation)
    USE_INSTRUCTIONS=(
      #"FUNCTIONALITY_SPECIFICATION" # Directly generate test cases

      #"ASSERT_SPECIFICATION" # Directly generate test cases

      "MULTI_ASSERT_SPECIFICATION" # Directly generate test cases
      #"MULTI_ASSERT_SPECIFICATION_humaneval" # One shot
      #"MULTI_ASSERT_SPECIFICATION_mbpp" # One shot

      #"GRAMMAR_ASSERT_SPECIFICATION" # Generate test cases based on grammar specification
      #"GRAMMAR_ASSERT_SPECIFICATION_humaneval" # One shot
      #"GRAMMAR_ASSERT_SPECIFICATION_mbpp" # One shot
    )
    ;;
  code_generation)
    USE_INSTRUCTIONS=(
      #"CODE_GENERATION_CS"
      "CODE_GENERATION_CT"
      # "MAKE_CODE_BLOCK_FS_CS" # Function Specification + Contracts Specification
      # "MAKE_CODE_BLOCK_FT_CT" # Function Specification + Contracts Test Cases
      # "CODE_REFINEMENT_WITH_INSTRUCTIONS_FC_CS" # Function Code + Contracts Specification
      # "CODE_REFINEMENT_WITH_INSTRUCTIONS_FC_CT" # Function Code + Contracts Test Case
    )
    ;;
  *)
    # Fail fast: otherwise USE_INSTRUCTIONS stays unset and the script would
    # finish without running anything.
    echo "Unknown TASK: '$TASK' (expected 'test_case_generation' or 'code_generation')" >&2
    exit 1
    ;;
esac

# Top-level log directory (no-op if it already exists).
mkdir -p log

# Fail fast instead of launching zero jobs. MODEL_NAMES is empty when no model
# was passed as the 2nd argument and none is uncommented. USE_INSTRUCTIONS is
# empty/unset when TASK matched no known task.
if [ "${#MODEL_NAMES[@]}" -eq 0 ]; then
  echo "Usage: $0 GPU_ID MODEL_NAME  (no model given; pass one as the 2nd argument)" >&2
  exit 1
fi
if [ "${#USE_INSTRUCTIONS[@]}" -eq 0 ]; then
  echo "No instructions selected for TASK='$TASK'" >&2
  exit 1
fi

# Sweep every dataset x instruction x model combination; one log file per run.
for DATASET in "${DATASETS[@]}"; do
  echo "==============================="
  echo "Using dataset: $DATASET"
  for USE_INSTRUCTION in "${USE_INSTRUCTIONS[@]}"; do
    echo "==============================="
    echo "Using instruction: $USE_INSTRUCTION"

    # Only these three exact instruction names get MAX_NEW_TOKENS=0, which
    # (per the original author) means max_new_tokens = max_length - input_length,
    # resolved from the input length. Everything else, including the
    # *_humaneval / *_mbpp one-shot variants, gets a fixed 1024.
    # NOTE(review): confirm the one-shot variants are meant to use 1024.
    case "$USE_INSTRUCTION" in
      ASSERT_SPECIFICATION | MULTI_ASSERT_SPECIFICATION | GRAMMAR_ASSERT_SPECIFICATION)
        MAX_NEW_TOKENS=0 # previously 2048
        ;;
      *)
        MAX_NEW_TOKENS=1024
        ;;
    esac

    for MODEL_NAME in "${MODEL_NAMES[@]}"; do
      MODEL_SHORT_NAME=$(basename "$MODEL_NAME")
      echo "Running model: $MODEL_NAME"
      LOG_DIR="log/${DATASET}_${MODEL_SHORT_NAME}"
      mkdir -p "$LOG_DIR"
      # HF_HOME applies only to this python process. Only stdout is redirected
      # to the log file; stderr is not captured. A failing run is reported on
      # stderr but does not stop the remaining runs.
      HF_HOME=/scratch/greghahn/huggingface python ../../code/TG_CG_main.py \
        --dataset "$DATASET" \
        --model_name "$MODEL_NAME" \
        --batch_size "$BATCH_SIZE" \
        --num_samples "$NUM_SAMPLES" \
        --max_length "$MAX_LENGTH" \
        --max_new_tokens "$MAX_NEW_TOKENS" \
        --use_instruction "$USE_INSTRUCTION" \
        --use_qlora "$USE_QLORA" \
        --skip_completed "$SKIP_COMPLETED" \
        --code_gen_mode "$CODE_GEN_MODE" \
        --sample_n "$SAMPLE_N" \
        > "$LOG_DIR/${USE_INSTRUCTION}.log" ||
        echo "WARNING: run failed (exit $?): model=$MODEL_NAME dataset=$DATASET instruction=$USE_INSTRUCTION" >&2
    done
  done
done
