import torch
from torch.nn import Linear


class SCNet(torch.nn.Module):
    """Two-layer MLP applied to ``data.x``.

    Pipeline: LeakyReLU -> Linear(in, hidden) -> LeakyReLU -> Linear(hidden, out).

    Args:
        in_channels: Size of the last dimension of ``data.x``.
        out_channels: Size of the last dimension of the returned tensor.
        hidden_channels: Width of the hidden layer. Defaults to 128, the value
            that used to be hard-coded, so existing callers and checkpoints
            are unaffected.
    """

    def __init__(self, in_channels, out_channels, *, hidden_channels=128):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        # Construction order (lin1, then lin2) is kept so parameter init under
        # a fixed RNG seed matches the previous implementation.
        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)
        # One parameter-free activation module is reused at both call sites.
        self.act = torch.nn.LeakyReLU()

    def forward(self, data):
        """Run the MLP on ``data.x``.

        Args:
            data: Any object exposing a tensor attribute ``x`` whose last
                dimension is ``in_channels`` (presumably a graph/data batch
                object — confirm against callers).

        Returns:
            Tensor with the same leading dimensions as ``data.x`` and a last
            dimension of ``out_channels``.
        """
        x = data.x
        # NOTE(review): the activation is applied to the raw input before
        # lin1, which is unusual for an MLP. Kept to preserve existing
        # behaviour and trained weights; confirm whether it is intentional.
        x = self.act(x)
        x = self.lin1(x)
        x = self.act(x)
        x = self.lin2(x)
        return x


class LinearRegression(torch.nn.Module):
    """Single affine layer mapping ``data.x`` to ``out_channels`` features.

    Args:
        in_channels: Size of the last dimension of ``data.x``.
        out_channels: Size of the last dimension of the returned tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lin1 = Linear(in_features=in_channels, out_features=out_channels)

    def forward(self, data):
        """Return ``lin1(data.x)``.

        ``data`` only needs a tensor attribute ``x`` whose last dimension is
        ``in_channels``; the result keeps the leading dimensions and ends in
        ``out_channels``.
        """
        return self.lin1(data.x)
