#!/usr/bin/env bash
# Reproduce Table 4 — Llama-3.1-8B-Instruct-Q4_K_M via llama-cpp-python
# Two conditions run sequentially:
#   (A) Standard encoding  — vocab_size=6  (the vocabulary size used in the paper's Table 4)
#   (B) Vocab-partitioned  — vocab_size=16 (block_size=5 ≥ max_val=5; novel ablation)
#
# Settings: N_dim=3, V_min=2, V_max=5, (O,S) ∈ {(1,1),(1,2),(4,1),(4,2)}
# Protocol: 5 seeds × 32 episodes per condition  (the paper used 5 seeds × 64 episodes)
# Listener: discussion_cot (DSPy chain-of-thought over a multi-turn conversation)
#
# Requirements:
#   pip install llama-cpp-python (ROCm/HIP or CUDA build — see backend docstring)
#   Model downloaded automatically to ~/.cache/huggingface/hub/ on first run (~5 GB).
#
# Optional Weave tracing (uncomment WEAVE_PROJECT line below):
#   pip install 'meta-rg-s2b[weave]'
set -euo pipefail

# Resolve the repo root from this script's own location so the relative
# run_eval.py / configs/ paths below work regardless of the caller's cwd.
# cd's stdout is discarded: with CDPATH set, `cd` echoes the target directory,
# which would otherwise be captured into SCRIPT_DIR alongside pwd's output.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null && pwd)"
REPO_ROOT="$(dirname -- "$SCRIPT_DIR")"
cd -- "$REPO_ROOT"

# Python environment. Defaults to the sibling venv this script was written
# against; override with VENV_DIR=/path/to/venv on other machines.
VENV_DIR="${VENV_DIR:-$REPO_ROOT/../p311TheRockLM_venv}"
if [[ ! -f "$VENV_DIR/bin/activate" ]]; then
    printf 'error: virtualenv not found: %s/bin/activate\n' "$VENV_DIR" >&2
    printf '       set VENV_DIR to a venv with llama-cpp-python installed\n' >&2
    exit 1
fi
# shellcheck source=/dev/null
source "$VENV_DIR/bin/activate"

readonly CONFIG="configs/eval/llama3_8b_llamacpp.yaml"
readonly N_SEEDS=5
readonly N_EPISODES=32
readonly MAX_NEW_TOKENS=256   # CoT needs more room than the default 64-token backend limit
readonly BASE_SEED=0
readonly WANDB_PROJECT="meta-rg-s2b"
# WEAVE_PROJECT="meta-rg-s2b"  # uncomment to enable Weave tracing

readonly OUT_ROOT="outputs/eval/table4_llama8b_llamacpp"

# ── Shared invocation ─────────────────────────────────────────────────────────
# Flags common to both conditions. Held in an array so every value stays a
# single argv word and the two runs cannot drift apart.
common_args=(
    --config "$CONFIG"
    --prompt_strategy discussion_cot
    --table4
    --n_seeds "$N_SEEDS"
    --n_episodes "$N_EPISODES"
    --max_new_tokens "$MAX_NEW_TOKENS"
    --base_seed "$BASE_SEED"
    --wandb_project "$WANDB_PROJECT"
)
# Weave tracing is opt-in: enabled whenever WEAVE_PROJECT is non-empty (either
# uncommented in the settings above or exported by the caller). The `:-`
# default keeps `set -u` from aborting when it is left unset.
if [[ -n "${WEAVE_PROJECT:-}" ]]; then
    common_args+=(--weave_project "$WEAVE_PROJECT")
fi

# banner TEXT
#   Prints a blank line, then TEXT (indented two spaces) between two rules.
banner() {
    local rule="========================================================================"
    printf '\n%s\n  %s\n%s\n' "$rule" "$1" "$rule"
}

# run_condition LABEL SUBDIR [EXTRA_ARGS...]
#   Prints a banner with LABEL, then runs run_eval.py with the shared flags,
#   any condition-specific EXTRA_ARGS, and --output_dir "$OUT_ROOT/SUBDIR".
#   Under `set -e` a failing run aborts the script before the next condition.
run_condition() {
    local label=$1 subdir=$2
    shift 2
    banner "$label"
    python run_eval.py "${common_args[@]}" "$@" --output_dir "$OUT_ROOT/$subdir"
}

# ── (A) Standard encoding ─────────────────────────────────────────────────────
# vocab_size=6 as in the paper; tokens 1-5 encode all latent values (shared range).
# Passed explicitly rather than inherited from $CONFIG, so the run always
# matches this banner and the paper setting regardless of the YAML contents.
run_condition "(A) Standard encoding  |  vocab_size=6  |  Llama-3.1-8B discussion_cot" \
    standard \
    --vocab_size 6

# ── (B) Vocab-partition encoding ──────────────────────────────────────────────
# vocab_size=16: block_size = (16-1)//3 = 5 ≥ max_val=5.
# Each latent gets a disjoint token range:
#   latent 0 → tokens  1- 5
#   latent 1 → tokens  6-10
#   latent 2 → tokens 11-15
# This removes the ambiguity between same-valued tokens across dimensions.
run_condition "(B) Vocab-partition    |  vocab_size=16  |  Llama-3.1-8B discussion_cot" \
    vocab_partition \
    --vocab_partition \
    --vocab_size 16

banner "Done.  Results written to $OUT_ROOT/{standard,vocab_partition}/"
