import torch
import torch.nn as nn
from diffusers import VQModel


class VQEncoder(nn.Module):
    """Frozen wrapper around a diffusers ``VQModel``.

    Exposes the VQ-VAE's encoder and decoder sub-modules as separate calls
    (``forward`` / ``decode``) plus the model's full forward pass
    (``forward_full``). Every parameter of the wrapped model has
    ``requires_grad`` disabled, so no gradients are computed for its weights.
    """

    def __init__(
        self,
        vqvae: "VQModel | None" = None,
        pretrained_model_name_or_path: str = "microsoft/vq-diffusion-ithq",
        subfolder: str = "vqvae",
    ):
        """Build the wrapper and freeze the VQ-VAE.

        Args:
            vqvae: An already-constructed ``VQModel`` to wrap. When ``None``
                (the default), one is loaded with ``VQModel.from_pretrained``.
            pretrained_model_name_or_path: Checkpoint id or local path passed
                to ``VQModel.from_pretrained``; ignored when ``vqvae`` is given.
            subfolder: Sub-folder of the checkpoint holding the VQ-VAE
                weights; ignored when ``vqvae`` is given.
        """
        super().__init__()
        if vqvae is None:
            vqvae = VQModel.from_pretrained(
                pretrained_model_name_or_path, subfolder=subfolder
            )
        self.vqvae = vqvae
        # Freeze all weights (same effect as setting requires_grad = False on
        # each parameter individually).
        self.vqvae.requires_grad_(False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Encode ``input`` with the VQ-VAE's encoder sub-module.

        NOTE(review): this calls ``vqvae.encoder`` directly rather than
        ``vqvae.encode``, so anything ``VQModel.encode`` applies on top of the
        raw encoder is skipped. ``decode`` mirrors this by calling
        ``vqvae.decoder`` directly. Confirm this pairing is intended.

        Args:
            input: Image batch; presumably ``(batch, channels, height, width)``
                (TODO: confirm against callers).

        Returns:
            The raw encoder output.
        """
        return self.vqvae.encoder(input)

    def forward_full(self, input: torch.Tensor) -> torch.Tensor:
        """Run the wrapped ``VQModel``'s full forward pass.

        Args:
            input: Image batch, in the same format ``forward`` expects.

        Returns:
            The ``.sample`` field of the model's output.
        """
        # Invoke the module (not ``.forward``) so registered hooks still run.
        return self.vqvae(input).sample

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode ``latents`` with the VQ-VAE's decoder sub-module.

        ``vqvae.decoder`` returns a plain tensor. Only ``VQModel.decode``
        wraps its result in an object with a ``.sample`` field, so the
        decoder output is returned as-is.

        Args:
            latents: Tensor in the format produced by ``forward``.

        Returns:
            The decoder output.
        """
        return self.vqvae.decoder(latents)