from symbols.symbols.learned_operator import LearnedOperator


class LinkingFunction:
    """Accumulates weighted transitions between labels and reports them as probabilities.

    ``map[start][end]`` holds the summed weight recorded for ``start -> end`` via
    ``add``; ``visit[start]`` counts calls to ``add_visit(start)``. A transition's
    probability is its accumulated weight divided by the visit count of its start.
    """

    def __init__(self):
        # start label -> {end label -> accumulated weight}
        self.map = dict()
        # start label -> number of add_visit calls for that label
        self.visit = dict()

    def add_visit(self, start):
        """Increment the visit count of ``start`` by one."""
        self.visit[start] = self.visit.get(start, 0) + 1

    def add(self, start, end, val):
        """Add weight ``val`` to the ``start -> end`` transition (starting from 0)."""
        targets = self.map.setdefault(start, dict())
        targets[end] = targets.get(end, 0) + val

    def __str__(self):
        """Render one ``'<start> -> <end> w.p. <prob>'`` line per positive transition.

        Raises:
            KeyError: if a start label has transitions but was never passed to
                ``add_visit``.
        """
        # Reuse emit() so the probability computation and >0 filter live in one place.
        return ''.join(
            str(start) + ' -> ' + str(end) + ' w.p. ' + str(prob) + '\n'
            for start, ends, probs in self.emit()
            for end, prob in zip(ends, probs)
        )

    def emit(self):
        """Yield ``(start, ends, probs)`` for every start label with a positive transition.

        ``probs[i]`` is ``map[start][ends[i]] / visit[start]``. Transitions whose
        probability is not > 0 are dropped, and start labels left with no
        transitions are not yielded at all.

        Raises:
            KeyError: if a start label has transitions but was never passed to
                ``add_visit``.
        """
        for start, targets in self.map.items():
            ends = list()
            probs = list()
            for end, weight in targets.items():
                prob = weight / self.visit[start]
                if prob > 0:
                    ends.append(end)
                    probs.append(prob)
            if ends:
                yield start, ends, probs


class LearnedLiftedOperator(LearnedOperator):
    """Wraps a learned operator and attaches one LinkingFunction per effect.

    The option/partition/precondition/effect/probability/reward properties are
    delegated to the wrapped operator. ``links`` maps each effect index to a
    LinkingFunction recording label transitions observed for that effect.
    """

    def __init__(self, learned_operator):
        # NOTE(review): LearnedOperator.__init__ is not called; all state is read
        # through the wrapped operator via the properties below — confirm intended.
        self._learned_symbol = learned_operator
        # One linking function per effect, keyed by effect index 0..n-1.
        self._links = {i: LinkingFunction() for i, _ in enumerate(self.list_effects)}

    @property
    def option(self):
        """Delegates to the wrapped operator's ``option``."""
        return self._learned_symbol.option

    @property
    def partition(self):
        """Delegates to the wrapped operator's ``partition``."""
        return self._learned_symbol.partition

    @property
    def precondition(self):
        """Delegates to the wrapped operator's ``precondition``."""
        return self._learned_symbol.precondition

    @property
    def list_probabilities(self):
        """Delegates to the wrapped operator's ``list_probabilities``."""
        return self._learned_symbol.list_probabilities

    @property
    def list_effects(self):
        """Delegates to the wrapped operator's ``list_effects``."""
        return self._learned_symbol.list_effects

    @property
    def list_rewards(self):
        """Delegates to the wrapped operator's ``list_rewards``."""
        return self._learned_symbol.list_rewards

    @property
    def links(self):
        """Dict mapping effect index -> LinkingFunction."""
        return self._links

    def update_links(self, label, next_label, probs):
        """Record a ``label -> next_label`` transition in every effect's linking function.

        Args:
            label: start label of the transition.
            next_label: end label of the transition.
            probs: one weight per effect outcome; ``probs[i]`` is added to effect
                ``i``'s link, and each effect's visit count of ``label`` is incremented.

        Raises:
            KeyError: if ``probs`` has more entries than there are effects.
        """
        for eff_idx, prob in enumerate(probs):
            link = self._links[eff_idx]
            link.add(label, next_label, prob)
            link.add_visit(label)

    def find_alternative(self, state, merge_map):
        """Return the first recorded start label that ``merge_map`` deems close to ``state``.

        Effects are scanned in index order. Returns None if no recorded label is
        close, or if ``merge_map`` is None (there is nothing to compare with;
        ``contains_init_labels`` passes ``merge_map=None`` by default).
        """
        if merge_map is None:
            return None
        for link in self._links.values():
            for start, _, _ in link.emit():
                if merge_map.is_close(state, start):
                    return start
        return None

    def contains_init_label(self, label, merge_map=None):
        """Return True if any effect has outgoing transitions for ``label``.

        Matching follows ``in_equivalence_class`` (exact match, then
        ``merge_map.is_overlap`` when a merge map is given).
        """
        return any(
            self.in_equivalence_class(eff_idx, label, merge_map)[0]
            for eff_idx in self._links
        )

    def contains_init_labels(self, labels, state, merge_map=None):
        """Return True if any of ``labels`` is a known start label or ``state`` has an alternative.

        Returns False for an empty ``labels``. Otherwise True if some label
        satisfies ``contains_init_label``, or ``find_alternative(state, merge_map)``
        finds a recorded label close to ``state``.
        """
        # find_alternative does not depend on the label, so it is evaluated at most
        # once (assumes merge_map.is_close is deterministic).
        alternative_found = None
        for label in labels:
            if self.contains_init_label(label, merge_map):
                return True
            if alternative_found is None:
                alternative_found = self.find_alternative(state, merge_map) is not None
            if alternative_found:
                return True
        return False

    def in_equivalence_class(self, eff_idx, label, merge_map=None):
        """Return ``(ends, probs)`` for effect ``eff_idx``'s transitions out of ``label``.

        An exact start-label match is preferred. Failing that, when ``merge_map``
        is given, the first start label for which ``merge_map.is_overlap(label, start)``
        holds is used. Returns ``([], [])`` if nothing matches.
        """
        link = self._links[eff_idx]
        for start, ends, probs in link.emit():
            if start == label:
                return ends, probs

        # Nothing matched exactly; check whether an overlapping label was recorded.
        if merge_map is not None:
            for start, ends, probs in link.emit():
                if merge_map.is_overlap(label, start):
                    return ends, probs

        return [], []

    def __str__(self):
        """Summarise option, partition, effect count and every linking function."""
        return "Option {}, partition {}\n#Effects: {}\nLinks:\n\n{}".format(
            self.option,
            self.partition,
            len(self.list_effects),
            '\n'.join(str(link) for link in self._links.values())
        )