import torch


class HeavisideST(torch.autograd.Function):
    """
    Heaviside step activation with a straight-through gradient estimator.

    Forward: every strictly positive element becomes 1 and every other
    element becomes 0 (the input is rounded up, then limited to [0, 1]).
    Backward: the incoming gradient is handed back unchanged, as if the
    forward pass had been the identity.
    """

    @staticmethod
    def forward(ctx, input):
        # Round up first so any value in (0, 1] lands on 1, then bound the
        # result so larger values saturate at 1 and negatives at 0.
        rounded_up = input.ceil()
        return torch.clamp(rounded_up, min=0, max=1)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: return a copy of the upstream gradient.
        return torch.clone(grad_output)


class BernoulliST(torch.autograd.Function):
    """
    Bernoulli sampling with a straight-through gradient estimator.

    Forward: treats each element of ``input`` as a probability and draws an
    independent 0/1 sample for it. Elements outside [0, 1] are clamped into
    that range first, because ``torch.bernoulli`` rejects values that are
    not valid probabilities.
    Backward: the incoming gradient is returned unchanged, as if the forward
    pass had been the identity. This also holds for elements that were
    clamped.
    """

    @staticmethod
    def forward(ctx, input):
        # Clamp here rather than in the caller so that the straight-through
        # backward below still delivers a gradient to out-of-range elements.
        probs = input.clamp(min=0, max=1)
        return torch.bernoulli(probs)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass a copy of the upstream gradient along.
        grad_input = grad_output.clone()
        return grad_input


class BernoulliPolicyHead(torch.nn.Module):
    """
    Policy head that maps features to binary samples via ``BernoulliST``.

    The constructor only records the activation class; the layer itself is
    built later by ``initialize``, which must run before ``forward``.

    NOTE(review): the default ``Tanh`` produces values in [-1, 1], while
    Bernoulli sampling is defined for probabilities in [0, 1] -- confirm how
    negative outputs are meant to be handled.
    """

    def __init__(self, activation=torch.nn.Tanh):
        super().__init__()
        # Autograd function used to turn the layer output into samples.
        self.distribution = BernoulliST
        # Stored as a class (or factory); instantiated in ``initialize``.
        self.activation = activation

    def initialize(self, input_size, action_size):
        """Build the linear projection followed by the activation."""
        projection = torch.nn.Linear(input_size, action_size)
        squash = self.activation()
        self.loc_layer = torch.nn.Sequential(projection, squash)

    def forward(self, inputs):
        """Project ``inputs`` and sample from the result."""
        return self.distribution.apply(self.loc_layer(inputs))


class SecondBernoulliPolicyHead(torch.nn.Module):
    """
    Policy head that draws binary samples from a noisy probability estimate.

    Two layers produce a location and a scale from the inputs. These
    parameterize ``distribution`` (a Normal by default), a value is drawn
    from it, and that value is used as the probability for ``BernoulliST``.

    ``initialize`` must be called before ``forward`` to build the layers.

    ``loc_fn`` and ``scale_fn`` are stored on the instance but are not
    applied in ``forward``.
    """

    def __init__(
        self,
        loc_activation=torch.nn.Tanh,
        loc_fn=None,
        scale_activation=torch.nn.Softplus,
        scale_min=1e-4,
        scale_max=1,
        scale_fn=None,
        distribution=torch.distributions.normal.Normal,
    ):
        super().__init__()
        self.loc_activation = loc_activation
        self.loc_fn = loc_fn
        self.scale_activation = scale_activation
        self.scale_min = scale_min
        self.scale_max = scale_max
        self.scale_fn = scale_fn
        self.distribution = distribution

    def initialize(self, input_size, action_size):
        """Build the location and scale layers read by ``forward``."""
        self.loc_layer = torch.nn.Sequential(
            torch.nn.Linear(input_size, action_size), self.loc_activation()
        )
        self.scale_layer = torch.nn.Sequential(
            torch.nn.Linear(input_size, action_size), self.scale_activation()
        )

    def forward(self, inputs):
        """Return 0/1 samples with one value per output of the layers."""
        loc = self.loc_layer(inputs)
        scale = self.scale_layer(inputs)
        # Keep the scale inside [scale_min, scale_max].
        scale = torch.clamp(scale, self.scale_min, self.scale_max)
        dist = self.distribution(loc, scale)
        # Draw a tensor from the distribution object; prefer the
        # reparameterized draw so gradients can reach loc and scale.
        draw = dist.rsample() if dist.has_rsample else dist.sample()
        # The draw is unbounded, but torch.bernoulli needs values in [0, 1].
        probs = draw.clamp(min=0, max=1)
        return BernoulliST.apply(probs)
