import argparse
import os
import shutil

import embeddings_aux as aux

DIM = 300
MIN_COUNT = 5
WINDOW_SIZE = 15

X_WORDS = ['he', 'him', 'his', 'himself', 'man', 'men', 'boy', 'boys']
Y_WORDS = ['she', 'her', 'hers', 'herself', 'woman', 'women', 'girl', 'girls']

def glove(corpus, out_file, dim=DIM, min_count=MIN_COUNT, window_size=WINDOW_SIZE):
    """Train plain (unmitigated) GloVe embeddings on *corpus*.

    Only the original-count model is produced, so no X/Y class words are
    needed. Output ends up in ``{out_file}_vecs.txt`` and
    ``{out_file}_vocab.txt`` (see ``glove_all``).
    """
    # The class word lists are unused when out_new is None; pass empty ones.
    no_class_words = []
    glove_all(
        corpus,
        dim,
        min_count,
        window_size,
        no_class_words,
        no_class_words,
        out_orig=out_file,
    )

def glove_mitigated(corpus, out_file, dim=DIM, min_count=MIN_COUNT, window_size=WINDOW_SIZE, X_words=None, Y_words=None):
    """Train GloVe embeddings on co-occurrence counts corrected for X/Y words.

    Args:
        corpus: training corpus, passed through to ``glove_all``.
        out_file: output filename root; ``{out_file}_vecs.txt`` and
            ``{out_file}_vocab.txt`` are written by ``glove_all``.
        dim, min_count, window_size: GloVe hyper-parameters.
        X_words, Y_words: word lists for the two classes. ``None`` (the
            default) means a fresh copy of ``X_WORDS`` / ``Y_WORDS``.
    """
    # Copy the module-level defaults rather than sharing them as default
    # arguments, so nothing downstream can mutate the module constants.
    if X_words is None:
        X_words = list(X_WORDS)
    if Y_words is None:
        Y_words = list(Y_WORDS)
    glove_all(corpus, dim, min_count, window_size, X_words, Y_words, out_new=out_file)

def glove_both(corpus, out_orig, out_new, dim=DIM, min_count=MIN_COUNT, window_size=WINDOW_SIZE, X_words=None, Y_words=None):
    """Train both the original and the X/Y-corrected GloVe embeddings.

    Args:
        corpus: training corpus, passed through to ``glove_all``.
        out_orig: output root for embeddings trained on the original counts.
        out_new: output root for embeddings trained on the corrected counts.
        dim, min_count, window_size: GloVe hyper-parameters.
        X_words, Y_words: word lists for the two classes. ``None`` (the
            default) means a fresh copy of ``X_WORDS`` / ``Y_WORDS``.
    """
    # Copy the module-level defaults rather than sharing them as default
    # arguments, so nothing downstream can mutate the module constants.
    if X_words is None:
        X_words = list(X_WORDS)
    if Y_words is None:
        Y_words = list(Y_WORDS)
    glove_all(corpus, dim, min_count, window_size, X_words, Y_words, out_orig=out_orig, out_new=out_new)

