{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec91d2ba-7ca4-4bd4-a6f5-c78a5b321179",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade pip\n",
    "%pip install openicl"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c9acf526-7e21-4f95-9a94-fb3b00973e8d",
   "metadata": {},
   "source": [
    "# Self-adaptive In-context Learning\n",
    "---\n",
    "Code for the paper [Self-adaptive In-context Learning](https://arxiv.org/abs/2212.10375)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd2e9c3a-c8eb-4a03-827f-253ed84b0c7e",
   "metadata": {},
   "source": [
    "## Templates "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "73a948c3-aa16-4b1e-9bc3-c6a6c0fdb7e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openicl import PromptTemplate\n",
    "\n",
    "\n",
    "def _numbered(pattern, words):\n",
    "    \"\"\"Return {0: pattern.format(words[0]), 1: pattern.format(words[1]), ...}.\"\"\"\n",
    "    return {label: pattern.format(word) for label, word in enumerate(words)}\n",
    "\n",
    "\n",
    "def _single_text_template(tp_dict):\n",
    "    \"\"\"Build a PromptTemplate whose only input column is 'text' (token '<X>').\"\"\"\n",
    "    return PromptTemplate(tp_dict, column_token_map={'text': '<X>'}, ice_token='</E>')\n",
    "\n",
    "\n",
    "# Shared wording of the topic-classification prompts (AG_NEWS and TREC).\n",
    "_TOPIC_PATTERN = '</E>\"<X>\" It is about {}.'\n",
    "\n",
    "# SST-2\n",
    "sst2_tp_dict = _numbered('</E>{} Movie Review: \"<X>\"', ['Positive', 'Negative'])\n",
    "sst2_template = _single_text_template(sst2_tp_dict)\n",
    "\n",
    "# SST-5\n",
    "sst5_tp_dict = _numbered(\n",
    "    '</E>Review: <X>\\nSentiment: {}',\n",
    "    ['terrible', 'bad', 'okay', 'good', 'great'],\n",
    ")\n",
    "sst5_template = _single_text_template(sst5_tp_dict)\n",
    "\n",
    "# AG_NEWS\n",
    "ag_news_tp_dict = _numbered(\n",
    "    _TOPIC_PATTERN,\n",
    "    ['world', 'sports', 'business', 'science and technology'],\n",
    ")\n",
    "ag_news_template = _single_text_template(ag_news_tp_dict)\n",
    "\n",
    "# TREC\n",
    "trec_tp_dict = _numbered(\n",
    "    _TOPIC_PATTERN,\n",
    "    [\n",
    "        'abbreviation',\n",
    "        'entity',\n",
    "        'description and abstract concept',\n",
    "        'human being',\n",
    "        'location',\n",
    "        'numeric value',\n",
    "    ],\n",
    ")\n",
    "trec_template = _single_text_template(trec_tp_dict)\n",
    "\n",
    "# SNLI & MNLI\n",
    "xnli_tp_dict = _numbered('</E><X1>? {}, <X2>', ['Yes', 'Maybe', 'No'])\n",
    "xnli_template = PromptTemplate(\n",
    "    xnli_tp_dict,\n",
    "    column_token_map={'premise': '<X1>', 'hypothesis': '<X2>'},\n",
    "    ice_token='</E>',\n",
    ")\n",
    "\n",
    "# QNLI\n",
    "qnli_tp_dict = _numbered('</E><X1> Can we know <X2>? {}.', ['Yes', 'No'])\n",
    "qnli_template = PromptTemplate(\n",
    "    qnli_tp_dict,\n",
    "    column_token_map={'sentence': '<X1>', 'question': '<X2>'},\n",
    "    ice_token='</E>',\n",
    ")\n",
    "\n",
    "# Commonsense QA: one template per answer key 'A'..'E', each ending in its own </AnsN> token\n",
    "_choice_keys = 'ABCDE'\n",
    "_cmsqa_tp_dict = {\n",
    "    key: f'</E>Answer the following question:\\n</Q>\\nAnswer: </Ans{n}>'\n",
    "    for n, key in enumerate(_choice_keys, start=1)\n",
    "}\n",
    "_cmsqa_token_map = {\n",
    "    'question': '</Q>',\n",
    "    **{key: f'</Ans{n}>' for n, key in enumerate(_choice_keys, start=1)},\n",
    "}\n",
    "cmsqa_template = PromptTemplate(_cmsqa_tp_dict, _cmsqa_token_map, ice_token='</E>')\n",
    "\n",
    "# Task name -> template; SNLI and MNLI share the same NLI template.\n",
    "templates = {\n",
    "    'sst2': sst2_template,\n",
    "    'snli': xnli_template,\n",
    "    'mnli': xnli_template,\n",
    "    'qnli': qnli_template,\n",
    "    'sst5': sst5_template,\n",
    "    'ag_news': ag_news_template,\n",
    "    'trec': trec_template,\n",
    "    'commonsense_qa': cmsqa_template,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cde382ff",
   "metadata": {},
   "source": [
    "## Datasets "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb05fc8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from openicl import DatasetReader\n",
    "\n",
    "# Hugging Face hub location of every task: [path, config name]\n",
    "data_path = {'sst2': [\"gpt3mix/sst2\", None],\n",
    "             'snli': ['snli', None],\n",
    "             'mnli': ['LysandreJik/glue-mnli-train', None],\n",
    "             \"qnli\": [\"glue\", \"qnli\"],\n",
    "             \"sst5\": [\"SetFit/sst5\", None],\n",
    "             \"ag_news\": [\"ag_news\", None],\n",
    "             \"trec\": [\"trec\", None],\n",
    "             \"commonsense_qa\": [\"commonsense_qa\", None]\n",
    "            }\n",
    "\n",
    "# Columns substituted into the prompt template of each task\n",
    "input_columns={'sst2': [\"text\"],\n",
    "             'snli': ['premise', 'hypothesis'],\n",
    "             'mnli': ['premise', 'hypothesis'],\n",
    "             \"qnli\": [\"sentence\", \"question\"],\n",
    "             \"sst5\": [\"text\"],\n",
    "             \"ag_news\": [\"text\"],\n",
    "             \"trec\": [\"text\"],\n",
    "             \"commonsense_qa\": ['question', 'A', 'B', 'C', 'D', 'E']\n",
    "            }\n",
    "\n",
    "# Column holding the gold label of each task\n",
    "output_column={'sst2': 'label',\n",
    "             'snli': 'label',\n",
    "             'mnli': 'label',\n",
    "             \"qnli\": 'label',\n",
    "             \"sst5\": 'label',\n",
    "             \"ag_news\": 'label',\n",
    "             \"trec\": 'label-coarse',\n",
    "             \"commonsense_qa\": \"answerKey\"\n",
    "            }\n",
    "\n",
    "# Split each task is evaluated on\n",
    "test_split={\n",
    "    'sst2': 'test',\n",
    "    'snli': 'test',\n",
    "    \"sst5\": 'test',\n",
    "    \"ag_news\": 'test',\n",
    "    \"trec\": 'test',\n",
    "    'mnli': 'validation', # cannot get gold labels for the test split\n",
    "    \"qnli\": 'validation',\n",
    "    \"commonsense_qa\": \"validation\"\n",
    "}\n",
    "\n",
    "# Change it for other tasks\n",
    "task_name='snli'\n",
    "\n",
    "path,name=data_path[task_name]\n",
    "dataset = load_dataset(path=path,name=name)\n",
    "\n",
    "# Preprocess for commonsense_qa\n",
    "def pre_process(example):\n",
    "    \"\"\"Copy the first five answer texts of `example['choices']['text']` into columns 'A'..'E'.\"\"\"\n",
    "    for i in range(5):\n",
    "        example[chr(ord('A') + i)] = example['choices']['text'][i]\n",
    "    return example\n",
    "\n",
    "if task_name=='commonsense_qa':\n",
    "    dataset=dataset.map(pre_process).remove_columns(['question_concept', 'id', 'choices'])\n",
    "\n",
    "# SNLI marks examples without a gold label with -1. The SNLI template has no entry for\n",
    "# that label, so drop those rows before they are retrieved as in-context examples or scored.\n",
    "if task_name=='snli':\n",
    "    dataset=dataset.filter(lambda example: example['label'] != -1)\n",
    "\n",
    "# If you only want to test part of the test set for faster running, you can use the following codes.\n",
    "# Subset before building the DatasetReader so that `data.references` and the retriever see the same examples.\n",
    "# dataset['test'] = dataset['test'].select(list(range(100)))\n",
    "# dataset['validation'] = dataset['validation'].select(list(range(100))) # trec,agnews don't have validation\n",
    "# dataset['train'] = dataset['train'].select(list(range(100)))\n",
    "\n",
    "# Pass the evaluation split explicitly so that `data.references` comes from the same split\n",
    "# the retriever predicts on (mnli, qnli and commonsense_qa are evaluated on 'validation').\n",
    "data=DatasetReader(dataset,\n",
    "                   input_columns=input_columns[task_name],\n",
    "                   output_column=output_column[task_name],\n",
    "                   test_split=test_split[task_name])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5f28ecda-13f8-430c-a0c9-c85e52c806cf",
   "metadata": {},
   "source": [
    "### TopK-MDL Experiments (method in the paper)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b3bc424-fe92-4b2a-9d5f-ac942ebc675f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from openicl import MDLRetriever, PPLInferencer, AccEvaluator\n",
    "\n",
    "# In-context example selection for the task chosen in `task_name`\n",
    "retriever = MDLRetriever(\n",
    "    data,\n",
    "    ice_num=8,\n",
    "    candidate_num=30,\n",
    "    select_time=10,\n",
    "    seed=1,\n",
    "    batch_size=12,\n",
    "    test_split=test_split[task_name],\n",
    ")\n",
    "\n",
    "# Inference with gpt2-xl; predictions are also written to the `mdl_<task>` output file\n",
    "inferencer = PPLInferencer(model_name='gpt2-xl', batch_size=8)\n",
    "predictions = inferencer.inference(\n",
    "    retriever,\n",
    "    ice_template=templates[task_name],\n",
    "    output_json_filename=f'mdl_{task_name}',\n",
    ")\n",
    "\n",
    "# Accuracy against the gold labels held by the dataset reader\n",
    "evaluator = AccEvaluator()\n",
    "scores = evaluator.score(predictions=predictions, references=data.references)\n",
    "print(scores)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}