#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## xgbooster.py
##
##

#
#==============================================================================
from __future__ import print_function
from .validate import SMTValidator
from .encode import SMTEncoder, MXEncoder
from .explain import SMTExplainer, MXExplainer, MXIExplainer
from .compile import MXCompiler
from .process import DataProcessor
import numpy as np
import os
import resource
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import sklearn
# print('The scikit-learn version is {}.'.format(sklearn.__version__))

from  sklearn.preprocessing import OneHotEncoder
import sys
from six.moves import range
from .tree import TreeEnsemble
import xgboost as xgb
from xgboost import XGBClassifier, Booster
import pickle


#
#==============================================================================
class XGBooster(object):
    """
        The main class to train/encode/explain XGBoost models.

        An instance is built from one of three inputs: a dataset (the model
        is then trained with :meth:`train`), a model previously pickled by
        :meth:`save_datainfo`, or a saved encoding. When categorical support
        is enabled, every categorical feature is one-hot encoded with its own
        scikit-learn OneHotEncoder before being given to XGBoost.
    """

    # values of ``options.encode`` that select the MaxSAT-based pipeline;
    # any other value goes through the SMT encoder / explainer
    _MX_ENCODINGS = ('mx', 'mxe', 'maxsat', 'mxint', 'mxa')

    def __init__(self, options, from_data=None, from_model=None,
            from_encoding=None):
        """
            Constructor.

            :param options: parsed command-line options; this method reads
                ``seed``, ``use_categorical``, ``n_estimators``, ``maxdepth``,
                ``testsplit``, ``output`` and ``files``.
            :param from_data: dataset object (takes precedence if given).
            :param from_model: path of a pickled model written by
                :meth:`save_datainfo` (used if ``from_data`` is not given).
            :param from_encoding: path of a saved encoding; ``.cnf`` files
                are read with MXEncoder, anything else with SMTEncoder.
            :raises ValueError: if none of the three inputs is given.
        """

        if not (from_data or from_model or from_encoding):
            raise ValueError('At least one input file should be specified')

        self.init_stime = resource.getrusage(resource.RUSAGE_SELF).ru_utime
        self.init_ctime = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime

        # saving command-line options
        self.options = options
        self.seed = self.options.seed
        np.random.seed(self.seed)

        if from_data:
            self._init_from_data(from_data)
        elif from_model:
            self._init_from_model(from_model)
        else:
            self._init_from_encoding(from_encoding)

        # create output directories
        self._ensure_dir(options.output)

        self.mapping_features()
        # round-trip sanity check of transform / transform_inverse
        self.test_encoding_transformes()

        bench_name = os.path.splitext(os.path.basename(options.files[0]))[0]
        bench_dir_name = os.path.join(options.output, bench_name)
        self._ensure_dir(bench_dir_name)

        self.basename = os.path.join(bench_dir_name, bench_name +
                "_nbestim_" + str(options.n_estimators) +
                "_maxdepth_" + str(options.maxdepth) +
                "_testsplit_" + str(options.testsplit))

        self.modfile = self.basename + '.mod.pkl'
        self.mod_plainfile = self.basename + '.mod.txt'
        self.resfile = self.basename + '.res.txt'
        self.encfile = self.basename + '.enc.txt'
        self.expfile = self.basename + '.exp.txt'

    @staticmethod
    def _ensure_dir(path):
        """
            Create directory ``path`` (including parents) unless it exists.
        """
        if not os.path.exists(path):
            os.makedirs(path)

    @staticmethod
    def _make_onehot_encoder():
        """
            Build a dense-output OneHotEncoder across scikit-learn versions.

            scikit-learn 1.2 renamed ``sparse`` to ``sparse_output`` and 1.4
            removed ``sparse``; try the new keyword first, then fall back.
        """
        try:
            return OneHotEncoder(categories='auto', sparse_output=False)
        except TypeError:
            return OneHotEncoder(categories='auto', sparse=False)

    def _init_from_data(self, data):
        """
            Load a dataset, split it into train/test sets and create an
            untrained XGBClassifier (plus one-hot encoders if requested).
        """
        self.use_categorical = self.options.use_categorical
        # saving data
        self.data = data
        dataset = np.asarray(self.data.samps, dtype=np.float32)

        # split data into X and y: the last column holds the class
        self.feature_names = self.data.names[:-1]
        self.nb_features = len(self.feature_names)

        self.X = dataset[:, 0:self.nb_features]
        self.Y = dataset[:, self.nb_features]
        self.num_class = len(set(self.Y))
        self.target_name = list(range(self.num_class))
        # sample weights keyed by the full row (features followed by class)
        self.wghts = {tuple(list(row) + [self.Y[i]]): self.data.wghts[i]
                for i, row in enumerate(self.X)}
        self.class_name = self.data.names[-1]

        param_dist = {'n_estimators': self.options.n_estimators,
                      'max_depth': self.options.maxdepth}
        if self.num_class == 2:
            param_dist['objective'] = 'binary:logistic'
            param_dist['eval_metric'] = 'error'
        else:
            param_dist['eval_metric'] = 'merror'

        self.model = XGBClassifier(**param_dist)

        # split data into train and test sets
        self.test_size = self.options.testsplit
        if self.test_size > 0:
            self.X_train, self.X_test, self.Y_train, self.Y_test = \
                    train_test_split(self.X, self.Y, test_size=self.test_size,
                            random_state=self.seed)
        else:
            self.X_train = self.X
            self.X_test = []  # TODO: empty list rather than an empty array
            self.Y_train = self.Y
            self.Y_test = []  # TODO: empty list rather than an empty array

        # check if we have info about categorical features
        if self.use_categorical:
            self.categorical_features = data.categorical_features
            self.categorical_names = data.categorical_names
            self.target_name = data.class_names

            # make sure that we use the same encoding as anchor: the
            # categorical features must be exactly the (sorted) keys of
            # categorical_names
            cat_names = sorted(self.categorical_names.keys())
            assert cat_names == self.categorical_features
            self.encoder = {}
            for i in self.categorical_features:
                # fit on the value indices 0..k-1 (not on the data) so every
                # declared category gets a column, whether or not it occurs
                self.encoder[i] = self._make_onehot_encoder()
                self.encoder[i].fit(
                        [[f] for f in range(len(self.categorical_names[i]))])
        else:
            self.categorical_features = []
            self.categorical_names = []
            self.encoder = []

    def _init_from_model(self, model_file):
        """
            Restore a trained model and its data description from disk;
            exit if categorical mode is requested for a model trained
            without it.
        """
        self.load_datainfo(model_file)
        if not self.use_categorical and self.options.use_categorical:
            print("Error: Note that the model is trained without categorical features info. Please do not use -c option for predictions")
            sys.exit(1)

    def _init_from_encoding(self, enc_file):
        """
            Read the encoding, feature names and number of classes from
            ``enc_file``.
        """
        self.use_categorical = self.options.use_categorical
        # NOTE(review): __init__ later calls mapping_features() and
        # test_encoding_transformes(), which read nb_features,
        # categorical_features and X_train; this branch does not set them
        # itself -- presumably the encoder does via the `self` reference.
        # Confirm.
        is_cnf = enc_file.endswith('.cnf')
        if is_cnf:
            enc = MXEncoder(None, [], None, self, enc_file)
        else:
            enc = SMTEncoder(None, None, None, self, enc_file)
        self.enc, self.intvs, self.imaps, self.ivars, self.feature_names, \
                self.num_class = enc.access()
        if is_cnf:
            self.mxe = enc

    def form_datefile_name(self, modfile):
        """
            Return the name of the data-description file that accompanies
            the model file ``modfile``.
        """
        return modfile + '.splitdata.pkl'

    def pickle_save_file(self, filename, data):
        """
            Pickle ``data`` into ``filename``; on failure print a message
            and exit the process.
        """
        try:
            with open(filename, "wb") as f:
                pickle.dump(data, f)
        except Exception as err:
            print("Cannot save to file", filename)
            print(err)
            sys.exit(1)

    def pickle_load_file(self, filename):
        """
            Unpickle and return the content of ``filename``; on failure
            print a message and exit the process.

            Security: pickle can execute arbitrary code while loading, so
            only load model files from trusted sources.
        """
        try:
            with open(filename, "rb") as f:
                return pickle.load(f)
        except Exception as err:
            print("Cannot load from file", filename)
            print(err)
            sys.exit(1)

    def save_datainfo(self, filename):
        """
            Pickle the model into ``filename`` and its data description into
            the companion file given by :meth:`form_datefile_name`.

            Only one training sample is kept (enough for the round-trip check
            in :meth:`test_encoding_transformes`); the other data fields are
            stored as placeholders.
        """
        print("saving  model to ", filename)
        self.pickle_save_file(filename, self.model)

        filename_data = self.form_datefile_name(filename)
        print("saving  data to ", filename_data)
        samples = {
            "X": [[]],
            "Y": [[]],
            "X_train": self.X_train[0:1, :],
            "Y_train": self.Y_train[0:1],
            "X_test": [[]],
            "Y_test": [[]],
            "feature_names": self.feature_names,
            "target_name": self.target_name,
            "num_class": self.num_class,
            "categorical_features": self.categorical_features,
            "categorical_names": self.categorical_names,
            "encoder": self.encoder,
            "use_categorical": self.use_categorical,
            "wghts": {},
            "class_name": self.class_name,
        }

        self.pickle_save_file(filename_data, samples)

    def load_datainfo(self, filename):
        """
            Restore a model pickled by :meth:`save_datainfo`, together with
            the data description stored in its companion file.
        """
        print("loading model from ", filename)
        self.model = self.pickle_load_file(filename)

        datafile = self.form_datefile_name(filename)
        print("loading data from ", datafile)
        loaded_data = self.pickle_load_file(datafile)
        self.X = loaded_data["X"]
        self.Y = loaded_data["Y"]
        self.X_train = loaded_data["X_train"]
        self.X_test = loaded_data["X_test"]
        self.Y_train = loaded_data["Y_train"]
        self.Y_test = loaded_data["Y_test"]
        self.feature_names = loaded_data["feature_names"]
        self.target_name = loaded_data["target_name"]
        self.num_class = loaded_data["num_class"]
        self.nb_features = len(self.feature_names)
        self.categorical_features = loaded_data["categorical_features"]
        self.categorical_names = loaded_data["categorical_names"]
        self.encoder = loaded_data["encoder"]
        self.use_categorical = loaded_data["use_categorical"]
        self.wghts = loaded_data["wghts"]
        self.class_name = loaded_data["class_name"]

    def train(self, outfile=None):
        """
            Train a tree ensemble using XGBoost.

            :returns: ``(train_accuracy, test_accuracy, model)``, see
                :meth:`build_xgbtree`.
        """

        return self.build_xgbtree(outfile)

    def encode(self, test_on=None):
        """
            Encode a tree ensemble trained previously and save the encoding
            to ``self.encfile``.

            :param test_on: optional sample passed to the encoder's
                ``test_sample`` check.
        """

        if self.options.encode in self._MX_ENCODINGS:
            encoder = MXEncoder(self.model, self.feature_names, self.num_class, self)
            self.mxe = encoder
        else:  # smt or smtbool
            encoder = SMTEncoder(self.model, self.feature_names, self.num_class, self)
        self.enc, self.intvs, self.imaps, self.ivars = encoder.encode()

        if test_on:
            encoder.test_sample(np.array(test_on))

        encoder.save_to(self.encfile)

    def explainer(self):
        """
            Return the explainer for the current encoding, creating it on
            first use: MXExplainer (or MXIExplainer when ``options.ilits``
            is set) for MaxSAT encodings, SMTExplainer otherwise.
        """
        if not hasattr(self, 'x'):
            if self.options.encode in self._MX_ENCODINGS:
                explainer_cls = MXIExplainer if self.options.ilits else MXExplainer
            else:
                explainer_cls = SMTExplainer
            self.x = explainer_cls(self.enc, self.intvs, self.imaps,
                    self.ivars, self.feature_names, self.num_class,
                    self.options, self)
        return self.x

    def explain(self, sample, inst_id, use_lime=False, use_anchor=False, use_shap=False,
            expl_ext=None, prefer_ext=False, nof_feats=5):
        """
            Explain a prediction made for a given sample with a previously
            trained tree ensemble.

            ``use_lime``, ``use_anchor``, ``use_shap`` and ``nof_feats`` are
            accepted for interface compatibility but are not used here.

            :returns: whatever the underlying explainer's ``explain`` returns.
        """

        return self.explainer().explain(np.array(sample), self.options.smallest,
                inst_id, expl_ext, prefer_ext)

    def compile(self):
        """
            Compile a given XGBoost model to a set of rules.
        """

        self.c = MXCompiler(self.enc, self.intvs, self.imaps,
                self.ivars, self.feature_names, self.num_class,
                self.options, self)

        return self.c.compile()

    def process(self):
        """
            Process a given dataset.
        """

        self.p = DataProcessor(self.enc, self.intvs, self.imaps, self.ivars,
                self.feature_names, self.num_class, self.options, self)

        self.p.process()

    def validate(self, sample, expl):
        """
            Make an attempt to show that a given explanation is optimistic.

            Builds an SMT encoding and a validator on first use.

            :returns: the result of ``SMTValidator.validate``.
        """

        # there must exist an encoding
        if not hasattr(self, 'enc'):
            encoder = SMTEncoder(self.model, self.feature_names, self.num_class,
                    self)
            self.enc, _, _, _ = encoder.encode()

        if not hasattr(self, 'v'):
            self.v = SMTValidator(self.enc, self.feature_names, self.num_class,
                    self)

        # try to compute a counterexample
        return self.v.validate(np.array(sample), expl)

    def transform(self, x):
        """
            Map raw samples to the model's input space.

            A 1-D ``x`` is treated as a single row. In categorical mode each
            categorical feature is replaced by its one-hot columns (features
            kept in order); otherwise ``x`` is returned as is. Empty input
            is returned unchanged.
        """
        if len(x) == 0:
            return x
        if len(x.shape) == 1:
            x = np.expand_dims(x, axis=0)
        if not self.use_categorical:
            return x

        assert self.encoder != []
        columns = []
        for i in range(self.nb_features):
            if i in self.categorical_features:
                # NOTE(review): resetting `drop` looks like a compatibility
                # shim for encoders pickled by older scikit-learn -- confirm
                self.encoder[i].drop = None
                columns.append(np.vstack(self.encoder[i].transform(x[:, [i]])))
            else:
                columns.append(x[:, [i]])
        return np.hstack(columns)

    def transform_inverse(self, x):
        """
            Inverse of :meth:`transform`.

            In categorical mode returns a list with one 1-D array of length
            ``nb_features`` per row; otherwise returns ``x`` (as a 2-D array
            when given 1-D). Empty input is returned unchanged.
        """
        if len(x) == 0:
            return x
        if len(x.shape) == 1:
            x = np.expand_dims(x, axis=0)
        if not self.use_categorical:
            return x

        assert self.encoder != []
        inverse_x = []
        for xi in x:
            inverse_xi = np.zeros(self.nb_features)
            for f in range(self.nb_features):
                if f in self.categorical_features:
                    nb_values = len(self.categorical_names[f])
                    v = np.expand_dims(xi[:nb_values], axis=0)
                    # inverse_transform yields a (1, 1) array: take the scalar
                    inverse_xi[f] = self.encoder[f].inverse_transform(v)[0][0]
                    xi = xi[nb_values:]
                else:
                    inverse_xi[f] = xi[0]
                    xi = xi[1:]
            inverse_x.append(inverse_xi)
        return inverse_x

    def transform_inverse_by_index(self, idx):
        """
            Return the ``(feature_name, category_index or None)`` pair of
            transformed column ``idx``, or None (with a warning) if unknown.
        """
        if idx in self.extended_feature_names:
            return self.extended_feature_names[idx]
        print("Warning there is no feature {} in the internal mapping".format(idx))
        return None

    def transform_by_value(self, feat_value_pair):
        """
            Return the first transformed column index mapped to
            ``feat_value_pair``, or None (with a warning) if there is none.
        """
        for key, value in self.extended_feature_names.items():
            if value == feat_value_pair:
                return key
        print("Warning there is no value {} in the internal mapping".format(feat_value_pair))
        return None

    def mapping_features(self):
        """
            Build the mapping from transformed column index to
            ``(feature_name, category_index)`` (``None`` for non-categorical
            features), plus the XGBoost-style column names ``f<i>`` /
            ``f<i>_<j>``.
        """
        self.extended_feature_names = {}
        self.extended_feature_names_as_array_strings = []
        counter = 0
        for i in range(self.nb_features):
            if self.use_categorical and i in self.categorical_features:
                for j, _ in enumerate(self.encoder[i].categories_[0]):
                    self.extended_feature_names[counter] = (self.feature_names[i], j)
                    self.extended_feature_names_as_array_strings.append("f{}_{}".format(i, j))
                    counter += 1
            else:
                self.extended_feature_names[counter] = (self.feature_names[i], None)
                self.extended_feature_names_as_array_strings.append("f{}".format(i))
                counter += 1

    def readable_sample(self, x):
        """
            Return ``x`` as an array with categorical values replaced by
            their category names.
        """
        return np.asarray([
            self.categorical_names[i][int(v)] if i in self.categorical_features else v
            for i, v in enumerate(x)])

    def test_encoding_transformes(self):
        """
            Check that :meth:`transform` followed by
            :meth:`transform_inverse` reproduces the first training sample;
            with ``options.verb > 3`` also dump the feature mapping.
        """
        X = self.X_train[[0], :]

        enc_X = self.transform(X)
        inv_X = self.transform_inverse(enc_X)
        assert (np.asarray(inv_X) == X).all()

        if self.options.verb > 3:
            for i in range(len(self.extended_feature_names)):
                print(i, self.transform_inverse_by_index(i))
            for value in self.extended_feature_names.values():
                print(value, self.transform_by_value(value))

    def transfomed_sample_info(self, i):
        """
            Print the categories of the one-hot encoder of feature ``i``.
        """
        if self.use_categorical and i in self.categorical_features:
            print(self.encoder[i].categories_)
        else:
            print("Feature {} is not one-hot encoded".format(i))

    @staticmethod
    def _final_accuracy(evals_result, eval_name):
        """
            Return ``1 - last recorded error`` of evaluation set
            ``eval_name`` rounded to 2 digits, or None if no error metric is
            recorded. ``merror`` (multi-class) is checked before ``error``.
        """
        metrics = evals_result.get(eval_name, {})
        for metric in ('merror', 'error'):
            values = metrics.get(metric)
            if values:
                return round(1 - values[-1], 2)
        return None

    def build_xgbtree(self, outfile=None):
        """
            Build an ensemble of trees.

            Fits the model, saves it (pickled and as plain text), checks that
            TreeEnsemble reproduces XGBoost's probabilities on up to 10
            training samples and writes the accuracies to ``self.resfile``.

            :param outfile: model file name; defaults to ``self.modfile``.
            :returns: ``(train_accuracy, test_accuracy, model)``;
                ``test_accuracy`` is 0 when there is no test data.
            :raises RuntimeError: if XGBoost recorded no training error.
        """

        start_time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        if outfile is None:
            outfile = self.modfile
        self.datafile = self.form_datefile_name(outfile)

        # fit model on training data (the test set, if any, is only evaluated)
        X_train_enc = self.transform(self.X_train)
        eval_set = [(X_train_enc, self.Y_train)]
        if len(self.X_test) > 0:
            eval_set.append((self.transform(self.X_test), self.Y_test))

        print("start xgb")
        self.model.fit(X_train_enc, self.Y_train, eval_set=eval_set,
                verbose=self.options.verb)
        print("end xgb")

        evals_result = self.model.evals_result()

        # saving model
        self.save_datainfo(outfile)
        print("saving plain model to ", self.mod_plainfile)
        self.model.get_booster().dump_model(self.mod_plainfile)

        # cross-check our own ensemble evaluation against XGBoost's
        ensemble = TreeEnsemble(self.model, self.extended_feature_names_as_array_strings,
                nb_classes=self.num_class)
        X_check = self.transform(self.X_train[:10])
        y_pred_prob = self.model.predict_proba(X_check)
        y_pred_prob_compute = ensemble.predict(X_check, self.num_class)
        assert np.absolute(y_pred_prob_compute - y_pred_prob).sum() < 0.01 * len(y_pred_prob)

        # accuracy
        train_accuracy = self._final_accuracy(evals_result, 'validation_0')
        if train_accuracy is None:
            raise RuntimeError("XGBoost recorded no training error metric")

        test_accuracy = self._final_accuracy(evals_result, 'validation_1')
        if test_accuracy is None:
            print("no results test data")
            test_accuracy = 0

        # saving results as a LaTeX table row
        print("saving results to ", self.resfile)
        with open(self.resfile, 'w') as f:
            f.write("{} & {} & {} &{}  &{} & {} \\\\ \n \\hline \n".format(
                    os.path.basename(self.options.files[0]).replace("_", "-"),
                    train_accuracy,
                    test_accuracy,
                    self.options.n_estimators,
                    self.options.maxdepth,
                    self.options.testsplit))

        # CPU time spent in this method (self + children)
        self.train_time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - start_time

        return train_accuracy, test_accuracy, self.model
