#!/usr/bin/env bash
# analyse_dlm_order.sh
# -*- coding: utf-8 -*-
#
# One-command launcher for DLM order analysis with SAE instrumentation.
# All customization is done here in bash.
#
# Key improvement for generation quality:
# - Build an n-shot GSM8K-style prompt (few-shot) and pass it via --prompt_file.
# - Enforce the final answer format: last line is "#### <integer>".
#
# Usage:
#   bash analyse_dlm_order.sh

# Strict mode: stop on the first failing command, on any reference to an unset
# variable, and when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# ----------------------------
# 0) Paths and environment
# ----------------------------

# Project checkout; OUTPUT_DIR is the directory reported at the end of the run
# as the place to look for results.
PROJ_DIR='/home/dslabra5/sae4dlm/dlm'
OUTPUT_DIR="${PROJ_DIR}/output"

# Everything below runs from the project root.
cd "${PROJ_DIR}"

# ----------------------------
# 1) Model configuration (customizable)
# ----------------------------

# Values forwarded as --model_name, --device and --dtype.
MODEL_NAME='Dream-org/Dream-v0-Base-7B'
DEVICE='cuda:0'
DTYPE='bf16'
# Bare flag, expanded unquoted at launch: set it to an empty string to omit it.
TRUST_REMOTE_CODE='--trust_remote_code'

# ----------------------------
# 2) Dream generation parameters (customizable)
# ----------------------------

# Keeping both steps and max_new_tokens reasonably large is often helpful for
# Dream DLM. When the two are equal and one token is produced per step, they
# roughly line up.
STEPS=128
MAX_NEW_TOKENS=128

# Sampling settings (forwarded as --temperature, --top_p, --alg_temp).
TEMPERATURE=0.7
TOP_P=0.95
ALG_TEMP=0
# Bare flag; empty string omits it.
DO_SAMPLE='--do_sample'

# Comma-separated list forwarded as --algs.
ALGS='origin,entropy,topk_margin'

# Optional official kwargs for the Dream implementation. Left empty by default;
# only forwarded (as --extra_generate_kwargs) when non-empty.
EXTRA_KWARGS=""

# ----------------------------
# 3) Prompt configuration (customizable)
# ----------------------------

# Number of worked examples placed before the target problem (0..8, default 8).
N_SHOT=8

# The problem to solve. The quoted heredoc delimiter keeps the text literal, so
# it may contain $, ', ", etc. without any escaping.
TARGET_PROBLEM=$(
  cat <<'EOF'
A candle melts by 2 centimeters every hour that it burns. How many centimeters shorter will a candle be after burning from 1:00 PM to 5:00 PM?
EOF
)

# Where the generated few-shot prompt is written.
PROMPT_FILE="${PROJ_DIR}/tmp_prompt_gsm8k_${N_SHOT}shot.txt"

# Bare flag (empty string omits it) and an optional mask-token override that is
# only forwarded when non-empty.
USE_CHAT_TEMPLATE='--use_chat_template'
MASK_TOKEN_STR=''

# The inline Python prompt builder reads these from its environment.
export N_SHOT TARGET_PROBLEM PROMPT_FILE

# Build the few-shot prompt file.
# Reads N_SHOT, TARGET_PROBLEM and PROMPT_FILE from the environment, writes the
# assembled prompt to PROMPT_FILE, and fails (non-zero exit) when PROMPT_FILE or
# TARGET_PROBLEM is empty.
python3 - <<'PY'
import os

# Inputs arrive through the environment exported by the surrounding script.
n_shot = int(os.environ.get("N_SHOT", "8"))
target_problem = os.environ.get("TARGET_PROBLEM", "").strip()
prompt_file = os.environ.get("PROMPT_FILE", "").strip()
if not prompt_file:
    raise RuntimeError("PROMPT_FILE is empty.")
if not target_problem:
    # Without a problem the prompt would end in a bare "Problem: " line.
    raise RuntimeError("TARGET_PROBLEM is empty.")

