# coding:utf-8
import logging
import time
import os
import math
import random
import pickle as pkl
from itertools import chain

import torch
from torch import nn, optim, autograd
import torch.nn.functional as F
import numpy as np

from cotk.metric import LanguageGenerationRecorder, NgramFwBwPerplexityMetric, FwBwBleuCorpusMetric,\
	MetricChain, SelfBleuCorpusMetric
from cotk_extend import FedMetric
from utils import Storage, cuda, BaseModel, SummaryHelper, get_mean, storage_to_list, \
	CheckpointManager, Tensor, zeros, ones, AdamW, RAdam, truncated_normal, LongTensor, reshape
from network import Network

class NAGAN(BaseModel):
	"""Generator/discriminator text model with training, evaluation and latent-space analysis.

	Training alternates one generator step and one discriminator step per batch
	(``train``), keeps EMA weights (``self.net.emaHelper``) that are swapped in for
	evaluation, and checkpoints on the dev FED score. ``interpolate``, ``record``,
	``walkgrad`` and ``adddiff`` are analysis utilities operating on the latent
	noise ``z`` fed to ``self.net.generate_from_z``. Helpers such as
	``get_next_batch``, ``zero_grad`` and ``save_checkpoint`` come from ``BaseModel``.
	"""

	def __init__(self, param):
		"""Build the network, one RAdam optimizer per sub-network and the checkpoint manager.

		Args:
			param: project configuration object; ``param.args`` supplies the
				hyper-parameters read below.
		"""
		args = param.args
		net = Network(param)
		# Generator and discriminator are optimised separately; only the
		# discriminator uses weight decay.
		self.genOptimizer = RAdam(net.get_parameters_by_name("gen"), lr=args.gen_lr, betas=(0.5, 0.9))
		self.disOptimizer = RAdam(net.get_parameters_by_name("dis"), lr=args.dis_lr, betas=(0.5, 0.9), weight_decay=0.05)

		optimizerList = {"genOptimizer": self.genOptimizer, "disOptimizer": self.disOptimizer}
		# "min": presumably smaller checkpoint values (train_process passes the dev FED) are better — confirm in CheckpointManager.
		checkpoint_manager = CheckpointManager(args.name, args.model_dir,
						args.checkpoint_steps, args.checkpoint_max_to_keep, "min")

		super().__init__(param, net, optimizerList, checkpoint_manager)

		self.create_summary()

	def create_summary(self):
		"""Create the train/dev/test summary groups (scalars, plus sample text for dev/test)."""
		args = self.param.args
		self.summaryHelper = SummaryHelper("%s/%s_%s" %
				(args.log_dir, args.name, time.strftime("%H%M%S", time.localtime())),
				args)

		scalarlist = ["gen_loss", "dis_loss", "gp_loss", "fake_v", "real_v", "int_v", "FED", "fw-bleu", "bw-bleu", "fw-bw-bleu"]

		self.trainSummary = self.summaryHelper.addGroup(
			scalar=scalarlist,
			prefix="train")

		tensorlist = []
		emblist = []
		# One text slot per batch index listed in show_sample (filled by evaluate()).
		textlist = ["show_str%d" % i for i in self.args.show_sample]
		self.devSummary = self.summaryHelper.addGroup(
			scalar=scalarlist,
			tensor=tensorlist,
			text=textlist,
			embedding=emblist,
			prefix="dev")
		self.testSummary = self.summaryHelper.addGroup(
			scalar=scalarlist,
			tensor=tensorlist,
			text=textlist,
			embedding=emblist,
			prefix="test")

	def _preprocess_batch(self, data, device=None):
		"""Wrap a raw batch in ``Storage``, record its shape and move ``sent`` to ``device``.

		Returns a ``Storage`` whose ``data`` field holds the batch with
		``batch_size``/``seqlen`` taken from ``data.sent.shape`` and ``sent``
		converted to a ``torch.LongTensor``.
		"""
		incoming = Storage()
		incoming.data = data = Storage(data)
		data.batch_size = data.sent.shape[0]
		data.seqlen = data.sent.shape[1]
		data.sent = cuda(torch.LongTensor(data.sent), device=device)
		return incoming

	def train(self, batch_num):
		"""Run ``batch_num`` alternating generator / discriminator updates.

		Each iteration draws one batch for a generator step and a fresh batch for a
		discriminator step, writes the merged results to the train summary, logs
		both losses and finally updates the EMA weights.
		"""
		args = self.param.args
		dm = self.param.volatile.dm
		datakey = 'train'

		for _ in range(batch_num):
			self.now_batch += 1

			# Generator step.
			incoming = self.get_next_batch(dm, datakey)
			incoming.args = Storage()
			self.zero_grad()
			self.net.G_forward(incoming)
			GLoss = incoming.result.loss
			GResult = incoming.result
			GLoss.backward()
			# Clip only the parameters this optimizer updates. ``nn.Module.parameters``
			# takes a ``recurse`` flag rather than a name, so (unless Network overrides
			# it) ``parameters("gen")`` would clip over every network parameter,
			# letting discriminator gradients shrink the generator update.
			nn.utils.clip_grad_norm_(self.net.get_parameters_by_name("gen"), args.grad_clip)
			self.genOptimizer.step()

			# Discriminator step on a fresh batch.
			incoming = self.get_next_batch(dm, datakey)
			incoming.args = Storage()
			self.zero_grad()
			self.net.D_forward(incoming)
			DLoss = incoming.result.loss
			DResult = incoming.result
			DLoss.backward()
			nn.utils.clip_grad_norm_(self.net.get_parameters_by_name("dis"), args.grad_clip)
			self.disOptimizer.step()

			# Generator entries overwrite discriminator entries on key clashes.
			DResult.update(GResult)
			self.trainSummary(self.now_batch, storage_to_list(DResult))
			logging.info("batch %d : dis loss=%f, gen loss=%f", self.now_batch,
				DLoss.item(), GLoss.item())

			self.net.emaHelper.update()

	def evaluate(self, key):
		"""Evaluate the EMA weights on split ``key`` and return a ``Storage`` of results.

		Contents: forward/backward BLEU and FED (both on a 200-sample subset and
		both referenced against split ``key``), the mean of every tensor-valued
		entry of the per-batch results, the ``show_str%d`` sample strings for
		``args.show_sample`` and the current generator/discriminator learning rates.
		The training weights are restored before returning.
		"""
		args = self.param.args
		dm = self.param.volatile.dm

		self.net.emaHelper.load()

		metric = MetricChain()
		# only test a few samples for saving time
		metric.add_metric(FwBwBleuCorpusMetric(dm, dm.get_all_batch(key)["sent_allvocabs"], sample=200))
		# FED is referenced against the evaluated split as well. It previously always
		# used "test", so dev-time checkpoint selection (train_process) was driven
		# by the test set.
		metric.add_metric(FedMetric(dm, dm.get_all_batch(key)['sent_allvocabs'], gen_key="gen", sample=200))

		result_arr = []
		for incoming in self.get_batches(dm, key)[1]:
			with torch.no_grad():
				self.net.detail_forward(incoming)
			result_arr.append(incoming.result)
			data = incoming.data
			data.gen = convert_to_sentence(incoming.genIncoming.gen.logits, data.sent_length, dm)
			metric.forward(data)

		detail_arr = Storage()
		for i in args.show_sample:
			index = [i * args.batch_size + j for j in range(args.batch_size)]
			incoming = self.get_select_batch(dm, key, index)
			incoming.args = Storage()
			with torch.no_grad():
				self.net.detail_forward(incoming)
			detail_arr["show_str%d" % i] = incoming.result.show_str

		detail_arr.update(metric.close())
		if result_arr:
			# Average every tensor-like entry (anything exposing ``detach``).
			first = result_arr[0]
			detail_arr.update({name: get_mean(result_arr, name) for name in first if "detach" in dir(first[name])})

		detail_arr['gen_lr'] = self.genOptimizer.param_groups[0]['lr']
		detail_arr['dis_lr'] = self.disOptimizer.param_groups[0]['lr']

		self.net.emaHelper.restore()

		return detail_arr

	def train_process(self):
		"""Train until ``args.epochs`` epochs are done, evaluating on dev after each epoch.

		A checkpoint is saved after every epoch with the dev FED as its value.
		"""
		args = self.param.args
		dm = self.param.volatile.dm

		while self.now_epoch < args.epochs:
			self.now_epoch += 1
			self.updateOtherWeights()

			dm.restart('train', args.batch_size)
			self.net.train()
			self.train(args.batch_per_epoch)

			#self.net.eval() # use dropout at test time!
			devloss_detail = self.evaluate("dev")
			self.devSummary(self.now_batch, devloss_detail)
			logging.info("epoch %d, evaluate dev", self.now_epoch)

			# testloss_detail = self.evaluate("test")
			# self.testSummary(self.now_batch, testloss_detail)
			# logging.info("epoch %d, evaluate test", self.now_epoch)

			self.save_checkpoint(value=devloss_detail['FED'])

	def test(self, key, evaluate_samples):
		"""Generate sentences from split ``key`` with the EMA weights and write them to disk.

		Batches are generated until at least ``evaluate_samples`` sentences exist.
		The first ``evaluate_samples`` are then written space-joined, one per line,
		to ``<out_dir>/<name>.txt``.
		"""
		args = self.param.args
		dm = self.param.volatile.dm

		self.net.emaHelper.load()

		res = Storage()
		metric = MetricChain()
		# use standalone evaluator, disable the metrics
		#metric.add_metric(NgramFwBwPerplexityMetric(dm, dm.get_all_batch('test')["sent_allvocabs"], 5, gen_key="sent"))
		#metric.add_metric(FwBwBleuCorpusMetric(dm, dm.get_all_batch('test')["sent_allvocabs"], sample=5000))
		#metric.add_metric(SelfBleuCorpusMetric(dm, sample=5000))
		metric.add_metric(LanguageGenerationRecorder(dm, gen_key="gen"))

		output_samples = 0
		dm.restart(key, args.batch_size)
		while output_samples < evaluate_samples:
			incoming = self.get_next_batch(dm, key)
			incoming.args = Storage()

			with torch.no_grad():
				self.net.detail_forward(incoming)
			data = incoming.data
			data.gen = convert_to_sentence(incoming.genIncoming.gen.logits, data.sent_length, dm)
			metric.forward(data)

			output_samples += data.gen.shape[0]
		res.update(metric.close())

		os.makedirs(args.out_dir, exist_ok=True)
		filename = args.out_dir + "/%s.txt" % (args.name)

		with open(filename, 'w', encoding='utf-8') as f:
			f.writelines(" ".join(res['gen'][j]) + "\n" for j in range(evaluate_samples))
		logging.info("result output to %s.", filename)

		self.net.emaHelper.restore()

	def test_process(self):
		"""Write 5000 generated test-split sentences to disk (dropout kept on)."""
		logging.info("Test Start.")
		self.net.train() # use dropout at test time
		self.test("test", 5000)
		logging.info("Test Finish.")

	def interpolate(self):
		"""Print sentences decoded along straight lines between pairs of random latent codes.

		For 10 random pairs ``(z1, z2)`` it prints the n+1 = 11 sentences decoded
		from ``(z1 * i + z2 * (n - i)) / n`` for ``i = 0..n``, moving from ``z2`` to ``z1``.
		"""
		L = 13  # length of generated sentences
		n = 10  # numbers of interpolated sentences

		self.net.emaHelper.load()
		self.net.eval() # disable dropout

		for _ in range(10):
			# presumably a normal sample truncated at +/-2 — see utils.truncated_normal
			z1 = Tensor(truncated_normal(1, L, self.args.z_size, threshold=2))
			z2 = Tensor(truncated_normal(1, L, self.args.z_size, threshold=2))
			for i in range(n+1):
				z = (z1 * i + z2 * (n - i)) / n
				sent, _ = self.net.generate_from_z(z)
				print(" ".join(sent[0]))

	def record(self):
		"""Decode 200 batches of 128 random latent codes and pickle them with their sentences.

		Writes ``(z_list, sent_list)`` to ``<out_dir>/record_<name>.pkl``, the input
		consumed by ``walkgrad`` and ``adddiff``.
		"""
		batch_size = 128
		L = 13  # length of generated sentences

		self.net.emaHelper.load()
		self.net.eval() # disable dropout

		z_list = []
		sent_list = []
		for sample in range(200):
			print(sample)
			z = Tensor(truncated_normal(batch_size, L, self.args.z_size, threshold=2))
			sent, _ = self.net.generate_from_z(z)

			z_list.extend(z.tolist())
			sent_list.extend(sent)

		os.makedirs(self.args.out_dir, exist_ok=True)
		with open("%s/record_%s.pkl" % (self.args.out_dir, self.args.name), "wb") as f:
			pkl.dump((z_list, sent_list), f)

	def walkgrad(self):
		"""Try to turn ``source`` into ``target`` in recorded sentences by gradient steps on z.

		Each recorded sentence containing ``source`` but not ``target`` is handled
		as follows. Its latent code is copied once per position and once per
		optimizer in a group of four. Each copy is then optimised for 100 steps so
		that its assigned position predicts ``target``, while the other positions
		keep their original tokens at weight 0.1. Among all decoded candidates
		containing ``target``, the one with the highest ``lcs`` overlap with the
		original sentence counts as the result. Stops after 100 successes and
		prints the success rate and mean overlap.
		"""
		source = "black"  # source word
		target = "white"  # target word

		self.net.emaHelper.load()
		self.net.eval()  # disable dropout

		print("loading")
		# NOTE: pickle can execute arbitrary code; only load files written by record().
		with open("%s/record_%s.pkl" % (self.args.out_dir, self.args.name), 'rb') as f:
			Z, sentl = pkl.load(f)
		print("loaded")

		success = 0
		all_num = 0
		overlaps_list = []
		target_id = self.param.volatile.dm.frequent_vocab_list.index(target)

		for z, s in zip(Z, sentl):
			if source not in s or target in s:
				continue

			ori_z = Tensor(np.array(z)).unsqueeze(0)
			sent1, _ = self.net.generate_from_z(ori_z)
			print(" ".join(sent1[0]))
			assert source in sent1[0]

			# One candidate per sentence position, replicated for each of 4 optimizers.
			batch_size = ori_z.shape[1]
			group = 4

			z1 = ori_z.clone().expand(batch_size, -1, -1).clone().detach()  # try replace each place in a batch
			z2 = z1.clone().detach()
			z3 = z1.clone().detach()
			z4 = z1.clone().detach()
			for candidate in (z1, z2, z3, z4):
				candidate.requires_grad = True

			last_sent = ["" for __ in range(batch_size * group)]

			# try multiple optimizers
			optimizer1 = optim.Adam([z1], lr=1e-3)
			optimizer2 = optim.SGD([z2], lr=1e-2)
			optimizer3 = optim.SGD([z3], lr=1e-3, momentum=0.5)
			# z4 is stepped by hand with a per-position learning rate; this optimizer
			# is only used to zero its gradient.
			optimizer4 = optim.SGD([z4], lr=1e-2)
			ori_lr = np.array([1e-2 for _ in range(batch_size)])
			lr = np.array([1e-2 for _ in range(batch_size)])
			# Row r of every group edits position r % batch_size.
			indices = LongTensor(np.arange(batch_size)).unsqueeze(-1).repeat(group, 1)
			ori_w = None

			res_sent = []
			res_ans = -1

			for _ in range(100):
				with torch.enable_grad():
					z_all = torch.cat([z1, z2, z3, z4], dim=0)
					sent, logits = self.net.generate_from_z(z_all)
					if ori_w is None:
						# Tokens decoded from the still-unmodified copies of z.
						ori_w = logits.max(dim=-1)[1].unsqueeze(-1) # batch * len

					now_sent = [" ".join(sent[j]) for j in range(batch_size * group)]

					# Keep the target-containing candidate closest to the original sentence.
					for i in range(batch_size * group):
						if target in sent[i]:
							ans = lcs(sent1[0], sent[i])
							if ans > res_ans:
								res_sent = sent[i]
								res_ans = ans

					# ``np.float`` was a deprecated alias of ``float`` and was removed in NumPy 1.24.
					flag = np.array([last_sent[j] == now_sent[j] for j in range(batch_size*3, batch_size*4)], dtype=float)
					lr = lr * 2 * flag + ori_lr * (1-flag)  # if the sentence doesn't change, double the lr; otherwise, lr = 1e-2
					last_sent = now_sent

					loss_matrix = -F.log_softmax(logits, dim=-1)
					keep_loss = torch.gather(loss_matrix, 2, ori_w)[:, :, 0] # batch * len
					minus_keep_loss = torch.gather(keep_loss, 1, indices)[:, 0] # batch

					change_loss = loss_matrix[:, :, target_id] # batch * len
					change_loss = torch.gather(change_loss, 1, indices)[:, 0] # batch
					# Keep every non-edited position (weight 0.1), push the edited one to target.
					loss = (keep_loss.sum() - minus_keep_loss.sum()) * 0.1 + change_loss.sum()

				optimizer1.zero_grad()
				optimizer2.zero_grad()
				optimizer3.zero_grad()
				optimizer4.zero_grad()
				loss.backward()
				optimizer1.step()
				optimizer2.step()
				optimizer3.step()
				z4.data.add_(-z4.grad.data * Tensor(lr).unsqueeze(-1).unsqueeze(-1))

			all_num += 1

			if target in res_sent:
				op = lcs(sent1[0], res_sent)
				overlaps_list.append(op)
				print("success", op, " ".join(res_sent))
				success += 1
				if success >= 100:
					break
			else:
				print("failed")

		if all_num == 0:
			print("no recorded sentence contains %r without %r" % (source, target))
			return
		print("success rate", success / all_num)
		print("overlaps", np.mean(overlaps_list) if overlaps_list else float("nan"))

	def adddiff(self):
		"""Try to turn ``source`` into ``target`` by adding a fixed latent offset.

		The offset is the mean recorded z of sentences containing only ``target``
		minus the mean z of those containing only ``source``. Each ``source``
		sentence's z is shifted by ``0.2 * offset`` up to 50 times, stopping once the
		decoded sentence contains ``target``. Stops after 100 successes and prints
		the success rate and mean ``lcs`` overlap.
		"""
		source = "black"  # source word
		target = "white"  # target word

		self.net.emaHelper.load()
		self.net.eval()  # disable dropout

		print("loading")
		# NOTE: pickle can execute arbitrary code; only load files written by record().
		with open("%s/record_%s.pkl" % (self.args.out_dir, self.args.name), 'rb') as f:
			Z, sentl = pkl.load(f)
		print("loaded")

		source_zs = [z for z, s in zip(Z, sentl) if source in s and target not in s]
		target_zs = [z for z, s in zip(Z, sentl) if target in s and source not in s]
		print(len(source_zs), len(target_zs))
		if not source_zs or not target_zs:
			# np.mean of an empty list would yield a NaN offset.
			print("not enough recorded sentences to estimate the %r -> %r offset" % (source, target))
			return
		diff = np.mean(target_zs, 0) - np.mean(source_zs, 0)
		step = Tensor(diff) * 0.2

		success = 0
		all_num = 0
		overlaps_list = []

		for z, s in zip(Z, sentl):
			if source not in s or target in s:
				continue

			ori_z = Tensor(np.array(z)).unsqueeze(0)
			sent1, _ = self.net.generate_from_z(ori_z)
			assert source in sent1[0]
			print(" ".join(sent1[0]))

			for _attempt in range(50):
				ori_z = ori_z + step
				sent2, _ = self.net.generate_from_z(ori_z)
				if target in sent2[0]:
					break

			all_num += 1
			if target in sent2[0]:
				success += 1
				# Record and print before the early stop so the 100th success counts too.
				overlaps_list.append(lcs(sent1[0], sent2[0]))
				print(" ".join(sent2[0]))
				if success >= 100:
					break
			else:
				print("failed")

		print("success rate", success / all_num)
		print("precision", np.mean(overlaps_list) if overlaps_list else float("nan"))


