import itertools
from collections import defaultdict
import imagehash

import numpy as np
from sklearn.cluster import DBSCAN

from domain.inventory import Inventory
from experiment.learn_operators import PCA_PATH
from pca.base_pca import PCA_N
from pca.pca import PCA
from symbols.render.image import Image


def render(effects, n_samples=100):
    """Sample every effect and turn each sampled state into an image array.

    :param effects: iterable of objects exposing ``sample(n)``
    :param n_samples: number of states requested from each effect
    :return: a ``defaultdict`` mapping the position of each effect in
        ``effects`` to the list of images produced from its samples
    """
    images = defaultdict(list)
    pca = PCA(PCA_N)
    pca.load(PCA_PATH)

    def as_image(state):
        # States with more than 10 entries are treated as PCA-compressed
        # images; anything shorter is treated as an inventory vector.
        if len(state) > 10:
            return pca.unflatten(pca.uncompress_(state))
        # Round the vector to whole numbers before drawing it.
        rounded = np.rint(state)
        return Image.to_array(Inventory.to_image(rounded))

    for index, effect in enumerate(effects):
        for state in effect.sample(n_samples):
            images[index].append(as_image(state))
    return images


class Effect:
    """The outcome of an option, stored as a list of ``(probability, effect)`` pairs.

    ``effects`` holds the pairs and ``n`` their count. The constructor always
    wraps its ``effects`` argument as a single outcome, so instances built
    through it hold at most one pair.
    """

    # Largest absolute difference at which two outcome probabilities still match.
    PROB_TOLERANCE = 0.1
    # Number of samples drawn from a distribution when comparing it to another.
    N_SAMPLES = 100
    # Divergence estimate above which two distributions are deemed different.
    # NOTE(review): this is a very permissive threshold - confirm it is intended.
    KL_THRESHOLD = 10000

    def __init__(self, effects, probs=1.0):
        """
        :param effects: the effect; always treated as ONE outcome, even if it
            is itself a list (the effects are known to be deterministic)
        :param probs: the probability of the outcome, or a list of probabilities
        """
        effects = [effects]
        if not isinstance(probs, list):
            probs = [probs]

        # zip stops at the shorter input, so at most one pair is stored.
        self.effects = list(zip(probs, effects))
        self.n = len(self.effects)

    def __iter__(self):
        """Iterate over the ``(probability, effect)`` pairs."""
        # __iter__ must return an iterator; returning the list itself makes
        # iter(self) raise TypeError.
        return iter(self.effects)

    def is_equal(self, other):
        """Return True if some ordering of ``other``'s outcomes matches this one's.

        :param other: the Effect to compare against
        """
        if self.n != other.n:
            return False
        # Outcomes are compared pairwise, so permuting one side is enough to
        # try every possible alignment.
        return any(
            self._is_equal(self.effects, candidate)
            for candidate in itertools.permutations(other.effects)
        )

    def _is_equal(self, x, y):
        """Compare two aligned sequences of ``(probability, effect)`` pairs.

        Each pair in ``x`` is checked against the pair at the same position in
        ``y``: the probabilities must agree within ``PROB_TOLERANCE`` and the
        effects must be similar.
        """
        for (p, e), (q, f) in zip(x, y):
            if abs(p - q) > self.PROB_TOLERANCE:
                return False
            if not self._is_similar(e, f):
                return False
        return True

    def visualise(self, effect):
        """Render ``effect`` to a single combined image and display it."""
        self.extract_image(effect).show()

    def extract_image(self, effect):
        """Render ``effect`` and return the result as one combined image.

        The samples of each rendered entry are merged with ``Image.merge``,
        converted with ``Image.to_image`` (using ``mode='RGB'`` when the merged
        array has three dimensions) and joined with ``Image.combine``.
        """
        rendered = render(effect)
        images = []
        for samples in rendered.values():
            merged = Image.merge(samples)
            if len(merged.shape) == 3:
                images.append(Image.to_image(merged, mode='RGB'))
            else:
                images.append(Image.to_image(merged))
        return Image.combine(images)

    def hash_distance(self, left_hash, right_hash):
        """Compute the hamming distance between two hashes

        :raises ValueError: if the hashes differ in length
        """
        if len(left_hash) != len(right_hash):
            raise ValueError('Hamming distance requires two strings of equal length')
        return sum(1 for left, right in zip(left_hash, right_hash) if left != right)

    def _is_similar(self, A, B):
        """Return True unless a pair of distributions in ``A`` and ``B`` diverges.

        ``A`` and ``B`` are iterated in lockstep. For each pair ``(a, b)``,
        samples are drawn from ``a`` and scored under both ``a`` and ``b``
        with ``score_samples``; the difference of the mean scores is compared
        to ``KL_THRESHOLD``.
        """
        for a, b in zip(A, B):
            x = a.sample(self.N_SAMPLES)
            # presumably score_samples returns log-densities, which would make
            # this a sample-based estimate of KL(a || b) - TODO confirm
            log_p_x = a.score_samples(x)
            log_q_x = b.score_samples(x)
            kl = log_p_x.mean() - log_q_x.mean()
            if kl > self.KL_THRESHOLD:
                return False
        return True


