{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 1-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: TOKENIZERS_PARALLELISM=true\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yh0068/.conda/envs/matheval/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "2025-05-23 14:57:40,840\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n"
     ]
    }
   ],
   "source": [
    "%env TOKENIZERS_PARALLELISM=true\n",
    "import sys\n",
    "import os\n",
    "\n",
    "current_dir = os.getcwd()\n",
    "evaluation_dir_path = os.path.join(current_dir, 'evaluation')\n",
    "\n",
    "if evaluation_dir_path not in sys.path:\n",
    "    sys.path.insert(0, evaluation_dir_path)\n",
    "\n",
    "import os\n",
    "import json\n",
    "import logging\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "from types import SimpleNamespace\n",
    "from vllm import LLM, SamplingParams\n",
    "from transformers import AutoTokenizer, AutoConfig\n",
    "\n",
    "from evaluation.math_eval import *\n",
    "\n",
    "logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',\n",
    "                    datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Args:\n",
    "    \"\"\"Attribute bag holding the settings for the Stage 1-1 evaluation run.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # --- Dataset selection ---\n",
    "        self.data_names = \"math\"  # comma-separated names, e.g. \"gsm8k,math\"\n",
    "        self.data_dir = \"./evaluation/data\"\n",
    "        self.data_path = None  # no explicit path by default\n",
    "        self.split = \"test\"\n",
    "        self.num_test_sample = 100  # -1 for the full data; keep small for quick tests\n",
    "        self.start = 0\n",
    "        self.end = -1  # -1 for all samples from start\n",
    "        self.shuffle = True\n",
    "        self.seed = 0\n",
    "\n",
    "        # --- Model / backend ---\n",
    "        self.model_name_or_path = \"models/Qwen2.5-1.5B-Instruct\"  # replace with your model\n",
    "        self.use_vllm = False  # False selects the Hugging Face loading path in run_eval\n",
    "        self.use_safetensors = True  # recommended\n",
    "        self.pipeline_parallel_size = 1\n",
    "        self.apply_chat_template = False  # enable if the model expects chatml or similar\n",
    "\n",
    "        # --- Prompting ---\n",
    "        self.prompt_type = \"qwen25-math-cot\"\n",
    "        self.num_shots = 5\n",
    "        self.num_skill_shots = 0\n",
    "        self.adapt_few_shot = False\n",
    "\n",
    "        # --- Decoding ---\n",
    "        self.temperature = 0.0\n",
    "        self.n_sampling = 1\n",
    "        self.top_p = 1.0\n",
    "        self.max_tokens_per_call = 1024  # reduced for faster testing, adjust as needed\n",
    "\n",
    "        # --- Project-specific switches (\"Ours\") ---\n",
    "        self.LLM_judge = False\n",
    "        self.PRM_judge = False\n",
    "        self.random_shots = False\n",
    "        self.llm_sol = False\n",
    "\n",
    "        # --- Output handling ---\n",
    "        self.output_dir = \"./output/stage1_inference\"\n",
    "        self.save_outputs = True\n",
    "        self.overwrite = True  # overwrite existing output files\n",
    "\n",
    "        # Greedy sampling (temperature 0) pins top_p to 1; otherwise top_p is kept.\n",
    "        if self.temperature == 0:\n",
    "            self.top_p = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed_value: int):\n",
    "    \"\"\"Seed Python's `random`, NumPy and PyTorch (and all CUDA devices when available).\"\"\"\n",
    "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
    "        seed_fn(seed_value)\n",
    "    # The CUDA generators are only touched when a GPU is actually available.\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed_value)\n",
    "    logger.info(f\"Set seed to {seed_value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _format_summary_rows(named_results):\n",
    "    \"\"\"Render the aligned header / score lines of the summary table.\n",
    "\n",
    "    Args:\n",
    "        named_results: non-empty list of ``(name, result)`` pairs. A result that is\n",
    "            falsy or has no ``\"acc\"`` entry is shown as ``N/A`` (with a warning).\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(header_line, score_line)`` of tab-joined, left-justified columns.\n",
    "    \"\"\"\n",
    "    pad_width = max(len(name) for name, _ in named_results)\n",
    "    header_parts = []\n",
    "    score_parts = []\n",
    "    for name, result in named_results:\n",
    "        header_parts.append(name.ljust(pad_width))\n",
    "        if result and \"acc\" in result:\n",
    "            score_parts.append(f\"{result['acc']:.1f}\".ljust(pad_width))\n",
    "        else:\n",
    "            score_parts.append(\"N/A\".ljust(pad_width))\n",
    "            logger.warning(f\"Accuracy not found for dataset: {name}\")\n",
    "    return \"\\t\".join(header_parts), \"\\t\".join(score_parts)\n",
    "\n",
    "\n",
    "def run_eval(args_obj: Args):\n",
    "    \"\"\"Load the model, evaluate every requested dataset and print an accuracy table.\n",
    "\n",
    "    Args:\n",
    "        args_obj: settings object. This function reads ``use_vllm``,\n",
    "            ``model_name_or_path``, ``pipeline_parallel_size``,\n",
    "            ``apply_chat_template`` and ``data_names``; the object is also passed\n",
    "            on to ``load_model`` and ``main``.\n",
    "\n",
    "    Raises:\n",
    "        Exception: whatever the model loader raised, re-raised after logging.\n",
    "    \"\"\"\n",
    "    llm_instance = None\n",
    "    tokenizer_instance = None\n",
    "\n",
    "    if args_obj.use_vllm:\n",
    "        # The tensor-parallel size below is derived from CUDA_VISIBLE_DEVICES,\n",
    "        # so warn when the variable is missing (a single entry is then assumed).\n",
    "        if not os.environ.get(\"CUDA_VISIBLE_DEVICES\"):\n",
    "            logger.warning(\"CUDA_VISIBLE_DEVICES is not set. vLLM might default to all available GPUs or the first one.\")\n",
    "        available_gpus = os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"0\").split(\",\")\n",
    "\n",
    "        logger.info(f\"Attempting to load model {args_obj.model_name_or_path} with vLLM.\")\n",
    "        try:\n",
    "            llm_instance = LLM(\n",
    "                model=args_obj.model_name_or_path,\n",
    "                # Visible GPUs divided by the pipeline stages, never below 1.\n",
    "                tensor_parallel_size=max(1, len(available_gpus) // args_obj.pipeline_parallel_size),\n",
    "                pipeline_parallel_size=args_obj.pipeline_parallel_size,\n",
    "                dtype=\"bfloat16\" if torch.cuda.is_bf16_supported() else \"float16\",\n",
    "                trust_remote_code=True,\n",
    "            )\n",
    "            logger.info(f\"vLLM loaded model: {args_obj.model_name_or_path}\")\n",
    "        except Exception as e:\n",
    "            logger.error(f\"Failed to load model with vLLM: {e}\")\n",
    "            raise\n",
    "\n",
    "        if args_obj.apply_chat_template:\n",
    "            try:\n",
    "                tokenizer_instance = AutoTokenizer.from_pretrained(\n",
    "                    args_obj.model_name_or_path, trust_remote_code=True\n",
    "                )\n",
    "                logger.info(f\"Tokenizer loaded for chat template: {args_obj.model_name_or_path}\")\n",
    "            except Exception as e:\n",
    "                # Best effort: evaluation continues with tokenizer_instance left as None.\n",
    "                logger.error(f\"Failed to load tokenizer for chat template: {e}\")\n",
    "    else:\n",
    "        logger.info(f\"Attempting to load model {args_obj.model_name_or_path} with Hugging Face Transformers.\")\n",
    "        try:\n",
    "            tokenizer_instance, llm_instance, _ = load_model(args_obj.model_name_or_path, args_obj)\n",
    "            logger.info(f\"Hugging Face model and tokenizer loaded: {args_obj.model_name_or_path}\")\n",
    "        except Exception as e:\n",
    "            logger.error(f\"Failed to load model with Hugging Face: {e}\")\n",
    "            raise\n",
    "\n",
    "    # Infer & eval. Each result is stored next to the dataset name it belongs to,\n",
    "    # so the summary never has to re-align names and results by index.\n",
    "    named_results = []\n",
    "    for raw_name in args_obj.data_names.split(\",\"):\n",
    "        data_name = raw_name.strip()\n",
    "        if not data_name:\n",
    "            continue  # ignore blank entries such as a trailing comma\n",
    "        logger.info(f\"\\nProcessing dataset: {data_name}\")\n",
    "        dataset_result = main(llm_instance, tokenizer_instance, data_name, args_obj)\n",
    "        named_results.append((data_name, dataset_result))\n",
    "\n",
    "    if not named_results:\n",
    "        logger.info(\"No datasets processed or no results returned.\")\n",
    "        return\n",
    "\n",
    "    if len(named_results) > 1:\n",
    "        # With several datasets, add an unweighted mean over the results that report \"acc\".\n",
    "        scored_results = [res for _, res in named_results if res and \"acc\" in res]\n",
    "        if scored_results:\n",
    "            avg_acc = sum(res[\"acc\"] for res in scored_results) / len(scored_results)\n",
    "            named_results.append((\"avg\", {\"acc\": avg_acc, \"data_name\": \"avg\"}))\n",
    "        else:\n",
    "            logger.warning(\"No valid results with 'acc' key found to calculate average.\")\n",
    "\n",
    "    logger.info(\"\\n\" + \"=\"*20 + \" Overall Summary \" + \"=\"*20)\n",
    "\n",
    "    final_header, final_scores = _format_summary_rows(named_results)\n",
    "    print(\"\\nResults Summary:\")  # printed as well as logged, for easy viewing\n",
    "    print(final_header)\n",
    "    print(final_scores)\n",
    "    logger.info(\"Final results summary (also printed above):\")\n",
    "    logger.info(f\"Datasets: {final_header}\")\n",
    "    logger.info(f\"Accuracy: {final_scores}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00,  5.45s/it]\n",
      "05/23/2025 14:18:04 - INFO - evaluation.math_eval - loaded model with 3085938688 parameters\n",
      "05/23/2025 14:18:04 - INFO - evaluation.math_eval - setting tokenizer.model_max_length to 1024\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "data: math  ,remain samples: 100\n",
      "{'idx': 2795, 'problem': 'Find all the integer roots of\\n\\\\[x^4 + 5x^3 + 9x^2 - x - 14 = 0.\\\\]Enter all the integer roots, separated by commas.', 'level': 'Level 1', 'solution': 'By the Integer Root Theorem, the possible integer roots are all the divisors of 14 (including negative divisors), which are $-14,$ $-7,$ $-2,$ $-1,$ $1,$ $2,$ $7,$ and $14.$  Checking, we find that the only integer roots are $\\\\boxed{-2,1}.$', 'subject': 'Intermediate Algebra', 'unique_id': 'test/intermediate_algebra/1102.json', 'answer': '-2,1'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 292.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------- Epoch 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating Completions: 100%|██████████| 100/100 [10:49<00:00,  6.49s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------- Epoch 1\n",
      "Unsolved samples: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|██████████| 100/100 [00:00<00:00, 100.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'num_samples': 100, 'num_scores': 100, 'timeout_samples': 0, 'empty_samples': 0, 'acc': 53.0}\n",
      "Saved to ./output/test_100_0+5shots.jsonl\n",
      "\n",
      "Results Summary:\n",
      "math\n",
      "53.0\n"
     ]
    }
   ],
   "source": [
    "# --- Initial Evaluation ---\n",
    "args = Args()\n",
    "set_seed(args.seed)\n",
    "run_eval(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 1-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yh0068/.conda/envs/prm_dev/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Cell 1: Setup and Imports\n",
    "import sys\n",
    "import os\n",
    "import json\n",
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from datasets import load_dataset\n",
    "from accelerate import Accelerator # For accelerator.device, accelerator.num_processes, etc.\n",
    "from collections import Counter\n",
    "from tqdm import tqdm # For notebook-level progress if any; worker has its own\n",
    "\n",
    "# Add math-rm to Python path to import custom modules\n",
    "# Assumes the notebook is in the parent directory of 'math-rm'\n",
    "sys.path.append('./math-rm')\n",
    "from rm_classify import worker # Only worker is directly called from the main logic\n",
    "\n",
    "# For a cleaner log, you might want to set Transformers logging level\n",
    "import logging\n",
    "logging.getLogger(\"transformers\").setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output directory: ./output/stage1_classified\n",
      "Using model: pwork7/llama31_it_prm_2e6_bz32_1epoch_conversation\n",
      "Using dataset: ./output/stage1_inference/test_100_0+5shots.jsonl\n",
      "Number of N candidates per question: 128\n",
      "Number of test samples: 100\n"
     ]
    }
   ],
   "source": [
    "# Cell 2: Define Arguments (mimicking argparse)\n",
    "class Args2:\n",
    "    \"\"\"Settings holder for the Stage 1-2 reward-model classification step.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Inputs: the JSONL file to classify and the reward model to load.\n",
    "        self.dataset = './output/stage1_inference/test_100_0+5shots.jsonl'\n",
    "        self.reward_name_or_path = 'pwork7/llama31_it_prm_2e6_bz32_1epoch_conversation'\n",
    "        self.model_type = \"Deepseek\"\n",
    "        # Thresholds and sizes handed to the worker (their semantics live in rm_classify).\n",
    "        self.pred_thres1 = 0.9\n",
    "        self.pred_thres2 = 0.7\n",
    "        self.num_n = 128  # reduced for faster notebook execution, original was 1024\n",
    "        self.num_test_sample = 100\n",
    "        # Folder that receives the metrics and the processed samples.\n",
    "        self.output_dir = \"./output/stage1_classified\"\n",
    "\n",
    "args = Args2()\n",
    "\n",
    "# Make sure the output folder is present (no error when it already exists).\n",
    "os.makedirs(args.output_dir, exist_ok=True)\n",
    "\n",
    "# Echo the active configuration, one setting per line.\n",
    "print(\n",
    "    f\"Output directory: {args.output_dir}\",\n",
    "    f\"Using model: {args.reward_name_or_path}\",\n",
    "    f\"Using dataset: {args.dataset}\",\n",
    "    f\"Number of N candidates per question: {args.num_n}\",\n",
    "    f\"Number of test samples: {args.num_test_sample if args.num_test_sample != -1 else 'ALL'}\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Process 0/1 using device: cuda\n",
      "Loading dataset ./output/stage1_inference/test_100_0+5shots.jsonl...\n",
      "Failed to load dataset directly. Trying as json: Couldn't find a dataset script at /scratch/gpfs/yh0068/slm-math/adaptmi/output/stage1_inference/test_100_0+5shots.jsonl/test_100_0+5shots.jsonl.py or any data file in the same directory.\n",
      "Selected 100 samples for processing.\n",
      "Loading reward model pwork7/llama31_it_prm_2e6_bz32_1epoch_conversation...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model and tokenizer loaded successfully.\n",
      "Process 0/1: processing 100 samples (indices 0-99).\n",
      "Starting worker on process 0...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:30<00:00,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Worker finished on process 0. Results: 100 classifications, 100 processed samples.\n",
      "Main process gathering results...\n",
      "Total gathered classification results: 100\n",
      "Total gathered samples for saving: 100\n",
      "\n",
      "Metrics:\n",
      "\n",
      "To stay consistent with the paper, positive means model failure, nagative means model success.\n",
      "\n",
      "TP: 45, FN: 2, FP: 21, TN: 32\n",
      "Total Predictions: 100\n",
      "Accuracy: 0.7700\n",
      "Precision: 0.6818\n",
      "Recall (Sensitivity): 0.9574\n",
      "F1 Score: 0.7965\n",
      "Specificity (captured failure case): 0.9574\n",
      "Metrics saved to ./output/stage1_classified/size100_thres1=0.9_thres2=0.7_metrics.json\n",
      "Processed data saved to ./output/stage1_classified/size100_thres1=0.9_thres2=0.7_save_data.jsonl\n",
      "Classification finished.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Cell 3: Main Logic (adapted from the original script's if __name__ == \"__main__\": block)\n",
    "\n",
    "# The Accelerator provides the device and the process count / index values.\n",
    "accelerator = Accelerator()\n",
    "device = accelerator.device\n",
    "\n",
    "# WORLD_SIZE / LOCAL_RANK from the environment win when they are set; otherwise\n",
    "# the Accelerator's num_processes / local_process_index are used.\n",
    "ddp_world_size = int(os.environ.get(\"WORLD_SIZE\", accelerator.num_processes))\n",
    "ddp_local_rank = int(os.environ.get(\"LOCAL_RANK\", accelerator.local_process_index))\n",
    "\n",
    "print(f\"Process {ddp_local_rank}/{ddp_world_size} using device: {device}\")\n",
    "\n",
    "# Load dataset\n",
    "print(f\"Loading dataset {args.dataset}...\")\n",
    "# NOTE(review): the records presumably need the 'question', 'code', 'pred' and 'score'\n",
    "# fields read by the worker's sample selection - confirm against rm_classify.\n",
    "if os.path.isfile(args.dataset):\n",
    "    # An existing local file is read with the \"json\" builder right away instead of\n",
    "    # first provoking (and printing) a failed lookup by dataset name.\n",
    "    ds = load_dataset(\"json\", data_files={\"test\": args.dataset}, split=\"test\")\n",
    "else:\n",
    "    try:\n",
    "        # Anything else is first treated as a dataset name.\n",
    "        ds = load_dataset(args.dataset, split=\"test\")\n",
    "    except Exception as e:\n",
    "        # Fallback kept from the original flow: retry the value as a JSON data file.\n",
    "        print(f\"Failed to load dataset directly. Trying as json: {e}\")\n",
    "        ds = load_dataset(\"json\", data_files={\"test\": args.dataset}, split=\"test\")\n",
    "\n",
    "# -1 keeps every sample; any other value is capped at the dataset size.\n",
    "if args.num_test_sample == -1:\n",
    "    num_sample = len(ds)\n",
    "else:\n",
    "    num_sample = min(args.num_test_sample, len(ds))\n",
    "ds = ds.select(range(num_sample))\n",
    "print(f\"Selected {len(ds)} samples for processing.\")\n",
    "\n",
    "# Load model and tokenizer\n",
    "print(f\"Loading reward model {args.reward_name_or_path}...\")\n",
    "# Loading is retried after an error, but only a bounded number of times: a permanent\n",
    "# failure (e.g. a wrong model name) must surface instead of looping forever.\n",
    "MAX_MODEL_LOAD_ATTEMPTS = 5\n",
    "downloaded = False\n",
    "for load_attempt in range(1, MAX_MODEL_LOAD_ATTEMPTS + 1):\n",
    "    try:\n",
    "        tokenizer = AutoTokenizer.from_pretrained(args.reward_name_or_path)\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            args.reward_name_or_path, \n",
    "            torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32\n",
    "        ).to(device).eval() # Model to the device determined by Accelerator\n",
    "        downloaded = True\n",
    "        print(\"Model and tokenizer loaded successfully.\")\n",
    "        break\n",
    "    except Exception as error:\n",
    "        print(f\"An error occurred during model loading: {error}\")\n",
    "        if load_attempt == MAX_MODEL_LOAD_ATTEMPTS:\n",
    "            # Out of attempts: re-raise the last error rather than hiding it.\n",
    "            raise\n",
    "        print(\"Retrying in 2 seconds...\")\n",
    "        time.sleep(2)\n",
    "\n",
    "tokenizer.padding_side = \"right\"\n",
    "# Use EOS as the padding token when the tokenizer / model config define none.\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "if model.config.pad_token_id is None:\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else model.config.eos_token_id\n",
    "\n",
    "\n",
    "# Data sharding: every process takes one contiguous chunk of the dataset.\n",
    "data_size = len(ds)\n",
    "# Chunk size is the ceiling of data_size / num_processes (floor division of the negated size).\n",
    "share = -(-data_size // accelerator.num_processes)\n",
    "start_idx = share * accelerator.process_index\n",
    "end_idx = min(start_idx + share, data_size)\n",
    "\n",
    "# Rows [start_idx, end_idx) belong to this process.\n",
    "current_ds_slice = ds.select(np.arange(start_idx, end_idx))\n",
    "print(f\"Process {accelerator.process_index}/{accelerator.num_processes}: processing {len(current_ds_slice)} samples (indices {start_idx}-{end_idx-1}).\")\n",
    "\n",
    "# The worker receives the samples as a plain list.\n",
    "data_for_worker = list(current_ds_slice)\n",
    "\n",
    "# `device` is passed in the worker's last positional slot (the original notes call it\n",
    "# 'local_rank' and say it is used for .to(...) moves inside the worker).\n",
    "print(f\"Starting worker on process {accelerator.process_index}...\")\n",
    "selected_data, new_data = worker(args, model, tokenizer, data_for_worker, device)\n",
    "print(f\"Worker finished on process {accelerator.process_index}. Results: {len(selected_data)} classifications, {len(new_data)} processed samples.\")\n",
    "\n",
    "# Distributed data gathering\n",
    "# Only relevant when several processes are running (e.g. via accelerate launch).\n",
    "if accelerator.num_processes > 1:\n",
    "    # Provide rendezvous defaults unless the launcher already exported them.\n",
    "    os.environ.setdefault('MASTER_ADDR', 'localhost')\n",
    "    os.environ.setdefault('MASTER_PORT', '12355') # Ensure this port is free or configurable\n",
    "    \n",
    "    if not dist.is_initialized():\n",
    "        # Pick the backend from the device this process actually runs on.\n",
    "        backend = 'nccl' if device.type == 'cuda' else 'gloo'\n",
    "        # init_process_group needs the GLOBAL rank: LOCAL_RANK repeats on every node and\n",
    "        # would yield duplicate ranks in multi-node runs.\n",
    "        ddp_global_rank = int(os.getenv(\"RANK\", accelerator.process_index))\n",
    "        print(f\"Process {ddp_global_rank}: Initializing torch.distributed with backend {backend}, world_size {ddp_world_size}, rank {ddp_global_rank}\")\n",
    "        dist.init_process_group(\n",
    "            backend=backend,\n",
    "            rank=ddp_global_rank,\n",
    "            world_size=ddp_world_size\n",
    "        )\n",
    "\n",
    "# Payload contributed by this process: the worker's two result lists.\n",
    "data_to_send_from_this_rank = {\n",
    "    \"selected_data_payload\": selected_data,\n",
    "    \"new_data_payload\": new_data\n",
    "}\n",
    "\n",
    "if accelerator.num_processes > 1:\n",
    "    # all_gather_object fills the pre-sized list with one payload per process.\n",
    "    gathered_dictionaries_list = [None] * accelerator.num_processes\n",
    "    dist.all_gather_object(gathered_dictionaries_list, data_to_send_from_this_rank)\n",
    "else:\n",
    "    # Single process: the local payload is the complete result.\n",
    "    gathered_dictionaries_list = [data_to_send_from_this_rank]\n",
    "\n",
    "def _safe_ratio(numerator, denominator):\n",
    "    \"\"\"Return numerator / denominator, or 0 when the denominator is not positive.\"\"\"\n",
    "    return numerator / denominator if denominator > 0 else 0\n",
    "\n",
    "\n",
    "gathered_classification_results = []\n",
    "gathered_full_samples = []\n",
    "\n",
    "# Process gathered data (only on the main process)\n",
    "if accelerator.is_main_process:\n",
    "    print(\"Main process gathering results...\")\n",
    "    for data_from_rank_i in gathered_dictionaries_list:\n",
    "        if data_from_rank_i:\n",
    "            gathered_classification_results.extend(data_from_rank_i[\"selected_data_payload\"])\n",
    "            gathered_full_samples.extend(data_from_rank_i[\"new_data_payload\"])\n",
    "    \n",
    "    print(f\"Total gathered classification results: {len(gathered_classification_results)}\")\n",
    "    print(f\"Total gathered samples for saving: {len(gathered_full_samples)}\")\n",
    "\n",
    "    # Calculate metrics.\n",
    "    # The worker's labels are mirrored (TP<->TN, FP<->FN) so that, as the printed note\n",
    "    # below states, \"positive\" means model failure.\n",
    "    counter = Counter(gathered_classification_results)\n",
    "    num_TN = counter[\"TP\"]\n",
    "    num_FP = counter[\"FN\"]\n",
    "    num_FN = counter[\"FP\"]\n",
    "    num_TP = counter[\"TN\"]\n",
    "\n",
    "    total_predictions = num_TN + num_TP + num_FN + num_FP\n",
    "    # Every ratio is 0 when its denominator is 0, so no division can fail.\n",
    "    precision = _safe_ratio(num_TP, num_TP + num_FP)\n",
    "    recall = _safe_ratio(num_TP, num_TP + num_FN)  # share of positive (failure) cases captured\n",
    "    f1 = _safe_ratio(2 * precision * recall, precision + recall)\n",
    "    accuracy = _safe_ratio(num_TN + num_TP, total_predictions)\n",
    "    # True Negative Rate: TN / (TN + FP). The earlier formula reused recall's\n",
    "    # TP / (TP + FN) under a TN + FP guard, duplicating recall and risking a\n",
    "    # ZeroDivisionError when there were no positive cases.\n",
    "    specificity = _safe_ratio(num_TN, num_TN + num_FP)\n",
    "\n",
    "    print(f\"\\nMetrics:\\n\")\n",
    "    print(\"To stay consistent with the paper, positive means model failure, negative means model success.\\n\")\n",
    "    print(f\"TP: {num_TP}, FN: {num_FN}, FP: {num_FP}, TN: {num_TN}\")\n",
    "    print(f\"Total Predictions: {total_predictions}\")\n",
    "    print(f\"Accuracy: {accuracy:.4f}\")\n",
    "    print(f\"Precision: {precision:.4f}\")\n",
    "    print(f\"Recall (Sensitivity): {recall:.4f}\")\n",
    "    print(f\"F1 Score: {f1:.4f}\")\n",
    "    print(f\"Specificity (True Negative Rate): {specificity:.4f}\")\n",
    "\n",
    "    metrics_summary = {\n",
    "        \"TP\": num_TP, \"FN\": num_FN, \"FP\": num_FP, \"TN\": num_TN,\n",
    "        \"total_predictions\": total_predictions,\n",
    "        \"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1_score\": f1,\n",
    "        \"specificity\": specificity,\n",
    "        \"num_test_samples_processed\": len(gathered_full_samples) # Should match num_sample if all processed\n",
    "    }\n",
    "    \n",
    "    # Both output files share one name prefix built from the run settings.\n",
    "    output_prefix = f\"size{args.num_test_sample}_thres1={args.pred_thres1}_thres2={args.pred_thres2}\"\n",
    "\n",
    "    output_metrics_file = os.path.join(args.output_dir, f\"{output_prefix}_metrics.json\")\n",
    "    with open(output_metrics_file, 'w') as f:\n",
    "        json.dump(metrics_summary, f, indent=4, ensure_ascii=False)\n",
    "    print(f\"Metrics saved to {output_metrics_file}\")\n",
    "\n",
    "    # Processed samples are written as JSON Lines (one object per line).\n",
    "    output_data_file = os.path.join(args.output_dir, f\"{output_prefix}_save_data.jsonl\")\n",
    "    with open(output_data_file, 'w') as f:\n",
    "        for entry in gathered_full_samples:\n",
    "            f.write(json.dumps(entry) + \"\\n\")\n",
    "    print(f\"Processed data saved to {output_data_file}\")\n",
    "\n",
    "# Tear down the torch.distributed process group, but only in multi-process runs\n",
    "# where a group is actually initialised.\n",
    "if accelerator.num_processes > 1:\n",
    "    if dist.is_initialized():\n",
    "        dist.destroy_process_group()\n",
    "        print(f\"Process {ddp_local_rank}: Destroyed DDP process group.\")\n",
    "\n",
    "print(\"Classification finished.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]\u001b[A\n",
      "Loading checkpoint shards:  50%|█████     | 1/2 [00:11<00:10, 10.99s/it]\u001b[A\n",
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:18<00:00,  9.18s/it]\u001b[A\n",
      "05/23/2025 15:09:54 - INFO - evaluation.math_eval - loaded model with 3085938688 parameters\n",
      "05/23/2025 15:09:54 - INFO - evaluation.math_eval - setting tokenizer.model_max_length to 1024\n",
      "\n",
      "Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 116.17ba/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "data: math-skill  ,remain samples: 100\n",
      "{'idx': 1139, 'question': 'The area of triangle $ABC$ is equal to $a^2 - (b - c)^2,$ where $a,$ $b,$ and $c$ are the sides of triangle $ABC,$ as usual.  Compute $\\\\tan A.$', 'gt_cot': 'The area of triangle $ABC$ is given by\\n\\\\[\\\\frac{1}{2} bc \\\\sin A.\\\\]Hence,\\n\\\\[\\\\frac{1}{2} bc \\\\sin A = a^2 - (b - c)^2 = a^2 - b^2 + 2bc - c^2.\\\\]By the Law of Cosines, $b^2 + c^2 - 2bc \\\\cos A = a^2,$ so\\n\\\\[\\\\frac{1}{2} bc \\\\sin A = 2bc - 2bc \\\\cos A.\\\\]This simplifies to $\\\\sin A = 4 - 4 \\\\cos A.$  Squaring both sides, we get\\n\\\\[\\\\sin^2 A = 16 - 32 \\\\cos A + 16 \\\\cos^2 A,\\\\]so $1 - \\\\cos^2 A = 16 - 32 \\\\cos A + 16 \\\\cos^2 A.$  This simplifies to\\n\\\\[17 \\\\cos^2 A - 32 \\\\cos A + 15 = 0.\\\\]This factors as $(\\\\cos A - 1)(17 \\\\cos A - 15) = 0.$  Since $\\\\cos A$ cannot be equal to 1, $\\\\cos A = \\\\frac{15}{17}.$\\n\\nThen $\\\\sin A = 4 - 4 \\\\cos A = \\\\frac{8}{17},$ so\\n\\\\[\\\\tan A = \\\\frac{\\\\sin A}{\\\\cos A} = \\\\boxed{\\\\frac{8}{15}}.\\\\]', 'gt': '\\\\frac{8}{15}', 'prompt': \"<|im_start|>system\\nPlease reason step by step, and put your final answer within \\\\boxed{}.<|im_end|>\\n<|im_start|>user\\nQuestion: Kevin Kangaroo begins hopping on a number line at 0. He wants to get to 1, but he can hop only $\\\\frac{1}{3}$ of the distance. Each hop tires him out so that he continues to hop $\\\\frac{1}{3}$ of the remaining distance. How far has he hopped after five hops? 
Express your answer as a common fraction.\\n\\nResponse: Let's think step by step\\n\\nKevin hops $1/3$ of the remaining distance with every hop.\\n\\nHis first hop takes $1/3$ closer.\\n\\nFor his second hop, he has $2/3$ left to travel, so he hops forward $(2/3)(1/3)$.\\n\\nFor his third hop, he has $(2/3)^2$ left to travel, so he hops forward $(2/3)^2(1/3)$.\\n\\nIn general, Kevin hops forward $(2/3)^{k-1}(1/3)$ on his $k$th hop.\\n\\nWe want to find how far he has hopped after five hops.\\n\\nThis is a finite geometric series with first term $1/3$, common ratio $2/3$, and five terms.\\n\\nThus, Kevin has hopped $\\\\frac{\\\\frac{1}{3}\\\\left(1-\\\\left(\\\\frac{2}{3}\\\\right)^5\\\\right)}{1-\\\\frac{2}{3}} = \\\\boxed{\\\\frac{211}{243}}$.\\n\\nThe answer is \\\\frac{211}{243}}\\n\\nQuestion: What is the area of the region defined by the equation $x^2+y^2 - 7 = 4y-14x+3$?\\n\\nResponse: Let's think step by step\\n\\nWe rewrite the equation as $x^2 + 14x + y^2 - 4y = 10$ and then complete the square,\\n\\nresulting in  $(x+7)^2-49 + (y-2)^2-4=10$,\\n\\nor $(x+7)^2+(y-2)^2=63$.\\n\\nThis is the equation of a circle with center $(-7, 2)$ and radius $\\\\sqrt{63},$\\n\\nso the area of this region is $\\\\pi r^2 = \\\\boxed{63\\\\pi}$.\\n\\nThe answer is 63\\\\pi\\n\\nQuestion: If $x^2+y^2=1$, what is the largest possible value of $|x|+|y|$?\\n\\nResponse: Let's think step by step\\n\\nIf $(x,y)$ lies on the circle,\\n\\nso does $(x,-y),$ $(-x,-y),$ and $(-x,-y),$ (which all give the same value of $|x| + |y|$),\\n\\nso we can assume that $x \\\\ge 0$ and $y \\\\ge 0.$\\n\\nThen $|x| + |y| = x + y.$  Squaring, we get\\n\\n\\\\[(x + y)^2 = x^2 + 2xy + y^2 = 1 + 2xy.\\\\]\\n\\nNote that $(x - y)^2 \\\\ge 0.$\\n\\nExpanding, we get $x^2 - 2xy + y^2 \\\\ge 0,$ so $2xy \\\\le x^2 + y^2 = 1.$\\n\\nHence,\\\\[1 + 2xy \\\\le 2,\\\\]which means $x + y \\\\le \\\\sqrt{2}.$\\n\\nEquality occurs when $x = y = \\\\frac{1}{\\\\sqrt{2}},$\\n\\nso the maximum value of $|x| + |y|$ 
is $\\\\boxed{\\\\sqrt{2}}.$\\n\\nThe answer is \\\\sqrt{2}\\n\\nQuestion: If $f(x)=\\\\frac{ax+b}{cx+d}, abcd\\\\n\\not=0$ and $f(f(x))=x$ for all $x$ in the domain of $f$, what is the value of $a+d$?\\n\\nResponse: Let's think step by step\\n\\nThe condition $f(f(x))$ means that $f$ is the inverse of itself,\\n\\nso its graph is symmetrical about the line $y = x$.\\n\\nWith a rational function of this form, we will have two asymptotes:\\n\\na vertical one at $x=-d/c$ if $cx+d$ does not divide $ax+b$,\\n\\nand a horizontal one at $y=a/c$,\\n\\nif we take the limit of $f(x)$ as $x$ goes to $\\\\pm\\\\infty$.\\n\\nIn order for $f$ to be its own inverse, the intersection of the asymptotes must lie on the line $y=x$\\n\\nso that it and its asymptotes reflect onto themselves.\\n\\nThis means that $-d/c=a/c$,\\n\\nand therefore $-d=a$ and $a+d=\\\\boxed{0}$.\\n\\nThe answer is 0\\n\\nQuestion: Expand $(2z^2 + 5z - 6)(3z^3 - 2z + 1)$.\\n\\nResponse: Let's think step by step\\n\\n$$\\\\begin{array}{crrrrrrr}\\n\\n& & & 3z^3 & & -2z & + 1 & \\\\\\\\\\n\\n\\\\times & & & & 2z^2 & +5z & -6 \\\\\\\\\\n\\n\\\\cline{1-7}\\\\rule{0pt}{0.17in}\\n\\n& & & -18z^3 & & +12z & -6 & \\\\\\\\\\n\\n& & +15z^4 & & -10z^2 & +5z & & \\\\\\\\\\n\\n+ & 6z^5 & & -4z^3 & +2z^2 & & & \\\\\\\\\\n\\n\\\\cline{1-7}\\\\rule{0pt}{0.17in}\\n\\n& 6z^5 & +15z^4 & -22z^3 & - 8z^2 &+17z & -6 &\\n\\n\\\\end{array}$$\\n\\nThe answer is 6z^5+15z^4-22z^3-8z^2+17z-6\\n\\nQuestion: The area of triangle $ABC$ is equal to $a^2 - (b - c)^2,$ where $a,$ $b,$ and $c$ are the sides of triangle $ABC,$ as usual.  
Compute $\\\\tan A.$<|im_end|>\\n<|im_start|>assistant\\n\", 'level': 'Level 3', 'solution': 'The area of triangle $ABC$ is given by\\n\\\\[\\\\frac{1}{2} bc \\\\sin A.\\\\]Hence,\\n\\\\[\\\\frac{1}{2} bc \\\\sin A = a^2 - (b - c)^2 = a^2 - b^2 + 2bc - c^2.\\\\]By the Law of Cosines, $b^2 + c^2 - 2bc \\\\cos A = a^2,$ so\\n\\\\[\\\\frac{1}{2} bc \\\\sin A = 2bc - 2bc \\\\cos A.\\\\]This simplifies to $\\\\sin A = 4 - 4 \\\\cos A.$  Squaring both sides, we get\\n\\\\[\\\\sin^2 A = 16 - 32 \\\\cos A + 16 \\\\cos^2 A,\\\\]so $1 - \\\\cos^2 A = 16 - 32 \\\\cos A + 16 \\\\cos^2 A.$  This simplifies to\\n\\\\[17 \\\\cos^2 A - 32 \\\\cos A + 15 = 0.\\\\]This factors as $(\\\\cos A - 1)(17 \\\\cos A - 15) = 0.$  Since $\\\\cos A$ cannot be equal to 1, $\\\\cos A = \\\\frac{15}{17}.$\\n\\nThen $\\\\sin A = 4 - 4 \\\\cos A = \\\\frac{8}{17},$ so\\n\\\\[\\\\tan A = \\\\frac{\\\\sin A}{\\\\cos A} = \\\\boxed{\\\\frac{8}{15}}.\\\\]', 'answer': '\\\\frac{8}{15}', 'subject': 'Precalculus', 'unique_id': 'test/precalculus/1082.json', 'correct_flagged': False, 'code': ['To solve for \\\\(\\\\tan A\\\\) given that the area of triangle \\\\(ABC\\\\) is \\\\(a^2 - (b - c)^2\\\\), we start by expressing the area using the standard formula for the area of a triangle with sides \\\\(a\\\\), \\\\(b\\\\), and \\\\(c\\\\):\\n\\n\\\\[\\n\\\\text{Area} = \\\\frac{1}{2}bc \\\\sin A\\n\\\\]\\n\\nWe are given that the area is also equal to \\\\(a^2 - (b - c)^2\\\\). 
Expanding the right-hand side, we get:\\n\\n\\\\[\\na^2 - (b - c)^2 = a^2 - (b^2 - 2bc + c^2) = a^2 - b^2 + 2bc - c^2\\n\\\\]\\n\\nSo, we have:\\n\\n\\\\[\\n\\\\frac{1}{2} bc \\\\sin A = a^2 - b^2 + 2bc - c^2\\n\\\\]\\n\\nNext, we use the Law of Cosines, which states:\\n\\n\\\\[\\na^2 = b^2 + c^2 - 2bc \\\\cos A\\n\\\\]\\n\\nRearranging this, we get:\\n\\n\\\\[\\na^2 - b^2 - c^2 = -2bc \\\\cos A\\n\\\\]\\n\\nSubstituting \\\\(a^2 - b^2 - c^2\\\\) into the area expression, we have:\\n\\n\\\\[\\na^2 - b^2 + 2bc - c^2 = -2bc \\\\cos A + 2bc = 2bc (1 - \\\\cos A)\\n\\\\]\\n\\nThus, we now have:\\n\\n\\\\[\\n\\\\frac{1}{2} bc \\\\sin A = 2bc (1 - \\\\cos A)\\n\\\\]\\n\\nDividing both sides by \\\\(bc\\\\) (assuming \\\\(b \\\\neq 0\\\\) and \\\\(c \\\\neq 0\\\\)):\\n\\n\\\\[\\n\\\\frac{1}{2} \\\\sin A = 2 (1 - \\\\cos A)\\n\\\\]\\n\\nMultiplying both sides by 2:\\n\\n\\\\[\\n\\\\sin A = 4 (1 - \\\\cos A)\\n\\\\]\\n\\nUsing the Pythagorean identity \\\\(\\\\sin^2 A + \\\\cos^2 A = 1\\\\), we substitute \\\\(\\\\sin A = 4(1 - \\\\cos A)\\\\):\\n\\n\\\\[\\n[4(1 - \\\\cos A)]^2 + \\\\cos^2 A = 1\\n\\\\]\\n\\nSimplifying the left-hand side:\\n\\n\\\\[\\n16(1 - \\\\cos A)^2 + \\\\cos^2 A = 1\\n\\\\]\\n\\nExpanding \\\\((1 - \\\\cos A)^2\\\\):\\n\\n\\\\[\\n16(1 - 2\\\\cos A + \\\\cos^2 A) + \\\\cos^2 A = 1\\n\\\\]\\n\\nDistributing the 16:\\n\\n\\\\[\\n16 - 32\\\\cos A + 16\\\\cos^2 A + \\\\cos^2 A = 1\\n\\\\]\\n\\nCombining like terms:\\n\\n\\\\[\\n16 - 32\\\\cos A + 17\\\\cos^2 A = 1\\n\\\\]\\n\\nSubtracting 1 from both sides:\\n\\n\\\\[\\n17\\\\cos^2 A - 32\\\\cos A + 15 = 0\\n\\\\]\\n\\nThis is a quadratic equation in \\\\(\\\\cos A\\\\). 
We solve it using the quadratic formula \\\\(\\\\cos A = \\\\frac{-b \\\\pm \\\\sqrt{b^2 - 4ac}}{2a}\\\\), where \\\\(a = 17\\\\), \\\\(b = -32\\\\), and \\\\(c = 15\\\\):\\n\\n\\\\[\\n\\\\cos A = \\\\frac{32 \\\\pm \\\\sqrt{(-32)^2 - 4 \\\\cdot 17 \\\\cdot 15}}{2 \\\\cdot 17}\\n\\\\]\\n\\nCalculating the discriminant:\\n\\n\\\\[\\n(-32)^2 - 4 \\\\cdot 17 \\\\cdot 15 = 1024 - 1020 = 4\\n\\\\]\\n\\nThus:\\n\\n\\\\[\\n\\\\cos A = \\\\frac{32 \\\\pm \\\\sqrt{4}}{34} = \\\\frac{32 \\\\pm 2}{34}\\n\\\\]\\n\\nThis gives us two solutions:\\n\\n\\\\[\\n\\\\cos A = \\\\frac{34}{34} = 1 \\\\quad \\\\text{(not possible since } \\\\cos A \\\\neq 1 \\\\text{ for an acute angle)}\\n\\\\]\\n\\\\[\\n\\\\cos A = \\\\frac{30}{34} = \\\\frac{15}{17}\\n\\\\]\\n\\nNow, we find \\\\(\\\\sin A\\\\) using \\\\(\\\\sin^2 A + \\\\cos^2 A = 1\\\\):\\n\\n\\\\[\\n\\\\sin^2 A = 1 - \\\\cos^2 A = 1 - \\\\left(\\\\frac{'], 'pred': ['1'], 'report': [None], 'score': [False], 'prm_pred': False, 'step_scores': [[0.3775406777858734, 0.7057850360870361, 0.7772998809814453, 0.8354835510253906, 0.8933094143867493, 0.8807970285415649, 0.8933094143867493, 0.9149009585380554, 0.9324532747268677, 0.9241418242454529, 0.9324532747268677, 0.9465966820716858, 0.9669140577316284, 0.9525741338729858, 0.9525741338729858, 0.9465966820716858, 0.970687747001648, 0.957912266254425, 0.9149009585380554, 0.9399133324623108, 0.957912266254425, 0.957912266254425, 0.9626730680465698, 0.957912266254425, 0.9626730680465698, 0.9669140577316284, 0.970687747001648, 0.9399133324623108, 0.957912266254425, 0.957912266254425, 0.9149009585380554, 0.9241418242454529, 0.9399133324623108, 0.977022647857666, 0.9840936064720154, 0.977022647857666, 0.9796676635742188, 0.9875683188438416, 0.9914224743843079, 0.8933094143867493, 0.08509904146194458]], 'best_ans_idx': 0, 'best_ans': 'To solve for \\\\(\\\\tan A\\\\) given that the area of triangle \\\\(ABC\\\\) is \\\\(a^2 - (b - c)^2\\\\), we start by expressing the area using the standard 
formula for the area of a triangle with sides \\\\(a\\\\), \\\\(b\\\\), and \\\\(c\\\\):\\n\\n\\\\[\\n\\\\text{Area} = \\\\frac{1}{2}bc \\\\sin A\\n\\\\]\\n\\nWe are given that the area is also equal to \\\\(a^2 - (b - c)^2\\\\). Expanding the right-hand side, we get:\\n\\n\\\\[\\na^2 - (b - c)^2 = a^2 - (b^2 - 2bc + c^2) = a^2 - b^2 + 2bc - c^2\\n\\\\]\\n\\nSo, we have:\\n\\n\\\\[\\n\\\\frac{1}{2} bc \\\\sin A = a^2 - b^2 + 2bc - c^2\\n\\\\]\\n\\nNext, we use the Law of Cosines, which states:\\n\\n\\\\[\\na^2 = b^2 + c^2 - 2bc \\\\cos A\\n\\\\]\\n\\nRearranging this, we get:\\n\\n\\\\[\\na^2 - b^2 - c^2 = -2bc \\\\cos A\\n\\\\]\\n\\nSubstituting \\\\(a^2 - b^2 - c^2\\\\) into the area expression, we have:\\n\\n\\\\[\\na^2 - b^2 + 2bc - c^2 = -2bc \\\\cos A + 2bc = 2bc (1 - \\\\cos A)\\n\\\\]\\n\\nThus, we now have:\\n\\n\\\\[\\n\\\\frac{1}{2} bc \\\\sin A = 2bc (1 - \\\\cos A)\\n\\\\]\\n\\nDividing both sides by \\\\(bc\\\\) (assuming \\\\(b \\\\neq 0\\\\) and \\\\(c \\\\neq 0\\\\)):\\n\\n\\\\[\\n\\\\frac{1}{2} \\\\sin A = 2 (1 - \\\\cos A)\\n\\\\]\\n\\nMultiplying both sides by 2:\\n\\n\\\\[\\n\\\\sin A = 4 (1 - \\\\cos A)\\n\\\\]\\n\\nUsing the Pythagorean identity \\\\(\\\\sin^2 A + \\\\cos^2 A = 1\\\\), we substitute \\\\(\\\\sin A = 4(1 - \\\\cos A)\\\\):\\n\\n\\\\[\\n[4(1 - \\\\cos A)]^2 + \\\\cos^2 A = 1\\n\\\\]\\n\\nSimplifying the left-hand side:\\n\\n\\\\[\\n16(1 - \\\\cos A)^2 + \\\\cos^2 A = 1\\n\\\\]\\n\\nExpanding \\\\((1 - \\\\cos A)^2\\\\):\\n\\n\\\\[\\n16(1 - 2\\\\cos A + \\\\cos^2 A) + \\\\cos^2 A = 1\\n\\\\]\\n\\nDistributing the 16:\\n\\n\\\\[\\n16 - 32\\\\cos A + 16\\\\cos^2 A + \\\\cos^2 A = 1\\n\\\\]\\n\\nCombining like terms:\\n\\n\\\\[\\n16 - 32\\\\cos A + 17\\\\cos^2 A = 1\\n\\\\]\\n\\nSubtracting 1 from both sides:\\n\\n\\\\[\\n17\\\\cos^2 A - 32\\\\cos A + 15 = 0\\n\\\\]\\n\\nThis is a quadratic equation in \\\\(\\\\cos A\\\\). 
We solve it using the quadratic formula \\\\(\\\\cos A = \\\\frac{-b \\\\pm \\\\sqrt{b^2 - 4ac}}{2a}\\\\), where \\\\(a = 17\\\\), \\\\(b = -32\\\\), and \\\\(c = 15\\\\):\\n\\n\\\\[\\n\\\\cos A = \\\\frac{32 \\\\pm \\\\sqrt{(-32)^2 - 4 \\\\cdot 17 \\\\cdot 15}}{2 \\\\cdot 17}\\n\\\\]\\n\\nCalculating the discriminant:\\n\\n\\\\[\\n(-32)^2 - 4 \\\\cdot 17 \\\\cdot 15 = 1024 - 1020 = 4\\n\\\\]\\n\\nThus:\\n\\n\\\\[\\n\\\\cos A = \\\\frac{32 \\\\pm \\\\sqrt{4}}{34} = \\\\frac{32 \\\\pm 2}{34}\\n\\\\]\\n\\nThis gives us two solutions:\\n\\n\\\\[\\n\\\\cos A = \\\\frac{34}{34} = 1 \\\\quad \\\\text{(not possible since } \\\\cos A \\\\neq 1 \\\\text{ for an acute angle)}\\n\\\\]\\n\\\\[\\n\\\\cos A = \\\\frac{30}{34} = \\\\frac{15}{17}\\n\\\\]\\n\\nNow, we find \\\\(\\\\sin A\\\\) using \\\\(\\\\sin^2 A + \\\\cos^2 A = 1\\\\):\\n\\n\\\\[\\n\\\\sin^2 A = 1 - \\\\cos^2 A = 1 - \\\\left(\\\\frac{', 'worst_step_idx': 0, 'worst_step': 'To solve for \\\\(\\\\tan A\\\\) given that the area of triangle \\\\(ABC\\\\) is \\\\(a^2 - (b - c)^2\\\\), we start by expressing the area using the standard formula for the area of a triangle with sides \\\\(a\\\\), \\\\(b\\\\), and \\\\(c\\\\):'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "  0%|          | 0/100 [00:00<?, ?it/s]\u001b[A\n",
      "  8%|▊         | 8/100 [00:00<00:01, 73.42it/s]\u001b[A\n",
      " 16%|█▌        | 16/100 [00:00<00:01, 59.57it/s]\u001b[A\n",
      " 23%|██▎       | 23/100 [00:00<00:01, 59.63it/s]\u001b[A\n",
      " 30%|███       | 30/100 [00:00<00:01, 60.06it/s]\u001b[A\n",
      " 37%|███▋      | 37/100 [00:00<00:01, 58.84it/s]\u001b[A\n",
      " 44%|████▍     | 44/100 [00:00<00:00, 61.81it/s]\u001b[A\n",
      " 51%|█████     | 51/100 [00:00<00:00, 61.55it/s]\u001b[A\n",
      " 58%|█████▊    | 58/100 [00:00<00:00, 61.73it/s]\u001b[A\n",
      " 66%|██████▌   | 66/100 [00:01<00:00, 64.92it/s]\u001b[A\n",
      " 73%|███████▎  | 73/100 [00:01<00:00, 62.65it/s]\u001b[A\n",
      " 80%|████████  | 80/100 [00:01<00:00, 60.53it/s]\u001b[A\n",
      " 87%|████████▋ | 87/100 [00:01<00:00, 59.70it/s]\u001b[A\n",
      " 93%|█████████▎| 93/100 [00:01<00:00, 57.33it/s]\u001b[A\n",
      "100%|██████████| 100/100 [00:01<00:00, 60.43it/s]\u001b[A\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------- Epoch 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Generating Completions:   0%|          | 0/66 [00:00<?, ?it/s]\u001b[A"
     ]
    },
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 28.28 GiB. GPU ",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOutOfMemoryError\u001b[0m                          Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 10\u001b[0m\n\u001b[1;32m      7\u001b[0m args\u001b[38;5;241m.\u001b[39mPRM_judge \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m      9\u001b[0m set_seed(args\u001b[38;5;241m.\u001b[39mseed)\n\u001b[0;32m---> 10\u001b[0m \u001b[43mrun_eval\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[4], line 59\u001b[0m, in \u001b[0;36mrun_eval\u001b[0;34m(args_obj)\u001b[0m\n\u001b[1;32m     56\u001b[0m         \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[1;32m     57\u001b[0m     logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mProcessing dataset: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdata_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 59\u001b[0m     dataset_result \u001b[38;5;241m=\u001b[39m \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mllm_instance\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokenizer_instance\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs_obj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     60\u001b[0m     results\u001b[38;5;241m.\u001b[39mappend(dataset_result)\n\u001b[1;32m     62\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m results:\n",
      "File \u001b[0;32m/scratch/gpfs/yh0068/slm-math/adaptmi/evaluation/math_eval.py:395\u001b[0m, in \u001b[0;36mmain\u001b[0;34m(llm, tokenizer, data_name, args)\u001b[0m\n\u001b[1;32m    393\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m [output\u001b[38;5;241m.\u001b[39moutputs[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mtext \u001b[38;5;28;01mfor\u001b[39;00m output \u001b[38;5;129;01min\u001b[39;00m outputs]\n\u001b[1;32m    394\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 395\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_completions\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    396\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mllm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    397\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    398\u001b[0m \u001b[43m        \u001b[49m\u001b[43mprompts\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprompts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    399\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_tokens_per_call\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    400\u001b[0m \u001b[43m        \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m16\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    401\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstop_id_sequences\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstop_words\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    402\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    404\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(outputs) \u001b[38;5;241m==\u001b[39m 
\u001b[38;5;28mlen\u001b[39m(current_prompts)\n\u001b[1;32m    406\u001b[0m \u001b[38;5;66;03m# process all outputs\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/matheval/lib/python3.10/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    114\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/scratch/gpfs/yh0068/slm-math/adaptmi/evaluation/model_utils.py:96\u001b[0m, in \u001b[0;36mgenerate_completions\u001b[0;34m(model, tokenizer, prompts, batch_size, stop_id_sequences, add_special_tokens, disable_tqdm, **generation_kwargs)\u001b[0m\n\u001b[1;32m     94\u001b[0m \u001b[38;5;66;03m# try:\u001b[39;00m\n\u001b[1;32m     95\u001b[0m stop_criteria \u001b[38;5;241m=\u001b[39m KeywordsStoppingCriteria(stop_id_sequences, tokenizer)\n\u001b[0;32m---> 96\u001b[0m batch_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     97\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_input_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     98\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     99\u001b[0m \u001b[43m    \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mStoppingCriteriaList\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mstop_criteria\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    100\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;66;43;03m# stopping_criteria=[KeyWordsCriteria(stop_id_sequences)] if stop_id_sequences else None,\u001b[39;49;00m\n\u001b[1;32m    101\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;66;43;03m# stopping_criteria=[KeyWordsCriteriaTrunc(stop_id_sequences, batch_input_ids.size(1))] if stop_id_sequences else None,\u001b[39;49;00m\n\u001b[1;32m    102\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mgeneration_kwargs\u001b[49m\n\u001b[1;32m    103\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    105\u001b[0m \u001b[38;5;66;03m# the stopping criteria is 
applied at batch level, so if other examples are not stopped, the entire batch will continue to generate.\u001b[39;00m\n\u001b[1;32m    106\u001b[0m \u001b[38;5;66;03m# so some outputs still have the stop sequence, which we need to remove.\u001b[39;00m\n\u001b[1;32m    107\u001b[0m \u001b[38;5;66;03m# if stop_id_sequences:\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    116\u001b[0m \u001b[38;5;66;03m# we changed our previous way of truncating the output token ids dicrectly because some tokenizer (e.g., llama) won't add space token before the first token.\u001b[39;00m\n\u001b[1;32m    117\u001b[0m \u001b[38;5;66;03m# space is important for some tasks (e.g., code completion).\u001b[39;00m\n\u001b[1;32m    118\u001b[0m batch_outputs \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mbatch_decode(batch_outputs, skip_special_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
      "File \u001b[0;32m~/.conda/envs/matheval/lib/python3.10/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    114\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.conda/envs/matheval/lib/python3.10/site-packages/transformers/generation/utils.py:1914\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m   1906\u001b[0m     input_ids, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_expand_inputs_for_generation(\n\u001b[1;32m   1907\u001b[0m         input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[1;32m   1908\u001b[0m         expand_size\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_return_sequences,\n\u001b[1;32m   1909\u001b[0m         is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[1;32m   1910\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[1;32m   1911\u001b[0m     )\n\u001b[1;32m   1913\u001b[0m     \u001b[38;5;66;03m# 13. 
run sample (it degenerates to greedy search when `generation_config.do_sample=False`)\u001b[39;00m\n\u001b[0;32m-> 1914\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1915\u001b[0m \u001b[43m        \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1916\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_logits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1917\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlogits_warper\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_logits_warper\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1918\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_stopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1919\u001b[0m \u001b[43m        \u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1920\u001b[0m \u001b[43m        \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1921\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstreamer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstreamer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1922\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1923\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1925\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m generation_mode \u001b[38;5;129;01min\u001b[39;00m (GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SAMPLE, GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SEARCH):\n\u001b[1;32m   
1926\u001b[0m     \u001b[38;5;66;03m# 11. prepare logits warper\u001b[39;00m\n\u001b[1;32m   1927\u001b[0m     prepared_logits_warper \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m   1928\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_logits_warper(generation_config, device\u001b[38;5;241m=\u001b[39minput_ids\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m   1929\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m generation_config\u001b[38;5;241m.\u001b[39mdo_sample\n\u001b[1;32m   1930\u001b[0m         \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1931\u001b[0m     )\n",
      "File \u001b[0;32m~/.conda/envs/matheval/lib/python3.10/site-packages/transformers/generation/utils.py:2651\u001b[0m, in \u001b[0;36mGenerationMixin._sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)\u001b[0m\n\u001b[1;32m   2648\u001b[0m model_inputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprepare_inputs_for_generation(input_ids, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs)\n\u001b[1;32m   2650\u001b[0m \u001b[38;5;66;03m# forward pass to get next token\u001b[39;00m\n\u001b[0;32m-> 2651\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2652\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2653\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m   2654\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2655\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2656\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2658\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m synced_gpus \u001b[38;5;129;01mand\u001b[39;00m this_peer_finished:\n\u001b[1;32m   2659\u001b[0m     \u001b[38;5;28;01mcontinue\u001b[39;00m  \u001b[38;5;66;03m# don't waste resources running the code we don't need\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/matheval/lib/python3.10/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1530\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.conda/envs/matheval/lib/python3.10/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1539\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1540\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1544\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/matheval/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1236\u001b[0m, in \u001b[0;36mQwen2ForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m   1234\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m   1235\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n\u001b[0;32m-> 1236\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[43mlogits\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1238\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1239\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m labels \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   1240\u001b[0m     \u001b[38;5;66;03m# Shift so that tokens < n predict n\u001b[39;00m\n",
      "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 28.28 GiB. GPU "
     ]
    }
   ],
   "source": [
    "# --- AdaptMI Evaluation ---\n",
    "# Build a fresh default Args(), apply the overrides for this run, then seed and launch\n",
    "# the evaluation.\n",
    "# NOTE(review): the saved output of this cell ends in a CUDA OutOfMemoryError raised inside\n",
    "# generate_completions (invoked with batch_size=16 from evaluation/math_eval.py); confirm\n",
    "# that enough GPU memory is free before re-running.\n",
    "args = Args()\n",
    "\n",
    "# Overrides are applied in insertion order, one attribute at a time.\n",
    "stage2_overrides = {\n",
    "    # Append the -skill suffix to the default dataset name (logged as math-skill).\n",
    "    \"data_names\": args.data_names + \"-skill\",\n",
    "    # Presumably the number of in-context examples per prompt (the logged prompt\n",
    "    # contains five worked examples) -- TODO confirm against the evaluation code.\n",
    "    \"num_skill_shots\": 5,\n",
    "    # Input samples, read from the stage-1 output directory.\n",
    "    \"data_path\": \"./output/stage1_classified/size100_thres1=0.9_thres2=0.7_save_data.jsonl\",\n",
    "    # Destination directory for this run.\n",
    "    \"output_dir\": \"./output/stage2_inference\",\n",
    "    # NOTE(review): semantics of this flag live in the evaluation code, not in this cell.\n",
    "    \"PRM_judge\": True,\n",
    "}\n",
    "for field_name, field_value in stage2_overrides.items():\n",
    "    setattr(args, field_name, field_value)\n",
    "\n",
    "# Seed from the configured value before starting the run.\n",
    "set_seed(args.seed)\n",
    "run_eval(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "matheval [~/.conda/envs/matheval/]",
   "language": "python",
   "name": "conda_matheval"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  },
  "vscode": {
   "interpreter": {
    "hash": "0240d59f15d8bfdbf895e85c2da540f18e732f8198eeec118c9bc623318a322c"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
