# PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 --time=30000 python src/tools/filters/filters/qwen_vl.py
from dataclasses import dataclass

import torch
import transformers

import qwen_vl_utils

from src.tools.filters.base import Filter


@dataclass
class Qwen_VL_Filter(Filter):
    """Image filter that scores samples with a single-token answer from Qwen2-VL.

    Each image is shown to the model together with ``prompt``. The logits of
    the next token are restricted to the first token of every entry in
    ``options`` and soft-maxed; the probability of the *first* option is the
    sample's ``retain_prob``. A sample is retained when that probability is
    strictly greater than ``threshold``.

    Attributes:
        model_name_or_path: Path or hub id handed to ``from_pretrained`` for
            both the processor and the model.
        prompt: Question appended after the image placeholder. It is passed
            through ``str.format`` with the sample's optional ``"mapping"``
            dict, so literal braces in the prompt must be doubled.
        options: Comma-separated candidate answers (a pre-split list of
            strings is also accepted). Whitespace is kept as written because
            it is part of what gets tokenized. After ``__post_init__`` this
            attribute holds the list of options.
        threshold: Minimum (exclusive) probability of the first option for a
            sample to be retained.
    """

    model_name_or_path: str = "/mnt/lustrenew/mllm_safety-shared/models/huggingface/Qwen/Qwen2-VL-7B"
    prompt: str = "Can you tell what animal is shown partially in the image?"
    # Comma-separated answers; a sample is retained when the FIRST option wins
    # with probability above `threshold`. Leading spaces are significant.
    options: str = " No, Yes"
    threshold: float = 0.5

    def __post_init__(self):
        """Load the processor and model and pre-tokenize the answer options.

        Raises:
            ValueError: If an option produces no tokens (e.g. an empty string
                caused by a stray comma in ``options``).
        """
        # Pixel budget forwarded to the processor to bound the image size
        # (presumably expressed in 28x28 patches -- TODO confirm).
        processor_kwargs = {"min_pixels": 32 * 28 * 28, "max_pixels": 128 * 28 * 28}
        self.processor = transformers.AutoProcessor.from_pretrained(
            self.model_name_or_path,
            # `apply` reads the logits at the last non-padded position, which
            # relies on padding being appended on the right.
            padding_side="right",
            **processor_kwargs,
        )
        self.model = transformers.AutoModelForVision2Seq.from_pretrained(
            self.model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        # Accept either the comma-separated string form or an already split
        # sequence of options.
        if isinstance(self.options, str):
            self.options = self.options.split(",")
        else:
            self.options = list(self.options)

        # Only the first token of each option is compared, so options must
        # differ in their first token to be distinguishable.
        self.first_token_of_options = []
        for option in self.options:
            token_ids = self.processor.tokenizer.encode(option, add_special_tokens=False)
            if not token_ids:
                raise ValueError(
                    f"Option {option!r} does not produce any token; check `options`."
                )
            self.first_token_of_options.append(token_ids[0])

    def apply(self, inputs: list[dict]) -> list[dict]:
        """Score a batch of samples.

        Args:
            inputs: One dict per sample. ``"image"`` is required and is
                loaded through ``qwen_vl_utils.fetch_image``; ``"mapping"``
                is optional and supplies the values used to format
                ``prompt``.

        Returns:
            One dict per input, in order, with ``"retain_prob"`` (probability
            of the first option among all options) and ``"retain"`` (whether
            that probability exceeds ``threshold``). An empty input list
            yields an empty list.
        """
        if not inputs:
            # Nothing to score; avoid calling the processor with empty batches.
            return []

        image_prefix = "<|vision_start|><|image_pad|><|vision_end|>"
        prompts = [
            (image_prefix + self.prompt).format(**item.get("mapping", {}))
            for item in inputs
        ]
        # One single-image list per sample, matching the one placeholder
        # inserted into each prompt.
        images = [
            [qwen_vl_utils.fetch_image({"image": item["image"]})]
            for item in inputs
        ]

        batch = self.processor(
            text=prompts,
            images=images,
            return_tensors="pt",
            padding=True,
        ).to(self.model.device)

        with torch.no_grad():
            # Call the module rather than `.forward` so registered hooks run.
            outputs = self.model(**batch)

        logits = outputs["logits"]
        # The attention mask marks real tokens, so its sum is the prompt
        # length even if the pad token id also occurs inside a prompt.
        last_positions = (batch["attention_mask"].sum(-1) - 1).to(logits.device)
        batch_indices = torch.arange(last_positions.size(0), device=logits.device)

        # Logits predicting the token right after each prompt, restricted to
        # the candidate answers.
        option_logits = logits[batch_indices, last_positions]
        option_logits = option_logits[:, self.first_token_of_options]
        # Normalise in float32; bfloat16 keeps too few mantissa bits for
        # probabilities that are compared against a threshold.
        option_probs = option_logits.float().softmax(-1)

        retain_probs = option_probs[:, 0].cpu().tolist()

        return [
            {"retain": retain_prob > self.threshold, "retain_prob": retain_prob}
            for retain_prob in retain_probs
        ]



if __name__ == "__main__":
    # Manual smoke test: download two sample images and run them through the
    # filter as a single batch. Requires network access and the model weights.
    from PIL import Image
    import requests

    def _load_rgb_image(url: str) -> Image.Image:
        """Download the image at *url* and return it converted to RGB."""
        # A timeout keeps the script from hanging forever on a dead host, and
        # the status check surfaces HTTP errors instead of a confusing PIL
        # failure on an error page.
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()
        return Image.open(response.raw).convert("RGB")

    qwen_filter = Qwen_VL_Filter()

    # Sample images for batch inference.
    url1 = "https://images.nubilefilms.com/videos/whats_yours_is_mine_with_chanel_camryn_freya_parker/samples/cover960.jpg"
    image1 = _load_rgb_image(url1)

    url2 = "https://aerospaceamerica.aiaa.org/wp-content/uploads/2023/06/0723_Aero_Starship-1200x675.jpg"
    image2 = _load_rgb_image(url2)

    print(qwen_filter.apply([{"image": image1}, {"image": image2}]))
