import logging
import numpy as np
import pandas as pd
from scipy.spatial import distance
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
# from tensorflow.keras.preprocessing import sequence
from scipy.stats import multivariate_normal
from sklearn.cluster import KMeans

# Column-type tags stored under the "type" key of each metadata entry
# produced by ``Transformer.get_metadata``.
CATEGORICAL, CONTINUOUS, ORDINAL = "categorical", "continuous", "ordinal"

class Transformer:
    """Base class for column-wise data transformers.

    Subclasses implement ``fit``, ``transform`` and ``inverse_transform``;
    the shared ``get_metadata`` helper describes each input column.
    """

    @staticmethod
    def get_metadata(data, categorical_columns=tuple(), ordinal_columns=tuple()):
        """Describe every column of ``data``.

        Args:
            data: Anything ``pd.DataFrame`` accepts; for a 2-D array the
                column labels are 0, 1, ...
            categorical_columns: Column labels to treat as categorical.
            ordinal_columns: Column labels to treat as ordinal (only consulted
                when the label is not categorical).

        Returns:
            One dict per column, in column order.  Categorical and ordinal
            entries hold ``name``, ``type``, ``size`` and ``i2s`` (distinct
            values, most frequent first); all other columns are continuous and
            hold ``name``, ``type``, ``min`` and ``max``.
        """
        frame = pd.DataFrame(data)

        def describe(name, column):
            # Categorical membership is checked before ordinal membership.
            if name in categorical_columns:
                labels = column.value_counts().index.tolist()
                return {
                    "name": name,
                    "type": CATEGORICAL,
                    "size": len(labels),
                    "i2s": labels,
                }
            if name in ordinal_columns:
                counted = dict(column.value_counts()).items()
                # Stable descending sort by frequency; ties keep value_counts order.
                ranked = sorted(counted, key=lambda pair: pair[1], reverse=True)
                labels = [value for value, _count in ranked]
                return {
                    "name": name,
                    "type": ORDINAL,
                    "size": len(labels),
                    "i2s": labels,
                }
            return {
                "name": name,
                "type": CONTINUOUS,
                "min": column.min(),
                "max": column.max(),
            }

        return [describe(name, frame[name]) for name in frame]

    def fit(self, data, categorical_columns=tuple(), ordinal_columns=tuple()):
        """Learn the transformation from ``data`` (implemented by subclasses)."""
        raise NotImplementedError

    def transform(self, data):
        """Encode ``data`` with the fitted transformation (implemented by subclasses)."""
        raise NotImplementedError

    def inverse_transform(self, data):
        """Decode ``data`` back to the raw layout (implemented by subclasses)."""
        raise NotImplementedError

# Model classes selectable through ``BGMTransformer(type=...)``.
# NOTE(review): KMeans takes ``n_clusters`` (not ``n_components``) and has no
# ``means_``/``covariances_``/``predict_proba``, so the 'kmeans' entry cannot
# be used with ``BGMTransformer.fit`` as written.
GMM = dict(
    gmm=GaussianMixture,
    bgm=BayesianGaussianMixture,
    kmeans=KMeans,
)

