import shap
import os
import pandas as pd
import numpy as np
import sklearn
from sklearn.preprocessing import OneHotEncoder
from sklearn.ensemble import RandomForestRegressor
import random
import pathlib


class ShapDataset:
    """Tabular dataset preprocessed for SHAP feature-attribution experiments.

    Construction:
      1. standardises ``y`` to zero mean and unit variance;
      2. discretises every column in ``continuous_features`` into quantile bins;
      3. converts the features to a float64 NumPy array;
      4. optionally one-hot encodes every column.

    Attributes:
        x: feature matrix (``np.ndarray``).
        y: standardised target.
        column_names: column names of the input DataFrame.
        categorical_features, continuous_features, bin_count: as passed in.
        feature_map: for each column of ``x``, the index of the original input
            column it was derived from.
        enc, feature_split, feature_names: set only when ``one_hot`` is True.
    """

    def __init__(self, x, y, categorical_features, continuous_features, bin_count, one_hot=False):
        """Preprocess ``x`` and ``y``.

        Args:
            x: pandas DataFrame of raw features. It is copied, so the caller's
                frame is left unmodified.
            y: array-like target values.
            categorical_features: names of the categorical columns (stored for
                reference; not otherwise used here).
            continuous_features: names of the columns to bin by quantiles.
            bin_count: number of quantile bins for each entry of
                ``continuous_features``, in the same order.
            one_hot: if True, one-hot encode all columns after binning.
        """
        y_std = np.std(y)
        # A constant target has zero spread; dividing by it would produce NaNs,
        # so in that case only centre it.
        self.y = (y - np.mean(y)) / (y_std if y_std > 0 else 1.0)
        # Work on a copy: binning assigns columns in place and must not leak
        # into the caller's DataFrame.
        self.x = x.copy()
        self.column_names = x.columns.values
        self.categorical_features = categorical_features
        self.continuous_features = continuous_features
        self.bin_count = bin_count
        # Before any one-hot encoding, column i maps to original feature i.
        self.feature_map = list(range(len(self.column_names)))

        # convert continuous features into categorical by quantiles
        if len(continuous_features) > 0:
            self._bin_continuous_features()
        self.x = self.x.to_numpy(dtype=np.float64)

        # convert categorical features to one-hot representations
        if one_hot:
            self._one_hot_transform()

    def _bin_continuous_features(self):
        """Replace each continuous column of ``self.x`` with its quantile-bin index.

        ``pd.qcut`` supplies the bin edges. Duplicate edges are dropped, so a
        column may end up with fewer bins than requested. ``np.digitize``
        against the interior edges then assigns integer bin ids starting at 0.
        """
        for i, c_name in enumerate(self.continuous_features):
            _, bins = pd.qcut(self.x[c_name], self.bin_count[i], retbins=True, labels=False, duplicates='drop')
            self.x[c_name] = np.digitize(self.x[c_name], bins[1:-1])

    def _one_hot_transform(self):
        """One-hot encode every column of ``self.x`` and rebuild the feature map.

        Binary columns are encoded as a single column (``drop="if_binary"``).
        Afterwards:
          - ``feature_map[j]`` is the original column index of encoded column j;
          - ``feature_split`` holds cumulative boundaries, so original feature
            i occupies encoded columns ``feature_split[i]:feature_split[i + 1]``;
          - ``feature_names`` are the encoder's output names.
        """
        self.enc = OneHotEncoder(sparse_output=False, drop="if_binary")
        # Fit once; the original code fitted the encoder twice.
        encoded = self.enc.fit_transform(self.x)

        # drop_idx_ is None when nothing is dropped at all. Otherwise it holds
        # one entry per feature, and that entry is None when nothing is dropped
        # for that feature.
        drop_idx = getattr(self.enc, "drop_idx_", None)

        self.feature_map = []
        self.feature_split = [0]
        for i, categories in enumerate(self.enc.categories_):
            bit_count = len(categories)
            if drop_idx is not None and drop_idx[i] is not None:
                bit_count -= 1
            self.feature_map.extend([i] * bit_count)
            self.feature_split.append(len(self.feature_map))
        self.feature_names = self.enc.get_feature_names_out(self.column_names)

        self.x = encoded

    def select_important_features(self, n_features, random_seed=0, n_jobs=20):
        '''
        Narrow ``self.x`` down to the ``n_features`` most important columns.

        A random forest is trained on a random 1/5 of the rows, and the rest
        are used for the printed validation score. Columns are ranked by the
        forest's impurity-based ``feature_importances_``. ``self.x`` and
        ``self.feature_map`` are then restricted to the top columns.

        Args:
            n_features: number of columns to keep.
            random_seed: seed for the global ``random`` module (row split) and
                for the forest.
            n_jobs: parallel jobs for the random forest.
        '''
        n_rows = self.x.shape[0]
        random.seed(random_seed)
        train_ind = random.sample(range(n_rows), n_rows // 5)
        # Set membership keeps the complement computation O(n), not O(n^2).
        train_set = set(train_ind)
        val_ind = [i for i in range(n_rows) if i not in train_set]
        self.rf = RandomForestRegressor(n_jobs=n_jobs, n_estimators=100, random_state=random_seed)
        self.rf.fit(self.x[train_ind], self.y[train_ind])
        # Indices ordered from most to least important.
        sorted_idx = list(self.rf.feature_importances_.argsort())
        sorted_idx.reverse()
        selected_features = sorted_idx[:n_features]
        print(f"Random forest score (trained on {len(train_ind)} points): {self.rf.score(self.x[val_ind], self.y[val_ind])}")
        print("Selected feature indices:", selected_features)
        print("Feature importances:", self.rf.feature_importances_[selected_features])
        self.x = self.x[:, selected_features]
        self.feature_map = list(np.array(self.feature_map)[selected_features])


class CrimesDataset(ShapDataset):
    """Communities-and-crime dataset loaded via ``shap.datasets.communitiesandcrime``.

    ``MedNumB`` is treated as categorical. Every other column is binned into 4
    quantile bins, and then all columns are one-hot encoded.
    """

    def __init__(self):
        features, target = shap.datasets.communitiesandcrime()
        categorical = ["MedNumB"]
        continuous = [col for col in features.columns if col not in categorical]
        super().__init__(
            features,
            target,
            categorical,
            continuous,
            [4 for _ in continuous],
            one_hot=True,
        )


class HarvardCleanEnergyDataset(ShapDataset):
    """Harvard clean-energy molecule dataset read from ``moldata{no_features}.csv``.

    The ``MorganFingerprint`` column holds bracketed, comma-separated integers.
    Each position becomes its own feature column, named by its integer index.
    The ``e_gap_alpha`` column is the target. No binning or one-hot encoding
    is applied.
    """

    def __init__(self, no_features=20):
        base_dir = pathlib.Path(__file__).parent.resolve()
        df = pd.read_csv(f"{base_dir}/raw_datasets/harvard_clean_energy/moldata{no_features}.csv")

        def parse_fingerprint(text):
            # "[1,0,1]" -> array([1, 0, 1])
            cleaned = text.replace("[", "").replace("]", "")
            return np.array([int(bit) for bit in cleaned.split(",")])

        bits = np.stack(df["MorganFingerprint"].map(parse_fingerprint).values)
        df = df.join(pd.DataFrame(bits))

        target = df["e_gap_alpha"]
        features = df.drop(["id", "SMILES_str", "MorganFingerprint", "e_gap_alpha"], axis=1)
        super().__init__(
            features,
            target.to_numpy().squeeze(),
            list(features.columns),
            [],
            [],
            one_hot=False,
        )


class EntacmaeaDataset(ShapDataset):
    """Entacmaea quadricolor fluorescence dataset read from ``quadricolor_fluorscent.csv``.

    Each ``genotype`` string has its first and last characters stripped, and
    every remaining character is parsed as an integer feature
    (``feature0``, ``feature1``, ...). The ``brightness`` column is the target.
    No binning or one-hot encoding is applied.
    """

    def __init__(self, use_cache=True):
        """Load the CSV and build the feature matrix.

        Args:
            use_cache: unused; kept for interface compatibility.

        Raises:
            FileNotFoundError: if the CSV file does not exist.
        """
        # Load data
        this_directory = pathlib.Path(__file__).parent.resolve()
        data_file = f"{this_directory}/raw_datasets/Entacmaea/quadricolor_fluorscent.csv"
        if not os.path.exists(data_file):
            # FileNotFoundError subclasses Exception, so existing handlers still catch it.
            raise FileNotFoundError(f"Could not find Entacmaea data in '{data_file}'")
        data_df = pd.read_csv(data_file)

        # Strip the first/last character (presumably brackets -- verify against the data)
        # and read each remaining character as one integer feature.
        mutations = [[int(bit) for bit in m[1:-1]] for m in data_df["genotype"]]
        columns = [f"feature{i}" for i in range(len(mutations[0]))]
        x = pd.DataFrame(np.array(mutations).astype(float), columns=columns)
        y = data_df["brightness"].to_numpy(dtype=float)

        categorical_features = list(x.columns)
        continuous_features = []
        # no continuous features to bin
        bin_count = []

        super().__init__(x, y, categorical_features, continuous_features, bin_count, one_hot=False)





class avGFPDataset(ShapDataset):
    """avGFP dataset read from the tab-separated ``avGFP.csv``.

    Features are binary indicators over mutated site indices. The target is
    ``medianBrightness``, averaged over records that mutate the same set of
    sites. No binning or one-hot encoding is applied.
    """

    def __init__(self, use_cache=True):
        """Load the TSV and build the site-mutation matrix.

        Args:
            use_cache: unused; kept for interface compatibility.

        Raises:
            FileNotFoundError: if the data file does not exist.
        """
        # Load data
        this_directory = pathlib.Path(__file__).parent.resolve()
        data_file = f"{this_directory}/raw_datasets/avGFP/avGFP.csv"
        if not os.path.exists(data_file):
            # FileNotFoundError subclasses Exception, so existing handlers still catch it.
            raise FileNotFoundError(f"Could not find avGFP data in '{data_file}'")
        self.data_df = pd.read_csv(data_file, delimiter="\t", keep_default_na=False)

        x, y = self.compute_mutation_dataset()
        columns = [f"feature{i}" for i in range(x.shape[1])]
        x = pd.DataFrame(np.array(x).astype(float), columns=columns)
        y = np.array(list(y)).astype(float)

        categorical_features = list(x.columns)
        continuous_features = []
        # no continuous features to bin
        bin_count = []

        super().__init__(x, y, categorical_features, continuous_features, bin_count, one_hot=False)

    def compute_mutation_dataset(self):
        """Build a binary site-mutation matrix from ``self.data_df``.

        Row 0 of ``data_df`` is treated as the reference and gets an empty site
        set. Every other ``aaMutations`` entry is a ``:``-separated list of
        codes, and characters ``[2:-1]`` of each code are parsed as the integer
        site index.

        Side effect: adds a ``mutation_sites`` column (tuple of sorted site
        indices) to ``self.data_df``.

        Returns:
            (X, y). X is a float array with one row per distinct site set (the
            all-zero reference row first) and one column per site index from
            0 to the largest observed site. y is the matching mean
            ``medianBrightness`` Series.
        """
        # Ignore the first row, which is the reference.
        mutations = self.data_df["aaMutations"].iloc[1:]

        # Convert mutations into tuples of site indices
        mutation_sites = [()] + [
            tuple(sorted(int(aam[2:-1]) for aam in aams.split(":")))
            for aams in mutations
        ]
        self.data_df["mutation_sites"] = mutation_sites

        # Aggregate records with the same mutation sites (but different mutations)
        # into their average brightness. Use the "mean" string: passing np.mean
        # to agg is deprecated in recent pandas.
        aggregated_df = (
            self.data_df.groupby('mutation_sites')
            .agg({'medianBrightness': 'mean'})
            .reset_index()
        )
        aggregated_df.columns = aggregated_df.columns.get_level_values(0)

        # groupby sorts its keys, so the empty tuple (the reference) is row 0.
        # Drop it here and re-add it as an all-zero row below.
        mutation_sites = [list(sites) for sites in aggregated_df["mutation_sites"]][1:]
        site_count = max(max(sites) for sites in mutation_sites if len(sites) > 0) + 1

        X = np.zeros((len(mutation_sites), site_count))
        row_idx = np.arange(X.shape[0]).repeat([len(sites) for sites in mutation_sites])
        X[row_idx, np.concatenate(mutation_sites)] = 1
        X = np.vstack([np.zeros([1, site_count]), X])  # Add the reference to X

        return X, aggregated_df["medianBrightness"]


class SGEMMDataset(ShapDataset):
    """SGEMM dataset read from ``raw_datasets/sgemm/sgemm.csv``.

    The four ``RunN (ms)`` columns are averaged into a ``time`` target. Every
    remaining column is treated as categorical and one-hot encoded.
    """

    def __init__(self):
        """Load the CSV and build features/target.

        Raises:
            FileNotFoundError: if the CSV file does not exist.
        """
        # Load data
        this_directory = pathlib.Path(__file__).parent.resolve()
        data_file = f"{this_directory}/raw_datasets/sgemm/sgemm.csv"
        if not os.path.exists(data_file):
            # FileNotFoundError subclasses Exception, so existing handlers still catch it.
            raise FileNotFoundError(f"Could not find sgemm data in '{data_file}'")
        self.data_df = pd.read_csv(data_file, keep_default_na=False)

        # Aggregate runtimes: row-wise mean of the four runs.
        runtime_columns = ["Run1 (ms)","Run2 (ms)","Run3 (ms)","Run4 (ms)"]
        self.data_df["time"] = self.data_df[runtime_columns].mean(axis=1)
        self.data_df = self.data_df.drop(columns=runtime_columns)

        x = self.data_df.drop(columns=["time"])
        y = self.data_df["time"]

        # Every non-target column is categorical, i.e. exactly x's columns.
        categorical_features = x.columns
        continuous_features = []
        bin_count = []
        super().__init__(x, y.to_numpy(), categorical_features, continuous_features, bin_count, one_hot=True)


class GB1Dataset(ShapDataset):
    """GB1 variant fitness dataset read from ``raw_datasets/gb1/gb1-1.csv``.

    Each ``Variants`` string is split into characters and one-hot encoded
    position by position. The ``Fitness`` column is the target.
    """

    def __init__(self):
        """Load the CSV and build the feature matrix.

        Raises:
            FileNotFoundError: if the CSV file does not exist.
            ValueError: if the variant strings differ in length.
        """
        # Load data
        this_directory = pathlib.Path(__file__).parent.resolve()
        data_file = f"{this_directory}/raw_datasets/gb1/gb1-1.csv"
        if not os.path.exists(data_file):
            # FileNotFoundError subclasses Exception, so existing handlers still catch it.
            raise FileNotFoundError(f"Could not find GB1 data in '{data_file}'")
        self.data_df = pd.read_csv(data_file, keep_default_na=False)

        self.variants = self.data_df["Variants"]
        x_frame = self.one_hot_encode_variants()
        columns = [f"feature{i}" for i in range(x_frame.shape[1])]
        x = pd.DataFrame(x_frame.astype(float), columns=columns)
        y = self.data_df["Fitness"]

        categorical_features = list(x.columns)
        continuous_features = []
        # no continuous features to bin
        bin_count = []

        # NOTE: x is already one-hot. The base class re-encodes it with
        # drop="if_binary", which maps each 0/1 column back to one column.
        super().__init__(x, y.to_numpy(), categorical_features, continuous_features, bin_count, one_hot=True)

    def one_hot_encode_variants(self):
        """One-hot encode ``self.variants`` position by position.

        Returns:
            Dense array with one row per variant and one column per observed
            (position, character) pair.

        Raises:
            ValueError: if the variants do not all have the same length.
        """
        lengths = {len(v) for v in self.variants}
        # Position-wise encoding needs equal-length strings. Use an explicit
        # check instead of assert so it still runs under `python -O`.
        if len(lengths) > 1:
            raise ValueError(f"All variants must have the same length, got lengths {sorted(lengths)}")

        # Split each variant into its characters to feed the encoder.
        return OneHotEncoder(sparse_output=False).fit_transform([list(v) for v in self.variants])

class AdultIncome(ShapDataset):
    """Adult income dataset loaded via ``shap.datasets.adult``.

    The listed columns are treated as categorical. Every other column is
    binned into 4 quantile bins, and then all columns are one-hot encoded.
    """

    def __init__(self):
        features, target = shap.datasets.adult()
        categorical = [
            "Workclass",
            "Marital Status",
            "Relationship",
            "Race",
            "Sex",
            "Country",
            "Occupation",
        ]
        continuous = features.drop(columns=categorical).columns.tolist()
        super().__init__(
            features,
            target,
            categorical,
            continuous,
            [4 for _ in continuous],
            one_hot=True,
        )
