from collections import defaultdict
from datetime import datetime

import os
import numpy as np

from domain.inventory import Inventory
from domain.levels import Task
from experiment.train_pca import PCA_N
from pca.pca import PCA
from symbols.experimental.learn_effects import learn_effects
from symbols.experimental.learn_preconditions import learn_preconditions
from symbols.experimental.partition_options import partition_options
from symbols.file_utils import make_path
from symbols.render.image import Image
from symbols.render.render import visualise_partitions, visualise_preconditions, visualise_effects
import pickle

# PCA model implementation and the saved weights used by render_masked to
# decompress image observations.
PCA_CLASS = PCA
PCA_PATH = make_path('pca_models/full_pca.dat')

# This function is only for states drawn from a distribution. The states should contain only the columns in the mask,
# and the ordering must correspond, i.e. column 0 of a state belongs to mask[0], column 1 to mask[1], etc.

def render_masked(mask, states, *, max_vector_len=10):
    """Render every masked entry of each state, grouped by mask variable.

    Each row of ``states`` holds only the masked columns, in the same order as
    ``mask``: column ``i`` of a state belongs to state variable ``mask[i]``.
    Entries with more than ``max_vector_len`` elements are treated as
    PCA-compressed images and decompressed. Shorter entries are treated as
    inventory vectors, which are rounded to the nearest integer and drawn via
    ``Inventory.to_image``.

    :param mask: sequence of state-variable identifiers, one per column of
        ``states``.
    :param states: 2-D array-like (must support ``.shape``) whose rows are
        states restricted to the masked columns.
    :param max_vector_len: entries longer than this are rendered as PCA
        images. The default (10) is the threshold the original code used.
    :return: ``defaultdict(list)`` mapping each mask variable to its rendered
        images, one per state, in state order.
    :raises ValueError: if ``mask`` is None or its length differs from the
        number of columns in ``states``.
    """
    # Validate before doing any expensive work (e.g. loading the PCA model).
    if mask is None:
        raise ValueError("mask must not be None")
    if len(mask) != states.shape[1]:
        raise ValueError(
            "mask has {} entries but states have {} columns".format(
                len(mask), states.shape[1]))

    images = defaultdict(list)
    pca = None  # loaded lazily: only needed if some entry is an image
    for state in states:
        for variable, x in zip(mask, state):
            if len(x) > max_vector_len:
                if pca is None:
                    pca = PCA_CLASS(PCA_N)
                    pca.load(PCA_PATH)
                image = pca.unflatten(pca.uncompress_(x))
            else:
                # Inventory vector. Sampled values may be non-integral, so
                # round them before drawing.
                image = Image.to_array(Inventory.to_image(np.rint(x)))
            images[variable].append(image)
    return images


def main():
    """Learn effects for one task and write visualisations of them.

    Generates the task environment from its seed and learns effects from the
    partitioned options under the task directory. It then caches the result
    with pickle in ``eff`` (current working directory) and renders the effects
    into ``vis_effects`` under the task directory.
    """
    seeds = [31, 33, 76, 82, 92]
    # Date-stamped (YYYYMMDD) results directory of the run to visualise.
    # For a run made today, use datetime.today().strftime('%Y%m%d') instead.
    directory = '20191119'
    task_id = 0
    task_dir = make_path(directory, task_id)
    partition_dir = make_path(task_dir, 'partitioned_options')

    env, _, _ = Task.generate(seeds[task_id])

    # Pass a real bool: the string 'True' (or even 'False') is always truthy.
    effects = learn_effects(env, partition_dir, view='agent', verbose=True)

    # Round-trip through the pickle cache so the visualisation below works on
    # the same data a reload of the cache file would produce.
    cache_path = 'eff'
    with open(cache_path, 'wb') as f:
        pickle.dump(effects, f)
    with open(cache_path, 'rb') as f:
        effects = pickle.load(f)

    visualise_effects(env, make_path(task_dir, 'vis_effects'), render_masked, effects, view='agent')


if __name__ == '__main__':
    main()
