#!/usr/bin/env python
# encoding: utf-8

import argparse
from collections import Counter
import itertools
import json
import os, os.path as osp
from logzero import logger
import numpy as np
import pickle
import time
from tqdm import tqdm
import scipy as sp
import scipy.sparse as smat
from sklearn.preprocessing import normalize
import pandas as pd
import torch
import multiprocessing as mp
from multiprocessing import Pool
import subprocess
import shlex
from multiprocessing.pool import ThreadPool

from transformers import (
    WEIGHTS_NAME,
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    XLMConfig,
    XLMForSequenceClassification,
    XLMTokenizer,
    XLNetConfig,
    XLNetForSequenceClassification,
    XLNetTokenizer,
    DistilBertConfig,
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    AlbertConfig,
    AlbertForSequenceClassification,
    AlbertTokenizer,
)

# ALL_MODELS = sum(
#     (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig,)),
#     (),
# )

# Maps the lower-cased ``--model_type`` value to the transformers
# (config class, sequence-classification model class, tokenizer class) triple.
# The tokenize_* functions in this file only use the tokenizer class.
MODEL_CLASSES = {
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
    "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
    "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
    "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer,),
    "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
}

def get_labels(file_name):
    """Read a label file and return one label per line.

    Args:
        file_name: Path of a text file holding one label per line.

    Returns:
        list[str]: Every line of the file, in order, with leading and
        trailing whitespace stripped. Blank lines yield empty strings.
    """
    # The context manager closes the handle even if reading raises, which the
    # previous manual open()/close() pair did not guarantee.
    with open(file_name, 'r') as fp:
        return [line.strip() for line in fp]

def make_mtx(labels, n_label):
    """Build a sparse indicator matrix from per-instance label id lists.

    Args:
        labels: Sequence where element ``i`` is an iterable of column
            (label) ids attached to instance ``i``.
        n_label: Number of columns of the resulting matrix.

    Returns:
        scipy.sparse.csr_matrix of shape ``(len(labels), n_label)`` with a
        1.0 contribution for every (instance, label id) pair.
    """
    # Flatten to (row, column) coordinates first, then split into two lists.
    coords = [(inst, lab) for inst, ids in enumerate(labels) for lab in ids]
    row = [inst for inst, _ in coords]
    col = [lab for _, lab in coords]
    data = np.ones(len(row))
    return smat.csr_matrix((data, (row, col)), shape=(len(labels), n_label))

def make_dic(target, top_label=None):
    """
    Split a sparse label matrix into one column-index array per row.

    ``target`` must expose ``indptr`` and ``indices`` (as a CSR matrix does).
    When ``top_label`` is given, only the first ``top_label`` columns are
    kept before splitting, so the label columns should already be sorted in
    the order that makes that prefix meaningful.

    Returns ``(rows, m)`` where ``rows`` is a list with the stored column
    indices of each row and ``m`` is the number of columns considered.
    """
    if top_label is not None:
        target = target[:, :top_label]
    _, m = target.shape
    ptr = target.indptr
    cols = target.indices
    # Row r owns the slice cols[ptr[r]:ptr[r + 1]].
    rows = [cols[start:stop] for start, stop in zip(ptr[:-1], ptr[1:])]
    return rows, m

def load_data(text_path):
    """Return every line of ``text_path`` as a list of strings.

    Lines keep their trailing newline; the file is opened in text mode with
    the platform default encoding.
    """
    with open(text_path) as handle:
        return list(handle)

def load_keyword(keyword_path, top=None):
    """Load per-document keyword lines, keeping at most ``top`` keywords each.

    Args:
        keyword_path: Path of a text file with one whitespace-separated
            keyword list per line, or ``None`` when there is no keyword file.
        top: Maximum number of leading keywords kept per line. ``None`` or a
            negative value keeps every keyword on the line.

    Returns:
        list[str] | None: One string per input line with the kept keywords
        joined by single spaces, or ``None`` if ``keyword_path`` is ``None``.
    """
    # The guard previously inspected the module-level ``args`` object, which
    # only exists when this file runs as a script (NameError on import) and
    # whose ``keyword`` flag is a bool that is never None. Decide from the
    # function's own argument instead.
    if keyword_path is None:
        return None

    with open(keyword_path) as fp:
        lines = fp.readlines()

    keep_all = top is None or top < 0
    trimmed = []
    for line in lines:
        words = line.split()
        # Slicing already caps at len(words), so no explicit min() is needed.
        trimmed.append(' '.join(words if keep_all else words[:top]))
    return trimmed

