import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import DDIMScheduler, StableDiffusionImg2ImgPipeline
from torchvision.transforms import transforms


class HuggingFaceDiffusionWrapper:
    """Denoise image batches with a light Stable Diffusion img2img pass.

    Input and output batches are float tensors in [-1, 1] laid out as
    (batch, channels, height, width). Each batch is round-tripped through
    PIL because the pipeline is fed PIL images here.
    """

    def __init__(
        self,
        model_id="CompVis/stable-diffusion-v1-4",
        device="cuda",
        strength=0.1,
        guidance_scale=0.5,
        prompt="clean and detailed imagenet image",
    ):
        """Load the fp16 img2img pipeline and swap in a DDIM scheduler.

        Args:
            model_id: Hugging Face Hub id of the Stable Diffusion checkpoint.
            device: Device the pipeline runs on; results are returned here too.
            strength: img2img strength, i.e. how much noise is added to the
                input before denoising (small values keep the image close to
                the original).
            guidance_scale: Classifier-free guidance weight. Per the diffusers
                docs, guidance is only applied when this is > 1, so the
                default 0.5 effectively disables it.
            prompt: Text prompt used for every image in a batch.
        """
        self.device = device
        self.pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        self.pipeline = self.pipeline.to(device)
        # Replace the checkpoint's default scheduler with DDIM, reusing its config.
        self.pipeline.scheduler = DDIMScheduler.from_config(self.pipeline.scheduler.config)
        self.strength = strength
        self.guidance_scale = guidance_scale
        self.prompt = prompt
        # Built once: maps a uint8 PIL image to a float CHW tensor in [0, 1].
        self._to_tensor = transforms.ToTensor()

    @staticmethod
    def _to_uint8_hwc(images: torch.Tensor) -> np.ndarray:
        """Convert a [-1, 1] NCHW float batch to a uint8 NHWC numpy array.

        Values are clamped to the valid pixel range before the cast; a bare
        uint8 cast would wrap out-of-range values (e.g. 256 -> 0). The +0.5
        makes the truncating cast round to nearest.
        """
        pixels = (images.detach() / 2 + 0.5) * 255
        pixels = pixels.add(0.5).clamp(0, 255).to(torch.uint8)
        return pixels.cpu().permute(0, 2, 3, 1).contiguous().numpy()

    def __call__(self, images, t):
        """Denoise a batch of images with the img2img pipeline.

        Args:
            images: Float tensor assumed to be (batch, 3, H, W) with values in
                [-1, 1]; out-of-range values are clamped.
            t: Passed to the pipeline as ``num_inference_steps``. diffusers
                scales this by ``strength``, so roughly ``t * strength``
                denoising steps actually run — TODO confirm behavior when
                that product is below 1.

        Returns:
            Float32 tensor of shape (batch, 3, H', W') on ``self.device`` with
            values in [-1, 1]. H', W' are whatever size the pipeline returns
            (may differ from H, W if it resizes inputs — TODO confirm).
        """
        pil_images = [Image.fromarray(frame) for frame in self._to_uint8_hwc(images)]

        denoised = self.pipeline(
            image=pil_images,
            strength=self.strength,
            guidance_scale=self.guidance_scale,
            num_inference_steps=t,
            prompt=[self.prompt] * len(pil_images),
        ).images

        # ToTensor yields [0, 1]; map back to the [-1, 1] convention of the input.
        batch = torch.stack([self._to_tensor(img) for img in denoised])
        return batch.to(self.device) * 2 - 1



