{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Automatically reload imported modules before each cell executes, so edits made to\n",
    "# them on disk are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xq77AgKWJM-N"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import json\n",
    "import numpy as np\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "from utils import score_fast, remove_eos_and_pad_left, \\\n",
    "                append_sol_and_remove_eos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters and run settings shared by the cells below.\n",
    "\n",
    "# Number of samples drawn per micro-batch in the training loop.\n",
    "bsz = 32\n",
    "# Micro-batches whose gradients are accumulated before each optimizer step.\n",
    "grad_acc = 8\n",
    "# The training loop prints the loss every `log_interval` steps.\n",
    "log_interval = 10\n",
    "\n",
    "# In this notebook, only used to build the checkpoint directory name.\n",
    "total_steps = 1000\n",
    "# Number of steps of linear learning-rate warmup in the schedule defined below.\n",
    "warmup_m_steps = 20\n",
    "# Number of optimizer steps in the training loop; the schedule reaches zero at this step.\n",
    "total_m_steps = 100\n",
    "\n",
    "# Base learning rate handed to AdamW.\n",
    "lr = 0.0001\n",
    "\n",
    "# In this notebook, these two are only used to build the checkpoint directory name.\n",
    "max_len = 5\n",
    "reward_temp = 1\n",
    "\n",
    "# Number of training examples kept after the length filter and shuffle.\n",
    "train_samples = 20\n",
    "# In this notebook, only used to build the checkpoint directory name.\n",
    "preseed_buffer = True\n",
    "\n",
    "# NOTE(review): only used to build the checkpoint directory name; the data cell seeds\n",
    "# `random` and `np.random` with a hard-coded 0 rather than with this value.\n",
    "rngseed = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which pretrained checkpoint to fine-tune; 'gpt2' is the other supported value.\n",
    "model_to_use = 'instruct-gpt-j-fp16' # 'gpt2'\n",
    "\n",
    "if model_to_use == 'instruct-gpt-j-fp16':\n",
    "    tokenizer = AutoTokenizer.from_pretrained('nlpcloud/instruct-gpt-j-fp16')\n",
    "    model = AutoModelForCausalLM.from_pretrained('nlpcloud/instruct-gpt-j-fp16',\n",
    "                                                torch_dtype=torch.bfloat16)\n",
    "elif model_to_use == 'gpt2':\n",
    "    tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
    "    model = AutoModelForCausalLM.from_pretrained('gpt2')\n",
    "else:\n",
    "    # Fail here with a clear message rather than with a NameError on `model` below.\n",
    "    raise ValueError(f'Unsupported model_to_use: {model_to_use!r}; '\n",
    "                     f'expected \\'instruct-gpt-j-fp16\\' or \\'gpt2\\'')\n",
    "\n",
    "# The trailing semicolon keeps the (very long) module repr out of the cell output.\n",
    "model.to('cuda');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seeds so the shuffled training subset is the same on every run.\n",
    "np.random.seed(0)\n",
    "random.seed(0)\n",
    "\n",
    "answers = [ 'objective', 'subjective' ]\n",
    "\n",
    "# Vocabulary ids of the two label tokens (the vocab entries carry a leading 'Ġ').\n",
    "obj_id = tokenizer.vocab['Ġobjective']\n",
    "subj_id = tokenizer.vocab['Ġsubjective']\n",
    "\n",
    "\n",
    "def _read_jsonl(path):\n",
    "    \"\"\"Parse a JSON-lines file into a list of records, closing the file afterwards.\n",
    "\n",
    "    Blank lines are skipped.\n",
    "    \"\"\"\n",
    "    with open(path, 'r') as f:\n",
    "        return [json.loads(line) for line in f if line.strip()]\n",
    "\n",
    "\n",
    "data_train = _read_jsonl('data/subj/train.jsonl')\n",
    "data_test = _read_jsonl('data/subj/test.jsonl')\n",
    "\n",
    "# Keep only training reviews with fewer than 25 whitespace-separated words.\n",
    "data_train = [sample for sample in data_train if len(sample['text'].split()) < 25]\n",
    "\n",
    "# Shuffle (seeded above) and keep the first `train_samples` examples.\n",
    "random.shuffle(data_train)\n",
    "data_train = data_train[:train_samples]\n",
    "\n",
    "intro_prompt = 'Classify this movie review as objective or subjective: \"'\n",
    "cot_prompt = '\" This review is'\n",
    "sol_prompt = ', so it is'\n",
    "\n",
    "\n",
    "def _build_prompts(samples):\n",
    "    \"\"\"Return (queries, sols) prompt strings for a list of dataset records.\n",
    "\n",
    "    Each query wraps the review text in the intro and rationale prompts; each\n",
    "    sol is the solution prompt followed by the record's label text and a period.\n",
    "    \"\"\"\n",
    "    queries = [intro_prompt + sample['text'] + cot_prompt for sample in samples]\n",
    "    sols = [sol_prompt + ' ' + sample['label_text'] + '.' for sample in samples]\n",
    "    return queries, sols\n",
    "\n",
    "\n",
    "train_queries, train_sols = _build_prompts(data_train)\n",
    "test_queries, test_sols = _build_prompts(data_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write at most the first `n` selected training examples back out as JSON lines.\n",
    "# NOTE(review): `data_train` was already truncated to `train_samples` entries in the\n",
    "# previous cell, so the file can hold fewer than `n` rows despite its name -- confirm\n",
    "# this is intended.\n",
    "n = 200\n",
    "train_jsons = list(map(json.dumps, data_train[:n]))\n",
    "with open(f'data/subj/train.{n}.jsonl', 'w') as f:\n",
    "    f.write('\\n'.join(train_jsons))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _encode(text):\n",
    "    \"\"\"Tokenize `text` and return its `input_ids` tensor moved to the GPU.\"\"\"\n",
    "    return tokenizer(text, return_tensors='pt')['input_ids'].cuda()\n",
    "\n",
    "\n",
    "encoded_train_queries = [_encode(query) for query in train_queries]\n",
    "encoded_train_sols = [_encode(answer) for answer in train_sols]\n",
    "# Both candidate solution strings, in the order objective, subjective.\n",
    "encoded_train_all_sols = [_encode(sol_prompt + ending)\n",
    "                          for ending in (' objective.', ' subjective.')]\n",
    "encoded_test_queries = [_encode(query) for query in test_queries]\n",
    "encoded_sol_prompt = _encode(sol_prompt)\n",
    "\n",
    "# The EOS token id is also used as the padding id throughout the notebook.\n",
    "eos_token_id = pad_token_id = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: display the first ten solution strings built for the training set.\n",
    "train_sols[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Modules that receive LoRA adapters; the names depend on which base model was loaded.\n",
    "if model_to_use == 'instruct-gpt-j-fp16':\n",
    "    lora_target_modules = [\"k_proj\", \"v_proj\"]\n",
    "else:\n",
    "    lora_target_modules = [\"c_attn\"]\n",
    "\n",
    "# NOTE(review): neither model loaded above is created with an obvious 'classifier'\n",
    "# module -- confirm that the `modules_to_save` entry is actually needed.\n",
    "lora_config = LoraConfig(\n",
    "    r=256,\n",
    "    lora_alpha=16,\n",
    "    target_modules=lora_target_modules,\n",
    "    lora_dropout=0.,\n",
    "    bias=\"none\",\n",
    "    modules_to_save=[\"classifier\"],\n",
    ")\n",
    "# Wrap the base model with the adapters described by `lora_config`.\n",
    "knowledge_model = get_peft_model(model, lora_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_preds(model, encoded_queries, top_n = 999999, bsz = 1):\n",
    "    \"\"\"Predict objective-vs-subjective for each encoded query.\n",
    "\n",
    "    Every query is scored twice with `score_fast`: once followed by\n",
    "    ', so it is objective' and once followed by ', so it is subjective'.\n",
    "\n",
    "    Args:\n",
    "        model: model handed to `score_fast`.\n",
    "        encoded_queries: sequence of token-id tensors; only row 0 of each is used\n",
    "            (presumably each has shape (1, seq_len) -- confirm against callers).\n",
    "        top_n: only the first `top_n` queries are scored.\n",
    "        bsz: maximum number of queries scored per batch. The final batch may be\n",
    "            smaller, so no query is dropped when the count is not a multiple of `bsz`.\n",
    "\n",
    "    Returns:\n",
    "        A list with one bool per scored query: True when the 'objective'\n",
    "        completion received the higher score.\n",
    "    \"\"\"\n",
    "    preds = []\n",
    "    encoded_obj = tokenizer(', so it is objective',\n",
    "                                return_tensors='pt').to('cuda')['input_ids'][0]\n",
    "    encoded_sub = tokenizer(', so it is subjective',\n",
    "                                return_tensors='pt').to('cuda')['input_ids'][0]\n",
    "    # Row 0 is the objective completion, row 1 the subjective one.\n",
    "    encoded_results = torch.nn.utils.rnn.pad_sequence([encoded_obj, encoded_sub], batch_first=True, padding_value=eos_token_id)\n",
    "    encoded_queries_to_use = encoded_queries[:top_n]\n",
    "    for start in range(0, len(encoded_queries_to_use), bsz):\n",
    "        batch_queries = encoded_queries_to_use[start:start + bsz]\n",
    "        # Size of this batch; smaller than `bsz` only for a trailing partial batch.\n",
    "        cur_bsz = len(batch_queries)\n",
    "        batch_input = torch.nn.utils.rnn.pad_sequence([x[0] for x in batch_queries],\n",
    "                                                      batch_first=True,\n",
    "                                                      padding_value=eos_token_id)\n",
    "        with torch.no_grad():\n",
    "            # Each query is duplicated so that consecutive rows pair it with the\n",
    "            # objective and then the subjective completion.\n",
    "            mean_reward = score_fast(model,\n",
    "                            append_sol_and_remove_eos(batch_input.repeat_interleave(2, dim=0),\n",
    "                                                      encoded_results.repeat(cur_bsz, 1), eos_token_id, pad_token_id),\n",
    "                            eos_token_id=eos_token_id)\n",
    "        pred = mean_reward.reshape(cur_bsz, 2)\n",
    "        preds += (pred[:, 0] > pred[:, 1]).tolist()\n",
    "    return preds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## M step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder: replace with the directory that holds the saved rationale tensors\n",
    "# (presumably written by a separate E-step run -- confirm).\n",
    "save_dir = '<YOUR SAVE DIR>'\n",
    "# Sub-directory name that encodes the model and run settings.\n",
    "ckpt_name = f'subj_obj_{model_to_use}_{train_samples}samples_len{max_len}_{total_steps}steps_rewtemp{reward_temp}_seed_{preseed_buffer}_rngseed_{rngseed}'\n",
    "\n",
    "# NOTE: torch.load relies on pickle, so only load files from a trusted source.\n",
    "# Queries with sampled rationales for training, and with greedy / sampled rationales for testing.\n",
    "encoded_train_queries_w_cot_sample = torch.load(f'{save_dir}/{ckpt_name}/encoded_train_queries_w_cot_sample.pt')\n",
    "encoded_test_queries_w_cot_greedy = torch.load(f'{save_dir}/{ckpt_name}/encoded_test_queries_w_cot_greedy.pt')\n",
    "encoded_test_queries_w_cot_sample = torch.load(f'{save_dir}/{ckpt_name}/encoded_test_queries_w_cot_sample.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth labels: True where the gold solution string contains 'objective'.\n",
    "true_preds = torch.tensor(['objective' in sol for sol in test_sols])\n",
    "\n",
    "knowledge_model.eval()\n",
    "\n",
    "# E-step-only greedy: one rationale per test query.\n",
    "test_preds_greedy = get_preds(knowledge_model, encoded_test_queries_w_cot_greedy, bsz = 100)\n",
    "greedy_preds = torch.tensor(test_preds_greedy)\n",
    "print(f'Test Acc (greedy) : {(greedy_preds == true_preds).sum() / len(true_preds)}')\n",
    "\n",
    "# E-step-only sample: flatten every sampled rationale into its own single-row query.\n",
    "flat_sampled_queries = [rationale.unsqueeze(0)\n",
    "                        for rationales in encoded_test_queries_w_cot_sample\n",
    "                        for rationale in rationales]\n",
    "test_preds_sample = get_preds(knowledge_model, flat_sampled_queries, bsz = 100)\n",
    "# Regroup the flat predictions into one row per test query.\n",
    "num_test_queries = len(encoded_test_queries_w_cot_sample)\n",
    "shaped_preds = torch.tensor(test_preds_sample).reshape(num_test_queries, -1)\n",
    "# Majority vote: 'objective' wins only with strictly more than half of the samples.\n",
    "objective_fraction = shaped_preds.sum(-1) / shaped_preds.size(1)\n",
    "agg_preds = objective_fraction > 0.5\n",
    "print(f'Test Acc (sample) : {(agg_preds == true_preds).sum() / len(true_preds)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_knowledge = torch.optim.AdamW([{'params': knowledge_model.parameters(), 'lr': lr}], betas=(0.9, 0.99))\n",
    "\n",
    "# learning rate schedule\n",
    "def get_lr_mult_at_step(step):\n",
    "    \"\"\"Return the learning-rate multiplier for optimizer step `step`.\n",
    "\n",
    "    The multiplier rises linearly from 0 to 1 over the first `warmup_m_steps`\n",
    "    steps, then falls linearly to 0 at `total_m_steps` and stays there.\n",
    "    Degenerate settings are handled without dividing by zero: with\n",
    "    `warmup_m_steps == 0` the schedule starts at the full rate, and with no\n",
    "    steps left for the decay phase the multiplier is 0 after warmup.\n",
    "    \"\"\"\n",
    "    if step <= warmup_m_steps:\n",
    "        if warmup_m_steps == 0:\n",
    "            # No warmup requested: use the full learning rate immediately.\n",
    "            return 1.\n",
    "        return min(step/warmup_m_steps, 1.)\n",
    "    decay_steps = total_m_steps - warmup_m_steps\n",
    "    if decay_steps <= 0:\n",
    "        # Already at or past the end of training; nothing to decay over.\n",
    "        return 0\n",
    "    return max((total_m_steps - step) / decay_steps, 0)\n",
    "sched = torch.optim.lr_scheduler.LambdaLR(opt_knowledge, get_lr_mult_at_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tune the model to predict the label token from (query + sampled rationale + solution prompt).\n",
    "knowledge_model.train()\n",
    "for step in range(total_m_steps):\n",
    "    opt_knowledge.zero_grad()\n",
    "    # Sum of the micro-batch mean losses of this step; used for logging only.\n",
    "    step_loss = 0.\n",
    "    for _ in range(grad_acc):\n",
    "        # build a batch\n",
    "        batch_input = []\n",
    "        batch_labels = []\n",
    "        for _ in range(bsz):\n",
    "            # Pick a random training query, then one of its sampled rationales.\n",
    "            query_ind = np.random.choice(np.arange(len(encoded_train_queries_w_cot_sample)))\n",
    "            rationale_ind = np.random.choice(np.arange(encoded_train_queries_w_cot_sample[query_ind].size(0)))\n",
    "            encoded_input = encoded_train_queries_w_cot_sample[query_ind][rationale_ind]\n",
    "            # Drop EOS tokens (also used as padding) before appending the solution prompt.\n",
    "            encoded_input = encoded_input[encoded_input != eos_token_id]\n",
    "            encoded_input = append_sol_and_remove_eos(encoded_input.unsqueeze(0),\n",
    "                                                      encoded_sol_prompt,\n",
    "                                                      eos_token_id=eos_token_id,\n",
    "                                                      pad_token_id=eos_token_id)[0]\n",
    "            sol = train_sols[query_ind]\n",
    "            if 'objective' in sol:\n",
    "                label = True\n",
    "            elif 'subjective' in sol:\n",
    "                label = False\n",
    "            else:\n",
    "                # Without this, inputs and labels would silently fall out of alignment.\n",
    "                raise ValueError(f'Unrecognized label in solution {sol!r} for training query {query_ind}')\n",
    "            batch_input.append(encoded_input)\n",
    "            batch_labels.append(label)\n",
    "        # Presumably left-pads the batch so index -1 is each row's last real token -- see utils.\n",
    "        batch_input, position_ids, _ = \\\n",
    "            remove_eos_and_pad_left(batch_input, eos_token_id=eos_token_id, pad_token_id=eos_token_id)\n",
    "        position_ids = position_ids.cuda()\n",
    "        batch_labels = torch.tensor(batch_labels, device='cuda', dtype=torch.bool)\n",
    "\n",
    "        # Log-probabilities over the vocabulary for the token following the solution prompt.\n",
    "        last_logprob = knowledge_model(batch_input,\n",
    "                                       attention_mask=batch_input!=eos_token_id,\n",
    "                                       position_ids=position_ids)['logits'][:, -1].log_softmax(dim=-1)\n",
    "        obj_logprob = last_logprob[:, obj_id]\n",
    "        subj_logprob = last_logprob[:, subj_id]\n",
    "        # Renormalize over the two label tokens so the loss is a two-way cross-entropy.\n",
    "        partition_fn = torch.logsumexp(torch.stack([obj_logprob, subj_logprob], dim=-1), dim=-1)\n",
    "        loss = torch.where(batch_labels, -(obj_logprob - partition_fn), -(subj_logprob - partition_fn))\n",
    "        micro_loss = loss.mean()\n",
    "        # Gradients from all micro-batches add up until the optimizer step below.\n",
    "        micro_loss.backward()\n",
    "        step_loss += micro_loss.detach().float()\n",
    "\n",
    "    opt_knowledge.step()\n",
    "    sched.step()\n",
    "    if step % log_interval == 0:\n",
    "        # Report the loss averaged over every micro-batch of this step.\n",
    "        print(f'loss: {(step_loss / grad_acc).item()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-run the same evaluation now that the model has been fine-tuned.\n",
    "knowledge_model.eval()\n",
    "\n",
    "# one-step EM greedy: one rationale per test query.\n",
    "test_preds_greedy = get_preds(knowledge_model, encoded_test_queries_w_cot_greedy, bsz = 100)\n",
    "greedy_preds = torch.tensor(test_preds_greedy)\n",
    "print(f'Test Acc (greedy) : {(greedy_preds == true_preds).sum() / len(true_preds)}')\n",
    "\n",
    "# one-step EM sample: flatten every sampled rationale into its own single-row query.\n",
    "flat_sampled_queries = [rationale.unsqueeze(0)\n",
    "                        for rationales in encoded_test_queries_w_cot_sample\n",
    "                        for rationale in rationales]\n",
    "test_preds_sample = get_preds(knowledge_model, flat_sampled_queries, bsz = 100)\n",
    "# Regroup the flat predictions into one row per test query.\n",
    "num_test_queries = len(encoded_test_queries_w_cot_sample)\n",
    "shaped_preds = torch.tensor(test_preds_sample).reshape(num_test_queries, -1)\n",
    "# Majority vote: 'objective' wins only with strictly more than half of the samples.\n",
    "objective_fraction = shaped_preds.sum(-1) / shaped_preds.size(1)\n",
    "agg_preds = objective_fraction > 0.5\n",
    "print(f'Test Acc (sample) : {(agg_preds == true_preds).sum() / len(true_preds)}')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
