from typing import Optional

import torch.nn as nn


class Reshape(nn.Module):
    """Reshape the non-leading dimensions of a tensor to a fixed shape.

    The size of the first dimension of the input is kept as-is; the
    remaining elements are rearranged into the ``shape`` passed at
    construction time (which may contain ``-1`` for an inferred dimension).
    """

    def __init__(self, *shape):
        """Store the target shape (excluding the leading dimension)."""
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.shape[0], *self.shape)``."""
        leading = x.shape[0]
        target = [leading, *self.shape]
        return x.reshape(target)


def linear(num_classes: int = 2, num_features: Optional[int] = None):
    """Build a bias-free linear classifier over flattened inputs.

    Args:
        num_classes: Number of output units of the linear layer.
        num_features: Number of input features of the linear layer, i.e. the
            per-sample size after flattening. Must be provided; the ``None``
            default only exists so the argument can be passed by keyword.

    Returns:
        An ``nn.Sequential`` of ``nn.Flatten`` followed by
        ``nn.Linear(num_features, num_classes, bias=False)``.

    Raises:
        TypeError: If ``num_features`` is not given.
    """
    if num_features is None:
        # Fail early with a clear message instead of an opaque error raised
        # from inside nn.Linear's parameter allocation.
        raise TypeError("linear() requires num_features to be set to an int")
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(num_features, num_classes, bias=False),
    )


def conv_linear(num_classes: int = 2, kernel_size: int = 2, remove_last_layer: bool = False,
                num_features: Optional[int] = None):
    """Build a single-channel circular 1-D convolution, optionally with a linear head.

    The model flattens each sample, views it as a one-channel sequence,
    applies a bias-free ``nn.Conv1d(1, 1)`` with ``padding='same'`` and
    ``padding_mode='circular'``, and flattens the result again. Unless
    ``remove_last_layer`` is set, a bias-free ``nn.Linear`` head is appended.

    Args:
        num_classes: Number of output units of the linear head.
        kernel_size: Kernel size of the convolution.
        remove_last_layer: If true, omit the linear head and return only the
            convolutional part.
        num_features: Number of input features of the linear head. Required
            when the head is kept; unused otherwise.

    Returns:
        An ``nn.Sequential`` containing the layers described above.

    Raises:
        TypeError: If the linear head is requested but ``num_features`` is
            not given.
    """
    head = None
    if not remove_last_layer:
        if num_features is None:
            # Fail early with a clear message instead of an opaque error
            # raised from inside nn.Linear's parameter allocation.
            raise TypeError(
                "conv_linear() requires num_features to be set to an int "
                "unless remove_last_layer=True"
            )
        # Created before the convolution on purpose: this keeps the order in
        # which the layers draw their random initial weights unchanged.
        head = nn.Linear(num_features, num_classes, bias=False)

    modules = [
        nn.Flatten(),
        Reshape(1, -1),  # add the single channel dimension expected by Conv1d
        nn.Conv1d(1, 1, kernel_size=kernel_size, bias=False,
                  padding_mode='circular', padding='same'),
        nn.Flatten(),
    ]
    if head is not None:
        modules.append(head)
    return nn.Sequential(*modules)
