import string, csv, os, re, urllib.request, glob
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from scipy import sparse
from scipy.io import savemat, loadmat
import pickle
from datasets import load_dataset


def write_csv(filename, X, Y=None, unique=False):
    """Write instances, optionally paired with labels, to a CSV file.

    The file starts with a header row: ``Instance,Label`` when ``Y`` is given,
    otherwise just ``Instance``. One row per element of ``X`` follows.

    Args:
        filename: Path of the CSV file to create (an existing file is overwritten).
        X: Sequence of instances, one per row.
        Y: Optional sequence of labels; must have the same length as ``X``.
        unique: Accepted for compatibility with existing callers; not used.

    Raises:
        AssertionError: If ``Y`` is given and its length differs from ``X``.
    """
    if Y is not None:
        # Check before opening: opening in 'w' mode truncates the target, so a
        # length mismatch must be detected before the file is touched.
        assert (len(X) == len(Y))
    # Explicit UTF-8 so non-ASCII text does not depend on the platform's
    # default encoding (which can fail to encode such characters).
    with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile, delimiter=',')
        if Y is not None:
            writer.writerow(["Instance", "Label"])
            writer.writerows([x, y] for x, y in zip(X, Y))
        else:
            writer.writerow(["Instance"])
            writer.writerows([x] for x in X)
    return


def read_document_from(ds):
    """Collect the texts and labels of an indexable dataset split.

    Args:
        ds: Object supporting ``len()`` and integer indexing, where each item
            is a mapping with the keys ``'content'`` and ``'label'``.

    Returns:
        A pair ``(X, Y)``: the list of ``'content'`` values and the list of
        ``'label'`` values, both in dataset order.
    """
    X, Y = [], []
    for i in range(len(ds)):
        # Fetch the row once and read both fields from it, instead of
        # indexing the dataset separately for each field.
        row = ds[i]
        X.append(row['content'])
        Y.append(row['label'])
    return X, Y


## Make data folder
# Every output of this script (CSV, .mat and .pkl files) is written to
# data/data_dbpedia, so create that directory together with its parent rather
# than only "data"; exist_ok makes re-runs harmless.
os.makedirs("data/data_dbpedia", exist_ok=True)

# Get DBpedia ontology dataset
dataset = load_dataset("dbpedia_14")
X_train, Y_train = read_document_from(dataset['train'])
X_test, Y_test = read_document_from(dataset['test'])


def contains_punctuation(w):
    """Return True if any character of ``w`` appears in ``string.punctuation``."""
    for char in w:
        if char in string.punctuation:
            return True
    return False


def contains_numeric(w):
    """Return True if any character of ``w`` is a digit according to ``str.isdigit``."""
    for char in w:
        if char.isdigit():
            return True
    return False


def pre_process(docs):
    """Tokenise and normalise raw documents.

    Each document is split into tokens (runs of word characters/apostrophes,
    or single punctuation characters). Tokens containing punctuation are
    dropped, the rest are lower-cased, and lower-cased tokens that contain a
    digit or are a single character long are dropped as well.

    Args:
        docs: List of raw document strings.

    Returns:
        A list with one string per input document: the surviving tokens joined
        by single spaces.
    """
    token_pattern = re.compile(r'''[\w']+|[.,!?;-~{}`´_<=>:/@*()&'$%#"]''')
    cleaned = []
    for text in docs:
        words = []
        for token in token_pattern.findall(text):
            if contains_punctuation(token):
                continue
            word = token.lower()
            # The digit and length checks run on the lower-cased form
            if contains_numeric(word) or len(word) <= 1:
                continue
            words.append(word)
        cleaned.append(" ".join(words))
    return cleaned


# Normalise the raw text of both splits before the vocabulary is built
X_train, X_test = pre_process(X_train), pre_process(X_test)

## Filter vocabulary
# The vocabulary is fitted on the train and test documents together
all_documents = [*X_train, *X_test]

print('counting document frequency of words...')
cvectorizer = CountVectorizer(stop_words=None, max_df=1.0, min_df=35)
# .sign() reduces the counts to a 0/1 indicator of "word occurs in document"
cvz = cvectorizer.fit_transform(all_documents).sign()
print(len(cvectorizer.vocabulary_))


