import shap
import sklearn
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
import os
import pickle
import numpy as np
import shap_dataset.shap_dataset

class ShapModel:
    """A random forest trained on a named dataset, cached on disk as a pickle.

    On construction the model is loaded from
    ``<cache_dir>/trained_models/<dataset>_t=<n_estimators>_d=<max_depth>.pkl``
    if that file exists. Otherwise the dataset is loaded with :meth:`load_data`,
    split into train/test (``test_size=0.1``, ``random_state=0``), a forest is
    fitted on the training split, and everything is pickled to that path.
    Either way the dataset name and the test-set score are printed.

    ``<cache_dir>`` is ``$EXPERIMENT_CACHE/trees`` when the ``EXPERIMENT_CACHE``
    environment variable is set, otherwise the directory containing this file.

    Attributes:
        rf: the fitted ``RandomForestRegressor`` or ``RandomForestClassifier``.
        X_train, X_test, y_train, y_test: the train/test split.
        n: number of feature columns (``X_train.shape[1]``).
    """

    def __init__(self, dataset_name, n_estimators, max_depth):
        pkl_path = ShapModel._cache_path(dataset_name, n_estimators, max_depth)
        if os.path.isfile(pkl_path):
            # NOTE: pickle.load can execute arbitrary code; only point
            # EXPERIMENT_CACHE at a directory you trust.
            with open(pkl_path, "rb") as f:
                (self.rf, self.X_train, self.X_test, self.y_train, self.y_test,
                 score, self.n) = pickle.load(f)
        else:
            score = self._train(dataset_name, n_estimators, max_depth)
            ShapModel._save(pkl_path, [self.rf, self.X_train, self.X_test,
                                       self.y_train, self.y_test, score, self.n])
        print(dataset_name, score)

    @staticmethod
    def _cache_path(dataset_name, n_estimators, max_depth):
        """Return the pickle path for this (dataset, n_estimators, max_depth) combination."""
        if 'EXPERIMENT_CACHE' in os.environ:
            dir_path = f"{os.environ['EXPERIMENT_CACHE']}/trees"
        else:
            dir_path = os.path.dirname(os.path.realpath(__file__))
        return f"{dir_path}/trained_models/{dataset_name}_t={n_estimators}_d={max_depth}.pkl"

    @staticmethod
    def _save(pkl_path, payload):
        """Pickle ``payload`` to ``pkl_path``, creating parent directories as needed.

        The data is first written to a sibling temp file and then moved into place
        with ``os.replace``. An interrupted write therefore never leaves a truncated
        pickle at ``pkl_path`` that a later run would mistake for a valid cache.
        """
        parent = os.path.dirname(pkl_path)
        if parent:
            # The trained_models/ directory may not exist yet on a fresh cache.
            os.makedirs(parent, exist_ok=True)
        tmp_path = f"{pkl_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(payload, f)
            os.replace(tmp_path, pkl_path)
        finally:
            # Only still present if dump/replace failed; don't leave debris behind.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _train(self, dataset_name, n_estimators, max_depth):
        """Load and split the dataset, fit the forest on ``self``, and return its test score."""
        # Import the submodule explicitly: a bare ``import sklearn`` does not
        # guarantee that ``sklearn.model_selection`` has been loaded.
        from sklearn.model_selection import train_test_split

        X, y, dataset_type = ShapModel.load_data(dataset_name)
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=0.1, random_state=0)
        self.n = self.X_train.shape[1]
        if dataset_type == "regression":
            forest_cls = RandomForestRegressor
        elif dataset_type == "classification":
            forest_cls = RandomForestClassifier
        else:
            raise ValueError(f"unknown dataset type \"{dataset_type}\"")
        self.rf = forest_cls(n_estimators=n_estimators, max_depth=max_depth, n_jobs=10)
        self.rf.fit(self.X_train, self.y_train)
        return self.rf.score(self.X_test, self.y_test)

    @staticmethod
    def load_data(dataset_name):
        """Load a named dataset and preprocess its target.

        Args:
            dataset_name: one of "adult", "boston", "crimes", "corrgroups60",
                "diabetes", "independent60", "iris", "entacmaea", "gpu", "gb1",
                "avgfp", "sgemm".

        Returns:
            ``(X, y, dataset_type)``, where ``dataset_type`` is "regression" or
            "classification". Regression targets are standardized to zero mean
            and unit (population) standard deviation. Classification targets are
            relabelled to integers 0..k-1, numbering classes in order of first
            appearance.

        Raises:
            ValueError: if ``dataset_name`` is not recognised.
        """
        def xy(dataset):
            return dataset.x, dataset.y

        # Loaders are lambdas so that attribute lookups such as
        # ``shap.datasets.boston`` only happen for the dataset actually requested.
        loaders = {
            "adult": (lambda: shap.datasets.adult(), "classification"),
            "boston": (lambda: shap.datasets.boston(), "regression"),
            "crimes": (lambda: shap.datasets.communitiesandcrime(), "regression"),
            "corrgroups60": (lambda: shap.datasets.corrgroups60(), "regression"),
            "diabetes": (lambda: shap.datasets.diabetes(), "regression"),
            "independent60": (lambda: shap.datasets.independentlinear60(), "regression"),
            "iris": (lambda: shap.datasets.iris(), "classification"),
            "entacmaea": (lambda: xy(shap_dataset.shap_dataset.EntacmaeaDataset()), "regression"),
            # "gpu" and "sgemm" both load SGEMMDataset.
            "gpu": (lambda: xy(shap_dataset.shap_dataset.SGEMMDataset()), "regression"),
            "gb1": (lambda: xy(shap_dataset.shap_dataset.GB1Dataset()), "regression"),
            "avgfp": (lambda: xy(shap_dataset.shap_dataset.avGFPDataset()), "regression"),
            "sgemm": (lambda: xy(shap_dataset.shap_dataset.SGEMMDataset()), "regression"),
        }
        try:
            load, dataset_type = loaders[dataset_name]
        except KeyError:
            raise ValueError(f"dataset name \"{dataset_name}\" is not known") from None

        X, y = load()
        if dataset_type == "regression":
            y = ShapModel._standardize(y)
        else:
            y = ShapModel._relabel_by_first_occurrence(y)
        return X, y, dataset_type

    @staticmethod
    def _standardize(values):
        """Return ``values`` shifted to zero mean and scaled to unit population std.

        A constant input has zero std. It is returned centred (all zeros) rather
        than divided by zero, which would produce NaN labels.
        """
        centred = values - np.mean(values)
        std = np.std(values)
        return centred / std if std > 0 else centred

    @staticmethod
    def _relabel_by_first_occurrence(labels):
        """Map class labels to 0..k-1, numbering classes in order of first appearance."""
        _, first_idx, inverse = np.unique(labels, return_index=True, return_inverse=True)
        # first_idx[inverse][i] is the position where labels[i]'s class first
        # occurs; ranking those positions numbers classes by first appearance.
        _, relabeled = np.unique(first_idx[inverse], return_inverse=True)
        return relabeled


if __name__ == "__main__":
    # Train (or load from the on-disk cache) a 1000-tree, depth-30 forest
    # on the "avgfp" dataset; the constructor prints the test score.
    ShapModel(dataset_name="avgfp", n_estimators=1000, max_depth=30)
