#Load libraries
import torch
from tqdm import tqdm
import sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import numpy as np

'''Retrofitting functions'''

#centering
def mean_centering(embs):
    """Return ``embs`` with its mean over dimension 0 subtracted.

    Args:
        embs: torch tensor; the mean is taken over dim 0 and broadcast back.

    Returns:
        Tensor of the same shape as ``embs`` whose mean over dim 0 is
        (up to floating-point error) zero.
    """
    dim0_mean = torch.mean(embs, dim=0)
    return embs - dim0_mean

#format vector
def format_vector(embs, vanilla):
    """Attach each vector to its (id, word, sense) metadata.

    Args:
        embs: iterable of vectors, aligned row-for-row with ``vanilla``.
        vanilla: sequence of records whose first three fields are id, word
            and sense; any further fields are ignored. It is iterated once per
            field, so pass a re-iterable sequence, not a one-shot iterator.

    Returns:
        list of (id, word, sense, vector) tuples, truncated to the shorter of
        ``vanilla`` and ``embs`` (``zip`` semantics).
    """
    # Build the id, word and sense columns, in that order.
    columns = [[record[field] for record in vanilla] for field in range(3)]
    return list(zip(*columns, embs))

#dimension removal
def dim_removal(embs, d):
    """Remove the top ``d`` principal components from a set of embeddings.

    The vectors (4th field of each entry) are stacked into a matrix, fitted
    with PCA, and reconstructed from every principal component except the
    first ``d`` (sklearn orders components by explained variance). The
    per-dimension mean is added back afterwards, so only the dominant
    directions are removed.

    Args:
        embs: sequence of (id, word, sense, vector) tuples, e.g. as produced by
            ``format_vector``; the vectors must be torch tensors of one common
            1-D shape so that stacking yields a 2-D matrix.
        d: number of leading principal components to drop (``d >= 0``).
            ``d == 0`` returns the (numerically) reconstructed input.

    Returns:
        numpy.ndarray of shape (n_entries, n_features) with the top ``d``
        components projected out.

    Raises:
        ValueError: if ``d`` is negative (a negative slice start would keep
            only the *last* components instead).
    """
    if d < 0:
        raise ValueError(f"d must be non-negative, got {d}")
    # detach()/cpu() so tensors that require grad or live on a GPU can be
    # converted; PCA is then fitted on the same numpy array.
    embeds = torch.stack(list(zip(*embs))[3]).detach().cpu().numpy()
    mu = np.mean(embeds, axis=0)
    pca = PCA()
    pca.fit(embeds)
    # Keep scores/components from index d to the end. Slicing to the end
    # (instead of a hard-coded 768) works for any embedding width; PCA()
    # keeps at most n_features components, so 768-dim inputs are unchanged.
    Xhat = np.dot(pca.transform(embeds)[:, d:], pca.components_[d:, :])
    Xhat += mu
    return Xhat

#retrofitting function    
def retrofit(embs):
    """Pull each vector halfway toward the mean of the vectors sharing its sense.

    For entry i with sense s, let N_s be every entry (including i itself)
    whose sense equals s, and n = |N_s|. The retrofitted vector is

        q_i_new = (n * q_i_hat + sum_{j in N_s} q_j) / (2 * n)

    which is the midpoint between the original vector q_i_hat and the mean
    of its sense group.

    Args:
        embs: sequence (indexable, re-iterable) of (id, word, sense, vector)
            tuples, e.g. as produced by ``format_vector``. Senses are used as
            dict keys, so they must be hashable. All vectors must share one
            shape.

    Returns:
        torch.Tensor stacking the retrofitted vectors in input order
        (an empty tensor when ``embs`` is empty).
    """
    # Group vectors by sense once, instead of rescanning all of embs for every
    # entry (the old per-entry scan was O(N^2)).
    groups = {}
    for entry in embs:
        groups.setdefault(entry[2], []).append(entry[3])

    sense_stats = {}
    for sense, members in groups.items():
        # np.stack (not np.concatenate): keeps one row per neighbour, so the
        # count is the number of neighbours and the axis-0 sum is a vector.
        # Concatenating 1-D vectors flattened them into one long array.
        neighbours = np.stack(members, axis=0)
        sense_stats[sense] = (len(neighbours), np.sum(neighbours, axis=0))

    vector_list = []
    for entry in tqdm(embs):
        n_i, neighbour_sum = sense_stats[entry[2]]
        q_i_cap = np.array(entry[3])
        q_i_new = (n_i * q_i_cap + neighbour_sum) / (2 * n_i)
        vector_list.append(q_i_new)

    if not vector_list:
        return torch.tensor(vector_list)
    # Stack into one ndarray first; torch.tensor on a list of ndarrays is slow.
    return torch.tensor(np.stack(vector_list))
