import string, csv, os, re, urllib.request, glob
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from scipy import sparse
from scipy.io import savemat, loadmat
import pickle
import ssl

# NOTE(review): this monkeypatch disables TLS certificate verification for every
# HTTPS connection that relies on the stdlib default context in this process
# (e.g. urllib.request). In this file only the disabled download snippet further
# down fetches anything over HTTPS, so this is likely removable — confirm.
ssl._create_default_https_context = ssl._create_unverified_context


def read_csv(filename):
    """Read a headerless CSV of labeled documents.

    Each row is expected to hold an integer label in column 0 and the document
    text in column 2; column 1 and any further columns are ignored.

    Returns:
        A ``(texts, labels)`` tuple of two parallel lists.
    """
    texts, labels = [], []
    with open(filename) as handle:
        for record in csv.reader(handle, delimiter=','):
            texts.append(record[2])
            labels.append(int(record[0]))
    return (texts, labels)


def write_csv(filename, X, Y=None, unique=False):
    """Write documents (and optionally their labels) to a CSV file with a header row.

    Args:
        filename: Destination path; the file is overwritten.
        X: Sequence of document strings.
        Y: Optional sequence of labels, same length as ``X``. When given, the
            header is ``Instance,Label`` and each row is ``[X[i], Y[i]]``;
            otherwise the header is ``Instance`` and each row is ``[X[i]]``.
        unique: Accepted for interface compatibility; not used.
    """
    with open(filename, 'w', newline='') as out:
        writer = csv.writer(out, delimiter=',')
        if Y is None:
            writer.writerow(["Instance"])
            writer.writerows([item] for item in X)
            return
        assert (len(X) == len(Y))
        writer.writerow(["Instance", "Label"])
        writer.writerows([doc, label] for doc, label in zip(X, Y))

def read_txt_from(folder_path):
    """Read every ``*.txt`` file directly inside *folder_path*.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        A list with one string per file, in sorted filename order. HTML line
        breaks (``<br />``) are replaced with spaces and surrounding whitespace
        is stripped.
    """
    docs = []
    # Sort so the document order does not depend on filesystem listing order.
    for filename in sorted(glob.glob(os.path.join(folder_path, '*.txt'))):
        # Decode explicitly as UTF-8 instead of the platform default encoding,
        # which may not be able to decode non-ASCII text.
        with open(filename, encoding='utf-8') as f:
            text = f.read()
        # str.strip() treats its argument as a *set of characters* and only
        # trims the ends, so tags must be removed with replace().
        docs.append(text.replace('<br />', ' ').strip())
    return docs


## Make data folders: data/ plus the data/data_imdb/ output directory that the
## CSV and .mat writers below expect.
os.makedirs(os.path.join("data", "data_imdb"), exist_ok=True)

# Disabled alternative: download and read the AG News CSV files instead of IMDB.
'''
urllib.request.urlretrieve(
    "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv",
    "data/temp_train.csv")
X_train, Y_train = read_csv("data/temp_train.csv")
os.remove("data/temp_train.csv")

urllib.request.urlretrieve("https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv",
                           "data/temp_test.csv")
X_test, Y_test = read_csv("data/temp_test.csv")
os.remove("data/temp_test.csv")
'''

# Load the IMDB reviews from the local aclImdb/ directory: label 1 for pos/, 0 for neg/.
X_train_pos = read_txt_from('aclImdb/train/pos')
Y_train_pos = [1]*len(X_train_pos)
X_train_neg = read_txt_from('aclImdb/train/neg')
Y_train_neg = [0]*len(X_train_neg)
X_train, Y_train = X_train_pos + X_train_neg, Y_train_pos + Y_train_neg

# Test label lists are sized from the test document lists so X_test and
# Y_test always stay aligned.
X_test_pos = read_txt_from('aclImdb/test/pos')
Y_test_pos = [1]*len(X_test_pos)
X_test_neg = read_txt_from('aclImdb/test/neg')
Y_test_neg = [0]*len(X_test_neg)
X_test, Y_test = X_test_pos + X_test_neg, Y_test_pos + Y_test_neg

# Unlabeled reviews (no label list is built for them).
X_unlabeled = read_txt_from('aclImdb/train/unsup')


