import multiprocessing
import warnings
from functools import partial

from sklearn.cluster import DBSCAN
from sklearn.cluster import OPTICS

from symbols.data.data import load_transition_data, write_partition_data, load_transition_data_pandas, \
    load_full_transition_data_pandas
from symbols.data.partitioned_option import PartitionedOption
from symbols.domain.domain import Domain

import numpy as np
import pandas as pd
from symbols.experimental.pca import PCA
from symbols.file_utils import make_dir, make_path
from symbols.logger.transition_sample import TransitionSample

import matplotlib.pyplot as plt
from random import sample

# Size passed to the PCA model, and the number of leading state components that debug_visualise
# slices off each state before handing it to that model.
PCA_DIMS = 35


def run_parallel(functions):
    """
    Run the given zero-argument callables in parallel, one worker process per callable, and return their
    results in the same order as the input.

    The callables are dispatched with ``Pool.apply_async`` and must therefore be picklable (e.g. module-level
    functions or ``functools.partial`` objects wrapping them).

    :param functions: the callables to execute
    :return: a list of results, ``result[i]`` being the return value of ``functions[i]``. An empty input yields
    an empty list.
    """
    functions = list(functions)
    if not functions:
        # Pool(processes=0) raises ValueError, and there is nothing to run anyway
        return []

    # The context manager tears the pool down once all results have been collected, so no worker processes
    # are left behind
    with multiprocessing.Pool(processes=len(functions)) as pool:
        pending = [pool.apply_async(function) for function in functions]
        return [result.get() for result in pending]


