import sys
import pickle
import numpy as np
import random

from hmmlearn.hmm import MultinomialHMM
from cotk_extend import PredefinedLanguageGeneration

def fithmm():
	"""Fit the oracle HMM and write the synthetic corpus files.

	Reads training batches from ``./mscoco_data``, fits a 100-state
	``MultinomialHMM`` on the token ids (folded modulo 500), pickles the
	fitted model to ``./synthetic_data/oracle.pkl``, writes the 500-entry
	``word_list.txt`` and finally samples ``train.txt`` / ``dev.txt`` /
	``test.txt`` (20 tokens per line) from the fitted model.
	"""
	import numpy as np
	loader = PredefinedLanguageGeneration("./mscoco_data")

	sequences, lengths = [], []
	for idx, batch in enumerate(loader.get_batches("train", batch_size=1)):
		tokens = batch['sent'][0]
		# Strip the leading <go> and trailing <eos>, then fold the ids
		# into a 500-word vocabulary.
		sequences.append(tokens[1:-1] % 500)
		lengths.append(tokens.shape[0] - 2)
		# Stop once the batch with index 5000 has been collected.
		if idx >= 5000:
			break

	hmm = MultinomialHMM(100, n_iter=100, verbose=True)
	# hmmlearn expects a single column of observations plus per-sequence lengths.
	X = np.concatenate(sequences, axis=0)[..., np.newaxis]
	print("fit start")
	hmm.fit(X, lengths)

	with open("./synthetic_data/oracle.pkl", "wb") as oracle_file:
		pickle.dump(hmm, oracle_file)

	# The synthetic vocabulary is simply the ids 0..499, one per line.
	with open("word_list.txt", "w") as vocab_file:
		vocab_file.writelines("%d\n" % word_id for word_id in range(500))

	# Sample each split in turn; every line is one 20-token sentence.
	splits = (("train.txt", 50000), ("dev.txt", 5000), ("test.txt", 5000))
	for filename, num in splits:
		with open(filename, "w") as out:
			for _ in range(num):
				observed = hmm.sample(20)[0]
				words = observed[:, 0].tolist()
				out.write(" ".join(map(str, words)) + "\n")

def _parse_sentence(line, length=20, vocab_size=500):
	"""Convert one generated line into exactly ``length`` integer token ids.

	Each ``<unk>``, ``<go>`` or ``<pad>`` token is replaced by its own
	independently drawn random id in ``[0, vocab_size)``. Sentences shorter
	than ``length`` are padded with random ids; longer ones are truncated.

	Raises:
		ValueError: if a token is neither a special token nor an integer.
	"""
	special_tokens = ('<unk>', '<go>', '<pad>')
	sent = []
	for token in line.split():
		if token in special_tokens:
			sent.append(random.randint(0, vocab_size - 1))
			continue
		try:
			sent.append(int(token))
		except ValueError as err:
			# Fail loudly: continuing would score a stale or undefined sentence.
			raise ValueError("error in generated sentences: %r" % line) from err
	sent = sent[:length]
	while len(sent) < length:
		sent.append(random.randint(0, vocab_size - 1))
	return sent


def evaluate():
	"""Score generated sentences against the pickled oracle HMM.

	Loads the model from ``./synthetic_data/oracle.pkl``, reads
	``./output/<sys.argv[1]>.txt`` line by line, normalises every line to
	20 token ids and calls ``hmm.score`` on it. Prints the negated mean
	score divided by 20 (the sentence length).

	Raises:
		IndexError: if no file name was given on the command line.
		ValueError: if a line contains a token that cannot be parsed.
		AssertionError: if the file does not contain exactly 5000 lines.
	"""
	# NOTE: pickle.load can execute arbitrary code; only load an oracle file
	# that was produced locally (e.g. by fithmm()).
	with open("./synthetic_data/oracle.pkl", "rb") as file:
		hmm = pickle.load(file)

	score = []
	filename = "./output/%s.txt" % sys.argv[1]
	with open(filename, 'r') as f:
		for line in f:
			sent = _parse_sentence(line)
			# The model is scored on a single column of observations.
			sent = np.expand_dims(np.array(sent), axis=-1)
			score.append(hmm.score(sent))

	assert len(score) == 5000, \
		"expected 5000 sentences in %s, got %d" % (filename, len(score))
	print(-np.mean(score) / 20)

# Script entry point: runs unconditionally when this file is executed or
# imported, scoring the output file named by the first command-line argument.
evaluate()
