from .model_init import init_subject_model
from PIL import Image
import torch
import open_clip


@torch.no_grad()
def get_text_embedding(model_dict, text):
    """Encode ``text`` with the model stored in ``model_dict``.

    Args:
        model_dict: Mapping that provides a ``"tokenizer"`` callable and a
            ``"model"`` object exposing ``parameters()`` and ``encode_text``.
        text: Passed unchanged to the tokenizer (presumably a string or a
            list of strings, as open_clip tokenizers accept -- confirm
            against the caller).

    Returns:
        Whatever ``model.encode_text`` returns for the tokenized input.
        The call runs with gradient tracking disabled.
    """
    tokenizer = model_dict["tokenizer"]
    model = model_dict["model"]

    # Place the tokens on the device the model actually lives on instead of
    # assuming "cuda": the hard-coded device failed on CPU-only hosts and
    # mismatched models kept on the CPU or on a non-default GPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # A model without parameters gives no hint; prefer CUDA if present.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokens = tokenizer(text).to(device)
    return model.encode_text(tokens)


@torch.no_grad()
def get_image_embedding(model_dict, image):
    """Encode a single ``image`` with the model stored in ``model_dict``.

    Args:
        model_dict: Mapping that provides a ``"preprocess"`` callable and a
            ``"model"`` object exposing ``parameters()`` and ``encode_image``.
        image: Passed unchanged to ``preprocess`` (presumably a PIL image --
            confirm against the caller). ``preprocess`` must return a tensor.

    Returns:
        Whatever ``model.encode_image`` returns for the one-item batch built
        from the preprocessed image. The call runs with gradient tracking
        disabled.
    """
    preprocess = model_dict["preprocess"]
    model = model_dict["model"]

    # Place the batch on the device the model actually lives on instead of
    # assuming "cuda": the hard-coded device failed on CPU-only hosts and
    # mismatched models kept on the CPU or on a non-default GPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # A model without parameters gives no hint; prefer CUDA if present.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # unsqueeze(0) adds the leading batch dimension before encoding.
    batch = preprocess(image).unsqueeze(0).to(device)
    return model.encode_image(batch)


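# NOTE: disabled alternative implementations, kept for reference only.
# They read a "processor" entry from model_dict and call
# `get_text_features` / `get_image_features` instead of the
# `encode_text` / `encode_image` calls used by the live functions in this file.
#
# @torch.no_grad()
# def get_text_embedding(model_dict, text):
#     inputs = model_dict["tokenizer"](text, padding=True, return_tensors="pt").to(model_dict["model"].device)
#     text_features = model_dict["model"].get_text_features(**inputs)
#     return text_features

# @torch.no_grad()
# def get_image_embedding(model_dict, image):
#     inputs = model_dict["processor"](images=image, return_tensors="pt").to(model_dict["model"].device)
#     image_features = model_dict["model"].get_image_features(**inputs)
#     return image_features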