class EffectClass:
    """Records, for each object, the effects of every option on that object.

    Two objects are considered equal when these effects line up. ``map``
    associates each object index with a list holding one slot per option;
    a slot is ``None`` until an effect is added for that option.
    """

    def __init__(self, n_options, n_objects):
        """
        :param n_options: the number of options (slots per object)
        :param n_objects: the number of objects
        """
        self.n_options = n_options
        self.map = {obj: [None] * n_options for obj in range(n_objects)}

    def add(self, option: int, object: int, probs, effects):
        """Append a new ``Effect`` to the slot of ``option`` for ``object``."""
        slots = self.map[object]
        if slots[option] is None:
            slots[option] = []
        slots[option].append(Effect(effects, probs))

    def get(self, object):
        """Return a new list with the per-option slots of ``object``."""
        return list(self.map[object])

    def is_same(self, object1, object2):
        """Compare the recorded effects of two objects option by option.

        :return: a tuple ``(equal, matches)`` where ``equal`` tells whether
            every option lines up and ``matches`` is the set of options for
            which both objects have effects that agree
        """
        first = self.map[object1]
        second = self.map[object2]

        n_first = sum(1 for slot in first if slot is not None)
        n_second = sum(1 for slot in second if slot is not None)
        equal = len(first) == len(second) and n_first == n_second

        matches = set()
        for option in range(self.n_options):
            p = first[option]
            q = second[option]
            if p is None and q is None:
                # neither object is affected by this option
                continue
            if p is None or q is None:
                # only one of the two objects is affected
                equal = False
                continue

            # look for any ordering of the two effect lists that agrees
            found = any(
                self.is_perm_equal(x, y)
                for x in itertools.permutations(p)
                for y in itertools.permutations(q)
            )
            if found:
                matches.add(option)
            else:
                equal = False
        return equal, matches

    def visualise(self, p):
        """Render and display the first outcome of every effect in ``p``."""
        for entry in p:
            rendered = render(entry.effects[0][1])
            pictures = []
            for samples in rendered.values():
                merged = Image.merge(samples)
                if len(merged.shape) == 3:
                    pictures.append(Image.to_image(merged, mode='RGB'))
                else:
                    pictures.append(Image.to_image(merged))
            Image.combine(pictures).show()

    def is_perm_equal(self, x, y):
        """Return True if ``x`` and ``y`` agree position by position.

        Only as many positions as the shorter sequence has are compared.
        """
        for left, right in zip(x[:min(len(x), len(y))], y):
            if not left.is_equal(right):
                return False
        return True

    def __str__(self):
        return str(self.map)


class UnionFind:
    """Disjoint-set forest over the objects ``0 .. n_objects - 1``.

    Sets are merged by rank and paths are compressed on lookup.
    """

    class Node:
        """One element of the forest."""

        def __init__(self, label):
            self.label = label
            self.parent = self  # a new node is the root of its own set
            self.rank = 0

        def __str__(self):
            # Labels are integers, but __str__ must return a str.
            return str(self.label)

    def __init__(self, n_objects):
        """
        :param n_objects: the number of objects, each starting in its own set
        """
        self.nodes = [UnionFind.Node(x) for x in range(n_objects)]

    def union(self, x, y):
        """Merge the sets containing objects ``x`` and ``y``.

        :param x: index of the first object
        :param y: index of the second object
        """
        x_root = self._find(self.nodes[x])
        y_root = self._find(self.nodes[y])
        if x_root is y_root:
            # already in the same set
            return
        if x_root.rank < y_root.rank:
            x_root.parent = y_root
        else:
            y_root.parent = x_root
            if x_root.rank == y_root.rank:
                # merging two trees of equal rank makes the result one deeper
                x_root.rank += 1

    def _find(self, x):
        """Return the root node of the set containing node ``x``."""
        root = x
        while root.parent is not root:
            root = root.parent
        # Path compression: point every node on the walked path at the root.
        while x is not root:
            following = x.parent
            x.parent = root
            x = following
        return root

    def get(self, object):
        """Return the label of the representative of ``object``'s set."""
        return self._find(self.nodes[object]).label
