from collections import defaultdict
from datetime import datetime

import os
import numpy as np

from domain.inventory import Inventory
from domain.levels import Task
from pca.base_pca import PCA_N
from pca.pca import PCA
from symbols.experimental.learn_preconditions import learn_preconditions
from symbols.experimental.partition_options import partition_options
from symbols.file_utils import make_path
from symbols.render.image import Image
from symbols.render.render import visualise_partitions, visualise_preconditions
import pickle

# Location of the saved PCA model that render() loads to decompress image components.
# NOTE(review): 'pca_modlels' looks like a misspelling of 'pca_models'; it must match the
# directory on disk, so confirm before renaming.
PCA_PATH = make_path('./pca_modlels/full_pca.dat')

def render(mask, states):
    """Convert selected components of each state into images, grouped by component index.

    Args:
        mask: Indices of the state components to render. ``None`` selects the
            default components ``[0, 9]``. Any iterable of indices is accepted; it is
            materialised as a list so it works as a NumPy fancy index (a tuple would
            otherwise be read as multi-dimensional indexing) and can be iterated
            again when pairing indices with the selected components.
        states: Iterable of states. Each state is indexed with a list of indices
            (``state[mask]``), so it is presumably a NumPy (object) array — TODO confirm.

    Returns:
        A ``defaultdict(list)`` mapping each index in ``mask`` to the list of images
        rendered for that component, one entry per state, in iteration order.
    """
    images = defaultdict(list)
    pca = PCA(PCA_N)
    pca.load(PCA_PATH)

    mask = [0, 9] if mask is None else list(mask)

    for state in states:
        for index, component in zip(mask, state[mask]):
            # NOTE(review): the length threshold of 10 is what separates the two
            # encodings below; confirm it holds for every component type.
            if len(component) > 10:
                # Image component: PCA-decompress, then reshape back into an image.
                image = pca.unflatten(pca.uncompress_(component))
            else:
                # Vector component: render it via the inventory image helper.
                image = Image.to_array(Inventory.to_image(component))
            images[index].append(image)
    return images


if __name__ == '__main__':

    # Environment seeds; task_id selects the one used for this run.
    seeds = [31, 33, 76, 82, 92]
    # Experiment directory to read from (hard-coded to a specific earlier run).
    directory = '20200106'
    # NOTE(review): operator_dir is not used below; the make_path call is kept in
    # case it has side effects (e.g. creating the directory) — confirm, then remove.
    operator_dir = make_path(directory, 'learned_operators')
    task_id = 0
    task_dir = make_path(directory, task_id)
    partition_dir = make_path(task_dir, 'partitioned_options')

    env, _, _ = Task.generate(seeds[task_id])

    # verbose must be a real boolean: any non-empty string (even 'False') is truthy.
    preconditions = learn_preconditions(env, task_dir, partition_dir, view='agent', verbose=True, render=render)

    # Set to True to also pickle-check and visualise the learned preconditions.
    # Off by default, so the script stops after learning.
    visualise = False
    if visualise:
        # Round-trip through pickle, presumably to check the preconditions serialise.
        # NOTE: writes (and overwrites) a file named "temp" in the working directory.
        with open("temp", "wb") as file:
            pickle.dump(preconditions, file)
        with open("temp", "rb") as file:
            preconditions = pickle.load(file)

        visualise_preconditions(env, task_dir, make_path(task_dir, 'vis_preconds'), render, preconditions, view='agent')