def glove_all(corpus, dim, min_count, window_size, X_words, Y_words, out_orig=None, out_new=None):
    """Train GloVe embeddings on *corpus*, optionally from X/Y-corrected counts.

    Builds the vocabulary and co-occurrence counts with ``embeddings_aux``.
    Then, for each requested output root, it trains a model and copies
    ``vectors.txt`` / ``glove_vocab.txt`` from ``./data/glove_temp`` to
    ``{root}_vecs.txt`` / ``{root}_vocab.txt``.

    Args:
        corpus: corpus passed to ``aux.glove_vocab`` / ``aux.glove_cooccur``
            (presumably a path to the training text -- confirm in embeddings_aux).
        dim: embedding dimension.
        min_count: minimum word count, passed to ``aux.glove_vocab``.
        window_size: co-occurrence context window size.
        X_words, Y_words: word lists for the two classes; only used when
            ``out_new`` is not None.
        out_orig: output root for the model trained on the original counts,
            or None to skip it.
        out_new: output root for the model trained on the corrected counts,
            or None to skip it.

    Raises:
        OSError: if copying the trained files out of the temp directory fails.
            The previous shell ``cp`` calls silently ignored such failures.
    """
    vocab_orig = aux.glove_vocab(corpus, min_count)
    counts = aux.glove_cooccur(corpus, window_size)

    if out_new is not None:
        # Counts restricted to X words, Y words, and their union. Each is
        # requested with shape=counts.shape (presumably so corrected_counts
        # can combine them element-wise -- confirm in embeddings_aux).
        x_counts = aux.glove_cooccur(corpus, window_size, key_words=X_words, shape=counts.shape)
        y_counts = aux.glove_cooccur(corpus, window_size, key_words=Y_words, shape=counts.shape)
        x_or_y_counts = aux.glove_cooccur(corpus, window_size, key_words=X_words + Y_words, shape=counts.shape)

        cc, vocab_new = aux.corrected_counts(
            counts,
            x_counts,
            y_counts,
            x_or_y_counts,
            vocab_orig,
        )

        # NOTE(review): assumes these calls leave vectors.txt and
        # glove_vocab.txt in ./data/glove_temp, which are copied next.
        aux.glove_matrix(cc, vocab_new, dim)
        aux.save_vocab_matrix(cc, vocab_new)
        _copy_glove_outputs(out_new)

    if out_orig is not None:
        aux.glove_matrix(counts, vocab_orig, dim)
        aux.save_vocab_matrix(counts, vocab_orig)
        _copy_glove_outputs(out_orig)


def _copy_glove_outputs(out_root, temp_dir='./data/glove_temp'):
    """Copy the trained vector and vocab files from *temp_dir* to ``{out_root}_*.txt``."""
    # shutil instead of os.system('cp ...'): no shell, so paths containing
    # spaces or shell metacharacters are safe, and failures raise OSError.
    shutil.copyfile(os.path.join(temp_dir, 'vectors.txt'), f'{out_root}_vecs.txt')
    shutil.copyfile(os.path.join(temp_dir, 'glove_vocab.txt'), f'{out_root}_vocab.txt')

if __name__ == '__main__':
    # Command-line entry point: train X/Y-mitigated GloVe embeddings.
    parser = argparse.ArgumentParser()
    parser.add_argument("corpus", help="corpus file to train from")
    parser.add_argument("output", help="output filename root (no extension)")
    # type=int: argparse yields strings otherwise; defaults mirror the module constants.
    parser.add_argument("--dim", type=int, default=DIM,
                        help=f"embedding dimension (default {DIM})")
    parser.add_argument("--min_count", type=int, default=MIN_COUNT,
                        help=f"minimum count of word to not ignore (default {MIN_COUNT})")
    parser.add_argument("--window_size", type=int, default=WINDOW_SIZE,
                        help=f"window size for contexts (default {WINDOW_SIZE})")
    parser.add_argument("--x_words", help="list of words for x class, separated by commas")
    parser.add_argument("--y_words", help="list of words for y class, separated by commas")
    args = parser.parse_args()

    # Parse comma-separated word lists, ignoring surrounding whitespace and
    # empty entries (e.g. a trailing comma). Fall back to the module defaults
    # when an option is not given.
    if args.x_words:
        x_words = [w.strip() for w in args.x_words.split(',') if w.strip()]
    else:
        x_words = list(X_WORDS)
    if args.y_words:
        y_words = [w.strip() for w in args.y_words.split(',') if w.strip()]
    else:
        y_words = list(Y_WORDS)
    if not x_words or not y_words:
        parser.error("--x_words and --y_words must each contain at least one word")

    glove_mitigated(args.corpus, args.output, args.dim, args.min_count, args.window_size, x_words, y_words)
