import torch
import torch.nn as nn

# Capacity names accepted by CREPE. Each maps to an optional override of that
# capacity's filter multiplier: None means "use CREPE's built-in default",
# any other value is used as the multiplier as-is.
models = dict.fromkeys(('tiny', 'small', 'medium', 'large', 'full'))


class Block(nn.Module):
    """One convolutional stage: Conv2d -> ReLU -> BatchNorm2d -> MaxPool -> Dropout.

    Both the convolution (kernel ``(1, width)``) and the pooling (kernel
    ``(1, 2)``) have height 1, so only the last axis is transformed; the
    second-to-last axis passes through unchanged.

    For an input of last-axis length ``W`` the output length is
    ``floor((floor((W + 2 * (width // 2) - width) / stride) + 1) / 2)``.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of convolution filters.
        width: convolution kernel length along the last axis.
        stride: convolution stride along the last axis.
        dropout: dropout probability applied at the end of the stage.
    """

    def __init__(self, in_channels, out_channels, width, stride, dropout=0.25):
        super().__init__()
        # Keep this layer order: it determines the ``conv.<i>`` parameter
        # names in the state_dict, so reordering would break saved checkpoints.
        self.conv = nn.Sequential(
            # Symmetric padding of width // 2 on the last axis only.
            nn.Conv2d(in_channels, out_channels, (1, width), (1, stride), (0, width // 2)),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            # Kernel (1, 2) with the default stride (= kernel) halves the last axis.
            nn.MaxPool2d((1, 2)),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the stage to ``x`` of shape ``(N, in_channels, H, W)``."""
        return self.conv(x)


class CREPE(nn.Module):
    """CREPE-style network: six convolutional Blocks plus a sigmoid output layer.

    Every Block convolves and pools along the last input axis only. The final
    fully connected layer maps each row of the third axis to 360 activations
    in ``(0, 1)``.

    Args:
        model_capacity: a key of the module-level ``models`` table
            (``'tiny'``, ``'small'``, ``'medium'``, ``'large'``, ``'full'``).
            It scales the number of filters in every layer.
        frame_length: length of the last input axis. The input size of the
            fully connected layer is derived from it. The default of 1024
            reproduces the previous fixed ``in_channels * 4`` size.

    Raises:
        KeyError: if ``model_capacity`` is not a key of ``models``.
        ValueError: if ``frame_length`` is too short for the conv/pool stack.
    """

    # Filter multiplier per capacity, used unless ``models`` holds a non-None
    # override for that capacity name.
    _DEFAULT_MULTIPLIERS = {
        'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32
    }

    def __init__(self, model_capacity, frame_length=1024):
        super().__init__()
        try:
            override = models[model_capacity]
        except KeyError:
            raise KeyError(
                f'unknown model capacity {model_capacity!r}; '
                f'expected one of {list(models)}'
            ) from None
        if override is None:
            self.capacity_multiplier = self._DEFAULT_MULTIPLIERS[model_capacity]
        else:
            self.capacity_multiplier = override

        self.filters = [n * self.capacity_multiplier for n in [32, 4, 4, 4, 8, 16]]
        self.widths = [512, 64, 64, 64, 64, 64]
        self.strides = [4, 1, 1, 1, 1, 1]

        in_channels = 1
        length = frame_length  # last-axis length after each Block
        self.layers = nn.ModuleList()
        for out_channels, width, stride in zip(self.filters, self.widths, self.strides):
            self.layers.append(Block(in_channels, out_channels, width, stride))
            in_channels = out_channels
            length = self._block_output_length(length, width, stride)
            if length < 1:
                raise ValueError(
                    f'frame_length={frame_length} is too short for the '
                    f'convolution/pooling stack'
                )
        # After the stack, each row is flattened to in_channels * length features.
        self.fc = nn.Sequential(
            nn.Linear(in_channels * length, 360),
            nn.Sigmoid(),
        )

    @staticmethod
    def _block_output_length(length, width, stride):
        """Return the last-axis length after one Block.

        Mirrors Block's layer geometry: Conv2d with padding ``width // 2``
        followed by MaxPool2d((1, 2)), which floors the length by half.
        """
        conv_length = (length + 2 * (width // 2) - width) // stride + 1
        return conv_length // 2

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape ``(batch, 1, rows, frame_length)``. The
                convolutions span only the last axis (kernel height 1), so
                each of the ``rows`` is processed independently by the convs.

        Returns:
            Tensor of shape ``(batch, rows, 360)`` with values in ``(0, 1)``.
        """
        for layer in self.layers:
            x = layer(x)
        # (batch, channels, rows, length) -> (batch, rows, channels * length)
        x = x.transpose(1, 2).flatten(-2)
        return self.fc(x)


if __name__ == '__main__':
    # Smoke test: push a dummy (1, 1, 36, 1024) tensor through the
    # full-capacity model and report the output shape.
    x = torch.ones((1, 1, 36, 1024))
    crepe = CREPE('full')
    # Inference mode: disable dropout and use BatchNorm running statistics
    # instead of updating them from this dummy batch.
    crepe.eval()
    with torch.no_grad():  # no autograd graph needed for a forward-only check
        y = crepe(x)
    print(tuple(y.shape))
