#!/usr/bin/env bash
set -euo pipefail

# ===== User-configurable =====
# Every <...> token below is a template placeholder, not valid shell syntax.
# Replace each one with a real value before running; left as-is, bash aborts
# with a syntax error at the first placeholder it reaches.

# Keep the caller's CUDA_VISIBLE_DEVICES when it is set and non-empty; otherwise use device 0.
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0}

# `model` is only echoed in the info banner by this script; N1 and K are
# forwarded to training as --n1 / --k, and N1 is reused as the sample count
# for generation (--n2, --n) and for merging (--n).
model=<model_name>                            # e.g., qwen2.5-math-1.5b-instruct
N1=<n1_for_exploration>                       # e.g., 16
K=<top_k_for_calibration>                     # e.g., 4
# Training settings, forwarded verbatim to joint_train_delta_temp.py as
# --max_tokens, --calib_epochs, --calib_batch_size, --calib_lr, --init_temp
# and --weight_decay respectively.
MAX_TOKENS=1024
CALIB_EPOCHS=100
CALIB_BATCH_SIZE=32
CALIB_LR=0.005
INIT_TEMP=0.8
WEIGHT_DECAY=1e-2

# Input locations and the directory that receives every artifact and log
# produced by this script (delta/temperature/bias files, logs, completions).
MODEL_PATH=<path_to_your_model>               # e.g., /path/to/Models/${model}
CONFIG_PATH=<path_to_your_config_yaml>        # e.g., recipes/${model}/calibration.yaml
INPUT_DATASET_PATH=<path_to_calibration_dataset_jsonl>
OUTPUT_DIR=<path_to_output_dir>

# Optional ablations (uncomment at most one; with both left commented, no ablation flag is passed to training)
# ABLATE_TEMPERATURE_FLAG="--ablate_temperature=True"   # learn delta only
# ABLATE_DELTA_FLAG="--ablate_delta=True"               # learn temperature only

# ===== Derived paths =====
# Every artifact lives directly under OUTPUT_DIR: create the directory
# (parents included, no error if it already exists), then name the files.
mkdir -p "$OUTPUT_DIR"
BIAS_PATH="$OUTPUT_DIR/bias.npz"
OUTPUT_TEMPERATURE_PATH="$OUTPUT_DIR/temperature.npz"
OUTPUT_DELTA_PATH="$OUTPUT_DIR/delta.npz"

# ===== Info =====
# Print the run configuration up front so it is captured in the console
# output before any long-running step starts. One line per argument.
printf '%s\n' \
  "=== Joint Training Delta and Temperature (for beam search) ===" \
  "Model: ${model}" \
  "Input dataset: ${INPUT_DATASET_PATH}" \
  "Output delta: ${OUTPUT_DELTA_PATH}" \
  "Output temperature: ${OUTPUT_TEMPERATURE_PATH}" \
  "N1: ${N1}, K: ${K}" \
  "Epochs: ${CALIB_EPOCHS}, LR: ${CALIB_LR}, Init Temp: ${INIT_TEMP}, Weight Decay: ${WEIGHT_DECAY}" \
  "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" \
  "========================================"

# ===== Joint training =====
# Runs joint_train_delta_temp.py with the calibration settings from the
# user-configurable section. The script is told where to write the delta and
# temperature files; its combined stdout/stderr is shown on the console and
# appended to ${OUTPUT_DIR}/joint_train.log.

# The ablation switches are documented as mutually exclusive, so refuse to
# start a long training run when both are set.
if [[ -n "${ABLATE_TEMPERATURE_FLAG:-}" && -n "${ABLATE_DELTA_FLAG:-}" ]]; then
  echo "Error: set at most one of ABLATE_TEMPERATURE_FLAG and ABLATE_DELTA_FLAG." >&2
  exit 1
fi

# Collect the optional flags in an array: an unset or empty flag contributes
# no argument at all, and a set flag is passed as exactly one word (no
# word-splitting or pathname expansion on its value).
ablation_args=()
if [[ -n "${ABLATE_TEMPERATURE_FLAG:-}" ]]; then
  ablation_args+=("${ABLATE_TEMPERATURE_FLAG}")
fi
if [[ -n "${ABLATE_DELTA_FLAG:-}" ]]; then
  ablation_args+=("${ABLATE_DELTA_FLAG}")
fi

echo "=== Starting Joint Training ==="
# The ${arr[@]+...} form expands to nothing for an empty array without
# tripping `set -u` on older bash releases.
python joint_train_delta_temp.py "${CONFIG_PATH}" \
  --model_path="${MODEL_PATH}" \
  --input_dataset_path="${INPUT_DATASET_PATH}" \
  --output_delta_path="${OUTPUT_DELTA_PATH}" \
  --output_temperature_path="${OUTPUT_TEMPERATURE_PATH}" \
  --calib_epochs="${CALIB_EPOCHS}" \
  --calib_batch_size="${CALIB_BATCH_SIZE}" \
  --calib_lr="${CALIB_LR}" \
  --init_temp="${INIT_TEMP}" \
  --weight_decay="${WEIGHT_DECAY}" \
  --max_tokens="${MAX_TOKENS}" \
  --n1="${N1}" \
  --k="${K}" \
  --gradient_accumulation=True \
  --log_epoch_loss=True \
  ${ablation_args[@]+"${ablation_args[@]}"} \
  2>&1 | tee -a "${OUTPUT_DIR}/joint_train.log"

