import os, sys, os.path as osp, time
import re
import numpy as np
from tqdm import tqdm
from logzero import logger
import pandas as pd

import joblib
import scipy.sparse as sp
import scipy.sparse as smat

import multiprocessing as mp
from multiprocessing import Pool
import subprocess
import shlex
from multiprocessing.pool import ThreadPool

import nltk
from nltk.tokenize import word_tokenize
from nltk.stem.porter import PorterStemmer

from sklearn.feature_extraction.text import TfidfTransformer, TfidfVectorizer, CountVectorizer

import argparse
from collections import Counter
import json

def make_list(label):
    """
    Get the distinct labels ranked by frequency, most frequent first.

    Args:
        label: iterable of hashable label values, one entry per occurrence.

    Returns:
        np.ndarray holding each distinct label once, ordered by descending
        occurrence count. Labels with equal counts keep the order in which
        they first appear in ``label``, so the result is reproducible.
    """
    # Counter remembers first-seen order and most_common() sorts stably, so
    # ties do not depend on set iteration order (which varies between runs
    # for strings because of hash randomization).
    ranked = [name for name, _ in Counter(label).most_common()]
    return np.array(ranked)

def get_labels(file_name):
    """
    Read a label file that holds one label per line.

    Args:
        file_name: path of the text file to read.

    Returns:
        list of str, one entry per line in file order, each with leading and
        trailing whitespace removed. Blank lines yield empty strings.
    """
    # The context manager closes the handle even if reading raises.
    with open(file_name, 'r') as label_file:
        return [line.strip() for line in label_file]

def get_csv_examples(input_path):
    """
    Load texts and label lists from a CSV file.

    The second column of every row is used as the text and the third column
    as a comma-separated label string; the first column is not used.

    Args:
        input_path: path of the CSV file.

    Returns:
        (texts, labels): ``texts`` is a list of str in which newline, tab and
        carriage-return characters are replaced by spaces; ``labels`` is a
        list of lists of str obtained by splitting the label field on ','.
    """
    data_dir, filename = osp.split(input_path)
    logger.info(f"Data {filename} in {data_dir}")

    # One translation table replaces each line-break / tab character by a space.
    to_space = str.maketrans({"\n": " ", "\t": " ", "\r": " "})

    # na_filter=False can improve the performance of reading a large file.
    frame = pd.read_csv(osp.join(data_dir, filename), na_filter=False)

    texts = []
    labels = []
    # record[0], record[1], record[2] -> id, text, label list
    for _, record in tqdm(enumerate(frame.values), desc='read from csv'):
        texts.append(record[1].translate(to_space))
        labels.append(record[2].split(','))
    return texts, labels

def convert_labels(labels, label_list):
    """
    Encode per-example label lists as a sparse indicator matrix.

    Args:
        labels: sequence whose items are iterables of label names, one item
            per example.
        label_list: ordered label names; the position of a name is its
            column index.

    Returns:
        scipy CSR matrix with one row per item of ``labels`` and one column
        per distinct name in ``label_list``, storing 1.0 for every
        (example, label) occurrence.

    Raises:
        KeyError: if an example carries a label missing from ``label_list``.
    """
    column_of = dict((name, pos) for pos, name in enumerate(label_list))

    columns = []
    # row_starts[k] is the offset in ``columns`` where row k begins.
    row_starts = [0]
    for example_labels in tqdm(labels):
        columns += [column_of[name] for name in example_labels]
        row_starts.append(len(columns))

    ones = np.ones(len(columns))
    n_rows = len(row_starts) - 1
    return smat.csr_matrix((ones, columns, row_starts), shape=(n_rows, len(column_of)))

def write_data(filename, data):
    """
    Write strings to a text file, one per line.

    Args:
        filename: path of the file to create or overwrite.
        data: iterable of str; a newline is appended to every item.
    """
    # The with-block closes the file on exit, including when a write fails.
    with open(filename, 'w') as fout:
        fout.writelines(d + '\n' for d in data)

def load_data(text_path):
    """
    Read a text file into a list of lines.

    Args:
        text_path: path of the file to read.

    Returns:
        list of str, one per line in file order, with line endings kept.
    """
    with open(text_path) as source:
        lines = list(source)
    return lines

