import re
import json
from collections import Counter

import numpy as np
from scipy import sparse

from common import file_handling as fh
from common.vocab import extract_vocab_params, convert_to_ngrams


def load_data(partition_file):
    """Load the train/dev/test documents described by a partition file.

    The partition file is read with ``fh.read_json`` and must provide the
    keys ``train_file``, ``dev_file``, ``test_file`` and ``train_indices``.
    ``dev_indices`` and ``test_indices`` are optional; a missing key or a
    null value means that split is empty.

    Each data file is read as one JSON object per line.  Every loaded
    document gets an ``'_i'`` entry recording the split prefix and its line
    number in the file it came from (``'tr_<n>'`` for anything read from the
    train file; ``'dev_<n>'`` / ``'test_<n>'`` for the separate files).

    Args:
        partition_file: path to the JSON partition description.

    Returns:
        A ``(train_docs, dev_docs, test_docs)`` tuple of lists of dicts.
    """
    partition = fh.read_json(partition_file)

    train_path = partition['train_file']
    dev_path = partition['dev_file']
    test_path = partition['test_file']

    # Sets give O(1) membership tests inside the per-line loop below.
    # `or []` folds "key missing", "value is null" and "empty list" together.
    train_indices = set(partition['train_indices'])
    dev_indices = set(partition.get('dev_indices') or [])
    test_indices = set(partition.get('test_indices') or [])

    dev_in_train_file = dev_path == train_path
    test_in_train_file = test_path == train_path

    train_docs = []
    dev_docs = []
    test_docs = []
    with open(train_path) as f:
        for i, raw in enumerate(f):
            doc = json.loads(raw)
            doc['_i'] = 'tr_' + str(i)
            if i in train_indices:
                train_docs.append(doc)
            # Splits that live in the train file are picked up in this pass.
            if dev_in_train_file and i in dev_indices:
                dev_docs.append(doc)
            if test_in_train_file and i in test_indices:
                test_docs.append(doc)

    # Splits stored in their own file are loaded separately.  When no indices
    # were requested the result is empty anyway, so the file is not opened.
    if not dev_in_train_file and dev_indices:
        dev_docs = load_subset(dev_path, dev_indices, 'dev')

    if not test_in_train_file and test_indices:
        test_docs = load_subset(test_path, test_indices, 'test')

    return train_docs, dev_docs, test_docs


def load_side_data(partition_file, side_data_file, side_data_field='data', side_data_col_names_field=None):
    """Slice a side-data matrix into train/dev/test rows using a partition.

    Args:
        partition_file: path to the JSON partition description (read with
            ``fh.read_json``); must contain ``train_indices`` and may contain
            ``dev_indices`` / ``test_indices``.
        side_data_file: file passed to ``np.load``; the loaded object is
            indexed by field name.
        side_data_field: name of the entry holding the 2-D side-data matrix.
        side_data_col_names_field: name of the entry holding the column
            names, or ``None`` to generate ``'_s_0'``, ``'_s_1'``, ...

    Returns:
        ``(side_train, side_dev, side_test, col_names)``.  ``side_dev`` and
        ``side_test`` are ``None`` when the corresponding indices are absent
        or null in the partition.
    """
    partition = fh.read_json(partition_file)

    train_indices = partition['train_indices']
    # `.get` so a partition without dev/test keys is handled the same way
    # load_data handles it (as "no such split") rather than raising KeyError.
    dev_indices = partition.get('dev_indices')
    test_indices = partition.get('test_indices')

    data = np.load(side_data_file)
    side_data = data[side_data_field]
    # Unpacking also enforces that the side data is two-dimensional.
    _, n_cols = side_data.shape
    if side_data_col_names_field is None:
        col_names = ['_s_' + str(k) for k in range(n_cols)]
    else:
        col_names = data[side_data_col_names_field]

    def select_rows(indices):
        # None means the split does not exist; propagate that to the caller.
        if indices is None:
            return None
        return side_data[np.array(indices, dtype=int), :]

    side_train = side_data[np.array(train_indices, dtype=int), :]
    side_dev = select_rows(dev_indices)
    side_test = select_rows(test_indices)

    return side_train, side_dev, side_test, col_names


