import pickle
from collections import defaultdict

import imagehash
import numpy as np
from sklearn.cluster import DBSCAN

from domain.inventory import Inventory
from experiment.train_pca import PCA_N
from pca.pca import PCA
from symbols.file_utils import make_path, load
from symbols.render.image import Image

# PCA implementation class, exposed as a module constant (not referenced by the
# functions below; presumably consumed by other modules — confirm before removing).
PCA_CLASS = PCA
# Saved PCA model that render_masked loads to decompress image-valued state columns.
PCA_PATH = make_path('pca_models/full_pca.dat')


def render_masked(mask, states):
    """Render sampled states into images, grouped by mask entry.

    Intended only for states drawn from a distribution. ``states`` must contain
    only the columns selected by ``mask``, in the same order: column ``k`` of
    every state corresponds to ``mask[k]``.

    An entry with more than 10 elements is treated as a PCA-compressed image and
    is decompressed with the model stored at ``PCA_PATH``. Any other entry is
    treated as an inventory vector: it is rounded to the nearest integer and
    rendered via ``Inventory.to_image``.

    :param mask: sequence of column identifiers, one per column of ``states``.
    :param states: 2-D array-like with ``len(mask)`` columns whose cells are
        sized vectors (presumably a numpy object array — confirm with callers).
    :return: ``defaultdict(list)`` mapping each mask entry to the images
        rendered for that column, in state order.
    :raises ValueError: if ``mask`` is None or its length differs from the
        number of columns in ``states``.
    """
    # Validate before touching the PCA model so bad input fails fast, without a disk load.
    if mask is None:
        raise ValueError('mask must not be None')
    if len(mask) != states.shape[1]:
        raise ValueError('mask has {} entries but states has {} columns'.format(
            len(mask), states.shape[1]))

    pca = PCA(PCA_N)
    pca.load(PCA_PATH)

    images = defaultdict(list)
    for state in states:
        for m, x in zip(mask, state):
            if len(x) > 10:
                # Long entries are PCA-compressed images.
                image = pca.unflatten(pca.uncompress_(x))
            else:
                # Short entries are inventory vectors; round to integers before rendering.
                image = Image.to_array(Inventory.to_image(np.rint(x)))
            images[m].append(image)
    return images

def extract_symbol(render, distribution, n_samples=100):
    """Render samples from ``distribution`` into one combined image.

    Draws ``n_samples`` states from ``distribution`` and renders them with
    ``render(distribution.mask, states)`` (for example ``render_masked``). It
    then merges the images of each mask entry into one and combines those
    per-entry images.

    :param render: callable ``(mask, states)`` returning a mapping from mask
        entry to the list of image arrays rendered for it.
    :param distribution: object exposing ``mask`` and ``sample(n)``.
    :param n_samples: number of states to draw (default 100, the previously
        hard-coded value).
    :return: the result of ``Image.combine`` over the per-entry images.
    """
    rendered = render(distribution.mask, distribution.sample(n_samples))
    images = []
    for arrays in rendered.values():
        merged = Image.merge(arrays)
        # 3-D arrays are converted as RGB; everything else uses Image.to_image's default mode.
        if len(merged.shape) == 3:
            images.append(Image.to_image(merged, mode='RGB'))
        else:
            images.append(Image.to_image(merged))
    return Image.combine(images)

# def hash_distance(left_hash, right_hash):
#     """Compute the hamming distance between two hashes"""
#     if len(left_hash) != len(right_hash):
#         raise ValueError('Hamming distance requires two strings of equal length')
#     return sum(map(lambda x: 0 if x[0] == x[1] else 1, zip(left_hash, right_hash)))

def distance(a, b, hashfunc):
    """Return the difference between the perceptual hashes of two rendered symbols.

    Each distribution is rendered to a single image via
    ``extract_symbol(render_masked, ...)``, ``a`` first and then ``b``. Each
    image is then hashed with ``hashfunc``. For ``imagehash`` hashes,
    subtracting two hashes yields their Hamming distance.

    :param a: first distribution (exposes ``mask`` and ``sample``).
    :param b: second distribution.
    :param hashfunc: callable mapping an image to a hash object supporting ``-``,
        e.g. ``imagehash.phash``.
    :return: ``hashfunc(image_a) - hashfunc(image_b)``.
    """
    rendered = [extract_symbol(render_masked, dist) for dist in (a, b)]
    hash_a, hash_b = map(hashfunc, rendered)
    return hash_a - hash_b


def _sample_rows(distribution, n_samples):
    """Draw ``n_samples`` states and flatten each one into a single feature row."""
    return np.array([np.hstack(state) for state in distribution.sample(n_samples)])


def cluster(a, b, n_samples=1000, eps=4):
    """Run DBSCAN over pooled samples from two distributions and return the labels found.

    Draws ``n_samples`` states from ``a`` and then from ``b``. Each state is
    flattened into one row with ``np.hstack``, the rows are stacked, and DBSCAN
    is fitted with ``min_samples=n_samples``. A single returned label means all
    pooled samples ended up in the same DBSCAN group.

    DBSCAN is skipped, and ``{0, 1}`` is returned, when the flattened samples of
    ``a`` and ``b`` differ in shape or have exactly two features.
    NOTE(review): the reason for the two-feature short-circuit is not evident
    here — confirm with the author.

    :param a: distribution exposing ``sample(n)``.
    :param b: distribution exposing ``sample(n)``.
    :param n_samples: samples drawn from each distribution; also used as
        DBSCAN's ``min_samples`` (default 1000, the previously hard-coded value).
    :param eps: DBSCAN neighbourhood radius (default 4, the previously
        hard-coded value).
    :return: set of DBSCAN labels (DBSCAN uses ``-1`` for noise points).
    """
    x = _sample_rows(a, n_samples)
    y = _sample_rows(b, n_samples)

    if x.shape != y.shape or x.shape[1] == 2:
        return {0, 1}

    data = np.concatenate((x, y))
    labels = DBSCAN(eps=eps, min_samples=n_samples).fit(data).labels_
    return set(labels)


if __name__ == '__main__':
    import itertools

    np.set_printoptions(formatter={'float': lambda x: "{0:0.2f}".format(x)}, linewidth=np.inf)
    # Only the symbol list is used; the other two loaded values are ignored here.
    (_factors, _option_symbols, symbol_list) = load('q.tmp')

    # Alternative hash functions to try:
    #   imagehash.average_hash, imagehash.dhash, imagehash.whash,
    #   lambda img: imagehash.whash(img, mode='db4')
    hashfunc = imagehash.phash

    # 1-based indices into symbol_list of the symbols to compare pairwise.
    symbols = [36, 37, 38, 39, 40, 49]
    for idx1, idx2 in itertools.combinations(symbols, 2):
        symbol_a = symbol_list[idx1 - 1]
        symbol_b = symbol_list[idx2 - 1]
        print('Distance: {}-{}: {}'.format(idx1, idx2, distance(symbol_a, symbol_b, hashfunc)))
        print('Cluster: {}-{}: {}'.format(idx1, idx2, cluster(symbol_a, symbol_b)))

    # Previously unreachable behind exit(0). Set to True to scan every pair in
    # symbol_list and print (1-based) the pairs that DBSCAN puts under a single label.
    run_all_pairs = False
    if run_all_pairs:
        for (i, symbol_a), (j, symbol_b) in itertools.combinations(enumerate(symbol_list, start=1), 2):
            if len(cluster(symbol_a, symbol_b)) == 1:
                print('Cluster: {}-{}'.format(i, j))
