import torch.nn as nn
from transformers import FlavaModel, AutoProcessor
from lib.data.base import VLInputs

class FlavaProcessor:
    """Namespace for obtaining the preprocessing components of FLAVA."""

    @staticmethod
    def get_processor():
        """Load the "facebook/flava-full" processor and split it in two.

        Returns:
            A 2-tuple ``(image_processor, tokenizer)`` taken from the
            loaded processor's ``image_processor`` and ``tokenizer``
            attributes, in that order.
        """
        flava_processor = AutoProcessor.from_pretrained("facebook/flava-full")
        image_side = flava_processor.image_processor
        text_side = flava_processor.tokenizer
        return image_side, text_side
    
class VanillaFlava(nn.Module):
    """Thin wrapper around the pretrained "facebook/flava-full" FlavaModel.

    The forward pass returns only the ``multimodal_embeddings`` field of
    the underlying model's output.

    Args:
        frozen: When True, every parameter of the wrapped model has
            ``requires_grad`` set to False; when False (the default),
            every parameter has it set to True.
    """

    def __init__(self, frozen: bool = False):
        super().__init__()
        self.model = FlavaModel.from_pretrained("facebook/flava-full")
        # A single module-level call sets the flag on all parameters.
        self.model.requires_grad_(not frozen)
        # NOTE(review): read from the top-level model config; presumably
        # this matches the width of the multimodal embeddings -- confirm.
        self.d_model = self.model.config.hidden_size

    def forward(self, batch: VLInputs, *args, **kwargs):
        """Run the wrapped model on ``batch`` and return its multimodal embeddings.

        Args:
            batch: Object exposing ``pixel_values``, ``input_ids`` and
                ``attention_mask`` attributes, which are forwarded to the
                model as keyword arguments of the same names.
            *args: Accepted but not used.
            **kwargs: Accepted but not used.

        Returns:
            The ``multimodal_embeddings`` attribute of the model output.
        """
        flava_out = self.model(
            pixel_values=batch.pixel_values,
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
        )
        return flava_out.multimodal_embeddings