def lcs(str_a, str_b):
	"""Return the fraction of the longer sequence covered by ``difflib`` matching blocks.

	Works on any two sequences of hashable items (strings or token lists). Matches
	come from :class:`difflib.SequenceMatcher`, so the count is the total size of
	its matching blocks. That is not necessarily the exact longest common
	subsequence. The count is normalised by ``max(len(str_a), len(str_b))``.

	Two empty inputs are treated as identical and give ``1.0``; previously this
	case raised ``ZeroDivisionError``.
	"""
	from difflib import SequenceMatcher
	longest = max(len(str_a), len(str_b))
	if longest == 0:
		return 1.0
	matcher = SequenceMatcher(None, str_a, str_b)
	# Each block's slice of str_a has exactly ``block.size`` items, so summing the
	# sizes equals counting the matched items without building the list.
	matched = sum(block.size for block in matcher.get_matching_blocks())
	return float(matched) / longest

def convert_to_sentence(logits, lengths, dm):
	logits = logits.clone()
	logits[:, :, dm.unk_id] -= 1e8
	logits[:, :, dm.go_id] -= 1e8
	logits[:, :, dm.eos_id] -= 1e8
	sent = logits.max(dim=-1)[1].detach().cpu().numpy()

	for i, sent_length in enumerate(lengths - 1):
		j = 0
		for k in range(sent_length):
			if j == 0 or sent[i][j-1] != sent[i][k]:
				sent[i][j] = sent[i][k]
				j += 1
		if j < sent.shape[1]:
			sent[i][j] = dm.eos_id
	return sent