# PYTHONPATH=. srun -p mllm_safety --quotatype=reserved --gres=gpu:1 --cpus-per-task=8 --time=30000 python src/tools/filters/filters/shieldgemma_2_4b_it.py
import requests
from dataclasses import dataclass

import torch
import transformers

from src.tools.filters.base import Filter


@dataclass
class SheildGemma_2_4b_it_Filter(Filter):
    """Image safety filter backed by the ShieldGemma 2 (4B) classifier.

    The processor builds one prompt per safety policy for every image. An
    image is retained only when, for every policy, the model's probability
    in the last output column exceeds ``threshold``.

    Attributes:
        model_name_or_path: Hugging Face model id or local checkpoint directory.
        threshold: Probability the last output column must exceed, for every
            policy, for an image to be retained. Defaults to 0.5.
    """

    # NOTE: the class name spelling ("Sheild") is kept as-is because callers
    # reference it by this name.
    model_name_or_path: str = "/mnt/lustrenew/mllm_safety-shared/models/huggingface/google/shieldgemma-2-4b-it"
    threshold: float = 0.5

    def __post_init__(self):
        """Load the processor and the classifier (bfloat16 weights, eval mode)."""
        self.processor = transformers.AutoProcessor.from_pretrained(self.model_name_or_path)
        self.model = transformers.ShieldGemma2ForImageClassification.from_pretrained(
            self.model_name_or_path,
            torch_dtype=torch.bfloat16,
        ).eval()

    def apply(self, inputs: list[dict]) -> list[dict]:
        """Classify a batch of images and decide which ones to keep.

        Args:
            inputs: One dict per sample; the image is read from its ``"image"``
                key and passed to the processor unchanged (e.g. a PIL image).

        Returns:
            A list aligned with ``inputs`` whose elements are ``{"retain": bool}``.
        """
        if not inputs:
            # Nothing to classify; avoid handing an empty batch to the processor.
            return []

        images = [sample["image"] for sample in inputs]

        # Keep inputs on the same device as the model, in case it was moved
        # off the CPU after loading.
        model_inputs = self.processor(images=images, return_tensors="pt").to(self.model.device)

        with torch.inference_mode():
            scores = self.model(**model_inputs)

        # Assumes the processor emits one row per (image, policy) pair, with
        # each image's policy rows contiguous. Reshaping by the image count
        # (rather than a hard-coded policy count) gives
        # (num_images, num_policies, 2) and fails loudly if rows don't divide
        # evenly.
        probs = scores.probabilities.reshape(len(images), -1, 2)
        # Per the transformers ShieldGemma2 docs, column 0 is P("Yes", violates
        # the policy) and column 1 is P("No"). Keep an image only if every
        # policy's "No" probability exceeds the threshold.
        retain_flags = (probs[..., -1] > self.threshold).all(dim=-1).tolist()

        return [{"retain": retain_flag} for retain_flag in retain_flags]


if __name__ == "__main__":

    from PIL import Image

    # Smoke test: run the filter over a set of sample images and print the
    # per-image retain decisions.
    data_dir = "data/moderation_v2/files"
    # Other sample sets used during development: sex00.jpg .. sex09.jpg in the
    # same directory, and "data/imagenet_animals_v2/files/bird.jpg". Swap them
    # into `image_paths` to test other categories.
    image_paths = [f"{data_dir}/violence{i:02d}.jpg" for i in range(10)]

    images = []
    for path in image_paths:
        # Copy the decoded pixels so the file handle can be closed by `with`.
        with Image.open(path) as img:
            images.append(img.copy())

    # Build the (expensive) model only after all images loaded successfully.
    shield_filter = SheildGemma_2_4b_it_Filter()
    results = shield_filter.apply([{"image": image} for image in images])
    print([result["retain"] for result in results])