import os
from diffusers import AutoencoderKLTemporalDecoder, AutoencoderKL
from diffusers import UNetSpatioTemporalConditionModel
from diffusers import EulerDiscreteScheduler
from diffusers.image_processor import VaeImageProcessor
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from PIL import Image
import torch
import numpy as np
from einops import rearrange
from torchmetrics import PeakSignalNoiseRatio
device = "cuda"
variant = "fp16"
# Both the reconstructions (VaeImageProcessor.postprocess with output_type="pt")
# and the reference frames (load_image) are in [0, 1], so pin data_range to 1.0.
# Leaving it as None makes torchmetrics infer it from the target's max - min,
# which makes the score depend on the content of each image batch.
# NOTE(review): with `dim` left at None the metric returns a single score per
# call, so `reduction` presumably has no effect here — confirm if per-frame
# scores are wanted (that would need `dim=(1, 2, 3)`).
psnr = PeakSignalNoiseRatio(data_range=1.0, reduction="none").to(device)

# Temporal-decoder VAE from Stable Video Diffusion, loaded from the fp16 variant files.
base_model_id = "stabilityai/stable-video-diffusion-img2vid"
svd_vae = AutoencoderKLTemporalDecoder.from_pretrained(base_model_id, subfolder="vae", variant=variant).to(device)

# Image VAE from Stable Diffusion 2.1 base, loaded from the fp16 variant files.
base_model_id = "stabilityai/stable-diffusion-2-1-base"
sd_vae = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae", variant=variant).to(device)

# Upcast each VAE to float32 only when it is currently fp16 AND its config
# requests it via `force_upcast`.
svd_needs_upcasting = svd_vae.dtype == torch.float16 and svd_vae.config.force_upcast
print(svd_vae.dtype, svd_vae.config.force_upcast)
if svd_needs_upcasting:
    svd_vae = svd_vae.to(dtype=torch.float32)
    print("upcasting svd_vae to float32")

sd_needs_upcasting = sd_vae.dtype == torch.float16 and sd_vae.config.force_upcast
if sd_needs_upcasting:
    sd_vae = sd_vae.to(dtype=torch.float32)
    print("upcasting sd_vae to float32")

# One pre/post-processor per VAE, both with default VaeImageProcessor settings.
svd_process = VaeImageProcessor()
sd_process = VaeImageProcessor()
def load_image(img_path, bg_color, rescale=True, return_type="np", img_wh=(384, 384)):
    """Load an image, resize it, and flatten any transparency onto a solid background.

    Args:
        img_path: Path to an image file readable by PIL.
        bg_color: Background colour blended in where the image is transparent;
            a scalar or an RGB triple with values in [0, 1].
        rescale: Unused; kept so existing call sites keep working. Mapping the
            result to [-1, 1] is left to the caller.
        return_type: "np" to return a NumPy array, "pt" to return a torch tensor.
        img_wh: (width, height) passed to PIL's ``Image.resize``.

    Returns:
        float32 array/tensor of shape (height, width, 3) with values in [0, 1].

    Raises:
        NotImplementedError: If ``return_type`` is neither "np" nor "pt".
    """
    # Not using cv2, which may load 16-bit data; converting to RGBA below also
    # guarantees 8-bit channels. The context manager closes the file handle
    # that a bare Image.open() would leave open.
    with Image.open(img_path) as src:
        # Resize in the source mode (same as before), then convert to RGBA so
        # grayscale / palette / RGB inputs all share one 4-channel layout.
        # RGB gets a fully opaque alpha, so its pixels pass through unchanged.
        rgba = src.resize(img_wh).convert("RGBA")
    img = np.asarray(rgba, dtype=np.float32) / 255.0  # [0, 1]
    alpha = img[..., 3:4]
    # Cast the background to float32 so a Python list/int does not promote the
    # composite to float64.
    background = np.asarray(bg_color, dtype=np.float32)
    img = img[..., :3] * alpha + background * (1.0 - alpha)

    if return_type == "np":
        return img
    if return_type == "pt":
        return torch.from_numpy(img)
    raise NotImplementedError

# Root directory holding one sub-directory of renders per 3D model.
data_root = 'data/3d-data-example/00/'
# Number of consecutive render_XXXX.png frames evaluated per model.
num_frames = 15

# Sorted so the per-model output order is reproducible; stray files are skipped
# because only directories can contain the render_XXXX.png frames.
models = [
    os.path.join(data_root, entry)
    for entry in sorted(os.listdir(data_root))
    if os.path.isdir(os.path.join(data_root, entry))
]

# compare the two vae models psnr metric
sd_psnr_list, svd_psnr_list = [], []
for model in models:
    img_list = [
        load_image(
            os.path.join(model, f"render_{i:04d}.png"),
            bg_color=[1, 1, 1],
            rescale=True,
            return_type="pt",
        )
        for i in range(num_frames)
    ]
    # (T, H, W, C) -> (T, C, H, W). Force float32 so the PSNR reference has a
    # fixed dtype whatever load_image returns.
    img = torch.stack(img_list, dim=0).to(device=device, dtype=torch.float32)
    img = rearrange(img, "t h w c -> t c h w")
    # VaeImageProcessor.preprocess normalizes the [0, 1] frames to [-1, 1];
    # cast once to each VAE's dtype/device.
    svd_image = svd_process.preprocess(img).to(device=svd_vae.device, dtype=svd_vae.dtype)
    sd_image = sd_process.preprocess(img).to(device=sd_vae.device, dtype=sd_vae.dtype)
    with torch.no_grad():
        # Deterministic round trip: encode to the latent distribution's mode, then decode.
        sd_image_latents = sd_vae.encode(sd_image).latent_dist.mode()
        svd_image_latents = svd_vae.encode(svd_image).latent_dist.mode()
        sd_decoded_image = sd_vae.decode(sd_image_latents).sample
        # The temporal decoder needs the number of frames in the batch.
        svd_decoded_image = svd_vae.decode(svd_image_latents, num_frames=num_frames).sample
        # Back to [0, 1] tensors so they are comparable with `img`.
        sd_decoded_image = sd_process.postprocess(sd_decoded_image, output_type="pt")
        svd_decoded_image = svd_process.postprocess(svd_decoded_image, output_type="pt")
        sd_psnr = psnr(sd_decoded_image, img)
        svd_psnr = psnr(svd_decoded_image, img)
        sd_psnr_list.append(sd_psnr)
        svd_psnr_list.append(svd_psnr)
        print(f"{model} sd_psnr: {sd_psnr}, svd_psnr: {svd_psnr}")

if sd_psnr_list:
    print(f"sd_psnr: {sum(sd_psnr_list) / len(sd_psnr_list)}, svd_psnr: {sum(svd_psnr_list) / len(svd_psnr_list)}")
else:
    # Avoid a ZeroDivisionError when no model directories were found.
    print(f"no model directories found under {data_root}")

