{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e1918997-302a-4b11-8608-96825e9df970",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, r2_score\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.svm import SVC\n",
    "from xgboost import XGBClassifier\n",
    "from xgboost import XGBRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4056b0f8-07fd-48af-9b8a-b8fbd3d714f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def col_check(df, cols):\n",
    "    '''\n",
    "    Align the columns of ``df`` to exactly ``cols``, in that order.\n",
    "\n",
    "    Columns listed in ``cols`` but absent from ``df`` are added and filled with 0;\n",
    "    columns of ``df`` not listed in ``cols`` are dropped. Reindexing to ``cols``\n",
    "    (instead of appending the missing columns at the end, in set order) guarantees\n",
    "    that two frames aligned against the same ``cols`` yield matrices whose columns\n",
    "    line up position by position after ``.to_numpy()``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Frame to align, e.g. the one-hot encoded features of one split.\n",
    "    cols : iterable of column labels\n",
    "        Target column labels, in the desired order.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        A new frame whose columns are exactly ``list(cols)``; ``df`` itself is not modified.\n",
    "    '''\n",
    "    return df.reindex(columns=list(cols), fill_value=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9a900903-5cd9-45c9-acfc-119c29daf268",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_MLE(X_test, y_test, X_train, y_train, task, seed = 1234):\n",
    "    '''\n",
    "    Machine-learning efficacy (MLE): fit a fixed panel of models on the training\n",
    "    split (the synthetic data) and score every model on the test split (the real data).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_test, y_test : array-like\n",
    "        Real held-out features / targets, used only for scoring.\n",
    "    X_train, y_train : array-like\n",
    "        Synthetic features / targets, used only for fitting.\n",
    "    task : {'classification', 'regression'}\n",
    "        Selects the model panel and the metrics.\n",
    "    seed : int, default 1234\n",
    "        ``random_state`` for every model that accepts one.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        ``{model_name: {metric: value}}``. Classification reports ``'f1 score'``\n",
    "        (weighted F1) and ``'accuracy'``; regression reports ``'R2'``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``task`` is neither ``'classification'`` nor ``'regression'``\n",
    "        (an unknown task used to silently return an empty dict).\n",
    "    '''\n",
    "    # Validate before building any model so a typo fails fast and loudly.\n",
    "    if task not in ('classification', 'regression'):\n",
    "        raise ValueError(f\"unknown task {task!r}; expected 'classification' or 'regression'\")\n",
    "\n",
    "    performance = {}\n",
    "    if task == 'classification':\n",
    "        models = {\n",
    "            'Logistic Regression': LogisticRegression(random_state=seed, max_iter=3000),\n",
    "            'Random Forest': RandomForestClassifier(max_depth=3, n_jobs=10, random_state=seed),\n",
    "            'SVC': SVC(random_state=seed),\n",
    "            'XGB': XGBClassifier(\n",
    "                # NOTE(review): recent XGBoost versions switch to a multi-class objective on their\n",
    "                # own when y has more than two classes -- confirm for the installed version.\n",
    "                objective='binary:logistic',\n",
    "                n_estimators=100,       # number of boosting rounds\n",
    "                learning_rate=0.1,      # step size shrinkage\n",
    "                max_depth=3,            # maximum depth of a tree\n",
    "                eval_metric='logloss',  # evaluation metric\n",
    "                random_state=seed,      # seeded like the other models\n",
    "            ),\n",
    "        }\n",
    "        for name, model in models.items():\n",
    "            model.fit(X_train, y_train)\n",
    "            predictions = model.predict(X_test)\n",
    "            performance[name] = {\n",
    "                'f1 score': f1_score(y_test, predictions, average='weighted'),\n",
    "                'accuracy': accuracy_score(y_test, predictions),\n",
    "            }\n",
    "    else:\n",
    "        models = {\n",
    "            'Linear Regression': LinearRegression(),\n",
    "            'Decision Tree': DecisionTreeRegressor(max_depth=3, random_state=seed),\n",
    "            'XGB': XGBRegressor(\n",
    "                objective='reg:squarederror',  # regression with squared loss\n",
    "                n_estimators=100,              # number of boosting rounds\n",
    "                learning_rate=0.1,             # step size shrinkage\n",
    "                max_depth=3,                   # maximum depth of a tree\n",
    "                random_state=seed,             # seeded like the other models\n",
    "            ),\n",
    "        }\n",
    "        for name, model in models.items():\n",
    "            model.fit(X_train, y_train)\n",
    "            performance[name] = {'R2': r2_score(y_test, model.predict(X_test))}\n",
    "\n",
    "    return performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "aee81c9b-65ac-4d58-b9fc-3b16f331edf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _encode_target(y, data_nm):\n",
    "    '''Return the target column as a numpy array; for Adult, map '>50K' to 1 and everything else to 0.'''\n",
    "    if data_nm == 'adult':\n",
    "        # NOTE(review): some Adult test files spell the label '>50K.' (trailing dot), which maps to 0 here -- confirm.\n",
    "        return (y == '>50K').astype(int).to_numpy()\n",
    "    return y.to_numpy()\n",
    "\n",
    "\n",
    "def _one_hot_aligned(train_x, test_x):\n",
    "    '''One-hot encode train/test features onto one shared, identically ordered set of dummy columns; return numpy arrays.'''\n",
    "    # Dummy columns come from both splits, so a category seen in only one split still gets a column.\n",
    "    dummy_cols = list(pd.get_dummies(pd.concat([train_x, test_x])).columns)\n",
    "    # col_check adds missing dummies (as 0) and drops extras; indexing with [dummy_cols]\n",
    "    # then pins the column order so train and test matrices line up position by position.\n",
    "    X_train = col_check(pd.get_dummies(train_x), dummy_cols)[dummy_cols].to_numpy()\n",
    "    X_test = col_check(pd.get_dummies(test_x), dummy_cols)[dummy_cols].to_numpy()\n",
    "    return X_train, X_test\n",
    "\n",
    "\n",
    "def MLE_eval(data_nm, test_path, cols, x_col, y_col, task,\n",
    "             sizes=(100, 200, 400, 800), n_reps=5, method_name='mallm'):\n",
    "    '''\n",
    "    Run the MLE benchmark over every synthetic training file of one dataset.\n",
    "\n",
    "    For each size ``num`` in ``sizes`` and replicate ``j`` in ``range(n_reps)`` this reads\n",
    "    ``gen/{data_nm}/{num}/df_{j}.csv``, one-hot encodes ``x_col`` consistently with the real\n",
    "    test set at ``test_path``, and scores the ``compare_MLE`` model panel\n",
    "    (trained on the synthetic file, tested on the real data).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data_nm : str\n",
    "        Dataset folder under ``gen/``; ``'adult'`` also enables the ``'>50K'`` -> 1 label mapping.\n",
    "    test_path : str\n",
    "        CSV file with the real held-out test data.\n",
    "    cols : list of str\n",
    "        Unused; kept so existing calls keep working.\n",
    "    x_col : list of str\n",
    "        Feature columns.\n",
    "    y_col : str\n",
    "        Target column.\n",
    "    task : {'classification', 'regression'}\n",
    "        Forwarded to ``compare_MLE``.\n",
    "    sizes : sequence of int, default (100, 200, 400, 800)\n",
    "        Synthetic training-set sizes, i.e. the sub-folder names under ``gen/{data_nm}/``.\n",
    "    n_reps : int, default 5\n",
    "        Number of replicate files (``df_0.csv`` ... ``df_{n_reps-1}.csv``) per size.\n",
    "    method_name : str, default 'mallm'\n",
    "        Key under which each replicate's scores are stored.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        ``res[str(num)][str(j)][method_name]`` is the ``compare_MLE`` result for that file.\n",
    "    '''\n",
    "    df_test = pd.read_csv(test_path)\n",
    "    y_test = _encode_target(df_test[y_col], data_nm)  # loop-invariant, so computed once\n",
    "    res = {}\n",
    "    for i, num in enumerate(sizes):\n",
    "        res[str(num)] = {}\n",
    "        for j in range(n_reps):\n",
    "            print(i, j)  # progress: (size index, replicate)\n",
    "            df_train = pd.read_csv(f'gen/{data_nm}/{num}/df_{j}.csv')\n",
    "            X_train, X_test = _one_hot_aligned(df_train[x_col], df_test[x_col])\n",
    "            y_train = _encode_target(df_train[y_col], data_nm)\n",
    "            performance = compare_MLE(X_test, y_test, X_train, y_train, task)\n",
    "            res[str(num)][str(j)] = {method_name: performance}\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "17e39f9b-ccd7-45b0-a4a1-d39ced84602d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mle_summary(res):\n",
    "    '''\n",
    "    Summarise classification results produced by ``MLE_eval``.\n",
    "\n",
    "    Each evaluation model is scored by the mean of its accuracy and weighted F1.\n",
    "    For every training size and generator key (e.g. ``'mallm'``), each replicate\n",
    "    (``seed``) keeps only its best-scoring evaluation model; those per-seed best\n",
    "    scores are then averaged across seeds. A ``'baseline'`` seed entry is ignored.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    res : dict\n",
    "        ``res[num][seed][generator][eval_model]`` -> ``{'accuracy': ..., 'f1 score': ...}``.\n",
    "        Evaluation models missing either metric are skipped.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        ``{num: {generator: {'highest_avg_mean', 'highest_avg_std', 'eval_model_selected'}}}``;\n",
    "        ``eval_model_selected`` maps each seed to the model that won it. When ``res`` comes\n",
    "        from ``MLE_eval`` the winner is picked on the same test set it is scored on, so the\n",
    "        mean is an optimistic best-of-panel figure.\n",
    "    '''\n",
    "    res_average = {}\n",
    "    for num, by_seed in res.items():\n",
    "        # generator -> {seed: (best score, winning eval model)}\n",
    "        best = {}\n",
    "        for seed, by_generator in by_seed.items():\n",
    "            if seed == 'baseline':\n",
    "                continue\n",
    "            for generator, by_eval_model in by_generator.items():\n",
    "                for eval_model, metrics in by_eval_model.items():\n",
    "                    accuracy = metrics.get('accuracy')\n",
    "                    f1 = metrics.get('f1 score')\n",
    "                    if accuracy is None or f1 is None:\n",
    "                        continue\n",
    "                    # Mean of accuracy and F1 (this used to be (f1 + f1) / 2, i.e. F1 only).\n",
    "                    score = (accuracy + f1) / 2\n",
    "                    per_seed = best.setdefault(generator, {})\n",
    "                    # Strict '>' keeps the first-listed model on ties.\n",
    "                    if seed not in per_seed or score > per_seed[seed][0]:\n",
    "                        per_seed[seed] = (score, eval_model)\n",
    "        res_average[num] = {}\n",
    "        for generator, per_seed in best.items():\n",
    "            scores = [score for score, _ in per_seed.values()]\n",
    "            res_average[num][generator] = {\n",
    "                'highest_avg_mean': np.mean(scores),\n",
    "                'highest_avg_std': np.std(scores),\n",
    "                'eval_model_selected': {seed: model for seed, (_, model) in per_seed.items()},\n",
    "            }\n",
    "    return res_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a5753d97-211f-473e-ac3e-a97c1f762725",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mle_summary2(res):\n",
    "    '''\n",
    "    Summarise regression results produced by ``MLE_eval``.\n",
    "\n",
    "    For every training size and generator key (e.g. ``'mallm'``), each replicate\n",
    "    (``seed``) keeps only the evaluation model with the highest R2; those per-seed\n",
    "    best R2 values are then averaged across seeds. A ``'baseline'`` seed entry is ignored.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    res : dict\n",
    "        ``res[num][seed][generator][eval_model]`` -> ``{'R2': ...}``.\n",
    "        Evaluation models without an ``'R2'`` entry are skipped.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        ``{num: {generator: {'highest_R2_mean', 'highest_R2_std', 'eval_model_selected'}}}``;\n",
    "        ``eval_model_selected`` maps each seed to the model that won it. When ``res`` comes\n",
    "        from ``MLE_eval`` the winner is picked on the same test set it is scored on, so the\n",
    "        mean is an optimistic best-of-panel figure.\n",
    "    '''\n",
    "    res_average = {}\n",
    "    for num, by_seed in res.items():\n",
    "        # generator -> {seed: (best R2, winning eval model)}\n",
    "        best = {}\n",
    "        for seed, by_generator in by_seed.items():\n",
    "            if seed == 'baseline':\n",
    "                continue\n",
    "            for generator, by_eval_model in by_generator.items():\n",
    "                for eval_model, metrics in by_eval_model.items():\n",
    "                    r2 = metrics.get('R2')\n",
    "                    if r2 is None:\n",
    "                        continue\n",
    "                    per_seed = best.setdefault(generator, {})\n",
    "                    # Strict '>' keeps the first-listed model on ties.\n",
    "                    if seed not in per_seed or r2 > per_seed[seed][0]:\n",
    "                        per_seed[seed] = (r2, eval_model)\n",
    "        res_average[num] = {}\n",
    "        for generator, per_seed in best.items():\n",
    "            r2_values = [r2 for r2, _ in per_seed.values()]\n",
    "            res_average[num][generator] = {\n",
    "                'highest_R2_mean': np.mean(r2_values),\n",
    "                'highest_R2_std': np.std(r2_values),\n",
    "                'eval_model_selected': {seed: model for seed, (_, model) in per_seed.items()},\n",
    "            }\n",
    "    return res_average"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "786b0fbb-184e-4ee0-9c4b-815d5991483e",
   "metadata": {},
   "source": [
    "# Adult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "257677fd-c5e6-43ac-8d88-29e1f67d7a31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adult: dataset folder name, real test set, and column roles.\n",
    "data_nm = 'adult'\n",
    "test_path = '../sample/Adult/data_test.csv'\n",
    "\n",
    "y_col = 'Income'\n",
    "x_col = [\n",
    "    'age', 'workclass', 'education', 'education-num', 'marital-status',\n",
    "    'occupation', 'relationship', 'race', 'sex', 'capital-gain',\n",
    "    'capital-loss', 'hours-per-week', 'native-country',\n",
    "]\n",
    "cols = [*x_col, y_col]\n",
    "\n",
    "# Column types, recorded for reference (MLE_eval does not take them).\n",
    "cate_var = [\n",
    "    'workclass', 'education', 'marital-status', 'occupation',\n",
    "    'relationship', 'race', 'sex', 'native-country', 'Income',\n",
    "]\n",
    "bool_var = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f67f6319-3e4a-4af7-bc8a-c6c443d7f681",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n",
      "/opt/conda/envs/py310/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    }
   ],
   "source": [
    "res_adult = MLE_eval(\n",
    "    data_nm=data_nm,\n",
    "    test_path=test_path,\n",
    "    cols=cols,\n",
    "    x_col=x_col,\n",
    "    y_col=y_col,\n",
    "    task='classification',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "84729ac4-4b19-4460-a1e3-02d49ba4de18",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'mallm': {'highest_avg_mean': 0.7871025808667989,\n",
       "   'highest_avg_std': 0.022877290196157623,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Random Forest',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Logistic Regression'}}},\n",
       " '200': {'mallm': {'highest_avg_mean': 0.7690483455831161,\n",
       "   'highest_avg_std': 0.031246547264812928,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Random Forest',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Logistic Regression'}}},\n",
       " '400': {'mallm': {'highest_avg_mean': 0.7906288789265383,\n",
       "   'highest_avg_std': 0.017478684413401627,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Random Forest',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'XGB'}}},\n",
       " '800': {'mallm': {'highest_avg_mean': 0.7954349423258762,\n",
       "   'highest_avg_std': 0.022127149976999202,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Random Forest'}}}}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mle_summary(res=res_adult)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f242c280-d936-4426-9d27-da805b4711ba",
   "metadata": {},
   "source": [
    "# ATACH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d5f299e2-3568-4877-8670-9090c177300f",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_nm = 'atach'\n",
    "test_path = '../sample/ATACH/data_test.csv'\n",
    "\n",
    "y_col = 'mRS score after 30 days'\n",
    "x_cols = [\n",
    "    'treatment', 'age', 'ICH volume', 'ICH Location', 'IVH volume',\n",
    "    'GCS score', 'NIHSS score', 'Systolic blood pressure', 'Diastolic Blood Pressure',\n",
    "    'Hypertension', 'Hyperlipidemia', 'Type I Diabetes', 'Type II Diabetes',\n",
    "    'Congestive heart failure', 'Atrial Fibrillation', 'PTCA',\n",
    "    'Peripheral Vascular Disease', 'Myocardial fraction', 'Anti-diabetic',\n",
    "    'Antihypertensives', 'White blood count', 'Hemoglobin', 'Hematocrit',\n",
    "    'Platelet count', 'APTT', 'INR', 'Glucose', 'Sodium', 'Potassium',\n",
    "    'Chloride', 'CD', 'Blood urea nitrogen', 'Creatinine',\n",
    "    'race', 'sex', 'ethnicity',\n",
    "]\n",
    "cols = [*x_cols, y_col]\n",
    "\n",
    "# Column types, recorded for reference (MLE_eval does not take them).\n",
    "cate_var = ['race', 'ICH Location', 'sex', 'ethnicity']\n",
    "bool_var = [\n",
    "    'treatment', 'Hypertension', 'Hyperlipidemia', 'Type I Diabetes',\n",
    "    'Type II Diabetes', 'Congestive heart failure', 'Atrial Fibrillation',\n",
    "    'PTCA', 'Peripheral Vascular Disease', 'Myocardial fraction',\n",
    "    'Anti-diabetic', 'Antihypertensives',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "a060f84d-fd16-4f30-a86e-65b9e9a489b7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "0 1\n",
      "0 2\n",
      "0 3\n",
      "0 4\n",
      "1 0\n",
      "1 1\n",
      "1 2\n",
      "1 3\n",
      "1 4\n",
      "2 0\n",
      "2 1\n",
      "2 2\n",
      "2 3\n",
      "2 4\n",
      "3 0\n",
      "3 1\n",
      "3 2\n",
      "3 3\n",
      "3 4\n"
     ]
    }
   ],
   "source": [
    "res_atach = MLE_eval(\n",
    "    data_nm=data_nm,\n",
    "    test_path=test_path,\n",
    "    cols=cols,\n",
    "    x_col=x_cols,\n",
    "    y_col=y_col,\n",
    "    task='regression',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "fb92e0e9-5ceb-4be2-bc1c-448d7235c448",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'mallm': {'highest_R2_mean': 0.2726437535528787,\n",
       "   'highest_R2_std': 0.07065633622194069,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '200': {'mallm': {'highest_R2_mean': 0.2755933100862018,\n",
       "   'highest_R2_std': 0.07055720662765429,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'XGB'}}},\n",
       " '400': {'mallm': {'highest_R2_mean': 0.2748287420113195,\n",
       "   'highest_R2_std': 0.04389970036946165,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'XGB',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'XGB'}}},\n",
       " '800': {'mallm': {'highest_R2_mean': 0.3587216031632454,\n",
       "   'highest_R2_std': 0.02251701122044336,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'XGB',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Decision Tree'}}}}"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mle_summary2(res=res_atach)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4871884-6428-4870-8d5d-a556aaea9740",
   "metadata": {},
   "source": [
    "# Asia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "6d5d7b89-bbf6-4755-855d-d2baeb0a603b",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_col = 'Dyspnea'\n",
    "x_cols = [\n",
    "    'visit to Asia', 'tuberculosis', 'smoke', 'lung cancer', 'bronchitis',\n",
    "    'either tuberculosis or lung cancer', 'chest X-ray',\n",
    "]\n",
    "cols = [*x_cols, y_col]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "c906c0e6-e589-41a5-9204-16576b2a1969",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "0 1\n",
      "0 2\n",
      "0 3\n",
      "0 4\n",
      "1 0\n",
      "1 1\n",
      "1 2\n",
      "1 3\n",
      "1 4\n",
      "2 0\n",
      "2 1\n",
      "2 2\n",
      "2 3\n",
      "2 4\n",
      "3 0\n",
      "3 1\n",
      "3 2\n",
      "3 3\n",
      "3 4\n"
     ]
    }
   ],
   "source": [
    "data_nm = 'asia'\n",
    "test_path = '../sample/Asia/data_test.csv'\n",
    "\n",
    "res_asia = MLE_eval(\n",
    "    data_nm=data_nm,\n",
    "    test_path=test_path,\n",
    "    cols=cols,\n",
    "    x_col=x_cols,\n",
    "    y_col=y_col,\n",
    "    task='classification',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "8481fd1f-2ac5-4ff1-862a-07d9228dbb9d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'mallm': {'highest_avg_mean': 0.8282369192116856,\n",
       "   'highest_avg_std': 0.0040629483723550715,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Random Forest',\n",
       "    '4': 'Random Forest'}}},\n",
       " '200': {'mallm': {'highest_avg_mean': 0.8302860665733747,\n",
       "   'highest_avg_std': 0.005913645150991315,\n",
       "   'eval_model_selected': {'0': 'SVC',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Random Forest',\n",
       "    '4': 'SVC'}}},\n",
       " '400': {'mallm': {'highest_avg_mean': 0.8261978179651195,\n",
       "   'highest_avg_std': 6.283887209126427e-05,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'XGB',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Logistic Regression'}}},\n",
       " '800': {'mallm': {'highest_avg_mean': 0.8373745729035618,\n",
       "   'highest_avg_std': 0.0020253362833138056,\n",
       "   'eval_model_selected': {'0': 'Logistic Regression',\n",
       "    '1': 'Logistic Regression',\n",
       "    '2': 'Logistic Regression',\n",
       "    '3': 'Logistic Regression',\n",
       "    '4': 'Logistic Regression'}}}}"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mle_summary(res=res_asia)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a37b4c7-dab4-4c25-b49f-5a7005f91658",
   "metadata": {},
   "source": [
    "# ERICH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "2949b837-2fdd-441c-9f99-9c1edcd927b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_col = 'mRS score after 90 days'\n",
    "x_cols = [\n",
    "    'sex', 'age', 'race', 'ICH volume', 'IVH volume', 'ICH location',\n",
    "    'GCS score', 'Systolic blood pressure', 'Diastolic blood pressure',\n",
    "    'Hypertension', 'Hyperlipidemia', 'Type 2 diabetes', 'Type 1 diabetes',\n",
    "    'Atrial Fibrillation', 'Congestive Heart Failure', 'Vascular disease',\n",
    "    'Cocaine use', 'Smoke', 'PTCA', 'Myocardial infraction', 'anti-diabetic',\n",
    "    'Anti-hypertensives', 'White blood cell count', 'Hemoglobin',\n",
    "    'Platelet count', 'APTT', 'INR', 'Glucose',\n",
    "]\n",
    "cols = [*x_cols, y_col]\n",
    "\n",
    "# Column types; num_var is every column (target included) that is neither categorical nor boolean.\n",
    "cate_var = ['race', 'sex', 'ICH location']\n",
    "bool_var = [\n",
    "    'Hypertension', 'Hyperlipidemia', 'Type 2 diabetes', 'Type 1 diabetes',\n",
    "    'Atrial Fibrillation', 'Congestive Heart Failure', 'Vascular disease',\n",
    "    'Cocaine use', 'Smoke', 'PTCA', 'Myocardial infraction', 'anti-diabetic',\n",
    "    'Anti-hypertensives',\n",
    "]\n",
    "num_var = list(set(cols) - set(cate_var) - set(bool_var))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "d77f2d15-5b3d-4ee4-bf45-21768fd2d389",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_nm, test_path = 'erich', '../sample/ERICH/data_test.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "0a97a40c-6f5a-47a2-b1f8-523b6342e486",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "0 1\n",
      "0 2\n",
      "0 3\n",
      "0 4\n",
      "1 0\n",
      "1 1\n",
      "1 2\n",
      "1 3\n",
      "1 4\n",
      "2 0\n",
      "2 1\n",
      "2 2\n",
      "2 3\n",
      "2 4\n",
      "3 0\n",
      "3 1\n",
      "3 2\n",
      "3 3\n",
      "3 4\n"
     ]
    }
   ],
   "source": [
    "res_erich = MLE_eval(\n",
    "    data_nm=data_nm,\n",
    "    test_path=test_path,\n",
    "    cols=cols,\n",
    "    x_col=x_cols,\n",
    "    y_col=y_col,\n",
    "    task='regression',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "8cfcf03c-90cd-44e7-b2b4-7b8376310ebd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'mallm': {'highest_R2_mean': -0.025303441006437243,\n",
       "   'highest_R2_std': 0.06714007592553137,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'XGB',\n",
       "    '4': 'Decision Tree'}}},\n",
       " '200': {'mallm': {'highest_R2_mean': 0.024862252565487398,\n",
       "   'highest_R2_std': 0.02182546049414037,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'XGB',\n",
       "    '4': 'XGB'}}},\n",
       " '400': {'mallm': {'highest_R2_mean': 0.022909552180283364,\n",
       "   'highest_R2_std': 0.033641591410841575,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'XGB',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'XGB'}}},\n",
       " '800': {'mallm': {'highest_R2_mean': 0.015070391277223339,\n",
       "   'highest_R2_std': 0.01984183404135882,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'XGB',\n",
       "    '2': 'XGB',\n",
       "    '3': 'XGB',\n",
       "    '4': 'Linear Regression'}}}}"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mle_summary2(res=res_erich)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0010d39b-5ba8-4483-bd78-3eee49d249dc",
   "metadata": {},
   "source": [
    "# Insurance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "56b39175-5f7d-499f-aafd-a90d7078ae3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_nm = 'Insurance'\n",
    "test_path = '../sample/Insurance/data_test.csv'\n",
    "\n",
    "y_col = 'charges'\n",
    "x_cols = ['age', 'sex', 'bmi', 'children', 'smoker', 'region']\n",
    "cols = [*x_cols, y_col]\n",
    "\n",
    "# Column types; num_var is every column (target included) that is not categorical.\n",
    "cate_var = ['sex', 'smoker', 'region']\n",
    "bool_var = []\n",
    "num_var = list(set(cols) - set(cate_var))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "32e28aed-8968-44b6-9f01-d19a52fb3281",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "0 1\n",
      "0 2\n",
      "0 3\n",
      "0 4\n",
      "1 0\n",
      "1 1\n",
      "1 2\n",
      "1 3\n",
      "1 4\n",
      "2 0\n",
      "2 1\n",
      "2 2\n",
      "2 3\n",
      "2 4\n",
      "3 0\n",
      "3 1\n",
      "3 2\n",
      "3 3\n",
      "3 4\n"
     ]
    }
   ],
   "source": [
    "# Separate name so the ERICH results stored in res_erich are not overwritten.\n",
    "res_insurance = MLE_eval(data_nm, test_path, cols, x_cols, y_col, \"regression\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "62f09232-7922-4ce0-92dc-ce1b7222f3ea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'100': {'mallm': {'highest_R2_mean': 0.7151909242151676,\n",
       "   'highest_R2_std': 0.04469875818188462,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'XGB',\n",
       "    '4': 'Decision Tree'}}},\n",
       " '200': {'mallm': {'highest_R2_mean': 0.6931016561006802,\n",
       "   'highest_R2_std': 0.0263639496081884,\n",
       "   'eval_model_selected': {'0': 'XGB',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'XGB',\n",
       "    '4': 'Linear Regression'}}},\n",
       " '400': {'mallm': {'highest_R2_mean': 0.7136494390532002,\n",
       "   'highest_R2_std': 0.03494503522622808,\n",
       "   'eval_model_selected': {'0': 'Decision Tree',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Linear Regression',\n",
       "    '3': 'Linear Regression',\n",
       "    '4': 'Linear Regression'}}},\n",
       " '800': {'mallm': {'highest_R2_mean': 0.7217330689872063,\n",
       "   'highest_R2_std': 0.014785901619559321,\n",
       "   'eval_model_selected': {'0': 'Linear Regression',\n",
       "    '1': 'Decision Tree',\n",
       "    '2': 'Decision Tree',\n",
       "    '3': 'Decision Tree',\n",
       "    '4': 'XGB'}}}}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mle_summary2(res_insurance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c79a999b-9157-47d5-8be6-f5d3418af720",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10",
   "language": "python",
   "name": "py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
