import numpy as np
import random
from collections import Counter, OrderedDict
from itertools import chain
from nltk.tokenize import WordPunctTokenizer
import pickle as pkl
import sys
import random
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_gan as tfgan

from cotk.metric import MetricBase
from cotk.dataloader import LanguageProcessing, GeneralVocab, FieldContext

class PredefinedLanguageGeneration(LanguageProcessing):
	'''Language-generation dataloader whose vocabulary is read from ``word_list.txt``.

	The vocabulary is ``["<pad>", "<go>", "<eos>", "<unk>"]`` followed by every
	line of ``<file_id>/word_list.txt`` (whitespace-stripped), and all of those
	words are passed to ``GeneralVocab.from_predefined`` as frequent vocabulary.

	Arguments:
		file_id (str): directory containing ``word_list.txt``; also passed to
			``LanguageProcessing.__init__``.
		max_sent_length (int, optional): forwarded to ``FieldContext``.
		convert_to_lower_letter (bool, optional): forwarded to ``FieldContext``.
		set_name (iterable of str, optional): names of the data sets; each one
			gets a single ``"sent"`` field of type ``SentenceDefault``. If ``None``,
			one field spec ``{"sent": "SentenceDefault"}`` is passed instead
			(presumably applied to every set by ``LanguageProcessing`` — confirm).
	'''
	def __init__(self, file_id, *,
			max_sent_length=None, \
			convert_to_lower_letter=None, \
			set_name=None
			):

		# Read the predefined word list; the context manager guarantees the
		# file handle is closed (the previous bare open() leaked it).
		with open(file_id + "/word_list.txt", encoding="utf-8") as word_file:
			word_list = [line.strip() for line in word_file]

		special_tokens = ["<pad>", "<go>", "<eos>", "<unk>"]
		# Every token (specials + word list) counts as frequent vocabulary.
		vocab = GeneralVocab.from_predefined(special_tokens + word_list, len(word_list) + len(special_tokens), \
			OrderedDict([("pad", "<pad>"), ("go", "<go>"), ("eos", "<eos>"), ("unk", "<unk>")]))

		if set_name is None:
			fields = OrderedDict([("sent", "SentenceDefault")])
		else:
			fields = {key: OrderedDict([("sent", "SentenceDefault")]) for key in set_name}

		# vocab_from maps each set name to the vocabulary source used when its
		# fields are built: dev/valid/test-like sets use "test", while "train"
		# and "fake" (presumably generated samples — confirm) use "train".
		with FieldContext.set_parameters(vocab=vocab,\
				tokenizer="nltk", \
				vocab_from={
					"train": "train",
					"training": "train",
					"dev": "test",
					"development": "test",
					"valid": "test",
					"validation": "test",
					"test": "test",
					"evaluation": "test",
					"fake": "train"
				},
				max_sent_length=max_sent_length,
				convert_to_lower_letter=convert_to_lower_letter):
			super().__init__(file_id, fields)
		self.set_default_field("train", "sent")

