{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "\n",
    "cache_dir = \"/work/hdd/bdkj/audreyh/.cache\"\n",
    "os.environ['XDG_CACHE_HOME'] = cache_dir\n",
    "# os.environ['HF_HOME'] = cache_dir\n",
    "# os.environ['HF_TOKEN_PATH'] = f\"{cache_dir}/token\"\n",
    "# os.environ['HF_HUB_CACHE'] = f\"{cache_dir}/hub\"\n",
    "\n",
    "\n",
    "import numpy as np \n",
    "import torch \n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline\n",
    "\n",
    "# print(os.environ.get('HF_HOME'))\n",
    "# print(os.environ.get('HF_HUB_CACHE'))\n",
    "# print(os.environ.get('HF_TOKEN_PATH'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1957075e40ba40d48b70782d91ef5424",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model_name = 'weqweasdas/RM-Gemma-2B'\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat = [\n",
    "   {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n",
    "   {\"role\": \"assistant\", \"content\": \"I'm doing great. How can I help you today?\"},\n",
    "   {\"role\": \"user\", \"content\": \"I'd like to show off how chat templating works!\"},\n",
    "  ]\n",
    "\n",
    "test_texts = [tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False).replace(tokenizer.bos_token, \"\")]\n",
    "_tok_texts = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=False, return_tensors='pt')\n",
    "\n",
    "mask = (_tok_texts != tokenizer.bos_token_id)\n",
    "tok_texts = _tok_texts[mask].unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad(): \n",
    "    out = model(tok_texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-1.6332]])\n"
     ]
    }
   ],
   "source": [
    "logits = out['logits']\n",
    "print(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "76d671e6c7294f5b9f1b7892aaa858cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "rm_pipe = pipeline(\n",
    "    # \"sentiment-analysis\",\n",
    "    model=\"weqweasdas/RM-Gemma-2B\",\n",
    "    device=\"cuda\",\n",
    "    tokenizer=tokenizer,\n",
    "    model_kwargs={\"torch_dtype\": torch.bfloat16}\n",
    ")\n",
    "pipe_kwargs = {\n",
    "      \"return_all_scores\": True,\n",
    "      \"function_to_apply\": \"none\",\n",
    "      \"batch_size\": 1\n",
    "  }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/audreyh/anaconda3/envs/rlhf/lib/python3.10/site-packages/transformers/pipelines/text_classification.py:106: UserWarning: `return_all_scores` is now deprecated,  if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "pipe_outputs = rm_pipe(test_texts, **pipe_kwargs)\n",
    "rewards = [output[0][\"score\"] for output in pipe_outputs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'label': 'LABEL_0', 'score': -2.734375}]]"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pipe_outputs"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
