import torch
from torch.autograd import Function


class FlipGradientFunction(Function):
    """Identity in the forward pass; reversed, scaled gradient in the backward pass.

    ``forward(ctx, x, alpha)`` hands ``x`` back unchanged. ``backward`` produces
    ``-alpha * grad_output`` as the gradient for ``x`` and no gradient for
    ``alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scale factor on the context so backward can reuse it.
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient, then scale it by alpha.
        reversed_grad = grad_output.neg() * ctx.alpha
        # alpha is treated as a constant: it receives no gradient.
        return reversed_grad, None


class FlipGradientBuilder(object):
    """Callable wrapper that applies ``FlipGradientFunction`` to a tensor.

    Each instance records in ``num_calls`` how many times it has been invoked.
    """

    def __init__(self):
        # Count of completed calls. It is incremented after the function is
        # applied, so a call that raises is not counted.
        self.num_calls = 0

    def __call__(self, x, alpha=1.0):
        """Pass ``x`` through unchanged while reversing and scaling its gradient.

        Args:
            x: Input tensor.
            alpha: Factor applied to the negated gradient in the backward pass.
                Defaults to 1.0 (plain gradient reversal); existing positional
                callers are unaffected.

        Returns:
            A tensor with the same values as ``x``. Its gradient with respect
            to ``x`` in the backward pass is ``-alpha * grad_output``.
        """
        # Call apply directly instead of binding it to a local named
        # ``flip_gradient``, which would shadow the module-level instance.
        y = FlipGradientFunction.apply(x, alpha)
        self.num_calls += 1
        return y


# Shared module-level FlipGradientBuilder: ``flip_gradient(x, alpha)`` returns x
# unchanged forward and ``-alpha * grad`` backward. Its ``num_calls`` counter is
# shared by every caller that uses this instance.
flip_gradient = FlipGradientBuilder()