class FedMetric(MetricBase):
	'''Frechet Embedding Distance (FED) between reference and generated sentences.

	Sentences are embedded with the TF-Hub Universal Sentence Encoder
	(``universal-sentence-encoder-large/3``), and the Frechet distance between
	the two embedding sets is computed with
	``tfgan.eval.frechet_classifier_distance_from_activations``.
	Modified based on
	https://github.com/deepmind/deepmind-research/blob/master/scratchgan/eval_metrics.py

	Arguments:
		dataloader: object providing ``trim_in_ids`` and ``convert_ids_to_tokens``.
		reference_test_list: reference sentences as id sequences. The first id
			of each is dropped before trimming (presumably a ``<go>`` token — confirm).
		gen_key (str): key of the generated id sequences in the ``data`` passed
			to :meth:`forward`. Default: ``"gen"``.
		sample (int): maximum number of references and of hypotheses used.
			Default: ``1000``.
		seed (int): seed of the shuffle that selects the sample. Default: ``1234``.
		cpu_count (int, optional): intra/inter-op thread count of the TF session.
			Only honoured when the shared session is first created.
			Default: ``None`` (16 threads).
	'''

	# Encoder module and tf.Session shared by all instances, so the TF-Hub
	# module is loaded at most once per process.
	cache = {}
	def __init__(self, dataloader, \
			reference_test_list, \
			gen_key="gen", \
			sample=1000, \
			seed=1234, \
			cpu_count=None):
		super().__init__("Fed", 1)
		self.dataloader = dataloader
		self.reference_test_list = reference_test_list
		self.gen_key = gen_key
		self.sample = sample
		self.seed = seed
		self.cpu_count = cpu_count
		self.refs = []
		self.hyps = []

	def forward(self, data):
		'''Collect generated sentences from ``data[self.gen_key]``.

		Raises:
			TypeError: if ``data[self.gen_key]`` is neither a list nor an ndarray.
		'''
		gen = data[self.gen_key]

		if not isinstance(gen, (np.ndarray, list)):
			raise TypeError("Unknown type for gen.")

		for gen_sen in gen:
			self.hyps.append(list(self.dataloader.trim_in_ids(gen_sen)))

	def get_text(self, sents):
		'''Convert id sequences to space-joined token strings.'''
		return [" ".join(self.dataloader.convert_ids_to_tokens(s)) for s in sents]

	@staticmethod
	def _load_encoder(cpu_count=None):
		'''Return the cached ``(embed, session)`` pair, creating it on first use.'''
		if "embed" not in FedMetric.cache:
			embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder-large/3")
			threads = 16 if cpu_count is None else cpu_count
			config = tf.ConfigProto()
			config.intra_op_parallelism_threads = threads
			config.inter_op_parallelism_threads = threads
			config.gpu_options.allow_growth = True
			session = tf.Session(config=config)
			session.run(tf.global_variables_initializer())
			session.run(tf.tables_initializer())
			# Populate the cache only after setup succeeded, so a failure cannot
			# leave 'embed' cached without a matching 'session'.
			FedMetric.cache['embed'] = embed
			FedMetric.cache['session'] = session
		return FedMetric.cache['embed'], FedMetric.cache['session']

	def close(self):
		'''Compute the FED score.

		Returns:
			dict: the dict returned by ``MetricBase.close()`` with key ``"FED"``
			added, holding the Frechet distance between the sampled reference
			and hypothesis embeddings.

		Raises:
			RuntimeError: if no data was forwarded, or fewer than 2 hypotheses
				or references are available.
		'''
		# Keep (rather than discard) whatever the base class reports.
		res = super().close() or {}
		if not self.hyps:
			raise RuntimeError("The metric has not been forwarded data correctly.")

		# Rebuild instead of appending, so references are never duplicated.
		self.refs = [list(self.dataloader.trim_in_ids(resp_sen[1:])) \
			for resp_sen in self.reference_test_list]

		sample_hyps = min(self.sample, len(self.hyps))
		sample_refs = min(self.sample, len(self.refs))

		if sample_hyps <= 1:
			raise RuntimeError('`sample_hyps` should be more than 1, '
				'whose value is `{}`'.format(sample_hyps))
		if sample_refs <= 1:
			raise RuntimeError('`sample_refs` should be more than 1, '
				'whose value is `{}`'.format(sample_refs))

		# A private seeded generator gives the same shuffle as reseeding the
		# global `random` module, without touching global RNG state.
		rng = random.Random(self.seed)
		rng.shuffle(self.hyps)
		rng.shuffle(self.refs)

		textA = self.get_text(self.refs[:sample_refs])
		textB = self.get_text(self.hyps[:sample_hyps])

		embed, session = self._load_encoder(self.cpu_count)
		real_embed = embed(textA)
		generated_embed = embed(textB)
		distance = tfgan.eval.frechet_classifier_distance_from_activations(
			real_embed, generated_embed)

		res.update({"FED": session.run(distance)})
		return res
