import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import LayerNorm2d

class DynamicTanh(nn.Module):
    """Dynamic Tanh layer: ``weight * tanh(alpha * x) + bias``.

    Used as a replacement for LayerNorm modules (see ``convert_ln_to_at``).
    ``alpha`` is a single learnable scalar; ``weight`` / ``bias`` are optional
    per-element affine parameters of shape ``normalized_shape``.

    Args:
        normalized_shape: Shape of the affine parameters (int or tuple, as
            accepted by ``torch.ones``).
        elementwise_affine: If True, learn ``weight`` and ``bias``; otherwise
            the output is just ``tanh(alpha * x)``.
        channels_last: If True, the affine parameters broadcast over the
            trailing dimensions of ``x``. If False, they are reshaped to
            ``(C, 1, 1)`` so they line up with the third-from-last dimension,
            e.g. an ``(N, C, H, W)`` input.
        alpha_init_value: Initial value of the learnable scalar ``alpha``
            (defaults to 1.0).
    """

    def __init__(self, normalized_shape, elementwise_affine=True, channels_last=True,
                 alpha_init_value=1.0):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.elementwise_affine = elementwise_affine
        self.channels_last = channels_last
        self.alpha_init_value = alpha_init_value

        self.alpha = nn.Parameter(torch.full((1,), float(alpha_init_value)))
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
        else:
            # Register as None so the attributes exist and state_dict stays clean.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        """Apply ``tanh(alpha * x)`` followed by the optional affine transform.

        The input tensor is not modified.
        """
        # Out-of-place scaling: an in-place ``x.mul_`` would clobber the
        # caller's tensor and break autograd for leaf inputs that require grad.
        x = torch.tanh(self.alpha * x)
        if not self.elementwise_affine:
            return x
        if self.channels_last:
            return self.weight * x + self.bias
        # Channels-first: reshape (C,) -> (C, 1, 1) to broadcast over (..., C, H, W).
        return self.weight[:, None, None] * x + self.bias[:, None, None]

    def extra_repr(self):
        return f"normalized_shape={self.normalized_shape}, elementwise_affine={self.elementwise_affine}, channels_last={self.channels_last}"

    @classmethod
    def convert_ln_to_at(cls, module):
        """Recursively replace every ``nn.LayerNorm`` in ``module`` with DynamicTanh.

        timm's ``LayerNorm2d`` is mapped to a channels-first DynamicTanh; any
        other ``nn.LayerNorm`` is mapped to a channels-last one. The new layer
        is freshly initialised; it does not copy the LayerNorm's weight/bias.
        It is placed on the same device and dtype as the replaced LayerNorm's
        parameters, when it has any.

        Args:
            module: Root module to convert. Non-LayerNorm modules are modified
                in place, since their children are re-registered.

        Returns:
            The converted module. This is a new DynamicTanh if ``module``
            itself was a LayerNorm, otherwise ``module``.
        """
        module_output = module
        if isinstance(module, nn.LayerNorm):
            module_output = cls(
                module.normalized_shape,
                module.elementwise_affine,
                not isinstance(module, LayerNorm2d),
            )
            # Keep the replacement on the same device/dtype as the original so
            # converting an already-moved (e.g. CUDA / half) model still works.
            reference = next(module.parameters(), None)
            if reference is not None:
                module_output = module_output.to(
                    device=reference.device, dtype=reference.dtype
                )
        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_ln_to_at(child))
        return module_output
    