import numpy as np

from symbols.data.data import load_option_partitions, load_transition_data
from symbols.data.partitioned_option import PartitionedOption
from symbols.experimental.merge_map2 import MergeMap2
from symbols.experimental.quick_merge import QuickMerge
from symbols.file_utils import make_path, make_dir
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity, KNeighborsClassifier

from symbols.logger.transition_sample import TransitionSample
from symbols.render.image import Image
from symbols.render.render import visualise_partitions, visualise_subpartitions
from symbols.symbols.learned_lifted_operator import LearnedLiftedOperator

"""UnionFind.py

Union-find data structure. Based on Josiah Carlson's code,
http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/215912
with significant additional changes by D. Eppstein.
"""


class UnionFind:
    """Union-find (disjoint-set) data structure.

    Each UnionFind instance X maintains a family of disjoint sets of
    hashable objects, supporting the following two operations:

    - X[item] returns a name for the set containing the given item.
      Each set is named by an arbitrarily-chosen one of its members; as
      long as the set remains unchanged it will keep the same name. If
      the item is not yet part of a set in X, a new singleton set is
      created for it.

    - X.merge(item1, item2, ...) merges the sets containing each item
      into a single larger set.  If any item is not yet part of a set
      in X, it is added to X as one of the members of the merged set.

    Lookups compress the path to the root, and merges hang the lighter
    roots under the heaviest one (``weights`` holds each root's set size).
    """

    def __init__(self, objects=()):
        """Create a union-find structure with each of ``objects`` in its own singleton set."""
        self.weights = {}
        self.parents = {}
        for obj in objects:
            self.parents[obj] = obj
            self.weights[obj] = 1

    def __getitem__(self, obj):
        """Find and return the name (root) of the set containing ``obj``."""
        # Previously unknown object: register it as a singleton set.
        if obj not in self.parents:
            self.parents[obj] = obj
            self.weights[obj] = 1
            return obj

        # Walk up to the root, remembering the path.
        path = [obj]
        root = self.parents[obj]
        while root != path[-1]:
            path.append(root)
            root = self.parents[root]

        # Path compression: point every visited node straight at the root.
        for ancestor in path:
            self.parents[ancestor] = root
        return root

    def __iter__(self):
        """Iterate through all items ever found or merged by this structure."""
        return iter(self.parents)

    def merge(self, *objects):
        """Merge the sets containing all of ``objects`` into a single set.

        Calling with no objects is a no-op.
        """
        # Deduplicate roots (keeping first-seen order): when several arguments
        # already share a set, adding that root's weight more than once would
        # inflate the merged set's size.
        roots = list(dict.fromkeys(self[obj] for obj in objects))
        if not roots:
            return
        heaviest = max((self.weights[r], r) for r in roots)[1]
        for r in roots:
            if r != heaviest:
                self.weights[heaviest] += self.weights[r]
                self.parents[r] = heaviest


