from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.svm import SVC

import numpy as np

from symbols.file_utils import make_dir, make_path
from symbols.render.image import Image
from symbols.symbols.svc import SupportVectorClassifier


class UnionFind:
    """Union-find data structure.

    Each UnionFind instance X maintains a family of disjoint sets of
    hashable objects, supporting the following two methods:

    - X[item] returns a name for the set containing the given item.
      Each set is named by an arbitrarily-chosen one of its members; as
      long as the set remains unchanged it will keep the same name. If
      the item is not yet part of a set in X, a new singleton set is
      created for it.

    - X.merge(item1, item2, ...) merges the sets containing each item
      into a single larger set.  If any item is not yet part of a set
      in X, it is added to X as one of the members of the merged set.
    """

    def __init__(self, objects):
        """Create a union-find structure with one singleton set per object.

        :param objects: iterable of hashable items to register up front
        """
        self.weights = {}
        self.parents = {}
        for item in objects:
            self.parents[item] = item
            self.weights[item] = 1

    def __getitem__(self, object):
        """Find and return the name of the set containing the object.

        An object that has never been seen is registered as a new
        singleton set and returned as its own name.
        """

        # check for previously unknown object
        if object not in self.parents:
            self.parents[object] = object
            self.weights[object] = 1
            return object

        # find path of objects leading to the root
        path = [object]
        root = self.parents[object]
        while root != path[-1]:
            path.append(root)
            root = self.parents[root]

        # compress the path and return
        for ancestor in path:
            self.parents[ancestor] = root
        return root

    def __iter__(self):
        """Iterate through all items ever found or merged by this structure."""
        return iter(self.parents)

    def merge(self, *objects):
        """Find the sets containing the objects and merge them all.

        The heaviest set's name becomes the name of the merged set.
        Calling with no objects is a no-op.
        """
        # Several arguments may already share a root; keep each root once
        # (in first-seen order) so its weight is not added more than once.
        roots = list(dict.fromkeys(self[x] for x in objects))
        if not roots:
            return

        try:
            # Ties on weight are broken by the largest root, which keeps
            # set names stable for orderable items.
            heaviest = max((self.weights[r], r) for r in roots)[1]
        except TypeError:
            # Roots are hashable but not mutually orderable: compare on
            # weight alone (first heaviest root wins).
            heaviest = max(roots, key=self.weights.__getitem__)

        for r in roots:
            if r != heaviest:
                self.weights[heaviest] += self.weights[r]
                self.parents[r] = heaviest


