{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **AdaptMI: Adaptive Skill-based In-context Math Instructions for Small Language Models**\n",
    "\n",
    "This is the Jupyter notebook accompanying the paper **AdaptMI: Adaptive Skill-based In-context Math Instructions for Small Language Models.**\n",
    "\n",
    "Our work draws an analogy between in-context learning in Small Language Models (SLMs) and how students learn from teachers in a classroom. Instead of feeding in a fixed set of in-context examples, we propose AdaptMI, an **Adapt**ive approach to selecting skill-based in-context **M**ath **I**nstructions for SLMs.\n",
    "\n",
    "Inspired by cognitive load theory from human pedagogy, our method introduces skill-based examples only when the model performs poorly on a question. This boosts the math reasoning accuracy of SLMs by up to 6% over naive skill-based prompting strategies.\n",
    "\n",
    "![Method Overview](figures/Picture1.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**‼️ Caveats:** Due to limited resources, this notebook tests AdaptMI on only 50 MATH examples, far fewer than the 5k-example test set used in the paper. The **exact** accuracy numbers (and accuracy gains) may therefore deviate from Table 1 in the paper. Note also that the saved cell outputs were produced on 100 examples."
   ]
  },
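  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why a small run can drift from the paper's numbers, here is a rough illustration (not part of the AdaptMI code): treating each question as an independent Bernoulli trial, the standard error of an accuracy estimate is $\\sqrt{p(1-p)/n}$. The accuracy $p = 0.45$ below is only an illustrative value, close to the 45.0 reported by the initial evaluation cell.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def accuracy_std_error(p: float, n: int) -> float:\n",
    "    # Binomial standard error of an accuracy estimate, in percentage points\n",
    "    return 100 * math.sqrt(p * (1 - p) / n)\n",
    "\n",
    "for n in (50, 5000):\n",
    "    print(f'n={n}: +/- {accuracy_std_error(0.45, n):.1f} points (1 s.e.)')\n",
    "```\n",
    "\n",
    "This prints about 7.0 points for n=50 versus 0.7 points for n=5000, so gains of a few points are hard to distinguish from noise on the small subset."
   ]
  },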
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🔔 Stage 1: **Detection of _easy_ and _difficult_ questions**\n",
    "\n",
    "In this stage, we label each question as _easy_ or _difficult_ for the Small Language Model.\n",
    "\n",
    "### 👉 Stage 1-1: Initial evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Environment Setup\n",
    "\n",
    "```bash\n",
    "conda create -n matheval python=3.10\n",
    "conda activate matheval\n",
    "\n",
    "cd evaluation/latex2sympy\n",
    "pip install -e .\n",
    "cd ..\n",
    "pip install torch\n",
    "pip install -r requirements.txt\n",
    "pip install vllm==0.5.1 --no-build-isolation\n",
    "pip install transformers==4.42.3\n",
    "conda install ipykernel\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please activate the environment `matheval`, and run the following cells:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: TOKENIZERS_PARALLELISM=true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yh0068/.conda/envs/evaltest/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "2025-05-25 20:29:03,278\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n"
     ]
    }
   ],
   "source": [
    "%env TOKENIZERS_PARALLELISM=true\n",
    "import sys\n",
    "import os\n",
    "\n",
    "current_dir = os.getcwd()\n",
    "evaluation_dir_path = os.path.join(current_dir, 'evaluation')\n",
    "\n",
    "if evaluation_dir_path not in sys.path:\n",
    "    sys.path.insert(0, evaluation_dir_path)\n",
    "\n",
    "import os\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "from types import SimpleNamespace\n",
    "from vllm import LLM, SamplingParams\n",
    "from transformers import AutoTokenizer, AutoConfig\n",
    "\n",
    "from evaluation.math_eval import *\n",
    "\n",
    "logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',\n",
    "                    datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args:\n",
    "    def __init__(self):\n",
    "        self.data_names = \"math\"  # Example: \"gsm8k,math\"\n",
    "        self.data_dir = \"./evaluation/data\"\n",
    "        self.data_path = None # Default: None\n",
    "        self.model_name_or_path = \"models/Qwen2.5-1.5B-Instruct\" # Replace with your model\n",
    "        self.output_dir = \"./output/stage1_inference\"\n",
    "        self.prompt_type = \"qwen25-math-cot\"\n",
    "        self.split = \"test\"\n",
    "        self.num_test_sample = 50 # -1 for full data, set to a small number for testing\n",
    "        self.seed = 0\n",
    "        self.start = 0\n",
    "        self.end = -1 # -1 for all samples from start\n",
    "        self.temperature = 0.0\n",
    "        self.n_sampling = 1\n",
    "        self.top_p = 1.0\n",
    "        self.max_tokens_per_call = 1024 # Reduced for faster testing, adjust as needed\n",
    "        self.shuffle = True\n",
    "        self.use_vllm = False # True: load with vLLM; False: load with Hugging Face Transformers (e.g. for models vLLM does not support well)\n",
    "        self.save_outputs = True\n",
    "\n",
    "        # AdaptMI-specific options\n",
    "        self.LLM_judge = False\n",
    "        self.PRM_judge = False\n",
    "        self.random_shots = False\n",
    "        self.llm_sol = False\n",
    "        \n",
    "        self.overwrite = True # Set to True to overwrite existing output files\n",
    "        self.use_safetensors = True # Recommended\n",
    "        self.num_shots = 5\n",
    "        self.num_skill_shots = 0\n",
    "        self.apply_chat_template = False # Set to True if your model expects chatml or similar\n",
    "        self.pipeline_parallel_size = 1\n",
    "        self.adapt_few_shot = False\n",
    "\n",
    "        # Auto-set top_p based on temperature for greedy sampling\n",
    "        self.top_p = (\n",
    "            1 if self.temperature == 0 else self.top_p\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed_value: int):\n",
    "    \"\"\"Sets the seed for reproducibility.\"\"\"\n",
    "    random.seed(seed_value)\n",
    "    np.random.seed(seed_value)\n",
    "    torch.manual_seed(seed_value)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed_value)\n",
    "    logger.info(f\"Set seed to {seed_value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_eval(args_obj: Args): # Type hint Args_obj with our class\n",
    "    # Ensure CUDA_VISIBLE_DEVICES is set, or vLLM might default unexpectedly\n",
    "    if args_obj.use_vllm and not os.environ.get(\"CUDA_VISIBLE_DEVICES\"):\n",
    "        logger.warning(\"CUDA_VISIBLE_DEVICES is not set. vLLM might default to all available GPUs or the first one.\")\n",
    "        # os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" # Optionally set a default if none is provided\n",
    "\n",
    "    available_gpus = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"0\").split(\",\")\n",
    "    \n",
    "    llm_instance = None\n",
    "    tokenizer_instance = None\n",
    "\n",
    "    if args_obj.use_vllm:\n",
    "        logger.info(f\"Attempting to load model {args_obj.model_name_or_path} with vLLM.\")\n",
    "        # The rope_scaling logic from the original script can be complex and model-specific.\n",
    "        # vLLM often handles this automatically or via its own config.\n",
    "        # For simplicity, direct LLM initialization is shown here.\n",
    "        # If specific rope_scaling is needed, it should be passed to LLM constructor.\n",
    "        try:\n",
    "            llm_instance = LLM(\n",
    "                model=args_obj.model_name_or_path,\n",
    "                tensor_parallel_size=max(1, len(available_gpus) // args_obj.pipeline_parallel_size), # Ensure at least 1\n",
    "                pipeline_parallel_size=args_obj.pipeline_parallel_size,\n",
    "                dtype=\"bfloat16\" if torch.cuda.is_bf16_supported() else \"float16\",\n",
    "                trust_remote_code=True,\n",
    "                # max_model_len=max_tokens_per_call + some_buffer # Consider setting max_model_len\n",
    "            )\n",
    "            logger.info(f\"vLLM loaded model: {args_obj.model_name_or_path}\")\n",
    "        except Exception as e:\n",
    "            logger.error(f\"Failed to load model with vLLM: {e}\")\n",
    "            raise\n",
    "\n",
    "        if args_obj.apply_chat_template:\n",
    "            try:\n",
    "                tokenizer_instance = AutoTokenizer.from_pretrained(\n",
    "                    args_obj.model_name_or_path, trust_remote_code=True\n",
    "                )\n",
    "                logger.info(f\"Tokenizer loaded for chat template: {args_obj.model_name_or_path}\")\n",
    "            except Exception as e:\n",
    "                logger.error(f\"Failed to load tokenizer for chat template: {e}\")\n",
    "                # Depending on strictness, you might want to raise e or allow proceeding without chat template\n",
    "    else:\n",
    "        logger.info(f\"Attempting to load model {args_obj.model_name_or_path} with Hugging Face Transformers.\")\n",
    "        try:\n",
    "            tokenizer_instance, llm_instance, _ = load_model(args_obj.model_name_or_path, args_obj)\n",
    "            logger.info(f\"Hugging Face model and tokenizer loaded: {args_obj.model_name_or_path}\")\n",
    "        except Exception as e:\n",
    "            logger.error(f\"Failed to load model with Hugging Face: {e}\")\n",
    "            raise\n",
    "\n",
    "    # Infer & eval\n",
    "    data_list = args_obj.data_names.split(\",\")\n",
    "    results = []\n",
    "    for data_name_str in data_list:\n",
    "        data_name = data_name_str.strip()\n",
    "        if not data_name:\n",
    "            continue\n",
    "        logger.info(f\"\\nProcessing dataset: {data_name}\")\n",
    "        \n",
    "        dataset_result = main(llm_instance, tokenizer_instance, data_name, args_obj)\n",
    "        results.append(dataset_result)\n",
    "\n",
    "    if results:\n",
    "        summary_data_list = [name.strip() for name in data_list if name.strip()]\n",
    "        \n",
    "        if len(summary_data_list) > 1:\n",
    "            # Calculate average accuracy if multiple datasets were processed\n",
    "            valid_results_for_avg = [res for res in results if res and \"acc\" in res]\n",
    "            if valid_results_for_avg:\n",
    "                avg_acc = sum(res[\"acc\"] for res in valid_results_for_avg) / len(valid_results_for_avg)\n",
    "                results.append({\"acc\": avg_acc, \"data_name\": \"avg\"}) # Add data_name for clarity\n",
    "                summary_data_list.append(\"avg\")\n",
    "            else:\n",
    "                logger.warning(\"No valid results with 'acc' key found to calculate average.\")\n",
    "\n",
    "        logger.info(\"\\n\" + \"=\"*20 + \" Overall Summary \" + \"=\"*20)\n",
    "        \n",
    "        pad_width = max(len(name) for name in summary_data_list) if summary_data_list else 10\n",
    "\n",
    "        header_parts = []\n",
    "        score_parts = []\n",
    "        \n",
    "        res_idx = 0\n",
    "        for name in summary_data_list:\n",
    "            header_parts.append(name.ljust(pad_width))\n",
    "            current_res = None\n",
    "            if name == \"avg\" and results[-1].get(\"data_name\") == \"avg\": # Check if last result is avg\n",
    "                 current_res = results[-1]\n",
    "            elif res_idx < len(results) and results[res_idx].get(\"data_name\", data_list[res_idx].strip()) == name: # Check by original name\n",
    "                 current_res = results[res_idx]\n",
    "                 res_idx +=1\n",
    "            elif res_idx < len(results): # Fallback if data_name not in result but order might match\n",
    "                 current_res = results[res_idx]\n",
    "                 logger.warning(f\"Result for {name} matched by order, not explicit data_name key in result dict.\")\n",
    "                 res_idx +=1\n",
    "\n",
    "\n",
    "            if current_res and \"acc\" in current_res:\n",
    "                score_parts.append(f\"{current_res['acc']:.1f}\".ljust(pad_width))\n",
    "            else:\n",
    "                score_parts.append(\"N/A\".ljust(pad_width))\n",
    "                logger.warning(f\"Accuracy not found for dataset: {name}\")\n",
    "        \n",
    "        if header_parts:\n",
    "            final_header = \"\\t\".join(header_parts)\n",
    "            final_scores = \"\\t\".join(score_parts)\n",
    "            print(\"\\nResults Summary:\") # Print to console for easy viewing\n",
    "            print(final_header)\n",
    "            print(final_scores)\n",
    "            logger.info(\"Final results summary (also printed above):\")\n",
    "            logger.info(f\"Datasets: {final_header}\")\n",
    "            logger.info(f\"Accuracy: {final_scores}\")\n",
    "        else:\n",
    "            logger.info(\"No results to display in summary table.\")\n",
    "    else:\n",
    "        logger.info(\"No datasets processed or no results returned.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "05/25/2025 20:29:40 - INFO - evaluation.math_eval - loaded model with 1543714304 parameters\n",
      "05/25/2025 20:29:41 - INFO - evaluation.math_eval - setting tokenizer.model_max_length to 1024\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "data: math  ,remain samples: 100\n",
      "{'idx': 2795, 'problem': 'Find all the integer roots of\\n\\\\[x^4 + 5x^3 + 9x^2 - x - 14 = 0.\\\\]Enter all the integer roots, separated by commas.', 'level': 'Level 1', 'solution': 'By the Integer Root Theorem, the possible integer roots are all the divisors of 14 (including negative divisors), which are $-14,$ $-7,$ $-2,$ $-1,$ $1,$ $2,$ $7,$ and $14.$  Checking, we find that the only integer roots are $\\\\boxed{-2,1}.$', 'subject': 'Intermediate Algebra', 'unique_id': 'test/intermediate_algebra/1102.json', 'answer': '-2,1'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 285.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------- Epoch 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating Completions: 100%|██████████| 100/100 [08:16<00:00,  4.96s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------- Epoch 1\n",
      "Unsolved samples: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|██████████| 100/100 [00:01<00:00, 78.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'num_samples': 100, 'num_scores': 100, 'timeout_samples': 0, 'empty_samples': 0, 'acc': 45.0}\n",
      "Saved to ./output/stage1_inference/test_100_0+5shots.jsonl\n",
      "\n",
      "Results Summary:\n",
      "math\n",
      "45.0\n"
     ]
    }
   ],
   "source": [
    "# --- Initial evaluation ---\n",
    "args = Args()\n",
    "set_seed(args.seed)\n",
    "run_eval(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 👉 Stage 1-2: Classifying questions as _easy_ or _difficult_\n",
    "This stage classifies questions as _easy_ or _difficult_ according to the model's performance. `math-rm/rm_classify.py` uses a process reward model (PRM) to score each step of the SLM's response. We then use thresholds τ1 and τ2 (`pred_thres1` and `pred_thres2` in the code) to decide whether a question q is easy or difficult."
   ]
  },
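  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of how two thresholds can turn per-step PRM scores into an _easy_/_difficult_ label, the hypothetical `is_difficult` helper below assumes one plausible rule: a question counts as _easy_ only if the final step scores at least τ1 and no step scores below τ2. This rule is an assumption for illustration; the rule actually applied lives in `math-rm/rm_classify.py`.\n",
    "\n",
    "```python\n",
    "from typing import Sequence\n",
    "\n",
    "def is_difficult(step_scores: Sequence[float], thres1: float = 0.9, thres2: float = 0.7) -> bool:\n",
    "    # Hypothetical rule: difficult if the last step is below thres1\n",
    "    # or any step is below thres2. An empty response counts as difficult.\n",
    "    if not step_scores:\n",
    "        return True\n",
    "    return step_scores[-1] < thres1 or min(step_scores) < thres2\n",
    "\n",
    "print(is_difficult([0.95, 0.92, 0.97]))  # False: confident final step, no weak step\n",
    "print(is_difficult([0.95, 0.60, 0.97]))  # True: one step below thres2\n",
    "print(is_difficult([0.95, 0.92, 0.80]))  # True: final step below thres1\n",
    "```\n",
    "\n",
    "The defaults mirror `pred_thres1 = 0.9` and `pred_thres2 = 0.7` in the arguments cell below."
   ]
  },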
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Environment Setup\n",
    "\n",
    "```bash\n",
    "conda create -n classify python=3.10.9\n",
    "conda activate classify\n",
    "\n",
    "git clone https://github.com/OpenAccess-AI-Collective/axolotl\n",
    "cd axolotl\n",
    "git checkout 55cc214c767741e83ee7b346e5e13e6c03b7b9fa\n",
    "pip install -e .\n",
    "\n",
    "pip3 install torch==2.1.2 torchvision torchaudio\n",
    "pip install flash-attn\n",
    "\n",
    "git clone https://github.com/lm-sys/FastChat.git\n",
    "cd FastChat\n",
    "pip install -e .\n",
    "\n",
    "git clone https://github.com/WeiXiongUST/RLHF-Reward-Modeling.git\n",
    "pip install deepspeed\n",
    "\n",
    "cd ../..  # back to the repository root, which contains math-rm/\n",
    "pip install -r math-rm/requirements.txt\n",
    "conda install ipykernel\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please activate the environment `classify`, and run the following cells:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yh0068/.conda/envs/prm_dev/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Cell 1: Setup and Imports\n",
    "import sys\n",
    "import os\n",
    "import json\n",
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from datasets import load_dataset\n",
    "from accelerate import Accelerator # For accelerator.device, accelerator.num_processes, etc.\n",
    "from collections import Counter\n",
    "from tqdm import tqdm # For notebook-level progress if any; worker has its own\n",
    "\n",
    "# Add math-rm to Python path to import custom modules\n",
    "# Assumes the notebook is in the parent directory of 'math-rm'\n",
    "sys.path.append('./math-rm')\n",
    "from rm_classify import worker # Only worker is directly called from the main logic\n",
    "\n",
    "# For a cleaner log, you might want to set Transformers logging level\n",
    "import logging\n",
    "logging.getLogger(\"transformers\").setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output directory: ./output/stage1_classified\n",
      "Using model: pwork7/llama31_it_prm_2e6_bz32_1epoch_conversation\n",
      "Using dataset: ./output/stage1_inference/test_100_0+5shots.jsonl\n",
      "Number of N candidates per question: 128\n",
      "Number of test samples: 100\n"
     ]
    }
   ],
   "source": [
    "# Cell 2: Define Arguments (mimicking argparse)\n",
    "class Args2:\n",
    "    def __init__(self):\n",
    "        self.reward_name_or_path = 'pwork7/llama31_it_prm_2e6_bz32_1epoch_conversation'\n",
    "        self.dataset = './output/stage1_inference/test_50_0+5shots.jsonl'\n",
    "        self.output_dir = \"./output/stage1_classified\"\n",
    "        self.pred_thres1 = 0.9\n",
    "        self.pred_thres2 = 0.7\n",
    "        self.num_n = 128  # Reduced for faster notebook execution, original was 1024\n",
    "        self.num_test_sample = 50\n",
    "        self.model_type = \"Deepseek\"\n",
    "\n",
    "args = Args2()\n",
    "\n",
    "# Create output directory if it doesn't exist\n",
    "os.makedirs(args.output_dir, exist_ok=True)\n",
    "print(f\"Output directory: {args.output_dir}\")\n",
    "print(f\"Using model: {args.reward_name_or_path}\")\n",
    "print(f\"Using dataset: {args.dataset}\")\n",
    "print(f\"Number of N candidates per question: {args.num_n}\")\n",
    "print(f\"Number of test samples: {args.num_test_sample if args.num_test_sample != -1 else 'ALL'}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Process 0/1 using device: cuda\n",
      "Loading dataset ./output/stage1_inference/test_100_0+5shots.jsonl...\n",
      "Failed to load dataset directly. Trying as json: Couldn't find a dataset script at /scratch/gpfs/yh0068/slm-math/adaptmi/output/stage1_inference/test_100_0+5shots.jsonl/test_100_0+5shots.jsonl.py or any data file in the same directory.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating test split: 100 examples [00:00, 6195.52 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Selected 100 samples for processing.\n",
      "Loading reward model pwork7/llama31_it_prm_2e6_bz32_1epoch_conversation...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 11.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model and tokenizer loaded successfully.\n",
      "Process 0/1: processing 100 samples (indices 0-99).\n",
      "Starting worker on process 0...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:30<00:00,  3.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Worker finished on process 0. Results: 100 classifications, 100 processed samples.\n",
      "Main process gathering results...\n",
      "Total gathered classification results: 100\n",
      "Total gathered samples for saving: 100\n",
      "\n",
      "Metrics:\n",
      "\n",
      "To stay consistent with the paper, positive means model failure, negative means model success.\n",
      "\n",
      "TP: 54, FN: 1, FP: 20, TN: 25\n",
      "Total Predictions: 100\n",
      "Accuracy: 0.7900\n",
      "Precision: 0.7297\n",
      "Recall (Sensitivity): 0.9818\n",
      "F1 Score: 0.8372\n",
      "Specificity (captured failure case): 0.9818\n",
      "Metrics saved to ./output/stage1_classified/size100_thres1=0.9_thres2=0.7_metrics.json\n",
      "Processed data saved to ./output/stage1_classified/size100_thres1=0.9_thres2=0.7_save_data.jsonl\n",
      "Classification finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Cell 3: Main Logic (adapted from the original script's if __name__ == \"__main__\": block)\n",
    "\n",
    "# Initialize Accelerator\n",
    "accelerator = Accelerator()\n",
    "\n",
    "# Determine distributed training parameters from Accelerator, falling back to script's os.getenv approach for torch.dist\n",
    "# script_world_size for initializing torch.distributed backend\n",
    "# accelerator.num_processes for data sharding\n",
    "# accelerator.local_process_index for rank and device mapping\n",
    "ddp_world_size = int(os.getenv(\"WORLD_SIZE\", accelerator.num_processes))\n",
    "ddp_local_rank = int(os.getenv(\"LOCAL_RANK\", accelerator.local_process_index))\n",
    "\n",
    "# Use accelerator's properties for device to ensure consistency\n",
    "device = accelerator.device\n",
    "print(f\"Process {ddp_local_rank}/{ddp_world_size} using device: {device}\")\n",
    "\n",
    "# Load dataset\n",
    "print(f\"Loading dataset {args.dataset}...\")\n",
    "# The dataset must provide the 'question', 'code', 'pred' and 'score' fields expected by select_sample.\n",
    "# Local .json/.jsonl files (such as the Stage 1-1 output) are loaded with the generic 'json' builder;\n",
    "# anything else is treated as a dataset name on the Hugging Face Hub.\n",
    "if args.dataset.endswith(('.json', '.jsonl')):\n",
    "    ds = load_dataset('json', data_files={'test': args.dataset}, split='test')\n",
    "else:\n",
    "    ds = load_dataset(args.dataset, split='test')\n",
    "\n",
    "\n",
    "if args.num_test_sample == -1:\n",
    "    num_sample = len(ds)\n",
    "else:\n",
    "    num_sample = min(args.num_test_sample, len(ds))\n",
    "ds = ds.select(range(num_sample))\n",
    "print(f\"Selected {len(ds)} samples for processing.\")\n",
    "\n",
    "# Load model and tokenizer\n",
    "print(f\"Loading reward model {args.reward_name_or_path}...\")\n",
    "max_retries = 5  # avoid looping forever on a bad model path\n",
    "for attempt in range(1, max_retries + 1):\n",
    "    try:\n",
    "        tokenizer = AutoTokenizer.from_pretrained(args.reward_name_or_path)\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            args.reward_name_or_path, \n",
    "            torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32\n",
    "        ).to(device).eval() # Model to the device determined by Accelerator\n",
    "        print(\"Model and tokenizer loaded successfully.\")\n",
    "        break\n",
    "    except Exception as error:\n",
    "        print(f\"Attempt {attempt}/{max_retries}: error during model loading: {error}\")\n",
    "        if attempt == max_retries:\n",
    "            raise\n",
    "        print(\"Retrying in 2 seconds...\")\n",
    "        time.sleep(2)\n",
    "\n",
    "tokenizer.padding_side = \"right\"\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "if model.config.pad_token_id is None:\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else model.config.eos_token_id\n",
    "\n",
    "\n",
    "# Prepare data for the current process (data sharding)\n",
    "data_size = len(ds)\n",
    "# Use accelerator.num_processes for sharding\n",
    "share = (data_size + accelerator.num_processes - 1) // accelerator.num_processes # ceiling division\n",
    "start_idx = accelerator.process_index * share\n",
    "end_idx = min((accelerator.process_index + 1) * share, data_size)\n",
    "\n",
    "# Select the portion of the dataset for this process\n",
    "current_ds_slice = ds.select(np.arange(start_idx, end_idx))\n",
    "print(f\"Process {accelerator.process_index}/{accelerator.num_processes}: processing {len(current_ds_slice)} samples (indices {start_idx}-{end_idx-1}).\")\n",
    "\n",
    "data_for_worker = [sample for sample in current_ds_slice]\n",
    "\n",
    "# Call worker function\n",
    "# The 'local_rank' argument for worker and select_sample is used for .to(local_rank)\n",
    "# We pass `device` (which could be 'cpu' or 'cuda:X') to ensure tensors are moved correctly.\n",
    "print(f\"Starting worker on process {accelerator.process_index}...\")\n",
    "selected_data, new_data = worker(args, model, tokenizer, data_for_worker, device)\n",
    "print(f\"Worker finished on process {accelerator.process_index}. Results: {len(selected_data)} classifications, {len(new_data)} processed samples.\")\n",
    "\n",
    "# Distributed data gathering\n",
    "# If running in a distributed environment (e.g. via accelerate launch)\n",
    "if accelerator.num_processes > 1:\n",
    "    # Ensure MASTER_ADDR and MASTER_PORT are set if not already by launch environment for torch.dist\n",
    "    os.environ.setdefault('MASTER_ADDR', 'localhost')\n",
    "    os.environ.setdefault('MASTER_PORT', '12355') # Ensure this port is free or configurable\n",
    "    \n",
    "    if not dist.is_initialized():\n",
    "        backend = 'nccl' if device.type == 'cuda' else 'gloo'\n",
    "        print(f\"Process {accelerator.process_index}: Initializing torch.distributed with backend {backend}, world_size {ddp_world_size}, rank {accelerator.process_index}\")\n",
    "        dist.init_process_group(\n",
    "            backend=backend,\n",
    "            rank=accelerator.process_index,  # global rank (local rank is only unique per node)\n",
    "            world_size=ddp_world_size\n",
    "        )\n",
    "\n",
    "# Prepare data for gathering (all_gather_object expects a list of objects to be populated)\n",
    "data_to_send_from_this_rank = {\n",
    "    \"selected_data_payload\": selected_data,\n",
    "    \"new_data_payload\": new_data\n",
    "}\n",
    "\n",
    "if accelerator.num_processes > 1:\n",
    "    gathered_dictionaries_list = [None] * accelerator.num_processes\n",
    "    dist.all_gather_object(gathered_dictionaries_list, data_to_send_from_this_rank)\n",
    "else:\n",
    "    gathered_dictionaries_list = [data_to_send_from_this_rank]\n",
    "\n",
    "gathered_classification_results = []\n",
    "gathered_full_samples = []\n",
    "\n",
    "# Process gathered data (only on the main process)\n",
    "if accelerator.is_main_process:\n",
    "    print(\"Main process gathering results...\")\n",
    "    for i in range(accelerator.num_processes):\n",
    "        data_from_rank_i = gathered_dictionaries_list[i]\n",
    "        if data_from_rank_i:\n",
    "            gathered_classification_results.extend(data_from_rank_i[\"selected_data_payload\"])\n",
    "            gathered_full_samples.extend(data_from_rank_i[\"new_data_payload\"])\n",
    "    \n",
    "    print(f\"Total gathered classification results: {len(gathered_classification_results)}\")\n",
    "    print(f\"Total gathered samples for saving: {len(gathered_full_samples)}\")\n",
    "\n",
    "    # Calculate metrics\n",
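    "    counter = Counter(gathered_classification_results)\n",
    "    # The worker's labels use the opposite convention (positive = model success);\n",
    "    # swap them so that positive means model failure, as in the paper.\n",
    "    num_TN = counter[\"TP\"]\n",
    "    num_FP = counter[\"FN\"]\n",
    "    num_FN = counter[\"FP\"]\n",
    "    num_TP = counter[\"TN\"]\n",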
    "\n",
    "    precision = 0\n",
    "    recall = 0\n",
    "    f1 = 0\n",
    "    \n",
    "    if num_TP + num_FP > 0:\n",
    "        precision = num_TP / (num_TP + num_FP)\n",
    "    if num_TP + num_FN > 0:\n",
    "        recall = num_TP / (num_TP + num_FN)\n",
    "    if precision + recall > 0:\n",
    "        f1 = 2 * precision * recall / (precision + recall)\n",
    "\n",
    "    accuracy = 0\n",
    "    total_predictions = num_TN + num_TP + num_FN + num_FP\n",
    "    if total_predictions > 0:\n",
    "        accuracy = (num_TN + num_TP) / total_predictions\n",
    "    \n",
    "    specificity = 0 # True Negative Rate: share of model successes classified as easy\n",
    "    if num_TN + num_FP > 0:\n",
    "        specificity = num_TN / (num_TN + num_FP)\n",
    "\n",
    "    print(f\"\\nMetrics:\\n\")\n",
    "    print(\"To stay consistent with the paper, positive means model failure, negative means model success.\\n\")\n",
    "    print(f\"TP: {num_TP}, FN: {num_FN}, FP: {num_FP}, TN: {num_TN}\")\n",
    "    print(f\"Total Predictions: {total_predictions}\")\n",
    "    print(f\"Accuracy: {accuracy:.4f}\")\n",
    "    print(f\"Precision: {precision:.4f}\")\n",
    "    print(f\"Recall (Sensitivity): {recall:.4f}\")\n",
    "    print(f\"F1 Score: {f1:.4f}\")\n",
    "    print(f\"Specificity (successes classified as easy): {specificity:.4f}\")\n",
    "\n",
    "    metrics_summary = {\n",
    "        \"TP\": num_TP, \"FN\": num_FN, \"FP\": num_FP, \"TN\": num_TN,\n",
    "        \"total_predictions\": total_predictions,\n",
    "        \"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1_score\": f1,\n",
    "        \"specificity\": specificity,\n",
    "        \"num_test_samples_processed\": len(gathered_full_samples) # Should match args.num_test_sample if all samples were processed\n",
    "    }\n",
    "    \n",
    "    output_metrics_file = os.path.join(args.output_dir, f\"size{args.num_test_sample}_thres1={args.pred_thres1}_thres2={args.pred_thres2}_metrics.json\")\n",
    "    with open(output_metrics_file, 'w') as f:\n",
    "        json.dump(metrics_summary, f, indent=4, ensure_ascii=False)\n",
    "    print(f\"Metrics saved to {output_metrics_file}\")\n",
    "\n",
    "    output_data_file = os.path.join(args.output_dir, f\"size{args.num_test_sample}_thres1={args.pred_thres1}_thres2={args.pred_thres2}_save_data.jsonl\")\n",
    "    with open(output_data_file, 'w') as f:\n",
    "        for entry in gathered_full_samples:\n",
    "            f.write(json.dumps(entry) + \"\\n\")\n",
    "    print(f\"Processed data saved to {output_data_file}\")\n",
    "\n",
    "if accelerator.num_processes > 1 and dist.is_initialized():\n",
    "    dist.destroy_process_group()\n",
    "    print(f\"Process {ddp_local_rank}: Destroyed DDP process group.\")\n",
    "\n",
    "print(\"Classification finished.\")"
   ]
  },
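  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The polarity swap in the metrics cell above is easier to check in isolation. The following is a minimal sketch, assuming (as the swap and the printed note imply) that the raw `TP`/`FN`/`FP`/`TN` labels treat a correct model answer as positive, while the paper treats model failure (a _difficult_ question) as positive. `paper_metrics` is a hypothetical helper, not part of the repository.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def paper_metrics(raw_labels):\n",
    "    # Hypothetical helper: recompute Stage 1 metrics with positive = model failure.\n",
    "    # Raw labels have the opposite polarity, so TP/TN and FP/FN swap roles.\n",
    "    c = Counter(raw_labels)\n",
    "    tp, fn, fp, tn = c['TN'], c['FP'], c['FN'], c['TP']\n",
    "\n",
    "    def div(a, b):\n",
    "        # Return 0.0 instead of raising ZeroDivisionError on empty denominators\n",
    "        return a / b if b else 0.0\n",
    "\n",
    "    precision = div(tp, tp + fp)\n",
    "    recall = div(tp, tp + fn)  # share of actual failures flagged as difficult\n",
    "    return {\n",
    "        'TP': tp, 'FN': fn, 'FP': fp, 'TN': tn,\n",
    "        'accuracy': div(tp + tn, tp + tn + fp + fn),\n",
    "        'precision': precision,\n",
    "        'recall': recall,\n",
    "        'f1': div(2 * precision * recall, precision + recall),\n",
    "        'specificity': div(tn, tn + fp),  # share of actual successes labeled easy\n",
    "    }\n",
    "\n",
    "\n",
    "m = paper_metrics(['TP', 'TP', 'TN', 'FN', 'FP'])\n",
    "print(m['TP'], m['FN'], m['FP'], m['TN'])  # 1 1 1 2\n",
    "```\n",
    "\n",
    "Here the two raw `TP` labels (correct answers predicted correct) become paper true negatives, so in this sketch specificity is TN / (TN + FP), the share of successful answers that Stage 1 labels _easy_."
   ]
  },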
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🔔 **Stage 2: Skill-based selection of in-context examples**\n",
    "\n",
    "- AdaptMI uses skill-based _k_-shot examples for _difficult_ questions and fixed _k_-shot examples for _easy_ questions.\n",
    "- AdaptMI+ further restricts the skill-based examples to the skills that the model’s initial response lacks."
   ]
  },
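  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The selection rule above can be sketched as a small function. This is a hypothetical illustration, not the repository's implementation: `select_shots`, the `skill_bank` dictionary (skill name to example list), and the `question_skills` / `missing_skills` arguments are assumed names and data layouts, and examples are simply taken in skill order.\n",
    "\n",
    "```python\n",
    "def select_shots(is_easy, fixed_shots, skill_bank, question_skills, missing_skills=None, k=5):\n",
    "    # Hypothetical sketch of Stage 2 (names and data layout are assumptions).\n",
    "    # Easy question (Stage 1 label): reuse the fixed k-shot examples.\n",
    "    if is_easy:\n",
    "        return fixed_shots[:k]\n",
    "    # Difficult question: AdaptMI uses all skills of the question;\n",
    "    # AdaptMI+ passes only the skills missing from the initial response.\n",
    "    skills = question_skills if missing_skills is None else missing_skills\n",
    "    pool = []\n",
    "    for skill in skills:\n",
    "        for example in skill_bank.get(skill, []):\n",
    "            if example not in pool:  # avoid duplicates shared across skills\n",
    "                pool.append(example)\n",
    "    # Assumption: fall back to the fixed examples if no skill example is available\n",
    "    return pool[:k] if pool else fixed_shots[:k]\n",
    "\n",
    "\n",
    "bank = {'algebra': ['a1', 'a2', 'a3'], 'geometry': ['g1', 'g2', 'g3']}\n",
    "fixed = ['f1', 'f2', 'f3', 'f4', 'f5']\n",
    "print(select_shots(False, fixed, bank, ['algebra', 'geometry']))  # ['a1', 'a2', 'a3', 'g1', 'g2']\n",
    "print(select_shots(False, fixed, bank, ['algebra', 'geometry'], missing_skills=['geometry']))  # ['g1', 'g2', 'g3']\n",
    "```\n",
    "\n",
    "The sketch defaults to `k=5` to match `num_shots` and `num_skill_shots` in the `Args` cell below."
   ]
  },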
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Activate the `matheval` environment, then run the following cells:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: TOKENIZERS_PARALLELISM=true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yh0068/.conda/envs/evaltest/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "2025-05-25 20:48:52,832\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n"
     ]
    }
   ],
   "source": [
    "%env TOKENIZERS_PARALLELISM=true\n",
    "import sys\n",
    "import os\n",
    "\n",
    "current_dir = os.getcwd()\n",
    "evaluation_dir_path = os.path.join(current_dir, 'evaluation')\n",
    "\n",
    "if evaluation_dir_path not in sys.path:\n",
    "    sys.path.insert(0, evaluation_dir_path)\n",
    "\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "from types import SimpleNamespace\n",
    "from vllm import LLM, SamplingParams\n",
    "from transformers import AutoTokenizer, AutoConfig\n",
    "\n",
    "from evaluation.math_eval import *\n",
    "\n",
    "logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',\n",
    "                    datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args:\n",
    "    def __init__(self):\n",
    "        self.data_names = \"math-skill\"  # Example: \"gsm8k,math\"\n",
    "        self.data_dir = \"./evaluation/data\"\n",
    "        self.data_path = \"./output/stage1_classified/size50_thres1=0.9_thres2=0.7_save_data.jsonl\" # Default: None\n",
    "        self.model_name_or_path = \"models/Qwen2.5-1.5B-Instruct\" # Replace with your model\n",
    "        self.output_dir = \"./output/stage2_inference\"\n",
    "        self.prompt_type = \"qwen25-math-cot\"\n",
    "        self.split = \"test\"\n",
    "        self.num_test_sample = 50 # -1 for full data, set to a small number for testing\n",
    "        self.seed = 0\n",
    "        self.start = 0\n",
    "        self.end = -1 # -1 for all samples from start\n",
    "        self.temperature = 0.0\n",
    "        self.n_sampling = 1\n",
    "        self.top_p = 1.0\n",
    "        self.max_tokens_per_call = 1024 # Reduced for faster testing, adjust as needed\n",
    "        self.shuffle = True\n",
    "        self.use_vllm = False # Set to False if not using vLLM or for models not supported well by vLLM\n",
    "        self.save_outputs = True\n",
    "\n",
    "        # Ours\n",
    "        self.LLM_judge = False\n",
    "        self.PRM_judge = True\n",
    "        self.random_shots = False\n",
    "        self.llm_sol = False\n",
    "        \n",
    "        self.overwrite = True # Set to True to overwrite existing output files\n",
    "        self.use_safetensors = True # Recommended\n",
    "        self.num_shots = 5\n",
    "        self.num_skill_shots = 5\n",
    "        self.apply_chat_template = False # Set to True if your model expects chatml or similar\n",
    "        self.pipeline_parallel_size = 1\n",
    "        self.adapt_few_shot = False\n",
    "\n",
    "        # Auto-set top_p based on temperature for greedy sampling\n",
    "        self.top_p = (\n",
    "            1 if self.temperature == 0 else self.top_p\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed_value: int):\n",
    "    \"\"\"Sets the seed for reproducibility.\"\"\"\n",
    "    random.seed(seed_value)\n",
    "    np.random.seed(seed_value)\n",
    "    torch.manual_seed(seed_value)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed_value)\n",
    "    logger.info(f\"Set seed to {seed_value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_eval(args_obj: Args):  # args_obj is an instance of the Args class above\n",
    "    # Ensure CUDA_VISIBLE_DEVICES is set, or vLLM might default unexpectedly\n",
    "    if args_obj.use_vllm and not os.environ.get(\"CUDA_VISIBLE_DEVICES\"):\n",
    "        logger.warning(\"CUDA_VISIBLE_DEVICES is not set. vLLM might default to all available GPUs or the first one.\")\n",
    "        # os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" # Optionally set a default if none is provided\n",
    "\n",
    "    available_gpus = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"0\").split(\",\")\n",
    "    \n",
    "    llm_instance = None\n",
    "    tokenizer_instance = None\n",
    "\n",
    "    if args_obj.use_vllm:\n",
    "        logger.info(f\"Attempting to load model {args_obj.model_name_or_path} with vLLM.\")\n",
    "        # The rope_scaling logic from the original script can be complex and model-specific.\n",
    "        # vLLM often handles this automatically or via its own config.\n",
    "        # For simplicity, direct LLM initialization is shown here.\n",
    "        # If specific rope_scaling is needed, it should be passed to LLM constructor.\n",
    "        try:\n",
    "            llm_instance = LLM(\n",
    "                model=args_obj.model_name_or_path,\n",
    "                tensor_parallel_size=max(1, len(available_gpus) // args_obj.pipeline_parallel_size), # Ensure at least 1\n",
    "                pipeline_parallel_size=args_obj.pipeline_parallel_size,\n",
    "                dtype=\"bfloat16\" if torch.cuda.is_bf16_supported() else \"float16\",\n",
    "                trust_remote_code=True,\n",
    "                # max_model_len=max_tokens_per_call + some_buffer # Consider setting max_model_len\n",
    "            )\n",
    "            logger.info(f\"vLLM loaded model: {args_obj.model_name_or_path}\")\n",
    "        except Exception as e:\n",
    "            logger.error(f\"Failed to load model with vLLM: {e}\")\n",
    "            raise\n",
    "\n",
    "        if args_obj.apply_chat_template:\n",
    "            try:\n",
    "                tokenizer_instance = AutoTokenizer.from_pretrained(\n",
    "                    args_obj.model_name_or_path, trust_remote_code=True\n",
    "                )\n",
    "                logger.info(f\"Tokenizer loaded for chat template: {args_obj.model_name_or_path}\")\n",
    "            except Exception as e:\n",
    "                logger.error(f\"Failed to load tokenizer for chat template: {e}\")\n",
    "                # Depending on strictness, you might want to raise e or allow proceeding without chat template\n",
    "    else:\n",
    "        logger.info(f\"Attempting to load model {args_obj.model_name_or_path} with Hugging Face Transformers.\")\n",
    "        try:\n",
    "            tokenizer_instance, llm_instance, _ = load_model(args_obj.model_name_or_path, args_obj)\n",
    "            logger.info(f\"Hugging Face model and tokenizer loaded: {args_obj.model_name_or_path}\")\n",
    "        except Exception as e:\n",
    "            logger.error(f\"Failed to load model with Hugging Face: {e}\")\n",
    "            raise\n",
    "\n",
    "    # Infer & eval\n",
    "    data_list = args_obj.data_names.split(\",\")\n",
    "    results = []\n",
    "    for data_name_str in data_list:\n",
    "        data_name = data_name_str.strip()\n",
    "        if not data_name:\n",
    "            continue\n",
    "        logger.info(f\"\\nProcessing dataset: {data_name}\")\n",
    "        \n",
    "        dataset_result = main(llm_instance, tokenizer_instance, data_name, args_obj)\n",
    "        results.append(dataset_result)\n",
    "\n",
    "    if results:\n",
    "        summary_data_list = [name.strip() for name in data_list if name.strip()]\n",
    "        \n",
    "        if len(summary_data_list) > 1:\n",
    "            # Calculate average accuracy if multiple datasets were processed\n",
    "            valid_results_for_avg = [res for res in results if res and \"acc\" in res]\n",
    "            if valid_results_for_avg:\n",
    "                avg_acc = sum(res[\"acc\"] for res in valid_results_for_avg) / len(valid_results_for_avg)\n",
    "                results.append({\"acc\": avg_acc, \"data_name\": \"avg\"}) # Add data_name for clarity\n",
    "                summary_data_list.append(\"avg\")\n",
    "            else:\n",
    "                logger.warning(\"No valid results with 'acc' key found to calculate average.\")\n",
    "\n",
    "        logger.info(\"\\n\" + \"=\"*20 + \" Overall Summary \" + \"=\"*20)\n",
    "        \n",
    "        pad_width = max(len(name) for name in summary_data_list) if summary_data_list else 10\n",
    "\n",
    "        header_parts = []\n",
    "        score_parts = []\n",
    "        \n",
    "        res_idx = 0\n",
    "        for name in summary_data_list:\n",
    "            header_parts.append(name.ljust(pad_width))\n",
    "            current_res = None\n",
    "            if name == \"avg\" and results[-1].get(\"data_name\") == \"avg\": # Check if last result is avg\n",
    "                current_res = results[-1]\n",
    "            elif res_idx < len(results) and results[res_idx].get(\"data_name\", data_list[res_idx].strip()) == name: # Check by original name\n",
    "                current_res = results[res_idx]\n",
    "                res_idx += 1\n",
    "            elif res_idx < len(results): # Fallback if data_name not in result but order might match\n",
    "                current_res = results[res_idx]\n",
    "                logger.warning(f\"Result for {name} matched by order, not explicit data_name key in result dict.\")\n",
    "                res_idx += 1\n",
    "\n",
    "            if current_res and \"acc\" in current_res:\n",
    "                score_parts.append(f\"{current_res['acc']:.1f}\".ljust(pad_width))\n",
    "            else:\n",
    "                score_parts.append(\"N/A\".ljust(pad_width))\n",
    "                logger.warning(f\"Accuracy not found for dataset: {name}\")\n",
    "        \n",
    "        if header_parts:\n",
    "            final_header = \"\\t\".join(header_parts)\n",
    "            final_scores = \"\\t\".join(score_parts)\n",
    "            print(\"\\nResults Summary:\") # Print to console for easy viewing\n",
    "            print(final_header)\n",
    "            print(final_scores)\n",
    "            logger.info(\"Final results summary (also printed above):\")\n",
    "            logger.info(f\"Datasets: {final_header}\")\n",
    "            logger.info(f\"Accuracy: {final_scores}\")\n",
    "        else:\n",
    "            logger.info(\"No results to display in summary table.\")\n",
    "    else:\n",
    "        logger.info(\"No datasets processed or no results returned.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "05/25/2025 20:49:14 - INFO - evaluation.math_eval - loaded model with 1543714304 parameters\n",
      "05/25/2025 20:49:15 - INFO - evaluation.math_eval - setting tokenizer.model_max_length to 1024\n",
      "Generating train split: 100 examples [00:00, 4729.39 examples/s]\n",
      "Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 126.03ba/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "data: math-skill  ,remain samples: 100\n",
      "{'idx': 1139, 'question': 'The area of triangle $ABC$ is equal to $a^2 - (b - c)^2,$ where $a,$ $b,$ and $c$ are the sides of triangle $ABC,$ as usual.  Compute $\\\\tan A.$', 'gt_cot': 'The area of triangle $ABC$ is given by\\n\\\\[\\\\frac{1}{2} bc \\\\sin A.\\\\]Hence,\\n\\\\[\\\\frac{1}{2} bc \\\\sin A = a^2 - (b - c)^2 = a^2 - b^2 + 2bc - c^2.\\\\]By the Law of Cosines, $b^2 + c^2 - 2bc \\\\cos A = a^2,$ so\\n\\\\[\\\\frac{1}{2} bc \\\\sin A = 2bc - 2bc \\\\cos A.\\\\]This simplifies to $\\\\sin A = 4 - 4 \\\\cos A.$  Squaring both sides, we get\\n\\\\[\\\\sin^2 A = 16 - 32 \\\\cos A + 16 \\\\cos^2 A,\\\\]so $1 - \\\\cos^2 A = 16 - 32 \\\\cos A + 16 \\\\cos^2 A.$  This simplifies to\\n\\\\[17 \\\\cos^2 A - 32 \\\\cos A + 15 = 0.\\\\]This factors as $(\\\\cos A - 1)(17 \\\\cos A - 15) = 0.$  Since $\\\\cos A$ cannot be equal to 1, $\\\\cos A = \\\\frac{15}{17}.$\\n\\nThen $\\\\sin A = 4 - 4 \\\\cos A = \\\\frac{8}{17},$ so\\n\\\\[\\\\tan A = \\\\frac{\\\\sin A}{\\\\cos A} = \\\\boxed{\\\\frac{8}{15}}.\\\\]', 'gt': '\\\\frac{8}{15}', 'prompt': \"<|im_start|>system\\nPlease reason step by step, and put your final answer within \\\\boxed{}.<|im_end|>\\n<|im_start|>user\\nQuestion: Kevin Kangaroo begins hopping on a number line at 0. He wants to get to 1, but he can hop only $\\\\frac{1}{3}$ of the distance. Each hop tires him out so that he continues to hop $\\\\frac{1}{3}$ of the remaining distance. How far has he hopped after five hops? 
Express your answer as a common fraction.\\n\\nResponse: Let's think step by step\\n\\nKevin hops $1/3$ of the remaining distance with every hop.\\n\\nHis first hop takes $1/3$ closer.\\n\\nFor his second hop, he has $2/3$ left to travel, so he hops forward $(2/3)(1/3)$.\\n\\nFor his third hop, he has $(2/3)^2$ left to travel, so he hops forward $(2/3)^2(1/3)$.\\n\\nIn general, Kevin hops forward $(2/3)^{k-1}(1/3)$ on his $k$th hop.\\n\\nWe want to find how far he has hopped after five hops.\\n\\nThis is a finite geometric series with first term $1/3$, common ratio $2/3$, and five terms.\\n\\nThus, Kevin has hopped $\\\\frac{\\\\frac{1}{3}\\\\left(1-\\\\left(\\\\frac{2}{3}\\\\right)^5\\\\right)}{1-\\\\frac{2}{3}} = \\\\boxed{\\\\frac{211}{243}}$.\\n\\nThe answer is \\\\frac{211}{243}}\\n\\nQuestion: What is the area of the region defined by the equation $x^2+y^2 - 7 = 4y-14x+3$?\\n\\nResponse: Let's think step by step\\n\\nWe rewrite the equation as $x^2 + 14x + y^2 - 4y = 10$ and then complete the square,\\n\\nresulting in  $(x+7)^2-49 + (y-2)^2-4=10$,\\n\\nor $(x+7)^2+(y-2)^2=63$.\\n\\nThis is the equation of a circle with center $(-7, 2)$ and radius $\\\\sqrt{63},$\\n\\nso the area of this region is $\\\\pi r^2 = \\\\boxed{63\\\\pi}$.\\n\\nThe answer is 63\\\\pi\\n\\nQuestion: If $x^2+y^2=1$, what is the largest possible value of $|x|+|y|$?\\n\\nResponse: Let's think step by step\\n\\nIf $(x,y)$ lies on the circle,\\n\\nso does $(x,-y),$ $(-x,-y),$ and $(-x,-y),$ (which all give the same value of $|x| + |y|$),\\n\\nso we can assume that $x \\\\ge 0$ and $y \\\\ge 0.$\\n\\nThen $|x| + |y| = x + y.$  Squaring, we get\\n\\n\\\\[(x + y)^2 = x^2 + 2xy + y^2 = 1 + 2xy.\\\\]\\n\\nNote that $(x - y)^2 \\\\ge 0.$\\n\\nExpanding, we get $x^2 - 2xy + y^2 \\\\ge 0,$ so $2xy \\\\le x^2 + y^2 = 1.$\\n\\nHence,\\\\[1 + 2xy \\\\le 2,\\\\]which means $x + y \\\\le \\\\sqrt{2}.$\\n\\nEquality occurs when $x = y = \\\\frac{1}{\\\\sqrt{2}},$\\n\\nso the maximum value of $|x| + |y|$ 
is $\\\\boxed{\\\\sqrt{2}}.$\\n\\nThe answer is \\\\sqrt{2}\\n\\nQuestion: If $f(x)=\\\\frac{ax+b}{cx+d}, abcd\\\\n\\not=0$ and $f(f(x))=x$ for all $x$ in the domain of $f$, what is the value of $a+d$?\\n\\nResponse: Let's think step by step\\n\\nThe condition $f(f(x))$ means that $f$ is the inverse of itself,\\n\\nso its graph is symmetrical about the line $y = x$.\\n\\nWith a rational function of this form, we will have two asymptotes:\\n\\na vertical one at $x=-d/c$ if $cx+d$ does not divide $ax+b$,\\n\\nand a horizontal one at $y=a/c$,\\n\\nif we take the limit of $f(x)$ as $x$ goes to $\\\\pm\\\\infty$.\\n\\nIn order for $f$ to be its own inverse, the intersection of the asymptotes must lie on the line $y=x$\\n\\nso that it and its asymptotes reflect onto themselves.\\n\\nThis means that $-d/c=a/c$,\\n\\nand therefore $-d=a$ and $a+d=\\\\boxed{0}$.\\n\\nThe answer is 0\\n\\nQuestion: Expand $(2z^2 + 5z - 6)(3z^3 - 2z + 1)$.\\n\\nResponse: Let's think step by step\\n\\n$$\\\\begin{array}{crrrrrrr}\\n\\n& & & 3z^3 & & -2z & + 1 & \\\\\\\\\\n\\n\\\\times & & & & 2z^2 & +5z & -6 \\\\\\\\\\n\\n\\\\cline{1-7}\\\\rule{0pt}{0.17in}\\n\\n& & & -18z^3 & & +12z & -6 & \\\\\\\\\\n\\n& & +15z^4 & & -10z^2 & +5z & & \\\\\\\\\\n\\n+ & 6z^5 & & -4z^3 & +2z^2 & & & \\\\\\\\\\n\\n\\\\cline{1-7}\\\\rule{0pt}{0.17in}\\n\\n& 6z^5 & +15z^4 & -22z^3 & - 8z^2 &+17z & -6 &\\n\\n\\\\end{array}$$\\n\\nThe answer is 6z^5+15z^4-22z^3-8z^2+17z-6\\n\\nQuestion: The area of triangle $ABC$ is equal to $a^2 - (b - c)^2,$ where $a,$ $b,$ and $c$ are the sides of triangle $ABC,$ as usual.  
Compute $\\\\tan A.$<|im_end|>\\n<|im_start|>assistant\\n\", 'level': 'Level 3', 'solution': 'The area of triangle $ABC$ is given by\\n\\\\[\\\\frac{1}{2} bc \\\\sin A.\\\\]Hence,\\n\\\\[\\\\frac{1}{2} bc \\\\sin A = a^2 - (b - c)^2 = a^2 - b^2 + 2bc - c^2.\\\\]By the Law of Cosines, $b^2 + c^2 - 2bc \\\\cos A = a^2,$ so\\n\\\\[\\\\frac{1}{2} bc \\\\sin A = 2bc - 2bc \\\\cos A.\\\\]This simplifies to $\\\\sin A = 4 - 4 \\\\cos A.$  Squaring both sides, we get\\n\\\\[\\\\sin^2 A = 16 - 32 \\\\cos A + 16 \\\\cos^2 A,\\\\]so $1 - \\\\cos^2 A = 16 - 32 \\\\cos A + 16 \\\\cos^2 A.$  This simplifies to\\n\\\\[17 \\\\cos^2 A - 32 \\\\cos A + 15 = 0.\\\\]This factors as $(\\\\cos A - 1)(17 \\\\cos A - 15) = 0.$  Since $\\\\cos A$ cannot be equal to 1, $\\\\cos A = \\\\frac{15}{17}.$\\n\\nThen $\\\\sin A = 4 - 4 \\\\cos A = \\\\frac{8}{17},$ so\\n\\\\[\\\\tan A = \\\\frac{\\\\sin A}{\\\\cos A} = \\\\boxed{\\\\frac{8}{15}}.\\\\]', 'answer': '\\\\frac{8}{15}', 'subject': 'Precalculus', 'unique_id': 'test/precalculus/1082.json', 'correct_flagged': False, 'code': ['To solve for \\\\(\\\\tan A\\\\) given that the area of triangle \\\\(ABC\\\\) is \\\\(a^2 - (b - c)^2\\\\), we start by using the standard formula for the area of a triangle involving sine:\\n\\n\\\\[\\n\\\\text{Area} = \\\\frac{1}{2}bc \\\\sin A\\n\\\\]\\n\\nGiven:\\n\\\\[\\n\\\\text{Area} = a^2 - (b - c)^2\\n\\\\]\\n\\nFirst, simplify the expression inside the parentheses:\\n\\\\[\\n(b - c)^2 = b^2 - 2bc + c^2\\n\\\\]\\nThus,\\n\\\\[\\na^2 - (b - c)^2 = a^2 - (b^2 - 2bc + c^2) = a^2 - b^2 + 2bc - c^2\\n\\\\]\\n\\nSo, the area becomes:\\n\\\\[\\n\\\\text{Area} = a^2 - b^2 + 2bc - c^2\\n\\\\]\\n\\nUsing the standard area formula again:\\n\\\\[\\n\\\\text{Area} = \\\\frac{1}{2} bc \\\\sin A\\n\\\\]\\n\\nEquating the two expressions for the area:\\n\\\\[\\n\\\\frac{1}{2} bc \\\\sin A = a^2 - b^2 + 2bc - c^2\\n\\\\]\\n\\nMultiply both sides by 2 to clear the fraction:\\n\\\\[\\nbc \\\\sin A = 2(a^2 - b^2 + 2bc - 
c^2)\\n\\\\]\\n\\nNext, we need to express \\\\(a^2 - b^2\\\\) in a different form. Recall the identity:\\n\\\\[\\na^2 - b^2 = (a + b)(a - b)\\n\\\\]\\n\\nSubstitute this into our equation:\\n\\\\[\\nbc \\\\sin A = 2((a + b)(a - b) + 2bc - c^2)\\n\\\\]\\n\\nDistribute the 2:\\n\\\\[\\nbc \\\\sin A = 2a^2 - 2b^2 + 2ab + 4bc - 2c^2\\n\\\\]\\n\\nRearrange terms:\\n\\\\[\\nbc \\\\sin A = 2a^2 + 2ab + 4bc - 2b^2 - 2c^2\\n\\\\]\\n\\nFactor out common terms:\\n\\\\[\\nbc \\\\sin A = 2(a^2 + ab + 2c^2 - b^2 - c^2)\\n\\\\]\\n\\nSimplify further:\\n\\\\[\\nbc \\\\sin A = 2(a^2 + ab + c^2 - b^2 - c^2)\\n\\\\]\\n\\nNotice that:\\n\\\\[\\na^2 + c^2 - b^2 = 2ac \\\\cos B\\n\\\\]\\n\\nThus:\\n\\\\[\\nbc \\\\sin A = 2(2ac \\\\cos B)\\n\\\\]\\n\\nDivide both sides by \\\\(2bc\\\\):\\n\\\\[\\n\\\\frac{\\\\sin A}{2} = \\\\cos B\\n\\\\]\\n\\nRecall the trigonometric identity for complementary angles:\\n\\\\[\\n\\\\sin(A + B) = \\\\sin A \\\\cos B + \\\\cos A \\\\sin B\\n\\\\]\\n\\nSince \\\\(A + B = 90^\\\\circ\\\\):\\n\\\\[\\n\\\\sin(A + B) = \\\\sin 90^\\\\circ = 1\\n\\\\]\\n\\nTherefore:\\n\\\\[\\n\\\\sin A \\\\cos B + \\\\cos A \\\\sin B = 1\\n\\\\]\\n\\nGiven \\\\(\\\\cos B = \\\\frac{\\\\sin A}{2}\\\\):\\n\\\\[\\n\\\\sin A \\\\cdot \\\\frac{\\\\sin A}{2} + \\\\cos A \\\\sin B = 1\\n\\\\]\\n\\nLet \\\\(\\\\tan A = t\\\\). 
Then:\\n\\\\[\\nt \\\\cdot \\\\frac{t}{2} + \\\\cos A \\\\sin B = 1\\n\\\\]\\n\\nSince \\\\(\\\\cos A = \\\\frac{1}{\\\\sqrt{1+t^2}}\\\\) and \\\\(\\\\sin B = \\\\frac{b}{c}\\\\), substitute back:\\n\\\\[\\n\\\\frac{t^2}{2} + \\\\frac{b}{c} \\\\cdot \\\\frac{b}{c} = 1\\n\\\\]\\n\\nFinally:\\n\\\\[\\nt = \\\\boxed{1}\\n\\\\]'], 'pred': ['1'], 'report': [None], 'score': [False], 'prm_pred': False, 'step_scores': [[0.5, 0.7549149990081787, 0.7549149990081787, 0.7772998809814453, 0.8354835510253906, 0.851952850818634, 0.8807970285415649, 0.851952850818634, 0.7549149990081787, 0.7981867790222168, 0.8354835510253906, 0.851952850818634, 0.7981867790222168, 0.851952850818634, 0.8354835510253906, 0.8670357465744019, 0.851952850818634, 0.8354835510253906, 0.9046505093574524, 0.9149009585380554, 0.8807970285415649, 0.8354835510253906, 0.7772998809814453, 0.9149009585380554, 0.957912266254425]], 'best_ans_idx': 0, 'best_ans': 'To solve for \\\\(\\\\tan A\\\\) given that the area of triangle \\\\(ABC\\\\) is \\\\(a^2 - (b - c)^2\\\\), we start by using the standard formula for the area of a triangle involving sine:\\n\\n\\\\[\\n\\\\text{Area} = \\\\frac{1}{2}bc \\\\sin A\\n\\\\]\\n\\nGiven:\\n\\\\[\\n\\\\text{Area} = a^2 - (b - c)^2\\n\\\\]\\n\\nFirst, simplify the expression inside the parentheses:\\n\\\\[\\n(b - c)^2 = b^2 - 2bc + c^2\\n\\\\]\\nThus,\\n\\\\[\\na^2 - (b - c)^2 = a^2 - (b^2 - 2bc + c^2) = a^2 - b^2 + 2bc - c^2\\n\\\\]\\n\\nSo, the area becomes:\\n\\\\[\\n\\\\text{Area} = a^2 - b^2 + 2bc - c^2\\n\\\\]\\n\\nUsing the standard area formula again:\\n\\\\[\\n\\\\text{Area} = \\\\frac{1}{2} bc \\\\sin A\\n\\\\]\\n\\nEquating the two expressions for the area:\\n\\\\[\\n\\\\frac{1}{2} bc \\\\sin A = a^2 - b^2 + 2bc - c^2\\n\\\\]\\n\\nMultiply both sides by 2 to clear the fraction:\\n\\\\[\\nbc \\\\sin A = 2(a^2 - b^2 + 2bc - c^2)\\n\\\\]\\n\\nNext, we need to express \\\\(a^2 - b^2\\\\) in a different form. 
Recall the identity:\\n\\\\[\\na^2 - b^2 = (a + b)(a - b)\\n\\\\]\\n\\nSubstitute this into our equation:\\n\\\\[\\nbc \\\\sin A = 2((a + b)(a - b) + 2bc - c^2)\\n\\\\]\\n\\nDistribute the 2:\\n\\\\[\\nbc \\\\sin A = 2a^2 - 2b^2 + 2ab + 4bc - 2c^2\\n\\\\]\\n\\nRearrange terms:\\n\\\\[\\nbc \\\\sin A = 2a^2 + 2ab + 4bc - 2b^2 - 2c^2\\n\\\\]\\n\\nFactor out common terms:\\n\\\\[\\nbc \\\\sin A = 2(a^2 + ab + 2c^2 - b^2 - c^2)\\n\\\\]\\n\\nSimplify further:\\n\\\\[\\nbc \\\\sin A = 2(a^2 + ab + c^2 - b^2 - c^2)\\n\\\\]\\n\\nNotice that:\\n\\\\[\\na^2 + c^2 - b^2 = 2ac \\\\cos B\\n\\\\]\\n\\nThus:\\n\\\\[\\nbc \\\\sin A = 2(2ac \\\\cos B)\\n\\\\]\\n\\nDivide both sides by \\\\(2bc\\\\):\\n\\\\[\\n\\\\frac{\\\\sin A}{2} = \\\\cos B\\n\\\\]\\n\\nRecall the trigonometric identity for complementary angles:\\n\\\\[\\n\\\\sin(A + B) = \\\\sin A \\\\cos B + \\\\cos A \\\\sin B\\n\\\\]\\n\\nSince \\\\(A + B = 90^\\\\circ\\\\):\\n\\\\[\\n\\\\sin(A + B) = \\\\sin 90^\\\\circ = 1\\n\\\\]\\n\\nTherefore:\\n\\\\[\\n\\\\sin A \\\\cos B + \\\\cos A \\\\sin B = 1\\n\\\\]\\n\\nGiven \\\\(\\\\cos B = \\\\frac{\\\\sin A}{2}\\\\):\\n\\\\[\\n\\\\sin A \\\\cdot \\\\frac{\\\\sin A}{2} + \\\\cos A \\\\sin B = 1\\n\\\\]\\n\\nLet \\\\(\\\\tan A = t\\\\). Then:\\n\\\\[\\nt \\\\cdot \\\\frac{t}{2} + \\\\cos A \\\\sin B = 1\\n\\\\]\\n\\nSince \\\\(\\\\cos A = \\\\frac{1}{\\\\sqrt{1+t^2}}\\\\) and \\\\(\\\\sin B = \\\\frac{b}{c}\\\\), substitute back:\\n\\\\[\\n\\\\frac{t^2}{2} + \\\\frac{b}{c} \\\\cdot \\\\frac{b}{c} = 1\\n\\\\]\\n\\nFinally:\\n\\\\[\\nt = \\\\boxed{1}\\n\\\\]', 'worst_step_idx': 0, 'worst_step': 'To solve for \\\\(\\\\tan A\\\\) given that the area of triangle \\\\(ABC\\\\) is \\\\(a^2 - (b - c)^2\\\\), we start by using the standard formula for the area of a triangle involving sine:'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 100/100 [00:01<00:00, 61.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------- Epoch 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating Completions: 100%|██████████| 74/74 [07:28<00:00,  6.06s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------- Epoch 1\n",
      "Unsolved samples: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|██████████| 100/100 [00:01<00:00, 65.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'num_samples': 100, 'num_scores': 100, 'timeout_samples': 0, 'empty_samples': 0, 'acc': 44.0}\n",
      "Saved to ./output/stage2_inference/test_100_5+0shots.jsonl\n",
      "\n",
      "Results Summary:\n",
      "math-skill\n",
      "44.0      \n"
     ]
    }
   ],
   "source": [
    "# --- AdaptMI Evaluation ---\n",
    "args = Args()\n",
    "\n",
    "set_seed(args.seed)\n",
    "run_eval(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ✨ Evaluation Summary\n",
    "Run this cell to summarize AdaptMI performance: overall accuracy before and after AdaptMI, and accuracy on the questions labeled _difficult_ in Stage 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "📎 Initial accuracy: 40.00\n",
      "✨ AdaptMI accuracy: 42.00\n",
      "Initial accuracy on difficult questions: 19.44\n",
      "Final accuracy on difficult questions: 22.22 🚀\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "def analyze(jsonl_file):\n",
    "    # Compare accuracy before (initial_score) and after (score) AdaptMI,\n",
    "    # overall and split by the Stage 1 label (prm_pred=True means easy).\n",
    "    total, correct_before, correct_after = 0, 0, 0\n",
    "    easy_all, diff_all, easy_correct_before, easy_correct_after, diff_correct_before, diff_correct_after = 0, 0, 0, 0, 0, 0\n",
    "\n",
    "    final_data = {}\n",
    "    with open(jsonl_file, 'r', encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            data = json.loads(line)\n",
    "            final_data[data[\"idx\"]] = data\n",
    "            total += 1\n",
    "\n",
    "            if True in data[\"initial_score\"]:\n",
    "                correct_before += 1\n",
    "            if True in data[\"score\"]:\n",
    "                correct_after += 1\n",
    "            \n",
    "            if data[\"prm_pred\"]: # easy\n",
    "                easy_all += 1\n",
    "                if True in data[\"initial_score\"]:\n",
    "                    easy_correct_before += 1\n",
    "                if True in data[\"score\"]:\n",
    "                    easy_correct_after += 1\n",
    "            else: # difficult\n",
    "                diff_all += 1\n",
    "                if True in data[\"initial_score\"]:\n",
    "                    diff_correct_before += 1\n",
    "                if True in data[\"score\"]:\n",
    "                    diff_correct_after += 1\n",
    "\n",
    "    # Guard against empty splits (e.g. no question labeled easy) to avoid ZeroDivisionError\n",
    "    initial_accuracy = correct_before / total if total else 0.0\n",
    "    final_accuracy = correct_after / total if total else 0.0\n",
    "    print(f\"📎 Initial accuracy: {(100*initial_accuracy):.2f}\")\n",
    "    print(f\"✨ AdaptMI accuracy: {(100*final_accuracy):.2f}\")\n",
    "    \n",
    "    acc_easy_before = easy_correct_before / easy_all if easy_all else 0.0\n",
    "    acc_easy_after = easy_correct_after / easy_all if easy_all else 0.0\n",
    "    acc_diff_before = diff_correct_before / diff_all if diff_all else 0.0\n",
    "    acc_diff_after = diff_correct_after / diff_all if diff_all else 0.0\n",
    "    print(f\"Initial accuracy on difficult questions: {(100*acc_diff_before):.2f}\\nFinal accuracy on difficult questions: {(100*acc_diff_after):.2f} 🚀\")\n",
    "\n",
    "analyze(\"output/stage2_inference/test_50_5+0shots.jsonl\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.17 ('matheval')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  },
  "vscode": {
   "interpreter": {
    "hash": "0993e674be3504149a46e24a80dacbf2567ac699f8ef2e2458766e310c002ab0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
