{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# XGBoost Readmission Prediction\n",
    "\n",
    "This notebook implements readmission prediction using XGBoost on the MIMIC-III dataset.\n",
    "- XGBoost classifier\n",
    "- Uses only the **most recent visit** (not full sequence)\n",
    "- Multi-hot encoding (not embeddings)\n",
    "- Evaluation metrics (AUPRC, AUROC, F1, Kappa)\n",
    "    - It is unclear whether PyHealth is suitable for evaluating these ML predictions, so scikit-learn metrics are used instead.\n",
    "- Data splits and preprocessing pipeline\n",
    "- Task definition (readmission < 15 days)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numpy Version: 1.26.2\n",
      "Sklearn Version: 1.4.0\n",
      "XGBoost Version: 3.1.2\n"
     ]
    }
   ],
   "source": [
    "# Standard library\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import time\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import xgboost as xgb\n",
    "from sklearn.metrics import average_precision_score, roc_auc_score, f1_score, cohen_kappa_score\n",
    "\n",
    "# Project-local preprocessing helpers\n",
    "from preprocess.parse_csv import Mimic3Parser, parse_patient_info, EHRParser\n",
    "from preprocess.encoded import encode_code\n",
    "from preprocess.build_dataset import (\n",
    "    split_patients_disparity,\n",
    "    build_code_y_binary\n",
    ")\n",
    "\n",
    "print(\"Numpy Version:\", np.__version__)\n",
    "print(\"Sklearn Version:\", sklearn.__version__)\n",
    "print(\"XGBoost Version:\", xgb.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Situation 1. If parser.parse(sample_num=None, sorting=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XGBoost Readmission Prediction - Data Preprocessing\n",
      "==================================================\n",
      "Dataset: mimic3\n",
      "Task: readmission\n",
      "Feature keys: ['d', 'p', 'm']\n",
      "Train patients: 4000, Test patients: 2000\n",
      "Random seed: 6666\n",
      "\n",
      "Step 1: Parsing CSV files...\n",
      "parsing the csv file of admission ...\n",
      "\t58976 in 58976 rows\n",
      "Sorting admissions:  True\n",
      "parsing csv file of diagnosis ...\n",
      "\t651047 in 651047 rows\n",
      "parsing csv file of procedures ...\n",
      "\t240095 in 240095 rows\n",
      "parsing csv file of medications ...\n",
      "\tmapping NDC to ATC codes...\n",
      "\t 1930338 Index(['SUBJECT_ID', 'HADM_ID', 'NDC'], dtype='object')\n",
      "\t 878345 Index(['SUBJECT_ID', 'HADM_ID', 'NDC'], dtype='object')\n",
      "\t878345 in 878345 rows\n",
      "calibrating patients by admission ...\n",
      "calibrating admission by patients ...\n",
      "Aligning admissions across different concepts ...\n",
      "\tnum of total admission:  19894\n",
      "\tnum of valid admission:  17926\n",
      "\tvalid diagnosis visit num:  17325\n",
      "\tvalid procedure visit num:  17325\n",
      "\tvalid medication visit num:  17325\n",
      "Total patients: 6497\n",
      "Total admissions: 17325\n",
      "\n",
      "Step 2: Encoding medical codes...\n",
      "# Diagnosis: 4702; # Procedure: 1431; # Medication: 151\n",
      "\n",
      "Step 3: Loading patient demographic information...\n",
      "Gender: {'M': 3617, 'F': 2880}\n",
      "Age Groups: {'<18': 49, '18-30': 144, '30-60': 1949, '>60': 4355}\n",
      "Gender vs Age Groups: {'M': {'<18': 30, '18-30': 75, '30-60': 1139, '>60': 2373}, 'F': {'<18': 19, '18-30': 69, '30-60': 810, '>60': 1982}}\n",
      "Loaded demographic info for 6497 patients\n",
      "\n",
      "Step 4: Splitting patients into train/valid/test sets...\n",
      "\t100%00%\n",
      "Train: 4000, Valid: 497, Test: 2000\n",
      "\n",
      "Step 5: Preparing admission events...\n",
      "Admission events prepared\n",
      "\n",
      "Step 6: Building multi-hot encoded features...\n",
      "Current concept type of codes_x: d\n",
      "\tShape of x: (7188, 42, 4702), (2871, 42, 4702)\n",
      "Current concept type of codes_x: p\n",
      "\tShape of x: (7188, 42, 1431), (2871, 42, 1431)\n",
      "Current concept type of codes_x: m\n",
      "\tShape of x: (7188, 42, 151), (2871, 42, 151)\n",
      "\n",
      "Step 7: Building readmission labels...\n",
      "Shape of y: (7188,), (769,), (2871,)\n",
      "The number of element with value of 1 is: 1303\n",
      "\n",
      "Step 8: Extracting most recent visit features for XGBoost...\n",
      "X_train shape: (7188, 6284)\n",
      "X_valid shape: (769, 6284)\n",
      "X_test shape: (2871, 6284)\n",
      "\n",
      "Step 9: Saving processed data to standard path...\n",
      "Data saved to data\\mimic3\\standard\\xgboost_data.pkl\n",
      "\n",
      "==================================================\n",
      "Data preprocessing completed!\n",
      "==================================================\n"
     ]
    }
   ],
   "source": [
    "# Cell 1: Data Preprocessing Pipeline\n",
    "# ====== Configuration ======\n",
    "dataset = 'mimic3'\n",
    "data_path = 'data'\n",
    "dataset_path = os.path.join(data_path, dataset)\n",
    "raw_path = os.path.join(dataset_path, 'raw')\n",
    "parsed_path = os.path.join(dataset_path, 'parsed')\n",
    "standard_path = os.path.join(dataset_path, 'standard')\n",
    "\n",
    "# Feature configuration\n",
    "feature_keys = ['d', 'p', 'm']  # diagnosis, procedure, medication\n",
    "task = 'readmission'\n",
    "train_num = 4000\n",
    "test_num = 2000\n",
    "seed = 6666\n",
    "# Apply the seed here (not only in the training cell) so the patient split is reproducible\n",
    "# on a fresh kernel in case the split helpers draw from the global RNGs.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "print(\"=\" * 50)\n",
    "print(\"XGBoost Readmission Prediction - Data Preprocessing\")\n",
    "print(\"=\" * 50)\n",
    "print(f\"Dataset: {dataset}\")\n",
    "print(f\"Task: {task}\")\n",
    "print(f\"Feature keys: {feature_keys}\")\n",
    "print(f\"Train patients: {train_num}, Test patients: {test_num}\")\n",
    "print(f\"Random seed: {seed}\\n\")\n",
    "\n",
    "print(\"Step 1: Parsing CSV files...\")\n",
    "parser = Mimic3Parser(raw_path, procedure=True, medication=True)\n",
    "# NOTE: the readmission labels built in Step 7 depend on admission order (compare Situation 2, sorting=False).\n",
    "patient_admission, admissions = parser.parse(sample_num=None, sorting=True)\n",
    "admission_codes, admission_pcs, admission_med = admissions\n",
    "\n",
    "print(f\"Total patients: {len(patient_admission)}\")\n",
    "print(f\"Total admissions: {len(admission_codes)}\\n\")\n",
    "\n",
    "print(\"Step 2: Encoding medical codes...\")\n",
    "admission_codes_encoded, code_map = encode_code(patient_admission, admission_codes)\n",
    "admission_pcs_encoded, pcs_map = encode_code(patient_admission, admission_pcs)\n",
    "admission_med_encoded, med_map = encode_code(patient_admission, admission_med)\n",
    "\n",
    "code_nums = {\n",
    "    'd': len(code_map),\n",
    "    'p': len(pcs_map),\n",
    "    'm': len(med_map)\n",
    "}\n",
    "print(f\"# Diagnosis: {code_nums['d']}; # Procedure: {code_nums['p']}; # Medication: {code_nums['m']}\\n\")\n",
    "\n",
    "print(\"Step 3: Loading patient demographic information...\")\n",
    "patient_info = parse_patient_info(raw_path, parsed_path)\n",
    "print(f\"Loaded demographic info for {len(patient_info)} patients\\n\")\n",
    "\n",
    "print(\"Step 4: Splitting patients into train/valid/test sets...\")\n",
    "train_pids, valid_pids, test_pids_g1, test_pids_g2 = split_patients_disparity(\n",
    "    patient_admission=patient_admission,\n",
    "    admission_codes=admission_codes,\n",
    "    code_map=code_map,\n",
    "    train_num=train_num,\n",
    "    test_num=test_num,\n",
    "    patient_info=patient_info,\n",
    "    feature_key=\"GENDER\",\n",
    "    g1_value=\"M\",\n",
    ")\n",
    "test_pids = np.concatenate((test_pids_g1, test_pids_g2))\n",
    "# A patient must appear in exactly one split, otherwise test metrics leak training data.\n",
    "assert not set(train_pids) & set(valid_pids), \"train/valid patient overlap\"\n",
    "assert not set(train_pids) & set(test_pids), \"train/test patient overlap\"\n",
    "assert not set(valid_pids) & set(test_pids), \"valid/test patient overlap\"\n",
    "print(f\"Train: {len(train_pids)}, Valid: {len(valid_pids)}, Test: {len(test_pids)}\\n\")\n",
    "\n",
    "print(\"Step 5: Preparing admission events...\")\n",
    "admission_events_encoded = {\n",
    "    'd': admission_codes_encoded,\n",
    "    'p': admission_pcs_encoded,\n",
    "    'm': admission_med_encoded\n",
    "}\n",
    "print(\"Admission events prepared\\n\")\n",
    "\n",
    "print(\"Step 6: Building multi-hot encoded features...\")\n",
    "\n",
    "def build_code_x_multi_hot_fixed(pids, patient_admission, admission_events_encoded, event_types, code_nums):\n",
    "    \"\"\"\n",
    "    Build multi-hot encoded visit histories, one sample per prediction point.\n",
    "\n",
    "    A patient with m admissions yields m - 1 samples: the patient's k-th sample holds\n",
    "    admissions 0..k, i.e. the history available before admission k + 1. Encoded codes\n",
    "    are shifted by -1 to column indices (assumes encode_code yields 1-based ids - TODO confirm).\n",
    "\n",
    "    Args:\n",
    "        pids: patient ids to include, in output order.\n",
    "        patient_admission: dict pid -> list of admissions, each indexable by EHRParser.adm_id_col.\n",
    "        admission_events_encoded: dict event type -> {admission id: encoded codes}.\n",
    "        event_types: event types to build, e.g. ['d', 'p', 'm'].\n",
    "        code_nums: dict event type -> vocabulary size.\n",
    "\n",
    "    Returns:\n",
    "        (x, lens), both keyed by the entries of event_types: x[e] is a bool array of shape\n",
    "        (n_samples, max_admission_num, code_nums[e]); lens[e] is an int array of shape\n",
    "        (n_samples,) holding the number of visits in each sample.\n",
    "    \"\"\"\n",
    "    n = sum(len(patient_admission[pid]) - 1 for pid in pids)\n",
    "    # Pad to the longest history over *all* patients so every split shares the same visit axis.\n",
    "    max_admission_num = max(len(admissions) for admissions in patient_admission.values())\n",
    "    x, lens = {}, {}\n",
    "\n",
    "    for e_type in event_types:\n",
    "        x_t = np.zeros((n, max_admission_num, code_nums[e_type]), dtype=bool)\n",
    "        lens_t = np.zeros((n,), dtype=int)\n",
    "\n",
    "        # Explicit running offset: avoids negative-index wraparound for the first patient\n",
    "        # (which broke the fill when pids held a single patient).\n",
    "        start = 0\n",
    "        for pid in pids:\n",
    "            admissions = patient_admission[pid]\n",
    "            end = start + len(admissions) - 1  # one past this patient's last sample row\n",
    "            for k, admission in enumerate(admissions[:-1]):\n",
    "                codes = np.array(admission_events_encoded[e_type][admission[EHRParser.adm_id_col]], dtype=int) - 1\n",
    "                # Admission k belongs to the history of this patient's samples k..m-2 (rows start+k .. end-1).\n",
    "                x_t[start + k:end, k, codes] = True\n",
    "                lens_t[start + k] = k + 1\n",
    "            start = end\n",
    "        x[e_type], lens[e_type] = x_t, lens_t\n",
    "    return x, lens\n",
    "\n",
    "train_codes_x, train_visit_lens = build_code_x_multi_hot_fixed(\n",
    "    train_pids, patient_admission, admission_events_encoded, feature_keys, code_nums\n",
    ")\n",
    "valid_codes_x, valid_visit_lens = build_code_x_multi_hot_fixed(\n",
    "    valid_pids, patient_admission, admission_events_encoded, feature_keys, code_nums\n",
    ")\n",
    "test_codes_x, test_visit_lens = build_code_x_multi_hot_fixed(\n",
    "    test_pids, patient_admission, admission_events_encoded, feature_keys, code_nums\n",
    ")\n",
    "\n",
    "for t in train_codes_x:\n",
    "    print(f\"Current concept type of codes_x: {t}\")\n",
    "    print(f\"\\tShape of x: {train_codes_x[t].shape}, {test_codes_x[t].shape}\")\n",
    "\n",
    "print(\"\\nStep 7: Building readmission labels...\")\n",
    "train_y = build_code_y_binary(train_pids, patient_admission, task=task)\n",
    "valid_y = build_code_y_binary(valid_pids, patient_admission, task=task)\n",
    "test_y = build_code_y_binary(test_pids, patient_admission, task=task)\n",
    "\n",
    "print(f\"Shape of y: {train_y.shape}, {valid_y.shape}, {test_y.shape}\")\n",
    "print(f\"The number of element with value of 1 is: {np.sum(train_y == 1).item()}\\n\")\n",
    "\n",
    "# Features and labels come from separate helpers; make sure their sample rows line up.\n",
    "for split_name, split_x, split_y in ((\"train\", train_codes_x, train_y), (\"valid\", valid_codes_x, valid_y), (\"test\", test_codes_x, test_y)):\n",
    "    for t in feature_keys:\n",
    "        assert split_x[t].shape[0] == len(split_y), f\"{split_name}/{t}: {split_x[t].shape[0]} samples vs {len(split_y)} labels\"\n",
    "\n",
    "print(\"Step 8: Extracting most recent visit features for XGBoost...\")\n",
    "\n",
    "def extract_last_visit_multihot(codes_x, visit_lens, feature_keys):\n",
    "    \"\"\"\n",
    "    Extract the multi-hot encoding of each sample's most recent visit only.\n",
    "\n",
    "    Args:\n",
    "        codes_x: dict key -> array of shape (N, max_visits, vocab_size_key).\n",
    "        visit_lens: dict key -> int array of shape (N,), number of visits per sample.\n",
    "        feature_keys: keys to extract; their blocks are concatenated in this order.\n",
    "\n",
    "    Returns:\n",
    "        float array of shape (N, sum of vocab_size_key over feature_keys).\n",
    "    \"\"\"\n",
    "    feature_vectors = []\n",
    "    for key in feature_keys:  # ['d', 'p', 'm']\n",
    "        key_codes = codes_x[key]\n",
    "        rows = np.arange(key_codes.shape[0])\n",
    "        last_visit_idx = np.asarray(visit_lens[key]) - 1\n",
    "        # Vectorized gather of visit last_visit_idx[i] for every row i; unlike a per-row\n",
    "        # Python loop this is fast and still yields a 2-D (0, vocab) block when N == 0.\n",
    "        feature_vectors.append(key_codes[rows, last_visit_idx, :])  # (N, vocab_size_key)\n",
    "\n",
    "    # Concatenate all feature types\n",
    "    X = np.concatenate(feature_vectors, axis=1)\n",
    "    return X.astype(float)\n",
    "\n",
    "X_train = extract_last_visit_multihot(train_codes_x, train_visit_lens, feature_keys)\n",
    "X_valid = extract_last_visit_multihot(valid_codes_x, valid_visit_lens, feature_keys)\n",
    "X_test = extract_last_visit_multihot(test_codes_x, test_visit_lens, feature_keys)\n",
    "\n",
    "print(f\"X_train shape: {X_train.shape}\")\n",
    "print(f\"X_valid shape: {X_valid.shape}\")\n",
    "print(f\"X_test shape: {X_test.shape}\\n\")\n",
    "\n",
    "print(\"Step 9: Saving processed data to standard path...\")\n",
    "os.makedirs(standard_path, exist_ok=True)\n",
    "# NOTE: Situation 2 below writes to the same path and overwrites this file.\n",
    "xgboost_data_file = os.path.join(standard_path, 'xgboost_data.pkl')\n",
    "xgboost_data = {\n",
    "    'X_train': X_train,\n",
    "    'X_valid': X_valid,\n",
    "    'X_test': X_test,\n",
    "    'y_train': train_y,\n",
    "    'y_valid': valid_y,\n",
    "    'y_test': test_y,\n",
    "}\n",
    "# Context manager guarantees the pickle is flushed and the handle closed before Cell 2 reads it.\n",
    "with open(xgboost_data_file, 'wb') as out_file:\n",
    "    pickle.dump(xgboost_data, out_file)\n",
    "print(f\"Data saved to {xgboost_data_file}\\n\")\n",
    "\n",
    "print(\"=\" * 50)\n",
    "print(\"Data preprocessing completed!\")\n",
    "print(\"=\" * 50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XGBoost Training and Evaluation\n",
      "==================================================\n",
      "Task: readmission\n",
      "Dataset: mimic3\n",
      "Random seed: 6666\n",
      "\n",
      "Loading data...\n",
      "X_train shape: (7188, 6284)\n",
      "X_valid shape: (769, 6284)\n",
      "X_test shape: (2871, 6284)\n",
      "Shape of y: (7188,), (769,), (2871,)\n",
      "The number of element with value of 1 is: 1303\n",
      "\n",
      "Output size: 1\n",
      "Model: XGBClassifier\n",
      "Parameters: objective=binary:logistic, eval_metric=logloss, random_state=6666\n",
      "\n",
      "Training XGBoost model...\n",
      "==================================================\n",
      "[0]\tvalidation_0-logloss:0.42345\n",
      "[1]\tvalidation_0-logloss:0.42484\n",
      "[2]\tvalidation_0-logloss:0.42036\n",
      "[3]\tvalidation_0-logloss:0.42097\n",
      "[4]\tvalidation_0-logloss:0.42081\n",
      "[5]\tvalidation_0-logloss:0.42203\n",
      "[6]\tvalidation_0-logloss:0.42066\n",
      "[7]\tvalidation_0-logloss:0.42103\n",
      "[8]\tvalidation_0-logloss:0.42136\n",
      "[9]\tvalidation_0-logloss:0.42122\n",
      "[10]\tvalidation_0-logloss:0.42170\n",
      "[11]\tvalidation_0-logloss:0.42312\n",
      "[12]\tvalidation_0-logloss:0.42387\n",
      "[13]\tvalidation_0-logloss:0.42259\n",
      "[14]\tvalidation_0-logloss:0.42616\n",
      "[15]\tvalidation_0-logloss:0.42610\n",
      "[16]\tvalidation_0-logloss:0.42546\n",
      "[17]\tvalidation_0-logloss:0.42725\n",
      "[18]\tvalidation_0-logloss:0.42668\n",
      "[19]\tvalidation_0-logloss:0.42597\n",
      "[20]\tvalidation_0-logloss:0.42485\n",
      "[21]\tvalidation_0-logloss:0.42462\n",
      "[22]\tvalidation_0-logloss:0.42453\n",
      "[23]\tvalidation_0-logloss:0.42420\n",
      "[24]\tvalidation_0-logloss:0.42464\n",
      "[25]\tvalidation_0-logloss:0.42387\n",
      "[26]\tvalidation_0-logloss:0.42365\n",
      "[27]\tvalidation_0-logloss:0.42537\n",
      "[28]\tvalidation_0-logloss:0.42262\n",
      "[29]\tvalidation_0-logloss:0.42204\n",
      "[30]\tvalidation_0-logloss:0.42267\n",
      "[31]\tvalidation_0-logloss:0.42159\n",
      "[32]\tvalidation_0-logloss:0.42160\n",
      "[33]\tvalidation_0-logloss:0.42206\n",
      "[34]\tvalidation_0-logloss:0.42094\n",
      "[35]\tvalidation_0-logloss:0.42184\n",
      "[36]\tvalidation_0-logloss:0.42176\n",
      "[37]\tvalidation_0-logloss:0.42205\n",
      "[38]\tvalidation_0-logloss:0.42175\n",
      "[39]\tvalidation_0-logloss:0.42223\n",
      "[40]\tvalidation_0-logloss:0.42176\n",
      "[41]\tvalidation_0-logloss:0.42076\n",
      "[42]\tvalidation_0-logloss:0.42074\n",
      "[43]\tvalidation_0-logloss:0.42042\n",
      "[44]\tvalidation_0-logloss:0.42074\n",
      "[45]\tvalidation_0-logloss:0.42042\n",
      "[46]\tvalidation_0-logloss:0.42023\n",
      "[47]\tvalidation_0-logloss:0.42104\n",
      "[48]\tvalidation_0-logloss:0.42155\n",
      "[49]\tvalidation_0-logloss:0.42303\n",
      "[50]\tvalidation_0-logloss:0.42397\n",
      "[51]\tvalidation_0-logloss:0.42520\n",
      "[52]\tvalidation_0-logloss:0.42485\n",
      "[53]\tvalidation_0-logloss:0.42454\n",
      "[54]\tvalidation_0-logloss:0.42445\n",
      "[55]\tvalidation_0-logloss:0.42389\n",
      "[56]\tvalidation_0-logloss:0.42498\n",
      "[57]\tvalidation_0-logloss:0.42398\n",
      "[58]\tvalidation_0-logloss:0.42418\n",
      "[59]\tvalidation_0-logloss:0.42445\n",
      "[60]\tvalidation_0-logloss:0.42495\n",
      "[61]\tvalidation_0-logloss:0.42511\n",
      "[62]\tvalidation_0-logloss:0.42454\n",
      "[63]\tvalidation_0-logloss:0.42493\n",
      "[64]\tvalidation_0-logloss:0.42481\n",
      "[65]\tvalidation_0-logloss:0.42437\n",
      "[66]\tvalidation_0-logloss:0.42418\n",
      "[67]\tvalidation_0-logloss:0.42415\n",
      "[68]\tvalidation_0-logloss:0.42478\n",
      "[69]\tvalidation_0-logloss:0.42439\n",
      "[70]\tvalidation_0-logloss:0.42326\n",
      "[71]\tvalidation_0-logloss:0.42303\n",
      "[72]\tvalidation_0-logloss:0.42335\n",
      "[73]\tvalidation_0-logloss:0.42375\n",
      "[74]\tvalidation_0-logloss:0.42363\n",
      "[75]\tvalidation_0-logloss:0.42343\n",
      "[76]\tvalidation_0-logloss:0.42281\n",
      "[77]\tvalidation_0-logloss:0.42340\n",
      "[78]\tvalidation_0-logloss:0.42396\n",
      "[79]\tvalidation_0-logloss:0.42430\n",
      "[80]\tvalidation_0-logloss:0.42403\n",
      "[81]\tvalidation_0-logloss:0.42424\n",
      "[82]\tvalidation_0-logloss:0.42454\n",
      "[83]\tvalidation_0-logloss:0.42499\n",
      "[84]\tvalidation_0-logloss:0.42526\n",
      "[85]\tvalidation_0-logloss:0.42534\n",
      "[86]\tvalidation_0-logloss:0.42502\n",
      "[87]\tvalidation_0-logloss:0.42506\n",
      "[88]\tvalidation_0-logloss:0.42498\n",
      "[89]\tvalidation_0-logloss:0.42558\n",
      "[90]\tvalidation_0-logloss:0.42564\n",
      "[91]\tvalidation_0-logloss:0.42583\n",
      "[92]\tvalidation_0-logloss:0.42572\n",
      "[93]\tvalidation_0-logloss:0.42570\n",
      "[94]\tvalidation_0-logloss:0.42541\n",
      "[95]\tvalidation_0-logloss:0.42555\n",
      "[96]\tvalidation_0-logloss:0.42598\n",
      "[97]\tvalidation_0-logloss:0.42614\n",
      "[98]\tvalidation_0-logloss:0.42637\n",
      "[99]\tvalidation_0-logloss:0.42556\n",
      "\n",
      "Training completed! Time cost: 2.0s\n",
      "\n",
      "==================================================\n",
      "Evaluation on test set:\n",
      "==================================================\n",
      "    Evaluation: AUPRC: 0.2356 --- AUROC: 0.5769 --- F1: 0.0933 --- Kappa: 0.0485\n",
      "\n",
      "==================================================\n",
      "XGBoost pipeline completed!\n",
      "==================================================\n"
     ]
    }
   ],
   "source": [
    "# Cell 2: XGBoost Training and Evaluation\n",
    "\n",
    "# ====== Configuration ======\n",
    "batch_size = 128  # Not used for XGBoost but kept for consistency\n",
    "task = 'readmission'\n",
    "feature_keys = ['d', 'p', 'm']\n",
    "data_path = 'data'\n",
    "dataset = 'mimic3'\n",
    "dataset_path = os.path.join(data_path, dataset)\n",
    "standard_path = os.path.join(dataset_path, 'standard')\n",
    "\n",
    "seed = 6666\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "print(\"=\" * 50)\n",
    "print(\"XGBoost Training and Evaluation\")\n",
    "print(\"=\" * 50)\n",
    "print(f\"Task: {task}\")\n",
    "print(f\"Dataset: {dataset}\")\n",
    "print(f\"Random seed: {seed}\\n\")\n",
    "\n",
    "# ====== Load data ======\n",
    "print(\"Loading data...\")\n",
    "# NOTE: pickle.load can execute arbitrary code - only load the file written by the preprocessing cell.\n",
    "with open(os.path.join(standard_path, 'xgboost_data.pkl'), 'rb') as in_file:\n",
    "    xgboost_dataset = pickle.load(in_file)\n",
    "X_train = xgboost_dataset['X_train']\n",
    "X_valid = xgboost_dataset['X_valid']\n",
    "X_test = xgboost_dataset['X_test']\n",
    "y_train = xgboost_dataset['y_train']\n",
    "y_valid = xgboost_dataset['y_valid']\n",
    "y_test = xgboost_dataset['y_test']\n",
    "assert len(X_train) == len(y_train) and len(X_valid) == len(y_valid) and len(X_test) == len(y_test), \"feature/label row mismatch\"\n",
    "\n",
    "print(f\"X_train shape: {X_train.shape}\")\n",
    "print(f\"X_valid shape: {X_valid.shape}\")\n",
    "print(f\"X_test shape: {X_test.shape}\")\n",
    "print(f\"Shape of y: {y_train.shape}, {y_valid.shape}, {y_test.shape}\")\n",
    "print(f\"The number of element with value of 1 is: {np.sum(y_train == 1).item()}\\n\")\n",
    "\n",
    "# ====== Define evaluation function ======\n",
    "def evaluate_xgboost(model, X, y, output_size=1, threshold=0.5):\n",
    "    \"\"\"\n",
    "    Evaluate XGBoost model with same metrics as Transformer baseline.\n",
    "\n",
    "    Args:\n",
    "        model: fitted binary classifier exposing predict_proba (e.g. xgb.XGBClassifier).\n",
    "        X: feature matrix, one row per sample.\n",
    "        y: binary (0/1) labels, one per row of X.\n",
    "        output_size: unused; kept for signature compatibility (only the positive-class\n",
    "            column of predict_proba is scored).\n",
    "        threshold: probability cut-off used to binarize predictions for F1 and Kappa.\n",
    "            NOTE(review): with imbalanced labels 0.5 may be a poor cut-off; consider\n",
    "            tuning it on the validation set.\n",
    "\n",
    "    Returns:\n",
    "        (loss, metrics): binary log loss on (X, y) and a dict with keys\n",
    "        'auprc', 'auroc', 'f1', 'kappa'.\n",
    "    \"\"\"\n",
    "    from sklearn.metrics import log_loss  # local import: not brought in by the import cell\n",
    "\n",
    "    y_pred_proba = model.predict_proba(X)[:, 1]\n",
    "    y_pred_binary = (y_pred_proba > threshold).astype(int)\n",
    "\n",
    "    auprc = average_precision_score(y, y_pred_proba)\n",
    "    auroc = roc_auc_score(y, y_pred_proba)\n",
    "    f1 = f1_score(y, y_pred_binary)\n",
    "    kappa = cohen_kappa_score(y, y_pred_binary)\n",
    "    # labels=[0, 1] keeps log_loss well-defined even if y happens to contain a single class.\n",
    "    loss = log_loss(y, y_pred_proba, labels=[0, 1])\n",
    "\n",
    "    print(f'    Evaluation: AUPRC: {auprc:.4f} --- AUROC: {auroc:.4f} --- F1: {f1:.4f} --- Kappa: {kappa:.4f}')\n",
    "    return loss, {'auprc': auprc, 'auroc': auroc, 'f1': f1, 'kappa': kappa}\n",
    "\n",
    "# ====== Model setup ======\n",
    "output_size = y_train.shape[1] if y_train.ndim == 2 else 1\n",
    "print(f\"Output size: {output_size}\")\n",
    "\n",
    "# Initialize XGBoost model (library defaults except objective, eval metric and seed).\n",
    "# NOTE: eval_set below is only monitored - there is no early stopping, so all boosting rounds are kept.\n",
    "xgb_params = dict(objective='binary:logistic', eval_metric='logloss', random_state=seed)\n",
    "model = xgb.XGBClassifier(**xgb_params)\n",
    "\n",
    "print(f\"Model: {model.__class__.__name__}\")\n",
    "# Printed from the same dict the model was built with, so the log cannot drift from the real params.\n",
    "print(\"Parameters: \" + \", \".join(f\"{k}={v}\" for k, v in xgb_params.items()) + \"\\n\")\n",
    "\n",
    "# ====== Training ======\n",
    "print(\"Training XGBoost model...\")\n",
    "print(\"=\" * 50)\n",
    "\n",
    "# perf_counter is monotonic, so the measured duration is immune to wall-clock adjustments.\n",
    "st = time.perf_counter()\n",
    "model.fit(\n",
    "    X_train,\n",
    "    y_train,\n",
    "    eval_set=[(X_valid, y_valid)],\n",
    "    verbose=True\n",
    ")\n",
    "et = time.perf_counter()\n",
    "\n",
    "# Calculate time cost, formatted as s, m+s or h+m+s\n",
    "time_cost_seconds = et - st\n",
    "if time_cost_seconds <= 60:\n",
    "    time_str = '%.1fs' % time_cost_seconds\n",
    "elif time_cost_seconds <= 3600:\n",
    "    minutes, seconds = divmod(time_cost_seconds, 60)\n",
    "    time_str = '%dm%.1fs' % (minutes, seconds)\n",
    "else:\n",
    "    hours, remainder = divmod(time_cost_seconds, 3600)\n",
    "    minutes, seconds = divmod(remainder, 60)\n",
    "    time_str = '%dh%dm%.1fs' % (hours, minutes, seconds)\n",
    "\n",
    "print(f\"\\nTraining completed! Time cost: {time_str}\\n\")\n",
    "\n",
    "# ====== Evaluation ======\n",
    "print(\"=\" * 50)\n",
    "print(\"Evaluation on test set:\")\n",
    "print(\"=\" * 50)\n",
    "test_loss, _ = evaluate_xgboost(model, X_test, y_test, output_size)\n",
    "\n",
    "print(\"\\n\" + \"=\" * 50)\n",
    "print(\"XGBoost pipeline completed!\")\n",
    "print(\"=\" * 50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
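    "## Situation 2. If parser.parse(sample_num=None, sorting=False)\n",
    "\n",
    "The preprocessing cell below repeats Situation 1 with `sorting=False`. Note that the number of positive training labels changes from 1303 (sorted) to 4049 (unsorted): the readmission labels depend on the order of admissions, so results from the two situations are not directly comparable."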
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XGBoost Readmission Prediction - Data Preprocessing\n",
      "==================================================\n",
      "Dataset: mimic3\n",
      "Task: readmission\n",
      "Feature keys: ['d', 'p', 'm']\n",
      "Train patients: 4000, Test patients: 2000\n",
      "Random seed: 6666\n",
      "\n",
      "Step 1: Parsing CSV files...\n",
      "parsing the csv file of admission ...\n",
      "\t58976 in 58976 rows\n",
      "Sorting admissions:  False\n",
      "parsing csv file of diagnosis ...\n",
      "\t651047 in 651047 rows\n",
      "parsing csv file of procedures ...\n",
      "\t240095 in 240095 rows\n",
      "parsing csv file of medications ...\n",
      "\tmapping NDC to ATC codes...\n",
      "\t 1930338 Index(['SUBJECT_ID', 'HADM_ID', 'NDC'], dtype='object')\n",
      "\t 878345 Index(['SUBJECT_ID', 'HADM_ID', 'NDC'], dtype='object')\n",
      "\t878345 in 878345 rows\n",
      "calibrating patients by admission ...\n",
      "calibrating admission by patients ...\n",
      "Aligning admissions across different concepts ...\n",
      "\tnum of total admission:  19894\n",
      "\tnum of valid admission:  17926\n",
      "\tvalid diagnosis visit num:  17325\n",
      "\tvalid procedure visit num:  17325\n",
      "\tvalid medication visit num:  17325\n",
      "Total patients: 6497\n",
      "Total admissions: 17325\n",
      "\n",
      "Step 2: Encoding medical codes...\n",
      "# Diagnosis: 4702; # Procedure: 1431; # Medication: 151\n",
      "\n",
      "Step 3: Loading patient demographic information...\n",
      "Gender: {'M': 3617, 'F': 2880}\n",
      "Age Groups: {'<18': 49, '18-30': 144, '30-60': 1949, '>60': 4355}\n",
      "Gender vs Age Groups: {'M': {'<18': 30, '18-30': 75, '30-60': 1139, '>60': 2373}, 'F': {'<18': 19, '18-30': 69, '30-60': 810, '>60': 1982}}\n",
      "Loaded demographic info for 6497 patients\n",
      "\n",
      "Step 4: Splitting patients into train/valid/test sets...\n",
      "\t100%00%\n",
      "Train: 4000, Valid: 497, Test: 2000\n",
      "\n",
      "Step 5: Preparing admission events...\n",
      "Admission events prepared\n",
      "\n",
      "Step 6: Building multi-hot encoded features...\n",
      "Current concept type of codes_x: d\n",
      "\tShape of x: (7188, 42, 4702), (2871, 42, 4702)\n",
      "Current concept type of codes_x: p\n",
      "\tShape of x: (7188, 42, 1431), (2871, 42, 1431)\n",
      "Current concept type of codes_x: m\n",
      "\tShape of x: (7188, 42, 151), (2871, 42, 151)\n",
      "\n",
      "Step 7: Building readmission labels...\n",
      "Shape of y: (7188,), (769,), (2871,)\n",
      "The number of element with value of 1 is: 4049\n",
      "\n",
      "Step 8: Extracting most recent visit features for XGBoost...\n",
      "X_train shape: (7188, 6284)\n",
      "X_valid shape: (769, 6284)\n",
      "X_test shape: (2871, 6284)\n",
      "\n",
      "Step 9: Saving processed data to standard path...\n",
      "Data saved to data\\mimic3\\standard\\xgboost_data.pkl\n",
      "\n",
      "==================================================\n",
      "Data preprocessing completed!\n",
      "==================================================\n"
     ]
    }
   ],
   "source": [
    "# Cell 1: Data Preprocessing Pipeline\n",
    "# ====== Configuration ======\n",
    "dataset = 'mimic3'\n",
    "data_path = 'data'\n",
    "dataset_path = os.path.join(data_path, dataset)\n",
    "raw_path = os.path.join(dataset_path, 'raw')\n",
    "parsed_path = os.path.join(dataset_path, 'parsed')\n",
    "standard_path = os.path.join(dataset_path, 'standard')\n",
    "\n",
    "# Feature configuration\n",
    "feature_keys = ['d', 'p', 'm']  # diagnosis, procedure, medication\n",
    "task = 'readmission'\n",
    "train_num = 4000\n",
    "test_num = 2000\n",
    "seed = 6666\n",
    "# Apply the seed here (not only in the training cell) so the patient split is reproducible\n",
    "# on a fresh kernel in case the split helpers draw from the global RNGs.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "print(\"=\" * 50)\n",
    "print(\"XGBoost Readmission Prediction - Data Preprocessing\")\n",
    "print(\"=\" * 50)\n",
    "print(f\"Dataset: {dataset}\")\n",
    "print(f\"Task: {task}\")\n",
    "print(f\"Feature keys: {feature_keys}\")\n",
    "print(f\"Train patients: {train_num}, Test patients: {test_num}\")\n",
    "print(f\"Random seed: {seed}\\n\")\n",
    "\n",
    "print(\"Step 1: Parsing CSV files...\")\n",
    "parser = Mimic3Parser(raw_path, procedure=True, medication=True)\n",
    "# NOTE: admissions are left unsorted here; the readmission labels built in Step 7 depend on admission order.\n",
    "patient_admission, admissions = parser.parse(sample_num=None, sorting=False)\n",
    "admission_codes, admission_pcs, admission_med = admissions\n",
    "\n",
    "print(f\"Total patients: {len(patient_admission)}\")\n",
    "print(f\"Total admissions: {len(admission_codes)}\\n\")\n",
    "\n",
    "print(\"Step 2: Encoding medical codes...\")\n",
    "admission_codes_encoded, code_map = encode_code(patient_admission, admission_codes)\n",
    "admission_pcs_encoded, pcs_map = encode_code(patient_admission, admission_pcs)\n",
    "admission_med_encoded, med_map = encode_code(patient_admission, admission_med)\n",
    "\n",
    "code_nums = {\n",
    "    'd': len(code_map),\n",
    "    'p': len(pcs_map),\n",
    "    'm': len(med_map)\n",
    "}\n",
    "print(f\"# Diagnosis: {code_nums['d']}; # Procedure: {code_nums['p']}; # Medication: {code_nums['m']}\\n\")\n",
    "\n",
    "print(\"Step 3: Loading patient demographic information...\")\n",
    "patient_info = parse_patient_info(raw_path, parsed_path)\n",
    "print(f\"Loaded demographic info for {len(patient_info)} patients\\n\")\n",
    "\n",
    "print(\"Step 4: Splitting patients into train/valid/test sets...\")\n",
    "train_pids, valid_pids, test_pids_g1, test_pids_g2 = split_patients_disparity(\n",
    "    patient_admission=patient_admission,\n",
    "    admission_codes=admission_codes,\n",
    "    code_map=code_map,\n",
    "    train_num=train_num,\n",
    "    test_num=test_num,\n",
    "    patient_info=patient_info,\n",
    "    feature_key=\"GENDER\",\n",
    "    g1_value=\"M\",\n",
    ")\n",
    "test_pids = np.concatenate((test_pids_g1, test_pids_g2))\n",
    "# A patient must appear in exactly one split, otherwise test metrics leak training data.\n",
    "assert not set(train_pids) & set(valid_pids), \"train/valid patient overlap\"\n",
    "assert not set(train_pids) & set(test_pids), \"train/test patient overlap\"\n",
    "assert not set(valid_pids) & set(test_pids), \"valid/test patient overlap\"\n",
    "print(f\"Train: {len(train_pids)}, Valid: {len(valid_pids)}, Test: {len(test_pids)}\\n\")\n",
    "\n",
    "print(\"Step 5: Preparing admission events...\")\n",
    "admission_events_encoded = {\n",
    "    'd': admission_codes_encoded,\n",
    "    'p': admission_pcs_encoded,\n",
    "    'm': admission_med_encoded\n",
    "}\n",
    "print(\"Admission events prepared\\n\")\n",
    "\n",
    "print(\"Step 6: Building multi-hot encoded features...\")\n",
    "\n",
    "def build_code_x_multi_hot_fixed(pids, patient_admission, admission_events_encoded, event_types, code_nums):\n",
    "    \"\"\"\n",
    "    Build multi-hot encoded visit histories, one sample per prediction point.\n",
    "\n",
    "    A patient with m admissions yields m - 1 samples: the patient's k-th sample holds\n",
    "    admissions 0..k, i.e. the history available before admission k + 1. Encoded codes\n",
    "    are shifted by -1 to column indices (assumes encode_code yields 1-based ids - TODO confirm).\n",
    "\n",
    "    Args:\n",
    "        pids: patient ids to include, in output order.\n",
    "        patient_admission: dict pid -> list of admissions, each indexable by EHRParser.adm_id_col.\n",
    "        admission_events_encoded: dict event type -> {admission id: encoded codes}.\n",
    "        event_types: event types to build, e.g. ['d', 'p', 'm'].\n",
    "        code_nums: dict event type -> vocabulary size.\n",
    "\n",
    "    Returns:\n",
    "        (x, lens), both keyed by the entries of event_types: x[e] is a bool array of shape\n",
    "        (n_samples, max_admission_num, code_nums[e]); lens[e] is an int array of shape\n",
    "        (n_samples,) holding the number of visits in each sample.\n",
    "    \"\"\"\n",
    "    n = sum(len(patient_admission[pid]) - 1 for pid in pids)\n",
    "    # Pad to the longest history over *all* patients so every split shares the same visit axis.\n",
    "    max_admission_num = max(len(admissions) for admissions in patient_admission.values())\n",
    "    x, lens = {}, {}\n",
    "\n",
    "    for e_type in event_types:\n",
    "        x_t = np.zeros((n, max_admission_num, code_nums[e_type]), dtype=bool)\n",
    "        lens_t = np.zeros((n,), dtype=int)\n",
    "\n",
    "        # Explicit running offset: avoids negative-index wraparound for the first patient\n",
    "        # (which broke the fill when pids held a single patient).\n",
    "        start = 0\n",
    "        for pid in pids:\n",
    "            admissions = patient_admission[pid]\n",
    "            end = start + len(admissions) - 1  # one past this patient's last sample row\n",
    "            for k, admission in enumerate(admissions[:-1]):\n",
    "                codes = np.array(admission_events_encoded[e_type][admission[EHRParser.adm_id_col]], dtype=int) - 1\n",
    "                # Admission k belongs to the history of this patient's samples k..m-2 (rows start+k .. end-1).\n",
    "                x_t[start + k:end, k, codes] = True\n",
    "                lens_t[start + k] = k + 1\n",
    "            start = end\n",
    "        x[e_type], lens[e_type] = x_t, lens_t\n",
    "    return x, lens\n",
    "\n",
    "train_codes_x, train_visit_lens = build_code_x_multi_hot_fixed(\n",
    "    train_pids, patient_admission, admission_events_encoded, feature_keys, code_nums\n",
    ")\n",
    "valid_codes_x, valid_visit_lens = build_code_x_multi_hot_fixed(\n",
    "    valid_pids, patient_admission, admission_events_encoded, feature_keys, code_nums\n",
    ")\n",
    "test_codes_x, test_visit_lens = build_code_x_multi_hot_fixed(\n",
    "    test_pids, patient_admission, admission_events_encoded, feature_keys, code_nums\n",
    ")\n",
    "\n",
    "for t in train_codes_x:\n",
    "    print(f\"Current concept type of codes_x: {t}\")\n",
    "    print(f\"\\tShape of x: {train_codes_x[t].shape}, {test_codes_x[t].shape}\")\n",
    "\n",
    "print(\"\\nStep 7: Building readmission labels...\")\n",
    "train_y = build_code_y_binary(train_pids, patient_admission, task=task)\n",
    "valid_y = build_code_y_binary(valid_pids, patient_admission, task=task)\n",
    "test_y = build_code_y_binary(test_pids, patient_admission, task=task)\n",
    "\n",
    "print(f\"Shape of y: {train_y.shape}, {valid_y.shape}, {test_y.shape}\")\n",
    "print(f\"The number of element with value of 1 is: {np.sum(train_y == 1).item()}\\n\")\n",
    "\n",
    "# Features and labels come from separate helpers; make sure their sample rows line up.\n",
    "for split_name, split_x, split_y in ((\"train\", train_codes_x, train_y), (\"valid\", valid_codes_x, valid_y), (\"test\", test_codes_x, test_y)):\n",
    "    for t in feature_keys:\n",
    "        assert split_x[t].shape[0] == len(split_y), f\"{split_name}/{t}: {split_x[t].shape[0]} samples vs {len(split_y)} labels\"\n",
    "\n",
    "print(\"Step 8: Extracting most recent visit features for XGBoost...\")\n",
    "\n",
    "def extract_last_visit_multihot(codes_x, visit_lens, feature_keys):\n",
    "    \"\"\"Extract the multi-hot vector of each sample's most recent visit.\n",
    "\n",
    "    Args:\n",
    "        codes_x: dict key -> array (N, max_visits, vocab_size_key).\n",
    "        visit_lens: dict key -> per-sample visit counts; the most recent visit\n",
    "            of sample i sits at index visit_lens[key][i] - 1.\n",
    "        feature_keys: keys to concatenate, in order (e.g. ['d', 'p', 'm']).\n",
    "\n",
    "    Returns:\n",
    "        float array (N, sum of vocab sizes over feature_keys).\n",
    "    \"\"\"\n",
    "    N = codes_x[feature_keys[0]].shape[0]\n",
    "    rows = np.arange(N)\n",
    "    feature_vectors = []\n",
    "    \n",
    "    for key in feature_keys:  # ['d', 'p', 'm']\n",
    "        last_visit_idx = np.asarray(visit_lens[key], dtype=int)[:N] - 1\n",
    "        # Pair row i with visit last_visit_idx[i] in one fancy-indexing step.\n",
    "        # Unlike stacking per-row slices with np.array(list), this keeps a\n",
    "        # 2-D (0, vocab_size_key) shape when N == 0, so the concatenate below works.\n",
    "        feature_vectors.append(codes_x[key][rows, last_visit_idx])  # (N, vocab_size_key)\n",
    "    \n",
    "    X = np.concatenate(feature_vectors, axis=1)  # (N, vocab_d + vocab_p + vocab_m)\n",
    "    return X.astype(float)\n",
    "\n",
    "X_train = extract_last_visit_multihot(train_codes_x, train_visit_lens, feature_keys)\n",
    "X_valid = extract_last_visit_multihot(valid_codes_x, valid_visit_lens, feature_keys)\n",
    "X_test = extract_last_visit_multihot(test_codes_x, test_visit_lens, feature_keys)\n",
    "\n",
    "print(f\"X_train shape: {X_train.shape}\")\n",
    "print(f\"X_valid shape: {X_valid.shape}\")\n",
    "print(f\"X_test shape: {X_test.shape}\\n\")\n",
    "\n",
    "print(\"Step 9: Saving processed data to standard path...\")\n",
    "# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs\n",
    "os.makedirs(standard_path, exist_ok=True)\n",
    "xgboost_data = {\n",
    "    'X_train': X_train,\n",
    "    'X_valid': X_valid,\n",
    "    'X_test': X_test,\n",
    "    'y_train': train_y,\n",
    "    'y_valid': valid_y,\n",
    "    'y_test': test_y,\n",
    "}\n",
    "xgboost_data_file = os.path.join(standard_path, 'xgboost_data.pkl')\n",
    "# The context manager flushes and closes the handle even if pickling fails.\n",
    "with open(xgboost_data_file, 'wb') as fh:\n",
    "    pickle.dump(xgboost_data, fh)\n",
    "print(f\"Data saved to {xgboost_data_file}\\n\")\n",
    "\n",
    "print(\"=\" * 50)\n",
    "print(\"Data preprocessing completed!\")\n",
    "print(\"=\" * 50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XGBoost Training and Evaluation\n",
      "==================================================\n",
      "Task: readmission\n",
      "Dataset: mimic3\n",
      "Random seed: 6666\n",
      "\n",
      "Loading data...\n",
      "X_train shape: (7188, 6284)\n",
      "X_valid shape: (769, 6284)\n",
      "X_test shape: (2871, 6284)\n",
      "Shape of y: (7188,), (769,), (2871,)\n",
      "The number of element with value of 1 is: 4049\n",
      "\n",
      "Output size: 1\n",
      "Model: XGBClassifier\n",
      "Parameters: objective=binary:logistic, eval_metric=logloss, random_state=6666\n",
      "\n",
      "Training XGBoost model...\n",
      "==================================================\n",
      "[0]\tvalidation_0-logloss:0.67975\n",
      "[1]\tvalidation_0-logloss:0.67163\n",
      "[2]\tvalidation_0-logloss:0.66867\n",
      "[3]\tvalidation_0-logloss:0.66245\n",
      "[4]\tvalidation_0-logloss:0.65998\n",
      "[5]\tvalidation_0-logloss:0.65751\n",
      "[6]\tvalidation_0-logloss:0.65797\n",
      "[7]\tvalidation_0-logloss:0.65677\n",
      "[8]\tvalidation_0-logloss:0.65483\n",
      "[9]\tvalidation_0-logloss:0.65405\n",
      "[10]\tvalidation_0-logloss:0.65430\n",
      "[11]\tvalidation_0-logloss:0.65531\n",
      "[12]\tvalidation_0-logloss:0.65347\n",
      "[13]\tvalidation_0-logloss:0.65424\n",
      "[14]\tvalidation_0-logloss:0.65547\n",
      "[15]\tvalidation_0-logloss:0.65838\n",
      "[16]\tvalidation_0-logloss:0.65962\n",
      "[17]\tvalidation_0-logloss:0.65948\n",
      "[18]\tvalidation_0-logloss:0.65816\n",
      "[19]\tvalidation_0-logloss:0.65888\n",
      "[20]\tvalidation_0-logloss:0.65836\n",
      "[21]\tvalidation_0-logloss:0.65861\n",
      "[22]\tvalidation_0-logloss:0.65865\n",
      "[23]\tvalidation_0-logloss:0.65888\n",
      "[24]\tvalidation_0-logloss:0.65846\n",
      "[25]\tvalidation_0-logloss:0.65810\n",
      "[26]\tvalidation_0-logloss:0.65898\n",
      "[27]\tvalidation_0-logloss:0.65900\n",
      "[28]\tvalidation_0-logloss:0.66225\n",
      "[29]\tvalidation_0-logloss:0.66218\n",
      "[30]\tvalidation_0-logloss:0.66286\n",
      "[31]\tvalidation_0-logloss:0.66211\n",
      "[32]\tvalidation_0-logloss:0.66279\n",
      "[33]\tvalidation_0-logloss:0.66389\n",
      "[34]\tvalidation_0-logloss:0.66507\n",
      "[35]\tvalidation_0-logloss:0.66581\n",
      "[36]\tvalidation_0-logloss:0.66611\n",
      "[37]\tvalidation_0-logloss:0.66625\n",
      "[38]\tvalidation_0-logloss:0.66648\n",
      "[39]\tvalidation_0-logloss:0.66683\n",
      "[40]\tvalidation_0-logloss:0.66749\n",
      "[41]\tvalidation_0-logloss:0.66658\n",
      "[42]\tvalidation_0-logloss:0.66688\n",
      "[43]\tvalidation_0-logloss:0.66666\n",
      "[44]\tvalidation_0-logloss:0.66698\n",
      "[45]\tvalidation_0-logloss:0.66633\n",
      "[46]\tvalidation_0-logloss:0.66615\n",
      "[47]\tvalidation_0-logloss:0.66532\n",
      "[48]\tvalidation_0-logloss:0.66506\n",
      "[49]\tvalidation_0-logloss:0.66587\n",
      "[50]\tvalidation_0-logloss:0.66628\n",
      "[51]\tvalidation_0-logloss:0.66690\n",
      "[52]\tvalidation_0-logloss:0.66707\n",
      "[53]\tvalidation_0-logloss:0.66763\n",
      "[54]\tvalidation_0-logloss:0.66856\n",
      "[55]\tvalidation_0-logloss:0.66930\n",
      "[56]\tvalidation_0-logloss:0.67002\n",
      "[57]\tvalidation_0-logloss:0.66981\n",
      "[58]\tvalidation_0-logloss:0.66977\n",
      "[59]\tvalidation_0-logloss:0.66986\n",
      "[60]\tvalidation_0-logloss:0.66972\n",
      "[61]\tvalidation_0-logloss:0.66934\n",
      "[62]\tvalidation_0-logloss:0.66890\n",
      "[63]\tvalidation_0-logloss:0.66868\n",
      "[64]\tvalidation_0-logloss:0.66864\n",
      "[65]\tvalidation_0-logloss:0.66773\n",
      "[66]\tvalidation_0-logloss:0.66870\n",
      "[67]\tvalidation_0-logloss:0.66985\n",
      "[68]\tvalidation_0-logloss:0.66921\n",
      "[69]\tvalidation_0-logloss:0.67021\n",
      "[70]\tvalidation_0-logloss:0.67050\n",
      "[71]\tvalidation_0-logloss:0.67001\n",
      "[72]\tvalidation_0-logloss:0.67050\n",
      "[73]\tvalidation_0-logloss:0.66964\n",
      "[74]\tvalidation_0-logloss:0.67103\n",
      "[75]\tvalidation_0-logloss:0.67044\n",
      "[76]\tvalidation_0-logloss:0.67103\n",
      "[77]\tvalidation_0-logloss:0.67110\n",
      "[78]\tvalidation_0-logloss:0.67185\n",
      "[79]\tvalidation_0-logloss:0.67274\n",
      "[80]\tvalidation_0-logloss:0.67401\n",
      "[81]\tvalidation_0-logloss:0.67420\n",
      "[82]\tvalidation_0-logloss:0.67466\n",
      "[83]\tvalidation_0-logloss:0.67472\n",
      "[84]\tvalidation_0-logloss:0.67512\n",
      "[85]\tvalidation_0-logloss:0.67488\n",
      "[86]\tvalidation_0-logloss:0.67555\n",
      "[87]\tvalidation_0-logloss:0.67530\n",
      "[88]\tvalidation_0-logloss:0.67523\n",
      "[89]\tvalidation_0-logloss:0.67616\n",
      "[90]\tvalidation_0-logloss:0.67527\n",
      "[91]\tvalidation_0-logloss:0.67647\n",
      "[92]\tvalidation_0-logloss:0.67690\n",
      "[93]\tvalidation_0-logloss:0.67718\n",
      "[94]\tvalidation_0-logloss:0.67687\n",
      "[95]\tvalidation_0-logloss:0.67799\n",
      "[96]\tvalidation_0-logloss:0.67857\n",
      "[97]\tvalidation_0-logloss:0.67769\n",
      "[98]\tvalidation_0-logloss:0.67672\n",
      "[99]\tvalidation_0-logloss:0.67732\n",
      "\n",
      "Training completed! Time cost: 1.8s\n",
      "\n",
      "==================================================\n",
      "Evaluation on test set:\n",
      "==================================================\n",
      "    Evaluation: AUPRC: 0.7139 --- AUROC: 0.6598 --- F1: 0.6818 --- Kappa: 0.2132\n",
      "\n",
      "==================================================\n",
      "XGBoost pipeline completed!\n",
      "==================================================\n"
     ]
    }
   ],
   "source": [
    "# Cell 2: XGBoost Training and Evaluation\n",
    "\n",
    "# ====== Configuration ======\n",
    "batch_size = 128  # Not used for XGBoost but kept for consistency\n",
    "task = 'readmission'\n",
    "feature_keys = ['d', 'p', 'm']\n",
    "data_path = 'data'\n",
    "dataset = 'mimic3'\n",
    "dataset_path = os.path.join(data_path, dataset)\n",
    "standard_path = os.path.join(dataset_path, 'standard')\n",
    "\n",
    "seed = 6666\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "print(\"=\" * 50)\n",
    "print(\"XGBoost Training and Evaluation\")\n",
    "print(\"=\" * 50)\n",
    "print(f\"Task: {task}\")\n",
    "print(f\"Dataset: {dataset}\")\n",
    "print(f\"Random seed: {seed}\\n\")\n",
    "\n",
    "# ====== Load data ======\n",
    "print(\"Loading data...\")\n",
    "# NOTE: pickle.load can execute arbitrary code; only load the file written by the preprocessing cell.\n",
    "# The context manager closes the handle once the data is read.\n",
    "with open(os.path.join(standard_path, 'xgboost_data.pkl'), 'rb') as fh:\n",
    "    xgboost_dataset = pickle.load(fh)\n",
    "X_train = xgboost_dataset['X_train']\n",
    "X_valid = xgboost_dataset['X_valid']\n",
    "X_test = xgboost_dataset['X_test']\n",
    "y_train = xgboost_dataset['y_train']\n",
    "y_valid = xgboost_dataset['y_valid']\n",
    "y_test = xgboost_dataset['y_test']\n",
    "\n",
    "print(f\"X_train shape: {X_train.shape}\")\n",
    "print(f\"X_valid shape: {X_valid.shape}\")\n",
    "print(f\"X_test shape: {X_test.shape}\")\n",
    "print(f\"Shape of y: {y_train.shape}, {y_valid.shape}, {y_test.shape}\")\n",
    "print(f\"The number of element with value of 1 is: {np.sum(y_train == 1).item()}\\n\")\n",
    "\n",
    "# ====== Define evaluation function ======\n",
    "def evaluate_xgboost(model, X, y, output_size=1, threshold=0.5):\n",
    "    \"\"\"Evaluate a fitted binary classifier with the same metrics as the Transformer baseline.\n",
    "\n",
    "    AUPRC/AUROC use the positive-class probability (predict_proba column 1);\n",
    "    F1/Kappa use hard labels obtained as proba > threshold.\n",
    "\n",
    "    Args:\n",
    "        model: fitted classifier exposing predict_proba.\n",
    "        X: feature matrix.\n",
    "        y: binary ground-truth labels.\n",
    "        output_size: unused; kept for signature parity with the baseline evaluator.\n",
    "        threshold: decision threshold for F1/Kappa (default 0.5, the previously hard-coded value).\n",
    "\n",
    "    Returns:\n",
    "        (None, metrics): no loss is computed here, so the first item is None;\n",
    "        metrics maps 'auprc', 'auroc', 'f1' and 'kappa' to their values.\n",
    "    \"\"\"\n",
    "    y_pred_proba = model.predict_proba(X)[:, 1]\n",
    "    y_pred_binary = (y_pred_proba > threshold).astype(int)\n",
    "\n",
    "    auprc = average_precision_score(y, y_pred_proba)\n",
    "    auroc = roc_auc_score(y, y_pred_proba)\n",
    "    f1 = f1_score(y, y_pred_binary)\n",
    "    kappa = cohen_kappa_score(y, y_pred_binary)\n",
    "    \n",
    "    print(f'    Evaluation: AUPRC: {auprc:.4f} --- AUROC: {auroc:.4f} --- F1: {f1:.4f} --- Kappa: {kappa:.4f}')\n",
    "    return None, {'auprc': auprc, 'auroc': auroc, 'f1': f1, 'kappa': kappa}\n",
    "\n",
    "# ====== Model setup ======\n",
    "output_size = 1 if len(y_train.shape) != 2 else y_train.shape[1]\n",
    "print(f\"Output size: {output_size}\")\n",
    "\n",
    "# Baseline XGBoost: library defaults except for the objective, the monitored metric and the seed\n",
    "model = xgb.XGBClassifier(objective='binary:logistic', eval_metric='logloss', random_state=seed)\n",
    "\n",
    "print(f\"Model: {type(model).__name__}\")\n",
    "print(f\"Parameters: objective=binary:logistic, eval_metric=logloss, random_state={seed}\\n\")\n",
    "\n",
    "banner = \"=\" * 50\n",
    "\n",
    "# ====== Training ======\n",
    "print(\"Training XGBoost model...\")\n",
    "print(banner)\n",
    "\n",
    "# eval_set only reports per-round validation logloss; no early stopping is configured.\n",
    "st = time.time()\n",
    "model.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], verbose=True)\n",
    "et = time.time()\n",
    "\n",
    "# Human-readable wall-clock cost: h+m+s above an hour, m+s above a minute, else seconds\n",
    "time_cost_seconds = et - st\n",
    "if time_cost_seconds > 3600:\n",
    "    time_str = '%dh%dm%.1fs' % (time_cost_seconds // 3600, (time_cost_seconds % 3600) // 60, time_cost_seconds % 60)\n",
    "elif time_cost_seconds > 60:\n",
    "    minutes, seconds = divmod(time_cost_seconds, 60)\n",
    "    time_str = '%dm%.1fs' % (minutes, seconds)\n",
    "else:\n",
    "    time_str = '%.1fs' % time_cost_seconds\n",
    "\n",
    "print(f\"\\nTraining completed! Time cost: {time_str}\\n\")\n",
    "\n",
    "# ====== Evaluation ======\n",
    "print(banner)\n",
    "print(\"Evaluation on test set:\")\n",
    "print(banner)\n",
    "test_loss, _ = evaluate_xgboost(model, X_test, y_test, output_size)\n",
    "\n",
    "print(\"\\n\" + banner)\n",
    "print(\"XGBoost pipeline completed!\")\n",
    "print(banner)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
