import csv
import itertools
import json
import os
import pickle
from collections import Counter
from multiprocessing import Pool

import torch
from numpy import dot
from numpy.linalg import norm
from sklearn.cluster import KMeans

def most_common(lst):
    """Return the most frequent element of ``lst``, or ``[]`` if it is empty.

    Elements must be hashable. Ties are broken in favour of the element that
    appears first in ``lst``, so the result is deterministic (a ``set``-based
    ``max`` would depend on hash iteration order). Runs in O(n).
    """
    if not lst:
        return []
    # Counter.most_common orders equal counts by first occurrence.
    return Counter(lst).most_common(1)[0][0]

class TagBuilder:
    """Turn the item tags in ``./tags.json`` into GloVe-based features.

    ``tags.json`` maps an item id to either a single tag string or a list of
    tag strings; both shapes are handled. Two pipelines are provided:

    * ``old_preprocess_json`` -- one summed embedding per item, cached in
      ``processed_tag.pkl`` and clustered with ``run_clustering``.
    * ``preprocess_json`` -- maps each item to its nearest "meta class" name
      and writes the result to ``labels.csv``.
    """

    # Candidate meta classes for preprocess_json (default set).
    INSTRUMENT_META_CLASSES = (
        'electric guitar',
        'acoustic guitar',
        'percussion',
        'live sound & stage',
        'studio recording equipment',
        'microphones & cable',
        'amplifiers & effects',
    )
    # Alternative meta-class set used for the Amazon data.
    AMAZON_META_CLASSES = (
        'camera photo and lighting',
        'audio and video',
        'bags cases and covers',
        'batteries and chargers',
        'peripherals keyboards and mice',
        'storage and networking',
    )
    # Alternative meta-class set used for the Flickr data.
    FLICKR_META_CLASSES = (
        'people', 'buildings', 'places', 'plants', 'animals', 'vehicles',
        'scenery',
    )

    def __init__(self):
        """Load cached ``(ids, embeddings)`` or compute and cache them."""
        if os.path.isfile('processed_tag.pkl'):
            # NOTE(review): pickle can execute arbitrary code when loading;
            # only use a cache file that this script wrote itself.
            with open('processed_tag.pkl', 'rb') as f:
                self.ids, self.embeddings = pickle.load(f)
        else:
            # old_preprocess_json returns (ids, embeddings) and writes the
            # same pickle read above. preprocess_json produces string labels
            # instead and cannot populate these attributes.
            self.ids, self.embeddings = self.old_preprocess_json()

    def find_match(self, glove, tag, meta_embeddings):
        """Score ``tag`` against every meta-class embedding.

        Args:
            glove: word -> vector lookup (e.g. ``torchtext.vocab.GloVe``).
            tag: raw tag text. It is lower-cased and split into words on
                whitespace, else on ``_``, else on ``-``. The tag embedding
                is the sum of its word vectors.
            meta_embeddings: meta-class name -> vector. Must be non-empty.

        Returns:
            ``(score, max_pick)``: a dict of cosine similarity per
            meta-class, and the meta-class with the highest similarity.
            If either vector has zero norm (e.g. every word of the tag is
            out of vocabulary), that similarity is reported as ``0.0``
            rather than NaN.
        """
        tag = tag.lower()
        if ' ' in tag:
            words = tag.split()
        elif '_' in tag:
            words = tag.split('_')
        elif '-' in tag:
            words = tag.split('-')
        else:
            words = [tag]

        tag_embedding = sum(glove[w] for w in words)
        tag_norm = norm(tag_embedding)
        score = {}
        for name, vec in meta_embeddings.items():
            denom = tag_norm * norm(vec)
            # Avoid the 0/0 NaN (and numpy RuntimeWarning) for zero vectors.
            score[name] = float(dot(tag_embedding, vec) / denom) if denom else 0.0
        max_pick = max(score, key=score.get)
        return score, max_pick

    def preprocess_json(self, meta_classes=None):
        """Label every item in ``./tags.json`` with its closest meta class.

        Each unique tag is matched to the meta class with the highest cosine
        similarity (see ``find_match``). An item's label is the most common
        match among its tags. Labels are written to ``labels.csv`` as
        ``id,label`` rows (quoted by ``csv`` when a field contains a comma).

        Args:
            meta_classes: candidate class names; defaults to
                ``INSTRUMENT_META_CLASSES``. Multi-word names are embedded as
                the sum of their word vectors.

        Returns:
            dict mapping item id -> label. Items with an empty tag list are
            reported and skipped rather than labelled.
        """
        import torchtext

        if meta_classes is None:
            meta_classes = self.INSTRUMENT_META_CLASSES

        glove = torchtext.vocab.GloVe()
        with open('./tags.json', 'r', encoding='utf-8') as f:
            tags = json.load(f)

        print('Sanitising tags')
        # A bare string is shorthand for a single-element tag list.
        tags = {k: [v] if isinstance(v, str) else v for k, v in tags.items()}
        unique_tags = set(itertools.chain.from_iterable(tags.values()))

        meta_embeddings = {
            m: sum(glove[w] for w in m.split()) for m in meta_classes
        }

        print('Finding closest match')
        tag_max_pick = {}
        for tag in unique_tags:
            score, max_pick = self.find_match(glove, tag, meta_embeddings)
            print(tag, score)
            tag_max_pick[tag] = max_pick

        print('Generating string labels')
        labels = {}
        for k, v in tags.items():
            if not v:
                # most_common([]) yields [], which is not a usable label.
                print(f'Skipping {k!r}: it has no tags')
                continue
            labels[k] = most_common([tag_max_pick[t] for t in v])

        print('Generating data on classes')
        counts = Counter(labels.values())
        for m in meta_classes:
            n = counts[m]
            print(m, n, n / len(labels) if labels else 0.0)

        # The csv module quotes fields that contain commas; lineterminator
        # keeps the original '\n' row endings.
        with open('labels.csv', 'w', newline='', encoding='utf-8') as file:
            csv.writer(file, lineterminator='\n').writerows(labels.items())
        return labels

    def old_preprocess_json(self):
        """Embed every item in ``./tags.json`` and cache the result.

        An item's embedding is the sum of the GloVe vectors of its tags (a
        bare string counts as one tag). Items whose value is neither a list
        nor a string are skipped.

        Returns:
            ``(ids, embeddings)``: the list of item ids and a numpy array
            with one row per id. The pair is also pickled to
            ``processed_tag.pkl``.

        Raises:
            ValueError: if an item has an empty tag list.
        """
        import torchtext

        with open('./tags.json', 'r', encoding='utf-8') as f:
            tags = json.load(f)

        glove = torchtext.vocab.GloVe()
        ids = []
        embeddings = []
        for k, v in tags.items():
            if isinstance(v, str):
                v = [v]
            elif not isinstance(v, list):
                continue
            if not v:
                raise ValueError(f'Empty tag list for item {k!r}')
            # NOTE(review): each tag is looked up whole, so a multi-word tag
            # such as "electric guitar" is presumably out of vocabulary and
            # contributes GloVe's unknown vector; find_match splits words.
            embeddings.append(sum(glove[t] for t in v))
            ids.append(k)
        embeddings = torch.stack(embeddings).numpy()
        with open('processed_tag.pkl', 'wb') as f:
            pickle.dump((ids, embeddings), f)
        return ids, embeddings

    def run_clustering(self, n_classes, random_state=None):
        """Run k-means with ``n_classes`` clusters over ``self.embeddings``.

        Args:
            n_classes: number of clusters.
            random_state: forwarded to ``KMeans`` for reproducible runs;
                ``None`` keeps the unseeded behaviour.

        Returns:
            ``(labels, score)``: the cluster index of each embedding row, and
            ``KMeans.score`` on the same data.
        """
        kmeans = KMeans(n_clusters=n_classes, random_state=random_state)
        kmeans.fit(self.embeddings)
        return kmeans.labels_, kmeans.score(self.embeddings)

if __name__ == '__main__':
    # Pool workers may re-import this module (e.g. under the "spawn" start
    # method), so the expensive work must only run in the main process.
    tag_builder = TagBuilder()

    # Sweep candidate cluster counts in parallel.
    classes = list(range(5, 80))
    with Pool(processes=32) as pool:
        results = pool.map(tag_builder.run_clustering, classes)

    for n_classes, (_, score) in zip(classes, results):
        print(n_classes, score)

    # 78 clusters was the best count found; to reproduce it:
    # labels, score = tag_builder.run_clustering(78)
    # ids = tag_builder.ids

