import torch
import torch.nn as nn
from util import log
import numpy as np
from modules import *

# Output directory for ablation artifacts written by Model._similarity (the
# similarity_matrix.npy dump) and Model._plot (per-sample PNGs). Paths are
# built by plain string concatenation (f'{ablation_dir}...'), so the trailing
# slash is required.
ablation_dir = 'out/ONiceNet/ablation/'


class NICELayer(nn.Module):
	"""Additive coupling layer (NICE-style) that splits features by parity.

	The input is split into its even-indexed and odd-indexed features. One half
	is kept unchanged; the other half is shifted by ``m(W, kept_half)``, where
	``W`` is a per-sample weight tensor supplied by the caller (it is not a
	parameter of this layer). The two halves are then re-interleaved.

	:param keep: ``'even'`` keeps the even-indexed half and updates the odd
		half; ``'odd'`` does the reverse. Any other value raises ``ValueError``
		in ``forward``/``inverse``.
	"""

	def __init__(self, keep):
		super(NICELayer, self).__init__()
		self.keep = keep
		# Registered as submodules (instead of being created inside ``m`` on
		# every call) so that ``.eval()`` actually disables dropout. Neither
		# module has parameters or buffers, so the state_dict is unchanged.
		self.dropout = nn.Dropout(0.5)
		self.activation = nn.LeakyReLU(negative_slope=0.5)

	def forward(self, W, x, device):
		"""
		:param W: of size (batch_size, x_size // 2, x_size // 2); multiplies the kept half
		:param x: of size (batch_size, x_size); x_size must be even
		:param device: passed through to ``m`` (not used there)
		:return: output of size (batch_size, x_size)
		"""
		x_even, x_odd = x[:, 0::2], x[:, 1::2]
		if self.keep == 'even':
			y_even = x_even
			y_odd = x_odd + self.m(W=W, x=x_even, device=device)
		elif self.keep == 'odd':
			y_odd = x_odd
			y_even = x_even + self.m(W=W, x=x_odd, device=device)
		else:
			raise ValueError('Keep should be in {even, odd}')
		return self._odd_even_concat(z_even=y_even, z_odd=y_odd)

	def inverse(self, W, y, device):
		"""Undo ``forward`` for the same ``W``.

		Exact only when dropout is inactive (eval mode); in training mode the
		forward and inverse passes draw independent dropout masks.

		:param W: the same weight tensor that was passed to ``forward``
		:param y: of size (batch_size, x_size)
		:param device: passed through to ``m`` (not used there)
		:return: reconstructed input of size (batch_size, x_size)
		"""
		y_even, y_odd = y[:, 0::2], y[:, 1::2]
		if self.keep == 'even':
			x_even = y_even
			x_odd = y_odd - self.m(W=W, x=y_even, device=device)
		elif self.keep == 'odd':
			x_odd = y_odd
			x_even = y_even - self.m(W=W, x=y_odd, device=device)
		else:
			raise ValueError('Keep should be in {even, odd}')
		return self._odd_even_concat(z_even=x_even, z_odd=x_odd)

	def m(self, W, x, device):
		"""Coupling function ``dropout(leaky_relu(W @ x))``, applied per sample.

		:param W: of size (batch_size, out_size, x_size)
		:param x: of size (batch_size, x_size)
		:param device: unused
		:return: of size (batch_size, out_size)
		"""
		out = self.activation(torch.bmm(W, x.unsqueeze(-1))).squeeze(-1)
		return self.dropout(out)

	def _consecutive_concat(self, z_1, z_2):
		"""Concatenate two halves back to back along the last dim (not used by forward/inverse)."""
		return torch.cat((z_1, z_2), dim=-1)

	def _odd_even_concat(self, z_even, z_odd):
		"""Interleave halves: out[:, 2*i] = z_even[:, i] and out[:, 2*i + 1] = z_odd[:, i].

		:param z_even: of size (batch_size, z_size)
		:param z_odd: of size (batch_size, z_size)
		:return: of size (batch_size, 2*z_size)
		"""
		# (batch_size, z_size, 2) flattened row-major gives e0, o0, e1, o1, ...
		# in one op instead of a Python loop over every feature.
		return torch.stack((z_even, z_odd), dim=-1).flatten(start_dim=1)

	def silu(self, input):
		"""SiLU / swish: ``input * sigmoid(input)`` (not used within this class)."""
		return input * torch.sigmoid(input)


