import os
import subprocess

import numpy as np
import scipy
import scipy.sparse

# Directory containing the GloVe executables invoked by this module
# (vocab_count, cooccur, shuffle, glove).
GLOVE_DIR = './GloVe-Bias/build'

# Layout of one record in a binary co-occurrence file: two 32-bit integer
# indices followed by a 64-bit float value, packed into 16 bytes in the
# machine's native byte order.
CREC = np.dtype({
    'names': ['row', 'col', 'val'],
    'formats': [np.int32, np.int32, np.float64],
})

def corrected_counts(counts, x_counts, y_counts, x_or_y_counts, i2w):
    """Compute a corrected co-occurrence matrix from four count matrices.

    Every input matrix is first symmetrised as ``(M + M.T) / 2``.  Indices
    whose x column total times y column total is not positive are dropped
    from all four matrices.  For each stored entry of the restricted
    ``x_or_y_counts`` the corrected value is

        counts / x_or_y * ((x / xP + y / yP) * (xP + yP) / 2 - xy)

    where ``xP`` / ``yP`` are the column totals of the symmetrised
    ``x_counts`` / ``y_counts`` looked up at the entry's row, and
    ``xy = x + y - x_or_y``.  The result is symmetrised once more.

    Args:
        counts, x_counts, y_counts, x_or_y_counts: square matrices over the
            same index set.  NOTE(review): presumably scipy sparse matrices
            (the code relies on ``.T``, boolean-mask indexing and
            ``.tocoo()``) -- confirm against callers.
        i2w: mapping from matrix index to word, indexed with ``i2w[i]``.

    Returns:
        ``(cc, vocab)``: the corrected, symmetrised sparse matrix and the
        list of words for the retained indices, in index order.
    """
    def symmetrize(m):
        return (m + m.T) / 2

    counts = symmetrize(counts)
    x_counts = symmetrize(x_counts)
    y_counts = symmetrize(y_counts)
    x_or_y_counts = symmetrize(x_or_y_counts)

    # Column totals, taken before any index is dropped.
    x_totals = np.array(x_counts.sum(axis=0)).squeeze()
    y_totals = np.array(y_counts.sum(axis=0)).squeeze()

    # Retain only indices with a positive total in both x and y.
    mask = (x_totals * y_totals) > 0
    kept_ids = np.arange(mask.shape[0])[mask]
    vocab = [i2w[idx] for idx in kept_ids]

    def restrict(m):
        # Rows first, then columns, so both axes use the same mask.
        return m[mask, :][:, mask]

    counts = restrict(counts)
    x_counts = restrict(x_counts)
    y_counts = restrict(y_counts)
    x_or_y_counts = restrict(x_or_y_counts)

    x_totals = x_totals[mask]
    y_totals = y_totals[mask]

    xy_counts = x_counts + y_counts - x_or_y_counts

    # The result shares the sparsity pattern of x_or_y_counts.
    cc = x_or_y_counts.copy().tocoo()
    rows, cols = cc.row, cc.col

    # Totals are looked up by the entry's row for both x and y.
    x_row_tot = x_totals[rows]
    y_row_tot = y_totals[rows]

    share = np.array(counts[rows, cols]) / cc.data
    mixed = (
        np.array(x_counts[rows, cols]) / x_row_tot
        + np.array(y_counts[rows, cols]) / y_row_tot
    )
    overlap = np.array(xy_counts[rows, cols])

    cc.data[:] = share * (mixed * (x_row_tot + y_row_tot) / 2 - overlap)

    cc = cc.tocsc()
    cc = (cc.T + cc) / 2

    return cc, vocab

def write_cooccur(matrix, fname):
    """Write a co-occurrence matrix to ``fname`` in GloVe's binary format.

    The file is a flat stream of ``CREC`` records, one per stored entry of
    the matrix: (int32 row, int32 col, float64 value), 16 bytes each, in the
    machine's byte order (little endian on the server this was tested on).
    Rows and columns may be switched relative to what GloVe calls them, but
    that should not actually matter.

    NOTE(review): indices are written exactly as stored in the matrix
    (0-based); stock GloVe tools use 1-based word ids -- confirm which
    convention the binaries in GLOVE_DIR expect.

    Args:
        matrix: object with a ``tocoo()`` method (e.g. a scipy sparse matrix).
        fname: destination path passed to ``ndarray.tofile``.
    """
    coo = matrix.tocoo()

    columns = {
        'row': coo.row.astype(np.int32),
        'col': coo.col.astype(np.int32),
        'val': coo.data.astype(np.float64),
    }

    records = np.zeros(columns['row'].shape, dtype=CREC)
    for field, values in columns.items():
        records[field] = values

    records.tofile(fname)