def convert_labels(label_mtx, label_map):
    """Translate the label ids of every instance through ``label_map``.

    Args:
        label_mtx: Sparse label matrix accepted by ``make_dic``.
        label_map: Object indexable by an original label id that yields the
            replacement id (e.g. an array or a dict).

    Returns:
        (label_ids, n_label): ``label_ids[i]`` is the ascending list of
        mapped ids for instance ``i``; ``n_label`` is the column count
        reported by ``make_dic``.
    """
    per_instance, n_label = make_dic(label_mtx)
    label_ids = []
    for original_ids in tqdm(per_instance):
        mapped = sorted(label_map[orig_id] for orig_id in original_ids)
        label_ids.append(mapped)
    return label_ids, n_label

def convert_features(texts, tokenizer, max_seq_length=512, keyword=None):
    """Tokenize every text (optionally paired with its keyword string).

    Args:
        texts: Sequence of raw text strings.
        tokenizer: Callable tokenizer; invoked with the stripped text, or
            with (stripped keyword, stripped text) as a pair when ``keyword``
            is given, plus max-length padding and truncation options.
        max_seq_length: Length every encoding is padded / truncated to.
        keyword: Optional sequence aligned with ``texts`` holding the keyword
            string used as the first segment of each pair.

    Returns:
        list: One tokenizer output per text, in input order.
    """
    paired = keyword is not None
    if paired:
        assert len(keyword) == len(texts), 'keyword text mismatch'

    # Options shared by both the single-segment and the pair call.
    tok_options = dict(padding='max_length', truncation=True,
                       max_length=max_seq_length, return_token_type_ids=True)

    features = []
    for (idx, text) in tqdm(enumerate(texts), desc='tokenize feature'):
        if paired:
            segments = (keyword[idx].strip(), text.strip())
        else:
            segments = (text.strip(),)
        encoding = tokenizer(*segments, **tok_options)

        # Log the first few decoded examples as a sanity check.
        if idx < 5:
            logger.info("*** Example ***")
            logger.info(f"ins id: {idx}")
            decoded = tokenizer.decode(encoding["input_ids"])
            logger.info(f"tokens: {decoded}")
        features.append(encoding)
    return features

def _frequency_rank_map(label_mtx):
    """Return an array mapping each original label id to its frequency rank (0 = most frequent)."""
    freq = np.asarray(np.sum(label_mtx, 0)).ravel()
    # by_rank[rank] is the ORIGINAL id of the label with that rank ...
    by_rank = np.argsort(freq)[::-1]
    # ... so invert the permutation to be able to index by original id.
    orig2new = np.empty(by_rank.size, dtype=by_rank.dtype)
    orig2new[by_rank] = np.arange(by_rank.size)
    return orig2new


def _save_label_list(path, labels, n_label):
    """Save ``[labels, n_label]`` as a two-element object array."""
    # Build the object array explicitly: passing the ragged nested list
    # straight to np.save raises ValueError on NumPy >= 1.24.
    packed = np.empty(2, dtype=object)
    packed[0] = labels
    packed[1] = n_label
    np.save(path, packed)


def process_label(args):
    """
    Re-index labels by training frequency and write the converted label files.

    NOTE: methods like APLC require the label ids to be sorted by frequency,
    i.e. label 0 has the largest number of training instances.
    We change the label ids according to their frequency, but this should not
    affect model performance.

    Reads ``Y.trn.npz`` and ``Y.tst.npz`` from ``args.data_dir`` and writes,
    into the same directory:
      * ``label.train.full.npy`` / ``label.dev.full.npy`` - object arrays
        ``[per-instance label id lists, n_label]``;
      * ``y.train.npz`` / ``y.dev.npz`` - the re-indexed sparse label matrices.

    The id mapping is the inverse of the descending-frequency argsort. The
    argsort itself maps rank -> original id, and using it directly as the
    original -> new map (as was done before) does not put the most frequent
    label at id 0 in general.

    Args:
        args: Namespace providing ``data_dir``.
    """
    train_label_mtx = smat.load_npz(osp.join(args.data_dir, 'Y.trn.npz'))
    test_label_mtx = smat.load_npz(osp.join(args.data_dir, 'Y.tst.npz'))

    # Frequencies come from the training split only; the same map is applied
    # to the test split so both share one id space.
    orig2new = _frequency_rank_map(train_label_mtx)

    train_labels, n_label = convert_labels(train_label_mtx, orig2new)
    test_labels, _ = convert_labels(test_label_mtx, orig2new)

    _save_label_list(osp.join(args.data_dir, 'label.train.full.npy'), train_labels, n_label)
    _save_label_list(osp.join(args.data_dir, 'label.dev.full.npy'), test_labels, n_label)

    smat.save_npz(osp.join(args.data_dir, 'y.train.npz'), make_mtx(train_labels, n_label))
    smat.save_npz(osp.join(args.data_dir, 'y.dev.npz'), make_mtx(test_labels, n_label))