def load_subset(path, indices=None, subset='na'):
    """Load selected lines of a JSON-lines file as documents.

    Every line is parsed as JSON and tagged with ``'_i' = '<subset>_<line
    number>'`` (zero-based); the parsed object is kept only if its line
    number is in ``indices``.

    Args:
        path: path to a file containing one JSON object per line.
        indices: iterable of zero-based line numbers to keep, or ``None`` to
            keep every line.
        subset: prefix used when building each document's ``'_i'`` value.

    Returns:
        List of the selected documents, in file order.
    """
    # Normalise to a set once so the per-line membership test is O(1) even
    # when the caller hands in a list.
    if indices is None or isinstance(indices, (set, frozenset)):
        wanted = indices
    else:
        wanted = set(indices)

    docs = []
    with open(path) as f:
        for i, raw in enumerate(f):
            doc = json.loads(raw)
            doc['_i'] = subset + '_' + str(i)
            if wanted is None or i in wanted:
                docs.append(doc)
    return docs


def encode_documents_as_bow(documents, vocab, config, idf=None, confounders=None, confound_matrix=None, side_data=None, truncate_feda=10):
    """Encode documents as a sparse bag-of-words feature matrix.

    Each document's tokens field must be a list of sentences, each sentence
    a list of token strings.  Tokens are optionally lower-cased and
    digit-masked (as dictated by ``extract_vocab_params(config)``), tokens
    that do not start with an allowed character are replaced by ``'_'``, and
    n-grams up to the configured level are added.  Only terms present in
    ``vocab`` are counted.

    Feature layout: columns ``[0, len(vocab))`` hold the term counts; if
    ``side_data`` is given its columns follow immediately after.
    Confounders are NOT part of the feature matrix; they are returned
    separately as a one-hot array.

    Args:
        documents: list of dicts; each needs ``'_i'`` and the tokens field,
            and may carry ``'id'``, the weight field, the feda field and the
            confounder field.
        vocab: sequence of terms; position defines the column index.
        config: configuration dict; ``config['dataset_reader']`` must hold
            ``tokens_field_name``, ``weight_field_name`` and ``feda`` (and
            ``confounder`` when ``confounders`` is given).
        idf: precomputed idf vector, or ``None`` to compute it from these
            documents (only used when the transform is ``'tfidf'``).
        confounders: sequence of possible confounder values, or ``None``.
        confound_matrix: 2-D array whose first dimension gives the number of
            confounders; required when ``confounders`` is given.
        side_data: optional 2-D array with one row per document, appended as
            extra feature columns.
        truncate_feda: maximum number of characters of the feda value used
            in the ``'<term>__<value>'`` decorated features.

    Returns:
        ``(ids, orig_indices, counts, idf, weights, confounder_values)``
        where ``counts`` is a CSR matrix, ``weights`` is a per-document array
        (default 1) and ``confounder_values`` is ``None`` without
        confounders.

    Raises:
        AssertionError: if a sentence is not a list, or if ``side_data`` does
            not have one row per document.
    """
    ngram_level, _, _, transform, lower, digits, require_alpha = extract_vocab_params(config)

    dataset_reader = config["dataset_reader"]
    tokens_field_name = dataset_reader['tokens_field_name']
    weight_field_name = dataset_reader['weight_field_name']
    feda = dataset_reader['feda']

    vocab_size = len(vocab)
    vocab_index = dict(zip(vocab, range(vocab_size)))

    n_docs = len(documents)

    ids = []
    orig_indices = []

    n_features = vocab_size
    if confounders is not None:
        confounders_index = dict(zip(confounders, range(len(confounders))))
        # Unpacking also enforces that confound_matrix is two-dimensional.
        n_confounders, _ = confound_matrix.shape
        confounder_values = np.zeros((n_docs, n_confounders))
    else:
        confounders_index = {}
        confounder_values = None

    if side_data is not None:
        n_side_items, n_side_features = side_data.shape
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if n_side_items != n_docs:
            raise AssertionError("side_data must have one row per document")
        n_features += n_side_features

    counts = sparse.lil_matrix((n_docs, n_features))
    weights = np.ones(n_docs)

    # Compiled once instead of once per token; patterns are unchanged.
    digit_pattern = re.compile(r'\d')
    token_pattern = re.compile(r'[a-zA-Z0-9#$!?%"]+')
    feda_pattern = re.compile(r'[a-zA-Z0-9#$!?%&"]+')

    for i, doc in enumerate(documents):
        ids.append(doc['id'] if 'id' in doc else doc['_i'])
        orig_indices.append(doc['_i'])

        token_counts = Counter()
        for sentence in doc[tokens_field_name]:
            if not isinstance(sentence, list):
                message = "Input tokens should be a list of lists, not a list of strings!"
                print(message)
                raise AssertionError(message)
            if lower:
                sentence = [token.lower() for token in sentence]
            if digits:
                sentence = [digit_pattern.sub('#', token) for token in sentence]
            # Tokens that do not start with an allowed character collapse to '_'.
            sentence = [token if token_pattern.match(token) is not None else '_' for token in sentence]
            token_counts.update(sentence)

            if feda is not None:
                # create duplicate features a la frustratingly easy domain adaptation
                feda_suffix = '__' + str(doc[feda])[:truncate_feda]
                token_counts.update(
                    token + feda_suffix for token in sentence if feda_pattern.match(token) is not None
                )

            for n in range(2, ngram_level + 1):
                ngrams = convert_to_ngrams(sentence, n, require_alpha)
                token_counts.update(ngrams)
                if feda is not None:
                    token_counts.update(
                        ngram + feda_suffix for ngram in ngrams if feda_pattern.match(ngram) is not None
                    )

        if transform == 'binarize':
            index_count_pairs = {vocab_index[term]: 1 for term in token_counts if term in vocab_index}
        else:
            index_count_pairs = {vocab_index[term]: count for term, count in token_counts.items() if term in vocab_index}

        if len(index_count_pairs) > 0:
            indices, item_counts = zip(*index_count_pairs.items())
            counts[i, indices] = item_counts

        # Confounders are recorded one-hot in a separate array; they do not
        # occupy any columns of `counts`.
        if confounders is not None:
            confounder_val = doc[dataset_reader['confounder']]
            confounder_values[i, confounders_index[confounder_val]] = 1.

        if side_data is not None:
            # Side data starts right after the vocabulary columns.  The offset
            # must not include n_confounders, because no confounder columns
            # are reserved in n_features.
            counts[i, vocab_size:] = side_data[i, :]

        if weight_field_name is not None:
            if weight_field_name in doc:
                weights[i] = doc[weight_field_name]

    if transform == 'tfidf':
        if idf is None:
            print("Computing idf")
            # NOTE(review): this divides n_docs by the per-column SUM of the
            # counts (not the number of documents containing the term), and
            # all-zero columns divide by zero - confirm this is intended.
            idf = float(n_docs) / (np.array(counts.sum(0)).reshape((n_features, )))
        counts = counts.multiply(idf)

    counts = counts.tocsr()
    return ids, orig_indices, counts, idf, weights, confounder_values


