import warnings

import numpy as np
from overrides import overrides
from sklearn.cluster import DBSCAN
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity

from symbols.symbols.distribution_symbol import DistributionSymbol




class KernelDensityEstimator(DistributionSymbol):




    def __init__(self,
                 mask,
                 data,
                 copy=False,
                 masked=False):

        self._mc_samples = 100
        self._mask = np.array(mask)
        self._n_objects = len(mask)
        self._kdes = list()
        if not copy:

            # msk = [int(x) for x in self._mask]
            # if data.shape[1] != len(msk):
            #     masked_data = data[:, msk]
            # else:
            #     masked_data = data

            # if masked already, don't do this:
            if not masked:
                masked_data = data[:, mask]
            else:
                masked_data = data
            self._raw_data = masked_data  # for debug purposes only!
            n_splits = len(data) if 0 < len(data) < 3 else 3

            if n_splits < 3:
                warnings.warn("Very little data to do KDE")

            for col in range(masked_data.shape[1]):


                x = masked_data[:, col]
                x = np.array([sample.ravel() for sample in x])
                params = {'bandwidth': np.arange(0.001, 0.1, 0.001)}
                # params = {'bandwidth': np.arange(0.0001, 0.2, 0.001)}
                grid = GridSearchCV(KernelDensity(kernel='gaussian'), params, cv=n_splits)
                grid.fit(x)
                params = grid.best_params_
                # print(params)
                self._kdes.append(grid.best_estimator_)

        self.name = 'KDE'

    @staticmethod
    def _extract_remaining(a, b):
        """ 
          Return the elements of a that are not in b.
          Return two nd-arrays, the first listing these elements,
          the second listing their indices. 
        """
        new_vars = []
        new_indices = []
        
        for pos in range(0, len(a)):
            val = a[pos]
            if not (val in b):
                new_vars.append(val)
                new_indices.append(pos)
                
        return np.array(new_vars), np.array(new_indices)

    @overrides
    def integrate_out(self,
                      variable_list):
        """
        Marginalize the listed variables out of this distribution.

        The marginal is approximated by sampling: draw ``self._mc_samples``
        points from this distribution, keep only the columns of the variables
        that are not in ``variable_list``, and fit a new estimator to them.

        :param variable_list: variables to remove
        :return: a new KernelDensityEstimator over the remaining variables
        """
        kept_vars, kept_columns = KernelDensityEstimator._extract_remaining(self._mask, variable_list)
        # TODO: check the case where more than one variable remains.
        draws = self.sample(self._mc_samples)
        marginal_draws = draws[:, kept_columns]
        return KernelDensityEstimator(mask=kept_vars, data=marginal_draws, masked=True)

    @overrides
    def distribution_and(self,
                         *distribution_symbols):

        new_vars = self._mask

        for symbol in distribution_symbols:
            new_vars = np.concatenate([new_vars, symbol.mask])

        n_new_variables = len(new_vars)

        if n_new_variables > len(np.unique(new_vars)):
            raise TypeError("In KernelDensityEstimator::distribution_and: attempted to and distributions "
                            "with overlapping masks.")

        kde = KernelDensityEstimator(new_vars, None, copy=True)
        kde._kdes += self._kdes
        name = str(self)
        for symbol in distribution_symbols:
            kde._kdes += symbol._kdes
            name += ' AND ' + str(symbol)

        # new_sample = np.zeros([self._mc_samples, n_new_variables])
        # mc_sample1 = self.sample(self._mc_samples)
        #
        # new_sample[:, 0:len(self._mask)] = mc_sample1
        # idx = len(self._mask)
        #
        # name = str(self)

        # for symbol in distribution_symbols:
        #     mc_sample2 = symbol.sample(self._mc_samples)
        #     new_sample[:, idx: idx + len(symbol.flat_mask)] = mc_sample2
        #     idx += len(symbol.flat_mask)
        #     name += ' AND ' + str(symbol)

        # kde = KernelDensityEstimator(mask=new_vars, data=new_sample)
        kde.name = name
        return kde


    def get_symbols(self):

        # return the kdes and objects to which they refer
        return [(x, y) for (x, y) in zip(self.mask, self._kdes)]

    def __str__(self):
        """Return the distribution's name ('KDE' unless reassigned)."""
        return self.name

    @overrides
    def probability_in_set(self,
                           conditional_symbol,
                           allow_fill_in=True,
                           fill_in=None):

        if not set(self.mask).issuperset(set(conditional_symbol.mask)):
            return 0

        s_prob = 0.0
        keep_indices = []
        for i in range(0, len(self._mask)):
            if self._mask[i] in conditional_symbol.mask:
                keep_indices.append(i)

        # Bail if no overlap.
        if len(keep_indices) == 0:
            return 0

        #TODO:
        self._mc_samples = 1000

        # mc_samples = self.sample(self._mc_samples)
        mc_samples = self.sample(self._mc_samples)
        mc_samples = mc_samples[:, keep_indices]

        # TODO: NEW: average mc samples and feed to classifier!
        # mean of images:
        dat = np.array([np.hstack(x) for x in mc_samples])
        mean_point = np.mean(dat, axis=0)
        mean_image_prob = conditional_symbol.probability(mean_point, use_mask=False)
        probs = [conditional_symbol.probability(mc_samples[pos, :], use_mask=False) for pos in range(0, self._mc_samples)]
        return mean_image_prob, np.max(probs), np.min(probs), np.mean(probs), np.var(probs)

        # TODO: NEW: take max of classifier output!
        # m = 0
        # for pos in range(0, self._mc_samples):
        #     point = mc_samples[pos, :mc_samples]
        #     m = max(m, conditional_symbol.probability(point, use_mask=False))
        # return m

        # TODO: OLD: standard averaging over samples!
        # for pos in range(0, self._mc_samples):
        #     point = mc_samples[pos, :]
        #     s_prob = s_prob + conditional_symbol.probability(point, use_mask=False)



        # add_list = []
        # for m in conditional_symbol.mask:
        #     if m not in self._mask:
        #         add_list.append(m)
        #
        # total_mask = self._mask[keep_indices]
        #
        # if len(add_list) > 0:
        #
        #     if not allow_fill_in:
        #         # not allowed to fill in data (e.g. ames et al)
        #         return 0
        #     if fill_in is not None:
        #         uniform = np.array([list(fill_in[add_list])] * self._mc_samples)
        #     else:
        #         uniform = np.random.uniform(0.0, 1.0, size=[self._mc_samples, len(add_list)])
        #     mc_new = np.zeros([self._mc_samples, len(keep_indices) + len(add_list)])
        #     mc_new[:, 0:len(keep_indices)] = mc_samples
        #     mc_new[:, len(keep_indices):] = uniform
        #     mc_samples = mc_new
        #     total_mask = np.concatenate((total_mask, np.array(add_list)))
        #
        # for pos in range(0, self._mc_samples):
        #     point = mc_samples[pos, :]
        #     t_point = np.zeros([np.max(total_mask) + 1])
        #     t_point[total_mask] = point
        #     s_prob = s_prob + conditional_symbol.probability(t_point)

        # return s_prob / self._mc_samples

    @property
    def mask(self):
        """The array of variable indices this distribution covers (read-only view of ``_mask``)."""
        return self._mask

    @overrides
    def sample(self,
               n_samples=100):
        samples = list()
        for i in range(n_samples):
            x = np.array([kde.sample().squeeze() for kde in self._kdes], dtype=object)
            samples.append(x)
        return np.array(samples)

    @overrides
    def kl_divergence(self,
                      distribution_symbol,
                      n_samples=100):
        x = self.sample(n_samples)
        log_p_x = self.score_samples(x)
        log_q_x = distribution_symbol.score_samples(x)
        return log_p_x.mean() - log_q_x.mean()


    def is_similar(self,
                   distribution_symbol,
                   threshold=500,
                   n_samples=100):

        # if ('43' in self.name and '77' in distribution_symbol.name) or ('77' in self.name and '43' in distribution_symbol.name):
        #     k=0


        if 9 in self.mask:
            threshold = 20
            return False

        kl = self.kl_divergence(distribution_symbol, n_samples)
        return kl < threshold


        x = self.sample(n_samples)
        x = np.array([np.hstack(i) for i in x])
        y = distribution_symbol.sample(n_samples)
        y = np.array([np.hstack(i) for i in y])
        if np.array_equal(self.mask, distribution_symbol.mask):

            data = np.concatenate((x, y))

            eps = 2 # 1
            if data.shape[1] < 20:
                # probably inventory - use smaller one
                data = np.rint(data)
                eps = 0.1
            # print(data.shape)
            db = DBSCAN(eps=eps).fit(data)
            # db = OPTICS(eps=neighbourhood_radius, min_samples=min_samples).fit(dat)
            labels = db.labels_
            return len(set(labels)) == 1
        else:
            raise ValueError("TODO")

    @overrides
    def score_samples(self, X):

        if X.shape[1] != len(self._kdes):
            raise ValueError

        scores = list()
        for i in range(X.shape[1]):

            x = X[:, i]
            x = np.array([np.hstack(i) for i in x])
            scores.append(self._kdes[i].score_samples(x))

        # for kde in self._kdes:

        return np.mean(scores, axis=0)

        return self._kde.score_samples(X[self._mask])

    @staticmethod
    def not_failed():
        kde = KernelDensityEstimator([])
        kde.id = 'notfailed'
        return kde

    @overrides
    def score(self, state, mask=None):
        if mask is None:
            mask = self._mask
        t = np.array(state)[mask].reshape(1, -1)
        return self.score_samples(t)[0]
