# HOO BOY...
from collections import defaultdict
from datetime import datetime

import numpy as np

from domain.inventory import Inventory
from domain.levels import Task
from experiment.collect_compressed import get_initial_states
from experiment.train_pca import PCA_N
from generate_forward_model import build_forward_model, generate_symbol_vocabulary
from pca.pca import PCA
from symbols.data.data import load_operators
from symbols.file_utils import make_path, make_dir, save, load
from symbols.pddl.typed_description import TypedDescription
from symbols.render.image import Image
import pickle

from symbols.render.render import visualise_ppddl_symbols

# Location of the serialised PCA model that is loaded to decompress long (image) state entries
PCA_PATH = make_path('pca_models/full_pca.dat')
# PCA implementation used for that model; constructed as PCA_CLASS(PCA_N) and then loaded from PCA_PATH
PCA_CLASS = PCA


# This function is only for states drawn from a distribution. The states should contain only the columns in the mask,
# and the ordering must correspond: the first column belongs to the first mask entry, the second to the second, etc.

def render_masked(mask, states):
    """Render every masked entry of every state as an image.

    Each entry is rendered according to its length. Entries with more than 10
    elements are decompressed with the PCA model stored at ``PCA_PATH``.
    Shorter entries are rounded to the nearest integer and drawn via
    ``Inventory.to_image``.

    Args:
        mask: Sequence of keys, one per column of ``states``. The i-th column
            of every state is grouped under ``mask[i]`` in the result.
        states: 2-D array whose rows are states and whose columns line up
            with ``mask``.

    Returns:
        A ``defaultdict(list)`` mapping each mask entry to the list of images
        rendered for that column, in row order.

    Raises:
        ValueError: If ``mask`` is None or its length differs from the number
            of columns in ``states``.
    """
    # Validate before touching the PCA model so bad input fails fast
    if mask is None:
        raise ValueError('mask must not be None')
    n_columns = states.shape[1]
    if len(mask) != n_columns:
        raise ValueError(
            'mask has {} entries but states has {} columns'.format(len(mask), n_columns))

    images = defaultdict(list)
    # Loaded lazily: input with no image-sized entries (or no rows) never needs the model
    pca = None
    for state in states:
        for m, x in zip(mask, state):
            if len(x) > 10:
                # image: long entries are decompressed and reshaped via the PCA model
                if pca is None:
                    pca = PCA_CLASS(PCA_N)
                    pca.load(PCA_PATH)
                image = pca.unflatten(pca.uncompress_(x))
            else:
                # vector: round to the nearest integer before drawing it as an inventory
                image = Image.to_array(Inventory.to_image(np.rint(x)))
            images[m].append(image)
    return images


if __name__ == '__main__':
    seeds = [31, 33, 76, 82, 92]
    # Experiment directory to process. Earlier runs used '20191227', '20191230' and
    # '20191204_full'; use datetime.today().strftime('%Y%m%d') for a run dated today.
    directory = '20200106'
    task_id = 0
    # Set to True to reuse the vocabulary saved to `vocabulary_cache` by a previous run
    # instead of regenerating it (regeneration also loads the PCA model and initial states).
    use_cached_vocabulary = False
    vocabulary_cache = 'q.tmp'

    task_dir = make_path(directory, task_id)
    operator_dir = make_path(task_dir, 'learned_operators')
    pddl_dir = make_path(task_dir, 'pddl')
    propositions_dir = make_path(task_dir, 'propositions')

    env, doors, objects = Task.generate(seeds[task_id])
    operators = load_operators(operator_dir)
    # all the doors and objects - craft table, the agent and the inventory
    n_objects = len(doors) + len(objects) - 1 + 2

    # generate (or reload) the symbol vocabulary
    if use_cached_vocabulary:
        factors, option_symbols, symbol_list = load(vocabulary_cache)
    else:
        pca = PCA_CLASS(PCA_N)
        pca.load(PCA_PATH)
        initial_states = get_initial_states(task_id, '20191230_raw', pca, 10)
        init_obs = [x[1] for x in initial_states]
        (factors, _), (option_symbols, symbol_list) = generate_symbol_vocabulary(
            env, propositions_dir, operators, n_objects, init_obs)
        save((factors, option_symbols, symbol_list), vocabulary_cache)

    # visualise vocabulary
    output_dir = make_path(task_dir, 'vis_symbols')
    visualise_ppddl_symbols(propositions_dir, output_dir, env, render_masked)

    # build pddl model
    ppddl_symbol_dir = make_path(task_dir, 'ppddl_symbols')
    rules = build_forward_model(env, task_dir, ppddl_symbol_dir, operators, factors, option_symbols, symbol_list)
    make_dir(pddl_dir)
    # Persist the rules; the in-memory object is used directly below, so no read-back is needed
    with open(make_path(pddl_dir, "pddl_rules.pkl"), 'wb') as file:
        pickle.dump(rules, file)

    print(rules)

    with open(make_path(pddl_dir, 'propositional.pddl'), 'w') as file:
        file.write(str(rules))

