from torch import nn
import torch


class MemoryUpdater(nn.Module):
	"""Common base type for memory updaters.

	The base implementation of ``update_memory`` performs no work; concrete
	updaters override it.
	"""

	def update_memory(self, unique_node_ids, unique_messages, timestamps):
		"""Do nothing with the given arguments and return ``None``."""
		return None


class SequenceMemoryUpdater(MemoryUpdater):
	"""Memory updater that combines messages with stored memory through a cell.

	The cell is read from ``self.memory_updater``. This class does not create
	that attribute itself; it has to be assigned (as ``GRUMemoryUpdater`` and
	``RNNMemoryUpdater`` do) before either update method is called with a
	non-empty batch of nodes.
	"""

	def __init__(self, memory, message_dimension, memory_dimension, device):
		"""Store the memory handle and configuration.

		Args:
			memory: object providing ``get_memory``, ``set_memory`` and
				``get_last_update`` methods plus indexable ``memory`` and
				``last_update`` attributes.
			message_dimension: stored on the instance as-is.
			memory_dimension: normalized shape given to the layer norm.
			device: stored on the instance; not otherwise used by this class.
		"""
		super().__init__()
		self.memory = memory
		# Constructed here but not applied by any method of this class.
		self.layer_norm = nn.LayerNorm(memory_dimension)
		self.message_dimension = message_dimension
		self.device = device

	def _assert_not_in_past(self, unique_node_ids, timestamps):
		"""Assert that no node's recorded last update is later than its new timestamp."""
		previous_update = self.memory.get_last_update(unique_node_ids)
		assert (previous_update <= timestamps).all().item(), "Trying to update memory to time in the past"

	def update_memory(self, unique_node_ids, unique_messages, timestamps):
		"""Update the stored memory and last-update times of the given nodes in place.

		Does nothing when ``unique_node_ids`` is empty. Always returns ``None``.

		Raises:
			AssertionError: if any timestamp is earlier than the last update
				recorded for the corresponding node.
		"""
		if len(unique_node_ids) == 0:
			return

		self._assert_not_in_past(unique_node_ids, timestamps)

		# Read the current memory before recording the new update times; the
		# last-update write happens before the cell is evaluated.
		current_memory = self.memory.get_memory(unique_node_ids)
		self.memory.last_update[unique_node_ids] = timestamps

		refreshed_memory = self.memory_updater(unique_messages, current_memory)
		self.memory.set_memory(unique_node_ids, refreshed_memory)

	def get_updated_memory(self, unique_node_ids, unique_messages, timestamps):
		"""Return updated copies of the memory and last-update tensors.

		The computation works on clones of ``self.memory.memory`` and
		``self.memory.last_update``; this method does not write to the stored
		tensors. When ``unique_node_ids`` is empty, plain clones are returned.

		Returns:
			A ``(memory, last_update)`` tuple of cloned tensors in which the
			entries for ``unique_node_ids`` have been replaced.

		Raises:
			AssertionError: if any timestamp is earlier than the last update
				recorded for the corresponding node.
		"""
		if len(unique_node_ids) == 0:
			memory_copy = self.memory.memory.data.clone()
			last_update_copy = self.memory.last_update.data.clone()
			return memory_copy, last_update_copy

		self._assert_not_in_past(unique_node_ids, timestamps)

		new_memory = self.memory.memory.data.clone()
		previous_rows = new_memory[unique_node_ids]
		new_memory[unique_node_ids] = self.memory_updater(unique_messages, previous_rows)

		new_last_update = self.memory.last_update.data.clone()
		new_last_update[unique_node_ids] = timestamps

		return new_memory, new_last_update


class GRUMemoryUpdater(SequenceMemoryUpdater):
	"""Sequence memory updater whose ``memory_updater`` cell is an ``nn.GRUCell``."""

	def __init__(self, memory, message_dimension, memory_dimension, device):
		"""Initialise the parent updater, then attach the GRU cell.

		The cell takes inputs of size ``message_dimension`` and keeps a hidden
		state of size ``memory_dimension``.
		"""
		super().__init__(memory, message_dimension, memory_dimension, device)

		gru_cell = nn.GRUCell(
			input_size=message_dimension,
			hidden_size=memory_dimension,
		)
		self.memory_updater = gru_cell


class RNNMemoryUpdater(SequenceMemoryUpdater):
	"""Sequence memory updater whose ``memory_updater`` cell is an ``nn.RNNCell``."""

	def __init__(self, memory, message_dimension, memory_dimension, device):
		"""Initialise the parent updater, then attach the RNN cell.

		The cell takes inputs of size ``message_dimension`` and keeps a hidden
		state of size ``memory_dimension``.
		"""
		super().__init__(memory, message_dimension, memory_dimension, device)

		rnn_cell = nn.RNNCell(
			input_size=message_dimension,
			hidden_size=memory_dimension,
		)
		self.memory_updater = rnn_cell


def get_memory_updater(module_type, memory, message_dimension, memory_dimension, device):
	"""Construct the memory updater selected by ``module_type``.

	Args:
		module_type: ``"gru"`` for a ``GRUMemoryUpdater`` or ``"rnn"`` for an
			``RNNMemoryUpdater``.
		memory: forwarded unchanged to the updater constructor.
		message_dimension: forwarded unchanged to the updater constructor.
		memory_dimension: forwarded unchanged to the updater constructor.
		device: forwarded unchanged to the updater constructor.

	Returns:
		The newly constructed updater instance.

	Raises:
		ValueError: if ``module_type`` is neither ``"gru"`` nor ``"rnn"``.
	"""
	if module_type == "gru":
		return GRUMemoryUpdater(memory, message_dimension, memory_dimension, device)
	if module_type == "rnn":
		return RNNMemoryUpdater(memory, message_dimension, memory_dimension, device)

	# Fail here rather than implicitly returning None, which would only surface
	# later as an unrelated error at the first use of the updater.
	raise ValueError(
		"Unknown memory updater type {!r}; expected 'gru' or 'rnn'".format(module_type)
	)
