# PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 --time=30000 python src/tools/filters/filters/qwen_vl_instruct.py
from dataclasses import dataclass

import torch
import transformers

import qwen_vl_utils

from src.tools.filters.base import Filter


@dataclass
class Qwen_VL_Instruct_Filter(Filter):
    """Score images with a Qwen-VL instruct model via next-token option probabilities.

    Each image is paired with ``prompt`` in a single user turn. The model's logits
    are read at the last prompt token (i.e. where the assistant's answer would
    start). Only the first token of every entry in ``options`` is kept, and these
    logits are softmax-normalized over the options alone. An input is retained
    when the probability of the FIRST option exceeds ``threshold``.
    """

    model_name_or_path: str = "/mnt/lustrenew/mllm_safety-shared/models/huggingface/Qwen/Qwen2.5-VL-7B-Instruct"
    # Question asked about each image. It may contain ``str.format`` fields that
    # are filled from the optional "mapping" dict of each input.
    prompt: str = "Can you tell what animal is shown partially in the image?"
    # Candidate answers, as a comma-separated string or a list of strings.
    # The first option will be retained (its probability decides ``retain``).
    # NOTE(review): options are tokenized without a leading space, on the
    # assumption that the answer starts right after the generation prompt.
    # Confirm this for the chat template in use.
    options: list = "No,Yes"
    # An input is retained when P(first option) is strictly greater than this.
    threshold: float = 0.5

    def __post_init__(self):
        """Load processor and model, and resolve the token id used for each option.

        Raises:
            ValueError: if fewer than two options are given, or if two options
                share the same first token (their probabilities could not be
                told apart).
        """
        processor_kwargs = {"min_pixels": 32 * 28 * 28, "max_pixels": 128 * 28 * 28}
        self.processor = transformers.AutoProcessor.from_pretrained(
            self.model_name_or_path,
            padding_side="right",
            **processor_kwargs,
        )
        self.model = transformers.AutoModelForVision2Seq.from_pretrained(
            self.model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        # Accept both "No,Yes" and ["No", "Yes"]; splitting a list would crash.
        if isinstance(self.options, str):
            self.options = self.options.split(",")
        else:
            self.options = list(self.options)
        if len(self.options) < 2:
            raise ValueError(f"Need at least two options to compare, got {self.options!r}")

        # Only the first sub-token of each option is scored.
        self.first_token_of_options = [
            self.processor.tokenizer.encode(option, add_special_tokens=False)[0]
            for option in self.options
        ]
        if len(set(self.first_token_of_options)) != len(self.first_token_of_options):
            raise ValueError(
                f"Options {self.options!r} do not have distinct first tokens "
                f"({self.first_token_of_options!r}); their probabilities would be ambiguous"
            )

    def apply(self, inputs: list[dict]) -> list[dict]:
        """Score a batch of inputs.

        Args:
            inputs: one dict per sample. Each dict holds an "image" entry, which
                is handed to ``qwen_vl_utils.process_vision_info``. It may also
                hold a "mapping" dict, used to ``str.format`` the prompt.

        Returns:
            One dict per input, in input order, with keys:
            "retain" (bool, ``retain_prob > threshold``) and
            "retain_prob" (float, softmax probability of the first option).
        """
        # The processor cannot build an empty batch.
        if not inputs:
            return []

        messages_list = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": example["image"]},
                        {"type": "text", "text": self.prompt.format(**example.get("mapping", {}))},
                    ],
                }
            ]
            for example in inputs
        ]
        texts = [
            self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            for messages in messages_list
        ]
        images, _ = qwen_vl_utils.process_vision_info(messages_list)

        batch = self.processor(
            text=texts,
            images=images,
            return_tensors="pt",
            padding=True,
        ).to(self.model.device)

        with torch.no_grad():
            logits = self.model(**batch).logits

        # With device_map="auto" the logits may sit on a different device than
        # self.model.device, so every index tensor is built on logits.device.
        device = logits.device
        attention_mask = batch["attention_mask"].to(device)
        positions = torch.arange(attention_mask.size(-1), device=device)
        # Position of the last real token in each row. This works for both left
        # and right padding, and does not rely on pad_token_id never appearing
        # inside the prompt.
        last_idx = (attention_mask * positions).amax(-1)
        rows = torch.arange(logits.size(0), device=device)
        option_ids = torch.tensor(self.first_token_of_options, device=device)

        option_logits = logits[rows, last_idx][:, option_ids]
        # Normalize in float32: bf16 keeps only ~3 significant digits, which
        # makes probabilities near the threshold unreliable.
        option_probs = option_logits.float().softmax(-1)

        retain_probs = option_probs[:, 0].cpu().tolist()

        return [{"retain": retain_prob > self.threshold, "retain_prob": retain_prob} for retain_prob in retain_probs]



if __name__ == "__main__":
    # Manual smoke test: downloads one image, loads one local image, and prints
    # the filter's decision for both. Requires network access and a GPU.
    import io

    import requests
    from PIL import Image

    def _load_image_from_url(url: str) -> Image.Image:
        """Download ``url`` and decode it as an RGB PIL image."""
        response = requests.get(url, timeout=30)
        # Fail loudly on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        # Decode from .content (not .raw) so any Content-Encoding is undone first.
        return Image.open(io.BytesIO(response.content)).convert("RGB")

    # Named to avoid shadowing the builtin ``filter``.
    image_filter = Qwen_VL_Instruct_Filter()

    url1 = "https://images.nubilefilms.com/videos/whats_yours_is_mine_with_chanel_camryn_freya_parker/samples/cover960.jpg"
    image1 = _load_image_from_url(url1)

    # Alternative remote sample for batch inference:
    # image2 = _load_image_from_url("https://aerospaceamerica.aiaa.org/wp-content/uploads/2023/06/0723_Aero_Starship-1200x675.jpg")

    # Local sample, converted to RGB like the downloaded one.
    image2 = Image.open("data/animals/files/bird.jpg").convert("RGB")

    print(image_filter.apply([{"image": image1}, {"image": image2}]))
