from functools import partial

import torch
import torch.nn as nn
import numpy as np

from timm.models.vision_transformer import PatchEmbed, Block

from util.pos_embed import get_2d_sincos_pos_embed
from MLP import MLP
from einops import repeat
import matplotlib.pyplot as plt


class VariationalAutoEncoder2ViT(nn.Module):
    """Masked variational auto-encoder built from two ViT stacks.

    The encoder stack (``self.blocks``) reads a patch sequence in which only
    the patches selected by a mask are kept (the others are replaced by
    ``mask_token``).  It prepends ``cls_token`` and maps the class-token output
    through ``head_qz`` to the mean and log-std of a diagonal Gaussian q(z).
    The decoder stack (``self.decoder_blocks``) reads one token embedding a
    latent sample ``z`` followed by the context-masked patch sequence.  It
    predicts every patch through ``decoder_pred``.

    :meth:`forward` returns the Gaussian NLL of the target patches plus
    ``betaKL`` times KL(q(z | context + target) || q(z | context)).

    ``ELastNorm`` / ``DLastNorm`` select the layer applied after the encoder /
    decoder stacks:

    - ``'LN'``: LayerNorm(eps=1e-6).
    - ``'NO'``: identity.
    - ``'IN'``: InstanceNorm1d with ``self.L`` affine channels.
    - ``'LR'``: LeakyReLU(0.2).

    Any other value raises ``ValueError``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, z_dim=128,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 ELastNorm='LR', DLastNorm='LR'):
        super().__init__()
        self.in_chans = in_chans
        # Transformer sequence length: one token per patch plus one leading
        # token (cls token in the encoder, z token in the decoder).
        self.L = (img_size // patch_size) ** 2 + 1
        self.patch_size = patch_size
        self.img_size = img_size
        self.z_dim = z_dim

        # The decoder re-uses ``patch_embed`` and ``mask_token`` (both of width
        # ``embed_dim``), so the two widths must agree.
        assert embed_dim == decoder_embed_dim

        # ---------------------------------------------------------- encoder
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        # Fixed (non-trainable) positional embedding, filled in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
                                      requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.encoder_norm = self._make_last_norm(ELastNorm, embed_dim, self.L)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Class-token output -> [mu, log_std] of q(z), each of size z_dim.
        self.head_qz = nn.Sequential(
            nn.Linear(embed_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, 2 * z_dim),
        )

        # ---------------------------------------------------------- decoder
        self.decoder_embed_dim = decoder_embed_dim
        # Latent sample z -> a single decoder token.
        self.z_embed = nn.Sequential(
            nn.Linear(z_dim, 2 * decoder_embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * decoder_embed_dim, decoder_embed_dim),
        )

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                              requires_grad=False)

        self.decoder_depth = decoder_depth
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)])
        self.decoder_norm = self._make_last_norm(DLastNorm, decoder_embed_dim, self.L)

        # Per-patch head: p*p*C pixel means plus one extra output channel whose
        # value forward_decoder() does not use (it substitutes a zero log-std).
        self.decoder_pred = nn.Sequential(
            nn.Linear(decoder_embed_dim, 2 * decoder_embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * decoder_embed_dim, 2 * decoder_embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * decoder_embed_dim, patch_size ** 2 * in_chans + 1),
        )

        self.initialize_weights()

    @staticmethod
    def _make_last_norm(kind, dim, num_tokens):
        """Build the layer applied after a transformer stack.

        Args:
            kind: one of ``'LN'``, ``'NO'``, ``'IN'``, ``'LR'``.
            dim: token width (used by LayerNorm).
            num_tokens: sequence length (used as InstanceNorm1d channel count).

        Raises:
            ValueError: for any other ``kind``.  Previously the attribute was
                silently left unset, which failed later with an opaque
                AttributeError.
        """
        if kind == 'LN':
            return nn.LayerNorm(dim, eps=1e-6)
        if kind == 'NO':
            return nn.Identity()
        if kind == 'IN':
            return nn.InstanceNorm1d(num_tokens, affine=True)
        if kind == 'LR':
            return nn.LeakyReLU(0.2)
        raise ValueError(
            f"unknown last-norm option {kind!r}; expected 'LN', 'NO', 'IN' or 'LR'")

    def initialize_weights(self):
        """Fill the fixed positional embeddings and (re-)initialise weights.

        The patch-embedding projection is initialised Xavier-uniform as if it
        were a Linear layer.  ``cls_token`` / ``mask_token`` are drawn from
        N(0, 0.02^2).  Every submodule is then passed to :meth:`_init_weights`.
        """
        grid_size = int(self.patch_embed.num_patches ** .5)

        # NOTE(review): get_2d_sincos_pos_embed is project code; with
        # nos_token=True it appears to return one row more than needed.  The
        # last row is dropped to match num_patches + 1 -- confirm.
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1], grid_size,
            nos_token=True, cls_token=True)[:-1, ...]
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], grid_size,
            nos_token=True, cls_token=True)[:-1, ...]
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # Treat the conv projection as a flattened Linear for Xavier init.
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Set Linear to Xavier-uniform with zero bias; set LayerNorm to weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """Split square images (N, C, H, W) into flattened patches.

        Returns:
            Tensor of shape (N, (H/p)*(W/p), p*p*C), with pixels ordered
            (row-in-patch, col-in-patch, channel).
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], self.in_chans, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * self.in_chans))

    def unpatchify(self, x):
        """Inverse of :meth:`patchify`: (N, h*w, p*p*C) -> (N, C, h*p, h*p)."""
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_chans))
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(shape=(x.shape[0], self.in_chans, h * p, h * p))

    def random_masking(self, x, r_ind=None, Sratio=0., Tratio=1., Lst=None):
        """Split the L tokens of each sample into disjoint context / target sets.

        Args:
            x: tensor of shape (N, L, D); only its shape and device are used.
            r_ind: optional (N, L) token permutation.  When None, a fresh
                ``torch.randperm`` is drawn per sample.
            Sratio: fraction of tokens (of L, or of ``Lst``) in the context set.
            Tratio: fraction of the remaining tokens in the target set, rounded
                up.  The sentinel ``-10.`` together with ``Lst`` means "use
                exactly ``Lst`` tokens in total".
            Lst: total context + target count for the sentinel mode.

        Returns:
            ``(mask_s, mask_t, ids_restore, r_ind)``:

            - ``mask_s`` / ``mask_t``: float (N, L) masks, 1 marking
              context / target tokens.
            - ``ids_restore``: ``argsort(r_ind)``.
            - ``r_ind``: the permutation that was used.
        """
        N, L, _ = x.shape

        if r_ind is None:
            r_ind = torch.stack([torch.randperm(L).to(x.device) for _ in range(N)])
        ids_restore = torch.argsort(r_ind, dim=1)

        if (Tratio == -10.) and (Lst is not None):
            len_s = int(Lst * Sratio)
            len_t = Lst - len_s
        else:
            len_s = int(L * Sratio)
            if isinstance(Tratio, torch.Tensor):
                len_t = int(np.ceil((L - len_s) * Tratio.cpu()))
            else:
                len_t = int(np.ceil((L - len_s) * Tratio))

        # Mark the first len_s permuted slots as context and the next len_t as
        # target, then map the masks back to the original token order.
        mask_s = torch.zeros([N, L], device=x.device)
        mask_t = torch.zeros([N, L], device=x.device)
        mask_s[:, :len_s] = 1
        mask_t[:, len_s:len_s + len_t] = 1
        mask_s = torch.gather(mask_s, dim=1, index=ids_restore)
        mask_t = torch.gather(mask_t, dim=1, index=ids_restore)

        return mask_s, mask_t, ids_restore, r_ind

    def forward_encoder(self, imgs, mask_st):
        """Encode the patches kept by ``mask_st`` into the parameters of q(z).

        Args:
            imgs: image batch fed to ``patch_embed``.
            mask_st: float mask over patches, presumably (N, num_patches).  A
                value of 1 keeps the patch embedding; 0 replaces it with
                ``mask_token``.

        Returns:
            ``(mu, log_std)``, each of shape (N, z_dim).
        """
        x = self.patch_embed(imgs)
        keep = mask_st.unsqueeze(-1)
        xst = keep * x + (1. - keep) * self.mask_token

        cls_token = self.cls_token.expand(xst.shape[0], -1, -1)
        cxst = torch.cat([cls_token, xst], dim=1) + self.pos_embed

        for blk in self.blocks:
            cxst = blk(cxst)
        cxst_embed = self.encoder_norm(cxst)

        mu_zqst, lstd_zqst = torch.chunk(self.head_qz(cxst_embed[:, 0, :]), chunks=2, dim=-1)
        return mu_zqst, lstd_zqst

    def forward_decoder(self, zqst, imgs, mask_s, toimg=False):
        """Decode a latent sample plus the context patches into patch predictions.

        Args:
            zqst: latent sample, (N, z_dim).
            imgs: image batch fed to ``patch_embed`` (and ``patchify`` if ``toimg``).
            mask_s: context mask; 1 marks a patch visible to the decoder.
            toimg: when True, return images instead of per-patch tensors.

        Returns:
            If ``toimg`` is False, ``(mu_xp, lstd_xp)``:

            - ``mu_xp``: tanh-squashed patch means of shape
              (N, num_patches, p*p*C).
            - ``lstd_xp``: an all-zero log-std of shape (N, 1, 1), i.e. a
              fixed unit-variance likelihood.

            If ``toimg`` is True, ``(recon_img, pred_img)``:

            - ``pred_img``: the full prediction.
            - ``recon_img``: the prediction with the ground-truth context
              patches pasted over it.
        """
        latent_zst = self.z_embed(zqst)

        x = self.patch_embed(imgs)
        visible = mask_s.unsqueeze(-1)
        xs = visible * x + (1. - visible) * self.mask_token

        x = torch.cat([latent_zst.unsqueeze(1), xs], dim=1) + self.decoder_pos_embed

        for blk in self.decoder_blocks:
            x = blk(x)
        dec_embed = self.decoder_norm(x)

        # Drop the leading z token; one prediction per patch remains.
        latent_pred_patch = self.decoder_pred(dec_embed[:, 1:, :])
        mu_xp = torch.tanh(latent_pred_patch[:, :, :-1])
        # The predicted log-std channel is discarded: a zero log-std of shape
        # (N, 1, 1) is used instead.
        lstd_xp = torch.zeros_like(latent_pred_patch[:, -1:, -1:])

        if toimg:
            pred_img = self.unpatchify(mu_xp)
            recon_ospt = visible * self.patchify(imgs) + (1. - visible) * mu_xp
            recon_ospt_img = self.unpatchify(recon_ospt)
            return recon_ospt_img, pred_img

        return mu_xp, lstd_xp

    def forward(self, imgs, r_ind=None, Sratio=0., Tratio=1., Lst=None, betaKL=1.):
        """Compute the training loss for one batch.

        Steps:

        1. Draw a context/target split with :meth:`random_masking`.
        2. Sample q(z | context + target) with the reparameterisation trick.
        3. Decode the sample from the context patches.
        4. Average the Gaussian NLL over the target patches only.
        5. Regularise q(z | context + target) toward q(z | context) with a KL
           term.

        Returns:
            ``(loss, NLL, KLloss)`` with ``loss = NLL + betaKL * KLloss``.
            ``NLL`` is 0 (rather than NaN) when the target set is empty.
        """
        assert Sratio >= 0.
        assert (Tratio > 0) or ((Tratio == -10.) and (Lst is not None) and (r_ind is not None))

        x = self.patchify(imgs)
        mask_s, mask_t, _, _ = self.random_masking(x, r_ind=r_ind, Sratio=Sratio,
                                                   Tratio=Tratio, Lst=Lst)

        mu_zqst, lstd_zqst = self.forward_encoder(imgs, mask_st=mask_s + mask_t)
        mu_zqs, lstd_zqs = self.forward_encoder(imgs, mask_st=mask_s)

        # Reparameterised sample z ~ q(z | context + target).
        zqst = mu_zqst + lstd_zqst.exp() * torch.randn_like(mu_zqst)

        mu_xp, lstd_xp = self.forward_decoder(zqst, imgs, mask_s, toimg=False)

        # Per-patch Gaussian NLL (additive constant dropped), averaged over the
        # target patches.  Clamping the count avoids 0/0 = NaN when no patch is
        # a target (e.g. Sratio=1); otherwise the value is unchanged.
        NLL1 = (lstd_xp + 0.5 * ((x - mu_xp) / lstd_xp.exp()).pow(2)).mean(-1)
        NLL = (NLL1 * mask_t).sum() / mask_t.sum().clamp(min=1.)

        # Closed-form KL between diagonal Gaussians, averaged over batch and z dims.
        lsmls = lstd_zqst - lstd_zqs
        KLloss = 0.5 * (
            ((mu_zqst - mu_zqs) / lstd_zqs.exp()).pow(2)
            + lsmls.exp().pow(2) - 1. - 2. * lsmls
        ).mean()
        loss = NLL + betaKL * KLloss

        return loss, NLL, KLloss


