{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/data/anon/miniconda3/envs/greatt/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import openml\n",
    "import pandas as pd\n",
    "import datetime\n",
    "import os\n",
    "import json\n",
    "import transformers\n",
    "from src.tabby import MHTabbyGPT2Config, MHTabbyGPT2\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load real data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_328052/2725583496.py:1: FutureWarning: Starting from Version 0.15 `download_data`, `download_qualities`, and `download_features_meta_data` will all be ``False`` instead of ``True`` by default to enable lazy loading. To disable this message until version 0.15 explicitly set `download_data`, `download_qualities`, and `download_features_meta_data` to a bool while calling `get_dataset`.\n",
      "  dataset = openml.datasets.get_dataset('diabetes')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train (576, 9) val (57, 9) test (135, 9)\n"
     ]
    }
   ],
   "source": [
    "# Fetch the OpenML dataset named 'diabetes'. The download flags are passed\n",
    "# explicitly (at their current default values) so the call no longer emits the\n",
    "# FutureWarning and keeps behaving the same once openml switches to lazy loading.\n",
    "dataset = openml.datasets.get_dataset(\n",
    "    'diabetes',\n",
    "    download_data=True,\n",
    "    download_qualities=True,\n",
    "    download_features_meta_data=True,\n",
    ")\n",
    "df, _, _, _ = dataset.get_data(dataset_format=\"dataframe\")\n",
    "\n",
    "# rename and clean columns (positional rename: pandas raises if the column count differs)\n",
    "cols = ['pregnancies', 'glucose-plasma', 'blood-pressure', 'skin-thickness', 'insulin', 'BMI', 'pedigree', 'age', 'diagnosis']\n",
    "df.columns = cols\n",
    "# every label other than 'tested_positive' is mapped to 'negative'\n",
    "df['diagnosis'] = df['diagnosis'].map(lambda x: 'positive' if x=='tested_positive' else 'negative')\n",
    "\n",
    "# shuffle data (fixed seed so the split below is reproducible)\n",
    "df = df.sample(frac=1, random_state=42, ignore_index=True)\n",
    "\n",
    "# split into train/val/test sets used by paper: 75% / 7.5% / remainder\n",
    "n = len(df)\n",
    "train_size = int(0.75 * n)\n",
    "val_size = int(0.075 * n)\n",
    "train = df.iloc[:train_size, :]\n",
    "val = df.iloc[train_size:train_size+val_size, :]\n",
    "test = df.iloc[train_size+val_size:, :]\n",
    "\n",
    "# the three slices must partition the shuffled table exactly\n",
    "assert len(train) + len(val) + len(test) == n\n",
    "print('train', train.shape, 'val', val.shape, 'test', test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Config, model weights and tokenizer are all loaded from the same pretrained checkpoint id.\n",
    "MODEL_ID = \"sonicc/tabby-distilgpt2-diabetes\"\n",
    "\n",
    "config = MHTabbyGPT2Config.from_pretrained(MODEL_ID)\n",
    "model = MHTabbyGPT2.from_pretrained(MODEL_ID)\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token ids of every column name (no special tokens added), in the training frame's column order.\n",
    "column_names_tokens = tokenizer(train.columns.tolist(), add_special_tokens=False)[\"input_ids\"]\n",
    "# token_heads is simply the index sequence 0 .. n_columns-1\n",
    "token_heads = list(range(train.shape[1]))\n",
    "model.set_generation_mode(column_names_tokens=column_names_tokens, token_heads=token_heads)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Perform Synthesis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "**Synthetic dataset:**\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pregnancies</th>\n",
       "      <th>glucose-plasma</th>\n",
       "      <th>blood-pressure</th>\n",
       "      <th>skin-thickness</th>\n",
       "      <th>insulin</th>\n",
       "      <th>BMI</th>\n",
       "      <th>pedigree</th>\n",
       "      <th>age</th>\n",
       "      <th>diagnosis</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2.0</td>\n",
       "      <td>74.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.102</td>\n",
       "      <td>22.0</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.0</td>\n",
       "      <td>90.0</td>\n",
       "      <td>62.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>59.0</td>\n",
       "      <td>25.1</td>\n",
       "      <td>1.268</td>\n",
       "      <td>25.0</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>93.0</td>\n",
       "      <td>70.0</td>\n",
       "      <td>31.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.4</td>\n",
       "      <td>0.315</td>\n",
       "      <td>23.0</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>7.0</td>\n",
       "      <td>184.0</td>\n",
       "      <td>84.0</td>\n",
       "      <td>33.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>35.5</td>\n",
       "      <td>0.355</td>\n",
       "      <td>41.0</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2.0</td>\n",
       "      <td>88.0</td>\n",
       "      <td>58.0</td>\n",
       "      <td>26.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>28.4</td>\n",
       "      <td>0.766</td>\n",
       "      <td>22.0</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  pregnancies glucose-plasma blood-pressure skin-thickness insulin   BMI  \\\n",
       "0         2.0           74.0            0.0            0.0     0.0   0.0   \n",
       "1         1.0           90.0           62.0           18.0    59.0  25.1   \n",
       "2         1.0           93.0           70.0           31.0     0.0  30.4   \n",
       "3         7.0          184.0           84.0           33.0     0.0  35.5   \n",
       "4         2.0           88.0           58.0           26.0    16.0  28.4   \n",
       "\n",
       "  pedigree   age diagnosis  \n",
       "0    0.102  22.0  negative  \n",
       "1    1.268  25.0  negative  \n",
       "2    0.315  23.0  negative  \n",
       "3    0.355  41.0  positive  \n",
       "4    0.766  22.0  negative  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_samples = 5\n",
    "\n",
    "# Sample one sequence per iteration, each started from a single BOS token.\n",
    "outputs = []\n",
    "for _ in range(n_samples):\n",
    "    inputs = torch.full((1, 1), tokenizer.bos_token_id).to(model.device)\n",
    "    # NOTE(review): max_length=10 presumably allows BOS plus one step per column (9 columns) -- confirm against MHTabbyGPT2\n",
    "    # [...,1:] drops the leading BOS token before decoding\n",
    "    toks = model.generate(inputs, do_sample=True, num_beams=1, max_length=10, pad_token_id=tokenizer.pad_token_id)[...,1:]\n",
    "    outputs.append(tokenizer.batch_decode(toks)[0])\n",
    "\n",
    "# parse the lines output by model\n",
    "def parse_line(l):\n",
    "    \"\"\"Parse one decoded model output into a ``{column: value}`` dict.\n",
    "\n",
    "    The text is split on '<EOC>' and the segment after the last marker is\n",
    "    discarded. An entry is accepted only if it splits into exactly three\n",
    "    space-separated words and its first word is a name in ``cols``; the third\n",
    "    word is stored as the value (as a raw string). Only the first occurrence\n",
    "    of each column is kept.\n",
    "\n",
    "    Returns the dict when every column in ``cols`` was found, otherwise ``None``.\n",
    "    \"\"\"\n",
    "    entries = l.split('<EOC>')[:-1] # remove newline at end\n",
    "    words = [c.split(' ') for c in entries] #'name', 'is', 'value'\n",
    "    d = dict()\n",
    "    for c in words:\n",
    "        if c[0] in cols and len(c) == 3 and c[0] not in d: # keep only first occurrence\n",
    "            d[c[0]] = c[2]\n",
    "\n",
    "    if set(d.keys()) == set(cols):\n",
    "        return d\n",
    "    return None\n",
    "\n",
    "print('\\n\\n**Synthetic dataset:**')\n",
    "dicts = [parse_line(out) for out in outputs]\n",
    "# parse_line yields None for incomplete rows; drop them, because pd.DataFrame\n",
    "# cannot build a proper table from a list that mixes dicts and None.\n",
    "valid_rows = [d for d in dicts if d is not None]\n",
    "if len(valid_rows) < len(dicts):\n",
    "    print(f'dropped {len(dicts) - len(valid_rows)} of {len(dicts)} generated rows that failed to parse')\n",
    "# columns=cols fixes the column order and keeps the header even if no row parsed\n",
    "pd.DataFrame(valid_rows, columns=cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "greatt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
