# coding:utf-8
import logging

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.utils as U
from torch.nn.init import uniform_
from torch.autograd import grad

from utils import zeros, generateMask, Tensor, cuda, \
			BaseNetwork, Storage, gumbel_max, straight_max, \
			BaseModule, reshape, EMAHelper, sequence_pooling
from utils.transformer_learnable_embedding import TransformerDecoder, TransformerEncoder, positional_embeddings, positional_encodings


# pylint: disable=W0221
class Network(BaseNetwork):
	"""GAN wrapper that owns a generator (``gen``) and a discriminator (``dis``).

	Every ``*_forward`` method receives an ``incoming`` Storage whose ``data``
	field holds the real batch, and writes its outputs to ``incoming.result``.
	"""

	def __init__(self, param):
		super().__init__(param, ['gen', 'dis'])

		self.genNetwork_gen = GenNetwork(param)
		self.disNetwork_dis = DisNetwork(param)

		# EMA helper is built over the parameters registered under the name "gen".
		self.emaHelper = EMAHelper(param, list(self.get_parameters_by_name("gen", silent=True)))

	@staticmethod
	def _raise_on_nan(result):
		"""Log ``result`` and raise FloatingPointError if ``result.loss`` contains NaN.

		``.any()`` keeps the check valid even if the loss is not a scalar.
		"""
		if bool(torch.isnan(result.loss).any()):
			logging.info("Nan detected")
			logging.info(result)
			raise FloatingPointError("Nan detected")

	def createGenIncoming(self, incoming): # for generator
		"""Create the generator input, attach it as ``incoming.genIncoming`` and return it.

		The new Storage shares ``incoming.data`` (same object, not a copy), so the
		sampled noise is also stored on ``incoming.data.z``.
		"""
		incoming.genIncoming = genIncoming = Storage()
		data = genIncoming.data = incoming.data
		# Standard-normal noise of shape (batch_size, seqlen - 1, z_size).
		data.z = Tensor(np.random.randn(data.batch_size, data.seqlen - 1, self.args.z_size), device=data.sent)
		genIncoming.args = Storage()
		return genIncoming

	def createFakeIncoming(self, incoming, genIncoming): # for discriminator
		"""Wrap the generator output as a discriminator input (``incoming.fakeIncoming``).

		The embedding fed to the discriminator is the generator's per-position
		vocabulary weights multiplied by the discriminator's embedding matrix.
		"""
		data = incoming.data
		incoming.fakeIncoming = fakeIncoming = Storage()
		fakeIncoming.args = Storage()
		fakeIncoming.data = fakeData = Storage()

		fakeData.seqlen = genIncoming.data.seqlen
		fakeData.batch_size = data.batch_size
		fakeData.sent_length = genIncoming.data.sent_length

		fakeData.onehot = genIncoming.gen.log_prob
		fakeData.embedding = torch.einsum("ijk,kl->ijl", genIncoming.gen.w_prob_gumbel, self.disNetwork_dis.embLayer.weight)
		return fakeIncoming

	def createInterpolatedIncoming(self, incoming, genIncoming): # for discriminator
		"""Build a random real/fake mixture used for the gradient penalty.

		The result is attached as ``incoming.interpolatedIncoming``. Its
		``data.onehot_e`` is a detached leaf with ``requires_grad`` set, so the
		caller can differentiate the discriminator output with respect to it.
		"""
		data = incoming.data
		incoming.interpolatedIncoming = intIncoming = Storage()
		intIncoming.args = Storage()
		intIncoming.data = intData = Storage()

		batch_size = intData.batch_size = data.batch_size
		intData.seqlen = genIncoming.data.seqlen
		intData.sent_length = genIncoming.data.sent_length

		# One mixing coefficient per sample, broadcast over positions and vocabulary.
		p = zeros(batch_size).uniform_()
		mix = p.unsqueeze(-1).unsqueeze(-1)

		real_onehot = zeros(data.batch_size, data.seqlen, self.param.volatile.dm.frequent_vocab_size).scatter_(-1, data.sent.unsqueeze(-1), 1)
		intData.onehot = real_onehot.detach() * mix + \
				genIncoming.gen.w_prob_gumbel.detach() * (1 - mix)

		intData.onehot_e = intData.onehot.detach().requires_grad_()
		intData.embedding = torch.einsum("ijk,kl->ijl", intData.onehot_e, self.disNetwork_dis.embLayer.weight)
		return intIncoming

	def G_forward(self, incoming):
		"""Generator step: sets ``incoming.result.loss`` / ``gen_loss``.

		Raises:
			FloatingPointError: if the loss contains NaN.
		"""
		# createGenIncoming already attached a fresh ``args`` Storage.
		genIncoming = self.createGenIncoming(incoming)
		genIncoming.args.non_differentiable_mode = self.args.gen_approx_mode  # add gumbel_noise for non-differentiable operations
		incoming.result = Storage()
		self.genNetwork_gen(genIncoming)

		fakeIncoming = self.createFakeIncoming(incoming, genIncoming)
		fakeIncoming.loss_mode = "gen"
		self.disNetwork_dis.forward_soft(fakeIncoming)

		incoming.result.loss = incoming.result.gen_loss = fakeIncoming.dis.loss

		self._raise_on_nan(incoming.result)

	def D_forward(self, incoming):
		"""Discriminator step: sets ``incoming.result.loss`` / ``dis_loss`` and diagnostics.

		``dis_loss`` = real loss + fake loss + ``gp_loss * args.gp``
		+ ``d_regularizer * args.dr``.

		Raises:
			FloatingPointError: if the loss contains NaN.
		"""
		genIncoming = self.createGenIncoming(incoming)
		genIncoming.args.non_differentiable_mode = self.args.dis_approx_mode  # add gumbel_noise for non-differentiable operations

		incoming.result = result = Storage()
		# The generator is only sampled here; no graph is recorded for it.
		with torch.no_grad():
			self.genNetwork_gen(genIncoming)

		incoming.loss_mode = "real"
		self.disNetwork_dis.forward_sent(incoming)

		fakeIncoming = self.createFakeIncoming(incoming, genIncoming)
		fakeIncoming.loss_mode = "fake"
		self.disNetwork_dis.forward_soft(fakeIncoming)

		intIncoming = self.createInterpolatedIncoming(incoming, genIncoming)
		intIncoming.loss_mode = "fake"
		self.disNetwork_dis.forward_soft(intIncoming)

		# Gradient penalty: masked squared gradient summed per sample, then the
		# largest value in the batch is penalised.
		gradients = grad(intIncoming.dis.predict.sum(), [intIncoming.data.onehot_e], retain_graph=True, create_graph=True, only_inputs=True)[0]
		# presumably (batch, seqlen) after the transpose -- TODO confirm against generateMask
		mask = generateMask(intIncoming.data.seqlen, intIncoming.data.sent_length).transpose(0, 1)
		gradients_norm = ((gradients ** 2).sum(-1) * mask).sum(-1)
		result.gp_loss = gradients_norm.max()

		result.d_regularizer = incoming.dis.d_regularizer + fakeIncoming.dis.d_regularizer + intIncoming.dis.d_regularizer

		result.dis_loss = incoming.dis.loss + fakeIncoming.dis.loss + \
						result.gp_loss * self.args.gp + result.d_regularizer * self.args.dr
		result.real_v = incoming.dis.predict.mean()
		result.fake_v = fakeIncoming.dis.predict.mean()
		result.int_v = intIncoming.dis.predict.mean()

		result.loss = result.dis_loss

		self._raise_on_nan(result)

	def detail_forward(self, incoming):
		"""Evaluation pass: detached discriminator statistics plus generated text.

		Writes ``dis_loss``, ``fake_v``, ``real_v`` and ``show_str`` (one
		``"gen: ..."`` line per sample) to ``incoming.result``.
		"""
		genIncoming = self.createGenIncoming(incoming)
		genIncoming.args.non_differentiable_mode = "max"  # do not add gumbel_noise for non-differentiable operations
		incoming.result = result = Storage()
		self.genNetwork_gen(genIncoming)

		incoming.loss_mode = "real"
		self.disNetwork_dis.forward_sent(incoming)

		fakeIncoming = self.createFakeIncoming(incoming, genIncoming)
		fakeIncoming.loss_mode = "fake"
		fakeIncoming.result = Storage()
		self.disNetwork_dis.forward_soft(fakeIncoming)

		result.dis_loss = (incoming.dis.loss + fakeIncoming.dis.loss).detach()
		result.fake_v = fakeIncoming.dis.predict.mean().detach()
		result.real_v = incoming.dis.predict.mean().detach()

		batch_size = incoming.data.batch_size
		dm = self.param.volatile.dm
		sent = [dm.convert_ids_to_sentence(ids) for ids in genIncoming.gen.w.detach().cpu().numpy()]

		result.show_str = "\n".join("gen: " + sent[i] for i in range(batch_size))

	def generate_from_z(self, z):
		"""Decode sentences from a given noise tensor ``z``.

		``batch_size`` and ``seqlen`` are derived from ``z.shape`` (``seqlen`` is
		``z.shape[1] + 1``); every sample is given the full length.

		Returns:
			A pair ``(sent, logits)``: ``sent`` is a list with one
			``convert_ids_to_tokens`` result per sample, decoded by arg-max with
			the unk id suppressed; ``logits`` are the generator's unmodified logits.
		"""
		incoming = Storage()
		incoming.data = data = Storage()
		data.z = z
		data.seqlen = z.shape[1] + 1
		data.batch_size = z.shape[0]
		data.sent_length = np.array([data.seqlen] * data.batch_size)
		incoming.args = Storage()
		incoming.args.non_differentiable_mode = "max"  # do not add gumbel_noise for non-differentiable operations
		incoming.result = Storage()
		self.genNetwork_gen(incoming)

		dm = self.param.volatile.dm
		# Work on a copy so the returned logits stay untouched.
		logits = incoming.gen.logits.clone()
		logits[:, :, dm.unk_id] -= 1e8  # disable unk
		word_without_unk = logits.max(dim=-1)[1]
		w = word_without_unk.detach().cpu().numpy()
		sent = [dm.convert_ids_to_tokens(ids) for ids in w]
		return sent, incoming.gen.logits

