import os
import sys
from PIL import Image
import torch

class VLM2VecQwen7B:
    """Embedder backed by the TIGER-Lab VLM2Vec-Qwen2VL-7B LoRA checkpoint.

    Instances are callable: pass one list of content items, each either
    ``{"type": "text", "text": ...}`` or ``{"type": "image", "location": ...}``,
    and get back one embedding tensor on the CPU.
    """

    def __init__(self, image_root_dir, device="cuda"):
        """Load the processor and model, and move the model to *device*.

        Args:
            image_root_dir: Directory that image ``location`` values are
                joined onto (via ``os.path.join``) when images are opened.
            device: Torch device for the model and its inputs. Defaults to
                ``"cuda"``, which was previously hard-coded.
        """
        # Imported here rather than at module top, so importing this module
        # does not require the vlm2vec package.
        from vlm2vec.model import MMEBModel
        from vlm2vec.arguments import ModelArguments
        from vlm2vec.model_utils import load_processor, QWEN2_VL, vlm_image_tokens

        # Base model + LoRA checkpoint, last-token pooling, normalized outputs.
        model_args = ModelArguments(
            model_name='Qwen/Qwen2-VL-7B-Instruct',
            checkpoint_path='TIGER-Lab/VLM2Vec-Qwen2VL-7B',
            pooling='last',
            normalize=True,
            model_backbone='qwen2_vl',
            lora=True
        )

        self.device = device
        self.processor = load_processor(model_args)
        self.model = MMEBModel.load(model_args)
        self.model = self.model.to(device, dtype=torch.bfloat16)
        self.model.eval()

        # Placeholder token inserted into the prompt at each image position.
        self.image_token = vlm_image_tokens[QWEN2_VL]

        self.image_root_dir = image_root_dir

    def _to_text(self, content: list[dict]) -> str:
        """Flatten *content* into the prompt string.

        Text items are concatenated verbatim. Each image item becomes the
        image token padded with one space on each side. The result is
        stripped of leading/trailing whitespace.

        Raises:
            NotImplementedError: if an item's ``type`` is neither
                ``"text"`` nor ``"image"``.
        """
        parts = []
        for cnt in content:
            if cnt["type"] == "text":
                parts.append(cnt["text"])
            elif cnt["type"] == "image":
                parts.append(f" {self.image_token} ")
            else:
                raise NotImplementedError(f"Unsupported content type: {cnt['type']!r}")
        return "".join(parts).strip()

    def _to_image(self, content: list[dict]) -> list:
        """Open every image item in *content*, in order, relative to image_root_dir."""
        images = []
        for cnt in content:
            if cnt["type"] != "image":
                continue
            path = os.path.join(self.image_root_dir, cnt["location"])
            # Decode eagerly inside the context manager so the file handle is
            # released now instead of whenever the Image is garbage-collected.
            # The pixel data stays usable after the `with` block exits.
            with Image.open(path) as img:
                img.load()
            images.append(img)
        return images

    def __call__(self, content: list[dict]) -> torch.Tensor:
        """Embed one content list and return the embedding on the CPU."""
        # Build text before opening images, so unsupported item types fail
        # before any file is touched (same order as argument evaluation before).
        text = self._to_text(content)
        images = self._to_image(content)
        inputs = self.processor(
            text=text,
            # For text-only content pass None, not an empty list; this is the
            # form VLM2Vec's own text-only usage takes.
            images=images or None,
            return_tensors="pt",
        )
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        # Add a leading dimension to the vision tensors. They are only present
        # when images were supplied, so text-only inputs skip this step.
        for key in ("pixel_values", "image_grid_thw"):
            if key in inputs:
                inputs[key] = inputs[key].unsqueeze(0)
        # Inference only: skip autograd bookkeeping to save activation memory.
        with torch.no_grad():
            qry_output = self.model(qry=inputs)["qry_reps"]

        # presumably qry_reps is (1, dim) for this single query; take row 0 — TODO confirm
        return qry_output[0].cpu()
    
# Registry of available embedders: maps the key callers use to select a model
# to the class that builds it.
MODELS: dict[str, type] = {"vlm2vec-qwen7b": VLM2VecQwen7B}

def get_embeddings(model, contents: list[list[dict]]):
    """Embed each content list with *model* and stack the results.

    Args:
        model: Callable mapping one content list to a tensor, e.g. an
            instance of a class registered in ``MODELS``.
        contents: Content lists to embed, processed one at a time in order.

    Returns:
        A tensor whose first dimension indexes ``contents``. It is built with
        ``torch.stack(..., dim=0)``, so every per-item embedding must have the
        same shape. ``torch.stack`` raises if ``contents`` is empty.
    """
    per_item = [model(item) for item in contents]
    return torch.stack(per_item, dim=0)