#!/bin/bash

# Positional arguments:
#   $1  answer model(s), forwarded to the run script as --answer_models
#   $2  MCTS algorithm, e.g. "pymc-thompson", "thompson-beta",
#       "thompson-gaussian-tau-square-01"
#   $3  simulation count (reduced by one below)
#   $4  temperature, forwarded as --temperature
#   $5  value forwarded as --answer_model_probs
#   $6  optional suffix for the experiment name
#   $7  optional flag; any non-empty value selects the resume branch
ANSWER_MODELS="$1"
MCTS_ALGO="$2"
# One is subtracted because of the initial node.
NUM_SIMULATIONS=$(( $3 - 1 ))
TEMPERATURE="$4"
MODEL_PROBS="$5"
# A non-empty suffix gets a leading underscore; an empty one stays empty.
SUFFIX="${6:+_$6}"
IS_FAIL="$7"

# Fixed settings shared by every stage of the pipeline.
PROMPT_TYPE='live_code_bench_code_generation_v10_single_turn'
RELEASE_VERSION='release_v4'
N_JOBS=16

# The experiment name encodes every tunable so that runs do not collide.
EXPERIMENT_NAME="live_code_bench_${MCTS_ALGO}_${ANSWER_MODELS}_temp${TEMPERATURE}_prob${MODEL_PROBS}_nsim${NUM_SIMULATIONS}${SUFFIX}"

# Select the file listing the task indices to process.
#   - IS_FAIL non-empty: resume from the per-experiment indices file, which
#     must already exist; otherwise abort with status 1.
#   - IS_FAIL empty: use the full release indices file.
# The path is quoted everywhere because EXPERIMENT_NAME is built from
# user-supplied arguments and may contain whitespace.
if [[ -n "$IS_FAIL" ]]; then
    INDICES_FILE="llm_mcts/tasks/live_code_bench_code_generation/${EXPERIMENT_NAME}.txt"
    if [[ ! -f "$INDICES_FILE" ]]; then
        # Diagnostics go to stderr so they are not mixed into captured stdout.
        echo "IS_FAIL is true, but INDICES_FILE: $INDICES_FILE does not exist" >&2
        exit 1
    fi
    echo "IS_FAIL is true, resume the experiment with INDICES_FILE: $INDICES_FILE"
else
    INDICES_FILE="llm_mcts/tasks/live_code_bench_code_generation/release_v4_202408_202411_indeces.txt"
fi


#######################################
# Run one pipeline stage over every index in INDICES_FILE and record how
# long it took.
# Each line of INDICES_FILE is fed on stdin to `parallel`, which substitutes
# it for {} and launches up to N_JOBS concurrent `python SCRIPT ...` jobs.
# Globals:   INDICES_FILE, N_JOBS, EXPERIMENT_NAME, RELEASE_VERSION,
#            PYTHONPATH (all read)
# Arguments: $1 - stage label used in the log line
#            $2 - path of the Python script to run
#            $@ - extra arguments appended after --release_version
# Outputs:   appends "Elapsed time (<label>): <minutes> minutes" to the
#            experiment's log.txt; warns on stderr if the stage failed
# Returns:   0; a failing stage does not stop the following stages
#######################################
run_stage() {
    local label=$1 script=$2
    shift 2
    local log_dir="logging/live_code_bench_code_generation/$EXPERIMENT_NAME"
    local start_time end_time elapsed_time status

    start_time=$(date +%s)
    # Redirect the indices file straight into parallel; the quoted path keeps
    # working when the experiment name contains whitespace.
    PYTHONPATH=".:$PYTHONPATH" parallel -j "$N_JOBS" \
        python "$script" \
        --experiment_name "$EXPERIMENT_NAME" \
        --idx {} \
        --release_version "$RELEASE_VERSION" \
        "$@" \
        < "$INDICES_FILE"
    status=$?
    end_time=$(date +%s)

    if (( status != 0 )); then
        # Keep going as before, but make the failure visible.
        echo "Warning: stage '$label' finished with exit status $status" >&2
    fi

    # Convert whole seconds to minutes with two decimals.
    elapsed_time=$((end_time - start_time))
    elapsed_time=$(echo "scale=2; $elapsed_time / 60" | bc)

    # Create the log directory if needed so the timing line is never lost
    # (presumably the Python scripts normally create it - confirm).
    mkdir -p -- "$log_dir"
    echo "Elapsed time ($label): $elapsed_time minutes" >> "$log_dir/log.txt"
}

# Stage 1: generate answers with the configured MCTS settings.
run_stage "run" \
    scripts/live_code_bench_code_generation/run_live_code_bench_code_generation.py \
    --answer_models "$ANSWER_MODELS" \
    --answer_model_probs "$MODEL_PROBS" \
    --temperature "$TEMPERATURE" \
    --mcts_algo "$MCTS_ALGO" \
    --num_simulations "$NUM_SIMULATIONS" \
    --initial_prompt_type "$PROMPT_TYPE"

# Stage 2: evaluate the generated answers.
run_stage "evaluate" \
    scripts/live_code_bench_code_generation/evaluate_live_code_bench_code_generation.py

# Stage 3: build the submission for each index.
run_stage "make submission" \
    scripts/live_code_bench_code_generation/make_submission_live_code_bench_code_generation.py

# Gather the per-index results of this experiment.
python scripts/live_code_bench_code_generation/gather_results_live_code_bench_code_generation.py \
    --experiment_name "$EXPERIMENT_NAME"

# The summary is always reported against the full release indices file, even
# when the stages above ran on a resume subset.
INDICES_FILE="llm_mcts/tasks/live_code_bench_code_generation/release_v4_202408_202411_indeces.txt"
summary_dir="logging/live_code_bench_code_generation/$EXPERIMENT_NAME"

# Count non-blank lines rather than newline characters, so the total is the
# same whether or not the file ends with a trailing newline.
TOTAL_PROBLEMS=$(grep -c '[^[:space:]]' "$INDICES_FILE")
TOTAL_PROBLEMS=${TOTAL_PROBLEMS:-0}

# Number of records in prediction.jsonl whose .private_tests.score equals 1.
SOLVED_PROBLEMS=$(jq -s '[.[] | select(.private_tests.score==1)] | length ' "$summary_dir/prediction.jsonl")

if [[ ! "$SOLVED_PROBLEMS" =~ ^[0-9]+$ ]]; then
    # jq failed or printed something unexpected; do not log a broken line.
    echo "Could not count solved problems from $summary_dir/prediction.jsonl" >&2
elif (( TOTAL_PROBLEMS == 0 )); then
    # Avoid a division by zero in the percentage below.
    echo "No problem indices found in $INDICES_FILE" >&2
else
    echo "Solved $SOLVED_PROBLEMS out of $TOTAL_PROBLEMS problems ($(echo "scale=2; 100 * $SOLVED_PROBLEMS / $TOTAL_PROBLEMS" | bc)%)" >> "$summary_dir/log.txt"
fi
