{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "996ae376",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import imblearn\n",
    "import scipy\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import sklearn\n",
    "import pickle\n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "from copy import deepcopy\n",
    "from collections import Counter\n",
    "from imblearn.over_sampling import SMOTE\n",
    "from random import gauss\n",
    "from scipy.spatial import distance_matrix\n",
    "\n",
    "from scipy.optimize import minimize\n",
    "from numpy.random import rand\n",
    "from scipy.spatial import distance_matrix\n",
    "\n",
    "from dataset import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72f7ab17",
   "metadata": {},
   "source": [
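    "## NeMe\n",
    "\n",
    "Semi-factual search with SciPy's Nelder-Mead optimizer. For each selected test instance, `m` candidates are optimized jointly to (i) maximize the classifier's probability of the semi-factual class, (ii) maximize their pairwise distances (diversity), and (iii) stay close to the training data (nearest-neighbour distance), all within actionability bounds."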
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "aab9d214",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_cat_idxs():\n",
    "    \"\"\"\n",
    "    Compute the column range of every one-hot encoded categorical feature.\n",
    "\n",
    "    The encoded layout is taken to be: all continuous features first, then one\n",
    "    block per entry of ``enc.categories_`` whose width is that entry's length.\n",
    "\n",
    "    Returns a list of ``[start, end]`` pairs (end exclusive), one per categorical\n",
    "    feature, in ``enc.categories_`` order.\n",
    "    \"\"\"\n",
    "    ranges = []\n",
    "    offset = len(continuous_feature_names)\n",
    "    for categories in enc.categories_:\n",
    "        width = categories.shape[0]\n",
    "        ranges.append([offset, offset + width])\n",
    "        offset += width\n",
    "    return ranges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5a6ca84c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def neme_bounds(x):\n",
    "    \"\"\"\n",
    "    Build per-column (lower, upper) optimizer bounds for one instance ``x``.\n",
    "\n",
    "    Continuous features (the first ``len(continuous_feature_names)`` entries of ``x``)\n",
    "    follow ``action_meta[name]``:\n",
    "      * not actionable, or neither direction allowed: frozen at the current value\n",
    "      * can increase and decrease: free in the scaled range\n",
    "      * can only increase: from the current value up to the range's upper end\n",
    "      * can only decrease: from the range's lower end up to the current value\n",
    "    The scaled range is [0, 1], widened to include the current value so the lower\n",
    "    bound never exceeds the upper one (the MinMaxScaler is fit on X_train only, so\n",
    "    test values can fall outside [0, 1]).\n",
    "\n",
    "    One-hot columns of each categorical feature (ranges from the global ``cat_idxs``)\n",
    "    are frozen at their current values when the feature is not actionable and are\n",
    "    free in [0, 1] otherwise; direction constraints for categoricals are enforced\n",
    "    after optimization by clip_cats_to_actionable.\n",
    "\n",
    "    Returns a tuple of (lower, upper) tuples, one per column of ``x``.\n",
    "    \"\"\"\n",
    "    bounds = list()\n",
    "\n",
    "    for i, cat_name in enumerate(continuous_feature_names):\n",
    "        meta = action_meta[cat_name]\n",
    "        value = x[i]\n",
    "        low, high = min(0, value), max(1, value)\n",
    "\n",
    "        # A missing 'actionable' key falls back to the direction flags alone\n",
    "        if not meta.get('actionable', True):\n",
    "            f_range = (value, value)\n",
    "        elif meta['can_increase'] and meta['can_decrease']:\n",
    "            f_range = (low, high)\n",
    "        elif meta['can_increase']:\n",
    "            f_range = (value, high)\n",
    "        elif meta['can_decrease']:\n",
    "            f_range = (low, value)\n",
    "        else:\n",
    "            f_range = (value, value)\n",
    "        bounds.append(f_range)\n",
    "\n",
    "    for i, (start, end) in enumerate(cat_idxs):\n",
    "        if not action_meta[categorical_feature_names[i]]['actionable']:\n",
    "            bounds.extend((x[j], x[j]) for j in range(start, end))\n",
    "        else:\n",
    "            bounds.extend((0, 1) for _ in range(start, end))\n",
    "\n",
    "    return tuple(bounds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d9cc81ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def round_neme_cats(x):\n",
    "    \"\"\"\n",
    "    Turn every categorical block of ``x`` into a hard one-hot vector.\n",
    "\n",
    "    Within each ``cat_idxs`` range the entry picked by np.argmax becomes 1 and the\n",
    "    others are multiplied by 0. ``x`` is updated in place through slice views (so it\n",
    "    should be a NumPy array) and is also returned.\n",
    "    \"\"\"\n",
    "    for pos in range(len(categorical_feature_names)):\n",
    "        lo, hi = cat_idxs[pos][0], cat_idxs[pos][1]\n",
    "        block = x[lo:hi]\n",
    "        winner = np.argmax(block)\n",
    "        block *= 0.\n",
    "        block[winner] = 1.\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bc99bc91",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clip_cats_to_actionable(result, original_x):\n",
    "    \"\"\"\n",
    "    Undo categorical moves that break one-directional actionability constraints.\n",
    "\n",
    "    For each candidate row of ``result`` and each categorical feature, the selected\n",
    "    category is the argmax of its one-hot block (``cat_idxs`` range). When the feature\n",
    "    is actionable and may only increase, a candidate whose category index is below the\n",
    "    original one gets the original block back; symmetrically for decrease-only\n",
    "    features. Features that may move both ways, or are not actionable, are left as is.\n",
    "\n",
    "    Rows are written in place, so ``result`` should be a 2-D NumPy array whose rows are\n",
    "    views; the same object is returned.\n",
    "    \"\"\"\n",
    "    for row in result:\n",
    "        for i, name in enumerate(categorical_feature_names):\n",
    "            lo, hi = cat_idxs[i][0], cat_idxs[i][1]\n",
    "            chosen = np.argmax(row[lo:hi])\n",
    "            initial = np.argmax(original_x[lo:hi])\n",
    "\n",
    "            meta = action_meta[name]\n",
    "            if not (meta['actionable'] == True):\n",
    "                continue\n",
    "\n",
    "            up_only = meta['can_increase'] and not meta['can_decrease']\n",
    "            down_only = meta['can_decrease'] and not meta['can_increase']\n",
    "            if (up_only and chosen < initial) or (down_only and chosen > initial):\n",
    "                row[lo:hi] = original_x[lo:hi]\n",
    "\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1abb3175",
   "metadata": {},
   "outputs": [],
   "source": [
    "def objective(x):\n",
    "    \"\"\"\n",
    "    Loss minimized by Nelder-Mead over ``m`` candidates flattened into ``x``.\n",
    "\n",
    "    loss = - sum of class-1 (semi-factual) probabilities under ``clf``\n",
    "           - sum of the pairwise Euclidean distance matrix (diversity; each pair counted twice)\n",
    "           + sum of each candidate's distance to its nearest ``knn`` training point (plausibility)\n",
    "    \"\"\"\n",
    "    candidates = x.reshape(m, -1)\n",
    "\n",
    "    sf_term = -clf.predict_proba(candidates)[:, 1].sum()\n",
    "    diversity_term = -distance_matrix(candidates, candidates).sum()\n",
    "    nn_dists, _ = knn.kneighbors(X=candidates, n_neighbors=1, return_distance=True)\n",
    "\n",
    "    return sf_term + diversity_term + nn_dists.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "146e0be5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/eoinkenny/opt/anaconda3/envs/semifactual/lib/python3.9/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator OneHotEncoder from version 1.1.1 when using version 1.1.3. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
      "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
      "  warnings.warn(\n",
      "/Users/eoinkenny/opt/anaconda3/envs/semifactual/lib/python3.9/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator MultinomialNB from version 1.1.1 when using version 1.1.3. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
      "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "def get_actionable_feature_idxs(continuous_features, categorical_features):\n",
    "    \"\"\"\n",
    "    List ``[position, can_increase, can_decrease]`` for every actionable feature.\n",
    "\n",
    "    ``position`` indexes into ``continuous_features + categorical_features`` (one\n",
    "    entry per feature name, not per one-hot column). Flags come from ``action_meta``.\n",
    "    \"\"\"\n",
    "    # Use the arguments instead of the module-level continuous_feature_names /\n",
    "    # categorical_feature_names lists so the function works on any feature subset\n",
    "    feature_names = list(continuous_features) + list(categorical_features)\n",
    "    return [\n",
    "        [i, action_meta[f]['can_increase'], action_meta[f]['can_decrease']]\n",
    "        for i, f in enumerate(feature_names)\n",
    "        if action_meta[f]['actionable']\n",
    "    ]\n",
    "\n",
    "\n",
    "action_meta = actionability_constraints()\n",
    "\n",
    "# Raw train/test frames; model predictions are appended as columns below\n",
    "df_train = pd.read_csv('data/df_train.csv')\n",
    "df_test = pd.read_csv('data/df_test.csv')\n",
    "\n",
    "# Feature matrices (presumably continuous + one-hot columns, as generate_cat_idxs\n",
    "# assumes -- TODO confirm) and labels\n",
    "X_train = np.load('data/X_train.npy')\n",
    "X_test = np.load('data/X_test.npy')\n",
    "y_train = np.load('data/y_train.npy')\n",
    "y_test = np.load('data/y_test.npy')\n",
    "\n",
    "# Min-max normalization, fitted on the training split only\n",
    "scaler = MinMaxScaler()\n",
    "scaler.fit(X_train)\n",
    "X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load trusted files.\n",
    "# One-hot encoder; its categories_ define the categorical column ranges.\n",
    "with open('data/enc.pkl', 'rb') as file:\n",
    "    enc = pickle.load(file)\n",
    "\n",
    "# Pre-trained classifier whose predictions are being explained.\n",
    "# NOTE(review): the saved output of this cell warns about unpickling a MultinomialNB,\n",
    "# so clf may not be a logistic regression -- TODO confirm.\n",
    "with open('data/clf.pkl', 'rb') as file:\n",
    "    clf = pickle.load(file)\n",
    "\n",
    "# Predicted labels and class-1 probabilities for both splits\n",
    "test_preds, train_preds = clf.predict(X_test), clf.predict(X_train)\n",
    "test_probs, train_probs = clf.predict_proba(X_test), clf.predict_proba(X_train)\n",
    "\n",
    "df_test['preds'] = test_preds\n",
    "df_test['probs'] = test_probs[:, 1]\n",
    "\n",
    "df_train['preds'] = train_preds\n",
    "df_train['probs'] = train_probs[:, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "34feb361",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): action_meta is already created in the data-loading cell; this\n",
    "# re-creation looks redundant (harmless if actionability_constraints() is deterministic).\n",
    "action_meta = actionability_constraints()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e2297a0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# [start, end) one-hot column ranges per categorical feature. Read as a global by\n",
    "# neme_bounds, round_neme_cats and clip_cats_to_actionable; requires enc to be loaded.\n",
    "cat_idxs = generate_cat_idxs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "976991ef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>KNeighborsClassifier()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">KNeighborsClassifier</label><div class=\"sk-toggleable__content\"><pre>KNeighborsClassifier()</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "KNeighborsClassifier()"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# k-NN index over the scaled training data. objective() calls knn.kneighbors to add a\n",
    "# plausibility term (distance from each candidate to its nearest training point);\n",
    "# in this notebook only the neighbour search is used, not knn's predictions.\n",
    "knn = KNeighborsClassifier()\n",
    "knn.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fa8774d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fd014c80",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3, 71)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/jy/977g91sd4p56btqf1lyy890w0000gn/T/ipykernel_7219/771131067.py:18: DeprecationWarning: Use of `minimize` with `x0.ndim != 1` is deprecated. Currently, singleton dimensions will be removed from `x0`, but an error will be raised in SciPy 1.11.0.\n",
      "  result = minimize(objective, x, method='nelder-mead', bounds=bnds)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3, 71)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/jy/977g91sd4p56btqf1lyy890w0000gn/T/ipykernel_7219/771131067.py:18: DeprecationWarning: Use of `minimize` with `x0.ndim != 1` is deprecated. Currently, singleton dimensions will be removed from `x0`, but an error will be raised in SciPy 1.11.0.\n",
      "  result = minimize(objective, x, method='nelder-mead', bounds=bnds)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3, 71)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/jy/977g91sd4p56btqf1lyy890w0000gn/T/ipykernel_7219/771131067.py:18: DeprecationWarning: Use of `minimize` with `x0.ndim != 1` is deprecated. Currently, singleton dimensions will be removed from `x0`, but an error will be raised in SciPy 1.11.0.\n",
      "  result = minimize(objective, x, method='nelder-mead', bounds=bnds)\n"
     ]
    }
   ],
   "source": [
    "ga_df = pd.read_csv('data/GA_Xps_diverse.csv')\n",
    "# Distinct test indices listed in the GA results file, in ascending order\n",
    "test_idxs = np.sort(ga_df['test_idx'].dropna().unique())\n",
    "\n",
    "m = 3        # semi-factual candidates optimized jointly per test instance\n",
    "N_TEST = 3   # number of test instances to process; None processes all of them\n",
    "final_data = list()\n",
    "found_sfs = list()\n",
    "\n",
    "for test_idx in test_idxs[:N_TEST]:\n",
    "\n",
    "    original_x = X_test[test_idx].copy()\n",
    "\n",
    "    # Same per-feature bounds for each of the m stacked candidates\n",
    "    bnds = tuple(neme_bounds(original_x)) * m\n",
    "\n",
    "    # All candidates start at the query. x0 must be 1-D: 2-D x0 is deprecated and\n",
    "    # raises from SciPy 1.11; objective() reshapes it to (m, n_features).\n",
    "    x0 = np.tile(original_x, m)\n",
    "    result = minimize(objective, x0, method='nelder-mead', bounds=bnds)\n",
    "\n",
    "    # Clip categories to actionable.\n",
    "    # NOTE(review): round_neme_cats is never applied, so one-hot blocks may stay\n",
    "    # fractional when passed to clf -- TODO confirm this is intended.\n",
    "    result = clip_cats_to_actionable(result['x'].reshape(m, -1), original_x)\n",
    "\n",
    "    for i, pred in enumerate(clf.predict(result).tolist()):\n",
    "        if pred == 0:\n",
    "            # Candidate left the semi-factual class (1): fall back to the query and\n",
    "            # record that no semi-factual was found for this slot\n",
    "            found_sfs.append(0)\n",
    "            final_data.append(original_x.tolist())\n",
    "        else:\n",
    "            found_sfs.append(1)\n",
    "            final_data.append(result[i].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c94c24",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d27a021f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8382b1ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per candidate: m rows for every processed test instance\n",
    "final_df = pd.DataFrame(final_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5f176175",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): assigning ga_df.test_idx aligns on the index, so row k of final_df gets\n",
    "# ga_df.test_idx[k]. That only matches if final_data was produced in the same row order\n",
    "# as ga_df (e.g. all test_idxs processed, m rows each) -- TODO confirm. While these lines\n",
    "# stay commented out, final_df has no 'test_idx' column and the filtering cell fails.\n",
    "# final_df['test_idx'] = ga_df.test_idx\n",
    "# final_df['sf_found'] = found_sfs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa8e72ef",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5814f42f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "0c878b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the NeMe results (the DataFrame index is written as the first column)\n",
    "final_df.to_csv('data/neme_diverse.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "5282dea0",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'DataFrame' object has no attribute 'test_idx'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Input \u001b[0;32mIn [21]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m final_df[\u001b[43mfinal_df\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtest_idx\u001b[49m\u001b[38;5;241m.\u001b[39misin(test_idxs)]\n",
      "File \u001b[0;32m~/opt/anaconda3/envs/semifactual/lib/python3.9/site-packages/pandas/core/generic.py:5902\u001b[0m, in \u001b[0;36mNDFrame.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   5895\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m   5896\u001b[0m     name \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_internal_names_set\n\u001b[1;32m   5897\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m name \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_metadata\n\u001b[1;32m   5898\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m name \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_accessors\n\u001b[1;32m   5899\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_info_axis\u001b[38;5;241m.\u001b[39m_can_hold_identifiers_and_holds_name(name)\n\u001b[1;32m   5900\u001b[0m ):\n\u001b[1;32m   5901\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m[name]\n\u001b[0;32m-> 5902\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mobject\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getattribute__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'DataFrame' object has no attribute 'test_idx'"
     ]
    }
   ],
   "source": [
    "# Keep only rows for the test instances in test_idxs. Needs a 'test_idx' column on\n",
    "# final_df; the saved output shows an AttributeError because that column was never added.\n",
    "final_df[final_df.test_idx.isin(test_idxs)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d973ffdc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "402747f0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d15683c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99090a01",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "semifactual",
   "language": "python",
   "name": "semifactual"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
