import torch
from torch.autograd import Function


class PseudoStep(Function):
    """
    Behaves like a step function in the forward pass and passes on the gradient in the backward pass.

    The forward pass maps every element to -1 (element < 0) or +1 (element >= 0); the backward pass
    returns the incoming gradient unchanged, as if the forward pass had been the identity.

    *Warning: This function is meant to be used after a tanh activation function.*
    """

    @staticmethod
    def forward(ctx, tensor):
        """
        Non-differentiable forward pass that behaves like a step function with -1 for tensor < 0 and 1 for tensor >= 0

        The result is always a float32 tensor, whatever the dtype of the input. NaN entries compare
        false against 0 and are therefore mapped to -1.

        :param ctx: The context (unused: the backward pass does not need any saved tensors)
        :param tensor: The input
        :return: -1 for tensor < 0 and 1 for tensor >= 0
        """
        # Nothing is stored on ctx: backward() is a pure pass-through, so keeping a reference to the
        # input would only hold its memory alive for the lifetime of the graph.
        #
        # The output is also not flagged with requires_grad manually; autograd attaches it to the
        # graph on its own whenever the input requires a gradient.
        return 2 * torch.greater_equal(tensor, 0).float() - 1

    @staticmethod
    def backward(ctx, grad_output):
        """
        Passes on the gradient.

        :param ctx: The context (unused)
        :param grad_output: The gradient w.r.t. the output of this module
        :return: The gradient w.r.t. the input of this module
        """
        return grad_output
