import numpy as np
from sklearn.model_selection import train_test_split

import lime
import lime.lime_tabular


# Seed NumPy's legacy global RNG once at import time so that draws made
# through the module-level np.random.* API are reproducible between runs.
np.random.seed(seed=0)


class LocalApprox:
    """Local linear surrogate of a binary classifier derived from LIME.

    Wraps a ``LimeTabularExplainer`` built on ``X_train`` and turns the
    explanation of a single instance into a hyperplane ``(w, b)``.
    """

    def __init__(self, X_train, predict_fn):
        """Create the LIME explainer.

        Args:
            X_train: Training data handed to ``LimeTabularExplainer``.
            predict_fn: Callable passed to LIME's ``explain_instance``;
                presumably returns per-class probabilities for the two
                classes '0' and '1' — TODO confirm against callers.
        """
        self.explainer = lime.lime_tabular.LimeTabularExplainer(
            X_train,
            class_names=['0', '1'],
            discretize_continuous=True,
            sample_around_instance=True,
            random_state=0,
        )
        self.predict_fn = predict_fn

    def extract_weights(self, x_0, shift=0.1, num_samples=5000):
        """Return LIME's class-1 weights and a bias anchored at ``x_0``.

        Args:
            x_0: Instance to explain; ``x_0.shape[0]`` is taken as the number
                of features (assumes a 1-D array).
            shift: The returned hyperplane ``w @ x + b`` equals ``-shift`` at
                ``x = x_0``.
            num_samples: Number of perturbation samples LIME draws.

        Returns:
            ``(w, b)``:
                - ``w`` has one weight per feature, indexed by feature position.
                  Features LIME did not report get weight 0.
                - ``b`` is a shape-``(1,)`` array equal to ``-shift - w @ x_0``.
        """
        n_features = x_0.shape[0]
        explanation = self.explainer.explain_instance(
            x_0,
            self.predict_fn,
            num_features=n_features,
            num_samples=num_samples,
        )

        # local_exp[1] holds (feature_index, weight) pairs for label 1. LIME
        # orders them by |weight|, not by feature index, so each weight must
        # be scattered into its feature slot. Reading them in list order
        # would pair weights with the wrong features.
        w = np.zeros(n_features)
        for feature_idx, weight in explanation.local_exp[1]:
            w[feature_idx] = weight

        # NOTE(review): with discretize_continuous=True, LIME fits its
        # surrogate on discretized features. These weights may therefore not
        # be slopes in the raw input space that x_0 lives in — confirm this
        # is intended.
        b = -shift - np.dot(w, x_0)
        return w, np.array(b).reshape(1,)


