import numpy as np

# Read the raw vocabulary file; each line is one candidate entry.
with open('vocab.txt') as f:
    words = f.readlines()

# Keep the entries that are purely alphabetic once newline characters are
# removed from both ends, then drop the first 1000 of those.
words_list = [
    token
    for token in (line.strip('\n') for line in words)
    if token.isalpha()
][1000:]

# Map every remaining word to its position in words_list.
word2ind = {word: j for j, word in enumerate(words_list)}

# Model and corpus dimensions.
N_hidden = 100      # size of the hidden-state space
N_word = 200        # number of columns (words) in the emission matrix
N_sent = 1000000
T = 32

# Initial hidden-state distribution: random positive weights scaled to sum to 1.
ini_vec = np.random.rand(N_hidden)
ini_vec = ini_vec / ini_vec.sum()

# Dense random component of the transition matrix; every row sums to 1.
trans_mat = np.random.rand(N_hidden, N_hidden)
trans_mat = trans_mat / trans_mat.sum(axis=1, keepdims=True)

# Block-diagonal component: independent 20x20 row-normalised blocks placed
# along the diagonal, zeros everywhere else.
trans_mat1 = np.zeros((N_hidden, N_hidden))
for block_ind in range(N_hidden // 20):
    lo = block_ind * 20
    hi = lo + 20
    block = np.random.rand(20, 20)
    trans_mat1[lo:hi, lo:hi] = block / block.sum(axis=1, keepdims=True)

# Final transition matrix: mostly block-diagonal with a small dense part mixed in.
trans_mat = 0.05 * trans_mat + 0.95 * trans_mat1

# Emission matrix: the k-th group of 10 hidden states only emits the k-th
# group of 20 words; each row is normalised to sum to 1.
gen_mat = np.zeros((N_hidden, N_word))
for block_ind in range(N_hidden // 10):
    row0 = block_ind * 10
    col0 = block_ind * 20
    block = np.random.rand(10, 20)
    gen_mat[row0:row0 + 10, col0:col0 + 20] = block / block.sum(axis=1, keepdims=True)

# Persist the sampled model parameters.
np.save('ini_vec_block_cor.npy', ini_vec)
np.save('trans_mat_block_cor.npy', trans_mat)
np.save('gen_mat_block_cor.npy', gen_mat)

# Index arrays that np.random.choice samples hidden states / word ids from.
h_lst = np.arange(N_hidden)
w_lst = np.arange(N_word)


def _sample_sentence():
    """Draw one sentence of T words from the HMM and return it space-joined.

    The hidden state is first drawn from ini_vec; then, for each of the T
    positions, the state transitions according to its row of trans_mat and a
    word is emitted from the new state's row of gen_mat.
    """
    state = np.random.choice(h_lst, p=ini_vec)
    tokens = []
    for _ in range(T):
        state = np.random.choice(h_lst, p=trans_mat[state])
        word_id = np.random.choice(w_lst, p=gen_mat[state])
        tokens.append(words_list[word_id])
    return " ".join(tokens)


N_sent = 1000000

# First corpus: N_sent independently sampled sentences.
sent_list = [_sample_sentence() for _ in range(N_sent)]

# Second corpus, sampled the same way, with a progress print every 1000
# sentences (at i = 1, 1001, 2001, ...).
down_sent_list = []
for i in range(N_sent):
    down_sent_list.append(_sample_sentence())
    if i % 1000 == 1:
        print(i)

import csv

# Value passed as `ind` to both passes when computing a sentence's label.
LABEL_IND = 25


def calc_viterbi(sent, ind):
    """Run the forward recursion over the first `ind` words of `sent`.

    Starting from ini_vec, each step applies trans_mat.T and then multiplies
    elementwise by the gen_mat column of the observed word. Despite the name,
    this accumulates sums (matrix-vector products), not maxima.

    Args:
        sent: sequence of words; every word used must be a key of word2ind.
        ind: number of leading words to consume.

    Returns:
        Length-N_hidden vector of unnormalised forward scores (ini_vec itself
        when ind == 0).
    """
    u = ini_vec
    for i in range(ind):
        u = trans_mat.T.dot(u) * gen_mat[:, word2ind[sent[i]]]
    return u


def calc_viterbi_post(sent, ind):
    """Run the backward recursion over words `ind` .. T-1 of `sent`.

    Starts from the emission column of the last word (index T-1) pushed
    through trans_mat.T, then walks backwards down to index `ind`, scaling by
    each word's emission column before applying trans_mat.T again.

    Args:
        sent: sequence of at least T words, all keys of word2ind.
        ind: index of the first word covered by the backward pass.

    Returns:
        Length-N_hidden vector of unnormalised backward scores.
    """
    u = gen_mat[:, word2ind[sent[T - 1]]].dot(trans_mat.T)
    for i in range(T - 2, ind - 1, -1):
        # Elementwise scaling is equivalent to multiplying by
        # diag(emission column) without building the dense diagonal matrix.
        u = (u * gen_mat[:, word2ind[sent[i]]]).dot(trans_mat.T)
    return u


def _posterior_label(sentence, ind=LABEL_IND):
    """Return the argmax of the forward-times-backward scores for `sentence`.

    The elementwise product of the two passes is proportional to the
    posterior over the hidden state that emitted word index ind - 1, given
    that sentences are sampled by transitioning before each emission.
    """
    tokens = sentence.split()
    return np.argmax(calc_viterbi(tokens, ind) * calc_viterbi_post(tokens, ind))


def _write_labeled_csv(path, start, stop):
    """Write down_sent_list[start:stop] with their labels to a CSV at `path`."""
    with open(path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter=',')
        writer.writerow(['text', 'label'])
        for ii in range(start, stop):
            print(ii)
            sentence = down_sent_list[ii]
            writer.writerow([sentence, _posterior_label(sentence)])


# Labelled data: the first 90000 sentences for training, the next 10000 for
# evaluation. The helpers above are defined before this point so the calls
# no longer hit a NameError.
_write_labeled_csv('toy_hmm_train.csv', 0, 90000)
_write_labeled_csv('toy_hmm_eval.csv', 90000, 100000)

# Plain-text corpus split: first 900000 sentences for training, the rest for
# evaluation. Both halves read from sent_list; the train half previously
# referenced the undefined name `sentences`.
with open('hmm_train.txt', 'w') as f:
    for sent in sent_list[:900000]:
        f.write(sent + '\n')
with open('hmm_eval.txt', 'w') as f:
    for sent in sent_list[900000:]:
        f.write(sent + '\n')