import copy
import json
import os
import pickle
import random
import urllib.request
from collections import defaultdict

import numpy as np

from domain.inventory import Inventory
from domain.levels import Task
from experiment.collect import collect_data
from experiment.collect_compressed import extract_data, get_initial_states
from experiment.effect_class import UnionFind, EffectClass
from experiment.generate_forward_model import generate_symbol_vocabulary, build_forward_model
from experiment.generate_problem_pddl import Problem, find_first, find_last
from experiment.link_problem_data import _find_closest_match, extract_raw_operator, visualise_symbols
from experiment.merge_similar import lift_pddl
from pca.base_pca import PCA_N
from pca.pca import PCA
from symbols.data.data import load_operators, load_option_partitions
from symbols.experimental.learn_effects import learn_effects
from symbols.experimental.learn_preconditions import learn_preconditions
from symbols.experimental.partition_options import partition_options
from symbols.file_utils import make_path, make_dir, load, save
from symbols.link.link import OperatorData, ProblemSymbols
from symbols.render.image import Image
from symbols.render.render import visualise_partitions, visualise_effects, visualise_preconditions, \
    visualise_ppddl_symbols
from symbols.symbols.learned_operator import LearnedOperator
from symbols.transfer.kb import KnowledgeBase
from symbols.pddl.predicate import Predicate, TypedPredicate

PCA_PATH = os.path.abspath('pca_models/full_pca.dat')


def render_masked(mask, states):
    """
    Render every state variable of every state as an image array, grouped by mask index.

    :param mask: the index of the state variable found at each position of a state. Must have exactly one
        entry per column of ``states``
    :param states: an array exposing ``shape``; it is iterated row by row and the entries of each row are
        paired up, in order, with the entries of ``mask``
    :return: a ``defaultdict`` mapping each mask index to the list of images rendered for it (one per state)
    :raises ValueError: if ``mask`` is ``None`` or its length differs from ``states.shape[1]``
    """
    # Validate the arguments first so that bad input fails before the PCA model is loaded
    if mask is None:
        raise ValueError("mask must not be None")
    if len(mask) != states.shape[1]:
        raise ValueError("mask has {} entries but each state has {}".format(len(mask), states.shape[1]))

    pca = PCA(PCA_N)
    pca.load(PCA_PATH)

    images = defaultdict(list)
    for state in states:
        for m, x in zip(mask, state):
            if len(x) > 10:
                # entries with more than 10 values are treated as PCA-compressed images
                image = pca.unflatten(pca.uncompress_(x))
            else:
                # shorter entries are treated as vectors: round to whole numbers, then draw via Inventory
                image = Image.to_array(Inventory.to_image(np.rint(x)))
            images[m].append(image)
    return images


def render_preconditions(mask, states):
    """
    Render the masked state variables of every state as image arrays, grouped by mask index.

    :param mask: the indices of the state variables to render. ``None`` selects ``[0, 9]``
    :param states: an iterable of states. Each state is indexed with the whole mask (``state[mask]``), so it
        must support indexing by a list of indices
    :return: a ``defaultdict`` mapping each mask index to the list of images rendered for it (one per state)
    """
    pca = PCA(PCA_N)
    pca.load(PCA_PATH)

    # The default is resolved once, up front. (A second ``mask is None`` check used to sit inside the loop, but
    # it could never be reached after this assignment.)
    if mask is None:
        mask = [0, 9]

    images = defaultdict(list)
    for state in states:
        for m, x in zip(mask, state[mask]):
            if len(x) > 10:
                # entries with more than 10 values are treated as PCA-compressed images
                image = pca.unflatten(pca.uncompress_(x))
            else:
                # shorter entries are treated as vectors and drawn via Inventory, without rounding
                image = Image.to_array(Inventory.to_image(x))
            images[m].append(image)
    return images


