#Load libraries
from minicons import cwe
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import csv
import sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import pickle
import os
import numpy as np
import sys
from retro_funcs import mean_centering, format_vector, dim_removal, retrofit

# ---- Load Embeddings ----

# Read the layer index before building the model, so that a missing
# command-line argument aborts immediately with a usage hint instead of
# surfacing as a bare IndexError after the model has been loaded.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} LAYER  (layer index in [0, 12])")
LAYER = int(sys.argv[1])  # specify which layer: [0, 12]

# minicons contextual-word-embedding wrapper around bert-base-uncased;
# the second argument is the device string handed to minicons.
bert = cwe.CWE("bert-base-uncased", "cuda:1")


# Read every sense-metadata row (header excluded) into memory as a list of
# string lists, one per CSV record.
# newline="" is what the csv module requires of the file object so that
# quoted fields containing line breaks are parsed as a single record.
with open("/makesense_dir_path/data/sense_metadata.csv", "r", newline="") as f:
    reader = csv.reader(f)
    # Skip the header through the reader (not the raw file handle) so it is
    # consumed as one CSV record; the default avoids StopIteration when the
    # file is empty.
    next(reader, None)
    sense_data = list(reader)

# Feed the metadata rows to the model 100 at a time.
sense_dl = DataLoader(sense_data, batch_size=100)

# One (position, word, sense, representation) tuple per metadata row.
embedding_data = []
# Each batch is unpacked as six parallel column sequences; the first two
# columns are not used here.
for _, _, sentences, words, positions, senses in tqdm(sense_dl):
    # Positions come out of the CSV as strings.
    positions = [int(raw) for raw in positions]
    # Query the span (position, position + 1) of each sentence.
    spans = [(start, start + 1) for start in positions]
    queries = list(zip(sentences, spans))
    vectors = bert.extract_representation(queries, layer=LAYER)
    embedding_data.extend(zip(positions, words, senses, vectors))

# Hand cached GPU memory back once extraction is finished.
torch.cuda.empty_cache()
    
# Collect the representation column (index 3) of every row into one
# stacked tensor: the entire embedding matrix.
original_columns = list(zip(*embedding_data))
for_pca = torch.stack(original_columns[3])

# Save the unmodified embeddings.
# NOTE(review): unlike the other output directories, this path has no
# 'ms_embs/' component - confirm that is intended.
os.chdir('/embedding_dir_path/bert/original/')
torch.save(for_pca, f'original_all_{LAYER}.pt')


# ---- Perform Anisotropy Adjusted Retrofitted ----

# Centering: mean-center the full matrix, then re-attach the per-row
# metadata taken from embedding_data.
centered = mean_centering(for_pca)
van_cent = format_vector(centered, embedding_data)

# Low anisotropy (la): centered + pca. The second argument of dim_removal
# (1) is interpreted by retro_funcs.
reduced = dim_removal(van_cent, 1)
van_pca = format_vector(reduced, embedding_data)
# NOTE(review): this variant is stacked with np.stack while the other
# variants use torch.stack - confirm the rows here are NumPy arrays.
la_columns = list(zip(*van_pca))
van_pca_emb = np.stack(la_columns[3])

os.chdir('/embedding_dir_path/ms_embs/bert/low_anisotropy/')
torch.save(van_pca_emb, f'la_all_{LAYER}.pt')
    
# Sense retrofitted (ser): retrofit the original, un-centered embeddings
# and re-attach the per-row metadata.
retrofitted = retrofit(embedding_data)
van_ser = format_vector(retrofitted, embedding_data)
ser_columns = list(zip(*van_ser))
van_ser_emb = torch.stack(ser_columns[3])

os.chdir('/embedding_dir_path/ms_embs/bert/sense_retrofitted/')
torch.save(van_ser_emb, f'ser_all_{LAYER}.pt')
    
# laser: centered + pca + sense retrofitted, i.e. retrofitting applied to
# the low-anisotropy rows (van_pca) rather than to the original ones.
laser_rows = retrofit(van_pca)
van_pca_retro = format_vector(laser_rows, embedding_data)
laser_columns = list(zip(*van_pca_retro))
van_retro_emb = torch.stack(laser_columns[3])

os.chdir('/embedding_dir_path/ms_embs/bert/laser/')
torch.save(van_retro_emb, f'laser_all_{LAYER}.pt')
    

