from torch import nn


class MessageFunction(nn.Module):
	"""
	Base module which computes the message for a given interaction.

	Subclasses override :meth:`compute_message`. :meth:`forward` delegates to
	it, so instances can also be called like any other ``nn.Module``
	(``module(raw_messages)``) instead of raising ``NotImplementedError``.
	"""

	def compute_message(self, raw_messages):
		"""
		Turn raw interaction messages into messages.

		The base implementation returns ``None``; subclasses are expected to
		override it.
		"""
		return None

	def forward(self, raw_messages):
		"""Delegate to :meth:`compute_message` so the module is callable."""
		return self.compute_message(raw_messages)


class MLPMessageFunction(MessageFunction):
	"""
	Message function that maps raw messages through a two-layer MLP.

	Architecture: Linear(raw -> raw // 2) -> ReLU -> Linear(raw // 2 -> message).
	The network is registered under two attribute names, ``mlp`` and
	``layers``, which refer to the same ``nn.Sequential`` object.
	"""

	def __init__(self, raw_message_dimension, message_dimension):
		super().__init__()

		hidden_dimension = raw_message_dimension // 2
		network = nn.Sequential(
			nn.Linear(raw_message_dimension, hidden_dimension),
			nn.ReLU(),
			nn.Linear(hidden_dimension, message_dimension),
		)
		# Both names point at one module, so ``state_dict`` contains keys under
		# both the ``mlp.`` and ``layers.`` prefixes; dropping either alias
		# would change the checkpoint key set.
		self.mlp = network
		self.layers = network

	def compute_message(self, raw_messages):
		"""Apply the MLP to ``raw_messages`` and return the result."""
		return self.mlp(raw_messages)


class IdentityMessageFunction(MessageFunction):
	"""Message function that passes raw messages through unchanged."""

	def compute_message(self, raw_messages):
		"""Return ``raw_messages`` as-is (the same object, no copy)."""
		# NOTE(review): the output width equals the raw message width, so
		# callers presumably use message_dimension == raw_message_dimension
		# with this module; confirm at the call site.
		passthrough = raw_messages
		return passthrough


def get_message_function(module_type, raw_message_dimension, message_dimension):
	"""
	Build the message function selected by ``module_type``.

	Args:
		module_type: ``"mlp"`` for :class:`MLPMessageFunction` or
			``"identity"`` for :class:`IdentityMessageFunction`.
		raw_message_dimension: Width of the raw input messages (used by "mlp").
		message_dimension: Width of the produced messages (used by "mlp";
			ignored by "identity").

	Returns:
		A new ``MessageFunction`` instance.

	Raises:
		ValueError: If ``module_type`` is not a recognised option. Previously
			this case silently returned ``None``.
	"""
	if module_type == "mlp":
		return MLPMessageFunction(raw_message_dimension, message_dimension)
	if module_type == "identity":
		return IdentityMessageFunction()
	raise ValueError(
		f"Unknown message function type {module_type!r}; "
		"expected 'mlp' or 'identity'."
	)