def load_glove_cooccur(fname, shape=None):
    """Read a binary co-occurrence file into a CSC sparse matrix.

    Args:
        fname: path to a file of ``CREC`` records (row, col, val).
        shape: optional ``(rows, cols)`` forwarded to ``coo_matrix``; when
            ``None`` scipy infers the shape from the indices.

    Returns:
        A ``scipy.sparse`` CSC matrix built from the records.
    """
    records = np.fromfile(fname, dtype=CREC)
    indices = (records['row'], records['col'])
    coo = scipy.sparse.coo_matrix((records['val'], indices), shape=shape)
    return coo.tocsc()

def load_glove_vocab(fname):
    """Read the words from a GloVe-style vocabulary file.

    Each line is expected to look like ``word count``; everything from the
    first space onward is discarded.  A line with no space yields the whole
    line (minus its newline), including a final line that lacks a trailing
    newline.

    Args:
        fname: path to the vocabulary text file.

    Returns:
        List of words, one per line, in file order.
    """
    with open(fname) as f:
        # partition() returns the full string when no space is present, so a
        # space-less last line is no longer clipped by one character.
        return [line.rstrip('\n').partition(' ')[0] for line in f]

def glove_vocab(corpus_file, min_word_count, verbose=2):
    """Build a vocabulary from a corpus with GloVe's ``vocab_count`` tool.

    Runs ``{GLOVE_DIR}/vocab_count`` with the corpus on stdin, writes its
    stdout to ``data/glove_temp/glove_vocab.txt`` and loads that file.

    Args:
        corpus_file: path to the corpus fed to ``vocab_count``.
        min_word_count: value passed as ``-min-count``.
        verbose: value passed as ``-verbose``.

    Returns:
        List of vocabulary words as returned by ``load_glove_vocab``.

    Raises:
        OSError: if the corpus or output file cannot be opened, or the
            executable cannot be started.
        subprocess.CalledProcessError: if ``vocab_count`` exits non-zero.
    """
    vocab_file = 'data/glove_temp/glove_vocab.txt'
    os.makedirs(os.path.dirname(vocab_file), exist_ok=True)

    command = [
        f'{GLOVE_DIR}/vocab_count',
        '-min-count', str(min_word_count),
        '-verbose', str(verbose),
    ]

    # An argument list avoids shell quoting problems with the paths, and
    # check=True surfaces a failed run instead of loading a stale or
    # truncated vocabulary file.
    with open(corpus_file, 'rb') as corpus, open(vocab_file, 'wb') as out:
        subprocess.run(command, stdin=corpus, stdout=out, check=True)

    return load_glove_vocab(vocab_file)

def glove_cooccur(corpus_file, window_size, key_words=None, verbose=2, shape=None):
    """Count co-occurrences in a corpus with GloVe's ``cooccur`` tool.

    Runs ``{GLOVE_DIR}/cooccur`` with the corpus on stdin, using the
    vocabulary at ``data/glove_temp/glove_vocab.txt``, writes its stdout to
    ``data/glove_temp/cooccurrence.bin`` and loads that file.

    Args:
        corpus_file: path to the corpus fed to ``cooccur``.
        window_size: value passed as ``-window-size``.
        key_words: optional iterable of strings; when given they are joined
            with commas and passed as ``-key-words``.
        verbose: value passed as ``-verbose``.
        shape: optional shape forwarded to ``load_glove_cooccur``.

    Returns:
        The matrix returned by ``load_glove_cooccur``.

    Raises:
        ValueError: if the comma-joined ``key_words`` is 1000 characters or
            longer.
        OSError: if the corpus or output file cannot be opened, or the
            executable cannot be started.
        subprocess.CalledProcessError: if ``cooccur`` exits non-zero.
    """
    vocab_file = 'data/glove_temp/glove_vocab.txt'
    cooccur_file = 'data/glove_temp/cooccurrence.bin'

    command = [
        f'{GLOVE_DIR}/cooccur',
        '-memory', '4.0',
        '-vocab-file', vocab_file,
        '-verbose', str(verbose),
        '-window-size', str(window_size),
    ]

    if key_words is not None:
        joined = ",".join(key_words)
        # Validate before touching any file so a bad call has no side effects.
        if len(joined) >= 1000:
            raise ValueError("key_words array too long.")
        command += ['-key-words', joined]

    # An argument list avoids shell quoting problems, and check=True surfaces
    # a failed run instead of loading a stale or truncated output file.
    with open(corpus_file, 'rb') as corpus, open(cooccur_file, 'wb') as out:
        subprocess.run(command, stdin=corpus, stdout=out, check=True)

    return load_glove_cooccur(cooccur_file, shape=shape)

