{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9fb39070-d80a-4f0b-bb13-1b0572b1aedc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "parent_dir = os.path.abspath('../..')\n",
    "sys.path.append(parent_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b981d19f-24b6-48aa-b128-b135b85ac1ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, r2_score\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.svm import SVC\n",
    "from xgboost import XGBClassifier\n",
    "from xgboost import XGBRegressor\n",
    "from eval_utils import compare_MLE, data_profiling, cate_info, col_check,  transform_label, _calculate_dcr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "39732ab0-73f0-445c-91e3-fa07d68d5061",
   "metadata": {},
   "outputs": [],
   "source": [
    "def col_check(df, cols):\n",
    "    '''\n",
    "    Align df to exactly the columns in cols, in the order given by cols.\n",
    "\n",
    "    Columns listed in cols but missing from df are added and filled with 0;\n",
    "    columns of df that are not listed in cols are dropped.\n",
    "\n",
    "    The result always follows the order of cols. Callers in this notebook turn\n",
    "    the result into a matrix with .to_numpy(), so two frames aligned against the\n",
    "    same cols must have their columns in identical positions.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Frame to align.\n",
    "    cols : iterable\n",
    "        Target column labels, in the desired order.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        A new frame whose columns are exactly cols; df itself is not modified.\n",
    "    '''\n",
    "    # reindex adds missing labels (filled with 0), drops unlisted ones and\n",
    "    # fixes the column order in a single step.\n",
    "    return df.reindex(columns=list(cols), fill_value=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8cdf35fd-e2a5-4e28-8212-c72b562b2844",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_MLE(X_test, y_test, X_train, y_train, task, seed = 1234):\n",
    "    '''\n",
    "    Fit a fixed set of models on the training split and score them on the test split.\n",
    "\n",
    "    In this notebook the training split is generated (or sampled) data and the\n",
    "    test split is the real test set. Note the argument order: test first, then train.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_test, y_test : array-like\n",
    "        Features and target the fitted models are scored on.\n",
    "    X_train, y_train : array-like\n",
    "        Features and target the models are fitted on.\n",
    "    task : str\n",
    "        Either 'classification' or 'regression'.\n",
    "    seed : int, default 1234\n",
    "        random_state passed to every model that accepts one.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Maps model name to its metrics: 'f1 score' (weighted) and 'accuracy'\n",
    "        for classification, 'R2' for regression.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If task is neither 'classification' nor 'regression'.\n",
    "    '''\n",
    "    # Validate up front so a misspelled task fails loudly instead of silently\n",
    "    # returning an empty result.\n",
    "    if task not in ('classification', 'regression'):\n",
    "        raise ValueError(f\"task must be 'classification' or 'regression', got {task!r}\")\n",
    "\n",
    "    performance = {}\n",
    "\n",
    "    if task == 'classification':\n",
    "        # NOTE(review): the stored outputs of this notebook show lbfgs hitting\n",
    "        # max_iter (ConvergenceWarning) on these features; consider scaling inputs.\n",
    "        classification_model = {\n",
    "            \"Logistic Regression\": LogisticRegression(random_state=seed, max_iter=3000),\n",
    "            \"Random Forest\": RandomForestClassifier(max_depth=3, n_jobs=10, random_state=seed),\n",
    "            \"SVC\": SVC(random_state=seed),\n",
    "            \"XGB\": XGBClassifier(\n",
    "                objective='binary:logistic',  # for binary classification, change for multiclass\n",
    "                n_estimators=100,             # number of boosting rounds\n",
    "                learning_rate=0.1,            # step size shrinkage\n",
    "                max_depth=3,                  # maximum depth of a tree\n",
    "                eval_metric='logloss',        # evaluation metric\n",
    "                random_state=seed,            # seeded like the other models\n",
    "            ),\n",
    "        }\n",
    "        for name, model in classification_model.items():\n",
    "            model.fit(X_train, y_train)\n",
    "            predictions = model.predict(X_test)\n",
    "            performance[name] = {\n",
    "                'f1 score': f1_score(y_test, predictions, average=\"weighted\"),\n",
    "                'accuracy': accuracy_score(y_test, predictions),\n",
    "            }\n",
    "    else:\n",
    "        regression_models = {\n",
    "            \"Linear Regression\": LinearRegression(),\n",
    "            \"Decision Tree\": DecisionTreeRegressor(max_depth=3, random_state=seed),\n",
    "            \"XGB\": XGBRegressor(\n",
    "                objective='reg:squarederror',  # regression with squared loss\n",
    "                n_estimators=100,              # number of boosting rounds\n",
    "                learning_rate=0.1,             # step size shrinkage\n",
    "                max_depth=3,                   # maximum depth of a tree\n",
    "                random_state=seed,             # seeded like the other models\n",
    "            ),\n",
    "        }\n",
    "        for name, model in regression_models.items():\n",
    "            model.fit(X_train, y_train)\n",
    "            y_pred = model.predict(X_test)\n",
    "            performance[name] = {'R2': r2_score(y_test, y_pred)}\n",
    "\n",
    "    return performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3a171fae-5064-4ae2-b74a-8175a3d2579a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def MLE_eval(data_nm, test_path, cols, x_col, y_col, task, model_lst):\n",
    "    '''\n",
    "    Score models trained on generated data against the test set.\n",
    "\n",
    "    For every sample size in (100, 200, 400, 800), every run index 0-4 and every\n",
    "    generator name in model_lst, reads gen/{data_nm}/{num}/{model_nm}/df_{j}.csv,\n",
    "    one-hot encodes the x_col columns of that file and of the test file, and\n",
    "    passes both to compare_MLE.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data_nm : str\n",
    "        Dataset name; also the sub-directory under gen/. For \"Adult\" the target\n",
    "        is binarised to 1 for \">50K\" and 0 otherwise.\n",
    "    test_path : str\n",
    "        CSV file with the test data.\n",
    "    cols : list\n",
    "        Not used by this function; kept so existing calls keep working.\n",
    "    x_col : list\n",
    "        Feature column names.\n",
    "    y_col : str\n",
    "        Target column name.\n",
    "    task : str\n",
    "        Forwarded to compare_MLE.\n",
    "    model_lst : list\n",
    "        Generator names, i.e. sub-directory names such as \"CTGAN\".\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        res[str(num)][str(j)][model_nm] -> result of compare_MLE.\n",
    "    '''\n",
    "    path = f\"gen/{data_nm}/\"\n",
    "    df_test = pd.read_csv(test_path)\n",
    "    num_lst = [100, 200, 400, 800]\n",
    "    res = {}\n",
    "    for num in num_lst:\n",
    "        res[str(num)] = {}\n",
    "        for j in range(5):\n",
    "            res[str(num)][str(j)] = {}\n",
    "            for model_nm in model_lst:\n",
    "                data_path = path + f\"{num}/{model_nm}/df_{j}.csv\"\n",
    "                df_temp = pd.read_csv(data_path)\n",
    "\n",
    "                # Dummy columns are taken from train and test together so both\n",
    "                # matrices share one feature space.\n",
    "                df_comb = pd.concat([df_temp, df_test])\n",
    "                dummy_cols = list(pd.get_dummies(df_comb[x_col]).columns)\n",
    "\n",
    "                # col_check adds/drops columns; selecting [dummy_cols] afterwards\n",
    "                # pins the column ORDER, so column k means the same feature in\n",
    "                # X_train and X_test once converted to arrays.\n",
    "                train_dummy = col_check(pd.get_dummies(df_temp[x_col]), dummy_cols)[dummy_cols]\n",
    "                test_dummy = col_check(pd.get_dummies(df_test[x_col]), dummy_cols)[dummy_cols]\n",
    "\n",
    "                X_train = train_dummy.to_numpy()\n",
    "                X_test = test_dummy.to_numpy()\n",
    "\n",
    "                if data_nm == \"Adult\":\n",
    "                    y_test = df_test[y_col].apply(lambda x: 1 if x == \">50K\" else 0).to_numpy()\n",
    "                    y_train = df_temp[y_col].apply(lambda x: 1 if x == \">50K\" else 0).to_numpy()\n",
    "                else:\n",
    "                    y_test = df_test[y_col].to_numpy()\n",
    "                    y_train = df_temp[y_col].to_numpy()\n",
    "\n",
    "                performance = compare_MLE(X_test, y_test, X_train, y_train, task)\n",
    "\n",
    "                res[str(num)][str(j)][model_nm] = performance\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "15d8e357-fb04-41be-bef3-c36f47ac1702",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mle_summary(res):\n",
    "    '''\n",
    "    Summarise classification results per sample size and generator.\n",
    "\n",
    "    For each seed (run index) the evaluation model with the highest mean of\n",
    "    accuracy and F1 is selected; the selected scores are then averaged over seeds.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    res : dict\n",
    "        Nested as res[num][seed][model][eval_model] -> metrics dict holding\n",
    "        'accuracy' and 'f1 score' (the layout MLE_eval returns). An entry keyed\n",
    "        'baseline' at the seed level is skipped, and so are metric dicts that\n",
    "        lack either metric.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        res_average[num][model] -> dict with 'highest_avg_mean',\n",
    "        'highest_avg_std' and 'eval_model_selected' (seed -> chosen eval_model).\n",
    "    '''\n",
    "    res_average = {}\n",
    "    for num, results in res.items():\n",
    "        # best[model][seed] = (best average so far, eval_model that produced it)\n",
    "        best = {}\n",
    "        for seed, results_by_model in results.items():\n",
    "            if seed == 'baseline':\n",
    "                continue\n",
    "            for model, results_by_eval in results_by_model.items():\n",
    "                for eval_model, metrics in results_by_eval.items():\n",
    "                    accuracy = metrics.get('accuracy', None)\n",
    "                    f1 = metrics.get('f1 score', None)\n",
    "                    if accuracy is None or f1 is None:\n",
    "                        continue\n",
    "                    # Mean of the two metrics (the earlier code averaged F1 with itself).\n",
    "                    avg_value = (accuracy + f1) / 2\n",
    "                    by_seed = best.setdefault(model, {})\n",
    "                    # Strict '>' keeps the first eval_model on ties.\n",
    "                    if seed not in by_seed or avg_value > by_seed[seed][0]:\n",
    "                        by_seed[seed] = (avg_value, eval_model)\n",
    "\n",
    "        res_average[num] = {}\n",
    "        for model, by_seed in best.items():\n",
    "            # Mean and std of the per-seed best values.\n",
    "            highest_avg_values_list = [value for value, _ in by_seed.values()]\n",
    "            res_average[num][model] = {\n",
    "                'highest_avg_mean': np.mean(highest_avg_values_list),\n",
    "                'highest_avg_std': np.std(highest_avg_values_list),\n",
    "                'eval_model_selected': {seed: chosen for seed, (_, chosen) in by_seed.items()}\n",
    "            }\n",
    "\n",
    "    return res_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "187945f6-78f7-46db-b73d-0d6f0c3b62b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mle_summary2(res):\n",
    "    '''\n",
    "    Summarise regression results: best R2 per seed, then mean and std over seeds.\n",
    "\n",
    "    res is nested as res[num][seed][model][eval_model] -> dict holding 'R2'.\n",
    "    An entry keyed 'baseline' at the seed level is left out of the summary, and\n",
    "    so are metric dicts without an 'R2' value.\n",
    "\n",
    "    Returns res_average[num][model] with the keys 'highest_R2_mean',\n",
    "    'highest_R2_std' and 'eval_model_selected' (seed -> eval_model with the best R2).\n",
    "    '''\n",
    "    res_average = {}\n",
    "    for num, per_seed_results in res.items():\n",
    "        # winners[model][seed] = (best R2 so far, eval_model that produced it)\n",
    "        winners = {}\n",
    "        for seed, per_model_results in per_seed_results.items():\n",
    "            if seed == 'baseline':\n",
    "                continue\n",
    "            for model, per_eval_results in per_model_results.items():\n",
    "                for eval_model, metrics in per_eval_results.items():\n",
    "                    r2_value = metrics.get('R2', None)\n",
    "                    if r2_value is None:\n",
    "                        continue\n",
    "                    by_seed = winners.setdefault(model, {})\n",
    "                    # Strict '>' keeps the first eval_model on ties.\n",
    "                    if seed not in by_seed or r2_value > by_seed[seed][0]:\n",
    "                        by_seed[seed] = (r2_value, eval_model)\n",
    "\n",
    "        summary = {}\n",
    "        for model, by_seed in winners.items():\n",
    "            best_scores = [score for score, _ in by_seed.values()]\n",
    "            summary[model] = {\n",
    "                'highest_R2_mean': np.mean(best_scores),\n",
    "                'highest_R2_std': np.std(best_scores),\n",
    "                'eval_model_selected': {seed: picked for seed, (_, picked) in by_seed.items()}\n",
    "            }\n",
    "        res_average[num] = summary\n",
    "\n",
    "    return res_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f88da071-8d1a-4874-a174-c123c9f4199f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mle_baseline(data_nm, cols, x_col, y_col, task):\n",
    "    '''\n",
    "    Baseline scores: train on the sampled data files and score on the test file.\n",
    "\n",
    "    For each size in (100, 200, 400, 800) reads ../sample/{data_nm}/data{size}.csv\n",
    "    as training data, one-hot encodes the x_col columns of it and of\n",
    "    ../sample/{data_nm}/data_test.csv, and passes both to compare_MLE.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data_nm : str\n",
    "        Dataset name; also the sub-directory under ../sample/. For \"Adult\" the\n",
    "        target is binarised to 1 for \">50K\" and 0 otherwise.\n",
    "    cols : list\n",
    "        Not used by this function; kept so existing calls keep working.\n",
    "    x_col : list\n",
    "        Feature column names.\n",
    "    y_col : str\n",
    "        Target column name.\n",
    "    task : str\n",
    "        Forwarded to compare_MLE.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        res[str(size)] -> result of compare_MLE.\n",
    "    '''\n",
    "    num_lst = [100, 200, 400, 800]\n",
    "    path = f\"../sample/{data_nm}/\"\n",
    "    data_test = pd.read_csv(path + \"data_test.csv\")\n",
    "    res = {}\n",
    "    for num in num_lst:\n",
    "        data_temp = pd.read_csv(path + f\"data{num}.csv\")\n",
    "\n",
    "        # Dummy columns are taken from train and test together so both\n",
    "        # matrices share one feature space.\n",
    "        data_comb = pd.concat([data_temp, data_test])\n",
    "        dummy_cols = list(pd.get_dummies(data_comb[x_col]).columns)\n",
    "\n",
    "        # col_check adds/drops columns; selecting [dummy_cols] afterwards pins\n",
    "        # the column ORDER, so column k means the same feature in X_train and\n",
    "        # X_test once converted to arrays.\n",
    "        train_dummy = col_check(pd.get_dummies(data_temp[x_col]), dummy_cols)[dummy_cols]\n",
    "        test_dummy = col_check(pd.get_dummies(data_test[x_col]), dummy_cols)[dummy_cols]\n",
    "        X_train = train_dummy.to_numpy()\n",
    "        X_test = test_dummy.to_numpy()\n",
    "\n",
    "        if data_nm == \"Adult\":\n",
    "            y_test = data_test[y_col].apply(lambda x: 1 if x == \">50K\" else 0).to_numpy()\n",
    "            y_train = data_temp[y_col].apply(lambda x: 1 if x == \">50K\" else 0).to_numpy()\n",
    "        else:\n",
    "            y_test = data_test[y_col].to_numpy()\n",
    "            y_train = data_temp[y_col].to_numpy()\n",
    "\n",
    "        performance = compare_MLE(X_test, y_test, X_train, y_train, task)\n",
    "\n",
    "        res[str(num)] = performance\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db294c45-fcfd-498a-b3da-e7ef796dfffa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_DCR(data_name):\n",
    "    '''\n",
    "    Describe the DCR values of each generator's pooled output.\n",
    "\n",
    "    For each generator the five files gen/{data_name}/{model_nm}/df_0..4.csv are\n",
    "    concatenated and passed, together with the reference data and its detected\n",
    "    metadata, to _calculate_dcr; the 'DCR' column of the result is summarised\n",
    "    with describe().\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data_name : str\n",
    "        Dataset name; sub-directory under ../sample/ and gen/.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Generator name -> pandas Series returned by describe().\n",
    "    '''\n",
    "    res = {}\n",
    "    path_real = f\"../sample/{data_name}/data100.csv\"\n",
    "    path = f\"gen/{data_name}/\"\n",
    "    # NOTE(review): MLE_eval reads gen/{data}/{num}/{model}/...; this path has no\n",
    "    # {num} level - confirm the directory layout.\n",
    "\n",
    "    # Reference data the generated rows are compared against.\n",
    "    # NOTE(review): loaded from path_real, which was defined but never read -\n",
    "    # confirm this is the intended reference set.\n",
    "    test_data = pd.read_csv(path_real)\n",
    "\n",
    "    # NOTE(review): SingleTableMetadata is not imported in the visible cells\n",
    "    # (presumably sdv.metadata) - it must be in scope before calling this.\n",
    "    meta_data = SingleTableMetadata()\n",
    "    meta_data.detect_from_dataframe(test_data)\n",
    "    meta_data = meta_data.to_dict()\n",
    "\n",
    "    model_nm_lst = [\"CTGAN\", \"TVAE\", \"BeGReaT\", \"tabddpm\"]\n",
    "    for model_nm in model_nm_lst:  # loop over generators\n",
    "        df_lst = [pd.read_csv(path + f\"{model_nm}/df_{j}.csv\") for j in range(5)]\n",
    "        df_comb = pd.concat(df_lst)\n",
    "        df_comb.reset_index(inplace=True, drop=True)\n",
    "        dcr_ = _calculate_dcr(df_comb, test_data, meta_data)\n",
    "        # Key by generator name so each generator keeps its own summary.\n",
    "        res[model_nm] = dcr_[\"DCR\"].describe()\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c3eb8d24-9cb6-4dbb-88e6-3892db78b592",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generator names passed to the evaluation helpers as model_lst.\n",
    "model_lst1 = [\n",
    "    \"CTGAN\",\n",
    "    \"TVAE\",\n",
    "    \"BeGReaT\",\n",
    "    \"tabddpm\",\n",
    "    \"SMOTE\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcde1bd5-337c-475c-bc60-bd21e7d46715",
   "metadata": {},
   "source": [
    "# Adult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "dc4dec11-fea0-4225-a008-d2d4ca84ad93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset name and the path of its test CSV.\n",
    "data_nm = \"Adult\"\n",
    "test_path = \"../sample/Adult/data_test.csv\"\n",
    "\n",
    "# Feature columns, followed by the target column.\n",
    "x_col = [\n",
    "    'age', 'workclass', 'education', 'education-num', 'marital-status',\n",
    "    'occupation', 'relationship', 'race', 'sex', 'capital-gain',\n",
    "    'capital-loss', 'hours-per-week', 'native-country',\n",
    "]\n",
    "y_col = 'Income'\n",
    "\n",
    "# Columns listed as categorical (target included); no boolean columns.\n",
    "cate_var = [\n",
    "    'workclass', 'education', 'marital-status', 'occupation',\n",
    "    'relationship', 'race', 'sex', 'native-country', 'Income',\n",
    "]\n",
    "bool_var = []\n",
    "\n",
    "# Every column in use: features first, target last.\n",
    "cols = [*x_col, y_col]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9eae3945-3ad9-45ee-b7ab-cbfb1b390142",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    }
   ],
   "source": [
    "# Run the classification MLE evaluation with the explicit model list `model_lst1`.\n",
    "# NOTE(review): this call passes `x_col`, while the Insurance/Asia sections use `x_cols` -- confirm the intended name.\n",
    "res_adult1 = MLE_eval(data_nm,test_path, cols, x_col, y_col, \"classification\", model_lst1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "98efa5f2-9454-43c2-a333-02967aeb33e7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'CTGAN': {'highest_avg_mean': 0.6636301352363263,\n",
       "   'highest_avg_std': 0.05940435968931609,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'XGB',\n",
       "    '4': 'Random Forest'}},\n",
       "  'TVAE': {'highest_avg_mean': 0.6677050513542302,\n",
       "   'highest_avg_std': 0.05440051213022549,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'XGB'}},\n",
       "  'BeGReaT': {'highest_avg_mean': 0.7116196618043663,\n",
       "   'highest_avg_std': 0.0262331105716064,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'XGB',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'tabddpm': {'highest_avg_mean': 0.7574216839268793,\n",
       "   'highest_avg_std': 0.021615070930736467,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}},\n",
       "  'SMOTE': {'highest_avg_mean': 0.7802230856871496,\n",
       "   'highest_avg_std': 0.01167121043900469,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'XGB',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'XGB'}}},\n",
       " '200': {'CTGAN': {'highest_avg_mean': 0.627527034625771,\n",
       "   'highest_avg_std': 0.03333583919576956,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'SVC'}},\n",
       "  'TVAE': {'highest_avg_mean': 0.6679672713678595,\n",
       "   'highest_avg_std': 0.04622199128184658,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'XGB',\n",
       "    '3': 'Random Forest',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'BeGReaT': {'highest_avg_mean': 0.7241501526999029,\n",
       "   'highest_avg_std': 0.05926795456563485,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'XGB',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'tabddpm': {'highest_avg_mean': 0.6014836552362094,\n",
       "   'highest_avg_std': 0.14777697310715177,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'XGB',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'XGB'}},\n",
       "  'SMOTE': {'highest_avg_mean': 0.7774130185812606,\n",
       "   'highest_avg_std': 0.03590879297088282,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '400': {'CTGAN': {'highest_avg_mean': 0.6306515037725602,\n",
       "   'highest_avg_std': 0.020715240425535124,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}},\n",
       "  'TVAE': {'highest_avg_mean': 0.706297012542217,\n",
       "   'highest_avg_std': 0.06550919835724806,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Random Forest'}},\n",
       "  'BeGReaT': {'highest_avg_mean': 0.7922234675395542,\n",
       "   'highest_avg_std': 0.044296791798715804,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'XGB',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'SVC'}},\n",
       "  'tabddpm': {'highest_avg_mean': 0.8171766840984158,\n",
       "   'highest_avg_std': 0.031165572626036434,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'XGB'}},\n",
       "  'SMOTE': {'highest_avg_mean': 0.8532543971287317,\n",
       "   'highest_avg_std': 0.0276483243551322,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '800': {'CTGAN': {'highest_avg_mean': 0.6378158886221824,\n",
       "   'highest_avg_std': 0.046793515232267435,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'XGB',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Random Forest'}},\n",
       "  'TVAE': {'highest_avg_mean': 0.7730885713373222,\n",
       "   'highest_avg_std': 0.019830500475346142,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'BeGReaT': {'highest_avg_mean': 0.7477855670283354,\n",
       "   'highest_avg_std': 0.06547682145004044,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'XGB',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'XGB'}},\n",
       "  'tabddpm': {'highest_avg_mean': 0.7008189431125961,\n",
       "   'highest_avg_std': 0.03420424922878394,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'SMOTE': {'highest_avg_mean': 0.7149212686853039,\n",
       "   'highest_avg_std': 0.03086047492725985,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'SVC',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'XGB',\n",
       "    '4': 'SVC'}}}}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summarise `res_adult1`; the saved output is nested by sample size, then by generator.\n",
    "mle_summary(res_adult1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3f1cb66-3400-4fc9-bc00-9f836860ea0c",
   "metadata": {},
   "source": [
    "# Insurance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "aaeb1f3d-e421-4b1b-af08-c73c075bbe96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration for the Insurance dataset (regression target: `charges`).\n",
    "data_nm = \"Insurance\"\n",
    "test_path = \"../sample/Insurance/data_test.csv\"\n",
    "x_cols = ['age', 'sex', 'bmi', 'children', 'smoker', 'region']\n",
    "y_col = 'charges'\n",
    "bool_var = []\n",
    "cols = x_cols + [y_col]\n",
    "cate_var = ['sex', 'smoker', 'region']\n",
    "# Keep the numeric columns in their `cols` order; a set difference would\n",
    "# yield an arbitrary ordering that can change from run to run.\n",
    "num_var = [col for col in cols if col not in cate_var]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "8ce8587f-4d73-46e8-8d55-e9df5391a524",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regression MLE evaluation for Insurance; no model list is passed, so MLE_eval's default applies.\n",
    "res_ins = MLE_eval(data_nm, test_path ,  cols, x_cols, y_col , \"regression\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "1ddae9a2-8182-46e5-b393-52f1e191f92e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'CTGAN': {'highest_R2_mean': -0.09040735565092328,\n",
       "   'highest_R2_std': 0.10930127629571845,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Linear Regression',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'Linear Regression'}},\n",
       "  'TVAE': {'highest_R2_mean': 0.39171860150248294,\n",
       "   'highest_R2_std': 0.14641499346673362,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Linear Regression',\n",
       "    '2': 'XGB',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'Linear Regression'}},\n",
       "  'BeGReaT': {'highest_R2_mean': 0.5418979468246582,\n",
       "   'highest_R2_std': 0.09707917457767164,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Linear Regression',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Linear Regression'}},\n",
       "  'tabddpm': {'highest_R2_mean': 0.7624303742190909,\n",
       "   'highest_R2_std': 0.03129635083178132,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '200': {'CTGAN': {'highest_R2_mean': -0.12193555230703738,\n",
       "   'highest_R2_std': 0.08413994844750666,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'XGB',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Decision Tree'}},\n",
       "  'TVAE': {'highest_R2_mean': 0.6152291497777626,\n",
       "   'highest_R2_std': 0.05118837012917515,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Linear Regression',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'XGB'}},\n",
       "  'BeGReaT': {'highest_R2_mean': 0.7194260161421691,\n",
       "   'highest_R2_std': 0.02754438574230846,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Linear Regression',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'Decision Tree'}},\n",
       "  'tabddpm': {'highest_R2_mean': 0.5557095359204078,\n",
       "   'highest_R2_std': 0.1368112962065708,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '400': {'CTGAN': {'highest_R2_mean': -0.18311718025972729,\n",
       "   'highest_R2_std': 0.10143256235415325,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Linear Regression'}},\n",
       "  'TVAE': {'highest_R2_mean': 0.6202113647015879,\n",
       "   'highest_R2_std': 0.05028388670900579,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Linear Regression',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'XGB'}},\n",
       "  'BeGReaT': {'highest_R2_mean': 0.7207394496033093,\n",
       "   'highest_R2_std': 0.029821808891411165,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'XGB'}},\n",
       "  'tabddpm': {'highest_R2_mean': 0.792891512443146,\n",
       "   'highest_R2_std': 0.033526650523678314,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '800': {'CTGAN': {'highest_R2_mean': -0.4135677511072884,\n",
       "   'highest_R2_std': 0.06487463749053196,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Linear Regression',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Linear Regression'}},\n",
       "  'TVAE': {'highest_R2_mean': 0.6791944692461079,\n",
       "   'highest_R2_std': 0.01025513881982885,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Linear Regression'}},\n",
       "  'BeGReaT': {'highest_R2_mean': 0.5264043214403559,\n",
       "   'highest_R2_std': 0.20754380408714018,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'XGB',\n",
       "    '4': 'Decision Tree'}},\n",
       "  'tabddpm': {'highest_R2_mean': 0.8333332139648313,\n",
       "   'highest_R2_std': 0.0068136764791930865,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}}}"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summarise the Insurance regression results; the saved output reports R2 mean/std per sample size and generator.\n",
    "mle_summary2(res_ins)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "c93494a9-0f2c-4747-aa9e-e06d2dd59902",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat the Insurance regression evaluation with the explicit model list `model_lst2`.\n",
    "# NOTE(review): `model_lst2` is defined outside this section -- confirm its contents before re-running.\n",
    "res_ins2 = MLE_eval(data_nm, test_path ,  cols, x_cols, y_col , \"regression\", model_lst2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "ff10445d-3f12-4e93-a20e-8ff6f1e83a78",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'SMOTE': {'highest_R2_mean': 0.7974248604927965,\n",
       "   'highest_R2_std': 0.006631692274423577,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '200': {'SMOTE': {'highest_R2_mean': 0.7930401456550837,\n",
       "   'highest_R2_std': 0.01595218049504417,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '400': {'SMOTE': {'highest_R2_mean': 0.8279899989817192,\n",
       "   'highest_R2_std': 0.003430451255819587,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '800': {'SMOTE': {'highest_R2_mean': 0.8320564040674512,\n",
       "   'highest_R2_std': 0.004563220114489803,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}}}"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summarise the `model_lst2` run of the Insurance regression evaluation.\n",
    "mle_summary2(res_ins2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "1873767b-ae29-4764-b45a-68ba973377d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: fit the evaluation models on each Insurance training sample\n",
    "# (data100/200/400/800.csv) and score them on the held-out test file.\n",
    "baseline_res = {}\n",
    "num_lst = [100, 200, 400, 800]\n",
    "data_test = pd.read_csv(\"../sample/Insurance/data_test.csv\")\n",
    "y_test = data_test[y_col].to_numpy()\n",
    "for num in num_lst:\n",
    "    path = f\"../sample/Insurance/data{num}.csv\"\n",
    "    df_train = pd.read_csv(path)\n",
    "\n",
    "    # Derive the one-hot layout from the predictors only (train + test\n",
    "    # combined) so every category level gets a column and the target\n",
    "    # `y_col` is never injected as an all-zero feature.\n",
    "    # dtype=float keeps to_numpy() purely numeric (pandas >= 2.0 would\n",
    "    # otherwise emit bool dummy columns).\n",
    "    df_comb = pd.concat([df_train, data_test])\n",
    "    dummy_cols = pd.get_dummies(df_comb[x_cols], dtype=float)\n",
    "    feature_cols = dummy_cols.columns\n",
    "\n",
    "    # reindex() fills levels missing from a split with 0 AND enforces one\n",
    "    # shared column order, so the positional arrays below line up.\n",
    "    train_dummy = pd.get_dummies(df_train[x_cols], dtype=float)\n",
    "    train_dummy = train_dummy.reindex(columns=feature_cols, fill_value=0)\n",
    "    X_train = train_dummy.to_numpy()\n",
    "    y_train = df_train[y_col].to_numpy()\n",
    "\n",
    "    test_dummy = pd.get_dummies(data_test[x_cols], dtype=float)\n",
    "    test_dummy = test_dummy.reindex(columns=feature_cols, fill_value=0)\n",
    "    X_test = test_dummy.to_numpy()\n",
    "\n",
    "    baseline = compare_MLE(X_test, y_test, X_train, y_train, task='regression')\n",
    "    baseline_res[str(num)] = baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "22ce9e80-7d21-4484-8ab7-a8d47d6b1cf0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'Linear Regression': {'R2': 0.7064976270675443},\n",
       "  'Decision Tree': {'R2': 0.7003012919982186},\n",
       "  'XGB': {'R2': 0.8154059881463351}},\n",
       " '200': {'Linear Regression': {'R2': 0.7175687268136394},\n",
       "  'Decision Tree': {'R2': 0.8126264414312482},\n",
       "  'XGB': {'R2': 0.8292661457672201}},\n",
       " '400': {'Linear Regression': {'R2': 0.7152057244067371},\n",
       "  'Decision Tree': {'R2': 0.8102072861693204},\n",
       "  'XGB': {'R2': 0.8462928576316135}},\n",
       " '800': {'Linear Regression': {'R2': 0.7213844234241604},\n",
       "  'Decision Tree': {'R2': 0.8151842843538009},\n",
       "  'XGB': {'R2': 0.853841321948853}}}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "baseline_res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05194b2c-a6eb-4b8f-89fe-3d3b99d3afec",
   "metadata": {},
   "source": [
    "# Asia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "2a6cf5a9-2d50-4874-82ec-63c926b8f430",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration for the Asia dataset (classification target: `Dyspnea`).\n",
    "data_nm = \"Asia\"\n",
    "test_path = \"../sample/Asia/data_test.csv\"\n",
    "x_cols = [\"visit to Asia\", \"tuberculosis\", \"smoke\", \"lung cancer\", \"bronchitis\", \"either tuberculosis or lung cancer\", \"chest X-ray\"]\n",
    "y_col = \"Dyspnea\"\n",
    "cols = x_cols + [y_col]\n",
    "num_var = []\n",
    "cate_var = []\n",
    "# Every column is treated as boolean; copy the list so that mutating\n",
    "# `bool_var` in place can never silently change `cols` (or vice versa).\n",
    "bool_var = list(cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "7293b2da-bdc8-4ae3-a24e-69c9105839a6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'Logistic Regression': {'f1 score': 0.8261663985290738,\n",
       "   'accuracy': 0.825},\n",
       "  'Random Forest': {'f1 score': 0.826323495709302, 'accuracy': 0.825},\n",
       "  'SVC': {'f1 score': 0.8313487946292341, 'accuracy': 0.83},\n",
       "  'XGB': {'f1 score': 0.8313487946292341, 'accuracy': 0.83}},\n",
       " '200': {'Logistic Regression': {'f1 score': 0.8261663985290738,\n",
       "   'accuracy': 0.825},\n",
       "  'Random Forest': {'f1 score': 0.8261663985290738, 'accuracy': 0.825},\n",
       "  'SVC': {'f1 score': 0.826323495709302, 'accuracy': 0.825},\n",
       "  'XGB': {'f1 score': 0.8261663985290738, 'accuracy': 0.825}},\n",
       " '400': {'Logistic Regression': {'f1 score': 0.8363619047619048,\n",
       "   'accuracy': 0.835},\n",
       "  'Random Forest': {'f1 score': 0.8363619047619048, 'accuracy': 0.835},\n",
       "  'SVC': {'f1 score': 0.8414252454701893, 'accuracy': 0.84},\n",
       "  'XGB': {'f1 score': 0.8363619047619048, 'accuracy': 0.835}},\n",
       " '800': {'Logistic Regression': {'f1 score': 0.8363619047619048,\n",
       "   'accuracy': 0.835},\n",
       "  'Random Forest': {'f1 score': 0.8413636363636364, 'accuracy': 0.84},\n",
       "  'SVC': {'f1 score': 0.8364421980694687, 'accuracy': 0.835},\n",
       "  'XGB': {'f1 score': 0.8363619047619048, 'accuracy': 0.835}}}"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display mle_baseline's result for the Asia classification task (the value is shown, not assigned).\n",
    "mle_baseline(data_nm, cols, x_cols, y_col, \"classification\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "fb6039ee-3fd2-49b8-b3cf-8a62fe95fef3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classification MLE evaluation for Asia; no model list is passed, so MLE_eval's default applies.\n",
    "res_asia = MLE_eval(data_nm,test_path, cols, x_cols, y_col, \"classification\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "2ad40cce-62d9-48ab-9f32-481d68c764ed",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'CTGAN': {'highest_avg_mean': 0.6119309716183847,\n",
       "   'highest_avg_std': 0.18716277531719197,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'XGB',\n",
       "    '4': 'SVC'}},\n",
       "  'TVAE': {'highest_avg_mean': 0.8353217066937431,\n",
       "   'highest_avg_std': 0.007475277813792156,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'SVC',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'BeGReaT': {'highest_avg_mean': 0.8393999091868756,\n",
       "   'highest_avg_std': 0.0024805202258318857,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'SVC',\n",
       "    '2': 'Random Forest',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Random Forest'}}},\n",
       " '200': {'CTGAN': {'highest_avg_mean': 0.7189452092291161,\n",
       "   'highest_avg_std': 0.08670464636445172,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'XGB',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'TVAE': {'highest_avg_mean': 0.8322699373055199,\n",
       "   'highest_avg_std': 0.007475277813792156,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Logistic Regression'}},\n",
       "  'BeGReaT': {'highest_avg_mean': 0.8313553376451963,\n",
       "   'highest_avg_std': 0.006401497389792805,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'SVC',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Logistic Regression'}}},\n",
       " '400': {'CTGAN': {'highest_avg_mean': 0.5917414364459938,\n",
       "   'highest_avg_std': 0.16774813711353181,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'SVC',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Random Forest',\n",
       "    '4': 'SVC'}},\n",
       "  'TVAE': {'highest_avg_mean': 0.8374552821319321,\n",
       "   'highest_avg_std': 0.005764527314681893,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'Random Forest',\n",
       "    '3': 'SVC',\n",
       "    '4': 'SVC'}},\n",
       "  'BeGReaT': {'highest_avg_mean': 0.834284394909465,\n",
       "   'highest_avg_std': 0.0068753341269205456,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'Random Forest',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Logistic Regression'}}},\n",
       " '800': {'CTGAN': {'highest_avg_mean': 0.4932351549546354,\n",
       "   'highest_avg_std': 0.03525705440719728,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'SVC',\n",
       "    '2': 'SVC',\n",
       "    '3': 'XGB',\n",
       "    '4': 'Random Forest'}},\n",
       "  'TVAE': {'highest_avg_mean': 0.8322699373055201,\n",
       "   'highest_avg_std': 0.007475277813792156,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Random Forest'}},\n",
       "  'BeGReaT': {'highest_avg_mean': 0.8333795667328298,\n",
       "   'highest_avg_std': 0.006792338862645273,\n",
       "   'eval_model_selected': {'0': 'Random Forest',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'SVC',\n",
       "    '3': 'SVC',\n",
       "    '4': 'Random Forest'}}}}"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mle_summary(res_asia)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "387fdb06-46d1-4d35-9e64-12408157b1e4",
   "metadata": {},
   "source": [
    "# ATACH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "32771387-3e5d-4fad-a7db-a16424ae6089",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_nm = \"ATACH\"\n",
    "test_path = \"../sample/ATACH/data_test.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "742f092f-121f-4d22-b497-af3bf3e837d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature columns handed to MLE_eval as x_cols (order kept as-is).\n",
    "x_cols = [\n",
    "    'treatment', 'age', 'ICH volume', 'ICH Location', 'IVH volume',\n",
    "    'GCS score', 'NIHSS score', 'Systolic blood pressure',\n",
    "    'Diastolic Blood Pressure', 'Hypertension', 'Hyperlipidemia',\n",
    "    'Type I Diabetes', 'Type II Diabetes', 'Congestive heart failure',\n",
    "    'Atrial Fibrillation', 'PTCA', 'Peripheral Vascular Disease',\n",
    "    'Myocardial fraction', 'Anti-diabetic', 'Antihypertensives',\n",
    "    'White blood count', 'Hemoglobin', 'Hematocrit', 'Platelet count',\n",
    "    'APTT', 'INR', 'Glucose', 'Sodium', 'Potassium', 'Chloride', 'CD',\n",
    "    'Blood urea nitrogen', 'Creatinine', 'race', 'sex', 'ethnicity',\n",
    "]\n",
    "# Target column handed to MLE_eval as y_col.\n",
    "y_col = 'mRS score after 30 days'\n",
    "\n",
    "# Every column of interest: features followed by the target.\n",
    "cols = [*x_cols, y_col]\n",
    "cate_var = ['race', 'ICH Location', 'sex', 'ethnicity']\n",
    "bool_var = [\n",
    "    'treatment', 'Hypertension', 'Hyperlipidemia', 'Type I Diabetes',\n",
    "    'Type II Diabetes', 'Congestive heart failure', 'Atrial Fibrillation',\n",
    "    'PTCA', 'Peripheral Vascular Disease', 'Myocardial fraction',\n",
    "    'Anti-diabetic', 'Antihypertensives',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "5c5f69a0-0aea-459a-b853-a25409e096d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_atach = MLE_eval(data_nm,test_path, cols, x_cols, y_col, \"regression\", model_lst2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "61c5da05-3242-4da6-9df9-87342fd849b7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'SMOTE': {'highest_R2_mean': 0.2707921630715144,\n",
       "   'highest_R2_std': 0.027694076993970862,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '200': {'SMOTE': {'highest_R2_mean': 0.305717950558993,\n",
       "   'highest_R2_std': 0.038256652672649015,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '400': {'SMOTE': {'highest_R2_mean': 0.3152947215546486,\n",
       "   'highest_R2_std': 0.021012623832375156,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '800': {'SMOTE': {'highest_R2_mean': 0.37128263331046407,\n",
       "   'highest_R2_std': 0.03093936082767671,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}}}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mle_summary2(res_atach)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26d48220-9f7b-4839-8fde-64f4e68fa1fd",
   "metadata": {},
   "source": [
    "# ERICH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "072c2615-6423-4b71-b858-fae44fb5d187",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_nm = \"ERICH\"\n",
    "test_path = \"../sample/ERICH/data_test.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e90dfc63-0f55-4a46-be44-96d80c7f14cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature columns for the ERICH runs (order kept as-is).\n",
    "x_cols = [\n",
    "    'sex', 'age', 'race', 'ICH volume', 'IVH volume', 'ICH location',\n",
    "    'GCS score', 'Systolic blood pressure', 'Diastolic blood pressure',\n",
    "    'Hypertension', 'Hyperlipidemia', 'Type 2 diabetes', 'Type 1 diabetes',\n",
    "    'Atrial Fibrillation', 'Congestive Heart Failure', 'Vascular disease',\n",
    "    'Cocaine use', 'Smoke', 'PTCA', 'Myocardial infraction',\n",
    "    'anti-diabetic', 'Anti-hypertensives', 'White blood cell count',\n",
    "    'Hemoglobin', 'Platelet count', 'APTT', 'INR', 'Glucose',\n",
    "]\n",
    "# Target column for the ERICH runs.\n",
    "y_col = 'mRS score after 90 days'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "354a6ab7-4e8f-4013-b333-8990bed2b111",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every column of interest: features followed by the target.\n",
    "cols = [*x_cols, y_col]\n",
    "cate_var = ['race', 'sex', 'ICH location']\n",
    "bool_var = [\n",
    "    'Hypertension', 'Hyperlipidemia', 'Type 2 diabetes', 'Type 1 diabetes',\n",
    "    'Atrial Fibrillation', 'Congestive Heart Failure', 'Vascular disease',\n",
    "    'Cocaine use', 'Smoke', 'PTCA', 'Myocardial infraction',\n",
    "    'anti-diabetic', 'Anti-hypertensives',\n",
    "]\n",
    "# Whatever is neither in cate_var nor in bool_var (this includes y_col).\n",
    "num_var = list(set(cols) - set(cate_var) - set(bool_var))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "3379cad8-955d-417e-9ae0-489731973bc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_erich2 = MLE_eval(data_nm,test_path, cols, x_cols, y_col, \"regression\", model_lst2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b719563e-ea97-4d43-adcd-a435383193c8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n",
      "/tmp/ipykernel_3737/328479900.py:5: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
      "  df[col] = [0] * len(df)\n"
     ]
    }
   ],
   "source": [
    "res_erich = MLE_eval(data_nm,test_path, cols, x_cols, y_col, \"regression\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "da967ec1-e397-41fa-a9e3-72a78318ea0a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'SMOTE': {'highest_R2_mean': -0.1494940462036806,\n",
       "   'highest_R2_std': 0.1268203716653973,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '200': {'SMOTE': {'highest_R2_mean': 0.04772545666001553,\n",
       "   'highest_R2_std': 0.0574216792233605,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '400': {'SMOTE': {'highest_R2_mean': 0.06908277981126347,\n",
       "   'highest_R2_std': 0.04891572715592982,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '800': {'SMOTE': {'highest_R2_mean': 0.10488877520197322,\n",
       "   'highest_R2_std': 0.05034555332614683,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}}}"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mle_summary2(res_erich2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "028ad341-6f9e-48c8-83be-36d706a0e81d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: run compare_MLE (regression task) with each ERICH sample file\n",
    "# (data{num}.csv) as training data and data_test.csv as test data.\n",
    "# Results are keyed by the sample size as a string.\n",
    "baseline_res = {}\n",
    "num_lst = [100, 200, 400, 800]\n",
    "data_test = pd.read_csv(\"../sample/ERICH/data_test.csv\")\n",
    "for num in num_lst:\n",
    "    df_train = pd.read_csv(f\"../sample/ERICH/data{num}.csv\")\n",
    "\n",
    "    # Encode train and test together so every category level seen in either\n",
    "    # split gets a column. Only x_cols are encoded: encoding cols would put\n",
    "    # y_col into the reference columns and add an all-zero y_col feature to\n",
    "    # both matrices.\n",
    "    df_comb = pd.concat([df_train, data_test])\n",
    "    dummy_cols = pd.get_dummies(df_comb[x_cols]).columns\n",
    "\n",
    "    # reindex pins one shared column order for both splits. col_check appends\n",
    "    # missing columns at the end, so train and test could disagree on feature\n",
    "    # positions once converted with to_numpy().\n",
    "    train_dummy = pd.get_dummies(df_train[x_cols]).reindex(columns=dummy_cols, fill_value=0)\n",
    "    # dtype=float keeps the matrix numeric even when dummy columns are bool.\n",
    "    X_train = train_dummy.to_numpy(dtype=float)\n",
    "    y_train = df_train[y_col].to_numpy()\n",
    "\n",
    "    test_dummy = pd.get_dummies(data_test[x_cols]).reindex(columns=dummy_cols, fill_value=0)\n",
    "    X_test = test_dummy.to_numpy(dtype=float)\n",
    "    y_test = data_test[y_col].to_numpy()\n",
    "\n",
    "    baseline = compare_MLE(X_test, y_test, X_train, y_train, task='regression')\n",
    "    baseline_res[str(num)] = baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "74fdca93-7a68-4a33-a40f-aa9864392e64",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'Linear Regression': {'R2': -2.8558346910998287},\n",
       "  'Decision Tree': {'R2': -0.26026096993894776},\n",
       "  'XGB': {'R2': -0.048671900539349666}},\n",
       " '200': {'Linear Regression': {'R2': -0.7334412855686576},\n",
       "  'Decision Tree': {'R2': 0.04463888431831975},\n",
       "  'XGB': {'R2': 0.15847891351135623}},\n",
       " '400': {'Linear Regression': {'R2': -0.29213022574775427},\n",
       "  'Decision Tree': {'R2': -0.03718086475927773},\n",
       "  'XGB': {'R2': 0.1824214939694706}},\n",
       " '800': {'Linear Regression': {'R2': -0.1482094797229898},\n",
       "  'Decision Tree': {'R2': 0.07243588170166804},\n",
       "  'XGB': {'R2': 0.2141180395139447}}}"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "baseline_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b736defc-ba52-419e-9ac8-2696d0867209",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
