import torch
from torch.nn import Linear, BatchNorm1d
from trainable_scattering.models.rbf import RBF, gaussian


class SCNet(torch.nn.Module):
    """Fully connected network with a configurable number of hidden layers.

    Layout: LeakyReLU -> Linear(in, hidden) -> LeakyReLU ->
    ``num_layers`` x [Linear(hidden, hidden) -> LeakyReLU] ->
    Linear(hidden, out).

    Args:
        in_channels: Size of the last dimension of the input features.
        intermediate_channels: Width of every hidden layer.
        num_layers: Number of hidden ``Linear(hidden, hidden)`` layers.
        out_channels: Size of the last dimension of the output.
    """

    def __init__(self, in_channels, intermediate_channels, num_layers, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.lin1 = Linear(in_channels, intermediate_channels)
        # The hidden layers must live in a ModuleList: a plain Python list
        # hides them from parameters(), state_dict() and .to(device), so
        # they would never be trained, saved, or moved with the model.
        # Exactly ``num_layers`` layers are built (not one per hidden unit).
        self.innerlayers = torch.nn.ModuleList(
            [
                Linear(intermediate_channels, intermediate_channels)
                for _ in range(num_layers)
            ]
        )
        self.lin2 = Linear(intermediate_channels, out_channels)
        self.act = torch.nn.LeakyReLU()

    def forward(self, data):
        """Run the network on ``data.x`` and return the output tensor.

        Args:
            data: Any object exposing the input features as attribute ``x``.

        Returns:
            The result of the final linear layer (no activation applied).
        """
        x = data.x
        # The activation is applied to the raw input before the first layer.
        x = self.act(x)
        x = self.lin1(x)
        x = self.act(x)
        for layer in self.innerlayers:
            x = self.act(layer(x))
        x = self.lin2(x)
        return x


class LinearRegression(torch.nn.Module):
    """Single affine layer mapping ``in_channels`` features to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.lin1 = Linear(in_features=in_channels, out_features=out_channels)

    def forward(self, data):
        """Apply the linear layer to ``data.x`` and return the result."""
        return self.lin1(data.x)


class SVM(torch.nn.Module):
    """Linear decision function: one ``Linear(in_channels, out_channels)`` layer.

    NOTE(review): the module itself is purely linear; any SVM-specific
    behaviour (e.g. a margin loss) presumably lives in the training code.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.lin1 = Linear(in_features=in_channels, out_features=out_channels)

    def forward(self, data):
        """Return the linear layer's output for the features in ``data.x``."""
        scores = self.lin1(data.x)
        return scores


class RBF_SVM(torch.nn.Module):
    """Batch-normalised RBF feature layer followed by a linear read-out.

    Pipeline: ``BatchNorm1d`` (no learnable affine) -> ``RBF`` built with the
    ``gaussian`` basis function -> ``Linear(num_centres, out_channels)``.
    """

    def __init__(self, in_channels, out_channels, num_centres):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # Normalisation only; affine=False means no learnable scale/shift.
        self.bn = BatchNorm1d(in_channels, affine=False)
        self.rbf = RBF(in_channels, num_centres, gaussian)
        self.lin1 = Linear(num_centres, out_channels)

    def set_centres(self, new):
        """Forward ``new`` to the RBF layer's ``set_centres``."""
        self.rbf.set_centres(new)

    def forward(self, data):
        """Normalise ``data.x``, apply the RBF layer, then the linear layer."""
        normalised = self.bn(data.x)
        activations = self.rbf(normalised)
        return self.lin1(activations)
