from collections import defaultdict
from datetime import datetime

import os
import numpy as np

from domain.inventory import Inventory
from domain.levels import Task
from pca.base_pca import PCA_N
from pca.pca import PCA
from pca.sparse_pca import SparsePCA
from symbols.experimental.partition_options import partition_options
from symbols.file_utils import make_path
from symbols.render.image import Image
from symbols.render.render import visualise_partitions
import matplotlib.pyplot as plt

PCA_PATH = make_path('pca_models/full_pca.dat')
PCA_CLASS = PCA

def render(_, states):
    """Reconstruct one image per state from the state's first object.

    The first positional argument is accepted only for interface
    compatibility with the other ``render*`` callbacks and is ignored.

    :param states: iterable of states; for each one, ``state[0]`` is sliced to
        its first ``PCA_N`` entries and passed through the PCA model's
        ``uncompress_`` and ``unflatten``.
    :return: list of reconstructed images, in the same order as ``states``.
    """
    model = PCA_CLASS(PCA_N)
    model.load(PCA_PATH)

    def _decode(state):
        compressed = state[0][:PCA_N]
        return model.unflatten(model.uncompress_(compressed))

    return [_decode(state) for state in states]

def render2(_, states):
    """Decode every object slot of every state, grouped by slot index.

    The first positional argument is ignored.

    :param states: array-like with a ``shape`` attribute; ``states.shape[1]``
        gives the number of object slots per state (presumably a 2-D numpy
        object array -- confirm against callers).
    :return: list with one list per slot. Each inner list holds, in state
        order, either a PCA-reconstructed image (when the slot's vector has
        exactly ``PCA_N`` entries) or the raw vector unchanged (treated as an
        inventory vector).
    """
    model = PCA_CLASS(PCA_N)
    model.load(PCA_PATH)
    slot_count = states.shape[1]
    per_slot = [[] for _slot in range(slot_count)]
    for row in states:
        for slot in range(slot_count):
            vec = row[slot]
            is_image = vec.shape[0] == PCA_N
            # Image slots are PCA-compressed; anything else is passed through.
            per_slot[slot].append(
                model.unflatten(model.uncompress_(vec)) if is_image else vec
            )
    return per_slot




def render_state(mask, states):
    """Reconstruct an image from the first ``PCA_N`` entries of each state.

    :param mask: ignored. Every state is indexed with ``np.arange(PCA_N)``
        regardless of this value. The parameter is kept so the signature
        matches the other ``render*`` callbacks.
    :param states: iterable of 1-D, index-able state vectors.
    :return: list of reconstructed images, in the same order as ``states``.
    """
    pca = PCA_CLASS(PCA_N)
    pca.load(PCA_PATH)
    # The index array is identical for every state, so build it once.
    # It gets its own name rather than overwriting the ``mask`` argument.
    pca_indices = np.arange(PCA_N)
    images = []
    for state in states:
        x = state[pca_indices]
        images.append(pca.unflatten(pca.uncompress_(x)))
    return images


def render_partition(_, states):
    """Render the object slots at indices 0 and 9 of each state.

    The first positional argument is ignored.

    :param states: iterable of states that support list (fancy) indexing,
        e.g. ``state[[0, 9]]``.
    :return: ``defaultdict(list)`` mapping slot index (0 or 9) to the list of
        rendered images for that slot, in state order.
    """
    images = defaultdict(list)
    pca = PCA_CLASS(PCA_N)
    pca.load(PCA_PATH)
    # NOTE(review): slot indices are hard-coded here. The ignored first
    # argument is presumably the partition mask -- confirm against callers.
    mask = [0, 9]
    for state in states:
        selected = state[mask]
        # Iterate the zip directly instead of binding it to a name like
        # ``iter``, which would shadow the builtin.
        for m, x in zip(mask, selected):
            if len(x) > 10:
                # Long vectors are treated as PCA-compressed images.
                image = pca.unflatten(pca.uncompress_(x))
            else:
                # Short vectors are treated as inventory vectors.
                image = Image.to_array(Inventory.to_image(x))
            images[m].append(image)
    return images


def render_debug_states(states):
    """Display up to 36 PCA-reconstructed states in one matplotlib figure.

    Each state's first ``PCA_N`` entries are decoded into an image. Images are
    drawn, in order, into a 5x8 subplot grid without ticks. ``plt.show()`` is
    called at the end.

    NOTE(review): the grid has 40 cells but at most 36 are filled, and the
    figure is 6x6 inches -- perhaps a 6x6 grid was intended; confirm.
    """
    model = PCA_CLASS(PCA_N)
    model.load(PCA_PATH)
    figure = plt.figure(figsize=(6, 6))
    figure.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
    shown = 0
    for state in states:
        if shown == 36:
            break
        decoded = model.unflatten(model.uncompress_(state[:PCA_N]))
        axis = figure.add_subplot(5, 8, shown + 1, xticks=[], yticks=[])
        axis.imshow(decoded, cmap=plt.cm.bone, interpolation='nearest')
        shown += 1
    plt.show()

if __name__ == '__main__':

    seeds = [31, 33, 76, 82, 92]
    # Output directory for this run. For a fresh dated run, use
    # datetime.today().strftime('%Y%m%d') instead. Earlier runs used
    # '20191119', '20191203_sparse', '20191204_full' and '20191230'.
    directory = '20200106'

    task_id = 0
    task_dir = make_path(directory, task_id)
    env, _, _ = Task.generate(seeds[task_id])
    partition_dir = make_path(task_dir, 'partitioned_options')
    # NOTE(review): an earlier version also passed
    # debug_render=render_debug_states here; presumably a debugging hook --
    # confirm in partition_options before re-enabling.
    partition_options(env, task_dir, partition_dir, view='agent')
    # visualise_partitions(env, partition_dir, make_path(task_dir, 'vis_partitioned'), render_partition, verbose=True)