class GenNetwork(BaseModule):
	"""Generator: projects the noise ``z`` and decodes all positions with a NATLayer."""

	def __init__(self, param):
		super().__init__()
		self.args = args = param.args
		self.param = param

		self.initLayer = nn.Linear(args.z_size, args.tf_size)
		self.natLayer = NATLayer(param)

	def forward(self, incoming):
		"""Run the generator and store its outputs on ``incoming.gen``.

		Reads ``incoming.data`` (``z``, ``batch_size``, ``seqlen``, ``sent_length``)
		and ``incoming.args.non_differentiable_mode``.

		Writes to ``incoming.gen``:
			logits: NATLayer output (the go position is NOT included).
			log_prob: softmax of ``logits`` with a one-hot go token prepended
				(despite the name these are probabilities, not log-probabilities).
			w_prob_gumbel: per-position vocabulary weights chosen by the mode,
				with the same one-hot go token prepended.
			w: arg-max word ids of ``w_prob_gumbel``.

		Modes: ``"max"`` uses ``straight_max``; ``"soft"`` uses the softmax
		probabilities; anything else uses ``gumbel_max`` with ``args.tau``.
		"""
		gen = incoming.gen = Storage()
		batch_size = incoming.data.batch_size
		# The first position is the fixed go token, so one fewer position is generated.
		seqlen = incoming.data.seqlen - 1
		sent_length = incoming.data.sent_length - 1
		z = self.initLayer(incoming.data.z)

		gen.logits = self.natLayer.forward(z, batch_size, seqlen, sent_length)
		probs = gen.logits.softmax(dim=-1)

		dm = self.param.volatile.dm
		front_pad_go = zeros(batch_size, 1, dm.frequent_vocab_size)
		front_pad_go[:, 0, dm.go_id] = 1
		gen.log_prob = torch.cat([front_pad_go, probs], dim=1)

		mode = incoming.args.non_differentiable_mode
		if mode == "max":
			w_prob = straight_max(gen.logits)[0]
		elif mode == "soft":
			# Fix: this branch used to read the non-existent attribute
			# ``gen.logitsh`` and raised AttributeError.
			# NOTE(review): softmax probabilities are used so each row is a
			# distribution like the prepended one-hot go token -- confirm that
			# raw logits were not the intent.
			w_prob = probs
		else:  # add gumbel noise and use straight-through estimator
			w_prob = gumbel_max(gen.logits, tau=self.args.tau)[0]

		gen.w_prob_gumbel = torch.cat([front_pad_go, w_prob], dim=1)
		gen.w = gen.w_prob_gumbel.max(dim=-1)[1]

