{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited modules (e.g. the local `utils`) automatically before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import json\n",
    "import numpy as np\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "from utils import score_fast, append_sol_and_remove_eos, remove_eos_and_pad_left"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# bsz: examples sampled per micro-batch; grad_acc: micro-batches accumulated per optimizer step.\n",
    "bsz = 32\n",
    "grad_acc = 8\n",
    "\n",
    "# lr: base learning rate handed to AdamW; the schedule ramps up over warmup_steps and then\n",
    "# decays linearly to zero at total_steps (total_steps is also the number of training steps).\n",
    "lr = 0.001\n",
    "warmup_steps = 20\n",
    "total_steps = 100\n",
    "\n",
    "# train_samples selects the training file data/subj/train.<train_samples>.jsonl;\n",
    "# log_interval is the number of steps between loss printouts.\n",
    "train_samples = 20\n",
    "log_interval = 10\n",
    "\n",
    "# Seed applied to numpy, random and torch.\n",
    "rng_seed = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG this notebook draws from (numpy, random, torch) with the same value.\n",
    "for _seed_fn in (np.random.seed, random.seed, torch.manual_seed):\n",
    "    _seed_fn(rng_seed)\n",
    "\n",
    "# Request deterministic cuDNN kernels and switch off its benchmark-based autotuner.\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_to_use = 'instruct-gpt-j-fp16' # 'gpt2'\n",
    "\n",
    "# Hub checkpoint and extra from_pretrained kwargs for every supported model.\n",
    "supported_models = {\n",
    "    'instruct-gpt-j-fp16': ('nlpcloud/instruct-gpt-j-fp16', {'torch_dtype': torch.bfloat16}),\n",
    "    'gpt2': ('gpt2', {}),\n",
    "}\n",
    "\n",
    "# Fail with a clear message instead of a NameError on `model` further down.\n",
    "if model_to_use not in supported_models:\n",
    "    raise ValueError(f'Unsupported model_to_use={model_to_use!r}; expected one of {sorted(supported_models)}')\n",
    "\n",
    "checkpoint, model_kwargs = supported_models[model_to_use]\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs)\n",
    "\n",
    "# The trailing semicolon stops the notebook from echoing the whole module tree.\n",
    "model.to('cuda');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "answers = [ 'objective', 'subjective' ]\n",
    "\n",
    "# Vocabulary ids of the two answer tokens (with the tokenizer's leading-space marker).\n",
    "obj_id = tokenizer.vocab['Ġobjective']\n",
    "subj_id = tokenizer.vocab['Ġsubjective']\n",
    "\n",
    "intro_prompt = 'Classify this movie review as objective or subjective: \"'\n",
    "cot_prompt = '\" It is'\n",
    "\n",
    "\n",
    "def load_jsonl(path):\n",
    "    \"\"\"Parse a JSON-lines file into a list of objects, closing the file when done.\n",
    "\n",
    "    Blank lines are skipped so a trailing empty line does not raise.\n",
    "    \"\"\"\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        return [json.loads(line) for line in f if line.strip()]\n",
    "\n",
    "\n",
    "def build_examples(samples):\n",
    "    \"\"\"Return (queries, sols) for `samples`.\n",
    "\n",
    "    Each query wraps sample['text'] in the prompt; each sol is sample['label_text']\n",
    "    prefixed with a single space.\n",
    "    \"\"\"\n",
    "    queries = [intro_prompt + sample['text'] + cot_prompt for sample in samples]\n",
    "    sols = [' ' + sample['label_text'] for sample in samples]\n",
    "    return queries, sols\n",
    "\n",
    "\n",
    "# Training keeps only reviews shorter than 25 whitespace-separated words; the test set is unfiltered.\n",
    "data_train = [sample for sample in load_jsonl(f'data/subj/train.{train_samples}.jsonl')\n",
    "              if len(sample['text'].split()) < 25]\n",
    "data_test = load_jsonl('data/subj/test.jsonl')\n",
    "\n",
    "train_queries, train_sols = build_examples(data_train)\n",
    "test_queries, test_sols = build_examples(data_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _encode(text):\n",
    "    \"\"\"Tokenize `text` and return its input ids as a PyTorch tensor moved to the GPU.\"\"\"\n",
    "    return tokenizer(text, return_tensors='pt')['input_ids'].cuda()\n",
    "\n",
    "\n",
    "encoded_train_queries = list(map(_encode, train_queries))\n",
    "encoded_train_sols = list(map(_encode, train_sols))\n",
    "encoded_train_all_sols = list(map(_encode, (' objective.', ' subjective.')))\n",
    "encoded_test_queries = list(map(_encode, test_queries))\n",
    "\n",
    "# Padding reuses the end-of-sequence token id.\n",
    "eos_token_id = tokenizer.eos_token_id\n",
    "pad_token_id = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: show the first ten training solutions (label text with a leading space).\n",
    "train_sols[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ccJtHB9IJX2o"
   },
   "outputs": [],
   "source": [
    "# The modules that receive LoRA adapters depend on which base model was loaded.\n",
    "if model_to_use == 'instruct-gpt-j-fp16':\n",
    "    _lora_target_modules = [\"k_proj\", \"v_proj\"]\n",
    "else:\n",
    "    _lora_target_modules = [\"c_attn\"]\n",
    "\n",
    "lora_config = LoraConfig(\n",
    "    r=256,\n",
    "    lora_alpha=16,\n",
    "    target_modules=_lora_target_modules,\n",
    "    lora_dropout=0.,\n",
    "    bias=\"none\",\n",
    "    # NOTE(review): nothing in this notebook creates a module named \"classifier\" -- confirm this entry is needed.\n",
    "    modules_to_save=[\"classifier\"],\n",
    ")\n",
    "\n",
    "# Wrap the base model with the adapters described by lora_config.\n",
    "knowledge_model = get_peft_model(model, lora_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IMLxprPGJp7_"
   },
   "outputs": [],
   "source": [
    "_param_groups = [{'params': knowledge_model.parameters(), 'lr': lr}]\n",
    "opt = torch.optim.AdamW(_param_groups, betas=(0.9, 0.99))\n",
    "\n",
    "\n",
    "def get_lr_mult_at_step(step):\n",
    "    \"\"\"Learning-rate multiplier: linear ramp to 1 over warmup_steps, then linear decay to 0 at total_steps.\"\"\"\n",
    "    past_warmup = step > warmup_steps\n",
    "    if past_warmup:\n",
    "        remaining_fraction = (total_steps - step) / (total_steps - warmup_steps)\n",
    "        return max(remaining_fraction, 0)\n",
    "    return min(step/warmup_steps, 1.)\n",
    "\n",
    "\n",
    "sched = torch.optim.lr_scheduler.LambdaLR(opt, get_lr_mult_at_step)\n",
    "\n",
    "\n",
    "def get_lr_at_step(x):\n",
    "    \"\"\"Absolute learning rate during warm-up, capped at lr (this helper does not model the decay phase).\"\"\"\n",
    "    return min(x/warmup_steps*lr, lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "dph6LMlH0ooD",
    "outputId": "cb17a778-4221-421b-f6b2-f163cad72797",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def _is_objective(sol):\n",
    "    \"\"\"Map a solution string to True (objective) or False (subjective).\n",
    "\n",
    "    Raises ValueError for anything else, so labels can never silently fall out of\n",
    "    step with the inputs they belong to.\n",
    "    \"\"\"\n",
    "    if 'objective' in sol:\n",
    "        return True\n",
    "    if 'subjective' in sol:\n",
    "        return False\n",
    "    raise ValueError(f'Unrecognised label: {sol!r}')\n",
    "\n",
    "\n",
    "# Labels never change between steps, so resolve them once instead of per sampled example.\n",
    "train_is_objective = [_is_objective(sol) for sol in train_sols]\n",
    "\n",
    "for step in range(total_steps):\n",
    "    opt.zero_grad()\n",
    "    step_loss = 0.  # mean loss over this step's micro-batches; used for logging only\n",
    "    # NOTE(review): gradients of the grad_acc micro-batches are summed, not averaged -- confirm this is intended.\n",
    "    for _micro in range(grad_acc):\n",
    "        # build a batch by sampling training examples with replacement\n",
    "        batch_input = []\n",
    "        batch_labels = []\n",
    "        for _sample in range(bsz):\n",
    "            query_ind = np.random.choice(np.arange(len(encoded_train_queries)))\n",
    "            encoded_input = encoded_train_queries[query_ind]\n",
    "            batch_input.append(encoded_input[0])  # first row of the encoded query tensor\n",
    "            batch_labels.append(train_is_objective[query_ind])\n",
    "        batch_input, position_ids, _ = remove_eos_and_pad_left(\n",
    "            batch_input, eos_token_id=eos_token_id, pad_token_id=eos_token_id)\n",
    "        position_ids = position_ids.cuda()\n",
    "        batch_labels = torch.tensor(batch_labels, device='cuda', dtype=torch.bool)\n",
    "\n",
    "        # Log-probabilities over the vocabulary at the final position of every sequence.\n",
    "        last_logprob = knowledge_model(batch_input,\n",
    "                                       attention_mask=batch_input!=eos_token_id,\n",
    "                                       position_ids=position_ids)['logits'][:, -1].log_softmax(dim=-1)\n",
    "        obj_logprob = last_logprob[:, obj_id]\n",
    "        subj_logprob = last_logprob[:, subj_id]\n",
    "        # Renormalise over the two answer tokens only, then take the negative log-likelihood of the true one.\n",
    "        partition_fn = torch.logsumexp(torch.stack([obj_logprob, subj_logprob], dim=-1), dim=-1)\n",
    "        loss = torch.where(batch_labels, -(obj_logprob - partition_fn), -(subj_logprob - partition_fn))\n",
    "        micro_loss = loss.mean()\n",
    "        micro_loss.backward()\n",
    "        step_loss += micro_loss.item() / grad_acc\n",
    "\n",
    "    opt.step()\n",
    "    sched.step()\n",
    "    if step % log_interval == 0:\n",
    "        print(f'loss: {step_loss}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_preds(model, encoded_queries, top_n = 999999, bsz = 1):\n",
    "    \"\"\"Predict objective vs. subjective for each encoded query.\n",
    "\n",
    "    Every query is paired with the ' objective' and ' subjective' continuations and both\n",
    "    pairs are scored with `score_fast`.\n",
    "\n",
    "    Args:\n",
    "        model: model handed to `score_fast`.\n",
    "        encoded_queries: sequence of encoded queries; row 0 of each entry is used.\n",
    "        top_n: only the first `top_n` queries are evaluated.\n",
    "        bsz: number of queries scored per forward pass (must be >= 1).\n",
    "\n",
    "    Returns:\n",
    "        A list with one bool per evaluated query, True when the ' objective' continuation\n",
    "        scored strictly higher than ' subjective'. A trailing partial batch is included,\n",
    "        so the list always has as many entries as queries evaluated.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `bsz` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if bsz < 1:\n",
    "        raise ValueError(f'bsz must be >= 1, got {bsz}')\n",
    "\n",
    "    pad_sequence = torch.nn.utils.rnn.pad_sequence\n",
    "    encoded_obj = tokenizer(' objective',\n",
    "                                return_tensors='pt').to('cuda')['input_ids'][0]\n",
    "    encoded_sub = tokenizer(' subjective',\n",
    "                                return_tensors='pt').to('cuda')['input_ids'][0]\n",
    "    encoded_results = pad_sequence([encoded_obj, encoded_sub], batch_first=True, padding_value=eos_token_id)\n",
    "    encoded_queries_to_use = encoded_queries[:top_n]\n",
    "\n",
    "    preds = []\n",
    "    # Step through every chunk, including a final short one, so no query is dropped.\n",
    "    for start in range(0, len(encoded_queries_to_use), bsz):\n",
    "        chunk = encoded_queries_to_use[start:start + bsz]\n",
    "        chunk_size = len(chunk)\n",
    "        batch_input = pad_sequence([x[0] for x in chunk],\n",
    "                                   batch_first=True,\n",
    "                                   padding_value=eos_token_id)\n",
    "        with torch.no_grad():\n",
    "            # Rows are laid out as (q0, obj), (q0, sub), (q1, obj), (q1, sub), ...\n",
    "            mean_reward = score_fast(model,\n",
    "                            append_sol_and_remove_eos(batch_input.repeat_interleave(2, dim=0),\n",
    "                                                      encoded_results.repeat(chunk_size, 1), eos_token_id, pad_token_id),\n",
    "                            eos_token_id=eos_token_id)\n",
    "        pred = mean_reward.reshape(chunk_size, 2)\n",
    "        preds += (pred[:, 0] > pred[:, 1]).tolist()\n",
    "    return preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def _accuracy(preds, targets):\n",
    "    \"\"\"Fraction of `preds` equal to the boolean `targets` tensor (returned as a 0-dim tensor).\"\"\"\n",
    "    return (torch.tensor(preds) == targets).sum() / len(targets)\n",
    "\n",
    "\n",
    "# Ground truth: True marks a solution containing 'objective', the class get_preds reports as True.\n",
    "true_preds_train = torch.tensor(['objective' in sol for sol in train_sols])\n",
    "true_preds = torch.tensor(['objective' in sol for sol in test_sols])\n",
    "\n",
    "knowledge_model.eval()\n",
    "train_preds = get_preds(knowledge_model, encoded_train_queries, bsz = 10)\n",
    "print(f'Train Acc : {_accuracy(train_preds, true_preds_train)}')\n",
    "test_preds = get_preds(knowledge_model, encoded_test_queries, bsz = 100)\n",
    "print(f'Test Acc : {_accuracy(test_preds, true_preds)}')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
