{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d926e0ab",
   "metadata": {},
   "source": [
    "## Early-Stopping criteria"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0eeceb20",
   "metadata": {},
   "source": [
    "In our paper, we introduce two early-stopping criteria:\n",
    "- **Ideal Early Stopping ($\\mathcal{IES}$):** We iteratively adds new steps into the final answer until the answer is contained in one of the steps (if any). If the answer is parsed, the algorithm stops and we return the pruned traces.\n",
    "- **Step-Tagging Early-Stopping (ST-ES):** Given the tags of a steps, our algorithm prune the remaining steps based on the count a specific step-type frequency."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46b27501",
   "metadata": {},
   "source": [
    "### A. Ideal Early-Stopping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76efe653",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from utils import process_answers_single\n",
    "\n",
    "def ideal_early_stoppping(steps, ground_truth):\n",
    "\n",
    "    \"\"\"\n",
    "    Ideal Early Stopping algorithm for step-by-step answer generation.\n",
    "\n",
    "    Inputs:\n",
    "        steps (list of str): A list of answer fragments (steps) that are concatenated \n",
    "            one by one to form the complete answer.\n",
    "        ground_truth (str): The correct final answer to compare against.\n",
    "\n",
    "    Returns:\n",
    "        tuple:\n",
    "            acc (bool): Whether the final accumulated answer is correct.\n",
    "            early_stop (bool): Whether early stopping occurred \n",
    "            es_step (int): The step index at which early stopping happened\n",
    "            answer (str): The ideal early-stopped pruned trace\n",
    "    \"\"\"\n",
    "\n",
    "    answer = \"\"\n",
    "    num_steps = len(steps)\n",
    "    early_stop = False\n",
    "    es_step = 0\n",
    "\n",
    "    for j in range(num_steps):\n",
    "\n",
    "        # accumulate the current answer with the current step\n",
    "        answer += steps[j]\n",
    "\n",
    "        try:\n",
    "\n",
    "            result = process_answers_single(gold=ground_truth, answer=answer, gold_is_latex=False)\n",
    "            result_acc = result[0]['is_correct']\n",
    "\n",
    "            if result_acc:\n",
    "                es_step = j\n",
    "                early_stop = True\n",
    "                break\n",
    "\n",
    "        except:\n",
    "            pass\n",
    "\n",
    "    if not early_stop:\n",
    "        es_step = num_steps\n",
    "\n",
    "    result = process_answers_single(gold=ground_truth, answer=answer, gold_is_latex=False)\n",
    "    acc = result[0]['is_correct']\n",
    "\n",
    "    return (acc, early_stop, es_step, answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecbede02",
   "metadata": {},
   "source": [
    "### Step-Tagging Early-Stopping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f42df7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def step_tagging_early_stopping(trace, tags, constraint):\n",
    "    \"\"\"\n",
    "    Apply a simple tag-based constraint to a single reasoning trace.\n",
    "\n",
    "    Args:\n",
    "        trace (list of str): List of step texts.\n",
    "        tags (list of str): Corresponding list of tags for each step.\n",
    "        constraint (tuple): (tag, threshold) specifying how many times a tag can appear before stopping.\n",
    "\n",
    "    Returns:\n",
    "        str: the pruned reasoning traces\n",
    "    \"\"\"\n",
    "    tag, threshold = constraint\n",
    "    tag_counts = {tag: 0}\n",
    "    answer_steps = [trace[0]]  # initialize with the first step\n",
    "\n",
    "    for step, step_tag in zip(trace[1:], tags[1:]):\n",
    "        if step_tag == tag:\n",
    "            if tag_counts[tag] < threshold:\n",
    "                answer_steps.append(step)\n",
    "                tag_counts[tag] += 1\n",
    "            else:\n",
    "                # break the algorithm if the constraint is broken\n",
    "                break\n",
    "        else:\n",
    "            # tag not constrained, so keep it\n",
    "            answer_steps.append(step)\n",
    "\n",
    "    return {\n",
    "        \"answer_steps\": answer_steps,\n",
    "        \"used_tags\": tag_counts\n",
    "    }\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
