import torch
import math


class RePU(torch.nn.Module):
    """Power activation: ``(temperature * x) ** power``.

    NOTE(review): despite the name, no rectification (clamping negative
    inputs to zero) is applied in ``forward`` -- confirm this is intended.

    Args:
        temperature: multiplicative scale applied to the input before the
            power is taken.
        power: exponent applied to the scaled input.
    """

    def __init__(self, temperature=1, power=2):
        super().__init__()
        # Stored as plain tensor attributes (neither parameters nor buffers).
        self.exponent = torch.tensor(power)
        self.temperature = torch.tensor(temperature)

    def forward(self, x):
        """Return the element-wise power of the scaled input."""
        scaled = self.temperature * x
        return torch.pow(scaled, self.exponent)


class Swish(torch.nn.Module):
    """Swish-style activation: ``(x * sigmoid(x / temperature)) ** power``.

    Args:
        temperature: divisor applied to the input inside the sigmoid gate.
        power: exponent applied to the gated value.
    """

    def __init__(self, temperature=1, power=1):
        super().__init__()
        # Stored as plain tensor attributes (neither parameters nor buffers).
        self.exponent = torch.tensor(power)
        self.temperature = torch.tensor(temperature)

    def forward(self, x):
        """Gate ``x`` with a temperature-scaled sigmoid, then raise to the power."""
        gate = torch.sigmoid(x / self.temperature)
        gated = x * gate
        return torch.pow(gated, self.exponent)


class GumbelFunction(torch.autograd.Function):
    """Custom autograd function computing ``exp(-exp(-x))`` element-wise."""

    @staticmethod
    def forward(ctx, input):
        """Compute ``exp(-exp(-input))``.

        The input tensor is stashed on ``ctx`` so that ``backward`` can
        evaluate the derivative at the same point.
        """
        ctx.save_for_backward(input)
        inner = torch.exp(-input)
        return torch.exp(-inner)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to ``input``.

        ``grad_output`` is the gradient of the loss with respect to the
        forward output. The only differentiable input is ``input``, so a
        single gradient tensor is returned.
        """
        (saved_input,) = ctx.saved_tensors
        # d/dx exp(-exp(-x)) = exp(-exp(-x)) * exp(-x). Both factors are folded
        # into one exponential so that very negative inputs give 0 rather
        # than the NaN that a literal 0 * inf product would produce.
        log_derivative = -torch.exp(-saved_input) - saved_input
        return torch.exp(log_derivative) * grad_output


class GumbelLU(torch.nn.Module):
    """Gumbel-gated linear unit: ``(x * exp(-exp(-x / temperature))) ** power``.

    The gate is evaluated through ``GumbelFunction``, which supplies its own
    backward pass.

    Args:
        temperature: divisor applied to the input inside the gate.
        power: exponent applied to the gated value.
    """

    def __init__(self, temperature=1, power=1):
        super().__init__()
        # Stored as plain tensor attributes (neither parameters nor buffers).
        self.exponent = torch.tensor(power)
        self.temperature = torch.tensor(temperature)

    def forward(self, x):
        """Gate ``x`` with the temperature-scaled Gumbel function, then raise to the power."""
        scaled = x / self.temperature
        gate = GumbelFunction.apply(scaled)
        return torch.pow(x * gate, self.exponent)


class GeLU(torch.nn.Module):
    """GELU-style activation: ``(x / 2 * (1 + erf(x / temperature / sqrt(2)))) ** power``.

    Args:
        temperature: divisor applied to the input inside the ``erf`` gate.
        power: exponent applied to the gated value.
    """

    def __init__(self, temperature=1, power=1):
        super().__init__()
        # Stored as plain tensor attributes (neither parameters nor buffers).
        self.temperature = torch.tensor(temperature)
        self.exponent = torch.tensor(power)

    def forward(self, x):
        """Return the erf-gated input raised to ``power``.

        The numeric constants are Python floats rather than freshly built
        tensors: this avoids allocating three tensors on every call, and it
        keeps sqrt(2) at full double precision. A float32 constant tensor
        would round sqrt(2) to float32 even when ``x`` is float64.
        """
        gate = 1.0 + torch.erf(x / self.temperature / math.sqrt(2.0))
        return torch.pow(x / 2.0 * gate, self.exponent)


class GudermanLU(torch.nn.Module):
    """Arctan-of-tanh gated linear unit.

    Computes ``(x * (1/2 + 2/pi * arctan(tanh(x / temperature)))) ** power``.

    Args:
        temperature: divisor applied to the input inside the gate.
        power: exponent applied to the gated value.
    """

    def __init__(self, temperature=1, power=1):
        super().__init__()
        # Stored as plain tensor attributes (neither parameters nor buffers).
        self.exponent = torch.tensor(power)
        self.temperature = torch.tensor(temperature)

    def forward(self, x):
        """Gate ``x`` with the arctan(tanh(.)) sigmoid, then raise to the power."""
        angle = torch.arctan(torch.tanh(x / self.temperature))
        # arctan(tanh(.)) lies in (-pi/4, pi/4); scaling by 2/pi and shifting
        # by 1/2 maps it into (0, 1).
        gate = angle * (2 / math.pi) + 1 / 2
        return torch.pow(x * gate, self.exponent)


class AlgebraicLU(torch.nn.Module):
    """Algebraic sigmoid raised to a power.

    With ``z = x / temperature`` this computes
    ``(1/2 * (1 + z / sqrt(z**2 + 1))) ** power``.

    NOTE(review): unlike the other ``*LU`` modules in this file, the gate is
    not multiplied by ``x`` here -- confirm whether that is intended.

    Args:
        temperature: divisor applied to the input.
        power: exponent applied to the sigmoid value.
    """

    def __init__(self, temperature=1, power=1):
        super().__init__()
        # Stored as plain tensor attributes (neither parameters nor buffers).
        self.exponent = torch.tensor(power)
        self.temperature = torch.tensor(temperature)

    def forward(self, x):
        """Return the temperature-scaled algebraic sigmoid of ``x`` raised to the power."""
        z = x / self.temperature
        ratio = z / torch.sqrt(torch.pow(z, 2) + 1)
        return torch.pow(1 / 2 * (1 + ratio), self.exponent)


if __name__ == '__main__':
    import numpy as np

    # Smoke check: evaluate GumbelFunction and its gradient at ten evenly
    # spaced points from 10 down to -1000000, printing the input, the output
    # and the gradient for each point.
    for sample in np.linspace(10, -1000000, 10):
        print('-' * 30)
        x = torch.tensor(sample, requires_grad=True)

        result = GumbelFunction.apply(x)
        result.backward()
        print(x)
        print(result)
        print(x.grad)

