{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6fd9797a-66c7-4516-9b62-ad7de09167aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "# Resolve the directory two levels above the current working directory and\n",
    "# add it to the end of the module search path.\n",
    "parent_dir = os.path.abspath('../..')\n",
    "sys.path += [parent_dir]\n",
    "# Set both process-wide environment flags in a single call.\n",
    "os.environ.update({\n",
    "    \"CUDA_VISIBLE_DEVICES\": \"1\",\n",
    "    \"TOKENIZERS_PARALLELISM\": \"false\",  # or \"true\" depending on your requirement\n",
    "})\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from eval_utils import compare_MLE, data_profiling, cate_info, col_check,  transform_label\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "import os\n",
    "import json\n",
    "import pickle\n",
    "from sdv.metadata import SingleTableMetadata\n",
    "from sdv.single_table import CTGANSynthesizer, TVAESynthesizer\n",
    "from sdmetrics.reports.single_table import QualityReport\n",
    "from be_great import GReaT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b41bd22f-3e9a-430b-8cf1-559eb3c567ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"../sample/Asia/\"\n",
    "\n",
    "# data_test.csv is read on its own; the four sized sample files follow.\n",
    "test_data = pd.read_csv(path + \"data_test.csv\")\n",
    "\n",
    "# Read the sample files in one pass; the unpacking order matches the file order.\n",
    "data100, data200, data400, data800 = [\n",
    "    pd.read_csv(path + csv_name)\n",
    "    for csv_name in (\"data100.csv\", \"data200.csv\", \"data400.csv\", \"data800.csv\")\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c3e27a22-9b48-40b4-b228-d3977c6d6cac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the four loaded frames and the seeds 0-4.\n",
    "data_lst = [data100, data200, data400, data800]\n",
    "seed_lst = [*range(5)]\n",
    "\n",
    "# Feature column names, followed by the single target column.\n",
    "x_cols = [\n",
    "    \"visit to Asia\",\n",
    "    \"tuberculosis\",\n",
    "    \"smoke\",\n",
    "    \"lung cancer\",\n",
    "    \"bronchitis\",\n",
    "    \"either tuberculosis or lung cancer\",\n",
    "    \"chest X-ray\",\n",
    "]\n",
    "y_col = \"Dyspnea\"\n",
    "cols = [*x_cols, y_col]\n",
    "\n",
    "# The numeric and categorical lists start empty; bool_var is the very same\n",
    "# list object as cols (an alias, not a copy).\n",
    "num_var, cate_var = [], []\n",
    "bool_var = cols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7fe489a0-5a7f-42dd-9202-53a2387ad514",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two separate, initially empty dictionaries; nothing is stored in them in this cell.\n",
    "res = dict()\n",
    "res_shape = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3fab2cbb-3e7f-4fb4-88f3-6f4912b8c7fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1200' max='1200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1200/1200 01:31, Epoch 400/400]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.355900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.247500</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 989.60it/s]|\n",
      "Column Shapes Score: 95.0%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 95.29it/s]|\n",
      "Column Pair Trends Score: 90.29%\n",
      "\n",
      "Overall Score (Average): 92.64%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1382.38it/s]|\n",
      "Column Shapes Score: 95.88%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 85.46it/s]|\n",
      "Column Pair Trends Score: 93.04%\n",
      "\n",
      "Overall Score (Average): 94.46%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                                      | 0/100 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "196it [00:00, 334.12it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2292.91it/s]|\n",
      "Column Shapes Score: 95.38%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 86.86it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 47.69%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2289.47it/s]|\n",
      "Column Shapes Score: 94.75%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 104.87it/s]|\n",
      "Column Pair Trends Score: 88.39%\n",
      "\n",
      "Overall Score (Average): 91.57%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2366.82it/s]|\n",
      "Column Shapes Score: 95.12%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 101.65it/s]|\n",
      "Column Pair Trends Score: 92.11%\n",
      "\n",
      "Overall Score (Average): 93.62%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "193it [00:00, 405.38it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2272.57it/s]|\n",
      "Column Shapes Score: 96.75%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 88.06it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.38%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2320.98it/s]|\n",
      "Column Shapes Score: 94.88%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 109.22it/s]|\n",
      "Column Pair Trends Score: 89.52%\n",
      "\n",
      "Overall Score (Average): 92.2%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2310.27it/s]|\n",
      "Column Shapes Score: 95.25%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 102.77it/s]|\n",
      "Column Pair Trends Score: 92.11%\n",
      "\n",
      "Overall Score (Average): 93.68%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "193it [00:00, 389.89it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2387.20it/s]|\n",
      "Column Shapes Score: 96.75%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 87.18it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.38%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2245.95it/s]|\n",
      "Column Shapes Score: 93.12%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 110.47it/s]|\n",
      "Column Pair Trends Score: 86.32%\n",
      "\n",
      "Overall Score (Average): 89.72%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2342.37it/s]|\n",
      "Column Shapes Score: 97.25%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 102.95it/s]|\n",
      "Column Pair Trends Score: 94.68%\n",
      "\n",
      "Overall Score (Average): 95.96%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "194it [00:00, 423.86it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2264.74it/s]|\n",
      "Column Shapes Score: 97.88%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 88.00it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.94%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2213.35it/s]|\n",
      "Column Shapes Score: 93.88%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 110.32it/s]|\n",
      "Column Pair Trends Score: 87.52%\n",
      "\n",
      "Overall Score (Average): 90.7%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2300.61it/s]|\n",
      "Column Shapes Score: 94.38%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 93.28it/s]|\n",
      "Column Pair Trends Score: 90.38%\n",
      "\n",
      "Overall Score (Average): 92.38%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "193it [00:00, 406.84it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2406.37it/s]|\n",
      "Column Shapes Score: 96.37%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 87.38it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.19%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='3000' max='3000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [3000/3000 04:16, Epoch 600/600]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.349600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.246300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.241300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.239200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.237700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.236800</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2039.04it/s]|\n",
      "Column Shapes Score: 94.56%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 109.98it/s]|\n",
      "Column Pair Trends Score: 88.55%\n",
      "\n",
      "Overall Score (Average): 91.56%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2328.88it/s]|\n",
      "Column Shapes Score: 92.94%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 103.71it/s]|\n",
      "Column Pair Trends Score: 87.48%\n",
      "\n",
      "Overall Score (Average): 90.21%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 403.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2218.18it/s]|\n",
      "Column Shapes Score: 97.62%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 86.05it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.81%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2182.12it/s]|\n",
      "Column Shapes Score: 94.31%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 99.49it/s]|\n",
      "Column Pair Trends Score: 88.64%\n",
      "\n",
      "Overall Score (Average): 91.48%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2325.65it/s]|\n",
      "Column Shapes Score: 94.0%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 101.31it/s]|\n",
      "Column Pair Trends Score: 90.05%\n",
      "\n",
      "Overall Score (Average): 92.03%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "298it [00:00, 446.39it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2303.46it/s]|\n",
      "Column Shapes Score: 97.38%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 87.46it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.69%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2300.61it/s]|\n",
      "Column Shapes Score: 95.56%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 108.64it/s]|\n",
      "Column Pair Trends Score: 89.46%\n",
      "\n",
      "Overall Score (Average): 92.51%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2249.26it/s]|\n",
      "Column Shapes Score: 93.62%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 99.62it/s]|\n",
      "Column Pair Trends Score: 89.32%\n",
      "\n",
      "Overall Score (Average): 91.47%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "298it [00:00, 420.07it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2281.68it/s]|\n",
      "Column Shapes Score: 95.88%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 84.58it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 47.94%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1215.48it/s]|\n",
      "Column Shapes Score: 92.94%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 97.37it/s]|\n",
      "Column Pair Trends Score: 86.95%\n",
      "\n",
      "Overall Score (Average): 89.94%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2293.38it/s]|\n",
      "Column Shapes Score: 94.94%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 104.48it/s]|\n",
      "Column Pair Trends Score: 91.46%\n",
      "\n",
      "Overall Score (Average): 93.2%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 422.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2290.87it/s]|\n",
      "Column Shapes Score: 96.88%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 88.30it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.44%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2210.58it/s]|\n",
      "Column Shapes Score: 93.88%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 107.95it/s]|\n",
      "Column Pair Trends Score: 88.46%\n",
      "\n",
      "Overall Score (Average): 91.17%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2133.69it/s]|\n",
      "Column Shapes Score: 93.44%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 101.04it/s]|\n",
      "Column Pair Trends Score: 89.09%\n",
      "\n",
      "Overall Score (Average): 91.26%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "296it [00:00, 429.75it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2172.09it/s]|\n",
      "Column Shapes Score: 96.31%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 86.69it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.16%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='4000' max='4000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [4000/4000 05:45, Epoch 400/400]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.352900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.249700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.244600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.242400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.240800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.239800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.239100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.238500</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2092.71it/s]|\n",
      "Column Shapes Score: 92.25%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 107.73it/s]|\n",
      "Column Pair Trends Score: 85.54%\n",
      "\n",
      "Overall Score (Average): 88.9%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2158.67it/s]|\n",
      "Column Shapes Score: 95.78%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 104.51it/s]|\n",
      "Column Pair Trends Score: 93.01%\n",
      "\n",
      "Overall Score (Average): 94.4%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "497it [00:01, 448.85it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2086.20it/s]|\n",
      "Column Shapes Score: 97.62%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 85.78it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.81%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1826.99it/s]|\n",
      "Column Shapes Score: 92.62%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 100.02it/s]|\n",
      "Column Pair Trends Score: 85.96%\n",
      "\n",
      "Overall Score (Average): 89.29%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2184.82it/s]|\n",
      "Column Shapes Score: 95.38%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 102.99it/s]|\n",
      "Column Pair Trends Score: 92.17%\n",
      "\n",
      "Overall Score (Average): 93.77%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "498it [00:01, 447.85it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2169.98it/s]|\n",
      "Column Shapes Score: 97.72%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 85.41it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.86%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2137.77it/s]|\n",
      "Column Shapes Score: 91.34%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 107.97it/s]|\n",
      "Column Pair Trends Score: 85.12%\n",
      "\n",
      "Overall Score (Average): 88.23%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2176.88it/s]|\n",
      "Column Shapes Score: 95.91%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 104.16it/s]|\n",
      "Column Pair Trends Score: 93.21%\n",
      "\n",
      "Overall Score (Average): 94.56%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "497it [00:01, 443.05it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2153.00it/s]|\n",
      "Column Shapes Score: 98.34%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 86.88it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 49.17%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2091.79it/s]|\n",
      "Column Shapes Score: 91.75%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 107.59it/s]|\n",
      "Column Pair Trends Score: 84.71%\n",
      "\n",
      "Overall Score (Average): 88.23%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2092.84it/s]|\n",
      "Column Shapes Score: 95.94%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 103.81it/s]|\n",
      "Column Pair Trends Score: 93.29%\n",
      "\n",
      "Overall Score (Average): 94.62%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "497it [00:01, 446.86it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2119.67it/s]|\n",
      "Column Shapes Score: 96.81%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 84.82it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.41%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2153.41it/s]|\n",
      "Column Shapes Score: 92.66%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 106.33it/s]|\n",
      "Column Pair Trends Score: 87.07%\n",
      "\n",
      "Overall Score (Average): 89.86%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2196.26it/s]|\n",
      "Column Shapes Score: 96.53%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 105.70it/s]|\n",
      "Column Pair Trends Score: 94.17%\n",
      "\n",
      "Overall Score (Average): 95.35%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "497it [00:01, 450.06it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2145.01it/s]|\n",
      "Column Shapes Score: 97.47%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 86.61it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.73%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='4000' max='4000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [4000/4000 05:49, Epoch 200/200]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.353900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.249300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.244400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.242200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.240600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.239700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.238900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.238300</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1023.81it/s]|\n",
      "Column Shapes Score: 92.98%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 94.81it/s]|\n",
      "Column Pair Trends Score: 86.58%\n",
      "\n",
      "Overall Score (Average): 89.78%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 2071.90it/s]|\n",
      "Column Shapes Score: 96.72%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 106.96it/s]|\n",
      "Column Pair Trends Score: 94.7%\n",
      "\n",
      "Overall Score (Average): 95.71%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "892it [00:02, 443.63it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1541.10it/s]|\n",
      "Column Shapes Score: 98.42%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 82.95it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 49.21%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1895.62it/s]|\n",
      "Column Shapes Score: 93.28%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 105.47it/s]|\n",
      "Column Pair Trends Score: 86.58%\n",
      "\n",
      "Overall Score (Average): 89.93%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1936.88it/s]|\n",
      "Column Shapes Score: 97.59%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 95.75it/s]|\n",
      "Column Pair Trends Score: 96.03%\n",
      "\n",
      "Overall Score (Average): 96.81%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "895it [00:01, 452.23it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1812.09it/s]|\n",
      "Column Shapes Score: 98.12%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 83.39it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 49.06%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1013.88it/s]|\n",
      "Column Shapes Score: 92.16%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 99.81it/s]|\n",
      "Column Pair Trends Score: 85.16%\n",
      "\n",
      "Overall Score (Average): 88.66%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1115.54it/s]|\n",
      "Column Shapes Score: 97.33%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 86.84it/s]|\n",
      "Column Pair Trends Score: 95.59%\n",
      "\n",
      "Overall Score (Average): 96.46%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "896it [00:02, 442.21it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1590.03it/s]|\n",
      "Column Shapes Score: 98.09%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 81.85it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 49.05%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1892.74it/s]|\n",
      "Column Shapes Score: 92.78%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 104.06it/s]|\n",
      "Column Pair Trends Score: 86.38%\n",
      "\n",
      "Overall Score (Average): 89.58%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 624.97it/s]|\n",
      "Column Shapes Score: 97.06%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 95.02it/s]|\n",
      "Column Pair Trends Score: 95.22%\n",
      "\n",
      "Overall Score (Average): 96.14%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "897it [00:02, 444.07it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1603.94it/s]|\n",
      "Column Shapes Score: 98.47%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 83.84it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 49.23%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1047.50it/s]|\n",
      "Column Shapes Score: 91.77%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |███████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 105.56it/s]|\n",
      "Column Pair Trends Score: 85.14%\n",
      "\n",
      "Overall Score (Average): 88.45%\n",
      "\n",
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1935.53it/s]|\n",
      "Column Shapes Score: 96.95%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 93.76it/s]|\n",
      "Column Pair Trends Score: 95.0%\n",
      "\n",
      "Overall Score (Average): 95.98%\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "897it [00:01, 461.08it/s]                                                                                                                                                          \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating report ...\n",
      "\n",
      "(1/2) Evaluating Column Shapes: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 1871.73it/s]|\n",
      "Column Shapes Score: 96.59%\n",
      "\n",
      "(2/2) Evaluating Column Pair Trends: |████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 83.95it/s]|\n",
      "Column Pair Trends Score: 0.0%\n",
      "\n",
      "Overall Score (Average): 48.3%\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Benchmark CTGAN, TVAE and GReaT on the Asia data for each training-set size.\n",
    "# Every generator is fitted once per size; one synthetic sample per entry of\n",
    "# seed_lst is then scored with an SDMetrics QualityReport against the test set\n",
    "# and with compare_MLE classifiers trained on the synthetic rows.\n",
    "num_lst = [100, 200, 400, 800]\n",
    "epoch_lst = [400, 600, 400, 200]  # GReaT epochs, aligned index-wise with num_lst\n",
    "for i, num in enumerate(num_lst):\n",
    "    sampled_data = data_lst[i][cols].copy()\n",
    "    df_test = test_data.copy()\n",
    "    df_comb = pd.concat([sampled_data, df_test])\n",
    "    # Dummy columns are derived from train + test together so that every\n",
    "    # split can be aligned to the same feature layout.\n",
    "    dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "    # Encode the features only (x_cols): the label column must never be part\n",
    "    # of X_test, and this now matches how every training matrix is built.\n",
    "    df_test_dummy = pd.get_dummies(df_test[x_cols])\n",
    "    # presumably col_check aligns the frame to dummy_cols - verify in eval_utils\n",
    "    df_test_dummy = col_check(df_test_dummy, dummy_cols)\n",
    "    X_test = df_test_dummy.to_numpy()\n",
    "    y_test = df_test[y_col].to_numpy()\n",
    "\n",
    "    # Fit the three generators once per training-set size.\n",
    "    metadata = SingleTableMetadata()\n",
    "    metadata.detect_from_dataframe(sampled_data[cols])\n",
    "    metadata_dict = metadata.to_dict()\n",
    "    ctgan_model = CTGANSynthesizer(metadata, cuda=True)\n",
    "    ctgan_model.fit(sampled_data[cols])\n",
    "    tvae_model = TVAESynthesizer(metadata, cuda=True)\n",
    "    tvae_model.fit(sampled_data[cols])\n",
    "\n",
    "    great_model = GReaT(llm='distilgpt2', batch_size=40, epochs=epoch_lst[i])\n",
    "    great_model.fit(sampled_data[cols])\n",
    "\n",
    "    # Baseline: classifiers trained on the real sampled rows.\n",
    "    df_train_dummy = pd.get_dummies(sampled_data[x_cols])\n",
    "    df_train_dummy = col_check(df_train_dummy, dummy_cols)\n",
    "    X_train = df_train_dummy.to_numpy()\n",
    "    y_train = sampled_data[y_col].to_numpy()\n",
    "    baseline = compare_MLE(X_test, y_test, X_train, y_train, task='classification')\n",
    "    res[str(num)] = {'baseline': baseline}\n",
    "    res_shape[str(num)] = {'ctgan': [], 'tvae': [], 'great': []}\n",
    "\n",
    "    # to_csv does not create missing directories, so make them up front.\n",
    "    for model_dir in ('CTGAN', 'TVAE', 'BeGReaT'):\n",
    "        os.makedirs(f'gen/Asia/{num}/{model_dir}', exist_ok=True)\n",
    "\n",
    "    # NOTE(review): `seed` only labels the draw (file name / result key); no\n",
    "    # random state is set on any sampler, so repeated runs are not reproducible.\n",
    "    for j, seed in enumerate(seed_lst):\n",
    "        res[str(num)][str(seed)] = {}\n",
    "\n",
    "        # CTGAN\n",
    "        ctgan_syn = ctgan_model.sample(len(sampled_data))\n",
    "        ctgan_syn = ctgan_syn.reset_index(drop=True)\n",
    "        ctgan_syn.to_csv(f'gen/Asia/{num}/CTGAN/df_{seed}.csv', index=False)\n",
    "        ctgan_report = QualityReport()\n",
    "        # verbose=False keeps the progress bars of the 60 reports out of the output\n",
    "        ctgan_report.generate(df_test[cols], ctgan_syn[cols], metadata_dict, verbose=False)\n",
    "        # get_score() is the public accessor for the overall quality score\n",
    "        res_shape[str(num)]['ctgan'].append(ctgan_report.get_score())\n",
    "\n",
    "        ctgan_df_dummy = pd.get_dummies(ctgan_syn[x_cols])\n",
    "        ctgan_df_dummy = col_check(ctgan_df_dummy, dummy_cols)\n",
    "        X_train = ctgan_df_dummy.to_numpy()\n",
    "        y_train = ctgan_syn[y_col].to_numpy()\n",
    "\n",
    "        ctgan_performance = compare_MLE(X_test, y_test, X_train, y_train, task='classification')\n",
    "        res[str(num)][str(seed)]['ctgan'] = ctgan_performance\n",
    "\n",
    "        # TVAE\n",
    "        tvae_syn = tvae_model.sample(len(sampled_data))\n",
    "        tvae_syn = tvae_syn.reset_index(drop=True)\n",
    "        tvae_syn.to_csv(f'gen/Asia/{num}/TVAE/df_{seed}.csv', index=False)\n",
    "        tvae_report = QualityReport()\n",
    "        tvae_report.generate(df_test[cols], tvae_syn[cols], metadata_dict, verbose=False)\n",
    "        res_shape[str(num)]['tvae'].append(tvae_report.get_score())\n",
    "        tvae_df_dummy = pd.get_dummies(tvae_syn[x_cols])\n",
    "        tvae_df_dummy = col_check(tvae_df_dummy, dummy_cols)\n",
    "\n",
    "        X_train = tvae_df_dummy.to_numpy()\n",
    "        y_train = tvae_syn[y_col].to_numpy()\n",
    "\n",
    "        tvae_performance = compare_MLE(X_test, y_test, X_train, y_train, task='classification')\n",
    "        # NOTE(review): 'tave' is a typo for 'tvae' (res_shape uses 'tvae'); the key\n",
    "        # is kept because previously saved results use it - rename everywhere at once.\n",
    "        res[str(num)][str(seed)]['tave'] = tvae_performance\n",
    "\n",
    "        # Be-great\n",
    "        great_syn_df = great_model.sample(len(sampled_data), max_length=4000)\n",
    "        great_syn_df = great_syn_df.reset_index(drop=True)\n",
    "        great_syn_df.to_csv(f'gen/Asia/{num}/BeGReaT/df_{seed}.csv', index=False)\n",
    "        great_report = QualityReport()\n",
    "        # NOTE(review): in the previously saved run the report printed right after\n",
    "        # GReaT sampling showed 'Column Pair Trends Score: 0.0%', which halves the\n",
    "        # overall score stored below - inspect the report details before trusting it.\n",
    "        great_report.generate(df_test[cols], great_syn_df[cols], metadata_dict, verbose=False)\n",
    "        res_shape[str(num)]['great'].append(great_report.get_score())\n",
    "        great_df_dummy = pd.get_dummies(great_syn_df[x_cols])\n",
    "        great_df_dummy = col_check(great_df_dummy, dummy_cols)\n",
    "\n",
    "        X_train = great_df_dummy.to_numpy()\n",
    "        y_train = great_syn_df[y_col].to_numpy()\n",
    "        great_performance = compare_MLE(X_test, y_test, X_train, y_train, task='classification')\n",
    "        res[str(num)][str(seed)]['great'] = great_performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e33f8637-09d3-4e38-a87e-2f55f020f0c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mle_summary_f1(res):\n",
    "    \"\"\"Summarise the best per-seed F1 score of every generator.\n",
    "\n",
    "    Args:\n",
    "        res: nested dict laid out as res[num][seed][model][eval_model],\n",
    "            where the innermost value is a metrics dict read through its\n",
    "            'f1 score' key. A 'baseline' entry at the seed level is skipped.\n",
    "\n",
    "    Returns:\n",
    "        dict: res_average[num][model] with\n",
    "            'highest_avg_mean'    - mean over seeds of the best F1 per seed,\n",
    "            'highest_avg_std'     - population std (np.std) of those values,\n",
    "            'eval_model_selected' - {seed: eval_model that reached the best F1}.\n",
    "        Models without any 'f1 score' entry are omitted.\n",
    "    \"\"\"\n",
    "    res_average = {}\n",
    "    for num, results in res.items():\n",
    "        # best_by_model[model][seed] = (best F1, eval_model that produced it)\n",
    "        best_by_model = {}\n",
    "        for seed, results0 in results.items():\n",
    "            if seed == 'baseline':\n",
    "                # real-data reference, not a synthetic-data run\n",
    "                continue\n",
    "            for model, results00 in results0.items():\n",
    "                for eval_model, results000 in results00.items():\n",
    "                    f1_score = results000.get('f1 score')\n",
    "                    if f1_score is None:\n",
    "                        continue\n",
    "                    best_by_seed = best_by_model.setdefault(model, {})\n",
    "                    # strict '>' keeps the first evaluator on ties\n",
    "                    if seed not in best_by_seed or f1_score > best_by_seed[seed][0]:\n",
    "                        best_by_seed[seed] = (f1_score, eval_model)\n",
    "\n",
    "        res_average[num] = {}\n",
    "        for model, best_by_seed in best_by_model.items():\n",
    "            best_scores = [score for score, _ in best_by_seed.values()]\n",
    "            res_average[num][model] = {\n",
    "                'highest_avg_mean': np.mean(best_scores),\n",
    "                'highest_avg_std': np.std(best_scores),\n",
    "                'eval_model_selected': {\n",
    "                    seed: eval_model for seed, (_, eval_model) in best_by_seed.items()\n",
    "                },\n",
    "            }\n",
    "\n",
    "    return res_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "41fee85f-e0a9-4a50-b9a5-e5fdd7203c1a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'ctgan': {'highest_avg_mean': 0.628015770718181,\n",
       "   'highest_avg_std': 0.1932072670972264,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Random Forest',\n",
       "    '4': 'SVC'}},\n",
       "  'tave': {'highest_avg_mean': 0.8267023836017104,\n",
       "   'highest_avg_std': 0.00505273209046899,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'SVC',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'great': {'highest_avg_mean': 0.8307665121726002,\n",
       "   'highest_avg_std': 7.521146659581241e-05,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'SVC',\n",
       "    '2': 'Random Forest',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Random Forest'}}},\n",
       " '200': {'ctgan': {'highest_avg_mean': 0.7119419975698825,\n",
       "   'highest_avg_std': 0.10484813494955227,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'SVC',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'tave': {'highest_avg_mean': 0.8246396143636046,\n",
       "   'highest_avg_std': 0.00505273209046899,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'great': {'highest_avg_mean': 0.8247121238614505,\n",
       "   'highest_avg_std': 0.004995282838904802,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'SVC',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Logistic Regression'}}},\n",
       " '400': {'ctgan': {'highest_avg_mean': 0.5924469797304788,\n",
       "   'highest_avg_std': 0.17415422151324847,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'SVC',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Random Forest',\n",
       "    '4': 'SVC'}},\n",
       "  'tave': {'highest_avg_mean': 0.8288320578539221,\n",
       "   'highest_avg_std': 0.003915417573975378,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'Random Forest',\n",
       "    '3': 'Random Forest',\n",
       "    '4': 'Random Forest'}},\n",
       "  'great': {'highest_avg_mean': 0.8266102687437273,\n",
       "   'highest_avg_std': 0.004977520623873177,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Logistic Regression'}}},\n",
       " '800': {'ctgan': {'highest_avg_mean': 0.48326169890010445,\n",
       "   'highest_avg_std': 0.0598793819132495,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'SVC',\n",
       "    '2': 'SVC',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Random Forest'}},\n",
       "  'tave': {'highest_avg_mean': 0.8246396143636044,\n",
       "   'highest_avg_std': 0.00505273209046899,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Random Forest'}},\n",
       "  'great': {'highest_avg_mean': 0.8246489077824426,\n",
       "   'highest_avg_std': 0.0049832394581810855,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Random Forest'}}}}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mle_summary_f1(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "194a7dd7-51f4-4f1e-a708-d8828ca48718",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Logistic Regression': {'f1 score': 0.8205140758873928, 'accuracy': 0.82},\n",
       " 'Random Forest': {'f1 score': 0.8308279220779221, 'accuracy': 0.83},\n",
       " 'SVC': {'f1 score': 0.8259177624078436, 'accuracy': 0.825}}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[\"100\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a54e2496-ac27-4b20-87e5-c40c3a962f83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Logistic Regression': {'f1 score': 0.8205140758873928, 'accuracy': 0.82},\n",
       " 'Random Forest': {'f1 score': 0.830674397314617, 'accuracy': 0.83},\n",
       " 'SVC': {'f1 score': 0.8308279220779221, 'accuracy': 0.83}}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[\"200\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "5dbdea7e-3761-4b20-ad68-ce88182191e0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Logistic Regression': {'f1 score': 0.830674397314617, 'accuracy': 0.83},\n",
       " 'Random Forest': {'f1 score': 0.8308279220779221, 'accuracy': 0.83},\n",
       " 'SVC': {'f1 score': 0.8308279220779221, 'accuracy': 0.83}}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[\"400\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9feb0339-181c-4873-962d-87322c63ce6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Logistic Regression': {'f1 score': 0.8308279220779221, 'accuracy': 0.83},\n",
       " 'Random Forest': {'f1 score': 0.8308279220779221, 'accuracy': 0.83},\n",
       " 'SVC': {'f1 score': 0.8308279220779221, 'accuracy': 0.83}}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[\"800\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5e8e6d0b-e827-456d-81ec-239fc61edf4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "file_name = \"mle_res/Asia_baseline.json\"\n",
    "with open(file_name, \"w\") as file:\n",
    "    json.dump(res, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9ac260f7-b31f-4934-bdc2-e373b5a75f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "from xgboost import XGBClassifier\n",
    "xgb_clf = XGBClassifier(\n",
    "                    objective='binary:logistic',  # for binary classification, change for multiclass\n",
    "                    n_estimators=100,             # number of boosting rounds\n",
    "                    learning_rate=0.1,            # step size shrinkage\n",
    "                    max_depth=3,                  # maximum depth of a tree\n",
    "                    eval_metric='logloss'         # evaluation metric\n",
    "                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e3a033c-a112-43c5-a7f4-09a97ed76615",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8de901e-0c41-479e-b6d9-6205167f5f62",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(4):\n",
    "    train_data = data_lst[i]\n",
    "    df_comb = pd.concat([train_data[x_cols], test_data[x_cols]])\n",
    "    dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "    train_dummies = pd.get_dummies(train_data[x_cols])\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef1ffec2-f4db-463f-afb2-ec1e99bc4d24",
   "metadata": {},
   "source": [
    "# SMOTE model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c8fe0837-fa31-404b-b97c-3e9e6b8d962a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from imblearn.over_sampling import SMOTE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d39ad2e2-7c85-41fc-93db-8bfa800e0a1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_lst = [100, 200, 400, 800]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e2a546b3-cca6-4323-b71c-7aa495784cd4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Logistic Regression': {'f1 score': 0.8261663985290738, 'accuracy': 0.825}, 'Random Forest': {'f1 score': 0.826323495709302, 'accuracy': 0.825}, 'SVC': {'f1 score': 0.8313487946292341, 'accuracy': 0.83}, 'XGB': {'f1 score': 0.8313487946292341, 'accuracy': 0.83}}\n",
      "{'Logistic Regression': {'f1 score': 0.8261663985290738, 'accuracy': 0.825}, 'Random Forest': {'f1 score': 0.8261663985290738, 'accuracy': 0.825}, 'SVC': {'f1 score': 0.826323495709302, 'accuracy': 0.825}, 'XGB': {'f1 score': 0.8261663985290738, 'accuracy': 0.825}}\n",
      "{'Logistic Regression': {'f1 score': 0.8363619047619048, 'accuracy': 0.835}, 'Random Forest': {'f1 score': 0.8363619047619048, 'accuracy': 0.835}, 'SVC': {'f1 score': 0.8414252454701893, 'accuracy': 0.84}, 'XGB': {'f1 score': 0.8363619047619048, 'accuracy': 0.835}}\n",
      "{'Logistic Regression': {'f1 score': 0.8363619047619048, 'accuracy': 0.835}, 'Random Forest': {'f1 score': 0.8413636363636364, 'accuracy': 0.84}, 'SVC': {'f1 score': 0.8364421980694687, 'accuracy': 0.835}, 'XGB': {'f1 score': 0.8363619047619048, 'accuracy': 0.835}}\n"
     ]
    }
   ],
   "source": [
    "res = {}\n",
    "for i in range(4):\n",
    "    num = num_lst[i]\n",
    "    sampled_data = data_lst[i].copy()\n",
    "    sampled_data = sampled_data[cols].copy()\n",
    "    sampled_data[\"label\"] = [0] * len(sampled_data)\n",
    "    sampled_data.loc[:len(sampled_data)//2, \"label\"] = 1\n",
    "    df_test = test_data.copy()\n",
    "    df_comb = pd.concat([sampled_data, df_test])\n",
    "    dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "    df_test_dummy = pd.get_dummies(df_test[x_cols])\n",
    "    df_test_dummy = col_check(df_test_dummy, dummy_cols)\n",
    "    X_test = df_test_dummy.to_numpy()\n",
    "    y_test = df_test[y_col].to_numpy()\n",
    "    df_train_dummy = pd.get_dummies(sampled_data[x_cols])\n",
    "    df_train_dummy = col_check(df_train_dummy, dummy_cols)\n",
    "    df_train_dummy[y_col] = sampled_data[y_col]\n",
    "    X_train = df_train_dummy.to_numpy()\n",
    "    print(compare_MLE(X_test, y_test, X_train[:, :-1], X_train[:, -1], task=\"classification\"))\n",
    "    y_train = sampled_data[\"label\"].to_numpy()\n",
    "    res[str(num)] = {}\n",
    "    for j in range(len(seed_lst)):\n",
    "        seed = seed_lst[j]\n",
    "        res[str(num)][str(seed)] = {}\n",
    "        smote = SMOTE(random_state=seed)\n",
    "        X_syn, y_syn = smote.fit_resample(X_train, y_train)\n",
    "        performance = compare_MLE(X_test, y_test, X_syn[:, :-1], X_syn[:, -1], task=\"classification\")\n",
    "        res[str(num)][str(seed)][\"SMOTE\"] = performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3aa27899-1ade-400f-9a0d-334408654362",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mle_summary_f1(res):\n",
    "    eval_sum_counts = {}\n",
    "    highest_avg_values = {}\n",
    "\n",
    "    for num, results in res.items():\n",
    "        eval_sum_counts[num] = {}\n",
    "        highest_avg_values[num] = {}\n",
    "        for seed, results0 in results.items():\n",
    "            if seed != 'baseline':\n",
    "                for model, results00 in results0.items():\n",
    "                    if model not in eval_sum_counts[num]:\n",
    "                        eval_sum_counts[num][model] = {}\n",
    "                    for eval_model, results000 in results00.items():\n",
    "                        accuracy = results000.get('accuracy', None)\n",
    "                        f1_score = results000.get('f1 score', None)\n",
    "                        if accuracy is not None and f1_score is not None:\n",
    "                            avg_value = (f1_score + f1_score) / 2\n",
    "                            if model not in highest_avg_values[num]:\n",
    "                                highest_avg_values[num][model] = []\n",
    "                            highest_avg_values[num][model].append((avg_value, seed, eval_model))\n",
    "            else:\n",
    "                eval_sum_counts[num]['baseline'] = results0\n",
    "\n",
    "    res_average = {}\n",
    "    for num, model_data in highest_avg_values.items():\n",
    "        res_average[num] = {}\n",
    "        for model, avg_data in model_data.items():\n",
    "            # Select the highest average value for each seed\n",
    "            highest_avg_by_seed = {}\n",
    "            for avg_value, seed, eval_model in avg_data:\n",
    "                if seed not in highest_avg_by_seed or avg_value > highest_avg_by_seed[seed][0]:\n",
    "                    highest_avg_by_seed[seed] = (avg_value, eval_model)\n",
    "\n",
    "            # Calculate mean and std of the highest average values across seeds\n",
    "            highest_avg_values_list = [value[0] for value in highest_avg_by_seed.values()]\n",
    "            res_average[num][model] = {\n",
    "                'highest_avg_mean': np.mean(highest_avg_values_list),\n",
    "                'highest_avg_std': np.std(highest_avg_values_list),\n",
    "                'eval_model_selected': {seed: eval_model for seed, (avg_value, eval_model) in highest_avg_by_seed.items()}\n",
    "            }\n",
    "\n",
    "    return res_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "64fa08da-1fa4-44c0-984d-edf2c8d6c24c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'SMOTE': {'highest_avg_mean': 0.8323590250134387,\n",
       "   'highest_avg_std': 0.004933267956213334,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'SVC',\n",
       "    '2': 'SVC',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Random Forest'}}},\n",
       " '200': {'SMOTE': {'highest_avg_mean': 0.8263234957093019,\n",
       "   'highest_avg_std': 1.1102230246251565e-16,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'SVC',\n",
       "    '2': 'SVC',\n",
       "    '3': 'SVC',\n",
       "    '4': 'SVC'}}},\n",
       " '400': {'SMOTE': {'highest_avg_mean': 0.8414252454701894,\n",
       "   'highest_avg_std': 1.1102230246251565e-16,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'SVC',\n",
       "    '2': 'SVC',\n",
       "    '3': 'SVC',\n",
       "    '4': 'SVC'}}},\n",
       " '800': {'SMOTE': {'highest_avg_mean': 0.8413636363636364,\n",
       "   'highest_avg_std': 0.0,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'Random Forest',\n",
       "    '3': 'Random Forest',\n",
       "    '4': 'Random Forest'}}}}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mle_summary_f1(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3237cdc3-175a-4abb-88c8-1af2788f547e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
