import torch


class DRUnit(torch.nn.Module):
	"""
	Discretise/Regularise Unit (DRU) for a learned communication channel.

	DRU from "Learning to Communicate with Deep Multi-Agent Reinforcement Learning" by Foerster et al.
	https://proceedings.neurips.cc/paper/2016/file/c7635bfd99248a2cdef8249ef7bfbef4-Paper.pdf

	Code based on DRU implementation by Minqi
	https://github.com/minqi/learning-to-communicate-pytorch/blob/master/modules/dru.py
	Licensed under the Apache License, Version 2.0

	In training mode (``self.training`` is True) the message is perturbed with
	Gaussian noise and squashed, see ``regularize``. In evaluation mode it is
	discretised, see ``discretize``.

	Args:
		sigma: Scale (standard deviation) of the Gaussian noise added in training mode.
		comm_narrow: If True, every message element is handled independently
			(sigmoid while training, threshold at 0 when evaluating). If False,
			the last dimension is handled jointly (softmax while training,
			one-hot of the arg-max when evaluating).
		hard: If True, evaluation outputs are exactly 0/1. If False, a steep
			sigmoid/softmax is used, giving values very close to 0/1.
	"""
	def __init__(self, sigma: float, comm_narrow: bool = True, hard: bool = True) -> None:
		super().__init__()
		self.sigma = sigma
		self.comm_narrow = comm_narrow
		self.hard = hard

	def regularize(self, m: torch.Tensor) -> torch.Tensor:
		"""
		Add Gaussian noise scaled by ``sigma`` to ``m`` and squash the result.

		Returns a tensor shaped like ``m``: element-wise sigmoid when
		``comm_narrow`` is True, otherwise a softmax over the last dimension.
		"""
		# randn_like keeps the noise on the same device *and* dtype as the message.
		noisy = m + torch.randn_like(m) * self.sigma
		if self.comm_narrow:
			return torch.sigmoid(noisy)
		# Normalise over the message (last) dimension so batched inputs are not
		# mixed across the batch; this matches the axis used by discretize().
		return torch.softmax(noisy, -1)

	def discretize(self, m: torch.Tensor) -> torch.Tensor:
		"""
		Map ``m`` to a (near-)binary message without adding noise.

		Returns a tensor shaped like ``m``.

		Raises:
			ValueError: if ``hard`` is True, ``comm_narrow`` is False and ``m``
				is neither 1-D nor 2-D.
		"""
		if self.hard:
			if self.comm_narrow:
				# 1 where the element is strictly positive, 0 otherwise.
				return m.gt(0.).to(m.dtype)
			if m.dim() not in (1, 2):
				raise ValueError('Wrong message shape: {}'.format(m.size()))
			# One-hot of the arg-max along the last dimension (the only
			# dimension for 1-D input, the per-row dimension for 2-D input).
			_, idx = m.max(-1, keepdim=True)
			return torch.zeros_like(m).scatter_(-1, idx, 1.)

		# Steepness of the soft discretisation.
		scale = 2 * 20
		if self.comm_narrow:
			# Threshold at 0, the same decision boundary as the hard branch and
			# as the sigmoid used in regularize(); yields sigmoid(+/- scale / 2).
			return torch.sigmoid((m.gt(0.).to(m.dtype) - 0.5) * scale)
		return torch.softmax(m * scale, -1)

	def forward(self, m: torch.Tensor) -> torch.Tensor:
		"""Regularise ``m`` in training mode, discretise it in evaluation mode."""
		if self.training:
			return self.regularize(m)
		return self.discretize(m)
