from symbols.pddl.operator import TypedOperator
from symbols.pddl.predicate import Predicate, TypedPredicate
import numpy as np

class TransferableOperator:
    """Pair a typed PDDL operator with extra precondition/effect symbols.

    The wrapped operator supplies the name, option and the typed predicates
    that define which objects take part in a transition. The
    ``preconditions`` and ``effects`` lists held here carry the symbols whose
    density estimators are evaluated by :meth:`transition_proba`.
    """

    def __init__(self, pddl_operator: "TypedOperator"):
        """Store the operator and seed the precondition list.

        :param pddl_operator: the operator being wrapped
        """
        self.operator = pddl_operator
        # Every operator starts with the propositional 'notfailed' guard.
        self.preconditions = [Predicate('notfailed')]
        self.effects = []

    def add_precondition(self, proposition):
        """Append ``proposition`` to this operator's preconditions."""
        self.preconditions.append(proposition)

    def add_effect(self, wrapped_proposition):
        """Append an effect.

        ``transition_proba`` reads ``.sign`` and ``.symbol`` from every
        stored effect, so ``wrapped_proposition`` must expose both.
        """
        self.effects.append(wrapped_proposition)

    @property
    def name(self):
        """Name of the wrapped operator."""
        return self.operator.name

    @property
    def option(self):
        """Option of the wrapped operator."""
        return self.operator.option

    @staticmethod
    def _type_id(param_type):
        """Parse the integer id out of a parameter type name.

        The first four characters are skipped and the rest is converted with
        ``int`` (presumably names look like ``'type3'`` -- confirm against
        the code that generates them).
        """
        return int(param_type[4:])

    @property
    def types(self):
        """Integer type ids of the objects grounded by the typed predicates.

        Preconditions are scanned before effects; each grounded object
        contributes one entry, the first time it is encountered.

        :return: list of integer type ids in first-seen order
        """
        found = []
        seen = set()
        for predicate in self.operator.preconditions + self.operator.effects:
            if not isinstance(predicate, TypedPredicate):
                continue
            # NOTE(review): this pairs every parameter type with every
            # grounded object, so for a predicate with several parameters all
            # of its new objects receive the FIRST parameter's type. That is
            # only right for single-parameter predicates (the only kind
            # transition_proba accepts) -- confirm whether zip() was intended.
            for param in predicate.param_types:
                for obj in predicate.grounding:
                    if obj not in seen:
                        found.append(self._type_id(param))
                    seen.add(obj)
        return found

    def transition_proba(self, state, next_state, types, mask):
        """Score how well ``state -> next_state`` matches this operator.

        The i-th typed precondition of the wrapped operator is matched with
        ``types[i]`` and ``mask[i]``. Each typed effect reuses the mask of
        the precondition that grounded the same object.

        :param state: state before execution, indexed by the entries of
            ``mask`` (presumably a numpy array -- ``_proba`` calls
            ``state[m].reshape``)
        :param next_state: state after execution, indexed the same way
        :param types: type id expected for each typed precondition
        :param mask: state index/mask for each typed precondition
        :return: the smaller of the precondition score on ``state`` and the
            positive-effect score on ``next_state``
        :raises ValueError: if a typed predicate grounds more than one
            object, if a type id disagrees with ``types``, or if the number
            of symbols and masks differ
        :raises KeyError: if a typed effect grounds an object that no typed
            precondition grounded
        """
        # Precondition pass: collect masks and remember which slot grounded
        # which object.
        pre_mask = []
        grounding_to_idx = {}
        typed_preconditions = (
            p for p in self.operator.preconditions
            if isinstance(p, TypedPredicate)
        )
        for i, predicate in enumerate(typed_preconditions):
            if len(predicate.grounding) > 1:
                raise ValueError(
                    "typed precondition {} grounds more than one object".format(i)
                )
            for param in predicate.param_types:
                expected = self._type_id(param)
                if types[i] != expected:
                    raise ValueError(
                        "type mismatch at position {}: operator expects {}, "
                        "got {}".format(i, expected, types[i])
                    )
                pre_mask.append(mask[i])
            grounding_to_idx[predicate.grounding[0]] = i

        # Effect pass: look up the mask of the object each effect refers to.
        eff_mask = []
        for predicate in self.operator.effects:
            if not isinstance(predicate, TypedPredicate):
                continue
            if len(predicate.grounding) > 1:
                raise ValueError("typed effect grounds more than one object")
            for grounding in predicate.grounding:
                # Looked up for every effect, whatever its sign, so that an
                # unknown grounding is always reported.
                m = mask[grounding_to_idx[grounding]]
                if predicate.sign > 0:
                    eff_mask.append(m)

        # Plain ``Predicate`` entries (such as 'notfailed') have no density
        # estimator and are skipped.
        pre_prob = self._proba(
            [x for x in self.preconditions if not isinstance(x, Predicate)],
            state,
            pre_mask,
        )
        # NOTE: only positive effects are scored; negative ones are ignored.
        eff_prob = self._proba(
            [x.symbol for x in self.effects if x.sign > 0],
            next_state,
            eff_mask,
        )
        return min(pre_prob, eff_prob)

    def _proba(self, kdes, state, mask):
        """Multiply the squashed density of each symbol on its masked state.

        :param kdes: symbols, each exposing ``_kdes`` with exactly one
            estimator that has a ``score_samples`` method (assumed to return
            log-density, as scikit-learn's ``KernelDensity`` does)
        :param state: state indexed by the entries of ``mask``
        :param mask: one index/mask per symbol, aligned with ``kdes``
        :return: product of the per-symbol scores, each in ``[0, 1]``;
            ``1`` when there are no symbols
        :raises ValueError: if ``kdes`` and ``mask`` differ in length or a
            symbol holds more than one estimator
        """
        if len(kdes) != len(mask):
            raise ValueError(
                "got {} symbols but {} masks".format(len(kdes), len(mask))
            )

        total_prob = 1
        for m, symbol in zip(mask, kdes):
            if len(symbol._kdes) > 1:
                raise ValueError("symbol holds more than one density estimator")
            estimator = symbol._kdes[0]
            log_density = estimator.score_samples(state[m].reshape(1, -1))[0]
            # density / (1 + density) is the logistic sigmoid of the log
            # density. Evaluating it as exp(-log(1 + exp(-x))) avoids the
            # inf / inf = NaN that the direct form produces once
            # exp(log_density) overflows.
            prob = np.exp(-np.logaddexp(0, -log_density))
            total_prob *= prob
        return total_prob
