import os
import timeit
import pickle
import argparse
from datetime import datetime
from pprint import pprint
from functools import partial

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor, RandomForestClassifier
from sklearn.neural_network import MLPClassifier, MLPRegressor

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, r2_score

# explainers
import shap
import lime
from lime.lime_tabular import LimeTabularExplainer

import dataloaders
from dataloaders import DATASETS, REGRESSION, give_prototype_indices
from train_save_models import load_models

from ExpCertifyBB import Ecertify

import warnings
warnings.filterwarnings('ignore')

# source: https://github.com/amparore/leaf/blob/master/leaf.py#L69
# Build the linear classifier of a LIME explainer
def get_LIME_classifier(lime_expl, label_x0, x0):
    """Rebuild the local linear surrogate of a LIME explanation as an estimator.

    Args:
        lime_expl: LIME explanation object. ``local_exp[label_x0]`` is read as
            a sequence of ``(feature_index, weight)`` entries and
            ``intercept[label_x0]`` as the surrogate's intercept.
        label_x0: label whose surrogate model is rebuilt.
        x0: the explained instance; only ``len(x0)`` is used, to size the
            dense coefficient vector.

    Returns:
        A linear estimator that was never fitted: its ``coef_`` and
        ``intercept_`` are assigned by hand from the explanation.
    """
    # Import the class explicitly: a bare ``import sklearn`` does not by itself
    # guarantee that the ``sklearn.linear_model`` submodule is loaded.
    from sklearn.linear_model import Ridge

    local_terms = lime_expl.local_exp[label_x0]
    feature_indices = [term[0] for term in local_terms]
    feature_weights = [term[1] for term in local_terms]
    intercept = lime_expl.intercept[label_x0]

    # Dense coefficient vector; features absent from the explanation stay 0.
    coef = np.zeros(len(x0))
    coef[feature_indices] = feature_weights

    if getattr(lime_expl, 'perfect_local_concordance', False):
        # NOTE(review): `perfect_local_concordance` / `TranslatedRidge` look
        # like they come from the LEAF-patched LIME referenced above rather
        # than stock LIME -- confirm which package is installed.
        g = lime.lime_base.TranslatedRidge(alpha=1.0)
        g.x0 = lime_expl.x0
        g.f_x0 = lime_expl.predict_proba[label_x0]
        g.coef_ = g.ridge.coef_ = coef
        g.intercept_ = g.ridge.intercept_ = intercept
    else:
        g = Ridge(alpha=1.0, fit_intercept=True)
        g.coef_ = coef
        g.intercept_ = intercept
    return g

# Build the linear classifier of a SHAP explainer
def get_SHAP_classifier(label_x0, phi, phi0, x0, EX):
    """Turn SHAP attributions into a linear estimator.

    The coefficient of feature ``i`` is ``phi[label_x0][i] / (x0[i] - EX[i])``;
    features where ``x0[i] == EX[i]`` get a coefficient of exactly 0.

    Args:
        label_x0: label whose attributions are used (index into ``phi`` and
            ``phi0``).
        phi: SHAP values, indexable by label.
        phi0: SHAP expected values, indexable by label; used as the intercept.
        x0: the explained instance (array-like of numbers).
        EX: reference point the attributions are relative to (array-like of
            numbers); assumed to list features in the same order as ``x0``.

    Returns:
        A ``Ridge`` estimator that was never fitted: its ``coef_`` and
        ``intercept_`` are assigned by hand.
    """
    # Import the class explicitly: a bare ``import sklearn`` does not by itself
    # guarantee that the ``sklearn.linear_model`` submodule is loaded.
    from sklearn.linear_model import Ridge

    phi_label = np.asarray(phi[label_x0], dtype=float)
    delta = np.asarray(x0, dtype=float) - np.asarray(EX, dtype=float)

    # `np.divide(..., where=mask)` leaves masked-out entries of a default
    # output array uninitialised, so supply a zero-filled `out` explicitly.
    coef = np.divide(
        phi_label,
        delta,
        out=np.zeros(np.broadcast(phi_label, delta).shape, dtype=float),
        where=delta != 0,
    )

    g = Ridge(alpha=1.0, fit_intercept=True)
    g.coef_ = coef
    g.intercept_ = phi0[label_x0]
    return g