def re_process(docs, vocab, thresh):
    """Restrict documents to a vocabulary and drop the ones that become too short.

    Args:
        docs: List of whitespace-separated document strings.
        vocab: Container supporting ``in``; only words found in it are kept.
        thresh: Minimum number of remaining words for a document to be kept.

    Returns:
        A pair ``(docs, eliminated)``: the kept documents re-joined with single
        spaces, and the positions (in the input list) of the dropped documents.
    """
    kept, eliminated = [], []
    for position, text in enumerate(docs):
        in_vocab = [token for token in str.split(text) if token in vocab]
        if len(in_vocab) < thresh:
            eliminated.append(position)
        else:
            kept.append(" ".join(in_vocab))
    return (kept, eliminated)


# Keep only in-vocabulary words and drop documents left with fewer than 4 of them
X_train, eliminated_train = re_process(X_train, cvectorizer.vocabulary_, 4)
X_test, eliminated_test = re_process(X_test, cvectorizer.vocabulary_, 4)

# Drop the labels of the eliminated documents so X and Y stay aligned.
# Membership is tested against sets: checking every index against the
# eliminated *lists* costs time proportional to (documents x eliminated).
dropped_train = set(eliminated_train)
dropped_test = set(eliminated_test)
Y_train = [label for i, label in enumerate(Y_train) if i not in dropped_train]
Y_test = [label for i, label in enumerate(Y_test) if i not in dropped_test]

# Sanity check: each pair of numbers should match
print(len(X_train), len(Y_train))
print(len(X_test), len(Y_test))

# Labelled training subset: for each label id 0..13, draw 1000 distinct
# documents of that label at random
Y = np.array(Y_train)
total_inds = []
for k in range(14):
    label_inds = np.where(Y == k)[0]
    total_inds += list(np.random.choice(label_inds, 1000, replace=False))

X_train_to_write = [X_train[i] for i in total_inds]
Y_train_to_write = [Y_train[i] for i in total_inds]

# Every training document that was not drawn goes to the unlabelled pool
used_inds = set(total_inds)
X_unsupervised = [X_train[i] for i in range(len(Y_train)) if i not in used_inds]

write_csv("data/data_dbpedia/train.csv", X_train_to_write, Y_train_to_write)
write_csv("data/data_dbpedia/test.csv", X_test, Y_test)
write_csv("data/data_dbpedia/unsupervised.csv", X_unsupervised)

# From here on the training set is the sampled, labelled subset
X_train, Y_train = X_train_to_write, Y_train_to_write

## Break X_unsupervised into X_unsupervised (main) + X_valid (validation) -- 90/10 split
X = X_unsupervised

n_valid = int(0.1 * len(X))
inds = np.random.choice(len(X), n_valid, replace=False)
X_valid = [X[i] for i in inds]

# Whatever was not picked for validation stays in the unsupervised set
used_inds = set(inds.tolist())
X_unsupervised = [doc for i, doc in enumerate(X) if i not in used_inds]

# Column sums of the document-term matrix, one entry per vocabulary word
sum_counts = cvz.sum(axis=0)
v_size = sum_counts.shape[1]
# Flatten the (1, v_size) result into a 1-D integer array in one step instead
# of copying it element by element in a Python loop
sum_counts_np = np.asarray(sum_counts, dtype=int).reshape(v_size)

# word -> column id, and the inverse mapping
word2id = dict(cvectorizer.vocabulary_)
id2word = {idx: w for w, idx in cvectorizer.vocabulary_.items()}

# Replace every word of every document by its vocabulary id
indexed_train = [[word2id[word] for word in doc.split()] for doc in X_train]
indexed_test = [[word2id[word] for word in doc.split()] for doc in X_test]
indexed_unsup = [[word2id[word] for word in doc.split()] for doc in X_unsupervised]
indexed_valid = [[word2id[word] for word in doc.split()] for doc in X_valid]


def create_list_words(in_docs):
    """Concatenate the items of all documents into one flat list, in order."""
    flat = []
    for doc in in_docs:
        flat.extend(doc)
    return flat


# Flat list of word ids for each split
words_train, words_test, words_unsup, words_valid = (
    create_list_words(indexed)
    for indexed in (indexed_train, indexed_test, indexed_unsup, indexed_valid)
)


def create_doc_indices(in_docs):
    """Return, for every item of every document, the position of its document.

    A document at position ``j`` with ``n`` items contributes ``n`` copies of
    ``j``; empty documents contribute nothing.
    """
    doc_ids = []
    for doc_id, doc in enumerate(in_docs):
        doc_ids.extend([int(doc_id)] * len(doc))
    return doc_ids


