{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "274e8bb5-530f-4902-9ac6-0b7b2a03876c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "parent_dir = os.path.abspath('../..')\n",
    "sys.path.append(parent_dir)\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"  # or \"true\" depending on your requirement\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from eval_utils import compare_MLE, data_profiling, cate_info, col_check,  transform_label\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "import os\n",
    "import json\n",
    "import pickle\n",
    "import torch\n",
    "from sdv.metadata import SingleTableMetadata\n",
    "from sdv.single_table import CTGANSynthesizer, TVAESynthesizer\n",
    "from sdmetrics.reports.single_table import QualityReport\n",
    "from be_great import GReaT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9f5a7947-082a-4b6f-9e31-5c287aedda8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"../sample/ERICH/\"\n",
    "\n",
    "test_data = pd.read_csv(path + \"data_test.csv\")\n",
    "data100 = pd.read_csv(path + \"data100.csv\")\n",
    "data200 = pd.read_csv(path + \"data200.csv\")\n",
    "data400 = pd.read_csv(path + \"data400.csv\")\n",
    "data800 = pd.read_csv(path + \"data800.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6f53f3b0-1b8c-4e12-bbfd-a3f61edeb703",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lst = [data100, data200, data400, data800]\n",
    "num_lst = [100, 200, 400, 800]\n",
    "\n",
    "seed_lst = list(range(5))\n",
    "\n",
    "x_cols = ['sex', 'age', 'race', 'ICH volume', 'IVH volume',\n",
    "       'ICH location', 'GCS score', 'Systolic blood pressure',\n",
    "       'Diastolic blood pressure', 'Hypertension', 'Hyperlipidemia',\n",
    "       'Type 2 diabetes', 'Type 1 diabetes', 'Atrial Fibrillation',\n",
    "       'Congestive Heart Failure', 'Vascular disease', 'Cocaine use', 'Smoke',\n",
    "       'PTCA', 'Myocardial infraction', 'anti-diabetic', 'Anti-hypertensives',\n",
    "       'White blood cell count', 'Hemoglobin', 'Platelet count', 'APTT', 'INR',\n",
    "       'Glucose']\n",
    "y_col = 'mRS score after 90 days'\n",
    "cols = x_cols + [y_col]\n",
    "cate_var = ['race', 'sex', 'ICH location']\n",
    "bool_var = ['Hypertension', 'Hyperlipidemia',\n",
    "       'Type 2 diabetes', 'Type 1 diabetes', 'Atrial Fibrillation',\n",
    "       'Congestive Heart Failure', 'Vascular disease', 'Cocaine use', 'Smoke',\n",
    "       'PTCA', 'Myocardial infraction', 'anti-diabetic', 'Anti-hypertensives']\n",
    "num_var = list((set(cols) - set(cate_var)) - set(bool_var))\n",
    "\n",
    "res_shape = {}\n",
    "res = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f6bd256-8a1c-49a1-875a-3724a89b375f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/torch/autograd/graph.py:768: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at ../aten/src/ATen/cuda/CublasHandlePool.cpp:135.)\n",
      "  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1200' max='1200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1200/1200 06:21, Epoch 400/400]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.787900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.632700</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                                                                                             | 0/100 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "102it [00:03, 27.11it/s]                                                                                                                                                                                                                  \n",
      "108it [00:03, 29.49it/s]                                                                                                                                                                                                                  \n",
      "136it [00:04, 27.74it/s]                                                                                                                                                                                                                  \n",
      "116it [00:03, 31.36it/s]                                                                                                                                                                                                                  \n",
      "107it [00:03, 29.31it/s]                                                                                                                                                                                                                  \n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='319' max='3000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 319/3000 01:57 < 16:31, 2.70 it/s, Epoch 63.60/600]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "num_lst = [100, 200, 400, 800]\n",
    "epoch_lst = [400, 600, 400, 200]\n",
    "for i in range(4):\n",
    "    num = num_lst[i]\n",
    "    sampled_data = data_lst[i]\n",
    "    sampled_data = sampled_data[cols].copy()\n",
    "    df_test = test_data.copy()\n",
    "    df_comb = pd.concat([sampled_data, df_test])\n",
    "    dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "    df_test_dummy = pd.get_dummies(df_test[cols])\n",
    "    df_test_dummy = col_check(df_test_dummy, dummy_cols)\n",
    "    X_test = df_test_dummy.to_numpy()\n",
    "    y_test = df_test[y_col].to_numpy()\n",
    "    metadata = SingleTableMetadata()\n",
    "    metadata.detect_from_dataframe(sampled_data[cols])\n",
    "    metadata_dict = metadata.to_dict()\n",
    "    ctgan_model = CTGANSynthesizer(metadata, cuda=True)\n",
    "    ctgan_model.fit(sampled_data[cols])\n",
    "    tvae_model = TVAESynthesizer(metadata, cuda=True)\n",
    "    tvae_model.fit(sampled_data[cols])\n",
    "\n",
    "    great_model = GReaT(llm='distilgpt2', batch_size = 40, epochs=epoch_lst[i])\n",
    "    great_model.fit(sampled_data[cols])\n",
    "\n",
    "    df_train_dummy = pd.get_dummies(sampled_data[x_cols])\n",
    "    df_train_dummy = col_check(df_train_dummy, dummy_cols)\n",
    "    X_train = df_train_dummy.to_numpy()\n",
    "    y_train = sampled_data[y_col].to_numpy()\n",
    "    baseline = compare_MLE(X_test, y_test, X_train, y_train, task='regression')\n",
    "    res[str(num)] = {}\n",
    "    res[str(num)]['baseline'] = baseline\n",
    "    res_shape[str(num)] = {'ctgan':[], 'tvae':[], 'great':[]}\n",
    "\n",
    "    for j in range(len(seed_lst)):\n",
    "        seed = seed_lst[j]\n",
    "        res[str(num)][str(seed)] = {}\n",
    "        # CTGAN\n",
    "        #ctgan_model.set_random_state(seed)\n",
    "        ctgan_syn = ctgan_model.sample(len(sampled_data))\n",
    "        ctgan_syn.reset_index(inplace=True, drop=True)\n",
    "        ctgan_syn.to_csv(f'gen/ERICH/{num}/CTGAN/df_{seed}.csv', index=False)\n",
    "        ctgan_df_dummy = pd.get_dummies(ctgan_syn[x_cols])\n",
    "        ctgan_df_dummy = col_check(ctgan_df_dummy, dummy_cols)\n",
    "        X_train = ctgan_df_dummy.to_numpy()\n",
    "        y_train = ctgan_syn[y_col].to_numpy()\n",
    "\n",
    "        ctgan_performance = compare_MLE(X_test, y_test, X_train, y_train, task='regression')\n",
    "        res[str(num)][str(seed)]['ctgan'] = ctgan_performance\n",
    "        \n",
    "        # TVAE\n",
    "        #tvae_model.set_random_state(seed)\n",
    "        tvae_syn = tvae_model.sample(len(sampled_data))\n",
    "        tvae_syn.reset_index(inplace=True, drop=True)\n",
    "        tvae_syn.to_csv(f'gen/ERICH/{num}/TVAE/df_{seed}.csv', index=False)\n",
    "        \n",
    "        tvae_df_dummy = pd.get_dummies(tvae_syn[x_cols])\n",
    "        tvae_df_dummy = col_check(tvae_df_dummy, dummy_cols)\n",
    "        X_train = tvae_df_dummy.to_numpy()\n",
    "        y_train = tvae_syn[y_col].to_numpy()\n",
    "        tvae_performance = compare_MLE(X_test, y_test, X_train, y_train, task='regression')\n",
    "        res[str(num)][str(seed)]['tave'] = tvae_performance\n",
    "\n",
    "        # Be-great\n",
    "        great_syn_df = great_model.sample(len(sampled_data), max_length=4000)\n",
    "        great_syn_df.reset_index(inplace=True, drop=True)\n",
    "        great_syn_df.to_csv(f'gen/ERICH/{num}/BeGReaT/df_{seed}.csv', index=False)\n",
    "        \n",
    "        great_df_dummy = pd.get_dummies(great_syn_df[x_cols])\n",
    "        great_df_dummy = col_check(great_df_dummy, dummy_cols)\n",
    "        X_train = great_df_dummy.to_numpy()\n",
    "        y_train = great_syn_df[y_col].to_numpy()\n",
    "        great_performance = compare_MLE(X_test, y_test, X_train, y_train, task='regression')\n",
    "        res[str(num)][str(seed)]['great'] = great_performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd2141ab-3c0c-49ca-9169-30bbe9bc6e3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mle_summary(res):\n",
    "\n",
    "    eval_sum_counts = {}\n",
    "    highest_r2_values = {}\n",
    "\n",
    "    for num, results in res.items():\n",
    "        eval_sum_counts[num] = {}\n",
    "        highest_r2_values[num] = {}\n",
    "        for seed, results0 in results.items():\n",
    "            if seed != 'baseline':\n",
    "                for model, results00 in results0.items():\n",
    "                    if model not in eval_sum_counts[num]:\n",
    "                        eval_sum_counts[num][model] = {}\n",
    "                    for eval_model, results000 in results00.items():\n",
    "                        r2_value = results000.get('R2', None)\n",
    "                        if r2_value is not None:\n",
    "                            if model not in highest_r2_values[num]:\n",
    "                                highest_r2_values[num][model] = []\n",
    "                            highest_r2_values[num][model].append((r2_value, seed, eval_model))\n",
    "            else:\n",
    "                eval_sum_counts[num]['baseline'] = results0\n",
    "\n",
    "    res_average = {}\n",
    "    for num, model_data in highest_r2_values.items():\n",
    "        res_average[num] = {}\n",
    "        for model, r2_data in model_data.items():\n",
    "            # Select the highest R2 value for each seed\n",
    "            highest_r2_by_seed = {}\n",
    "            for r2_value, seed, eval_model in r2_data:\n",
    "                if seed not in highest_r2_by_seed or r2_value > highest_r2_by_seed[seed][0]:\n",
    "                    highest_r2_by_seed[seed] = (r2_value, eval_model)\n",
    "\n",
    "            # Calculate mean and std of the highest R2 values across seeds\n",
    "            highest_r2_values_list = [value[0] for value in highest_r2_by_seed.values()]\n",
    "            res_average[num][model] = {\n",
    "                'highest_R2_mean': np.mean(highest_r2_values_list),\n",
    "                'highest_R2_std': np.std(highest_r2_values_list),\n",
    "                'eval_model_selected': {seed: eval_model for seed, (r2_value, eval_model) in highest_r2_by_seed.items()}\n",
    "            }\n",
    "\n",
    "    return res_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b64696a4-c722-4bd6-bdc6-576ceeb54306",
   "metadata": {},
   "outputs": [],
   "source": [
    "mle_summary(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "eb8d3df6-4205-4f86-9d9a-860537a55105",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Linear Regression': {'R2': -2.8558346910998154},\n",
       " 'Decision Tree': {'R2': -0.26026096993894776}}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[\"100\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "88167ad7-8f24-4cc8-80d3-6a9c43a2e4f7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Linear Regression': {'R2': -0.7334412855686625},\n",
       " 'Decision Tree': {'R2': -0.02811539816982167}}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[\"200\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fc820a0d-5171-4027-a3a8-7d29df09267f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Linear Regression': {'R2': -0.2921302257477518},\n",
       " 'Decision Tree': {'R2': -0.03718086475927773}}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[\"400\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "bf844357-eb1e-4703-8fdc-a4505819486b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Linear Regression': {'R2': -0.14820947972299292},\n",
       " 'Decision Tree': {'R2': 0.07243588170166804}}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[\"800\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7fb376e1-e851-46d0-9032-3663681a81c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "file_name = \"mle_res/atach_baseline.json\"\n",
    "with open(file_name, \"w\") as file:\n",
    "    json.dump(res, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5832bca6-bcf0-4ca5-8ecb-84dac73056af",
   "metadata": {},
   "outputs": [],
   "source": [
    "import xgboost as xgb\n",
    "from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, r2_score\n",
    "xgb_reg = xgb.XGBRegressor(\n",
    "    objective='reg:squarederror',  # regression with squared loss\n",
    "    n_estimators=100,              # number of boosting rounds\n",
    "    learning_rate=0.1,             # step size shrinkage\n",
    "    max_depth=3,                   # maximum depth of a tree\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "cc54adb0-99ef-4463-9346-0394342157b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100, r2: -0.048671900539349666\n",
      "200, r2: 0.15847891351135623\n",
      "400, r2: 0.1824214939694706\n",
      "800, r2: 0.2141180395139447\n"
     ]
    }
   ],
   "source": [
    "num_lst = [100, 200, 400,800]\n",
    "data_lst = [data100, data200, data400, data800]\n",
    "seed_lst = list(range(5))\n",
    "for i in range(4):\n",
    "    num = num_lst[i]\n",
    "    sampled_data = data_lst[i]\n",
    "    sampled_data = sampled_data[cols].copy()\n",
    "    df_test = test_data.copy()\n",
    "    df_comb = pd.concat([sampled_data, df_test])\n",
    "    dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "    df_test_dummy = pd.get_dummies(df_test[cols])\n",
    "    df_test_dummy = col_check(df_test_dummy, dummy_cols)\n",
    "    X_test = df_test_dummy.to_numpy()\n",
    "    y_test = df_test[y_col].to_numpy()\n",
    "    \n",
    "    df_train_dummy = pd.get_dummies(sampled_data[x_cols])\n",
    "    df_train_dummy = col_check(df_train_dummy, dummy_cols)\n",
    "    X_train = df_train_dummy.to_numpy()\n",
    "    y_train = sampled_data[y_col].to_numpy()\n",
    "\n",
    "    \n",
    "    xgb_reg.fit(X_train, y_train)\n",
    "    y_pred = xgb_reg.predict(X_test)\n",
    "    r2_linear = r2_score(y_test, y_pred)\n",
    "    print(f\"{num}, r2: {r2_linear}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "85ee037e-120e-4f40-9255-9f3fc3a6327e",
   "metadata": {},
   "outputs": [],
   "source": [
    "ctgan100 = pd.read_csv(\"gen/ERICH/100/TVAE/df_1.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "09ca6ceb-d757-46e1-9786-371b256a2c6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train_dummy = pd.get_dummies(ctgan100[x_cols])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "36a12c56-4402-4c61-8f79-8a1c0772b79f",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train_dummy = col_check(df_train_dummy, dummy_cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "901d82f9-a2e5-40aa-9a02-745a62e01c41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "800, r2: -0.05938151610575715\n"
     ]
    }
   ],
   "source": [
    "X_train = df_train_dummy.to_numpy()\n",
    "y_train = ctgan100[y_col].to_numpy()\n",
    "xgb_reg.fit(X_train, y_train)\n",
    "y_pred = xgb_reg.predict(X_test)\n",
    "r2_linear = r2_score(y_test, y_pred)\n",
    "print(f\"{num}, r2: {r2_linear}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67927cb4-ed11-40b3-aae3-eaeddffb7781",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
