{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "318733a0-eb89-4df2-b278-e80afc144cca",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade pip\n",
    "%pip install openicl\n",
    "# Restart the kernel after the installation is completed"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1be4e6a4-dff9-4f39-a86f-62736ccd82c4",
   "metadata": {},
   "source": [
    "# 2. Using Different Language Models with OpenICL"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b92d14f4-2e73-4b1c-ba9e-0f9112a764b1",
   "metadata": {},
   "source": [
    "In this chapter, we show how to use OpenICL for in-context learning (ICL) with different language models: [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf), [FLAN-T5](https://arxiv.org/abs/2109.01652), [XGLM](https://arxiv.org/abs/2112.10668), OpenAI's [GPT-3](https://arxiv.org/abs/2005.14165) API, and the [OPT-175B](https://arxiv.org/abs/2205.01068) API."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06ba89e7-2790-4097-9e5d-c75e356e152f",
   "metadata": {},
   "source": [
    "## 2-1 Huggingface Library's Models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca01fccb-6188-4960-85ee-662af04718ce",
   "metadata": {},
   "source": [
    "In this section, we take GPT-2, FLAN-T5, and XGLM as examples to show how to use models from the [huggingface library](https://huggingface.co/models) with OpenICL. In general, you only need to pass the model's name to the `model_name` parameter when declaring an `Inferencer`; the examples below show this for each model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2aa253f5-7d3e-4810-a5fe-ddebc9699551",
   "metadata": {},
   "source": [
    "### 2-1-1 GPT-2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2a7f087a",
   "metadata": {},
   "source": [
    "This example also appears in [tutorial1](https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb). This time, however, we set `batch_size` for both `TopkRetriever` and `PPLInferencer` to speed things up. Note that the two `batch_size` values can differ (here `8` and `6`). This is because each component receives its full input at the start (the complete dataset for retrieval, and the retrieval results for the entire test set for inference) rather than being handed data one batch at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faff9428",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openicl import DatasetReader, PromptTemplate, TopkRetriever, PPLInferencer\n",
    "\n",
    "# Define a DatasetReader, loading dataset from huggingface.\n",
    "data = DatasetReader('gpt3mix/sst2', input_columns=['text'], output_column='label')\n",
    "\n",
    "# SST-2 Template Example\n",
    "tp_dict = {\n",
    "     0: '</E>Positive Movie Review: </text>',\n",
    "     1: '</E>Negative Movie Review: </text>' \n",
    "}\n",
    "template = PromptTemplate(tp_dict, {'text' : '</text>'}, ice_token='</E>')\n",
    "\n",
    "# TopK Retriever\n",
    "retriever = TopkRetriever(data, ice_num=2, batch_size=8)\n",
    "\n",
    "# Define an Inferencer\n",
    "inferencer = PPLInferencer(model_name='gpt2', batch_size=6)\n",
    "\n",
    "# Inference\n",
    "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='gpt2_sst2')\n",
    "print(predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0a7b09a-2c1a-468e-bf1a-26f69ad51b21",
   "metadata": {},
   "source": [
    "### 2-1-2 XGLM"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b56bb0c2",
   "metadata": {},
   "source": [
    "XGLM is a good choice for machine translation. When using XGLM, however, we **don't recommend** setting `batch_size` in `GenInferencer`: calling the `model.generate` method of the [huggingface transformers library](https://huggingface.co/docs/transformers/index) on several inputs at once requires padding, and in our tests padding affected XGLM's generation. The code for evaluating the ICL performance of XGLM (7.5B) on the WMT16 (de-en) dataset\n",
    "with the direct inference strategy is as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6951df6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openicl import DatasetReader, PromptTemplate, BM25Retriever, GenInferencer\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Loading dataset from huggingface \n",
    "dataset = load_dataset('wmt16', name='de-en')\n",
    "\n",
    "# Data Preprocessing\n",
    "dataset = dataset.map(lambda example: example['translation']).remove_columns('translation')\n",
    "\n",
    "# Define a DatasetReader, selecting 5 pieces of data randomly.\n",
    "data = DatasetReader(dataset, input_columns='de', output_column='en', ds_size=5)\n",
    "\n",
    "# WMT16 de->en Template Example\n",
    "template = PromptTemplate('</E></de> = </en>', {'en' : '</en>', 'de' : '</de>'}, ice_token='</E>')\n",
    "\n",
    "# BM25 Retriever\n",
    "retriever = BM25Retriever(data, ice_num=1, index_split='validation', test_split='test', batch_size=5)\n",
    "\n",
    "# Define an Inferencer\n",
    "inferencer = GenInferencer(model_name='facebook/xglm-7.5B')\n",
    "\n",
    "# Inference\n",
    "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='xglm_wmt')\n",
    "print(predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2518e5ee-b58e-4378-a30f-ac0535da5559",
   "metadata": {},
   "source": [
    "### 2-1-3 FLAN-T5"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8778eead",
   "metadata": {},
   "source": [
    "In this section, we will use FLAN-T5 with OpenICL to reproduce the results in the figure below:"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "be985353",
   "metadata": {},
   "source": [
    "<div align=\"center\">\n",
    "<img src=\"https://s1.ax1x.com/2023/03/10/ppuZnQP.png\"  width=300px />\n",
    "<p>(figure in <a href=\"https://arxiv.org/abs/2109.01652\">Finetuned Language Models Are Zero-Shot Learners</a>)</p>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1f592cb3-fae4-4295-baa9-1a3998c3ea19",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset snli (/home/zhangyudejia/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)\n",
      "100%|██████████| 3/3 [00:00<00:00, 323.35it/s]\n",
      "[2023-03-10 15:40:19,712] [openicl.icl_inferencer.icl_gen_inferencer] [INFO] Starting inference process...\n",
      "100%|██████████| 10/10 [00:00<00:00, 15.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['yes', 'yes', 'yes', 'no', 'yes', 'no', 'it is not possible to tell', 'yes', 'yes', 'it is not possible to tell']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from openicl import DatasetReader, PromptTemplate, ZeroRetriever, GenInferencer\n",
    "\n",
    "# Define a DatasetReader, loading dataset from huggingface and selecting 10 pieces of data randomly.\n",
    "data = DatasetReader('snli', input_columns=['premise', 'hypothesis'], output_column='label', ds_size=10)\n",
    "\n",
    "# SNLI Template\n",
    "tp_str = '</E>Premise:</premise>\\nHypothesis:</hypothesis>\\nDoes the premise entail the hypothesis?\\nOPTIONS:\\n-yes -It is not possible to tell -no'\n",
    "template = PromptTemplate(tp_str, column_token_map={'premise' : '</premise>', 'hypothesis' : '</hypothesis>'}, ice_token='</E>')\n",
    "\n",
    "# ZeroShot Retriever (do nothing)\n",
    "retriever = ZeroRetriever(data, index_split='train', test_split='test')\n",
    "\n",
    "# Define an Inferencer\n",
    "inferencer = GenInferencer(model_name='google/flan-t5-small', max_model_token_num=1000)\n",
    "\n",
    "# Inference\n",
    "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='flan-t5-small')\n",
    "print(predictions)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "39e65e67",
   "metadata": {},
   "source": [
    "## 2-2 Using API-based Models"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "17a9d790",
   "metadata": {},
   "source": [
    "OpenICL also supports OpenAI's GPT-3 API and the OPT-175B API. Both require some configuration before use."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "deab3b33",
   "metadata": {},
   "source": [
    "### 2-2-1 OpenAI's GPT-3 API"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5b784102",
   "metadata": {},
   "source": [
    "OpenAI provides its own open-source library, [openai](https://github.com/openai/openai-python), for calling its API services. To use this library with OpenICL, set the environment variable `OPENAI_API_KEY` in advance. Here is a simple way to do so (for details, see OpenAI's documentation [here](https://platform.openai.com/docs/api-reference/introduction)):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84f229ea",
   "metadata": {
    "vscode": {
     "languageId": "powershell"
    }
   },
   "outputs": [],
   "source": [
    "# Replace 'your_api_key' with your key, and run this command in bash\n",
    "export OPENAI_API_KEY=\"your_api_key\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f7a7dab9",
   "metadata": {},
   "source": [
    "Once the key is set, pass `api_name='gpt3'` to the `Inferencer` to use the API. Below is a code snippet:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9510bd1e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zhangyudejia/.local/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Found cached dataset mtop (/home/zhangyudejia/.cache/huggingface/datasets/iohadrubin___mtop/mtop/1.0.0/4ba6d9db9efaebd4f6504db7e36925632e959f456071b9d7f1b86a85cce52448)\n",
      "100%|██████████| 3/3 [00:00<00:00, 814.96it/s]\n",
      "[2023-03-10 19:03:00,481] [openicl.icl_retriever.icl_bm25_retriever] [INFO] Retrieving data for test set...\n",
      "100%|██████████| 1/1 [00:00<00:00, 1504.95it/s]\n",
      "[2023-03-10 19:03:00,486] [openicl.icl_inferencer.icl_gen_inferencer] [INFO] Starting inference process...\n",
      "100%|██████████| 1/1 [00:06<00:00,  6.38s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['[IN:GET_EVENTS [SL:TYPE music festivals ] [SL:DATE in 2018 ] ]']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from openicl import DatasetReader, PromptTemplate, BM25Retriever, GenInferencer\n",
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"iohadrubin/mtop\")\n",
    "dataset['train'] = dataset['train'].select([0, 1, 2])\n",
    "dataset['test'] = dataset['test'].select([0])\n",
    "\n",
    "dr = DatasetReader(dataset, input_columns=['question'], output_column='logical_form')  \n",
    "\n",
    "tp_str = \"</E></Q>\\t</L>\"      \n",
    "tp = PromptTemplate(tp_str, column_token_map={'question' : '</Q>', 'logical_form' : '</L>'}, ice_token='</E>')\n",
    "\n",
    "rtr = BM25Retriever(dr, ice_num=1)\n",
    "\n",
    "infr = GenInferencer(api_name='gpt3', engine='text-davinci-003', sleep_time=3)\n",
    "\n",
    "print(infr.inference(rtr, ice_template=tp))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9d974057",
   "metadata": {},
   "source": [
    "Some OpenAI models are paid and rate-limited, so we set `sleep_time` (3 seconds) here to control the request frequency. To avoid losing data if an exception is thrown, we also recommend running this API on a small test set each time. For more information about API parameter configuration in OpenICL, see [api_service.py](https://github.com/Shark-NLP/OpenICL/blob/main/openicl/utils/api_service.py)."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "48635ed8",
   "metadata": {},
   "source": [
    "### 2-2-2 OPT-175B API "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f3a240a2",
   "metadata": {},
   "source": [
    "For OPT-175B, you need to deploy the model yourself (or get a URL of a deployed model from your friends :\\) ). \n",
    "Visit the [metaseq](https://github.com/facebookresearch/metaseq) repository for more information on deployment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bf3c44c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openicl import GenInferencer\n",
    "\n",
    "# Replace 'xxx' with the URL of your deployed OPT-175B service\n",
    "URL = \"xxx\"\n",
    "inferencer = GenInferencer(api_name='opt-175b', URL=URL)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
