from datetime import datetime

import pickle

from domain.levels import Task
from symbols.data.data import load_operators
from symbols.file_utils import make_path, load
from symbols.pddl.predicate import TypedPredicate
import numpy as np

class Problem:
    """A PDDL problem instance that renders itself as PDDL text via ``str()``.

    Objects, start-state facts and goal facts are accumulated with the
    ``add_*`` methods; ``__str__`` emits the ``(define (problem ...))`` form.
    """

    def __init__(self, problem_name, domain_name):
        self.name = problem_name
        self.domain = domain_name
        # object id -> (object name, type name)
        self.objects = dict()
        # every type name passed to add_object
        self.types = set()
        # object key -> grounded predicate true in the start state
        self.start_predicates = dict()
        # propositions true in the start state, each rendered as "(<prop>)"
        self.start_propositions = list()
        # object key -> grounded predicate required in the goal
        self.goal_predicates = dict()
        # propositions required in the goal, each rendered as "(<prop>)"
        self.goal_propositions = list()

    def add_object(self, id: int, object: str, type: str):
        """Register ``object`` of PDDL type ``type`` under integer ``id``.

        The id is also emitted in the start state as ``(= (id <object>) <id>)``.
        """
        self.objects[id] = (object, type)
        self.types.add(type)

    def add_start_predicate(self, predicate, object_id=None):
        """Add a fact to the start state.

        If ``object_id`` is ``None``, ``predicate`` is stored as a proposition.
        Otherwise ``predicate(object_id)`` is stored as that object's grounded
        start predicate, replacing any earlier entry for the same ``object_id``.
        """
        if object_id is None:
            self.start_propositions.append(predicate)
        else:
            self.start_predicates[object_id] = predicate(object_id)

    def add_goal_predicate(self, predicate, object_id=None):
        """Add a fact to the goal.

        If ``object_id`` is ``None``, ``predicate`` is stored as a proposition.
        Otherwise ``predicate(object_id)`` is stored as that object's grounded
        goal predicate, replacing any earlier entry for the same ``object_id``.
        """
        if object_id is None:
            self.goal_propositions.append(predicate)
        else:
            self.goal_predicates[object_id] = predicate(object_id)

    @staticmethod
    def _format_grounded(predicate):
        """Render a grounded predicate as ``(name arg1 arg2 ...)``.

        Raises:
            AssertionError: if ``predicate.is_grounded()`` is false.
        """
        # is_grounded must be *called*: the bare bound method is always truthy.
        assert predicate.is_grounded(), 'predicate {} is not grounded'.format(predicate.name)
        return '({} {})'.format(predicate.name, ' '.join(predicate.grounding))

    def __str__(self):
        """Return the full PDDL ``(define (problem ...))`` text for this problem."""
        objects = '\n\t\t\t\t'.join('{} - {}'.format(o, t) for o, t in self.objects.values())

        starts = [self._format_grounded(p) for p in self.start_predicates.values()]
        starts.extend('({})'.format(prop) for prop in self.start_propositions)
        # numeric fluent assigning each object its integer id
        starts.extend('(= (id {}) {})'.format(obj, obj_id)
                      for obj_id, (obj, _) in self.objects.items())
        start = '\n\t\t\t'.join(starts)

        goals = []
        # The first start proposition (by convention 'notfailed') must still
        # hold in the goal; skip it when no proposition was added.
        if self.start_propositions:
            goals.append('({})'.format(self.start_propositions[0]))
        goals.extend(self._format_grounded(p) for p in self.goal_predicates.values())
        goals.extend('({})'.format(prop) for prop in self.goal_propositions)
        # every object's start predicate must no longer hold in the goal
        goals.extend('(not {})'.format(self._format_grounded(p))
                     for p in self.start_predicates.values())

        if len(goals) > 1:
            goal = '(and {})'.format('\n\t\t\t\t'.join(goals))
        else:
            goal = '{}'.format('\n'.join(goals))

        definition = "(define (problem {})\n" \
                     "\t(:domain {})\n\n" \
                     "\t(:objects   {}\n" \
                     "\t)\n\n" \
                     "\t(:init  {}\n" \
                     "\t)\n\n" \
                     "\t(:goal {}\n" \
                     "\t)\n" \
                     ")".format(self.name, self.domain, objects, start, goal)
        return definition


"""
(define (problem open-chest)
   (:domain task1)

   (:objects hand - Hand
             red blue green - Block
   )
   (:init (BlockOnTable red)
          (BlockOnTable blue)
          (BlockOnTable green)
          (HandEmpty hand)
          (notfailed)
    )
   (:goal (and (BlockOnBlock red)
               (BlockOnBlock_BlockCovered green)
               (BlockOnTable_BlockCovered blue)
               ))
)
"""

def find_first(vocabulary, type):
    """Return the first TypedPredicate in ``vocabulary`` whose ``param_types`` contain ``type``.

    Entries that are not TypedPredicate instances are ignored. Returns ``None``
    when no typed predicate matches.

    Raises:
        NotImplementedError: if a typed predicate with more than one parameter
            type is reached before a match is found.
    """
    typed_predicates = (entry for entry in vocabulary if isinstance(entry, TypedPredicate))
    for candidate in typed_predicates:
        if len(candidate.param_types) > 1:
            raise NotImplementedError("Haven't accounted for multiple params")
        if type in candidate.param_types:
            return candidate
    return None