def contains_punctuation(w):
    """Return True if *w* contains at least one ASCII punctuation character
    (as listed in ``string.punctuation``)."""
    return not set(w).isdisjoint(string.punctuation)


def contains_numeric(w):
    """Return True if any character of *w* satisfies ``str.isdigit``."""
    return any(map(str.isdigit, w))


def pre_process(docs):
    """Tokenize and clean a list of raw documents.

    For each document:
      * split into tokens with the regex below (runs of word characters and
        apostrophes, or single listed symbols);
      * drop any token containing a ``string.punctuation`` character (this
        also discards tokens containing an apostrophe, e.g. "don't");
      * lowercase, then drop tokens containing a digit and tokens of length <= 1.

    Returns:
        A list with one space-joined string of surviving tokens per document.

    Note: inside the character class, ``;-~`` is a range (';' through '~'),
    not three literals. Every token matched by that class is a single
    character and is discarded by the filters, so this does not affect output.
    """
    token_pattern = re.compile(r'''[\w']+|[.,!?;-~{}`´_<=>:/@*()&'$%#"]''')
    cleaned = []
    for text in docs:
        kept = []
        for token in token_pattern.findall(text):
            if contains_punctuation(token):
                continue
            token = token.lower()
            if contains_numeric(token) or len(token) <= 1:
                continue
            kept.append(token)
        cleaned.append(" ".join(kept))
    return cleaned


# Tokenize and clean every split with the same pipeline (see pre_process).
X_train = pre_process(X_train)
X_test = pre_process(X_test)
X_unlabeled = pre_process(X_unlabeled)

## Filter vocabulary
# The vocabulary is fit on train, test and unlabeled text together.
all_documents = X_train + X_test + X_unlabeled

print('counting document frequency of words...')
# min_df=10: keep only words that occur in at least 10 documents. max_df=1.0 and
# stop_words=None apply no upper document-frequency cutoff and no stop-word list.
cvectorizer = CountVectorizer(min_df=10, max_df=1.0, stop_words=None)
# sign() turns the counts into 0/1 presence flags, so column sums of cvz are
# document frequencies.
cvz = cvectorizer.fit_transform(all_documents).sign()
print(len(cvectorizer.vocabulary_))



def re_process(docs, vocab, thresh):
    """Restrict documents to a vocabulary and drop those that become too short.

    Args:
        docs: List of whitespace-separated token strings.
        vocab: Container supporting ``in`` (e.g. a word -> id dict); tokens
            not in it are removed.
        thresh: Minimum number of remaining tokens for a document to be kept.

    Returns:
        ``(kept_docs, eliminated)``: the surviving documents re-joined with
        single spaces, and the ascending positions (in *docs*) of the
        documents that were dropped.
    """
    kept_docs, eliminated = [], []
    for position, text in enumerate(docs):
        tokens = [tok for tok in text.split() if tok in vocab]
        if len(tokens) < thresh:
            eliminated.append(position)
        else:
            kept_docs.append(" ".join(tokens))
    return (kept_docs, eliminated)


# Keep only in-vocabulary tokens and drop documents left with fewer than 4 tokens;
# eliminated_* hold the original positions of the dropped documents.
X_train, eliminated_train = re_process(X_train, cvectorizer.vocabulary_, 4)
X_test, eliminated_test = re_process(X_test, cvectorizer.vocabulary_, 4)
X_unlabeled, eliminated_unlabeled = re_process(X_unlabeled, cvectorizer.vocabulary_, 4)

print(len(eliminated_unlabeled))

# Drop the labels of eliminated documents so labels stay aligned with X_train/X_test.
Y_train = [Y_train[i] for i in range(len(Y_train)) if i not in eliminated_train]
Y_test = [Y_test[i] for i in range(len(Y_test)) if i not in eliminated_test]

print(len(X_train), len(Y_train))
print(len(X_test), len(Y_test))

Y = np.array(Y_train)
# NOTE: np.random is not seeded here, so the sampled splits differ between runs.

# Labeled train split: 1000 documents per class (labels 0 and 1) drawn from the
# train set; every train document that is not sampled goes to the unsupervised pool.
total_inds = []
for k in range(0, 2):
    label_inds = np.where(Y == k)[0]
    inds = list(np.random.choice(label_inds, 1000, replace=False))
    total_inds.extend(inds)

