import warnings
import numpy as np
from sklearn.cluster import DBSCAN

from symbols.pddl.predicate import Predicate
from symbols.utils import samples2np


class Link:
    """Pair clustered start states with the (already partitioned) end states.

    The start set may be stochastic, i.e. contain several distinct modes, so it
    is split into clusters with DBSCAN. Iterating a ``Link`` yields one
    ``(start_samples, end_samples)`` pair per start cluster.
    """

    def __init__(self, start, end, epsilon=0.5, min_samples=1):
        """
        :param start: start-state samples; clustered with DBSCAN
        :param end: end-state samples (already partitioned), or None
        :param epsilon: DBSCAN ``eps`` neighbourhood radius
        :param min_samples: DBSCAN ``min_samples``; with the default of 1 every
            point is a core point, so no sample is discarded as noise
        """
        # end is already partitioned, but start may be stochastic, so cluster it too
        self.end = end
        self.epsilon = epsilon
        self.min_samples = min_samples
        self.starts = self._partition(start)

    def __iter__(self):
        """Yield ``(start, end)`` pairs converted with ``samples2np``; ``end`` is None when absent."""
        for start in self.starts:
            yield samples2np(start), (None if self.end is None else samples2np(self.end))

    def _partition(self, X):
        """Split the samples ``X`` into DBSCAN clusters.

        :param X: samples indexable by an integer index array (``X[indices]``)
        :return: list of ``X[indices]`` sub-collections, one per cluster, in
            ascending label order. Noise points (label -1) are dropped; if every
            point is noise (and there are more than ``min_samples`` points), all
            points are returned as a single cluster with a warning.
        """
        states = samples2np(X)
        if len(states) == 0:
            # DBSCAN rejects empty input; there is nothing to cluster
            return []

        labels = DBSCAN(eps=self.epsilon, min_samples=self.min_samples).fit(states).labels_
        if np.all(labels == -1) and len(labels) > self.min_samples:
            warnings.warn("All datapoints classified as noise!")
            labels = np.zeros(len(labels), dtype=int)

        clusters = []
        for label in np.unique(labels):
            if label == -1:
                continue  # noise
            # Positions of the samples belonging to the current cluster
            indices = np.flatnonzero(labels == label)
            clusters.append(X[indices])
        return clusters


class ProblemSymbols:
    def __init__(self):
        self._symbols = list()

    def __len__(self):
        return len(self._symbols)

    def add(self, data):
        idx = self._check_similar(data)
        if idx > -1:
            return idx
        self._symbols.append(data)
        return len(self._symbols) - 1

    def _check_similar(self, data):
        mean = np.mean(data, axis=0)
        for i, d in enumerate(self._symbols):
            m = np.mean(d, axis=0)
            if np.linalg.norm(mean - m, np.inf) < 0.55:
                return i
        return -1

    def means(self):
        return np.array([np.mean(data, axis=0) for data in self._symbols])


class OperatorData:
    """Bundle a partitioned option with its PDDL operator schemata and grounded links.

    Holds the partitioned option, its subpartitions, the PDDL operators built
    for it, the raw operator (whose precondition mask contributes to
    ``full_mask``), and one ``Link`` per subpartition.
    """

    def __init__(self, partitioned_option, pddl_operators, raw_operator, full_mask=None):
        """
        :param partitioned_option: partitioned option; this class uses its
            ``subpartition()``, ``states``, ``observations``,
            ``extract_prob_space``, ``_mask``, ``option`` and ``partition``
        :param pddl_operators: PDDL operator schemata for this partition
        :param raw_operator: raw operator; its ``precondition.mask`` is merged
            into ``full_mask``
        :param full_mask: optional explicit mask that overrides the computed one
        """
        self.partitioned_option = partitioned_option
        self.subpartitions = partitioned_option.subpartition()
        self.schemata = pddl_operators
        self.raw_operator = raw_operator

        if len(self.subpartitions) == 0:
            # No subpartitions: only the precondition needs grounding, so build a
            # single link from the problem-space observations with no end states.
            states = np.array([partitioned_option.extract_prob_space(partitioned_option.observations[i])
                               for i in range(partitioned_option.states.shape[0])])
            self.links = [Link(states, None)]
        else:
            self.links = [Link(subpartition.states, subpartition.next_states)
                          for subpartition in self.subpartitions]
        self._full_mask = full_mask

    @property
    def full_mask(self):
        """
        The standard mask plus any extra preconditions, as a sorted list.

        An explicit ``full_mask`` passed to the constructor is returned unchanged.
        """
        if self._full_mask is not None:
            return self._full_mask

        mask = set(self.partitioned_option._mask)
        mask.update(self.raw_operator.precondition.mask)
        return sorted(mask)

    @property
    def option(self):
        return self.partitioned_option.option

    @property
    def partition(self):
        return self.partitioned_option.partition

    @property
    def n_subpartitions(self):
        """Number of subpartitions; 1 when there are none (only the precondition is grounded)."""
        if len(self.subpartitions) == 0:
            return 1  # Then we need to ground the precondition only
        return len(self.subpartitions)

    def observations(self, idx=0):
        """Initiation-set samples of subpartition ``idx`` in problem space, via ``samples2np``."""
        return samples2np(self.subpartitions[idx].states)

    def next_observations(self, idx=0):
        """Terminal-set samples (next states) of subpartition ``idx``, via ``samples2np``."""
        return samples2np(self.subpartitions[idx].next_states)

    def add_problem_symbols(self, pddl, precondition_idx, effect_idx):
        """Link every operator schema to the problem symbols ``psymbol_<idx>``.

        :param pddl: PDDL description; only ``is_ambiguous(variable)`` is used here
        :param precondition_idx: index of the symbol linked as precondition
        :param effect_idx: index of the symbol linked as effect, or -1 for none
        """
        print("Adding p_symbol{} and p_symbol{}".format(precondition_idx, effect_idx))
        for operator in self.schemata:
            # A fresh Predicate per operator, so no predicate object is shared between schemata
            precondition = Predicate('psymbol_{}'.format(precondition_idx))
            effect = None if effect_idx == -1 else Predicate('psymbol_{}'.format(effect_idx))
            operator.link(precondition, effect)

            # Propositionalise objects to avoid ambiguity: add object-ID fluents for
            # the mask variables to the precondition. The per-variable ambiguity
            # flags are passed along (presumably so that unambiguous objects are
            # left out and the generated PDDL stays readable).
            mask = self.full_mask
            ambiguous = [pddl.is_ambiguous(m) for m in mask]
            operator.add_object_to_precondition(mask, ambiguous)

    def observation_mask(self, idx):
        """Return the sorted indices of variables that change in subpartition ``idx``.

        A variable index ``j`` is included when ``obs[j]`` differs from
        ``next_obs[j]`` for at least one sample. Returns ``[]`` when there are
        no subpartitions.
        """
        if len(self.subpartitions) == 0:
            return []

        changed = set()
        for obs, next_obs in zip(self.observations(idx), self.next_observations(idx)):
            changed.update(j for j in range(len(obs)) if not np.array_equal(obs[j], next_obs[j]))
        # dtype=int keeps an empty result usable as an index array
        # (np.sort of an empty list would yield float64)
        return np.array(sorted(changed), dtype=int)