class NaiveVariationalAutoEncoder(VariationalAutoEncoder2ViT):
    """Plain-VAE baseline that shares every network with its parent class.

    :meth:`forward` always does the following:

    - encodes the whole image;
    - decodes with no visible context patches;
    - regularises the posterior toward a standard normal N(0, I) instead of
      toward q(z | context).
    """

    def forward(self, imgs, r_ind=None, Sratio=0., Tratio=1., Lst=None, betaKL=1.):
        """Compute the plain-VAE training loss for one batch.

        ``Sratio`` / ``Tratio`` / ``Lst`` go through the same sanity checks as
        in the parent class, so the call signature stays interchangeable.  The
        split actually used is always "no context, every patch is a target".

        Returns:
            ``(loss, NLL, KLloss)`` with ``loss = NLL + betaKL * KLloss``.
        """
        assert Sratio >= 0.
        assert (Tratio > 0) or ((Tratio == -10.) and (Lst is not None) and (r_ind is not None))

        patches = self.patchify(imgs)
        # Sratio=0 / Tratio=1: empty context mask, all-ones target mask.
        no_ctx, all_tgt, _, _ = self.random_masking(
            patches, r_ind=r_ind, Sratio=0., Tratio=1., Lst=Lst,
        )

        post_mu, post_lstd = self.forward_encoder(imgs, mask_st=no_ctx + all_tgt)

        # Reparameterised sample from the posterior.
        eps = torch.randn_like(post_mu)
        z = post_mu + post_lstd.exp() * eps

        pred_mu, pred_lstd = self.forward_decoder(z, imgs, no_ctx, toimg=False)

        # Per-patch Gaussian NLL (constant dropped), averaged over target patches.
        scaled_err = (patches - pred_mu) / pred_lstd.exp()
        per_patch_nll = (pred_lstd + 0.5 * scaled_err.pow(2)).mean(-1)
        NLL = (per_patch_nll * all_tgt).sum() / all_tgt.sum()

        # Closed-form KL(q(z|x) || N(0, I)), averaged over batch and z dims.
        KLloss = 0.5 * (
            post_mu.pow(2) + post_lstd.exp().pow(2) - 1. - 2. * post_lstd
        ).mean()
        loss = NLL + betaKL * KLloss

        return loss, NLL, KLloss


