#!/usr/bin/env python
# encoding: utf-8

import argparse
from collections import Counter
import itertools
import json
import os, os.path as osp
from logzero import logger
import numpy as np
import pickle
import time
from tqdm import tqdm
import scipy as sp
import scipy.sparse as smat
from sklearn.preprocessing import normalize
import pandas as pd
import torch
import multiprocessing as mp
from multiprocessing import Pool
import subprocess
import shlex
from multiprocessing.pool import ThreadPool

from transformers import (
    WEIGHTS_NAME,
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    XLMConfig,
    XLMForSequenceClassification,
    XLMTokenizer,
    XLNetConfig,
    XLNetForSequenceClassification,
    XLNetTokenizer,
    DistilBertConfig,
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    AlbertConfig,
    AlbertForSequenceClassification,
    AlbertTokenizer,
)
from sklearn.model_selection import KFold

# ALL_MODELS = sum(
#     (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig,)),
#     (),
# )

# Registry mapping a lowercase model-type key (the --model_type CLI value) to the
# transformers (config class, sequence-classification model class, tokenizer class).
MODEL_CLASSES = dict(
    bert=(BertConfig, BertForSequenceClassification, BertTokenizer),
    xlnet=(XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
    xlm=(XLMConfig, XLMForSequenceClassification, XLMTokenizer),
    roberta=(RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
    distilbert=(DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
    albert=(AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
)

def get_labels(file_name):
    """Read the label vocabulary, one label per line.

    Args:
        file_name: path to a text file (e.g. ``labels.txt``) with one label name per line.

    Returns:
        list[str]: the labels in file order, each stripped of surrounding whitespace.
        A label's position in this list becomes its integer id in ``convert_labels``.
    """
    # ``with`` guarantees the handle is closed even if reading raises.
    with open(file_name, 'r') as fp:
        return [line.strip() for line in fp]

def load_data(text_path):
    """Return all lines of ``text_path`` as a list, newline characters preserved."""
    with open(text_path) as handle:
        lines = list(handle)
    return lines

def load_keyword(keyword_path, top=None):
    """Load per-example keyword lines, optionally keeping only the first ``top`` keywords.

    Args:
        keyword_path: path to a text file with one whitespace-separated keyword list per
            line (line ``i`` belongs to example ``i``), or ``None`` for "no keywords".
        top: keep at most this many keywords per line; ``None`` or a negative value
            keeps all of them.

    Returns:
        list[str] with one space-joined keyword string per line (empty string for a blank
        line), or ``None`` when ``keyword_path`` is ``None``.
    """
    # Deliberately does not consult the script-level ``args``. That global only exists
    # when this file runs as ``__main__`` (NameError on import), and callers already
    # decide whether keywords are wanted before calling.
    if keyword_path is None:
        return None
    keep_all = top is None or top < 0
    keywords = []
    with open(keyword_path) as fp:
        for line in fp:
            tokens = line.split()
            # Slicing past the end is safe, so no min(top, len(tokens)) is needed.
            keywords.append(' '.join(tokens if keep_all else tokens[:top]))
    return keywords

def get_csv_examples(input_path, size=-1, random_state=None):
    """Read a CSV of examples and return its texts and label-name lists.

    Args:
        input_path: path to a CSV whose columns 1 and 2 are the text and the
            comma-separated labels (see ``create_examples``).
        size: ``-1`` to use every row; otherwise the number of rows to sample at random.
        random_state: seed forwarded to ``DataFrame.sample`` so a sampled subset can be
            reproduced; the default ``None`` keeps the previous unseeded behaviour.

    Returns:
        ``(texts, labels)`` as produced by ``create_examples``.
    """
    logger.info(f"Data {osp.basename(input_path)} in {osp.dirname(input_path)}")
    # na_filter=False keeps empty cells as '' instead of NaN floats. create_examples calls
    # str methods on the text/label columns, so this must apply to the sampled path too.
    # It also speeds up reading large files.
    data_df = pd.read_csv(input_path, na_filter=False)
    if size != -1:
        data_df = data_df.sample(size, random_state=random_state)
    return create_examples(data_df)

def create_examples(df):
    """Split an (id, text, labels) DataFrame into parallel text and label lists.

    Columns are read by position: 0 = id (ignored), 1 = text, 2 = comma-separated
    label names.

    Returns:
        ``(texts, labels)``: ``texts[k]`` is row k's stripped text; ``labels[k]`` is the
        list of label names for row k, each stripped, with empty entries (from trailing
        commas or a blank cell) dropped.
    """
    texts, labels = [], []
    for row in df.values:
        texts.append(row[1].strip())
        # Strip each name so "a, b" matches the stripped vocabulary from get_labels.
        labels.append([name.strip() for name in row[2].split(',') if name.strip()])
    return texts, labels

def convert_labels(labels, label_list):
    """Translate per-example label names into sorted integer ids.

    Args:
        labels: one list of label names per example.
        label_list: the label vocabulary; a name's position in it is its id.

    Returns:
        ``(label_ids, n_label)``: ``label_ids[k]`` is the ascending list of ids for
        example k, and ``n_label`` is ``len(label_list)``.

    Raises:
        KeyError: if an example uses a name that is not in ``label_list``.
    """
    index_of = dict(zip(label_list, range(len(label_list))))
    converted = [sorted(index_of[name] for name in names) for names in tqdm(labels)]
    return converted, len(label_list)

def convert_features(texts, tokenizer, max_seq_length=512, keyword=None):
    """Tokenize ``texts``, optionally paired with per-example keywords, into model inputs.

    Args:
        texts: sequence of raw example strings.
        tokenizer: a callable transformers tokenizer. Every example is padded and
            truncated to exactly ``max_seq_length`` tokens.
        max_seq_length: fixed sequence length for every encoding.
        keyword: optional sequence of keyword strings, one per text. When given, each
            example is encoded as the pair ``(keyword, text)``, keyword first.

    Returns:
        list of tokenizer outputs (input_ids, token_type_ids, attention_mask), one per
        text, in input order.

    Raises:
        ValueError: if ``keyword`` is given and its length differs from ``texts``.
    """
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if keyword is not None and len(keyword) != len(texts):
        raise ValueError('keyword text mismatch')
    features = []
    # Passing the sequence (not an enumerate object) lets tqdm show total/ETA.
    for idx, text in enumerate(tqdm(texts, desc='tokenize feature')):
        if keyword is None:
            segments = (text.strip(),)
        else:
            segments = (keyword[idx].strip(), text.strip())
        encoding = tokenizer(*segments, padding='max_length', truncation=True,
                             max_length=max_seq_length, return_token_type_ids=True)
        if idx < 5:
            # Log a few decoded examples for sanity checking.
            logger.info("*** Example ***")
            logger.info(f"ins id: {idx}")
            tokens = tokenizer.decode(encoding["input_ids"])
            logger.info(f"tokens: {tokens}")
        features.append(encoding)
    return features

def process_label(args, data_type='train'):
    """Convert the label column of ``{data_dir}/{data_type}.csv`` to id lists and save them.

    Label ids come from ``{data_dir}/labels.txt``. The result is written to
    ``{data_dir}/label.{data_type}.full.npy`` as a 2-element object array
    ``[label_ids, n_label]``; read it with ``np.load(..., allow_pickle=True)``.
    """
    data_path = osp.join(args.data_dir, f'{data_type}.csv')
    _, labels = get_csv_examples(data_path)
    label_list = get_labels(osp.join(args.data_dir, 'labels.txt'))
    label_ids, n_label = convert_labels(labels, label_list)
    # Build the pair as an explicit 1-D object array. Passing the ragged list
    # [label_ids, n_label] straight to np.save raises ValueError on NumPy >= 1.24.
    payload = np.empty(2, dtype=object)
    payload[0] = label_ids
    payload[1] = n_label
    np.save(osp.join(args.data_dir, f'label.{data_type}.full.npy'), payload, allow_pickle=True)

def tokenize_data(args, data_type):
    """Tokenize ``{data_dir}/{data_type}.csv`` and cache the features with ``torch.save``.

    The cache file is ``{data_dir}/{data_type}.{model_name}.{max_seq_length}``, with a
    ``.{feature_name}`` suffix when ``args.keyword`` is set. In that case each text is
    paired with its keywords from ``x.{data_type}.key.{feature_name}.txt``. An existing
    cache file is left untouched unless ``args.overwrite`` is set.
    """
    cached_data_file = os.path.join(args.data_dir,
                                    f'{data_type}.{args.model_name}.{args.max_seq_length}')
    if args.keyword:
        cached_data_file += f'.{args.feature_name}'
    # NOTE(review): top_keyword and do_lower_case are not part of the cache name, so
    # changing them requires --overwrite to take effect.

    if osp.exists(cached_data_file) and not args.overwrite:
        logger.info(f"{data_type} exists, skip")
        return

    # Only the tokenizer class is needed. Lowercase the key, as main() does, so
    # e.g. "BERT" resolves.
    tokenizer_class = MODEL_CLASSES[args.model_type.lower()][2]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)
    texts, _ = get_csv_examples(osp.join(args.data_dir, f'{data_type}.csv'))

    keyword = None
    if args.keyword:
        keyword_path = osp.join(args.data_dir, f'x.{data_type}.key.{args.feature_name}.txt')
        keyword = load_keyword(keyword_path, top=args.top_keyword)
    features = convert_features(texts, tokenizer, max_seq_length=args.max_seq_length, keyword=keyword)
    logger.info(f'save data to {cached_data_file}')
    torch.save(features, cached_data_file)

def tokenize_keyword(args, data_type):
    """Tokenize only the keyword lines (no example text) and cache them with ``torch.save``.

    Reads ``{data_dir}/x.{data_type}.key.{feature_name}.txt`` and writes
    ``{data_dir}/{data_type}.{feature_name}.{max_seq_length}``. Unlike ``tokenize_data``,
    this always re-tokenizes; it has no cache/overwrite check.

    Raises:
        ValueError: if no keyword list could be loaded.
    """
    cached_data_file = os.path.join(args.data_dir,
                                    f'{data_type}.{args.feature_name}.{args.max_seq_length}')
    keyword_path = osp.join(args.data_dir, f'x.{data_type}.key.{args.feature_name}.txt')
    keyword = load_keyword(keyword_path, top=args.top_keyword)
    # Fail fast with a clear message, before loading the tokenizer, instead of a
    # TypeError from deep inside convert_features.
    if keyword is None:
        raise ValueError(f'no keywords loaded from {keyword_path} (is --keyword set?)')

    tokenizer_class = MODEL_CLASSES[args.model_type.lower()][2]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)
    # Reuse convert_features with the keyword strings in place of the texts.
    features = convert_features(keyword, tokenizer, max_seq_length=args.max_seq_length, keyword=None)
    logger.info(f'save data to {cached_data_file}')
    torch.save(features, cached_data_file)



def _load_cached_features(feature_file):
    """Load a feature list this pipeline previously wrote with ``torch.save``.

    The cached objects are tokenizer outputs, not plain tensors, so they need full
    unpickling. PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
    rejects them. Full unpickling can run arbitrary code, so only point this at files
    produced by this script.
    """
    try:
        return torch.load(feature_file, weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the ``weights_only`` keyword.
        return torch.load(feature_file)


def _save_label_pair(path, label_ids, n_label):
    """Save ``[label_ids, n_label]`` as a 2-element object array (the layout main() loads).

    Built explicitly because np.save on a ragged list raises on NumPy >= 1.24.
    """
    payload = np.empty(2, dtype=object)
    payload[0] = label_ids
    payload[1] = n_label
    np.save(path, payload, allow_pickle=True)


def main(args):
    """Split cached training features and labels into ``args.fold`` K-fold train/dev pairs.

    Reads ``{data_dir}/train.{model_name}.{max_seq_length}[.{feature_name}]`` (features)
    and ``{data_dir}/label.train.full.npy`` (labels). For each fold ``i`` it writes to
    ``output_dir``:

    * ``train|dev.{model_name}.{max_seq_length}[.{feature_name}].split{i}-{fold}`` via torch.save
    * ``label.train|dev.full.split{i}-{fold}.npy`` holding ``[label_ids, n_label]``

    The ``.{feature_name}`` part is present only when ``args.keyword`` is set. Folds are
    shuffled with the fixed seed ``random_state=2`` so they are reproducible.

    Raises:
        ValueError: if the feature and label files hold different numbers of examples.
    """
    args.model_type = args.model_type.lower()
    logger.info(f"model type: {args.model_type}, model name: {args.model_name}, tokenizer_name: {args.tokenizer_name}")

    # Keyword-augmented features carry an extra ".{feature_name}" suffix.
    suffix = f'.{args.feature_name}' if args.keyword else ''
    feature_file = os.path.join(args.data_dir, f'train.{args.model_name}.{args.max_seq_length}{suffix}')
    features = _load_cached_features(feature_file)

    label_path = osp.join(args.data_dir, 'label.train.full.npy')
    labels, n_label = np.load(label_path, allow_pickle=True)
    # Misaligned inputs would otherwise silently pair features with the wrong labels.
    if len(labels) != len(features):
        raise ValueError(f'{len(features)} features in {feature_file} but '
                         f'{len(labels)} label rows in {label_path}')

    os.makedirs(args.output_dir, exist_ok=True)
    kf = KFold(n_splits=args.fold, shuffle=True, random_state=2)
    for i, (train_idx, dev_idx) in enumerate(kf.split(np.arange(len(features)))):
        if i == 0:
            logger.info(f'split train size: {len(train_idx)}, dev size: {len(dev_idx)}')
        split_tag = f'split{i}-{args.fold}'
        f_train_i_path = osp.join(args.output_dir, f'train.{args.model_name}.{args.max_seq_length}{suffix}.{split_tag}')
        f_dev_i_path = osp.join(args.output_dir, f'dev.{args.model_name}.{args.max_seq_length}{suffix}.{split_tag}')
        l_train_i_path = osp.join(args.output_dir, f'label.train.full.{split_tag}.npy')
        l_dev_i_path = osp.join(args.output_dir, f'label.dev.full.{split_tag}.npy')
        logger.info(f"save split {i}-{args.fold}")
        torch.save([features[j] for j in train_idx], f_train_i_path)
        torch.save([features[j] for j in dev_idx], f_dev_i_path)
        # Plain list indexing: np.array(labels) on ragged multi-label lists raises on
        # NumPy >= 1.24.
        _save_label_pair(l_train_i_path, [labels[j] for j in train_idx], n_label)
        _save_label_pair(l_dev_i_path, [labels[j] for j in dev_idx], n_label)


if __name__ == "__main__":
    # CLI entry point: parse options, default the tokenizer to the model name, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .csv data, labels.txt and the "
                             "cached feature/label files built from them.")
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name", default=None, type=str, required=True)
    parser.add_argument("--output_dir", required=True, type=str,
                        help="Directory where the per-fold train/dev splits are written.")

    parser.add_argument("--tokenizer_name", default=None, type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument('--feature_name', default='tfidf', type=str)

    # keyword
    parser.add_argument('--keyword', action='store_true')
    parser.add_argument('--top_keyword', type=int, default=None)
    parser.add_argument('--overwrite', action='store_true')
    parser.add_argument('--keyword_only', action='store_true')
    parser.add_argument('--shuffle', action='store_true')

    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--fold", type=int, default=5,
                        help="Number of folds for the K-fold train/dev split.")

    args = parser.parse_args()
    if args.tokenizer_name is None:
        args.tokenizer_name = args.model_name
    print(args)
    main(args)





























