import string, csv, os, re, urllib.request
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from scipy import sparse
from scipy.io import savemat, loadmat
import pickle
import ssl
# NOTE(review): this monkeypatch disables HTTPS certificate verification for
# every HTTPS request in this process that relies on the default SSL context
# (including the urllib.request.urlretrieve downloads below), which exposes
# them to man-in-the-middle tampering. It is presumably a workaround for a
# Python install without CA certificates; prefer fixing the certificate store
# and removing this line.
ssl._create_default_https_context = ssl._create_unverified_context
def read_csv(filename):
    """Read a label/text CSV file into parallel text and label lists.

    Column 0 of every row is parsed as an integer label and column 2 is kept
    as the document text; column 1 (presumably the AG News title) is ignored.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A tuple ``(X, Y)``: ``X`` is the list of column-2 strings and ``Y`` the
        list of integer labels, both in file order.
    """
    X = []
    Y = []
    # newline='' is required by the csv module so quoted fields that contain
    # line breaks are parsed correctly; an explicit encoding avoids depending
    # on the platform's locale.
    with open(filename, newline='', encoding='utf-8') as csv_file:
        for row in csv.reader(csv_file, delimiter=','):
            X.append(row[2])
            Y.append(int(row[0]))
    return X, Y

def write_csv(filename, X, Y=None, unique=False):
    """Write instances, and optionally their labels, to a CSV file.

    A header row is always written: ``Instance,Label`` when labels are given,
    otherwise just ``Instance``. The file is overwritten.

    Args:
        filename: Destination path.
        X: Sequence of instances, one per output row.
        Y: Optional sequence of labels aligned element-wise with ``X``.
        unique: Unused; kept only so existing call sites keep working.

    Raises:
        ValueError: If ``Y`` is given and its length differs from ``X``. The
            check runs before the file is opened, so an existing file is not
            truncated on bad input.
    """
    if Y is not None and len(X) != len(Y):
        # An explicit check instead of ``assert``, which is stripped under -O.
        raise ValueError(
            "X and Y must have the same length (got {} and {})".format(len(X), len(Y)))
    with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile, delimiter=',')
        if Y is not None:
            writer.writerow(["Instance", "Label"])
            writer.writerows(zip(X, Y))
        else:
            writer.writerow(["Instance"])
            writer.writerows([x] for x in X)

## Make the data folders. Every output written later in this script (CSV, .mat
## and .pkl files) goes into data/data_agnews/, so create that sub-folder too;
## exist_ok keeps re-runs harmless.
os.makedirs(os.path.join("data", "data_agnews"), exist_ok=True)

_AGNEWS_URL = ("https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/"
               "master/data/ag_news_csv/{}.csv")


def _download_split(split):
    """Download one AG News split ("train"/"test") and parse it with read_csv."""
    tmp_path = "data/temp_{}.csv".format(split)
    urllib.request.urlretrieve(_AGNEWS_URL.format(split), tmp_path)
    try:
        return read_csv(tmp_path)
    finally:
        # Remove the temporary download even if parsing fails.
        os.remove(tmp_path)


X_train, Y_train = _download_split("train")
X_test, Y_test = _download_split("test")


def contains_punctuation(w):
    """Tell whether ``w`` has at least one character from ``string.punctuation``."""
    for ch in w:
        if ch in string.punctuation:
            return True
    return False

def contains_numeric(w):
    """Tell whether ``w`` has at least one character for which ``isdigit()`` holds."""
    for ch in w:
        if ch.isdigit():
            return True
    return False

# Runs of word characters and apostrophes. The original pattern also had a
# second alternative, "[.,!?;-~{}`´_<=>:/@*()&'$%#\"]", whose ";-~" is a
# character RANGE (it spans A-Z, a-z, [, \, ], ^, | and does not match a
# literal "-"). That alternative only ever produced single-character tokens,
# which pre_process always discards (len > 1), so it never affected output.
_TOKEN_RE = re.compile(r"[\w']+")


def pre_process(docs):
    """Normalize raw documents into space-joined, lowercase word strings.

    For each document, in order:
      1. extract runs of word characters/apostrophes;
      2. drop tokens containing any ``string.punctuation`` character (this
         includes apostrophes and underscores, so e.g. "don't" is dropped
         whole), then lowercase the rest;
      3. drop tokens containing digits;
      4. drop tokens shorter than two characters.

    Args:
        docs: Sequence of raw document strings.

    Returns:
        A list with one space-joined string of surviving tokens per document.
    """
    cleaned = []
    for doc in docs:
        words = [w.lower() for w in _TOKEN_RE.findall(doc) if not contains_punctuation(w)]
        words = [w for w in words if not contains_numeric(w) and len(w) > 1]
        cleaned.append(" ".join(words))
    return cleaned

X_train, X_test = pre_process(X_train), pre_process(X_test)

## Filter vocabulary: keep words whose document frequency over the combined
## train+test corpus is at least 10 (min_df=10).
# NOTE(review): the vocabulary is fit on test documents too.
all_documents = X_train + X_test

