import os
import pickle

import numpy as np
import scipy.io
import scipy.sparse

## BOW data helpers
def _fetch(path, name):
    """Load the bag-of-words tokens and counts for one dataset split.

    Reads ``bow_<name>_tokens`` and ``bow_<name>_counts`` from *path* with
    ``scipy.io.loadmat`` (which also tries the ``.mat`` extension) and
    squeezes away singleton dimensions.

    Args:
        path: directory containing the ``bow_*`` MATLAB files.
        name: split name; one of ``'train'``, ``'test'``, ``'unsup'``,
            ``'valid'``.

    Returns:
        dict with keys ``'tokens'`` and ``'counts'`` holding the squeezed
        ``'tokens'`` / ``'counts'`` variables from the respective files.

    Raises:
        ValueError: if *name* is not a known split (previously this surfaced
            as an ``UnboundLocalError`` from the unassigned file paths).
    """
    valid_names = ('train', 'test', 'unsup', 'valid')
    if name not in valid_names:
        raise ValueError(
            'unknown split {!r}; expected one of {}'.format(name, valid_names))
    # All splits follow the same naming scheme, so build the paths directly.
    token_file = os.path.join(path, 'bow_{}_tokens'.format(name))
    count_file = os.path.join(path, 'bow_{}_counts'.format(name))
    tokens = scipy.io.loadmat(token_file)['tokens'].squeeze()
    counts = scipy.io.loadmat(count_file)['counts'].squeeze()
    return {'tokens': tokens, 'counts': counts}

def get_data(path):
    """Load the vocabulary and every bag-of-words split stored under *path*.

    Returns a 5-tuple ``(vocab, train, test, unsup, valid)``. ``vocab`` is
    the unpickled contents of ``vocab.pkl``. Each split is the dict returned
    by ``_fetch`` for that split name.
    """
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    vocab_path = os.path.join(path, 'vocab.pkl')
    with open(vocab_path, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    # Same order as the returned tuple: train, test, unsup, valid.
    split_names = ('train', 'test', 'unsup', 'valid')
    splits = tuple(_fetch(path, split) for split in split_names)
    return (vocab,) + splits

def sparse_matrix_format(tokens, counts, vocab_size):
    """Build a document-term CSR matrix from per-document token/count entries.

    Args:
        tokens: sequence with one entry per document (row). Entry ``i`` holds
            the column indices present in document ``i``. It may be a scalar,
            a flat list/array, or a ``(1, n)`` array such as the cells
            produced by ``scipy.io.loadmat``.
        counts: sequence parallel to *tokens*. Entry ``i`` holds one value per
            index in ``tokens[i]``.
        vocab_size: number of columns in the result.

    Returns:
        ``scipy.sparse.csr_matrix`` of shape ``(len(tokens), vocab_size)``.
        Entry ``(i, j)`` is the sum of the counts paired with index ``j`` in
        document ``i``; the CSR constructor sums repeated indices.

    Raises:
        ValueError: if a document's token and count entries differ in length.
    """
    data, row_ind, col_ind = [], [], []
    for row, (col, count) in enumerate(zip(tokens, counts)):
        # np.ravel turns scalars, 1-D and (1, n) entries into flat arrays.
        # Single-token and multi-token documents then share one code path.
        # .tolist() yields Python scalars, so the result dtype is int64/float64.
        cols = np.ravel(col).tolist()
        vals = np.ravel(count).tolist()
        if len(cols) != len(vals):
            # A mismatch would desynchronise the parallel index lists below.
            raise ValueError(
                'document {}: {} token indices but {} counts'.format(
                    row, len(cols), len(vals)))
        col_ind.extend(cols)
        data.extend(vals)
        row_ind.extend([row] * len(cols))

    return scipy.sparse.csr_matrix(
        (data, (row_ind, col_ind)), shape=(len(tokens), vocab_size))