import torch
from torch import nn
import diffusers


class AutoencoderKL(nn.Module):
    """Frozen wrapper around a pretrained diffusers ``AutoencoderKL`` (VAE).

    Latents produced by :meth:`encode` are normalised as
    ``(z - shift) * scale``; :meth:`decode` applies the exact inverse
    (``latent / scale + shift``) before running the underlying decoder.
    Weights are frozen and both directions run under ``torch.no_grad()``.

    Args:
        scale: Multiplier applied to latents after the shift. The default
            0.18215 presumably matches the Stable Diffusion VAE
            ``scaling_factor`` -- confirm against the loaded VAE config.
        shift: Offset subtracted from raw latents before scaling.
        repo: Repo id or path passed to ``from_pretrained``; weights are
            loaded from its ``vae`` subfolder.
    """

    def __init__(
        self,
        scale: float = 0.18215,
        shift: float = 0.0,
        repo="stabilityai/stable-diffusion-2-1",
    ):
        super().__init__()
        self.scale = scale
        self.shift = shift
        self.ae = diffusers.AutoencoderKL.from_pretrained(repo, subfolder="vae")
        # Inference-only: put the VAE in eval mode and freeze its parameters.
        # NOTE(review): calling .train() on this wrapper recursively switches
        # self.ae back to train mode; harmless only if the VAE has no
        # train-dependent layers -- confirm if that matters.
        self.ae.eval()
        self.ae.requires_grad_(False)

    def forward(self, img):
        """Encode ``img``; equivalent to ``self.encode(img)`` with defaults."""
        return self.encode(img)

    @torch.no_grad()
    def encode(self, img, *, sample: bool = True, generator=None):
        """Encode images into normalised latents.

        Args:
            img: Image batch accepted by the VAE encoder -- presumably
                ``(B, 3, H, W)`` in ``[-1, 1]``; TODO confirm against callers.
            sample: If True (default, the original behaviour) draw a random
                sample from the encoder posterior; if False use the
                posterior's ``mode()`` for a deterministic encoding.
            generator: Optional ``torch.Generator`` forwarded to the
                posterior's ``sample`` for reproducible latents. Unused when
                ``sample`` is False.

        Returns:
            ``(z - shift) * scale`` where ``z`` is the sampled (or mode) latent.
        """
        posterior = self.ae.encode(img, return_dict=False)[0]
        if sample:
            latent = posterior.sample(generator=generator)
        else:
            latent = posterior.mode()
        return (latent - self.shift) * self.scale

    @torch.no_grad()
    def decode(self, latent):
        """Decode normalised latents back to images.

        Undoes the normalisation applied by :meth:`encode`
        (``latent / scale + shift``) and returns the first element of the
        underlying decoder's tuple output.
        """
        rec = self.ae.decode(latent / self.scale + self.shift, return_dict=False)[0]
        return rec


class TinyAutoencoderKL(nn.Module):
    """Frozen wrapper around a pretrained diffusers ``AutoencoderTiny``.

    Latents are returned exactly as produced by the tiny encoder (no extra
    scale/shift is applied here). The encoder call is wrapped with
    ``torch.compile`` by default; the decoder runs eagerly.

    Args:
        repo: Repo id or path passed to ``AutoencoderTiny.from_pretrained``.
        compile_encode: If True (default, the original behaviour) replace
            ``self.ae.encode`` with a ``torch.compile``'d version
            (``fullgraph=True, dynamic=False``). Set False on platforms where
            compilation is unavailable or fails.
    """

    def __init__(self, repo="madebyollin/taesd", compile_encode: bool = True):
        super().__init__()
        self.ae = diffusers.AutoencoderTiny.from_pretrained(repo)
        # Inference-only: eval mode and frozen parameters.
        self.ae.eval()
        self.ae.requires_grad_(False)
        # Only ae.encode is compiled. The former `self.ae.compile()` was
        # dropped: nn.Module.compile() only affects calls to `self.ae(...)`,
        # which this wrapper never makes (it calls encode/decode directly).
        if compile_encode:
            # dynamic=False: each distinct input shape triggers its own
            # specialised compilation.
            self.ae.encode = torch.compile(self.ae.encode, fullgraph=True, dynamic=False)

    def forward(self, img):
        """Encode ``img``; equivalent to ``self.encode(img)``."""
        return self.encode(img)

    @torch.no_grad()
    def encode(self, img):
        """Encode images into latents with the tiny autoencoder.

        Returns a fresh copy of the first element of the encoder's tuple
        output. NOTE(review): the copy presumably guards against the compiled
        graph reusing its output buffer across calls -- confirm before
        removing it.
        """
        latent = self.ae.encode(img, return_dict=False)[0]
        return latent.clone()

    @torch.no_grad()
    def decode(self, latent):
        """Decode latents to images; returns the decoder's first tuple element."""
        rec = self.ae.decode(latent, return_dict=False)[0]
        return rec
