#!/bin/bash
#
# Run a sequential-refinement MCTS experiment on the code_contest test split,
# then evaluate it, build submissions, gather results and log the solve rate.
#
# Usage:
#   <script> ANSWER_MODELS NUM_SIMULATIONS TEMPERATURE MODEL_PROBS [SUFFIX] [IS_FAIL]
#
#   ANSWER_MODELS    passed to --answer_models
#   NUM_SIMULATIONS  total simulation budget; one simulation is used by the
#                    initial expansion, so the search is run with N-1
#   TEMPERATURE      passed to --temperature
#   MODEL_PROBS      passed to --answer_model_probs
#   SUFFIX           optional tag appended to the experiment name as "_SUFFIX"
#   IS_FAIL          optional; when true-ish, only the problem indices listed in
#                    llm_mcts/tasks/code_contest/<experiment name>.txt are run.
#                    Empty, 0, false, no and off all mean "run every problem".

# Succeed when $1 is a true-ish flag value; fail for empty or false-like values.
is_truthy() {
    case "$1" in
        "" | 0 | [Ff]alse | FALSE | [Nn]o | NO | [Oo]ff | OFF) return 1 ;;
        *) return 0 ;;
    esac
}

ANSWER_MODELS=$1

NUM_SIMULATIONS=$(($2 - 1))  # Need to subtract 1 because of the initial expansion
NUM_EXPAND_SAMPLES=1
INITIAL_EXPAND_SAMPLES=1

TEMPERATURE=$3
MODEL_PROBS=$4

SUFFIX=$5
if [ "$SUFFIX" ]; then
    SUFFIX="_${SUFFIX}"
fi

# The rest of the script treats any non-empty IS_FAIL as "resume mode", so
# collapse false-like values to the empty string; otherwise passing "false"
# would silently select resume mode.
IS_FAIL=$6
if ! is_truthy "$IS_FAIL"; then
    IS_FAIL=""
fi

MCTS_ALGO="standard"
PROMPT_TYPE="code_contest_lcb_single_turn"

EXPERIMENT_NAME="sequential-refinement_${ANSWER_MODELS}_init${INITIAL_EXPAND_SAMPLES}_nsim${NUM_SIMULATIONS}_nexp${NUM_EXPAND_SAMPLES}_prob${MODEL_PROBS}_temp${TEMPERATURE}_${PROMPT_TYPE}${SUFFIX}"


# Problem index range (inclusive) and parallelism for the stages below.
SPLIT="test"
START_IDX=0
END_IDX=164
N_JOBS=16
N_WORKERS=16

# Resume mode (IS_FAIL non-empty) re-runs only the problem indices stored in a
# per-experiment file; abort before any work starts if that file is missing.
if [ -n "$IS_FAIL" ]; then
    INDICES_FILE="llm_mcts/tasks/code_contest/${EXPERIMENT_NAME}.txt"
    # Quoted so an experiment name containing spaces cannot break the test
    # (an unquoted `[` would error out and skip this guard).
    if [[ ! -f "$INDICES_FILE" ]]; then
        echo "IS_FAIL is true, but INDICES_FILE: $INDICES_FILE does not exist" >&2
        exit 1
    fi
    echo "IS_FAIL is true, resume the experiment with INDICES_FILE: $INDICES_FILE"
else
    echo "IS_FAIL is false, run the experiment with all problems"
fi

# --- Stage 1: run the search for every selected problem. ---
# Indices come from INDICES_FILE in resume mode (IS_FAIL non-empty) or from
# the full START_IDX..END_IDX range otherwise. Up to N_JOBS problems run
# concurrently under `parallel`, which substitutes each index for `{}`.
# The wall-clock time of the stage is appended to the experiment log.

# Print the problem indices to run, one per line.
_list_run_indices() {
    if [ -n "$IS_FAIL" ]; then
        cat -- "$INDICES_FILE"
    else
        seq "$START_IDX" "$END_IDX"
    fi
}

start_time=$(date +%s)
run_status=0
_list_run_indices | parallel -j "$N_JOBS" \
    python scripts/code_contest/run_code_contest.py \
    --experiment_name "$EXPERIMENT_NAME" \
    --idx {} \
    --split "$SPLIT" \
    --answer_models "$ANSWER_MODELS" \
    --answer_model_probs "$MODEL_PROBS" \
    --temperature "$TEMPERATURE" \
    --mcts_algo "$MCTS_ALGO" \
    --num_simulations "$NUM_SIMULATIONS" \
    --num_expand_samples "$NUM_EXPAND_SAMPLES" \
    --initial_expand_samples "$INITIAL_EXPAND_SAMPLES" \
    --initial_prompt_type "$PROMPT_TYPE" || run_status=$?
