{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import json\n",
    "import numpy as np\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "\n",
    "from utils import score_fast, append_sol_and_remove_eos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_samples = 50\n",
    "log_interval = 10\n",
    "\n",
    "rng_seed = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(rng_seed)\n",
    "random.seed(rng_seed)\n",
    "torch.manual_seed(rng_seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_to_use = 'instruct-gpt-j-fp16' # 'gpt2'\n",
    "\n",
    "if model_to_use == 'instruct-gpt-j-fp16':\n",
    "    tokenizer = AutoTokenizer.from_pretrained('nlpcloud/instruct-gpt-j-fp16')\n",
    "    model = AutoModelForCausalLM.from_pretrained('nlpcloud/instruct-gpt-j-fp16',\n",
    "                                                torch_dtype=torch.bfloat16)\n",
    "elif model_to_use == 'gpt2':\n",
    "    tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
    "    model = AutoModelForCausalLM.from_pretrained('gpt2')\n",
    "\n",
    "model.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "answers = [ 'objective', 'subjective' ]\n",
    "\n",
    "obj_id = tokenizer.vocab['Ġobjective']\n",
    "subj_id = tokenizer.vocab['Ġsubjective']\n",
    "\n",
    "data_train = [ json.loads(l) for l in open(f'data/subj/train.{train_samples}.jsonl', 'r') ]\n",
    "data_test = [ json.loads(l) for l in open('data/subj/test.jsonl', 'r') ]\n",
    "\n",
    "data_train = [sample for sample in data_train]\n",
    "data_test = [sample for sample in data_test]\n",
    "\n",
    "train_queries = []\n",
    "train_sols = []\n",
    "\n",
    "test_queries = []\n",
    "test_sols = []\n",
    "\n",
    "intro_prompt = 'Classify this movie review as objective or subjective: \"'\n",
    "cot_prompt = '\" It is'\n",
    "\n",
    "for sample in data_train:\n",
    "    train_queries.append(intro_prompt + sample['text'] + cot_prompt)\n",
    "    train_sols.append(' ' + sample['label_text'])\n",
    "\n",
    "few_show_examples = [train_queries[i] + train_sols[i] + '.\\n' for i in range(train_samples)]\n",
    "random.shuffle(few_show_examples)\n",
    "few_shot_prompt = ''.join(few_show_examples)\n",
    "    \n",
    "for sample in data_test:\n",
    "    test_queries.append(few_shot_prompt+intro_prompt + sample['text'] + cot_prompt)\n",
    "    test_sols.append(' ' + sample['label_text'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_train_queries = [tokenizer(query, return_tensors='pt')['input_ids'].cuda() for query in train_queries]\n",
    "encoded_train_sols = [tokenizer(answer, return_tensors='pt')['input_ids'].cuda() for answer in train_sols]\n",
    "encoded_train_all_sols = [tokenizer(' objective.', return_tensors='pt')['input_ids'].cuda(),\n",
    "                          tokenizer(' subjective.', return_tensors='pt')['input_ids'].cuda()]\n",
    "encoded_test_queries = [tokenizer(query, return_tensors='pt')['input_ids'].cuda() for query in test_queries]\n",
    "\n",
    "eos_token_id = tokenizer.eos_token_id\n",
    "pad_token_id = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_preds(model, encoded_queries, top_n = 999999, bsz = 1):\n",
    "    preds = []\n",
    "    encoded_obj = tokenizer(' objective',\n",
    "                                return_tensors='pt').to('cuda')['input_ids'][0]\n",
    "    encoded_sub = tokenizer(' subjective',\n",
    "                                return_tensors='pt').to('cuda')['input_ids'][0]\n",
    "    encoded_results = torch.nn.utils.rnn.pad_sequence([encoded_obj, encoded_sub], batch_first=True, padding_value=eos_token_id)\n",
    "    encoded_queries_to_use = encoded_queries[:top_n]\n",
    "    for i in range(len(encoded_queries_to_use) // bsz):\n",
    "        batch_input = torch.nn.utils.rnn.pad_sequence([x[0] for x in encoded_queries_to_use[i*bsz:(i+1)*bsz]],\n",
    "                                                      batch_first=True,\n",
    "                                                      padding_value=eos_token_id)\n",
    "        with torch.no_grad():\n",
    "            mean_reward = score_fast(model,\n",
    "                            append_sol_and_remove_eos(batch_input.repeat_interleave(2, dim=0),\n",
    "                                                      encoded_results.repeat(bsz, 1), eos_token_id, pad_token_id),\n",
    "                            eos_token_id=eos_token_id)\n",
    "        pred = mean_reward.reshape(bsz, 2)\n",
    "        preds += (pred[:, 0] > pred[:, 1]).tolist()\n",
    "    return preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "true_preds_train = torch.tensor([True if 'objective' in sol else False for sol in train_sols])\n",
    "true_preds = torch.tensor([True if 'objective' in sol else False for sol in test_sols])\n",
    "\n",
    "model.eval()\n",
    "train_preds = get_preds(model, encoded_train_queries, bsz = 10)\n",
    "print(f'Train Acc : {(torch.tensor(train_preds) == true_preds_train).sum() / len(true_preds_train)}')\n",
    "test_preds = get_preds(model, encoded_test_queries, bsz = 10)\n",
    "print(f'Test Acc : {(torch.tensor(test_preds) == true_preds).sum() / len(true_preds)}')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
