#!/bin/bash
export PYTHONPATH="./src:$PYTHONPATH"

# Multi-round LoRA fine-tuning driver.
# Round 1 fine-tunes the base model. Every later round first runs inference
# with the previous round's adapter to write an emulator-weight file, then
# resumes fine-tuning with that file passed as the sample-weight path.

# Parameters
PROMPT_TEMPLATE="nl" # Options: "ascii", "nl"
MODEL_NAME="meta-llama/Meta-Llama-3-8B"
DATASET_PATH="./data/xlogomini-dataset-train.json"
NUM_EPOCHS=8    # Total number of fine-tuning epochs across all rounds
RESAMPLE_FREQ=3 # Number of epochs per round, i.e. how often weights are resampled
LORA_R=32
LORA_ALPHA=128
BATCHING_STRATEGY="padding"
SEED=42

source activate llama_recipes # Replace `llama_recipes` with the name of the environment installed with `requirements_ft.txt`

#######################################
# Derive the round schedule from the epoch budget.
# Globals:
#   NUM_EPOCHS, RESAMPLE_FREQ (read)
#   TOTAL_ROUNDS, LAST_ROUND_EPOCHS (written)
# Returns:
#   0 on success, 1 if either input is not a positive integer.
#######################################
compute_round_schedule() {
  local positive_int='^[1-9][0-9]*$'
  if ! [[ "${NUM_EPOCHS}" =~ $positive_int && "${RESAMPLE_FREQ}" =~ $positive_int ]]; then
    echo "NUM_EPOCHS and RESAMPLE_FREQ must be positive integers (got '${NUM_EPOCHS}' and '${RESAMPLE_FREQ}')." >&2
    return 1
  fi

  # Ceiling division: a plain "quotient + 1" schedules an extra round with
  # zero epochs whenever NUM_EPOCHS is an exact multiple of RESAMPLE_FREQ.
  TOTAL_ROUNDS=$(((NUM_EPOCHS + RESAMPLE_FREQ - 1) / RESAMPLE_FREQ))

  # Whatever is left after all the full-length rounds; always between 1 and
  # RESAMPLE_FREQ inclusive.
  LAST_ROUND_EPOCHS=$((NUM_EPOCHS - (TOTAL_ROUNDS - 1) * RESAMPLE_FREQ))
}

# Calculate total rounds, including the initial fine-tuning round
compute_round_schedule || exit 1

# Initial fine-tuning (round 1): trains the LoRA adapter from the base model.
# Clamp to the total budget so a NUM_EPOCHS smaller than RESAMPLE_FREQ does not
# train more epochs than requested.
EPOCHS=$((NUM_EPOCHS < RESAMPLE_FREQ ? NUM_EPOCHS : RESAMPLE_FREQ))
OUTPUT_DIR="./results/checkpoints/${PROMPT_TEMPLATE}/Meta-Llama-3-8B-Emu/R1"

echo "Starting initial fine-tuning for ${EPOCHS} epochs..."
# Abort on failure: the later rounds read this round's checkpoint, so carrying
# on after a failed run would only fail later with a less obvious error.
python -m torch.distributed.launch \
  --nnodes 1 \
  --nproc_per_node 2 \
  src/xlogominiprog/finetuning.py \
  --dataset "custom_dataset" \
  --custom_dataset.file "./src/xlogominiprog/custom_dataset.py" \
  --custom_dataset.prompt_template "${PROMPT_TEMPLATE}" \
  --model_name "${MODEL_NAME}" \
  --use_peft \
  --peft_method lora \
  --enable_fsdp \
  --fsdp_config.pure_bf16 \
  --output_dir "${OUTPUT_DIR}" \
  --use_fast_kernels \
  --train_config.num_epochs "${EPOCHS}" \
  --lora_config.r "${LORA_R}" \
  --lora_config.lora_alpha "${LORA_ALPHA}" \
  --train_config.batching_strategy "${BATCHING_STRATEGY}" \
  --train_config.seed "${SEED}" \
  || { echo "Initial fine-tuning failed; aborting." >&2; exit 1; }