class MergeMap2:
    """Assign a single label to subpartitions that are judged equivalent.

    Every subpartition is addressed by an integer id built from its
    (option, partition, subpartition) indices with a nested Cantor pairing.
    A union-find structure groups ids whose partitions are considered
    similar (only when ``do_merge`` is set), and a classifier is fitted on
    the states of each group so that a state can be mapped to a group label.

    NOTE(review): partition objects are assumed to expose ``states``,
    ``combined_mask``, ``init_sets_similar`` and ``masked_distance_to``;
    their definitions are not visible in this file.
    """

    def __init__(self, env, subpartitions, do_merge=False):
        """Flatten the subpartitions, merge similar ones and fit the classifier.

        :param env: environment; ``env.observation_space.shape[1]`` is used as
            the number of state features (assumes the shape has a second
            entry -- TODO confirm against the environment class)
        :param subpartitions: mapping ``option -> sequence of partitions``,
            each partition being a sequence of subpartition objects
        :param do_merge: if False, every subpartition keeps its own label
        """
        self._do_merge = do_merge
        self._domain = env

        # id -> subpartition object, in iteration order of the input
        self._flattened = dict()
        for option, partitions in subpartitions.items():
            for partition, subs in enumerate(partitions):
                for n, sub in enumerate(subs):
                    self._flattened[MergeMap2.to_id(option, partition, n)] = sub

        self._merge_map = self._merge(self._flattened)

        # Collect the state arrays per merged label first and concatenate
        # once per label, rather than re-copying a growing array every time.
        grouped = dict()
        total = 0
        for key, partition in self._flattened.items():
            grouped.setdefault(self.reduce(key), []).append(partition.states)
            total += len(partition.states)

        # label -> all states belonging to that label
        self._data = {
            label: chunks[0] if len(chunks) == 1 else np.concatenate(chunks)
            for label, chunks in grouped.items()
        }

        n_features = env.observation_space.shape[1]
        X = np.empty(shape=(total, n_features))
        Y = np.empty(shape=total)

        row = 0
        for label, states in self._data.items():
            end = row + len(states)
            X[row:end, :] = states
            Y[row:end] = label
            row = end

        self._classifier = SupportVectorClassifier(np.arange(0, n_features), X, Y,
                                                   probabilistic=False)._classifier

    # check if partition 1 and 2 are actually the same
    def is_overlap(self, partition1, partition2):
        """Return whether the partitions behind the two ids have similar init sets."""
        A = self.get_partition(partition1)
        B = self.get_partition(partition2)
        return A.init_sets_similar(B.states)

    def get_partition(self, id):
        """Return the first stored partition whose merged label equals that of ``id``.

        Returns None when no stored partition shares the label.
        """
        target = self.reduce(id)
        for key, partition in self._flattened.items():
            if self.reduce(key) == target:
                return partition
        return None

    def _merge(self, flattened):
        """Build the union-find over all ids, merging similar partitions.

        Two partitions are merged when merging is enabled, their combined
        masks contain the same elements, and ``init_sets_similar`` accepts
        them. With merging disabled every id stays in its own set.
        """
        union_find = UnionFind(flattened.keys())
        if not self._do_merge:
            # nothing can be merged; skip the pairwise comparison entirely
            return union_find

        entries = list(flattened.items())
        # visit every unordered pair exactly once, in insertion order
        for i, (a, P) in enumerate(entries):
            for b, Q in entries[i + 1:]:
                if set(P.combined_mask) == set(Q.combined_mask) and \
                        P.init_sets_similar(Q.states, (P.combined_mask, Q.combined_mask)):
                    union_find.merge(a, b)

        return union_find

    def is_close(self, state, label):
        """Return whether ``state`` lies within 0.01 (masked distance) of the label's partition."""
        partition = self._flattened[self.reduce(label)]
        return partition.masked_distance_to(state) < 0.01

    @staticmethod
    def _cantor_pairing(x: int, y: int):
        """Encode two non-negative integers as one using the Cantor pairing.

        Exact integer arithmetic is used so that large inputs (in particular
        the nested pairing done by ``to_id``) do not lose precision.
        """
        x, y = int(x), int(y)
        s = x + y
        # s * (s + 1) is always even, so the floor division is exact
        return s * (s + 1) // 2 + y

    @staticmethod
    def _invert_cantor_pairing(z: int):
        """Decode a Cantor-paired integer back into its ``(x, y)`` pair."""
        z = int(z)
        # Floating-point estimate of the triangular root ...
        w = int((np.sqrt(float(8 * z + 1)) - 1) / 2)
        # ... corrected with exact integer checks, since the estimate can be
        # off by a little once z exceeds float precision.
        while w * (w + 1) // 2 > z:
            w -= 1
        while (w + 1) * (w + 2) // 2 <= z:
            w += 1
        y = z - w * (w + 1) // 2
        x = w - y
        return x, y

    def reduce(self, id):
        """Return the merged label (union-find set name) for ``id``."""
        return self._merge_map[id]

    def most_likely(self, state, _=None):
        """Return the classifier's predicted label for a single state."""
        return self._classifier.predict(state.reshape(1, -1))[0]

    @staticmethod
    def to_id(option, partition, subpartition):
        """Pack (option, partition, subpartition) into one integer id."""
        return MergeMap2._cantor_pairing(MergeMap2._cantor_pairing(option, partition), subpartition)

    @staticmethod
    def from_id(id):
        """Unpack an id produced by ``to_id`` into (option, partition, subpartition)."""
        t, subpartition = MergeMap2._invert_cantor_pairing(id)
        option, partition = MergeMap2._invert_cantor_pairing(t)
        return option, partition, subpartition

    def visualise(self, directory, env, subpartitions):
        """Render one image per merged label into ``<directory>/problem_partitions``.

        For each label, the first id encountered with that label supplies
        the states that are rendered.
        """
        make_dir(make_path(directory, 'problem_partitions'))
        used = set()
        for key in self._flattened:
            label = self.reduce(key)
            if label in used:
                continue
            self._visualise_init_set(directory, env, key, label, subpartitions)
            used.add(label)

    def _visualise_init_set(self, dir, env, id, label, subpartitions):
        """Render the states of the subpartition behind ``id`` and save them as one image.

        The file is written to ``problem_partitions/partition<label>.png``
        under ``dir``.
        """
        o, p, s = MergeMap2.from_id(id)
        partition = subpartitions[o][p][s]
        init = partition.states
        images = env.render_states(init)
        im = Image.merge(images)
        try:
            filename = make_path(dir, 'problem_partitions/partition{}.png'.format(label))
            im.save(filename)
        finally:
            # presumably releases the merged image's resources; make sure it
            # still happens when saving fails
            im.free()
