import torch
import torch.nn as nn
import numpy as np

class ReLU(nn.Module):
    """Rectifier layer that also supports a backward relevance/gradient pass.

    With ``beta=None`` the layer is a plain ``nn.ReLU``; otherwise it is
    ``nn.Softplus(beta=beta)``, a smooth surrogate. ``forward`` remembers its
    input so that ``analyze`` can later gate a backward signal with it.

    Attributes:
        relu: the wrapped activation module (``nn.ReLU`` or ``nn.Softplus``).
        beta: the Softplus sharpness, or ``None`` for the hard ReLU.
        grad: the most recent gated signal, cached by ``analyze`` only for
            ``method == 'gbp_all'`` or ``method == 'midgrad'``; starts as ``None``.
        X: the last input seen by ``forward`` (``0`` before any call). A
            reference to the tensor is kept, not a copy.
    """

    def __init__(self, beta=None):
        super().__init__()
        self.beta = beta
        self.relu = nn.ReLU() if beta is None else nn.Softplus(beta=beta)
        self.grad = None
        self.X = 0

    def forward(self, x):
        """Store ``x`` for a later ``analyze`` call and return the activation."""
        self.X = x
        return self.relu(x)

    def analyze(self, method, R):
        """Propagate the signal ``R`` back through this activation.

        The gate is the derivative of the forward activation evaluated at the
        stored input: the 0/1 step ``X > 0`` for ReLU, or
        ``sigmoid(beta * X)`` for Softplus.

        Args:
            method: name of the propagation rule. Names whose first three
                characters are ``'gbp'`` or ``'our'`` take the guided branch,
                which first passes ``R`` through the layer's own activation.
                All other names multiply ``R`` by the gate directly.
                ``'gbp_all'`` (guided) and ``'midgrad'`` (plain) additionally
                cache the result in ``self.grad``.
            R: the incoming signal; must be broadcast-compatible with ``X``.

        Returns:
            The gated signal.
        """
        if self.beta is None:
            gate = (self.X > 0).float()
        else:
            gate = torch.sigmoid(self.X * self.beta)

        if method[:3] in ('gbp', 'our'):
            # NOTE(review): when beta is set, this applies Softplus (not a hard
            # ReLU) to R itself; confirm that is the intended guided rule.
            R = self.relu(R) * gate
            cache_it = method == 'gbp_all'
        else:
            R = R * gate
            cache_it = method == 'midgrad'

        if cache_it:
            self.grad = R
        return R