from turtle import forward
import torch
import torch.nn as nn

from transformers import HubertForSequenceClassification


class HubertClassifier(HubertForSequenceClassification):
    """``HubertForSequenceClassification`` whose call returns only the logits.

    Calling an instance (``model(x)``) runs the parent model's forward pass
    through ``nn.Module.__call__`` and unwraps the ``logits`` field of the
    result. This lets the model be used wherever a plain "inputs -> logits"
    module is expected. ``model.forward`` itself is not overridden and still
    returns the parent's full output object.
    """

    def __call__(self, x, **kwargs):
        """Run the forward pass and return only the classification logits.

        Args:
            x: First positional input to the parent forward, i.e. its
                ``input_values``. Presumably batched raw audio waveforms;
                TODO confirm shape/sampling rate against callers.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                parent forward (for example ``attention_mask``).

        Returns:
            The ``logits`` attribute of the parent model's output.
        """
        # Always request a ModelOutput. If the config sets
        # ``use_return_dict=False``, the parent would otherwise return a plain
        # tuple and the ``.logits`` access below would raise AttributeError.
        kwargs["return_dict"] = True
        # Go through nn.Module.__call__ (not forward directly) so registered
        # hooks still run.
        return super().__call__(x, **kwargs).logits