#!/bin/bash

# Experiment configuration for inference.
# Each setting may be overridden from the environment when invoking this
# script, e.g. `DATASET_TYPE=canonical-symmetry-grouping bash <this script>`.
# When a variable is unset or empty, the default shown below is used.
DATASET_TYPE="${DATASET_TYPE:-random-80-10-10}"       # Options: canonical-symmetry-grouping, random-80-10-10
REPRESENTATION_MODE="${REPRESENTATION_MODE:-nl}"      # Forwarded to the inference script as --representation_mode
INSTRUCT_MODEL="${INSTRUCT_MODEL:-True}"              # Forwarded to the inference script as --instruct_model

# Set the model checkpoint path (update as needed)
# MODEL_CHECKPOINT="/mnt/data/data/stlm-logic/checkpoints/MBZUAI-LaMini-GPT-124M_random-80-10-10_legal_move_nl_grpo-nl-expt-final/checkpoint-5000"
# MODEL_CHECKPOINT="/mnt/data/data/stlm-logic/checkpoints/meta-llama-Llama-3.2-1B-Instruct_random-80-10-10_legal_move_nl_grpo-nl-expt-final/checkpoint-400"
# MODEL_CHECKPOINT="/mnt/data/data/stlm-logic/checkpoints/MBZUAI-LaMini-GPT-774M_random-80-10-10_legal_move_nl_grpo-nl-expt-final/checkpoint-600"

# MODEL_CHECKPOINT="/mnt/data/data/stlm-logic/checkpoints/meta-llama-Llama-3.2-1B-Instruct_canconical-symmetry-grouping_legal_move_nl_grpo-nl-expt-final/checkpoint-1200"

# Checkpoints from the runs with the fixed prompt and data loading
# MODEL_CHECKPOINT="/mnt/data/data/stlm-logic/updated-checkpoints/meta-llama-Llama-3.2-1B-Instruct_canconical-symmetry-grouping_legal_move_nl_grpo-nl-expt-final-fixed-prompt-and-data-loading/checkpoint-1400"
# MODEL_CHECKPOINT="/mnt/data/data/stlm-logic/updated-checkpoints/meta-llama-Llama-3.2-1B-Instruct_random-80-10-10_legal_move_nl_grpo-nl-expt-final-fixed-prompt-and-data-loading/checkpoint-2400"


# MODEL_CHECKPOINT="/mnt/data/data/stlm-logic/updated-checkpoints/llamafy-Qwen-Qwen2.5-0.5B-Instruct-llamafied_random-80-10-10_legal_move_nl_grpo-nl-expt-final-fixed-prompt-and-data-loading/checkpoint-2800"
# MODEL_CHECKPOINT="/mnt/data/data/stlm-logic/updated-checkpoints/llamafy-Qwen-Qwen2.5-1.5B-Instruct-llamafied_random-80-10-10_legal_move_nl_grpo-nl-expt-final-fixed-prompt-and-data-loading/checkpoint-2200"

MODEL_CHECKPOINT="/mnt/data/data/stlm-logic/updated-checkpoints/MBZUAI-LaMini-GPT-774M_random-80-10-10_legal_move_nl_grpo-nl-expt-final-fixed-prompt-and-data-loading/checkpoint-1400"

# Build a label of the form "<run directory name>-<checkpoint name>" from the
# last two components of the checkpoint path. The label is handed to the
# inference script as --model_mark (intended as the name of the log file).
checkpoint_name="$(basename "$MODEL_CHECKPOINT")"
run_dir="$(dirname "$MODEL_CHECKPOINT")"
run_name="$(basename "$run_dir")"
MODEL_MARK="${run_name}-${checkpoint_name}"
printf 'Model mark: %s\n' "$MODEL_MARK"



# Select the test dataset file that corresponds to DATASET_TYPE.
# An unrecognised value is a configuration error: report it on stderr (so it
# is not mixed into regular output) together with the accepted values, then
# abort with exit status 1 before any inference is started.
case "$DATASET_TYPE" in
    canonical-symmetry-grouping)
        TEST_PATH="/mnt/data/data/stlm-logic/datasets/tictactoe_test.json"
        ;;
    random-80-10-10)
        TEST_PATH="/mnt/data/data/stlm-logic/datasets/random_test_dataset_0.8_0.1_0.1.json"
        ;;
    *)
        printf 'Invalid DATASET_TYPE selected: "%s" (expected one of: canonical-symmetry-grouping, random-80-10-10)\n' \
            "$DATASET_TYPE" >&2
        exit 1
        ;;
esac

# Run scripts/python/grpo_inference.py on the "test" split with a batch size
# of 8, forwarding the checkpoint, dataset path, representation mode,
# instruct-model flag and model mark configured earlier in this script.
# The script path is relative, so it is resolved against the current working
# directory. Arguments are gathered in an array so that every value is passed
# as exactly one word.
inference_args=(
  --model_checkpoint "$MODEL_CHECKPOINT"
  --test_dataset_path "$TEST_PATH"
  --representation_mode "$REPRESENTATION_MODE"
  --instruct_model "$INSTRUCT_MODEL"
  --dataset_splits "test"
  --batch_size 8
  --model_mark "$MODEL_MARK"
)
python scripts/python/grpo_inference.py "${inference_args[@]}"
