{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56a013a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import sys\n",
    "import json, torch\n",
    "import torch.nn.functional as F\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "from math500.math_utils import * \n",
    "\n",
    "MODEL_ID   = \"sentence-transformers/all-mpnet-base-v2\"\n",
    "BATCH_SIZE = 1\n",
    "TOP_K      = 1                                      \n",
    "device     = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(device)\n",
    "\n",
    "def mean_pooling(model_output, attention_mask):\n",
    "    token_emb = model_output[0]                       # (B, L, H)\n",
    "    mask       = attention_mask.unsqueeze(-1).expand(token_emb.shape).float()\n",
    "    return (token_emb * mask).sum(1) / torch.clamp(mask.sum(1), min=1e-9)\n",
    "\n",
    "@torch.inference_mode()\n",
    "def encode_texts(texts, tokenizer, model):\n",
    "    \"\"\"Embed `texts` in slices of BATCH_SIZE and return the stacked vectors on CPU.\n",
    "\n",
    "    Each slice is tokenized with padding and truncation, moved to `device`,\n",
    "    passed through `model`, mean-pooled with its attention mask and\n",
    "    L2-normalized along dim 1, so the dot product of two rows is their cosine\n",
    "    similarity. Note: an empty `texts` produces no slices, so `torch.cat`\n",
    "    receives an empty list.\n",
    "    \"\"\"\n",
    "\n",
    "    def embed(chunk):\n",
    "        inputs = tokenizer(chunk, padding=True, truncation=True, return_tensors=\"pt\").to(device)\n",
    "        pooled = mean_pooling(model(**inputs), inputs[\"attention_mask\"])\n",
    "        # Unit-length rows, copied to CPU before being collected.\n",
    "        return F.normalize(pooled, p=2, dim=1).cpu()\n",
    "\n",
    "    starts = range(0, len(texts), BATCH_SIZE)\n",
    "    return torch.cat([embed(texts[lo : lo + BATCH_SIZE]) for lo in starts], dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcb328b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer and encoder for MODEL_ID; the encoder is moved to `device` and set to eval mode.\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
    "model = AutoModel.from_pretrained(MODEL_ID)\n",
    "model = model.to(device)\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "22dc2f5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the test split, one JSON object per line. Blank or whitespace-only lines\n",
    "# are skipped so a stray empty line does not make json.loads raise.\n",
    "with open(\"data/math500/test.jsonl\", \"r\", encoding=\"utf-8\") as f:\n",
    "    test_data = [json.loads(ln) for ln in f if ln.strip()]\n",
    "\n",
    "# Few-shot pool (load_prompt presumably comes from the math500.math_utils star import).\n",
    "# Element 0 of each example is used as its question text.\n",
    "few_shot = load_prompt(num_shots=5)\n",
    "few_questions = [ex[0] for ex in few_shot]\n",
    "test_questions = [ex[\"problem\"] for ex in test_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "038ab7ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed both question sets; encode_texts returns unit-norm rows, so the product below is cosine similarity.\n",
    "few_embs = encode_texts(few_questions, tokenizer, model)    # (M, H)\n",
    "test_embs = encode_texts(test_questions, tokenizer, model)  # (N, H)\n",
    "\n",
    "similarity = test_embs @ few_embs.T                          # (N, M)\n",
    "topk_vals, topk_idxs = similarity.topk(TOP_K, dim=1)         # each (N, k)\n",
    "\n",
    "jsonl_path = \"embedding/math500/math500_k.jsonl\"\n",
    "os.makedirs(os.path.dirname(jsonl_path), exist_ok=True)\n",
    "\n",
    "# One record per test question: its row index and the indices of its TOP_K most similar few-shot questions.\n",
    "records = [\n",
    "    {\"test_idx\": row, \"fewshot_topk\": neighbours}\n",
    "    for row, neighbours in enumerate(topk_idxs.tolist())\n",
    "]\n",
    "\n",
    "with open(jsonl_path, \"w\", encoding=\"utf-8\") as f:\n",
    "    f.write(\"\".join(json.dumps(rec, ensure_ascii=False) + \"\\n\" for rec in records))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "proj2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
