# Taken from: hassanaskary's blog post on straight-through estimators (STE).
import torch
import torch.nn.functional as F


class Binarizer(torch.autograd.Function):
    """Hard 0/1 threshold with a clipped straight-through gradient.

    Forward maps every element strictly greater than zero to 1.0 and every
    other element (including 0.0) to 0.0, returning a float32 tensor.

    Backward treats the threshold as the identity and passes the upstream
    gradient through unchanged, except that it is clipped elementwise to
    [-1, 1]. The gradient does not depend on the input values, so nothing is
    saved on ``ctx``.
    """

    @staticmethod
    def forward(ctx, input):
        positive_mask = input.gt(0.0)
        return positive_mask.to(torch.float32)

    @staticmethod
    def backward(ctx, grad_output):
        # hardtanh with its default bounds, spelled out for readability.
        return F.hardtanh(grad_output, min_val=-1.0, max_val=1.0)


class StraightThroughEstimatorBinary(torch.nn.Module):
    """``nn.Module`` wrapper that applies :class:`Binarizer` to its input.

    Has no parameters; ``forward`` returns ``Binarizer.apply(x)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return Binarizer.apply(x)


class DiscreteSample(torch.autograd.Function):
    """Draw a one-hot categorical sample with a straight-through gradient.

    Forward: ``input`` is passed to ``torch.distributions.Categorical`` as
    ``probs`` (Categorical normalizes them along the last dimension). One
    category index is sampled per distribution, and the result is returned
    one-hot encoded as a float32 tensor with the same shape as ``input``.

    Backward: sampling is not differentiable, so the upstream gradient is
    passed straight through to ``input`` after being clipped elementwise to
    [-1, 1] with ``F.hardtanh``.
    """

    @staticmethod
    def forward(ctx, input):
        from torch.distributions import Categorical

        # Categorical uses the LAST dimension as the category axis, so the
        # one-hot width must come from shape[-1]. The previous shape[1]
        # crashed on 1-D input and gave the wrong width (or a one_hot error)
        # for inputs with more than two dimensions.
        num_classes = input.shape[-1]
        state = Categorical(probs=input).sample()
        return F.one_hot(state, num_classes).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: skip the sampling step, only clip the gradient.
        return F.hardtanh(grad_output)


class StraightThroughEstimatorDiscreteSample(torch.nn.Module):
    """``nn.Module`` wrapper that applies :class:`DiscreteSample` to its input.

    Has no parameters; ``forward`` returns ``DiscreteSample.apply(x)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return DiscreteSample.apply(x)