print('counting document frequency of words...')
cvectorizer = CountVectorizer(min_df=10, max_df=1.0, stop_words=None)
# sign() turns every positive count into 1, i.e. a document-occurrence matrix.
cvz = cvectorizer.fit_transform(all_documents)
cvz = cvz.sign()
print(len(cvectorizer.vocabulary_))


def re_process(docs, vocab, thresh):
    """Keep only in-vocabulary words and drop documents that end up too short.

    Args:
        docs: Sequence of space-separated document strings.
        vocab: Container supporting ``in`` (e.g. a word -> id dict).
        thresh: Minimum number of in-vocabulary words a document must keep.

    Returns:
        ``(kept_docs, eliminated)``: the surviving documents re-joined with
        single spaces (original order), and the ascending positions in
        ``docs`` of the documents that were dropped.
    """
    kept = []
    eliminated = []
    for position, text in enumerate(docs):
        in_vocab = [token for token in text.split() if token in vocab]
        if len(in_vocab) < thresh:
            eliminated.append(position)
        else:
            kept.append(" ".join(in_vocab))
    return kept, eliminated

# Keep only vocabulary words, then discard documents left with fewer than 4.
X_train, eliminated_train = re_process(X_train, cvectorizer.vocabulary_, 4)
X_test, eliminated_test = re_process(X_test, cvectorizer.vocabulary_, 4)

# Drop the labels of discarded documents so texts and labels stay aligned.
# The eliminated index lists are turned into sets first: membership tests on
# the raw lists made these filters O(n_docs * n_eliminated).
_eliminated_train = set(eliminated_train)
_eliminated_test = set(eliminated_test)
Y_train = [y for i, y in enumerate(Y_train) if i not in _eliminated_train]
Y_test = [y for i, y in enumerate(Y_test) if i not in _eliminated_test]


print(len(X_train), len(Y_train))
print(len(X_test), len(Y_test))

## Labeled training subset: 1000 randomly chosen documents for each of the
## labels 1-4; every other training document goes to the unlabeled pool.
# Seed NumPy's global RNG so the sampled subset (and any later draws from the
# global RNG in this script) is reproducible from run to run.
np.random.seed(0)

Y = np.array(Y_train)
total_inds = []
for k in range(1, 5):
    label_inds = np.where(Y == k)[0]
    # replace=False -> 1000 distinct documents; raises if a label has fewer.
    inds = list(np.random.choice(label_inds, 1000, replace=False))
    total_inds.extend(inds)

X_train_to_write = [X_train[i] for i in total_inds]
Y_train_to_write = [Y_train[i] for i in total_inds]

# Everything not sampled above is treated as unlabeled data.
used_inds = set(total_inds)
X_unsupervised = [doc for i, doc in enumerate(X_train) if i not in used_inds]

write_csv("data/data_agnews/train.csv", X_train_to_write, Y=Y_train_to_write)
write_csv("data/data_agnews/test.csv", X_test, Y=Y_test)
write_csv("data/data_agnews/unsupervised.csv", X_unsupervised)


# From here on, "train" refers only to the labeled subset sampled above.
X_train, Y_train = X_train_to_write, Y_train_to_write

## Hold out 10% of the unlabeled pool (sampled without replacement) as a
## validation set; the remaining 90% stays in X_unsupervised.
X = X_unsupervised

inds = np.random.choice(len(X), int(0.1*len(X)), replace=False)
X_valid = [X[i] for i in inds]

used_inds = set(inds.tolist())
X_unsupervised = [doc for pos, doc in enumerate(X) if pos not in used_inds]

## Vocabulary statistics and id mappings.
# Column sums of cvz; cvz holds only 0/1 entries (it was passed through
# .sign()), so each sum is the number of documents containing that word.
sum_counts = cvz.sum(axis=0)
v_size = sum_counts.shape[1]
# Flatten the 1 x V result into a 1-D int array in one vectorized step
# (replaces a per-element Python loop over the whole vocabulary).
sum_counts_np = np.asarray(sum_counts).ravel().astype(int)
word2id = dict(cvectorizer.vocabulary_)
id2word = {idx: word for word, idx in cvectorizer.vocabulary_.items()}


def _index_docs(docs):
    """Map each space-separated document to its list of vocabulary ids."""
    return [[word2id[word] for word in doc.split()] for doc in docs]


indexed_train = _index_docs(X_train)
indexed_test = _index_docs(X_test)
indexed_unsup = _index_docs(X_unsupervised)
indexed_valid = _index_docs(X_valid)

def create_list_words(in_docs):
    """Concatenate the per-document word lists into a single flat list."""
    flat = []
    for doc in in_docs:
        flat.extend(doc)
    return flat

# Flattened word-id stream of each split (one entry per word occurrence).
words_train, words_test, words_unsup, words_valid = (
    create_list_words(docs)
    for docs in (indexed_train, indexed_test, indexed_unsup, indexed_valid)
)



