{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6fd9797a-66c7-4516-9b62-ad7de09167aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import json\n",
    "import os\n",
    "import pickle\n",
    "import sys\n",
    "\n",
    "# Put the directory two levels above the working directory on sys.path\n",
    "# (presumably where eval_utils lives -- confirm). The membership check keeps\n",
    "# a re-run of this cell from appending the same entry a second time.\n",
    "parent_dir = os.path.abspath('../..')\n",
    "if parent_dir not in sys.path:\n",
    "    sys.path.append(parent_dir)\n",
    "\n",
    "# Environment variables are set before the third-party imports below so they\n",
    "# are already in place when those libraries are loaded.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"  # expose only GPU index 1 to CUDA\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"  # or \"true\" depending on your requirement\n",
    "\n",
    "# Third-party: numerics and scikit-learn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, r2_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "\n",
    "# Project-local helpers (resolved through the sys.path entry added above)\n",
    "from eval_utils import compare_MLE, data_profiling, cate_info, col_check, transform_label\n",
    "\n",
    "# Third-party: SDV, SDMetrics and GReaT\n",
    "from sdv.metadata import SingleTableMetadata\n",
    "from sdv.single_table import CTGANSynthesizer, TVAESynthesizer\n",
    "from sdmetrics.reports.single_table import QualityReport\n",
    "from be_great import GReaT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b41bd22f-3e9a-430b-8cf1-559eb3c567ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"../sample/Adult/\"\n",
    "\n",
    "# Read the five CSV files under `path`, in the order listed, and bind each\n",
    "# resulting DataFrame to the name in the same position on the left.\n",
    "# NOTE(review): data100..data800 look like samples of increasing size -- confirm.\n",
    "test_data, data100, data200, data400, data800 = (\n",
    "    pd.read_csv(path + csv_name)\n",
    "    for csv_name in (\n",
    "        \"data_test.csv\",\n",
    "        \"data100.csv\",\n",
    "        \"data200.csv\",\n",
    "        \"data400.csv\",\n",
    "        \"data800.csv\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "159e4f3f-0941-4af1-b8f0-d82e38cdc4d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['age', 'workclass', 'education', 'education-num', 'marital-status',\n",
       "       'occupation', 'relationship', 'race', 'sex', 'capital-gain',\n",
       "       'capital-loss', 'hours-per-week', 'native-country', 'Income'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the column labels of data800 (read from data800.csv).\n",
    "data800.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c3e27a22-9b48-40b4-b228-d3977c6d6cac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The four loaded datasets, and the seeds 0..4.\n",
    "data_lst = [data100, data200, data400, data800]\n",
    "seed_lst = list(range(5))\n",
    "\n",
    "# Column roles, going by the names: x_cols are the inputs, y_col the label.\n",
    "x_cols = ['age', 'workclass', 'education', 'education-num',\n",
    "          'marital-status', 'occupation', 'relationship', 'race', 'sex',\n",
    "          'capital-gain', 'capital-loss', 'hours-per-week', 'native-country']\n",
    "y_col = \"Income\"\n",
    "cols = x_cols + [y_col]\n",
    "\n",
    "# Column types. Note that the label column is listed among the categoricals.\n",
    "cate_var = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', 'Income']\n",
    "bool_var = []\n",
    "\n",
    "# Numeric columns are whatever is neither categorical nor boolean. Filtering\n",
    "# `cols` in order (instead of subtracting sets) gives num_var the same element\n",
    "# order on every run; iteration order of a set of strings can change from one\n",
    "# interpreter session to the next because string hashing is randomised.\n",
    "num_var = [col for col in cols if col not in cate_var and col not in bool_var]\n",
    "\n",
    "# Empty result containers -- presumably filled in by later cells.\n",
    "res = {}\n",
    "res_shape = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fab2cbb-3e7f-4fb4-88f3-6f4912b8c7fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/torch/autograd/graph.py:768: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at ../aten/src/ATen/cuda/CublasHandlePool.cpp:135.)\n",
      "  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1200' max='1200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1200/1200 03:01, Epoch 400/400]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.561500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.402500</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1209.01it/s]|\n",
      "Column Shapes Score: 84.93%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 93.69it/s]|\n",
      "Column Pair Trends Score: 71.49%\n",
      "\n",
      "Overall Score (Average): 78.21%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1343.96it/s]|\n",
      "Column Shapes Score: 81.46%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:01<00:00, 87.22it/s]|\n",
      "Column Pair Trends Score: 66.18%\n",
      "\n",
      "Overall Score (Average): 73.82%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                                      | 0/100 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "166it [00:01, 151.43it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 2007.26it/s]|\n",
      "Column Shapes Score: 91.07%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 95.83it/s]|\n",
      "Column Pair Trends Score: 73.34%\n",
      "\n",
      "Overall Score (Average): 82.21%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1152.76it/s]|\n",
      "Column Shapes Score: 82.93%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 93.39it/s]|\n",
      "Column Pair Trends Score: 67.62%\n",
      "\n",
      "Overall Score (Average): 75.28%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1364.70it/s]|\n",
      "Column Shapes Score: 83.32%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:01<00:00, 87.96it/s]|\n",
      "Column Pair Trends Score: 66.74%\n",
      "\n",
      "Overall Score (Average): 75.03%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "175it [00:01, 161.13it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 2011.38it/s]|\n",
      "Column Shapes Score: 90.5%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 96.13it/s]|\n",
      "Column Pair Trends Score: 78.19%\n",
      "\n",
      "Overall Score (Average): 84.34%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1950.45it/s]|\n",
      "Column Shapes Score: 84.79%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 95.98it/s]|\n",
      "Column Pair Trends Score: 70.07%\n",
      "\n",
      "Overall Score (Average): 77.43%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1399.10it/s]|\n",
      "Column Shapes Score: 81.32%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:01<00:00, 90.35it/s]|\n",
      "Column Pair Trends Score: 66.29%\n",
      "\n",
      "Overall Score (Average): 73.8%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "173it [00:01, 166.73it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 2005.75it/s]|\n",
      "Column Shapes Score: 88.46%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 94.10it/s]|\n",
      "Column Pair Trends Score: 69.45%\n",
      "\n",
      "Overall Score (Average): 78.96%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 2021.42it/s]|\n",
      "Column Shapes Score: 83.25%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 97.71it/s]|\n",
      "Column Pair Trends Score: 67.44%\n",
      "\n",
      "Overall Score (Average): 75.34%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1334.92it/s]|\n",
      "Column Shapes Score: 81.43%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:01<00:00, 88.58it/s]|\n",
      "Column Pair Trends Score: 65.91%\n",
      "\n",
      "Overall Score (Average): 73.67%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "172it [00:01, 159.02it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 2017.81it/s]|\n",
      "Column Shapes Score: 89.39%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 97.16it/s]|\n",
      "Column Pair Trends Score: 69.94%\n",
      "\n",
      "Overall Score (Average): 79.67%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1346.33it/s]|\n",
      "Column Shapes Score: 84.07%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 96.52it/s]|\n",
      "Column Pair Trends Score: 69.83%\n",
      "\n",
      "Overall Score (Average): 76.95%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1338.20it/s]|\n",
      "Column Shapes Score: 82.82%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:01<00:00, 88.17it/s]|\n",
      "Column Pair Trends Score: 66.84%\n",
      "\n",
      "Overall Score (Average): 74.83%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "175it [00:01, 165.54it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1923.43it/s]|\n",
      "Column Shapes Score: 87.93%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 94.83it/s]|\n",
      "Column Pair Trends Score: 76.31%\n",
      "\n",
      "Overall Score (Average): 82.12%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='3000' max='3000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [3000/3000 08:50, Epoch 600/600]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.552000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.398000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.375300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.360700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.350800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.345600</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1701.49it/s]|\n",
      "Column Shapes Score: 88.04%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 101.86it/s]|\n",
      "Column Pair Trends Score: 75.59%\n",
      "\n",
      "Overall Score (Average): 81.82%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1949.48it/s]|\n",
      "Column Shapes Score: 77.11%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:01<00:00, 89.81it/s]|\n",
      "Column Pair Trends Score: 61.34%\n",
      "\n",
      "Overall Score (Average): 69.22%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "292it [00:01, 182.63it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1883.93it/s]|\n",
      "Column Shapes Score: 89.71%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 100.83it/s]|\n",
      "Column Pair Trends Score: 77.89%\n",
      "\n",
      "Overall Score (Average): 83.8%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1920.41it/s]|\n",
      "Column Shapes Score: 86.46%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 99.78it/s]|\n",
      "Column Pair Trends Score: 73.48%\n",
      "\n",
      "Overall Score (Average): 79.97%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1937.83it/s]|\n",
      "Column Shapes Score: 77.57%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:01<00:00, 88.75it/s]|\n",
      "Column Pair Trends Score: 63.79%\n",
      "\n",
      "Overall Score (Average): 70.68%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "285it [00:01, 181.60it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 2124.62it/s]|\n",
      "Column Shapes Score: 91.32%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 102.22it/s]|\n",
      "Column Pair Trends Score: 74.64%\n",
      "\n",
      "Overall Score (Average): 82.98%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 2158.36it/s]|\n",
      "Column Shapes Score: 87.39%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 103.63it/s]|\n",
      "Column Pair Trends Score: 73.7%\n",
      "\n",
      "Overall Score (Average): 80.54%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1597.13it/s]|\n",
      "Column Shapes Score: 78.0%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:01<00:00, 87.64it/s]|\n",
      "Column Pair Trends Score: 61.79%\n",
      "\n",
      "Overall Score (Average): 69.9%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "291it [00:01, 187.37it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 2078.81it/s]|\n",
      "Column Shapes Score: 91.18%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 103.67it/s]|\n",
      "Column Pair Trends Score: 76.02%\n",
      "\n",
      "Overall Score (Average): 83.6%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1415.63it/s]|\n",
      "Column Shapes Score: 87.36%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 100.31it/s]|\n",
      "Column Pair Trends Score: 74.98%\n",
      "\n",
      "Overall Score (Average): 81.17%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1376.63it/s]|\n",
      "Column Shapes Score: 77.43%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:01<00:00, 87.83it/s]|\n",
      "Column Pair Trends Score: 56.79%\n",
      "\n",
      "Overall Score (Average): 67.11%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "284it [00:01, 185.52it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1984.53it/s]|\n",
      "Column Shapes Score: 86.04%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 101.94it/s]|\n",
      "Column Pair Trends Score: 69.85%\n",
      "\n",
      "Overall Score (Average): 77.94%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1937.20it/s]|\n",
      "Column Shapes Score: 86.79%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 101.72it/s]|\n",
      "Column Pair Trends Score: 75.16%\n",
      "\n",
      "Overall Score (Average): 80.97%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1947.28it/s]|\n",
      "Column Shapes Score: 77.75%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:01<00:00, 88.45it/s]|\n",
      "Column Pair Trends Score: 61.2%\n",
      "\n",
      "Overall Score (Average): 69.47%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "284it [00:01, 183.98it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 1955.97it/s]|\n",
      "Column Shapes Score: 89.18%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 99.72it/s]|\n",
      "Column Pair Trends Score: 79.18%\n",
      "\n",
      "Overall Score (Average): 84.18%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='790' max='4000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 790/4000 02:17 < 09:21, 5.72 it/s, Epoch 78.90/400]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.559100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "num_lst = [100, 200, 400, 800]\n",
    "epoch_lst = [400, 600, 400, 200]\n",
    "for i in range(4):\n",
    "    num = num_lst[i]\n",
    "    sampled_data = data_lst[i]\n",
    "    sampled_data = sampled_data[cols].copy()\n",
    "    df_test = test_data.copy()\n",
    "    df_comb = pd.concat([sampled_data, df_test])\n",
    "    dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "    df_test_dummy = pd.get_dummies(df_test[cols])\n",
    "    df_test_dummy = col_check(df_test_dummy, dummy_cols)\n",
    "    X_test = df_test_dummy.to_numpy()\n",
    "    y_test = df_test[y_col].to_numpy()\n",
    "    metadata = SingleTableMetadata()\n",
    "    metadata.detect_from_dataframe(sampled_data[cols])\n",
    "    metadata_dict = metadata.to_dict()\n",
    "    ctgan_model = CTGANSynthesizer(metadata, cuda=True)\n",
    "    ctgan_model.fit(sampled_data[cols])\n",
    "    tvae_model = TVAESynthesizer(metadata, cuda=True)\n",
    "    tvae_model.fit(sampled_data[cols])\n",
    "\n",
    "    great_model = GReaT(llm='distilgpt2', batch_size = 40, epochs=epoch_lst[i])\n",
    "    great_model.fit(sampled_data[cols])\n",
    "\n",
    "    df_train_dummy = pd.get_dummies(sampled_data[x_cols])\n",
    "    df_train_dummy = col_check(df_train_dummy, dummy_cols)\n",
    "    X_train = df_train_dummy.to_numpy()\n",
    "    y_train = sampled_data[y_col].to_numpy()\n",
    "    baseline = compare_MLE(X_test, y_test, X_train, y_train, task='classification')\n",
    "    res[str(num)] = {}\n",
    "    res[str(num)]['baseline'] = baseline\n",
    "    res_shape[str(num)] = {'ctgan':[], 'tvae':[], 'great':[]}\n",
    "\n",
    "    for j in range(len(seed_lst)):\n",
    "        seed = seed_lst[j]\n",
    "        res[str(num)][str(seed)] = {}\n",
    "        # CTGAN\n",
    "        #ctgan_model.set_random_state(seed)\n",
    "        ctgan_syn = ctgan_model.sample(len(sampled_data))\n",
    "        ctgan_syn.reset_index(inplace=True, drop=True)\n",
    "        ctgan_syn.to_csv(f'gen/Adult/{num}/CTGAN/df_{seed}.csv', index=False)\n",
    "        ctgan_report = QualityReport()\n",
    "        ctgan_report.generate(df_test[cols], ctgan_syn[cols], metadata_dict)\n",
    "        res_shape[str(num)]['ctgan'].append(ctgan_report._overall_score)\n",
    "        \n",
    "        ctgan_df_dummy = pd.get_dummies(ctgan_syn[x_cols])\n",
    "        ctgan_df_dummy = col_check(ctgan_df_dummy, dummy_cols)\n",
    "        X_train = ctgan_df_dummy.to_numpy()\n",
    "        y_train = ctgan_syn[y_col].to_numpy()\n",
    "\n",
    "        ctgan_performance = compare_MLE(X_test, y_test, X_train, y_train, task='classification')\n",
    "        res[str(num)][str(seed)]['ctgan'] = ctgan_performance\n",
    "        \n",
    "        # TVAE\n",
    "        #tvae_model.set_random_state(seed)\n",
    "        tvae_syn = tvae_model.sample(len(sampled_data))\n",
    "        tvae_syn.reset_index(inplace=True, drop=True)\n",
    "        tvae_syn.to_csv(f'gen/Adult/{num}/TVAE/df_{seed}.csv', index=False)\n",
    "        tvae_report = QualityReport()\n",
    "        tvae_report.generate(df_test[cols], tvae_syn[cols], metadata_dict)\n",
    "        res_shape[str(num)]['tvae'].append(tvae_report._overall_score)\n",
    "        tvae_df_dummy = pd.get_dummies(tvae_syn[x_cols])\n",
    "        tvae_df_dummy = col_check(tvae_df_dummy, dummy_cols)\n",
    "\n",
    "        X_train = tvae_df_dummy.to_numpy()\n",
    "        y_train = tvae_syn[y_col].to_numpy()\n",
    "\n",
    "        tvae_performance = compare_MLE(X_test, y_test, X_train, y_train, task='classification')\n",
    "        res[str(num)][str(seed)]['tave'] = tvae_performance\n",
    "\n",
    "        # Be-great\n",
    "        great_syn_df = great_model.sample(len(sampled_data), max_length=4000)\n",
    "        great_syn_df.reset_index(inplace=True, drop=True)\n",
    "        great_syn_df.to_csv(f'gen/Adult/{num}/BeGReaT/df_{seed}.csv', index=False)\n",
    "        great_report = QualityReport()\n",
    "        great_report.generate(df_test[cols], great_syn_df[cols], metadata_dict)\n",
    "        res_shape[str(num)]['great'].append(great_report._overall_score)\n",
    "        great_df_dummy = pd.get_dummies(great_syn_df[x_cols])\n",
    "        great_df_dummy = col_check(great_df_dummy, dummy_cols)\n",
    "\n",
    "        X_train = great_df_dummy.to_numpy()\n",
    "        y_train = great_syn_df[y_col].to_numpy()\n",
    "        great_performance = compare_MLE(X_test, y_test, X_train, y_train, task='classification')\n",
    "        res[str(num)][str(seed)]['great'] = great_performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e33f8637-09d3-4e38-a87e-2f55f020f0c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mle_summary_f1(res):\n",
    "    eval_sum_counts = {}\n",
    "    highest_avg_values = {}\n",
    "\n",
    "    for num, results in res.items():\n",
    "        eval_sum_counts[num] = {}\n",
    "        highest_avg_values[num] = {}\n",
    "        for seed, results0 in results.items():\n",
    "            if seed != 'baseline':\n",
    "                for model, results00 in results0.items():\n",
    "                    if model not in eval_sum_counts[num]:\n",
    "                        eval_sum_counts[num][model] = {}\n",
    "                    for eval_model, results000 in results00.items():\n",
    "                        accuracy = results000.get('accuracy', None)\n",
    "                        f1_score = results000.get('f1 score', None)\n",
    "                        if accuracy is not None and f1_score is not None:\n",
    "                            avg_value = (f1_score + f1_score) / 2\n",
    "                            if model not in highest_avg_values[num]:\n",
    "                                highest_avg_values[num][model] = []\n",
    "                            highest_avg_values[num][model].append((avg_value, seed, eval_model))\n",
    "            else:\n",
    "                eval_sum_counts[num]['baseline'] = results0\n",
    "\n",
    "    res_average = {}\n",
    "    for num, model_data in highest_avg_values.items():\n",
    "        res_average[num] = {}\n",
    "        for model, avg_data in model_data.items():\n",
    "            # Select the highest average value for each seed\n",
    "            highest_avg_by_seed = {}\n",
    "            for avg_value, seed, eval_model in avg_data:\n",
    "                if seed not in highest_avg_by_seed or avg_value > highest_avg_by_seed[seed][0]:\n",
    "                    highest_avg_by_seed[seed] = (avg_value, eval_model)\n",
    "\n",
    "            # Calculate mean and std of the highest average values across seeds\n",
    "            highest_avg_values_list = [value[0] for value in highest_avg_by_seed.values()]\n",
    "            res_average[num][model] = {\n",
    "                'highest_avg_mean': np.mean(highest_avg_values_list),\n",
    "                'highest_avg_std': np.std(highest_avg_values_list),\n",
    "                'eval_model_selected': {seed: eval_model for seed, (avg_value, eval_model) in highest_avg_by_seed.items()}\n",
    "            }\n",
    "\n",
    "    return res_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "41fee85f-e0a9-4a50-b9a5-e5fdd7203c1a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'ctgan': {'highest_avg_mean': 0.6633027947763328,\n",
       "   'highest_avg_std': 0.05976949773728505,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Random Forest'}},\n",
       "  'tave': {'highest_avg_mean': 0.6642512925400983,\n",
       "   'highest_avg_std': 0.0579953818099571,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Random Forest'}},\n",
       "  'great': {'highest_avg_mean': 0.6724756105046781,\n",
       "   'highest_avg_std': 0.020385463117044176,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Logistic Regression'}}},\n",
       " '200': {'ctgan': {'highest_avg_mean': 0.6115476734331524,\n",
       "   'highest_avg_std': 0.01899163459889648,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Random Forest',\n",
       "    '4': 'Random Forest'}},\n",
       "  'tave': {'highest_avg_mean': 0.6666204714525936,\n",
       "   'highest_avg_std': 0.046509774596710046,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Random Forest',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'great': {'highest_avg_mean': 0.6871301600655922,\n",
       "   'highest_avg_std': 0.047648608315218895,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Logistic Regression'}}},\n",
       " '400': {'ctgan': {'highest_avg_mean': 0.6115593089907368,\n",
       "   'highest_avg_std': 0.01699211874599802,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Random Forest'}},\n",
       "  'tave': {'highest_avg_mean': 0.706297012542217,\n",
       "   'highest_avg_std': 0.06550919835724806,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Random Forest'}},\n",
       "  'great': {'highest_avg_mean': 0.775739414240481,\n",
       "   'highest_avg_std': 0.05777041669713688,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Logistic Regression'}}},\n",
       " '800': {'ctgan': {'highest_avg_mean': 0.629143254709608,\n",
       "   'highest_avg_std': 0.036808962537931304,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Random Forest'}},\n",
       "  'tave': {'highest_avg_mean': 0.7742213012426568,\n",
       "   'highest_avg_std': 0.020768348334283045,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'great': {'highest_avg_mean': 0.7760030319636344,\n",
       "   'highest_avg_std': 0.023454937001709766,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Logistic Regression'}}}}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mle_summary_f1(res)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb805934-4e62-411c-8a9f-953be9e70444",
   "metadata": {},
   "source": [
    "# Adult dataset exploration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "764cbd0b-b366-4809-b90f-5d165fab4ba6",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
