#!/bin/bash
#
# Debug run of the Omni-MATH pipeline over problem indices START_IDX..END_IDX:
# solve -> evaluate -> make submission -> gather results, followed by a
# solve-rate summary.
#
# Every setting below may be overridden from the environment, e.g.
#   START_IDX=0 END_IDX=9 N_JOBS=8 ./this_script.sh
# When nothing is overridden, the values are exactly the defaults shown here.

# set -e
# NOTE(review): strict mode is left disabled, presumably so that a stage in
# which some parallel jobs fail does not abort the remaining stages — confirm.

# NOTE(review): the plural option names these feed (--answer_models,
# --answer_model_probs, --answer_temperatures) suggest they may hold
# space-separated lists — confirm against run_omni_math.py.
ANSWER_MODELS="${ANSWER_MODELS:-gpt-4o-2024-08-06}"
MODEL_PROBS="${MODEL_PROBS:-1}"

TEMPERATURE="${TEMPERATURE:-0.6}"
NUM_SIMULATIONS="${NUM_SIMULATIONS:-5}"
NUM_EXPAND_SAMPLES="${NUM_EXPAND_SAMPLES:-1}"
INITIAL_EXPAND_SAMPLES="${INITIAL_EXPAND_SAMPLES:-1}"
REWARD_MODEL_NAME="${REWARD_MODEL_NAME:-Qwen/Qwen2.5-Math-RM-72B}"
N_JOBS="${N_JOBS:-64}"  # number of jobs to run concurrently
MCTS_ALGO="${MCTS_ALGO:-standard}"
# Inclusive problem index range; the defaults 0..127 sample 128 problems.
START_IDX="${START_IDX:-0}"
END_IDX="${END_IDX:-127}"

# The index range feeds `seq` and shell arithmetic (including a division by
# the range size at the end), so require plain non-negative integers without
# leading zeros (bash arithmetic would read e.g. "08" as invalid octal) and
# END_IDX >= START_IDX.
if ! [[ "$START_IDX" =~ ^(0|[1-9][0-9]*)$ && "$END_IDX" =~ ^(0|[1-9][0-9]*)$ ]] \
    || (( END_IDX < START_IDX )); then
  echo "Error: need integers 0 <= START_IDX <= END_IDX (got START_IDX=$START_IDX, END_IDX=$END_IDX)" >&2
  exit 2
fi

PROMPT_TYPE="${PROMPT_TYPE:-omni_math_kou_v1}"


# Output directory name; derived from ANSWER_MODELS unless overridden.
EXPERIMENT_NAME="${EXPERIMENT_NAME:-debug_omni_math_${ANSWER_MODELS}}"

# Stage: solve. Runs run_omni_math.py once per problem index in
# START_IDX..END_IDX (at most N_JOBS concurrently), then prints the stage's
# wall-clock time in minutes.
# NOTE: without -q, GNU parallel joins these arguments into one command line
# that a shell re-parses, so quoting here cannot protect values containing
# spaces. ANSWER_MODELS / MODEL_PROBS / TEMPERATURE are left unquoted on
# purpose in case they hold space-separated lists.
start_time=$(date +%s)
run_status=0
seq "$START_IDX" "$END_IDX" | parallel -j "$N_JOBS" \
    python scripts/omni_math/run_omni_math.py \
    --experiment_name "$EXPERIMENT_NAME" \
    --idx {} \
    --judge_model KbsdJames/Omni-Judge \
    --answer_models $ANSWER_MODELS \
    --answer_model_probs $MODEL_PROBS \
    --answer_temperatures $TEMPERATURE \
    --reward_model_name "$REWARD_MODEL_NAME" \
    --only_reward_model \
    --is_sigmoid \
    --num_simulations "$NUM_SIMULATIONS" \
    --num_expand_samples "$NUM_EXPAND_SAMPLES" \
    --initial_expand_samples "$INITIAL_EXPAND_SAMPLES" \
    --initial_prompt_type "$PROMPT_TYPE" \
    --mcts_algo "$MCTS_ALGO" \
    || run_status=$?
end_time=$(date +%s)
elapsed_time=$((end_time - start_time))
elapsed_time=$(echo "scale=2; $elapsed_time / 60" | bc)
echo "Elapsed time (run): $elapsed_time minutes"
# GNU parallel exits with the number of failed jobs (1-100), 101 if more than
# 100 failed, or 255 on other errors. Report it instead of silently ignoring
# it, but keep going so the later stages still run.
if (( run_status != 0 )); then
  echo "Warning (run): parallel exited with status $run_status; some jobs failed" >&2
