import pandas as pd 
from gensim.models.word2vec import Word2Vec
import pickle
import numpy as np
from helpers import timer 
import argparse
'''
Self-train a word2vec model on the unsupervised dataset and save the
resulting embedding matrix to
data/data_<dataset>/word2vec_matrix_trained_embsize<embedding_size>.npy
'''
@timer
def train_word2vec_unsupervised(
    dataset: str = "agnews",
    embedding_size: int = 300,
    window: int = 5,
    workers: int = 1,
    seed: int = 1,
) -> np.ndarray:
    """Train word2vec on a dataset's unsupervised split and save a vocab-aligned matrix.

    Reads ``data/data_<dataset>/unsupervised.csv``, splits every entry of its
    ``Instance`` column on single spaces, and trains a gensim ``Word2Vec``
    model on the resulting token lists. The learned vectors are then copied
    into a ``(len(vocab), embedding_size)`` matrix whose row index for a word
    is the id stored in ``id2word.pkl``. Words that the trained model does
    not contain keep an all-zero row. The matrix is written to
    ``data/data_<dataset>/word2vec_matrix_trained_embsize<embedding_size>.npy``.

    The function is wrapped by ``timer`` from ``helpers`` (not shown in this
    file; presumably it reports the run time -- confirm in ``helpers``).

    Args:
        dataset: Name used to build the data directory ``data/data_<dataset>``.
        embedding_size: Dimensionality of the word vectors (gensim
            ``vector_size``) and number of columns of the returned matrix.
        window: Context window size passed to ``Word2Vec``.
        workers: Number of worker threads passed to ``Word2Vec``. Per the
            gensim documentation, a fully reproducible run requires
            ``workers=1`` in addition to a fixed ``seed``.
        seed: Random seed passed to ``Word2Vec``.

    Returns:
        The embedding matrix, a float ``numpy.ndarray`` of shape
        ``(len(vocab), embedding_size)``.
    """
    datapath = "data/data_" + dataset

    # Load data. An empty ``Instance`` cell is read by pandas as NaN (a float),
    # which has no ``split`` method, so missing rows are dropped up front
    # instead of crashing the tokenization below.
    frame = pd.read_csv(datapath + "/unsupervised.csv")
    texts = frame["Instance"].dropna().astype(str)
    sentences = [text.split(" ") for text in texts]

    print('Training word2vec model...')
    model = Word2Vec(
        sentences=sentences,
        vector_size=embedding_size,
        window=window,
        workers=workers,
        seed=seed,
    )
    word_vectors = model.wv
    print('model train complete')

    # NOTE: pickle.load can execute arbitrary code; these files must come
    # from a trusted source.
    with open(datapath + '/id2word.pkl', 'rb') as fh:
        id2word = pickle.load(fh)

    # Invert the id -> word mapping so each word yields its matrix row index.
    word2id = {id2word[k]: k for k in id2word}

    with open(datapath + '/vocab.pkl', 'rb') as fh:
        my_vocab = pickle.load(fh)

    vocab_size = len(my_vocab)

    count_missing = 0
    print('Generating word2vec matrix for vocab...')
    # Rows start at zero, so words without a trained vector stay all-zero.
    word2vec_matrix = np.zeros((vocab_size, embedding_size))
    for word, index in word2id.items():
        if word in word_vectors:
            word2vec_matrix[index, :] = word_vectors[word]
        else:
            count_missing += 1

    print("word embedding matrix has shape: ", word2vec_matrix.shape)
    # NOTE(review): misses are counted over the id2word entries while the
    # denominator is len(vocab); confirm both files describe the same
    # vocabulary. The guard avoids a ZeroDivisionError on an empty vocab.
    percentage_missing = (count_missing / vocab_size) * 100 if vocab_size else 0.0
    print("total missing words:", count_missing, "  percentage missing: ", percentage_missing, "%")

    fname = datapath + '/word2vec_matrix_trained_embsize' + str(embedding_size) + '.npy'
    with open(fname, 'wb') as f:
        np.save(f, word2vec_matrix)
    # Report the path only after the file has actually been written.
    print(f'file saved at {fname}')

    print("Trained embedding saved...")
    return word2vec_matrix

if __name__ == "__main__":
    # Script entry point. Command-line parsing (an int ``--embedding_size``
    # flag via argparse) is currently disabled, so the run configuration is
    # hard-coded below.
    run_config = {
        "dataset": "agnews",
        "embedding_size": 200,
        "window": 5,
        "workers": 3,
    }
    train_word2vec_unsupervised(**run_config)