import os, sys, os.path as osp, time
import re
import numpy as np
import torch
# from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfTransformer, TfidfVectorizer, CountVectorizer
from tqdm import tqdm
from logzero import logger
from transformers import BertTokenizer

import joblib
import argparse
import scipy.sparse as sp
import scipy.sparse as smat
from collections import Counter

import multiprocessing as mp
from multiprocessing import Pool
import subprocess
import shlex
from multiprocessing.pool import ThreadPool

import nltk
from nltk.stem.porter import PorterStemmer


def make_label_list(label_file):
    """Read *label_file* and split every line on whitespace into its labels.

    Returns a 2-D NumPy string array when every line has the same number of
    labels (an empty file yields ``np.array([])``). When lines have different
    numbers of labels, it returns a 1-D ``object`` array whose items are the
    per-line label lists. NumPy >= 1.24 refuses to build ragged arrays
    implicitly, so that case needs explicit handling.
    """
    with open(label_file) as fp:
        rows = [line.split()
                for line in tqdm(fp, desc='making label list', leave=False)]

    if len({len(row) for row in rows}) <= 1:
        # Uniform (or no) rows: same result as a plain np.array call.
        return np.array(rows)

    # Ragged rows: fill an object array element by element so NumPy never
    # tries to infer a rectangular shape.
    labels = np.empty(len(rows), dtype=object)
    for i, row in enumerate(rows):
        labels[i] = row
    return labels

def tokenize(text):
    """Word-tokenize *text* with NLTK, Porter-stem each token and re-join with spaces.

    Returns a single string of the stemmed tokens separated by one space.
    """
    # Build the stemmer once per call instead of once per token.
    stemmer = PorterStemmer()
    return ' '.join(stemmer.stem(token) for token in nltk.word_tokenize(text))

def tok_batch(batch):
    """Return ``tokenize(text)`` for every text in *batch*, in order.

    Progress is shown with tqdm. Iterating *batch* directly, rather than
    ``enumerate(batch)``, lets tqdm display a total for sized inputs.
    """
    return [tokenize(text) for text in tqdm(batch)]

# Module-level accumulator that ``push_res`` appends results to.
tok_dic = []


def push_res(result):
    """Append *result* as one new element at the end of ``tok_dic``."""
    tok_dic.extend([result])

def tok(corpus):
    """Return the tokenized/stemmed form of every document in *corpus*.

    Each document is passed through :func:`tokenize`. Progress is reported
    with a tqdm bar labelled ``token``. The result is a list in the same order
    as *corpus*.
    """
    return [tokenize(doc) for doc in tqdm(corpus, desc='token')]

def load_data(text_path):
    """Return every line of *text_path* as a list of strings.

    Line terminators are kept, exactly as ``readlines`` would return them.
    """
    with open(text_path) as handle:
        lines = list(handle)
    return lines

def write_data(filename, data):
    """Write each string in *data* to *filename* as exactly one line.

    Newlines embedded in an item are replaced by spaces, so every item takes
    a single line. The file is truncated first. The ``with`` block closes it,
    so no explicit ``close()`` is needed.
    """
    with open(filename, 'w') as fout:
        fout.writelines(d.replace('\n', ' ') + '\n' for d in data)

def make_npz(label_map, labels):
    """Build a one-hot CSR matrix with one row per entry of *labels*.

    Row ``i`` has a single ``1.0`` in column ``label_map[labels[i]]``. The
    matrix shape is ``(len(labels), len(label_map))``. An unknown label raises
    ``KeyError``.
    """
    num_rows, num_cols = len(labels), len(label_map)
    columns = list(map(label_map.__getitem__, labels))
    rows = np.arange(num_rows)
    ones = np.ones(num_rows)
    return smat.csr_matrix((ones, (rows, columns)), shape=(num_rows, num_cols))

def make_label_npz(label_list_path, train_label_path, test_label_path):
    """Encode the train/test label files as one-hot CSR matrices.

    ``label_list_path`` holds one label per line, and the line order defines
    the column index. Each line of ``train_label_path`` / ``test_label_path``
    is looked up as a single label, one row per line. Each matrix is saved
    next to its source file with the extension replaced by ``.npz``. Both
    matrices are built before either is saved, so an unknown label (KeyError)
    leaves no partial output.
    NOTE(review): duplicate entries in the label list collapse in the dict,
    which can make column indices exceed ``len(label_map)``; confirm the list
    is unique.
    """
    def _stripped_lines(path):
        # load_data keeps the trailing '\n'. Strip it so a label on a final
        # line without a newline matches the same label on any other line.
        return [line.rstrip('\n') for line in load_data(path)]

    label_list = _stripped_lines(label_list_path)
    train_label = _stripped_lines(train_label_path)
    test_label = _stripped_lines(test_label_path)
    label_map = {label: i for i, label in enumerate(label_list)}
    train_label_npz = make_npz(label_map, train_label)
    test_label_npz = make_npz(label_map, test_label)
    smat.save_npz(osp.splitext(train_label_path)[0] + '.npz', train_label_npz)
    smat.save_npz(osp.splitext(test_label_path)[0] + '.npz', test_label_npz)

