import numpy as np
import torch

from models.r2d2_config import device

class Gumbel_Softmax:
	"""Thin wrapper around ``torch.nn.functional.gumbel_softmax``.

	Stores a temperature and a ``hard`` flag, and picks the effective
	``hard`` setting per call depending on whether the caller is training.
	"""

	def __init__(self, temperature=1.0, hard=True):
		"""Remember the sampling configuration.

		Args:
			temperature: Passed as ``tau`` to ``gumbel_softmax``.
			hard: Used as the ``hard`` argument while training. True selects
				the straight-through variant (gs-st), False the plain
				reparameterised relaxation (gs-reparam).
		"""
		self.hard = hard
		self.temperature = temperature

	def forward(self, m, train_mode):
		"""Draw a Gumbel-Softmax sample from the logits ``m``.

		Args:
			m: Logits tensor handed straight to ``gumbel_softmax``.
			train_mode: Truthy while training; falsy otherwise.

		Returns:
			The tensor returned by ``gumbel_softmax``. In train mode the
			configured ``hard`` flag is honoured; otherwise ``hard=True`` is
			forced so that only discrete values are produced.

		Note:
			The non-training path still goes through ``gumbel_softmax``, so
			it remains a random sample rather than a plain argmax.
		"""
		# Only training may use the soft relaxation; everything else is
		# discretised regardless of how the instance was configured.
		use_hard = self.hard if train_mode else True
		return torch.nn.functional.gumbel_softmax(
			m, tau=self.temperature, hard=use_hard
		)