class Model(nn.Module):
	"""Encode an image sequence, meta-learn a coupling flow from step 0 to step 1,
	apply it to step 2, and score the remaining steps (the candidate choices)
	by a learned weighted squared distance to that prediction.

	The flow is a stack of ``NICELayer``s whose per-sample weight matrices are
	read out of a learned memory bank by an LSTM-driven attention.
	"""

	def __init__(self, task_gen, args):
		"""
		:param task_gen: unused by this constructor
		:param args: config namespace; this class reads ``encoder``,
			``transformation_method`` (for ``c4_group_cnn``), ``num_memory``,
			``num_memory_layer``, ``use_memory`` and ``key_value_type``
		:raises ValueError: if ``args.encoder`` is not a known encoder name
		"""
		super(Model, self).__init__()
		self.args = args

		# Encoder
		log.info('Building encoder...')
		if args.encoder == 'c4_group_cnn':
			in_channels = 3 if 'cifar100' in args.transformation_method else 1
			self.encoder = C4_Encoder_group_cnn(in_channels=in_channels)
		elif args.encoder == 'mlp':
			self.encoder = Encoder_mlp(args)
		elif args.encoder == 'resnet':
			self.encoder = Encoder_ResNet()
		else:
			# Fail fast instead of leaving ``self.encoder`` unset and crashing
			# later in ``forward``.
			raise ValueError(f'Unknown encoder: {args.encoder}')

		# Latent size; the encoder output must have this many features
		# (diff_coef and the LSTM input are sized from it).
		self.z_size = 128
		self.num_memory = args.num_memory
		self.num_memory_layer = args.num_memory_layer

		assert self.num_memory_layer >= 2, "Number of layers must be >=2."

		# Alternate the kept parity half so consecutive layers update
		# complementary halves of the latent.
		self.nice_layer_list = nn.ModuleList(
			NICELayer(keep='even' if i % 2 == 0 else 'odd')
			for i in range(self.num_memory_layer)
		)

		# Projections of z_2 used as pseudo-targets by all but the last two layers.
		self.activation = nn.ReLU()
		self.phi = nn.ModuleList(
			nn.Sequential(nn.Linear(self.z_size, self.z_size // 2), self.activation)
			for _ in range(self.num_memory_layer - 2)
		)

		if self.args.use_memory == 1:
			# One bank of num_memory candidate (z_size//2, z_size//2) matrices per layer.
			self.memory_value_list = nn.Parameter(
				torch.randn(self.num_memory_layer, self.num_memory, self.z_size // 2, self.z_size // 2),
				requires_grad=True)
		# NOTE(review): ``forward`` reads ``self.memory_value_list`` unconditionally,
		# so ``use_memory != 1`` currently fails there with an AttributeError.

		self.key_base_matrix = nn.Parameter(torch.randn(self.num_memory_layer, self.z_size // 2, self.z_size // 2), requires_grad=True)
		self.value_base_matrix = nn.Parameter(torch.randn(self.num_memory_layer, self.z_size // 2, self.z_size // 2), requires_grad=True)

		# Log-scale coefficients used by ``_get_orthogonal_matrix``.
		self.a_orthogonal_value = nn.Parameter(torch.ones(1), requires_grad=True)
		self.a_orthogonal_key = nn.Parameter(torch.ones(1), requires_grad=True)

		# Log-scale per-feature weights of the distance used to score choices.
		self.diff_coef = nn.Parameter(torch.zeros(self.z_size), requires_grad=True)

		# Emits num_memory attention weights over the memory bank per layer step.
		self.lstm_attention = nn.LSTM(input_size=self.z_size, hidden_size=self.num_memory, batch_first=True)

	def forward(self, x_seq, device):
		"""
		:param x_seq: either (batch_size, seq_len, H, W) or (batch_size, seq_len, C, H, W);
			steps 0-2 are context, steps >= 3 are the candidate choices
		:param device: passed through to the coupling layers
		:return: ``(y_pred_linear, y_pred)`` — scores of size (batch_size, num_choices)
			(higher is better) and their argmax of size (batch_size,)
		"""
		batch_size = x_seq.shape[0]

		# Encode every step; steps 0-2 are context, steps >= 3 are choices.
		z_in_seq = []
		z_choices_seq = []
		for t in range(x_seq.shape[1]):
			if len(x_seq.shape) == 4:
				x_t = x_seq[:, t, :, :].unsqueeze(1)  # add a channel axis
			else:
				x_t = x_seq[:, t, :, :, :]
			z_t = self.encoder(x_t)
			if t >= 3:
				z_choices_seq.append(z_t)
			else:
				z_in_seq.append(z_t)
		z_seq = torch.stack(z_in_seq, dim=1)
		z_choices = torch.stack(z_choices_seq, dim=1)  # (batch_size, num_choices, z_size)

		z_1 = z_seq[:, 0, :]
		z_2 = z_seq[:, 1, :]
		z_3 = z_seq[:, 2, :]

		# Meta-learn, layer by layer, the weights of a flow that maps z_1 to z_2.
		W_list = []
		z_in = z_1

		# Allocate the LSTM state where the data lives rather than hard-coding
		# ``.cuda()``, so the model also runs on CPU.
		h_lstm = torch.zeros(1, batch_size, self.num_memory, device=z_1.device)
		c_lstm = torch.zeros(1, batch_size, self.num_memory, device=z_1.device)
		for i in range(self.num_memory_layer):
			if i < self.num_memory_layer - 2:
				z_pseudotarget = self.phi[i](z_2)
			elif i % 2 == 0:
				# Even layers keep the even half and update the odd half, so the
				# last two layers are matched against that half of z_2.
				z_pseudotarget = z_2[:, 1::2]
			else:
				z_pseudotarget = z_2[:, 0::2]

			# Query with the half of z_in that this layer updates.
			if i % 2 == 0:
				z_in_query = z_in[:, 1::2]
			else:
				z_in_query = z_in[:, 0::2]

			W, (h_lstm, c_lstm) = self._meta_attention(
				input=z_in_query, output=z_pseudotarget, M_value=self.memory_value_list[i],
				h_lstm=h_lstm, c_lstm=c_lstm)

			z_in = self._meta_forward(index=i, z=z_in, W=W, device=device)
			W_list.append(W)

		# Apply the meta-learned flow to z_3 to predict the answer's latent.
		z_predicted = z_3
		for i, W in enumerate(W_list):
			z_predicted = self._meta_forward(index=i, z=z_predicted, W=W, device=device)

		z_predicted = torch.clip(z_predicted, min=-1e6, max=1e6)
		# Broadcast the prediction across however many choices there are
		# (previously hard-coded to exactly 4).
		diff = torch.exp(self.diff_coef) * (z_choices - z_predicted.unsqueeze(1)) ** 2
		y_pred_linear = -torch.sum(diff, dim=[-1]) / self.z_size
		y_pred_linear = torch.clip(y_pred_linear, min=-1e6)
		y_pred = y_pred_linear.argmax(1)
		return y_pred_linear, y_pred

	def _similarity(self, x_seq, W):
		"""Ablation helper: dump pairwise distances between per-sample weights.

		Computes the mean squared difference between every pair of flattened
		``W[b]`` (despite the name, larger means less similar). When the batch
		has fewer than 32 samples, saves the matrix to
		``{ablation_dir}similarity_matrix.npy`` and plots each sequence of
		``x_seq``. Returns nothing.
		"""
		W = W.reshape(W.shape[0], -1)
		res = (W.unsqueeze(1) - W) ** 2  # (batch_size, batch_size, num_features)
		similarity_matrix = torch.sum(res, dim=-1) / W.shape[1]
		if similarity_matrix.shape[0] < 32:
			np.save(f'{ablation_dir}similarity_matrix.npy', similarity_matrix.cpu().detach().numpy())
			for i in range(x_seq.shape[0]):
				self._plot(x_seq[i], fig_name=str(i))

	def _get_orthogonal_matrix(self, index, v, device, key=False):
		"""Return ``base - exp(a) * v v^T`` for a learned base matrix and scale.

		The result is orthogonal only in special cases (e.g. identity base,
		``exp(a) == 2`` and unit ``v`` give a Householder reflection).

		:param index: layer index selecting the base matrix
		:param v: of size (batch_size, v_size, 1)
		:param device: unused
		:param key: if True use ``key_base_matrix`` / ``a_orthogonal_key``,
			otherwise ``value_base_matrix`` / ``a_orthogonal_value``
		:return: tensor of size (batch_size, v_size, v_size)
		"""
		# Each branch now uses its own coefficient (they were previously swapped).
		if key:
			base, log_scale = self.key_base_matrix[index], self.a_orthogonal_key
		else:
			base, log_scale = self.value_base_matrix[index], self.a_orthogonal_value
		# base (v_size, v_size) broadcasts over the batched outer products.
		return base - torch.exp(log_scale) * torch.bmm(v, v.transpose(1, 2))

	def _meta_inverse(self, index, z, W, device):
		"""Invert coupling layer ``index`` on ``z`` with weights ``W``."""
		return self.nice_layer_list[index].inverse(W=W, y=z, device=device)

	def _meta_forward(self, index, z, W, device):
		"""Apply coupling layer ``index`` to ``z`` with weights ``W``."""
		return self.nice_layer_list[index](W=W, x=z, device=device)

	def _meta_attention(self, input, output, M_value, h_lstm, c_lstm):
		"""Read a weight matrix out of a memory bank via one LSTM attention step.

		:param input: query of size (batch_size, q)
		:param output: pseudo-target of size (batch_size, p), with q + p == z_size
		:param M_value: memory bank of size (num_memory, A, B)
		:param h_lstm: LSTM hidden state of size (1, batch_size, num_memory)
		:param c_lstm: LSTM cell state of size (1, batch_size, num_memory)
		:return: ``(W_attention, (h_lstm, c_lstm))`` where W_attention is the
			attention-weighted sum of the memory, of size (batch_size, A, B)
			(or (batch_size, A, 1) when ``args.key_value_type == 'vector'``)
		"""
		batch_size, A, B = input.shape[0], M_value.shape[1], M_value.shape[2]

		M_value_flatten = M_value.reshape(M_value.shape[0], -1)  # (num_memory, A*B)

		# The LSTM output is used directly (no softmax) as attention weights.
		x_lstm = torch.cat((input, output), dim=-1).unsqueeze(1)  # (batch_size, 1, q + p)
		attention_weight, (h_lstm, c_lstm) = self.lstm_attention(x_lstm, (h_lstm, c_lstm))
		attention_weight = attention_weight.reshape(batch_size, self.num_memory, 1)

		# (batch_size, num_memory, 1) * (num_memory, A*B) broadcasts without
		# materializing a per-sample copy of the memory bank.
		W_attention = torch.sum(attention_weight * M_value_flatten, dim=1)  # (batch_size, A*B)
		if self.args.key_value_type == 'vector':
			# NOTE(review): this reshape only succeeds when B == 1; with the
			# square memory built in __init__ it would fail.
			W_attention = W_attention.reshape(batch_size, A, 1)
		else:
			W_attention = W_attention.reshape(batch_size, A, B)
		return W_attention, (h_lstm, c_lstm)

	def _meta_find_weights(self, z_in, z_target, device):
		"""Per sample, approximately the minimum-norm W with ``W @ z_in = z_target``.

		Computed as ``z_target z_in^+`` with ``z_in^+`` from
		``_approx_pseudo_inverse``; no bias column is appended.

		:param z_in: of size (batch_size, z_in_size)
		:param z_target: of size (batch_size, z_target_size)
		:param device: unused
		:return: weight of size (batch_size, z_target_size, z_in_size)
		"""
		z_in_pseudoinverse = self._approx_pseudo_inverse(z_in.unsqueeze(-1))  # (batch_size, 1, z_in_size)
		return torch.bmm(z_target.unsqueeze(-1), z_in_pseudoinverse)

	def _approx_pseudo_inverse(self, A, iterative_step=3, init_scale=1e-4):
		"""Approximate the Moore-Penrose pseudo-inverse by Newton-Schulz iteration.

		Iterates ``X <- 2X - X A X`` from ``X_0 = init_scale * A^T``; this
		converges when ``0 < init_scale < 2 / sigma_max(A)^2``, and with only a
		few steps the result is a rough approximation.

		:param A: of size (batch_size, m, n)
		:param iterative_step: number of iterations
		:param init_scale: scale of the initial guess
		:return: of size (batch_size, n, m)
		"""
		A_pseudoinverse = init_scale * A.transpose(1, 2)
		for _ in range(iterative_step):
			A_pseudoinverse = 2 * A_pseudoinverse - torch.bmm(torch.bmm(A_pseudoinverse, A), A_pseudoinverse)
		return A_pseudoinverse

	def _plot(self, tensor, fig_name):
		"""Save the images of one sequence side by side to ``{ablation_dir}{fig_name}.png``.

		:param tensor: images indexed along dim 0; each ``tensor[i]`` is shown
			with ``imshow(cmap='gray')`` (assumed 2-D (H, W) — TODO confirm for
			multi-channel inputs)
		:param fig_name: file stem of the saved figure
		"""
		import matplotlib.pyplot as plt

		tensor = tensor.detach().cpu()
		num_img = len(tensor)
		# squeeze=False keeps a 2-D axes array even when num_img == 1.
		fig, ax_list = plt.subplots(1, num_img, squeeze=False)

		for i in range(num_img):
			ax = ax_list[0, i]
			ax.imshow(np.array(tensor[i]), cmap='gray')
			ax.set_xticks([])
			ax.set_yticks([])
		fig.savefig(f'{ablation_dir}{fig_name}.png')
		# Close the figure so repeated calls don't accumulate open figures.
		plt.close(fig)