def encode_documents_as_seq(documents, vocab, config):
    """Encode documents as flat sequences of vocabulary indices.

    Each document's tokens field is iterated sentence by sentence.  Tokens
    are optionally lower-cased and digit-masked (as dictated by
    ``extract_vocab_params(config)``), passed through ``convert_to_ngrams``
    at level 1, and followed by a ``'__SEP__'`` marker.  Terms missing from
    ``vocab`` map to the index of ``'__UNK__'``.

    Args:
        documents: list of dicts; each needs ``'_i'`` and the tokens field
            and may carry ``'id'``.
        vocab: sequence of terms; position defines the index.  It must
            contain ``'__UNK__'`` if any term (including ``'__SEP__'``) can
            be out of vocabulary.
        config: configuration dict; ``config['dataset_reader']`` must hold
            ``tokens_field_name``.

    Returns:
        ``(ids, orig_indices, token_indices)`` where ``token_indices`` has
        one list of ints per document.

    Raises:
        AssertionError: if the configured n-gram level is not 1.
        KeyError: if an out-of-vocabulary term is met and ``'__UNK__'`` is
            not in ``vocab``.
    """
    ngram_level, _, _, _, lower, digits, require_alpha = extract_vocab_params(config)
    # Explicit raise (not `assert`) so the check survives `python -O`.
    if ngram_level != 1:
        raise AssertionError("encode_documents_as_seq requires ngram_level == 1")

    tokens_field_name = config["dataset_reader"]['tokens_field_name']

    vocab_index = dict(zip(vocab, range(len(vocab))))

    digit_pattern = re.compile(r'\d')

    ids = []
    orig_indices = []
    token_indices = []

    for doc in documents:
        ids.append(doc['id'] if 'id' in doc else doc['_i'])
        orig_indices.append(doc['_i'])

        doc_indices = []
        for sentence in doc[tokens_field_name]:
            if lower:
                sentence = [token.lower() for token in sentence]
            if digits:
                sentence = [digit_pattern.sub('#', token) for token in sentence]
            # Copy before appending: the list returned by convert_to_ngrams is
            # not ours to mutate (it may alias the caller's sentence list).
            terms = list(convert_to_ngrams(sentence, ngram_level, require_alpha))
            # add sent end token
            terms.append('__SEP__')
            # '__UNK__' is looked up lazily so a vocab without it only fails
            # when an unknown term is actually encountered.
            doc_indices.extend(
                vocab_index[t] if t in vocab_index else vocab_index['__UNK__'] for t in terms
            )
        token_indices.append(doc_indices)

    return ids, orig_indices, token_indices