def truncate_texts(texts, max_len):
    """Cut every text to a character budget derived from ``max_len``.

    The budget is ``max(max_len * 10, 4000)`` characters. ``texts`` is
    modified in place and also returned.
    """
    char_limit = max(max_len * 10, 4000)
    logger.info(f'truncate size to {char_limit}')
    for pos, text in enumerate(texts):
        texts[pos] = text[:char_limit]
    return texts

def tokenize_data(args, data_type):
    """Tokenize ``{data_type}_raw_texts.txt`` and cache the features with torch.save.

    When ``args.keyword`` is set, each text is paired with its keyword line
    from ``x.{data_type}.key.{feature_name}.txt`` and the cache name carries
    the feature name and ``top_keyword``. An existing cache file is left
    untouched unless ``args.overwrite`` is set.

    Args:
        args: Parsed command-line namespace.
        data_type: Split name used in the input and cache file names.
    """
    cache_name = f'{data_type}.{args.model_name}.{args.max_seq_length}'
    if args.keyword:
        cache_name = f'{cache_name}.{args.feature_name}_{args.top_keyword}'
    cached_data_file = os.path.join(args.data_dir, cache_name)

    if not args.overwrite and osp.exists(cached_data_file):
        logger.info(f"{data_type} exists, skip")
        return

    # Only the tokenizer class of the triple is needed here.
    tokenizer_class = MODEL_CLASSES[args.model_type][2]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)

    # Load the raw texts and pre-truncate them by character count.
    raw_texts = load_data(osp.join(args.data_dir, f'{data_type}_raw_texts.txt'))
    texts = truncate_texts(raw_texts, args.max_seq_length)

    keyword = None
    if args.keyword:
        keyword_file = osp.join(args.data_dir, f'x.{data_type}.key.{args.feature_name}.txt')
        keyword = load_keyword(keyword_file, top=args.top_keyword)

    features = convert_features(texts, tokenizer, max_seq_length=args.max_seq_length, keyword=keyword)

    logger.info(f'save data to {cached_data_file}')
    torch.save(features, cached_data_file)

def tokenize_keyword(args, data_type):
    """Tokenize only the keyword lines of a split and cache the features.

    Reads ``x.{data_type}.key.{feature_name}.txt``, keeps the first
    ``args.top_keyword`` keywords per line and feeds the resulting strings to
    ``convert_features`` as if they were the document texts. The output file
    is always rewritten; no existing-cache check is performed here.

    Args:
        args: Parsed command-line namespace.
        data_type: Split name used in the input and cache file names.
    """
    cache_name = f'{data_type}.{args.model_name}.{args.max_seq_length}.keyword_only.{args.feature_name}'
    cached_data_file = os.path.join(args.data_dir, cache_name)

    # Only the tokenizer class of the triple is needed here.
    tokenizer_class = MODEL_CLASSES[args.model_type][2]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)

    keyword_file = osp.join(args.data_dir, f'x.{data_type}.key.{args.feature_name}.txt')
    keyword_texts = load_keyword(keyword_file, top=args.top_keyword)

    # The keyword strings take the place of the texts; no second segment.
    features = convert_features(keyword_texts, tokenizer, max_seq_length=args.max_seq_length, keyword=None)

    logger.info(f'save data to {cached_data_file}')
    torch.save(features, cached_data_file)