X_train_to_write = [X_train[i] for i in total_inds]
Y_train_to_write = [Y_train[i] for i in total_inds]

used_inds = set(total_inds)
X_unsupervised = [X_train[i] for i in range(len(Y_train)) if i not in used_inds]

# Labeled test split: 3500 documents per class drawn from the TEST set; the
# remaining test documents also join the unsupervised pool. Class positions must
# come from the test labels, because these indices are used on X_test/Y_test.
Y_test_arr = np.array(Y_test)
total_inds = []
for k in range(0, 2):
    label_inds = np.where(Y_test_arr == k)[0]
    inds = list(np.random.choice(label_inds, 3500, replace=False))
    total_inds.extend(inds)

X_test_to_write = [X_test[i] for i in total_inds]
Y_test_to_write = [Y_test[i] for i in total_inds]

used_inds = set(total_inds)
X_unsupervised.extend(X_test[i] for i in range(len(Y_test)) if i not in used_inds)


X_unsupervised = X_unsupervised + X_unlabeled

print('num unsup:', len(X_unsupervised))
print('num train:', len(X_train_to_write))
print('num test:', len(X_test_to_write))

write_csv("data/data_imdb/train.csv", X_train_to_write, Y=Y_train_to_write)
write_csv("data/data_imdb/test.csv", X_test_to_write, Y=Y_test_to_write)
write_csv("data/data_imdb/unsupervised.csv", X_unsupervised)

# Use the sampled labeled splits as the working train/test sets from here on.
X_train = X_train_to_write
Y_train = Y_train_to_write

X_test = X_test_to_write
Y_test = Y_test_to_write

## Break X_unsupervised into X_unsupervised (main) + X_valid (validation) -- 90/10 split
X = X_unsupervised

inds = np.random.choice(len(X), int(0.1 * len(X)), replace=False)
X_valid = [X[i] for i in inds]

used_inds = set(inds.tolist())
X_unsupervised = [X[i] for i in range(len(X)) if i not in used_inds]

# Per-word document frequencies: cvz holds 0/1 presence flags, so its column
# sums count documents. cvz.sum(axis=0) is a (1, V) matrix; flatten it in one
# vectorized step instead of copying it entry by entry.
sum_counts = cvz.sum(axis=0)
v_size = sum_counts.shape[1]
sum_counts_np = np.asarray(sum_counts).ravel().astype(int)

# word -> column id (as assigned by the vectorizer) and its inverse.
word2id = dict(cvectorizer.vocabulary_)
id2word = {idx: word for word, idx in cvectorizer.vocabulary_.items()}

# Replace every token with its vocabulary id; re_process already removed
# out-of-vocabulary tokens, so each lookup succeeds.
indexed_train = [[word2id[word] for word in doc.split()] for doc in X_train]
indexed_test = [[word2id[word] for word in doc.split()] for doc in X_test]
indexed_unsup = [[word2id[word] for word in doc.split()] for doc in X_unsupervised]
indexed_valid = [[word2id[word] for word in doc.split()] for doc in X_valid]


def create_list_words(in_docs):
    """Concatenate per-document lists of word ids into one flat list, in order."""
    flat = []
    for doc in in_docs:
        flat.extend(doc)
    return flat


# Flatten each split into one stream of word ids (documents concatenated in order).
words_train = create_list_words(indexed_train)
words_test = create_list_words(indexed_test)
words_unsup = create_list_words(indexed_unsup)
words_valid = create_list_words(indexed_valid)


def create_doc_indices(in_docs):
    """Return the owning document index for every word of every document.

    The result lines up element-for-element with create_list_words(in_docs):
    document ``j`` with ``n`` words contributes ``n`` copies of ``j``.
    """
    owners = []
    for position, doc in enumerate(in_docs):
        owners.extend([position] * len(doc))
    return owners


# Parallel to words_*: entry k is the index of the document that owns the k-th word.
doc_indices_train = create_doc_indices(indexed_train)
doc_indices_test = create_doc_indices(indexed_test)
doc_indices_unsup = create_doc_indices(indexed_unsup)
doc_indices_valid = create_doc_indices(indexed_valid)

# Number of documents in each set
n_docs_train = len(indexed_train)
n_docs_test = len(indexed_test)
n_docs_unsup = len(indexed_unsup)
n_docs_valid = len(indexed_valid)


