import torch
import torch.nn as nn
import sys

from mar.models.vae import AutoencoderKL
from einops import rearrange

class VAEEncoder(nn.Module):
    """Wrap a pretrained KL-16 ``AutoencoderKL`` as a latent-token encoder.

    ``forward`` encodes an image batch with the VAE, takes the mode of the
    resulting posterior and flattens its spatial grid into a token sequence.

    Attributes:
        name: Identifier for this encoder.
        base_model: The wrapped ``AutoencoderKL``, switched to eval mode at
            construction time.
        emb_dim: Channel width of each output token (the VAE ``embed_dim``).
        patch_size: Input pixels covered by one latent cell along each
            spatial axis.
        latent_ndim: Number of spatial dimensions of the latent grid.
    """

    # Latent channel count passed to AutoencoderKL; forward() emits tokens of
    # exactly this width, so emb_dim is derived from it rather than repeated.
    _EMBED_DIM = 16
    # Channel multipliers of the KL-16 architecture. NOTE(review): assumes the
    # encoder halves resolution between consecutive levels, i.e.
    # len(ch_mult) - 1 = 4 times -> 16x overall, consistent with the "kl16"
    # checkpoint name; confirm against AutoencoderKL's Encoder.
    _CH_MULT = (1, 1, 2, 2, 4)

    def __init__(self, name='vae', ckpt_path='path_to_kl16_ckpt'):
        """Build the encoder and load the VAE weights.

        Args:
            name: Identifier stored on ``self.name``.
            ckpt_path: Checkpoint forwarded to ``AutoencoderKL``. The default
                is the original placeholder and must be replaced by a real path.
        """
        super().__init__()
        self.name = name
        # NOTE: nn.Module.train() on this wrapper recurses into base_model and
        # turns training mode back on; re-call .eval() if the VAE must stay
        # in inference mode.
        self.base_model = AutoencoderKL(
            embed_dim=self._EMBED_DIM,
            ch_mult=self._CH_MULT,
            ckpt_path=ckpt_path,
        ).eval()
        # Must equal the channel dim of the tokens forward() returns.
        self.emb_dim = self._EMBED_DIM
        # One latent cell per 2**(levels - 1) input pixels (16 for KL-16).
        self.patch_size = 2 ** (len(self._CH_MULT) - 1)
        self.latent_ndim = 2

    def forward(self, x):
        """Encode ``x`` into a sequence of latent tokens.

        Args:
            x: Image batch passed unchanged to ``AutoencoderKL.encode``.
                NOTE(review): presumably (B, 3, H, W) with H and W divisible
                by ``patch_size``; confirm the expected pixel range with the
                VAE's training setup.

        Returns:
            Tensor of shape (B, h * w, C), one token per latent cell in
            row-major order, where (B, C, h, w) is the posterior mode's shape.
        """
        latent = self.base_model.encode(x).mode()
        # The conv latent is channels-first (B, C, h, w): flatten the grid,
        # then move channels last. Same as einops 'b c h w -> b (h w) c'; the
        # former pattern 'b h w c' mislabelled the axes and yielded (B, C*h, w).
        return latent.flatten(2).transpose(1, 2)