import torch
import torch.nn as nn
from lib.models import VanillaFlava
from lib.data.base import VLInputs 

class FlavaUpperBoundCL(nn.Module):
    """Classification model: a ``VanillaFlava`` backbone followed by a linear head.

    The backbone output is indexed at position 0 along its second axis and
    that slice is projected to ``n_output_classes`` logits.
    """

    def __init__(self, n_output_classes: int):
        """Build the backbone and the linear classification head.

        Args:
            n_output_classes: Output width of the linear head, i.e. the
                number of logits produced per example.
        """
        super().__init__()

        # Backbone is created first so the head can be sized from its
        # ``d_model`` attribute.
        backbone = VanillaFlava()
        self.feature_extractor = backbone

        head_in_features = backbone.d_model
        self.classification_head = nn.Linear(head_in_features, n_output_classes)

    def forward(self, inputs: VLInputs):
        """Run the backbone on ``inputs`` and classify the slice at position 0.

        Args:
            inputs: Batch passed unchanged to the ``VanillaFlava`` backbone.

        Returns:
            The output of the linear head applied to ``backbone_out[:, 0]``.
        """
        backbone_out = self.feature_extractor(inputs)
        # NOTE(review): assumes the backbone returns a tensor shaped like
        # (batch, sequence, d_model), with index 0 on the second axis being
        # a summary/[CLS]-style token — confirm against VanillaFlava.
        first_position = backbone_out[:, 0]
        return self.classification_head(first_position)