def create_bow(doc_indices, words, n_docs, vocab_size):
    """Build an ``(n_docs, vocab_size)`` CSR count matrix.

    Each ``(doc_indices[k], words[k])`` pair contributes 1; repeated pairs are
    summed when converting to CSR, giving per-document word counts.
    """
    ones = [1] * len(doc_indices)
    triplets = sparse.coo_matrix((ones, (doc_indices, words)), shape=(n_docs, vocab_size))
    return triplets.tocsr()


# Document-term count matrices: one row per document, one column per vocabulary id.
bow_train = create_bow(doc_indices_train, words_train, n_docs_train, len(cvectorizer.vocabulary_))
bow_test = create_bow(doc_indices_test, words_test, n_docs_test, len(cvectorizer.vocabulary_))
bow_unsup = create_bow(doc_indices_unsup, words_unsup, n_docs_unsup, len(cvectorizer.vocabulary_))
bow_valid = create_bow(doc_indices_valid, words_valid, n_docs_valid, len(cvectorizer.vocabulary_))

# Split bow into token/value pairs
print('splitting bow into token/value pairs and saving to disk...')
data_path = "data/data_imdb/"


def split_bow(bow_in, n_docs):
    """Split the first *n_docs* rows of a bag-of-words matrix into per-document lists.

    Args:
        bow_in: Sparse document-term matrix (one row per document). Non-CSR
            input is converted with ``tocsr()``.
        n_docs: Number of leading rows to convert.

    Returns:
        ``(indices, counts)``: ``indices[d]`` lists the column (word) ids stored
        in row ``d`` and ``counts[d]`` the matching stored values, in storage order.
    """
    bow_in = bow_in.tocsr()
    # Read each row straight out of the CSR buffers (row d occupies
    # indptr[d]:indptr[d + 1]) instead of building two temporary sparse
    # matrices per document with bow_in[doc, :].
    ptr, cols, vals = bow_in.indptr, bow_in.indices, bow_in.data
    indices = [list(cols[ptr[d]:ptr[d + 1]]) for d in range(n_docs)]
    counts = [list(vals[ptr[d]:ptr[d + 1]]) for d in range(n_docs)]
    return indices, counts


bow_train_tokens, bow_train_counts = split_bow(bow_train, n_docs_train)
bow_test_tokens, bow_test_counts = split_bow(bow_test, n_docs_test)
bow_unsup_tokens, bow_unsup_counts = split_bow(bow_unsup, n_docs_unsup)
bow_valid_tokens, bow_valid_counts = split_bow(bow_valid, n_docs_valid)

# Each split is saved as two .mat files (savemat appends ".mat" to the name by
# default): per-document word ids under 'tokens' and matching counts under 'counts'.
# NOTE(review): the per-document lists are ragged; this relies on savemat storing
# them as a cell array. Recent NumPy refuses to build arrays from ragged nested
# lists implicitly, so confirm this works with the installed scipy/numpy versions.
savemat(os.path.join(data_path, 'bow_train_tokens'), {'tokens': bow_train_tokens}, do_compression=True)
savemat(os.path.join(data_path, 'bow_train_counts'), {'counts': bow_train_counts}, do_compression=True)
savemat(os.path.join(data_path, 'bow_test_tokens'), {'tokens': bow_test_tokens}, do_compression=True)
savemat(os.path.join(data_path, 'bow_test_counts'), {'counts': bow_test_counts}, do_compression=True)
savemat(os.path.join(data_path, 'bow_unsup_tokens'), {'tokens': bow_unsup_tokens}, do_compression=True)
savemat(os.path.join(data_path, 'bow_unsup_counts'), {'counts': bow_unsup_counts}, do_compression=True)
savemat(os.path.join(data_path, 'bow_valid_tokens'), {'tokens': bow_valid_tokens}, do_compression=True)
savemat(os.path.join(data_path, 'bow_valid_counts'), {'counts': bow_valid_counts}, do_compression=True)

# vocab[i] is the word whose id is i.
vocab = [id2word[cc] for cc in range(v_size)]
with open(os.path.join(data_path, 'vocab.pkl'), 'wb') as f:
    pickle.dump(vocab, f)

with open(os.path.join(data_path, 'id2word.pkl'), 'wb') as f:
    pickle.dump(id2word, f)