class BGMTransformer(Transformer):
    """Encode tabular data and split it per mixture component.

    ``fit`` one-hot encodes categorical columns, rescales ordinal columns to
    ``code / size * 2 - 1``, passes continuous columns through unchanged, and
    fits the selected mixture model on the resulting matrix.  ``transform``
    assigns each row to a component and returns one copy of the encoded data
    per selected component, with continuous columns standardized against that
    component's statistics.  ``inverse_transform`` undoes the encoding.
    """

    def __init__(self, n_clusters=10, type='bgm'):
        """Create an unfitted transformer.

        Args:
            n_clusters: Number of mixture components passed as
                ``n_components``; for the Bayesian mixture this is the upper
                bound on the number of modes.
            type: Key into ``GMM`` selecting the model class.
        """
        self.meta = None
        self.n_clusters = n_clusters
        self.model = GMM[type]

    def fit(self, train, categorical_columns=tuple(), ordinal_columns=tuple()):
        """Learn column metadata and fit the mixture model on ``train``.

        Args:
            train: 2-D array indexed as ``train[:, j]``, one column per feature.
            categorical_columns: Column indices to one-hot encode.
            ordinal_columns: Column indices holding ordinal codes.

        Sets ``meta``, ``output_info``, ``cat_idx``/``con_idx``/``ord_idx``
        (column positions in the encoded matrix), ``gmm``, ``means``,
        ``stds``, ``weights``, ``probs`` and ``rvs``.
        """
        self.meta = self.get_metadata(train, categorical_columns, ordinal_columns)

        self.output_info = []
        # NOTE(review): output_dim and components are never updated here.
        self.output_dim = 0
        self.components = []
        cat_idx = []
        self.con_idx = []
        self.ord_idx = []

        st = 0
        for info in self.meta:
            if info['type'] == CATEGORICAL:
                cat_idx.extend(range(st, st + info['size']))
                self.output_info.append((info['size'], 'softmax'))
                st += info['size']
            elif info['type'] == ORDINAL:
                self.ord_idx.append(st)
                self.output_info.append((1, 'tanh'))
                st += 1
            else:  # continuous
                self.con_idx.append(st)
                self.output_info.append((1, 'tanh'))
                st += 1
        # Always an integer ndarray, even with no categorical columns.  int64
        # is used so that positions above 127 do not wrap (an int8 cast would).
        self.cat_idx = np.asarray(cat_idx, dtype=np.int64)

        onehot_data = self.return_onehot(train)

        self.gmm = self.model(
            n_components=self.n_clusters,
            n_init=1,
        )
        self.gmm.fit(onehot_data)
        self.means = self.gmm.means_
        # Per-component, per-feature scale: 4x the square root of each
        # covariance matrix's diagonal.
        self.stds = np.sqrt(np.array([np.diag(cov) for cov in self.gmm.covariances_])) * 4
        self.weights = self.gmm.weights_
        self.probs = self.gmm.predict_proba(onehot_data)

        self.rvs = [
            multivariate_normal(mean=mean, cov=cov, allow_singular=True)
            for mean, cov in zip(self.means, self.gmm.covariances_)
        ]

    def return_onehot(self, train):
        """Encode ``train`` column by column using the fitted ``meta``.

        Categorical columns become one-hot blocks ordered like ``info['i2s']``.
        Ordinal columns are mapped to ``value / size * 2 - 1``.  Continuous
        columns are copied unchanged.

        Raises:
            ValueError: if a categorical value does not appear in ``info['i2s']``.
        """
        n_rows = len(train)
        blocks = []
        for col, info in enumerate(self.meta):
            current = train[:, col]
            if info['type'] == CATEGORICAL:
                # Dict lookup instead of list.index: O(1) per cell.
                lookup = {value: pos for pos, value in enumerate(info['i2s'])}
                try:
                    idx = [lookup[value] for value in current]
                except KeyError as err:
                    raise ValueError(
                        f"unseen category {err.args[0]!r} in column {col!r}"
                    ) from err
                onehot = np.zeros([n_rows, info['size']])
                onehot[np.arange(n_rows), idx] = 1
                blocks.append(onehot)
            elif info['type'] == ORDINAL:
                blocks.append((current / info['size'] * 2 - 1).reshape(-1, 1))
            else:  # continuous
                blocks.append(current.reshape(-1, 1))
        return np.concatenate(blocks, axis=1)

    def transform(self, train, min_samples=100):
        """Split the encoded ``train`` into one standardized copy per component.

        Each row is ranked against the mixture components by
        ``gmm.predict_proba`` on ``train`` itself.  A component that is the top
        choice for fewer than ``min_samples`` rows is deprecated.  Its rows
        fall back to their highest-ranked non-deprecated component.

        Args:
            train: Raw data in the same layout as passed to ``fit``.
            min_samples: Minimum number of rows a component must win to be kept.

        Returns:
            A list with one array per selected component, in ascending
            component id (matching ``self.final_n_clusters``).  Each array is
            the encoded ``train`` with its continuous columns replaced by
            ``clip((x - mean) / std, -0.99, 0.99)`` for that component.

        Sets ``full_log_probs`` (log-densities under every non-deprecated
        component), ``opt_sel`` (selected component per row),
        ``final_n_clusters`` and ``log_probs`` (log-densities under each
        selected component).

        Raises:
            ValueError: if every component would be deprecated.
        """
        onehot_train = self.return_onehot(train)

        # Components per row, most probable first, computed for *this* data.
        ranking = np.argsort(self.gmm.predict_proba(onehot_train), axis=1)[:, ::-1]
        n_components = ranking.shape[1]

        # Deprecate by actual component id (np.unique values), not by position.
        winners, counts = np.unique(ranking[:, 0], return_counts=True)
        deprecated = [int(comp) for comp, n in zip(winners, counts) if n < min_samples]
        logging.info("%s have been deprecated due to the dataset size limit", deprecated)
        if len(deprecated) == n_components:
            raise ValueError(
                f"all {n_components} components won fewer than {min_samples} rows; "
                "lower min_samples or fit fewer clusters"
            )

        deprecated_set = set(deprecated)
        # logpdf avoids the underflow of np.log(pdf) far from a component's mean.
        self.full_log_probs = [
            rv.logpdf(onehot_train)
            for comp, rv in enumerate(self.rvs)
            if comp not in deprecated_set
        ]

        # Per row, take the highest-ranked component that is not deprecated.
        usable = ~np.isin(ranking, deprecated)
        self.opt_sel = ranking[np.arange(len(ranking)), np.argmax(usable, axis=1)]
        self.final_n_clusters = np.unique(self.opt_sel)

        trains = []
        self.log_probs = []
        con = self.con_idx
        for comp in self.final_n_clusters:
            self.log_probs.append(self.rvs[comp].logpdf(onehot_train))
            mean = self.means[comp][con]
            std = self.stds[comp][con]
            standardized = onehot_train.copy()
            # Only continuous columns are standardized; others are left as encoded.
            standardized[:, con] = np.clip((onehot_train[:, con] - mean) / std, -.99, .99)
            trains.append(standardized)

        return trains

    def inverse_transform(self, data, id_):
        """Map encoded rows for component ``id_`` back to the raw column layout.

        Continuous columns are clipped to ``[-0.99, 0.99]`` and de-standardized
        with component ``id_``'s mean and std.  Ordinal columns become
        ``round((x + 1) / 2 * size)`` clipped to ``[0, size - 1]``.
        Categorical blocks are decoded by argmax into ``info['i2s']``.

        Args:
            data: Encoded 2-D array in the layout produced by ``return_onehot``.
            id_: Component whose statistics were used to standardize ``data``.

        Returns:
            Float array of shape ``(len(data), len(self.meta))``.  NOTE(review):
            categorical labels are written into this float array, so
            non-numeric ``i2s`` labels will fail here.
        """
        data = data.copy()
        con = self.con_idx
        mean = self.means[id_][con]
        std = self.stds[id_][con]
        data[:, con] = np.clip(data[:, con], -.99, .99) * std + mean

        data_t = np.zeros([len(data), len(self.meta)])
        st = 0
        for col, info in enumerate(self.meta):
            if info['type'] == CONTINUOUS:
                data_t[:, col] = data[:, st]
                st += 1
            elif info['type'] == ORDINAL:
                current = (data[:, st] + 1) / 2 * info['size']
                data_t[:, col] = np.round(current).clip(0, info['size'] - 1)
                st += 1
            else:  # categorical
                idx = np.argmax(data[:, st:st + info['size']], axis=1)
                data_t[:, col] = [info['i2s'][i] for i in idx]
                st += info['size']

        return data_t