def render_effects(mask, states):
    """
    Render the state variables changed by an effect as image arrays, grouped by mask index.

    :param mask: the index of the state variable found at each position of a state. Must have exactly one
        entry per column of ``states``
    :param states: an array exposing ``shape``; it is iterated row by row and the entries of each row are
        paired up, in order, with the entries of ``mask``
    :return: a ``defaultdict`` mapping each mask index to the list of images rendered for it (one per state)
    :raises ValueError: if ``mask`` is ``None`` or its length differs from ``states.shape[1]``
    """
    # Reject bad arguments before paying for the PCA model load
    if mask is None:
        raise ValueError("mask must not be None")
    n_variables = states.shape[1]
    if len(mask) != n_variables:
        raise ValueError("mask has {} entries but each state has {}".format(len(mask), n_variables))

    pca = PCA(PCA_N)
    pca.load(PCA_PATH)

    images = defaultdict(list)
    for state in states:
        for m, x in zip(mask, state):
            if len(x) > 10:
                # entries with more than 10 values are treated as PCA-compressed images
                image = pca.unflatten(pca.uncompress_(x))
            else:
                # shorter entries are treated as vectors: round to whole numbers, then draw via Inventory
                image = Image.to_array(Inventory.to_image(np.rint(x)))
            images[m].append(image)
    return images


def render_partition(_, states):
    """
    Render state variables 0 and 9 of every state as image arrays, grouped by variable index.

    :param _: ignored; the rendered variables are always 0 and 9
    :param states: an iterable of states, each of which supports indexing by a list of indices
    :return: a ``defaultdict`` mapping 0 and 9 to the list of images rendered for them (one per state)
    """
    rendered = defaultdict(list)
    model = PCA(PCA_N)
    model.load(PCA_PATH)
    shown = [0, 9]
    for state in states:
        selected = state[shown]
        for index, entry in zip(shown, selected):
            if len(entry) <= 10:
                # short entries are treated as vectors and drawn via Inventory
                picture = Image.to_array(Inventory.to_image(entry))
            else:
                # entries with more than 10 values are treated as PCA-compressed images
                picture = model.unflatten(model.uncompress_(entry))
            rendered[index].append(picture)
    return rendered


