{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x38eeb7510>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../..\")\n",
    "\n",
    "from copy import deepcopy\n",
    "import omegaconf\n",
    "import torch\n",
    "import pandas as pd\n",
    "from next_sentence.utils import FrozenModelSentenceGivenPrompt, ReplayBuffer\n",
    "from next_sentence.lightning_module import NextSentenceGFNTask\n",
    "from train_next_sentence import get_model\n",
    "\n",
    "# Inference-only notebook: switch autograd off globally. The trailing ';' keeps\n",
    "# the returned set_grad_enabled object from being echoed as the cell output.\n",
    "torch.set_grad_enabled(False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal config handed to get_model: base LM name plus the LoRA adapter settings.\n",
    "config = omegaconf.DictConfig(\n",
    "    {\n",
    "        \"task\": {\n",
    "            \"model\": {\n",
    "                \"name\": \"distilgpt2\",\n",
    "                \"lora_config\": {\n",
    "                    \"_target_\": \"peft.LoraConfig\",\n",
    "                    \"target_modules\": [\"c_attn\", \"c_proj\", \"c_fc\"],\n",
    "                    \"r\": 64,\n",
    "                    \"lora_alpha\": 16,\n",
    "                    \"lora_dropout\": 0.1,\n",
    "                    \"bias\": \"none\",\n",
    "                    \"fan_in_fan_out\": True,\n",
    "                },\n",
    "            }\n",
    "        }\n",
    "    }\n",
    ")\n",
    "\n",
    "# Use the best accelerator that is actually present instead of hard-coding \"mps\",\n",
    "# so the notebook also runs on CUDA or CPU-only machines.\n",
    "if torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "elif torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "model, tokenizer = get_model(config)\n",
    "# Copy the freshly built network before the checkpoint is loaded; presumably this is\n",
    "# the un-finetuned reference (assumes load_from_checkpoint fills `model` in place - TODO confirm).\n",
    "model_orig_ = deepcopy(model).to(device)\n",
    "model = NextSentenceGFNTask.load_from_checkpoint(\n",
    "    \"~/Downloads/best.ckpt\", model=model, tokenizer=tokenizer, map_location=device\n",
    ")\n",
    "# Clone the loaded task and swap in the pre-checkpoint network, so the two tasks\n",
    "# differ only in their `.model` attribute.\n",
    "model_orig = deepcopy(model)\n",
    "model_orig.model = model_orig_\n",
    "del model_orig_\n",
    "\n",
    "model.eval()\n",
    "model_orig.eval()\n",
    "# Same reward temperature on both tasks so their outputs are comparable.\n",
    "model.reward.temperature = 1.0\n",
    "model_orig.reward.temperature = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompts to probe; each keeps its trailing space exactly as written.\n",
    "prompts = [\n",
    "    \"My dog is black. \",\n",
    "    \"I've been working on this far too long. \",\n",
    "    \"There's something I need to get off my chest. \",\n",
    "]\n",
    "# Tokenize each prompt, take the first row of input_ids and move it to the model's device.\n",
    "prompts_tensor = [\n",
    "    encoding.input_ids[0].to(model.device)\n",
    "    for encoding in (tokenizer(text, return_tensors=\"pt\") for text in prompts)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seed so the sampled sentences can be reproduced on re-run\n",
    "# (assumes sample_probes draws from torch's global generator - TODO confirm).\n",
    "torch.manual_seed(0)\n",
    "samples = model.sample_probes(prompts_tensor, n_samples=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seed so the reference model's samples can be reproduced on re-run\n",
    "# (assumes sample_probes draws from torch's global generator - TODO confirm).\n",
    "torch.manual_seed(0)\n",
    "samples_orig = model_orig.sample_probes(prompts_tensor, n_samples=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Prompt</th>\n",
       "      <th>Sampled sentence</th>\n",
       "      <th>logP(s)</th>\n",
       "      <th>Sampled sentence (beam)</th>\n",
       "      <th>logP(s) (beam)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>There's something I need to get off my chest.</td>\n",
       "      <td>????</td>\n",
       "      <td>-8.783124</td>\n",
       "      <td>ㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠ�</td>\n",
       "      <td>-24.739201</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>My dog is black.</td>\n",
       "      <td>________</td>\n",
       "      <td>-9.519690</td>\n",
       "      <td>__________________\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n...</td>\n",
       "      <td>-20.008408</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>I've been working on this far too long.</td>\n",
       "      <td>\\nhttp://www</td>\n",
       "      <td>-11.228891</td>\n",
       "      <td>_________________\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\...</td>\n",
       "      <td>-21.900307</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           Prompt Sampled sentence    logP(s)  \\\n",
       "2  There's something I need to get off my chest.              ????  -8.783124   \n",
       "0                               My dog is black.          ________  -9.519690   \n",
       "1        I've been working on this far too long.      \\nhttp://www -11.228891   \n",
       "\n",
       "                             Sampled sentence (beam)  logP(s) (beam)  \n",
       "2                                  ㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠㅠ�      -24.739201  \n",
       "0  __________________\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n...      -20.008408  \n",
       "1  _________________\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\n\\...      -21.900307  "
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Probe results returned by sample_probes for the checkpoint-loaded task.\n",
    "samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe results returned by sample_probes for the task holding the pre-checkpoint network.\n",
    "samples_orig"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gfn-lm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