class VariationalAutoEncoder1ViT(nn.Module):
    """Masked variational auto-encoder whose encoder and decoder share one ViT.

    Both :meth:`forward_encoder` and :meth:`forward_decoder` run the same
    ``self.blocks`` stack with the same ``self.pos_embed``.

    - The encoder prepends ``cls_token`` and maps its output through
      ``head_qz`` to the mean and log-std of a diagonal Gaussian q(z).
    - The decoder prepends an embedding of a latent sample and predicts every
      patch through ``decoder_pred``.

    :meth:`forward` returns the Gaussian NLL of the target patches plus
    ``betaKL`` times KL(q(z | context + target) || q(z | context)).

    ``ELastNorm`` / ``DLastNorm`` select the layer applied after the stack in
    encoder / decoder mode:

    - ``'LN'``: LayerNorm(eps=1e-6).
    - ``'NO'``: identity.
    - ``'IN'``: InstanceNorm1d with ``self.L`` affine channels.
    - ``'LR'``: LeakyReLU(0.2).

    Any other value raises ``ValueError``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, z_dim=128,
                 embed_dim=1024, depth=24, num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 ELastNorm='LR', DLastNorm='LR'):
        super().__init__()
        self.in_chans = in_chans
        # Transformer sequence length: one token per patch plus one leading
        # token (cls token for encoding, z token for decoding).
        self.L = (img_size // patch_size) ** 2 + 1
        self.patch_size = patch_size
        self.img_size = img_size
        self.z_dim = z_dim

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        # Fixed (non-trainable) positional embedding, filled in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
                                      requires_grad=False)

        # Shared by the encoder and decoder passes.
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.encoder_norm = self._make_last_norm(ELastNorm, embed_dim, self.L)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Class-token output -> [mu, log_std] of q(z), each of size z_dim.
        self.head_qz = nn.Sequential(
            nn.Linear(embed_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, 2 * z_dim),
        )

        # Latent sample z -> a single token.
        self.z_embed = nn.Sequential(
            nn.Linear(z_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, embed_dim),
        )

        self.decoder_norm = self._make_last_norm(DLastNorm, embed_dim, self.L)

        # Per-patch head: p*p*C pixel means plus one extra output channel whose
        # value forward_decoder() does not use (it substitutes a zero log-std).
        self.decoder_pred = nn.Sequential(
            nn.Linear(embed_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, patch_size ** 2 * in_chans + 1),
        )

        self.initialize_weights()

    @staticmethod
    def _make_last_norm(kind, dim, num_tokens):
        """Build the layer applied after the transformer stack.

        Args:
            kind: one of ``'LN'``, ``'NO'``, ``'IN'``, ``'LR'``.
            dim: token width (used by LayerNorm).
            num_tokens: sequence length (used as InstanceNorm1d channel count).

        Raises:
            ValueError: for any other ``kind`` (previously the attribute was
                silently left unset).
        """
        if kind == 'LN':
            return nn.LayerNorm(dim, eps=1e-6)
        if kind == 'NO':
            return nn.Identity()
        if kind == 'IN':
            return nn.InstanceNorm1d(num_tokens, affine=True)
        if kind == 'LR':
            return nn.LeakyReLU(0.2)
        raise ValueError(
            f"unknown last-norm option {kind!r}; expected 'LN', 'NO', 'IN' or 'LR'")

    def initialize_weights(self):
        """Fill the fixed positional embedding and (re-)initialise weights.

        The patch-embedding projection is initialised Xavier-uniform as a
        flattened Linear.  ``cls_token`` / ``mask_token`` are drawn from
        N(0, 0.02^2).  Every submodule is then passed to :meth:`_init_weights`.
        """
        # NOTE(review): get_2d_sincos_pos_embed is project code; with
        # nos_token=True it appears to return one extra row.  The last row is
        # dropped to match num_patches + 1 -- confirm.
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
            nos_token=True, cls_token=True)[:-1, ...]
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Set Linear to Xavier-uniform with zero bias; set LayerNorm to weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """Split square images (N, C, H, W) into patches of shape (N, h*w, p*p*C)."""
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], self.in_chans, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * self.in_chans))

    def unpatchify(self, x):
        """Inverse of :meth:`patchify`: (N, h*w, p*p*C) -> (N, C, h*p, h*p)."""
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_chans))
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(shape=(x.shape[0], self.in_chans, h * p, h * p))

    def random_masking(self, x, r_ind=None, Sratio=0., Tratio=1., Lst=None):
        """Split the L tokens of each sample into disjoint context / target sets.

        Args:
            x: tensor of shape (N, L, D); only its shape and device are used.
            r_ind: optional (N, L) token permutation.  When None, a fresh
                ``torch.randperm`` is drawn per sample.
            Sratio: fraction of tokens (of L, or of ``Lst``) in the context set.
            Tratio: fraction of the remaining tokens in the target set, rounded
                up.  The sentinel ``-10.`` with ``Lst`` means "use exactly
                ``Lst`` tokens in total".
            Lst: total context + target count for the sentinel mode.

        Returns:
            ``(mask_s, mask_t, ids_restore, r_ind)``:

            - ``mask_s`` / ``mask_t``: float (N, L) masks, 1 marking
              context / target tokens.
            - ``ids_restore``: ``argsort(r_ind)``.
            - ``r_ind``: the permutation that was used.
        """
        N, L, _ = x.shape

        if r_ind is None:
            r_ind = torch.stack([torch.randperm(L).to(x.device) for _ in range(N)])
        ids_restore = torch.argsort(r_ind, dim=1)

        if (Tratio == -10.) and (Lst is not None):
            len_s = int(Lst * Sratio)
            len_t = Lst - len_s
        else:
            len_s = int(L * Sratio)
            if isinstance(Tratio, torch.Tensor):
                len_t = int(np.ceil((L - len_s) * Tratio.cpu()))
            else:
                len_t = int(np.ceil((L - len_s) * Tratio))

        # First len_s permuted slots -> context, next len_t -> target; then map
        # the masks back to the original token order.
        mask_s = torch.zeros([N, L], device=x.device)
        mask_t = torch.zeros([N, L], device=x.device)
        mask_s[:, :len_s] = 1
        mask_t[:, len_s:len_s + len_t] = 1
        mask_s = torch.gather(mask_s, dim=1, index=ids_restore)
        mask_t = torch.gather(mask_t, dim=1, index=ids_restore)

        return mask_s, mask_t, ids_restore, r_ind

    def forward_encoder(self, imgs, mask_st):
        """Encode the patches kept by ``mask_st`` into the parameters of q(z).

        Args:
            imgs: image batch fed to ``patch_embed``.
            mask_st: float patch mask, presumably (N, num_patches).  A value of
                1 keeps the patch embedding; 0 replaces it with ``mask_token``.

        Returns:
            ``(mu, log_std)``, each of shape (N, z_dim).
        """
        x = self.patch_embed(imgs)
        keep = mask_st.unsqueeze(-1)
        xst = keep * x + (1. - keep) * self.mask_token

        cls_token = self.cls_token.expand(xst.shape[0], -1, -1)
        cxst = torch.cat([cls_token, xst], dim=1) + self.pos_embed

        for blk in self.blocks:
            cxst = blk(cxst)
        cxst_embed = self.encoder_norm(cxst)

        mu_zqst, lstd_zqst = torch.chunk(self.head_qz(cxst_embed[:, 0, :]), chunks=2, dim=-1)
        return mu_zqst, lstd_zqst

    def forward_decoder(self, zqst, imgs, mask_s, toimg=False):
        """Decode a latent sample plus the context patches with the shared stack.

        Returns:
            If ``toimg`` is False, ``(mu_xp, lstd_xp)``:

            - ``mu_xp``: tanh-squashed patch means of shape
              (N, num_patches, p*p*C).
            - ``lstd_xp``: an all-zero log-std of shape (N, 1, 1).

            If ``toimg`` is True, ``(recon_img, pred_img)``, where
            ``recon_img`` pastes the ground-truth context patches
            (``mask_s == 1``) over the prediction.
        """
        latent_zst = self.z_embed(zqst)

        x = self.patch_embed(imgs)
        visible = mask_s.unsqueeze(-1)
        xs = visible * x + (1. - visible) * self.mask_token

        x = torch.cat([latent_zst.unsqueeze(1), xs], dim=1) + self.pos_embed

        for blk in self.blocks:
            x = blk(x)
        dec_embed = self.decoder_norm(x)

        # Drop the leading z token; one prediction per patch remains.
        latent_pred_patch = self.decoder_pred(dec_embed[:, 1:, :])
        mu_xp = torch.tanh(latent_pred_patch[:, :, :-1])
        # The predicted log-std channel is discarded in favour of a zero
        # log-std of shape (N, 1, 1).
        lstd_xp = torch.zeros_like(latent_pred_patch[:, -1:, -1:])

        if toimg:
            pred_img = self.unpatchify(mu_xp)
            recon_ospt = visible * self.patchify(imgs) + (1. - visible) * mu_xp
            recon_ospt_img = self.unpatchify(recon_ospt)
            return recon_ospt_img, pred_img

        return mu_xp, lstd_xp

    def forward(self, imgs, r_ind=None, Sratio=0., Tratio=1., Lst=None, betaKL=1.):
        """Compute the training loss for one batch.

        Returns:
            ``(loss, NLL, KLloss)`` with ``loss = NLL + betaKL * KLloss``:

            - ``NLL`` is averaged over target patches only, and is 0 (not NaN)
              when there are none.
            - ``KLloss`` is KL(q(z | context + target) || q(z | context)).
        """
        assert Sratio >= 0.
        assert (Tratio > 0) or ((Tratio == -10.) and (Lst is not None) and (r_ind is not None))

        x = self.patchify(imgs)
        mask_s, mask_t, _, _ = self.random_masking(x, r_ind=r_ind, Sratio=Sratio,
                                                   Tratio=Tratio, Lst=Lst)

        mu_zqst, lstd_zqst = self.forward_encoder(imgs, mask_st=mask_s + mask_t)
        mu_zqs, lstd_zqs = self.forward_encoder(imgs, mask_st=mask_s)

        # Reparameterised sample z ~ q(z | context + target).
        zqst = mu_zqst + lstd_zqst.exp() * torch.randn_like(mu_zqst)

        mu_xp, lstd_xp = self.forward_decoder(zqst, imgs, mask_s, toimg=False)

        # Clamp avoids 0/0 = NaN on an empty target set; otherwise unchanged.
        NLL1 = (lstd_xp + 0.5 * ((x - mu_xp) / lstd_xp.exp()).pow(2)).mean(-1)
        NLL = (NLL1 * mask_t).sum() / mask_t.sum().clamp(min=1.)

        # Closed-form KL between diagonal Gaussians, averaged over batch and z dims.
        lsmls = lstd_zqst - lstd_zqs
        KLloss = 0.5 * (
            ((mu_zqst - mu_zqs) / lstd_zqs.exp()).pow(2)
            + lsmls.exp().pow(2) - 1. - 2. * lsmls
        ).mean()
        loss = NLL + betaKL * KLloss

        return loss, NLL, KLloss


class PatchVariationalAutoEncoder1ViT(nn.Module):
    """Masked VAE over pre-patchified inputs, with one ViT shared by encoder and decoder.

    Inputs (``bar_imgs``) are token sequences of width ``x_dim``, embedded by
    ``input_embed``.  Their token count must equal
    ``(img_size // patch_size) ** 2`` so that it matches ``pos_embed``.

    - The encoder pass prepends ``cls_token`` and maps its output through
      ``head_qz`` to the mean and log-std of a diagonal Gaussian q(z).
    - The decoder pass prepends an embedding of a latent sample and predicts
      every token through ``decoder_pred``.  No output squashing is applied.

    ``ELastNorm`` / ``DLastNorm`` select the post-stack layer:

    - ``'LN'``: ``norm_layer``.
    - ``'NO'``: identity.
    - ``'IN'``: InstanceNorm1d with ``self.L`` affine channels.
    - ``'LR'``: LeakyReLU(0.2).

    Any other value raises ``ValueError``.  ``barweight`` is stored as an
    attribute and is not read by this class.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 x_dim=768, z_dim=128,
                 embed_dim=1024, depth=24, num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 ELastNorm='LR', DLastNorm='LR', barweight='Uniform'):
        super().__init__()
        self.in_chans = in_chans
        self.num_patches = (img_size // patch_size) ** 2
        # Transformer sequence length: tokens plus one leading cls / z token.
        self.L = self.num_patches + 1
        self.patch_size = patch_size
        self.img_size = img_size
        self.z_dim = z_dim
        self.x_dim = x_dim
        self.barweight = barweight

        # Input token (width x_dim) -> embed_dim.
        self.input_embed = nn.Sequential(
            nn.Linear(x_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, embed_dim),
        )
        # Fixed (non-trainable) positional embedding, filled in initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim),
                                      requires_grad=False)

        # Shared by the encoder and decoder passes.
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.encoder_norm = self._make_last_norm(ELastNorm, embed_dim, self.L, norm_layer)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Class-token output -> [mu, log_std] of q(z), each of size z_dim.
        self.head_qz = nn.Sequential(
            nn.Linear(embed_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, 2 * z_dim),
        )

        # Latent sample z -> a single token.
        self.z_embed = nn.Sequential(
            nn.Linear(z_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, embed_dim),
        )

        self.decoder_norm = self._make_last_norm(DLastNorm, embed_dim, self.L, norm_layer)

        # Per-token head: x_dim means plus one extra output channel whose value
        # forward_decoder() does not use (it substitutes a zero log-std).
        self.decoder_pred = nn.Sequential(
            nn.Linear(embed_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * embed_dim, x_dim + 1),
        )

        self.initialize_weights()

    @staticmethod
    def _make_last_norm(kind, dim, num_tokens, norm_layer):
        """Build the layer applied after the transformer stack.

        Args:
            kind: one of ``'LN'``, ``'NO'``, ``'IN'``, ``'LR'``.
            dim: token width (passed to ``norm_layer``).
            num_tokens: sequence length (used as InstanceNorm1d channel count).
            norm_layer: constructor used for ``'LN'``.

        Raises:
            ValueError: for any other ``kind`` (previously the attribute was
                silently left unset).
        """
        if kind == 'LN':
            return norm_layer(dim)
        if kind == 'NO':
            return nn.Identity()
        if kind == 'IN':
            return nn.InstanceNorm1d(num_tokens, affine=True)
        if kind == 'LR':
            return nn.LeakyReLU(0.2)
        raise ValueError(
            f"unknown last-norm option {kind!r}; expected 'LN', 'NO', 'IN' or 'LR'")

    def initialize_weights(self):
        """Fill the fixed positional embedding and (re-)initialise weights.

        ``cls_token`` / ``mask_token`` are drawn from N(0, 0.02^2), and every
        submodule is then passed to :meth:`_init_weights`.
        """
        # NOTE(review): get_2d_sincos_pos_embed is project code; with
        # nos_token=True it appears to return one extra row.  The last row is
        # dropped to match num_patches + 1 -- confirm.
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1], int(self.num_patches ** .5),
            nos_token=True, cls_token=True)[:-1, ...]
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Set Linear to Xavier-uniform with zero bias; set LayerNorm to weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """Split square images (N, C, H, W) into patches of shape (N, h*w, p*p*C)."""
        p = self.patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], self.in_chans, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * self.in_chans))

    def unpatchify(self, x):
        """Inverse of :meth:`patchify`: (N, h*w, p*p*C) -> (N, C, h*p, h*p)."""
        p = self.patch_size
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_chans))
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(shape=(x.shape[0], self.in_chans, h * p, h * p))

    def random_masking(self, x, r_ind=None, Sratio=0., Tratio=1., Lst=None):
        """Split the L tokens of each sample into disjoint context / target sets.

        Args:
            x: tensor of shape (N, L, D); only its shape and device are used.
            r_ind: optional (N, L) token permutation.  When None, a fresh
                ``torch.randperm`` is drawn per sample.
            Sratio: fraction of tokens (of L, or of ``Lst``) in the context set.
            Tratio: fraction of the remaining tokens in the target set, rounded
                up.  The sentinel ``-10.`` with ``Lst`` means "use exactly
                ``Lst`` tokens in total".
            Lst: total context + target count for the sentinel mode.

        Returns:
            ``(mask_s, mask_t, ids_restore, r_ind)``:

            - ``mask_s`` / ``mask_t``: float (N, L) masks, 1 marking
              context / target tokens.
            - ``ids_restore``: ``argsort(r_ind)``.
            - ``r_ind``: the permutation that was used.
        """
        N, L, _ = x.shape

        if r_ind is None:
            r_ind = torch.stack([torch.randperm(L).to(x.device) for _ in range(N)])
        ids_restore = torch.argsort(r_ind, dim=1)

        if (Tratio == -10.) and (Lst is not None):
            len_s = int(Lst * Sratio)
            len_t = Lst - len_s
        else:
            len_s = int(L * Sratio)
            if isinstance(Tratio, torch.Tensor):
                len_t = int(np.ceil((L - len_s) * Tratio.cpu()))
            else:
                len_t = int(np.ceil((L - len_s) * Tratio))

        # First len_s permuted slots -> context, next len_t -> target; then map
        # the masks back to the original token order.
        mask_s = torch.zeros([N, L], device=x.device)
        mask_t = torch.zeros([N, L], device=x.device)
        mask_s[:, :len_s] = 1
        mask_t[:, len_s:len_s + len_t] = 1
        mask_s = torch.gather(mask_s, dim=1, index=ids_restore)
        mask_t = torch.gather(mask_t, dim=1, index=ids_restore)

        return mask_s, mask_t, ids_restore, r_ind

    def forward_encoder(self, bar_imgs, mask_st):
        """Encode the tokens kept by ``mask_st`` into the parameters of q(z).

        Args:
            bar_imgs: input tokens, presumably (N, num_patches, x_dim).
            mask_st: float token mask; 1 keeps the embedding, 0 replaces it
                with ``mask_token``.

        Returns:
            ``(mu, log_std)``, each of shape (N, z_dim).
        """
        x = self.input_embed(bar_imgs)
        keep = mask_st.unsqueeze(-1)
        xst = keep * x + (1. - keep) * self.mask_token

        cls_token = self.cls_token.expand(xst.shape[0], -1, -1)
        cxst = torch.cat([cls_token, xst], dim=1) + self.pos_embed

        for blk in self.blocks:
            cxst = blk(cxst)
        cxst_embed = self.encoder_norm(cxst)

        mu_zqst, lstd_zqst = torch.chunk(self.head_qz(cxst_embed[:, 0, :]), chunks=2, dim=-1)
        return mu_zqst, lstd_zqst

    def forward_decoder(self, zqst, bar_imgs, mask_s):
        """Decode a latent sample plus the context tokens with the shared stack.

        Returns:
            ``(mu_xp, lstd_xp)``:

            - ``mu_xp``: raw (unsquashed) token means, (N, tokens, x_dim).
            - ``lstd_xp``: an all-zero log-std of shape (N, 1, 1).
        """
        latent_zst = self.z_embed(zqst)
        x = self.input_embed(bar_imgs)
        visible = mask_s.unsqueeze(-1)
        xs = visible * x + (1. - visible) * self.mask_token

        x = torch.cat([latent_zst.unsqueeze(1), xs], dim=1) + self.pos_embed

        for blk in self.blocks:
            x = blk(x)
        dec_embed = self.decoder_norm(x)

        # Drop the leading z token; one prediction per input token remains.
        latent_pred_patch = self.decoder_pred(dec_embed[:, 1:, :])
        mu_xp = latent_pred_patch[:, :, :-1]
        # The predicted log-std channel is discarded in favour of zeros.
        lstd_xp = torch.zeros_like(latent_pred_patch[:, -1:, -1:])

        return mu_xp, lstd_xp

    def forward(self, bar_imgs, r_ind=None, Sratio=0., Tratio=1., Lst=None, betaKL=1.):
        """Compute the training loss for one batch of input tokens.

        Returns:
            ``(loss, NLL, KLloss, mu_xp)`` with ``loss = NLL + betaKL * KLloss``:

            - ``NLL`` is averaged over target tokens only, and is 0 (not NaN)
              when there are none.
            - ``KLloss`` is KL(q(z | context + target) || q(z | context)).
            - ``mu_xp`` is the decoder's mean prediction.
        """
        assert Sratio >= 0.
        assert (Tratio > 0) or ((Tratio == -10.) and (Lst is not None) and (r_ind is not None))

        x = bar_imgs
        mask_s, mask_t, _, _ = self.random_masking(x, r_ind=r_ind, Sratio=Sratio,
                                                   Tratio=Tratio, Lst=Lst)

        mu_zqst, lstd_zqst = self.forward_encoder(bar_imgs, mask_st=mask_s + mask_t)
        mu_zqs, lstd_zqs = self.forward_encoder(bar_imgs, mask_st=mask_s)

        # Reparameterised sample z ~ q(z | context + target).
        zqst = mu_zqst + lstd_zqst.exp() * torch.randn_like(mu_zqst)

        mu_xp, lstd_xp = self.forward_decoder(zqst, bar_imgs, mask_s)

        # Clamp avoids 0/0 = NaN on an empty target set; otherwise unchanged.
        NLL1 = (lstd_xp + 0.5 * ((x - mu_xp) / lstd_xp.exp()).pow(2)).mean(-1)
        NLL = (NLL1 * mask_t).sum() / mask_t.sum().clamp(min=1.)

        # Closed-form KL between diagonal Gaussians, averaged over batch and z dims.
        lsmls = lstd_zqst - lstd_zqs
        KLloss = 0.5 * (
            ((mu_zqst - mu_zqs) / lstd_zqs.exp()).pow(2)
            + lsmls.exp().pow(2) - 1. - 2. * lsmls
        ).mean()
        loss = NLL + betaKL * KLloss

        return loss, NLL, KLloss, mu_xp
