import torch
import torch.nn as nn
from util import log
import numpy as np
from modules import *


class NICELayer(nn.Module):
	"""
	Additive coupling layer whose coupling weights are supplied on every call.

	The feature axis of the input is split into even-indexed and odd-indexed
	columns. The half named by ``keep`` passes through unchanged; the other
	half is shifted by ``m(W, kept_half)``. ``inverse`` subtracts the same
	shift, so it undoes ``forward`` whenever ``m`` is deterministic (i.e. in
	eval mode, where dropout is disabled).
	"""

	def __init__(self, keep):
		"""
		:param keep: 'even' or 'odd' -- which half of the features is left unchanged.
			Any other value makes ``forward`` / ``inverse`` raise ValueError.
		"""
		super(NICELayer, self).__init__()
		self.keep = keep

	def forward(self, W, x, device):
		"""
		:param W: batched weight matrices, consumed by ``m`` through ``torch.bmm``
		:param x: of size (batch_size, x_size)
		:param device: unused here; only passed through to ``m``
		:return: output of size (batch_size, x_size)
		:raises ValueError: if ``self.keep`` is neither 'even' nor 'odd'
		"""
		x_even, x_odd = x[:, 0::2], x[:, 1::2]
		if self.keep == 'even':
			y_even = x_even
			y_odd = x_odd + self.m(W=W, x=x_even, device=device)
		elif self.keep == 'odd':
			y_odd = x_odd
			y_even = x_even + self.m(W=W, x=x_odd, device=device)
		else:
			raise ValueError('Keep should be in {even, odd}')
		return self._odd_even_concat(z_even=y_even, z_odd=y_odd)

	def inverse(self, W, y, device):
		"""
		Undo ``forward`` by subtracting the coupling shift from the modified half.

		:param W: the same batched weight matrices that were given to ``forward``
		:param y: of size (batch_size, x_size)
		:param device: unused here; only passed through to ``m``
		:return: reconstructed input of size (batch_size, x_size)
		:raises ValueError: if ``self.keep`` is neither 'even' nor 'odd'
		"""
		y_even, y_odd = y[:, 0::2], y[:, 1::2]
		if self.keep == 'even':
			x_even = y_even
			x_odd = y_odd - self.m(W=W, x=y_even, device=device)
		elif self.keep == 'odd':
			x_odd = y_odd
			x_even = y_even - self.m(W=W, x=y_odd, device=device)
		else:
			raise ValueError('Keep should be in {even, odd}')
		return self._odd_even_concat(z_even=x_even, z_odd=x_odd)

	def m(self, W, x, device):
		"""
		Coupling function: LeakyReLU(W @ x) followed by dropout.

		:param W: of size (batch_size, out_size, in_size)
		:param x: of size (batch_size, in_size)
		:param device: unused
		:return: of size (batch_size, out_size)
		"""
		out = torch.bmm(W, x.unsqueeze(-1)).squeeze(-1)
		out = nn.functional.leaky_relu(out, negative_slope=0.5)
		# Dropout must follow self.training. A freshly constructed nn.Dropout is
		# always in training mode, so it would keep dropping units after
		# model.eval() and make inference (and inverse) non-deterministic.
		return nn.functional.dropout(out, p=0.5, training=self.training)

	def _consecutive_concat(self, z_1, z_2):
		"""Concatenate z_1 and z_2 back to back along the last dimension."""
		return torch.cat((z_1, z_2), dim=-1)

	def _odd_even_concat(self, z_even, z_odd):
		"""
		Interleave two halves so the result reads [even_0, odd_0, even_1, odd_1, ...].

		:param z_even: of size (batch_size, z_size)
		:param z_odd: of size (batch_size, z_size)
		:return: of size (batch_size, 2 * z_size)
		"""
		# Stacking on a new trailing axis gives (batch_size, z_size, 2);
		# flattening the last two axes interleaves the columns in one step.
		return torch.stack((z_even, z_odd), dim=-1).flatten(start_dim=1)

	def silu(self, input):
		"""Return input * sigmoid(input)."""
		return input * torch.sigmoid(input)


