import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
from tqdm import tqdm
import os
import torch
import pickle
import scipy.cluster.hierarchy as shc
from sklearn.cluster import AgglomerativeClustering
from knn_funcs import sense_data_load
from sklearn.decomposition import PCA
import seaborn as sns
from sklearn.metrics.cluster import homogeneity_score
from sklearn.cluster import AgglomerativeClustering

from paths import auth1_path, auth2_path


def cluster_embs(filt, n):
    """Cluster embeddings with complete-linkage agglomerative clustering on cosine distance.

    Args:
        filt: sequence of pairs whose second element is a tensor exposing
            ``.numpy()``; the first element is ignored here.
        n: number of clusters to form.

    Returns:
        The fitted ``AgglomerativeClustering`` estimator; the cluster id of
        each input row is in its ``labels_`` attribute.
    """
    embs = [x[1].numpy() for x in filt]
    try:
        # scikit-learn >= 1.2 calls the distance `metric`; the old
        # `affinity` keyword was deprecated in 1.2 and removed in 1.4.
        cluster = AgglomerativeClustering(n_clusters=n, metric='cosine', linkage='complete')
    except TypeError:
        # scikit-learn < 1.2 only knows `affinity`.
        cluster = AgglomerativeClustering(n_clusters=n, affinity='cosine', linkage='complete')
    cluster.fit(embs)
    return cluster


from sklearn.decomposition import PCA

def emb_pca(filt):
    """Project embeddings onto their first two principal components.

    Args:
        filt: sequence of pairs whose second element is a tensor exposing
            ``.numpy()``; the first element is ignored here.

    Returns:
        Array of shape (len(filt), 2) holding the projected coordinates.
    """
    vectors = list(map(lambda pair: pair[1].numpy(), filt))
    # fit() returns the estimator itself, so the projection can be chained.
    projector = PCA(n_components=2).fit(vectors)
    return projector.transform(vectors)

import matplotlib.pyplot as plt

def emb_plotter(X, labels_, x, y, x1, x2, y1, y2, type_):
    """Scatter-plot 2-D points coloured by sense and display the figure.

    Args:
        X: 2-D array; column 0 is plotted on the x axis, column 1 on the y axis.
        labels_: DataFrame with a ``sense`` column aligned with the rows of
            ``X``; used as the colour (hue) variable.
        x, y: figure width and height passed to ``figsize``.
        x1, x2: x-axis limits.
        y1, y2: y-axis limits.
        type_: plot title.

    Returns:
        None (the figure is shown via ``plt.show()``).
    """
    sns.set(style="white", color_codes=True)
    plt.figure(figsize=(x, y))
    # legend=False instead of drawing a legend and then calling
    # `g.legend_.remove()`, which raises AttributeError if no legend was made.
    sns.scatterplot(
        x=X[:, 0],
        y=X[:, 1],
        hue=labels_.sense,
        palette='deep',
        s=150,
        linewidth=1,
        edgecolor="black",
        legend=False,
    )
    plt.xlim([x1, x2])
    plt.ylim([y1, y2])
    plt.title(type_)
    plt.show()


# One embedding dump per layer index and variant lives under this directory.
EMB_DIR = f'{auth2_path}/context_div/ms_embs/bert'
N_LAYERS = 13  # layer indices 0..12, as in the original loop


def _load_layer(layer, n_labels):
    """Load the original and laser embedding dumps for one layer index.

    Raises ValueError if either dump does not hold exactly one entry per
    label: embeddings are paired with labels by position, so a length
    mismatch would silently misalign them (``zip`` used to truncate).
    """
    original = torch.load(f'{EMB_DIR}/original/original_all_{layer}.pt')
    laser = torch.load(f'{EMB_DIR}/laser/laser_all_{layer}.pt')
    for name, embs in (('original', original), ('laser', laser)):
        if len(embs) != n_labels:
            raise ValueError(
                f"layer {layer} {name} embeddings: {len(embs)} entries "
                f"for {n_labels} labels; positional pairing would misalign"
            )
    return original, laser


def _keyword_positions(labels):
    """Map each keyword (label[2]) to the indices of its occurrences in labels.

    Each index list is in label order, then stably sorted with the same key
    the original per-keyword filtering used, so clustering sees the samples
    in the same order as before.
    """
    positions = {}
    for i, label in enumerate(labels):
        positions.setdefault(label[2], []).append(i)
    # NOTE(review): label[0][3] indexes into the 'dataset' field (column 0),
    # not 'sense' (label[3]); probably a typo. Kept verbatim because the
    # homogeneity scores do not depend on sample order.
    return {
        word: sorted(idx, key=lambda i: labels[i][0][3])
        for word, idx in positions.items()
    }


def _layer_scores(labels, idx, original, laser, ws):
    """Cluster one keyword's occurrences in both embedding sets and score them.

    Returns (a, b): sense homogeneity of the original and laser clusterings,
    each rounded to 3 decimals.
    """
    senses = [labels[i][3] for i in idx]
    cluster_o = cluster_embs([(labels[i], original[i]) for i in idx], ws)
    cluster_l = cluster_embs([(labels[i], laser[i]) for i in idx], ws)
    a = round(homogeneity_score(senses, cluster_o.labels_), 3)
    b = round(homogeneity_score(senses, cluster_l.labels_), 3)
    return a, b


labels = sense_data_load()

# Sorted: iteration order of a set of strings changes between runs (hash
# randomisation), and hom_list does not record the keyword, so an unsorted
# order would make the saved results impossible to map back to words.
words = sorted({x[2] for x in labels})
occurrences = _keyword_positions(labels)

# (occurrence count wf, distinct sense count ws) per keyword; both depend
# only on the labels, not on any layer's embeddings.
word_stats = {
    w: (len(occurrences[w]), len({labels[i][3] for i in occurrences[w]}))
    for w in words
}

# Layer-outer loop: each embedding file is read once (instead of once per
# keyword) and only one layer is held in memory at a time.
per_word = {w: [] for w in words}
for layer in tqdm(range(N_LAYERS)):
    original, laser = _load_layer(layer, len(labels))
    for keyword in words:
        wf, ws = word_stats[keyword]
        if wf == 1:
            # A single occurrence cannot be clustered.
            continue
        a, b = _layer_scores(labels, occurrences[keyword], original, laser, ws)
        per_word[keyword].append((layer, ws, a, b))
    del original, laser  # free this layer before loading the next one

# Report and collect in keyword-major order, as the original loop did.
hom_list = []
for keyword in words:
    wf, ws = word_stats[keyword]
    print("word: ", keyword)
    print("Total word occurrences: ", wf)
    print("Unique word senses: ", ws)
    for layer, _, a, b in per_word[keyword]:
        print("Layer :", layer)
        print("Sense homogeneity in")
        print("original embedding clusters:", a)
        print("laser embedding clusters:", b)
    hom_list.extend(per_word[keyword])

torch.save(hom_list, f"{auth2_path}/makesense/analysis/heir_all.pt")
        
    
    