from collections import defaultdict

from symbols.data.data import load_option_partitions, load_ppddl_symbols
from symbols.domain.domain import Domain
from symbols.file_utils import make_dir, make_path
from symbols.logger.precondition_reader import PreconditionReader
from symbols.render.image import Image
from symbols.symbols.distribution_symbol import DistributionSymbol
import numpy as np
import matplotlib.pyplot as plt




def visualise_subpartitions(domain: Domain,
                            option,
                            partitions,
                            output_dir,
                            render,
                            verbose=False):
    """
    Write the preconditions, transition probabilities and effects of one option's partitions to disk.

    For partition number ``i`` (counted from 1) the following names are handed to ``make_path``
    together with ``output_dir``: ``<option>-init<i>.bmp``, ``<option>-transition<i>.txt`` and,
    for effect number ``j`` (counted from 1), ``<option>-effect<i>-<j>.bmp``.

    :param domain: the domain; provides ``action_space`` and is forwarded to ``render``
    :param option: the option whose partitions are visualised
    :param partitions: the partitions of the option, accessed by integer index
    :param output_dir: the directory to save the files to
    :param render: callable invoked as ``render(domain, states)``; its result is given to ``Image.merge``
    :param verbose: whether to print information to screen
    """
    # NOTE(review): other functions in this file treat the result of Image.merge as an array and
    # save it with Image.save(...); confirm the .save()/.free() methods used below still exist.
    if verbose:
        print("Creating output directory: " + output_dir)
    make_dir(output_dir)

    action_space = domain.action_space
    n_partitions = len(partitions)

    if verbose:
        print("Visualising option " + action_space.describe(option) + " with " + str(n_partitions) + " partition(s)")

    for index in range(n_partitions):
        number = str(index + 1)  # partitions are numbered from 1 in messages and file names

        if verbose:
            print("Processing partition " + number + "...")

        current = partitions[index]

        # Preconditions: one merged image of the partition's states
        merged = Image.merge(render(domain, current.states))
        target = make_path(output_dir, action_space.describe(option) + "-init" + number + ".bmp")
        merged.save(target)
        merged.free()

        if verbose:
            print("\tPreconditions complete.")

        # Transition probabilities: one entry per line
        target = make_path(output_dir, action_space.describe(option) + "-transition" + number + ".txt")
        with open(target, "w") as out:
            out.writelines(str(entry) + "\n" for entry in current.get_transition_probabilities())

        if verbose:
            print("\tTransition probabilities complete.")

        # Effects: one merged image per effect
        for effect in range(current.get_number_effects()):
            merged = Image.merge(render(domain, current.get_next_states(effect)))
            name = action_space.describe(option) + "-effect" + number + "-" + str(effect + 1) + ".bmp"
            merged.save(make_path(output_dir, name))
            merged.free()
            if verbose:
                print("\t\tImage " + str(effect) + " complete.")

        if verbose:
            print("Images complete")
    if verbose:
        print("Partitions visualised.")


def _flatten(s):
    state = list()
    for j in s:
        if isinstance(j, int):
            state.append(j)
        else:
            for x in j:
                state.append(x)
    return np.array(state)


def debug_state(render, state):
    """
    Display the rendering of a single state with matplotlib (debugging aid).

    ``render`` is called as ``render(None, [state])`` and its result is iterated by key; each
    value is a sequence of image arrays. Arrays found at the same position across the values
    are combined into one image, and only the first combined image is passed to ``plt.imshow``.

    :param render: callable producing the keyed image arrays
    :param state: the state to render
    """
    rendered = render(None, [state])

    # grouped[i] collects the i-th array of every value, in key order
    grouped = list()
    for key in rendered:
        for position, view in enumerate(rendered[key]):
            if position < len(grouped):
                grouped[position].append(view)
            else:
                grouped.append([view])

    combined = list()
    for group in grouped:
        # 3-dimensional arrays are converted as RGB, everything else with the default mode
        converted = [
            Image.to_image(view, mode='RGB') if len(view.shape) == 3 else Image.to_image(view)
            for view in group
        ]
        combined.append(Image.combine(converted))
    plt.imshow(combined[0])