# Document index of every word occurrence, per split
doc_indices_train, doc_indices_test, doc_indices_unsup, doc_indices_valid = (
    create_doc_indices(indexed)
    for indexed in (indexed_train, indexed_test, indexed_unsup, indexed_valid)
)

# Number of documents in each set
n_docs_train, n_docs_test, n_docs_unsup, n_docs_valid = map(
    len, (indexed_train, indexed_test, indexed_unsup, indexed_valid)
)


def create_bow(doc_indices, words, n_docs, vocab_size):
    """Build a CSR document-term count matrix from parallel index lists.

    Args:
        doc_indices: Row (document) index of every word occurrence.
        words: Column (word id) of every word occurrence, parallel to
            ``doc_indices``.
        n_docs: Number of rows of the result.
        vocab_size: Number of columns of the result.

    Returns:
        A ``(n_docs, vocab_size)`` CSR matrix.
    """
    ones = [1] * len(doc_indices)
    occurrences = sparse.coo_matrix(
        (ones, (doc_indices, words)),
        shape=(n_docs, vocab_size),
    )
    # Converting to CSR adds up repeated (document, word) entries
    return occurrences.tocsr()


# Bag-of-words matrix of every split, all with one column per vocabulary word
bow_train, bow_test, bow_unsup, bow_valid = (
    create_bow(doc_ids, word_ids, n_docs, len(cvectorizer.vocabulary_))
    for doc_ids, word_ids, n_docs in (
        (doc_indices_train, words_train, n_docs_train),
        (doc_indices_test, words_test, n_docs_test),
        (doc_indices_unsup, words_unsup, n_docs_unsup),
        (doc_indices_valid, words_valid, n_docs_valid),
    )
)

# Split bow into token/value pairs
print('splitting bow into token/value pairs and saving to disk...')
# Directory that receives the .mat and .pkl files written below
data_path = "data/data_dbpedia"


def split_bow(bow_in, n_docs):
    """Split a sparse bag-of-words matrix into per-document tokens and counts.

    Args:
        bow_in: Sparse matrix supporting ``bow_in[doc, :]`` row slicing whose
            row slices expose ``.indices`` and ``.data``.
        n_docs: Number of leading rows to process.

    Returns:
        A pair ``(indices, counts)`` of lists with one entry per document: the
        stored column indices of that row and the matching stored values.
    """
    indices, counts = [], []
    for doc in range(n_docs):
        # Slice the row once and read both arrays from it; every slice
        # builds a new sparse matrix, so slicing twice doubles the work.
        row = bow_in[doc, :]
        indices.append(list(row.indices))
        counts.append(list(row.data))
    return indices, counts


# Per-document token ids and counts for every split
bow_train_tokens, bow_train_counts = split_bow(bow_train, n_docs_train)
bow_test_tokens, bow_test_counts = split_bow(bow_test, n_docs_test)
bow_unsup_tokens, bow_unsup_counts = split_bow(bow_unsup, n_docs_unsup)
bow_valid_tokens, bow_valid_counts = split_bow(bow_valid, n_docs_valid)

# One compressed .mat file per (split, tokens/counts) pair, written in this order
for mat_name, mat_key, mat_rows in (
    ('bow_train_tokens', 'tokens', bow_train_tokens),
    ('bow_train_counts', 'counts', bow_train_counts),
    ('bow_test_tokens', 'tokens', bow_test_tokens),
    ('bow_test_counts', 'counts', bow_test_counts),
    ('bow_unsup_tokens', 'tokens', bow_unsup_tokens),
    ('bow_unsup_counts', 'counts', bow_unsup_counts),
    ('bow_valid_tokens', 'tokens', bow_valid_tokens),
    ('bow_valid_counts', 'counts', bow_valid_counts),
):
    savemat(os.path.join(data_path, mat_name), {mat_key: mat_rows}, do_compression=True)

# Vocabulary as a list ordered by word id (position == id)
vocab = list(map(id2word.__getitem__, range(v_size)))

# Persist both the ordered word list and the id -> word mapping
for pkl_name, payload in (('vocab.pkl', vocab), ('id2word.pkl', id2word)):
    with open(os.path.join(data_path, pkl_name), 'wb') as f:
        pickle.dump(payload, f)