def tokenize_shuffle_data(args, data_type):
    """Tokenize a word-shuffled copy of each text and cache the features.

    Texts are read from ``{data_type}.csv``; each one is split on single
    spaces, cut to its first ``args.max_seq_length`` words, randomly permuted
    with ``np.random.shuffle`` and re-joined before tokenization. Keywords
    are never used. An existing cache file is kept unless ``args.overwrite``
    is set.

    Args:
        args: Parsed command-line namespace.
        data_type: Split name used in the input and cache file names.
    """

    cached_data_file = os.path.join(args.data_dir,
                                        f'{data_type}.{args.model_name}.{args.max_seq_length}.shuffle')

    if osp.exists(cached_data_file) and not args.overwrite:
        logger.info(f"{data_type} exists, skip")
        return


    # Only tokenizer_class is used below; config/model classes are ignored.
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)
    # path and load data
    data_path = osp.join(args.data_dir, f'{data_type}.csv')
    #keyword_path = osp.join(args.data_dir, f'x.{data_type}.key.{args.feature_name}.txt')
    # NOTE(review): get_csv_examples is neither defined nor imported in the
    # visible part of this file, so this call raises NameError unless it is
    # provided elsewhere - confirm. It is expected to return a pair whose
    # first item is an indexable, mutable sequence of text strings.
    texts, _ = get_csv_examples(data_path)
    #keyword = load_keyword(keyword_path, top=args.top_keyword)
    keyword = None

    # shuffle: permute the word order of each (word-truncated) text in place.
    # No seed is set here, so the permutation differs between runs unless the
    # caller seeds NumPy's global RNG.
    for i, text in enumerate(texts):
        text = np.array(text.split(' ')[:args.max_seq_length])
        l = len(text)
        ids = np.arange(l)
        np.random.shuffle(ids)
        # Log the first few before/after pairs as a sanity check.
        if i < 5:
            logger.info(f"text: {' '.join(text)}")
            logger.info(f"shuffled: {' '.join(text[ids])}")
        texts[i] = ' '.join(text[ids])

    features = convert_features(texts, tokenizer, max_seq_length=args.max_seq_length, keyword=keyword)
    # save the tokenized features
    logger.info(f'save data to {cached_data_file}')
    torch.save(features, cached_data_file)

def main(args):
    """Run the requested preprocessing steps.

    ``args.model_type`` is lower-cased in place. With ``args.do_label`` the
    label files are converted first. With ``args.do_feature`` the 'train' and
    'dev' splits are tokenized, choosing the mode in this order of
    precedence: ``keyword_only``, then ``shuffle``, then the regular
    (optionally keyword-paired) document tokenization.

    Args:
        args: Parsed command-line namespace.
    """
    args.model_type = args.model_type.lower()
    logger.info(f"model type: {args.model_type}, model name: {args.model_name}, tokenizer_name: {args.tokenizer_name}")

    if args.do_label:
        process_label(args)

    if not args.do_feature:
        # Plain return instead of exit(0): exit() is the interactive helper
        # installed by the site module (absent under `python -S`) and would
        # raise SystemExit into any caller importing this function. When run
        # as a script the process still ends with status 0.
        return

    for mode in ('train', 'dev'):
        if args.keyword_only:
            logger.info(f'keyword only with top {args.top_keyword}')
            tokenize_keyword(args, data_type=mode)
            continue

        logger.info(f'process doc, keyword: {args.keyword}, top: {args.top_keyword}')
        if args.shuffle:
            tokenize_shuffle_data(args, data_type=mode)
        else:
            tokenize_data(args, data_type=mode)

# Command-line entry point: parse the options, default the tokenizer name to
# the model name, then hand over to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # NOTE(review): the help text mentions .tsv files, but the functions in
    # this file read Y.trn.npz / Y.tst.npz, {split}_raw_texts.txt,
    # x.{split}.key.{feature_name}.txt and {split}.csv from this directory.
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    # model_name is used in cache file names and as the tokenizer fallback.
    parser.add_argument("--model_name", default=None, type=str, required=True)
    parser.add_argument("--tokenizer_name", default=None, type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    # NOTE(review): cache_dir is parsed but not read by any function visible
    # in this file.
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    # feature_name selects the keyword file x.{split}.key.{feature_name}.txt
    parser.add_argument('--feature_name', default='tfidf', type=str)
    # do_feature runs tokenization; do_label runs the label re-indexing.
    parser.add_argument("--do_feature", action='store_true')
    parser.add_argument("--do_label", action='store_true')
    #parser.add_argument("--do_eval", action='store_true')

    # keyword
    # keyword: pair each text with its keyword line; top_keyword: number of
    # leading keywords kept per line (None or negative keeps all).
    parser.add_argument('--keyword', action='store_true')
    parser.add_argument('--top_keyword', type=int, default=None)
    # overwrite: regenerate cached features even if the cache file exists.
    parser.add_argument('--overwrite', action='store_true')
    # keyword_only takes precedence over shuffle in main().
    parser.add_argument('--keyword_only', action='store_true')
    parser.add_argument('--shuffle', action='store_true')

    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    args = parser.parse_args()
    # Fall back to the model name when no separate tokenizer is given.
    if args.tokenizer_name is None:
        args.tokenizer_name = args.model_name
    print(args)
    main(args)





























