import pickle
import warnings
from collections import defaultdict
from datetime import datetime

from sklearn.cluster import DBSCAN

from domain.levels import Task
from symbols.data.data import load_option_partitions, load_operators
from symbols.experimental.partition_options import _flatten
from symbols.file_utils import make_path, make_dir, load, save
import numpy as np

import matplotlib.pyplot as plt
from matplotlib import colors

from symbols.link.link import OperatorData, ProblemSymbols
from symbols.logger.transition_sample import TransitionSample
from symbols.pddl.operator import TypedOperator
from symbols.pddl.predicate import Predicate, TypedPredicate
from symbols.utils import samples2np
import copy


def show(mean):
    data = np.zeros(shape=(41, 11))
    data[0, :] = 1
    data[-1, :] = 1
    data[:, 0] = 1
    data[:, -1] = 1

    data[-7 - 1, :] = 1
    data[-15 - 1, :] = 1
    data[-23 - 1, :] = 1
    data[-31 - 1, :] = 1

    row, col = 39 - int(mean[1]), int(10 - mean[0])
    data[row, col] = 0.5

    cmap = colors.ListedColormap(['Blue', 'red'])
    plt.figure(figsize=(6, 6))
    plt.pcolor(data[::-1])
    # plt.xlim(12, -1)  # decreasing time
    plt.show()


def visualise_symbols(dir, symbols):
    make_dir(dir)
    for i, d in enumerate(symbols):
        mean = np.mean(d, axis=0)
        data = np.zeros(shape=(41, 11))

        data[0, :] = 1
        data[-1, :] = 1
        data[:, 0] = 1
        data[:, -1] = 1

        data[-7 - 1, :] = 1
        data[-15 - 1, :] = 1
        data[-23 - 1, :] = 1
        data[-31 - 1, :] = 1

        row, col = 39 - int(mean[1]), int(10 - mean[0])
        data[row, col] = 0.5

        cmap = colors.ListedColormap(['Blue', 'red'])
        plt.figure(figsize=(6, 6))
        plt.pcolor(data[::-1])
        # plt.xlim(12, -1)  # decreasing time
        plt.savefig(make_path(dir, '{}-{}.png'.format(i, mean)))


def _get_types(operators):
    types = set()
    for operator in operators:
        for predicate in operator.preconditions:
            if isinstance(predicate, TypedPredicate):
                for x in predicate.param_types:
                    types.add(x)
    return types


def extract_symbol_names(schema):
    """Return the names of all non-string symbols referenced by ``schema``.

    Two sources are scanned:
    - Every entry of every precondition in ``schema._preconditions``.
    - The ``.symbol`` of every entry in each rule's ``symbols``.

    Entries that are plain strings are skipped. Everything else contributes
    its ``.name``.

    :param schema: object exposing ``_preconditions`` and ``rules``
    :return: set of symbol names
    """
    precondition_names = {
        entry.name
        for precondition in schema._preconditions
        for entry in precondition
        if not isinstance(entry, str)
    }
    rule_names = {
        symbol.name
        for rule in schema.rules
        for symbol in (wrapper.symbol for wrapper in rule.symbols)
        if not isinstance(symbol, str)
    }
    return precondition_names | rule_names


def replace(names, replaced):
    """Map each name through ``replaced`` and return the results as a set.

    Names that have no entry in ``replaced`` pass through unchanged.

    :param names: iterable of names
    :param replaced: mapping of old name to new name
    :return: set of (possibly substituted) names
    """
    return {replaced[name] if name in replaced else name for name in names}


def _find_closest_match(pddl, map, partitioned_option, replaced):
    # rules = load(make_path(pddl_dir, "pddl_rules.pkl"))
    # old_schemata = [schema for schema in rules if
    #                 schema.option == partitioned_option.option and schema.partition == partitioned_option.partition]
    #
    candidates = [x for (option, partition), x in map.items() if
                  partitioned_option.option == option and partitioned_option.partition != partition]
    # matches = list()
    # for schema in old_schemata:
    #     names = extract_symbol_names(schema)
    #     names = replace(names, replaced)
    #
    #     for candidate in candidates:
    #         for operator in candidate:
    #             other = operator.extract_symbol_names()
    #             if other == names:
    #                 matches.append(candidate)
    #
    # if len(matches) != 0:
    #     return matches


    types = {'type{}'.format(pddl.object_type(x)) for x in partitioned_option._mask}
    matches = list()
    for candidate in candidates:
        if types == _get_types(candidate):
            matches += candidate

    return matches


