#!/usr/bin/env python
# encoding: utf-8

import argparse
from collections import Counter
import itertools
import json
import os, os.path as osp
from logzero import logger
import numpy as np
import pickle
import time
from tqdm import tqdm
import scipy as sp
import scipy.sparse as smat
from sklearn.preprocessing import normalize
import pandas as pd
import torch
import multiprocessing as mp
from multiprocessing import Pool
import subprocess
import shlex
from multiprocessing.pool import ThreadPool

from transformers import (
    WEIGHTS_NAME,
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    XLMConfig,
    XLMForSequenceClassification,
    XLMTokenizer,
    XLNetConfig,
    XLNetForSequenceClassification,
    XLNetTokenizer,
    DistilBertConfig,
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    AlbertConfig,
    AlbertForSequenceClassification,
    AlbertTokenizer,
)

# ALL_MODELS = sum(
#     (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig,)),
#     (),
# )

# Maps the (lower-cased) --model_type CLI value to its
# (config class, sequence-classification model class, tokenizer class) triple.
# In this script only the tokenizer class (index 2) is actually used.
MODEL_CLASSES = {
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
    "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
    "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
    "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer,),
    "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
}

def get_labels(file_name):
    """Read the label vocabulary, one label per line.

    Each line is stripped of surrounding whitespace and the file order is
    preserved, so a label's position in the returned list is the integer id
    that ``convert_labels`` assigns to it.

    Args:
        file_name: path to the label list file.

    Returns:
        list[str]: the stripped labels, in file order (blank lines become '').
    """
    # ``with`` guarantees the handle is closed even if reading raises.
    with open(file_name, 'r') as fp:
        return [line.strip() for line in fp]

def load_data(text_path):
    """Return all lines of ``text_path`` as a list, trailing newlines kept."""
    with open(text_path) as handle:
        lines = list(handle)
    return lines

def load_keyword(keyword_path, top=None):
    """Load per-document keyword lines, optionally keeping only the leading ``top``.

    Args:
        keyword_path: path to a text file with one whitespace-separated keyword
            list per line. If None, nothing is loaded.
        top: keep at most this many leading keywords per line; None or a
            negative value keeps all of them.

    Returns:
        list[str] | None: one space-joined keyword string per input line, or
        None when ``keyword_path`` is None.
    """
    # Callers decide whether keywords are wanted; do not consult global CLI
    # state here. A global ``args`` only exists when the file runs as a script,
    # and gating on it made --keyword_only silently yield None.
    if keyword_path is None:
        return None
    limit = None if top is None or top < 0 else top
    return [' '.join(line.split()[:limit]) for line in load_data(keyword_path)]

def get_csv_examples(input_path, size=-1):
    """Read a CSV split and return its texts and label lists.

    Args:
        input_path: path to a CSV whose columns are (id, text, comma-separated
            label list).
        size: number of rows to randomly sample; -1 (any negative value, or
            None) reads every row.

    Returns:
        tuple[list[str], list[list[str]]]: ``(texts, labels)`` as produced by
        ``create_examples``.
    """
    filename = osp.basename(input_path)
    data_dir = osp.dirname(input_path)
    logger.info(f"Data {filename} in {data_dir}")
    # na_filter=False keeps empty cells as '' instead of NaN. create_examples
    # calls str methods on the text/label cells, so this must hold for the
    # sampled path too. It also speeds up reading large files.
    data_df = pd.read_csv(input_path, na_filter=False)
    if size is None or size < 0:
        return create_examples(data_df)
    return create_examples(data_df.sample(size))

def create_examples(df):
    """Split a (id, text, label-list) frame into parallel text and label lists.

    Texts are whitespace-stripped; the label cell is split on ',' into a list.
    """
    # Positional columns: 0 = id (ignored), 1 = text, 2 = comma-separated labels.
    pairs = [(row[1].strip(), row[2].split(',')) for row in df.values]
    texts = [text for text, _ in pairs]
    labels = [label_list for _, label_list in pairs]
    return texts, labels

def convert_labels(labels, label_list):
    """Translate each example's label strings into a sorted list of integer ids.

    Ids are positions in ``label_list``. A label missing from ``label_list``
    raises KeyError.

    Returns:
        tuple[list[list[int]], int]: the per-example id lists and the number of
        labels in ``label_list``.
    """
    index_of = dict(zip(label_list, range(len(label_list))))
    converted = [sorted(index_of[name] for name in example) for example in tqdm(labels)]
    return converted, len(label_list)