end_time=$(date +%s)

# Keep going so later stages still process whatever succeeded, but do not
# let failed jobs pass unnoticed.
if [ "$run_status" -ne 0 ]; then
    echo "Warning: run stage exited with status $run_status; some problems may have failed" >&2
fi

elapsed_time=$((end_time - start_time))
elapsed_time=$(echo "scale=2; $elapsed_time / 60" | bc)
# Created only after the runs so a pre-existing directory cannot affect them.
mkdir -p "logging/code_contest/$EXPERIMENT_NAME"
echo "Elapsed time (run): $elapsed_time minutes" >> "logging/code_contest/$EXPERIMENT_NAME/log.txt"

# --- Stage 2: evaluate the selected problems one after another. ---
# Problem selection matches the run stage: INDICES_FILE in resume mode, the
# full START_IDX..END_IDX range otherwise. Each evaluator call receives
# --num_workers N_WORKERS (presumably its own internal parallelism; confirm).
start_time=$(date +%s)
if [ -n "$IS_FAIL" ]; then
    # Split the file on any whitespace into an array; unlike an unquoted
    # $(cat ...), `read -a` never glob-expands the file's contents.
    read -r -d '' -a eval_indices < "$INDICES_FILE" || true
else
    read -r -d '' -a eval_indices < <(seq "$START_IDX" "$END_IDX") || true
fi

eval_failures=0
for IDX in "${eval_indices[@]}"; do
    python scripts/code_contest/evaluate_code_contest.py \
        --experiment_name "$EXPERIMENT_NAME" \
        --idx "$IDX" \
        --split "$SPLIT" \
        --num_workers "$N_WORKERS" || eval_failures=$((eval_failures + 1))
done
end_time=$(date +%s)

# Continue to the submission stage, but report evaluator failures.
if [ "$eval_failures" -gt 0 ]; then
    echo "Warning: evaluation failed for $eval_failures problem(s)" >&2
fi

elapsed_time=$((end_time - start_time))
elapsed_time=$(echo "scale=2; $elapsed_time / 60" | bc)
mkdir -p "logging/code_contest/$EXPERIMENT_NAME"
echo "Elapsed time (evaluate): $elapsed_time minutes" >> "logging/code_contest/$EXPERIMENT_NAME/log.txt"

# --- Stage 3: build a submission for every problem in START_IDX..END_IDX. ---
# This always uses the full range, even in resume mode.
# NOTE(review): presumably intentional so that submissions cover every
# problem after a partial re-run; confirm.
submission_cmd=(
    python scripts/code_contest/make_submission_code_contest.py
    --experiment_name "$EXPERIMENT_NAME"
    --idx {}
    --split "$SPLIT"
)

start_time=$(date +%s)
seq "$START_IDX" "$END_IDX" | parallel -j "$N_JOBS" "${submission_cmd[@]}"
end_time=$(date +%s)

# Stage duration in minutes, truncated to two decimals by bc's scale.
elapsed_time=$(bc <<< "scale=2; $((end_time - start_time)) / 60")
printf 'Elapsed time (make submission): %s minutes\n' "$elapsed_time" >> "logging/code_contest/$EXPERIMENT_NAME/log.txt"

# --- Stage 4: aggregate predictions and log the overall solve rate. ---
python scripts/code_contest/gather_results_code_contest.py \
    --experiment_name "$EXPERIMENT_NAME"

TOTAL_PROBLEMS=$((END_IDX - START_IDX + 1))
prediction_file="logging/code_contest/$EXPERIMENT_NAME/prediction.jsonl"

# A problem counts as solved when its private_tests.score equals 1. If jq
# fails (e.g. the file is missing), stop instead of writing an empty count
# and a broken percentage to the log.
if ! SOLVED_PROBLEMS=$(jq -s '[.[] | select(.private_tests.score==1)] | length ' "$prediction_file"); then
    echo "Could not count solved problems from $prediction_file" >&2
    exit 1
fi

solve_rate=$(echo "scale=2; 100 * $SOLVED_PROBLEMS / $TOTAL_PROBLEMS" | bc)
echo "Solved $SOLVED_PROBLEMS out of $TOTAL_PROBLEMS problems (${solve_rate}%)" >> "logging/code_contest/$EXPERIMENT_NAME/log.txt"
