from typing import List

from symbols.pddl.predicate import Predicate, TypedPredicate, NamedParameterPredicate


# Untyped operator description (rendered with an empty PDDL parameter list). Does not support probabilistic effects!
class Operator:
    """
    Untyped PDDL operator built from an option and a partition.

    Every operator starts with the ``notfailed`` precondition. ``str(operator)`` renders a single
    ``(:action ...)`` block with an empty parameter list when no linking pairs have been added.
    Otherwise it renders one action per pair registered with ``link``, separated by blank lines.
    """

    def __init__(self, name, option, partition):
        self.name = name.replace(' ', '-')  # NB ensure no spaces in name
        self.preconditions = [Predicate('notfailed')]
        self.effects = list()
        self._option = option
        self._partition = partition
        # list of (precondition, effect) pairs linking problem-space symbols; see link()
        self._linking = list()

    @property
    def partition(self):
        """The partition passed to the constructor."""
        return self._partition

    @property
    def option(self):
        """The option passed to the constructor."""
        return self._option

    def is_duplicate(self, other):
        """
        Return True if ``other`` has the same precondition names and effect names as this operator.

        Names are compared as sorted lists, so order is ignored but multiplicity matters.
        TODO(review): only predicate names are compared; groundings, linking pairs, option and
        partition are ignored -- confirm this is the intended notion of "duplicate".
        """
        if sorted(p.name for p in self.preconditions) != sorted(p.name for p in other.preconditions):
            return False
        return sorted(e.name for e in self.effects) == sorted(e.name for e in other.effects)

    def add_precondition(self, predicate: Predicate):
        """Append ``predicate`` to the operator's preconditions."""
        self.preconditions.append(predicate)

    def add_effect(self, predicate: Predicate):
        """Append ``predicate`` to the operator's effects."""
        self.effects.append(predicate)

    def link(self, precondition: Predicate, effect: Predicate):
        """Register a (precondition, effect) linking pair; each pair yields its own rendered action."""
        self._linking.append((precondition, effect))

    def __str(self, precondition: Predicate, effect: Predicate):  # problem_space preconditions and effects
        """
        Render one PDDL action, optionally extended with a single linking pair.

        :param precondition: extra precondition appended after ``self.preconditions``, or None.
        :param effect: extra effect appended after ``self.effects``, or None. When given, the
            negation of ``precondition`` (if there is one) is also added to the effect.
        :return: the tab-indented ``(:action ...)`` text.
        """
        pre_clauses = ['({})'.format(p) for p in self.preconditions]
        if precondition is not None:
            pre_clauses.append('({})'.format(precondition))

        eff_clauses = ['({})'.format(e) for e in self.effects]
        if effect is not None:
            eff_clauses.append('({})'.format(effect))
            if precondition is not None:
                # the linking precondition is negated in the effect (presumably so the action consumes it)
                eff_clauses.append('({})'.format(precondition.negate()))

        # Emit a conjunction whenever there is at least one clause. Deciding this from the combined
        # clause list, not just self.preconditions / self.effects, keeps the output valid when only
        # the linking clauses are present.
        pre = ('and ' + ' '.join(pre_clauses)) if pre_clauses else ''
        eff = ('and ' + ' '.join(eff_clauses)) if eff_clauses else ''

        return '\t(:action {}\n' \
               '\t :parameters ()\n' \
               '\t :precondition ({})\n' \
               '\t :effect ({})\n' \
               '\t)'.format(self.name, pre, eff)

    def __str__(self):
        """Render the operator: one action, or one action per linking pair joined by blank lines."""
        if not self._linking:
            return self.__str(None, None)
        return '\n\n'.join(self.__str(pre, eff) for pre, eff in self._linking)


