import torch.nn as nn

class OperateFunc(nn.Module):
    """Point-wise MLP applied independently to every feature vector of the input.

    The network is ``Linear(input_size, embed_size) -> BatchNorm1d -> GELU``,
    followed by ``layer_num - 1`` hidden ``Linear -> BatchNorm1d -> GELU`` blocks
    and a final ``Linear(embed_size, output_size)`` projection.

    Args:
        input_size: Size of the last dimension of the input.
        output_size: Size of the last dimension of the output.
        layer_num: Number of ``Linear`` layers after the input block (the last
            one is the output projection). With ``layer_num < 1`` no output
            projection is added, so the output has ``embed_size`` features and
            ``output_size`` is ignored.
        embed_size: Width of the hidden layers.
    """

    def __init__(self, input_size, output_size=1, layer_num=2, embed_size=32):
        super().__init__()
        self.layer_num = layer_num
        self.embed_size = embed_size
        self.input_size = input_size
        # Build into a local list. The module order (and therefore state_dict
        # keys such as ``network.0.weight``) is identical to the original layout.
        layers = [nn.Linear(input_size, embed_size), nn.BatchNorm1d(embed_size), nn.GELU()]
        for _ in range(layer_num - 1):
            layers += [nn.Linear(embed_size, embed_size), nn.BatchNorm1d(embed_size), nn.GELU()]
        if layer_num >= 1:
            layers.append(nn.Linear(embed_size, output_size))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``.

        Args:
            x: Tensor of shape ``(..., input_size)``. The historical contract is
                ``(batch_size, in_len, resolution, input_size)``; any number of
                leading dimensions is accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and the network's
            output width as its last dimension.

        Raises:
            ValueError: If ``x`` is 0-dimensional or its last dimension differs
                from ``input_size``.
        """
        if x.dim() < 1:
            raise ValueError('The input must have at least one dimension!')
        *leading_shape, input_size = x.shape
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if input_size != self.input_size:
            raise ValueError(
                f'The input size does not agree! Expected {self.input_size}, got {input_size}.'
            )
        # BatchNorm1d expects (N, C), so fold every leading dim into one batch dim.
        y = self.network(x.reshape(-1, input_size))
        # Use the concrete width (not -1) so the reshape is unambiguous.
        return y.reshape(*leading_shape, y.shape[-1])
    
