import copy
import numpy as np
from tqdm import tqdm
from typing import Tuple
from collections import OrderedDict
from typing import NamedTuple, Dict, List
from dataset import FeatIndex
from sklearn.mixture import BayesianGaussianMixture

from dataset import Adult


class TransInfo(NamedTuple):
    """ Fitted transformation of one numerical feature, as built by DataTransformer._fit_continuous """
    # mixture model fitted on the values of the feature
    transform: BayesianGaussianMixture
    # boolean mask over the mixture components; True where the fitted weight exceeds the weight threshold
    valid: np.ndarray
    # width of the encoded feature: one normalized scalar plus one one-hot slot per valid component
    output_dim: int


class DataTransformer:
    """
    Model numerical columns with a BayesianGMM

    Every numerical feature gets its own mixture model. A value is encoded as a
    scalar normalized against one sampled mixture component, followed by a
    one-hot vector naming that component. All other columns are carried through
    `transform` as they are and snapped to hard one-hot vectors per categorical
    feature by `reverse_transform`.
    """

    # normalized value = (x - mean) / (STD_MULTIPLIER * std) of the selected component
    STD_MULTIPLIER = 4

    def __init__(self, max_clusters=10, weight_threshold=0.005):
        """
        max_clusters: number of mixture components fitted for each numerical feature
        weight_threshold: components whose fitted weight is not above this value are ignored
        """
        self._max_clusters = max_clusters
        self._weight_threshold = weight_threshold

    def _fit_continuous(self, X: np.ndarray) -> TransInfo:
        """
        Fit numerical data with Gaussian Mixture Model
        X is 1-D numpy array represents one numerical feature
        Returned TransInfo holds the fitted model, the mask of valid components
        and the encoded width (number of valid components + 1)
        """

        gm = BayesianGaussianMixture(
            n_components=self._max_clusters,
            weight_concentration_prior_type='dirichlet_process',
            max_iter=5000,
            weight_concentration_prior=0.001,
            n_init=1,
        )
        gm.fit(X.reshape(-1, 1))

        valid_component_indicator = gm.weights_ > self._weight_threshold
        # plain int so that output_dim matches its declared type
        num_components = int(valid_component_indicator.sum())

        return TransInfo(gm, valid_component_indicator, num_components + 1)

    def _transform_continuous(self, X: np.ndarray, trans_info: TransInfo):
        """
        Transform numerical data with Gaussian Mixture Model
        X is 1-D numpy array represents one numerical feature
        Returned value has two columns: the normalized value (clipped to [-0.99, 0.99])
        and the index of the sampled component among the valid components
        """

        data = X.reshape(-1, 1)
        gm = trans_info.transform
        valid = trans_info.valid

        # size taken from the fitted model itself, not from this instance's configuration
        means = gm.means_.reshape(1, -1)
        stds = np.sqrt(gm.covariances_).reshape(1, -1)
        normalized_values = ((data - means) / (self.STD_MULTIPLIER * stds))[:, valid]
        component_probs = gm.predict_proba(data)[:, valid]

        # one component is sampled per row according to its (smoothed) posterior probability
        candidates = np.arange(valid.sum())
        selected_component = np.zeros(len(data), dtype=np.int64)
        for i, probs in enumerate(component_probs):
            probs = probs + 1e-6  # keep every valid component selectable
            selected_component[i] = np.random.choice(candidates, p=probs / probs.sum())

        normalized = normalized_values[np.arange(len(data)), selected_component]
        normalized = np.clip(normalized, -.99, .99)

        return np.stack([normalized, selected_component], axis=1)

    def _reverse_continuous(self, X: np.ndarray, trans_info: TransInfo):
        """
        Invert the encoding of one numerical feature
        X has the normalized value in column 0 and the component scores in the remaining columns
        Returned value is a column vector (n, 1) in the original scale of the feature
        """

        selected_component = np.argmax(X[:, 1:], axis=1).astype(np.int64)
        normalized = np.clip(X[:, 0], -1., 1.)

        means = trans_info.transform.means_.reshape([-1])
        stds = np.sqrt(trans_info.transform.covariances_).reshape([-1])
        mean_t = means[trans_info.valid][selected_component]
        std_t = stds[trans_info.valid][selected_component]
        reversed_data = normalized * self.STD_MULTIPLIER * std_t + mean_t

        return reversed_data.reshape(-1, 1)

    def fit(self, X: np.ndarray, feat_idx: FeatIndex):
        """ Fit one mixture model per numerical feature listed in feat_idx.num_feat """

        self._ori_feat_idx = feat_idx
        self._trans_dict = OrderedDict()
        for feat_name in feat_idx.num_feat:
            idx = feat_idx.feat2idx[feat_name]
            self._trans_dict[feat_name] = self._fit_continuous(X[:, idx])

        return

    def transform(self, X: np.ndarray):
        """
        Data after transformation: categorical first then encoded numerical features

        Also rebuilds the feature index for the new layout (see `feat_idx`).
        Every categorical index is shifted down by the number of numerical
        columns, which is only correct when all numerical columns precede the
        categorical ones in X.
        """

        ori_feat_idx = self._ori_feat_idx

        output_list = []
        num_feat_dim = OrderedDict()
        for feat_name, trans_info in self._trans_dict.items():
            idx = ori_feat_idx.feat2idx[feat_name]
            transformed = self._transform_continuous(X[:, idx], trans_info)

            # column 0: normalized value, columns 1..: one-hot of the sampled component
            output = np.zeros((len(transformed), trans_info.output_dim))
            output[:, 0] = transformed[:, 0]
            index = transformed[:, 1].astype(int)
            output[np.arange(index.size), index + 1] = 1.

            output_list.append(output)
            num_feat_dim[feat_name] = trans_info.output_dim

        # np.delete returns a new array, the caller's X is left untouched
        X = np.delete(X, ori_feat_idx.num_idx, 1)
        output_list.insert(0, X)
        X = np.concatenate(output_list, axis=1).astype(np.float64)

        # get updated feature index
        feat2idx = copy.deepcopy(ori_feat_idx.feat2idx)
        for feat in ori_feat_idx.cat_feat:
            feat2idx[feat] = [idx - len(ori_feat_idx.num_idx) for idx in feat2idx[feat]]
        curr_idx = len(ori_feat_idx.cat_idx)
        for feat in ori_feat_idx.num_feat:
            feat2idx[feat] = list(range(curr_idx, curr_idx + num_feat_dim[feat]))
            curr_idx += num_feat_dim[feat]

        feat2idx = OrderedDict(sorted(feat2idx.items(), key=lambda x: x[1][0]))  # ordered by index

        # a feature may belong to several groups, hence independent checks
        cat_idx, num_idx, sen_idx = [], [], []
        for feat, indices in feat2idx.items():
            if feat in ori_feat_idx.cat_feat:
                cat_idx.extend(indices)
            if feat in ori_feat_idx.num_feat:
                num_idx.extend(indices)
            if feat in ori_feat_idx.sen_feat:
                sen_idx.extend(indices)

        self._update_feat_idx = FeatIndex(
            ori_feat_idx.cat_feat,
            ori_feat_idx.num_feat,
            ori_feat_idx.sen_feat,
            feat2idx, cat_idx, num_idx, sen_idx
        )

        return X

    def reverse_transform(self, X: np.ndarray):
        """
        Map transformed data back to the original representation

        Returned array holds one reversed column per numerical feature (in fit
        order) followed by the remaining columns, where every categorical
        feature is replaced by the one-hot vector of its largest entry.
        The input array is not modified.
        """

        feat_idx = self._update_feat_idx
        # work on a copy: the categorical columns are overwritten below
        X = np.array(X, copy=True)

        reversed_list = []  # reverse for numerical feature
        for feat_name, trans_info in self._trans_dict.items():
            idx = feat_idx.feat2idx[feat_name]
            reversed_list.append(self._reverse_continuous(X[:, idx], trans_info))

        # categorical feature
        # NOTE(review): a categorical feature stored in a single column is always set to 1
        # here (argmax over one column is 0) - confirm every such feature has >= 2 columns
        for feat_name in feat_idx.cat_feat:
            idx = feat_idx.feat2idx[feat_name]
            scores = X[:, idx]
            one_hot = np.zeros_like(scores)
            one_hot[np.arange(len(scores)), np.argmax(scores, axis=1)] = 1
            X[:, idx] = one_hot

        X = np.delete(X, feat_idx.num_idx, 1)
        reversed_list.append(X)

        return np.concatenate(reversed_list, axis=1)

    @property
    def feat_idx(self) -> FeatIndex:
        """ Updated FeatIndex for transformed data (a copy; available after `transform` has run) """
        return copy.deepcopy(self._update_feat_idx)

    @property
    def trans_dict(self) -> Dict[str, TransInfo]:
        """ Dictionary mapping numerical feature name to TransInfo (a copy; available after `fit` has run) """
        return copy.deepcopy(self._trans_dict)


if __name__ == "__main__":
    # Round-trip check on the Adult training split: encode, decode, then
    # report every row that does not come back exactly as it went in.
    adult = Adult()
    features, labels = adult.train_data(scale="none")
    original_index = adult.feat_idx

    transformer = DataTransformer()

    transformer.fit(features, original_index)
    encoded = transformer.transform(features)
    transformed_index = transformer.feat_idx
    decoded = transformer.reverse_transform(encoded)

    separator = "-----------------"
    for row_in, row_out in zip(features, decoded):
        if np.array_equal(row_in, row_out):
            continue
        # original row, decoded row, element-wise comparison, then a divider
        print(row_in, row_out, np.equal(row_in, row_out), separator, sep="\n")

    assert np.array_equal(features, decoded)
