{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xq77AgKWJM-N"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import json\n",
    "import numpy as np\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "import replay_buffer\n",
    "import utils\n",
    "from utils import generate, append_sol_and_remove_eos\n",
    "from rewards import get_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## HPs\n",
    "# `bsz` sequences are generated (or replayed) per query; each optimizer step\n",
    "# accumulates gradients over `grad_acc` such queries.\n",
    "bsz = 64\n",
    "grad_acc = 8\n",
    "\n",
    "# Peak learning rate; the LR multiplier ramps up over `warmup_steps` and then\n",
    "# decays linearly, reaching zero at `total_steps` (also the number of training steps).\n",
    "lr = 0.0005\n",
    "warmup_steps = 100\n",
    "total_steps = 1000\n",
    "# Bounds of the uniform range the sampling temperature is drawn from when the\n",
    "# tempered behavior policy is selected.\n",
    "pf_temp_high = 2\n",
    "pf_temp_low = 0.5\n",
    "\n",
    "\n",
    "# Base of the per-length weight applied to the sub-trajectory loss terms.\n",
    "subtb_lambda = 1.\n",
    "# Temperature handed to the reward at construction; during training it is\n",
    "# overwritten every step by a linear schedule from `reward_sched_start` to\n",
    "# `reward_sched_end` over `reward_sched_horizon` steps.\n",
    "reward_temp = 1.\n",
    "reward_sched_start = 1.2\n",
    "reward_sched_end = 1.0\n",
    "reward_sched_horizon = 150\n",
    "\n",
    "# Length limits forwarded to the generation helpers (presumably in tokens -- confirm in utils).\n",
    "max_len = 5\n",
    "min_len = 1\n",
    "\n",
    "# Every `eval_interval` steps a test example is printed; every `log_interval` steps the loss is.\n",
    "eval_interval = 100\n",
    "log_interval = 10\n",
    "\n",
    "# Rationales generated per training query when pre-seeding the replay buffer.\n",
    "n_rationales = 20\n",
    "# Size of the training subset; used in the checkpoint name.\n",
    "train_samples = 50\n",
    "# Whether to fill the replay buffer with generated rationales before training.\n",
    "preseed_buffer = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base language model to fine-tune: 'gpt-j' or 'gpt2'.\n",
    "model_to_use = 'gpt-j' # 'gpt2'\n",
    "\n",
    "if model_to_use == 'gpt-j':\n",
    "    # GPT-J checkpoint, loaded in bfloat16\n",
    "    tokenizer = AutoTokenizer.from_pretrained('nlpcloud/instruct-gpt-j-fp16')\n",
    "    model = AutoModelForCausalLM.from_pretrained('nlpcloud/instruct-gpt-j-fp16',\n",
    "                                                torch_dtype=torch.bfloat16)\n",
    "elif model_to_use == 'gpt2':\n",
    "    tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
    "    model = AutoModelForCausalLM.from_pretrained('gpt2')\n",
    "else:\n",
    "    # fail here with a clear message instead of a NameError on `model` below\n",
    "    raise ValueError(f\"unknown model_to_use {model_to_use!r}; expected 'gpt-j' or 'gpt2'\")\n",
    "\n",
    "model.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(0)\n",
    "random.seed(0)\n",
    "\n",
    "answers = [ 'objective', 'subjective' ]\n",
    "\n",
    "# vocabulary ids of the two label tokens\n",
    "obj_id = tokenizer.vocab['Ġobjective']\n",
    "subj_id = tokenizer.vocab['Ġsubjective']\n",
    "\n",
    "def _read_jsonl(path):\n",
    "    \"\"\"Return the JSON object parsed from each line of `path`; the file is closed on exit.\"\"\"\n",
    "    with open(path, 'r') as f:\n",
    "        return [json.loads(line) for line in f]\n",
    "\n",
    "# f-string: the training file is selected by the `train_samples` subset size\n",
    "data_train = _read_jsonl(f'data/subj/train.{train_samples}.jsonl')\n",
    "data_test = _read_jsonl('data/subj/test.jsonl')\n",
    "\n",
    "# prompt pieces: query = intro + review text + cot prompt; solution = sol prompt + label\n",
    "intro_prompt = 'Classify this movie review as objective or subjective: \"'\n",
    "cot_prompt = '\" This review is'\n",
    "sol_prompt = ', so it is'\n",
    "\n",
    "train_queries = [intro_prompt + sample['text'] + cot_prompt for sample in data_train]\n",
    "train_sols = [sol_prompt + ' ' + sample['label_text'] + '.' for sample in data_train]\n",
    "\n",
    "test_queries = [intro_prompt + sample['text'] + cot_prompt for sample in data_test]\n",
    "test_sols = [sol_prompt + ' ' + sample['label_text'] + '.' for sample in data_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _encode(text):\n",
    "    \"\"\"Tokenize `text` and return its input ids as a tensor on the GPU.\"\"\"\n",
    "    return tokenizer(text, return_tensors='pt')['input_ids'].cuda()\n",
    "\n",
    "# per-example queries and gold solutions\n",
    "encoded_train_queries = [_encode(query) for query in train_queries]\n",
    "encoded_train_sols = [_encode(answer) for answer in train_sols]\n",
    "# both candidate solutions, in the order [objective, subjective]\n",
    "encoded_train_all_sols = [_encode(sol_prompt+' objective.'),\n",
    "                          _encode(sol_prompt+' subjective.')]\n",
    "encoded_test_queries = [_encode(query) for query in test_queries]\n",
    "encoded_sol_prompt = _encode(sol_prompt)\n",
    "\n",
    "# the EOS token doubles as the padding token\n",
    "eos_token_id = pad_token_id = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check: display the first ten solution strings built from the training labels\n",
    "train_sols[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the modules LoRA is attached to depend on the base architecture\n",
    "if model_to_use == 'gpt-j':\n",
    "    lora_targets = [\"k_proj\", \"v_proj\"]\n",
    "else:\n",
    "    lora_targets = [\"c_attn\"]\n",
    "\n",
    "lora_config = LoraConfig(\n",
    "    r=256,\n",
    "    lora_alpha=16,\n",
    "    target_modules=lora_targets,\n",
    "    lora_dropout=0.,\n",
    "    bias=\"none\",\n",
    "    modules_to_save=[\"classifier\"],\n",
    ")\n",
    "inference_model = get_peft_model(model, lora_config)\n",
    "\n",
    "loss_type = 'modified_subtb' # 'tb' 'tb_no_z' 'hvi' 'hvi_bl' 'pg'\n",
    "\n",
    "opt = torch.optim.AdamW(\n",
    "    [{'params': inference_model.parameters(), 'lr': lr}],\n",
    "    betas=(0.9, 0.99),\n",
    ")\n",
    "\n",
    "def get_lr_mult_at_step(step):\n",
    "    \"\"\"LR multiplier: linear warm-up to 1 over `warmup_steps`, then linear decay to 0 at `total_steps`.\"\"\"\n",
    "    if step > warmup_steps:\n",
    "        decay = (total_steps - step) / (total_steps - warmup_steps)\n",
    "        return max(decay, 0)\n",
    "    return min(step/warmup_steps, 1.)\n",
    "\n",
    "sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=get_lr_mult_at_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# settings of the single \"FrozenModel\" reward component\n",
    "frozen_model_kwargs = {\n",
    "    \"model\": inference_model,\n",
    "    \"eos_token_id\": eos_token_id,\n",
    "    \"temperature\": reward_temp,\n",
    "    \"solution_beta\": 1.,\n",
    "    \"cot_beta\": 1.0,\n",
    "    \"len_beta\": 0,\n",
    "    \"min_len\": 0,\n",
    "}\n",
    "rew = get_reward([\"FrozenModel\"], [frozen_model_kwargs])\n",
    "\n",
    "def get_reward_temp(x):\n",
    "    \"\"\"Reward temperature at step `x`: linear from `reward_sched_start` to `reward_sched_end`, constant after `reward_sched_horizon`.\"\"\"\n",
    "    progress = min(1, x / reward_sched_horizon)\n",
    "    return reward_sched_start + (reward_sched_end - reward_sched_start) * progress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "rbuffer = replay_buffer.ReplayBuffer(100, eos_token_id=eos_token_id, sim_tolerance=0.1)\n",
    "\n",
    "# ids of every vocabulary entry that is not purely alphanumeric once 'Ġ' and whitespace are stripped\n",
    "list_of_symbols = [v for k, v in tokenizer.vocab.items() if not k.strip('Ġ').strip().isalnum()]\n",
    "# set of the same ids: membership is tested on plain ints in O(1), whereas\n",
    "# `tensor in list` compares the tensor against every list entry\n",
    "symbol_ids = set(list_of_symbols)\n",
    "\n",
    "# add plausible rationales to the replay buffer\n",
    "if preseed_buffer:\n",
    "    for query_ind in range(len(train_sols)):\n",
    "        encoded_input = encoded_train_queries[query_ind]\n",
    "        encoded_result = encoded_train_sols[query_ind]\n",
    "        # condition generation on the gold label so the sampled rationale is likely to support it\n",
    "        if 'objective' in train_sols[query_ind]:\n",
    "            wishful_prompt = train_queries[query_ind] + ' objective because it is'\n",
    "        else:\n",
    "            wishful_prompt = train_queries[query_ind] + ' subjective because it is'\n",
    "        encoded_wishful_input = tokenizer(wishful_prompt, return_tensors='pt')['input_ids'].cuda()\n",
    "        # keep only the newly generated tokens (everything after the wishful prompt)\n",
    "        encoded_rationale = generate(inference_model,\n",
    "                                    encoded_wishful_input.repeat(n_rationales, 1),\n",
    "                                    eos_token_id=eos_token_id,\n",
    "                                    max_len=max_len,\n",
    "                                    temperature=1)[0][:, encoded_wishful_input.size(-1):]\n",
    "        # find the first non-letter symbol and replace it (and everything after it) with EOS;\n",
    "        # the ids are copied to the host once, which is safe because each row is\n",
    "        # modified at most once before moving on\n",
    "        for j, row in enumerate(encoded_rationale.tolist()):\n",
    "            for k, token_id in enumerate(row):\n",
    "                if token_id in symbol_ids:\n",
    "                    encoded_rationale[j, k:] = eos_token_id\n",
    "                    break\n",
    "        print(tokenizer.batch_decode(encoded_rationale)[0])\n",
    "\n",
    "        def reward_fn(x):\n",
    "            \"\"\"Score `x` with the gold solution and subtract 50 where the better-scoring candidate label is not the gold one.\"\"\"\n",
    "            # one scoring call over three stacked copies of x: gold solution, ' objective.', ' subjective.'\n",
    "            results = rew.score(append_sol_and_remove_eos(x.repeat(3, 1),\n",
    "                                                          torch.cat([encoded_result.repeat(x.size(0), 1),\n",
    "                                                                     encoded_train_all_sols[0].repeat(x.size(0), 1),\n",
    "                                                                     encoded_train_all_sols[1].repeat(x.size(0), 1)]),\n",
    "                                                          eos_token_id,\n",
    "                                                          pad_token_id),\n",
    "                                skip_first=encoded_input.size(-1),\n",
    "                                solution_len=0)\n",
    "            base_reward = results[:x.size(0)]\n",
    "            obj_score = results[x.size(0):2*x.size(0)]\n",
    "            sub_score = results[2*x.size(0):]\n",
    "            pred_obj = obj_score > sub_score\n",
    "            if 'obj' in train_sols[query_ind]:\n",
    "                return torch.where(pred_obj, base_reward, base_reward - 50)\n",
    "            if 'sub' in train_sols[query_ind]:\n",
    "                return torch.where(pred_obj, base_reward - 50, base_reward)\n",
    "            raise NotImplementedError\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # replay the fixed rationales through the model; element [3] of the result is the log-rewards\n",
    "            logrewards = utils.generate_and_return_eos_logprob(inference_model,\n",
    "                                                            encoded_input.repeat(n_rationales, 1),\n",
    "                                                            eos_token_id=eos_token_id,\n",
    "                                                            reward_fn=reward_fn,\n",
    "                                                            max_len=max_len,\n",
    "                                                            min_len=min_len,\n",
    "                                                            temperature=1,\n",
    "                                                            action_seq=encoded_rationale)[3]\n",
    "        rbuffer.add_batch(query=encoded_input,\n",
    "                          answer=encoded_result,\n",
    "                          rationales=encoded_rationale,\n",
    "                          logrewards=logrewards,\n",
    "                          tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every buffer entry, average the first field of its stored rationales\n",
    "# (presumably the log-reward -- confirm in replay_buffer), then print the mean over entries.\n",
    "rewards_to_date = [\n",
    "    np.mean([stored[0] for stored in entry['rationales']])\n",
    "    for entry in rbuffer._buffer.values()\n",
    "]\n",
    "print(np.mean(rewards_to_date))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "dph6LMlH0ooD",
    "outputId": "cb17a778-4221-421b-f6b2-f163cad72797",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for step in range(1, total_steps+1):\n",
    "    opt.zero_grad()\n",
    "    # running sum of the micro-batch losses for logging; kept as a plain float so\n",
    "    # no autograd graph is held across accumulation steps and logging still works\n",
    "    # when every micro-batch of a step was skipped\n",
    "    loss = 0.\n",
    "    # change reward temperature\n",
    "    rew.temperature = get_reward_temp(step)\n",
    "    for _ in range(grad_acc):\n",
    "        # select an example\n",
    "        query_ind = np.random.choice(np.arange(len(encoded_train_queries)))\n",
    "        encoded_input = encoded_train_queries[query_ind]\n",
    "        encoded_result = encoded_train_sols[query_ind]\n",
    "        if loss_type.startswith('modified'):\n",
    "            def reward_fn(x):\n",
    "                \"\"\"Score `x` with the gold solution and subtract 50 where the better-scoring candidate label is not the gold one.\"\"\"\n",
    "                # one scoring call over three stacked copies of x: gold solution, ' objective.', ' subjective.'\n",
    "                results = rew.score(append_sol_and_remove_eos(x.repeat(3, 1),\n",
    "                                                              torch.cat([encoded_result.repeat(x.size(0), 1),\n",
    "                                                                         encoded_train_all_sols[0].repeat(x.size(0), 1),\n",
    "                                                                         encoded_train_all_sols[1].repeat(x.size(0), 1)]),\n",
    "                                                              eos_token_id,\n",
    "                                                              pad_token_id),\n",
    "                                    skip_first=encoded_input.size(-1),\n",
    "                                    solution_len=0)\n",
    "                base_reward = results[:x.size(0)]\n",
    "                obj_score = results[x.size(0):2*x.size(0)]\n",
    "                sub_score = results[2*x.size(0):]\n",
    "                pred_obj = obj_score > sub_score\n",
    "                if 'obj' in train_sols[query_ind]:\n",
    "                    return torch.where(pred_obj, base_reward, base_reward - 50)\n",
    "                if 'sub' in train_sols[query_ind]:\n",
    "                    return torch.where(pred_obj, base_reward - 50, base_reward)\n",
    "                raise NotImplementedError\n",
    "\n",
    "            # choose a behavior policy: 0 = untempered policy, 1 = tempered policy, 2-3 = replay buffer\n",
    "            b_policy_choice = random.randint(0, 3)\n",
    "            if b_policy_choice in [0, 1]:\n",
    "                # sampling from the current policy, at temperature 1 or at a random temperature\n",
    "                generated_text, logPF, eos_logprob, logrewards = \\\n",
    "                    utils.generate_and_return_eos_logprob(inference_model,\n",
    "                                                    encoded_input.repeat(bsz, 1),\n",
    "                                                    eos_token_id=eos_token_id,\n",
    "                                                    reward_fn=reward_fn,\n",
    "                                                    max_len=max_len,\n",
    "                                                    min_len=min_len,\n",
    "                                                    use_tools=False,\n",
    "                                                    temperature=1 if b_policy_choice == 0 else random.random()*(pf_temp_high-pf_temp_low)+pf_temp_low)\n",
    "                rbuffer.add_batch(query=encoded_input,\n",
    "                                answer=encoded_result,\n",
    "                                rationales=generated_text[:, encoded_input.size(-1):],\n",
    "                                logrewards=logrewards * rew.temperature, # undo the effect of reward tempering\n",
    "                                tokenizer=tokenizer)\n",
    "            else:\n",
    "                # using samples from the replay buffer\n",
    "                action_seq, logrewards = rbuffer.sample(bsz, query=encoded_input, answer=encoded_result)\n",
    "                if action_seq is None:\n",
    "                    # nothing stored for this query yet: skip this micro-batch\n",
    "                    continue\n",
    "                logrewards *= (1/rew.temperature) # redo the effect of reward tempering\n",
    "                generated_text, logPF, eos_logprob, logrewards_2 = \\\n",
    "                    utils.generate_and_return_eos_logprob(inference_model,\n",
    "                                                    encoded_input.repeat(action_seq.size(0), 1),\n",
    "                                                    eos_token_id=eos_token_id,\n",
    "                                                    reward_fn=reward_fn,\n",
    "                                                    max_len=max_len,\n",
    "                                                    min_len=min_len,\n",
    "                                                    use_tools=False,\n",
    "                                                    action_seq=action_seq,\n",
    "                                                    skip_rewards=True)\n",
    "            if loss_type == 'modified_db':\n",
    "                # modified db loss with logpb=0\n",
    "                db_loss = (logrewards[:, :-1] + logPF[:, :-1] + eos_logprob[:, 1:] - logrewards[:, 1:] - eos_logprob[:, :-1])**2\n",
    "                # get a mask for newly generated tokens after the first eos in generated_text\n",
    "                mask = (generated_text[:, encoded_input.size(-1):] == eos_token_id).cumsum(dim=-1) >= 1\n",
    "                # if mask is too short, pad it\n",
    "                # NOTE(review): this pads to max_len-1 columns while the subTB branch pads to max_len -- confirm intended\n",
    "                if mask.size(-1) < max_len:\n",
    "                    mask = torch.cat([mask, torch.ones(mask.size(0), max_len-1-mask.size(-1), dtype=torch.bool, device='cuda')], dim=-1)\n",
    "                mask = mask[:, :max_len]\n",
    "                # get trajectory lengths by summing the mask\n",
    "                traj_len = (~mask).sum(dim=-1)\n",
    "                # get rid of the loss for the terminating step\n",
    "                db_loss[mask] = 0\n",
    "                batch_loss = db_loss.sum(-1) / traj_len\n",
    "            elif loss_type == 'modified_subtb':\n",
    "                # modified subTB loss with logpb=0\n",
    "                delta = (logrewards[:, :-1] - eos_logprob[:, :-1] + logPF[:, :-1] - (logrewards[:, 1:] - eos_logprob[:, 1:]))\n",
    "                # prefix sums, so the residual of any sub-trajectory is a difference of two entries\n",
    "                delta_cumsum = torch.cat( [ torch.zeros_like(delta[:, :1]), delta ], 1).cumsum(1)\n",
    "                # get a mask for tokens after the first eos in generated_text\n",
    "                mask = (generated_text == eos_token_id).cumsum(dim=-1) >= 1\n",
    "                mask = mask[:, encoded_input.size(-1):]\n",
    "                mask = mask[:, :max_len]\n",
    "                # if mask is too short, pad it\n",
    "                if mask.size(-1) < max_len:\n",
    "                    mask = torch.cat([mask, torch.ones(mask.size(0), max_len-mask.size(-1), dtype=torch.bool, device='cuda')], dim=-1)\n",
    "                # sum the squared residuals of all unmasked sub-trajectories, weighted by subtb_lambda**(length-1)\n",
    "                batch_loss = 0.\n",
    "                total_lambda = 0.\n",
    "                for subtraj_len in range(1, max_len+1):\n",
    "                    subtb_term = (delta_cumsum[:, subtraj_len:] - delta_cumsum[:, :-subtraj_len])**2\n",
    "                    subtb_term[mask[:, subtraj_len - 1:]] = 0\n",
    "                    batch_loss += subtb_lambda ** (subtraj_len - 1) * subtb_term.sum()\n",
    "                    total_lambda += subtb_lambda ** (subtraj_len - 1) * (~mask[:, subtraj_len - 1:]).sum()\n",
    "                batch_loss /= total_lambda\n",
    "        else:\n",
    "            raise NotImplementedError\n",
    "        # reduce once, backpropagate, and keep only the detached value for logging\n",
    "        micro_loss = batch_loss.mean()\n",
    "        micro_loss.backward()\n",
    "        loss += micro_loss.item()\n",
    "    opt.step()\n",
    "    sched.step()\n",
    "    if step % log_interval == 0:\n",
    "        print(f'loss: {loss}')\n",
    "    if step % eval_interval == 0:\n",
    "        print(f'Step: {step}')\n",
    "        # pick a random example from the test set\n",
    "        query_ind = random.randint(0, len(encoded_test_queries)-1)\n",
    "        encoded_input = encoded_test_queries[query_ind]\n",
    "        generated_text = generate(inference_model,\n",
    "                                 encoded_input.repeat(3, 1),\n",
    "                                 eos_token_id=eos_token_id,\n",
    "                                 max_len=max_len,\n",
    "                                 temperature=1)[0]\n",
    "        print(\"Test example:\")\n",
    "        print('\\n'.join(tokenizer.batch_decode(append_sol_and_remove_eos(generated_text, [None,] * generated_text.size(0), eos_token_id, pad_token_id))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_acc(encoded_test_queries, test_sols, top_n = 2000):\n",
    "    \"\"\"Accuracy of reward-based label prediction on the first `top_n` queries.\n",
    "\n",
    "    For each query one rationale is generated at temperature 0.1, each of the two\n",
    "    candidate conclusions is appended to it, and the candidate with the larger\n",
    "    `rew.score` is the prediction. Both candidates share the same rationale so\n",
    "    that their scores are directly comparable.\n",
    "\n",
    "    Args:\n",
    "        encoded_test_queries: sequence of encoded query tensors.\n",
    "        test_sols: solution strings aligned with the queries; the gold label is\n",
    "            read from whether the string contains 'objective' or 'subjective'.\n",
    "        top_n: maximum number of examples to evaluate.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of evaluated examples predicted correctly, or nan when there\n",
    "        is nothing to evaluate.\n",
    "    \"\"\"\n",
    "    correct, total = 0, 0\n",
    "    encoded_obj = tokenizer(', so it is objective.',\n",
    "                                return_tensors='pt').to('cuda')['input_ids']\n",
    "    encoded_sub = tokenizer(', so it is subjective.',\n",
    "                                return_tensors='pt').to('cuda')['input_ids']\n",
    "    # evaluation only: no gradients are needed\n",
    "    with torch.no_grad():\n",
    "        for encoded_input, sol in zip(encoded_test_queries[:top_n], test_sols[:top_n]):\n",
    "            generated_text = generate(inference_model,\n",
    "                                    encoded_input,\n",
    "                                    eos_token_id=eos_token_id,\n",
    "                                    max_len=max_len,\n",
    "                                    temperature=.1)[0]\n",
    "            lls = []\n",
    "            for encoded_result in [encoded_obj, encoded_sub]:\n",
    "                # clone in case the helper edits its input in place\n",
    "                mean_reward = rew.score(\n",
    "                                append_sol_and_remove_eos(generated_text.clone(), encoded_result, eos_token_id, pad_token_id),\n",
    "                                skip_first=encoded_input.size(-1),\n",
    "                                solution_len=0).mean().item()\n",
    "                lls.append(mean_reward)\n",
    "            pred_obj = lls[0] > lls[1]\n",
    "            if (pred_obj and 'objective' in sol) or (not pred_obj and 'subjective' in sol):\n",
    "                correct += 1\n",
    "            total += 1\n",
    "    return correct/total if total else float('nan')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# accuracy on the full training set, then on the first 100 test examples\n",
    "train_acc = eval_acc(encoded_train_queries, train_sols)\n",
    "print('Train Acc:', train_acc)\n",
    "test_acc_at_100 = eval_acc(encoded_test_queries, test_sols, top_n=100)\n",
    "print('Test Acc @ 100:', test_acc_at_100)\n",
    "# larger evaluation, left disabled:\n",
    "#print('Test Acc @ 2000:', eval_acc(encoded_test_queries, test_sols, 2000))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the checkpoint name records the run configuration\n",
    "ckpt_name = f'subj_obj_{model_to_use}_{train_samples}samples_len{max_len}_{total_steps}steps_rewtemp{reward_temp}_seed_{preseed_buffer}'\n",
    "ckpt_dir = f'ckpts/{ckpt_name}'\n",
    "inference_model.save_pretrained(ckpt_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample CoTs for the entire training set: `n_samples` generations per query at temperature 1\n",
    "n_samples = 100\n",
    "encoded_train_queries_w_cot_sample = [\n",
    "    generate(inference_model,\n",
    "             query_ids.repeat(n_samples, 1),\n",
    "             eos_token_id=eos_token_id,\n",
    "             max_len=max_len,\n",
    "             temperature=1)[0]\n",
    "    for query_ids in encoded_train_queries\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# greedily generate CoTs for the entire test set: one generation per query at temperature 0.01\n",
    "encoded_test_queries_w_cot_greedy = [\n",
    "    generate(inference_model,\n",
    "             query_ids,\n",
    "             eos_token_id=eos_token_id,\n",
    "             max_len=max_len,\n",
    "             temperature=.01)[0]\n",
    "    for query_ids in encoded_test_queries\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample CoTs for the entire test set: `n_samples` generations per query at temperature 1\n",
    "n_samples = 10\n",
    "encoded_test_queries_w_cot_sample = [\n",
    "    generate(inference_model,\n",
    "             query_ids.repeat(n_samples, 1),\n",
    "             eos_token_id=eos_token_id,\n",
    "             max_len=max_len,\n",
    "             temperature=1)[0]\n",
    "    for query_ids in encoded_test_queries\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write each generated collection next to the checkpoint, one .pt file per collection\n",
    "cot_outputs = {\n",
    "    'encoded_train_queries_w_cot_sample': encoded_train_queries_w_cot_sample,\n",
    "    'encoded_test_queries_w_cot_greedy': encoded_test_queries_w_cot_greedy,\n",
    "    'encoded_test_queries_w_cot_sample': encoded_test_queries_w_cot_sample,\n",
    "}\n",
    "for file_stem, encoded_outputs in cot_outputs.items():\n",
    "    torch.save(encoded_outputs, f'ckpts/{ckpt_name}/{file_stem}.pt')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
