import torch
import torch.nn as nn
import numpy as np

class AllChnsEmbedder(nn.Module):
    """Embed non-overlapping time/frequency patches that span all channels.

    The input's dim 2 ("time") is split into ``n_time`` steps and dim 3
    ("freq") into ``n_freq`` steps. Each patch covers every channel and has
    size ``patch_time x patch_freq``. Every patch is flattened (channel-major:
    ``channel, time-offset, freq-offset``) into
    ``n_channels * patch_time * patch_freq`` values and projected to
    ``embed_dim`` features by a single linear layer.

    With the defaults (64 channels, 105 time steps, 2 freq steps, 12x20
    patches, 16 features) an input of shape ``(batch, 64, 1260, 40)`` yields
    an output of shape ``(batch, 210, 16)``. Trailing rows/columns that do
    not fill a whole grid step (``dim2 - n_time * patch_time`` and
    ``dim3 - n_freq * patch_freq``) are ignored.

    Args:
        n_channels: expected size of input dim 1.
        n_time: number of patches along input dim 2.
        n_freq: number of patches along input dim 3.
        patch_time: patch extent along input dim 2.
        patch_freq: patch extent along input dim 3.
        embed_dim: number of output features per patch.
    """

    def __init__(
        self,
        *,
        n_channels=64,
        n_time=105,
        n_freq=2,
        patch_time=12,
        patch_freq=20,
        embed_dim=16,
    ):
        super().__init__()
        self.n_channels = n_channels
        self.n_time = n_time
        self.n_freq = n_freq
        self.patch_time = patch_time
        self.patch_freq = patch_freq
        # Attribute name kept as ``linear`` so existing state_dicts still load.
        self.linear = torch.nn.Linear(n_channels * patch_time * patch_freq, embed_dim)

    def forward(self, x):
        """Return patch embeddings of shape ``(batch, n_time * n_freq, embed_dim)``.

        Patches are ordered time-major: output row ``t * n_freq + f`` holds
        the patch at time step ``t`` and frequency step ``f``.

        Raises:
            ValueError: if ``x`` is not 4-D, or its channel count / spatial
                sizes do not match the configured patch grid.
        """
        if x.dim() != 4:
            raise ValueError(
                "expected a 4-D input (batch, channels, time, freq), "
                f"got shape {tuple(x.shape)}"
            )
        batch, chans, height, width = x.shape
        n_t, n_f = self.n_time, self.n_freq
        p_t, p_f = self.patch_time, self.patch_freq
        # Same acceptance rule as the grid implies: each grid cell must be
        # exactly p_t x p_f; only a remainder smaller than one step is dropped.
        if chans != self.n_channels or height // n_t != p_t or width // n_f != p_f:
            raise ValueError(
                f"expected input (batch, {self.n_channels}, H, W) with "
                f"H // {n_t} == {p_t} and W // {n_f} == {p_f}, "
                f"got shape {tuple(x.shape)}"
            )

        # Crop the remainder, then split each spatial axis into (grid, patch).
        patches = x[:, :, : n_t * p_t, : n_f * p_f].reshape(batch, chans, n_t, p_t, n_f, p_f)
        # -> (batch, n_t, n_f, chans, p_t, p_f) so each patch flattens channel-major.
        # Explicit sizes (no -1) keep an empty batch reshapeable.
        patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(
            batch, n_t * n_f, chans * p_t * p_f
        )
        # Cast to the layer's dtype: float64/integer inputs keep working against
        # float32 weights, and half-precision models are no longer forced to fp32.
        patches = patches.to(dtype=self.linear.weight.dtype)
        return self.linear(patches)