# ===== Convert delta to bias (if your generation path expects bias.npz) =====
# Runs tools/delta_to_bias.py on the delta file written by the training step.
echo "=== Converting Delta to Bias ==="
python tools/delta_to_bias.py \
  --model_path "${MODEL_PATH}" \
  --delta_path "${OUTPUT_DELTA_PATH}"

# NOTE(review): the converter is not told where to write its output, while the
# generation step reads the bias from ${BIAS_PATH}. Presumably the converter
# writes bias.npz next to the delta file -- confirm against delta_to_bias.py.
# Surface a mismatch here rather than deep inside generation; this is only a
# warning so a generation path that tolerates a missing file keeps working.
if [[ ! -f "${BIAS_PATH}" ]]; then
  echo "Warning: expected bias file not found at ${BIAS_PATH} after conversion." >&2
fi

# ===== Beam-search generation =====
# Runs generate_with_temperature_and_bias_beam.py with the bias and
# temperature files; both sample counts are taken from N1.
echo "=== Generating Completions (Beam Search) ==="
N2="${N1}"
N="${N2}"

# Arguments are gathered in an array purely for readability; each element is
# passed to the script as one word.
beam_gen_args=(
  "${CONFIG_PATH}"
  "--model_path=${MODEL_PATH}"
  "--bias_file_path=${BIAS_PATH}"
  "--temperature_file_path=${OUTPUT_TEMPERATURE_PATH}"
  "--calib_output_path=${OUTPUT_DIR}"
  "--n2=${N2}"
  "--n=${N}"
)

# stdout and stderr are merged, shown on the console and written to the log
# (the log is overwritten on each run, not appended).
python generate_with_temperature_and_bias_beam.py "${beam_gen_args[@]}" 2>&1 \
  | tee "${OUTPUT_DIR}/generate_with_delta_and_temp_beam.log"

# ===== Accuracy on generated set =====
# Scores the completions file under OUTPUT_DIR with compute_accuracy.py and
# stores the result next to it as accuracy.json.
echo "=== Computing Accuracy (Generated) ==="
GEN_OUTPUT_FILE_PATH="$OUTPUT_DIR/accuracy.json"
GEN_INPUT_FILE_PATH="$OUTPUT_DIR/calibration_completions.jsonl"
python compute_accuracy.py \
  "$GEN_INPUT_FILE_PATH" \
  --output "$GEN_OUTPUT_FILE_PATH"

# ===== Merge with original best-of-N and evaluate =====
echo "=== Merging Completions ==="
# "before" is the original calibration dataset; "after" is the completions
# file under OUTPUT_DIR (the same path the generated-accuracy step scores).
merge_before_file="${INPUT_DATASET_PATH}"
merge_after_file="${OUTPUT_DIR}/calibration_completions.jsonl"
# Template placeholder: replace with the real destination path for the merged
# completions before running (left as-is, this line is a bash syntax error).
merge_output_file=<path_to_output_merged_completions_jsonl>

# Run tools/merge_completions.py on the before/after files, passing N1 as --n
# and selecting the "last" aggregation strategy; the merged set is written to
# ${merge_output_file}.
merge_cli_args=(
  --before_file "${merge_before_file}"
  --after_file "${merge_after_file}"
  --n "${N1}"
  --output_file "${merge_output_file}"
  --aggregation_strategy last
)
python tools/merge_completions.py "${merge_cli_args[@]}"

# Evaluate merged results
# Scores the merged completions file with the same compute_accuracy.py script
# used for the generated set.
echo "=== Computing Merge Completions Accuracy ==="
MERGED_INPUT_FILE_PATH="${merge_output_file}"
# Template placeholder: replace with the real destination path for the merged
# accuracy JSON before running (left as-is, this line is a bash syntax error).
MERGED_OUTPUT_FILE_PATH=<path_to_output_merged_results_accuracy_json>
python compute_accuracy.py "${MERGED_INPUT_FILE_PATH}" --output "${MERGED_OUTPUT_FILE_PATH}"

# Closing summary: list where each artifact of this run is expected to be,
# one line per argument.
printf '%s\n' \
  "=== Done (Beam) ===" \
  "Delta saved to: ${OUTPUT_DELTA_PATH}" \
  "Temperature saved to: ${OUTPUT_TEMPERATURE_PATH}" \
  "Bias saved to: ${BIAS_PATH}" \
  "Generated accuracy: ${GEN_OUTPUT_FILE_PATH}" \
  "Merged accuracy: ${MERGED_OUTPUT_FILE_PATH}"