def extract_raw_operator(option, partition, operators):
    """Return the unique learned operator for ``(option, partition)``.

    :param option: value compared against each ``operator.option``
    :param partition: value compared against each ``operator.partition``
    :param operators: iterable of operators exposing ``option`` and ``partition``
    :return: the single operator whose option and partition both match
    :raises ValueError: if zero or more than one operator matches; the message
        names the option, the partition and the number of matches found
    """
    ops = [x for x in operators if x.option == option and x.partition == partition]
    if len(ops) != 1:
        raise ValueError('Expected exactly one operator for option {} partition {}, found {}'.format(
            option, partition, len(ops)))
    return ops[0]


if __name__ == '__main__':
    # Link learned operators to their lifted PDDL schemata, build the
    # problem-specific symbols, and write the linked (P)PDDL domain plus
    # supporting pickles for a single task.

    seeds = [31, 33, 76, 82, 92]
    # Run directory to process. Earlier runs used '20190918', '20191119' and
    # '20191204_full'. Use datetime.today().strftime('%Y%m%d') for a fresh run.
    directory = '20200106'
    task_id = 0
    task_dir = make_path(directory, task_id)
    partition_dir = make_path(task_dir, 'partitioned_options')
    operator_dir = make_path(task_dir, 'learned_operators')
    pddl_dir = make_path(task_dir, 'pddl')

    env, _, _ = Task.generate(seeds[task_id])

    partitions = load_option_partitions(env.action_space, partition_dir)
    operators = load_operators(operator_dir)

    pddl = load(make_path(pddl_dir, "lifted_pddl.pkl"))

    # Only consumed by _find_closest_match, which does not currently use it.
    replaced = load(make_path(pddl_dir, "replaced.dat"))

    # Group the lifted schemata by the (option, partition) they came from.
    # NOTE: this is a defaultdict, so the lookup in the loop below inserts an
    # empty list for every (option, partition) that has no schemata.
    idx_to_schema = defaultdict(list)
    for operator in pddl.operators():
        idx_to_schema[(operator.option, operator.partition)].append(operator)

    operator_data = list()

    for option, ps in partitions.items():
        for partitioned_option in ps:
            if partitioned_option.option != option:
                raise ValueError('Partitioned option for option {} was filed under option {}'.format(
                    partitioned_option.option, option))

            schemata = idx_to_schema[partitioned_option.option, partitioned_option.partition]
            if len(schemata) == 0:
                # No schema for this exact partition: borrow from sibling
                # partitions of the same option.
                schemata = _find_closest_match(pddl, idx_to_schema, partitioned_option, replaced)

            if len(schemata) == 0:
                warnings.warn('No PDDL schemata found for option {} partition {}'.format(
                    option, partitioned_option.partition))

            data = OperatorData(partitioned_option, schemata,
                                extract_raw_operator(option, partitioned_option.partition, operators))
            operator_data.append(data)

    save(operator_data, make_path(pddl_dir, "operator_data.dat"))

    # The knowledge base receives deep copies. Presumably this isolates it
    # from the add_problem_symbols calls below — confirm.
    kb_operator_data = copy.deepcopy(operator_data)
    kb = load(make_path(pddl_dir, "knowledge.kb"))
    for od in kb_operator_data:
        kb.add_operator_data(task_id, od)
    save(kb, make_path(pddl_dir, "knowledge.kb"))

    problem_symbols = ProblemSymbols()

    for operator in operator_data:
        # NOTE(review): `i` is unused, so the identical link loop below runs
        # n_subpartitions times per operator. Confirm that ProblemSymbols.add
        # and add_problem_symbols tolerate the repeats, or drop this loop.
        # TODO will probably break for stochastic transitions
        for i in range(operator.n_subpartitions):
            for link in operator.links:
                for start, end in link:
                    precondition = problem_symbols.add(start)
                    if end is None:
                        # A link without an end state is recorded with effect -1.
                        operator.add_problem_symbols(pddl, precondition, -1)
                    else:
                        effect = problem_symbols.add(end)
                        operator.add_problem_symbols(pddl, precondition, effect)
                    print(operator.option, operator.partition)

    # One predicate, named psymbol_<index>, per problem symbol.
    for i in range(len(problem_symbols)):
        pddl.add_predicate(Predicate('psymbol_{}'.format(i)))

    # Render the domain twice: once with the `probabilistic` flag off, once with it on.
    pddl.probabilistic = False
    with open(make_path(pddl_dir, 'linked.pddl'), 'w') as file:
        file.write(str(pddl))

    pddl.probabilistic = True
    with open(make_path(pddl_dir, 'probabilistic_linked.ppddl'), 'w') as file:
        file.write(str(pddl))

    with open(make_path(pddl_dir, "linked_pddl.pkl"), 'wb') as file:
        pickle.dump(pddl, file)

    print(pddl)

    # Uncomment to render one image per problem symbol:
    # visualise_symbols(make_path(task_dir, 'vis_p_symbols'), problem_symbols._symbols)
    save(problem_symbols.means(), make_path(pddl_dir, "psymbol_means.pkl"))
