from functools import partial

import numpy as np
from tqdm import tqdm
import scipy.stats as stats
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import copy
from timm.models.vision_transformer import Block
from models.vision_transformer_enco import Block as encoViTBlock
from models.vision_transformer import Block as decoViTBlock

import os, sys

sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..'))
from models.diffloss import DiffLoss
from models.difflossv2 import DiffLoss as DiffLossv2
from models.diffloss_taylorSeer import DiffLoss as DiffLossTaylorSeer

from models.utils import Text_Embedding, Attention_Block
from calflops import calculate_flops
from calflops.utils import flops_to_string, macs_to_string, params_to_string


def cumulate_flops(block, flops_dic, **kwargs):
    """Profile one call of ``block`` with calflops and add the cost to ``flops_dic``.

    Args:
        block: the module handed to ``calculate_flops`` as ``model``.
        flops_dic: dict holding the running counters; ``'flops'`` and
            ``'macs'`` are incremented, ``'params'`` is overwritten with the
            value reported for this block.
        **kwargs: keyword arguments forwarded to the profiled call.
    """
    # Profile on deep copies so the profiling run works on its own objects
    # and the caller's arguments are left as they were.
    profiling_inputs = {name: copy.deepcopy(arg) for name, arg in kwargs.items()}

    flops, macs, params = calculate_flops(model=block,
                                          kwargs=profiling_inputs,
                                          print_results=False,
                                          output_as_string=False)
    # Drop the copies before touching the counters.
    del profiling_inputs

    flops_dic['flops'] += flops
    flops_dic['macs'] += macs
    # Not accumulated: the entry always reflects the most recently profiled block.
    flops_dic['params'] = params


def print_flops(current):
    """Print the accumulated FLOPs / MACs / parameter counters, then zero them.

    Args:
        current: dict with the ``'flops'``, ``'macs'`` and ``'params'``
            counters; all three are reset to 0 after printing.
    """
    flops_str = flops_to_string(current['flops'], units=None, precision=2)
    macs_str = macs_to_string(current['macs'], units=None, precision=2)
    params_str = params_to_string(current['params'], units=None, precision=2)
    # NOTE(review): the "Bert(...)" label looks like a leftover from an
    # example snippet; it is kept unchanged because it is runtime output.
    print("Bert(hfl/chinese-roberta-wwm-ext) FLOPs:%s   MACs:%s   Params:%s \n" % (flops_str, macs_str, params_str))
    # Reset the counters
    for counter in ('flops', 'macs', 'params'):
        current[counter] = 0

def mask_by_order(mask_len, order, bsz, seq_len):
    """Build a boolean mask that marks the first ``mask_len`` positions of each order.

    Args:
        mask_len: number of leading entries of every row of ``order`` to mark;
            a Python number or a single-element tensor.
        order: integer tensor whose rows list token positions.
        bsz: number of rows of the returned mask.
        seq_len: number of columns of the returned mask.

    Returns:
        Bool tensor of shape ``(bsz, seq_len)``, True at the selected
        positions, allocated on the same device as ``order``.
    """
    # Allocate next to ``order`` instead of hard-coding CUDA so the helper
    # also works on CPU or on a non-default GPU.
    device = order.device
    num_masked = int(mask_len)
    masking = torch.zeros(bsz, seq_len, device=device)
    masking = torch.scatter(masking, dim=-1, index=order[:, :num_masked],
                            src=torch.ones(bsz, seq_len, device=device))
    return masking.bool()

def global_force_fresh(cache_dic, current):
    """Tell whether the current step must be computed in full.

    A step is "force fresh" when it comes before ``cache_dic['start_step']``
    or when its distance from ``start_step`` is a multiple of
    ``cache_dic['fresh_t']``.
    """
    step = current['step']
    start_step = cache_dic['start_step']
    if step < start_step:
        return True
    return (step - start_step) % cache_dic['fresh_t'] == 0