def visualise_preconditions(domain, input_dir, output_dir, render, preconditions, view='problem', threshold=0.7):
    """
    Save one image per learned precondition, showing the logged states the classifier accepts.

    For every entry of ``preconditions`` the logged samples of the option are read with a
    ``PreconditionReader``; the states whose ``svm.probability`` exceeds ``threshold`` are
    rendered, merged per key of the render result, combined and saved. When no state passes
    the threshold a "Fail" line is printed instead.

    :param domain: the domain; its ``action_space.describe`` supplies the option name for the file name
    :param input_dir: the directory the PreconditionReader reads from
    :param output_dir: the directory to save the images to
    :param render: callable invoked as ``render(svm.mask, states)``; returns a mapping of image arrays
    :param preconditions: mapping from ``(option, object)`` pairs to precondition classifiers
    :param view: forwarded to the PreconditionReader
    :param threshold: a state is kept when its probability is strictly greater than this value
    """
    # NOTE(review): output_dir is not created here (the make_dir call is disabled in this
    # function); presumably the caller prepares the directory - confirm.
    option_count = defaultdict(int)
    for (option, _), svm in preconditions.items():

        reader = PreconditionReader(input_dir, option, view=view)

        # Evaluate the classifier exactly once per sample
        positive = list()
        for sample in reader.get_samples():
            state = sample.state
            if svm.probability(state) > threshold:
                positive.append(state)

        if positive:
            rendered = render(svm.mask, positive)
            images = list()
            for views in rendered.values():
                merged = Image.merge(views)
                # 3-dimensional arrays are converted as RGB, everything else with the default mode
                if len(merged.shape) == 3:
                    images.append(Image.to_image(merged, mode='RGB'))
                else:
                    images.append(Image.to_image(merged))

            image = Image.combine(images)

            mask = '+'.join(map(str, rendered.keys()))
            filename = '{}-init{}-{}.bmp'.format(domain.action_space.describe(option), option_count[option] + 1, mask)
            filename = make_path(output_dir, filename)
            Image.save(image, filename, mode='RGB')
        else:
            print("Fail ", option, option_count[option] + 1)

        # Numbering is per option, so several preconditions of one option get distinct file names
        option_count[option] += 1
        print('\n')


def _make_filename(directory, option_name, type, index, mask, extension):
    """
    Build an output path for one visualisation file.

    The file name has the form ``<option_name>-<type><index>-<mask>.<extension>`` and is
    passed to ``make_path`` together with ``directory``.

    :param directory: the directory handed to ``make_path``
    :param option_name: the name of the option
    :param type: the kind of file; placed directly before the index
    :param index: the number placed directly after the type
    :param mask: the mask placed before the extension
    :param extension: the file extension, without the leading dot
    :return: the value returned by ``make_path``
    """
    basename = '{name}-{kind}{number}-{mask}.{suffix}'.format(name=option_name,
                                                              kind=type,
                                                              number=index,
                                                              mask=mask,
                                                              suffix=extension)
    return make_path(directory, basename)


def _combine_rendered_views(rendered):
    """
    Merge every value of a render result and combine the merged images into one.

    :param rendered: mapping returned by the render callable; each value is passed to ``Image.merge``
    :return: the image produced by ``Image.combine``
    """
    images = list()
    for views in rendered.values():
        merged = Image.merge(views)
        # 3-dimensional arrays are converted as RGB, everything else with the default mode
        if len(merged.shape) == 3:
            images.append(Image.to_image(merged, mode='RGB'))
        else:
            images.append(Image.to_image(merged))
    return Image.combine(images)


