# PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 --time=30000 python src/tools/filters/filters/llama_guard_3_11b_vision.py
import requests
from dataclasses import dataclass

import torch
import transformers

from src.tools.filters.base import Filter


@dataclass
class Llama_Guard_3_11b_Vision_Filter(Filter):
    """Image filter driven by a Llama Guard 3 11B Vision checkpoint.

    Every image is paired with the same text ``prompt`` in a single-turn user
    message. The model generates a short completion, and a sample is retained
    unless that completion contains the substring ``"unsafe"``.
    """

    # Checkpoint location handed to ``from_pretrained`` for processor and model.
    model_name_or_path: str = "/mnt/lustrenew/mllm_safety-shared/models/huggingface/meta-llama/Llama-Guard-3-11B-Vision"
    # Text part of the user turn that accompanies every image.
    prompt: str = "Describe the image to a 3 year-old."

    def __post_init__(self):
        """Load the processor and model and build the shared conversation."""
        self.processor = transformers.AutoProcessor.from_pretrained(self.model_name_or_path)
        self.model = transformers.AutoModelForVision2Seq.from_pretrained(
            self.model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        ).eval()
        # Single user turn: the text prompt followed by one image placeholder.
        # The same conversation object is reused for every sample in a batch.
        self.conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": self.prompt},
                    {"type": "image"},
                ],
            }
        ]

    def apply(self, inputs: list[dict]) -> list[dict]:
        """Classify a batch of images and decide which ones to keep.

        Args:
            inputs: One dict per sample; each must provide the image under
                the ``"image"`` key.

        Returns:
            One ``{"retain": bool}`` dict per sample, in input order.
            ``retain`` is ``False`` when the generated text contains
            ``"unsafe"``. An empty batch yields an empty list.
        """
        # Nothing to classify: skip the processor and the model entirely.
        if not inputs:
            return []

        # The processor receives one single-image list per sample.
        images = [[sample["image"]] for sample in inputs]
        conversations = [self.conversation] * len(images)

        input_prompt = self.processor.apply_chat_template(
            conversations,
            return_tensors="pt",
        )

        # Kept under a separate name so the ``inputs`` argument is not clobbered.
        model_inputs = self.processor(
            text=input_prompt,
            images=images,
            return_tensors="pt",
        ).to(self.model.device)

        # Every row shares the same prompt, so a single prompt length is used
        # to slice the newly generated tokens off each output row.
        prompt_len = model_inputs["input_ids"].shape[1]
        output = self.model.generate(
            **model_inputs,
            max_new_tokens=20,
            # NOTE(review): pad id is hard-coded to 0 — confirm it is the
            # intended padding token for this tokenizer.
            pad_token_id=0,
            do_sample=False,
        )

        generated_tokens = output[:, prompt_len:]
        generated_text = self.processor.batch_decode(generated_tokens, skip_special_tokens=True)

        return [{"retain": "unsafe" not in text} for text in generated_text]


if __name__ == "__main__":
    from PIL import Image

    # Manual smoke test: run the filter over a fixed set of local sample
    # images and print the per-image retain decisions.
    # Named ``guard_filter`` so the builtin ``filter`` is not shadowed.
    guard_filter = Llama_Guard_3_11b_Vision_Filter()

    # Ten "sex" samples, then ten "violence" samples (indices 00-09), followed
    # by one image from the imagenet_animals_v2 directory.
    image_paths = [
        f"data/moderation_v2/files/{category}{index:02d}.jpg"
        for category in ("sex", "violence")
        for index in range(10)
    ]
    image_paths.append("data/imagenet_animals_v2/files/bird.jpg")

    images = [Image.open(path) for path in image_paths]

    results = guard_filter.apply([{"image": image} for image in images])
    print([result["retain"] for result in results])
    # Drop into the debugger so the results can be inspected interactively.
    breakpoint()