def create_raw_text_and_label(data_dir, data_type):
    """
    Build the raw-text file and the label matrix for one data split.

    Reads ``<data_dir>/<data_type>.csv`` and ``<data_dir>/labels.txt``, then
    writes ``y.<data_type>.npz`` (sparse label matrix) and
    ``x.<data_type>.raw.txt`` (one text per line) into ``data_dir``.

    Args:
        data_dir: directory holding the input files and receiving the outputs.
        data_type: split name used in the file names (main() passes 'train'
            and 'dev').

    Returns:
        (texts, label_npz): the list of texts and the sparse label matrix.
    """
    data_path = osp.join(data_dir, f'{data_type}.csv')
    texts, labels = get_csv_examples(data_path)

    # Resolve labels.txt from the data_dir argument, not from a script-level
    # ``args`` global, so the function also works when the module is imported.
    label_list_path = osp.join(data_dir, 'labels.txt')
    label_list = get_labels(label_list_path)

    label_npz = convert_labels(labels, label_list)
    smat.save_npz(osp.join(data_dir, f'y.{data_type}.npz'), label_npz)
    write_data(osp.join(data_dir, f'x.{data_type}.raw.txt'), texts)
    return texts, label_npz

def tokenize(sentence):
    """
    Word-tokenize a sentence and lower-case the tokens.

    Tokens that contain no word character at all (e.g. pure punctuation) are
    dropped.

    Args:
        sentence: text to tokenize.

    Returns:
        str with the kept tokens joined by single spaces.
    """
    kept = []
    for token in word_tokenize(sentence):
        # Keep the token only if at least one character matches \w.
        if re.search(r'\w', token) is not None:
            kept.append(token.lower())
    return ' '.join(kept)

def stem(text):
    """
    Apply the Porter stemmer to every whitespace-separated token of a text.

    Args:
        text: input string; it is split on whitespace.

    Returns:
        str with the stemmed tokens joined by single spaces.
    """
    stemmer = PorterStemmer()
    return ' '.join(map(stemmer.stem, text.split()))

def get_top_features(ids, vals, names, top=None):
    """
    Select the entries with the largest values.

    Args:
        ids: array of feature ids supporting integer-array indexing
            (e.g. a numpy array).
        vals: values used for ranking, aligned with ``ids``.
        names: sequence of feature names aligned with ``ids``.
        top: maximum number of entries to keep; None keeps all of them.

    Returns:
        (top_ids, top_names): ids and names reordered by descending value and
        cut to at most ``top`` entries; names are returned as a numpy array.
    """
    limit = len(ids)
    if top is not None:
        limit = min(top, limit)

    # argsort is ascending, so reverse it before cutting to the limit.
    descending = np.argsort(vals)[::-1]
    keep = descending[:limit]

    name_array = np.array(names)
    return ids[keep], name_array[keep]

def write_top_feature(args, feature, model, out_file):
    """
    Write the highest-weighted feature names of every row to a text file.

    One line is written per row of ``feature``: the row's feature names
    ordered by descending weight, separated by spaces. The first five lines
    are also logged.

    Args:
        args: namespace providing ``top`` (maximum names per row; None keeps
            all of them).
        feature: sparse matrix whose row slices expose ``indices`` and
            ``data`` (presumably CSR, as saved by main()).
        model: fitted vectorizer providing ``inverse_transform``.
        out_file: output path; an existing file is logged and overwritten.
    """
    n_rows, _ = feature.shape
    if osp.exists(out_file):
        logger.info(f'{out_file} exists')
    # The with-block closes the file on exit, so no explicit close() follows.
    with open(out_file, 'w') as fout:
        for i in tqdm(range(n_rows)):
            # Slice the row once: every sparse row slice builds a new matrix.
            row = feature[i]
            names = model.inverse_transform(row)[0]
            _, top_names = get_top_features(row.indices, row.data, names, top=args.top)
            line = ' '.join(top_names)
            if i < 5:
                logger.info(line)
            fout.write(line + '\n')

