{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6873dc86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TRAIN_DATASET = \"../datasets/train_deepcoder.parquet\"\n",
    "TRAIN_DATASET = \"../datasets/train_DAPO-Math-17k.parquet\"\n",
    "\n",
    "WARMUP_GENERATIONS = 2  # completions sampled per prompt (SamplingParams n)\n",
    "MODEL = \"Qwen/Qwen3-1.7B\"\n",
    "\n",
    "# Derive the output path from TRAIN_DATASET and MODEL so switching datasets cannot\n",
    "# silently write into another dataset's corpus file. With the values above this is\n",
    "# ../icl_corpus/train_DAPO-Math-17k_icl_corpus_qwen3_1.7b.json\n",
    "_dataset_stem = TRAIN_DATASET.rsplit(\"/\", 1)[-1].removesuffix(\".parquet\")\n",
    "_model_slug = MODEL.rsplit(\"/\", 1)[-1].lower().replace(\"-\", \"_\")\n",
    "SAVE_FILE = f\"../icl_corpus/{_dataset_stem}_icl_corpus_{_model_slug}.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bc61a76",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import datasets\n",
    "import numpy as np\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import AutoTokenizer\n",
    "from verl.utils.reward_score import default_compute_score\n",
    "from concurrent.futures import ProcessPoolExecutor, as_completed\n",
    "\n",
    "def filter_tests(tests, max_tests=15):\n",
    "    \"\"\"Keep at most `max_tests` test cases, preferring those with the longest inputs.\n",
    "\n",
    "    Tests are ranked by `len(str(test.get(\"input\")))`, longest first; Python's\n",
    "    stable sort keeps the original order among tests of equal input length.\n",
    "    Each test is expected to be a dict (`.get` is called on it).\n",
    "    NOTE(review): presumably longer inputs are the more demanding cases — confirm.\n",
    "    \"\"\"\n",
    "    ranked = sorted(tests, key=lambda test: -len(str(test.get(\"input\"))))\n",
    "    return ranked[:max_tests]\n",
    "\n",
    "def reward_model_processor(reward_models):\n",
    "    \"\"\"Batched `datasets.map` function normalising the `reward_model` column.\n",
    "\n",
    "    Entries stored as JSON strings are decoded, their `ground_truth` test list is\n",
    "    trimmed with `filter_tests`, and that list is re-encoded as a JSON string.\n",
    "    Any other entry (e.g. an already-decoded dict) is passed through unchanged.\n",
    "    Returns the new column as {\"reward_model\": [...]}.\n",
    "    \"\"\"\n",
    "    processed_reward_models = []\n",
    "    for reward_model in reward_models:\n",
    "        if isinstance(reward_model, str):\n",
    "            reward_model = json.loads(reward_model)\n",
    "            selected_tests = filter_tests(reward_model[\"ground_truth\"])\n",
    "            reward_model[\"ground_truth\"] = json.dumps(selected_tests)\n",
    "        processed_reward_models.append(reward_model)\n",
    "    return {\"reward_model\": processed_reward_models}\n",
    "\n",
    "# Load the training split, trim each row's reward_model tests, then hand off to pandas.\n",
    "raw_dataset = datasets.load_dataset(\"parquet\", data_files=TRAIN_DATASET)[\"train\"]\n",
    "dataframe = raw_dataset.map(\n",
    "    reward_model_processor,\n",
    "    input_columns=\"reward_model\",\n",
    "    batched=True,\n",
    "    batch_size=16,\n",
    "    num_proc=8,\n",
    "    desc=\"Processing reward_model\",\n",
    ")\n",
    "df = dataframe.to_pandas()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47d9a854-eb22-4965-a426-3f87a0a7dfae",
   "metadata": {
    "collapsed": true,
    "execution": {
     "iopub.execute_input": "2025-07-16T04:01:35.812679Z",
     "iopub.status.busy": "2025-07-16T04:01:35.812215Z",
     "iopub.status.idle": "2025-07-16T04:03:49.790819Z",
     "shell.execute_reply": "2025-07-16T04:03:49.790300Z",
     "shell.execute_reply.started": "2025-07-16T04:01:35.812648Z"
    },
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "from vllm import LLM, SamplingParams\n",
    "\n",
    "TENSOR_PARALLEL_SIZE = 4\n",
    "MAX_NEW_TOKENS = 16384\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
    "# Render every chat-formatted prompt to a plain string; enable_thinking is\n",
    "# forwarded to the model's chat template.\n",
    "prompts = df[\"prompt\"].apply(\n",
    "    lambda x: tokenizer.apply_chat_template(x, tokenize=False, add_generation_prompt=True, enable_thinking=True)\n",
    ").tolist()\n",
    "\n",
    "llm = LLM(model=MODEL, tensor_parallel_size=TENSOR_PARALLEL_SIZE)\n",
    "sampling_params = SamplingParams(\n",
    "    temperature=1.0, top_p=0.95, n=WARMUP_GENERATIONS, max_tokens=MAX_NEW_TOKENS\n",
    ")\n",
    "# Generate for *every* prompt: later cells pair these completions with\n",
    "# user_prompts / ground_truths row by row, and build_icl_corpus_mp asserts equal lengths.\n",
    "outputs = llm.generate(prompts, sampling_params)\n",
    "gen_outputs = [[completion.text for completion in output.outputs] for output in outputs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb22ff24",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Remove <think> ... </think>\n",
    "# Keep only the text after the last </think>; a completion without a closing tag\n",
    "# (presumably truncated at max_tokens) has no final answer and becomes \"\".\n",
    "# gen_outputs is rebuilt from the raw vLLM `outputs` rather than edited in place,\n",
    "# so this cell is idempotent (stripping in place would blank every answer on a re-run).\n",
    "def _strip_think(text):\n",
    "    _reasoning, sep, answer = text.rpartition(\"</think>\")\n",
    "    return answer.strip() if sep else \"\"\n",
    "\n",
    "if MODEL == \"Qwen/Qwen3-1.7B\":\n",
    "    gen_outputs = [\n",
    "        [_strip_think(completion.text) for completion in output.outputs]\n",
    "        for output in outputs\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbfe6249-9fa1-4e08-bb81-b078db87c2e7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-07-16T04:37:42.276428Z",
     "iopub.status.busy": "2025-07-16T04:37:42.275894Z",
     "iopub.status.idle": "2025-07-16T04:37:42.282591Z",
     "shell.execute_reply": "2025-07-16T04:37:42.282152Z",
     "shell.execute_reply.started": "2025-07-16T04:37:42.276403Z"
    }
   },
   "outputs": [],
   "source": [
    "# Per-row scoring inputs, index-aligned with gen_outputs.\n",
    "# NOTE(review): takes the first chat message as the user turn — confirm there is no system message.\n",
    "prompt_col, reward_col = df[\"prompt\"], df[\"reward_model\"]\n",
    "user_prompts = prompt_col.map(lambda messages: messages[0][\"content\"].strip()).values\n",
    "ground_truths = reward_col.map(lambda reward: reward[\"ground_truth\"]).values\n",
    "data_source = df[\"data_source\"].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fdd83fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _score_one_prompt(args):\n",
    "    \"\"\"Score one prompt's sampled completions; return (prompt, passing completions).\n",
    "\n",
    "    `args` is the tuple (idx, user_prompt, data_src, gt_raw, outs) built in\n",
    "    build_icl_corpus_mp; idx is not used here. A completion passes when its\n",
    "    score equals 1. A completion whose scoring raises is counted as failing.\n",
    "    \"\"\"\n",
    "    _idx, user_prompt, data_src, gt_raw, outs = args\n",
    "    passed = []\n",
    "    for o in outs:\n",
    "        try:\n",
    "            result = default_compute_score(data_src, o, gt_raw)\n",
    "        except Exception:\n",
    "            # Best effort: a scorer crash on one completion just counts as a failure.\n",
    "            continue\n",
    "        # Some verl scorers (e.g. math_dapo) return a dict such as {\"score\": ..., \"acc\": ...}\n",
    "        # rather than a number; comparing that dict to 1 would always be False.\n",
    "        score = result.get(\"score\") if isinstance(result, dict) else result\n",
    "        if score == 1:\n",
    "            passed.append(str(o))\n",
    "    return user_prompt.strip(), passed\n",
    "\n",
    "def build_icl_corpus_mp(user_prompts, ground_truths, gen_outputs, data_source,\n",
    "                        max_workers=None):\n",
    "    \"\"\"Score every prompt in a process pool; map each prompt to its passing completions.\n",
    "\n",
    "    The four sequences must be index-aligned and of equal length (asserted).\n",
    "    Returns {stripped user prompt: [completions scored 1]}; prompts with no\n",
    "    passing completion map to []. Keys are inserted in input order, so the\n",
    "    result does not depend on task completion timing; for duplicate prompts the\n",
    "    later row wins. A task that raises is reported and left out of the result.\n",
    "    \"\"\"\n",
    "    assert len(user_prompts) == len(ground_truths) == len(gen_outputs) == len(data_source)\n",
    "    n = len(user_prompts)\n",
    "    items = [(i, user_prompts[i], data_source[i], ground_truths[i], gen_outputs[i]) for i in range(n)]\n",
    "\n",
    "    if max_workers is None:\n",
    "        # Leave two cores free but never ask for fewer than one worker\n",
    "        # (os.cpu_count() may be 1, 2 or None).\n",
    "        max_workers = max(1, (os.cpu_count() or 1) - 2)\n",
    "\n",
    "    results = [None] * n\n",
    "    with ProcessPoolExecutor(max_workers=max_workers) as ex:\n",
    "        future_to_idx = {ex.submit(_score_one_prompt, it): it[0] for it in items}\n",
    "        with tqdm(total=n, desc=\"Scoring\", dynamic_ncols=True) as pbar:\n",
    "            for fut in as_completed(future_to_idx):\n",
    "                idx = future_to_idx[fut]\n",
    "                try:\n",
    "                    results[idx] = fut.result()\n",
    "                except Exception as e:\n",
    "                    # Keep going even if a task failed, but report which one.\n",
    "                    print(f\"Scoring prompt {idx} failed: {e!r}\")\n",
    "                pbar.update(1)\n",
    "\n",
    "    # as_completed yields in completion order; assemble in input order instead.\n",
    "    icl_corpus = {}\n",
    "    for res in results:\n",
    "        if res is not None:\n",
    "            prompt, passed = res\n",
    "            icl_corpus[prompt] = passed\n",
    "    return icl_corpus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b24f4a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each prompt to the sampled completions the verl scorer accepts.\n",
    "icl_corpus = build_icl_corpus_mp(\n",
    "    user_prompts=user_prompts,\n",
    "    ground_truths=ground_truths,\n",
    "    gen_outputs=gen_outputs,\n",
    "    data_source=data_source,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d53496",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of prompts with at least one passing completion (NaN for an empty corpus).\n",
    "n_prompts_with_example = sum(1 for passed in icl_corpus.values() if passed)\n",
    "n_prompts_with_example / len(icl_corpus) if icl_corpus else float(\"nan\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c3efabc-7a7a-42be-8cbf-6536d6607c67",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-07-16T05:07:06.359733Z",
     "iopub.status.busy": "2025-07-16T05:07:06.359462Z",
     "iopub.status.idle": "2025-07-16T05:07:06.648225Z",
     "shell.execute_reply": "2025-07-16T05:07:06.647797Z",
     "shell.execute_reply.started": "2025-07-16T05:07:06.359716Z"
    }
   },
   "outputs": [],
   "source": [
    "import json\n",
    "# Create the output directory first; open() alone fails when it does not exist.\n",
    "os.makedirs(os.path.dirname(SAVE_FILE) or \".\", exist_ok=True)\n",
    "with open(SAVE_FILE, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(icl_corpus, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae3276fb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "verl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