class Model(nn.Module):
	"""
	Scores candidate frames by meta-learning per-sample coupling weights.

	Every frame of the input sequence is encoded. The first three embeddings
	(z_1, z_2, z_3) are the context and the remaining ones are the candidates.
	Per-layer weights W are derived from z_1 and z_2 (optionally replaced by a
	readout from a learned memory), the resulting stack of NICE layers is
	applied to z_3, and each candidate is scored by its negative weighted
	squared distance to that prediction.
	"""

	def __init__(self, task_gen, args):
		"""
		:param task_gen: not used by this constructor
		:param args: configuration object; reads ``encoder``, ``transformation_method``
			(only for the 'c4_group_cnn' encoder), ``num_memory``, ``num_memory_layer``
			and ``use_memory`` here, and ``pseudoinverse_init`` later on
		:raises ValueError: if ``args.encoder`` is not a known encoder name
		"""
		super(Model, self).__init__()
		self.args = args

		# Encoder
		log.info('Building encoder...')
		if args.encoder == 'c4_group_cnn':
			in_channels = 3 if 'cifar100' in args.transformation_method else 1
			self.encoder = C4_Encoder_group_cnn(in_channels=in_channels)
		elif args.encoder == 'mlp':
			self.encoder = Encoder_mlp(args)
		elif args.encoder == 'resnet':
			self.encoder = Encoder_ResNet()
		else:
			# Fail at construction time instead of with an AttributeError on
			# self.encoder during the first forward pass.
			raise ValueError('Unknown encoder: {}'.format(args.encoder))

		self.z_size = 128
		self.num_memory = args.num_memory
		self.num_memory_layer = args.num_memory_layer

		assert self.num_memory_layer >= 2, "Number of layers must be >=2."

		# Coupling layers alternate which half of the features they leave unchanged.
		self.nice_layer_list = nn.ModuleList(
			[NICELayer(keep='even' if i % 2 == 0 else 'odd') for i in range(self.num_memory_layer)]
		)

		# One projection of z_2 to half width for every layer except the last two.
		self.activation = nn.ReLU()
		self.phi = nn.ModuleList(
			[
				nn.Sequential(nn.Linear(self.z_size, self.z_size // 2), self.activation)
				for _ in range(self.num_memory_layer - 2)
			]
		)

		if self.args.use_memory == 1:
			memory_shape = (self.num_memory_layer, self.num_memory, self.z_size // 2, self.z_size // 2)
			self.memory_key_list = nn.Parameter(torch.randn(*memory_shape), requires_grad=True)
			self.memory_value_list = nn.Parameter(torch.randn(*memory_shape), requires_grad=True)

		# Log-scale per-feature weights of the squared distance used for scoring.
		self.diff_coef = nn.Parameter(torch.zeros(self.z_size), requires_grad=True)

	def forward(self, x_seq, device):
		"""
		:param x_seq: frame sequence with frames indexed along dim 1; a 4-D input gets
			a singleton dim inserted at position 1 of every frame before encoding,
			otherwise frames are taken as ``x_seq[:, t, :, :, :]``
		:param device: unused here; only passed through to the helper methods
		:return: tuple ``(y_pred_linear, y_pred)`` -- scores of size
			(batch_size, num_choices) and the argmax index of size (batch_size,)
		:raises ValueError: if ``x_seq`` holds fewer than 4 frames
		"""
		num_frames = x_seq.shape[1]
		if num_frames < 4:
			raise ValueError('x_seq must contain 3 context frames and at least one choice')

		# Encode all images in sequence
		z_all = []
		for t in range(num_frames):
			if x_seq.dim() == 4:
				x_t = x_seq[:, t, :, :].unsqueeze(1)
			else:
				x_t = x_seq[:, t, :, :, :]
			z_all.append(self.encoder(x_t))

		z_1, z_2, z_3 = z_all[:3]
		z_choices = torch.stack(z_all[3:], dim=1)		# of size (batch_size, num_choices, z_size)

		# Meta-learn weights of neural net that maps z_1 to z_2
		W_list = []
		z_in = z_1
		for i in range(self.num_memory_layer):
			# Even layers work with the odd-indexed columns, odd layers with the even-indexed ones.
			# NOTE(review): NICELayer(keep='even') feeds the even-indexed columns to W, while the
			# query for even i is built from the odd-indexed columns -- confirm this is intended.
			cols = slice(1, None, 2) if i % 2 == 0 else slice(0, None, 2)

			if i < self.num_memory_layer - 2:
				z_pseudotarget = self.phi[i](z_2)
			else:
				# The last two layers target the matching half of z_2 directly.
				z_pseudotarget = z_2[:, cols]
			z_in_query = z_in[:, cols]

			W_query = self._meta_find_query(z_in=z_in_query, z_target=z_pseudotarget, device=device)	# of size (batch_size, target_size, query_size)

			if self.args.use_memory == 1:
				# NOTE(review): slots are indexed by i // 2, so consecutive layer pairs read the
				# same slot and the upper slots of the (num_memory_layer, ...) parameters are
				# never read here -- confirm this sharing is intended.
				W = self._meta_attention(index=i, W=W_query, M_key=self.memory_key_list[i // 2],
										 M_value=self.memory_value_list[i // 2], device=device)
			else:
				W = W_query
			z_in = self._meta_forward(index=i, z=z_in, W=W, device=device)
			W_list.append(W)

		# Compute predicted image
		z_in = z_3
		for i, W in enumerate(W_list):
			z_in = self._meta_forward(index=i, z=z_in, W=W, device=device)

		z_predicted = torch.clip(z_in, min=-1e6, max=1e6)
		# Broadcasting over the choice axis supports any number of candidates.
		diff = torch.exp(self.diff_coef) * (z_choices - z_predicted.unsqueeze(1)) ** 2
		y_pred_linear = -torch.sum(diff, dim=-1) / self.z_size
		y_pred_linear = torch.clip(y_pred_linear, min=-1e6)
		y_pred = y_pred_linear.argmax(1)
		return y_pred_linear, y_pred

	def _meta_forward(self, index, z, W, device):
		"""Apply the ``index``-th coupling layer to ``z`` using weights ``W``."""
		nice_layer = self.nice_layer_list[index]
		return nice_layer(W=W, x=z, device=device)

	def _meta_attention(self, index, W, M_key, M_value, device):
		"""
		Read out a weight matrix from memory, using the flattened query W as key.

		The scores are scaled dot products between W and every memory key; they are
		clipped but not normalised, and weight the memory values linearly.

		:param index: unused
		:param W: of size (batch_size, A, B)
		:param M_key: of size (num_memory, A, B)
		:param M_value: of size (num_memory, A, B)
		:param device: unused
		:return: attention output of size (batch_size, A, B)
		"""
		batch_size, A, B = W.shape[0], W.shape[1], W.shape[2]

		W_flatten = W.reshape(batch_size, -1)		# of size (batch_size, A*B)
		M_key_flatten = M_key.reshape(M_key.shape[0], -1)		# of size (num_memory, A*B)
		M_value_flatten = M_value.reshape(M_value.shape[0], -1)		# of size (num_memory, A*B)

		# Plain matmuls against the shared memory avoid materialising batch_size copies of it.
		attention_weight = torch.matmul(W_flatten, M_key_flatten.t()) / np.sqrt(A * B)	# of size (batch_size, num_memory)
		attention_weight = torch.clip(attention_weight, min=-1e6, max=1e6)

		W_attention = torch.matmul(attention_weight, M_value_flatten)		# of size (batch_size, A*B)
		return W_attention.reshape(batch_size, A, B)

	def _meta_find_query(self, z_in, z_target, device):
		"""
		Build per-sample weights as the outer product of the target with the
		approximate pseudo-inverse of the input (no bias term).

		:param z_in: of size (batch_size, z_in_size)
		:param z_target: of size (batch_size, z_target_size)
		:param device: unused
		:return: weights of size (batch_size, z_target_size, z_in_size)
		"""
		z_in_column = z_in.unsqueeze(-1)		# of size (batch_size, z_in_size, 1)
		z_in_pseudoinverse = self._approx_pseudo_inverse(z_in_column)	# of size (batch_size, 1, z_in_size)

		return torch.bmm(z_target.unsqueeze(-1), z_in_pseudoinverse)		# of size (batch_size, z_target_size, z_in_size)

	def _approx_pseudo_inverse(self, A, iterative_step=3):
		"""
		Approximate the pseudo-inverse of a batch of matrices with the iteration
		X <- 2 X - X A X, started from ``args.pseudoinverse_init * A^T``.

		:param A: of size (batch_size, rows, cols)
		:param iterative_step: number of refinement iterations
		:return: of size (batch_size, cols, rows)
		"""
		A_pseudoinverse = (self.args.pseudoinverse_init * A).transpose(1, 2)
		for _ in range(iterative_step):
			A_pseudoinverse = 2 * A_pseudoinverse - torch.bmm(torch.bmm(A_pseudoinverse, A), A_pseudoinverse)
		return A_pseudoinverse