# GSM8K-style worked examples; the text is kept verbatim.
examples = [
    {
        "q": "To make pizza, together with other ingredients, Kimber needs 10 cups of water, 16 cups of flour, and 1/2 times as many teaspoons of salt as the number of cups of flour. Calculate the combined total number of cups of water, flour, and teaspoons of salt that she needs to make the pizza.",
        "a": "To make the pizza, Kimber half as many teaspoons of salt as the number of cups of flour, meaning she needs 1/2*16 = <<16*1/2=8>>8 teaspoons of salt.\nThe total number of cups of flour and teaspoons of salt she needs is 8+16 = <<8+16=24>>24\nShe also needs 10 cups of water, which means the total number of cups of water and flour and teaspoons of salt she needs is 24+10 = <<24+10=34>>34\n#### 34",
    },
    {
        "q": "Mr. Sam shared a certain amount of money between his two sons, Ken and Tony. If Ken got $1750, and Tony got twice as much as Ken, how much was the money shared?",
        "a": "Tony got twice $1750 which is 2*$1750 = $<<2*1750=3500>>3500\nThe total amount shared was $1750+$3500 = $<<1750+3500=5250>>5250\n#### 5250",
    },
    {
        "q": "Leah earned $28 working odd jobs around the neighborhood. She spent a seventh of it on a milkshake and put half of the rest in her savings account. She left the remaining money in her wallet. Her dog got ahold of her wallet and shredded all the money inside but $1. How many dollars did Leah lose?",
        "a": "Leah spent 28 / 7 = $<<28/7=4>>4 on a milkshake.\nShe had 28 - 4 = $<<28-4=24>>24 left.\nShe put half in her savings account and half in her wallet, so she had 24 / 2 = $<<24/2=12>>12 in her wallet.\nHer dog shredded all the money in her wallet but $1, so Leah lost 12 - 1 = $<<12-1=11>>11.\n#### 11",
    },
    {
        "q": "Leo's assignment was divided into three parts. He finished the first part of his assignment in 25 minutes. It took him twice as long to finish the second part. If he was able to finish his assignment in 2 hours, how many minutes did Leo finish the third part of the assignment?",
        "a": "It took Leo 25 x 2 = <<25*2=50>>50 minutes to finish the second part of the assignment.\nLeo finished the first and second parts of the assignment in 25 + 50 = <<25+50=75>>75 minutes.\nHe finished the entire assignment in 60 x 2 = <<60*2=120>>120 minutes.\nTherefore, it took Leo 120 - 75 = <<120-75=45>>45 minutes to finish the third part of the assignment.\n#### 45",
    },
    {
        "q": "Nancy is filling an aquarium for her fish. She fills it halfway and goes to answer the door. While she's gone, her cat knocks the aquarium over and spills half the water in it. Then Nancy comes back and triples the amount of water in the aquarium. If the aquarium is 4 feet long, 6 feet wide, and 3 feet high, how many cubic feet of water are in the aquarium?",
        "a": "First calculate the volume of the aquarium by multiplying its length, width and height: 4 ft * 6 ft * 3 ft = <<4*6*3=72>>72 cubic ft\nThen figure out what proportion of the aquarium is full after the cat knocks it over: 1/2 * 1/2 = 1/4\nThen figure out what proportion of the aquarium is full after Nancy refills it: 3 * 1/4 = 3/4\nNow multiply the proportion of the aquarium that's full by the aquarium's volume to find out how much water is in it: 72 cubic ft * 3/4 = <<72*3/4=54>>54 cubic ft\n#### 54",
    },
    {
        "q": "Manny had 3 birthday cookie pies to share with his 24 classmates and his teacher, Mr. Keith. If each of the cookie pies were cut into 10 slices and Manny, his classmates, and Mr. Keith all had 1 piece, how many slices are left?",
        "a": "There is a total of 3 x 10 = <<3*10=30>>30 cookie slices.\nThere are 24 + 1 + 1 = <<24+1+1=26>>26 people who ate the cookie pieces.\nThere is 30 - 26 = <<30-26=4>>4 cookie slices left.\n#### 4",
    },
    {
        "q": "Peter goes to the store to buy a soda. The soda costs $.25 an ounce. He brought $2 with him and leaves with $.50. How many ounces of soda did he buy?",
        "a": "He spend $1.5 on soda because 2 - .5 = <<2-.5=1.5>>1.5\nHe bought 6 ounces of soda because 1.5 / .25 = <<1.5/0.25=6>>6\n#### 6",
    },
    {
        "q": "Jerry’s two daughters play softball on different teams. They each have 8 games this season. Each team practices 4 hours for every game they play. If each game lasts for 2 hours, how many hours will Jerry spend at the field watching his daughters play and practice altogether?",
        "a": "Jerry will spend 8 games x 2 hours per game = <<8*2=16>>16 hours watching one daughter play her games.\nHe will spend 16 x 2 = <<16*2=32>>32 hours watching both daughters play their games.\nHe will spend 8 games x 4 hours of practice = <<8*4=32>>32 hours watching one daughter practice.\nHe will spend 32 x 2 = <<32*2=64>>64 hours watching both daughters practice.\nHe will spend a total of 32 hours watching games + 64 hours watching practice = <<32+64=96>>96 hours.\n#### 96",
    },
]

