import warnings
from datetime import datetime

import pickle
import numpy as np
from domain.levels import Task
from experiment.effect_class import UnionFind, EffectClass
from symbols.data.data import load_operators
from symbols.file_utils import make_path, save, load
from symbols.pddl.description import Description
from symbols.pddl.operator import Operator, TypedOperator
from symbols.pddl.predicate import TypedPredicate
from symbols.pddl.typed_description import TypedDescription
from symbols.pddl.typed_pddl import TypedPDDL
from symbols.transfer.kb import KnowledgeBase
from symbols.transfer.transferable_learned_operator import TransferableLearnedOperator
from symbols.transfer.transferable_operator import TransferableOperator


def find_replacer(proposition, candidates, type_getter):
    """
    Find from the list of candidates which proposition was used to replace the given one.

    :param proposition: the proposition to find a replacement for. Must expose ``mask``
        (an iterable of object indices) and ``kl_divergence(other)``.
    :param candidates: iterable of pairs whose element 0 is a candidate proposition
        (with a ``mask``) and element 1 is the value to return if that candidate wins.
    :param type_getter: callable mapping an object index to its type id.
    :return: element 1 of the candidate whose ``mask`` equals the types of
        ``proposition``'s objects and whose KL divergence from ``proposition`` is the
        smallest, or ``None`` if no candidate qualifies. On ties the earliest candidate
        is kept (strict ``<`` comparison).
    """
    # The object types depend only on `proposition`, so compute them once rather than
    # once per candidate.
    types = [type_getter(i) for i in proposition.mask]

    best_dist = np.inf
    closest = None
    for other in candidates:
        candidate, replacement = other[0], other[1]
        if not np.array_equal(types, candidate.mask):
            continue
        d = proposition.kl_divergence(candidate)
        if d < best_dist:
            best_dist = d
            closest = replacement
    return closest  # TODO: what if multiple candidates are equally close?


def _lookup_predicate(name, predicates, replaced):
    """
    Return the kept typed predicate called ``name``, or the predicate recorded in
    ``replaced`` for it when ``name`` was dropped from the vocabulary.

    The value from ``replaced`` may be ``None`` (no replacement was found); callers must
    handle that. Raises ``KeyError`` when ``name`` is neither kept nor in ``replaced``.
    """
    matching = [p for p in predicates if p.name == name]
    assert len(matching) < 2, "Duplicate predicate name {}".format(name)
    if matching:
        return matching[0]
    return replaced[name]


def _ground(predicate, proposition, negate=False):
    """
    Ground ``predicate`` (optionally negated) on the object indices in ``proposition.mask``.

    Raises ``ValueError`` when ``predicate`` is ``None``, i.e. the proposition was dropped
    and ``find_replacer`` found no replacement for it.
    """
    if predicate is None:
        raise ValueError("No typed predicate or replacement found for proposition {}"
                         .format(proposition.name))
    if negate:
        predicate = predicate.negate()
    return predicate(*proposition.mask)


def lift_pddl(task_id, pddl: Description, kb: KnowledgeBase):
    """
    Lift the learned PDDL description of one task into a typed PDDL description.

    Vocabulary: each symbol of ``pddl`` is mapped to the type ids of its objects (via
    ``kb.get_type``). A symbol is kept as a ``TypedPredicate`` only when every object's
    type id equals the object index itself. Kept predicates are registered with ``kb`` and
    the typed PDDL. Every other symbol is dropped and mapped, via ``find_replacer``, to the
    kept predicate with matching types and the smallest KL divergence (or ``None``).

    Operators: each schema becomes a ``TypedOperator`` named
    ``"<option_name>-partition-<partition>"`` plus a ``TransferableOperator`` that records
    the original propositions. A schema is skipped if any of its preconditions does not
    contain exactly one string entry (presumably the ``notfailed`` marker). When a
    schema has more than one rule, the first rule's probability is stored as
    ``operator.failure_probability`` and that rule contributes no effects.

    :param task_id: task whose object types are looked up in ``kb``.
    :param pddl: the learned (untyped) description.
    :param kb: knowledge base; ``add_predicate`` and ``add_operator`` are called on it.
    :return: ``(typed_pddl, replaced, transferable_operators)`` where ``replaced`` maps
        each dropped proposition name to its replacement predicate (or ``None``).
    :raises ValueError: if a dropped proposition with no replacement is used by a kept
        operator.
    """

    def type_getter(object_idx):
        return kb.get_type(task_id, object_idx)

    transferable_operators = list()
    typed_pddl = TypedPDDL(task_id, 'task{}'.format(task_id), kb)

    # create typed vocabulary
    predicates = list()
    all_types = set()
    dropped = list()
    kept = list()  # (proposition, predicate) pairs, used as replacement candidates
    for proposition in pddl.get_symbols():
        # get the type of each object represented by the proposition
        types = [type_getter(i) for i in proposition.mask]
        types_names = ['type{}'.format(i) for i in types]
        all_types.update(types_names)
        if np.array_equal(types, proposition.mask):
            # an original predicate. Keep in list
            predicate = TypedPredicate(proposition.name, *types_names)
            predicates.append(predicate)
            kb.add_predicate(predicate)
            typed_pddl.add_predicate(predicate)
            kept.append((proposition, predicate))
        else:
            print("Dropping {}".format(proposition.name))
            dropped.append(proposition)
    typed_pddl.set_types(all_types)

    replaced = {proposition.name: find_replacer(proposition, kept, type_getter)
                for proposition in dropped}

    for schema in pddl._schemata:
        operator = TypedOperator("{}-partition-{}".format(schema.option_name, schema.partition), schema.option,
                                 schema.partition)  # this operator is just to generate a PDDL description

        transfer = TransferableOperator(operator)  # this operator will contain information necessary for transfer

        if len(schema._preconditions) > 1:
            warnings.warn("Schema has more than 1 precondition! Don't know if I accounted for that!")

        # add preconditions
        valid = True
        for precondition in schema._preconditions:
            matches = list()
            for proposition in precondition:
                if isinstance(proposition, str):
                    continue  # ignore notfailed
                # Propositions outside the vocabulary are mapped to their replacement.
                predicate = _lookup_predicate(proposition.name, predicates, replaced)
                matches.append((predicate, proposition))

            # Every non-string entry is matched above, so this only fails when the
            # precondition does not contain exactly one string entry.
            if len(matches) != len(precondition) - 1:
                valid = False
                break  # the operator is discarded, so stop grounding it
            for predicate, proposition in matches:
                operator.add_precondition(_ground(predicate, proposition))
                transfer.add_precondition(proposition)

        if not valid:
            continue

        # add effects
        rules = schema.rules
        if len(rules) > 1:
            warnings.warn("Must fix probabilistic effects!")
            # NOTE(review): assumes rules[0] is the failure outcome — confirm.
            operator.failure_probability = rules[0].probability
            rules = rules[1:]

        for rule in rules:
            for wrapper in rule.symbols:
                proposition = wrapper.symbol
                if isinstance(proposition, str):
                    continue  # TODO FIX!!!
                predicate = _lookup_predicate(proposition.name, predicates, replaced)
                operator.add_effect(_ground(predicate, proposition, negate=wrapper.sign < 0))
                transfer.add_effect(wrapper)

        typed_pddl.add_operator(operator)
        transferable_operators.append(transfer)
        kb.add_operator(transfer)

    return typed_pddl, replaced, transferable_operators


