import warnings

from symbols.pddl.schema import Schema





class LiftedSchema:
    """Render a base schema as a PDDL-style action keyed on ``(partition)``.

    The wrapped base schema supplies the option, the partition, the
    precondition text and the per-rule effects.  This class adds an integer
    id used in the action name, an optional action descriptor that is pushed
    onto the base schema when rendering, and a set of "links" from which the
    conditional ``(when (= (partition) X) ...)`` effect clauses are built.
    """

    def __init__(self, base_schema):
        """Wrap *base_schema*; id starts at 0, descriptor and links unset."""
        self._base_schema = base_schema
        self._links = None
        self._id = 0
        self.action_descriptor = None

    @property
    def option(self):
        """The ``option`` attribute of the wrapped base schema."""
        return self._base_schema.option

    @property
    def partition(self):
        """The ``partition`` attribute of the wrapped base schema."""
        return self._base_schema.partition

    def __str__(self):
        """Return the action as text: a ``;`` comment line plus ``(:action ...)``.

        Side effect: the base schema's ``action_descriptor`` is overwritten
        with this object's ``action_descriptor`` before anything is rendered.
        """
        base = self._base_schema
        # Pushed onto the base schema first; presumably its own rendering
        # code consults it -- confirm against Schema.
        base.action_descriptor = self.action_descriptor

        header = 'Action ' + base.option_name + '-partition-' + str(base.partition)
        action_name = base.option_name + '_' + str(self._id)
        precondition = base._precondition_to_str()
        effect = self._effect_to_str()

        lines = (
            '\t;{}'.format(header),
            '\t(:action {}'.format(action_name),
            '\t :parameters()',
            '\t :precondition {}'.format(precondition),
            '\t :effect {}'.format(effect),
            '\t)',
        )
        return '\n'.join(lines) + '\n'

    def set_id(self, id):
        """Set the integer suffix appended to the option name in the action name."""
        self._id = id

    def _effect_to_str(self):
        """Build the effect text from the links set via ``set_links``.

        Each key of ``self._links`` also indexes the base schema's
        ``_effects``; each link's ``emit()`` yields ``(x, y, probs)`` triples.
        A triple whose ``y`` has more than one entry only triggers a warning
        and contributes nothing.  Otherwise one ``(when ...)`` clause is
        produced that tests ``(partition)`` against ``int(x)`` and appends an
        ``(assign (partition) y)`` term to the rule's base effect.

        Returns the empty string when no clause is produced, the single
        clause when there is exactly one, and an ``(and ...)`` wrapper around
        all clauses when there are several.
        """
        clauses = []
        for rule_key in self._links:
            link = self._links[rule_key]
            for start, targets, _probs in link.emit():
                if len(targets) > 1:
                    warnings.warn("Haven't yet coded probabilistic partition effects")
                    continue

                base_effect = self._base_effect_to_str(self._base_schema._effects[rule_key])

                # NOTE(review): the non-int branch is only reached when
                # `targets` is empty, where indexing would normally raise.
                target = int(targets[0]) if len(targets) == 1 else targets[0]

                # Drop the closing paren of the base effect so the assignment
                # can be appended inside it, then close it again.
                body = base_effect[:-1] + ' (assign (partition) ' + str(target) + '))'
                guard = '(= (partition) ' + str(int(start)) + ')'
                clauses.append('(when {} {})\n'.format(guard, body))

        # Every clause after the first is prefixed with the alignment padding.
        joined = '\t\t\t\t  '.join(clauses)
        if len(clauses) > 1:
            return '(and ' + joined + '\t\t\t )'
        return joined

    def _base_effect_to_str(self, rule):
        """Render *rule*'s symbols plus a reward term via ``Schema.conjunction_to_str``.

        A negative ``rule.reward`` becomes ``decrease (reward) <magnitude>``;
        any other value becomes ``increase (reward) <value>``.  The number is
        formatted with two decimal places.
        """
        amount = rule.reward
        if amount < 0:
            verb, amount = 'decrease', -amount
        else:
            verb = 'increase'
        reward_term = verb + ' (reward) ' + '{0:.2f}'.format(amount)

        return Schema.conjunction_to_str(rule._symbols + [reward_term])

    def set_links(self, links):
        """Store the links that ``_effect_to_str`` iterates over."""
        self._links = links
