
import torch
from diffusers import DiffusionPipeline
from transformers import JanusForConditionalGeneration, JanusProcessor


class DiffusionModel():
    """Thin wrapper around a text-to-image pipeline loaded through diffusers.

    The pipeline is loaded in float16 from ``model_id``. When the last path
    component of ``model_id`` starts with ``stable-diffusion``, the pipeline's
    safety checker is replaced by a pass-through so no image is ever flagged.
    """

    # Device strings accepted by ``to``.
    _DEVICES = ("cuda", "cpu")

    def __init__(self, model_id="stable-diffusion-v1-5/stable-diffusion-v1-5"):
        """Load the pipeline identified by ``model_id``.

        Args:
            model_id: Identifier handed to ``DiffusionPipeline.from_pretrained``.
        """
        self.model_id = model_id
        self.pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

        ##### ignore nsfw #####
        if model_id.split("/")[-1].startswith("stable-diffusion"):
            self.pipeline.safety_checker = self._passthrough_safety_checker

    @staticmethod
    def _passthrough_safety_checker(images, **kwargs):
        """Return ``images`` unchanged together with one ``False`` flag per image.

        Args:
            images: Batch of images given to the safety checker; only its
                length is inspected.
            **kwargs: Ignored; accepted so the helper can be called with the
                same keywords as the checker it replaces.

        Returns:
            A ``(images, flags)`` tuple where ``flags`` is a list of ``False``
            with ``len(images)`` entries.
        """
        # One flag per image keeps the flag list aligned with the batch; a
        # fixed single-element list is only correct for a batch of one.
        return images, [False] * len(images)

    def __call__(self, prompt):
        """Run the pipeline on ``prompt`` and return the first generated image.

        Args:
            prompt: Prompt forwarded unchanged to the pipeline.

        Returns:
            The first entry of the pipeline result's ``images``.
        """
        return self.pipeline(prompt).images[0]

    def to(self, device):
        """Move the pipeline to ``device`` and return ``self`` for chaining.

        Args:
            device: Either ``"cuda"`` or ``"cpu"``.

        Returns:
            This wrapper instance.

        Raises:
            ValueError: If ``device`` is not one of the accepted strings.
        """
        # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
        if device not in self._DEVICES:
            raise ValueError(f"unknown device: {device}")

        self.pipeline.to(device)
        if device == "cpu":
            # The weights just left the GPU, so release the cached GPU memory.
            torch.cuda.empty_cache()

        return self


class Janus_Pro():
    """Text-to-image wrapper around the Janus model classes from transformers.

    The processor and model are loaded from ``model_id``; the model is loaded
    in bfloat16 with ``device_map="cpu"`` and configured to return a single
    sequence per generation call.
    """

    # Device strings accepted by ``to``.
    _DEVICES = ("cuda", "cpu")

    def __init__(self, model_id="deepseek-community/Janus-Pro-7B"):
        """Load the processor and model identified by ``model_id``.

        Args:
            model_id: Identifier handed to both ``from_pretrained`` calls.
        """
        self.model_id = model_id
        self.processor = JanusProcessor.from_pretrained(model_id)
        self.model = JanusForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="cpu"
        )

        # Set number of images to generate
        self.model.generation_config.num_return_sequences = 1

    def __call__(self, prompt):
        """Generate an image for ``prompt`` and return the first result.

        Args:
            prompt: Text placed in a single user message of the chat template.

        Returns:
            The first entry of ``pixel_values`` in the processor's
            post-processed output.
        """
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]

        # Apply chat template (kept under its own name so ``prompt`` still
        # refers to the caller's raw text).
        chat_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = self.processor(
            text=chat_prompt,
            generation_mode="image",
            return_tensors="pt",
        ).to(self.model.device, dtype=torch.bfloat16)

        # Pure inference: skip autograd bookkeeping for generation and decoding.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                generation_mode="image",
                do_sample=True,
                use_cache=True,
            )
            # Decode the generated image tokens.
            decoded_image = self.model.decode_image_tokens(outputs)

        # Post-process the decoded tensors; nothing is written to disk here.
        images = self.processor.postprocess(list(decoded_image.float()), return_tensors="PIL.Image.Image")

        return images["pixel_values"][0]

    def to(self, device):
        """Move the model to ``device`` and return ``self`` for chaining.

        Args:
            device: Either ``"cuda"`` or ``"cpu"``.

        Returns:
            This wrapper instance.

        Raises:
            ValueError: If ``device`` is not one of the accepted strings.
        """
        # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
        if device not in self._DEVICES:
            raise ValueError(f"unknown device: {device}")

        self.model.to(device)
        if device == "cpu":
            # The weights just left the GPU, so release the cached GPU memory.
            torch.cuda.empty_cache()

        return self

