from sklearn.cluster import DBSCAN
import numpy as np

from symbols.logger.transition_sample import TransitionSample


class PartitionedOption:
    """A partition of an option's transition data, holding one ``Rule`` per effect.

    The constructor stores the problem-space samples (states, mask, rewards,
    next states) and builds the first ``Rule`` from them; additional effects are
    appended with ``merge``. Agent-space observations for the same samples can be
    attached afterwards with ``inject_observations``.
    """

    def __init__(self,
                 option,
                 partition,
                 states,
                 mask,
                 rewards,
                 next_states,
                 epsilon_problem_space,
                 epsilon_agent_space,
                 object_ids=None):
        """
        :param option: the option this partition belongs to (stored as-is)
        :param partition: partition identifier (stored as-is)
        :param states: start states; the first axis indexes samples
        :param mask: state-variable indices modified by the initial effect
        :param rewards: per-sample rewards
        :param next_states: resulting states; must have as many rows as ``states``
        :param epsilon_problem_space: DBSCAN ``eps`` used when clustering states
        :param epsilon_agent_space: DBSCAN ``eps`` used when clustering observations
        :param object_ids: optional object identifiers (stored as-is)
        """
        assert (states.shape[0] == next_states.shape[0])
        self._option = option
        self._partition = partition
        self._states = states
        self._mask = mask
        self._rewards = rewards
        self._next_states = next_states

        self._rules = [Rule(states, mask, rewards, next_states)]
        self._observations = None
        self._next_observations = None
        self._epsilon_problem_space = epsilon_problem_space
        self._epsilon_agent_space = epsilon_agent_space
        self._no_augment = set()
        self._object_ids = object_ids

    def inject_observations(self,
                            observations,
                            next_observations,
                            append=False):
        """Attach the observations from the current task that are used for grounding.

        :param observations: observations matching ``states``
        :param next_observations: observations matching ``next_states``
        :param append: if True, concatenate onto the existing observations. If none
            have been injected yet, the arrays are stored as given rather than
            concatenated with ``None``.
        """
        if append and self._observations is not None:
            self._observations = np.concatenate((self._observations, observations))
            self._next_observations = np.concatenate((self._next_observations, next_observations))
        else:
            self._observations = observations
            self._next_observations = next_observations

    def add_no_augment(self, partition):
        """Record ``partition`` as one this partition must not be augmented with."""
        self._no_augment.add(partition)

    def can_augment(self, idx):
        """Return True unless ``idx`` was registered via ``add_no_augment``."""
        return idx not in self._no_augment

    @property
    def states(self):
        return self._states

    @property
    def next_states(self):
        return self._next_states

    @property
    def observations(self):
        return self._observations

    @property
    def next_observations(self):
        return self._next_observations

    @property
    def option(self):
        return self._option

    @property
    def partition(self):
        return self._partition

    @property
    def rules(self):
        return self._rules

    @property
    def object_ids(self):
        return self._object_ids

    def get_number_effects(self):
        """Return the number of effects (rules) in this partition."""
        return len(self._rules)

    def l2_dist(self, s):
        """
        Distance from s to the mean of the initiation set
        """
        M = np.mean(self.states, axis=0)
        dist = np.linalg.norm(s - M)
        return dist

    def __contains__(self, s):
        """Return True if ``s`` is within L2 distance 0.01 of any stored start state."""
        return any(np.linalg.norm(s - state) < 0.01 for state in self.states)

    def init_sets_similar(self,
                          s_candidate,
                          masks=None):
        """Return True if ``s_candidate`` overlaps this partition's start states.

        Both sets are clustered with DBSCAN (``epsilon_problem_space``). They are
        considered similar when each has at least one cluster and their union has
        no more clusters than the larger of the two alone.

        :param s_candidate: candidate start states
        :param masks: optional pair ``(mask_self, mask_candidate)`` used to index
            ``self.states`` and ``s_candidate`` respectively before clustering
        """
        if masks is not None:
            current = self.states[masks[0]]
            other = s_candidate[masks[1]]
        else:
            current = self.states
            other = s_candidate
        return self._init_sets_overlap(current, other, self._epsilon_problem_space)

    def observation_init_sets_similar(self,
                                      obs_candidates):
        """Like ``init_sets_similar``, but over observations (``epsilon_agent_space``)."""
        return self._init_sets_overlap(self.observations, obs_candidates, self._epsilon_agent_space)

    def _init_sets_overlap(self, current, other, eps):
        """Cluster-count overlap test shared by the two ``*init_sets_similar`` methods."""
        cluster_current = self._num_clusters(current, eps)
        cluster_other = self._num_clusters(other, eps)
        if cluster_other * cluster_current == 0:
            # A set with no clusters (all noise, or empty) cannot overlap.
            return False
        cluster_total = self._num_clusters(np.concatenate((current, other)), eps)
        return cluster_total <= max(cluster_current, cluster_other)

    def _num_clusters(self, dat, eps):
        """Return the number of DBSCAN clusters in ``dat``, excluding noise.

        Each sample is flattened with ``np.concatenate(sample).ravel()`` and
        clustered with ``min_samples=5``. Empty input yields 0, instead of
        handing a zero-sample array to DBSCAN, which rejects it.
        """
        if len(dat) == 0:
            return 0
        # flatten 2d state representation (presumably each sample is a sequence of arrays)
        flat = np.array([np.concatenate(sample).ravel() for sample in dat])
        labels = DBSCAN(eps=eps, min_samples=5).fit(flat).labels_
        # DBSCAN labels noise points -1; noise is not a cluster.
        return len(set(labels)) - (1 if -1 in labels else 0)

    def get_transition_probabilities(self):
        """Return each rule's share of this partition's samples."""
        return [rule.probability(self.states.shape[0]) for rule in self._rules]

    def get_transition_probability(self, effect_idx):
        """Return rule ``effect_idx``'s share of this partition's samples."""
        return self._rules[effect_idx].probability(self.states.shape[0])

    def get_transition_rule(self, effect_idx):
        return self._rules[effect_idx]

    def get_next_states(self, effect_idx):
        """Return the termination set (next states) of rule ``effect_idx``."""
        return self._rules[effect_idx].termination_set

    def get_mask(self, effect_idx):
        return self._rules[effect_idx].mask

    @property
    def combined_mask(self):
        """Concatenation of all rules' masks, in rule order (duplicates kept)."""
        return [v for rule in self._rules for v in rule.mask]

    # When merging, an outcome is created for each effect cluster (which could be distinct because of clustering or
    # because of a different mask) and assigned an outcome probability based on the fraction of the samples assigned
    # to it.
    def merge(self, states, mask, rewards, next_states):
        """Add the given samples to this partition as a new rule (effect)."""
        self._rewards = np.append(self._rewards, rewards)
        self._states = np.concatenate((self._states, states))
        self._next_states = np.concatenate((self._next_states, next_states))
        self._rules.append(Rule(states, mask, rewards, next_states))

    def distance_to(self, s):
        """Return the minimum L2 distance from ``s`` to any start state (``inf`` if none)."""
        min_dist = np.inf
        for state in self.states:
            dist = np.linalg.norm(s - state)
            if dist < min_dist:
                min_dist = dist
            if min_dist == 0:
                # An exact match cannot be beaten.
                break
        return min_dist

    def masked_distance_to(self, s):
        """Like ``distance_to``, but only over the indices in ``combined_mask``."""
        # NOTE(review): combined_mask may repeat indices shared by several rules, which
        # weights those dimensions more than once in the norm -- confirm this is intended.
        mask = self.combined_mask
        s_masked = s[mask]
        min_dist = np.inf
        for state in self.states:
            dist = np.linalg.norm(s_masked - state[mask])
            if dist < min_dist:
                min_dist = dist
            if min_dist == 0:
                break
        return min_dist

    def effect_distance_to(self, s, eff_no):
        """Return the minimum masked L2 distance from ``s`` to rule ``eff_no``'s next states.

        Only the indices in that rule's mask are compared. Returns ``inf`` if the
        termination set is empty.
        """
        rule = self._rules[eff_no]
        s_dat = s[rule.mask]
        target_dat = rule.termination_set[:, rule.mask]
        min_dist = np.inf
        for target in target_dat:
            min_dist = min(min_dist, np.linalg.norm(s_dat - target))
        return min_dist

    def get_outcome_index(self,
                          s,
                          s_prime):
        """Return the index of the effect that best explains the transition ``s -> s_prime``.

        Candidates are the effects whose mask, as a set, equals the set of indices
        where ``s`` and ``s_prime`` differ. If several match, the one whose
        termination set is closest to ``s_prime`` (``effect_distance_to``) wins.

        :returns: the effect index, or -1 if no effect matches (or, with several
            candidates, none is at finite distance).
        """
        msk = {x for x in range(len(s)) if s[x] != s_prime[x]}
        candidates = [i for i in range(self.get_number_effects()) if set(self.get_mask(i)) == msk]

        if not candidates:
            return -1
        if len(candidates) == 1:
            return candidates[0]

        min_pos = -1
        min_val = np.inf
        for candidate in candidates:
            dist = self.effect_distance_to(s_prime, candidate)
            if dist < min_val:
                min_pos = candidate
                min_val = dist
        return min_pos

    def extract_prob_space(self, obs):
        """Wrap ``obs[1]`` in a length-1 object array (the problem-space part of ``obs``)."""
        arr = np.empty((1,), dtype=object)
        arr[0] = obs[1]
        return arr

    def subpartition(self, verbose=True):
        """Re-partition this partition's samples with problem and agent space swapped.

        Builds one ``TransitionSample`` per stored sample, using the problem-space
        part of the injected observations (``extract_prob_space``) alongside the
        stored states, then delegates to ``partition_option(..., subpartition=True)``.
        ``inject_observations`` must have been called first.
        """
        # TODO: hack. problem space is only xy orientation. So obs[1]
        samples = np.array([
            TransitionSample(self.extract_prob_space(self.observations[i]), self.states[i], self.option,
                             0, self._rewards[i], self.extract_prob_space(self.next_observations[i]),
                             self.next_states[i])
            for i in range(self.states.shape[0])
        ])
        # Imported locally, presumably to avoid a circular import -- confirm before hoisting.
        from symbols.experimental.partition_options import partition_option
        return partition_option(self.option, samples, verbose=verbose, subpartition=True)


class Rule:
    """A single effect: the samples assigned to it and the variables it changes.

    Attributes are stored exactly as passed in.
    """

    def __init__(self, initiation_set, mask, rewards, termination_set):
        # Mask of state-variable indices this effect modifies.
        self.mask = mask
        # Start states and resulting states of the samples assigned to this effect.
        self.initiation_set = initiation_set
        self.termination_set = termination_set
        self.rewards = rewards

    def probability(self, total_num_states):
        """Return the fraction of ``total_num_states`` samples that belong to this rule."""
        own_count = self.initiation_set.shape[0]
        fraction = own_count / total_num_states
        return fraction
