{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "educated-niger",
   "metadata": {},
   "source": [
    "# Use GReaT to generate data for the Adult Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "returning-diversity",
   "metadata": {},
   "source": [
    "### Import packages and set config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "false-magnitude",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import random\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "severe-viking",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "centered-jamaica",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.asyncio import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "medieval-workshop",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"adult\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "excellent-homeless",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smaller than the original GPT-2, only 6 layers (instead of 12), 12 heads, 768 dimensions (like GPT-2)\n",
    "checkpoint = \"distilgpt2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "certified-newport",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of epochs the model was trained.\n",
    "epochs = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "narrow-developer",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the model that will be loaded\n",
    "model_name = dataset + \"_\" + checkpoint + \"_\" + str(epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cheap-prerequisite",
   "metadata": {},
   "outputs": [],
   "source": [
    "columns = ['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', \n",
    "           'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'income']\n",
    "label_col = \"income\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ethical-viewer",
   "metadata": {},
   "source": [
    "### Load tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceramic-desktop",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "chicken-culture",
   "metadata": {},
   "source": [
    "### Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "inclusive-symphony",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "parliamentary-regard",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(checkpoint)\n",
    "model.load_state_dict(torch.load(\"models/\" + model_name + \".pt\"))\n",
    "model.to(device)\n",
    "print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "mineral-armor",
   "metadata": {},
   "source": [
    "### Generate Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ignored-adelaide",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_samples = 1000\n",
    "n_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "answering-executive",
   "metadata": {},
   "outputs": [],
   "source": [
    "temperature = 0.7"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "rubber-offset",
   "metadata": {},
   "source": [
    "Conditioning on the target column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acting-graduation",
   "metadata": {},
   "outputs": [],
   "source": [
    "population = [\"<=50K\", \">50K\"]\n",
    "weights = [0.76, 0.24]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "floating-kansas",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.asyncio import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "active-consultancy",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "k = 100\n",
    "\n",
    "generated_data = []\n",
    "\n",
    "for i in tqdm(range(int(n_samples/k)+1)):\n",
    "    start_words = random.choices(population, weights, k=k)\n",
    "    start = [label_col + \" is \" + str(s) + \",\" for s in start_words]\n",
    "    start_token = torch.tensor(tokenizer(start)[\"input_ids\"]).to(device)\n",
    "\n",
    "    gens = model.generate(input_ids=start_token, min_length=80, max_length=130,\n",
    "                         do_sample=True, temperature=temperature, pad_token_id=50256) # input_ids=test_ids\n",
    "    \n",
    "    for g in gens:\n",
    "        generated_data.append(g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sharp-measurement",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "decoded_data = [tokenizer.decode(gen) for gen in generated_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "educated-observer",
   "metadata": {},
   "outputs": [],
   "source": [
    "decoded_data = [d.replace(\"<|endoftext|>\", \"\") for d in decoded_data]\n",
    "decoded_data = [d.replace(\"\\n\", \" \") for d in decoded_data]\n",
    "decoded_data = [d.replace(\"\\r\", \"\") for d in decoded_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "purple-airport",
   "metadata": {},
   "outputs": [],
   "source": [
    "decoded_data[:3]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "portuguese-letters",
   "metadata": {},
   "source": [
    "### Convert Text back to Tabular Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "wicked-kazakhstan",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "df_gen = pd.DataFrame(columns=columns)\n",
    "\n",
    "for g in tqdm(decoded_data):\n",
    "    features = g.split(\",\")\n",
    "    \n",
    "    td = dict.fromkeys(columns)\n",
    "\n",
    "    # Transform all features back to tabular data\n",
    "    for f in features:\n",
    "        values = f.strip().split(\" \")\n",
    "\n",
    "        if values[0] in columns:\n",
    "            td[values[0]] = [values[-1]]\n",
    "            \n",
    "    df_gen = pd.concat([df_gen, pd.DataFrame(td)], ignore_index=True, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "honey-minister",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Needed for adult dataset. Some categorical values there can be \"?\" and the GREAT Transformer\n",
    "# does not set a space before it\n",
    "\n",
    "df_gen[\"workclass\"] = df_gen[\"workclass\"].replace(\"is?\", \"?\")\n",
    "df_gen[\"occupation\"] = df_gen[\"occupation\"].replace(\"is?\", \"?\")\n",
    "df_gen[\"native-country\"] = df_gen[\"native-country\"].replace(\"is?\", \"?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "innovative-physiology",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_gen.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a8c8d67-e504-4b8a-9db3-956114ff1529",
   "metadata": {},
   "source": [
    "Thats a small sanity check to verify all numerical columns are really numbers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "889c92fe-1d4d-4dcc-b075-4a9a5d158a05",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_cols = config[\"num_cols\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "stuck-canberra",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i_num_cols in num_cols:\n",
    "    df_gen = df_gen[pd.to_numeric(df_gen[i_num_cols], errors='coerce').notnull()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "utility-bahrain",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_gen[num_cols] = df_gen[num_cols].astype(\"int\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "thermal-thing",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_gen.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bda0ecf-c37f-4200-b24b-cd524a0ac70d",
   "metadata": {},
   "source": [
    "Remove all columns with missing values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "secret-float",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_gen = df_gen.drop(df_gen[df_gen.isna().any(axis=1)].index)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e78f3d81-30e7-410a-ba18-3faa5af72bdc",
   "metadata": {},
   "source": [
    "Only take the desired number of samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eaaaa24-5b15-4d90-90bd-a59ad126f4c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_gen = df_gen.head(n_samples)\n",
    "df_gen = df_gen.reset_index(drop=True) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bc824f6-950c-478d-a7e4-fd9813b678fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_gen.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "treated-assault",
   "metadata": {},
   "source": [
    "### Save synthetic dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "recovered-cocktail",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "df_gen_.to_csv(\"gen_data/adult_samples.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "metropolitan-electron",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
