import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL
from ..utils.util import zero_module

class VAEAdatper(nn.Module):
    """Encode an image plus a conditioning image into scaled VAE latents.

    The module sandwiches a pretrained ``AutoencoderKL`` between two 1x1
    convolutions wrapped by ``zero_module``:

    * ``zeros_in``  (3 -> 3 channels) is applied to the conditioning image
      before it is added to the main image.
    * ``zeros_out`` (4 -> 4 channels) is applied to the scaled latents.

    NOTE(review): ``zero_module`` lives in ``..utils.util`` and is not visible
    here; judging by its name it presumably zero-initialises the wrapped
    layer's parameters, in which case a freshly built adapter outputs zeros
    until trained. Confirm against the helper.

    The class name keeps its original spelling because callers import it.

    Args:
        model_path: Location passed to ``AutoencoderKL.from_pretrained``.
        weight_dtype: dtype the VAE weights are cast to.
        device: Device the VAE is moved to. Defaults to ``"cuda"``, which is
            what was previously hard-coded. Only the VAE is moved; the two
            convolutions stay wherever the caller places this module.
        scaling_factor: Multiplier applied to the sampled latents. Defaults to
            ``0.18215``, the value previously hard-coded in ``forward``
            (presumably the Stable Diffusion v1 VAE scaling factor; verify
            against the checkpoint's config when using a different VAE).
    """

    # Class-level fallback so instances that did not run this ``__init__``
    # (for example, objects unpickled from an older version) keep working.
    scaling_factor = 0.18215

    def __init__(self, model_path, weight_dtype: torch.dtype, device="cuda",
                 scaling_factor: float = 0.18215):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.vae = AutoencoderKL.from_pretrained(model_path).to(
            device, dtype=weight_dtype
        )
        # 1x1 conv over the 3-channel conditioning image.
        self.zeros_in = zero_module(
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, padding=0, stride=1)
        )
        # 1x1 conv over the 4-channel latents produced by the VAE encoder.
        self.zeros_out = zero_module(
            nn.Conv2d(in_channels=4, out_channels=4, kernel_size=1, padding=0, stride=1)
        )

    def forward(self, img: torch.Tensor, img_c: torch.Tensor) -> torch.Tensor:
        """Return ``zeros_out`` applied to the scaled latents of the fused image.

        Args:
            img: Main image batch; it is added to ``zeros_in(img_c)``, so it
                must be broadcast-compatible with a 3-channel image batch.
            img_c: Conditioning image batch with 3 channels (required by
                ``zeros_in``).

        Returns:
            The output of ``zeros_out`` on the latents sampled from the VAE
            posterior and multiplied by ``self.scaling_factor``.
        """
        fused = img + self.zeros_in(img_c)
        # ``sample()`` draws from the encoder's latent distribution, so the
        # result is stochastic unless the caller controls the RNG.
        latents = self.vae.encode(fused).latent_dist.sample()
        latents = latents * self.scaling_factor
        return self.zeros_out(latents)
