from pathlib import Path

import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
# Wildcard import kept for backward compatibility; names used below are imported explicitly.
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import *
from PIL import Image
import torchvision.transforms as T
# Round-trip a single image through the Stable Diffusion VAE (encode -> decode)
# and save the reconstruction, to eyeball how much detail the VAE preserves.

pretrained_model_name_or_path = 'CompVis/stable-diffusion-v1-4'
revision = None
image_size = 512

img_path = 'DATASETS/concat_laion_11k_64/49.png'
# Optional fine-tuned VAE weights, e.g. 'vae_checkpoint/laion_11k/vae.pth'; None uses the pretrained VAE.
ckpt_path = None
# Output name follows the input file stem ('49.png' -> 'vae_reconstructed_img_49.png').
output_path = "vae_reconstructed_img_{}.png".format(Path(img_path).stem)

# Preprocess: PIL RGB -> float tensor in [0, 1] -> add batch dim -> resize -> map to [-1, 1].
Resize = T.Resize(size=(image_size, image_size))
Normalize = T.Normalize([0.5], [0.5])
with Image.open(img_path) as pil_input:  # close the file handle once the pixels are converted
    img_batch = Normalize(Resize(T.ToTensor()(pil_input.convert('RGB')).unsqueeze(dim=0)))

vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
if ckpt_path is not None:
    # map_location="cpu": the VAE is on CPU here, and a GPU-saved checkpoint would otherwise
    # fail to load on a machine without that device.
    vae_state_dict = torch.load(ckpt_path, map_location="cpu")
    vae.load_state_dict(vae_state_dict)
    print("load checkpoint from {} successfully".format(ckpt_path))
vae.eval()

# Pure inference: no autograd graph is needed for either pass.
with torch.no_grad():
    # .sample() draws from the encoder's posterior, so reconstructions vary between runs;
    # use .latent_dist.mode() instead for a deterministic result.
    latents = vae.encode(img_batch).latent_dist.sample()
    # Scale into the space the diffusion model works in, then undo it before decoding,
    # mirroring the Stable Diffusion pipeline. Net effect here is (numerically) a no-op.
    latents = latents * vae.config.scaling_factor
    latents = 1 / vae.config.scaling_factor * latents
    image = vae.decode(latents, return_dict=False)[0]

# vae_scale_factor is the VAE's spatial downsampling factor (one 2x per down block after the
# first, i.e. 8 for SD v1) -- NOT the latent `scaling_factor` used above.
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
# postprocess maps [-1, 1] back to [0, 1] for each item whose flag is True; one flag per batch item.
do_denormalize = [True] * image.shape[0]
image = image_processor.postprocess(image, output_type='pil', do_denormalize=do_denormalize)
print("image:", image)
pil_image = image[0]
pil_image.save(output_path)