def create_doc_indices(in_docs):
    """Return the owning document index of every word position, flattened.

    The k-th entry is the index (in ``in_docs``) of the document that the k-th
    word of the flattened corpus belongs to; each document contributes
    ``len(doc)`` copies of its index.
    """
    return [doc_id for doc_id, doc in enumerate(in_docs) for _ in range(len(doc))]

# For every word position, the document it came from (parallel to words_*).
doc_indices_train = create_doc_indices(indexed_train)
doc_indices_test = create_doc_indices(indexed_test)
doc_indices_unsup = create_doc_indices(indexed_unsup)
doc_indices_valid = create_doc_indices(indexed_valid)


# Number of documents in each set
n_docs_train, n_docs_test, n_docs_unsup, n_docs_valid = map(
    len, (indexed_train, indexed_test, indexed_unsup, indexed_valid))

def create_bow(doc_indices, words, n_docs, vocab_size):
    """Build an ``n_docs x vocab_size`` CSR bag-of-words count matrix.

    Each (doc_indices[k], words[k]) pair contributes a 1; repeated pairs are
    summed when the COO triplets are converted to CSR, yielding word counts.
    """
    ones = [1] * len(doc_indices)
    coords = (doc_indices, words)
    triplets = sparse.coo_matrix((ones, coords), shape=(n_docs, vocab_size))
    return triplets.tocsr()

# Bag-of-words count matrices: one row per document, one column per word id.
_vocab_size = len(cvectorizer.vocabulary_)
bow_train = create_bow(doc_indices_train, words_train, n_docs_train, _vocab_size)
bow_test = create_bow(doc_indices_test, words_test, n_docs_test, _vocab_size)
bow_unsup = create_bow(doc_indices_unsup, words_unsup, n_docs_unsup, _vocab_size)
bow_valid = create_bow(doc_indices_valid, words_valid, n_docs_valid, _vocab_size)

# Split each bag-of-words matrix into per-document token-id/count pairs.
print('splitting bow into token/value pairs and saving to disk...')
data_path = "data/data_agnews/"


def split_bow(bow_in, n_docs):
    """Split the first ``n_docs`` rows of a sparse BoW matrix into id/count lists.

    Args:
        bow_in: SciPy sparse matrix with one row per document; converted to
            CSR if it is not already (no copy when it is).
        n_docs: Number of leading rows to split.

    Returns:
        ``(indices, counts)``: two lists of length ``n_docs``. ``indices[d]``
        holds the column (word) ids stored for row ``d`` and ``counts[d]`` the
        matching stored values, in the matrix's stored order.

    Raises:
        IndexError: If ``n_docs`` exceeds the number of rows.
    """
    csr = bow_in.tocsr()
    # Slice the CSR buffers directly through indptr instead of materializing
    # a 1 x V sparse matrix per row (``bow_in[doc, :]``), which is far slower.
    indptr, col_ids, values = csr.indptr, csr.indices, csr.data
    indices = []
    counts = []
    for doc in range(n_docs):
        start, end = indptr[doc], indptr[doc + 1]
        indices.append(list(col_ids[start:end]))
        counts.append(list(values[start:end]))
    return indices, counts

bow_train_tokens, bow_train_counts = split_bow(bow_train, n_docs_train)
bow_test_tokens, bow_test_counts = split_bow(bow_test, n_docs_test)
bow_unsup_tokens, bow_unsup_counts = split_bow(bow_unsup, n_docs_unsup)
bow_valid_tokens, bow_valid_counts = split_bow(bow_valid, n_docs_valid)


def _as_cell_array(rows):
    """Pack a list of variable-length lists into a 1-D object array.

    savemat writes 1-D object arrays as MATLAB cell arrays. Building the array
    explicitly avoids relying on NumPy to infer a ragged array from nested
    lists (NumPy >= 1.24 raises ValueError for that unless dtype=object is
    given), and keeps the output a cell array even if all rows happen to have
    the same length.
    """
    cells = np.empty(len(rows), dtype=object)
    for i, row in enumerate(rows):
        cells[i] = row
    return cells


# Written in the same order as before: tokens then counts, for each split.
for _split, _tokens, _counts in (
        ('train', bow_train_tokens, bow_train_counts),
        ('test', bow_test_tokens, bow_test_counts),
        ('unsup', bow_unsup_tokens, bow_unsup_counts),
        ('valid', bow_valid_tokens, bow_valid_counts)):
    savemat(os.path.join(data_path, 'bow_' + _split + '_tokens'),
            {'tokens': _as_cell_array(_tokens)}, do_compression=True)
    savemat(os.path.join(data_path, 'bow_' + _split + '_counts'),
            {'counts': _as_cell_array(_counts)}, do_compression=True)

# Vocabulary as a list ordered by word id (position i holds id2word[i]).
vocab = list(map(id2word.__getitem__, range(v_size)))

# Persist the ordered vocabulary and the id -> word mapping, in that order.
for filename, obj in (('vocab.pkl', vocab), ('id2word.pkl', id2word)):
    with open(os.path.join(data_path, filename), 'wb') as f:
        pickle.dump(obj, f)