from collections import defaultdict

from domain.inventory import Inventory
from experiment.link_problem_data import OperatorData
from pca.base_pca import PCA_N
from pca.pca import PCA
from symbols.data.data import load_transition_data
from symbols.file_utils import load, make_path, save
import numpy as np

from symbols.render.image import Image


# Path of the saved PCA model that render() loads before decoding
# compressed state vectors back into images.
PCA_PATH = make_path('pca_models/full_pca.dat')

def _state_to_image(pca, state):
    """Convert a single state vector into a displayable image.

    :param pca: loaded PCA model used to decode compressed image vectors
    :param state: state vector; its length decides how it is drawn
    :return: whatever ``Image.to_image`` returns for the decoded state
    """
    if len(state) > 10:
        # Long vectors are handled as PCA-compressed images: decompress and
        # restore the image layout before converting.
        return Image.to_image(pca.unflatten(pca.uncompress_(state)))

    # Everything else is handled as a small vector: round to whole numbers,
    # draw it through Inventory and convert the result as RGB. Using a single
    # threshold keeps the drawing path and the conversion mode in agreement
    # (previously a vector of exactly 10 entries was drawn as a vector but
    # converted without the RGB mode).
    pixels = Image.to_array(Inventory.to_image(np.rint(state)))
    return Image.to_image(pixels, mode='RGB')


def render(effects, n_samples=100):
    """Show the given effects on screen.

    Two kinds of input are supported:

    * a ``list`` of state vectors -- each vector is converted to an image
      (see ``_state_to_image``) and shown in its own window, exactly once;
    * any other iterable -- each element is asked to draw itself through
      ``effect.visualise(effect.effects[0][1])``.

    :param effects: list of state vectors, or an iterable of effect objects
    :param n_samples: currently unused; kept so existing callers still work
    :return: None
    """
    # The PCA model is loaded up front, as before, regardless of input kind.
    pca = PCA(PCA_N)
    pca.load(PCA_PATH)

    # The container type is checked once. Previously the check sat inside the
    # loop over ``effects`` and re-walked the whole list on every iteration,
    # so each image was displayed len(effects) times.
    # NOTE(review): the list elements are treated as state vectors, matching
    # the original list branch. If lists of per-effect sample lists were
    # intended instead, this needs revisiting -- confirm against the caller.
    if isinstance(effects, list):
        for state in effects:
            Image.combine([_state_to_image(pca, state)]).show()
        return

    for effect in effects:
        effect.visualise(effect.effects[0][1])

if __name__ == '__main__':
    # Load the knowledge base and collect the agent-view transition data of
    # options 0-8 from the same run directory.
    kb = load('20200106/0/pddl/knowledge.kb')
    data = list()
    for option in range(9):
        transition_data = load_transition_data(option, '20200106/0/', view='agent')
        # Alternative data set: load_transition_data(option, '20200109/1/0_5/', view='agent')
        data.extend(transition_data)

    # Infer a type per object (render is handed over as the display callback)
    # and record every inferred type in the knowledge base.
    types = kb.infer_types(data, render=render)

    # Loop names avoid shadowing the builtins ``object`` and ``type``.
    for obj, obj_type in types.items():
        kb.set_type(1, obj, obj_type)

    # Round-trip the knowledge base through disk.
    save(kb, 'tempkb')
    kb = load('tempkb')

    # Attach each sample to its most probable operator; samples whose best
    # probability does not exceed 0.2 are kept aside as unassigned.
    assigned_data = defaultdict(list)
    unassigned_data = list()

    for sample in data:
        probs = kb.operator_proba(1, sample.state, sample.option, sample.next_state)
        operator = max(probs, key=probs.get)
        prob = probs[operator]

        if prob > 0.2:
            assigned_data[operator].append(sample)
        else:
            unassigned_data.append(sample)

        print(np.max(list(probs.values())))

    # Build one OperatorData entry per operator that received samples.
    # The loop variable is ``samples`` (not ``data``) so the transition list
    # collected above is no longer overwritten.
    # NOTE(review): ``make_partitioned_option`` and ``todo`` are not defined
    # or imported in the visible part of this file -- this section looks
    # unfinished and will raise NameError as written.
    operator_data = list()
    for operator, samples in assigned_data.items():
        partitioned_option = make_partitioned_option(operator, samples)
        full_mask = todo
        operator_data.append(
            OperatorData(partitioned_option, None, None, [operator], None, full_mask=full_mask)
        )

    # NOTE(review): ``partitions``, ``idx_to_schema``, ``_find_closest_match``,
    # ``pddl``, ``replaced``, ``extract_raw_operator``, ``operators`` and
    # ``ProblemSymbols`` are likewise not defined or imported in the visible
    # part of this file -- confirm where they are meant to come from.
    for option, ps in partitions.items():
        for partitioned_option in ps:
            # Every partition must belong to the option it is filed under.
            if partitioned_option.option != option:
                raise ValueError

            schemata = idx_to_schema[partitioned_option.option, partitioned_option.partition]
            if len(schemata) == 0:
                # No schema recorded for this partition: fall back to the
                # closest match.
                schemata = _find_closest_match(pddl, idx_to_schema, partitioned_option, replaced)

            if len(schemata) == 0:
                # Still nothing found: report it, but carry on.
                print("XXXXXAAAHA {} {}".format(option, partitioned_option.partition))

            raw_operator = extract_raw_operator(option, partitioned_option.partition, operators)
            operator_data.append(
                OperatorData(partitioned_option, None, None, schemata, raw_operator)
            )

    problem_symbols = ProblemSymbols()
