import torch
from torch.nn import Linear
#from trainable_scattering.models.scatter import Scatter
from trainable_scattering.models.attention_scatter import Scatter
from torch.nn import Linear, BatchNorm1d
from trainable_scattering.models.rbf import RBF, gaussian

class attention_rbf(torch.nn.Module):
    """Scattering features -> BatchNorm1d (no affine) -> RBF layer -> Linear head.

    Args:
        in_channels: forwarded to ``Scatter`` as its first argument.
        out_channels: output width of the final ``Linear`` layer.
        num_centres: number of RBF centres; also the input width of the
            final ``Linear`` layer.
        trainable_laziness: forwarded positionally to ``Scatter`` as its
            second argument.
        **kwargs: accepted but ignored by this module.
    """

    def __init__(self, in_channels, out_channels, num_centres, trainable_laziness=False, **kwargs):
        super().__init__()
        # Plain configuration attributes kept for introspection.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.trainable_laziness = trainable_laziness

        # Submodules are registered in this order (scatter, bn, rbf, lin1);
        # keep it so parameter ordering and state_dict keys stay stable.
        self.scatter = Scatter(in_channels, trainable_laziness)
        # The feature width is whatever the scatter layer reports via out_shape().
        self.bn = BatchNorm1d(num_features=self.scatter.out_shape(), affine=False)
        self.rbf = RBF(self.scatter.out_shape(), num_centres, gaussian)
        self.lin1 = Linear(in_features=num_centres, out_features=out_channels)

    def set_centres(self, new):
        """Delegate to the RBF layer's ``set_centres`` with ``new``."""
        self.rbf.set_centres(new)

    def forward(self, data):
        """Run ``data`` through scatter, bn, rbf and lin1, in that order.

        NOTE(review): the accepted type/shape of ``data`` is determined by
        ``Scatter``, which is defined elsewhere — confirm against it.
        """
        out = data
        for stage in (self.scatter, self.bn, self.rbf, self.lin1):
            out = stage(out)
        return out
