import torch
import torch.nn as nn
import torch.nn.functional as F




import torch
import torch.nn as nn
# import matplotlib.pyplot as plt

class SReLU3nt(nn.Module):
    """Smooth ReLU with a quartic polynomial transition on (-a, a).

    Piecewise definition::

        f(x) = 0                                          if x <= -a
        f(x) = x/2 + 3x^2/(8a) + 3a/16 - x^4/(16a^3)      if -a < x < a
        f(x) = x                                          if x >= a

    At x = -a and x = a the polynomial matches the outer pieces in value,
    first derivative and second derivative, so the activation is C2-smooth.
    """

    def __init__(self, trainable=False, a=1.0):
        """
        Args:
            trainable (bool): accepted for backward compatibility; currently
                has no effect (``a`` is stored as a plain Python float, not a
                learnable ``nn.Parameter``).
            a (float): half-width of the smooth transition interval. Must be
                positive. Defaults to 1.0, the value previously hard-coded.

        Raises:
            ValueError: if ``a`` is not positive.
        """
        super().__init__()
        if a <= 0:
            raise ValueError(f"a must be positive, got {a}")
        self.a = float(a)

    def forward(self, x):
        """Apply the activation elementwise to tensor ``x``."""
        a = self.a
        # Evaluate the polynomial on a clamped copy of x. Inside (-a, a) this
        # is identical to using x; outside it, the discarded x**4 / x**3 terms
        # can no longer overflow to inf, which would otherwise turn the zero
        # gradient of the unselected torch.where branch into NaN (0 * inf).
        xc = x.clamp(-a, a)
        poly = (xc / 2.0
                + 3.0 * xc**2 / (8.0 * a)
                + 3.0 * a / 16.0
                - xc**4 / (16.0 * a**3))
        return torch.where(x <= -a, torch.zeros_like(x),
                           torch.where(x >= a, x, poly))

    def extra_repr(self):
        """Show the transition half-width in the module's repr."""
        return f'a={self.a:.3f}'
