{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XIx6KXZo1G2w"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from time import sleep\n",
    "from tqdm.notebook import tqdm\n",
    "import nltk\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ChatGPT dataset generation example\n",
    "\n",
    "This notebook demonstrates the data generation process with the OpenAI API for the WikiM dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the Punkt sentence-tokenizer models; nltk.tokenize.sent_tokenize\n",
    "# (called in preprocess_text) needs them.\n",
    "nltk.download('punkt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wiki40B language code -> English language name.\n",
    "# The name is interpolated into the generation prompt and, lower-cased, is passed to\n",
    "# nltk's sentence tokenizer, so every entry must be a correctly spelled, non-empty name.\n",
    "languages = {\n",
    "    'en':\"English\",\n",
    "    'ar':\"Arabic\",\n",
    "    'zh-cn':\"Chinese\",\n",
    "    'nl':\"Dutch\",\n",
    "    'fr':'French',\n",
    "    'de':\"German\",\n",
    "    'it':\"Italian\",\n",
    "    'ja':\"Japanese\",\n",
    "    'ko':\"Korean\",\n",
    "    'pl':\"Polish\",\n",
    "    'pt':\"Portuguese\",\n",
    "    'ru':\"Russian\",\n",
    "    'es':\"Spanish\",\n",
    "    'th':\"Thai\",\n",
    "    'tr':\"Turkish\",\n",
    "    'bg':\"Bulgarian\",\n",
    "    'ca':\"Catalan\",\n",
    "    'cs':\"Czech\",\n",
    "    'da':\"Danish\",\n",
    "    'el':\"Greek\",\n",
    "    'et':\"Estonian\",\n",
    "    'fa':\"Persian\",\n",
    "    'fi':\"Finnish\",\n",
    "    'he':\"Hebrew\",\n",
    "    'hi':\"Hindi\",\n",
    "    'hr':\"Croatian\",\n",
    "    'hu':\"Hungarian\",\n",
    "    'id':\"Indonesian\",\n",
    "    'lt':\"Lithuanian\",\n",
    "    'lv':\"Latvian\",\n",
    "    'ms':\"Malay\",\n",
    "    'no':\"Norwegian\",\n",
    "    'ro':\"Romanian\",\n",
    "    'sk':\"Slovak\",\n",
    "    'sl':\"Slovenian\",\n",
    "    'sr':\"Serbian\",\n",
    "    # 'sv' (not 'sw', which is Swahili) is the Swedish code and the Wiki40B config name.\n",
    "    'sv':\"Swedish\",\n",
    "    'tl':\"Tagalog\",\n",
    "    'uk':\"Ukrainian\",\n",
    "    'vi':\"Vietnamese\"\n",
    "}\n",
    "\n",
    "# Language used for this run.\n",
    "language_id = 'en'\n",
    "language_name = languages[language_id]\n",
    "# Instruction prefix placed before each sample's prompt body.\n",
    "base_promt = f'Continue the following lines in Wikipedia style with paragraph titles in the {language_name} language:\\n'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create Wiki40B dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Take 5000 samples from train, and 300 for test / validation from Wiki40b;\n",
    "\n",
    "2. Choose longer samples; clean them.\n",
    "\n",
    "3. Generate new samples from ChatGPT (3.5 Turbo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of articles to keep per split.\n",
    "N_TRAIN = 5_000  # train split\n",
    "N_TEST = 300     # used for both the test and the validation split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clear_text(text):\n",
    "    \"\"\"\n",
    "    Strip the Wiki40B markup tokens from `text`.\n",
    "\n",
    "    `_NEWLINE_` becomes a real line break, the `_START_ARTICLE_`, `_START_SECTION_`\n",
    "    and `_START_PARAGRAPH_` markers are removed, each pair of consecutive spaces is\n",
    "    collapsed into one, and leading/trailing whitespace is stripped.\n",
    "\n",
    "    Plain `str.replace` is used because every pattern is a literal string; this also\n",
    "    keeps the function free of the `re` module, which this notebook never imports.\n",
    "    \"\"\"\n",
    "    pure_text = text.replace(\"_NEWLINE_\", \"\\n\")\n",
    "    for marker in (\"_START_ARTICLE_\", \"_START_SECTION_\", \"_START_PARAGRAPH_\"):\n",
    "        pure_text = pure_text.replace(marker, \"\")\n",
    "    # Single left-to-right pass: a run of three spaces is reduced to two, not one.\n",
    "    pure_text = pure_text.replace(\"  \", \" \")\n",
    "    return pure_text.strip()\n",
    "\n",
    "def preprocess_text(text):\n",
    "    \"\"\"\n",
    "    Build the generation prompt body and the cleaned reference text for one article.\n",
    "\n",
    "    The prompt body is every line up to and including the first `_START_PARAGRAPH_`\n",
    "    marker line, plus the first sentence of the line that follows the marker, with\n",
    "    the special tokens removed by `clear_text`.\n",
    "\n",
    "    Returns:\n",
    "        (promt_body, pure_text): the prompt body and the whole article cleaned by\n",
    "        `clear_text`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `text` has no `_START_PARAGRAPH_` line followed by another line.\n",
    "    \"\"\"\n",
    "    lines = text.split('\\n')\n",
    "    marker_idx = next(\n",
    "        (idx for idx, line in enumerate(lines) if line.strip() == \"_START_PARAGRAPH_\"),\n",
    "        None,\n",
    "    )\n",
    "    if marker_idx is None or marker_idx + 1 >= len(lines):\n",
    "        raise ValueError(\"text has no paragraph line after a _START_PARAGRAPH_ marker\")\n",
    "    sentences = nltk.tokenize.sent_tokenize(lines[marker_idx + 1], language=language_name.lower())\n",
    "    # A blank paragraph line tokenizes to no sentences; use an empty first sentence then.\n",
    "    first_sentence = sentences[0] if sentences else \"\"\n",
    "    promt_body = clear_text(\"\\n\".join(lines[:marker_idx + 1] + [first_sentence]))\n",
    "    pure_text = clear_text(text)\n",
    "    return promt_body, pure_text\n",
    "   \n",
    "def preprocess_sample(sample):\n",
    "    \"\"\"\n",
    "    Convert one dataset record in place and attach its prompt body.\n",
    "\n",
    "    `text` is replaced by the cleaned article, `promt_body` is added, and\n",
    "    `version_id` / `wikidata_id` are decoded from UTF-8 bytes to `str`.\n",
    "    The same (mutated) `sample` object is returned.\n",
    "    \"\"\"\n",
    "    raw_text = sample['text'].numpy().decode('utf-8')\n",
    "    prompt, cleaned = preprocess_text(raw_text)\n",
    "    sample['text'] = cleaned\n",
    "    sample['promt_body'] = prompt\n",
    "    for id_field in ('version_id', 'wikidata_id'):\n",
    "        sample[id_field] = sample[id_field].numpy().decode('utf-8')\n",
    "    return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Load the three Wiki40B splits for the selected language, in the order\n",
    "# train, test, validation.\n",
    "train, test, validation = (\n",
    "    tfds.load(f\"wiki40b/{language_id}\", split=split_name, try_gcs=True)\n",
    "    for split_name in (\"train\", \"test\", \"validation\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Character length of every raw train article; the 0.3 quantile of these lengths\n",
    "# is the minimum length a sample must have to be kept.\n",
    "lengths = [len(sample['text'].numpy().decode('utf-8')) for sample in tqdm(train)]\n",
    "length_threshold = np.quantile(lengths, 0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def collect_samples(dataset, limit, min_length):\n",
    "    \"\"\"\n",
    "    Return up to `limit` preprocessed samples from `dataset`, in iteration order.\n",
    "\n",
    "    Samples whose raw text is shorter than `min_length` characters are skipped;\n",
    "    the kept ones are passed through `preprocess_sample`.\n",
    "    \"\"\"\n",
    "    collected = []\n",
    "    for sample in tqdm(dataset):\n",
    "        if len(collected) == limit:\n",
    "            break\n",
    "        if len(sample['text'].numpy().decode('utf-8')) < min_length:\n",
    "            continue\n",
    "        collected.append(preprocess_sample(sample))\n",
    "    return collected\n",
    "\n",
    "# The threshold is computed on the train split and applied to all three splits.\n",
    "train_samples = collect_samples(train, N_TRAIN, length_threshold)\n",
    "validation_samples = collect_samples(validation, N_TEST, length_threshold)\n",
    "test_samples = collect_samples(test, N_TEST, length_threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each language keep N_TRAIN (5000) train samples and N_TEST (300) samples each for test / validation;\n",
    "# samples shorter than the 30th percentile of the train text lengths are skipped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# makedirs also creates the parent 'wiki40b' folder and tolerates an existing directory.\n",
    "os.makedirs(f'wiki40b/{language_id}', exist_ok=True)\n",
    "# The default index column is written too; it is read back later as 'Unnamed: 0'.\n",
    "pd.DataFrame(train_samples).to_csv(f'wiki40b/{language_id}/train.csv')\n",
    "pd.DataFrame(test_samples).to_csv(f'wiki40b/{language_id}/test.csv')\n",
    "pd.DataFrame(validation_samples).to_csv(f'wiki40b/{language_id}/validation.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subset = 'train'\n",
    "# Drop the unnamed index column that pandas wrote when the CSV was saved.\n",
    "data = pd.read_csv(f'wiki40b/{language_id}/{subset}.csv')\n",
    "data = data.drop(columns=['Unnamed: 0'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Provenance columns stored alongside every sample.\n",
    "data = data.assign(\n",
    "    model='gpt-3.5-turbo',\n",
    "    dataset='wiki40b',\n",
    "    language=language_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fcT79Ska1WPu"
   },
   "outputs": [],
   "source": [
    "# Take the key from the environment so the secret is never saved inside the notebook.\n",
    "openai.api_key = os.environ.get(\"OPENAI_API_KEY\")\n",
    "if not openai.api_key:\n",
    "    raise RuntimeError(\"Set the OPENAI_API_KEY environment variable before running this notebook\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the instruction prefix that is prepended to every sample's prompt body.\n",
    "base_promt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JkvvPdWEWTim",
    "outputId": "2386c9ee-114d-4efb-b37e-0e85652c57e4"
   },
   "outputs": [],
   "source": [
    "MAX_ATTEMPTS = 5        # give up on a prompt after this many failed API calls\n",
    "RETRY_DELAY_SEC = 60    # pause between two attempts for the same prompt\n",
    "\n",
    "generated_samples = []\n",
    "for promt_body in tqdm(data['promt_body']):\n",
    "    full_prompt = base_promt + '\\n' + promt_body\n",
    "    # Retry a bounded number of times so that a permanent failure (bad API key,\n",
    "    # rejected request) surfaces as an error instead of sleeping and retrying forever.\n",
    "    for attempt in range(1, MAX_ATTEMPTS + 1):\n",
    "        try:\n",
    "            query_result = openai.ChatCompletion.create(\n",
    "                model=\"gpt-3.5-turbo\",\n",
    "                messages=[\n",
    "                  {\"role\": \"user\", \"content\": full_prompt}\n",
    "                ]\n",
    "            )\n",
    "            break\n",
    "        except Exception as err:\n",
    "            if attempt == MAX_ATTEMPTS:\n",
    "                raise RuntimeError(f\"Generation failed after {MAX_ATTEMPTS} attempts\") from err\n",
    "            sleep(RETRY_DELAY_SEC)\n",
    "    generated_samples.append(query_result[\"choices\"][0][\"message\"][\"content\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 363
    },
    "id": "nne5bRANaW_M",
    "outputId": "091e588f-179e-44ec-81c4-c976c3c2d425"
   },
   "outputs": [],
   "source": [
    "# One generated continuation per row, in the same order as the prompts.\n",
    "data = data.assign(gen_body=generated_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "M-0IKgSt2oDL"
   },
   "outputs": [],
   "source": [
    "# Save the frame, including the new gen_body column, next to the input CSV of this subset.\n",
    "data.to_csv(f\"wiki40b/{language_id}/{subset}_generation.csv\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python3.8 (ATD)",
   "language": "python",
   "name": ".env_atd"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
