{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.cuda.get_device_name(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BertTokenizer, BertForMaskedLM, BertModel\n",
    "from tokenizer import *\n",
    "import pickle\n",
    "from torch.utils.data import DataLoader\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from data_vcpkg import help_tokenize, load_paired_data, FunctionDataset_CL, FunctionDataset_CL_Load\n",
    "from transformers import AdamW\n",
    "import torch.nn.functional as F\n",
    "import argparse\n",
    "import logging\n",
    "import sys\n",
    "import time\n",
    "import data_vcpkg as data\n",
    "import pickle\n",
    "import sys\n",
    "from datautils_windows.playdata_vcpkg import DatasetBase as DatasetBase\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import argparse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def finetune_eval(net, data_loader):\n",
    "    net.eval()\n",
    "    with torch.no_grad():\n",
    "        avg=[]\n",
    "        gt=[]\n",
    "        cons=[]\n",
    "        eval_iterator = tqdm(data_loader)\n",
    "        for i, (seq1,seq2,_,mask1,mask2,_) in enumerate(eval_iterator):\n",
    "            input_ids1, attention_mask1= seq1.cuda(),mask1.cuda()\n",
    "            input_ids2, attention_mask2= seq2.cuda(),mask2.cuda()\n",
    "\n",
    "            anchor,pos=0,0\n",
    "\n",
    "            output1 = model(input_ids=input_ids1, attention_mask=attention_mask1)\n",
    "            anchor = output1.pooler_output\n",
    "\n",
    "            output2 = model(input_ids=input_ids2, attention_mask=attention_mask2)\n",
    "            pos = output2.pooler_output\n",
    "\n",
    "            ans=0\n",
    "            for i in range(len(anchor)):    # check every vector of (vA,vB)\n",
    "                vA=anchor[i:i+1].cuda()  #pos[i]\n",
    "                sim=[]\n",
    "                for j in range(len(pos)):\n",
    "                    vB=pos[j:j+1].cuda()   # pos[j]\n",
    "                    AB_sim=F.cosine_similarity(vA, vB).item()\n",
    "                    sim.append(AB_sim)\n",
    "                    if j!=i:\n",
    "                        cons.append(AB_sim)\n",
    "                sim=np.array(sim)\n",
    "                y=np.argsort(-sim)\n",
    "                posi=0\n",
    "                for j in range(len(pos)):\n",
    "                    if y[j]==i:\n",
    "                        posi=j+1\n",
    "\n",
    "                gt.append(sim[i])\n",
    "                ans+=1/posi\n",
    "            ans=ans/len(anchor)\n",
    "            avg.append(ans)\n",
    "        print(avg)\n",
    "        return np.mean(np.array(avg))\n",
    "\n",
    "class BinBertModel(BertModel):\n",
    "    def __init__(self, config, add_pooling_layer=True):\n",
    "        super().__init__(config)\n",
    "        self.config = config\n",
    "        self.embeddings.position_embeddings=self.embeddings.word_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"models/jTrans-finetune\"\n",
    "model = BinBertModel.from_pretrained(model_path)\n",
    "\n",
    "eval_path = \"/data/jTrans/some_extract\"\n",
    "model = nn.DataParallel(model)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained(\"./jtrans_tokenizer\")\n",
    "valid_set = FunctionDataset_CL_Load(tokenizer, eval_path, convert_jump_addr=True, opt=None)\n",
    "valid_dataloader = DataLoader(valid_set, batch_size=16, num_workers=128, shuffle=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mrr = finetune_eval(model, valid_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mrr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jtrans",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