# Clamp the requested shot count to what is actually available.
n_shot = max(0, min(n_shot, len(examples)))

header = (
    "You are a helpful math tutor.\n"
    "Solve the problem step by step.\n"
    "IMPORTANT:\n"
    "1) Once the answer is calculated, immediately stop outputting to avoid getting stuck in a repeat loop.\n"
    "2) The last sentence output MUST be: #### <Answer>\n"
    "3) Do not add anything after the final #### line.\n"
    "\n"
    "Here are some examples:\n"
)

# Header, then the chosen examples (each followed by a blank line), then the
# target problem with an open "Answer:" for the model to complete.
parts = [header]
for ex in examples[:n_shot]:
    parts.extend(["Problem: " + ex["q"], "Answer:\n" + ex["a"], ""])

parts.append("Now solve the next problem.\n")
parts.append("Problem: " + target_problem)
parts.append("Answer:")

text = "\n".join(parts).strip() + "\n"

# A bare filename has an empty dirname, and os.makedirs("") raises
# FileNotFoundError, so only create the directory when there is one.
out_dir = os.path.dirname(prompt_file)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
with open(prompt_file, "w", encoding="utf-8") as f:
    f.write(text)

print(f"[INFO] Wrote prompt_file: {prompt_file}")
print(f"[INFO] n_shot used: {n_shot}")
PY

# ----------------------------
# 4) SAE configuration (customizable)
# ----------------------------

# Values forwarded as --sae_root_dir, --sae_layers (comma-separated) and
# --sae_topk. SAE_TRAINER and SAE_K are optional: each is only forwarded
# (as --sae_trainer / --sae_k) when non-empty.
SAE_ROOT_DIR='/home/dslabra5/sae4dlm/dictionary_learning_demo/saes_mask_Dream-org_Dream-v0-Base-7B_top_k'
SAE_LAYERS='5,14,23'
SAE_TRAINER=''
SAE_K='50'
SAE_TOPK=50

# ----------------------------
# 5) Position selection policy (customizable)
# ----------------------------

# full_gen captures every generated position at every step. That is costly,
# but it makes Step1/2/3 much easier to interpret.
# Choices: full_gen, update_only, update_plus_anchors, mask_only
POSITIONS_MODE='full_gen'

# Bare flag; empty string omits it.
INCLUDE_ANSWER_POS='--include_answer_pos'
ANCHOR_HEAD_K=32
ANCHOR_TAIL_K=32
# Always forwarded as --extra_positions, even when empty.
EXTRA_POSITIONS=''

ATTN_MASK_DTYPE='bool'

# ----------------------------
# 6) Similarity metric and analysis controls (customizable)
# ----------------------------

SIM_METRIC='jaccard'
MAX_STEPS_ANALYZE=0
# Bare flags; an empty string omits the flag from the command line.
SAVE_TOP1_TRACES=''
COMPUTE_STEP3='--compute_step3'
STEP3_TOPN=20

STRICT_DETERMINISM='--strict_determinism'
SEED=42

# ----------------------------
# 7) Experiment name (customizable)
# ----------------------------

# Optional; only turned into an --exp_name argument when non-empty.
EXP_NAME=''

# ----------------------------
# 8) Assemble optional args
# ----------------------------