def convert_features(texts, tokenizer, max_seq_length=512, keyword=None):
    """Tokenize documents, optionally paired with keywords, into fixed-length encodings.

    Args:
        texts: sequence of document strings.
        tokenizer: a callable tokenizer (in this script, one built by a
            transformers ``*Tokenizer.from_pretrained``).
        max_seq_length: every encoding is padded/truncated to this length.
        keyword: optional sequence of keyword strings aligned with ``texts``.
            When given, each example is encoded as the pair (keyword, text).

    Returns:
        list: one tokenizer output per text (input_ids, token_type_ids,
        attention_mask).

    Raises:
        ValueError: if ``keyword`` is given and its length differs from ``texts``.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if keyword is not None and len(keyword) != len(texts):
        raise ValueError('keyword text mismatch')
    features = []
    for idx, text in tqdm(enumerate(texts), desc='tokenize feature'):
        if keyword is None:
            # Single-sequence input.
            encoding = tokenizer(text.strip(), padding='max_length', truncation=True,
                                 max_length=max_seq_length, return_token_type_ids=True)
        else:
            # Sentence-pair input: keywords as the first segment, document second.
            encoding = tokenizer(keyword[idx].strip(), text.strip(), padding='max_length',
                                 truncation=True, max_length=max_seq_length,
                                 return_token_type_ids=True)
        if idx < 5:
            # Log a few decoded examples as a sanity check.
            logger.info("*** Example ***")
            logger.info(f"ins id: {idx}")
            tokens = tokenizer.decode(encoding["input_ids"])
            logger.info(f"tokens: {tokens}")
        features.append(encoding)
    return features

def process_label(args, data_type='train'):
    """Convert the labels of ``{data_type}.csv`` to ids and save them to disk.

    Writes ``label.{data_type}.full.npy`` in ``args.data_dir``. The file holds a
    2-element object array ``[label_ids, n_label]``: ``label_ids`` is a ragged
    list of sorted id lists and ``n_label`` is the size of ``labels.txt``.
    Reading it back requires ``np.load(..., allow_pickle=True)``.
    """
    data_path = osp.join(args.data_dir, f'{data_type}.csv')
    _, labels = get_csv_examples(data_path)
    label_list = get_labels(osp.join(args.data_dir, 'labels.txt'))
    label_ids, n_label = convert_labels(labels, label_list)
    # ``[label_ids, n_label]`` is ragged. NumPy >= 1.24 refuses to build such
    # an array implicitly (ValueError), so construct the 2-element object array
    # explicitly. This is the same layout older NumPy produced on its own.
    payload = np.empty(2, dtype=object)
    payload[0] = label_ids
    payload[1] = n_label
    np.save(osp.join(args.data_dir, f'label.{data_type}.full.npy'), payload)

def tokenize_data(args, data_type):
    """Tokenize ``{data_type}.csv`` (optionally with keywords) and cache the features.

    The cache file name encodes the model name and max length, plus the feature
    name and keyword count when ``args.keyword`` is set. An existing cache file
    is left untouched unless ``args.overwrite`` is true.
    """
    cache_name = f'{data_type}.{args.model_name}.{args.max_seq_length}'
    if args.keyword:
        cache_name = f'{cache_name}.{args.feature_name}_{args.top_keyword}'
    cached_data_file = os.path.join(args.data_dir, cache_name)

    if not args.overwrite and osp.exists(cached_data_file):
        logger.info(f"{data_type} exists, skip")
        return

    tokenizer_class = MODEL_CLASSES[args.model_type][2]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)
    texts, _ = get_csv_examples(osp.join(args.data_dir, f'{data_type}.csv'))

    keyword = None
    if args.keyword:
        keyword_file = osp.join(args.data_dir, f'x.{data_type}.key.{args.feature_name}.txt')
        keyword = load_keyword(keyword_file, top=args.top_keyword)

    features = convert_features(texts, tokenizer, max_seq_length=args.max_seq_length, keyword=keyword)
    logger.info(f'save data to {cached_data_file}')
    torch.save(features, cached_data_file)

def tokenize_keyword(args, data_type):
    """Tokenize only the keyword file for ``data_type`` and cache the features.

    Unlike ``tokenize_data`` this has no "already cached" check: the cache file
    is always (re)written.
    """
    # NOTE(review): the cache name does not include args.top_keyword, so runs
    # with different --top_keyword values write to the same file.
    cache_name = f'{data_type}.{args.model_name}.{args.max_seq_length}.keyword_only.{args.feature_name}'
    cached_data_file = os.path.join(args.data_dir, cache_name)

    tokenizer_class = MODEL_CLASSES[args.model_type][2]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)
    keyword_file = osp.join(args.data_dir, f'x.{data_type}.key.{args.feature_name}.txt')
    keywords = load_keyword(keyword_file, top=args.top_keyword)

    # Keywords stand in for the document text; there is no second segment.
    features = convert_features(keywords, tokenizer, max_seq_length=args.max_seq_length, keyword=None)
    logger.info(f'save data to {cached_data_file}')
    torch.save(features, cached_data_file)

def _shuffle_words(text, max_words):
    """Split ``text`` on single spaces, keep the first ``max_words`` pieces, and permute them.

    Returns ``(words, shuffled_words)`` as NumPy arrays.
    """
    words = np.array(text.split(' ')[:max_words])
    order = np.arange(len(words))
    np.random.shuffle(order)
    return words, words[order]


def tokenize_shuffle_data(args, data_type):
    """Tokenize ``{data_type}.csv`` with each document's word order randomly shuffled.

    Documents are cut to ``args.max_seq_length`` space-separated words before
    shuffling. Features are cached to ``{data_type}.{model}.{len}.shuffle``.
    An existing cache is kept unless ``args.overwrite`` is true.
    """
    # NOTE(review): shuffling uses the unseeded global NumPy RNG, so each
    # regeneration with --overwrite yields a different shuffle.
    cache_name = f'{data_type}.{args.model_name}.{args.max_seq_length}.shuffle'
    cached_data_file = os.path.join(args.data_dir, cache_name)

    if not args.overwrite and osp.exists(cached_data_file):
        logger.info(f"{data_type} exists, skip")
        return

    tokenizer_class = MODEL_CLASSES[args.model_type][2]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)
    texts, _ = get_csv_examples(osp.join(args.data_dir, f'{data_type}.csv'))

    for i, text in enumerate(texts):
        words, shuffled = _shuffle_words(text, args.max_seq_length)
        if i < 5:
            logger.info(f"text: {' '.join(words)}")
            logger.info(f"shuffled: {' '.join(shuffled)}")
        texts[i] = ' '.join(shuffled)

    features = convert_features(texts, tokenizer, max_seq_length=args.max_seq_length, keyword=None)
    logger.info(f'save data to {cached_data_file}')
    torch.save(features, cached_data_file)

def main(args):
    """Optionally convert labels to ids, then tokenize the train and dev splits.

    The tokenization mode is chosen per split: keyword-only, shuffled words, or
    plain documents (optionally with keywords).
    """
    # NOTE(review): --do_train / --do_eval are parsed but not consulted here;
    # both splits are always processed.
    args.model_type = args.model_type.lower()
    logger.info(f"model type: {args.model_type}, model name: {args.model_name}, tokenizer_name: {args.tokenizer_name}")

    if args.do_label:
        for split in ('train', 'dev'):
            process_label(args, data_type=split)

    for mode in ('train', 'dev'):
        if args.keyword_only:
            logger.info(f'keyword only with top {args.top_keyword}')
            tokenize_keyword(args, data_type=mode)
            continue
        logger.info(f'process doc, keyword: {args.keyword}, top: {args.top_keyword}')
        run_split = tokenize_shuffle_data if args.shuffle else tokenize_data
        run_split(args, data_type=mode)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain train.csv, dev.csv and labels.txt "
                             "(plus x.{split}.key.{feature_name}.txt when using keywords).")
    # type=str.lower makes the choices check case-insensitive, matching main()'s
    # lower-casing, and rejects unknown model types up front instead of raising
    # KeyError deep inside tokenization.
    parser.add_argument("--model_type", default=None, type=str.lower, required=True,
                        choices=list(MODEL_CLASSES),
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name", default=None, type=str, required=True,
                        help="Model identifier; used in cached feature file names and as the "
                             "default tokenizer_name.")
    parser.add_argument("--tokenizer_name", default=None, type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    # NOTE(review): --cache_dir is parsed but never read in this script.
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument('--feature_name', default='tfidf', type=str,
                        help="Keyword feature name used in keyword and cache file names.")
    parser.add_argument("--do_train", action='store_true', help="Currently unused.")
    parser.add_argument("--do_label", action='store_true',
                        help="Convert label strings to ids and save label.{split}.full.npy.")
    parser.add_argument("--do_eval", action='store_true', help="Currently unused.")

    # keyword
    parser.add_argument('--keyword', action='store_true',
                        help="Encode each document as a (keywords, text) pair.")
    parser.add_argument('--top_keyword', type=int, default=None,
                        help="Keep at most this many keywords per document; unset or negative keeps all.")
    parser.add_argument('--overwrite', action='store_true',
                        help="Re-tokenize even if a cached feature file already exists.")
    parser.add_argument('--keyword_only', action='store_true',
                        help="Tokenize only the keywords, without the document text.")
    parser.add_argument('--shuffle', action='store_true',
                        help="Randomly shuffle the words of each document before tokenizing.")

    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    args = parser.parse_args()
    if args.tokenizer_name is None:
        args.tokenizer_name = args.model_name
    print(args)
    main(args)





