if __name__ == '__main__':
    # Runs the whole pipeline for a single task: compress the collected data, partition the options, learn
    # operators, build a propositional PDDL model, lift it to typed PDDL, link it to problem-specific symbols,
    # write the problem file and finally post domain + problem to an online planner.
    # Every step reads the files written by the previous ones from output_dir.

    TASK_ID = 0
    RAW = True
    seeds = [31, 33, 76, 82, 92]
    n_episodes = 15
    random.seed(seeds[TASK_ID])
    np.random.seed(seeds[TASK_ID])
    env, _, _ = Task.generate(seeds[TASK_ID])

    raw_dir_name = '../data/raw_data'
    output_dir = '../data'

    # 1. Collect raw image data (disabled here; the later steps read from raw_dir_name)
    # collect_data(PCA_PATH, seeds[TASK_ID], raw_dir_name, n_episodes, 300, verbose=True, raw=True)

    # 2. Compress to PCA
    pca = PCA(PCA_N)
    pca.load(PCA_PATH)
    extract_data(TASK_ID, raw_dir_name, pca, output_dir, n_episodes, verbose=True)

    # 3. Partition options
    partition_dir = make_path(output_dir, 'partitioned_options')
    partition_options(env, output_dir, partition_dir, view='agent', )
    visualise_partitions(env, partition_dir, make_path(output_dir, 'vis_partitioned'), render_partition, verbose=True)

    # 4. Learn preconditions and effects, then save one operator per (option, partition) that has both
    operator_dir = make_path(output_dir, 'learned_operators')
    # NOTE(review): verbose is passed as the string 'True' (truthy) - confirm whether a bool was intended
    preconditions = learn_preconditions(env, output_dir, partition_dir, view='agent', verbose='True', render=None)
    effects = learn_effects(env, partition_dir, view='agent', verbose='True')
    make_dir(operator_dir, clean=False)
    operators = list()
    for (option, partition), precondition in preconditions.items():
        if (option, partition) in effects:
            probs, rewards, eff = effects[(option, partition)]
            operators.append(LearnedOperator(option, partition, precondition, probs, rewards, eff))
    for operator in operators:
        filename = make_path(operator_dir, 'operator-{}-{}.dat'.format(operator.option, operator.partition))
        with open(filename, "wb") as file:
            pickle.dump(operator, file)
    visualise_effects(env, make_path(output_dir, 'vis_effects'), render_effects, effects, view='agent')
    visualise_preconditions(env, output_dir, make_path(output_dir, 'vis_preconds'), render_preconditions, preconditions,
                            view='agent')

    # 5. Generate propositional PDDL
    pddl_dir = make_path(output_dir, 'pddl')
    env, doors, objects = Task.generate(seeds[TASK_ID])

    operators = load_operators(operator_dir)
    # all the doors and objects - craft table, the agent and the inventory
    # (i.e. one entry is dropped from the doors/objects count, then two are added)
    n_objects = len(doors) + len(objects) - 1 + 2

    pca = PCA(PCA_N)
    pca.load(PCA_PATH)

    initial_states = get_initial_states(TASK_ID, raw_dir_name, pca, 10)
    # keep only the second element of every initial-state entry
    init_obs = [x[1] for x in initial_states]

    # generate vocabulary
    propositions_dir = make_path(output_dir, 'propositions')
    (factors, _), (option_symbols, symbol_list) = generate_symbol_vocabulary(env, propositions_dir,
                                                                             operators, n_objects, init_obs)

    visualise_ppddl_symbols(propositions_dir, make_path(output_dir, 'vis_symbols'), env, render_masked)

    # build pddl model
    ppddl_symbol_dir = make_path(output_dir, 'ppddl_symbols')
    rules = build_forward_model(env, output_dir, ppddl_symbol_dir, operators, factors, option_symbols, symbol_list)
    make_dir(pddl_dir)
    with open(make_path(pddl_dir, "pddl_rules.pkl"), 'wb') as file:
        pickle.dump(rules, file)
    # The rules are read straight back from the file that was just written, so everything below works with
    # the unpickled copy (presumably a check that the rules survive pickling)
    with open(make_path(pddl_dir, "pddl_rules.pkl"), 'rb') as file:
        rules = pickle.load(file)
    with open(make_path(pddl_dir, 'propositional.pddl'), 'w') as file:
        file.write(str(rules))

    # 6. Merge similar object types
    kb = KnowledgeBase()
    env, doors, objects = Task.generate(seeds[TASK_ID])

    operators = load_operators(operator_dir)
    #             0     1      2     3    4       5       6      7    8       9
    # objects are pov, door, door, door, door,  pickaxe, chest, ore, gold, inventory
    uf = UnionFind(n_objects)
    # compute merge
    effect_classes = EffectClass(env.action_space.n, n_objects)  # map from objects to effect list
    for operator in operators:
        for prob, eff in zip(operator.list_probabilities, operator.list_effects):
            # each object in the effect's mask is registered with its own single KDE
            for obj, kde in zip(eff.mask, eff._kdes):
                effect_classes.add(operator.option, obj, prob, [kde])

    # Compare every pair of objects, skipping object 0 and the last object: both ranges stop short of
    # n_objects - 1 and the outer one starts at 1.
    for i in range(1, n_objects - 1):
        for j in range(i + 1, n_objects - 1):
            if uf.get(i) == uf.get(j):
                # already merged through an earlier pair
                print("{} and {} are the same".format(i, j))
                continue
            is_match, options = effect_classes.is_same(i, j)
            if is_match:
                uf.union(i, j)
                print("{} and {} are the same".format(i, j))
            elif len(options) > 0:
                print("{} and {} matches on {}".format(i, j, options))

    # NOTE(review): the previous code set the effects to None for an object whose type had already been seen,
    # but then unconditionally overwrote that value, so every object has always been stored with its own
    # effects. That runtime behaviour is kept here; confirm whether passing None for repeated types was
    # actually intended.
    for i in range(n_objects):
        kb.set_type(TASK_ID, i, uf.get(i), effect_classes.get(i))

    rules = load(make_path(pddl_dir, "pddl_rules.pkl"))
    save(uf, make_path(pddl_dir, "classes.pkl"))

    typed_pddl, replaced, transferable_operators = lift_pddl(TASK_ID, rules, kb)
    # the flag is switched off only while the deterministic version is written out
    typed_pddl.probabilistic = False
    save(typed_pddl, make_path(pddl_dir, "unlinked_deterministic.pddl"), binary=False)
    typed_pddl.probabilistic = True
    print(typed_pddl)

    save(replaced, make_path(pddl_dir, "replaced.dat"))
    save(typed_pddl, make_path(pddl_dir, "lifted_pddl.pkl"))
    save(typed_pddl, make_path(pddl_dir, "unlinked.pddl"), binary=False)
    save(transferable_operators, make_path(pddl_dir, "transferable.pkl"))
    save(kb, make_path(pddl_dir, "knowledge.kb"))
    print(kb)

    # 7. Link
    env, _, _ = Task.generate(seeds[TASK_ID])

    partitions = load_option_partitions(env.action_space, partition_dir)
    operators = load_operators(operator_dir)

    pddl = load(make_path(pddl_dir, "lifted_pddl.pkl"))

    replaced = load(make_path(pddl_dir, "replaced.dat"))

    # group the lifted operators by the (option, partition) they came from
    idx_to_schema = defaultdict(list)
    for operator in pddl.operators():
        idx_to_schema[(operator.option, operator.partition)].append(operator)

    operator_data = list()
    for option, ps in partitions.items():
        for partitioned_option in ps:
            if partitioned_option.option != option:
                raise ValueError("partition listed under option {} belongs to option {}".format(
                    option, partitioned_option.option))

            schemata = idx_to_schema[partitioned_option.option, partitioned_option.partition]
            if len(schemata) == 0:
                # no schema for this exact partition: fall back to the closest match
                schemata = _find_closest_match(pddl, idx_to_schema, partitioned_option, replaced)

            if len(schemata) == 0:
                print("XXXXXAAAHA {} {}".format(option, partitioned_option.partition))

            data = OperatorData(partitioned_option, schemata,
                                extract_raw_operator(option, partitioned_option.partition, operators))
            operator_data.append(data)

    save(operator_data, make_path(pddl_dir, "operator_data.dat"))

    # the knowledge base receives deep copies, so it shares no objects with operator_data used below
    temp = copy.deepcopy(operator_data)
    kb = load(make_path(pddl_dir, "knowledge.kb"))
    for od in temp:
        kb.add_operator_data(TASK_ID, od)
    save(kb, make_path(pddl_dir, "knowledge.kb"))

    problem_symbols = ProblemSymbols()

    for operator in operator_data:
        # NOTE(review): the subpartition index is never used, so all links are processed once per
        # subpartition - confirm that this repetition is intended
        # TODO will probably break for stochastic transitions
        for _ in range(operator.n_subpartitions):
            for link in operator.links:
                for start, end in link:
                    precondition = problem_symbols.add(start)
                    if end is None:
                        # no end symbol for this link: -1 is passed in place of an effect symbol
                        operator.add_problem_symbols(pddl, precondition, -1)
                    else:
                        effect = problem_symbols.add(end)
                        operator.add_problem_symbols(pddl, precondition, effect)
                    print(operator.option, operator.partition)

    # one untyped predicate per problem symbol
    for i in range(len(problem_symbols)):
        pddl.add_predicate(Predicate('psymbol_{}'.format(i)))

    pddl.probabilistic = False
    with open(make_path(pddl_dir, 'linked.pddl'), 'w') as file:
        file.write(str(pddl))

    pddl.probabilistic = True
    with open(make_path(pddl_dir, 'probabilistic_linked.ppddl'), 'w') as file:
        file.write(str(pddl))

    with open(make_path(pddl_dir, "linked_pddl.pkl"), 'wb') as file:
        pickle.dump(pddl, file)

    print(pddl)

    visualise_symbols(make_path(output_dir, 'vis_p_symbols'), problem_symbols._symbols)
    save(problem_symbols.means(), make_path(pddl_dir, "psymbol_means.pkl"))

    print("AND DONE!")

    # 8. Generate problem PDDL
    pddl_dir = make_path(output_dir, 'pddl')
    pddl = load(make_path(pddl_dir, "linked_pddl.pkl"))

    pddl_problem = Problem('open-chest', 'task{}'.format(TASK_ID))

    # hard-coded names, in the same order as the object indices listed in step 6
    object_names = ['agent', 'door1', 'door2', 'door3', 'door4', 'pickaxe', 'chest', 'ore', 'gold', 'inventory']
    for i, name in enumerate(object_names):
        pddl_problem.add_object(i, name, pddl.type_name(i))

    # about to make some assumptions here!
    pddl_problem.add_start_predicate(pddl.vocabulary[0])  # notfailed is always first
    for i, name in enumerate(object_names):
        type_name = pddl.type_name(i)
        parent = pddl.get_parent_type(type_name)
        if parent is not None:
            type_name = parent
        # assume that the first predicate referring to the type is in the start state
        pred = find_first(pddl.vocabulary, type_name)
        pddl_problem.add_start_predicate(pred, name)  # add grounded predicate with name

    # get problem symbols
    # the start is the symbol whose mean has the smallest value at index 1 (the first one wins on ties)
    means = load(make_path(pddl_dir, "psymbol_means.pkl"))
    idx = 0
    small = means[0][1]
    for i, mean in enumerate(means):
        if mean[1] < small:
            small = mean[1]
            idx = i
    start_symbol_name = 'psymbol_{}'.format(idx)
    for i, predicate in enumerate(pddl.vocabulary):
        # the untyped predicate named after the chosen problem symbol marks the start position
        if i > 0 and not isinstance(predicate, TypedPredicate) and predicate.name == start_symbol_name:
            pddl_problem.add_start_predicate(predicate)
            break

    # TODO come back to this: manually insert goal
    chest_idx = 6
    name = object_names[chest_idx]
    type_name = pddl.type_name(chest_idx)
    # assume that the last predicate referring to the type is in the goal state
    pred = find_last(pddl.vocabulary, type_name)
    pddl_problem.add_goal_predicate(pred, name)  # add grounded predicate with name

    print(pddl_problem)

    with open(make_path(pddl_dir, 'problem.pddl'), 'w') as file:
        file.write(str(pddl_problem))

    # the planner is sent the deterministic domain together with the problem that was just written
    with open(make_path(pddl_dir, 'linked.pddl'), 'r') as file:
        domain = file.read()

    with open(make_path(pddl_dir, 'problem.pddl'), 'r') as file:
        problem = file.read()

    data = {'domain': domain,
            'problem': problem}

    params = json.dumps(data).encode('utf8')
    request = urllib.request.Request('http://solver.planning.domains/solve',
                                     headers={'Content-Type': 'application/json'},
                                     data=params)

    # If a proxy is required, configure it outside the source (urllib honours the http_proxy / https_proxy
    # environment variables); never commit proxy credentials to this file.

    # the context manager closes the HTTP response once the body has been read
    with urllib.request.urlopen(request) as response:
        resp = json.loads(response.read())

    result = resp['result']

    print(result['output'])

    if result['parse_status'] == 'ok':
        plan = result['plan']

        print('\n'.join([act['name'] for act in plan]))

        print()

        print('\n\n'.join([act['action'] for act in plan]))