if __name__ == '__main__':

    kb = KnowledgeBase()

    seeds = [31, 33, 76, 82, 92]
    task_id = 0
    # Earlier runs read from dated experiment folders (e.g. make_path('20200106', task_id));
    # the data now lives in a fixed relative directory.
    task_dir = '../data'
    operator_dir = make_path(task_dir, 'learned_operators')
    pddl_dir = make_path(task_dir, 'pddl')

    env, doors, objects = Task.generate(seeds[task_id])

    operators = load_operators(operator_dir)
    #             0     1      2     3    4       5       6      7    8       9
    # objects are pov, door, door, door, door,  pickaxe, chest, ore, gold, inventory

    # all the doors and objects, minus the craft table, plus the agent and the inventory
    n_objects = len(doors) + len(objects) - 1 + 2

    uf = UnionFind(n_objects)

    # Record, per (option, object), the effect KDEs of every learned operator.
    effect_classes = EffectClass(env.action_space.n, n_objects)  # map from objects to effect list
    for operator in operators:
        for prob, eff in zip(operator.list_probabilities, operator.list_effects):
            for obj, kde in zip(eff.mask, eff._kdes):
                # One KDE per affected object. An earlier version passed the whole
                # eff._kdes list here; it is unclear which is correct.
                effect_classes.add(operator.option, obj, prob, [kde])

    # Merge objects with matching effect profiles. Object 0 (pov) and the last object
    # (inventory) are never compared.
    for i in range(1, n_objects - 1):
        for j in range(i + 1, n_objects - 1):
            if uf.get(i) == uf.get(j):
                print("{} and {} are the same".format(i, j))
                continue

            match, options = effect_classes.is_same(i, j)
            if match:
                uf.union(i, j)
                print("{} and {} are the same".format(i, j))
            elif len(options) > 0:
                print("{} and {} matches on {}".format(i, j, options))

    # Register each object's type (its union-find representative) and effects with the KB.
    # NOTE(review): a previous version tracked already-seen types and set effects to None
    # for repeats, but the effects were then overwritten unconditionally, so every object
    # has always been registered with its own effects. That behaviour is kept here;
    # confirm whether per-type deduplication was intended.
    for i in range(n_objects):
        kb.set_type(task_id, i, uf.get(i), effect_classes.get(i))

    pddl_description = load(make_path(pddl_dir, "pddl_rules.pkl"))
    save(uf, make_path(pddl_dir, "classes.pkl"))

    typed_pddl, replaced, transferable_operators = lift_pddl(task_id, pddl_description, kb)
    typed_pddl.probabilistic = True
    print(typed_pddl)

    save(replaced, make_path(pddl_dir, "replaced.dat"))
    save(typed_pddl, make_path(pddl_dir, "lifted_pddl.pkl"))
    save(typed_pddl, make_path(pddl_dir, "unlinked.pddl"), binary=False)
    save(transferable_operators, make_path(pddl_dir, "transferable.pkl"))
    save(kb, make_path(pddl_dir, "knowledge.kb"))
    print(kb)
