{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e8b39b24-a571-45ea-981e-cd12349e2d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import xgboost as xgb\n",
    "from sklearn.metrics import balanced_accuracy_score, f1_score, roc_auc_score, precision_recall_curve, auc\n",
    "from sklearn.model_selection import StratifiedGroupKFold, cross_val_score, train_test_split\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d31a1d1-678a-4693-b8d1-13783d671412",
   "metadata": {},
   "source": [
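    "## Train / test splits\n",
    "Load the precomputed cross-validation folds: an iterable of `(train_index, test_index)` row-index pairs, one per fold."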
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "33e0a7ac-5a66-4864-8ac5-80579a4d9651",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code - only open trusted files.\n",
    "with open(\"he_splits.pickle\", \"rb\") as f:\n",
    "    splits = pickle.load(f)\n",
    "# every fold is a (train_index, test_index) pair; a row on both sides would leak test data into training\n",
    "assert all(np.intersect1d(tr, te).size == 0 for tr, te in splits), \"train/test index overlap in he_splits.pickle\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d80521b8-5c3c-4863-83d4-bdb64cfba593",
   "metadata": {},
   "source": [
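    "## Feature dataset and labels\n",
    "Load the feature matrix `X` and the label vector `y`. Both are sliced with the same fold indices, so `y` must have one entry per row of `X`."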
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "309cea09-b6a3-4723-ad23-7983133b34dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pickle.load runs arbitrary code while unpickling; only open trusted files\n",
    "with open(\"starcoder2_dataset.pickle\", mode=\"rb\") as dataset_file:\n",
    "    X = pickle.load(dataset_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a2075d94-70df-4fca-be9b-eb03d4956cac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4100, 6912)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# rows are the samples selected by the fold indices, columns are the features (X.shape[-1])\n",
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4b2be9dd-fcb5-46f8-bcbd-193adc58e475",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pickle.load runs arbitrary code while unpickling; only open trusted files\n",
    "with open('starcoder2_humaneval_Y.pickle', 'rb') as f:\n",
    "    y = np.array(pickle.load(f))\n",
    "# X and y are sliced with the same fold indices, so they must have the same number of rows\n",
    "assert y.shape[0] == X.shape[0], 'got {} labels for {} feature rows'.format(y.shape[0], X.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3fa1f75b-7e6d-4f73-bedc-3af78c509cd6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4100,)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# one label per row of X (indexed by the same fold indices)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "851f2c64-c1d5-4fe2-9750-44b824dfd309",
   "metadata": {},
   "source": [
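    "## Train XGBoost\n",
    "Train one XGBoost classifier per fold on all features, keep each fitted model in `clfs`, and report ROC AUC and F1 on the fold's held-out rows."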
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9667bff8-9a39-443f-bd22-320fc8d3b003",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_classifier(X_train, y_train, X_test, y_test, random_state=0, verbose=True, **xgb_params):\n",
    "    \"\"\"Fit an XGBoost classifier on one fold and score it on the held-out rows.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X_train, y_train : training features and labels for this fold.\n",
    "    X_test, y_test : held-out features and labels for this fold.\n",
    "    random_state : explicit seed forwarded to ``XGBClassifier`` (pinned so\n",
    "        re-runs are reproducible).\n",
    "    verbose : if True, print the train/test array shapes.\n",
    "    **xgb_params : extra keyword arguments forwarded to ``XGBClassifier``;\n",
    "        they override ``tree_method`` / ``random_state`` when given.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (bst, rocauc, f1) : the fitted model, ROC AUC computed from column 1 of\n",
    "    ``predict_proba``, and F1 computed from the hard ``predict`` labels.\n",
    "    \"\"\"\n",
    "    if verbose:\n",
    "        print (X_train.shape, X_test.shape)\n",
    "        print (y_train.shape, y_test.shape)\n",
    "    params = {\"tree_method\": \"hist\", \"random_state\": random_state}\n",
    "    params.update(xgb_params)\n",
    "    bst = xgb.XGBClassifier(**params)\n",
    "    bst.fit(X_train, y_train)\n",
    "    test_pred = bst.predict(X_test)\n",
    "    test_proba = bst.predict_proba(X_test)\n",
    "\n",
    "    # ROC AUC is ranked on the positive-class score; F1 uses thresholded labels\n",
    "    rocauc = roc_auc_score(y_test, test_proba[:, 1])\n",
    "    f1 = f1_score(y_test, test_pred)\n",
    "\n",
    "    return bst, rocauc, f1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "664f4dde-d377-4165-a730-eb91d665be60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RUN 0\n",
      "(3275, 6912) (825, 6912)\n",
      "(3275,) (825,)\n",
      "RUN 1\n",
      "(3275, 6912) (825, 6912)\n",
      "(3275,) (825,)\n",
      "RUN 2\n",
      "(3275, 6912) (825, 6912)\n",
      "(3275,) (825,)\n",
      "RUN 3\n",
      "(3275, 6912) (825, 6912)\n",
      "(3275,) (825,)\n",
      "RUN 4\n",
      "(3300, 6912) (800, 6912)\n",
      "(3300,) (800,)\n"
     ]
    }
   ],
   "source": [
    "# Baseline: one full-feature model per fold. The fitted models are kept in\n",
    "# `clfs` because the feature-pruning section ranks features with clfs[i].\n",
    "clfs = []\n",
    "roc_auc_scores = []\n",
    "f1_scores = []\n",
    "for fold_no, (tr_idx, te_idx) in enumerate(splits):\n",
    "    print ('RUN {}'.format(fold_no))\n",
    "    model, fold_auc, fold_f1 = train_classifier(X[tr_idx], y[tr_idx], X[te_idx], y[te_idx])\n",
    "    clfs.append(model)\n",
    "    roc_auc_scores.append(fold_auc)\n",
    "    f1_scores.append(fold_f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d70262ee-70f8-4eb1-8988-2dfc44325248",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "ROC AUC:  0.829 0.027\n",
      "[0.805 0.88  0.815 0.813 0.833]\n",
      "\n",
      "F1 score 0.542 0.069\n",
      "[0.529 0.586 0.413 0.576 0.605]\n"
     ]
    }
   ],
   "source": [
    "# For each metric: mean and std across folds, then the per-fold values\n",
    "for metric_label, metric_scores in (('\\nROC AUC: ', roc_auc_scores), ('\\nF1 score', f1_scores)):\n",
    "    arr = np.asarray(metric_scores)\n",
    "    print (metric_label, np.mean(arr).round(3), np.std(arr).round(3))\n",
    "    print (arr.round(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba98f204-6d02-4e9e-b20c-e4f3fca4d092",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "106e6cbe-e98d-4754-a0cc-c1969b40c70f",
   "metadata": {},
   "source": [
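    "## Feature pruning\n",
    "For each ratio in `n_feat_ratio`, keep only the top-ranked features according to the `feature_importances_` of that fold's full-feature model (`clfs[i]`). Retrain on the reduced feature set and record the mean and standard deviation of ROC AUC and F1 across folds, keyed by the number of retained features."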
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4fb63deb-4590-4c50-8ffc-620b310e17a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RATIO: 1.0\n",
      "N FEATURES: 6912\n",
      "(3275, 6912) (825, 6912)\n",
      "(3275,) (825,)\n",
      "(3275, 6912) (825, 6912)\n",
      "(3275,) (825,)\n",
      "(3275, 6912) (825, 6912)\n",
      "(3275,) (825,)\n",
      "(3275, 6912) (825, 6912)\n",
      "(3275,) (825,)\n",
      "(3300, 6912) (800, 6912)\n",
      "(3300,) (800,)\n",
      "RATIO: 0.2\n",
      "N FEATURES: 1382\n",
      "(3275, 1382) (825, 1382)\n",
      "(3275,) (825,)\n",
      "(3275, 1382) (825, 1382)\n",
      "(3275,) (825,)\n",
      "(3275, 1382) (825, 1382)\n",
      "(3275,) (825,)\n",
      "(3275, 1382) (825, 1382)\n",
      "(3275,) (825,)\n",
      "(3300, 1382) (800, 1382)\n",
      "(3300,) (800,)\n",
      "RATIO: 0.1\n",
      "N FEATURES: 691\n",
      "(3275, 691) (825, 691)\n",
      "(3275,) (825,)\n",
      "(3275, 691) (825, 691)\n",
      "(3275,) (825,)\n",
      "(3275, 691) (825, 691)\n",
      "(3275,) (825,)\n",
      "(3275, 691) (825, 691)\n",
      "(3275,) (825,)\n",
      "(3300, 691) (800, 691)\n",
      "(3300,) (800,)\n",
      "RATIO: 0.05\n",
      "N FEATURES: 346\n",
      "(3275, 346) (825, 346)\n",
      "(3275,) (825,)\n",
      "(3275, 346) (825, 346)\n",
      "(3275,) (825,)\n",
      "(3275, 346) (825, 346)\n",
      "(3275,) (825,)\n",
      "(3275, 346) (825, 346)\n",
      "(3275,) (825,)\n",
      "(3300, 346) (800, 346)\n",
      "(3300,) (800,)\n",
      "RATIO: 0.01\n",
      "N FEATURES: 69\n",
      "(3275, 69) (825, 69)\n",
      "(3275,) (825,)\n",
      "(3275, 69) (825, 69)\n",
      "(3275,) (825,)\n",
      "(3275, 69) (825, 69)\n",
      "(3275,) (825,)\n",
      "(3275, 69) (825, 69)\n",
      "(3275,) (825,)\n",
      "(3300, 69) (800, 69)\n",
      "(3300,) (800,)\n",
      "RATIO: 0.005\n",
      "N FEATURES: 35\n",
      "(3275, 35) (825, 35)\n",
      "(3275,) (825,)\n",
      "(3275, 35) (825, 35)\n",
      "(3275,) (825,)\n",
      "(3275, 35) (825, 35)\n",
      "(3275,) (825,)\n",
      "(3275, 35) (825, 35)\n",
      "(3275,) (825,)\n",
      "(3300, 35) (800, 35)\n",
      "(3300,) (800,)\n",
      "RATIO: 0.001\n",
      "N FEATURES: 7\n",
      "(3275, 7) (825, 7)\n",
      "(3275,) (825,)\n",
      "(3275, 7) (825, 7)\n",
      "(3275,) (825,)\n",
      "(3275, 7) (825, 7)\n",
      "(3275,) (825,)\n",
      "(3275, 7) (825, 7)\n",
      "(3275,) (825,)\n",
      "(3300, 7) (800, 7)\n",
      "(3300,) (800,)\n",
      "RATIO: 0.0001\n",
      "N FEATURES: 1\n",
      "(3275, 1) (825, 1)\n",
      "(3275,) (825,)\n",
      "(3275, 1) (825, 1)\n",
      "(3275,) (825,)\n",
      "(3275, 1) (825, 1)\n",
      "(3275,) (825,)\n",
      "(3275, 1) (825, 1)\n",
      "(3275,) (825,)\n",
      "(3300, 1) (800, 1)\n",
      "(3300,) (800,)\n"
     ]
    }
   ],
   "source": [
    "# ratio of features to keep, relative to the full feature count\n",
    "n_feat_ratio = np.array([1.0, 0.2, 0.1, 0.05, 0.01, 0.005, 0.001, 0.0001])\n",
    "n_features = X.shape[-1]\n",
    "# [mean, std] across folds, keyed by number of retained features\n",
    "roc_auc_dict = {}\n",
    "f1_dict = {}\n",
    "for r in n_feat_ratio:\n",
    "    # clamp to [1, n_features] so tiny (or >1) ratios still give a valid count\n",
    "    n_valid_feat = min(n_features, max(1, int(np.round(n_features * r))))\n",
    "    if n_valid_feat in roc_auc_dict:\n",
    "        # another ratio already rounded to this count and would select the same\n",
    "        # features; skip instead of silently overwriting the stored result\n",
    "        print ('RATIO: {} -> {} features already evaluated, skipping'.format(r, n_valid_feat))\n",
    "        continue\n",
    "    print ('RATIO: {}'.format(r))\n",
    "    print ('N FEATURES: {}'.format(n_valid_feat))\n",
    "    # fold-level scores use their own names so the baseline roc_auc_scores /\n",
    "    # f1_scores lists from the training section are not clobbered\n",
    "    fold_roc_aucs = []\n",
    "    fold_f1s = []\n",
    "    for i, (train_index, test_index) in enumerate(splits):\n",
    "        if n_valid_feat == n_features:\n",
    "            # keep every column in its original order\n",
    "            idx = np.arange(n_features)\n",
    "        else:\n",
    "            # rank by the importances of this fold's full-feature model, which\n",
    "            # was fit on this fold's training rows only (no test leakage)\n",
    "            f_imp = clfs[i].feature_importances_\n",
    "            idx = np.argsort(f_imp)[::-1][:n_valid_feat]\n",
    "\n",
    "        X_train, X_test = X[train_index][:, idx], X[test_index][:, idx]\n",
    "        y_train, y_test = y[train_index], y[test_index]\n",
    "\n",
    "        _, rocauc, f1 = train_classifier(X_train, y_train, X_test, y_test)\n",
    "        fold_roc_aucs.append(rocauc)\n",
    "        fold_f1s.append(f1)\n",
    "\n",
    "    roc_auc_dict[n_valid_feat] = [np.mean(fold_roc_aucs), np.std(fold_roc_aucs)]\n",
    "    f1_dict[n_valid_feat] = [np.mean(fold_f1s), np.std(fold_f1s)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5a75f61a-cd74-4dab-b7d7-6a679945b65d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{6912: [0.8292541176535716, 0.02690027179445281],\n",
       " 1382: [0.8276969887545664, 0.026586875700867485],\n",
       " 691: [0.8214412912631979, 0.025205299374000414],\n",
       " 346: [0.8214373630515558, 0.03468322021272182],\n",
       " 69: [0.7933733723160551, 0.03410984022561642],\n",
       " 35: [0.7559856988384809, 0.018145698895861716],\n",
       " 7: [0.7005054641875688, 0.05537690094546059],\n",
       " 1: [0.59802156545816, 0.03944814082614698]}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# {number of retained features: [mean ROC AUC, std ROC AUC] across folds}\n",
    "roc_auc_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5f5fa740-4386-47c3-907c-6e3356265f5f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{6912: [0.5418316758169006, 0.06894192735186622],\n",
       " 1382: [0.5321938919451055, 0.0618384835715161],\n",
       " 691: [0.5647276696666125, 0.036563783696122766],\n",
       " 346: [0.5245616329511918, 0.06334539794211273],\n",
       " 69: [0.5035559872516128, 0.07697708578384646],\n",
       " 35: [0.4783804433596764, 0.031691606379855485],\n",
       " 7: [0.4112018370233871, 0.05147814422039116],\n",
       " 1: [0.3559568711108261, 0.0994750362814694]}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# {number of retained features: [mean F1, std F1] across folds}\n",
    "f1_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef626a5e-b463-4837-b2f2-5e65cf0e21a8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:base] *",
   "language": "python",
   "name": "conda-base-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
