{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "34a67c0e-7be1-4249-9a89-ecb9d6c1a8f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "parent_dir = os.path.abspath('../..')\n",
    "sys.path.append(parent_dir)\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"4\"\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"  # or \"true\" depending on your requirement\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from eval_utils import compare_MLE, data_profiling, cate_info, col_check,  transform_label\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "import os\n",
    "import json\n",
    "import pickle\n",
    "import torch\n",
    "from sdv.metadata import SingleTableMetadata\n",
    "from sdv.single_table import CTGANSynthesizer, TVAESynthesizer\n",
    "from sdmetrics.reports.single_table import QualityReport\n",
    "from be_great import GReaT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6197d76a-685e-4000-8c9d-0e4117a90ad2",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"../sample/Insurance/\"\n",
    "\n",
    "test_data = pd.read_csv(path + \"data_test.csv\")\n",
    "data100 = pd.read_csv(path + \"data100.csv\")\n",
    "data200 = pd.read_csv(path + \"data200.csv\")\n",
    "data400 = pd.read_csv(path + \"data400.csv\")\n",
    "data800 = pd.read_csv(path + \"data800.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "71fe024a-85ce-46b3-ac50-9f34d272b9f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lst = [data100, data200, data400, data800]\n",
    "num_lst = [100, 200, 400, 800]\n",
    "\n",
    "seed_lst = list(range(5))\n",
    "\n",
    "x_cols = ['age', 'sex', 'bmi', 'children', 'smoker', 'region']\n",
    "y_col = 'charges'\n",
    "bool_var = []\n",
    "cols = x_cols + [y_col]\n",
    "cate_var = ['sex', 'smoker', 'region']\n",
    "num_var = list(set(cols) - set(cate_var))\n",
    "\n",
    "res_shape = {}\n",
    "res = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "42e27ea1-69d7-4fb6-816f-2c91f0c0cfef",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/torch/autograd/graph.py:768: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at ../aten/src/ATen/cuda/CublasHandlePool.cpp:135.)\n",
      "  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1200' max='1200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1200/1200 01:14, Epoch 400/400]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>1.018800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.627600</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                                                                                             | 0/100 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "183it [00:00, 415.48it/s]                                                                                                                                                                                                                 \n",
      "179it [00:00, 442.79it/s]                                                                                                                                                                                                                 \n",
      "180it [00:00, 508.01it/s]                                                                                                                                                                                                                 \n",
      "176it [00:00, 507.26it/s]                                                                                                                                                                                                                 \n",
      "181it [00:00, 545.80it/s]                                                                                                                                                                                                                 \n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='3000' max='3000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [3000/3000 03:30, Epoch 600/600]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>1.129900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.684600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.559600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.491200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.462000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.450000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "283it [00:00, 555.98it/s]                                                                                                                                                                                                                 \n",
      "273it [00:00, 550.42it/s]                                                                                                                                                                                                                 \n",
      "281it [00:00, 554.71it/s]                                                                                                                                                                                                                 \n",
      "278it [00:00, 566.30it/s]                                                                                                                                                                                                                 \n",
      "273it [00:00, 530.60it/s]                                                                                                                                                                                                                 \n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='4000' max='4000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [4000/4000 04:39, Epoch 400/400]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>1.268200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.841100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.698100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.623100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.565800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.531500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.512000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.503300</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "474it [00:00, 544.94it/s]                                                                                                                                                                                                                 \n",
      "461it [00:00, 531.66it/s]                                                                                                                                                                                                                 \n",
      "465it [00:00, 576.42it/s]                                                                                                                                                                                                                 \n",
      "481it [00:00, 598.13it/s]                                                                                                                                                                                                                 \n",
      "482it [00:00, 596.46it/s]                                                                                                                                                                                                                 \n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='4000' max='4000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [4000/4000 04:39, Epoch 200/200]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>1.338400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>1.037800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.848100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.758300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.714500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.685800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.665700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.654400</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "877it [00:01, 562.52it/s]                                                                                                                                                                                                                 \n",
      "886it [00:01, 599.43it/s]                                                                                                                                                                                                                 \n",
      "869it [00:01, 581.67it/s]                                                                                                                                                                                                                 \n",
      "869it [00:01, 563.67it/s]                                                                                                                                                                                                                 \n",
      "870it [00:01, 574.21it/s]                                                                                                                                                                                                                 \n"
     ]
    }
   ],
   "source": [
    "num_lst = [100, 200, 400, 800]\n",
    "epoch_lst = [400, 600, 400, 200]\n",
    "for i in range(4):\n",
    "    num = num_lst[i]\n",
    "    sampled_data = data_lst[i]\n",
    "    sampled_data = sampled_data[cols].copy()\n",
    "    metadata = SingleTableMetadata()\n",
    "    metadata.detect_from_dataframe(sampled_data[cols])\n",
    "    metadata_dict = metadata.to_dict()\n",
    "    \n",
    "    ctgan_model = CTGANSynthesizer(metadata, cuda=True)\n",
    "    ctgan_model.fit(sampled_data[cols])\n",
    "    tvae_model = TVAESynthesizer(metadata, cuda=True)\n",
    "    tvae_model.fit(sampled_data[cols])\n",
    "    great_model = GReaT(llm='distilgpt2', batch_size = 40, epochs=epoch_lst[i])\n",
    "    great_model.fit(sampled_data[cols])\n",
    "\n",
    "\n",
    "    for j in range(len(seed_lst)):\n",
    "        seed = seed_lst[j]\n",
    "        # CTGAN\n",
    "        #ctgan_model.set_random_state(seed)\n",
    "        ctgan_syn = ctgan_model.sample(len(sampled_data))\n",
    "        ctgan_syn.reset_index(inplace=True, drop=True)\n",
    "        ctgan_syn.to_csv(f'gen/Insurance/{num}/CTGAN/df_{seed}.csv', index=False)\n",
    "        \n",
    "        # TVAE\n",
    "        #tvae_model.set_random_state(seed)\n",
    "        tvae_syn = tvae_model.sample(len(sampled_data))\n",
    "        tvae_syn.reset_index(inplace=True, drop=True)\n",
    "        tvae_syn.to_csv(f'gen/Insurance/{num}/TVAE/df_{seed}.csv', index=False)\n",
    "        \n",
    "\n",
    "        # Be-great\n",
    "        great_syn_df = great_model.sample(len(sampled_data), max_length=4000)\n",
    "        great_syn_df.reset_index(inplace=True, drop=True)\n",
    "        great_syn_df.to_csv(f'gen/Insurance/{num}/BeGReaT/df_{seed}.csv', index=False)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cc0c89f-a979-4857-8eb6-fede30c51eb8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cda22057-33a0-4eb6-9f58-afc58856e5e7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