def visualise_partitions(domain: Domain,
                         input_dir,
                         output_dir,
                         render,
                         verbose=False):
    """
    Visualise the partitions of every option in the domain's action space.

    The partitions are loaded from ``input_dir``. For partition number ``i`` (counted from 1) of
    each option, an image of its states (type ``init``), a text file with its transition
    probabilities (type ``transition``) and one image per effect (type ``effect``) are written;
    the paths come from ``_make_filename``. Effect images are indexed ``<i>-<j>`` with ``j`` the
    effect number (counted from 1), so that the effects of one partition do not overwrite
    each other.

    :param domain: the domain; provides ``action_space`` and is forwarded to ``render``
    :param input_dir: the directory containing the partitioned options
    :param output_dir: the directory to save the files to
    :param render: callable invoked as ``render(domain, states)``; returns a mapping of image arrays
    :param verbose: whether to print information to screen
    """
    if verbose:
        print("Creating output directory: " + output_dir)
    make_dir(output_dir)
    options = domain.action_space
    if verbose:
        print("Loading partitions...")

    option_partitions = load_option_partitions(options, input_dir)

    if verbose:
        print("Partitions loaded.")

    for option in options:

        partitions = option_partitions[option]
        option_name = options.describe(option)

        if verbose:
            print("Visualising option " + option_name + " with " + str(len(partitions)) + " partition(s)")

        for i in range(0, len(partitions)):

            if verbose:
                print("Processing partition " + str(i + 1) + "...")

            partition = partitions[i]

            # Write preconditions
            im = _combine_rendered_views(render(domain, partition.states))
            filename = _make_filename(output_dir, option_name, 'init', i + 1, partition.combined_mask, 'bmp')
            Image.save(im, filename, mode='RGB')

            if verbose:
                print("\tPreconditions complete.")
                print("*******************************************************")

            # Write probabilities
            filename = _make_filename(output_dir, option_name, 'transition', i + 1,
                                      partition.combined_mask, 'txt')

            with open(filename, "w") as file:
                for item in partition.get_transition_probabilities():
                    file.write("%s\n" % str(item))

            if verbose:
                print("\tTransition probabilities complete.")

            # Write images
            n_transition = partition.get_number_effects()
            for j in range(0, n_transition):
                im = _combine_rendered_views(render(domain, partition.get_next_states(j)))

                # The effect number is part of the index; without it every effect of this
                # partition would be saved under the same file name
                filename = _make_filename(output_dir, option_name, 'effect', '{}-{}'.format(i + 1, j + 1),
                                          partition.combined_mask, 'bmp')

                Image.save(im, filename, mode='RGB')

                if verbose:
                    print("\t\tImage " + str(j) + " complete.")

            if verbose:
                print("Images complete")
                print("*******************************************************")

    if verbose:
        print("Partitions visualised.")


def render_distribution(env, distribution: DistributionSymbol, width, height, filename, view='problem', ndims=-1):
    """
    Render states sampled from a distribution symbol and save the merged image.

    100 states are filled with ``np.random.rand`` values; the columns selected by
    ``distribution.mask`` are then overwritten with 100 samples drawn from the distribution.
    The states are rendered through ``env.render_states`` and the merged result is saved.

    :param env: provides ``render_states`` and, when ``ndims`` is negative, the space whose shape gives the
    number of state variables
    :param distribution: the distribution symbol to sample from
    :param width: not used by this function
    :param height: not used by this function
    :param filename: the file the merged image is saved to
    :param view: forwarded to ``env.render_states``; 'problem' selects ``observation_space`` for the dimension
    lookup, any other value selects ``agent_space``
    :param ndims: the number of state variables; a negative value means it is read from ``env``
    :return: the images returned by ``env.render_states`` (not the merged image)
    """
    # NOTE(review): other functions in this file treat the result of Image.merge as an array and
    # save it with Image.save(...); confirm the .save()/.free() methods used below still exist.
    n_samples = 100

    if ndims < 0:
        if view == 'problem':
            ndims = env.observation_space.shape[1]
        else:
            ndims = env.agent_space.shape[1]

    # Random values everywhere, then overwrite the masked columns with samples of the symbol
    states = np.random.rand(n_samples, ndims)
    states[:, distribution.mask] = distribution.sample(n_samples)

    images = env.render_states(states, view=view, background_alpha=1.0 / states.shape[0], foreground_alpha=0.5)

    merged = Image.merge(images)
    merged.save(filename)
    merged.free()
    return images


def debug(env, dir, A, render, samples):
    """
    Print summary statistics of ``samples`` and show the image ``A`` (debugging aid).

    The mean, variance, minimum and maximum along axis 0 are printed in that order, each
    followed by a separator line, with floats formatted to two decimals.

    :param env: the environment; its ``action_space`` is used to load the option partitions
    :param dir: the directory the option partitions are loaded from
    :param A: an object with a ``show`` method, called at the end
    :param render: not used by this function
    :param samples: the samples to summarise
    """
    # The loaded partitions are not used below; the call is kept so the load itself still runs
    load_option_partitions(env.action_space, dir)

    np.set_printoptions(formatter={'float': "{0:0.2f}".format})

    separator = "______________________________________"
    for statistic in (np.mean, np.var, np.min, np.max):
        print(statistic(samples, axis=0))
        print(separator)

    A.show()