# Module-level flag. It is not read by any code visible in this file: `_bb`
# takes its own `is_regression` parameter, and the certification routine
# passes `is_regression=False` explicitly.
is_regression = False


def _bb(x, model, label_x0=0, is_regression=False):
    """
    Evaluate the black-box model on one point.

                x: single 1d numpy array of shape (d, )
            model: object exposing `predict` (used when `is_regression` is
                   true) or `predict_proba` (used otherwise)
         label_x0: column of the `predict_proba` output to return
    is_regression: selects between the two prediction paths

    Returns the model output for `x` as a single value.
    """
    batch = [x]  # the model is called with a one-row batch

    if not is_regression:
        probabilities = model.predict_proba(batch)
        return probabilities[:, label_x0][0]

    predictions = model.predict(batch)
    return predictions[0]


def _e(x, expl_func):
    """
    Evaluate the explanation model on one point.

                x: single 1d numpy array of shape (d, )
        expl_func: a callable/sklearn model with predict method

    Returns the first element of `expl_func.predict` applied to the one-row
    batch `[x]`.
    """
    predictions = expl_func.predict([x])
    return predictions[0]


def _fidelity_function(bb, e):
    """Return the fidelity function ``f(x) = 1 - |bb(x) - e(x)|`` (1 - MAE)."""

    def f(x):
        return 1 - abs(bb(x) - e(x))

    return f


def _certify_explanation(x, f, ub, strategies, query_budgets, theta, Z, sigma, choice):
    """Run ``Ecertify`` for every (strategy, query budget) pair and collect stats.

    Args:
        x: point at the centre of the hypercube (1d numpy array).
        f: fidelity function taking a 1d numpy array.
        ub: initial hypercube half-width handed to ``Ecertify``.
        strategies: iterable of strategy identifiers.
        query_budgets: iterable of query budgets ``Q``.
        theta, Z, sigma, choice: forwarded unchanged to ``Ecertify``.

    Returns:
        ``{s: {Q: stats}}`` where ``stats`` holds the per-run widths (``"w"``),
        ``"time-per-run"``, ``"num-runs"``, ``"w-mean"``, ``"w-error"`` and
        ``"choice"``.
    """
    lb = 0  # since x is the center of the hypercube
    certresult = {}

    for s in strategies:
        # strategy 4 is repeated fewer times than the others
        num_runs = 10 if s == 4 else 50
        certresult[s] = {}
        for Q in query_budgets:
            print(f"\tcertifying with s={s}, Q={Q} with NUMRUNS={num_runs}")

            widths = np.zeros(num_runs)
            t_0 = timeit.default_timer()
            for irun in range(num_runs):
                widths[irun] = Ecertify(x, theta, Z, Q, lb, ub, sigma, s, f, choice=choice)
            t_1 = timeit.default_timer()

            certresult[s][Q] = {
                "w": widths,
                "time-per-run": round((t_1 - t_0) / num_runs, 3),
                "num-runs": num_runs,
                "w-mean": np.mean(widths),
                # standard error of the mean over the runs
                "w-error": np.std(widths) / np.sqrt(num_runs),
                "choice": choice,
            }

    return certresult


