{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "996ae376",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import imblearn\n",
    "import scipy\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import sklearn\n",
    "import pickle\n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "from copy import deepcopy\n",
    "from collections import Counter\n",
    "from imblearn.over_sampling import SMOTE\n",
    "from random import gauss\n",
    "from scipy.spatial import distance_matrix\n",
    "\n",
    "from scipy.optimize import minimize\n",
    "from numpy.random import rand\n",
    "from scipy.spatial import distance_matrix\n",
    "\n",
    "from dataset import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72f7ab17",
   "metadata": {},
   "source": [
    "## NeMe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "146e0be5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_actionable_feature_idxs(continuous_features, categorical_features):\n",
    "    feature_names = continuous_feature_names + categorical_feature_names\n",
    "    actionable_idxs = list() \n",
    "    for i, f in enumerate(feature_names):\n",
    "        if action_meta[f]['actionable']:\n",
    "            actionable_idxs.append( [i, action_meta[f]['can_increase'], action_meta[f]['can_decrease']] )\n",
    "    return actionable_idxs\n",
    "\n",
    "\n",
    "action_meta = actionability_constraints()\n",
    "\n",
    "df_train = pd.read_csv('data/df_train.csv')\n",
    "df_test = pd.read_csv('data/df_test.csv')\n",
    "\n",
    "X_train = np.load('data/X_train.npy', )\n",
    "X_test = np.load('data/X_test.npy', )\n",
    "y_train = np.load('data/y_train.npy', )\n",
    "y_test = np.load('data/y_test.npy', )\n",
    "\n",
    "# ## Normalization\n",
    "scaler = MinMaxScaler().fit(X_train)\n",
    "X_train = scaler.transform(X_train)\n",
    "X_test = scaler.transform(X_test)\n",
    "\n",
    "with open('data/enc.pkl', 'rb') as file:\n",
    "    enc = pickle.load(file)\n",
    "\n",
    "# ## Generate Training Column Label\n",
    "#### Logistic Regression\n",
    "with open('data/clf.pkl', 'rb') as file:\n",
    "    clf = pickle.load(file)\n",
    "\n",
    "test_preds = clf.predict(X_test)\n",
    "train_preds = clf.predict(X_train)\n",
    "\n",
    "test_probs = clf.predict_proba(X_test)\n",
    "train_probs = clf.predict_proba(X_train)\n",
    "\n",
    "df_test['preds'] = test_preds\n",
    "df_test['probs'] = test_probs.T[1]\n",
    "\n",
    "df_train['preds'] = train_preds\n",
    "df_train['probs'] = train_probs.T[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "aab9d214",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_cat_idxs():\n",
    "    \"\"\"\n",
    "    Get indexes for all categorical features that are one hot encoded\n",
    "    \"\"\"\n",
    "\n",
    "    cat_idxs = list()\n",
    "    start_idx = len(continuous_feature_names)\n",
    "    for cat in enc.categories_:\n",
    "        cat_idxs.append([start_idx, start_idx + cat.shape[0]])\n",
    "        start_idx = start_idx + cat.shape[0]\n",
    "    return cat_idxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "34feb361",
   "metadata": {},
   "outputs": [],
   "source": [
    "action_meta = actionability_constraints()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e2297a0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_idxs = generate_cat_idxs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "976991ef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>KNeighborsClassifier()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">KNeighborsClassifier</label><div class=\"sk-toggleable__content\"><pre>KNeighborsClassifier()</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "KNeighborsClassifier()"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn = KNeighborsClassifier()\n",
    "knn.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5a6ca84c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def neme_bounds(x):\n",
    "    \n",
    "    bounds = list()\n",
    "\n",
    "    for i in range(len(continuous_feature_names)):\n",
    "\n",
    "        cat_name = continuous_feature_names[i]\n",
    "        value = x[i]\n",
    "\n",
    "        # If the continuous feature can take any value\n",
    "        if action_meta[ cat_name ]['can_increase'] and action_meta[ cat_name ]['can_decrease']:\n",
    "            f_range = (0,1)\n",
    "            bounds.append(f_range)\n",
    "\n",
    "        # If the continous feature can only go up\n",
    "        elif action_meta[ cat_name ]['can_increase'] and not action_meta[ cat_name ]['can_decrease']:\n",
    "            f_range = (value,1)\n",
    "            bounds.append(f_range)\n",
    "\n",
    "        # if the continuous features can only go down\n",
    "        elif not action_meta[ cat_name ]['can_increase'] and action_meta[ cat_name ]['can_decrease']:\n",
    "            f_range = (0, value)\n",
    "            bounds.append(f_range)\n",
    "\n",
    "        # If it's not actionable\n",
    "        else:\n",
    "            f_range = (value, value)\n",
    "            bounds.append(f_range)\n",
    "            \n",
    "    for i in range(len(cat_idxs)):\n",
    "                \n",
    "        if action_meta[categorical_feature_names[i]]['actionable'] == False:\n",
    "            for j in range(cat_idxs[i][1] - cat_idxs[i][0]):\n",
    "                bounds.append((x[cat_idxs[i][0]+j], x[cat_idxs[i][0]+j]))\n",
    "                \n",
    "        else:\n",
    "            for j in range(cat_idxs[i][1] - cat_idxs[i][0]):\n",
    "                bounds.append((0,1))\n",
    "                \n",
    "    return tuple(bounds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d9cc81ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def round_neme_cats(x):\n",
    "    for i in range(len(categorical_feature_names)):\n",
    "        cat_values = x[cat_idxs[i][0]: cat_idxs[i][1]]\n",
    "        max_value_idx = np.argmax(cat_values)\n",
    "        cat_values *= 0.\n",
    "        cat_values[max_value_idx] = 1.\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "92cb7ea1",
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_values = np.array([0,1,0.95, 0.9]) > 0.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "646ae0db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# df_test.iloc[test_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "50328f69",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_actionable_range(x):\n",
    "\n",
    "    dice_action = {}\n",
    "\n",
    "    for feature in action_meta.keys():\n",
    "\n",
    "        # Only add actionable features for DiCE's constraints\n",
    "        if action_meta[feature]['actionable']:\n",
    "\n",
    "            if feature in continuous_feature_names:\n",
    "                query_min_value = float(x[feature])\n",
    "                query_max_value = float(x[feature])\n",
    "                min_value = min([float(xxx) for xxx in pd.concat([x_train, x_test])[feature].values])\n",
    "                max_value = max([float(xxx) for xxx in pd.concat([x_train, x_test])[feature].values])\n",
    "            else:\n",
    "                query_min_value = int(x[feature][0])\n",
    "                query_max_value = int(x[feature][0])\n",
    "                min_value = min([int(xxx[0]) for xxx in pd.concat([x_train, x_test])[feature].values])\n",
    "                max_value = max([int(xxx[0]) for xxx in pd.concat([x_train, x_test])[feature].values])\n",
    "\n",
    "            # Is it up or down mutable?\n",
    "            if action_meta[feature]['can_increase']:\n",
    "                query_max_value = max_value\n",
    "\n",
    "            if action_meta[feature]['can_decrease']:\n",
    "                query_min_value = min_value\n",
    "\n",
    "            # If it is a continuous feature\n",
    "            if feature in numerical:\n",
    "                dice_action[feature] = [float(query_min_value), float(query_max_value)]\n",
    "            else:\n",
    "                dice_action[feature] = [str(x) + '-Cat' for x in list(range(query_min_value, query_max_value+1))]\n",
    "\n",
    "    return dice_action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "202a8f79",
   "metadata": {},
   "outputs": [
    {
     "ename": "IndexError",
     "evalue": "only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
      "Input \u001b[0;32mIn [26]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mget_actionable_range\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n",
      "Input \u001b[0;32mIn [25]\u001b[0m, in \u001b[0;36mget_actionable_range\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m action_meta[feature][\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mactionable\u001b[39m\u001b[38;5;124m'\u001b[39m]:\n\u001b[1;32m     10\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m feature \u001b[38;5;129;01min\u001b[39;00m continuous_feature_names:\n\u001b[0;32m---> 11\u001b[0m         query_min_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mfloat\u001b[39m(\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[43mfeature\u001b[49m\u001b[43m]\u001b[49m)\n\u001b[1;32m     12\u001b[0m         query_max_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mfloat\u001b[39m(x[feature])\n\u001b[1;32m     13\u001b[0m         min_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmin\u001b[39m([\u001b[38;5;28mfloat\u001b[39m(xxx) \u001b[38;5;28;01mfor\u001b[39;00m xxx \u001b[38;5;129;01min\u001b[39;00m pd\u001b[38;5;241m.\u001b[39mconcat([x_train, x_test])[feature]\u001b[38;5;241m.\u001b[39mvalues])\n",
      "\u001b[0;31mIndexError\u001b[0m: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices"
     ]
    }
   ],
   "source": [
    "get_actionable_range(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c56ac6e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "34679323",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.choice(np.flatnonzero(cat_values == cat_values.max()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "08b31882",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_actionable_range(original_x, cat_name):\n",
    "\n",
    "    for i in range(len(cat_idxs)):\n",
    "\n",
    "        if action_meta[categorical_feature_names[i]]['actionable'] == False:\n",
    "            for j in range(cat_idxs[i][1] - cat_idxs[i][0]):\n",
    "                bounds.append((x[cat_idxs[i][0]+j], x[cat_idxs[i][0]+j]))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "fe4ad708",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.07142857, 0.19924692, 0.35185185, 1.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 1.        , 1.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 1.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       1.        , 0.        , 1.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 1.        , 0.        ,\n",
       "       1.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 1.        , 1.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 1.        , 0.        ,\n",
       "       1.        , 0.        , 0.        , 1.        , 0.        ,\n",
       "       0.        , 0.        , 1.        , 0.        , 0.        ,\n",
       "       0.        , 1.        , 1.        , 0.        , 1.        ,\n",
       "       0.        ])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "original_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "131fff60",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 0, 0, 0])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = np.array([0,1,2,3,4,5])\n",
    "\n",
    "x[3:] = 0\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "bc99bc91",
   "metadata": {},
   "outputs": [],
   "source": [
    "\tdef clip_cats_to_actionable(instance, original_x):\n",
    "\t\t\n",
    "\t\tfor i in range(len(categorical_feature_names)):\n",
    "\n",
    "\t\t\tcat_name = categorical_feature_names[i]\n",
    "\t\t\tcat_values = instance[cat_idxs[i][0]: cat_idxs[i][1]] > 0.95\n",
    "\t#         value_idx = np.argmax(cat_values)\n",
    "\t\t\tvalue_idx = np.random.choice(np.flatnonzero(cat_values == cat_values.max()))\n",
    "\t\t\t\n",
    "\t\t\t\n",
    "\t\t\torg_value_idx = np.argmax(original_x[cat_idxs[i][0]: cat_idxs[i][1]])\n",
    "\n",
    "\t\t\t# If actionable\n",
    "\t\t\tif action_meta[categorical_feature_names[i]]['actionable'] == True:\n",
    "\n",
    "\t\t\t\t# If the feature can take any value\n",
    "\t\t\t\tif action_meta[ cat_name ]['can_increase'] and action_meta[ cat_name ]['can_decrease']:\n",
    "\t\t\t\t\tinstance[cat_idxs[i][0]: cat_idxs[i][1]] = [0. for _ in range(len(cat_values))]\n",
    "\t\t\t\t\tinstance[cat_idxs[i][0]: cat_idxs[i][1]][value_idx] = 1.\n",
    "\n",
    "\t\t\t\t# If the feature can only go up\n",
    "\t\t\t\telif action_meta[ cat_name ]['can_increase'] and not action_meta[ cat_name ]['can_decrease']:\n",
    "\t\t\t\t\tif value_idx < org_value_idx:\n",
    "\t\t\t\t\t\tinstance[cat_idxs[i][0]: cat_idxs[i][1]] = original_x[cat_idxs[i][0]: cat_idxs[i][1]]\n",
    "\t\t\t\t\telse:\n",
    "\t\t\t\t\t\tinstance[cat_idxs[i][0]: cat_idxs[i][1]] = [0. for _ in range(len(cat_values))]\n",
    "\t\t\t\t\t\tinstance[cat_idxs[i][0]: cat_idxs[i][1]][value_idx] = 1.\n",
    "\n",
    "\t\t\t\t# if the feature can only go down\n",
    "\t\t\t\telif not action_meta[ cat_name ]['can_increase'] and action_meta[ cat_name ]['can_decrease']:\n",
    "\t\t\t\t\tif value_idx > org_value_idx:\n",
    "\t\t\t\t\t\tinstance[cat_idxs[i][0]: cat_idxs[i][1]] = original_x[cat_idxs[i][0]: cat_idxs[i][1]]\n",
    "\t\t\t\t\telse:\n",
    "\t\t\t\t\t\tinstance[cat_idxs[i][0]: cat_idxs[i][1]] = [0. for _ in range(len(cat_values))]\n",
    "\t\t\t\t\t\tinstance[cat_idxs[i][0]: cat_idxs[i][1]][value_idx] = 1.\n",
    "\t\t\t\t\t\t\n",
    "\t#             print(instance)\n",
    "\t\t\t\t\t\t\n",
    "\t\treturn instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "e825ac3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# x = np.array([0,0,1,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "cf09be5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.random.choice(np.flatnonzero(x == x.max()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "1abb3175",
   "metadata": {},
   "outputs": [],
   "source": [
    "def objective(x):\n",
    "    \"\"\"\n",
    "    probability of semi-factual class\n",
    "    l2 distance matrix of m samples\n",
    "    minimize negative of both to maximize objective\n",
    "    \"\"\"\n",
    "    \n",
    "    sf_class = clf.predict(original_x.reshape(1,-1)).item()\n",
    "    is_sf_loss = clf.predict(x.reshape(1,-1)).item() == sf_class\n",
    "    \n",
    "    similarity_orig_loss = -1. * C_reg * np.linalg.norm(x - original_x, 2)\n",
    "    diversity_loss = (-C_diversity * sum([np.linalg.norm(x - np.array(sf), 2) for sf in CURRENT_SFS])) / m\n",
    "    loss = similarity_orig_loss + diversity_loss \n",
    "    return loss * is_sf_loss\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "id": "28748f90",
   "metadata": {},
   "outputs": [],
   "source": [
    "C_simple=.1 \n",
    "C_reg=1. \n",
    "C_diversity=5.\n",
    "C_feasibility=1.\n",
    "C_sf=1.\n",
    "sparsity_upper_bound=2.\n",
    "solver=\"Nelder-Mead\"\n",
    "max_iter=None\n",
    "non_zero_threshold_sparsity = 1e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "id": "9332b54d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(500, 71)"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "id": "6d0ea4f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "ga_df = pd.read_csv('data/GA_Xps_diverse.csv')\n",
    "# test_idxs = np.sort(np.array(ga_df.test_idx.value_counts().index.tolist()))\n",
    "m = 10\n",
    "final_data = list()\n",
    "found_sfs = list()\n",
    "\n",
    "for test_idx in [3]:\n",
    "\n",
    "    # Compute diverse sfs\n",
    "    CURRENT_SFS = list()\n",
    "    original_x = deepcopy(X_test[test_idx])\n",
    "    bnds = neme_bounds(original_x) \n",
    "    \n",
    "    for i in range(m):\n",
    "        x = deepcopy(original_x)\n",
    "        result = minimize(objective, x, method='nelder-mead', bounds=bnds, options={'maxiter': 5000})\n",
    "        result = clip_cats_to_actionable(result['x'].reshape(1, -1)[0], original_x)\n",
    "        CURRENT_SFS.append(result.tolist())\n",
    "#         CURRENT_SFS.append(result['x'].reshape(1, -1)[0])\n",
    "\n",
    "    CURRENT_SFS = np.array(CURRENT_SFS)\n",
    "    \n",
    "    for i, pred in enumerate(clf.predict(CURRENT_SFS).tolist()):\n",
    "        if pred == 0:\n",
    "            found_sfs.append(0)\n",
    "            final_data.append(original_x.tolist())\n",
    "        else:\n",
    "            found_sfs.append(1)\n",
    "            final_data.append(CURRENT_SFS[i].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "id": "8382b1ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_df = pd.DataFrame(final_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "id": "5f176175",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_df['test_idx'] = [2]*10 #+ [2]*6 + [3]*6\n",
    "final_df['sf_found'] = found_sfs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "id": "0c878b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_df.to_csv('data/neme_diverse.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "id": "df59ee71",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>...</th>\n",
       "      <th>63</th>\n",
       "      <th>64</th>\n",
       "      <th>65</th>\n",
       "      <th>66</th>\n",
       "      <th>67</th>\n",
       "      <th>68</th>\n",
       "      <th>69</th>\n",
       "      <th>70</th>\n",
       "      <th>test_idx</th>\n",
       "      <th>sf_found</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.071429</td>\n",
       "      <td>0.199247</td>\n",
       "      <td>0.351852</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.072159</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.351852</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.562994</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.351852</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.522915</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.351852</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.071429</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.351852</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.071429</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.351852</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.601480</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.351852</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.077023</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.351852</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.199247</td>\n",
       "      <td>0.351852</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.071429</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.351852</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10 rows × 73 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          0         1         2    3    4    5    6    7    8    9  ...   63  \\\n",
       "0  0.071429  0.199247  0.351852  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0   \n",
       "1  0.072159  1.000000  0.351852  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0   \n",
       "2  0.562994  1.000000  0.351852  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0   \n",
       "3  0.522915  1.000000  0.351852  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0   \n",
       "4  0.071429  1.000000  0.351852  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0   \n",
       "5  0.071429  1.000000  0.351852  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0   \n",
       "6  0.601480  1.000000  0.351852  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0   \n",
       "7  0.077023  1.000000  0.351852  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0   \n",
       "8  1.000000  0.199247  0.351852  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0   \n",
       "9  0.071429  1.000000  0.351852  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0   \n",
       "\n",
       "    64   65   66   67   68   69   70  test_idx  sf_found  \n",
       "0  0.0  0.0  1.0  1.0  0.0  1.0  0.0         2         0  \n",
       "1  0.0  0.0  1.0  1.0  0.0  1.0  0.0         2         1  \n",
       "2  0.0  0.0  1.0  1.0  0.0  1.0  0.0         2         1  \n",
       "3  0.0  0.0  1.0  1.0  0.0  1.0  0.0         2         1  \n",
       "4  0.0  0.0  1.0  1.0  0.0  1.0  0.0         2         1  \n",
       "5  0.0  0.0  1.0  1.0  0.0  1.0  0.0         2         1  \n",
       "6  0.0  0.0  1.0  1.0  0.0  1.0  0.0         2         1  \n",
       "7  0.0  0.0  1.0  1.0  0.0  1.0  0.0         2         1  \n",
       "8  0.0  0.0  1.0  1.0  0.0  1.0  0.0         2         1  \n",
       "9  0.0  0.0  1.0  1.0  0.0  1.0  0.0         2         1  \n",
       "\n",
       "[10 rows x 73 columns]"
      ]
     },
     "execution_count": 169,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 170,
   "id": "5282dea0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.0,\n",
       " 0.8014832604836231,\n",
       " 1.2923188870835571,\n",
       " 1.2522391394224863,\n",
       " 0.8007530911232972,\n",
       " 0.8007530793286106,\n",
       " 1.3308040422987149,\n",
       " 0.8063475711925506,\n",
       " 0.9285714285714286,\n",
       " 0.8007530793286106]"
      ]
     },
     "execution_count": 170,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[sum(final_df.values[:, :-2][i] - original_x) for i in range(m)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "id": "e6849c9c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((0.07142857142857142, 1),\n",
       " (0.19924692067138938, 1),\n",
       " (0.35185185185185186, 0.35185185185185186),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (1.0, 1.0),\n",
       " (0.0, 0.0),\n",
       " (0.0, 0.0),\n",
       " (0.0, 0.0),\n",
       " (0.0, 0.0),\n",
       " (0.0, 0.0),\n",
       " (0.0, 0.0),\n",
       " (0.0, 0.0),\n",
       " (0.0, 0.0),\n",
       " (0.0, 0.0),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0.0, 0.0),\n",
       " (0.0, 0.0),\n",
       " (1.0, 1.0),\n",
       " (0.0, 0.0),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (0, 1),\n",
       " (1.0, 1.0),\n",
       " (0.0, 0.0),\n",
       " (1.0, 1.0),\n",
       " (0.0, 0.0))"
      ]
     },
     "execution_count": 171,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bnds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd8b3fa0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43841d08",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b9eeee5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75917c8c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d118344",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c00c53f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "semifactual",
   "language": "python",
   "name": "semifactual"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
