import torch
import torch.nn as nn


class OnsetLinearModel(nn.Module):
    """Linear model driven only by which delay slots contain any signal.

    Each input sample is viewed as ``(embedding_dim, delay_dim)``. A delay
    slot is marked as an "onset" (1.0) when any of its ``embedding_dim``
    values is nonzero, otherwise 0.0. The output is that binary mask
    multiplied by a learned ``(delay_dim, channel_dim)`` matrix
    (``delay_bias``), plus an optional learned per-channel ``bias``.

    Args:
        delay_dim: Number of delay slots per input sample.
        embedding_dim: Number of embedding values per delay slot.
        channel_dim: Number of output channels.
        use_bias: If True, add a learned per-channel bias to the output.
    """

    def __init__(self, delay_dim, embedding_dim,
                 channel_dim, use_bias=False):
        super().__init__()

        self.delay_dim = delay_dim
        self.embedding_dim = embedding_dim
        self.channel_dim = channel_dim
        self.use_bias = use_bias

        # One learned output vector per delay slot; forward() sums the rows
        # belonging to the slots that are "on".
        self.delay_bias = nn.Parameter(torch.zeros(delay_dim, channel_dim))

        if use_bias:
            self.bias = nn.Parameter(torch.zeros(channel_dim))
        else:
            # Register an empty slot (as torch.nn.Linear does) so `self.bias`
            # is always defined and simply None when unused.
            self.register_parameter("bias", None)

        # Not used by forward(); only exposed through get_weights().
        # NOTE(review): presumably kept so this model matches the interface of
        # a full linear model over the flattened input — confirm with callers.
        # requires_grad must be given to nn.Parameter itself: passing it to
        # torch.zeros has no effect because nn.Parameter defaults to
        # requires_grad=True, which left this unused tensor trainable.
        self.weights = nn.Parameter(
            torch.zeros(channel_dim, delay_dim * embedding_dim),
            requires_grad=False,
        )

    def get_weights(self):
        """Return the (unused by forward) ``(channel_dim, delay_dim * embedding_dim)`` weight parameter."""
        return self.weights

    def forward(self, X):
        """Map inputs to channel outputs from their per-delay onset mask.

        Args:
            X: Tensor whose elements can be reshaped to
                ``(-1, embedding_dim, delay_dim)``, e.g. a flattened batch of
                shape ``(batch, embedding_dim * delay_dim)``.

        Returns:
            Tensor of shape ``(batch, channel_dim)``.
        """
        with torch.no_grad():
            # The onset mask is a hard threshold, so no gradient flows to X.
            X_onset_view = X.reshape(-1, self.embedding_dim, self.delay_dim)
            # (batch, delay_dim): 1 where any embedding entry is nonzero.
            # Cast to the parameter dtype so double-precision models work.
            is_onset = torch.any(X_onset_view != 0, dim=1).to(
                self.delay_bias.dtype
            )
        out = torch.matmul(is_onset, self.delay_bias)
        if self.use_bias:
            out = out + self.bias
        return out

    def numpy_forward(self, X):
        """Run forward() on array-like input and return a NumPy array.

        The input is converted to float32 on the device of this model's
        parameters; the result is detached and copied back to the CPU.
        """
        device = next(self.parameters()).device
        # as_tensor avoids an extra copy (and torch's copy-construct warning)
        # when X is already a tensor of the right dtype/device.
        X = torch.as_tensor(X, dtype=torch.float32, device=device)
        with torch.no_grad():
            out = self.forward(X)
        return out.detach().cpu().numpy()
