{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c85ff65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TRAIN_DATASET = \"../datasets/train_deepcoder.parquet\"\n",
    "# EMBEDDING_MODEL = \"Qwen/Qwen3-Embedding-8B\"\n",
    "# SAVE_FILE = \"../icl_corpus/train_deepcoder_similar_questions.json\"\n",
    "# TOPK = 15\n",
    "\n",
    "TRAIN_DATASET = \"../datasets/train_DAPO-Math-17k.parquet\"\n",
    "EMBEDDING_MODEL = \"Qwen/Qwen3-Embedding-8B\"\n",
    "SAVE_FILE = \"../icl_corpus/train_DAPO-Math-17k_similar_questions.json\"\n",
    "TOPK = 15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3eb59624-68fb-4898-b8df-d91b54dd70a6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-07-15T19:26:35.014501Z",
     "iopub.status.busy": "2025-07-15T19:26:35.014269Z",
     "iopub.status.idle": "2025-07-15T19:26:35.377178Z",
     "shell.execute_reply": "2025-07-15T19:26:35.376733Z",
     "shell.execute_reply.started": "2025-07-15T19:26:35.014486Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "df = pd.read_parquet(TRAIN_DATASET)\n",
    "prompts = df[\"prompt\"].apply(lambda x: x[0][\"content\"].strip()).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "068bde1d-0658-456c-aa47-93a3b141b244",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-07-15T19:27:10.176503Z",
     "iopub.status.busy": "2025-07-15T19:27:10.176190Z",
     "iopub.status.idle": "2025-07-15T19:27:11.901778Z",
     "shell.execute_reply": "2025-07-15T19:27:11.901284Z",
     "shell.execute_reply.started": "2025-07-15T19:27:10.176484Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2754d8eebe864875b2a9c7ebe6a7d24e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model = SentenceTransformer(EMBEDDING_MODEL, model_kwargs={\"torch_dtype\": torch.bfloat16})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b11d1bcb-ae54-4e16-b821-1aafdfce189a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-07-15T19:27:15.029237Z",
     "iopub.status.busy": "2025-07-15T19:27:15.028598Z",
     "iopub.status.idle": "2025-07-15T19:27:47.597171Z",
     "shell.execute_reply": "2025-07-15T19:27:47.596338Z",
     "shell.execute_reply.started": "2025-07-15T19:27:15.029221Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f685ef9e67c849cd9ebc4118e14ca1ff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Chunks:   0%|          | 0/80 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1a4b785af1a74f5a924c2d0246997723",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Chunks:   0%|          | 0/80 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "query_embeddings = model.encode(prompts, \n",
    "                                batch_size=16, \n",
    "                                prompt_name=\"query\", \n",
    "                                device=[\"cuda:0\", \"cuda:1\", \"cuda:2\", \"cuda:3\", \"cuda:4\", \"cuda:5\", \"cuda:6\", \"cuda:7\"], \n",
    "                                show_progress_bar=True\n",
    "                            )\n",
    "\n",
    "document_embeddings = model.encode(prompts,\n",
    "                                   batch_size=16,\n",
    "                                   device=[\"cuda:0\", \"cuda:1\", \"cuda:2\", \"cuda:3\", \"cuda:4\", \"cuda:5\", \"cuda:6\", \"cuda:7\"],\n",
    "                                   show_progress_bar=True\n",
    "                                )\n",
    "\n",
    "similarity = model.similarity(query_embeddings, document_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6edbcd6f-e85e-44cb-be69-4b277d7d21a4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-07-15T19:29:23.535338Z",
     "iopub.status.busy": "2025-07-15T19:29:23.535010Z",
     "iopub.status.idle": "2025-07-15T19:29:28.463226Z",
     "shell.execute_reply": "2025-07-15T19:29:28.462366Z",
     "shell.execute_reply.started": "2025-07-15T19:29:23.535322Z"
    }
   },
   "outputs": [],
   "source": [
    "similar_questions = {}\n",
    "for idx, prompt in enumerate(prompts):\n",
    "    sim = similarity[idx]\n",
    "    similar_questions[prompt] = [prompts[k] for k in torch.argsort(sim)[-TOPK:-1].numpy()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "12445b1e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q1\n",
      "Trapezoid \\(ABCD\\) has \\(AD \\parallel BC\\), \\(BD = 1\\), \\(\\angle DBA = 23^\\circ\\), and \\(\\angle BDC = 46^\\circ\\). The ratio \\(BC: AD\\) is \\(9: 5\\). What is the value of \\(CD\\)? The original answer is in the form \\(\\frac{k}{m}\\), where \\(\\frac{k}{m}\\) is a fully simplified fraction. Please provide the value of \\(k + m\\).\n",
      "\n",
      "Let's think step by step and output the final answer within \\boxed{}.\n",
      "Q2\n",
      "In triangle \\(ABC\\), \\(\\measuredangle CBA=72^\\circ\\), \\(E\\) is the midpoint of side \\(AC\\), and \\(D\\) is a point on side \\(BC\\) such that \\(2BD=DC\\); \\(AD\\) and \\(BE\\) intersect at \\(F\\). Find the ratio of the area of triangle \\(BDF\\) to the area of quadrilateral \\(FDCE\\). The original answer is in \\(\\frac{k}{m}\\) format, please give the value of k + m.\n",
      "\n",
      "Let's think step by step and output the final answer within \\boxed{}.\n"
     ]
    }
   ],
   "source": [
    "print(\"Q1\")\n",
    "print(prompt)\n",
    "print(\"Q2\")\n",
    "print(similar_questions[prompt][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1400910f-9947-4324-b9a2-ce407f8824b5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-07-15T19:36:07.357113Z",
     "iopub.status.busy": "2025-07-15T19:36:07.356732Z",
     "iopub.status.idle": "2025-07-15T19:36:07.563248Z",
     "shell.execute_reply": "2025-07-15T19:36:07.562703Z",
     "shell.execute_reply.started": "2025-07-15T19:36:07.357093Z"
    }
   },
   "outputs": [],
   "source": [
    "import json\n",
    "with open(SAVE_FILE, 'w') as f:\n",
    "    json.dump(similar_questions, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de8131e9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "verl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
