{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We host our model checkpoints on the Hugging Face Hub and load them with the `transformers` library. The code below loads the 11-billion-parameter model, but you can also try the 3-billion-parameter model by changing \"seonghyeonye/flipped_11B\" to \"seonghyeonye/flipped_3B\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
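    "import torch\n",
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "# 11B checkpoint; change to \"seonghyeonye/flipped_3B\" to try the 3B model instead.\n",
    "# Note: the 11B weights alone need more than 40 GB of memory in float32.\n",
    "model_name = \"seonghyeonye/flipped_11B\"\n",
    "model = T5ForConditionalGeneration.from_pretrained(model_name)\n",
    "tokenizer = T5Tokenizer.from_pretrained(model_name)"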
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build the input and the instruction, we start from the whole prompt string that a conventional encoder-decoder model (such as T0) would receive as input. We then identify the chunks of that string that describe the task, i.e., what the string is asking the model to do. In the input, we replace these chunks with the sentinel tokens \"<extra_id_0>\", \"<extra_id_1>\", ... in left-to-right order, and the instruction (`instruct`) lists each masked chunk right after its sentinel. Please refer to the image below.\n",
    "![title](instruction_input.PNG)"
   ]
  },
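  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The masking step can be sketched as a small helper. `build_flipped_example` below is hypothetical (it is not part of this repository): it takes the prompt as a list of `(text, is_instruction)` segments, replaces each instruction segment in the input with the next sentinel token, and lists the masked segments after their sentinels in the instruction, similar in layout to T5's span-corruption targets. The single space after each sentinel in the instruction is an assumption inferred from the examples in this notebook. The cell rebuilds the second example listed below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_flipped_example(segments):\n",
    "    \"\"\"Hypothetical helper: split a full prompt into (input, instruct).\n",
    "\n",
    "    `segments` is a list of (text, is_instruction) pairs whose texts,\n",
    "    concatenated, give the original prompt.\n",
    "    \"\"\"\n",
    "    input_parts, instruct_parts = [], []\n",
    "    index = 0\n",
    "    for text, is_instruction in segments:\n",
    "        if is_instruction:\n",
    "            sentinel = f\"<extra_id_{index}>\"\n",
    "            # The instruction chunk is masked out of the input...\n",
    "            input_parts.append(sentinel)\n",
    "            # ...and listed after its sentinel in the instruction (assumed: one space).\n",
    "            instruct_parts.append(f\"{sentinel} {text}\")\n",
    "            index += 1\n",
    "        else:\n",
    "            # Task content stays in the input unchanged.\n",
    "            input_parts.append(text)\n",
    "    return \"\".join(input_parts), \"\".join(instruct_parts)\n",
    "\n",
    "\n",
    "# Rebuilds the second example listed below.\n",
    "example_input, example_instruct = build_flipped_example([\n",
    "    (\"Dana Reeve, the widow of the actor Christopher Reeve, has died of lung cancer at age 44, according to the Christopher Reeve Foundation.\", False),\n",
    "    (\"Using only the above description and what you know about the world, is \\\"\", True),\n",
    "    (\"Christopher Reeve had an accident.\", False),\n",
    "    (\"\\\" definitely correct? Yes or no?\", True),\n",
    "])\n",
    "print(example_input)\n",
    "print(example_instruct)"
   ]
  },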
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are a few more examples you can try at inference time by substituting them for `input`, `instruct`, and `options` in the next cell.\n",
    "```\n",
    "input = \"<extra_id_0>On a shelf, there are five books: a yellow book, a brown book, a gray book, a black book, and a purple book. The yellow book is to the right of the purple book. The black book is to the left of the purple book. The black book is the third from the left. The brown book is the leftmost.<extra_id_1>\"\n",
    "instruct = \"<extra_id_0> The following paragraphs each describe a set of five objects arranged in a fixed order. The statements are logically consistent within each paragraph.\\n\\n <extra_id_1>  \"\n",
    "options = [\"The yellow book is the leftmost.\", \"The brown book is the leftmost.\", \"The gray book is the leftmost.\", \"The black book is the leftmost.\", \"The purple book is the leftmost.\"]\n",
    "```\n",
    "```\n",
    "input = \"Dana Reeve, the widow of the actor Christopher Reeve, has died of lung cancer at age 44, according to the Christopher Reeve Foundation.<extra_id_0>Christopher Reeve had an accident.<extra_id_1>\"\n",
    "instruct = \"<extra_id_0> Using only the above description and what you know about the world, is \\\"<extra_id_1> \\\" definitely correct? Yes or no?\"\n",
    "options = [\"Yes\", \"No\"]\n",
    "```\n",
    "```\n",
    "input = \"<extra_id_0>Almost Sunrise is a 2016 American documentary film directed by Michael Collins. It recounts the story of two Iraq veterans, Tom Voss and Anthony Anderson, who, in an attempt to put their combat experience behind them, embark on a 2,700-mile trek on foot across America. It made its world premiere on the opening night of the Telluride Mountainfilm Festival on 27 May, 2016.<extra_id_1>Tom and Anthony have both killed someone.<extra_id_2>\"\n",
    "instruct = \"<extra_id_0> Suppose <extra_id_1>  Can we infer that \\\"<extra_id_2> \\\"? Yes, no, or maybe?\"\n",
    "options = [\"Yes\", \"No\", \"Maybe\"]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input = \"<extra_id_0> this is the best cast iron skillet you will ever buy<extra_id_1>\"\n",
    "instruct = \"<extra_id_0> Title: <extra_id_1> Review:\"\n",
    "options = [\"Positive\", \"Negative\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For inference, we build a single batch of batch size 1 that holds four tensors. \"source_ids\"/\"source_mask\" are the token IDs/attention mask of the input concatenated with each candidate option, with shape [1, len(options), 512]; \"target_ids\"/\"target_mask\" are the token IDs/attention mask of the instruction, with shape [1, len(options), 64]. The instruction is the same for every option, so its rows are identical.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_ids_batch = []\n",
    "target_ids_batch = []\n",
    "src_mask_batch = []\n",
    "target_mask_batch = []\n",
    "for target in options:\n",
    "    # Source: the input followed by one candidate option. Target: the instruction to score.\n",
    "    input_ = f'input: {input}\\n output: {target}'\n",
    "    target_ = instruct\n",
    "    print(\"input is:\\n\", input_)\n",
    "    print(\"target is:\\n\", target_)\n",
    "    source = tokenizer(input_, max_length=512, padding='max_length', truncation=True, return_tensors=\"pt\")\n",
    "    targets = tokenizer(target_, max_length=64, padding='max_length', truncation=True, return_tensors=\"pt\")\n",
    "    # Drop the leading batch dimension: [1, length] -> [length].\n",
    "    source_ids_batch.append(source[\"input_ids\"].squeeze(0))\n",
    "    target_ids_batch.append(targets[\"input_ids\"].squeeze(0))\n",
    "    src_mask_batch.append(source[\"attention_mask\"].squeeze(0))\n",
    "    target_mask_batch.append(targets[\"attention_mask\"].squeeze(0))\n",
    "\n",
    "# Stack the options and add a batch dimension: [1, len(options), length].\n",
    "batch = {\n",
    "    \"source_ids\": torch.stack(source_ids_batch).unsqueeze(dim=0),\n",
    "    \"source_mask\": torch.stack(src_mask_batch).unsqueeze(dim=0),\n",
    "    \"target_ids\": torch.stack(target_ids_batch).unsqueeze(dim=0),\n",
    "    \"target_mask\": torch.stack(target_mask_batch).unsqueeze(dim=0),\n",
    "}"
   ]
  },
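  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, the batch shapes described above (one row per option, 512 source tokens and 64 target tokens) can be checked before running the model. This is a small sanity-check sketch, not part of the original pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each tensor should have shape [batch size 1, number of options, max length].\n",
    "for key, max_len in [(\"source_ids\", 512), (\"source_mask\", 512), (\"target_ids\", 64), (\"target_mask\", 64)]:\n",
    "    assert tuple(batch[key].shape) == (1, len(options), max_len), (key, batch[key].shape)\n",
    "    print(key, tuple(batch[key].shape))"
   ]
  },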
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
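    "For each option, we compute the log-probability that the model generates the instruction given the input followed by that option: the sum of the log-probabilities of the instruction tokens, with padding excluded. We predict the option with the highest log-probability as the answer."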
   ]
  },
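  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The score above can also be checked in isolation. The sketch below uses a hypothetical `sequence_log_prob` helper (not part of this repository) and random tensors in place of real model logits: it gathers each gold token's log-probability from the log-softmax over the vocabulary, zeros out padding positions with the mask, sums over the sequence, and confirms the result matches an explicit per-token loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def sequence_log_prob(logits, labels, mask):\n",
    "    \"\"\"Hypothetical helper: sum of log p(labels[t]) over positions where mask == 1.\n",
    "\n",
    "    logits: [batch, seq_len, vocab], labels: [batch, seq_len] token ids, mask: [batch, seq_len].\n",
    "    \"\"\"\n",
    "    log_probs = torch.log_softmax(logits, dim=-1)\n",
    "    token_log_probs = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)\n",
    "    return (token_log_probs * mask).sum(dim=-1)\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "demo_logits = torch.randn(1, 6, 10)  # random stand-in for decoder logits (vocab size 10)\n",
    "demo_labels = torch.randint(0, 10, (1, 6))  # random instruction token ids\n",
    "demo_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])  # the last two positions are padding\n",
    "\n",
    "# Explicit per-token loop over non-padding positions, for comparison.\n",
    "demo_log_probs = torch.log_softmax(demo_logits, dim=-1)\n",
    "loop_total = sum(demo_log_probs[0, t, demo_labels[0, t]] for t in range(6) if demo_mask[0, t])\n",
    "\n",
    "vectorized_total = sequence_log_prob(demo_logits, demo_labels, demo_mask)[0]\n",
    "assert torch.allclose(vectorized_total, loop_total)\n",
    "print(vectorized_total.item(), loop_total.item())"
   ]
  },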
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    prob_list = []\n",
    "    num_options = int(batch[\"source_ids\"].shape[1])\n",
    "    for index in range(num_options):\n",
    "        target_ids = batch[\"target_ids\"][:, index, :]\n",
    "        target_mask = batch[\"target_mask\"][:, index, :]\n",
    "        # Labels for the model: padding positions set to -100 (ignored by the loss).\n",
    "        # clone() avoids modifying `batch` in place.\n",
    "        lm_channel_labels = target_ids.clone()\n",
    "        lm_channel_labels[lm_channel_labels == tokenizer.pad_token_id] = -100\n",
    "        output_channel = model(\n",
    "            input_ids=batch[\"source_ids\"][:, index, :].contiguous(),\n",
    "            attention_mask=batch[\"source_mask\"][:, index, :].contiguous(),\n",
    "            labels=lm_channel_labels.contiguous(),\n",
    "            decoder_attention_mask=target_mask.contiguous(),\n",
    "        )\n",
    "        # Log-probability of each gold instruction token; padding positions are zeroed out.\n",
    "        log_probs = torch.log_softmax(output_channel.logits, dim=-1)\n",
    "        token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1) * target_mask\n",
    "        # Sequence score: log p(instruction | input + option).\n",
    "        seq_channel_log_prob = token_log_probs.sum(dim=-1)\n",
    "        print(f'log-probability for label {options[index]} is {seq_channel_log_prob[0]}')\n",
    "        prob_list.append(seq_channel_log_prob.unsqueeze(-1))\n",
    "\n",
    "    concat = torch.cat(prob_list, 1).view(-1, num_options)\n",
    "    predictions = concat.argmax(dim=1)\n",
    "    print(f'The answer is {options[predictions.item()]}')"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
