import torch
import torch.nn as nn
class DynamicTanh(nn.Module):
    """Element-wise ``alpha * tanh(beta * x + gamma)`` with learnable terms.

    ``n_params`` selects which of the three terms are trainable
    ``nn.Parameter`` tensors of shape ``(in_features,)``:

    * ``3`` -- ``alpha`` (init 1), ``beta`` (init 1) and ``gamma`` (init 0).
    * ``2`` -- ``alpha`` and ``beta``; ``gamma`` is the constant ``0``.
    * anything else -- only ``alpha``; ``beta`` is the constant ``1.0`` and
      ``gamma`` is the constant ``0``.

    Args:
        in_features: Size passed to ``torch.ones`` / ``torch.zeros`` when
            creating each learnable tensor.
        n_params: Number of learnable terms (see above). Stored unchanged
            on ``self.n_params``.
    """

    def __init__(self, in_features, n_params=3):
        super().__init__()
        self.n_params = n_params

        # Decide up front which optional terms are trainable; alpha always is.
        has_learned_shift = n_params == 3
        has_learned_scale = has_learned_shift or n_params == 2

        # Registration order (alpha, beta, gamma) is kept so that
        # ``parameters()`` / ``state_dict()`` ordering does not change.
        self.alpha = nn.Parameter(torch.ones(in_features))

        if has_learned_scale:
            self.beta = nn.Parameter(torch.ones(in_features))
        else:
            self.beta = 1.0  # fixed input scale, plain Python float

        if has_learned_shift:
            self.gamma = nn.Parameter(torch.zeros(in_features))
        else:
            self.gamma = 0  # fixed shift, plain Python int

    def forward(self, x):
        """Return ``alpha * tanh(beta * x + gamma)`` for input ``x``."""
        pre_activation = self.beta * x + self.gamma
        return self.alpha * torch.tanh(pre_activation)