echo "Initial fine-tuning completed."

# Resampling rounds: ROUND indexes the round whose checkpoint already exists;
# each iteration produces round ROUND + 1.
# An arithmetic loop is used instead of `seq` because BSD/macOS `seq 1 0`
# counts down (1, 0) rather than printing nothing when no rounds remain.
for ((ROUND = 1; ROUND < TOTAL_ROUNDS; ROUND++)); do
  # On entry EPOCHS still holds the epoch count of the round that wrote the
  # previous checkpoint (the initial round, or the previous iteration).
  PREV_EPOCHS=$EPOCHS

  if ((ROUND == TOTAL_ROUNDS - 1)); then
    EPOCHS=$LAST_ROUND_EPOCHS
  else
    EPOCHS=$RESAMPLE_FREQ
  fi

  # A round with no epochs left has nothing to train; skip its inference too.
  if ((EPOCHS <= 0)); then
    break
  fi

  PREV_OUTPUT_DIR="./results/checkpoints/${PROMPT_TEMPLATE}/Meta-Llama-3-8B-Emu/R${ROUND}"
  NEW_OUTPUT_DIR="./results/checkpoints/${PROMPT_TEMPLATE}/Meta-Llama-3-8B-Emu/R$((ROUND + 1))"

  # The adapter to resume from is the previous round's last epoch, so the
  # suffix must be that round's epoch count, not the current round's (they
  # differ in the shorter final round).
  # NOTE(review): the `epoch_<n>` directory naming is inferred from how round
  # 1's output is consumed here -- confirm against finetuning.py.
  PREV_CHECKPOINT="${PREV_OUTPUT_DIR}/epoch_${PREV_EPOCHS}"
  EMU_WEIGHT_PATH="./results/emu_weight/${PROMPT_TEMPLATE}/Meta-Llama-3-8B-Emu/r${ROUND}.json"

  # Inference and calculate emulator weights
  echo "Starting inference and emulator weight calculation for Round ${ROUND}..."
  source activate vllmenv

  python src/xlogominiprog/inference_ft_vllm.py \
    --model_name "${MODEL_NAME}" \
    --peft_model "${PREV_CHECKPOINT}" \
    --top_p 1 \
    --temperature 0 \
    --dataset_path "${DATASET_PATH}" \
    --use_emulator_sample \
    --emulator_weight_save_path "${EMU_WEIGHT_PATH}" \
    || { echo "Inference for Round ${ROUND} failed; aborting." >&2; exit 1; }
  echo "Inference and emulator weight calculation for Round ${ROUND} completed."

  # Next round fine-tuning
  echo "Starting fine-tuning for Round $((ROUND + 1)) for ${EPOCHS} epochs..."
  source activate llama_recipes

  python -m torch.distributed.launch \
    --nnodes 1 \
    --nproc_per_node 2 \
    src/xlogominiprog/finetuning.py \
    --dataset "custom_dataset" \
    --custom_dataset.file "./src/xlogominiprog/custom_dataset.py" \
    --custom_dataset.prompt_template "${PROMPT_TEMPLATE}" \
    --model_name "${MODEL_NAME}" \
    --use_peft \
    --peft_method lora \
    --enable_fsdp \
    --fsdp_config.pure_bf16 \
    --output_dir "${NEW_OUTPUT_DIR}" \
    --use_fast_kernels \
    --train_config.num_epochs "${EPOCHS}" \
    --lora_config.r "${LORA_R}" \
    --lora_config.lora_alpha "${LORA_ALPHA}" \
    --train_config.batching_strategy "${BATCHING_STRATEGY}" \
    --train_config.seed "${SEED}" \
    --train_config.peft_path "${PREV_CHECKPOINT}" \
    --train_config.sample_weight_path "${EMU_WEIGHT_PATH}" \
    || { echo "Fine-tuning for Round $((ROUND + 1)) failed; aborting." >&2; exit 1; }
  echo "Fine-tuning for Round $((ROUND + 1)) completed."
done

echo "All rounds completed."