def run_for_one_dataset_and_classifier(dataset, model, x_train, y_train,
                                       n_samples_to_certify=3,  # no. of prototypes

                                       # lime arguments
                                       num_features=5,
                                       top_labels=1,
                                       explanation_samples=1000,

                                       # certification arguments
                                       Z=10,
                                       theta=0.75,
                                       sigma=0.1
                                       ):
    """Certify LIME and SHAP explanations of ``model`` around prototype samples.

    For each prototype row chosen by ``give_prototype_indices`` the function
    builds a LIME and a SHAP linear surrogate for the model's predicted class
    and runs ``Ecertify`` on the fidelity ``1 - |model - surrogate|`` for every
    strategy in ``[1, 2, 3, 4]`` and every query budget in
    ``10, 100, 1000, 10000``.

    Args:
        dataset: dataset name; only used in the progress message.
        model: classifier exposing ``predict_proba``.
        x_train: pandas DataFrame of features (``.iloc``, ``.columns`` and
            ``.values`` are used).
        y_train: labels, indexed positionally with ``.iloc``.
        n_samples_to_certify: number of prototypes requested from
            ``give_prototype_indices``.
        num_features, top_labels: forwarded to LIME's ``explain_instance``.
        explanation_samples: ``num_samples`` for LIME; also passed as
            ``nsamples`` to the ``shap.KernelExplainer`` constructor.
        Z, theta, sigma: forwarded to ``Ecertify``. ``Z`` and ``theta`` used to
            be overwritten by hard-coded ``10`` / ``0.75`` inside the loop; the
            defaults keep those values.

    Returns:
        dict with ``"prototype-indices"``, ``"prototype-weights"`` and, for
        every prototype index, ``{"LIME": ..., "SHAP": ...}`` where each value
        maps ``strategy -> query budget -> stats``.
    """
    model_name = str(model).replace('\n', '').replace(' ', '')
    print(f"dataset: {dataset} --- model: {model_name} --- {n_samples_to_certify} examples")

    EX = x_train.mean(0)  # feature means: SHAP background point / reference

    result = {}
    W, S = give_prototype_indices(x_train, m=n_samples_to_certify)
    result["prototype-indices"] = list(S)
    result["prototype-weights"] = list(W)

    QUERY_BUDGET = 10 ** (np.arange(4) + 1)
    STRATEGIES = [1, 2, 3, 4]
    CHOICE = "min"

    # Initial hypercube half-widths. The two explainers deliberately keep the
    # different starting values the original code used.
    lime_ub = x_train.values.max()  # hoisted: does not change between runs
    shap_ub = 1

    for sample_idx in S:
        print(sample_idx)
        result[sample_idx] = {}
        x0 = x_train.iloc[sample_idx]
        true_label = y_train.iloc[sample_idx]

        # Get the output of the black-box classifier on x0
        output = model.predict_proba([x0])[0]
        label_x0 = np.argmax(output)
        prob_x0 = output[label_x0]

        print(f"model class probs output: {output}, ground truth: {true_label}")
        print('prob_x0: ', prob_x0, '   label_x0: ', label_x0)

        # The black box is the same for both explainers: probability of the
        # predicted class.
        bb = partial(_bb, model=model, label_x0=label_x0, is_regression=False)

        ###################### CERTIFYING LIME EXPLANATION
        print(f"--------- lime ---------")

        # Built inside the loop on purpose: every sample is explained by an
        # explainer freshly created with random_state=1234.
        LIMEEXPL = LimeTabularExplainer(x_train.astype('float'),
                                        feature_names=x_train.columns.tolist(),
                                        discretize_continuous=False,
                                        sample_around_instance=True,
                                        random_state=1234)

        lime_expl = LIMEEXPL.explain_instance(np.array(x0), model.predict_proba,
                                              num_features=num_features,
                                              top_labels=top_labels,
                                              num_samples=explanation_samples)
        func = get_LIME_classifier(lime_expl, label_x0, x0)

        e = partial(_e, expl_func=func)
        f = _fidelity_function(bb, e)

        print(f"bb(x)={bb(x0.values):.4f}, e(x)={e(x0.values):.4f}")
        print(f"fidelity at x0: f(x)={f(x0.values):.4f}")

        result[sample_idx]["LIME"] = _certify_explanation(
            x0.values, f, lime_ub, STRATEGIES, QUERY_BUDGET, theta, Z, sigma, CHOICE)

        ###################### CERTIFYING SHAP EXPLANATION
        print(f"--------- shap ---------")

        # NOTE(review): `nsamples` is documented as an argument of
        # `shap_values`, not of the KernelExplainer constructor -- confirm it
        # has any effect here.
        SHAPEXPL = shap.KernelExplainer(model.predict_proba, EX, nsamples=explanation_samples)

        # The SHAP feature cap is hard-coded to 10; `num_features` only
        # applies to LIME.
        shap_phi = SHAPEXPL.shap_values(x0, l1_reg="num_features(10)")
        shap_phi0 = SHAPEXPL.expected_value
        func = get_SHAP_classifier(label_x0, shap_phi, shap_phi0, x0, EX)

        e = partial(_e, expl_func=func)
        f = _fidelity_function(bb, e)

        print(f"bb(x)={bb(x0.values):.4f}, e(x-w)={e(x0.values):.4f}")
        print(f"fidelity at x0: f(x)={f(x0.values):.4f}")

        result[sample_idx]["SHAP"] = _certify_explanation(
            x0.values, f, shap_ub, STRATEGIES, QUERY_BUDGET, theta, Z, sigma, CHOICE)
        print("-" * 100)

    return result


