import sys

import editdistance
import numpy as np
import torch
from joblib import Parallel, delayed
from scipy.spatial.distance import cosine
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr


class EvaluationMetrics:
    """Metrics relating a set of meanings to the representations produced for them.

    Attributes:
        meaning: collection of meanings; it is converted with ``.astype(float)``
            before use, so it must support that method (e.g. a NumPy array).
        representation: collection of representations, one per meaning,
            compared pairwise with ``edit_dist``.
        vocab_size: stored for callers; not read by the methods of this class.
    """

    def __init__(self, meaning, representation, vocab_size):
        super().__init__()
        self.meaning, self.representation, self.vocab_size = (
            meaning,
            representation,
            vocab_size,
        )

    def calculate_topographic_similarity(self):
        """Return the Spearman correlation between the two pairwise-distance lists.

        Pairwise edit distances between representations are correlated with
        pairwise cosine distances between the (float-converted) meanings.
        """
        repr_distances = edit_dist(self.representation)
        meaning_distances = cosine_dist(self.meaning.astype(float))
        return spearmanr(repr_distances, meaning_distances).correlation


def edit_dist(_list):
    """Return normalised pairwise edit distances between the elements of ``_list``.

    Each element is rendered to text with ``np.array2string``. Every unordered
    pair ``(i, j)`` with ``i < j`` is then compared with ``editdistance.eval``.
    The raw distance is divided by the length of the longer of the two strings.

    Args:
        _list: sequence of array-likes (anything ``np.asarray`` accepts).

    Returns:
        list of floats in condensed order (0,1), (0,2), ..., (0,n-1), (1,2), ...
        Empty when ``_list`` has fewer than two elements.
    """
    # Render each element once up front instead of twice per pair. This needs
    # n conversions rather than n*(n-1) of them.
    # threshold=sys.maxsize disables NumPy's "..." summarisation (triggered by
    # default above 1000 elements), which would make distinct large arrays
    # render to near-identical truncated strings.
    # NOTE(review): other print options (linewidth, precision) still come from
    # NumPy's global settings and shape the rendered text.
    strings = [np.array2string(np.asarray(el), threshold=sys.maxsize) for el in _list]
    n = len(strings)

    def calculate_edit_distance(str_el1, str_el2):
        return editdistance.eval(str_el1, str_el2) / max(len(str_el1), len(str_el2))

    result = Parallel(n_jobs=-1)(
        delayed(calculate_edit_distance)(strings[i], strings[j])
        for i in range(n - 1)
        for j in range(i + 1, n)
    )
    return result


def cosine_dist(_list):
    """Return pairwise cosine distances between the elements of ``_list``.

    Args:
        _list: sequence of equal-length numeric vectors (e.g. a 2-D array with
            one vector per row). Each element is flattened to a vector.

    Returns:
        list of floats in condensed order (0,1), (0,2), ..., (0,n-1), (1,2), ...
        Empty when ``_list`` has fewer than two elements.
    """
    vectors = np.asarray(_list, dtype=float)
    n = len(vectors)
    if n < 2:
        return []
    # pdist evaluates every i < j pair in compiled code and emits them in the
    # same row-major condensed order as a nested i/j loop. This avoids one
    # joblib task dispatch per pair.
    return pdist(vectors.reshape(n, -1), metric='cosine').tolist()


if __name__ == '__main__':
    # Experiment configuration. These must match the message file: each row is
    # presumably word_length integer symbols in [0, dictionary_size).
    # TODO: confirm against the message generator.
    dictionary_size = 100
    data_size = 18432
    word_length = 10
    label_link = 'data/labels.npy'
    true_label = np.load(label_link)
    # Drop label columns 0 and 3 before measuring meaning distances.
    # NOTE(review): the reason these columns are excluded is not documented here.
    true_label = np.delete(true_label, [0, 3], axis=1)
    message_link = '...'  # placeholder path: set to the comma-separated message file
    messages = np.genfromtxt(message_link, delimiter=',', dtype=int)
    # Indexing rows of the identity matrix one-hot encodes every symbol. The
    # reshape raises if the file does not hold data_size * word_length symbols.
    one_hot_messages = np.eye(dictionary_size)[messages]
    one_hot_messages = one_hot_messages.reshape(data_size, word_length, dictionary_size)
    print(messages)
    # Reuse dictionary_size rather than repeating the literal vocabulary size.
    eval_message = EvaluationMetrics(meaning=true_label, representation=one_hot_messages, vocab_size=dictionary_size)
    # NOTE(review): both distance helpers compare all n*(n-1)/2 pairs, which is
    # about 170M pairs for data_size=18432. Expect this to be very slow.
    topSim = eval_message.calculate_topographic_similarity()
    print(f'Topographic similarity: {topSim}')

