from .constants import *
from transformers import CLIPFeatureExtractor, CLIPProcessor, AutoTokenizer
from medclip.dataset import  MedCLIPFeatureExtractor



class MedCLIPProcessor(CLIPProcessor):
    """CLIP-style processor pairing MedCLIP's image extractor with a BERT tokenizer.

    The image side is a ``MedCLIPFeatureExtractor``. The text side is the
    tokenizer loaded with ``AutoTokenizer.from_pretrained(BERT_TYPE)``. Both
    are passed to ``CLIPProcessor.__init__``.
    """

    # NOTE(review): transformers' ProcessorMixin resolves/validates components
    # against these class names, but the extractor actually passed below is
    # MedCLIPFeatureExtractor -- presumably a CLIPFeatureExtractor subclass;
    # confirm in medclip.dataset.
    feature_extractor_class = "CLIPFeatureExtractor"
    tokenizer_class = ("BertTokenizer", "BertTokenizerFast")

    def __init__(self, *, max_length: int = 77):
        """Build the feature extractor and tokenizer, then register them.

        Args:
            max_length: Value written to the tokenizer's ``model_max_length``.
                Defaults to 77, the value this class previously hard-coded
                (presumably chosen to mirror CLIP's 77-token text context).
                Keyword-only so a positionally passed component can never be
                mistaken for it.

        Raises:
            ValueError: If ``max_length`` is smaller than 1.
        """
        # Validate before doing any (potentially slow) tokenizer loading.
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length!r}")

        feature_extractor = MedCLIPFeatureExtractor()
        tokenizer = AutoTokenizer.from_pretrained(BERT_TYPE)
        # Override the checkpoint's default limit; the tokenizer uses this
        # value as its implicit maximum length when truncating.
        tokenizer.model_max_length = max_length
        # Positional on purpose: the first CLIPProcessor parameter is named
        # differently across transformers versions (feature_extractor vs
        # image_processor).
        super().__init__(feature_extractor, tokenizer)