if __name__ == '__main__':
    # Script entry point: load one dataset, standardize it with training
    # statistics, then certify LIME/SHAP explanations for the saved
    # gradient-boosting model(s) and pickle the results.

    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="heartrisk", help='dataset name')
    parser.add_argument('--seed', type=int, default=1234, help='random seed')
    args = parser.parse_args()

    dataset = args.dataset
    seed = args.seed

    print(dataset)
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f'./results/{dataset}_{run_id}'
    os.makedirs(results_dir, exist_ok=True)

    # The loader is looked up by name in the `dataloaders` module.
    function_string = f'load_{dataset}_dataset'
    loader = getattr(dataloaders, function_string)
    df, y = loader()

    if dataset == "arrhythmia":
        df = df.select_dtypes("float")  # remove byte columns
        df = df[df.columns[(df != 0).any()]]  # remove columns with all zeros
        # remove columns having 0 std deviation
        df = df.drop(columns=['chV6_RPwaveAmp', 'chDII_SPwave', 'chV6_RPwave', 'chDII_SPwaveAmp'])
        # also need to fill na values --- skip this dataset

    # `--seed` now drives the split (it was parsed but unused before); the
    # default 1234 reproduces the previously hard-coded split.
    # NOTE(review): the saved models were presumably trained on the
    # default-seed split -- confirm before evaluating with another seed.
    x_train, x_test, y_train, y_test = train_test_split(df, y, train_size=0.7, random_state=seed)
    if dataset == "arrhythmia":
        # Impute with training-set column means; reassign rather than fill
        # the split frames in place.
        x_train = x_train.fillna(x_train.mean())
        x_test = x_test.fillna(x_train.mean())

    # standardization (training statistics are applied to both splits)
    EXtr, StdXtr = x_train.mean(0), x_train.std(0)
    x_train = (x_train - EXtr) / StdXtr
    x_test = (x_test - EXtr) / StdXtr

    models = load_models(dataset=dataset, path='./saved_models_20230512_110418/')

    metric = r2_score if dataset in REGRESSION else accuracy_score
    metric_name = "Test r2_score:" if dataset in REGRESSION else "Test accuracy:"

    for model in models:
        # whitespace-free repr of the estimator, also used in the file name
        model_name = str(model).replace('\n', '').replace(' ', '')
        print(model_name)
        if not model_name.startswith("Gradient"):
            print(f"skipping because taking too much time!")
            continue

        print(f"\t {metric_name} {100 * metric(y_test, model.predict(x_test)):.2f}%")

        res = run_for_one_dataset_and_classifier(dataset, model, x_train, y_train, n_samples_to_certify=5)

        fname = "result_" + model_name
        res_object_name = os.path.join(results_dir, fname)
        with open(res_object_name, 'wb') as output:
            pickle.dump(res, output, pickle.HIGHEST_PROTOCOL)