# Command-line words are grouped into arrays so each value stays a single
# argument. An optional flag is appended only when its value is non-empty.

PROMPT_ARGS=(--prompt_file "${PROMPT_FILE}")

SAE_ARGS=(--sae_root_dir "${SAE_ROOT_DIR}" --sae_layers "${SAE_LAYERS}" --sae_topk "${SAE_TOPK}")
[[ -z "${SAE_TRAINER}" ]] || SAE_ARGS+=(--sae_trainer "${SAE_TRAINER}")
[[ -z "${SAE_K}" ]] || SAE_ARGS+=(--sae_k "${SAE_K}")

MASK_ARGS=()
[[ -z "${MASK_TOKEN_STR}" ]] || MASK_ARGS+=(--mask_token_str "${MASK_TOKEN_STR}")

EXTRA_KWARGS_ARGS=()
[[ -z "${EXTRA_KWARGS}" ]] || EXTRA_KWARGS_ARGS+=(--extra_generate_kwargs "${EXTRA_KWARGS}")

EXP_ARGS=(--output_dir "${OUTPUT_DIR}")
[[ -z "${EXP_NAME}" ]] || EXP_ARGS+=(--exp_name "${EXP_NAME}")

# ----------------------------
# 9) Run
# ----------------------------

# Launch the analysis.
#
# N_SHOT, TARGET_PROBLEM and PROMPT_FILE were exported before the prompt file
# was generated, so they do not need to be exported again here.
#
# Flag-only settings (TRUST_REMOTE_CODE, USE_CHAT_TEMPLATE, DO_SAMPLE,
# STRICT_DETERMINISM, INCLUDE_ANSWER_POS, SAVE_TOP1_TRACES, COMPUTE_STEP3) are
# expanded unquoted on purpose: an empty value must disappear from the command
# line instead of becoming an empty argument.
#
# EXP_ARGS carries --output_dir (and --exp_name when set); it has to be passed
# for the results to land in the directory reported below.
#
# MASK_ARGS and EXTRA_KWARGS_ARGS may be empty. Under `set -u`, bash releases
# before 4.4 treat "${arr[@]}" on an empty array as an unbound variable, so
# they are expanded through the ${arr[@]+...} guard.
# shellcheck disable=SC2086
python3 -u "${PROJ_DIR}/src/dlm_order.py" \
  --model_name "${MODEL_NAME}" \
  --device "${DEVICE}" \
  --dtype "${DTYPE}" \
  ${TRUST_REMOTE_CODE} \
  "${PROMPT_ARGS[@]}" \
  ${USE_CHAT_TEMPLATE} \
  --algs "${ALGS}" \
  --steps "${STEPS}" \
  --max_new_tokens "${MAX_NEW_TOKENS}" \
  --temperature "${TEMPERATURE}" \
  --top_p "${TOP_P}" \
  --alg_temp "${ALG_TEMP}" \
  ${DO_SAMPLE} \
  --seed "${SEED}" \
  ${STRICT_DETERMINISM} \
  "${SAE_ARGS[@]}" \
  --positions_mode "${POSITIONS_MODE}" \
  ${INCLUDE_ANSWER_POS} \
  --anchor_head_k "${ANCHOR_HEAD_K}" \
  --anchor_tail_k "${ANCHOR_TAIL_K}" \
  --extra_positions "${EXTRA_POSITIONS}" \
  --attn_mask_dtype "${ATTN_MASK_DTYPE}" \
  --sim_metric "${SIM_METRIC}" \
  --max_steps_analyze "${MAX_STEPS_ANALYZE}" \
  ${SAVE_TOP1_TRACES} \
  ${COMPUTE_STEP3} \
  --step3_topn "${STEP3_TOPN}" \
  "${EXP_ARGS[@]}" \
  ${MASK_ARGS[@]+"${MASK_ARGS[@]}"} \
  ${EXTRA_KWARGS_ARGS[@]+"${EXTRA_KWARGS_ARGS[@]}"}

echo ""
echo "[OK] Done. Check outputs under: ${OUTPUT_DIR}"
echo "[OK] Prompt file used: ${PROMPT_FILE}"
