{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "274e8bb5-530f-4902-9ac6-0b7b2a03876c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "# Put the directory two levels above the working directory on the import path\n",
    "# (presumably where the local eval_utils module lives -- confirm).\n",
    "parent_dir = os.path.abspath('../..')\n",
    "sys.path.append(parent_dir)\n",
    "# Both variables are set before torch / be_great are imported further down in this cell.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"  # or \"true\" depending on your requirement\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from eval_utils import compare_MLE, data_profiling, cate_info, col_check,  transform_label\n",
    "from sklearn.preprocessing import LabelBinarizer\n",
    "import os\n",
    "import json\n",
    "import pickle\n",
    "import torch\n",
    "from sdv.metadata import SingleTableMetadata\n",
    "from sdv.single_table import CTGANSynthesizer, TVAESynthesizer\n",
    "from sdmetrics.reports.single_table import QualityReport\n",
    "from be_great import GReaT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9f5a7947-082a-4b6f-9e31-5c287aedda8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder with the ATACH CSV files, relative to the notebook's working directory.\n",
    "path = \"../sample/ATACH/\"\n",
    "\n",
    "# Test split followed by the four training subsets; the names on the left\n",
    "# line up one-to-one with the file names listed below.\n",
    "test_data, data100, data200, data400, data800 = (\n",
    "    pd.read_csv(path + csv_name)\n",
    "    for csv_name in (\n",
    "        \"data_test.csv\",\n",
    "        \"data100.csv\",\n",
    "        \"data200.csv\",\n",
    "        \"data400.csv\",\n",
    "        \"data800.csv\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6f53f3b0-1b8c-4e12-bbfd-a3f61edeb703",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training subsets and their labels, paired by position.\n",
    "data_lst = [data100, data200, data400, data800]\n",
    "num_lst = [100, 200, 400, 800]\n",
    "\n",
    "# Labels for the repeated synthetic draws made per training subset.\n",
    "seed_lst = list(range(5))\n",
    "\n",
    "# Feature columns.\n",
    "x_cols = [\n",
    "    'treatment', 'age', 'ICH volume',\n",
    "    'ICH Location', 'IVH volume', 'GCS score', 'NIHSS score',\n",
    "    'Systolic blood pressure',\n",
    "    'Diastolic Blood Pressure',\n",
    "    'Hypertension', 'Hyperlipidemia',\n",
    "    'Type I Diabetes', 'Type II Diabetes',\n",
    "    'Congestive heart failure', 'Atrial Fibrillation',\n",
    "    'PTCA',\n",
    "    'Peripheral Vascular Disease',\n",
    "    'Myocardial fraction', 'Anti-diabetic',\n",
    "    'Antihypertensives', 'White blood count', 'Hemoglobin',\n",
    "    'Hematocrit', 'Platelet count', 'APTT', 'INR', 'Glucose', 'Sodium',\n",
    "    'Potassium', 'Chloride', 'CD', 'Blood urea nitrogen', 'Creatinine',\n",
    "    'race', 'sex', 'ethnicity',\n",
    "]\n",
    "# Regression target.\n",
    "y_col = 'mRS score after 30 days'\n",
    "\n",
    "cols = x_cols + [y_col]\n",
    "cate_var = ['race', 'ICH Location', 'sex', 'ethnicity']\n",
    "bool_var = [\n",
    "    'treatment', 'Hypertension', 'Hyperlipidemia',\n",
    "    'Type I Diabetes', 'Type II Diabetes',\n",
    "    'Congestive heart failure', 'Atrial Fibrillation',\n",
    "    'PTCA',\n",
    "    'Peripheral Vascular Disease',\n",
    "    'Myocardial fraction', 'Anti-diabetic',\n",
    "    'Antihypertensives',\n",
    "]\n",
    "\n",
    "# Columns that are neither categorical nor boolean, listed in `cols` order.\n",
    "# A plain set difference yields the same members, but its iteration order\n",
    "# depends on string hashing and can change between kernel sessions.\n",
    "_non_numeric = set(cate_var) | set(bool_var)\n",
    "num_var = [col for col in cols if col not in _non_numeric]\n",
    "\n",
    "# Result containers filled in by the generation / evaluation loop.\n",
    "res_shape = {}\n",
    "res = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f6bd256-8a1c-49a1-875a-3724a89b375f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sdv/single_table/base.py:92: UserWarning: We strongly recommend saving the metadata using 'save_to_json' for replicability in future SDV versions.\n",
      "  warnings.warn(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/torch/autograd/graph.py:768: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at ../aten/src/ATen/cuda/CublasHandlePool.cpp:135.)\n",
      "  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n"
     ]
    }
   ],
   "source": [
    "num_lst = [100, 200, 400, 800]\n",
    "# Epochs passed to GReaT, paired by position with num_lst.\n",
    "epoch_lst = [400, 600, 400, 200]\n",
    "\n",
    "\n",
    "def _to_xy(df, dummy_cols):\n",
    "    \"\"\"Return the (X, y) arrays for `df`.\n",
    "\n",
    "    Features are the one-hot encoded `x_cols`, passed through `col_check`\n",
    "    together with `dummy_cols`; the target is the `y_col` column.\n",
    "    \"\"\"\n",
    "    encoded = col_check(pd.get_dummies(df[x_cols]), dummy_cols)\n",
    "    return encoded.to_numpy(), df[y_col].to_numpy()\n",
    "\n",
    "\n",
    "for num, train_df, great_epochs in zip(num_lst, data_lst, epoch_lst):\n",
    "    sampled_data = train_df[cols].copy()\n",
    "    df_test = test_data.copy()\n",
    "\n",
    "    # Take the dummy columns from train + test together so both are encoded\n",
    "    # against the same list of feature columns.\n",
    "    df_comb = pd.concat([sampled_data, df_test])\n",
    "    dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "\n",
    "    # Test features come from x_cols only, exactly like the training side.\n",
    "    # (They used to be encoded from `cols`, which also contains the target and\n",
    "    # relied on col_check to discard it.)\n",
    "    X_test, y_test = _to_xy(df_test, dummy_cols)\n",
    "\n",
    "    metadata = SingleTableMetadata()\n",
    "    metadata.detect_from_dataframe(sampled_data)\n",
    "    ctgan_model = CTGANSynthesizer(metadata, cuda=True)\n",
    "    ctgan_model.fit(sampled_data)\n",
    "    tvae_model = TVAESynthesizer(metadata, cuda=True)\n",
    "    tvae_model.fit(sampled_data)\n",
    "\n",
    "    great_model = GReaT(llm='distilgpt2', batch_size=40, epochs=great_epochs)\n",
    "    great_model.fit(sampled_data)\n",
    "\n",
    "    # Baseline: evaluation models trained on the real subset itself.\n",
    "    X_train, y_train = _to_xy(sampled_data, dummy_cols)\n",
    "    baseline = compare_MLE(X_test, y_test, X_train, y_train, task='regression')\n",
    "    res[str(num)] = {'baseline': baseline}\n",
    "    res_shape[str(num)] = {'ctgan': [], 'tvae': [], 'great': []}\n",
    "\n",
    "    n_rows = len(sampled_data)\n",
    "    # result key -> (output sub-folder, callable drawing n synthetic rows)\n",
    "    # NOTE: TVAE results used to be stored under the misspelled key 'tave';\n",
    "    # anything reading older result files has to account for the rename.\n",
    "    generators = {\n",
    "        'ctgan': ('CTGAN', ctgan_model.sample),\n",
    "        'tvae': ('TVAE', tvae_model.sample),\n",
    "        'great': ('BeGReaT', lambda n: great_model.sample(n, max_length=4000)),\n",
    "    }\n",
    "\n",
    "    for seed in seed_lst:\n",
    "        # Tie the global NumPy / torch RNG state to the seed label so a draw can be repeated.\n",
    "        # NOTE(review): SDV synthesizers may keep their own sampling state -- confirm\n",
    "        # whether this reaches CTGAN / TVAE or only the GReaT sampling.\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "        res[str(num)][str(seed)] = {}\n",
    "        for key, (folder, draw) in generators.items():\n",
    "            syn_df = draw(n_rows).reset_index(drop=True)\n",
    "\n",
    "            # to_csv does not create missing folders.\n",
    "            out_dir = f'gen/ATACH/{num}/{folder}'\n",
    "            os.makedirs(out_dir, exist_ok=True)\n",
    "            syn_df.to_csv(f'{out_dir}/df_{seed}.csv', index=False)\n",
    "\n",
    "            X_train, y_train = _to_xy(syn_df, dummy_cols)\n",
    "            res[str(num)][str(seed)][key] = compare_MLE(\n",
    "                X_test, y_test, X_train, y_train, task='regression'\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fd2141ab-3c0c-49ca-9169-30bbe9bc6e3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mle_summary(res):\n",
    "    \"\"\"Summarise the best R2 per training size and synthesizer.\n",
    "\n",
    "    `res` is read as ``res[num][seed][model][eval_model]['R2']``. The\n",
    "    ``'baseline'`` entry stored next to the seeds is left out of the summary,\n",
    "    and entries without an ``'R2'`` value are skipped.\n",
    "\n",
    "    For every (num, model) pair the largest R2 across eval models is picked\n",
    "    per seed (the first one encountered wins a tie). The result maps\n",
    "    ``num -> model`` to the mean and standard deviation of those per-seed\n",
    "    maxima plus the eval model that produced each of them.\n",
    "    \"\"\"\n",
    "    res_average = {}\n",
    "\n",
    "    for num, runs in res.items():\n",
    "        # model -> {seed: (best R2 so far, eval model that produced it)}\n",
    "        best = {}\n",
    "        for seed, per_model in runs.items():\n",
    "            if seed == 'baseline':\n",
    "                continue\n",
    "            for model, per_eval in per_model.items():\n",
    "                for eval_model, metrics in per_eval.items():\n",
    "                    score = metrics.get('R2', None)\n",
    "                    if score is None:\n",
    "                        continue\n",
    "                    by_seed = best.setdefault(model, {})\n",
    "                    if seed not in by_seed or score > by_seed[seed][0]:\n",
    "                        by_seed[seed] = (score, eval_model)\n",
    "\n",
    "        res_average[num] = {}\n",
    "        for model, by_seed in best.items():\n",
    "            top_scores = [score for score, _ in by_seed.values()]\n",
    "            res_average[num][model] = {\n",
    "                'highest_R2_mean': np.mean(top_scores),\n",
    "                'highest_R2_std': np.std(top_scores),\n",
    "                'eval_model_selected': {seed: picked for seed, (_, picked) in by_seed.items()},\n",
    "            }\n",
    "\n",
    "    return res_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b64696a4-c722-4bd6-bdc6-576ceeb54306",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'ctgan': {'highest_R2_mean': -0.5720288608135177,\n",
       "   'highest_R2_std': 0.39894601286391207,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'Decision Tree'}},\n",
       "  'tave': {'highest_R2_mean': -0.5725813315211916,\n",
       "   'highest_R2_std': 0.5591437431199254,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'Decision Tree'}},\n",
       "  'great': {'highest_R2_mean': -0.41160780260593394,\n",
       "   'highest_R2_std': 0.1832038815907858,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'Decision Tree'}}},\n",
       " '200': {'ctgan': {'highest_R2_mean': -0.31729707584562444,\n",
       "   'highest_R2_std': 0.07954201500415138,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Linear Regression'}},\n",
       "  'tave': {'highest_R2_mean': -0.08759421061969426,\n",
       "   'highest_R2_std': 0.04547710667900465,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Decision Tree'}},\n",
       "  'great': {'highest_R2_mean': 0.10862290357780394,\n",
       "   'highest_R2_std': 0.08389957303923688,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Decision Tree'}}},\n",
       " '400': {'ctgan': {'highest_R2_mean': -0.08652553487036412,\n",
       "   'highest_R2_std': 0.059130208993469976,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Linear Regression'}},\n",
       "  'tave': {'highest_R2_mean': -0.11217534773422681,\n",
       "   'highest_R2_std': 0.18862762685971377,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'Decision Tree'}},\n",
       "  'great': {'highest_R2_mean': 0.18453450242074737,\n",
       "   'highest_R2_std': 0.06098159537536542,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Linear Regression',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'Decision Tree'}}},\n",
       " '800': {'ctgan': {'highest_R2_mean': -0.05247001727277267,\n",
       "   'highest_R2_std': 0.04396287604783294,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'Decision Tree'}},\n",
       "  'tave': {'highest_R2_mean': 0.11786490618706255,\n",
       "   'highest_R2_std': 0.0770612300839469,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'Linear Regression'}},\n",
       "  'great': {'highest_R2_mean': 0.0015496204732398766,\n",
       "   'highest_R2_std': 0.1420343637130474,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Linear Regression'}}}}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per training size and synthesizer: mean / std over seeds of the best R2 across eval models.\n",
    "mle_summary(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "eb8d3df6-4205-4f86-9d9a-860537a55105",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Linear Regression': {'R2': -0.2656002913756361},\n",
       " 'Decision Tree': {'R2': 0.25944843882665236}}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scores of the eval models trained on the real data100 subset (stored as 'baseline' by the generation loop).\n",
    "res[\"100\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "88167ad7-8f24-4cc8-80d3-6a9c43a2e4f7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Linear Regression': {'R2': 0.035822014412650094},\n",
       " 'Decision Tree': {'R2': 0.26847709034248934}}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scores of the eval models trained on the real data200 subset (stored as 'baseline' by the generation loop).\n",
    "res[\"200\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "fc820a0d-5171-4027-a3a8-7d29df09267f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Linear Regression': {'R2': 0.31222352438952494},\n",
       " 'Decision Tree': {'R2': 0.25435013951511465}}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scores of the eval models trained on the real data400 subset (stored as 'baseline' by the generation loop).\n",
    "res[\"400\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bf844357-eb1e-4703-8fdc-a4505819486b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Linear Regression': {'R2': 0.1808548414813128},\n",
       " 'Decision Tree': {'R2': 0.3995694156890841}}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scores of the eval models trained on the real data800 subset (stored as 'baseline' by the generation loop).\n",
    "res[\"800\"][\"baseline\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7fb376e1-e851-46d0-9032-3663681a81c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the full result dictionary (baselines and per-seed synthetic scores) as JSON.\n",
    "file_name = \"mle_res/atach_baseline.json\"\n",
    "# Opening a file for writing does not create missing folders, so make sure\n",
    "# the target directory exists first.\n",
    "os.makedirs(os.path.dirname(file_name), exist_ok=True)\n",
    "with open(file_name, \"w\") as file:\n",
    "    json.dump(res, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6d7010ba-c75f-4776-9108-b5dfae715099",
   "metadata": {},
   "outputs": [],
   "source": [
    "import xgboost as xgb\n",
    "from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, r2_score\n",
    "# Hyper-parameters of the XGBoost regressor, kept together so they are easy to tune.\n",
    "xgb_params = {\n",
    "    'objective': 'reg:squarederror',  # squared-error regression loss\n",
    "    'n_estimators': 100,              # boosting rounds\n",
    "    'learning_rate': 0.1,             # shrinkage applied to each round\n",
    "    'max_depth': 3,                   # depth limit per tree\n",
    "}\n",
    "xgb_reg = xgb.XGBRegressor(**xgb_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cc54adb0-99ef-4463-9346-0394342157b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100, r2: -0.6835652670960299\n"
     ]
    }
   ],
   "source": [
    "# Spot check: fit the XGBoost regressor on one saved CTGAN sample\n",
    "# (the 100 setting, draw 1) and score it on the real test split.\n",
    "atach100 = pd.read_csv(\"gen/ATACH/100/CTGAN/df_1.csv\")\n",
    "df_test = test_data.copy()\n",
    "\n",
    "# Take the dummy columns from train + test together so both are encoded\n",
    "# against the same list of feature columns.\n",
    "df_comb = pd.concat([atach100, df_test])\n",
    "dummy_cols = list(pd.get_dummies(df_comb[x_cols]).columns)\n",
    "\n",
    "df_train_dummy = col_check(pd.get_dummies(atach100[x_cols]), dummy_cols)\n",
    "X_train = df_train_dummy.to_numpy()\n",
    "y_train = atach100[y_col].to_numpy()\n",
    "\n",
    "# Test features come from x_cols only, exactly like the training side.\n",
    "# (They used to be encoded from `cols`, which also contains the target and\n",
    "# relied on col_check to discard it.)\n",
    "df_test_dummy = col_check(pd.get_dummies(df_test[x_cols]), dummy_cols)\n",
    "X_test = df_test_dummy.to_numpy()\n",
    "y_test = df_test[y_col].to_numpy()\n",
    "\n",
    "xgb_reg.fit(X_train, y_train)\n",
    "y_pred = xgb_reg.predict(X_test)\n",
    "# Named after the model that produced it (this was `r2_linear` although it is an XGBoost score).\n",
    "r2_xgb = r2_score(y_test, y_pred)\n",
    "print(f\"100, r2: {r2_xgb}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e10c9a-3872-463a-b816-49d37cf201bf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