def save_vocab_matrix(matrix, vocab, vocab_file='data/glove_temp/glove_vocab.txt'):
    """Write a GloVe-style vocabulary file derived from a matrix.

    Each output line is ``word total`` where ``total`` is the sum of the
    corresponding matrix column.

    Args:
        matrix: 2-D matrix (dense or scipy sparse) with one column per word.
        vocab: iterable of words, aligned with the matrix columns.
        vocab_file: destination path; defaults to the shared temp location.
    """
    # ravel() rather than squeeze(): squeeze() collapses a one-word
    # vocabulary to a 0-d array, which cannot be iterated by zip().
    column_totals = np.asarray(matrix.sum(axis=0)).ravel()

    with open(vocab_file, 'w') as f:
        f.writelines(f"{v} {c}\n" for v, c in zip(vocab, column_totals))


def glove_matrix(matrix, vocab, vec_dim, verbose=2, seed=0, output_dir='data/glove_temp'):
    """Train GloVe vectors on a matrix of co-occurrence counts.

    Runs GloVe on a matrix of cooccurrence counts (or probabilities, sure).
    Writes the vocabulary and the binary co-occurrence file into
    ``output_dir``, shuffles the records with ``{GLOVE_DIR}/shuffle``, trains
    with ``{GLOVE_DIR}/glove`` and loads ``{output_dir}/vectors.txt``.

    Args:
        matrix: co-occurrence matrix accepted by ``write_cooccur``.
        vocab: words aligned with the matrix columns.
        vec_dim: value passed as ``-vector-size``.
        verbose: value passed as ``-verbose`` to both tools.
        seed: value passed as ``-seed`` to ``glove``.
        output_dir: directory for all intermediate and output files;
            created if it does not exist.

    Returns:
        ``(vec_arr, vocab)`` as returned by ``load_glove``.

    Raises:
        OSError: if a file cannot be opened or an executable cannot be
            started.
        subprocess.CalledProcessError: if ``shuffle`` or ``glove`` exits
            non-zero.
    """
    vocab_file = f'{output_dir}/glove_vocab.txt'
    cooccur_file = f'{output_dir}/cooccurrence.bin'
    shuffle_file = f'{output_dir}/cooccurrence.shuf.bin'
    vec_file = f'{output_dir}/vectors'

    os.makedirs(output_dir, exist_ok=True)

    # Write the vocabulary where glove is told to read it; previously it was
    # always written to the default directory regardless of output_dir.
    save_vocab_matrix(matrix, vocab, vocab_file)
    write_cooccur(matrix, cooccur_file)

    # Argument lists avoid shell quoting problems; check=True stops the
    # pipeline on a failed step instead of training on stale files.
    shuffle_cmd = [
        f'{GLOVE_DIR}/shuffle',
        '-memory', '4.0',
        '-verbose', str(verbose),
    ]
    with open(cooccur_file, 'rb') as src, open(shuffle_file, 'wb') as dst:
        subprocess.run(shuffle_cmd, stdin=src, stdout=dst, check=True)

    glove_cmd = [
        f'{GLOVE_DIR}/glove',
        '-save-file', vec_file,
        '-threads', '8',
        '-input-file', shuffle_file,
        '-x-max', '10',
        '-iter', '15',
        '-vector-size', str(vec_dim),
        '-binary', '2',
        '-vocab-file', vocab_file,
        '-verbose', str(verbose),
        '-seed', str(seed),
    ]
    subprocess.run(glove_cmd, check=True)

    return load_glove(f'{vec_file}.txt')

def load_glove(data_file):
    """Load word vectors from a space-separated text file.

    Each line holds a word followed by its vector components, separated by
    single spaces.  The vector width is taken from the number of spaces on
    the first line, and the file is read twice: once to size the array and
    once to fill it.

    Args:
        data_file: path to the vectors text file.

    Returns:
        ``(vec_arr, vocab)``: a float array with one row per line, and the
        list of words in file order.
    """
    # First pass: vector width from the first line, row count from all lines.
    with open(data_file) as handle:
        n_dims = handle.readline().count(' ')
        handle.seek(0)
        n_rows = sum(1 for _ in handle)

    words = []
    vectors = np.zeros((n_rows, n_dims))

    # Second pass: split each line into its word and numeric components.
    with open(data_file) as handle:
        for row_idx, raw in enumerate(handle):
            token, *components = raw.split(' ')
            words.append(token)
            vectors[row_idx, :] = [float(value) for value in components]

    return vectors, words
