from nltk.corpus import wordnet as wn
from sklearn.metrics.pairwise import cosine_similarity
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import confusion_matrix
import torch
import numpy as np
import os
import pickle


class Similarities_manager:
    def __init__(self, dataset, wordnet_data_path, glove_data_path, glove_embeddings_data_file):
        self.dataset = dataset
        self.data_path = wordnet_data_path
        self.glove_data_path = glove_data_path
        self.glove_embeddings_data_file = glove_embeddings_data_file
        self.wordnet_matrix = None
        self.glove_matrix = None
        self.init_weight_matrix = None
        self.current_weight_matrix = None
        self.confusion_matrix = None
        self.confusion_matrix_weight_product = None
        self.wordnet_similarities = self.get_wordnet_similarities()
        self.glove_similarities = self.get_glove_similarities()

    def get_wordnet_similarities(self):
        """gets initial dataset names similarities"""

        if os.path.isfile(self.data_path):
            print('[INFO] Saved wordnet data found')
            data = pickle.load(open(self.data_path, 'rb'))
            similarity_wordnet = data['wordnet_sim']
            CSM_wordnet_nodiagonal = data['wordnet_nd']
            self.wordnet_matrix = similarity_wordnet

            return CSM_wordnet_nodiagonal
        else:

            print('[INFO] No wordnet data found, computing...')
            label_onehot, label_codes = self.dataset.label_order
            label_tensor = torch.Tensor(label_codes).squeeze()
            prod = torch.cartesian_prod(label_tensor, label_tensor).numpy().astype(int)
            similarity_wordnet = np.array(list(map(get_word_dist, prod))).reshape(len(label_codes), -1)

            self.wordnet_matrix = similarity_wordnet

            CSM_wordnet_nodiagonal = similarity_wordnet[~np.eye(similarity_wordnet.shape[0], dtype=bool)].reshape(
                similarity_wordnet.shape[0], -1)
            CSM_wordnet_nodiagonal = (CSM_wordnet_nodiagonal - np.min(CSM_wordnet_nodiagonal)) / (
                        np.max(CSM_wordnet_nodiagonal) - np.min(CSM_wordnet_nodiagonal))

            pickle.dump({'wordnet_sim': similarity_wordnet, 'wordnet_nd': CSM_wordnet_nodiagonal},
                        open(self.data_path, 'wb'))

            return CSM_wordnet_nodiagonal

    def get_glove_similarities(self):
        """gets initial dataset names similarities"""

        def load_glove_embeddings(glove_file):
            embeddings_indexes = {}
            with open(glove_file, 'r', encoding='utf-8') as f:
                for line in f:
                    values = line.split()
                    word = values[0]
                    coefs = np.asarray(values[1:], dtype='float32')
                    embeddings_indexes[word] = coefs
            return embeddings_indexes

        def get_synonyms(term):
            synonyms = set()
            for synset in wn.synsets(term):
                for lemma in synset.lemmas():
                    synonyms.add(lemma.name())
            return list(synonyms)

        def compute_embedding(name, embeddings_indexes):
            if name in embeddings_indexes:
                return embeddings_indexes[name]
            else:
                if '_' in name:
                    name_list = name.lower().split('_')
                    representation = np.array(
                        [embeddings_indexes.get(key) for key in name_list if embeddings_indexes.get(key) is not None])
                    if representation is not None:
                        final_representation = np.mean(representation, axis=0)
                        return final_representation
                    else:
                        return None
                else:
                    return None

        # Compute similarity matrix
        def compute_similarity_matrix(class_names, embeddings_indexes, embedding_length):
            vectors = []
            for name in class_names:
                embedding = compute_embedding(name, embeddings_indexes)
                if embedding is not None:
                    vectors.append(embedding)
                else:
                    synonyms = get_synonyms(name)
                    embeddings_from_synonyms = np.array(
                        [compute_embedding(synonym, embeddings_indexes) for synonym in synonyms if
                         compute_embedding(synonym, embeddings_indexes) is not None])
                    if len(embeddings_from_synonyms) == 0:
                        vectors.append(np.zeros(embedding_length))
                    else:
                        vectors.append(np.mean(embeddings_from_synonyms, axis=0))

            vectors = np.array(vectors)
            similarity_matrix = cosine_similarity(vectors)
            return similarity_matrix

        if os.path.isfile(self.glove_data_path):
            print('[INFO] Saved GLOVE data found')
            data = pickle.load(open(self.glove_data_path, 'rb'))
            similarity_glove = data['glove_sim']
            CSM_glove_nodiagonal = data['glove_nd']
            self.glove_matrix = similarity_glove

            return CSM_glove_nodiagonal
        else:

            print('[INFO] No GLOVE data found, computing...')
            label_onehot, label_codes = self.dataset.label_order
            label_array = torch.Tensor(label_codes).squeeze().numpy().astype(int)
            class_text_names = [wn.synset_from_pos_and_offset('n', label_array[i]).name().split('.')[0] for i in
                                range(len(label_array))]

            embedding_len = int(self.glove_embeddings_data_file.split("\\")[-1].split('.')[2].replace('d', ''))

            embeddings_index = load_glove_embeddings(self.glove_embeddings_data_file)
            # raw similarity matrix
            similarity_glove = compute_similarity_matrix(class_text_names, embeddings_index, embedding_len)

            self.glove_matrix = similarity_glove

            CSM_glove_nodiagonal = similarity_glove[~np.eye(similarity_glove.shape[0], dtype=bool)].reshape(
                similarity_glove.shape[0], -1)
            CSM_glove_nodiagonal = (CSM_glove_nodiagonal - np.min(CSM_glove_nodiagonal)) / (
                    np.max(CSM_glove_nodiagonal) - np.min(CSM_glove_nodiagonal))

            pickle.dump({'glove_sim': similarity_glove, 'glove_nd': CSM_glove_nodiagonal},
                        open(self.glove_data_path, 'wb'))

            return CSM_glove_nodiagonal

    def compute_min_max_weights_average(self, similarity_array, quantile):
        similarity_array_sorted = np.abs(np.sort(-similarity_array, axis=1))
        thresholds = np.quantile(similarity_array_sorted[:, 1:], quantile, axis=1)
        if quantile > 0.5:
            mask = similarity_array_sorted[:, 1:] >= thresholds[:, np.newaxis]
        else:
            mask = similarity_array_sorted[:, 1:] <= thresholds[:, np.newaxis]
        masked_values = similarity_array_sorted[:, 1:][mask]
        masked_values_mean = np.mean(masked_values)
        return masked_values_mean

    def get_semantic_sim(self, model):
        """Compare the classifier-weight similarities with the WordNet and GloVe references.

        The cosine-similarity matrix of the classifier weights is stored as
        ``self.init_weight_matrix`` on the first call and as
        ``self.current_weight_matrix`` on every later call. Its off-diagonal
        entries are min-max scaled and compared with ``self.wordnet_similarities``
        and ``self.glove_similarities``.

        Returns:
            An 11-tuple of Python floats, in this order: WordNet cosine, WordNet
            SSIM, mean raw weight similarity, WordNet MSE, WordNet MAE, low-tail
            (0.05) weight similarity, high-tail (0.95) weight similarity, GloVe
            cosine, GloVe SSIM, GloVe MSE, GloVe MAE.
        """
        weight_sim = self.get_cos_sim_weights(model)

        # The very first matrix seen is the baseline; later ones are "current".
        if self.init_weight_matrix is not None:
            self.current_weight_matrix = weight_sim
        else:
            self.init_weight_matrix = weight_sim

        n_classes = weight_sim.shape[0]
        off_diagonal = ~np.eye(n_classes, dtype=bool)
        raw_nd = weight_sim[off_diagonal].reshape(n_classes, -1)

        lowest = np.min(raw_nd)
        scaled_nd = (raw_nd - lowest) / (np.max(raw_nd) - lowest)

        def flat_cosine(reference):
            # Cosine similarity between the two matrices viewed as flat vectors.
            numerator = np.sum(reference * scaled_nd)
            denominator = (np.sqrt(np.sum(reference * reference)) *
                           np.sqrt(np.sum(scaled_nd * scaled_nd)))
            return numerator / denominator

        def compare(reference):
            # Returns (cosine, mse, mae, ssim) of the scaled weights vs. reference.
            cosine = flat_cosine(reference)
            mse = np.mean((reference - scaled_nd) ** 2)
            mae = np.mean(np.abs(reference - scaled_nd))
            structural, _ = ssim(scaled_nd, reference, full=True, data_range=1)
            return cosine, mse, mae, structural

        wn_cosine, wn_mse, wn_mae, wn_ssim = compare(self.wordnet_similarities)

        mean_raw = np.mean(raw_nd)
        low_tail = self.compute_min_max_weights_average(raw_nd, 0.05)
        high_tail = self.compute_min_max_weights_average(raw_nd, 0.95)

        gl_cosine, gl_mse, gl_mae, gl_ssim = compare(self.glove_similarities)

        ordered = (wn_cosine, wn_ssim, mean_raw, wn_mse, wn_mae, low_tail, high_tail,
                   gl_cosine, gl_ssim, gl_mse, gl_mae)
        return tuple(value.item() for value in ordered)

    def get_dm_metric(self, model, y_true, y_pred, num_of_classes, CSM_type):
        if CSM_type == 'N':
            CSM = self.get_cos_sim_weights(model)
        elif CSM_type == 'G':
            CSM = self.glove_matrix
        else:
            CSM = self.wordnet_matrix

        SCSM = np.argsort(-CSM, axis=1)
        y_true = torch.argmax(torch.concatenate(y_true), dim=1)
        y_pred = torch.argmax(torch.concatenate(y_pred), dim=1)
        y_reshaped = np.reshape(y_true.detach().cpu().numpy(), (y_true.shape[0]))
        DM = (SCSM[y_reshaped] == np.expand_dims(y_pred.detach().cpu().numpy(), axis=1))
        DM = np.argmax(DM.astype(int), axis=1)
        DM = 1 - np.mean(DM / (num_of_classes - 1))
        return DM.item()

    def get_cos_sim_weights(self, model):
        """Return the pairwise cosine similarity between rows of the classifier's first parameter.

        Assumes the last layer is always named 'clf'. The first tensor yielded by
        ``model.clf.parameters()`` is used (presumably the weight matrix with one
        row per class - verify against the model definition).
        """
        classifier_params = list(model.clf.parameters())
        # Detach and move to CPU so the tensor can be handed to NumPy/sklearn.
        first_param = classifier_params[0].detach().cpu().numpy()
        return cosine_similarity(first_param)

    def confusion_mean(self, y_true, preds, num_classes):
        """Build the confusion matrix (normalize='true') and return the mean of its diagonal.

        Args:
            y_true: sequence of 2-D tensors; the arg-max of each row is the true class.
            preds: sequence of 2-D tensors; the arg-max of each row is the predicted class.
            num_classes: the matrix is built over labels ``0 .. num_classes - 1``.

        Side effects:
            Stores the matrix in ``self.confusion_matrix``.

        Returns:
            The mean of the matrix diagonal as a Python float.
        """
        predicted = torch.argmax(torch.concatenate(preds), dim=1)
        actual = torch.argmax(torch.concatenate(y_true), dim=1)

        all_labels = list(range(num_classes))
        matrix = confusion_matrix(actual.detach().cpu().numpy(),
                                  predicted.detach().cpu().numpy(),
                                  normalize='true',
                                  labels=all_labels)
        self.confusion_matrix = matrix

        diagonal = np.diag(matrix)
        return np.mean(diagonal).item()

    def get_confusion_weight_product(self):
        if self.current_weight_matrix is not None:
            product = self.current_weight_matrix * self.confusion_matrix
            self.confusion_matrix_weight_product = product
        else:
            self.confusion_matrix_weight_product = np.zeros_like(self.confusion_matrix)
        return self.confusion_matrix_weight_product

    def get_cf_wg_prod_sum(self):
        cf_wg_prod = self.get_confusion_weight_product()
        return np.sum(cf_wg_prod).item()

    def get_confusion_metric(self, model, CSM_type):
        if CSM_type == 'N':
            CSM = self.get_cos_sim_weights(model)
        elif CSM_type == 'G':
            CSM = self.glove_matrix
        else:
            CSM = self.wordnet_matrix

        confusion_nd = self.get_nondiagonal(self.confusion_matrix)
        CSM_nd = self.get_nondiagonal(CSM)
        out = np.sum(CSM_nd * confusion_nd) / \
               (np.sqrt(np.sum(CSM_nd * CSM_nd)) * np.sqrt(np.sum(confusion_nd * confusion_nd)))
        return out.item()

    def return_matrices(self):
        """Return the six stored matrices as one tuple.

        Order: WordNet, initial weights, current weights, GloVe,
        confusion/weight product, confusion matrix. Entries that have not been
        computed yet are returned as they are stored.
        """
        matrices = (
            self.wordnet_matrix,
            self.init_weight_matrix,
            self.current_weight_matrix,
            self.glove_matrix,
            self.confusion_matrix_weight_product,
            self.confusion_matrix,
        )
        return matrices

    def get_nondiagonal(self, matrix):
        return matrix[~np.eye(matrix.shape[0], dtype=bool)].reshape(matrix.shape[0], -1)

def get_word_dist(ids):
    """Return the WordNet path similarity between two noun synsets given by offset.

    ``ids`` is a pair ``(offset_1, offset_2)``; each offset is resolved with
    part of speech 'n' before ``path_similarity`` is taken from the first
    synset to the second.
    """
    first_offset, second_offset = ids
    first_synset = wn.synset_from_pos_and_offset('n', first_offset)
    second_synset = wn.synset_from_pos_and_offset('n', second_offset)
    return first_synset.path_similarity(second_synset)
