import random
import numpy as np
import os
import sys
import shutil


# Evaluation script: scores generated sentences against the SNLI test split
# with FedMetric, forward/backward BLEU (5-gram) and n-gram forward/backward
# perplexity (4-gram).
#
# Usage: python <this_script> <run_name>
#   reads  ./output/<run_name>.txt  (one whitespace-tokenized sentence per line)
#   writes ./output/<run_name>.res  (str() of the metric result dict)

# Key under which generated sentences are fed to, and read by, every metric.
GEN_KEY = "sent_allvocabs"
# Seed shared by the sampling metrics so runs are comparable.
SEED = 1234
# Number of generated sentences evaluated; also the metrics' sample size.
NUM_SAMPLES = 5000


def load_generated_ids(path, convert, limit=NUM_SAMPLES):
	"""Read the first *limit* lines of *path* and convert each to token ids.

	Each line is split on whitespace and the resulting token list is passed
	to *convert* (e.g. a dataloader's ``convert_tokens_to_ids``). Blank lines
	are kept and yield ``convert([])``, matching the original script.

	Args:
		path: UTF-8 text file with one generated sentence per line.
		convert: callable mapping a list of token strings to its converted form.
		limit: exact number of sentences to return.

	Returns:
		A list of exactly *limit* converted sentences.

	Raises:
		ValueError: if *path* contains fewer than *limit* lines.
	"""
	from itertools import islice

	with open(path, 'r', encoding='utf-8') as fin:
		# islice stops after `limit` lines instead of reading the whole file.
		sents = [convert(line.split()) for line in islice(fin, limit)]
	if len(sents) < limit:
		# Explicit check instead of `assert`, which is stripped under `python -O`.
		raise ValueError(
			"%s has only %d lines; at least %d are required" % (path, len(sents), limit))
	return sents


if __name__ == '__main__':
	from cotk_extend import PredefinedLanguageGeneration, FedMetric
	from cotk.metric import MetricChain, FwBwBleuCorpusMetric, SelfBleuCorpusMetric, NgramFwBwPerplexityMetric

	import utils
	# utils.debug()

	# Validate the CLI argument before paying for dataloader construction.
	if len(sys.argv) < 2:
		sys.exit("usage: %s <run_name>  (reads ./output/<run_name>.txt)" % sys.argv[0])
	run_name = sys.argv[1]

	datapath = "./snli_data"
	dm = PredefinedLanguageGeneration(datapath)

	filename = "./output/%s.txt" % run_name
	try:
		generated_text = load_generated_ids(filename, dm.convert_tokens_to_ids)
	except ValueError as err:
		sys.exit(str(err))

	# Each metric gets its own freshly fetched reference list rather than one
	# shared object. NOTE(review): kept this way in case a metric mutates its
	# reference list (e.g. shuffles it) - sharing could change the scores.
	metric = MetricChain()
	metric.add_metric(FedMetric(dm, dm.get_all_batch("test")[GEN_KEY], gen_key=GEN_KEY, seed=SEED, sample=NUM_SAMPLES))
	metric.add_metric(FwBwBleuCorpusMetric(dm, dm.get_all_batch("test")[GEN_KEY], ngram=5, gen_key=GEN_KEY, seed=SEED, sample=NUM_SAMPLES))
	#metric.add_metric(SelfBleuCorpusMetric(dm, ngram=5, gen_key=GEN_KEY, seed=SEED, sample=NUM_SAMPLES))
	metric.add_metric(NgramFwBwPerplexityMetric(dm, dm.get_all_batch("test")[GEN_KEY], ngram=4, gen_key=GEN_KEY))

	# Feed sentences one at a time, as batches of size 1.
	for sent in generated_text:
		metric.forward({GEN_KEY: [sent]})

	res = metric.close()

	with open("./output/%s.res" % run_name, 'w', encoding='utf-8') as f:
		f.write(str(res))
	print(res)
