{
 "cells": [
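  {
   "cell_type": "markdown",
   "id": "adopted-junior",
   "metadata": {},
   "source": [
    "# Train GReaT on the Adult Dataset\n",
    "\n",
    "Fine-tunes a pretrained causal language model on the Adult training data, with every table row serialized as text, and saves the resulting model weights."
   ]
  },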
  {
   "cell_type": "markdown",
   "id": "bottom-anniversary",
   "metadata": {},
   "source": [
    "## Import packages and set config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abstract-giving",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "blind-gathering",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer\n",
    "from datasets import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "infrared-sending",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "exterior-trail",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name is used to save training checkpoints:\n",
    "# it is passed as the output directory (first argument) of TrainingArguments below\n",
    "experiment_name = \"trainer_adult\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "humanitarian-penguin",
   "metadata": {},
   "outputs": [],
   "source": [
    "# HuggingFace model id used to load both the tokenizer and the model below;\n",
    "# it is also part of the saved model's file name.\n",
    "# Smaller than the original GPT-2, only 6 layers (instead of 12), 12 heads, 768 dimensions (like GPT-2)\n",
    "checkpoint = \"distilgpt2\"  # alternative: \"gpt2-medium\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dense-sewing",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle the feature order of each serialized row (True) or keep the column order (False)\n",
    "shuffle = True\n",
    "\n",
    "# Seed Python's `random` module, which draws the per-row column permutations,\n",
    "# so the shuffled training text is reproducible across Restart & Run All\n",
    "seed = 42\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "falling-holocaust",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "included-yacht",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Adult training table; the first CSV column is used as the index\n",
    "df = pd.read_csv(\"adult_train.csv\", index_col=0)\n",
    "\n",
    "# Feature names, in file order, used to serialize each row into text\n",
    "columns = list(df.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "active-reviewer",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (number of rows, number of columns) of the loaded table\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "upset-warehouse",
   "metadata": {},
   "source": [
    "## Train the GReaT Transformer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "applied-external",
   "metadata": {},
   "source": [
    "### Prepare the tabular data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "simplified-rugby",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the HuggingFace dataset object.\n",
    "# preserve_index=False stops the pandas index (read via index_col=0) from being\n",
    "# stored as an extra `__index_level_0__` column, so only the real features remain.\n",
    "ds = Dataset.from_pandas(df, preserve_index=False)"
   ]
  },
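  {
   "cell_type": "markdown",
   "id": "verified-chemical",
   "metadata": {},
   "source": [
    "Each row is serialized into a text string of the form `<column> is <value>, ` repeated for every feature.\n",
    "The features appear either in their original column order or in a random order per row, controlled by the `shuffle` flag."
   ]
  },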
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "juvenile-surgery",
   "metadata": {},
   "outputs": [],
   "source": [
    "def combine_data_ordered(sample):\n",
    "    \"\"\"Serialize one row as text, visiting the features in their original column order.\n",
    "\n",
    "    Every feature contributes a fragment \"<column> is <value>, \" (value converted to str\n",
    "    and stripped of surrounding whitespace), so the text ends with a trailing \", \".\n",
    "    \"\"\"\n",
    "    pieces = (\"%s is %s, \" % (name, str(sample[name]).strip()) for name in columns)\n",
    "    return {\"concat\": \"\".join(pieces)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "assigned-berkeley",
   "metadata": {},
   "outputs": [],
   "source": [
    "def combine_data_shuffled(sample):\n",
    "    \"\"\"Serialize one row as text, visiting the features in a random order.\n",
    "\n",
    "    A fresh permutation of `columns` is drawn from Python's global `random` state on\n",
    "    every call; each feature contributes \"<column> is <value>, \".\n",
    "    \"\"\"\n",
    "    order = random.sample(columns, k=len(columns))\n",
    "    fragments = [\"%s is %s, \" % (name, str(sample[name]).strip()) for name in order]\n",
    "    return {\"concat\": \"\".join(fragments)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "moving-pearl",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialize every row into the text column \"concat\", shuffled or in column order\n",
    "combine_fn = combine_data_shuffled if shuffle else combine_data_ordered\n",
    "combined_ds = ds.map(combine_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "casual-venezuela",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove all original features from the dataset so that only the \"concat\" text\n",
    "# column produced by the combine function remains\n",
    "combined_ds = combined_ds.remove_columns(ds.column_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "needed-charm",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take a look at the first three serialized rows\n",
    "# (slice rows first so the whole column is not materialized just to show three entries)\n",
    "combined_ds[:3][\"concat\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "minimal-elephant",
   "metadata": {},
   "source": [
    "### Tokenize the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "liquid-practice",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer matching the checkpoint\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "# padding=\"max_length\" in tokenizer_function needs a pad token; reuse EOS for it\n",
    "# (presumably the checkpoint's tokenizer defines none - confirm for other checkpoints)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "meaningful-quest",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed token length of every training sequence: longer rows are truncated and\n",
    "# shorter rows are padded (see tokenizer_function).\n",
    "# NOTE(review): confirm no serialized row exceeds this, otherwise trailing features are silently cut off\n",
    "max_length = 125"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "retired-value",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenizer_function(sample):\n",
    "    \"\"\"Tokenize a batch of serialized rows for causal language modelling.\n",
    "\n",
    "    Sequences are truncated/padded to exactly `max_length` tokens, and the labels\n",
    "    are a copy of the input ids. Padding positions are not masked out of the labels.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(\n",
    "        sample[\"concat\"],\n",
    "        truncation=True,\n",
    "        padding=\"max_length\",\n",
    "        max_length=max_length,\n",
    "    )\n",
    "    encoded[\"labels\"] = list(encoded[\"input_ids\"])\n",
    "    return encoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "supreme-newman",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize the dataset in batches; indexing the result yields PyTorch tensors\n",
    "tokenizer_ds = combined_ds.map(tokenizer_function, batched=True).with_format(\"torch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "allied-prefix",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the tokenized dataset (its columns and number of rows)\n",
    "tokenizer_ds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "satisfied-stable",
   "metadata": {},
   "source": [
    "### Create Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "accepting-leonard",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pretrained causal language model for the same checkpoint as the tokenizer\n",
    "model = AutoModelForCausalLM.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "metric-second",
   "metadata": {},
   "source": [
    "### Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eastern-table",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of training epochs; also part of the saved model's file name\n",
    "epochs = 200\n",
    "# Per-device training batch size\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "harmful-address",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All hyperparameters:\n",
    "# https://huggingface.co/docs/transformers/v4.19.2/en/main_classes/trainer#transformers.TrainingArguments\n",
    "\n",
    "# Checkpoints are written to the `experiment_name` directory every 5000 steps\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=experiment_name,\n",
    "    num_train_epochs=epochs,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    save_steps=5000,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "casual-chocolate",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine model, training arguments, tokenized data and tokenizer into a Trainer\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenizer_ds,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "organic-spice",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Long-running: trains for `epochs` epochs.\n",
    "# To continue from the latest checkpoint in `experiment_name`, call trainer.train(resume_from_checkpoint=True)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "flush-boston",
   "metadata": {},
   "source": [
    "### Save Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ordinary-freight",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Save only the model weights (state_dict) after training.\n",
    "# Slashes in hub ids (e.g. \"org/model\") are replaced so the name stays a single file,\n",
    "# and the target directory is created if it does not exist yet (torch.save does not create it).\n",
    "model_name = \"adult_\" + checkpoint.replace(\"/\", \"_\") + \"_\" + str(epochs) + \".pt\"\n",
    "\n",
    "model_dir = \"models\"\n",
    "os.makedirs(model_dir, exist_ok=True)\n",
    "model_path = os.path.join(model_dir, model_name)\n",
    "torch.save(model.state_dict(), model_path)\n",
    "model_name"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "young-banner",
   "metadata": {},
   "source": [
    "### Look at loss curve"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "quality-monthly",
   "metadata": {},
   "source": [
    "`trainer.state.log_history` contains the metrics logged during training.\n",
    "Here we only plot the training loss curve, but one could also save the whole state of the trainer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mature-robertson",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the log entries that carry a training \"loss\" value; this skips the\n",
    "# end-of-training summary and any other entry without one (which would raise KeyError).\n",
    "loss_logs = [entry for entry in trainer.state.log_history if \"loss\" in entry]\n",
    "steps = [entry[\"step\"] for entry in loss_logs]\n",
    "loss = [entry[\"loss\"] for entry in loss_logs]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(steps, loss)\n",
    "ax.set(title=\"Training loss\", xlabel=\"Training step\", ylabel=\"Loss\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "changing-valuation",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
