import torch
import torch.nn as nn
from lib.models import VanillaFlava
from lib.models.incremental_classifier import IncrementalClassifier
from lib.data.base import VLInputs 

class FlavaLowerBoundCL(nn.Module):
    """FLAVA feature extractor followed by an incrementally growable classifier.

    The class name suggests this is the naive "lower bound" continual-learning
    baseline (no mechanism against forgetting is visible here). Verify against
    the experiment configs.

    Submodules (registered in this order):
        feature_extractor: a ``VanillaFlava`` encoder.
        incremental_classifier: an ``IncrementalClassifier`` sized from the
            encoder's ``d_model``.
    """

    def __init__(self, n_output_classes: int):
        """Build the encoder and the classification head.

        Args:
            n_output_classes: number of output classes the head starts with.
        """
        super().__init__()
        # The encoder is registered first, so it precedes the head in
        # state_dict / parameters() ordering.
        encoder = VanillaFlava()
        self.feature_extractor = encoder
        self.incremental_classifier = IncrementalClassifier(
            encoder.d_model,
            n_output_classes,
        )

    def forward(self, inputs: VLInputs):
        """Encode ``inputs`` and classify the feature at position 0 of dim 1.

        NOTE(review): this assumes the encoder returns (batch, seq, d_model) and
        that position 0 is a CLS-style summary token. TODO confirm against
        ``VanillaFlava``.

        Returns:
            Whatever ``incremental_classifier`` produces for that feature
            (the logits).
        """
        first_position = self.feature_extractor(inputs)[:, 0]
        return self.incremental_classifier(first_position)

    def adaptation(self, n_output_classes: int, **kwargs):
        """Resize the classifier head to ``n_output_classes``.

        Extra keyword arguments are accepted and ignored. Presumably this keeps
        the call signature shared with other CL models. Verify against the
        callers.
        """
        self.incremental_classifier.adaptation(n_output_classes)