def visualise_effects(domain, output_dir, render, effects, view='problem', ndims=-1):
    """
    Save one image for every outcome distribution of every learned effect.

    Each outcome is sampled 100 times, the samples are rendered, merged per key of the render
    result, combined and saved as ``<option>-eff<object>.<k>-<mask>.bmp`` where ``k`` is the
    outcome number (counted from 1) and ``<mask>`` joins the keys of the render result with '+'.

    :param domain: the domain; its ``action_space.describe`` supplies the option name
    :param output_dir: the directory to save the images to
    :param render: callable invoked as ``render(mask, samples)``; returns a mapping of image arrays
    :param effects: mapping from ``(option, object)`` pairs to ``(probabilities, rewards, outcomes)`` triples
    :param view: not used by this function
    :param ndims: not used by this function
    """
    make_dir(output_dir)
    for (option, target), (_probabilities, _rewards, outcomes) in effects.items():
        for position, outcome in enumerate(outcomes):
            rendered = render(outcome.mask, outcome.sample(100))

            converted = list()
            for views in rendered.values():
                merged = Image.merge(views)
                # 3-dimensional arrays are converted as RGB, everything else with the default mode
                if len(merged.shape) == 3:
                    converted.append(Image.to_image(merged, mode='RGB'))
                else:
                    converted.append(Image.to_image(merged))
            combined = Image.combine(converted)

            mask_label = '+'.join(str(key) for key in rendered.keys())
            basename = '{}-eff{}.{}-{}.bmp'.format(domain.action_space.describe(option), target, position + 1,
                                                   mask_label)
            Image.save(combined, make_path(output_dir, basename), mode='RGB')


def _draw_samples(observation_space, symbol, n_variables, n_samples):
    # TODO FIX: not working :(
    # out_samp = np.fromfunction(lambda row, _, observation_space=observation_space: observation_space.sample(), (n_samples, n_variables), dtype=float)

    out_samp = np.empty([n_samples, n_variables])
    for i in range(out_samp.shape[0]):

        s = observation_space.sample()
        state = list()
        for j in s:

            if isinstance(j, int):
                state.append(j)
            else:
                for x in j:
                    state.append(x)
        out_samp[i, :] = np.array(state)

    samp = symbol.sample(n_samples)
    if len(symbol.flat_mask) == 1:
        samp = np.reshape(samp, (n_samples, 1))
    out_samp[:, symbol.flat_mask] = samp
    return out_samp


def visualise_ppddl_symbols(input_directory,
                            output_directory,
                            domain,
                            render,
                            verbose=True):
    """
    Save one image per PPDDL symbol
    :param input_directory: the directory the symbols are loaded from
    :param output_directory: the directory to save the images to
    :param domain:  the domain (not used by this function)
    :param render: callable invoked as render(mask, samples); returns a mapping of image arrays
    :param verbose: whether to print information to screen
    """

    symbols = load_ppddl_symbols(input_directory, verbose=verbose)
    make_dir(output_directory)
    for number, symbol in enumerate(symbols, start=1):

        if verbose:
            print("Visualising symbol " + str(number) + " of " + str(len(symbols)))

        # Render 100 samples of the symbol; every value of the result becomes one merged image
        rendered = render(symbol.mask, symbol.sample(100))
        converted = list()
        for views in rendered.values():
            merged = Image.merge(views)
            # 3-dimensional arrays are converted as RGB, everything else with the default mode
            if len(merged.shape) == 3:
                converted.append(Image.to_image(merged, mode='RGB'))
            else:
                converted.append(Image.to_image(merged))
        combined = Image.combine(converted)

        # The file name carries the symbol name and the keys of the render result joined by '+'
        mask_label = '+'.join(str(key) for key in rendered.keys())
        target = make_path(output_directory, '{}-{}.bmp'.format(symbol.name, mask_label))
        Image.save(combined, target, mode='RGB')


def describe_ppddl_symbols(input_directory,
                           output_directory,
                           domain: Domain,
                           verbose=True):
    """
    Describe the PPDDL symbols through the domain's describe_state
    :param input_directory: the directory the symbols are loaded from
    :param output_directory: a directory that is created with make_dir; this function writes no files to it
    :param domain:  the domain; provides state_length, observation_space and describe_state
    :param verbose: whether to print progress information to screen
    """

    symbols = load_ppddl_symbols(input_directory, verbose=verbose)
    make_dir(output_directory)
    for position, symbol in enumerate(symbols):

        if verbose:
            # Progress is reported zero-based: "<position> of <last position>"
            print("Visualising symbol " + str(position) + " of " + str(len(symbols) - 1))

        # A single state is drawn per symbol and described on the symbol's variables
        n_vars = domain.state_length
        drawn = _draw_samples(domain.observation_space, symbol, n_vars, 1)
        for state in drawn:
            domain.describe_state(state, symbol.flat_mask)

        # The separator is printed whether or not verbose is set
        print('***************************\n')
