{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "from transformers import pipeline\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "import torch\n",
    "import shap\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.impute import SimpleImputer\n",
    "\n",
    "import copy\n",
    "import gc\n",
    "import itertools\n",
    "import logging\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy.sparse\n",
    "import sklearn\n",
    "from packaging import version\n",
    "from scipy.special import binom\n",
    "from sklearn.linear_model import Lasso, LassoLarsIC, lars_path\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def kmeans(X, k, round_values=True):\n",
    "    \"\"\" Summarize a dataset with k mean samples weighted by the number of data points they\n",
    "    each represent.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : numpy.array or pandas.DataFrame or any scipy.sparse matrix\n",
    "        Matrix of data samples to summarize (# samples x # features)\n",
    "\n",
    "    k : int\n",
    "        Number of means to use for approximation.\n",
    "\n",
    "    round_values : bool\n",
    "        For all i, round the ith dimension of each mean sample to match the nearest value\n",
    "        from X[:,i]. This ensures discrete features always get a valid value.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    DenseData object.\n",
    "        The k cluster centers (missing values are mean-imputed before clustering), with\n",
    "        one weight per center proportional to the number of samples assigned to it. A\n",
    "        center that received no samples gets weight 0.\n",
    "    \"\"\"\n",
    "\n",
    "    group_names = [str(i) for i in range(X.shape[1])]\n",
    "    if isinstance(X, pd.DataFrame):\n",
    "        group_names = X.columns\n",
    "        X = X.values\n",
    "\n",
    "    # in case there are any missing values in data impute them\n",
    "    imp = SimpleImputer(missing_values=np.nan, strategy='mean')\n",
    "    X = imp.fit_transform(X)\n",
    "\n",
    "    # Specify `n_init` for consistent behaviour between sklearn versions.\n",
    "    # (Named `km` so the local does not shadow this function.)\n",
    "    km = KMeans(n_clusters=k, random_state=0, n_init=10).fit(X)\n",
    "    centers = km.cluster_centers_\n",
    "\n",
    "    if round_values:\n",
    "        # Column-major loop: each (possibly sparse) column is densified once rather than\n",
    "        # once per cluster. Every (i, j) entry is rounded independently, so the loop\n",
    "        # order does not change the result.\n",
    "        for j in range(X.shape[1]):\n",
    "            xj = X[:, j].toarray().flatten() if scipy.sparse.issparse(X) else X[:, j] # sparse support courtesy of @PrimozGodec\n",
    "            for i in range(k):\n",
    "                ind = np.argmin(np.abs(xj - centers[i, j]))\n",
    "                centers[i, j] = X[ind, j]\n",
    "\n",
    "    # minlength=k guarantees one weight per center even when a cluster ends up empty\n",
    "    # (e.g. when X has fewer distinct rows than k); a shorter weight vector would trip\n",
    "    # DenseData's \"# weights must match data matrix!\" assertion.\n",
    "    weights = 1.0 * np.bincount(km.labels_, minlength=k)\n",
    "    return DenseData(centers, group_names, None, weights)\n",
    "\n",
    "\n",
    "class Instance:\n",
    "    \"\"\"A sample to be explained plus optional per-group display values.\"\"\"\n",
    "\n",
    "    def __init__(self, x, group_display_values):\n",
    "        # display values may be None; match_instance_to_data fills them in for DenseData\n",
    "        self.group_display_values = group_display_values\n",
    "        self.x = x\n",
    "\n",
    "\n",
    "def convert_to_instance(val):\n",
    "    \"\"\"Return ``val`` if it is already an Instance, otherwise wrap it in a new one\n",
    "    (with no group display values).\"\"\"\n",
    "    return val if isinstance(val, Instance) else Instance(val, None)\n",
    "\n",
    "\n",
    "class InstanceWithIndex(Instance):\n",
    "    \"\"\"An Instance that also remembers its pandas index and column names.\n",
    "\n",
    "    Used when the explainer runs with ``keep_index`` so the sample can be rebuilt as a\n",
    "    DataFrame before being passed to the model.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, x, column_name, index_value, index_name, group_display_values):\n",
    "        super().__init__(x, group_display_values)\n",
    "        self.column_name = column_name\n",
    "        self.index_name = index_name\n",
    "        self.index_value = index_value\n",
    "\n",
    "    def convert_to_df(self):\n",
    "        \"\"\"Rebuild ``x`` as a DataFrame with ``column_name`` columns, indexed by ``index_name``.\"\"\"\n",
    "        index_frame = pd.DataFrame(self.index_value, columns=[self.index_name])\n",
    "        value_frame = pd.DataFrame(self.x, columns=self.column_name)\n",
    "        combined = pd.concat([index_frame, value_frame], axis=1)\n",
    "        return combined.set_index(self.index_name)\n",
    "\n",
    "\n",
    "def convert_to_instance_with_index(val, column_name, index_value, index_name):\n",
    "    \"\"\"Wrap ``val`` in an InstanceWithIndex; group display values start out unset.\"\"\"\n",
    "    return InstanceWithIndex(\n",
    "        x=val,\n",
    "        column_name=column_name,\n",
    "        index_value=index_value,\n",
    "        index_name=index_name,\n",
    "        group_display_values=None,\n",
    "    )\n",
    "\n",
    "\n",
    "def match_instance_to_data(instance, data):\n",
    "    \"\"\"Align ``instance`` with the feature groups of the background ``data``.\n",
    "\n",
    "    Only DenseData backgrounds are handled: missing group display values are filled in\n",
    "    (the raw value for single-feature groups, \"\" for multi-feature groups), the count\n",
    "    is checked against the number of groups, and ``data.groups`` is copied onto the\n",
    "    instance. Any other data type leaves the instance untouched.\n",
    "    \"\"\"\n",
    "    assert isinstance(instance, Instance), \"instance must be of type Instance!\"\n",
    "\n",
    "    if not isinstance(data, DenseData):\n",
    "        return\n",
    "\n",
    "    if instance.group_display_values is None:\n",
    "        display_values = []\n",
    "        for group in data.groups:\n",
    "            display_values.append(instance.x[0, group[0]] if len(group) == 1 else \"\")\n",
    "        instance.group_display_values = display_values\n",
    "    assert len(instance.group_display_values) == len(data.groups)\n",
    "    instance.groups = data.groups\n",
    "\n",
    "\n",
    "class Model:\n",
    "    \"\"\"A prediction callable ``f`` together with optional names for its outputs.\"\"\"\n",
    "\n",
    "    def __init__(self, f, out_names):\n",
    "        # out_names may be None; match_model_to_data then fills in default names\n",
    "        self.out_names = out_names\n",
    "        self.f = f\n",
    "\n",
    "\n",
    "def convert_to_model(val, keep_index=False):\n",
    "    \"\"\" Convert a model to a Model object.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    val : function or Model object\n",
    "        The model function or a Model object.\n",
    "\n",
    "    keep_index : bool\n",
    "        If True the model is expected to receive a pandas.DataFrame that carries the index\n",
    "        and column names, so it is returned without touching its feature names.\n",
    "        When this is False the feature names will be removed from the model object to avoid unnecessary warnings.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Model\n",
    "        ``val`` itself if it is already a Model, otherwise a new Model wrapping it. When\n",
    "        feature names have to be stripped, a deep copy is returned so the caller's\n",
    "        estimator keeps its ``feature_names_in_``.\n",
    "    \"\"\"\n",
    "    out = val if isinstance(val, Model) else Model(val, None)\n",
    "\n",
    "    # with keep_index a DataFrame with the expected feature names is passed to the model\n",
    "    if keep_index:\n",
    "        return out\n",
    "\n",
    "    # Fix for the sklearn warning\n",
    "    # 'X does not have valid feature names, but <model> was fitted with feature names'\n",
    "    estimator = getattr(out.f, \"__self__\", None)\n",
    "    if not (estimator and hasattr(estimator, \"feature_names_in_\")):\n",
    "        return out\n",
    "\n",
    "    # Make a copy so that the feature names are not removed from the original model\n",
    "    stripped = copy.deepcopy(out)\n",
    "    stripped.f.__self__.feature_names_in_ = None\n",
    "    return stripped\n",
    "\n",
    "\n",
    "def match_model_to_data(model, data):\n",
    "    \"\"\"Evaluate ``model`` on the background ``data`` and fill in default output names.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : Model\n",
    "    data : DenseData, DenseDataWithIndex or SparseData\n",
    "        DenseDataWithIndex is passed to the model as a DataFrame (via convert_to_df);\n",
    "        other data types pass their raw ``data`` matrix.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The raw (unconverted) model output on the background data.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    Whatever the model function raises; a short diagnostic is printed first.\n",
    "    \"\"\"\n",
    "    assert isinstance(model, Model), \"model must be of type Model!\"\n",
    "\n",
    "    try:\n",
    "        if isinstance(data, DenseDataWithIndex):\n",
    "            out_val = model.f(data.convert_to_df())\n",
    "        else:\n",
    "            out_val = model.f(data.data)\n",
    "    except Exception:\n",
    "        print(\"Provided model function fails when applied to the provided data set.\")\n",
    "        raise\n",
    "\n",
    "    if model.out_names is None:\n",
    "        # np.shape also works for outputs without a .shape attribute (e.g. lists)\n",
    "        out_shape = np.shape(out_val)\n",
    "        if len(out_shape) <= 1:\n",
    "            model.out_names = [\"output value\"]\n",
    "        else:\n",
    "            # rows are background samples and columns are model outputs, so there is\n",
    "            # one name per column (shape[1]), not one per sample (shape[0])\n",
    "            model.out_names = [\"output value \" + str(i) for i in range(out_shape[1])]\n",
    "\n",
    "    return out_val\n",
    "\n",
    "\n",
    "\n",
    "class Data:\n",
    "    \"\"\"Base class for the background-data containers (DenseData, SparseData).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        \"\"\"Nothing to set up; subclasses initialise their own attributes.\"\"\"\n",
    "\n",
    "\n",
    "class SparseData(Data):\n",
    "    \"\"\"Background data held as a sparse matrix (convert_to_data passes CSR), with\n",
    "    every sample weighted equally.\n",
    "\n",
    "    Extra positional arguments are accepted and ignored. No feature groups are defined\n",
    "    (``groups`` is None); ``groups_size`` is the number of columns.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data, *args):\n",
    "        n_rows = data.shape[0]\n",
    "        n_cols = data.shape[1]\n",
    "        self.data = data\n",
    "        self.groups_size = n_cols\n",
    "        self.group_names = None\n",
    "        self.groups = None\n",
    "        self.transposed = False\n",
    "        # uniform weights that sum to 1\n",
    "        self.weights = np.ones(n_rows) / n_rows\n",
    "\n",
    "\n",
    "class DenseData(Data):\n",
    "    \"\"\"Dense background data with feature groups and normalized per-sample weights.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : 2-D array\n",
    "        Background samples, normally (# samples x # features). If the number of\n",
    "        features covered by the groups does not equal ``data.shape[1]``, the matrix is\n",
    "        treated as transposed (# features x # samples).\n",
    "    group_names : sequence\n",
    "        One name per feature group.\n",
    "    *args : optional ``groups`` then ``weights``\n",
    "        ``groups`` is a list of feature-index arrays, one per group (None or omitted:\n",
    "        one group per feature). ``weights`` holds one value per sample (None or\n",
    "        omitted: uniform); it is copied and normalized to sum to 1, so the caller's\n",
    "        array is never modified.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data, group_names, *args):\n",
    "        groups = args[0] if len(args) > 0 else None\n",
    "        weights = args[1] if len(args) > 1 else None\n",
    "\n",
    "        self.groups = groups if groups is not None else [np.array([i]) for i in range(len(group_names))]\n",
    "\n",
    "        n_grouped = sum(len(g) for g in self.groups)\n",
    "        num_samples = data.shape[0]\n",
    "        t = False\n",
    "        if n_grouped != data.shape[1]:\n",
    "            t = True\n",
    "            num_samples = data.shape[1]\n",
    "\n",
    "        valid = (not t and n_grouped == data.shape[1]) or (t and n_grouped == data.shape[0])\n",
    "        assert valid, \"# of names must match data matrix!\"\n",
    "\n",
    "        # np.array(..., dtype=float) makes a float copy: the in-place normalization below\n",
    "        # must not mutate the caller's array, and must also work for integer weights.\n",
    "        if weights is None:\n",
    "            self.weights = np.ones(num_samples)\n",
    "        else:\n",
    "            self.weights = np.array(weights, dtype=float)\n",
    "        self.weights /= np.sum(self.weights)\n",
    "        wl = len(self.weights)\n",
    "        valid = (not t and wl == data.shape[0]) or (t and wl == data.shape[1])\n",
    "        assert valid, \"# weights must match data matrix!\"\n",
    "\n",
    "        self.transposed = t\n",
    "        self.group_names = group_names\n",
    "        self.data = data\n",
    "        self.groups_size = len(self.groups)\n",
    "\n",
    "\n",
    "class DenseDataWithIndex(DenseData):\n",
    "    \"\"\"DenseData that remembers the original pandas index so it can be rebuilt as a DataFrame.\"\"\"\n",
    "\n",
    "    def __init__(self, data, group_names, index, index_name, *args):\n",
    "        super().__init__(data, group_names, *args)\n",
    "        self.index_name = index_name\n",
    "        self.index_value = index\n",
    "\n",
    "    def convert_to_df(self):\n",
    "        \"\"\"Return the background data as a DataFrame with ``group_names`` columns, indexed by ``index_name``.\"\"\"\n",
    "        frames = [\n",
    "            pd.DataFrame(self.index_value, columns=[self.index_name]),\n",
    "            pd.DataFrame(self.data, columns=self.group_names),\n",
    "        ]\n",
    "        return pd.concat(frames, axis=1).set_index(self.index_name)\n",
    "\n",
    "\n",
    "def convert_to_data(val, keep_index=False):\n",
    "    \"\"\"Convert ``val`` into a Data object usable as a background dataset.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    val : Data, numpy.ndarray, pandas.Series, pandas.DataFrame or scipy.sparse matrix\n",
    "        Data objects are returned unchanged. A 1-D array or a Series is treated as a\n",
    "        single sample (one row). Sparse input is converted to CSR.\n",
    "    keep_index : bool\n",
    "        For DataFrames, keep the index (DenseDataWithIndex) instead of dropping it.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    AssertionError\n",
    "        If ``val`` has an unsupported type.\n",
    "    \"\"\"\n",
    "    if isinstance(val, Data):\n",
    "        return val\n",
    "    elif type(val) == np.ndarray:\n",
    "        if val.ndim == 1:\n",
    "            # a single sample: treat it as one row, like the Series branch below\n",
    "            val = val.reshape((1, len(val)))\n",
    "        return DenseData(val, [str(i) for i in range(val.shape[1])])\n",
    "    elif isinstance(val, pd.Series):\n",
    "        return DenseData(val.values.reshape((1,len(val))), list(val.index))\n",
    "    elif isinstance(val, pd.DataFrame):\n",
    "        if keep_index:\n",
    "            return DenseDataWithIndex(val.values, list(val.columns), val.index.values, val.index.name)\n",
    "        else:\n",
    "            return DenseData(val.values, list(val.columns))\n",
    "    elif scipy.sparse.issparse(val):\n",
    "        if not scipy.sparse.isspmatrix_csr(val):\n",
    "            val = val.tocsr()\n",
    "        return SparseData(val)\n",
    "    else:\n",
    "        # Raise explicitly: a bare `assert False` is stripped under `python -O`, which\n",
    "        # would make this function silently return None.\n",
    "        raise AssertionError(\"Unknown type passed as data object: \"+str(type(val)))\n",
    "\n",
    "class Link:\n",
    "    \"\"\"Base class for link functions.\n",
    "\n",
    "    Subclasses provide ``f`` (applied to model outputs, e.g. to form the expected\n",
    "    value) and ``finv`` (its inverse).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        \"\"\"Nothing to initialise.\"\"\"\n",
    "\n",
    "\n",
    "class IdentityLink(Link):\n",
    "    \"\"\"Link that passes values through unchanged in both directions.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def f(x):\n",
    "        \"\"\"Return ``x`` as-is.\"\"\"\n",
    "        return x\n",
    "\n",
    "    # the identity is its own inverse, so reuse the same static method\n",
    "    finv = f\n",
    "\n",
    "    def __str__(self):\n",
    "        return \"identity\"\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class LogitLink(Link):\n",
    "    \"\"\"Logit link: ``f`` maps probabilities to log-odds, and ``finv`` (the logistic\n",
    "    sigmoid) maps log-odds back to probabilities.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def f(x):\n",
    "        \"\"\"log(x / (1 - x)); only finite for x strictly between 0 and 1.\"\"\"\n",
    "        odds = x/(1-x)\n",
    "        return np.log(odds)\n",
    "\n",
    "    @staticmethod\n",
    "    def finv(x):\n",
    "        \"\"\"1 / (1 + exp(-x)).\"\"\"\n",
    "        return 1/(1+np.exp(-x))\n",
    "\n",
    "    def __str__(self):\n",
    "        return \"logit\"\n",
    "\n",
    "\n",
    "def convert_to_link(val):\n",
    "    \"\"\"Return the Link object described by ``val``.\n",
    "\n",
    "    ``val`` may be a Link instance (returned unchanged) or one of the strings\n",
    "    \"identity\" / \"logit\". Anything else raises AssertionError.\n",
    "    \"\"\"\n",
    "    if isinstance(val, Link):\n",
    "        return val\n",
    "    elif val == \"identity\":\n",
    "        return IdentityLink()\n",
    "    elif val == \"logit\":\n",
    "        return LogitLink()\n",
    "    else:\n",
    "        # Raise explicitly: a bare `assert False` is stripped under `python -O`, which\n",
    "        # would make this function silently return None.\n",
    "        raise AssertionError(\"Passed link object must be a subclass of iml.Link\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "\n",
    "class Kernel:\n",
    "    \"\"\"Uses the Kernel SHAP method to explain the output of any function.\n",
    "\n",
    "    Kernel SHAP is a method that uses a special weighted linear regression\n",
    "    to compute the importance of each feature. The computed importance values\n",
    "    are Shapley values from game theory and also coefficients from a local linear\n",
    "    regression.\n",
    "\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : function or iml.Model\n",
    "        User supplied function that takes a matrix of samples (# samples x # features) and\n",
    "        computes the output of the model for those samples. The output can be a vector\n",
    "        (# samples) or a matrix (# samples x # model outputs).\n",
    "\n",
    "    data : numpy.array or pandas.DataFrame or shap.common.DenseData or any scipy.sparse matrix\n",
    "        The background dataset to use for integrating out features. To determine the impact\n",
    "        of a feature, that feature is set to \"missing\" and the change in the model output\n",
    "        is observed. Since most models aren't designed to handle arbitrary missing data at test\n",
    "        time, we simulate \"missing\" by replacing the feature with the values it takes in the\n",
    "        background dataset. So if the background dataset is a simple sample of all zeros, then\n",
    "        we would approximate a feature being missing by setting it to zero. For small problems\n",
    "        this background dataset can be the whole training set, but for larger problems consider\n",
    "        using a single reference value or using the kmeans function to summarize the dataset.\n",
    "        Note: for sparse case we accept any sparse matrix but convert to lil format for\n",
    "        performance.\n",
    "\n",
    "    feature_names : list\n",
    "        The names of the features in the background dataset. If the background dataset is\n",
    "        supplied as a pandas.DataFrame, then feature_names can be set to None (the default value)\n",
    "        and the feature names will be taken as the column names of the dataframe.\n",
    "\n",
    "    link : \"identity\" or \"logit\"\n",
    "        A generalized linear model link to connect the feature importance values to the model\n",
    "        output. Since the feature importance values, phi, sum up to the model output, it often makes\n",
    "        sense to connect them to the output with a link function where link(output) = sum(phi).\n",
    "        If the model output is a probability then the LogitLink link function makes the feature\n",
    "        importance values have log-odds units.\n",
    "\n",
    "    Examples\n",
    "    --------\n",
    "    See :ref:`Kernel Explainer Examples <kernel_explainer_examples>`\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, data, feature_names=None, link=IdentityLink(), **kwargs):\n",
    "        \"\"\"Wrap the model and background data and compute the expected model output\n",
    "        over the (weighted) background samples.\n",
    "\n",
    "        Recognized kwargs: ``keep_index`` and ``keep_index_ordered`` (both default False).\n",
    "        \"\"\"\n",
    "\n",
    "        # Names of the background features. Only DataFrames have `.columns`, so fall back\n",
    "        # to `feature_names`, then to the names convert_to_data assigns below, instead of\n",
    "        # raising AttributeError for the numpy / sparse / DenseData inputs the class\n",
    "        # docstring says are supported.\n",
    "        if hasattr(data, \"columns\"):\n",
    "            self.data_feature_names = list(data.columns)\n",
    "        elif feature_names is not None:\n",
    "            self.data_feature_names = list(feature_names)\n",
    "        else:\n",
    "            self.data_feature_names = None\n",
    "\n",
    "        # convert incoming inputs to standardized iml objects\n",
    "        # NOTE(review): the `link` argument is ignored and the logit link is always used.\n",
    "        # Confirm this is intended; the class docstring still documents `link`.\n",
    "        self.link = LogitLink()\n",
    "        self.keep_index = kwargs.get(\"keep_index\", False)\n",
    "        self.keep_index_ordered = kwargs.get(\"keep_index_ordered\", False)\n",
    "        self.model = convert_to_model(model, keep_index=self.keep_index)\n",
    "        self.data = convert_to_data(data, keep_index=self.keep_index)\n",
    "        if self.data_feature_names is None and getattr(self.data, \"group_names\", None) is not None:\n",
    "            self.data_feature_names = list(self.data.group_names)\n",
    "        model_null = match_model_to_data(self.model, self.data)\n",
    "\n",
    "        # enforce our current input type limitations\n",
    "        assert isinstance(self.data, DenseData) or isinstance(self.data, SparseData), \\\n",
    "               \"Shap explainer only supports the DenseData and SparseData input currently.\"\n",
    "        assert not self.data.transposed, \"Shap explainer does not support transposed DenseData or SparseData currently.\"\n",
    "\n",
    "        # init our parameters\n",
    "        self.N = self.data.data.shape[0]\n",
    "        self.P = self.data.data.shape[1]\n",
    "        self.linkfv = np.vectorize(self.link.f)\n",
    "        self.nsamplesAdded = 0\n",
    "        self.nsamplesRun = 0\n",
    "\n",
    "        # find E_x[f(x)]: weighted mean of the model output over the background samples\n",
    "        if isinstance(model_null, (pd.DataFrame, pd.Series)):\n",
    "            model_null = np.squeeze(model_null.values)\n",
    "        self.fnull = np.sum((model_null.T * self.data.weights).T, 0)\n",
    "        self.expected_value = self.linkfv(self.fnull)\n",
    "\n",
    "        # see if we have a vector output\n",
    "        self.vector_out = True\n",
    "        if len(self.fnull.shape) == 0:\n",
    "            self.vector_out = False\n",
    "            self.fnull = np.array([self.fnull])\n",
    "            self.D = 1\n",
    "            self.expected_value = float(self.expected_value)\n",
    "        else:\n",
    "            self.D = self.fnull.shape[0]\n",
    "\n",
    "    def shap_values(self, X, **kwargs):\n",
    "        \"\"\" Estimate the SHAP values for a set of samples.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : numpy.array or pandas.DataFrame or any scipy.sparse matrix\n",
    "            A matrix of samples (# samples x # features) on which to explain the model's output.\n",
    "\n",
    "        nsamples : \"auto\" or int\n",
    "            Number of times to re-evaluate the model when explaining each prediction. More samples\n",
    "            lead to lower variance estimates of the SHAP values. The \"auto\" setting uses\n",
    "            `nsamples = 2 * X.shape[1] + 2048`.\n",
    "\n",
    "        l1_reg : \"num_features(int)\", \"auto\" (default for now, but deprecated), \"aic\", \"bic\", or float\n",
    "            The l1 regularization to use for feature selection (the estimation procedure is based on\n",
    "            a debiased lasso). The auto option currently uses \"aic\" when less than 20% of the possible sample\n",
    "            space is enumerated, otherwise it uses no regularization. THE BEHAVIOR OF \"auto\" WILL CHANGE\n",
    "            in a future version to be based on num_features instead of AIC.\n",
    "            The \"aic\" and \"bic\" options use the AIC and BIC rules for regularization.\n",
    "            Using \"num_features(int)\" selects a fix number of top features. Passing a float directly sets the\n",
    "            \"alpha\" parameter of the sklearn.linear_model.Lasso model used for feature selection.\n",
    "\n",
    "        gc_collect : bool\n",
    "           Run garbage collection after each explanation round. Sometimes needed for memory intensive explanations (default False).\n",
    "\n",
    "        silent : bool\n",
    "           Disable the progress bar (default False).\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        array or list\n",
    "            For models with a single output this returns a matrix of SHAP values\n",
    "            (# samples x # features). Each row sums to the difference between the model output for that\n",
    "            sample and the expected value of the model output (which is stored as expected_value\n",
    "            attribute of the explainer). For models with vector outputs this returns a list\n",
    "            of such matrices, one for each output.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If the explainer was built with ``keep_index=True`` but X is not a pandas.DataFrame.\n",
    "        \"\"\"\n",
    "\n",
    "        # index metadata is only available when X is a DataFrame and keep_index is set\n",
    "        index_value = index_name = column_name = None\n",
    "        has_index = False\n",
    "\n",
    "        # convert dataframes\n",
    "        if str(type(X)).endswith(\"pandas.core.series.Series'>\"):\n",
    "            X = X.values\n",
    "        elif str(type(X)).endswith(\"'pandas.core.frame.DataFrame'>\"):\n",
    "            if self.keep_index:\n",
    "                index_value = X.index.values\n",
    "                index_name = X.index.name\n",
    "                column_name = list(X.columns)\n",
    "                has_index = True\n",
    "            X = X.values\n",
    "\n",
    "        # keep_index rebuilds a DataFrame (index + columns) for every row passed to the\n",
    "        # model, so it needs the metadata captured above. Fail early with a clear message\n",
    "        # rather than an UnboundLocalError inside the loops below.\n",
    "        if self.keep_index and not has_index:\n",
    "            raise ValueError(\"keep_index=True requires X to be a pandas.DataFrame.\")\n",
    "\n",
    "        x_type = str(type(X))\n",
    "        arr_type = \"'numpy.ndarray'>\"\n",
    "        # if sparse, convert to lil for performance\n",
    "        if scipy.sparse.issparse(X) and not scipy.sparse.isspmatrix_lil(X):\n",
    "            X = X.tolil()\n",
    "        assert x_type.endswith(arr_type) or scipy.sparse.isspmatrix_lil(X), \"Unknown instance type: \" + x_type\n",
    "        assert len(X.shape) == 1 or len(X.shape) == 2, \"Instance must have 1 or 2 dimensions!\"\n",
    "\n",
    "        # single instance\n",
    "        if len(X.shape) == 1:\n",
    "            data = X.reshape((1, X.shape[0]))\n",
    "            if self.keep_index:\n",
    "                # arguments in signature order: (val, column_name, index_value, index_name)\n",
    "                data = convert_to_instance_with_index(data, column_name, index_value, index_name)\n",
    "            explanation = self.explain(data, **kwargs)\n",
    "\n",
    "            # vector-output\n",
    "            s = explanation.shape\n",
    "            if len(s) == 2:\n",
    "                outs = [np.zeros(s[0]) for j in range(s[1])]\n",
    "                for j in range(s[1]):\n",
    "                    outs[j] = explanation[:, j]\n",
    "                return outs\n",
    "\n",
    "            # single-output\n",
    "            else:\n",
    "                out = np.zeros(s[0])\n",
    "                out[:] = explanation\n",
    "                return out\n",
    "\n",
    "        # explain the whole dataset\n",
    "        elif len(X.shape) == 2:\n",
    "            explanations = []\n",
    "            for i in tqdm(range(X.shape[0]), disable=kwargs.get(\"silent\", False)):\n",
    "                data = X[i:i + 1, :]\n",
    "                if self.keep_index:\n",
    "                    data = convert_to_instance_with_index(data, column_name, index_value[i:i + 1], index_name)\n",
    "                explanations.append(self.explain(data, **kwargs))\n",
    "                if kwargs.get(\"gc_collect\", False):\n",
    "                    gc.collect()\n",
    "\n",
    "            # vector-output\n",
    "            s = explanations[0].shape\n",
    "            if len(s) == 2:\n",
    "                outs = [np.zeros((X.shape[0], s[0])) for j in range(s[1])]\n",
    "                for i in range(X.shape[0]):\n",
    "                    for j in range(s[1]):\n",
    "                        outs[j][i] = explanations[i][:, j]\n",
    "                return outs\n",
    "\n",
    "            # single-output\n",
    "            else:\n",
    "                out = np.zeros((X.shape[0], s[0]))\n",
    "                for i in range(X.shape[0]):\n",
    "                    out[i] = explanations[i]\n",
    "                return out\n",
    "\n",
    "    def explain(self, incoming_instance, **kwargs):\n",
    "        \"\"\"Estimate the SHAP values of a single instance.\n",
    "\n",
    "        ``incoming_instance`` is presumably a 1-row matrix or an Instance wrapping one\n",
    "        (shap_values passes ``X[i:i + 1, :]``).\n",
    "\n",
    "        Recognized kwargs: ``nsamples`` (default \"auto\" = 2 * M + 2048, where M is the\n",
    "        number of varying feature groups) and ``l1_reg`` (default \"auto\"; stored on\n",
    "        ``self.l1_reg``, presumably consumed by ``self.solve``, which is not visible here).\n",
    "\n",
    "        Returns an array of shape (groups_size, D), squeezed to (groups_size,) when the\n",
    "        model has a single output. Per-feature variances (``phi_var``) are computed but\n",
    "        not returned.\n",
    "\n",
    "        Side effects: overwrites per-call state on ``self`` (varyingInds,\n",
    "        varyingFeatureGroups, M, fx, l1_reg, nsamples, max_samples) and calls the helper\n",
    "        methods allocate/addsample/run/solve. Random subsets are drawn from the global\n",
    "        ``np.random`` state, so results vary between runs unless it is seeded.\n",
    "        \"\"\"\n",
    "        # convert incoming input to a standardized iml object\n",
    "        instance = convert_to_instance(incoming_instance)\n",
    "        match_instance_to_data(instance, self.data)\n",
    "\n",
    "        # find the feature groups we will test. If a feature does not change from its\n",
    "        # current value then we know it doesn't impact the model\n",
    "        self.varyingInds = self.varying_groups(instance.x)\n",
    "        if self.data.groups is None:\n",
    "            self.varyingFeatureGroups = np.array([i for i in self.varyingInds])\n",
    "            self.M = self.varyingFeatureGroups.shape[0]\n",
    "        else:\n",
    "            self.varyingFeatureGroups = [self.data.groups[i] for i in self.varyingInds]\n",
    "            self.M = len(self.varyingFeatureGroups)\n",
    "            groups = self.data.groups\n",
    "            # convert to numpy array as it is much faster if not jagged array (all groups of same length)\n",
    "            if self.varyingFeatureGroups and all(len(groups[i]) == len(groups[0]) for i in self.varyingInds):\n",
    "                self.varyingFeatureGroups = np.array(self.varyingFeatureGroups)\n",
    "                # further performance optimization in case each group has a single value\n",
    "                if self.varyingFeatureGroups.shape[1] == 1:\n",
    "                    self.varyingFeatureGroups = self.varyingFeatureGroups.flatten()\n",
    "\n",
    "        # find f(x)\n",
    "        if self.keep_index:\n",
    "            model_out = self.model.f(instance.convert_to_df())\n",
    "        else:\n",
    "            model_out = self.model.f(instance.x)\n",
    "        if isinstance(model_out, (pd.DataFrame, pd.Series)):\n",
    "            model_out = model_out.values\n",
    "        # only the first row's output is used (shap_values always passes a single row)\n",
    "        self.fx = model_out[0]\n",
    "\n",
    "        if not self.vector_out:\n",
    "            self.fx = np.array([self.fx])\n",
    "\n",
    "        # if no features vary then no feature has an effect\n",
    "        if self.M == 0:\n",
    "            phi = np.zeros((self.data.groups_size, self.D))\n",
    "            phi_var = np.zeros((self.data.groups_size, self.D))\n",
    "\n",
    "        # if only one feature varies then it has all the effect\n",
    "        elif self.M == 1:\n",
    "            phi = np.zeros((self.data.groups_size, self.D))\n",
    "            phi_var = np.zeros((self.data.groups_size, self.D))\n",
    "            diff = self.link.f(self.fx) - self.link.f(self.fnull)\n",
    "            for d in range(self.D):\n",
    "                phi[self.varyingInds[0],d] = diff[d]\n",
    "\n",
    "        # if more than one feature varies then we have to do real work\n",
    "        else:\n",
    "            self.l1_reg = kwargs.get(\"l1_reg\", \"auto\")\n",
    "\n",
    "            # pick a reasonable number of samples if the user didn't specify how many they wanted\n",
    "            self.nsamples = kwargs.get(\"nsamples\", \"auto\")\n",
    "            if self.nsamples == \"auto\":\n",
    "                # 2**11 == 2048\n",
    "                self.nsamples = 2 * self.M + 2**11\n",
    "\n",
    "            # if we have enough samples to enumerate all subsets then ignore the unneeded samples\n",
    "            # 2**M - 2 counts every subset except the empty and the full one; beyond M = 30\n",
    "            # the cap is just a large constant (it is also the denominator passed to solve)\n",
    "            self.max_samples = 2 ** 30\n",
    "            if self.M <= 30:\n",
    "                self.max_samples = 2 ** self.M - 2\n",
    "                if self.nsamples > self.max_samples:\n",
    "                    self.nsamples = self.max_samples\n",
    "\n",
    "            # reserve space for some of our computations\n",
    "            self.allocate()\n",
    "\n",
    "            # weight the different subset sizes\n",
    "            # Total weight of all subsets of size i is proportional to (M-1)/(i*(M-i)), which is\n",
    "            # symmetric in i <-> M-i. Only sizes 1..ceil((M-1)/2) are listed; sizes that have a\n",
    "            # distinct complement size (the first floor((M-1)/2)) are doubled to cover it.\n",
    "            num_subset_sizes = int(np.ceil((self.M - 1) / 2.0))\n",
    "            num_paired_subset_sizes = int(np.floor((self.M - 1) / 2.0))\n",
    "            weight_vector = np.array([(self.M - 1.0) / (i * (self.M - i)) for i in range(1, num_subset_sizes + 1)])\n",
    "            weight_vector[:num_paired_subset_sizes] *= 2\n",
    "            weight_vector /= np.sum(weight_vector)\n",
    "\n",
    "            # fill out all the subset sizes we can completely enumerate\n",
    "            # given nsamples*remaining_weight_vector[subset_size]\n",
    "            num_full_subsets = 0\n",
    "            num_samples_left = self.nsamples\n",
    "            group_inds = np.arange(self.M, dtype='int64')\n",
    "            mask = np.zeros(self.M)\n",
    "            remaining_weight_vector = copy.copy(weight_vector)\n",
    "            for subset_size in range(1, num_subset_sizes + 1):\n",
    "\n",
    "                # determine how many subsets (and their complements) are of the current size\n",
    "                nsubsets = binom(self.M, subset_size)\n",
    "                if subset_size <= num_paired_subset_sizes:\n",
    "                    nsubsets *= 2\n",
    "\n",
    "                # see if we have enough samples to enumerate all subsets of this size\n",
    "                # (this size's share of the remaining budget must cover every subset, up to a\n",
    "                # small tolerance); sizes are tried in order and the first miss stops the loop\n",
    "                if num_samples_left * remaining_weight_vector[subset_size - 1] / nsubsets >= 1.0 - 1e-8:\n",
    "                    num_full_subsets += 1\n",
    "                    num_samples_left -= nsubsets\n",
    "\n",
    "                    # rescale what's left of the remaining weight vector to sum to 1\n",
    "                    if remaining_weight_vector[subset_size - 1] < 1.0:\n",
    "                        remaining_weight_vector /= (1 - remaining_weight_vector[subset_size - 1])\n",
    "\n",
    "                    # add all the samples of the current subset size\n",
    "                    # per-subset weight: this size's weight spread over its binom(M, size)\n",
    "                    # subsets, halved when each subset's complement is added as well\n",
    "                    w = weight_vector[subset_size - 1] / binom(self.M, subset_size)\n",
    "                    if subset_size <= num_paired_subset_sizes:\n",
    "                        w /= 2.0\n",
    "                    for inds in itertools.combinations(group_inds, subset_size):\n",
    "                        mask[:] = 0.0\n",
    "                        mask[np.array(inds, dtype='int64')] = 1.0\n",
    "                        self.addsample(instance.x, mask, w)\n",
    "                        if subset_size <= num_paired_subset_sizes:\n",
    "                            mask[:] = np.abs(mask - 1)\n",
    "                            self.addsample(instance.x, mask, w)\n",
    "                else:\n",
    "                    break\n",
    "\n",
    "            # add random samples from what is left of the subset space\n",
    "            # (everything added so far came from full enumeration)\n",
    "            nfixed_samples = self.nsamplesAdded\n",
    "            samples_left = self.nsamples - self.nsamplesAdded\n",
    "            if num_full_subsets != num_subset_sizes:\n",
    "                remaining_weight_vector = copy.copy(weight_vector)\n",
    "                remaining_weight_vector[:num_paired_subset_sizes] /= 2 # because we draw two samples each below\n",
    "                remaining_weight_vector = remaining_weight_vector[num_full_subsets:]\n",
    "                remaining_weight_vector /= np.sum(remaining_weight_vector)\n",
    "                # Draw a pool of 4x the needed subset sizes up front. Duplicate masks only bump an\n",
    "                # existing weight, so the loop below can exhaust this pool before samples_left hits 0.\n",
    "                ind_set = np.random.choice(len(remaining_weight_vector), 4 * samples_left, p=remaining_weight_vector)\n",
    "                ind_set_pos = 0\n",
    "                used_masks = {}\n",
    "                while samples_left > 0 and ind_set_pos < len(ind_set):\n",
    "                    mask.fill(0.0)\n",
    "                    ind = ind_set[ind_set_pos] # we call np.random.choice once to save time and then just read it here\n",
    "                    ind_set_pos += 1\n",
    "                    subset_size = ind + num_full_subsets + 1\n",
    "                    mask[np.random.permutation(self.M)[:subset_size]] = 1.0\n",
    "\n",
    "                    # only add the sample if we have not seen it before, otherwise just\n",
    "                    # increment a previous sample's weight\n",
    "                    mask_tuple = tuple(mask)\n",
    "                    new_sample = False\n",
    "                    if mask_tuple not in used_masks:\n",
    "                        new_sample = True\n",
    "                        used_masks[mask_tuple] = self.nsamplesAdded\n",
    "                        samples_left -= 1\n",
    "                        self.addsample(instance.x, mask, 1.0)\n",
    "                    else:\n",
    "                        self.kernelWeights[used_masks[mask_tuple]] += 1.0\n",
    "\n",
    "                    # add the compliment sample\n",
    "                    if samples_left > 0 and subset_size <= num_paired_subset_sizes:\n",
    "                        mask[:] = np.abs(mask - 1)\n",
    "\n",
    "                        # only add the sample if we have not seen it before, otherwise just\n",
    "                        # increment a previous sample's weight\n",
    "                        if new_sample:\n",
    "                            samples_left -= 1\n",
    "                            self.addsample(instance.x, mask, 1.0)\n",
    "                        else:\n",
    "                            # we know the compliment sample is the next one after the original sample, so + 1\n",
    "                            self.kernelWeights[used_masks[mask_tuple] + 1] += 1.0\n",
    "\n",
    "                # normalize the kernel weights for the random samples to equal the weight left after\n",
    "                # the fixed enumerated samples have been already counted\n",
    "                weight_left = np.sum(weight_vector[num_full_subsets:])\n",
    "\n",
    "                self.kernelWeights[nfixed_samples:] *= weight_left / self.kernelWeights[nfixed_samples:].sum()\n",
    "\n",
    "            # execute the model on the synthetic samples we have created\n",
    "            self.run()\n",
    "\n",
    "            # solve then expand the feature importance (Shapley value) vector to contain the non-varying features\n",
    "            # NOTE(review): the first argument looks like the fraction of the subset space that was\n",
    "            # sampled; solve is defined outside this view, so confirm its meaning there.\n",
    "            phi = np.zeros((self.data.groups_size, self.D))\n",
    "            phi_var = np.zeros((self.data.groups_size, self.D))\n",
    "            for d in range(self.D):\n",
    "                vphi, vphi_var = self.solve(self.nsamples / self.max_samples, d)\n",
    "                phi[self.varyingInds, d] = vphi\n",
    "                phi_var[self.varyingInds, d] = vphi_var\n",
    "\n",
    "        # single-output models: drop the trailing output axis\n",
    "        if not self.vector_out:\n",
    "            phi = np.squeeze(phi, axis=1)\n",
    "            phi_var = np.squeeze(phi_var, axis=1)\n",
    "\n",
    "        return phi\n",
    "\n",
    "    @staticmethod\n",
    "    def not_equal(i, j):\n",
    "        \"\"\"Return 0 when ``i`` and ``j`` count as equal and 1 otherwise.\n",
    "\n",
    "        Pairs of numbers (Python int/float or numpy scalars) are compared with\n",
    "        ``np.isclose(..., equal_nan=True)``, so two NaNs are equal; any other pair\n",
    "        falls back to ``==``.\n",
    "        \"\"\"\n",
    "        numeric_types = (int, float, np.number)\n",
    "        both_numeric = isinstance(i, numeric_types) and isinstance(j, numeric_types)\n",
    "        matches = np.isclose(i, j, equal_nan=True) if both_numeric else (i == j)\n",
    "        return 0 if matches else 1\n",
    "\n",
    "    def varying_groups(self, x):\n",
    "        \"\"\"Return the indices of the entries of ``x`` that differ from the background data.\n",
    "\n",
    "        ``explain`` only attributes SHAP values to the returned indices; everything else\n",
    "        gets 0.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        x : numpy.ndarray or scipy.sparse matrix\n",
    "            The instance being explained; only row 0 is read.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        numpy.ndarray\n",
    "            Sorted indices: feature-group indices when ``x`` is dense, column indices\n",
    "            when ``x`` is sparse (NOTE(review): the sparse path assumes one column per\n",
    "            group -- confirm with callers that use groups).\n",
    "        \"\"\"\n",
    "        if not scipy.sparse.issparse(x):\n",
    "            varying = np.zeros(self.data.groups_size)\n",
    "            for i in range(0, self.data.groups_size):\n",
    "                inds = self.data.groups[i]\n",
    "                x_group = x[0, inds]\n",
    "                # NOTE(review): x is not sparse in this branch, so for numpy inputs this\n",
    "                # check looks unreachable -- confirm before removing.\n",
    "                if scipy.sparse.issparse(x_group):\n",
    "                    if all(j not in x.nonzero()[1] for j in inds):\n",
    "                        varying[i] = False\n",
    "                        continue\n",
    "                    x_group = x_group.todense()\n",
    "                # element-wise not_equal of the instance's group values against every background\n",
    "                # row; a single mismatch marks the group as varying\n",
    "                num_mismatches = np.sum(np.frompyfunc(self.not_equal, 2, 1)(x_group, self.data.data[:, inds]))\n",
    "                varying[i] = num_mismatches > 0\n",
    "            varying_indices = np.nonzero(varying)[0]\n",
    "            return varying_indices\n",
    "        else:\n",
    "            varying_indices = []\n",
    "            # go over all nonzero columns in background and evaluation data\n",
    "            # if both background and evaluation are zero, the column does not vary\n",
    "            varying_indices = np.unique(np.union1d(self.data.data.nonzero()[1], x.nonzero()[1]))\n",
    "            remove_unvarying_indices = []\n",
    "            for i in range(0, len(varying_indices)):\n",
    "                varying_index = varying_indices[i]\n",
    "                # now verify the nonzero values do vary\n",
    "                data_rows = self.data.data[:, [varying_index]]\n",
    "                nonzero_rows = data_rows.nonzero()[0]\n",
    "\n",
    "                # a column whose background is entirely zero (but x is nonzero) stays varying\n",
    "                if nonzero_rows.size > 0:\n",
    "                    background_data_rows = data_rows[nonzero_rows]\n",
    "                    if scipy.sparse.issparse(background_data_rows):\n",
    "                        background_data_rows = background_data_rows.toarray()\n",
    "                    # count nonzero background entries that differ from x's value by more than 1e-7\n",
    "                    num_mismatches = np.sum(np.abs(background_data_rows - x[0, varying_index]) > 1e-7)\n",
    "                    # Note: If feature column non-zero but some background zero, can't remove index\n",
    "                    if num_mismatches == 0 and not \\\n",
    "                        (np.abs(x[0, [varying_index]][0, 0]) > 1e-7 and len(nonzero_rows) < data_rows.shape[0]):\n",
    "                        remove_unvarying_indices.append(i)\n",
    "            # drop the candidate columns that were found not to vary (order is preserved)\n",
    "            mask = np.ones(len(varying_indices), dtype=bool)\n",
    "            mask[remove_unvarying_indices] = False\n",
    "            varying_indices = varying_indices[mask]\n",
    "            return varying_indices\n",
    "\n",
    "    def allocate(self):\n",
    "        \"\"\"Allocate the per-explanation sample buffers.\n",
    "\n",
    "        ``synth_data`` receives ``self.nsamples`` stacked copies of the background data\n",
    "        (converted to LIL format when the background is sparse, so ``addsample`` can write\n",
    "        rows cheaply). Mask, kernel-weight and model-output buffers are zero-filled and the\n",
    "        added/run counters are reset to 0.\n",
    "        \"\"\"\n",
    "        background = self.data.data\n",
    "        n_copies = self.nsamples\n",
    "        if not scipy.sparse.issparse(background):\n",
    "            self.synth_data = np.tile(background, (n_copies, 1))\n",
    "        else:\n",
    "            n_rows, n_cols = background.shape\n",
    "            tiled_shape = (n_rows * n_copies, n_cols)\n",
    "            if background.nnz == 0:\n",
    "                empty = scipy.sparse.csr_matrix(tiled_shape, dtype=background.dtype)\n",
    "                self.synth_data = empty.tolil()\n",
    "            else:\n",
    "                # Tile the CSR arrays directly: each copy's row pointers are shifted by the number\n",
    "                # of stored entries per copy, and only the final copy keeps the closing pointer.\n",
    "                row_ptr = background.indptr\n",
    "                entries_per_copy = row_ptr[-1]\n",
    "                ptr_chunks = [row_ptr[:-1] + copy_no * entries_per_copy for copy_no in range(n_copies - 1)]\n",
    "                ptr_chunks.append(row_ptr + (n_copies - 1) * entries_per_copy)\n",
    "                tiled = scipy.sparse.csr_matrix(\n",
    "                    (np.tile(background.data, n_copies), np.tile(background.indices, n_copies), np.concatenate(ptr_chunks)),\n",
    "                    shape=tiled_shape,\n",
    "                )\n",
    "                self.synth_data = tiled.tolil()\n",
    "\n",
    "        # zero-initialised bookkeeping for masks, kernel weights and model outputs\n",
    "        self.maskMatrix = np.zeros((n_copies, self.M))\n",
    "        self.kernelWeights = np.zeros(n_copies)\n",
    "        self.y = np.zeros((n_copies * self.N, self.D))\n",
    "        self.ey = np.zeros((n_copies, self.D))\n",
    "        self.lastMask = np.zeros(n_copies)\n",
    "        self.nsamplesAdded = 0\n",
    "        self.nsamplesRun = 0\n",
    "        if self.keep_index:\n",
    "            self.synth_data_index = np.tile(self.data.index_value, n_copies)\n",
    "\n",
    "    def addsample(self, x, m, w):\n",
    "        \"\"\"Write one coalition sample into the pre-allocated buffers.\n",
    "\n",
    "        For each varying group ``j`` with ``m[j] == 1.0``, the ``self.N`` rows reserved for\n",
    "        this sample in ``synth_data`` get the instance values ``x[0, ...]`` for that group's\n",
    "        columns; all other columns keep their background values. The mask ``m`` and kernel\n",
    "        weight ``w`` are recorded and ``nsamplesAdded`` is incremented.\n",
    "        \"\"\"\n",
    "        start = self.nsamplesAdded * self.N\n",
    "        rows = slice(start, start + self.N)\n",
    "        groups = self.varyingFeatureGroups\n",
    "        if isinstance(groups, (list,)):\n",
    "            # jagged groups: copy column by column for each group switched on in the mask\n",
    "            for j in range(self.M):\n",
    "                if m[j] == 1.0:\n",
    "                    for col in groups[j]:\n",
    "                        self.synth_data[rows, col] = x[0, col]\n",
    "        else:\n",
    "            chosen = groups[m == 1.0]\n",
    "            if chosen.ndim == 2:\n",
    "                # equal-sized multi-column groups stored as a 2-D array\n",
    "                for group in chosen:\n",
    "                    self.synth_data[rows, group] = x[0, group]\n",
    "            else:\n",
    "                # one column per group: a single fancy-indexed assignment\n",
    "                values = x[0, chosen]\n",
    "                # a sparse instance written into a dense buffer must be densified first\n",
    "                if scipy.sparse.issparse(x) and not scipy.sparse.issparse(self.synth_data):\n",
    "                    values = values.toarray()\n",
    "                self.synth_data[rows, chosen] = values\n",
    "        self.maskMatrix[self.nsamplesAdded, :] = m\n",
    "        self.kernelWeights[self.nsamplesAdded] = w\n",
    "        self.nsamplesAdded += 1\n",
    "\n",
    "    def run(self):\n",
    "        \"\"\"Evaluate the model on every sample added since the previous call.\n",
    "\n",
    "        Rows ``[nsamplesRun * N, nsamplesAdded * N)`` of ``synth_data`` are passed to\n",
    "        ``self.model.f`` in one batch. The outputs are stored in ``self.y``; for each sample,\n",
    "        the ``self.data.weights``-weighted sum of its ``N`` background-row outputs is written to\n",
    "        ``self.ey``. ``nsamplesRun`` is then advanced to ``nsamplesAdded``.\n",
    "        \"\"\"\n",
    "        first, last = self.nsamplesRun, self.nsamplesAdded\n",
    "        row_start, row_stop = first * self.N, last * self.N\n",
    "        data = self.synth_data[row_start:row_stop, :]\n",
    "        if self.keep_index:\n",
    "            index = self.synth_data_index[row_start:row_stop]\n",
    "            index = pd.DataFrame(index, columns=[self.data.index_name])\n",
    "            data = pd.DataFrame(data, columns=self.data.group_names)\n",
    "            data = pd.concat([index, data], axis=1).set_index(self.data.index_name)\n",
    "            if self.keep_index_ordered:\n",
    "                data = data.sort_index()\n",
    "        model_out = self.model.f(data)\n",
    "        if isinstance(model_out, (pd.DataFrame, pd.Series)):\n",
    "            model_out = model_out.values\n",
    "        self.y[row_start:row_stop, :] = np.reshape(model_out, (row_stop - row_start, self.D))\n",
    "\n",
    "        # expected value of each output per sample: one vectorised contraction over the N\n",
    "        # background rows instead of a Python loop over samples x background rows\n",
    "        if last > first:\n",
    "            per_sample = self.y[row_start:row_stop, :].reshape(last - first, self.N, self.D)\n",
    "            weights = np.asarray(self.data.weights, dtype=float)\n",
    "            self.ey[first:last, :] = np.einsum(\"snd,n->sd\", per_sample, weights)\n",
    "        self.nsamplesRun = last\n",
    "\n",
    "    def solve(self, fraction_evaluated, dim):\n",
    "        \"\"\"Estimate the Shapley values of output ``dim`` by constrained weighted least squares.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        fraction_evaluated : float\n",
    "            Share of the coalition space that was sampled (``explain`` passes\n",
    "            ``self.nsamples / self.max_samples``). With ``l1_reg == \"auto\"`` the l1 feature\n",
    "            selection only runs when this is below 0.2.\n",
    "        dim : int\n",
    "            Output index: column of ``self.ey`` and index into ``self.fx`` / ``self.fnull``.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        (phi, phi_var) : tuple of numpy.ndarray, each of length ``self.M``\n",
    "            ``phi`` sums to ``link(fx) - link(fnull)`` (up to the final < 1e-10 clean-up);\n",
    "            features not selected by the l1 step get 0, and all-zero ``phi`` is returned when\n",
    "            nothing is selected. ``phi_var`` is a placeholder of ones (no variance is estimated).\n",
    "        \"\"\"\n",
    "        # per-sample expected output in link space, shifted so the null (background) output is 0\n",
    "        eyAdj = self.linkfv(self.ey[:, dim]) - self.link.f(self.fnull[dim])\n",
    "        # coalition size of each sample: number of groups set to the instance's values\n",
    "        s = np.sum(self.maskMatrix, 1)\n",
    "\n",
    "        # do feature selection if we have not well enumerated the space\n",
    "        nonzero_inds = np.arange(self.M)\n",
    "        \n",
    "        # Deprecation warning for l1_reg=\"auto\", currently disabled:\n",
    "        # if self.l1_reg == \"auto\":\n",
    "        #     warnings.warn(\n",
    "        #         \"l1_reg=\\\"auto\\\" is deprecated and in the next version (v0.29) the behavior will change from a \" \\\n",
    "        #         \"conditional use of AIC to simply \\\"num_features(10)\\\"!\"\n",
    "        #     )\n",
    "        if (self.l1_reg not in [\"auto\", False, 0]) or (fraction_evaluated < 0.2 and self.l1_reg == \"auto\"):\n",
    "            # augmented regression: rows [maskMatrix; maskMatrix - 1] weighted by sqrt of\n",
    "            # kernelWeights * (M - s) and kernelWeights * s; the second half's target is eyAdj\n",
    "            # minus the full effect link(fx) - link(fnull)\n",
    "            w_aug = np.hstack((self.kernelWeights * (self.M - s), self.kernelWeights * s))\n",
    "            w_sqrt_aug = np.sqrt(w_aug)\n",
    "            eyAdj_aug = np.hstack((eyAdj, eyAdj - (self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim]))))\n",
    "            eyAdj_aug *= w_sqrt_aug\n",
    "            mask_aug = np.transpose(w_sqrt_aug * np.transpose(np.vstack((self.maskMatrix, self.maskMatrix - 1))))\n",
    "            #var_norms = np.array([np.linalg.norm(mask_aug[:, i]) for i in range(mask_aug.shape[1])])\n",
    "\n",
    "            # select a fixed number of top features\n",
    "            if isinstance(self.l1_reg, str) and self.l1_reg.startswith(\"num_features(\"):\n",
    "                # r is parsed from the text between \"num_features(\" and the closing parenthesis\n",
    "                r = int(self.l1_reg[len(\"num_features(\"):-1])\n",
    "                nonzero_inds = lars_path(mask_aug, eyAdj_aug, max_iter=r)[1]\n",
    "\n",
    "            # use an adaptive regularization method\n",
    "            elif self.l1_reg == \"auto\" or self.l1_reg == \"bic\" or self.l1_reg == \"aic\":\n",
    "                c = \"aic\" if self.l1_reg == \"auto\" else self.l1_reg\n",
    "\n",
    "                # \"Normalize\" parameter of LassoLarsIC was deprecated in sklearn version 1.2\n",
    "                if version.parse(sklearn.__version__) < version.parse(\"1.2.0\"):\n",
    "                    kwg = dict(normalize=False)\n",
    "                else:\n",
    "                    kwg = {}\n",
    "                model = make_pipeline(StandardScaler(with_mean=False), LassoLarsIC(criterion=c, **kwg))\n",
    "                # [1] is the LassoLarsIC step of the fitted pipeline\n",
    "                nonzero_inds = np.nonzero(model.fit(mask_aug, eyAdj_aug)[1].coef_)[0]\n",
    "\n",
    "            # use a fixed regularization coefficient\n",
    "            else:\n",
    "                nonzero_inds = np.nonzero(Lasso(alpha=self.l1_reg).fit(mask_aug, eyAdj_aug).coef_)[0]\n",
    "\n",
    "        if len(nonzero_inds) == 0:\n",
    "            return np.zeros(self.M), np.ones(self.M)\n",
    "\n",
    "        # eliminate one variable with the constraint that all features sum to the output:\n",
    "        # the last selected feature's value is recovered from the others after the solve\n",
    "        eyAdj2 = eyAdj - self.maskMatrix[:, nonzero_inds[-1]] * (\n",
    "                    self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim]))\n",
    "        etmp = np.transpose(np.transpose(self.maskMatrix[:, nonzero_inds[:-1]]) - self.maskMatrix[:, nonzero_inds[-1]])\n",
    "\n",
    "        # solve a weighted least squares equation to estimate phi:\n",
    "        # w = (E^T K E)^-1 E^T K y with E = etmp, K = diag(kernelWeights), y = eyAdj2\n",
    "        tmp = np.transpose(np.transpose(etmp) * np.transpose(self.kernelWeights))\n",
    "        etmp_dot = np.dot(np.transpose(tmp), etmp)\n",
    "        try:\n",
    "            tmp2 = np.linalg.inv(etmp_dot)\n",
    "        except np.linalg.LinAlgError:\n",
    "            tmp2 = np.linalg.pinv(etmp_dot)\n",
    "            warnings.warn(\n",
    "                \"Linear regression equation is singular, Moore-Penrose pseudoinverse is used instead of the regular inverse.\\n\"\n",
    "                \"To use regular inverse do one of the following:\\n\"\n",
    "                \"1) turn up the number of samples,\\n\"\n",
    "                \"2) turn up the L1 regularization with num_features(N) where N is less than the number of samples,\\n\"\n",
    "                \"3) group features together to reduce the number of inputs that need to be explained.\"\n",
    "            )\n",
    "        w = np.dot(tmp2, np.dot(np.transpose(tmp), eyAdj2))\n",
    "\n",
    "        # scatter the solution back to all M features; the eliminated one absorbs the remainder\n",
    "        phi = np.zeros(self.M)\n",
    "        phi[nonzero_inds[:-1]] = w\n",
    "        phi[nonzero_inds[-1]] = (self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim])) - sum(w)\n",
    "\n",
    "        # clean up any rounding errors\n",
    "        for i in range(self.M):\n",
    "            if np.abs(phi[i]) < 1e-10:\n",
    "                phi[i] = 0\n",
    "\n",
    "        return phi, np.ones(len(phi))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KernelReX:\n",
    "\n",
    "    def __init__(self, model, data, feature_names=None, link=IdentityLink(), **kwargs):\n",
    "        \"\"\"Build a Kernel SHAP explainer around ``model`` with ``data`` as the background.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        model\n",
    "            Anything ``convert_to_model`` accepts.\n",
    "        data\n",
    "            Background samples. ``data.columns`` is read, so a pandas.DataFrame is expected.\n",
    "        feature_names\n",
    "            Not used by this constructor.\n",
    "        link\n",
    "            NOTE(review): ignored -- ``self.link`` is always ``LogitLink()`` whatever is passed.\n",
    "        **kwargs\n",
    "            ``keep_index`` and ``keep_index_ordered`` (both default False).\n",
    "        \"\"\"\n",
    "        # read first, so a background without ``columns`` fails before any conversion work\n",
    "        self.data_feature_names = list(data.columns)\n",
    "\n",
    "        # wrap the model and background data in the standardized iml objects\n",
    "        self.link = LogitLink()\n",
    "        self.keep_index = kwargs.get(\"keep_index\", False)\n",
    "        self.keep_index_ordered = kwargs.get(\"keep_index_ordered\", False)\n",
    "        self.model = convert_to_model(model, keep_index=self.keep_index)\n",
    "        self.data = convert_to_data(data, keep_index=self.keep_index)\n",
    "        null_outputs = match_model_to_data(self.model, self.data)\n",
    "\n",
    "        # only untransposed DenseData / SparseData backgrounds are supported\n",
    "        assert isinstance(self.data, (DenseData, SparseData)), \"Shap explainer only supports the DenseData and SparseData input currently.\"\n",
    "        assert not self.data.transposed, \"Shap explainer does not support transposed DenseData or SparseData currently.\"\n",
    "\n",
    "        background_shape = self.data.data.shape\n",
    "        self.N = background_shape[0]\n",
    "        self.P = background_shape[1]\n",
    "        self.linkfv = np.vectorize(self.link.f)\n",
    "        self.nsamplesAdded = 0\n",
    "        self.nsamplesRun = 0\n",
    "\n",
    "        # E_x[f(x)]: model outputs on the background, summed with the background weights\n",
    "        if isinstance(null_outputs, (pd.DataFrame, pd.Series)):\n",
    "            null_outputs = np.squeeze(null_outputs.values)\n",
    "        self.fnull = np.sum((null_outputs.T * self.data.weights).T, 0)\n",
    "        self.expected_value = self.linkfv(self.fnull)\n",
    "\n",
    "        # a 0-d fnull means a scalar-output model: store it as a length-1 vector so D is defined\n",
    "        self.vector_out = len(self.fnull.shape) != 0\n",
    "        if self.vector_out:\n",
    "            self.D = self.fnull.shape[0]\n",
    "        else:\n",
    "            self.fnull = np.array([self.fnull])\n",
    "            self.D = 1\n",
    "            self.expected_value = float(self.expected_value)\n",
    "\n",
    "    def shap_values(self, X, **kwargs):\n",
    "        \"\"\" Estimate the SHAP values for a set of samples.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : numpy.array or pandas.DataFrame or any scipy.sparse matrix\n",
    "            A matrix of samples (# samples x # features) on which to explain the model's output.\n",
    "            A 1-D array or pandas.Series is explained as a single sample. When the explainer was\n",
    "            built with ``keep_index=True``, X must be a pandas.DataFrame (TypeError otherwise).\n",
    "\n",
    "        nsamples : \"auto\" or int\n",
    "            Number of times to re-evaluate the model when explaining each prediction. More samples\n",
    "            lead to lower variance estimates of the SHAP values. The \"auto\" setting uses\n",
    "            `nsamples = 2 * M + 2048`, where M is the number of feature groups that differ from the\n",
    "            background for the sample being explained; it is capped at 2**M - 2 when M <= 30.\n",
    "\n",
    "        l1_reg : \"num_features(int)\", \"auto\" (default), \"aic\", \"bic\", or float\n",
    "            The l1 regularization to use for feature selection (the estimation procedure is based on\n",
    "            a debiased lasso). The \"auto\" option uses \"aic\" when less than 20% of the possible sample\n",
    "            space is enumerated, otherwise it uses no regularization.\n",
    "            The \"aic\" and \"bic\" options use the AIC and BIC rules for regularization.\n",
    "            Using \"num_features(int)\" selects a fixed number of top features. Passing a float directly sets the\n",
    "            \"alpha\" parameter of the sklearn.linear_model.Lasso model used for feature selection.\n",
    "\n",
    "        silent : bool\n",
    "            Hide the tqdm progress bar when explaining several samples (default False).\n",
    "\n",
    "        gc_collect : bool\n",
    "            Run garbage collection after each explanation round. Sometimes needed for memory intensive\n",
    "            explanations (default False).\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        array or list\n",
    "            For models with a single output this returns a matrix of SHAP values\n",
    "            (# samples x # features), or a vector for a single 1-D sample. Each row sums to the\n",
    "            difference between the model output for that sample and the expected value of the model\n",
    "            output (which is stored as expected_value attribute of the explainer). For models with\n",
    "            vector outputs this returns a list of such matrices, one for each output.\n",
    "        \"\"\"\n",
    "        # unwrap pandas inputs; only a DataFrame carries the index metadata keep_index needs\n",
    "        has_index_info = False\n",
    "        if isinstance(X, pd.Series):\n",
    "            X = X.values\n",
    "        elif isinstance(X, pd.DataFrame):\n",
    "            if self.keep_index:\n",
    "                index_value = X.index.values\n",
    "                index_name = X.index.name\n",
    "                column_name = list(X.columns)\n",
    "                has_index_info = True\n",
    "            X = X.values\n",
    "\n",
    "        x_type = str(type(X))\n",
    "        arr_type = \"'numpy.ndarray'>\"\n",
    "        # if sparse, convert to lil for performance\n",
    "        if scipy.sparse.issparse(X) and not scipy.sparse.isspmatrix_lil(X):\n",
    "            X = X.tolil()\n",
    "        assert x_type.endswith(arr_type) or scipy.sparse.isspmatrix_lil(X), \"Unknown instance type: \" + x_type\n",
    "        assert len(X.shape) == 1 or len(X.shape) == 2, \"Instance must have 1 or 2 dimensions!\"\n",
    "        if self.keep_index and not has_index_info:\n",
    "            # previously this surfaced as an UnboundLocalError on column_name/index_value\n",
    "            raise TypeError(\"keep_index=True requires X to be a pandas.DataFrame\")\n",
    "\n",
    "        # single instance (never from a DataFrame, so no index metadata to attach)\n",
    "        if len(X.shape) == 1:\n",
    "            explanation = self.explain(X.reshape((1, X.shape[0])), **kwargs)\n",
    "            s = explanation.shape\n",
    "            if len(s) == 2:\n",
    "                # vector output: one array of per-feature values per model output\n",
    "                return [explanation[:, j] for j in range(s[1])]\n",
    "            out = np.zeros(s[0])\n",
    "            out[:] = explanation\n",
    "            return out\n",
    "\n",
    "        # explain the whole dataset, one row at a time\n",
    "        n_rows = X.shape[0]\n",
    "        if n_rows == 0:\n",
    "            # nothing to explain: return empty results with the usual layout\n",
    "            empty = np.zeros((0, self.data.groups_size))\n",
    "            return [empty.copy() for _ in range(self.D)] if self.vector_out else empty\n",
    "        explanations = []\n",
    "        for i in tqdm(range(n_rows), disable=kwargs.get(\"silent\", False)):\n",
    "            data = X[i:i + 1, :]\n",
    "            if self.keep_index:\n",
    "                data = convert_to_instance_with_index(data, column_name, index_value[i:i + 1], index_name)\n",
    "            explanations.append(self.explain(data, **kwargs))\n",
    "            if kwargs.get(\"gc_collect\", False):\n",
    "                gc.collect()\n",
    "\n",
    "        s = explanations[0].shape\n",
    "        if len(s) == 2:\n",
    "            # vector output: regroup into one (n_rows x features) matrix per model output\n",
    "            outs = [np.zeros((n_rows, s[0])) for j in range(s[1])]\n",
    "            for i in range(n_rows):\n",
    "                for j in range(s[1]):\n",
    "                    outs[j][i] = explanations[i][:, j]\n",
    "            return outs\n",
    "\n",
    "        # single output\n",
    "        out = np.zeros((n_rows, s[0]))\n",
    "        for i in range(n_rows):\n",
    "            out[i] = explanations[i]\n",
    "        return out\n",
    "\n",
    "    def explain(self, incoming_instance, **kwargs):\n",
    "        \"\"\"Compute the SHAP values of a single instance.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        incoming_instance\n",
    "            One sample (``shap_values`` passes a 1 x n_features row); converted with\n",
    "            ``convert_to_instance`` and checked with ``match_instance_to_data``.\n",
    "        **kwargs\n",
    "            ``l1_reg`` (default \"auto\") and ``nsamples`` (default \"auto\", i.e. 2 * M + 2**11)\n",
    "            are read when more than one feature group varies.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        numpy.ndarray\n",
    "            ``phi`` of shape (groups_size, D), squeezed to (groups_size,) for scalar-output models.\n",
    "            Non-varying groups get 0. ``phi_var`` is filled alongside but not returned.\n",
    "        \"\"\"\n",
    "        # convert incoming input to a standardized iml object\n",
    "        instance = convert_to_instance(incoming_instance)\n",
    "        match_instance_to_data(instance, self.data)\n",
    "\n",
    "        # find the feature groups we will test. If a feature does not change from its\n",
    "        # current value then we know it doesn't impact the model\n",
    "        self.varyingInds = self.varying_groups(instance.x)\n",
    "        if self.data.groups is None:\n",
    "            self.varyingFeatureGroups = np.array([i for i in self.varyingInds])\n",
    "            self.M = self.varyingFeatureGroups.shape[0]\n",
    "        else:\n",
    "            self.varyingFeatureGroups = [self.data.groups[i] for i in self.varyingInds]\n",
    "            self.M = len(self.varyingFeatureGroups)\n",
    "            groups = self.data.groups\n",
    "            # convert to numpy array as it is much faster if not jagged array (all groups of same length)\n",
    "            # NOTE(review): lengths are compared with groups[0], which need not be a varying group;\n",
    "            # this only decides whether the array fast path in addsample is taken.\n",
    "            if self.varyingFeatureGroups and all(len(groups[i]) == len(groups[0]) for i in self.varyingInds):\n",
    "                self.varyingFeatureGroups = np.array(self.varyingFeatureGroups)\n",
    "                # further performance optimization in case each group has a single value\n",
    "                if self.varyingFeatureGroups.shape[1] == 1:\n",
    "                    self.varyingFeatureGroups = self.varyingFeatureGroups.flatten()\n",
    "\n",
    "        # find f(x)\n",
    "        if self.keep_index:\n",
    "            model_out = self.model.f(instance.convert_to_df())\n",
    "        else:\n",
    "            model_out = self.model.f(instance.x)\n",
    "        if isinstance(model_out, (pd.DataFrame, pd.Series)):\n",
    "            model_out = model_out.values\n",
    "        self.fx = model_out[0]\n",
    "\n",
    "        # keep fx indexable by output dimension, like fnull\n",
    "        if not self.vector_out:\n",
    "            self.fx = np.array([self.fx])\n",
    "\n",
    "        # if no features vary then no feature has an effect\n",
    "        if self.M == 0:\n",
    "            phi = np.zeros((self.data.groups_size, self.D))\n",
    "            phi_var = np.zeros((self.data.groups_size, self.D))\n",
    "\n",
    "        # if only one feature varies then it has all the effect\n",
    "        elif self.M == 1:\n",
    "            phi = np.zeros((self.data.groups_size, self.D))\n",
    "            phi_var = np.zeros((self.data.groups_size, self.D))\n",
    "            # the whole link-space difference link(fx) - link(fnull) goes to that group\n",
    "            diff = self.link.f(self.fx) - self.link.f(self.fnull)\n",
    "            for d in range(self.D):\n",
    "                phi[self.varyingInds[0],d] = diff[d]\n",
    "\n",
    "        # if more than one feature varies then we have to do real work\n",
    "        else:\n",
    "            self.l1_reg = kwargs.get(\"l1_reg\", \"auto\")\n",
    "\n",
    "            # pick a reasonable number of samples if the user didn't specify how many they wanted\n",
    "            self.nsamples = kwargs.get(\"nsamples\", \"auto\")\n",
    "            if self.nsamples == \"auto\":\n",
    "                self.nsamples = 2 * self.M + 2**11\n",
    "\n",
    "            # if we have enough samples to enumerate all subsets then ignore the unneeded samples\n",
    "            # (2**M - 2 = every coalition except the empty and the full one)\n",
    "            self.max_samples = 2 ** 30\n",
    "            if self.M <= 30:\n",
    "                self.max_samples = 2 ** self.M - 2\n",
    "                if self.nsamples > self.max_samples:\n",
    "                    self.nsamples = self.max_samples\n",
    "\n",
    "            # reserve space for some of our computations\n",
    "            self.allocate()\n",
    "\n",
    "            # weight the different subset sizes: (M - 1) / (s * (M - s)) for s = 1..ceil((M - 1) / 2);\n",
    "            # sizes s <= floor((M - 1) / 2) also stand for their complement size M - s, so they are doubled\n",
    "            num_subset_sizes = int(np.ceil((self.M - 1) / 2.0))\n",
    "            num_paired_subset_sizes = int(np.floor((self.M - 1) / 2.0))\n",
    "            weight_vector = np.array([(self.M - 1.0) / (i * (self.M - i)) for i in range(1, num_subset_sizes + 1)])\n",
    "            weight_vector[:num_paired_subset_sizes] *= 2\n",
    "            weight_vector /= np.sum(weight_vector)\n",
    "\n",
    "            # fill out all the subset sizes we can completely enumerate\n",
    "            # given nsamples*remaining_weight_vector[subset_size]\n",
    "            num_full_subsets = 0\n",
    "            num_samples_left = self.nsamples\n",
    "            group_inds = np.arange(self.M, dtype='int64')\n",
    "            mask = np.zeros(self.M)\n",
    "            remaining_weight_vector = copy.copy(weight_vector)\n",
    "            for subset_size in range(1, num_subset_sizes + 1):\n",
    "\n",
    "                # determine how many subsets (and their complements) are of the current size\n",
    "                nsubsets = binom(self.M, subset_size)\n",
    "                if subset_size <= num_paired_subset_sizes:\n",
    "                    nsubsets *= 2\n",
    "\n",
    "                # see if we have enough samples to enumerate all subsets of this size\n",
    "                if num_samples_left * remaining_weight_vector[subset_size - 1] / nsubsets >= 1.0 - 1e-8:\n",
    "                    num_full_subsets += 1\n",
    "                    num_samples_left -= nsubsets\n",
    "\n",
    "                    # rescale what's left of the remaining weight vector to sum to 1\n",
    "                    if remaining_weight_vector[subset_size - 1] < 1.0:\n",
    "                        remaining_weight_vector /= (1 - remaining_weight_vector[subset_size - 1])\n",
    "\n",
    "                    # add all the samples of the current subset size\n",
    "                    w = weight_vector[subset_size - 1] / binom(self.M, subset_size)\n",
    "                    if subset_size <= num_paired_subset_sizes:\n",
    "                        w /= 2.0\n",
    "                    for inds in itertools.combinations(group_inds, subset_size):\n",
    "                        mask[:] = 0.0\n",
    "                        mask[np.array(inds, dtype='int64')] = 1.0\n",
    "                        self.addsample(instance.x, mask, w)\n",
    "                        if subset_size <= num_paired_subset_sizes:\n",
    "                            mask[:] = np.abs(mask - 1)\n",
    "                            self.addsample(instance.x, mask, w)\n",
    "                else:\n",
    "                    break\n",
    "\n",
    "            # add random samples from what is left of the subset space\n",
    "            nfixed_samples = self.nsamplesAdded\n",
    "            samples_left = self.nsamples - self.nsamplesAdded\n",
    "            if num_full_subsets != num_subset_sizes:\n",
    "                remaining_weight_vector = copy.copy(weight_vector)\n",
    "                remaining_weight_vector[:num_paired_subset_sizes] /= 2 # because we draw two samples each below\n",
    "                remaining_weight_vector = remaining_weight_vector[num_full_subsets:]\n",
    "                remaining_weight_vector /= np.sum(remaining_weight_vector)\n",
    "                # draw 4 * samples_left subset-size indices up front; the loop below stops once\n",
    "                # samples_left reaches 0 or these draws run out\n",
    "                ind_set = np.random.choice(len(remaining_weight_vector), 4 * samples_left, p=remaining_weight_vector)\n",
    "                ind_set_pos = 0\n",
    "                used_masks = {}\n",
    "                while samples_left > 0 and ind_set_pos < len(ind_set):\n",
    "                    mask.fill(0.0)\n",
    "                    ind = ind_set[ind_set_pos] # we call np.random.choice once to save time and then just read it here\n",
    "                    ind_set_pos += 1\n",
    "                    subset_size = ind + num_full_subsets + 1\n",
    "                    mask[np.random.permutation(self.M)[:subset_size]] = 1.0\n",
    "\n",
    "                    # only add the sample if we have not seen it before, otherwise just\n",
    "                    # increment a previous sample's weight\n",
    "                    mask_tuple = tuple(mask)\n",
    "                    new_sample = False\n",
    "                    if mask_tuple not in used_masks:\n",
    "                        new_sample = True\n",
    "                        used_masks[mask_tuple] = self.nsamplesAdded\n",
    "                        samples_left -= 1\n",
    "                        self.addsample(instance.x, mask, 1.0)\n",
    "                    else:\n",
    "                        self.kernelWeights[used_masks[mask_tuple]] += 1.0\n",
    "\n",
    "                    # add the compliment sample\n",
    "                    if samples_left > 0 and subset_size <= num_paired_subset_sizes:\n",
    "                        mask[:] = np.abs(mask - 1)\n",
    "\n",
    "                        # only add the sample if we have not seen it before, otherwise just\n",
    "                        # increment a previous sample's weight\n",
    "                        if new_sample:\n",
    "                            samples_left -= 1\n",
    "                            self.addsample(instance.x, mask, 1.0)\n",
    "                        else:\n",
    "                            # we know the compliment sample is the next one after the original sample, so + 1\n",
    "                            self.kernelWeights[used_masks[mask_tuple] + 1] += 1.0\n",
    "\n",
    "                # normalize the kernel weights for the random samples to equal the weight left after\n",
    "                # the fixed enumerated samples have been already counted\n",
    "                weight_left = np.sum(weight_vector[num_full_subsets:])\n",
    "\n",
    "                self.kernelWeights[nfixed_samples:] *= weight_left / self.kernelWeights[nfixed_samples:].sum()\n",
    "\n",
    "            # execute the model on the synthetic samples we have created\n",
    "            self.run()\n",
    "\n",
    "            # solve then expand the feature importance (Shapley value) vector to contain the non-varying features\n",
    "            phi = np.zeros((self.data.groups_size, self.D))\n",
    "            phi_var = np.zeros((self.data.groups_size, self.D))\n",
    "            for d in range(self.D):\n",
    "                vphi, vphi_var = self.solve(self.nsamples / self.max_samples, d)\n",
    "                phi[self.varyingInds, d] = vphi\n",
    "                phi_var[self.varyingInds, d] = vphi_var\n",
    "\n",
    "        # scalar-output models: drop the length-1 output axis\n",
    "        if not self.vector_out:\n",
    "            phi = np.squeeze(phi, axis=1)\n",
    "            phi_var = np.squeeze(phi_var, axis=1)\n",
    "\n",
    "        return phi\n",
    "\n",
    "    @staticmethod\n",
    "    def not_equal(i, j):\n",
    "        \"\"\"Return 1 if the two scalar values differ, otherwise 0.\n",
    "\n",
    "        Numeric pairs are compared with ``np.isclose`` (NaN counts as equal to NaN);\n",
    "        any other pair falls back to ``==``.\n",
    "        \"\"\"\n",
    "        numeric = (int, float, np.number)\n",
    "        both_numeric = isinstance(i, numeric) and isinstance(j, numeric)\n",
    "        same = np.isclose(i, j, equal_nan=True) if both_numeric else i == j\n",
    "        return 0 if same else 1\n",
    "\n",
    "    def varying_groups(self, x):\n",
    "        \"\"\"Return the indices of the features/groups whose value in `x` differs from the background.\n",
    "\n",
    "        Groups that match every background row are excluded from the Shapley estimation.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        x : numpy.array or scipy.sparse matrix\n",
    "            The instance being explained. It is indexed as x[0, ...], so presumably a single\n",
    "            row of shape (1, n_features) -- TODO confirm against the caller.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        numpy.array\n",
    "            Dense x: indices into self.data.groups of groups with at least one mismatch.\n",
    "            Sparse x: column indices. NOTE(review): this branch ignores self.data.groups,\n",
    "            so it presumably assumes one feature per group.\n",
    "        \"\"\"\n",
    "        if not scipy.sparse.issparse(x):\n",
    "            varying = np.zeros(self.data.groups_size)\n",
    "            for i in range(0, self.data.groups_size):\n",
    "                inds = self.data.groups[i]\n",
    "                x_group = x[0, inds]\n",
    "                # NOTE(review): handles a sparse slice; not expected when x itself is dense\n",
    "                if scipy.sparse.issparse(x_group):\n",
    "                    if all(j not in x.nonzero()[1] for j in inds):\n",
    "                        varying[i] = False\n",
    "                        continue\n",
    "                    x_group = x_group.todense()\n",
    "                # not_equal yields 0/1 per cell, so the sum counts background cells that differ from x\n",
    "                num_mismatches = np.sum(np.frompyfunc(self.not_equal, 2, 1)(x_group, self.data.data[:, inds]))\n",
    "                varying[i] = num_mismatches > 0\n",
    "            varying_indices = np.nonzero(varying)[0]\n",
    "            return varying_indices\n",
    "        else:\n",
    "            varying_indices = []\n",
    "            # go over all nonzero columns in background and evaluation data\n",
    "            # if both background and evaluation are zero, the column does not vary\n",
    "            varying_indices = np.unique(np.union1d(self.data.data.nonzero()[1], x.nonzero()[1]))\n",
    "            remove_unvarying_indices = []\n",
    "            for i in range(0, len(varying_indices)):\n",
    "                varying_index = varying_indices[i]\n",
    "                # now verify the nonzero values do vary\n",
    "                data_rows = self.data.data[:, [varying_index]]\n",
    "                nonzero_rows = data_rows.nonzero()[0]\n",
    "\n",
    "                if nonzero_rows.size > 0:\n",
    "                    background_data_rows = data_rows[nonzero_rows]\n",
    "                    if scipy.sparse.issparse(background_data_rows):\n",
    "                        background_data_rows = background_data_rows.toarray()\n",
    "                    # count nonzero background entries that differ from x's value by more than 1e-7\n",
    "                    num_mismatches = np.sum(np.abs(background_data_rows - x[0, varying_index]) > 1e-7)\n",
    "                    # Note: If feature column non-zero but some background zero, can't remove index\n",
    "                    if num_mismatches == 0 and not \\\n",
    "                        (np.abs(x[0, [varying_index]][0, 0]) > 1e-7 and len(nonzero_rows) < data_rows.shape[0]):\n",
    "                        remove_unvarying_indices.append(i)\n",
    "            # drop the candidate columns that turned out to be constant across background and x\n",
    "            mask = np.ones(len(varying_indices), dtype=bool)\n",
    "            mask[remove_unvarying_indices] = False\n",
    "            varying_indices = varying_indices[mask]\n",
    "            return varying_indices\n",
    "\n",
    "    def allocate(self):\n",
    "        \"\"\"Pre-allocate the buffers used while building and evaluating synthetic samples.\n",
    "\n",
    "        synth_data receives self.nsamples stacked copies of the background data. Sparse\n",
    "        input is tiled directly in CSR form and then converted to LIL. The remaining arrays\n",
    "        hold per-sample masks, kernel weights, model outputs and expected outputs. The\n",
    "        nsamplesAdded / nsamplesRun counters are reset to 0.\n",
    "        \"\"\"\n",
    "        if scipy.sparse.issparse(self.data.data):\n",
    "            # We tile the sparse matrix in csr format but convert it to lil\n",
    "            # for performance when adding samples\n",
    "            shape = self.data.data.shape\n",
    "            nnz = self.data.data.nnz\n",
    "            data_rows, data_cols = shape\n",
    "            rows = data_rows * self.nsamples\n",
    "            shape = rows, data_cols\n",
    "            if nnz == 0:\n",
    "                self.synth_data = scipy.sparse.csr_matrix(shape, dtype=self.data.data.dtype).tolil()\n",
    "            else:\n",
    "                # Build the tiled CSR arrays by hand: repeat data/indices nsamples times and\n",
    "                # shift each copy's row pointers by the nnz of all previous copies.\n",
    "                # NOTE(review): reads .data/.indices/.indptr, so this assumes CSR input.\n",
    "                data = self.data.data.data\n",
    "                indices = self.data.data.indices\n",
    "                indptr = self.data.data.indptr\n",
    "                last_indptr_idx = indptr[len(indptr) - 1]\n",
    "                indptr_wo_last = indptr[:-1]\n",
    "                new_indptrs = []\n",
    "                for i in range(0, self.nsamples - 1):\n",
    "                    new_indptrs.append(indptr_wo_last + (i * last_indptr_idx))\n",
    "                # only the final copy keeps its closing row pointer\n",
    "                new_indptrs.append(indptr + ((self.nsamples - 1) * last_indptr_idx))\n",
    "                new_indptr = np.concatenate(new_indptrs)\n",
    "                new_data = np.tile(data, self.nsamples)\n",
    "                new_indices = np.tile(indices, self.nsamples)\n",
    "                self.synth_data = scipy.sparse.csr_matrix((new_data, new_indices, new_indptr), shape=shape).tolil()\n",
    "        else:\n",
    "            self.synth_data = np.tile(self.data.data, (self.nsamples, 1))\n",
    "\n",
    "        # one row per synthetic sample: its mask over the M varying groups and its kernel weight\n",
    "        self.maskMatrix = np.zeros((self.nsamples, self.M))\n",
    "        self.kernelWeights = np.zeros(self.nsamples)\n",
    "        # raw model outputs for every (sample, background row) pair; data.weights-weighted sum per sample\n",
    "        self.y = np.zeros((self.nsamples * self.N, self.D))\n",
    "        self.ey = np.zeros((self.nsamples, self.D))\n",
    "        self.lastMask = np.zeros(self.nsamples)\n",
    "        self.nsamplesAdded = 0\n",
    "        self.nsamplesRun = 0\n",
    "        if self.keep_index:\n",
    "            self.synth_data_index = np.tile(self.data.index_value, self.nsamples)\n",
    "\n",
    "    def addsample(self, x, m, w):\n",
    "        \"\"\"Append one synthetic sample built from the instance `x` and the group mask `m`.\n",
    "\n",
    "        For every varying group j with m[j] == 1.0, the group's columns in this sample's block\n",
    "        of self.N rows of synth_data are overwritten with x's values. All other columns keep the\n",
    "        background values tiled in by allocate(). The mask and kernel weight `w` are recorded,\n",
    "        and self.nsamplesAdded is incremented.\n",
    "        \"\"\"\n",
    "        # first row of this sample's block inside synth_data\n",
    "        offset = self.nsamplesAdded * self.N\n",
    "        if isinstance(self.varyingFeatureGroups, (list,)):\n",
    "            for j in range(self.M):\n",
    "                for k in self.varyingFeatureGroups[j]:\n",
    "                    if m[j] == 1.0:\n",
    "                        self.synth_data[offset:offset+self.N, k] = x[0, k]\n",
    "        else:\n",
    "            # for non-jagged numpy array we can significantly boost performance\n",
    "            mask = m == 1.0\n",
    "            groups = self.varyingFeatureGroups[mask]\n",
    "            if len(groups.shape) == 2:\n",
    "                for group in groups:\n",
    "                    self.synth_data[offset:offset+self.N, group] = x[0, group]\n",
    "            else:\n",
    "                # further performance optimization in case each group has a single feature\n",
    "                evaluation_data = x[0, groups]\n",
    "                # In edge case where background is all dense but evaluation data\n",
    "                # is all sparse, make evaluation data dense\n",
    "                if scipy.sparse.issparse(x) and not scipy.sparse.issparse(self.synth_data):\n",
    "                    evaluation_data = evaluation_data.toarray()\n",
    "                self.synth_data[offset:offset+self.N, groups] = evaluation_data\n",
    "        self.maskMatrix[self.nsamplesAdded, :] = m\n",
    "        self.kernelWeights[self.nsamplesAdded] = w\n",
    "        self.nsamplesAdded += 1\n",
    "\n",
    "    def run(self):\n",
    "        \"\"\"Evaluate the model on every synthetic sample added since the previous call.\n",
    "\n",
    "        Model outputs for rows [nsamplesRun*N, nsamplesAdded*N) of synth_data are stored in\n",
    "        self.y. Each sample's N outputs are then combined with self.data.weights into\n",
    "        self.ey, and self.nsamplesRun is advanced to self.nsamplesAdded.\n",
    "        \"\"\"\n",
    "        num_to_run = self.nsamplesAdded * self.N - self.nsamplesRun * self.N\n",
    "        data = self.synth_data[self.nsamplesRun*self.N:self.nsamplesAdded*self.N,:]\n",
    "        if self.keep_index:\n",
    "            # re-attach the background index as the DataFrame index (optionally sorted) before calling the model\n",
    "            index = self.synth_data_index[self.nsamplesRun*self.N:self.nsamplesAdded*self.N]\n",
    "            index = pd.DataFrame(index, columns=[self.data.index_name])\n",
    "            data = pd.DataFrame(data, columns=self.data.group_names)\n",
    "            data = pd.concat([index, data], axis=1).set_index(self.data.index_name)\n",
    "            if self.keep_index_ordered:\n",
    "                data = data.sort_index()\n",
    "        modelOut = self.model.f(data)\n",
    "        if isinstance(modelOut, (pd.DataFrame, pd.Series)):\n",
    "            modelOut = modelOut.values\n",
    "        self.y[self.nsamplesRun * self.N:self.nsamplesAdded * self.N, :] = np.reshape(modelOut, (num_to_run, self.D))\n",
    "\n",
    "        # find the expected value of each output\n",
    "        # (range() bounds are fixed when the loop starts, so incrementing nsamplesRun inside is safe)\n",
    "        for i in range(self.nsamplesRun, self.nsamplesAdded):\n",
    "            eyVal = np.zeros(self.D)\n",
    "            for j in range(0, self.N):\n",
    "                eyVal += self.y[i * self.N + j, :] * self.data.weights[j]\n",
    "\n",
    "            self.ey[i, :] = eyVal\n",
    "            self.nsamplesRun += 1\n",
    "\n",
    "    def solve(self, fraction_evaluated, dim):\n",
    "        \"\"\"Estimate the attribution vector for output `dim` by constrained weighted least squares.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        fraction_evaluated : float\n",
    "            The caller passes self.nsamples / self.max_samples. When l1_reg == \"auto\",\n",
    "            values below 0.2 enable feature selection.\n",
    "        dim : int\n",
    "            Column of self.ey / self.fx / self.fnull to explain.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        (numpy.array, numpy.array)\n",
    "            phi of length self.M (zero for features dropped by feature selection), and an\n",
    "            array of ones returned in place of a variance estimate.\n",
    "        \"\"\"\n",
    "        # sample outputs in link space, relative to the linked base value fnull\n",
    "        # (self.linkfv is presumably a vectorised self.link.f -- TODO confirm)\n",
    "        eyAdj = self.linkfv(self.ey[:, dim]) - self.link.f(self.fnull[dim])\n",
    "        # number of groups switched on in each sample\n",
    "        s = np.sum(self.maskMatrix, 1)\n",
    "\n",
    "        # do feature selection if we have not well enumerated the space\n",
    "        nonzero_inds = np.arange(self.M)\n",
    "        \n",
    "        # if self.l1_reg == \"auto\":\n",
    "        #     warnings.warn(\n",
    "        #         \"l1_reg=\\\"auto\\\" is deprecated and in the next version (v0.29) the behavior will change from a \" \\\n",
    "        #         \"conditional use of AIC to simply \\\"num_features(10)\\\"!\"\n",
    "        #     )\n",
    "        if (self.l1_reg not in [\"auto\", False, 0]) or (fraction_evaluated < 0.2 and self.l1_reg == \"auto\"):\n",
    "            # Augmented problem: every sample appears twice, with weights kernelWeights*(M - s) and\n",
    "            # kernelWeights*s. The second copy uses mask - 1 and target minus the full output gap.\n",
    "            # Rows are scaled by sqrt(weight), so an unweighted lasso fits the weighted problem.\n",
    "            w_aug = np.hstack((self.kernelWeights * (self.M - s), self.kernelWeights * s))\n",
    "            w_sqrt_aug = np.sqrt(w_aug)\n",
    "            eyAdj_aug = np.hstack((eyAdj, eyAdj - (self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim]))))\n",
    "            eyAdj_aug *= w_sqrt_aug\n",
    "            mask_aug = np.transpose(w_sqrt_aug * np.transpose(np.vstack((self.maskMatrix, self.maskMatrix - 1))))\n",
    "            #var_norms = np.array([np.linalg.norm(mask_aug[:, i]) for i in range(mask_aug.shape[1])])\n",
    "\n",
    "            # select a fixed number of top features\n",
    "            if isinstance(self.l1_reg, str) and self.l1_reg.startswith(\"num_features(\"):\n",
    "                r = int(self.l1_reg[len(\"num_features(\"):-1])\n",
    "                # lars_path returns (alphas, active, coefs); keep the active set after at most r steps\n",
    "                nonzero_inds = lars_path(mask_aug, eyAdj_aug, max_iter=r)[1]\n",
    "\n",
    "            # use an adaptive regularization method\n",
    "            elif self.l1_reg == \"auto\" or self.l1_reg == \"bic\" or self.l1_reg == \"aic\":\n",
    "                c = \"aic\" if self.l1_reg == \"auto\" else self.l1_reg\n",
    "\n",
    "                # \"Normalize\" parameter of LassoLarsIC was deprecated in sklearn version 1.2\n",
    "                if version.parse(sklearn.__version__) < version.parse(\"1.2.0\"):\n",
    "                    kwg = dict(normalize=False)\n",
    "                else:\n",
    "                    kwg = {}\n",
    "                model = make_pipeline(StandardScaler(with_mean=False), LassoLarsIC(criterion=c, **kwg))\n",
    "                # Pipeline.fit returns the pipeline; step [1] is the fitted LassoLarsIC\n",
    "                nonzero_inds = np.nonzero(model.fit(mask_aug, eyAdj_aug)[1].coef_)[0]\n",
    "\n",
    "            # use a fixed regularization coefficient\n",
    "            else:\n",
    "                nonzero_inds = np.nonzero(Lasso(alpha=self.l1_reg).fit(mask_aug, eyAdj_aug).coef_)[0]\n",
    "\n",
    "        if len(nonzero_inds) == 0:\n",
    "            return np.zeros(self.M), np.ones(self.M)\n",
    "\n",
    "        # eliminate one variable with the constraint that all features sum to the output\n",
    "        eyAdj2 = eyAdj - self.maskMatrix[:, nonzero_inds[-1]] * (\n",
    "                    self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim]))\n",
    "        etmp = np.transpose(np.transpose(self.maskMatrix[:, nonzero_inds[:-1]]) - self.maskMatrix[:, nonzero_inds[-1]])\n",
    "\n",
    "        # solve a weighted least squares equation to estimate phi\n",
    "        tmp = np.transpose(np.transpose(etmp) * np.transpose(self.kernelWeights))\n",
    "        etmp_dot = np.dot(np.transpose(tmp), etmp)\n",
    "        try:\n",
    "            tmp2 = np.linalg.inv(etmp_dot)\n",
    "        except np.linalg.LinAlgError:\n",
    "            tmp2 = np.linalg.pinv(etmp_dot)\n",
    "            warnings.warn(\n",
    "                \"Linear regression equation is singular, Moore-Penrose pseudoinverse is used instead of the regular inverse.\\n\"\n",
    "                \"To use regular inverse do one of the following:\\n\"\n",
    "                \"1) turn up the number of samples,\\n\"\n",
    "                \"2) turn up the L1 regularization with num_features(N) where N is less than the number of samples,\\n\"\n",
    "                \"3) group features together to reduce the number of inputs that need to be explained.\"\n",
    "            )\n",
    "        w = np.dot(tmp2, np.dot(np.transpose(tmp), eyAdj2))\n",
    "\n",
    "        phi = np.zeros(self.M)\n",
    "        phi[nonzero_inds[:-1]] = w\n",
    "        # the eliminated feature takes whatever makes phi sum to link(fx) - link(fnull)\n",
    "        phi[nonzero_inds[-1]] = (self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim])) - sum(w)\n",
    "\n",
    "        # clean up any rounding errors\n",
    "        for i in range(self.M):\n",
    "            if np.abs(phi[i]) < 1e-10:\n",
    "                phi[i] = 0\n",
    "\n",
    "        return phi, np.ones(len(phi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from predict import predict, predict_lr,tokenizer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "# Tab-separated test split with no header: text<TAB>label.\n",
    "# Set REX_DATA_DIR to point elsewhere instead of editing this absolute path.\n",
    "DATA_DIR = os.environ.get('REX_DATA_DIR', '/home/outerform/github-repo/ReX/data')\n",
    "test = pd.read_csv(os.path.join(DATA_DIR, 'sentiment-test'), header=None,\n",
    "                   sep='\\t', names=['text', 'label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# keep only rows whose text encodes to at most 10 tokens (special tokens not counted)\n",
    "token_counts = test['text'].map(lambda text: len(tokenizer.encode(text, add_special_tokens=False)))\n",
    "test = test.loc[token_counts <= 10].reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Slicing a positional slice with .loc is not supported, and will raise TypeError in a future version.  Use .loc with labels or .iloc with positions instead.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>word_1</th>\n",
       "      <th>word_2</th>\n",
       "      <th>word_3</th>\n",
       "      <th>word_4</th>\n",
       "      <th>word_5</th>\n",
       "      <th>word_6</th>\n",
       "      <th>word_7</th>\n",
       "      <th>word_8</th>\n",
       "      <th>word_9</th>\n",
       "      <th>word_10</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>15491</td>\n",
       "      <td>28616</td>\n",
       "      <td>22444</td>\n",
       "      <td>4095</td>\n",
       "      <td>1997</td>\n",
       "      <td>6782</td>\n",
       "      <td>1998</td>\n",
       "      <td>11541</td>\n",
       "      <td>1012</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2009</td>\n",
       "      <td>1005</td>\n",
       "      <td>1055</td>\n",
       "      <td>2074</td>\n",
       "      <td>11757</td>\n",
       "      <td>10634</td>\n",
       "      <td>1012</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2204</td>\n",
       "      <td>3185</td>\n",
       "      <td>1012</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2087</td>\n",
       "      <td>2047</td>\n",
       "      <td>5691</td>\n",
       "      <td>2031</td>\n",
       "      <td>1037</td>\n",
       "      <td>4408</td>\n",
       "      <td>20682</td>\n",
       "      <td>1012</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2498</td>\n",
       "      <td>2062</td>\n",
       "      <td>2084</td>\n",
       "      <td>1037</td>\n",
       "      <td>19960</td>\n",
       "      <td>3695</td>\n",
       "      <td>16748</td>\n",
       "      <td>13012</td>\n",
       "      <td>21031</td>\n",
       "      <td>1012</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>237</th>\n",
       "      <td>1037</td>\n",
       "      <td>8242</td>\n",
       "      <td>6298</td>\n",
       "      <td>4038</td>\n",
       "      <td>1012</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>238</th>\n",
       "      <td>2092</td>\n",
       "      <td>1010</td>\n",
       "      <td>2009</td>\n",
       "      <td>2515</td>\n",
       "      <td>2175</td>\n",
       "      <td>2006</td>\n",
       "      <td>5091</td>\n",
       "      <td>1012</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>239</th>\n",
       "      <td>2054</td>\n",
       "      <td>4268</td>\n",
       "      <td>2097</td>\n",
       "      <td>7523</td>\n",
       "      <td>2003</td>\n",
       "      <td>1037</td>\n",
       "      <td>2047</td>\n",
       "      <td>8145</td>\n",
       "      <td>7028</td>\n",
       "      <td>1012</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>240</th>\n",
       "      <td>2045</td>\n",
       "      <td>1005</td>\n",
       "      <td>1055</td>\n",
       "      <td>2025</td>\n",
       "      <td>2438</td>\n",
       "      <td>2000</td>\n",
       "      <td>15770</td>\n",
       "      <td>1996</td>\n",
       "      <td>4038</td>\n",
       "      <td>1012</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>241</th>\n",
       "      <td>1037</td>\n",
       "      <td>2613</td>\n",
       "      <td>18856</td>\n",
       "      <td>16814</td>\n",
       "      <td>2121</td>\n",
       "      <td>1012</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>242 rows × 11 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    word_1 word_2 word_3 word_4 word_5 word_6 word_7 word_8 word_9 word_10  \\\n",
       "0    15491  28616  22444   4095   1997   6782   1998  11541   1012       0   \n",
       "1     2009   1005   1055   2074  11757  10634   1012      0      0       0   \n",
       "2     2204   3185   1012      0      0      0      0      0      0       0   \n",
       "3     2087   2047   5691   2031   1037   4408  20682   1012      0       0   \n",
       "4     2498   2062   2084   1037  19960   3695  16748  13012  21031    1012   \n",
       "..     ...    ...    ...    ...    ...    ...    ...    ...    ...     ...   \n",
       "237   1037   8242   6298   4038   1012      0      0      0      0       0   \n",
       "238   2092   1010   2009   2515   2175   2006   5091   1012      0       0   \n",
       "239   2054   4268   2097   7523   2003   1037   2047   8145   7028    1012   \n",
       "240   2045   1005   1055   2025   2438   2000  15770   1996   4038    1012   \n",
       "241   1037   2613  18856  16814   2121   1012      0      0      0       0   \n",
       "\n",
       "    label  \n",
       "0     NaN  \n",
       "1     NaN  \n",
       "2     NaN  \n",
       "3     NaN  \n",
       "4     NaN  \n",
       "..    ...  \n",
       "237   NaN  \n",
       "238   NaN  \n",
       "239   NaN  \n",
       "240   NaN  \n",
       "241   NaN  \n",
       "\n",
       "[242 rows x 11 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per test sentence: its first 10 token ids (padded to 10 with the tokenizer's pad id),\n",
    "# followed by a 'label' column.\n",
    "MAX_TOKENS = 10\n",
    "word_cols = [f'word_{i}' for i in range(1, MAX_TOKENS + 1)]\n",
    "cols = word_cols + ['label']\n",
    "new_test = pd.DataFrame(columns=cols)\n",
    "for i, text in enumerate(test['text']):\n",
    "    # Address the word columns by label. The old `.loc[i, 0:-1]` put a positional slice in a\n",
    "    # label-based indexer, which pandas deprecates (it will raise TypeError in a future version).\n",
    "    new_test.loc[i, word_cols] = tokenizer.encode(text, add_special_tokens=False, max_length=MAX_TOKENS,\n",
    "                                                 truncation=True, padding='max_length')\n",
    "# NOTE(review): 'label' is never filled in here, so that column stays NaN\n",
    "new_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def predict_shap(sentences):\n",
    "    \"\"\"Score token-id sequences with `predict`, returning its columns 0 and 1 swapped.\n",
    "\n",
    "    The sequences are decoded back to text with `tokenizer.batch_decode` (special tokens\n",
    "    such as padding are skipped) before being passed to `predict`.\n",
    "    \"\"\"\n",
    "    texts = tokenizer.batch_decode(sentences, skip_special_tokens=True, clean_up_tokenization_spaces=True)\n",
    "    res = predict(texts)\n",
    "    # swap columns 0 and 1 in place; the fancy-indexed right-hand side is a copy, so this is safe\n",
    "    res[:, [0, 1]] = res[:, [1, 0]]\n",
    "    return res\n",
    "\n",
    "def match(predicates, sentence):\n",
    "    \"\"\"Return False if some predicate [x, y, d] has sum(sentence[x+1:y]) < d, otherwise True.\"\"\"\n",
    "    return not any(sum(sentence[x + 1:y]) < d for x, y, d in predicates)\n",
    "\n",
    "link = LogitLink()\n",
    "\n",
    "def predict_one(ins, sentence, num_sample=100):\n",
    "    \"\"\"Average prediction over random keep/drop masks that satisfy the predicates in `sentence`.\n",
    "\n",
    "    Draws `num_sample` 0/1 masks of length len(ins) and keeps those for which\n",
    "    `match(sentence, mask)` holds. It averages column 0 of `predict_shap` over them in link\n",
    "    space and returns [p, 1 - p], or [0.5, 0.5] when no mask qualifies.\n",
    "    \"\"\"\n",
    "    # Bernoulli(0.5) keep/drop decision per position. The previous call\n",
    "    # `np.random.binomial(num_sample*len(ins))` omitted the required `p` and raised a TypeError.\n",
    "    samples = np.random.binomial(1, 0.5, size=(num_sample, len(ins)))\n",
    "    samples = np.array([s for s in samples if match(sentence, s)])\n",
    "    if len(samples) == 0:\n",
    "        # no mask satisfies the predicates: return an uninformative prediction rather than\n",
    "        # calling the model on an empty batch\n",
    "        return [0.5, 0.5]\n",
    "    # NOTE(review): the 0/1 masks are handed to predict_shap (which decodes token ids) without\n",
    "    # being mapped onto the tokens of `ins`; confirm this prototype is still meant to be used.\n",
    "    res = predict_shap(samples)[:, 0]\n",
    "    # average in link (logit) space, then map the mean back to a probability\n",
    "    res = link.finv(link.f(res).mean())\n",
    "    return [res, 1 - res]\n",
    "\n",
    "def predict_shap_rex(ins, sentences):\n",
    "    \"\"\"Stack `predict_one(ins, s)` for every rule vector `s` in `sentences`.\"\"\"\n",
    "    return np.array([predict_one(ins, sentence) for sentence in sentences])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# back_ground = pd.DataFrame(columns=new_test.columns)\n",
    "# back_ground.loc[0] = tokenizer.unk_token_id\n",
    "# back_ground"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import token\n",
    "\n",
    "\n",
    "class ReXExplainer:\n",
    "    \"\"\"Explain one sentence with KernelReX over positional predicates on its tokens.\n",
    "\n",
    "    Each predicate is a triple [x, y, d] over token positions (see `match` for how it is\n",
    "    checked). KernelReX is run on the predicate-presence vector, so the stored values\n",
    "    attribute the prediction to predicates rather than to individual tokens. All work\n",
    "    happens in __init__; `shap_value()` returns the stored result.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, sentence):\n",
    "        self.tokens = np.array(tokenizer.encode(sentence, add_special_tokens=False))\n",
    "        n_tokens = len(self.tokens)\n",
    "        predicates = []\n",
    "        for i in range(n_tokens):\n",
    "            # [i, i, d]: token i kept and at least d present tokens before it\n",
    "            for d in range(i):\n",
    "                predicates.append([i, i, d])\n",
    "            # [i, j, d], j > i: token i kept, token j not kept, at least d present tokens between them\n",
    "            for j in range(i + 1, n_tokens):\n",
    "                for d in range(j - i):\n",
    "                    predicates.append([i, j, d])\n",
    "        # a single all-zero background row, i.e. no predicate active\n",
    "        back_ground = pd.DataFrame(columns=list(range(len(predicates))))\n",
    "        back_ground.loc[0] = 0\n",
    "        self.predicates = np.array(predicates)\n",
    "        self.text = sentence\n",
    "        self.explainer = KernelReX(self.predict_shap_rex, back_ground, link=\"logit\")\n",
    "        # explain the instance in which every predicate is active\n",
    "        self.shap_values = self.explainer.shap_values(np.ones(len(self.predicates)), nsamples=1000)\n",
    "\n",
    "    def match(self, rules, sentence):\n",
    "        \"\"\"Return True iff the perturbation `sentence` satisfies every predicate active in `rules`.\n",
    "\n",
    "        `rules` selects predicates where it equals 1. `sentence` has one entry per token:\n",
    "        0 = dropped, 1 = kept, 2 = replaced by the unk token (see `covert`).\n",
    "        \"\"\"\n",
    "        for [x,y,d] in self.predicates[rules==1]:\n",
    "            if x == y:\n",
    "                if sentence[x]!=1 or sum(sentence[:x]>0)<d:\n",
    "                    return False\n",
    "            # NOTE(review): this branch requires token y to NOT be kept; confirm that is intended\n",
    "            elif sentence[x]!=1 or sentence[y]==1 or sum(sentence[x+1:y]>0)<d:\n",
    "                return False\n",
    "        return True\n",
    "\n",
    "    def covert(self, samples):\n",
    "        \"\"\"Turn drop/keep/unk perturbations into token-id sequences of length len(self.tokens).\n",
    "\n",
    "        In each row, 1 keeps the original token, 2 substitutes `tokenizer.unk_token_id`, and\n",
    "        0 drops the position. Surviving ids are left-aligned and right-padded with 0.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            ids = (samples == 1) * self.tokens + (samples == 2) * tokenizer.unk_token_id\n",
    "        except ValueError as e:\n",
    "            raise ValueError(\n",
    "                f'cannot map samples of shape {np.shape(samples)} onto {len(self.tokens)} tokens'\n",
    "            ) from e\n",
    "        n_tokens = len(self.tokens)\n",
    "        res = []\n",
    "        for row in ids:\n",
    "            kept = row[row.nonzero()]\n",
    "            # Pad only on the right. An int pad_width pads BOTH sides, which made rows longer\n",
    "            # than n_tokens.\n",
    "            res.append(np.pad(kept, (0, n_tokens - len(kept)), mode='constant'))\n",
    "        return res\n",
    "\n",
    "    def predict_one(self, sentence, num_sample=3):\n",
    "        \"\"\"Estimate [p, 1 - p] for one rule vector `sentence` over self.predicates.\n",
    "\n",
    "        Draws `num_sample` random perturbations (per position 0 = drop, 1 = keep, 2 = unk) and\n",
    "        keeps those satisfying the active predicates. It scores them with `predict_shap` and\n",
    "        averages column 0 in link space. Returns [0.5, 0.5] if no draw satisfies the rules.\n",
    "        \"\"\"\n",
    "        # NOTE(review): only `num_sample` (default 3) draws per rule; when none satisfies the\n",
    "        # rules the result is [0.5, 0.5] -- consider a larger default.\n",
    "        samples = np.random.choice(3, size=(num_sample, len(self.tokens)))\n",
    "        samples = np.array([s for s in samples if self.match(sentence, s)])\n",
    "        if len(samples) == 0:\n",
    "            return [0.5, 0.5]\n",
    "        token_ids = self.covert(samples)\n",
    "        res = predict_shap(token_ids)[:, 0]\n",
    "        # average in link space, then map the mean back to a probability\n",
    "        res = link.finv(link.f(res).mean())\n",
    "        return [res, 1 - res]\n",
    "\n",
    "    def predict_shap_rex(self, sentences):\n",
    "        \"\"\"Model function handed to KernelReX: one [p, 1 - p] row per rule vector.\"\"\"\n",
    "        return np.array([self.predict_one(sentence) for sentence in sentences])\n",
    "\n",
    "    def shap_value(self):\n",
    "        \"\"\"Return the values computed by KernelReX in __init__.\"\"\"\n",
    "        return self.shap_values\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "demo_sentence = 'I am not OK'\n",
    "explainer = ReXExplainer(demo_sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([-0.10329754, -0.17960393, -0.11221275,  0.3186519 , -0.04528667,\n",
       "        -0.03804585, -0.04530104, -0.04920282, -0.03443438,  0.        ,\n",
       "         0.03580937,  0.08360895,  0.27262453, -0.15869112,  0.        ,\n",
       "         0.        ]),\n",
       " array([ 0.10329754,  0.17960393,  0.11221275, -0.3186519 ,  0.04528667,\n",
       "         0.03804585,  0.04530104,  0.04920282,  0.03443438,  0.        ,\n",
       "        -0.03580937, -0.08360895, -0.27262453,  0.15869112,  0.        ,\n",
       "         0.        ])]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the values computed when `explainer` was constructed (same object shap_value() returns)\n",
    "explainer.shap_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentences</th>\n",
       "      <th>pred</th>\n",
       "      <th>label</th>\n",
       "      <th>exp_shap</th>\n",
       "      <th>exp_rex_shap</th>\n",
       "      <th>shap_time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>uneasy mishmash of styles and genres .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[2.7219935204, 2.0933041182, 0.86498494300000...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.250064</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>it 's just incredibly dull .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[-0.1276223819, -0.0865660875, -0.1404735482,...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.237741</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>good movie .</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>[[-6.9086247285999995, -1.4494726712000001, -0...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.195784</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>most new movies have a bright sheen .</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>[[-0.0084425255, -2.1161486946, 0.208185259400...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.239273</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>nothing more than a mediocre trifle .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[2.574848053, -0.1271480509, -0.2942770142, -...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>0.246558</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 sentences  pred  label  \\\n",
       "0  uneasy mishmash of styles and genres .      0      0   \n",
       "1            it 's just incredibly dull .      0      0   \n",
       "2                            good movie .      1      1   \n",
       "3   most new movies have a bright sheen .      1      1   \n",
       "4   nothing more than a mediocre trifle .      0      0   \n",
       "\n",
       "                                            exp_shap  exp_rex_shap  shap_time  \n",
       "0  [[2.7219935204, 2.0933041182, 0.86498494300000...           NaN   0.250064  \n",
       "1  [[-0.1276223819, -0.0865660875, -0.1404735482,...           NaN   0.237741  \n",
       "2  [[-6.9086247285999995, -1.4494726712000001, -0...           NaN   0.195784  \n",
       "3  [[-0.0084425255, -2.1161486946, 0.208185259400...           NaN   0.239273  \n",
       "4  [[2.574848053, -0.1271480509, -0.2942770142, -...           NaN   0.246558  "
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "RESULT_PATH = 'result.json'\n",
    "result_df = pd.read_json(RESULT_PATH)\n",
    "result_df.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-filled placeholder columns; later cells overwrite 'predicates' and drop 'times'.\n",
    "result_df = result_df.assign(predicates=0, times=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bc44071a44194df5ba598b48c8427fa5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/242 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# Explain every test sentence with ReX and record, per sentence:\n",
    "# [ReX SHAP values, the explainer's predicates, elapsed seconds].\n",
    "res = []\n",
    "for sentence in tqdm(test['text'], total=len(test)):\n",
    "    # perf_counter is monotonic, so durations are immune to system clock changes\n",
    "    start = time.perf_counter()\n",
    "    explainer = ReXExplainer(sentence)\n",
    "    rex_shap = explainer.shap_value()\n",
    "    res.append([rex_shap, explainer.predicates, time.perf_counter() - start])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transpose the per-sentence records into three parallel tuples.\n",
    "exp_rex_shap, predicates, times = zip(*res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-sentence ReX runtime in seconds.\n",
    "result_df = result_df.assign(rex_time=times)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0      0\n",
       "1      0\n",
       "2      0\n",
       "3      0\n",
       "4      0\n",
       "      ..\n",
       "237    0\n",
       "238    0\n",
       "239    0\n",
       "240    0\n",
       "241    0\n",
       "Name: times, Length: 242, dtype: int64"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Discard the zero-filled 'times' placeholder; timings live in 'rex_time'.\n",
    "result_df.pop(item='times')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Item assignment: attribute assignment only sets a column that already exists.\n",
    "result_df['exp_rex_shap'] = exp_rex_shap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Item assignment: attribute assignment only sets a column that already exists.\n",
    "result_df['predicates'] = predicates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CSV snapshot of the ReX results.\n",
    "result_df.to_csv(path_or_buf='rex_res.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import fabs\n",
    "\n",
    "\n",
    "def select_top_predicates(weights, preds):\n",
    "    \"\"\"Rank predicate indices by |weight|, keeping one per (preds[j][0], preds[j][1]) pair.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    weights : per-predicate ReX SHAP values (row 0 of an exp_rex_shap entry).\n",
    "    preds : the explainer's predicates, e.g. [[0, 1, 0], [0, 2, 0], ...].\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of indices into `preds`, highest |weight| first (ties keep index\n",
    "    order, as sorted() is stable), truncated to preds[-1][0] + 1 entries.\n",
    "    \"\"\"\n",
    "    ranked = sorted(range(len(preds)), key=lambda j: fabs(weights[j]), reverse=True)\n",
    "    seen_pairs = set()\n",
    "    kept = []\n",
    "    for j in ranked:\n",
    "        pair = (preds[j][0], preds[j][1])\n",
    "        # set lookup instead of rebuilding the list of kept pairs each iteration\n",
    "        if pair not in seen_pairs:\n",
    "            seen_pairs.add(pair)\n",
    "            kept.append(j)\n",
    "    # NOTE(review): the cap comes from the first field of the last predicate;\n",
    "    # its meaning is defined by ReXExplainer and is not visible here -- confirm.\n",
    "    return kept[:preds[-1][0] + 1]\n",
    "\n",
    "\n",
    "selected = [\n",
    "    select_top_predicates(shap_row[0], preds)\n",
    "    for shap_row, preds in zip(exp_rex_shap, predicates)\n",
    "]\n",
    "result_df['selected'] = selected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentences</th>\n",
       "      <th>pred</th>\n",
       "      <th>label</th>\n",
       "      <th>exp_shap</th>\n",
       "      <th>exp_rex_shap</th>\n",
       "      <th>predicates</th>\n",
       "      <th>rex_time</th>\n",
       "      <th>selected</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>uneasy mishmash of styles and genres .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[-2.9538865764, -4.5909770619, -1.1575369739,...</td>\n",
       "      <td>[[0.0030301080037045125, 0.04512306584326642, ...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>3.349509</td>\n",
       "      <td>[116, 43, 126, 90, 29, 3, 81, 58, 11]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>it 's just incredibly dull .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[0.7497899702, 0.5451659446, 0.3746587302, -0...</td>\n",
       "      <td>[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>3.980032</td>\n",
       "      <td>[67, 12, 24, 0, 1, 3, 6]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>good movie .</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>[[4.6833177305, 2.2067418796, 0.7322422585, -0...</td>\n",
       "      <td>[[-0.6053785292911353, -0.1460227179761313, -0...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [1, 1, 0], [...</td>\n",
       "      <td>0.864872</td>\n",
       "      <td>[5, 4, 0]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>most new movies have a bright sheen .</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>[[0.071488014, 1.0026802766, 1.0461042556, 0.3...</td>\n",
       "      <td>[[0.0, 0.0, 0.0, 0.0, -0.13180292906812108, -0...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>3.687434</td>\n",
       "      <td>[62, 32, 40, 9, 96, 36, 74, 18]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>nothing more than a mediocre trifle .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[-2.8252216751, 0.0676713233, 0.4272910984, -...</td>\n",
       "      <td>[[0.0, 0.0, 0.16593652560397842, 0.0, 0.0, 0.0...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>3.601137</td>\n",
       "      <td>[93, 42, 114, 142, 127, 45, 116, 190, 136, 200]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>237</th>\n",
       "      <td>a pleasant romantic comedy .</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>[[0.7536170943, 2.3774593832, 2.3024900137, 1....</td>\n",
       "      <td>[[-0.6200508956650741, -0.27859673602830437, -...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>2.007713</td>\n",
       "      <td>[0, 27, 2, 23, 5]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>238</th>\n",
       "      <td>well , it does go on forever .</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>[[1.5891878341, -0.273119646, 0.66363222610000...</td>\n",
       "      <td>[[0.0, 0.34067151078257396, -0.281100989178876...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>2.690393</td>\n",
       "      <td>[35, 48, 59, 103, 30, 17, 79, 64]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>239</th>\n",
       "      <td>what kids will discover is a new collectible .</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>[[0.2380497192, 1.1394186681, 1.5273038723, 0....</td>\n",
       "      <td>[[-0.48986563470030475, 1.0892845498378332, 0....</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>2.766771</td>\n",
       "      <td>[148, 97, 183, 152, 74, 24, 66, 11, 197, 1]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>240</th>\n",
       "      <td>there 's not enough to sustain the comedy .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[0.4373973349, 0.3067432346, 0.0313808303, -1...</td>\n",
       "      <td>[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>2.831442</td>\n",
       "      <td>[116, 128, 130, 112, 124, 13, 105, 41, 0, 1]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>241</th>\n",
       "      <td>a real clunker .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[1.9007193874000001, 7.6332260962, -2.7738666...</td>\n",
       "      <td>[[-0.43779886917471056, -0.3152869819914629, -...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>2.401160</td>\n",
       "      <td>[19, 35, 22, 39, 18, 0]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>242 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           sentences  pred  label  \\\n",
       "0            uneasy mishmash of styles and genres .      0      0   \n",
       "1                      it 's just incredibly dull .      0      0   \n",
       "2                                      good movie .      1      1   \n",
       "3             most new movies have a bright sheen .      1      1   \n",
       "4             nothing more than a mediocre trifle .      0      0   \n",
       "..                                               ...   ...    ...   \n",
       "237                    a pleasant romantic comedy .      1      1   \n",
       "238                  well , it does go on forever .      1      0   \n",
       "239  what kids will discover is a new collectible .      1      1   \n",
       "240     there 's not enough to sustain the comedy .      0      0   \n",
       "241                                a real clunker .      0      0   \n",
       "\n",
       "                                              exp_shap  \\\n",
       "0    [[-2.9538865764, -4.5909770619, -1.1575369739,...   \n",
       "1    [[0.7497899702, 0.5451659446, 0.3746587302, -0...   \n",
       "2    [[4.6833177305, 2.2067418796, 0.7322422585, -0...   \n",
       "3    [[0.071488014, 1.0026802766, 1.0461042556, 0.3...   \n",
       "4    [[-2.8252216751, 0.0676713233, 0.4272910984, -...   \n",
       "..                                                 ...   \n",
       "237  [[0.7536170943, 2.3774593832, 2.3024900137, 1....   \n",
       "238  [[1.5891878341, -0.273119646, 0.66363222610000...   \n",
       "239  [[0.2380497192, 1.1394186681, 1.5273038723, 0....   \n",
       "240  [[0.4373973349, 0.3067432346, 0.0313808303, -1...   \n",
       "241  [[1.9007193874000001, 7.6332260962, -2.7738666...   \n",
       "\n",
       "                                          exp_rex_shap  \\\n",
       "0    [[0.0030301080037045125, 0.04512306584326642, ...   \n",
       "1    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...   \n",
       "2    [[-0.6053785292911353, -0.1460227179761313, -0...   \n",
       "3    [[0.0, 0.0, 0.0, 0.0, -0.13180292906812108, -0...   \n",
       "4    [[0.0, 0.0, 0.16593652560397842, 0.0, 0.0, 0.0...   \n",
       "..                                                 ...   \n",
       "237  [[-0.6200508956650741, -0.27859673602830437, -...   \n",
       "238  [[0.0, 0.34067151078257396, -0.281100989178876...   \n",
       "239  [[-0.48986563470030475, 1.0892845498378332, 0....   \n",
       "240  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...   \n",
       "241  [[-0.43779886917471056, -0.3152869819914629, -...   \n",
       "\n",
       "                                            predicates  rex_time  \\\n",
       "0    [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  3.349509   \n",
       "1    [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  3.980032   \n",
       "2    [[0, 1, 0], [0, 2, 0], [0, 2, 1], [1, 1, 0], [...  0.864872   \n",
       "3    [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  3.687434   \n",
       "4    [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  3.601137   \n",
       "..                                                 ...       ...   \n",
       "237  [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  2.007713   \n",
       "238  [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  2.690393   \n",
       "239  [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  2.766771   \n",
       "240  [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  2.831442   \n",
       "241  [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  2.401160   \n",
       "\n",
       "                                            selected  \n",
       "0              [116, 43, 126, 90, 29, 3, 81, 58, 11]  \n",
       "1                           [67, 12, 24, 0, 1, 3, 6]  \n",
       "2                                          [5, 4, 0]  \n",
       "3                    [62, 32, 40, 9, 96, 36, 74, 18]  \n",
       "4    [93, 42, 114, 142, 127, 45, 116, 190, 136, 200]  \n",
       "..                                               ...  \n",
       "237                                [0, 27, 2, 23, 5]  \n",
       "238                [35, 48, 59, 103, 30, 17, 79, 64]  \n",
       "239      [148, 97, 183, 152, 74, 24, 66, 11, 197, 1]  \n",
       "240     [116, 128, 130, 112, 124, 13, 105, 41, 0, 1]  \n",
       "241                          [19, 35, 22, 39, 18, 0]  \n",
       "\n",
       "[242 rows x 8 columns]"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the assembled results; pandas truncates long frames in the rich display.\n",
    "result_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# JSON snapshot of the ReX results, reloadable with pd.read_json.\n",
    "result_df.to_json(path_or_buf='rex_res.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import itertools\n",
    "\n",
    "\n",
    "def normalize(rules):\n",
    "    \"\"\"Scale (key, weight) pairs so that the absolute weights sum to 1.\n",
    "\n",
    "    rules: iterable of (key, weight) pairs. It is materialised first because\n",
    "    it is traversed twice. If every weight is zero the pairs are returned\n",
    "    unscaled instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    rules = list(rules)\n",
    "    total = sum(abs(weight) for _, weight in rules)\n",
    "    if total == 0:\n",
    "        return [(key, weight) for key, weight in rules]\n",
    "    return [(key, weight / total) for key, weight in rules]\n",
    "\n",
    "\n",
    "def coverage_shap(line, threshold):\n",
    "    \"\"\"Compare coverage and precision of SHAP and the selected ReX rules on one row.\n",
    "\n",
    "    Perturbations are enumerated by keeping, unk-masking or dropping each\n",
    "    token of the sentence (token order is preserved). A perturbation is\n",
    "    covered by a method when its summed attribution passes `threshold`\n",
    "    towards the predicted class.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    line : row of result_df with 'sentences', 'pred', 'exp_shap',\n",
    "        'exp_rex_shap', 'predicates' and 'selected'.\n",
    "    threshold : float, minimum attribution mass for a perturbation to be covered.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ori : line['pred'].\n",
    "    cover : [SHAP, ReX] fraction of all perturbations that are covered.\n",
    "    precision : [SHAP, ReX] fraction of covered perturbations whose\n",
    "        predict_shap label equals `ori`; 1 when a method covers nothing.\n",
    "    \"\"\"\n",
    "    sentence = line['sentences']\n",
    "    words = tokenizer.encode(sentence, add_special_tokens=False)\n",
    "    shap_values = line['exp_shap'][0]\n",
    "    rex_shap_values = np.array(line['exp_rex_shap'][0])[line['selected']]\n",
    "    rex_shape_predicate = np.array(line['predicates'])[line['selected']]\n",
    "    tot = set()\n",
    "    inrule = set()\n",
    "    inLIME = set()\n",
    "\n",
    "    def fit_rex(now, threshold):\n",
    "        \"\"\"True if the selected ReX predicates satisfied by `now` sum past `threshold`.\"\"\"\n",
    "        totwt = 0\n",
    "        for (wordx, wordy, dist), wt in zip(rex_shape_predicate, rex_shap_values):\n",
    "            # positions of the last occurrences of both predicate tokens in `now`\n",
    "            posx = -1\n",
    "            posy = -1\n",
    "            for i, x in enumerate(now):\n",
    "                if x == words[wordx]:\n",
    "                    posx = i\n",
    "                if x == words[wordy]:\n",
    "                    posy = i\n",
    "            if posx == -1 or posy == -1:\n",
    "                continue\n",
    "            if dist == -1:\n",
    "                totwt += wt  # presence of both tokens is enough\n",
    "            elif wordx == wordy:\n",
    "                if posx >= dist:\n",
    "                    totwt += wt\n",
    "            elif posy - posx >= dist:\n",
    "                totwt += wt\n",
    "        if threshold < 0:\n",
    "            totwt = -totwt\n",
    "            threshold = -threshold\n",
    "        return totwt > threshold\n",
    "\n",
    "    def fit_shap(now, threshold):\n",
    "        \"\"\"True if the SHAP values of the tokens `now` keeps in place sum past `threshold`.\"\"\"\n",
    "        # Only perturbations that keep every position (kept or unk-masked) are scored.\n",
    "        if len(now) != len(words):\n",
    "            return False\n",
    "        totwt = 0\n",
    "        for i, x in enumerate(now):\n",
    "            if words[i] == x:\n",
    "                totwt += shap_values[i]\n",
    "        if threshold < 0:\n",
    "            totwt = -totwt\n",
    "            threshold = -threshold\n",
    "        # Signed test, as in fit_rex: abs() would also count evidence for the\n",
    "        # opposite class and turn the sign flip above into a no-op.\n",
    "        return totwt > threshold\n",
    "\n",
    "    def dfs(i, now):\n",
    "        \"\"\"Enumerate perturbations: keep, unk-mask or drop token i, then recurse.\"\"\"\n",
    "        if i == len(words):\n",
    "            # Only the original token order is scored; scoring every permutation\n",
    "            # would be factorial in sentence length.\n",
    "            seq = tuple(now)\n",
    "            tot.add(seq)\n",
    "            if fit_rex(seq, threshold):\n",
    "                inrule.add(seq)\n",
    "            if fit_shap(seq, threshold):\n",
    "                inLIME.add(seq)\n",
    "            return\n",
    "        dfs(i + 1, now + [words[i]])\n",
    "        dfs(i + 1, now + [tokenizer.unk_token_id])\n",
    "        dfs(i + 1, now)\n",
    "\n",
    "    def agreement(covered):\n",
    "        \"\"\"Fraction of `covered` perturbations the model labels `ori` (1 if empty).\"\"\"\n",
    "        if not covered:\n",
    "            return 1\n",
    "        # NOTE(review): assumes predict_shap's column 0 is the positive class.\n",
    "        positive = 1 - np.argmax(predict_shap(list(covered)), -1)\n",
    "        frac = sum(positive) / len(covered)\n",
    "        return 1 - frac if ori == 0 else frac\n",
    "\n",
    "    ori = line['pred']\n",
    "    if ori == 0:\n",
    "        # Make the threshold strictly negative (even when it is 0) so the sign\n",
    "        # flip in fit_rex/fit_shap applies to negative predictions.\n",
    "        threshold = -threshold - 0.00000001\n",
    "    dfs(0, [])\n",
    "    cover = [len(inLIME) / len(tot), len(inrule) / len(tot)]\n",
    "    # Each method is scored on its own covered set; an empty set scores 1.\n",
    "    precision = [agreement(inLIME), agreement(inrule)]\n",
    "    return ori, cover, precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "17d0b24e93f9474e86492c939f3492a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/242 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_1721215/2228990062.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpbar\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0mline\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult_df\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m     \u001b[0mori\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcover\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mprecision\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcoverage_shap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mline\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0.3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m     \u001b[0mpbar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_description\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"cover: {cover[0]} vs {cover[1]} precision: {precision[0]} vs {precision[1]}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m     \u001b[0mcovers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcover\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1721215/1478667781.py\u001b[0m in \u001b[0;36mcoverage_shap\u001b[0;34m(line, threshold)\u001b[0m\n\u001b[1;32m     87\u001b[0m         \u001b[0mprecision\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     88\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 89\u001b[0;31m         \u001b[0mmyres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpredict_shap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minrule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     90\u001b[0m         \u001b[0manchores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpredict_shap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minLIME\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     91\u001b[0m         \u001b[0mprecision\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0;34m[\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0manchores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minLIME\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmyres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minrule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_1721215/3241712949.py\u001b[0m in \u001b[0;36mpredict_shap\u001b[0;34m(sentences)\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0msentences\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtokenizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_decode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msentences\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mskip_special_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclean_up_tokenization_spaces\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0;31m# print(sentences)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m     \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpipe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msentences\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'label'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m\"P\"\u001b[0m \u001b[0;32melse\u001b[0m 
\u001b[0;34m[\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'score'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/ml/lib/python3.7/site-packages/transformers/pipelines/text_classification.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    138\u001b[0m             \u001b[0mIf\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mtop_k\u001b[0m\u001b[0;31m`\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mused\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mone\u001b[0m \u001b[0msuch\u001b[0m \u001b[0mdictionary\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mreturned\u001b[0m \u001b[0mper\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    139\u001b[0m         \"\"\"\n\u001b[0;32m--> 140\u001b[0;31m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    141\u001b[0m         \u001b[0;31m# TODO try and retrieve it in a nicer way from _sanitize_parameters.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    142\u001b[0m         \u001b[0m_legacy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"top_k\"\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/ml/lib/python3.7/site-packages/transformers/pipelines/base.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs, num_workers, batch_size, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1061\u001b[0m                     \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_workers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreprocess_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mforward_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpostprocess_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1062\u001b[0m                 )\n\u001b[0;32m-> 1063\u001b[0;31m                 \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0moutput\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfinal_iterator\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1064\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1065\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/ml/lib/python3.7/site-packages/transformers/pipelines/base.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m   1061\u001b[0m                     \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_workers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreprocess_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mforward_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpostprocess_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1062\u001b[0m                 )\n\u001b[0;32m-> 1063\u001b[0;31m                 \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0moutput\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfinal_iterator\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1064\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1065\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/ml/lib/python3.7/site-packages/transformers/pipelines/pt_utils.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    113\u001b[0m         \u001b[0;31m# We're out of items within a batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    114\u001b[0m         \u001b[0mitem\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m         \u001b[0mprocessed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    116\u001b[0m         \u001b[0;31m# We now have a batch of \"inferred things\".\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    117\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloader_batch_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/ml/lib/python3.7/site-packages/transformers/pipelines/text_classification.py\u001b[0m in \u001b[0;36mpostprocess\u001b[0;34m(self, model_outputs, function_to_apply, top_k, _legacy)\u001b[0m\n\u001b[1;32m    198\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m         dict_scores = [\n\u001b[0;32m--> 200\u001b[0;31m             \u001b[0;34m{\u001b[0m\u001b[0;34m\"label\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mid2label\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"score\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mscore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscore\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    201\u001b[0m         ]\n\u001b[1;32m    202\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0m_legacy\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Per-row coverage/precision for the SHAP vs. ReX-SHAP explanations.\n",
    "# coverage_shap returns (ori, cover, precision); cover and precision are pairs\n",
    "# ordered (shap, rex_shap), as unpacked into result_df columns in the next cell.\n",
    "# NOTE(review): the second argument of coverage_shap looks like a threshold;\n",
    "# confirm its meaning against the coverage_shap definition.\n",
    "COVERAGE_ARG = 0.3\n",
    "\n",
    "covers = []\n",
    "precisions = []\n",
    "\n",
    "pbar = tqdm(range(len(result_df)))\n",
    "for idx in pbar:\n",
    "    line = result_df.iloc[idx]\n",
    "    # first return value is not used by this analysis\n",
    "    ori, cover, precision = coverage_shap(line, COVERAGE_ARG)\n",
    "    pbar.set_description(f\"cover: {cover[0]} vs {cover[1]} precision: {precision[0]} vs {precision[1]}\")\n",
    "    covers.append(cover)\n",
    "    precisions.append(precision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the (shap, rex_shap) pairs into separate result columns.\n",
    "# The coverage loop can be interrupted part-way; fail with a clear message\n",
    "# instead of pandas' length-mismatch error (or a tuple-unpacking error when empty).\n",
    "assert len(covers) == len(precisions) == len(result_df), (\n",
    "    f\"expected {len(result_df)} results, got {len(covers)} covers / {len(precisions)} precisions; \"\n",
    "    \"re-run the coverage loop to completion\"\n",
    ")\n",
    "result_df['cover_shap'], result_df['cover_shap_rex'] = zip(*covers)\n",
    "result_df['precision_shap'], result_df['precision_shap_rex'] = zip(*precisions)\n",
    "\n",
    "SAVE_RESULTS = False  # set True to persist the metrics to result_shap.json\n",
    "if SAVE_RESULTS:\n",
    "    result_df.to_json('result_shap.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentences</th>\n",
       "      <th>pred</th>\n",
       "      <th>label</th>\n",
       "      <th>exp_shap</th>\n",
       "      <th>exp_rex_shap</th>\n",
       "      <th>predicates</th>\n",
       "      <th>rex_time</th>\n",
       "      <th>selected</th>\n",
       "      <th>cover_shap</th>\n",
       "      <th>cover_shap_rex</th>\n",
       "      <th>precision_shap</th>\n",
       "      <th>precision_shap_rex</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>uneasy mishmash of styles and genres .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[-2.9538865764, -4.5909770619, -1.1575369739,...</td>\n",
       "      <td>[[0.0030301080037045125, 0.04512306584326642, ...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>3.349509</td>\n",
       "      <td>[116, 43, 126, 90, 29, 3, 81, 58, 11]</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>it 's just incredibly dull .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[0.7497899702, 0.5451659446, 0.3746587302, -0...</td>\n",
       "      <td>[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>3.980032</td>\n",
       "      <td>[67, 12, 24, 0, 1, 3, 6]</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>good movie .</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>[[4.6833177305, 2.2067418796, 0.7322422585, -0...</td>\n",
       "      <td>[[-0.6053785292911353, -0.1460227179761313, -0...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [1, 1, 0], [...</td>\n",
       "      <td>0.864872</td>\n",
       "      <td>[5, 4, 0]</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>most new movies have a bright sheen .</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>[[0.071488014, 1.0026802766, 1.0461042556, 0.3...</td>\n",
       "      <td>[[0.0, 0.0, 0.0, 0.0, -0.13180292906812108, -0...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>3.687434</td>\n",
       "      <td>[62, 32, 40, 9, 96, 36, 74, 18]</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>nothing more than a mediocre trifle .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[-2.8252216751, 0.0676713233, 0.4272910984, -...</td>\n",
       "      <td>[[0.0, 0.0, 0.16593652560397842, 0.0, 0.0, 0.0...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>3.601137</td>\n",
       "      <td>[93, 42, 114, 142, 127, 45, 116, 190, 136, 200]</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>237</th>\n",
       "      <td>a pleasant romantic comedy .</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>[[0.7536170943, 2.3774593832, 2.3024900137, 1....</td>\n",
       "      <td>[[-0.6200508956650741, -0.27859673602830437, -...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>2.007713</td>\n",
       "      <td>[0, 27, 2, 23, 5]</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>238</th>\n",
       "      <td>well , it does go on forever .</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>[[1.5891878341, -0.273119646, 0.66363222610000...</td>\n",
       "      <td>[[0.0, 0.34067151078257396, -0.281100989178876...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>2.690393</td>\n",
       "      <td>[35, 48, 59, 103, 30, 17, 79, 64]</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>239</th>\n",
       "      <td>what kids will discover is a new collectible .</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>[[0.2380497192, 1.1394186681, 1.5273038723, 0....</td>\n",
       "      <td>[[-0.48986563470030475, 1.0892845498378332, 0....</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>2.766771</td>\n",
       "      <td>[148, 97, 183, 152, 74, 24, 66, 11, 197, 1]</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>240</th>\n",
       "      <td>there 's not enough to sustain the comedy .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[0.4373973349, 0.3067432346, 0.0313808303, -1...</td>\n",
       "      <td>[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>2.831442</td>\n",
       "      <td>[116, 128, 130, 112, 124, 13, 105, 41, 0, 1]</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>241</th>\n",
       "      <td>a real clunker .</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[[1.9007193874000001, 7.6332260962, -2.7738666...</td>\n",
       "      <td>[[-0.43779886917471056, -0.3152869819914629, -...</td>\n",
       "      <td>[[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...</td>\n",
       "      <td>2.401160</td>\n",
       "      <td>[19, 35, 22, 39, 18, 0]</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>242 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           sentences  pred  label  \\\n",
       "0            uneasy mishmash of styles and genres .      0      0   \n",
       "1                      it 's just incredibly dull .      0      0   \n",
       "2                                      good movie .      1      1   \n",
       "3             most new movies have a bright sheen .      1      1   \n",
       "4             nothing more than a mediocre trifle .      0      0   \n",
       "..                                               ...   ...    ...   \n",
       "237                    a pleasant romantic comedy .      1      1   \n",
       "238                  well , it does go on forever .      1      0   \n",
       "239  what kids will discover is a new collectible .      1      1   \n",
       "240     there 's not enough to sustain the comedy .      0      0   \n",
       "241                                a real clunker .      0      0   \n",
       "\n",
       "                                              exp_shap  \\\n",
       "0    [[-2.9538865764, -4.5909770619, -1.1575369739,...   \n",
       "1    [[0.7497899702, 0.5451659446, 0.3746587302, -0...   \n",
       "2    [[4.6833177305, 2.2067418796, 0.7322422585, -0...   \n",
       "3    [[0.071488014, 1.0026802766, 1.0461042556, 0.3...   \n",
       "4    [[-2.8252216751, 0.0676713233, 0.4272910984, -...   \n",
       "..                                                 ...   \n",
       "237  [[0.7536170943, 2.3774593832, 2.3024900137, 1....   \n",
       "238  [[1.5891878341, -0.273119646, 0.66363222610000...   \n",
       "239  [[0.2380497192, 1.1394186681, 1.5273038723, 0....   \n",
       "240  [[0.4373973349, 0.3067432346, 0.0313808303, -1...   \n",
       "241  [[1.9007193874000001, 7.6332260962, -2.7738666...   \n",
       "\n",
       "                                          exp_rex_shap  \\\n",
       "0    [[0.0030301080037045125, 0.04512306584326642, ...   \n",
       "1    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...   \n",
       "2    [[-0.6053785292911353, -0.1460227179761313, -0...   \n",
       "3    [[0.0, 0.0, 0.0, 0.0, -0.13180292906812108, -0...   \n",
       "4    [[0.0, 0.0, 0.16593652560397842, 0.0, 0.0, 0.0...   \n",
       "..                                                 ...   \n",
       "237  [[-0.6200508956650741, -0.27859673602830437, -...   \n",
       "238  [[0.0, 0.34067151078257396, -0.281100989178876...   \n",
       "239  [[-0.48986563470030475, 1.0892845498378332, 0....   \n",
       "240  [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...   \n",
       "241  [[-0.43779886917471056, -0.3152869819914629, -...   \n",
       "\n",
       "                                            predicates  rex_time  \\\n",
       "0    [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  3.349509   \n",
       "1    [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  3.980032   \n",
       "2    [[0, 1, 0], [0, 2, 0], [0, 2, 1], [1, 1, 0], [...  0.864872   \n",
       "3    [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  3.687434   \n",
       "4    [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  3.601137   \n",
       "..                                                 ...       ...   \n",
       "237  [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  2.007713   \n",
       "238  [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  2.690393   \n",
       "239  [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  2.766771   \n",
       "240  [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  2.831442   \n",
       "241  [[0, 1, 0], [0, 2, 0], [0, 2, 1], [0, 3, 0], [...  2.401160   \n",
       "\n",
       "                                            selected  cover_shap  \\\n",
       "0              [116, 43, 126, 90, 29, 3, 81, 58, 11]           0   \n",
       "1                           [67, 12, 24, 0, 1, 3, 6]           0   \n",
       "2                                          [5, 4, 0]           0   \n",
       "3                    [62, 32, 40, 9, 96, 36, 74, 18]           0   \n",
       "4    [93, 42, 114, 142, 127, 45, 116, 190, 136, 200]           0   \n",
       "..                                               ...         ...   \n",
       "237                                [0, 27, 2, 23, 5]           0   \n",
       "238                [35, 48, 59, 103, 30, 17, 79, 64]           0   \n",
       "239      [148, 97, 183, 152, 74, 24, 66, 11, 197, 1]           0   \n",
       "240     [116, 128, 130, 112, 124, 13, 105, 41, 0, 1]           0   \n",
       "241                          [19, 35, 22, 39, 18, 0]           0   \n",
       "\n",
       "     cover_shap_rex  precision_shap  precision_shap_rex  \n",
       "0                 0               0                   0  \n",
       "1                 0               0                   0  \n",
       "2                 0               0                   0  \n",
       "3                 0               0                   0  \n",
       "4                 0               0                   0  \n",
       "..              ...             ...                 ...  \n",
       "237               0               0                   0  \n",
       "238               0               0                   0  \n",
       "239               0               0                   0  \n",
       "240               0               0                   0  \n",
       "241               0               0                   0  \n",
       "\n",
       "[242 rows x 12 columns]"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show identifying, timing and metric columns only; exp_shap, exp_rex_shap,\n",
    "# predicates and selected hold long nested lists that render as truncated text.\n",
    "result_df[['sentences', 'pred', 'label', 'rex_time',\n",
    "           'cover_shap', 'cover_shap_rex', 'precision_shap', 'precision_shap_rex']]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "getexp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
