import torch
from torch import nn

from collections import defaultdict
from copy import deepcopy


class Memory(nn.Module):
	"""
	Per-node memory table, per-node last-update values, and a per-node buffer of raw messages.

	``memory`` and ``last_update`` are stored as non-trainable ``nn.Parameter``s so that they
	are part of the module's ``state_dict`` and are saved/loaded together with the model.
	"""

	def __init__(self, n_nodes, memory_dimension, input_dimension, message_dimension=None,
							 device="cpu", combination_method='sum'):
		"""
		Args:
			n_nodes: number of rows (nodes) in the memory table.
			memory_dimension: width of each node's memory vector.
			input_dimension: stored as an attribute; not used within this class.
			message_dimension: stored as an attribute; not used within this class.
			device: device on which the memory tensors are allocated.
			combination_method: stored as an attribute; not used within this class.
		"""
		super().__init__()
		self.n_nodes = n_nodes
		self.memory_dimension = memory_dimension
		self.input_dimension = input_dimension
		self.message_dimension = message_dimension
		self.device = device

		self.combination_method = combination_method

		self.__init_memory__()

	def __init_memory__(self):
		"""
		Initializes the memory to all zeros. It should be called at the start of each epoch.

		Also resets ``last_update`` to zeros and empties the raw-message buffer.
		"""
		# Treat memory as parameter so that it is saved and loaded together with the model.
		# Allocate directly on the target device instead of creating on CPU and copying.
		self.memory = nn.Parameter(
			torch.zeros((self.n_nodes, self.memory_dimension), device=self.device),
			requires_grad=False)
		self.last_update = nn.Parameter(
			torch.zeros(self.n_nodes, device=self.device),
			requires_grad=False)

		# node id -> list of raw message tuples
		self.messages = defaultdict(list)

	def store_raw_messages(self, nodes, node_id_to_messages):
		"""Append ``node_id_to_messages[node]`` to the buffer of every node in ``nodes``."""
		for node in nodes:
			self.messages[node].extend(node_id_to_messages[node])

	def get_memory(self, node_idxs):
		"""Return the memory rows selected by ``node_idxs``."""
		return self.memory[node_idxs, :]

	def set_memory(self, node_idxs, values):
		"""Overwrite, in place, the memory rows selected by ``node_idxs`` with ``values``."""
		self.memory[node_idxs, :] = values

	def get_last_update(self, node_idxs):
		"""Return the ``last_update`` entries selected by ``node_idxs``."""
		return self.last_update[node_idxs]

	@staticmethod
	def _clone_message_store(store):
		"""Return a plain dict copy of ``store`` with the first two elements of every tuple cloned."""
		# Only the first two elements of each message tuple are kept, as in the original code.
		return {k: [(x[0].clone(), x[1].clone()) for x in v] for k, v in store.items()}

	def backup_memory(self):
		"""
		Snapshot the full memory state.

		Returns:
			A tuple ``(memory, last_update, messages)``. The first two items are clones of the
			underlying tensors. The last item is a plain dict with cloned message tuples.
		"""
		return (self.memory.data.clone(),
						self.last_update.data.clone(),
						self._clone_message_store(self.messages))

	def restore_memory(self, memory_backup):
		"""
		Restore state from a tuple produced by :meth:`backup_memory`.

		The backup is cloned again, so it stays untouched and can be restored multiple times.
		"""
		self.memory.data, self.last_update.data = memory_backup[0].clone(), memory_backup[1].clone()

		self.messages = defaultdict(list, self._clone_message_store(memory_backup[2]))

	def detach_memory(self):
		"""
		Detach the memory tensor and all buffered messages from the autograd graph.

		Only the first element of each message tuple is detached; the second element is kept
		as-is (presumably a timestamp that does not carry gradient — confirm against callers).
		"""
		self.memory.detach_()

		# Rebinding values of existing keys does not change the dict's size, so it is safe while iterating.
		for k, v in self.messages.items():
			self.messages[k] = [(message[0].detach(), message[1]) for message in v]

	def clear_messages(self, nodes):
		"""Empty the raw-message buffer of every node in ``nodes``."""
		for node in nodes:
			self.messages[node] = []