def main(args, data_dir, **kwargs):
    """
    Run the preprocessing stages selected on the command line.

    Stages run in this order, each only when its flag on ``args`` is set:
    ``create_raw`` (raw text and label matrices), ``tokenize``, ``stem``,
    ``tfidf`` (fit on train, transform dev, save model and features) and
    ``top_feature`` (write top-weighted feature names per example). Every
    stage reads and writes files inside ``data_dir`` for the 'train' and
    'dev' splits.

    Args:
        args: parsed arguments; reads ``feature_name``, ``vocab_size``,
            ``overwrite`` and the stage flags (``top`` is read by
            write_top_feature).
        data_dir: directory containing the inputs and receiving the outputs.
        **kwargs: accepted and ignored.
    """
    def in_data_dir(name):
        return osp.join(data_dir, name)

    splits = ('train', 'dev')
    feature_tag = f'{args.feature_name}.{args.vocab_size}'
    stem_paths = {split: in_data_dir(f'x.{split}.stem.txt') for split in splits}
    tfidf_paths = {split: in_data_dir(f'x.{split}.{feature_tag}.npz') for split in splits}
    tf_model_path = in_data_dir(f'model.{feature_tag}')
    key_paths = {split: in_data_dir(f'x.{split}.key.{args.feature_name}.txt') for split in splits}

    # Stage 1: raw text files and label matrices.
    if args.create_raw:
        logger.info('create raw text')
        for split in splits:
            create_raw_text_and_label(data_dir, split)

    # Stage 2: tokenize the raw text for the tfidf feature.
    if args.tokenize:
        logger.info('tokenize text')
        for split, progress_label in (('train', 'tokenize train'), ('dev', 'tokenize test')):
            raw_text = load_data(in_data_dir(f'x.{split}.raw.txt'))
            tok_text = [tokenize(sent) for sent in tqdm(raw_text, desc=progress_label)]
            write_data(in_data_dir(f'x.{split}.tok.txt'), tok_text)

    # Stage 3: stem the tokenized text.
    if args.stem:
        logger.info('stem text')
        for split, progress_label in (('train', 'stem train'), ('dev', 'stem test')):
            tok_text = load_data(in_data_dir(f'x.{split}.tok.txt'))
            stem_text = [stem(sent) for sent in tqdm(tok_text, desc=progress_label)]
            write_data(stem_paths[split], stem_text)

    # Stage 4: fit the vectorizer on train and transform both splits. Nothing
    # is done here when a saved model exists and --overwrite is not given.
    if args.tfidf:
        logger.info(f'tfidf feature name: {args.feature_name}')
        if not osp.exists(tf_model_path) or args.overwrite:
            train_stem_text = load_data(stem_paths['train'])
            st = time.time()
            # vocab_size == -1 means no cap on the vocabulary size.
            vectorizer_options = {'stop_words': 'english'}
            if args.vocab_size != -1:
                vectorizer_options['max_features'] = args.vocab_size
            model = TfidfVectorizer(**vectorizer_options)
            logger.info('transform train feature')
            train_feature = model.fit_transform(train_stem_text)
            logger.info(f'{time.time() - st}')
            joblib.dump(model, tf_model_path)
            logger.info(f'len vocab: {len(model.vocabulary_)}')
            smat.save_npz(tfidf_paths['train'], train_feature)
            dev_stem_text = load_data(stem_paths['dev'])
            dev_feature = model.transform(dev_stem_text)
            smat.save_npz(tfidf_paths['dev'], dev_feature)

    # Stage 5: write the top-weighted feature names of every example.
    if args.top_feature:
        logger.info(f'extract top feature from file: {args.feature_name}')
        model = joblib.load(tf_model_path)
        logger.info(f'number of vocab {len(model.vocabulary_)}')
        for split in splits:
            split_tfidf = smat.load_npz(tfidf_paths[split])
            write_top_feature(args, split_tfidf, model, key_paths[split])



if __name__ == '__main__':
    # Command-line interface: the store_true flags select which stages of
    # main() are executed.
    parser = argparse.ArgumentParser(description='Preprocess')
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    # NOTE(review): the default is the bool True although type=str — confirm intent.
    parser.add_argument('--data_type', type=str, default=True)
    parser.add_argument('--vocab_size', type=int, default=-1)
    parser.add_argument('--create_raw', action='store_true')
    parser.add_argument('--tokenize', action='store_true')
    parser.add_argument('--data', type=str)
    parser.add_argument('--stem', action='store_true')
    parser.add_argument('--tfidf', action='store_true')
    parser.add_argument('--top_feature', action='store_true')
    parser.add_argument('--feature_name', default='tfidf', type=str)
    parser.add_argument('--overwrite', action='store_true')
    parser.add_argument('--top', type=int, default=None)

    # The parsed namespace stays bound to the module-level name ``args``.
    args = parser.parse_args()
    # Every parsed option is also forwarded as a keyword argument
    # (data_dir among them).
    main(args, **vars(args))