fi

# Stage: evaluate. Runs evaluate_omni_math.py once per problem index in
# START_IDX..END_IDX (at most N_JOBS concurrently), then prints the stage's
# wall-clock time in minutes.
start_time=$(date +%s)
eval_status=0
seq "$START_IDX" "$END_IDX" | parallel -j "$N_JOBS" \
    python scripts/omni_math/evaluate_omni_math.py \
    --experiment_name "$EXPERIMENT_NAME" \
    --idx {} \
    --judge_model KbsdJames/Omni-Judge \
    --reward_model_name "$REWARD_MODEL_NAME" \
    --only_reward_model \
    --is_sigmoid \
    || eval_status=$?
end_time=$(date +%s)
elapsed_time=$((end_time - start_time))
elapsed_time=$(echo "scale=2; $elapsed_time / 60" | bc)
echo "Elapsed time (evaluate): $elapsed_time minutes"
# GNU parallel exits with the number of failed jobs (1-100), 101 if more than
# 100 failed, or 255 on other errors. Report it instead of silently ignoring
# it, but keep going so the later stages still run.
if (( eval_status != 0 )); then
  echo "Warning (evaluate): parallel exited with status $eval_status; some jobs failed" >&2
fi

# Stage: make submission. Runs make_submission_omni_math.py once per problem
# index in START_IDX..END_IDX (at most N_JOBS concurrently), then prints the
# stage's wall-clock time in minutes.
start_time=$(date +%s)
submit_status=0
seq "$START_IDX" "$END_IDX" | parallel -j "$N_JOBS" \
    python scripts/omni_math/make_submission_omni_math.py \
    --experiment_name "$EXPERIMENT_NAME" \
    --idx {} \
    --judge_model KbsdJames/Omni-Judge \
    --reward_model_name "$REWARD_MODEL_NAME" \
    --only_reward_model \
    --is_sigmoid \
    || submit_status=$?
end_time=$(date +%s)
elapsed_time=$((end_time - start_time))
elapsed_time=$(echo "scale=2; $elapsed_time / 60" | bc)
echo "Elapsed time (make submission): $elapsed_time minutes"
# GNU parallel exits with the number of failed jobs (1-100), 101 if more than
# 100 failed, or 255 on other errors. Report it instead of silently ignoring
# it, but keep going so the results are still gathered.
if (( submit_status != 0 )); then
  echo "Warning (make submission): parallel exited with status $submit_status; some jobs failed" >&2
fi

# Stage: gather per-problem results, then report how many problems were
# solved, i.e. records in prediction.jsonl whose test_score == 1.
if ! python scripts/omni_math/gather_results_omni_math.py \
    --experiment_name "$EXPERIMENT_NAME"; then
  echo "Warning: gather_results_omni_math.py failed; prediction.jsonl may be missing or stale" >&2
fi

TOTAL_PROBLEMS=$((END_IDX - START_IDX + 1))
PREDICTION_FILE="logging/omni_math/${EXPERIMENT_NAME}/prediction.jsonl"
# Without these checks, a missing or unparsable file leaves SOLVED_PROBLEMS
# empty, and an empty range makes the divisor 0. In either case bc prints an
# error instead of a percentage.
if (( TOTAL_PROBLEMS <= 0 )); then
  echo "Error: empty problem range (START_IDX=$START_IDX, END_IDX=$END_IDX)" >&2
  exit 1
fi
if [[ ! -f "$PREDICTION_FILE" ]]; then
  echo "Error: $PREDICTION_FILE not found; cannot compute the solve rate" >&2
  exit 1
fi
# -s slurps every JSON line into one array before filtering.
if ! SOLVED_PROBLEMS=$(jq -s '[.[] | select(.test_score==1)] | length ' "$PREDICTION_FILE"); then
  echo "Error: jq failed to parse $PREDICTION_FILE" >&2
  exit 1
fi
echo "Solved $SOLVED_PROBLEMS out of $TOTAL_PROBLEMS problems ($(echo "scale=2; 100 * $SOLVED_PROBLEMS / $TOTAL_PROBLEMS" | bc)%)"