class FLUID(nn.Module):
    """
        Finetune Masked Autoencoder with VisionTransformer backbone
    """

    def __init__(self, img_size=256, vae_stride=16, patch_size=1, text_depth=6,
                 encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 cross_embed_dim=1024, cross_depth=16, cross_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 caption_channels=4096, max_length=512,
                 interpolate_offset=0.1,
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 text_drop_prob=0.1,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 token_cache=False,
                 cfg_cache=False,
                 diff_type='origin',
                 ):
        """Build the text embedding, encoder, decoder, cross-attention stack and diffusion head.

        ``token_cache`` and ``cfg_cache`` are accepted but not used here; the
        caching switches are read from ``args`` in ``sample_tokens``.
        ``diff_type`` selects the diffusion head implementation (``'cache'``,
        ``'taylor'`` or ``'origin'``); any other value raises
        ``NotImplementedError``.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE / patchify geometry and plain configuration values
        self.vae_embed_dim = vae_embed_dim
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        tokens_per_side = img_size // vae_stride // patch_size
        self.seq_h = tokens_per_side
        self.seq_w = tokens_per_side
        # Fixed at 256 rather than seq_h * seq_w (start from the 256px checkpoint);
        # forward / sample_tokens overwrite it per call.
        self.seq_len = 256
        self.token_embed_dim = vae_embed_dim * patch_size ** 2
        self.grad_checkpointing = grad_checkpointing
        self.caption_channels = caption_channels
        self.max_length = max_length
        self.cross_embed_dim = cross_embed_dim
        self.interpolate_offset = interpolate_offset
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.diffloss_d = diffloss_d

        # Keyword arguments shared by every transformer block built below.
        block_kwargs = dict(qkv_bias=True, norm_layer=norm_layer,
                            proj_drop=proj_dropout, attn_drop=attn_dropout)

        # --------------------------------------------------------------------------
        # Cross-attention (text conditioning) stage
        self.last_embed = nn.Linear(decoder_embed_dim, cross_embed_dim, bias=True)
        self.last_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, cross_embed_dim))

        self.text_drop_prob = text_drop_prob
        # Stand-in text embedding used for the unconditional branch of CFG
        self.fake_latent = nn.Parameter(torch.zeros(1, 1, cross_embed_dim))

        self.text_emb = Text_Embedding(caption_channels, encoder_embed_dim,
                                       text_num_heads=encoder_num_heads,
                                       text_depth=text_depth,
                                       mlp_ratio=mlp_ratio,
                                       norm_layer=norm_layer,
                                       proj_dropout=proj_dropout,
                                       attn_dropout=attn_dropout,
                                       grad_checkpointing=grad_checkpointing)

        cross_blocks = []
        for _ in range(cross_depth):
            cross_blocks.append(Attention_Block(dim=cross_embed_dim, num_heads=cross_num_heads,
                                                mlp_ratio=mlp_ratio, is_cross=True, **block_kwargs))
        self.cross_attn_blocks = nn.ModuleList(cross_blocks)
        self.cross_norm = norm_layer(cross_embed_dim)

        # --------------------------------------------------------------------------
        # Masking ratio: left half of a Gaussian centred at 1.0 (std 0.25),
        # truncated to [mask_ratio_min, 1.0]
        lower_bound = (mask_ratio_min - 1.0) / 0.25
        self.mask_ratio_generator = stats.truncnorm(lower_bound, 0, loc=1.0, scale=0.25)

        # --------------------------------------------------------------------------
        # Encoder
        self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
        self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, encoder_embed_dim))

        encoder_blocks = []
        for _ in range(encoder_depth):
            encoder_blocks.append(encoViTBlock(encoder_embed_dim, encoder_num_heads, mlp_ratio, **block_kwargs))
        self.encoder_blocks = nn.ModuleList(encoder_blocks)
        self.encoder_norm = norm_layer(encoder_embed_dim)

        # --------------------------------------------------------------------------
        # Decoder
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        decoder_blocks = []
        for _ in range(decoder_depth):
            decoder_blocks.append(decoViTBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio, **block_kwargs))
        self.decoder_blocks = nn.ModuleList(decoder_blocks)
        self.decoder_norm = norm_layer(decoder_embed_dim)

        # Runs before the diffusion head exists, so the head is not touched by it.
        self.initialize_weights()

        # --------------------------------------------------------------------------
        # Diffusion head
        if diff_type == 'cache':
            diffloss_cls = DiffLossv2
        elif diff_type == 'taylor':
            diffloss_cls = DiffLossTaylorSeer
        elif diff_type == 'origin':
            diffloss_cls = DiffLoss
        else:
            raise NotImplementedError
        self.diffloss = diffloss_cls(
            target_channels=self.token_embed_dim,
            z_channels=decoder_embed_dim,
            width=diffloss_w,
            depth=diffloss_d,
            num_sampling_steps=num_sampling_steps,
            grad_checkpointing=grad_checkpointing
        )

        self.diffusion_batch_mul = diffusion_batch_mul

    def initialize_weights(self):
        # parameters
        torch.nn.init.normal_(self.fake_latent, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.last_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, x):
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x  # [n, l, d]

    def unpatchify(self, x):
        bsz = x.shape[0]
        p = self.patch_size
        c = self.vae_embed_dim
        h_, w_ = self.seq_h, self.seq_w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x  # [n, c, h, w]

    def sample_orders(self, bsz):
        # generate a batch of random generation orders
        orders = []
        for _ in range(bsz):
            order = np.array(list(range(self.seq_len)))
            np.random.shuffle(order)
            orders.append(order)
        orders = torch.Tensor(np.array(orders)).cuda().long()
        return orders

    def random_masking(self, x, orders):
        # generate token mask
        bsz, seq_len, embed_dim = x.shape
        mask_rate = self.mask_ratio_generator.rvs(1)[0]
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask = torch.zeros(bsz, seq_len, device=x.device)
        mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
                             src=torch.ones(bsz, seq_len, device=x.device))
        return mask

    def interpolate_pos_encoding(self, x, pos_embed):
        previous_dtype = x.dtype
        npatch = x.shape[1]
        N = pos_embed.shape[1]
        if npatch == N and self.seq_w == self.seq_h:
            return pos_embed
        pos_embed = pos_embed.float()
        dim = x.shape[-1]
        w0 = self.seq_w
        h0 = self.seq_h
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            **kwargs,
        )
        assert (w0, h0) == pos_embed.shape[-2:]
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return pos_embed.to(previous_dtype)


    def obtain_rela_mask(self, current, cache_dic, bsz, x):
        # bsz ,seq_len, dim  相对seq_len， 上一step预测的token
        if cache_dic['prev_pred'] is None:
            prev_pred = torch.zeros_like(current['mask']).to(x.device).bool()
        else:
            prev_pred = cache_dic['prev_pred']
        # bsz, seq_len, dim  相对于seq_len, 截至本step已经预测的token
        cur_pred_mask = ~current['mask']
        current["prev_mtp_rela"] = torch.masked_select(prev_pred, cur_pred_mask).reshape(bsz, -1)


    def obtain_update_mask(self, mask, prev_pred_mask, mask_to_pred, current):
        '''
        1表示需要更新的 token，0表示使用缓存的 token
        '''
        device = mask.device
        prev_pred_mask = torch.zeros_like(mask_to_pred, device=device).bool() if prev_pred_mask is None else prev_pred_mask.bool()
        unpredicted_mask = mask.bool()
        predicted_mask = torch.logical_not(mask).bool()  # 确保为布尔类型
        to_pred_mask = mask_to_pred.bool()
        none_cache_mask = torch.ones_like(mask, device=device).bool()

        current["update_mask"] = none_cache_mask
        current['to_pred_mask'] = to_pred_mask
        current['prev_pred_mask'] = prev_pred_mask
        current['predicted_mask'] = predicted_mask


    def forward_fluid_encoder(self, x, mask, current, cache_dic):
        """Project the tokens, drop the masked ones and run the encoder blocks.

        Args:
            x: token latents; projected by ``z_proj`` and expected to be 3-D
                (batch, sequence, channels) after projection.
            mask: per-token mask; positions where it is zero are kept.
            current: per-step state dict. Reads ``'cal_flops'``,
                ``'token_cache'``, ``'cfg_cache'`` and ``'is_force_fresh'``;
                writes ``'enco_layer_idx'`` (and ``'prev_mtp_rela'`` via
                ``obtain_rela_mask``).
            cache_dic: cache state handed to every encoder block.

        Returns:
            The output of ``encoder_norm`` for the kept tokens.
        """
        if current['cal_flops']:
            cumulate_flops(self.z_proj, flops_dic=current, input=x)
        x = self.z_proj(x)
        bsz, _, embed_dim = x.shape
        encoder_pos_embed = self.interpolate_pos_encoding(x, self.encoder_pos_embed_learned)
        x = x + encoder_pos_embed
        x = self.z_proj_ln(x)
        # Keep only the unmasked positions; the reshape assumes every row of the
        # batch keeps the same number of tokens.
        x = x[torch.logical_not(mask).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
        _, cur_pred_len, _ = x.shape
        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # NOTE(review): blocks are called with ``x`` only here, while the
            # branch below passes ``current`` and ``cache_dic`` as well —
            # confirm the block signature supports both call forms.
            for block in self.encoder_blocks:
                x = checkpoint(block, x)
        else:
            if current['token_cache'] and not current['is_force_fresh']:
                # Token cache: feed the blocks only the tokens flagged in
                # ``prev_mtp_rela`` (those predicted in the previous step).
                self.obtain_rela_mask(current, cache_dic, bsz, x)
                x = torch.masked_select(x, current["prev_mtp_rela"].unsqueeze(-1).expand(-1, -1, embed_dim).bool()).reshape(bsz, -1, embed_dim)
            if current['cfg_cache'] and not current['is_force_fresh']:
                # CFG cache: run only the first half of the batch.
                x = x[:int(bsz / 2)].detach().clone()
            for i, block in enumerate(self.encoder_blocks):
                current['enco_layer_idx'] = i
                if current['cal_flops']:
                    cumulate_flops(block, flops_dic=current, x=x, current=current, cache_dic=cache_dic)
                x = block(x, current, cache_dic)
            if current['cfg_cache'] and not current['is_force_fresh']:
                # Restore the full batch size by duplicating the computed half.
                x = torch.cat([x, x], dim = 0)
            if current['token_cache'] and not current['is_force_fresh']:
                # Scatter the recomputed tokens back to their slots; every other
                # slot of the kept-token sequence stays zero.
                new_x_full = torch.zeros(bsz, cur_pred_len, embed_dim).to(x)
                new_x_full[current["prev_mtp_rela"].bool()] = x.view(-1, x.size(-1))
                x = new_x_full
        x = self.encoder_norm(x)
        return x

    def forward_fluid_decoder(self, x, mask, current, cache_dic):
        """Pad the encoder output with mask tokens and run the decoder blocks.

        Args:
            x: encoder output for the unmasked tokens.
            mask: per-token mask over the full sequence; zero marks positions
                that receive encoder outputs, the rest receive ``mask_token``.
            current: per-step state dict. Reads ``'cal_flops'``,
                ``'token_cache'``, ``'cfg_cache'`` and ``'is_force_fresh'``;
                writes the mask entries, ``'to_pred_len'``, ``'prev_pred_len'``,
                ``'layer_idx'`` and, on the cached paths, ``'original_x'`` and
                the ``'orig_*'`` mask backups.
            cache_dic: cache state. Reads ``'prev_pred'`` and ``'mask_to_pred'``;
                reads and writes ``cache_dic['cache']`` (``'de_ou'`` and the
                per-layer ``'diff'`` entry).

        Returns:
            The output of ``decoder_norm`` over the full sequence.
        """
        if current['cal_flops']:
            cumulate_flops(self.decoder_embed, flops_dic=current, input=x)
        x = self.decoder_embed(x)
        # pad mask tokens
        mask_tokens = self.mask_token.repeat(mask.shape[0], mask.shape[1], 1).to(x.dtype)
        x_after_pad = mask_tokens.clone()
        x_after_pad[torch.logical_not(mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        # decoder position embedding
        decoder_pos_embed = self.interpolate_pos_encoding(x_after_pad, self.decoder_pos_embed_learned)
        x = x_after_pad + decoder_pos_embed

        self.obtain_update_mask(mask, cache_dic['prev_pred'], cache_dic['mask_to_pred'], current)
        # Lengths are taken from the first batch row only.
        current['to_pred_len'] = torch.sum(cache_dic['mask_to_pred'][0])
        current['prev_pred_len'] = 0 if cache_dic['prev_pred'] is None else torch.sum(cache_dic['prev_pred'][0])
        if current['token_cache'] and not current['is_force_fresh']:
            B, N, C = x.shape
            # Refresh the cached decoder input only at the positions predicted
            # in the previous step; every other position keeps its cached value.
            cache_dic['cache']['de_ou'] = torch.where(current['prev_pred_mask'].unsqueeze(-1).expand(-1, -1, C), x,
                                                      cache_dic['cache']['de_ou'])
            x = copy.deepcopy(cache_dic['cache']['de_ou'])
            current['original_x'] = x
            # ``update_mask`` is all ones as built by obtain_update_mask, so this
            # selection currently keeps every position.
            x = x[current["update_mask"].nonzero(as_tuple=True)].reshape(x.shape[0], -1, x.shape[2])
        else:
            cache_dic['cache']['de_ou'] = x

        B, N, C = x.size()
        x_origi = x
        if current['cfg_cache'] and not current['is_force_fresh']:
            # CFG cache: run only the first half of the batch; back up the
            # full-batch masks and replace them with their first-half slices.
            x = x[:int(B / 2)].detach().clone()
            current['orig_update_mask'] = current['update_mask']
            current['orig_to_pred_mask'] = current['to_pred_mask']
            current['orig_prev_pred_mask'] = current['prev_pred_mask']
            current['orig_predicted_mask'] = current['predicted_mask']

            current['update_mask'] = current['update_mask'][:int(B / 2)].detach().clone()
            current['to_pred_mask'] = current['to_pred_mask'][:int(B / 2)].detach().clone()
            current['prev_pred_mask'] = current['prev_pred_mask'][:int(B / 2)].detach().clone()
            current['predicted_mask'] = current['predicted_mask'][:int(B / 2)].detach().clone()

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # NOTE(review): this branch calls the blocks with ``x`` only and never
            # sets current['layer_idx'], which the code below reads — confirm it
            # is still meant to be usable.
            for block in self.decoder_blocks:
                x = checkpoint(block, x)
        else:
            for i, block in enumerate(self.decoder_blocks):
                current['layer_idx'] = i
                if current['cal_flops']:
                    cumulate_flops(block, flops_dic=current, x=x, current=current, cache_dic=cache_dic)
                x = block(x, current, cache_dic)


        if current['cfg_cache'] and not current['is_force_fresh']:
            # Rebuild the second half of the batch as "first half + cached diff",
            # using the diff stored under the last visited layer index.
            if current['token_cache']:
                diff = cache_dic['cache'][current['layer_idx']]['diff'][current['update_mask'].nonzero(as_tuple=True)].reshape(int(B / 2), -1, C)
            else:
                diff = cache_dic['cache'][current['layer_idx']]['diff']
            x_origi[:int(B / 2)] = x
            x_origi[int(B / 2):] = x + diff
            x = x_origi
            # Put the full-batch masks back.
            current['update_mask'] = current['orig_update_mask']
            current['to_pred_mask'] = current['orig_to_pred_mask']
            current['prev_pred_mask'] = current['orig_prev_pred_mask']
            current['predicted_mask'] = current['orig_predicted_mask']
        else:
            # Record "second half - first half" of the batch for later reuse.
            # NOTE(review): this also runs when the batch was not duplicated for
            # CFG, and the two slices differ in size for odd B — confirm intent.
            cache_dic['cache'][current['layer_idx']]['diff'] = x[int(B / 2):] - x[:int(B / 2)]


        if current['token_cache'] and not current['is_force_fresh']:
            # Write the recomputed tokens back into the cached full sequence.
            current['original_x'][(current['update_mask']).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
            x = current['original_x']
        x = self.decoder_norm(x)
        return x

    def forward_fluid_text_decoder(self, x, text_embeddings, mask_to_pred, current):
        if current['cal_flops']:
            cumulate_flops(self.last_embed, flops_dic=current, input=x)
        x = self.last_embed(x)
        # x = x + self.last_pos_embed_learned
        text_pos_embed = self.interpolate_pos_encoding(x, self.last_pos_embed_learned)
        x = x + text_pos_embed
        bsz, _, dim = x.shape

        x = x[mask_to_pred.nonzero(as_tuple=True)].reshape(bsz, -1, dim)

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.cross_attn_blocks:
                x = checkpoint(block, x, text_embeddings)
        else:
            for block in self.cross_attn_blocks:
                if current['cal_flops']:
                    cumulate_flops(block, flops_dic=current, q=x, k_v=text_embeddings)
                x = block(x, text_embeddings)
        x = self.cross_norm(x)

        return x

    def forward_loss(self, z, target, mask):
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)
        loss = self.diffloss(z=z, target=target, mask=mask)
        return loss

    def forward(self, imgs, texts, height=512, width=512):
        """Training forward pass: mask tokens, encode, decode, condition on text, return the loss.

        The encoder, decoder and text-decoder methods require the ``current`` /
        ``cache_dic`` state (and ``mask_to_pred``) that ``sample_tokens``
        builds; this method previously called them without those arguments and
        therefore always raised ``TypeError``. It now builds a context with all
        caching switched off, equivalent to the first, fully recomputed
        sampling step.

        NOTE(review): the project-local encoder/decoder blocks receive this
        context; whether their cache writes are compatible with autograd is not
        visible from this file — confirm before training with this path.

        Args:
            imgs: latent images of shape ``[n, c, h, w]``.
            texts: input accepted by ``self.text_emb``.
            height, width: image size in pixels; both must be divisible by
                ``vae_stride * patch_size``.

        Returns:
            The value returned by ``self.diffloss`` for the masked tokens.
        """
        assert height % (self.vae_stride * self.patch_size) == 0 and width % (self.vae_stride * self.patch_size) == 0
        self.seq_h, self.seq_w = height // self.vae_stride // self.patch_size, width // self.vae_stride // self.patch_size
        self.seq_len = self.seq_h * self.seq_w
        # patchify and mask (drop) tokens
        x = self.patchify(imgs)
        gt_latents = x.clone().detach()
        bsz = x.size(0)
        orders = self.sample_orders(bsz=bsz)

        mask = self.random_masking(x, orders)

        # Build the per-step state the sub-modules expect, with every cache off.
        # ``cfg_cache=True`` is passed to get_cache_dic only so that it does not
        # double the batch size: there is no CFG duplication during training.
        depth = max(self.encoder_depth, len(self.decoder_blocks))
        cache_dic, current = self.get_cache_dic(bsz, depth, self.encoder_num_heads,
                                                self.encoder_embed_dim, x.device, True)
        current['depth'] = self.encoder_depth
        current['token_cache'] = False
        current['cfg_cache'] = False
        current['cal_flops'] = False
        current['flops'], current['macs'], current['params'] = 0, 0, 0
        current['step'] = 0
        current['is_force_fresh'] = True
        current['use_cache'] = False
        current['mask'] = mask.bool()
        cache_dic['start_step'] = 0
        cache_dic['fresh_t'] = 1
        cache_dic['prev_pred'] = None
        # Keep every position in the text decoder so that z lines up with
        # gt_latents and mask in forward_loss.
        mask_to_pred = torch.ones_like(mask, dtype=torch.bool)
        cache_dic['mask_to_pred'] = mask_to_pred

        # mae encoder
        x = self.forward_fluid_encoder(x, mask, current, cache_dic)

        # mae decoder
        x = self.forward_fluid_decoder(x, mask, current, cache_dic)

        # Tokenize and encode the input text
        text_embeddings = self.text_emb(texts)
        z = self.forward_fluid_text_decoder(x, text_embeddings, mask_to_pred, current)

        # diffloss
        loss = self.forward_loss(z=z, target=gt_latents, mask=mask)

        return loss

    def get_cache_dic(self, bsz, depth, num_heads, embed_dim, device, cfg_cache):
        assert embed_dim % num_heads == 0, 'dim should be divisible by num_heads'
        head_dim = embed_dim // num_heads
        cache_dic = {}
        cache = {}
        enco_cache = {}

        if device.type != 'cpu':
            dtype = torch.float16
        else:
            dtype = torch.float32
        if not cfg_cache:
            bsz = int(bsz * 2)
        for j in range(depth):
            cache[j] = {}
            enco_cache[j] = {}
            enco_cache[j]['k'] = torch.zeros(
                (
                    bsz,
                    num_heads,
                    1024,
                    head_dim,
                ), dtype=dtype, device=device)

            enco_cache[j]['v'] = torch.zeros(
                (
                    bsz,
                    num_heads,
                    1024,
                    head_dim,
                ), dtype=dtype, device=device)
            enco_cache[j]['cur_kv_len'] = 0

        cache_dic['cache'] = cache
        cache_dic['enco_cache'] = enco_cache
        current = {}
        return cache_dic, current

    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", texts=None, temperature=1.0, height=512,
                      width=512, progress=False, args=None):
        """Iteratively generate token latents and return them un-patchified.

        Starting from a fully masked sequence, each of the ``num_iter`` steps
        runs encoder, decoder and text decoder, samples latents for the tokens
        scheduled for this step and writes them into the token buffer.

        Args:
            bsz: number of samples to generate.
            num_iter: number of generation steps.
            cfg: classifier-free guidance scale; any value other than 1.0
                duplicates the batch with ``fake_latent`` as text embedding.
            cfg_schedule: ``"linear"`` or ``"constant"``; anything else raises
                ``NotImplementedError``.
            texts: input for ``self.text_emb``; ``None`` uses ``fake_latent``.
            temperature: forwarded to the diffusion sampler.
            height, width: image size in pixels; both must be divisible by
                ``vae_stride * patch_size``.
            progress: wrap the step loop in ``tqdm``.
            args: object providing ``cfg_cache``, ``token_cache``,
                ``cal_flops``, ``start_step``, ``fresh_t`` and ``diff_type``.
                Required despite the ``None`` default, since its attributes are
                read unconditionally.

        Returns:
            The result of ``self.unpatchify`` on the generated tokens.

        Tensors are created with ``.cuda()``, so a CUDA device is required.
        """

        device = torch.device("cuda")
        cache_dic, current = self.get_cache_dic(bsz,  self.encoder_depth, self.encoder_num_heads, self.encoder_embed_dim, device, args.cfg_cache)
        current['depth'] = self.encoder_depth
        current['token_cache'] = args.token_cache
        current['cfg_cache'] = args.cfg_cache
        current['cal_flops'] = args.cal_flops
        cache_dic['start_step'] = args.start_step
        cache_dic['fresh_t'] = args.fresh_t
        cache_dic['prev_pred'] = None
        cache_dic['mask_to_pred'] = None
        current['flops'], current['macs'], current['params'] = 0, 0, 0

        assert height % (self.vae_stride * self.patch_size) == 0 and width % (self.vae_stride * self.patch_size) == 0
        self.seq_h, self.seq_w = height // self.vae_stride // self.patch_size, width // self.vae_stride // self.patch_size
        self.seq_len = self.seq_h * self.seq_w

        # init and sample generation orders
        mask = torch.ones(bsz, self.seq_len).cuda()
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
        orders = self.sample_orders(bsz)

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)
        if texts is not None:
            # The text embedding does not change between steps; compute it once.
            text_embeddings_temp = self.text_emb(texts)
        # generate latents
        for step in indices:
            current['step'] = step
            cur_tokens = tokens.clone()
            # text_embedding and CFG
            if texts is None:
                text_embeddings = self.fake_latent.expand(bsz, self.max_length, self.cross_embed_dim)
            else:
                text_embeddings = text_embeddings_temp
            if not cfg == 1.0:
                # Duplicate the batch: the second half is conditioned on the fake text embedding.
                tokens = torch.cat([tokens, tokens], dim=0)
                text_embeddings = torch.cat(
                    [text_embeddings, self.fake_latent.expand(bsz, self.max_length, self.cross_embed_dim)], dim=0)
                mask = torch.cat([mask, mask], dim=0)

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()

            # masks out at least one for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                     torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))

            # get masking for next iteration and locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)

            if step >= num_iter - 1:
                # Last step: predict everything that is still masked.
                mask_to_pred = mask[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            if not cfg == 1.0:
                mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)

            cache_dic['mask_to_pred'] = mask_to_pred
            current['is_force_fresh'] = global_force_fresh(cache_dic, current)
            current['use_cache'] = not current['is_force_fresh']
            current['mask'] = mask.bool()

            # mae encoder
            x = self.forward_fluid_encoder(tokens, mask, current, cache_dic)

            # mae decoder
            x = self.forward_fluid_decoder(x, mask, current, cache_dic)

            z_pred = self.forward_fluid_text_decoder(x, text_embeddings, mask_to_pred, current)
            z_pred = z_pred.reshape(-1, z_pred.size(2))

            # Remember what was predicted in this step and advance the mask.
            cache_dic['prev_pred'] = mask_to_pred
            mask = mask_next

            # cfg schedule follow Muse
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError

            cache_type = 'default'
            if args.diff_type == 'cache':
                sampled_token_latent = self.diffloss.sample(z_pred, temperature=temperature, cfg=cfg_iter,
                                                            diffloss_d=self.diffloss_d, cache_type=cache_type, step=step, cache_threshold=7, device="cuda")  # 1:[256,16] 2:[512,16]...
                if current['cal_flops']:
                    # NOTE(review): this profiles a call of ``self.diffloss`` itself
                    # with the sampler's keyword arguments — confirm that matches
                    # the module's forward signature.
                    cumulate_flops(self.diffloss, flops_dic=current, z=z_pred, temperature=temperature, cfg=cfg_iter,
                                        diffloss_d=self.diffloss_d, cache_type=cache_type, step=step, cache_threshold=7,
                                        device="cuda")  # 1:[256,16] 2:[512,16]...

            elif args.diff_type == 'taylor':
                sampled_token_latent = self.diffloss.sample(z_pred, temperature, cfg_iter)
                if current['cal_flops']:
                    cumulate_flops(self.diffloss, flops_dic=current, z=z_pred, temperature=temperature, cfg=cfg_iter)
            else:
                sampled_token_latent = self.diffloss.sample(z_pred, temperature, cfg_iter)
                if current['cal_flops']:
                    cumulate_flops(self.diffloss, flops_dic=current, z=z_pred, temperature=temperature, cfg=cfg_iter)

            if not cfg == 1.0:
                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples
                mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)

            # Write the sampled latents into the (un-duplicated) token buffer.
            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            tokens = cur_tokens.clone()

        # unpatchify
        if current['cal_flops']:
            print_flops(current)
        tokens = self.unpatchify(tokens)
        return tokens


def fluid_base(**kwargs):
    """Build the base FLUID configuration (width 768, depth 12, 12 heads per stage).

    Extra keyword arguments are forwarded to ``FLUID``.
    """
    width, depth, heads = 768, 12, 12
    return FLUID(
        encoder_embed_dim=width, encoder_depth=depth, encoder_num_heads=heads,
        decoder_embed_dim=width, decoder_depth=depth, decoder_num_heads=heads,
        cross_embed_dim=width, cross_depth=depth, cross_num_heads=heads,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)


def fluid_large(**kwargs):
    """Build the large FLUID configuration (width 1024, depth 16, 16 heads per stage).

    Extra keyword arguments are forwarded to ``FLUID``.
    """
    width, depth, heads = 1024, 16, 16
    return FLUID(
        encoder_embed_dim=width, encoder_depth=depth, encoder_num_heads=heads,
        decoder_embed_dim=width, decoder_depth=depth, decoder_num_heads=heads,
        cross_embed_dim=width, cross_depth=depth, cross_num_heads=heads,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)


def fluid_huge(**kwargs):
    """Build the huge FLUID configuration (width 1280, depth 20, 16 heads per stage).

    Extra keyword arguments are forwarded to ``FLUID``.
    """
    width, depth, heads = 1280, 20, 16
    return FLUID(
        encoder_embed_dim=width, encoder_depth=depth, encoder_num_heads=heads,
        decoder_embed_dim=width, decoder_depth=depth, decoder_num_heads=heads,
        cross_embed_dim=width, cross_depth=depth, cross_num_heads=heads,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)


# Manual smoke test: builds the large model, embeds a few sentences with T5 and
# runs one forward pass on random latents. Needs a CUDA device and the local
# checkpoints referenced by the hard-coded paths below.
if __name__ == "__main__":
    mt5_cache_dir = '/data/xianfeng/code/model/google/flan-t5-xxl'
    mt5_model_name = mt5_cache_dir
    patch_size = 2
    vae_embed_dim = 16
    vae_stride = 8
    vae_temperal_stride = 4
    diffloss_d = 8
    diffloss_w = 1280

    model = fluid_large(vae_embed_dim=vae_embed_dim, vae_stride=vae_stride, patch_size=patch_size,
                        diffloss_d=diffloss_d, diffloss_w=diffloss_w)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of trainable parameters: {}M".format(n_params / 1e6))
    model = model.cuda()

    # generate pseudo latent data (batch_size, channels, height, width)
    fake_image = torch.randn(4, 16, 32, 32).cuda()  # batch_size=4, 16 latent channels, 32x32 latent grid

    # generate pseudo text data
    fake_texts = ["This is a test sentence.", "Another example sentence.", "Text embedding test.",
                  "Vision and text fusion."]

    # T5 Embedding
    import os, sys

    # NOTE(review): "__file__" is a string literal here, so dirname() returns ''
    # and these entries resolve relative to the current working directory.
    sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
    sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..'))

    # define the vae and mar model
    from diffusers.models import AutoencoderKL

    vae_path = '/data/xianfeng/code/model/stabilityai/stable-diffusion-3.5-large'
    # The VAE is loaded and frozen but not used below: the model is fed
    # ``fake_image`` directly as latents.
    vae = AutoencoderKL.from_pretrained(os.path.join(vae_path, "vae")).cuda().eval()
    for param in vae.parameters():
        param.requires_grad = False

    from models.utils import T5_Embedding

    max_length = 512
    t5_infer = T5_Embedding(mt5_model_name, mt5_cache_dir, max_length).cuda()

    import time

    start_time = time.time()

    text_emb = t5_infer(fake_texts)

    x = fake_image
    # To feed real images instead, encode them with the VAE and normalise the
    # latents (kept for reference):
    # posterior = vae.encode(fake_image)
    # # normalize the std of latent to be 1. Change it if you use a different tokenizer
    # if vae.config.shift_factor is not None:
    #     x = (posterior.sample() - vae.config.shift_factor) * vae.config.scaling_factor
    # else:
    #     x = posterior.sample().mul_(vae.config.scaling_factor)

    # check
    with torch.cuda.amp.autocast():
        loss = model(x, text_emb, height=256, width=256)

    end_time = time.time()
    print(f'cost time:{end_time - start_time}')
    import pdb

    # Drops into the debugger before the loss is printed.
    pdb.set_trace()
    print(f"Loss: {loss.item()}")