import torch
import torch.nn as nn
import torch.nn.functional as F

class ANL(nn.Module):
    """Layer combining an activated linear branch with a shared B-spline branch.

    ``forward`` returns ``base_output + spline_output`` where

    * ``base_output = act(linear(act(x), base_weight))``, with ``act`` an
      instance of ``base_activation`` applied both before and after the
      linear map, and
    * ``spline_output = linear(sum_i B(x_i), spline_weight)``: the B-spline
      basis values of every input feature are summed over the feature axis
      before being mixed, so ``spline_weight`` (shape
      ``(out_features, grid_size + spline_order)``) is shared across all
      input features.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        grid_size: Number of uniform grid intervals spanning ``grid_range``.
        spline_order: Degree of the B-spline basis (3 = cubic).
        scale_noise: Half-width of the uniform init range of ``spline_weight``.
        base_activation: Activation *class* (instantiated once in ``__init__``)
            used by the base branch.
        grid_range: ``(low, high)`` pair; the knot vector is padded with
            ``spline_order`` extra knots on each side of this range.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        base_activation=nn.Tanh,
        grid_range=(-1, 1),  # tuple: avoids a shared mutable default
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector: grid_size intervals of width h over grid_range,
        # extended by spline_order knots on each side, giving
        # grid_size + 2 * spline_order + 1 knots in total.
        low, high = grid_range[0], grid_range[1]
        h = (high - low) / grid_size
        grid = torch.arange(-spline_order, grid_size + spline_order + 1) * h + low
        self.register_buffer("grid", grid)

        # torch.empty: the values are overwritten right away by reset_parameters().
        self.base_weight = nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = nn.Parameter(
            torch.empty(out_features, grid_size + spline_order)
        )

        self.scale_noise = scale_noise
        self.base_activation = base_activation()

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-normal init for ``base_weight``; U(-scale_noise, scale_noise) for ``spline_weight``."""
        nn.init.xavier_normal_(self.base_weight)
        nn.init.uniform_(self.spline_weight, -self.scale_noise, self.scale_noise)

    def b_splines(self, x):
        """Evaluate the B-spline basis of degree ``spline_order`` for every input feature.

        Uses the Cox-de Boor recursion over ``self.grid``: it starts from
        piecewise-constant interval indicators and raises the degree by one
        per iteration.

        Args:
            x: 2-D tensor of shape ``(batch, in_features)``.

        Returns:
            Tensor of shape ``(batch, in_features, grid_size + spline_order)``.
            Inputs outside the knot span ``[grid[0], grid[-1])`` get all-zero
            bases.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        x = x.unsqueeze(-1)  # (batch, in_features, 1): broadcasts against the knots
        grid = self.grid
        # Degree-0 basis: 1 where grid[j] <= x < grid[j + 1], else 0.
        bases = ((x >= grid[:-1]) & (x < grid[1:])).float()

        for k in range(1, self.spline_order + 1):
            left_denom = grid[k:-1] - grid[: -(k + 1)]
            right_denom = grid[k + 1 :] - grid[1:-k]
            # Repeated knots would give zero-width spans; avoid dividing by zero.
            left_denom = left_denom.masked_fill(left_denom == 0, 1)
            right_denom = right_denom.masked_fill(right_denom == 0, 1)

            left = (x - grid[: -(k + 1)]) / left_denom * bases[:, :, :-1]
            right = (grid[k + 1 :] - x) / right_denom * bases[:, :, 1:]
            bases = left + right  # one fewer basis function per degree step

        return bases

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(..., in_features)``.

        Returns:
            Tensor of shape ``(..., out_features)``.
        """
        original_shape = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work.
        x = x.reshape(-1, self.in_features)

        base_output = self.base_activation(
            F.linear(self.base_activation(x), self.base_weight)
        )

        # Sum the per-feature bases over the feature axis:
        # (batch, in_features, n_bases) -> (batch, n_bases).
        spline_input = self.b_splines(x).sum(dim=1)
        # The degree-0 indicators are built as float32, so the bases may not
        # match the weight dtype; cast before the matmul.
        spline_input = spline_input.to(dtype=self.spline_weight.dtype)
        spline_output = F.linear(spline_input, self.spline_weight)

        output = base_output + spline_output
        return output.reshape(*original_shape[:-1], self.out_features)