# Typed operator description. Does not support probabilistic effects, only probabilistic failures!
class TypedOperator(Operator):
    """
    Typed PDDL operator whose grounded predicates are rendered with named, typed parameters.

    Probabilistic effects are not supported. A failure probability can be attached instead: when
    ``probabilistic`` is True and ``failure_probability`` is positive, the rendered effect asserts
    ``(not (notfailed))`` with that probability and the normal effect otherwise.
    """

    def __init__(self, name, option, partition):
        super().__init__(name, option, partition)
        self.id = -1  # appended to the action name when rendered ("<name>-<id>")
        self.child_type = dict()
        self._child_type = list()  # list of old types and the new ones!
        # not doing probabilistic effects, but will do probabilistic failures!
        self.failure_probability = 0
        self.probabilistic = False
        # one object-id mask and one ambiguity list per linking pair (see add_object_to_precondition)
        self._object_preconditions = list()
        self._ambiguous = list()

    def add_object_to_precondition(self, mask, ambiguous):
        """
        Record the object ids and ambiguity flags for the next linking pair.

        ``__str__`` requires exactly one call per linking pair. The entries are matched to the
        linking pairs by position. When rendering, ``(= (id ?v) <id>)`` is added to the
        precondition for each parameter whose ambiguity flag is truthy.
        """
        self._object_preconditions.append(mask)
        self._ambiguous.append(ambiguous)

    def _create_named_predicate(self, predicate: TypedPredicate, variable_names):
        """
        Return ``predicate`` with its grounding objects replaced by variable names.

        Ungrounded predicates are returned unchanged. Grounded ones are wrapped in a
        NamedParameterPredicate with names ``variable_names[obj]`` for each obj in ``predicate.grounding``.
        """
        if not predicate.is_grounded():
            return predicate
        names = [variable_names[obj] for obj in predicate.grounding]
        return NamedParameterPredicate(predicate, names)

    def instantiate_object(self, param_idx, new_type):
        """
        Deprecated: always raises. Use ``add_object_to_precondition`` instead.

        :raises ValueError: always.
        """
        # The old retyping logic (recording (old_type, new_type) into _child_type / child_type)
        # sat after this unconditional raise and was unreachable, so it has been removed.
        raise ValueError("Not using this. Use add_object_to_precondition instead! ")

    def __type(self, type, old_type, new_type):
        """
        Return ``new_type`` if ``type == old_type``, otherwise ``type`` unchanged.

        If ``old_type`` is None, ``type`` is returned unchanged; ``new_type`` must then also be None.

        :raises ValueError: if ``old_type`` is None but ``new_type`` is not.
        """
        if old_type is None:
            if new_type is not None:
                raise ValueError
            return type
        if type == old_type:
            return new_type
        return type

    @staticmethod
    def _variable_name(index):
        """Return the PDDL variable name (without the leading '?') for the index-th distinct object."""
        # The first four objects keep the historical names w, x, y, z. Later objects get w4, w5, ...
        # Counting on with chr() would produce characters such as '{' and '|', which are invalid in PDDL.
        if index < 4:
            return chr(ord('w') + index)
        return 'w{}'.format(index)

    def __str(self, linking_precondition: Predicate, linking_effect: Predicate, object_ids: List[int], ambiguous,
              variant=None):
        """
        Render one typed PDDL action.

        :param linking_precondition: extra precondition appended last, or None.
        :param linking_effect: extra effect appended last, or None. When given, the negation of
            ``linking_precondition`` (if there is one) is also added to the effect.
        :param object_ids: ids zipped position-wise with the parameter list, or None.
        :param ambiguous: flags zipped with ``object_ids``. An id equality is only emitted where the flag is truthy.
        :param variant: optional suffix appended to the action name.
        :return: the tab-indented ``(:action ...)`` text.
        """
        object_to_var_name = dict()
        params = list()
        for predicate in self.preconditions:
            if not predicate.is_grounded():
                continue
            for obj, obj_type in predicate.ground_types():
                if obj not in object_to_var_name:
                    object_to_var_name[obj] = self._variable_name(len(object_to_var_name))
                # NOTE(review): a parameter is appended for every grounding. An object that appears in
                # several grounded preconditions is therefore listed more than once. object_ids/ambiguous
                # are zipped against this list, so confirm that alignment before de-duplicating.
                params.append((object_to_var_name[obj], obj_type))

        params_str = ' '.join('?{} - {}'.format(var, var_type) for var, var_type in params)

        pre_clauses = ['({})'.format(self._create_named_predicate(p, object_to_var_name))
                       for p in self.preconditions]
        if object_ids is not None and len(object_ids) > 0:
            for (variable_name, _), object_id, is_ambiguous in zip(params, object_ids, ambiguous):
                if is_ambiguous:  # only if ambiguous; keeps the output readable
                    pre_clauses.append('(= (id ?{}) {})'.format(variable_name, object_id))
        if linking_precondition is not None:
            pre_clauses.append('({})'.format(linking_precondition))
        pre = ('and ' + ' '.join(pre_clauses)) if (pre_clauses or self._object_preconditions) else ''

        eff_clauses = ['({})'.format(self._create_named_predicate(e, object_to_var_name)) for e in self.effects]
        if linking_effect is not None:
            eff_clauses.append('({})'.format(linking_effect))
            if linking_precondition is not None:
                # the linking precondition is negated in the effect (presumably so the action consumes it)
                eff_clauses.append('({})'.format(linking_precondition.negate()))
        # Decide on the conjunction from the combined clauses so linking-only effects stay valid PDDL.
        eff = ('and ' + ' '.join(eff_clauses)) if eff_clauses else ''

        if self.probabilistic and self.failure_probability > 0:
            prob = round(self.failure_probability, 2)
            # Round the complement too: in binary floating point, 1 - 0.7 == 0.30000000000000004.
            success = round(1 - prob, 2)
            eff = 'probabilistic {} {} {} ({})'.format(prob, '(not (notfailed))', success, eff)

        name = '{}-{}'.format(self.name, self.id)
        if variant is not None:
            name += variant

        return '\t(:action {}\n' \
               '\t :parameters ({})\n' \
               '\t :precondition ({})\n' \
               '\t :effect ({})\n' \
               '\t)'.format(name, params_str, pre, eff)

    def __str__(self):
        """
        Render the operator as PDDL.

        Without linking pairs, a single action is emitted. Otherwise one action is emitted per linking
        pair, with variant suffixes 'a', 'b', ..., and the actions are joined by blank lines.

        :raises ValueError: if the number of linking pairs differs from the number of
            ``add_object_to_precondition`` calls.
        """
        if len(self._linking) != len(self._object_preconditions):
            raise ValueError('{} linking pairs but {} object preconditions'.format(
                len(self._linking), len(self._object_preconditions)))

        if not self._linking:
            return self.__str(None, None, None, None)

        actions = []
        for i, ((precondition, effect), object_ids, ambiguous) in enumerate(
                zip(self._linking, self._object_preconditions, self._ambiguous)):
            actions.append(self.__str(precondition, effect, object_ids, ambiguous, variant=chr(ord('a') + i)))
        return '\n\n'.join(actions)

    def extract_symbol_names(self):
        """
        Get the names of all symbols mentioned in this operator (except notfailed!).

        Only TypedPredicate instances among the preconditions and effects are considered. The
        ``notfailed`` precondition is created as a plain Predicate in Operator.__init__.
        """
        return {pred.name for pred in self.preconditions + self.effects if isinstance(pred, TypedPredicate)}
