import argparse
import os

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models import AutoencoderKL
from PIL import Image


def main(args):
    """Round-trip one image through a pretrained VAE and save the result.

    The image at ``args.image_path`` is resized to a square of
    ``args.image_size`` pixels, encoded and decoded by the
    ``stabilityai/<args.vae>`` autoencoder, resized back to its original
    dimensions, and written next to the input with ``_vae`` inserted before
    the file extension (``foo.jpg`` -> ``foo_vae.jpg``).

    Args:
        args: Namespace providing ``seed`` (int), ``vae`` (str, repository
            name under ``stabilityai/``), ``image_path`` (str) and
            ``image_size`` (int).
    """
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    vae = AutoencoderKL.from_pretrained(f"stabilityai/{args.vae}").to(device)

    # Derive the output path from the real extension so that any image type
    # (including upper-case or less common extensions) gets a distinct
    # "_vae" file instead of silently overwriting the input.
    img_path = args.image_path
    root, ext = os.path.splitext(img_path)
    out_path = f"{root}_vae{ext}"
    input_size = args.image_size

    # load image; convert() returns an independent copy, so the file handle
    # can be released as soon as the block exits.
    with Image.open(img_path) as src:
        img = src.convert("RGB")

    # preprocess
    size_org = img.size  # PIL reports (width, height)
    img = img.resize((input_size, input_size))
    img = np.array(img) / 255.
    x = 2.0 * img - 1.0  # x value is between [-1, 1]
    x = torch.tensor(x)
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)
    # Use the device selected above (not a hard-coded "cuda") so the script
    # also runs on CPU-only machines.
    x_input = x.float().to(device)

    # inference
    with torch.no_grad():
        # Map input images to latent space + normalize latents. The same
        # factor is divided out again before decoding, so it cancels here.
        latent = vae.encode(x_input).latent_dist.sample().mul_(0.18215)
        # reconstruct (output is nominally in [-1, 1]; clamped below):
        output = vae.decode(latent / 0.18215).sample

    # postprocess: F.interpolate expects (height, width), hence the swap of
    # PIL's (width, height); then go from NCHW to HWC for the first sample.
    output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0]
    sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()

    # save
    Image.fromarray(sample).save(out_path)
    print(f"Reconstructed image is saved to {out_path}")


if __name__ == "__main__":
    # Command-line options, registered in the order they appear in --help.
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--image-path", {"type": str, "default": "assets/example.jpg"}),
        ("--vae", {"type": str, "choices": ["sdxl-vae", "sd-vae-ft-mse"], "default": "sd-vae-ft-mse"}),
        ("--image-size", {"type": int, "choices": [256, 512, 1024], "default": 512}),
        ("--seed", {"type": int, "default": 0}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    # Parse the command line and run the reconstruction.
    main(cli.parse_args())