class NATLayer(nn.Module):
	"""Transformer decoder that maps a projected noise sequence to vocabulary logits."""

	def __init__(self, param):
		super().__init__()
		self.param = param
		self.args = args = param.args

		# Output projection; also handed to the decoder as its vocab-attention layer.
		self.out = nn.Linear(args.tf_size, param.volatile.dm.frequent_vocab_size)

		# NOTE(review): the decoder depth comes from ``args.n_dislayers`` while
		# forward() builds ``args.n_layers`` memory entries -- confirm the two
		# are meant to differ.
		self.decoder = TransformerDecoder(
			args.tf_size, args.tf_size, args.tf_hidden_size, args.n_heads, args.n_dislayers,
			args.droprate, args.input_droprate,
			windows=args.windows, attend_mode="full",
			vocabulary_attention=args.va, relative_clip=args.rc)

		# Both learnable position tables start from the same positional encodings.
		init_emb = positional_encodings(args.max_sent_length, args.tf_size // 2)[0]
		for table_name in ("first_emb", "reverse_emb"):
			setattr(self, table_name, nn.Parameter(torch.Tensor(init_emb.detach().cpu().numpy())))

	def forward(self, src, batch_size, seqlen, sent_length):
		"""Return the logits computed from the last element of the decoder output.

		``batch_size`` is accepted for interface compatibility but not used here.
		"""
		position_term = positional_embeddings(seqlen, self.first_emb, self.reverse_emb, sent_length)
		src_pos = src + position_term
		# attend to z: the same (position-free) source is offered n_layers times
		srcs = [src] * self.args.n_layers
		decoded = self.decoder(src_pos, srcs, sent_length, sent_length, vocab_attention_layer=self.out)
		return self.out(decoded[-1])

class DisNetwork(BaseModule):
	"""Discriminator: word embedding followed by an OutputLayer scorer."""

	def __init__(self, param):
		super().__init__()
		self.param = param
		self.args = args = param.args

		vocab_size = param.volatile.dm.frequent_vocab_size
		self.embLayer = nn.Embedding(vocab_size, args.embedding_size)
		# Replace the random initialisation with the supplied word vectors.
		self.embLayer.weight.data = torch.Tensor(param.volatile.wordvec)
		self.outputLayer = OutputLayer(param)

	def forward_sent(self, incoming):
		"""Score token ids: embeds ``incoming.data.sent`` and runs the output layer."""
		data = incoming.data
		data.embedding_input = self.embLayer(data.sent)
		self.outputLayer.forward_sent(incoming)

	def forward_soft(self, incoming):
		"""Score a precomputed embedding taken from ``incoming.data.embedding``."""
		data = incoming.data
		data.embedding_input = data.embedding
		self.outputLayer.forward_soft(incoming)

class OutputLayer(nn.Module):
	"""Transformer-encoder scorer that produces one logit per sequence."""

	def __init__(self, param):
		super().__init__()
		self.args = args = param.args
		self.param = param

		self.inputLayer = nn.Linear(args.embedding_size, args.tf_size)

		self.module = TransformerEncoder(args.tf_size, args.tf_hidden_size, args.n_heads, args.n_layers,
			args.droprate, args.input_droprate, relative_clip=args.rc)

		# Both learnable position tables start from the same positional encodings.
		init_emb = positional_encodings(args.max_sent_length, args.tf_size // 2)[0]
		self.first_emb = nn.Parameter(torch.Tensor(init_emb.detach().cpu().numpy()))
		self.reverse_emb = nn.Parameter(torch.Tensor(init_emb.detach().cpu().numpy()))

		self.fc = nn.Linear(args.tf_size, 1)
		self.drop = nn.Dropout(args.droprate)

	def forward(self, incoming):
		"""Score ``incoming.data.embedding_input`` and fill ``incoming.dis``.

		``incoming.dis`` must already exist (``forward_sent`` / ``forward_soft``
		create it). Fields written: ``predict`` (raw logits), ``activate``
		(per-sample loss), ``d_regularizer`` (mean squared logit) and ``loss``
		(mean of ``activate``).

		``incoming.loss_mode`` selects the loss:
			"real" / "gen": ``-logsigmoid(logits)``
			"fake": ``-logsigmoid(-logits)``

		Raises:
			ValueError: if ``incoming.loss_mode`` is none of the above.
		"""
		seqlen = incoming.data.seqlen
		length = incoming.data.sent_length

		x = self.inputLayer(incoming.data.embedding_input)   # batch * seq * tf_size
		x = x + positional_embeddings(seqlen, self.first_emb, self.reverse_emb, length)
		z = self.module(x, length=length)[-1]

		# Max-pool over the sequence, then project to a single logit per sample.
		logits = self.fc(sequence_pooling(z, length, "max"))[:, 0]

		dis = incoming.dis
		dis.predict = logits
		loss_mode = incoming.loss_mode
		if loss_mode in ("real", "gen"):
			dis.activate = - F.logsigmoid(logits)
		elif loss_mode == "fake":
			dis.activate = - F.logsigmoid(-logits)
		else:
			# Previously an unknown mode left ``activate`` unset and failed on
			# the ``.mean()`` below with an unrelated error.
			raise ValueError("unknown loss_mode: %r (expected 'real', 'fake' or 'gen')" % (loss_mode,))

		# an idea from https://arxiv.org/pdf/1909.13188.pdf
		dis.d_regularizer = (dis.predict ** 2).mean()
		dis.loss = dis.activate.mean()

	def forward_sent(self, incoming):
		"""Reset ``incoming.dis`` to a fresh Storage and run ``forward``."""
		incoming.dis = Storage()
		self.forward(incoming)

	def forward_soft(self, incoming):
		"""Reset ``incoming.dis`` to a fresh Storage and run ``forward``."""
		incoming.dis = Storage()
		self.forward(incoming)