def debug_visualise(states, mask, pca_path="D:\\PycharmProjects\\Marlo\\full_pca.dat", max_images=36):
    """
    Debug helper: render the first few states as images in a single matplotlib figure and show it.

    The first PCA_DIMS entries of each state are passed through ``pca.uncompress_`` and ``pca.unflatten``
    and the result is drawn with ``imshow``.
    NOTE(review): presumably this reconstructs the original image from its PCA encoding - confirm against PCA.

    :param states: the states to draw; each must support slicing
    :param mask: unused; kept so that existing callers keep working
    :param pca_path: the file the PCA model is loaded from. The default is the previously hard-coded location
    :param max_images: the maximum number of states to draw. The default matches the previous fixed limit
    """
    n_cols = 8
    # Enough rows to hold max_images tiles (ceiling division); 36 images -> the original 5 x 8 grid
    n_rows = max(1, -(-max_images // n_cols))

    pca = PCA(PCA_DIMS)
    pca.load(pca_path)
    fig = plt.figure(figsize=(6, 6))
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
    for i, state in enumerate(states):

        if i >= max_images:
            break

        x = state[0:PCA_DIMS]
        image = pca.unflatten(pca.uncompress_(x))
        ax = fig.add_subplot(n_rows, n_cols, i + 1, xticks=[], yticks=[])
        ax.imshow(image, cmap=plt.cm.bone, interpolation='nearest')
    plt.show()


def partition_options(env: Domain, data_dir, output_dir, view='problem', verbose=True, parallel=False, **kwargs):
    """
    Partition every option in the environment's action space and write each option's partitions to file.

    :param env: the domain whose action space lists the options
    :param data_dir: the directory the transition data is loaded from
    :param output_dir: the directory the partition data is written to (created via make_dir)
    :param view: passed through to the data loaders
    :param verbose: whether to print information to screen
    :param parallel: if True, all transition data is loaded once and the options are divided among up to
    ``multiprocessing.cpu_count()`` worker processes; otherwise options are processed one after the other
    :param kwargs: forwarded to partition_option. In the sequential case, an optional ``n_episodes`` entry
    selects the pandas-based loader and is passed to it. The parallel case does not use ``n_episodes`` for
    loading, so it is not required there.
    """
    make_dir(output_dir)
    if parallel:

        if verbose:
            print("Loading all transition data..")
        all_data = load_full_transition_data_pandas(env, data_dir,  view=view, verbose=verbose)
        if verbose:
            print("Data loaded..")
        n_jobs = min(multiprocessing.cpu_count(), env.action_space.n)
        # Give each worker a contiguous chunk of option indices
        splits = np.array_split(list(range(env.action_space.n)), n_jobs)
        functions = [partial(partition_options_parallel, all_data, output_dir, splits[i], view, verbose, **kwargs)
                     for i in range(n_jobs)]
        run_parallel(functions)
    else:
        for option in env.action_space:

            if 'n_episodes' in kwargs:
                n_episodes = kwargs['n_episodes']
                transition_data = load_transition_data_pandas(option, data_dir, n_episodes, view, verbose)
            else:
                transition_data = load_transition_data(option, data_dir, view=view, verbose=verbose)
            partitions = partition_option(option, transition_data, verbose=verbose, **kwargs)

            # write to file, named after the option's string representation
            name = str(option)
            write_partition_data(name, partitions, output_dir, verbose=verbose)


def partition_options_parallel(all_transition_data, output_dir, options, view, verbose, **kwargs):
    """
    Partition each of the given options and write the resulting partitions to file.

    :param all_transition_data: the transition data for all options; each element exposes an ``option`` attribute
    :param output_dir: the directory the partition data is written to
    :param options: the options to process
    :param view: not used by this function
    :param verbose: whether to print information to screen
    :param kwargs: forwarded to partition_option
    """
    for current_option in options:
        # Keep only the transitions that belong to the option being processed
        relevant = [transition for transition in all_transition_data
                    if transition.option == current_option]
        result = partition_option(current_option, relevant, verbose=verbose, **kwargs)
        # The file is named after the option's string representation
        write_partition_data(str(current_option), result, output_dir, verbose=verbose)


def partition_option(option,
                     samples,
                     neighbourhood_agent_space=1,  # 5.95, # 8,  # 0.4 / 14,
                     neighbourhood_problem_space=0.5,  # 0.8 / 14,
                     subpartition=False,
                     verbose=False,
                     **kwargs):
    """
    Partition the given option based on its transition data to obey the subgoal property.

    The procedure has three stages:
      1. For every object id, and every non-empty mask seen for that object, the matching samples are
         clustered by _partition_samples.
      2. Each resulting cluster is compared against every entry already collected (via _try_intersect on the
         start states). Entries whose samples only partly overlap are split into an overlapping and a
         non-overlapping part.
      3. The collected entries are turned into PartitionedOption objects. An entry is merged into an existing
         PartitionedOption when both init_sets_similar and observation_init_sets_similar report a match;
         otherwise a new PartitionedOption is created, provided the entry has at least 3 samples.

    :param option: the option. It is compared against hard-coded integer ids to choose the clustering epsilon,
    and is passed to PartitionedOption
    :param samples: the transition sample data. It is iterated several times, and each sample is read for its
    mask, object_id, state, reward, next_state, observation and next_observation attributes
    :param neighbourhood_agent_space: passed to _try_intersect as the neighbourhood radius, and to
    PartitionedOption
    :param neighbourhood_problem_space: passed to PartitionedOption
    :param subpartition: if True, the clustering epsilon is forced to 0.5 and _partition_samples is called with
    minimal=True
    :param verbose: whether to print information to screen
    :param kwargs: forwarded to _partition_samples
    :return: the list of PartitionedOption objects that were created
    """

    # Clustering radius for _partition_samples: hard-coded per option id, overridden when sub-partitioning
    epsilon = 4
    if option in [6, 4, 3]:
        epsilon = 0.99
    elif option == 8:
        epsilon = 2
    if subpartition:
        epsilon = 0.5

    # NOTE(review): this value is never read - masks is recomputed per object inside the loop below
    masks = {tuple(np.sort(sample.mask)) for sample in samples}  # Get a list of all the masks from the dataset
    objects = set([sample.object_id for sample in samples])

    if verbose:
        print("Data loaded.")

    # Each entry is a list [states, rewards, next_states, mask, observations, next_observations, object_ids]
    data_list = []
    for object in objects:
        # for mask in masks:

        # if len(mask) == 0:
        #     continue

        if verbose:
            print('Processing object: ' + str(object))

        object_filtered_samples = [sample for sample in samples if sample.object_id == object]
        # filtered_samples = [sample for sample in samples if tuple(sample.mask) == mask]

        # The distinct (sorted) masks seen for this object
        masks = {tuple(np.sort(sample.mask)) for sample in object_filtered_samples}
        for mask in masks:

            # Samples with an empty mask are skipped
            if len(mask) == 0:
                continue

            # if len(set(masks)) > 1:
            #     raise ValueError
            # mask = masks[0]

            # NOTE(review): mask is sorted but sample.mask is compared unsorted here, so a sample whose mask
            # is stored in a different order would match no mask at all - confirm masks are always sorted
            filtered_samples = [sample for sample in object_filtered_samples if tuple(sample.mask) == mask]

            if verbose:
                print("Clustering samples...")

            # TODO THIS
            # One cluster label per sample; -1 marks noise
            labels = _partition_samples(filtered_samples, epsilon, minimal=subpartition, **kwargs)

            states = np.array([sample.state for sample in filtered_samples])
            rewards = np.array([sample.reward for sample in filtered_samples])
            next_states = np.array([sample.next_state for sample in filtered_samples])

            alt_states = np.array([sample.observation for sample in filtered_samples])
            alt_next_states = np.array([sample.next_observation for sample in filtered_samples])

            object_ids = np.array([sample.object_id for sample in filtered_samples])

            if verbose:
                n_clust = (len(set(labels)) - (1 if -1 in labels else 0))
                print(str(n_clust) + " cluster(s) found")

            # The data is now partitioned into distinct effect distributions, but may be over-partitioned because
            # distinct effects may occur from the same start state partition.

            for label in set(labels):

                # Noise samples are discarded
                if label == -1:
                    continue

                # Get the data belonging to the current cluster
                indices = [i for i in range(0, len(labels)) if labels[i] == label]
                s_data = states[indices]
                r_data = rewards[indices]
                s_prime_data = next_states[indices]

                alt_s_data = alt_states[indices]
                alt_s_prime_data = alt_next_states[indices]

                object_id_data = object_ids[indices]

                new_dat = [s_data, r_data, s_prime_data, mask, alt_s_data, alt_s_prime_data, object_id_data]
                # Fixed before the loop, so entries appended while splitting are not revisited for this cluster
                dat_total = len(data_list)

                # For each pair of partitions, determine whether their start states samples overlap substantially by
                # clustering the combined start state samples from each partition, and determine whether each resulting
                # cluster contains data from both partitions. If so, the common data is merged into a single partition

                for pos in range(0, dat_total):
                    old_data = data_list[pos]

                    # Per-sample flags: 1.0 if the sample lies in the overlap of the two entries, else 0.0
                    (label_old, label_new) = _try_intersect(old_data, new_dat, neighbourhood_agent_space)

                    # Both flag values present: split the old entry, keeping the non-overlapping part in place
                    if len(set(label_old)) > 1:
                        old_zeros = _select_transition_data(old_data, label_old, 0.0)
                        old_ones = _select_transition_data(old_data, label_old, 1.0)

                        data_list[pos] = old_zeros
                        data_list.append(old_ones)

                    # Likewise for the new entry: the overlapping part is stored now, the rest carries on
                    if len(set(label_new)) > 1:
                        new_zeros = _select_transition_data(new_dat, label_new, 0.0)
                        new_ones = _select_transition_data(new_dat, label_new, 1.0)
                        data_list.append(new_ones)
                        new_dat = new_zeros

                data_list.append(new_dat)

    # Here, we should have sets of input-output pairs (with masks) that are maximally split.
    # So we just need to merge.
    if verbose:
        print("Total options before merging: " + str(len(data_list)))

    # Check merge
    partitioned_options = []

    for dat in data_list:
        merged = False
        s_data = dat[0]
        r_data = dat[1]
        s_prime_data = dat[2]
        mask = dat[3]
        alt_s_data = dat[4]
        alt_s_prime_data = dat[5]
        object_ids_data = dat[6]

        # When merging, an outcome is created for each effect cluster (which could be distinct due of clustering or due
        # to a different mask) and assigned an outcome probability based on the fraction of the samples assigned to it.
        # Indices of existing partitions whose init sets match but whose observation init sets do not
        no_augment = list()
        for idx, partitioned_option in enumerate(partitioned_options):

            # TODO: NEW: check if both the init sets and observation init sets overlap!
            if partitioned_option.init_sets_similar(dat[0]):

                if partitioned_option.observation_init_sets_similar(dat[4]):
                    # a bonafide stochastic transition
                    warnings.warn("Stochastic transition found!")

                    # NOTE(review): repeats the check above and discards the result - looks redundant unless
                    # the method has side effects; confirm
                    partitioned_option.observation_init_sets_similar(dat[4])

                    partitioned_option.merge(s_data, mask, r_data, s_prime_data)
                    partitioned_option.inject_observations(alt_s_data, alt_s_prime_data, append=True)
                    merged = True
                    break

                else:
                    # actually a deterministic transition! Do not augment negative samples when partitioning
                    no_augment.append(idx)
                    pass

        # Entries that were not merged and have fewer than 3 samples are dropped
        if not merged and len(s_prime_data) >= 3:  # todo fix
            pos = len(partitioned_options)
            partitioned_option = PartitionedOption(option, pos, s_data, mask, r_data, s_prime_data,
                                                   neighbourhood_agent_space,  # this way around
                                                   neighbourhood_problem_space, object_ids_data)
            partitioned_option.inject_observations(alt_s_data, alt_s_prime_data)
            partitioned_options.append(partitioned_option)

            # Record the no-augment relation in both directions
            for idx in no_augment:
                partitioned_options[pos].add_no_augment(idx)
                partitioned_options[idx].add_no_augment(pos)

    if verbose:
        print(str(len(partitioned_options)) + ' partitioned options found')

    return partitioned_options


def _flatten(state):
    return np.concatenate(state).ravel()


def _partition_samples(samples, epsilon, minimal=False, **kwargs):
    sprime = np.array([_flatten(sample.next_state[sample.mask]) for sample in samples])
    # TODO Warning: unsure how to order objects when we move to multitask! e.g. if we have A,B and B,A, then we'll think it's different clusters. We must order!!
    # Visualise states for debug purposes
    for _ in range(5):
        if 'debug_render' in kwargs:
            kwargs['debug_render'](sprime[np.random.choice(sprime.shape[0], min(40, sprime.shape[0]))])

    # Another crappy hyperparameter where something can go wrong
    if minimal:
        min_samples = 1
    else:
        min_samples = max(3, min(10, len(sprime) // 10))  # at least 3, at most 10

    db = DBSCAN(eps=epsilon, min_samples=min_samples).fit(sprime)
    labels = db.labels_
    if all(elem == -1 for elem in labels) and len(labels) > min_samples:
        warnings.warn("All datapoints classified as noise!")
        return [0] * len(labels)
    return labels


def _try_intersect(od, nd, neighbourhood_radius):
    """
    Determine which start states of two transition data sets overlap.

    The start states of both sets are flattened, pooled and clustered with DBSCAN. A sample counts as
    overlapping if its cluster contains samples from both sets.

    :param od: the first ("old") transition data set; its element 0 holds the start states
    :param nd: the second ("new") transition data set; its element 0 holds the start states
    :param neighbourhood_radius: the DBSCAN neighbourhood radius (eps)
    :return: a pair of float arrays (one per data set) with 1.0 for each overlapping sample and 0.0 otherwise
    """

    def _flat_rows(state_rows):
        # flatten 2d state representation
        return np.array([np.concatenate(row).ravel() for row in state_rows])

    first = _flat_rows(od[0])
    second = _flat_rows(nd[0])
    n_first, n_second = len(first), len(second)

    pooled = np.concatenate((first, second))
    clustering = DBSCAN(eps=neighbourhood_radius, min_samples=5).fit(np.array(pooled))
    all_labels = clustering.labels_

    labels_old = all_labels[:n_first]
    labels_new = all_labels[n_first:n_first + n_second]

    # Real clusters (i.e. excluding the noise label) present in each data set
    clusters_old = set(labels_old) - {-1}
    clusters_new = set(labels_new) - {-1}
    shared = clusters_old & clusters_new

    old_ret = np.array([lbl in shared for lbl in labels_old], dtype=float)
    new_ret = np.array([lbl in shared for lbl in labels_new], dtype=float)

    # Handle "noise" - count as intersected if the whole group has been subsumed.
    if (-1 in labels_old) and clusters_old <= clusters_new:
        old_ret[labels_old == -1] = 1.0

    if (-1 in labels_new) and clusters_new <= clusters_old:
        new_ret[labels_new == -1] = 1.0

    return old_ret, new_ret


def _select_transition_data(trans_data, labels, val):
    s_data = trans_data[0]
    r_data = trans_data[1]
    s_prime_data = trans_data[2]
    mask = trans_data[3]
    alt_s_data = trans_data[4]
    alt_s_prime_data = trans_data[5]
    object_ids_data = trans_data[6]
    return [s_data[labels == val], r_data[labels == val], s_prime_data[labels == val], mask,
            alt_s_data[labels == val], alt_s_prime_data[labels == val], object_ids_data[labels == val]]


def _subsample(data, n_samples):
    return data[np.random.choice(data.shape[0], n_samples, replace=True)]