class KNNClassifier:
    """k-nearest-neighbour classifier over a collection of partitions.

    Every state of the i-th partition in ``partitioned_symbols`` becomes a
    training point with class ``i``; ``index_to_label`` maps that class index
    back to an external label, which is what the ``predict*`` methods return.
    """

    def _to_prob(self, logit):
        """Map a logit to a probability with the logistic sigmoid.

        Computed as ``exp(-logaddexp(0, -logit))``, so large positive logits
        give 1.0 instead of ``inf / inf = nan``.
        """
        return np.exp(-np.logaddexp(0.0, np.negative(logit)))

    def __init__(self, partitioned_symbols, index_to_label, norm=2):
        """Fit the classifier.

        Args:
            partitioned_symbols: iterable of partitions, each exposing ``states``.
            index_to_label: mapping from class index (position in
                ``partitioned_symbols``) to the label to return.
            norm: Minkowski ``p`` used by the neighbour search.

        Raises:
            ValueError: if ``partitioned_symbols`` is empty.
        """
        partitioned_symbols = list(partitioned_symbols)
        if not partitioned_symbols:
            raise ValueError("KNNClassifier requires at least one partition")

        # Concatenate once instead of growing the arrays inside the loop,
        # which re-copies all accumulated data on every iteration.
        data = np.concatenate([partition.states for partition in partitioned_symbols])
        labels = np.concatenate([
            np.full(len(partition.states), i, dtype=float)
            for i, partition in enumerate(partitioned_symbols)
        ])

        # One neighbour per partition, but never more than there are samples:
        # sklearn refuses to predict when n_neighbors exceeds n_samples_fit.
        n_neighbors = max(1, min(len(partitioned_symbols), len(data)))
        self._knn = KNeighborsClassifier(n_neighbors=n_neighbors, p=norm)
        self._knn.fit(data, labels)
        self._index_to_label = index_to_label

    def predict_multiple(self, state, min_prob=0.01):
        """Return ``(probs, labels)`` for every class whose probability exceeds ``min_prob``."""
        proba = self._knn.predict_proba(np.array(state).reshape(1, -1))[0]
        probs = list()
        best = list()
        # predict_proba's columns follow ``classes_``, which omits partitions that
        # contributed no states, so map each column through it rather than
        # assuming column i is partition i.
        for cls, p in zip(self._knn.classes_, proba):
            if p > min_prob:
                probs.append(p)
                best.append(self._index_to_label[int(cls)])
        return probs, best

    def predict(self, state):
        """Return ``(1, label)`` for the single most likely class of ``state``."""
        best = self._knn.predict(np.array(state).reshape(1, -1))[0]
        # Class values are stored as floats; index with the integer class id.
        return 1, self._index_to_label[int(best)]


