import torch
from test_medclip import MedCLIPVisionModelViT, MedCLIPModel
from constants import *
from MedCLIPProcessor import MedCLIPProcessor
from concurrent.futures import ThreadPoolExecutor
from medclip import PromptClassifier



if __name__ == '__main__':
    # Script entry point: zero-shot CheXpert classification of a single chest
    # X-ray with a MedCLIP model (ViT vision backbone) and a prompt ensemble.
    import sys

    from PIL import Image
    from medclip.prompts import generate_chexpert_class_prompts, process_class_prompts

    # Image to classify; the first command-line argument overrides the default.
    DEFAULT_IMAGE_PATH = '/mnt/nvme_share/wuwl/project/CARZero-main/Dataset/MIMIC-CXR-JPG/2.0.0/files/p10/p10046166/s50051329/427446c1-881f5cce-85191ce1-91a58ba9-0a57d3f5.jpg'
    image_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_IMAGE_PATH

    # Choose the device before loading the checkpoint so tensors can be mapped
    # onto it; a CUDA-saved checkpoint cannot be loaded on a CPU-only machine
    # without map_location.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
    state_dict = torch.load(VIT_Pretrain, map_location=device)
    # The checkpoint must actually be applied to the model, otherwise inference
    # runs with freshly initialised weights. strict=False follows the loading
    # style of medclip's from_pretrained; the key lists are printed so a
    # mismatched checkpoint (e.g. weights nested under another key) is visible.
    # NOTE(review): confirm VIT_Pretrain holds a plain state dict for MedCLIPModel.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print('missing keys:', missing_keys)
    print('unexpected keys:', unexpected_keys)
    model = model.to(device)

    processor = MedCLIPProcessor()
    clf = PromptClassifier(model, ensemble=True)
    # Honour the device chosen above instead of hard-coding .cuda().
    # NOTE(review): some medclip versions move inputs to CUDA inside
    # PromptClassifier.forward; CPU-only runs may still need that addressed.
    clf = clf.to(device)
    # Inference mode: disables dropout so predictions are deterministic.
    clf.eval()

    # prepare input image; the context manager closes the file once the
    # processor has turned it into a tensor.
    with Image.open(image_path) as image:
        inputs = processor(images=image, return_tensors="pt")

    # prepare input prompt texts (10 prompts per CheXpert class)
    cls_prompts = process_class_prompts(generate_chexpert_class_prompts(n=10))
    inputs['prompt_inputs'] = cls_prompts

    # make classification without building an autograd graph
    with torch.no_grad():
        output = clf(**inputs)
    print(output)