import torch


class STEFunction(torch.autograd.Function):
    """Binarize a tensor with a hard threshold, using a straight-through gradient.

    The forward pass maps each element to ``1.0`` if it is strictly greater
    than ``threshold`` and to ``0.0`` otherwise. The backward pass ignores the
    derivative of the step function and passes the incoming gradient through
    unchanged (a straight-through estimator). That derivative is zero almost
    everywhere.

    Usage::

        y = STEFunction.apply(x)        # default threshold 0.5
        y = STEFunction.apply(x, 0.3)   # custom threshold
    """

    @staticmethod
    def forward(ctx, input, threshold=0.5):
        """Return ``input > threshold`` as a float32 tensor of 0s and 1s.

        Args:
            ctx: Autograd context. It is unused because backward needs no
                saved state.
            input: Tensor to binarize.
            threshold: Cut-off value. Elements equal to it map to 0.
                Defaults to 0.5.

        Returns:
            A float32 tensor with the same shape as ``input``.
        """
        # ``.float()`` always produces float32, regardless of input's dtype.
        return (input > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Pass ``grad_output`` straight through to ``input``.

        Returns one gradient per forward input: ``grad_output`` for ``input``
        and ``None`` for ``threshold``, which receives no gradient. Autograd
        drops the trailing ``None`` when ``threshold`` was not passed to
        ``apply``.
        """
        return grad_output, None