class MergeMap:
    """Maps sub-partition ids to merged (equivalence-class) ids.

    Each sub-partition is identified by an integer id that encodes the
    ``(option, partition, subpartition)`` triple with nested Cantor pairing
    (see :meth:`to_id` / :meth:`from_id`). When ``do_merge`` is set,
    sub-partitions whose initiation sets are judged similar are unioned, and
    :meth:`reduce` returns the representative id of an id's class. KNN
    classifiers (one global, one per option) map a raw state to the most
    likely sub-partition id.
    """

    def is_overlap(self, partition1, partition2):
        """Return whether the partitions with ids ``partition1`` and ``partition2`` have similar initiation sets."""
        A = self.get_partition(partition1)
        B = self.get_partition(partition2)
        return A.init_sets_similar(B.states)

    def get_partition(self, id):
        """Return a sub-partition whose merged id equals that of ``id``, or None."""
        target = self.reduce(id)
        for key, partition in self._flattened.items():
            if self.reduce(key) == target:
                return partition
        return None

    def __init__(self, env, subpartitions, do_merge=False):
        """Index every sub-partition and fit the state classifiers.

        Args:
            env: the environment/domain (stored as ``_domain``).
            subpartitions: mapping ``option -> [partition -> [subpartition]]``.
            do_merge: whether to union sub-partitions with similar initiation sets.
        """
        self._do_merge = do_merge
        self._domain = env
        index_to_label = dict()
        self._flattened = dict()
        self._option_classifiers = dict()
        partitions = list()
        for option, val in subpartitions.items():
            option_partitions = list()
            option_labels = dict()
            for partition, x in enumerate(val):
                for n, y in enumerate(x):
                    label = MergeMap.to_id(option, partition, n)
                    self._flattened[label] = y
                    index_to_label[len(partitions)] = label
                    partitions.append(y)
                    option_labels[len(option_partitions)] = label
                    option_partitions.append(y)
            if option_partitions:
                self._option_classifiers[option] = KNNClassifier(option_partitions, option_labels)
        self._global_classifier = KNNClassifier(partitions, index_to_label)

        self._merge_map = self._merge()

    def _merge(self):
        """Build the union-find over all sub-partition ids, merging similar pairs if enabled."""
        union_find = UnionFind(self._flattened.keys())
        # Skip the quadratic pairwise scan entirely when merging is disabled.
        if not self._do_merge:
            return union_find

        items = list(self._flattened.items())
        for i, (a, P) in enumerate(items):
            for b, Q in items[i + 1:]:
                if P.init_sets_similar(Q.states):
                    union_find.merge(a, b)
        return union_find

    def _make_kde(self, A):
        """Fit a Gaussian KDE to ``A.states``, choosing the bandwidth by grid-search CV."""
        params = {'bandwidth': np.arange(0.01, 0.25, 0.01)}
        grid = GridSearchCV(KernelDensity(kernel='gaussian'), params)
        grid.fit(A.states)
        return grid.best_estimator_

    @staticmethod
    def _kl_divergence(A, B, n_samples=100):
        """Monte-Carlo estimate of KL(A || B) from ``n_samples`` draws of density ``A``."""
        x = A.sample(n_samples)
        log_p_x = A.score_samples(x)
        log_q_x = B.score_samples(x)
        return log_p_x.mean() - log_q_x.mean()

    @staticmethod
    def _cantor_pairing(x: int, y: int):
        """Map a pair of non-negative integers to a unique non-negative integer."""
        # Pure integer arithmetic: the float form ``0.5 * (x + y) * (x + y + 1)``
        # loses precision once values exceed 2**53, corrupting nested ids.
        x, y = int(x), int(y)
        s = x + y
        return s * (s + 1) // 2 + y

    @staticmethod
    def _invert_cantor_pairing(z: int):
        """Invert :meth:`_cantor_pairing`, returning ``(x, y)``."""
        import math

        z = int(z)
        # Float estimate of w = floor((sqrt(8z + 1) - 1) / 2), corrected with
        # exact integer checks so large ids are recovered exactly.
        w = int((math.sqrt(8 * z + 1) - 1) // 2)
        while w * (w + 1) // 2 > z:
            w -= 1
        while (w + 1) * (w + 2) // 2 <= z:
            w += 1
        t = w * (w + 1) // 2
        y = z - t
        x = w - y
        return x, y

    def reduce(self, id):
        """Return the representative (merged) id of ``id``."""
        return self._merge_map[id]

    def most_likely(self, state, option=None):
        """Return the most likely sub-partition id for ``state``.

        With ``option`` None the global classifier is used. Otherwise the
        option's classifier is used, and None is returned when that option
        has no sub-partitions.
        """
        if option is None:
            return self._global_classifier.predict(state.reshape(1, -1))[1]

        if option not in self._option_classifiers:
            return None
        return self._option_classifiers[option].predict(state.reshape(1, -1))[1]

    def score(self, id, agent_state, state, symbols):
        """Score how well ``state`` matches sub-partition ``id`` under its option's classifier.

        Returns the classifier's probability when its prediction reduces to the
        same merged id as ``id``, else 0 (also 0 when the option has no
        classifier). ``agent_state`` and ``symbols`` are currently unused.
        """
        option, _, _ = MergeMap.from_id(id)
        classifier = self._option_classifiers.get(option)
        if classifier is None:
            return 0
        # Call the classifier directly: most_likely() returns only the label,
        # so unpacking its result into (prob, label) raised a TypeError.
        prob, beta = classifier.predict(state.reshape(1, -1))
        if self.reduce(beta) == self.reduce(id):
            return prob
        return 0

    @staticmethod
    def to_id(option, partition, subpartition):
        """Encode ``(option, partition, subpartition)`` as a single integer id."""
        return MergeMap._cantor_pairing(MergeMap._cantor_pairing(option, partition), subpartition)

    @staticmethod
    def from_id(id):
        """Decode an id built by :meth:`to_id` back into ``(option, partition, subpartition)``."""
        t, subpartition = MergeMap._invert_cantor_pairing(id)
        option, partition = MergeMap._invert_cantor_pairing(t)
        return option, partition, subpartition

    def visualise(self, directory, env, subpartitions):
        """Render every sub-partition's states to ``<directory>/problem_partitions``.

        Each image is labelled with the sub-partition's own id (not its merged
        representative).
        """
        make_dir(make_path(directory, 'problem_partitions'))
        for partition_id in self._flattened:
            self._visualise_init_set(directory, env, partition_id, partition_id, subpartitions)

    def _visualise_init_set(self, dir, env, id, label, subpartitions):
        """Render the states of sub-partition ``id`` and save them as ``partition<label>.png``."""
        o, p, s = MergeMap.from_id(id)
        partition = subpartitions[o][p][s]
        init = partition.states
        images = env.render_states(init)
        im = Image.merge(images)
        filename = make_path(dir, 'problem_partitions/partition{}.png'.format(label))
        im.save(filename)
        im.free()


def _rescale(logit):
    odds = np.exp(logit)
    return odds / (1 + odds)


def _rule_probs(symbol, d):
    if len(symbol.list_effects) == 0:
        raise ValueError("WTF?!?!")

    log_likelihoods = [x.score(d) for x in symbol.list_effects]
    # Shove them through a softmax
    e_x = np.exp(log_likelihoods - np.max(log_likelihoods))
    return e_x / e_x.sum()


def refers_to(transition, symbols, *, min_state_prob=0.01, min_score=0.6):
    """Return the symbol that best explains ``transition``, or None.

    A candidate must have the transition's option and give the start state a
    precondition probability above ``min_state_prob``. Its score is that
    probability times the best sigmoid-rescaled effect log-likelihood of the
    next state. The highest-scoring candidate whose score exceeds
    ``min_score`` is returned.
    """
    d = transition.state
    d_prime = transition.next_state
    best_score = 0
    best = None
    for candidate in symbols:
        if candidate.option != transition.option:
            continue

        state_prob = candidate.precondition.probability(d)
        if not state_prob > min_state_prob:
            continue

        # Best-fitting effect. Effects are indexed in step with
        # list_probabilities, so a shorter list_effects still raises IndexError.
        max_eff = 0
        for i, _ in enumerate(candidate.list_probabilities):
            eff = _rescale(candidate.list_effects[i].score(d_prime))
            max_eff = max(max_eff, eff)

        total_score = state_prob * max_eff
        if total_score > min_score and total_score > best_score:
            best_score = total_score
            best = candidate
    return best


def assign_data(symbols, data):
    """Group transitions by the symbol that explains them and wrap each group.

    Every sample in ``data`` is attributed to a symbol via ``refers_to``;
    unexplained samples are dropped. Each symbol that received samples is
    mapped to a list of ``PartitionedOption`` objects built from its
    transitions (states, rewards, next states and observations). The position
    in that list is stored on each object's ``_partition``.
    """
    grouped = {}
    for sample in data:
        owner = refers_to(sample, symbols)
        if owner is not None:
            grouped.setdefault(owner, []).append(sample)

    partitions = {}
    for symbol in symbols:
        transitions = grouped.get(symbol)
        if transitions is None:
            continue

        states = np.array([t.state for t in transitions])
        next_states = np.array([t.next_state for t in transitions])
        observations = np.array([t.observation for t in transitions])
        next_observations = np.array([t.next_observation for t in transitions])
        rewards = np.array([t.reward for t in transitions])

        option_partition = PartitionedOption(symbol.option, -1, states, [], rewards, next_states)
        option_partition.inject_observations(observations, next_observations)

        bucket = partitions.setdefault(symbol, [])
        option_partition._partition = len(bucket)
        bucket.append(option_partition)

    return partitions


def merge(curr_partitions, old_partitions):
    """Append each symbol's partitions to ``curr_partitions[symbol.option]`` in place.

    Each symbol's ``partition`` attribute is set to the index its (last)
    appended partition occupies. The symbol is recorded once per appended
    partition.

    Returns:
        ``(curr_partitions, seen_symbols)``.
    """
    seen_symbols = []
    for symbol, symbol_partitions in old_partitions.items():
        for part in symbol_partitions:
            target = curr_partitions[symbol.option]
            symbol.partition = len(target)
            target.append(part)
            seen_symbols.append(symbol)
    return curr_partitions, seen_symbols


def find_new_partition(option, symbols):
    """Return the next free partition index for ``option``.

    This is one more than the largest ``partition`` among ``symbols`` with
    that option. Indices below -1 are ignored, so the result is never below 0.
    """
    existing = [symbol.partition for symbol in symbols if symbol.option == option]
    return max([-1] + existing) + 1


def merge_relabel(prev_symbols, curr_symbols):
    """Return ``curr_symbols`` followed by ``prev_symbols``, relabelling the latter.

    Each previous symbol's ``partition`` is overwritten with the next free
    index for its option among the symbols collected so far.
    """
    combined = list(curr_symbols)
    for old_symbol in prev_symbols:
        old_symbol.partition = find_new_partition(old_symbol.option, combined)
        combined.append(old_symbol)
    return combined


def debug_state(env, state, filename='debug.png'):
    """Render ``state`` with ``env``, save it to ``filename`` and open it for viewing.

    ``os.startfile`` exists only on Windows. On other platforms the image is
    still saved and its path is printed, instead of raising AttributeError.
    """
    import os
    images = env.render_states([state])
    im = Image.merge(images)
    try:
        im.save(filename)
    finally:
        # Release the image even if saving fails.
        im.free()
    startfile = getattr(os, 'startfile', None)
    if startfile is None:
        print('Saved debug image to {}'.format(filename))
    else:
        startfile(filename)


def _extract_transitions(partitions, view):
    samples = list()
    for partition in partitions:
        for i in range(len(partition.states)):
            s = TransitionSample(partition.states[i], partition.observations[i], partition.option,
                                 partition._rewards[i], partition.next_states[i], partition.next_observations[i],
                                 view=view)
            samples.append(s)
    return samples


def learn_links(directory, env, prev_symbols, symbols, n_samples):
    """Learn links between problem-space partitions and lifted operators' effects.

    Steps:
      1. Load the option partitions stored under ``<directory>/partitioned_options``.
      2. For every option, attribute its agent-view transitions to
         ``prev_symbols`` (``assign_data``) and append the resulting partitions
         via ``merge``, which relabels those symbols' ``partition``.
      3. Wrap ``symbols`` plus the re-used previous symbols in
         ``LearnedLiftedOperator``.
      4. For each problem-view transition, find the most likely start and end
         partitions with a ``QuickMerge``. Then call ``update_links`` on every
         lifted operator of that option that has at least one effect, passing
         the softmax of its effects' scores on the next observation.

    ``n_samples`` is currently unused.

    Returns:
        ``(lifted_symbols, merge_map, used_symbols)``.
    """
    print("In " + directory)

    # Partition in problem space, re-using data explained by previous rules.
    partitions = load_option_partitions(env.action_space, make_path(directory, 'partitioned_options'))
    used_symbols = []
    for option in env.action_space:
        agent_transitions = load_transition_data(option, directory, view='agent', verbose=False)
        reassigned = assign_data(prev_symbols, agent_transitions)
        partitions, seen = merge(partitions, reassigned)
        used_symbols += seen

    lifted_symbols = [LearnedLiftedOperator(s) for s in symbols + used_symbols]

    merge_map = QuickMerge(env)

    for option in env.action_space:
        problem_transitions = load_transition_data(option, directory, view='problem', verbose=False)
        if len(problem_transitions) == 0:
            # No data: the option does not exist in this domain.
            continue

        for sample in problem_transitions:
            alpha = merge_map.most_likely(sample.state)
            if alpha is None:
                continue

            beta = merge_map.most_likely(sample.next_state)
            if beta is None:
                # No partition data matched the next state.
                continue

            next_observation = sample.next_observation
            for symbol in lifted_symbols:
                if option == symbol.option and len(symbol.list_effects) > 0:
                    probs = _rule_probs(symbol, next_observation)  # Pr(next_observation | Theta)
                    symbol.update_links(merge_map.reduce(alpha), merge_map.reduce(beta), probs)

    return lifted_symbols, merge_map, used_symbols
