import torch
from torch.nn import Linear
from trainable_scattering.models.attention_scatter import Scatter
#from scatter import Scatter


class TSNetAttention(torch.nn.Module):
    """Scattering feature extractor followed by a two-layer MLP head.

    ``self.scatter`` produces features whose width is reported by
    ``self.scatter.out_shape()``. The head maps them through a hidden layer
    of ``hidden_channels`` units to ``out_channels`` outputs. A LeakyReLU is
    applied before each of the two linear layers.

    Args:
        in_channels: Passed to ``Scatter`` as its first argument.
        out_channels: Output width of the final linear layer.
        trainable_laziness: Passed to ``Scatter`` as its second positional
            argument, and stored on the instance.
        hidden_channels: Width of the hidden layer between ``lin1`` and
            ``lin2``. Defaults to 64, the previously hard-coded value, so
            parameter shapes (and thus saved state dicts) are unchanged when
            the default is used.
        **kwargs: Accepted for call-site compatibility; currently ignored.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        trainable_laziness=False,
        hidden_channels=64,
        **kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.trainable_laziness = trainable_laziness
        self.hidden_channels = hidden_channels
        # Submodules are assigned in the original order (scatter, lin1, lin2,
        # act) so that state_dict keys and their order stay the same.
        self.scatter = Scatter(in_channels, trainable_laziness)
        self.lin1 = Linear(self.scatter.out_shape(), hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)
        # One activation module is reused for both nonlinearities; LeakyReLU
        # holds no learnable parameters, so sharing it is safe.
        self.act = torch.nn.LeakyReLU()

    def forward(self, data):
        """Apply the scattering transform, then the MLP head.

        Args:
            data: Passed unchanged to ``self.scatter``. Its expected type is
                whatever ``Scatter.forward`` accepts (not visible here).

        Returns:
            The output of ``lin2``; its last dimension is ``out_channels``.
        """
        x = self.scatter(data)
        # Nonlinearity on the raw scattering features before the first
        # linear layer, then again between the two linear layers.
        x = self.act(x)
        x = self.lin1(x)
        x = self.act(x)
        return self.lin2(x)
