{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, re, json\n",
    "import pandas as pd\n",
    "\n",
    "import torch, numpy as np\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Model & Tokenizer\n",
    "low_cpu_mem_usage = True\n",
    "torch.set_grad_enabled(False)\n",
    "device=\"cuda\"\n",
    "\n",
    "model_name = r\"EleutherAI/gpt-j-6b\"\n",
    "\n",
    "# Load Tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
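  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above only loads the tokenizer, even though it also configures `low_cpu_mem_usage` and `device`. The next cell is a minimal, optional sketch of the model load those settings imply; `torch.float16` is an assumption made here to fit GPT-J-6B on a single GPU, and nothing in the dataset-construction code below depends on the model being loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: load GPT-J-6B with the settings defined above.\n",
    "# Assumptions: a CUDA device is available and float16 precision is acceptable.\n",
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    low_cpu_mem_usage=low_cpu_mem_usage,\n",
    "    torch_dtype=torch.float16,\n",
    ").to(device)\n",
    "model.eval()"
   ]
  },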
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset Construction Code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_single_dataset(word_set, label_rule = lambda x,n: x[n], world_size=5, train_size=500, test_size=100, choose_n=1, allow_duplicates=False):\n",
    "    \"\"\"Simple dataset where targets are created based on defined label rule. \n",
    "    The default label rule is \"choose position n\" out of the possible slots in world_size\"\"\"\n",
    "    sample_N = lambda x, N: list(np.random.choice(x, N, replace=allow_duplicates))    \n",
    "    dataset = []\n",
    "    label_indices = []\n",
    "    assert (choose_n < world_size)\n",
    "\n",
    "    for i in range(train_size + test_size):\n",
    "        sample = sample_N(word_set, world_size)\n",
    "        label = label_rule(sample, choose_n)\n",
    "        label_index = sample.index(label)\n",
    "    \n",
    "        dataset.append({'input':sample, 'output':label})\n",
    "        label_indices.append(label_index)\n",
    "\n",
    "    return dataset[:train_size], dataset[train_size:], label_indices\n",
    "\n",
    "def construct_mixed_dataset(target_word_set, distractor_word_set, world_size=5, train_size=500, test_size=100, choose_n=1, allow_duplicates=False):\n",
    "    \"\"\"Simple dataset where choose_n targets are chosen from one dataset, and distractors chosen from another\"\"\"\n",
    "    sample_N = lambda x, N: list(np.random.choice(x, N, replace=allow_duplicates))\n",
    "    n_distractors = world_size - choose_n\n",
    "\n",
    "    dataset = []\n",
    "    label_indices = []\n",
    "    assert (choose_n < world_size)\n",
    "\n",
    "    for i in range(train_size + test_size):\n",
    "        distractors = sample_N(distractor_word_set, n_distractors)\n",
    "        target = sample_N(target_word_set, choose_n)\n",
    "        sample = list(np.random.permutation(distractors + target))\n",
    "        label = target[0]\n",
    "        label_index = sample.index(label)\n",
    "    \n",
    "        dataset.append({'input':sample, 'output':label})\n",
    "        label_indices.append(label_index)\n",
    "\n",
    "    return dataset[:train_size], dataset[train_size:], label_indices\n",
    "\n",
    "def construct_rule_dataset(word_set, label_rule):        \n",
    "    dataset = []\n",
    "    size = len(word_set)\n",
    "    for i in range(size):\n",
    "        sample = word_set[i]\n",
    "        label = label_rule(sample, choose_n)    \n",
    "        dataset.append({'input':sample, 'output':label})\n",
    "\n",
    "    return dataset\n",
    "\n",
    "\n",
    "# Label Rules:\n",
    "def choose_n(x, n):\n",
    "    \"\"\"Returns the word from x at position n\"\"\"\n",
    "    return x[n]\n",
    "\n",
    "def alphabetically_first(x,n):\n",
    "    \"\"\"Returns the word from x that appears alphabetically first, ties are broken by choosing lower index\"\"\"\n",
    "    return min(x)\n",
    "\n",
    "def alphabetically_last(x,n):\n",
    "    \"\"\"Returns the word from x that appears alphabetically last, ties are broken by choosing lower index\"\"\"\n",
    "    return max(x)\n",
    "\n",
    "def most_vowels(x,n):\n",
    "    \"\"\"Returns word from list x with the most vowels, ties are broken by choosing lower index\"\"\"\n",
    "    return max(x, key=lambda y: len(re.findall(r'[aeiouAEIOU]', y)))\n",
    "\n",
    "def longest_word(x,n):\n",
    "    \"\"\"Returns the longest word from list x, ties are broken by choosing lower index\"\"\"\n",
    "    return sorted(x, key=len)[-1]\n",
    "\n",
    "def shortest_word(x,n):\n",
    "    \"\"\"Returns the shortest word from list x, ties are broken by choosing lower index\"\"\"\n",
    "    return sorted(x, key=len)[0]\n",
    "\n",
    "def capitalize_first_letter(x, n):\n",
    "    \"\"\"Returns the first letter of the word, capitalized\"\"\"\n",
    "    if isinstance(x,list):\n",
    "        x = x[0]\n",
    "    return x.title().strip()[0]\n",
    "\n",
    "def capitalize(x, n):\n",
    "    \"\"\"Returns the word capitalized\"\"\"\n",
    "    if isinstance(x,list):\n",
    "        x = x[0]\n",
    "    return x.strip().title()\n",
    "\n",
    "def len_word(x,n):\n",
    "    return str(len(x.strip()))\n",
    "\n",
    "def next_capital_letter(x,n):\n",
    "    \"\"\"Returns the letter after the first letter of the word, capitalized\"\"\"\n",
    "    def next_alpha(s):\n",
    "        return chr((ord(s.upper())+1 - 65) % 26 + 65)\n",
    "    \n",
    "    if isinstance(x, list):\n",
    "        x = x[0]\n",
    "    return next_alpha(x.title().strip()[0])\n",
    "\n",
    "def next_letter(x,n):\n",
    "    \"\"\"Returns the letter after the first letter of the word, capitalized\"\"\"\n",
    "    def next_alpha(s):\n",
    "        return chr((ord(s.upper())+1 - 65) % 26 + 65)\n",
    "    \n",
    "    if isinstance(x, list):\n",
    "        x = x[0]\n",
    "    return next_alpha(x.strip()[0])\n",
    "\n",
    "def capitalize_last_letter(x,n):\n",
    "    if isinstance(x,list):\n",
    "        x = x[0]\n",
    "    return x.strip()[-1].title()\n",
    "\n",
    "def capitalize_second_letter(x,n):\n",
    "    if isinstance(x,list):\n",
    "        x = x[0]\n",
    "    return x.strip()[1].title()\n",
    "\n",
    "def lowercase_first_letter(x, n):\n",
    "    if isinstance(x,list):\n",
    "        x = x[0]\n",
    "    return x.lower().strip()[0]\n",
    "\n",
    "def lowercase_last_letter(x, n):\n",
    "    if isinstance(x,list):\n",
    "        x = x[0]\n",
    "    return x.lower().strip()[-1]\n",
    "\n",
    "def parens(x, n):\n",
    "    \"\"\"Returns the word capitalized\"\"\"\n",
    "    if isinstance(x,list):\n",
    "        x = x[0]\n",
    "    return '('+x.strip()+')'"
   ]
  },
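  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal usage sketch of the helpers above, on a tiny illustrative word list (`demo_words` is not one of the real category word sets): `construct_single_dataset` returns train/test lists of `{'input': [...], 'output': ...}` records plus the index of the label within each sample, and `construct_rule_dataset` maps each word to a single transformed output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: demo_words is a toy list, not part of the generated datasets.\n",
    "demo_words = ['apple', 'banana', 'cherry', 'grape', 'kiwi', 'mango', 'pear']\n",
    "\n",
    "demo_train, demo_test, demo_label_indices = construct_single_dataset(\n",
    "    demo_words, label_rule=alphabetically_first, world_size=3, train_size=4, test_size=2)\n",
    "print(demo_train[0])           # e.g. {'input': ['pear', 'kiwi', 'apple'], 'output': 'apple'}\n",
    "print(demo_label_indices[:4])  # position of each label within its sample\n",
    "\n",
    "demo_rules = construct_rule_dataset(demo_words, label_rule=capitalize_first_letter)\n",
    "print(demo_rules[0])           # {'input': 'apple', 'output': 'A'}"
   ]
  },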
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create New Choose Item from List Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "categories = json.load(open('categories.json', 'r'))\n",
    "print(categories.keys())\n",
    "big_list = []\n",
    "for x in list(categories.keys()):\n",
    "    big_list.extend(categories[x])\n",
    "\n",
    "big_list = list(set(big_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for world_size in [3,5]:\n",
    "    # Indexed Datasets\n",
    "    train_dataset, _,_ = construct_single_dataset(big_list, choose_n=0, train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    ##json.dump(train_dataset, open(f'dataset_files/extractive/choose_first_of_{world_size}.json','w'))\n",
    "\n",
    "    train_dataset, _,_ = construct_single_dataset(big_list, choose_n=world_size//2, train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    #json.dump(train_dataset, open(f'dataset_files/extractive/choose_middle_of_{world_size}.json','w'))\n",
    "\n",
    "    train_dataset, _,_ = construct_single_dataset(big_list, choose_n=-1, train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    #json.dump(train_dataset, open(f'dataset_files/extractive/choose_last_of_{world_size}.json','w'))\n",
    "\n",
    "    train_dataset, _,_ = construct_single_dataset(big_list, label_rule=alphabetically_first, train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    #json.dump(train_dataset, open(f'dataset_files/extractive/alphabetically_first_{world_size}.json','w'))\n",
    "\n",
    "    train_dataset, _,_ = construct_single_dataset(big_list, label_rule=alphabetically_last, train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    #json.dump(train_dataset, open(f'dataset_files/extractive/alphabetically_last_{world_size}.json', 'w'))\n",
    "\n",
    "\n",
    "    # Mixed/Distractor Datasets\n",
    "    train_dataset, _,_ = construct_mixed_dataset(categories['object'], categories['verb'] + categories['adjective'] + categories['preposition'], train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    #json.dump(train_dataset, open(f'dataset_files/extractive/object_v_concept_{world_size}.json', 'w'))\n",
    "\n",
    "    train_dataset, _,_ = construct_mixed_dataset(categories['verb'] + categories['adjective'] + categories['preposition'],categories['object'], train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    #json.dump(train_dataset, open(f'dataset_files/extractive/concept_v_object_{world_size}.json', 'w'))\n",
    "\n",
    "    train_dataset, _,_ = construct_mixed_dataset(categories['fruit'], categories['animal'], train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    #json.dump(train_dataset, open(f'dataset_files/extractive/fruit_v_animal_{world_size}.json', 'w'))\n",
    "\n",
    "    train_dataset, _,_ = construct_mixed_dataset(categories['color'], categories['animal'], train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    #json.dump(train_dataset, open(f'dataset_files/extractive/color_v_animal_{world_size}.json', 'w'))\n",
    "\n",
    "    train_dataset, _,_ = construct_mixed_dataset(categories['animal'], categories['object'], train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    #json.dump(train_dataset, open(f'dataset_files/extractive/animal_v_object_{world_size}.json', 'w'))\n",
    "\n",
    "    train_dataset, _,_ = construct_mixed_dataset(categories['verb'], categories['adjective'], train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    #json.dump(train_dataset, open(f'dataset_files/extractive/verb_v_adjective_{world_size}.json', 'w'))\n",
    "\n",
    "    train_dataset, _,_ = construct_mixed_dataset(categories['adjective'], categories['verb'], train_size=1000, test_size=0, world_size=world_size)\n",
    "    train_dataset = [{'input':\", \".join(list(w['input'])),'output':str(w['output'])} for w in train_dataset]\n",
    "    #json.dump(train_dataset, open(f'dataset_files/extractive/adjective_v_verb_{world_size}.json', 'w'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DATASETS:\n",
    "# word_length, next_capital_letter, capitalize_last_letter, capitalize_second_letter, lowercase_first_letter, lowercase_last_letter\n",
    "\n",
    "train_dataset, _ = construct_rule_dataset([x.upper() for x in big_list], label_rule=lowercase_first_letter)\n",
    "#json.dump(train_dataset, open(f'dataset_files/unused/lowercase_last_letter.json','w'))\n",
    "\n",
    "train_dataset, _ = construct_rule_dataset([x.upper() for x in big_list], label_rule=lowercase_last_letter)\n",
    "#json.dump(train_dataset, open(f'dataset_files/unused/lowercase_last_letter.json','w'))\n",
    "\n",
    "train_dataset, _ = construct_rule_dataset(big_list, label_rule=capitalize_last_letter)\n",
    "#json.dump(train_dataset, open(f'dataset_files/unused/capitalize_last_letter.json','w'))\n",
    "\n",
    "train_dataset, _ = construct_rule_dataset(big_list, label_rule=next_capital_letter)\n",
    "#json.dump(train_dataset, open(f'dataset_files/unused/next_capital_letter.json','w'))\n",
    "\n",
    "train_dataset, _ = construct_rule_dataset(big_list, label_rule=len_word)\n",
    "#json.dump(train_dataset, open(f'dataset_files/unused/word_length.json','w'))\n",
    "\n",
    "train_dataset = construct_rule_dataset(big_list, label_rule=capitalize_first_letter)\n",
    "#json.dump(train_dataset, open(f'dataset_files/abstractive/capitalize_first_letter.json', 'w'))\n",
    "\n",
    "train_dataset = construct_rule_dataset(big_list, label_rule=capitalize)\n",
    "#json.dump(train_dataset, open(f'dataset_files/abstractive/capitalize.json', 'w'))\n",
    "\n",
    "train_dataset = construct_rule_dataset(big_list, label_rule=parens)\n",
    "#json.dump(train_dataset, open(f'dataset_files/abstractive/parens.json', 'w'))\n",
    "\n",
    "\n",
    "big_list2 = []\n",
    "big_list2 = list(filter(lambda x: len(x) > 1, big_list))\n",
    "train_dataset, _ = construct_rule_dataset(big_list2, label_rule=capitalize_second_letter)\n",
    "#json.dump(train_dataset, open(f'dataset_files/unused/capitalize_second_letter.json','w'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NLP Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CONLL2003"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset as ld\n",
    "import re\n",
    "import json\n",
    "import numpy as np\n",
    "\n",
    "dataset = ld(\"conll2003\")\n",
    "\n",
    "conll_label_dict = {'person':{\"B_ind\": 1, \"I_ind\":2}, \n",
    "                    'organization':{\"B_ind\": 3, \"I_ind\":4}, \n",
    "                    'location':{\"B_ind\": 5, \"I_ind\":6}}\n",
    "\n",
    "re_test = re.compile(r\"[\\s]([,\\.])\")\n",
    "\n",
    "for category in conll_label_dict.keys():\n",
    "    B_ind = conll_label_dict[category]['B_ind']\n",
    "    I_ind = conll_label_dict[category]['I_ind']\n",
    "\n",
    "    data_filtered = []\n",
    "    n_train, n_val = len(dataset['train']), len(dataset['validation'])\n",
    "\n",
    "    for i in range(n_train + n_val):\n",
    "        if i < n_train:\n",
    "            data_point_i = dataset['train'][i]\n",
    "        else:\n",
    "            data_point_i = dataset['validation'][i - n_train]\n",
    "        # print(data_point_i)\n",
    "\n",
    "        if B_ind in data_point_i['ner_tags']:\n",
    "            tag_counts = np.unique(data_point_i['ner_tags'], return_counts=True)\n",
    "            tag_dict = {k:v for k,v in zip(tag_counts[0], tag_counts[1])}\n",
    "            \n",
    "            if tag_dict[B_ind] == 1: # Filter to sentences with only 1 appearance of B_ind\n",
    "                input_text = \" \".join(data_point_i['tokens'])\n",
    "                input_text = re_test.sub(r\"\\1\", input_text)\n",
    "                output = data_point_i['tokens'][data_point_i['ner_tags'].index(B_ind)]\n",
    "                if I_ind in data_point_i['ner_tags']:\n",
    "                    name_cont_ind = list(np.where(np.array(data_point_i['ner_tags']) == I_ind)[0])\n",
    "                    output_2 = \" \".join([data_point_i['tokens'][x] for x in name_cont_ind])\n",
    "                    output += \" \" + output_2\n",
    "                \n",
    "                data_filtered.append({\"input\":input_text, \"output\":output})\n",
    "    # json.dump(data_filtered, open(f'dataset_files/conll2003/conll2003_{category}.json', 'w'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CommonsenseQA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset as ld\n",
    "# import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "dataset = ld(\"commonsense_qa\", 'plain_text')\n",
    "df = pd.DataFrame(dataset['train'])\n",
    "df2 = pd.DataFrame(dataset['validation'])\n",
    "df = pd.concat([df, df2], axis=0)\n",
    "df['choices_label'] = df.choices.apply(lambda x: [y.lower() for y in  x['label']])\n",
    "df['choices_text'] = df.choices.apply(lambda x: x['text'])\n",
    "df['input'] = df.apply(lambda x: x['question'] + '\\n' + \"\\n\".join([(y[0] + ': ' + y[1]) for y in zip(x.choices_label, x.choices_text)]), axis=1)\n",
    "df['output'] = df['answerKey'].str.lower()\n",
    "df = df.drop(columns=['choices', 'id', 'question_concept', 'choices_label', 'choices_text', 'question', 'answerKey'])\n",
    "df = df.reset_index(drop=True)\n",
    "\n",
    "combined_data = [{'input':df.iloc[i].input, 'output':df.iloc[i].output} for i in range(len(df))]\n",
    "# json.dump(combined_data, open('dataset_files/abstractive/commonsense_qa.json', 'w'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### AG News"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = ld('ag_news')\n",
    "df = pd.DataFrame(dataset['test'])\n",
    "df = df.rename(columns={'text':'input', 'label':'output'})\n",
    "combined_data = [{'input':df.iloc[i].input, 'output':int(df.iloc[i].output)} for i in range(len(df))]\n",
    "# json.dump(combined_data, open('dataset_files/abstractive/ag_news.json', 'w'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Additional Datasets Considered"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SQUAD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset as ld\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "dataset = ld('squad')\n",
    "df = pd.DataFrame(dataset['validation'])\n",
    "df = df[df.apply(lambda x: len(np.unique(x['answers']['text']))==1, axis=1)] # filter to cases where answer is unique\n",
    "df['input'] = df.apply(lambda x: x['context'] + '\\n' + x['question'], axis=1)\n",
    "df['output'] = df.apply(lambda x: x['answers']['text'][0], axis=1)\n",
    "df = df.drop(columns=['id', 'title', 'answers', 'context', 'question']).reset_index(drop=True)\n",
    "\n",
    "combined_data = [{'input':df.iloc[i].input, 'output':df.iloc[i].output} for i in range(len(df))]\n",
    "# len(combined_data)\n",
    "# json.dump(combined_data, open('dataset_files/abstractive/squad_val.json', 'w'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MultiNERD Simple NER Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"tner/multinerd\", 'en')\n",
    "\n",
    "mn_label_dict = {'person':{\"B_ind\": 1, \"I_ind\":2}, \n",
    "                 'location':{\"B_ind\": 3, \"I_ind\":4}, \n",
    "                 'organization':{\"B_ind\": 5, \"I_ind\":6}}\n",
    "                #  'animal':{\"B_ind\": 7, \"I_ind\":8},\n",
    "                #  'celestial':{\"B_ind\": 11, \"I_ind\":12},\n",
    "                #  'disease':{\"B_ind\": 13, \"I_ind\":14},\n",
    "                #  'event':{\"B_ind\": 15, \"I_ind\":16},\n",
    "                #  'food':{\"B_ind\": 17, \"I_ind\":18}}\n",
    "\n",
    "re_test = re.compile(r\"[\\s]([,\\.])\")\n",
    "\n",
    "for category in mn_label_dict.keys():\n",
    "    B_ind = mn_label_dict[category]['B_ind']\n",
    "    I_ind = mn_label_dict[category]['I_ind']\n",
    "\n",
    "    data_filtered = []\n",
    "\n",
    "    for i in range(len(dataset['test'])):\n",
    "        data_point_i = dataset['test'][i]\n",
    "\n",
    "        if B_ind in data_point_i['tags']:\n",
    "            tag_counts = np.unique(data_point_i['tags'], return_counts=True)\n",
    "            tag_dict = {k:v for k,v in zip(tag_counts[0], tag_counts[1])}\n",
    "            \n",
    "            if tag_dict[B_ind] == 1: # Filter to sentences with only 1 appearance of B_ind\n",
    "                input_text = \" \".join(data_point_i['tokens'])\n",
    "                input_text = re_test.sub(r\"\\1\", input_text)\n",
    "                output = data_point_i['tokens'][data_point_i['tags'].index(B_ind)]\n",
    "                if I_ind in data_point_i['tags']:\n",
    "                    name_cont_ind = list(np.where(np.array(data_point_i['tags']) == I_ind)[0])\n",
    "                    output_2 = \" \".join([data_point_i['tokens'][x] for x in name_cont_ind])\n",
    "                    output += \" \" + output_2\n",
    "                \n",
    "                data_filtered.append({\"input\":input_text, \"output\":output})\n",
    "\n",
    "    # json.dump(data_filtered, open(f'dataset_files/multinerd/multinerd_{category}.json', 'w'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Next & Prev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lower = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n",
    "double_lower = [x*2 for x in lower]\n",
    "\n",
    "upper = [x.upper() for x in lower]\n",
    "double_upper = [x.upper()*2 for x in lower]\n",
    "\n",
    "number_strings = [\"zero\",\"one\", \"two\", \"three\", \"four\", \"five\",\n",
    "           \"six\", \"seven\", \"eight\", \"nine\", \"ten\",\n",
    "           \"eleven\", \"twelve\", \"thirteen\", \"fourteen\", \"fifteen\",\n",
    "           \"sixteen\", \"seventeen\", \"eighteen\", \"nineteen\", \"twenty\"]\n",
    "\n",
    "numbers = [str(i) for i in range(30)]\n",
    "\n",
    "roman_numerals_upper = ['I', 'II', 'III', 'IV', 'V', 'VI', 'VII', 'VIII', 'IX', 'X',\n",
    "                        'XI', 'XII', 'XIII', 'XIV', 'XV', 'XVI', 'XVII', 'XVIII', 'XIX', 'XX']\n",
    "\n",
    "roman_numerals_lower = [x.lower() for x in roman_numerals_upper]\n",
    "\n",
    "days = ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']\n",
    "days_upper = [x.title() for x in days]\n",
    "\n",
    "months = ['january', 'february', 'march', 'april', 'may', 'june', 'july', 'august', 'september', 'october', 'november', 'december']\n",
    "months_upper = [x.title() for x in months]\n",
    "\n",
    "\n",
    "pair_up_next = lambda l : [{\"input\":l[i], \"output\":l[i+1]} for i in range(len(l)-1)]\n",
    "pair_up_prev = lambda l : [{\"input\":l[i+1], \"output\":l[i]} for i in range(len(l)-1)]\n",
    "\n",
    "PREV_dataset = pair_up_prev(number_strings) + pair_up_prev(numbers) + pair_up_prev(lower) + pair_up_prev(upper) + pair_up_prev(double_upper) + pair_up_prev(double_lower) +  pair_up_prev(roman_numerals_upper) + pair_up_prev(roman_numerals_lower) + pair_up_prev(days) +  pair_up_prev(months) +  pair_up_prev(days_upper)  + pair_up_prev(months_upper)  + [{'input':'monday', 'output':'sunday'}] +  [{'input':'january', 'output':'december'}] +  [{'input':'Monday', 'output':'Sunday'}] +  [{'input':'January', 'output':'December'}]\n",
    "NEXT_dataset = pair_up_next(number_strings) + pair_up_next(numbers) + pair_up_next(lower) + pair_up_next(upper) + pair_up_next(double_upper) + pair_up_next(double_lower) + pair_up_next(roman_numerals_upper) + pair_up_next(roman_numerals_lower) + pair_up_next(days) + pair_up_next(months)+ pair_up_next(days_upper) + pair_up_next(months_upper) + [{'input':'sunday', 'output':'monday'}] +  [{'input':'december', 'output':'january'}] +  [{'input':'Sunday', 'output':'Monday'}] +  [{'input':'December', 'output':'January'}]\n",
    "\n",
    "# json.dump(NEXT_dataset, open('dataset_files/abstractive/next_item.json','w'))\n",
    "# json.dump(PREV_dataset, open('dataset_files/abstractive/prev_item.json','w'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "b00494373a840e182ed4f756382d605e0c3662c3065e6fd11aabe1def62f89ca"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