def _tokenized_split(split, text_path, tok_path, overwrite):
    """Return the tokenized lines for one data split, reusing the on-disk cache.

    If ``tok_path`` exists and ``overwrite`` is false, the cached lines are
    read back with ``load_data``. Otherwise the raw text at ``text_path`` is
    tokenized with ``tok``, written to ``tok_path``, and returned.
    """
    use_cache = osp.exists(tok_path) and not overwrite
    if use_cache:
        return load_data(tok_path)
    logger.info(f'parse {split}')
    tokenized = tok(load_data(text_path))
    write_data(tok_path, tokenized)
    return tokenized


def main(args, data_path, vocab_size, **kwargs):
    """Tokenize the train/test text of one feature and build TF-IDF matrices.

    All inputs and outputs live in ``data_path``. The steps are:

    1. When ``args.process_label`` is set, convert the plain, head and hs label
       files to ``.npz`` with ``make_label_npz``.
    2. Tokenize ``{split}.{feature}.txt`` into ``{split}.{feature}.tok.txt``
       for the train and test splits, reusing existing token files.
    3. Fit ``TfidfVectorizer(max_features=vocab_size)`` on the train tokens,
       dump it with joblib and save the train matrix.
    4. Transform the test tokens. If step 3 was skipped, the model is
       reloaded from disk first.

    Existing outputs are reused unless ``args.overwrite`` is set. Any extra
    keyword arguments, such as the rest of the CLI namespace, are ignored.
    """
    feature = args.feature

    def under_data(filename):
        return osp.join(data_path, filename)

    train_text_path = under_data(f'train.{feature}.txt')
    test_text_path = under_data(f'test.{feature}.txt')
    train_tok_path = under_data(f'train.{feature}.tok.txt')
    test_tok_path = under_data(f'test.{feature}.tok.txt')
    train_feature_path = under_data(f'train.tfidf.{feature}.{vocab_size}.npz')
    test_feature_path = under_data(f'test.tfidf.{feature}.{vocab_size}.npz')
    tf_model_path = under_data(f'tfidf.{feature}.model.{vocab_size}')

    if args.process_label:
        logger.info('process label')
        # Each label set has a list file plus matching train/test label files.
        for kind in ('label', 'head_label', 'hs_label'):
            make_label_npz(under_data(f'{kind}_list.txt'),
                           under_data(f'train.{kind}.txt'),
                           under_data(f'test.{kind}.txt'))

    tok_start = time.time()
    train_corpus = _tokenized_split('train', train_text_path, train_tok_path, args.overwrite)
    test_corpus = _tokenized_split('test', test_text_path, test_tok_path, args.overwrite)
    print(f'total time {time.time() - tok_start}')

    model = None
    if args.overwrite or not osp.exists(train_feature_path):
        logger.info('load train text')
        logger.info('build model')
        fit_start = time.time()
        model = TfidfVectorizer(max_features=vocab_size)
        logger.info('transform train feature')
        train_feature = model.fit_transform(train_corpus)
        logger.info(f'{time.time() - fit_start}')
        joblib.dump(model, tf_model_path)
        smat.save_npz(train_feature_path, train_feature)

    if args.overwrite or not osp.exists(test_feature_path):
        logger.info('load test text')
        if model is None:
            # The train step was skipped, so reuse the previously dumped model.
            model = joblib.load(tf_model_path)
        logger.info('transform test feature')
        smat.save_npz(test_feature_path, model.transform(test_corpus))

if __name__ == '__main__':
    # Command-line entry point: parse options and run the preprocessing.
    parser = argparse.ArgumentParser(description='Preprocess')
    parser.add_argument('--data_path', required=True, type=str)
    parser.add_argument('--feature', default='text', type=str)
    parser.add_argument('--label_path', type=str)
    parser.add_argument('--vocab_path', type=str)
    parser.add_argument('--vocab_size', type=int, default=50000)
    parser.add_argument('--overwrite', action='store_true')
    parser.add_argument('--process_label', action='store_true')
    args = parser.parse_args()

    print(vars(args))
    # data_path and vocab_size bind to main's named parameters; the remaining
    # options go to **kwargs.
    main(args, **vars(args))




























