import pickle
from collections import defaultdict
from datetime import datetime

import numpy as np

from domain.inventory import Inventory
from domain.levels import Task
from experiment.train_pca import PCA_N
from pca.pca import PCA
from symbols.experimental.learn_effects import learn_effects
from symbols.experimental.learn_preconditions import learn_preconditions
from symbols.file_utils import make_path, make_dir, save, load
from symbols.render.image import Image
from symbols.render.render import visualise_effects, visualise_preconditions
from symbols.symbols.learned_operator import LearnedOperator

# Class instantiated (with PCA_N) by the render helpers below; presumably kept as
# a module-level hook so an alternative implementation exposing the same
# load / uncompress_ / unflatten API can be swapped in.
PCA_CLASS = PCA
# Serialised PCA model loaded by each render helper via PCA_CLASS.load().
PCA_PATH = make_path('pca_models/full_pca.dat')

# This function is intended only for states drawn from a distribution. Each state should contain only the columns
# selected by the mask, in the same order: column 1 corresponds to mask[0], column 2 to mask[1], and so on.

def render_effects(mask, states):
    """Render every sampled effect state into per-variable images.

    Each row of ``states`` holds only the columns selected by ``mask``, in the
    same order (column ``i`` corresponds to ``mask[i]``).

    :param mask: sequence of state-variable identifiers; its entries become the
        keys of the returned mapping.
    :param states: 2-D array with ``states.shape[1] == len(mask)``; each entry is
        one variable's value vector (presumably an object array of numpy vectors
        -- TODO confirm against callers).
    :return: ``defaultdict(list)`` mapping each mask entry to its rendered
        images, one per row of ``states``, in row order.
    :raises ValueError: if ``mask`` is None or its length does not match the
        number of columns in ``states``.
    """
    # Validate before constructing/loading the PCA model so bad input fails
    # fast without reading the model from disk.
    if mask is None:
        raise ValueError('mask must not be None')
    if len(mask) != states.shape[1]:
        raise ValueError(
            f'mask has {len(mask)} entries but states has '
            f'{states.shape[1]} columns')

    pca = PCA_CLASS(PCA_N)
    pca.load(PCA_PATH)

    images = defaultdict(list)
    for state in states:
        for m, x in zip(mask, state):
            if len(x) > 10:
                # Long vectors are treated as PCA-compressed images: decompress,
                # then reshape back to image form.
                image = pca.unflatten(pca.uncompress_(x))
            else:
                # Short vectors are inventory vectors. Sampled values may be
                # fractional, so round to whole numbers before rendering.
                image = Image.to_array(Inventory.to_image(np.rint(x)))
            images[m].append(image)
    return images


def render_preconditions(mask, states):
    """Render the masked variables of each precondition state into images.

    Unlike ``render_effects``, each state here is a full state: the columns
    listed in ``mask`` are selected from it before rendering.

    :param mask: indices of the state variables to render; ``[0, 9]`` is used
        when None. NOTE(review): 0 and 9 are presumably the variables of
        interest for this domain -- confirm.
    :param states: iterable of per-state arrays that support fancy indexing
        with a list (presumably numpy arrays -- TODO confirm against callers).
    :return: ``defaultdict(list)`` mapping each mask index to its rendered
        images, one per state, in iteration order.
    """
    if mask is None:
        mask = [0, 9]

    pca = PCA_CLASS(PCA_N)
    pca.load(PCA_PATH)

    images = defaultdict(list)
    for state in states:
        # Keep only the requested columns; they line up one-to-one with mask.
        for m, x in zip(mask, state[mask]):
            if len(x) > 10:
                # Long vectors are treated as PCA-compressed images: decompress,
                # then reshape back to image form.
                image = pca.unflatten(pca.uncompress_(x))
            else:
                # Short vectors are rendered via Inventory.to_image as-is.
                # NOTE(review): render_effects rounds these with np.rint first;
                # confirm precondition samples never need rounding.
                image = Image.to_array(Inventory.to_image(x))
            images[m].append(image)
    return images