def find_last(vocabulary, type):
    """Return the last TypedPredicate in ``vocabulary`` whose ``param_types`` contain ``type``.

    ``vocabulary`` is scanned from the end via ``reversed``; entries that are
    not TypedPredicate instances are ignored. Returns ``None`` when no typed
    predicate matches.

    Raises:
        NotImplementedError: if a typed predicate with more than one parameter
            type is reached (scanning backwards) before a match is found.
    """
    for candidate in reversed(vocabulary):
        if not isinstance(candidate, TypedPredicate):
            continue
        if len(candidate.param_types) > 1:
            raise NotImplementedError("Haven't accounted for multiple params")
        if type not in candidate.param_types:
            continue
        return candidate
    return None


def main():
    """Assemble and write the PDDL problem for one task from learned artefacts.

    Reads ``linked_pddl.pkl`` and ``psymbol_means.pkl`` from the task's ``pddl``
    directory, builds the start and goal states, prints the resulting problem
    and writes it to ``problem.pddl`` in that same directory.

    Raises:
        ValueError: if a required typed predicate or the start proposition
            cannot be found, or if there are no problem-symbol means.
    """
    import warnings

    seeds = [31, 33, 76, 82, 92]
    # Experiment run to read from. Earlier runs: '20190918', '20191119',
    # '20191204_full', '20191230'; use datetime.today().strftime('%Y%m%d') for today.
    directory = '20200106'
    task_id = 0
    task_dir = make_path(directory, task_id)
    pddl_dir = make_path(task_dir, 'pddl')

    _, doors, objects = Task.generate(seeds[task_id])

    pddl = load(make_path(pddl_dir, "linked_pddl.pkl"))

    # Index-aligned with pddl.type_name(i) (index 0 is the agent / pov):
    # 0      1      2      3      4      5        6      7    8     9
    # agent, door1, door2, door3, door4, pickaxe, chest, ore, gold, inventory
    object_names = ['agent', 'door1', 'door2', 'door3', 'door4', 'pickaxe', 'chest', 'ore', 'gold', 'inventory']

    # all the doors and objects minus the craft table, plus the agent and the inventory
    n_objects = len(doors) + len(objects) - 1 + 2
    if n_objects != len(object_names):
        warnings.warn('task {} has {} objects but {} object names are hard-coded; '
                      'object types may be misaligned'.format(task_id, n_objects, len(object_names)))

    pddl_problem = Problem('open-chest', 'task{}'.format(task_id + 1))
    for i, name in enumerate(object_names):
        pddl_problem.add_object(i, name, pddl.type_name(i))

    # Assumption: vocabulary[0] is always the 'notfailed' proposition.
    pddl_problem.add_start_predicate(pddl.vocabulary[0])
    for i, name in enumerate(object_names):
        type_name = pddl.type_name(i)
        parent = pddl.get_parent_type(type_name)
        if parent is not None:
            type_name = parent
        # Assumption: the first predicate referring to the type holds in the start state.
        start_pred = find_first(pddl.vocabulary, type_name)
        if start_pred is None:
            raise ValueError('no typed predicate found for type {!r} (object {!r})'.format(type_name, name))
        pddl_problem.add_start_predicate(start_pred, name)  # grounded with the object name

    # The start position is the problem symbol whose mean has the smallest
    # z-value (component 1 of each mean); ties resolve to the earliest symbol.
    means = load(make_path(pddl_dir, "psymbol_means.pkl"))
    if len(means) == 0:
        raise ValueError('psymbol_means.pkl contains no problem-symbol means')
    start_idx = min(range(len(means)), key=lambda k: means[k][1])
    start_symbol = 'psymbol_{}'.format(start_idx)

    # Find the untyped proposition with that name, skipping index 0 ('notfailed').
    for i, predicate in enumerate(pddl.vocabulary):
        if i > 0 and not isinstance(predicate, TypedPredicate) and predicate.name == start_symbol:
            pddl_problem.add_start_predicate(predicate)
            break
    else:
        raise ValueError('start proposition {!r} not found in the vocabulary'.format(start_symbol))

    # TODO come back to this: the goal is inserted manually.
    chest_idx = 6
    chest_name = object_names[chest_idx]
    chest_type = pddl.type_name(chest_idx)
    # Assumption: the last predicate referring to the type holds in the goal state.
    goal_pred = find_last(pddl.vocabulary, chest_type)
    if goal_pred is None:
        raise ValueError('no typed predicate found for goal type {!r}'.format(chest_type))
    pddl_problem.add_goal_predicate(goal_pred, chest_name)  # grounded with the object name

    problem_text = str(pddl_problem)
    print(problem_text)

    with open(make_path(pddl_dir, 'problem.pddl'), 'w', encoding='utf-8') as file:
        file.write(problem_text)


if __name__ == '__main__':
    main()


