# PYTHONPATH=. srun -p p-cpu-new --quotatype=reserved --cpus-per-task=8 --time=30000 python src/tools/filters/filters/openai_moderation_api.py
# Route all HTTP(S) traffic of this process through the closeai-proxy host
# (presumably required to reach the OpenAI API from the cluster — confirm).
# This runs at import time and overrides any proxy already configured.
# The proxy URL can be overridden with the SHLAB_PROXY environment variable;
# by default the hard-coded cluster proxy is used.
import os
SHLAB_PROXY = os.environ.get("SHLAB_PROXY", "http://closeai-proxy.pjlab.org.cn:23128")
os.environ["http_proxy"] = SHLAB_PROXY
os.environ["https_proxy"] = SHLAB_PROXY
import base64
import io
import time
from dataclasses import dataclass
from typing import Optional

import openai

from src.tools.filters.base import Filter


@dataclass
class OpenAI_Moderation_Filter(Filter):
    """Keep only images that OpenAI's moderation endpoint does not flag.

    Each input image is JPEG-encoded, sent as a base64 ``data:`` URL to the
    ``omni-moderation-latest`` model, and marked ``retain=False`` when the
    first moderation result is flagged.
    """

    # Seconds to sleep before retrying a request that failed with a
    # rate-limit or connection error.
    waiting_time: float = 15
    # Maximum number of retries per image before the last error is re-raised;
    # None (the default) retries indefinitely.
    max_retries: Optional[int] = None

    def __post_init__(self):
        # Default client configuration; the openai library reads credentials
        # such as the API key from the environment.
        self.client = openai.OpenAI()

    @staticmethod
    def _to_data_url(image) -> str:
        """Encode a PIL image as a base64 JPEG ``data:`` URL."""
        # JPEG cannot store alpha or palette images (Pillow raises for e.g.
        # RGBA/LA/P), so normalise anything other than RGB/greyscale to RGB.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return f"data:image/jpeg;base64,{encoded}"

    def _moderate(self, data_url: str):
        """Moderate one image, retrying on rate-limit / connection errors.

        Raises:
            openai.RateLimitError, openai.APIConnectionError: only when
                ``max_retries`` is set and has been exhausted.
        """
        attempt = 0
        while True:
            try:
                return self.client.moderations.create(
                    model="omni-moderation-latest",
                    input=[
                        {
                            "type": "image_url",
                            "image_url": {"url": data_url},
                        },
                    ],
                )
            except (openai.RateLimitError, openai.APIConnectionError):
                if self.max_retries is not None and attempt >= self.max_retries:
                    raise
                attempt += 1
                print("waiting for quota")
                time.sleep(self.waiting_time)

    def apply(self, inputs: list[dict]) -> list[dict]:
        """Moderate every input image and decide whether to keep it.

        Args:
            inputs: Records whose ``"image"`` key holds a PIL image.

        Returns:
            One dict per input, in the same order, with ``"retain"``
            (``False`` when the image was flagged) and ``"meta"`` (the raw
            moderation response).
        """
        # Extract all images up front so a malformed record fails before any
        # API request is made.
        images = [record["image"] for record in inputs]

        results = []
        for image in images:
            response = self._moderate(self._to_data_url(image))
            results.append(
                {"retain": not response.results[0].flagged, "meta": response}
            )
        return results
            


if __name__ == "__main__":
    # Demo: moderate the sample images sex00-09 and violence00-09 plus one
    # ImageNet bird image, and print which ones the filter would retain.
    from PIL import Image

    moderation_filter = OpenAI_Moderation_Filter()

    image_paths = (
        [f"data/moderation_v2/files/sex{i:02d}.jpg" for i in range(10)]
        + [f"data/moderation_v2/files/violence{i:02d}.jpg" for i in range(10)]
        + ["data/imagenet_animals_v2/files/bird.jpg"]
    )

    images = []
    for path in image_paths:
        # Image.open is lazy and keeps the file open; copy() loads the pixels
        # into memory so the file handle can be closed right away.
        with Image.open(path) as img:
            images.append(img.copy())

    results = moderation_filter.apply([{"image": image} for image in images])
    print([result["retain"] for result in results])
    # Pause in the debugger to inspect the raw responses in results[i]["meta"]
    # (set PYTHONBREAKPOINT=0 to skip when running non-interactively).
    breakpoint()
