from functools import partial

import numpy as np
from tqdm import tqdm
import scipy.stats as stats
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import copy
from timm.models.vision_transformer import Block
from models.vision_transformer_enco import Block as encoViTBlock
from models.vision_transformer import Block as decoViTBlock

import os, sys

sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..'))
from models.diffloss_GtR import DiffLoss as DiffLossGtR

from models.utils import Text_Embedding, Attention_Block
from calflops import calculate_flops
from calflops.utils import flops_to_string, macs_to_string, params_to_string
from models.sampler_util import *

def cumulate_flops(block, flops_dic, **kwargs):
    """Profile one forward call of ``block`` with calflops and record the result.

    Args:
        block: Module handed to ``calculate_flops`` as ``model``.
        flops_dic: Mutable mapping with ``'flops'``, ``'macs'`` and
            ``'params'`` entries; it is updated in place.
        **kwargs: Keyword arguments describing the forward call to profile.

    Note:
        ``'flops'`` and ``'macs'`` are accumulated, whereas ``'params'`` is
        overwritten with the value reported for the most recent ``block``.
    """
    # Profile on deep copies so the measurement run works on its own objects
    # rather than on the caller's tensors / cache dictionaries.
    isolated_kwargs = {name: copy.deepcopy(arg) for name, arg in kwargs.items()}

    flops, macs, params = calculate_flops(
        model=block,
        kwargs=isolated_kwargs,
        print_results=False,
        output_as_string=False,
    )

    # Release the copies before touching the counters.
    del isolated_kwargs

    flops_dic['flops'] += flops
    flops_dic['macs'] += macs
    flops_dic['params'] = params


def print_flops(current):
    """Print the accumulated FLOPs / MACs / parameter counters and zero them.

    Args:
        current: Mapping holding the ``'flops'``, ``'macs'`` and ``'params'``
            counters. All three are reset to ``0`` after printing.
    """
    flops_text = flops_to_string(current['flops'], units=None, precision=2)
    macs_text = macs_to_string(current['macs'], units=None, precision=2)
    params_text = params_to_string(current['params'], units=None, precision=2)

    # NOTE(review): the "Bert(...)" label does not match the model defined in
    # this file; it looks carried over from an example -- confirm before
    # relying on this log line.
    print("Bert(hfl/chinese-roberta-wwm-ext) FLOPs:%s   MACs:%s   Params:%s \n" % (flops_text, macs_text, params_text))

    # Reset the counters
    for counter in ('flops', 'macs', 'params'):
        current[counter] = 0

def mask_by_order(mask_len, order, bsz, seq_len):
    """Mark the first ``mask_len`` positions of every generation order.

    Args:
        mask_len: How many leading entries of each row of ``order`` to mark.
            Accepts a Python number or a single-element tensor; the value is
            truncated to an integer.
        order: Integer (int64) index tensor with one row of token positions
            per sample.
        bsz: Number of rows of the returned mask.
        seq_len: Number of columns of the returned mask.

    Returns:
        Bool tensor of shape ``(bsz, seq_len)`` living on ``order.device``;
        ``True`` at the positions listed in ``order[:, :mask_len]``.
    """
    keep = int(mask_len)
    # Allocate next to ``order`` instead of assuming CUDA, so the helper works
    # wherever the caller keeps its index tensor.
    masking = torch.zeros(bsz, seq_len, device=order.device)
    masking.scatter_(-1, order[:, :keep], 1.0)
    return masking.bool()

def global_force_fresh(cache_dic, current):
    """Tell whether the current step must recompute everything (no cache reuse).

    A step is "fresh" while it lies before ``cache_dic['start_step']`` and,
    from ``start_step`` onwards, on every ``cache_dic['fresh_t']``-th step.

    Args:
        cache_dic: Mapping providing ``'start_step'`` and ``'fresh_t'``.
        current: Mapping providing the current ``'step'``.

    Returns:
        ``True`` when a full refresh is required, ``False`` otherwise.
    """
    step = current['step']
    first_cached_step = cache_dic['start_step']
    if step < first_cached_step:
        return True
    return (step - first_cached_step) % cache_dic['fresh_t'] == 0

class FLUID(nn.Module):
    """
        Finetune Masked Autoencoder with VisionTransformer backbone
    """

    def __init__(self, img_size=256, vae_stride=16, patch_size=1, text_depth=6,
                 encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 cross_embed_dim=1024, cross_depth=16, cross_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 caption_channels=4096, max_length=512,
                 interpolate_offset=0.1,
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 text_drop_prob=0.1,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 diffloss_d=3,
                 diffloss_w=1024,
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 token_cache=False,
                 cfg_cache=False,
                 # GtR specific parameters
                 diff_upper_steps=50,
                 diff_lower_steps=5,
                 diff_annealing_strategy="linear",
                 diff_sampler="default",
                 pivot_step_threshold=15,
                 pivot_diffusion_steps=50,
                 token_selection_strategy="pivotal",
                 pivot_token_percentage=0.1,
                 order_strategy="random",
                 mask_strategy="cosine",
                 ):
        """Create the text embedder, encoder, decoder, cross-attention stack and diffusion head.

        Args:
            img_size, vae_stride, patch_size: Define the initial token grid,
                ``seq_h = seq_w = img_size // vae_stride // patch_size``.
            text_depth: Forwarded to ``Text_Embedding`` as ``text_depth``.
            encoder_embed_dim, encoder_depth, encoder_num_heads: Width, number
                of blocks and heads of the encoder stack. The width and head
                count are also passed to ``Text_Embedding``.
            decoder_embed_dim, decoder_depth, decoder_num_heads: Same for the
                decoder stack.
            cross_embed_dim, cross_depth, cross_num_heads: Same for the
                cross-attention stack.
            mlp_ratio, norm_layer: Forwarded to every block.
            caption_channels: First argument of ``Text_Embedding``.
            max_length: Stored on the instance; ``sample_tokens`` uses it when
                expanding ``fake_latent``.
            interpolate_offset: Offset used by ``interpolate_pos_encoding``.
            vae_embed_dim: Channels per latent position; the token width is
                ``vae_embed_dim * patch_size ** 2``.
            mask_ratio_min: Lower bound of the training mask-ratio distribution.
            text_drop_prob: Stored on the instance (not used in ``__init__``).
            attn_dropout, proj_dropout: Dropout rates forwarded to the blocks.
            diffloss_d, diffloss_w: Depth and width passed to ``DiffLossGtR``.
            diffusion_batch_mul: Repeat factor applied in ``forward_loss``.
            grad_checkpointing: Enables ``torch.utils.checkpoint`` in the
                forward helpers and is forwarded to the sub-modules.
            token_cache, cfg_cache: Accepted but not read in ``__init__``.
            diff_upper_steps, diff_lower_steps, diff_annealing_strategy,
            diff_sampler, pivot_step_threshold, pivot_diffusion_steps,
            token_selection_strategy, pivot_token_percentage: Forwarded
                unchanged to ``DiffLossGtR``.
            order_strategy, mask_strategy: Stored; consulted by
                ``sample_tokens``.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE and patchify specifics
        self.vae_embed_dim = vae_embed_dim

        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        # NOTE: fixed at 256 regardless of img_size; every learned positional
        # embedding below is sized from this value.
        self.seq_len = 256  # self.seq_h * self.seq_w # start from 256px ckpt
        self.token_embed_dim = vae_embed_dim * patch_size ** 2
        self.grad_checkpointing = grad_checkpointing
        self.caption_channels = caption_channels
        self.max_length = max_length
        self.cross_embed_dim = cross_embed_dim
        self.interpolate_offset = interpolate_offset
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.diffloss_d = diffloss_d

        # --------------------------------------------------------------------------
        # Fluid cross attn specifics (Text Embedding)
        self.last_embed = nn.Linear(decoder_embed_dim, cross_embed_dim, bias=True)
        self.last_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, cross_embed_dim))

        self.text_drop_prob = text_drop_prob
        # Fake text embedding for CFG's unconditional generation
        self.fake_latent = nn.Parameter(torch.zeros(1, 1, cross_embed_dim))

        self.text_emb = Text_Embedding(caption_channels, encoder_embed_dim, text_num_heads=encoder_num_heads,
                                       text_depth=text_depth,
                                       mlp_ratio=mlp_ratio, norm_layer=norm_layer, proj_dropout=proj_dropout,
                                       attn_dropout=attn_dropout,
                                       grad_checkpointing=grad_checkpointing)

        # Cross-attention layer (projecting text embedding into the vision token space)
        self.cross_attn_blocks = nn.ModuleList([
            Attention_Block(dim=cross_embed_dim, num_heads=cross_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True,
                            norm_layer=norm_layer,
                            proj_drop=proj_dropout, attn_drop=attn_dropout, is_cross=True) for _ in range(cross_depth)])
        self.cross_norm = norm_layer(cross_embed_dim)

        # --------------------------------------------------------------------------
        # Fluid variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
        # truncnorm takes its bounds in standard-deviation units, so the
        # sampled ratio lies in [mask_ratio_min, 1.0].
        self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

        # --------------------------------------------------------------------------
        # Fluid encoder specifics
        self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
        self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, encoder_embed_dim))

        self.encoder_blocks = nn.ModuleList([
            encoViTBlock(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
        self.encoder_norm = norm_layer(encoder_embed_dim)

        # --------------------------------------------------------------------------
        # Fluid decoder specifics
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            decoViTBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)

        # Called before self.diffloss is attached, so the module-wide
        # ``apply`` inside does not visit the diffusion head.
        self.initialize_weights()

        # --------------------------------------------------------------------------
        # Diffusion Loss - using GtR
        self.diffloss = DiffLossGtR(
            target_channels=self.token_embed_dim,
            z_channels=decoder_embed_dim,
            width=diffloss_w,
            depth=diffloss_d,
            grad_checkpointing=grad_checkpointing,
            diff_upper_steps=diff_upper_steps,
            diff_lower_steps=diff_lower_steps,
            diff_annealing_strategy=diff_annealing_strategy,
            diff_sampler=diff_sampler,
            pivot_step_threshold=pivot_step_threshold,
            pivot_diffusion_steps=pivot_diffusion_steps,
            token_selection_strategy=token_selection_strategy,
            pivot_token_percentage=pivot_token_percentage
        )

        self.diffusion_batch_mul = diffusion_batch_mul
        self.order_strategy = order_strategy
        self.mask_strategy = mask_strategy
        # Piecewise cosine decay schedule configuration
        # Will be set dynamically in sample_tokens method based on num_iter
        self.piecewise_schedule = None

    def initialize_weights(self):
        # parameters
        torch.nn.init.normal_(self.fake_latent, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.last_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, x):
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x  # [n, l, d]

    def unpatchify(self, x):
        bsz = x.shape[0]
        p = self.patch_size
        c = self.vae_embed_dim
        h_, w_ = self.seq_h, self.seq_w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x  # [n, c, h, w]

    def sample_orders(self, bsz):
        # generate a batch of random generation orders
        orders = []
        for _ in range(bsz):
            order = np.array(list(range(self.seq_len)))
            np.random.shuffle(order)
            orders.append(order)
        orders = torch.Tensor(np.array(orders)).cuda().long()
        return orders

    def compute_piecewise_cosine_decay(self, step):
        """
        计算分段余弦衰减的掩码数量

        Args:
            step: 当前迭代步数

        Returns:
            mask_length: 当前步数对应的掩码数量
        """
        # 获取排序的步数端点
        steps = sorted(self.piecewise_schedule.keys())

        # 找到当前步数所在的区间
        start_step = 0
        end_step = steps[-1]
        start_value = self.piecewise_schedule[0]
        end_value = self.piecewise_schedule[steps[-1]]

        for i in range(len(steps) - 1):
            if step <= steps[i + 1]:
                start_step = steps[i]
                end_step = steps[i + 1]
                start_value = self.piecewise_schedule[start_step]
                end_value = self.piecewise_schedule[end_step]
                break

        # 如果步数超过最后一个端点，直接返回最终值
        if step >= steps[-1]:
            return self.piecewise_schedule[steps[-1]]

        # 在当前区间内计算余弦衰减
        if start_step == end_step:
            return start_value

        # 计算在当前区间内的进度（0到1）
        progress = (step - start_step) / (end_step - start_step)

        # 应用余弦衰减：cos(0) = 1, cos(π/2) = 0
        # cosine_decay = np.cos(progress * math.pi / 2)
        cosine_decay = 1 - (progress) ** 2

        # 线性插值：从start_value衰减到end_value
        mask_length = end_value + (start_value - end_value) * cosine_decay

        return mask_length

    def random_masking(self, x, orders):
        # generate token mask
        bsz, seq_len, embed_dim = x.shape
        mask_rate = self.mask_ratio_generator.rvs(1)[0]
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask = torch.zeros(bsz, seq_len, device=x.device)
        mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
                             src=torch.ones(bsz, seq_len, device=x.device))
        return mask

    def interpolate_pos_encoding(self, x, pos_embed):
        previous_dtype = x.dtype
        npatch = x.shape[1]
        N = pos_embed.shape[1]
        if npatch == N and self.seq_w == self.seq_h:
            return pos_embed
        pos_embed = pos_embed.float()
        dim = x.shape[-1]
        w0 = self.seq_w
        h0 = self.seq_h
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            **kwargs,
        )
        assert (w0, h0) == pos_embed.shape[-2:]
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return pos_embed.to(previous_dtype)


    def obtain_rela_mask(self, current, cache_dic, bsz, x):
        # bsz ,seq_len, dim  相对seq_len， 上一step预测的token
        if cache_dic['prev_pred'] is None:
            prev_pred = torch.zeros_like(current['mask']).to(x.device).bool()
        else:
            prev_pred = cache_dic['prev_pred']
        # bsz, seq_len, dim  相对于seq_len, 截至本step已经预测的token
        cur_pred_mask = ~current['mask']
        current["prev_mtp_rela"] = torch.masked_select(prev_pred, cur_pred_mask).reshape(bsz, -1)


    def obtain_update_mask(self, mask, prev_pred_mask, mask_to_pred, current):
        '''
        1表示需要更新的 token，0表示使用缓存的 token
        '''
        device = mask.device
        prev_pred_mask = torch.zeros_like(mask_to_pred, device=device).bool() if prev_pred_mask is None else prev_pred_mask.bool()
        unpredicted_mask = mask.bool()
        predicted_mask = torch.logical_not(mask).bool()  # 确保为布尔类型
        to_pred_mask = mask_to_pred.bool()
        none_cache_mask = torch.ones_like(mask, device=device).bool()

        current["update_mask"] = none_cache_mask
        current['to_pred_mask'] = to_pred_mask
        current['prev_pred_mask'] = prev_pred_mask
        current['predicted_mask'] = predicted_mask


    def forward_fluid_encoder(self, x, mask, current, cache_dic):
        """Encode the unmasked tokens, optionally reusing cached computation.

        Args:
            x: Token latents; projected by ``self.z_proj``.
            mask: Per-token mask; positions where it is zero/False are the
                tokens kept for encoding. Every sample must keep the same
                number of tokens because the selection is reshaped to
                ``(bsz, -1, embed_dim)``.
            current: Per-step state dict. Reads ``'cal_flops'``,
                ``'token_cache'``, ``'cfg_cache'``, ``'is_force_fresh'``;
                writes ``'enco_layer_idx'`` (and ``'prev_mtp_rela'`` through
                ``obtain_rela_mask``).
            cache_dic: Cache state passed on to every encoder block.

        Returns:
            Output of ``self.encoder_norm`` for the kept tokens.
        """
        if current['cal_flops']:
            cumulate_flops(self.z_proj, flops_dic=current, input=x)
        x = self.z_proj(x)
        bsz, _, embed_dim = x.shape
        encoder_pos_embed = self.interpolate_pos_encoding(x, self.encoder_pos_embed_learned)
        x = x + encoder_pos_embed
        x = self.z_proj_ln(x)
        # keep only the tokens that are not masked
        x = x[torch.logical_not(mask).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
        # number of kept tokens per sample; needed to rebuild the full tensor below
        _, cur_pred_len, _ = x.shape
        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # NOTE(review): blocks are called with ``x`` only here, whereas the
            # branch below calls ``block(x, current, cache_dic)`` -- confirm the
            # block signature supports this path.
            for block in self.encoder_blocks:
                x = checkpoint(block, x)
        else:
            if current['token_cache'] and not current['is_force_fresh']:
                # feed the blocks only the tokens flagged in ``prev_mtp_rela``
                # (those predicted in the previous step)
                self.obtain_rela_mask(current, cache_dic, bsz, x)
                x = torch.masked_select(x, current["prev_mtp_rela"].unsqueeze(-1).expand(-1, -1, embed_dim).bool()).reshape(bsz, -1, embed_dim)
            if current['cfg_cache'] and not current['is_force_fresh']:
                # run only the first half of the batch through the blocks
                x = x[:int(bsz / 2)].detach().clone()
            for i, block in enumerate(self.encoder_blocks):
                current['enco_layer_idx'] = i
                if current['cal_flops']:
                    cumulate_flops(block, flops_dic=current, x=x, current=current, cache_dic=cache_dic)
                x = block(x, current, cache_dic)
            if current['cfg_cache'] and not current['is_force_fresh']:
                # restore the batch size by duplicating the computed half
                x = torch.cat([x, x], dim = 0)
            if current['token_cache'] and not current['is_force_fresh']:
                # scatter the refreshed tokens back; positions that were not
                # recomputed stay zero in this tensor
                new_x_full = torch.zeros(bsz, cur_pred_len, embed_dim).to(x)
                new_x_full[current["prev_mtp_rela"].bool()] = x.view(-1, x.size(-1))
                x = new_x_full
        x = self.encoder_norm(x)
        return x

    def forward_fluid_decoder(self, x, mask, current, cache_dic):
        """Decode the full token sequence (encoded tokens plus mask tokens).

        Args:
            x: Encoder output for the unmasked tokens.
            mask: Per-token mask; zero/False positions receive the encoder
                output, the remaining positions receive ``self.mask_token``.
            current: Per-step state dict. Reads ``'cal_flops'``,
                ``'token_cache'``, ``'cfg_cache'``, ``'is_force_fresh'``;
                writes ``'to_pred_len'``, ``'prev_pred_len'``, ``'layer_idx'``,
                the masks set by ``obtain_update_mask`` and, on cached steps,
                ``'original_x'`` / ``'orig_*'`` entries.
            cache_dic: Cache state. Reads ``'prev_pred'`` and
                ``'mask_to_pred'``; reads/writes ``cache_dic['cache']``
                (``'de_ou'`` and the per-layer ``'diff'`` entry).

        Returns:
            Output of ``self.decoder_norm`` for every token position.
        """
        if current['cal_flops']:
            cumulate_flops(self.decoder_embed, flops_dic=current, input=x)
        x = self.decoder_embed(x)
        # pad mask tokens
        mask_tokens = self.mask_token.repeat(mask.shape[0], mask.shape[1], 1).to(x.dtype)
        x_after_pad = mask_tokens.clone()
        x_after_pad[torch.logical_not(mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        # decoder position embedding
        decoder_pos_embed = self.interpolate_pos_encoding(x_after_pad, self.decoder_pos_embed_learned)
        x = x_after_pad + decoder_pos_embed

        self.obtain_update_mask(mask, cache_dic['prev_pred'], cache_dic['mask_to_pred'], current)
        # counts are taken from the first sample only
        current['to_pred_len'] = torch.sum(cache_dic['mask_to_pred'][0])
        current['prev_pred_len'] = 0 if cache_dic['prev_pred'] is None else torch.sum(cache_dic['prev_pred'][0])
        if current['token_cache'] and not current['is_force_fresh']:
            B, N, C = x.shape
            # refresh the cached decoder input only at the positions predicted
            # in the previous step; everything else keeps its cached value
            cache_dic['cache']['de_ou'] = torch.where(current['prev_pred_mask'].unsqueeze(-1).expand(-1, -1, C), x,
                                                      cache_dic['cache']['de_ou'])
            x = copy.deepcopy(cache_dic['cache']['de_ou'])
            current['original_x'] = x
            # ``update_mask`` is all True as set by obtain_update_mask, so this
            # selection currently keeps every token
            x = x[current["update_mask"].nonzero(as_tuple=True)].reshape(x.shape[0], -1, x.shape[2])
        else:
            cache_dic['cache']['de_ou'] = x

        B, N, C = x.size()
        x_origi = x
        if current['cfg_cache'] and not current['is_force_fresh']:
            # run only the first half of the batch through the blocks; keep the
            # full-batch masks so they can be restored afterwards
            x = x[:int(B / 2)].detach().clone()
            current['orig_update_mask'] = current['update_mask']
            current['orig_to_pred_mask'] = current['to_pred_mask']
            current['orig_prev_pred_mask'] = current['prev_pred_mask']
            current['orig_predicted_mask'] = current['predicted_mask']

            current['update_mask'] = current['update_mask'][:int(B / 2)].detach().clone()
            current['to_pred_mask'] = current['to_pred_mask'][:int(B / 2)].detach().clone()
            current['prev_pred_mask'] = current['prev_pred_mask'][:int(B / 2)].detach().clone()
            current['predicted_mask'] = current['predicted_mask'][:int(B / 2)].detach().clone()

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # NOTE(review): this path passes only ``x`` to the blocks and never
            # sets ``current['layer_idx']``, which the code below reads.
            for block in self.decoder_blocks:
                x = checkpoint(block, x)
        else:
            for i, block in enumerate(self.decoder_blocks):
                current['layer_idx'] = i
                if current['cal_flops']:
                    cumulate_flops(block, flops_dic=current, x=x, current=current, cache_dic=cache_dic)
                x = block(x, current, cache_dic)


        if current['cfg_cache'] and not current['is_force_fresh']:
            # rebuild the second half of the batch as "first half + cached
            # difference" stored under the last layer index
            if current['token_cache']:
                diff = cache_dic['cache'][current['layer_idx']]['diff'][current['update_mask'].nonzero(as_tuple=True)].reshape(int(B / 2), -1, C)
            else:
                diff = cache_dic['cache'][current['layer_idx']]['diff']
            x_origi[:int(B / 2)] = x
            x_origi[int(B / 2):] = x + diff
            x = x_origi
            current['update_mask'] = current['orig_update_mask']
            current['to_pred_mask'] = current['orig_to_pred_mask']
            current['prev_pred_mask'] = current['orig_prev_pred_mask']
            current['predicted_mask'] = current['orig_predicted_mask']
        else:
            # store (second half - first half) of the batch for later reuse.
            # NOTE(review): this also runs when the batch was not doubled for
            # CFG -- confirm that is intended.
            cache_dic['cache'][current['layer_idx']]['diff'] = x[int(B / 2):] - x[:int(B / 2)]


        if current['token_cache'] and not current['is_force_fresh']:
            # write the recomputed tokens back into the full cached tensor
            current['original_x'][(current['update_mask']).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
            x = current['original_x']
        x = self.decoder_norm(x)
        return x

    def forward_fluid_text_decoder(self, x, text_embeddings, mask_to_pred, current):
        if current['cal_flops']:
            cumulate_flops(self.last_embed, flops_dic=current, input=x)
        x = self.last_embed(x)
        # x = x + self.last_pos_embed_learned
        text_pos_embed = self.interpolate_pos_encoding(x, self.last_pos_embed_learned)
        x = x + text_pos_embed
        bsz, _, dim = x.shape

        x = x[mask_to_pred.nonzero(as_tuple=True)].reshape(bsz, -1, dim)

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.cross_attn_blocks:
                x = checkpoint(block, x, text_embeddings)
        else:
            for block in self.cross_attn_blocks:
                if current['cal_flops']:
                    cumulate_flops(block, flops_dic=current, q=x, k_v=text_embeddings)
                x = block(x, text_embeddings)
        x = self.cross_norm(x)

        return x

    def forward_loss(self, z, target, mask):
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)
        loss = self.diffloss(z=z, target=target, mask=mask)
        return loss

    def forward(self, imgs, texts, height=512, width=512):
        """Training forward pass: mask tokens, encode/decode and return the diffusion loss.

        Args:
            imgs: Latent tensor of shape ``(bsz, c, h, w)`` passed to
                ``patchify``.
            texts: Input for ``self.text_emb``.
            height, width: Target size; each must be divisible by
                ``vae_stride * patch_size``. They overwrite ``self.seq_h``,
                ``self.seq_w`` and ``self.seq_len``.

        Returns:
            Loss value produced by ``forward_loss``.

        NOTE(review): the calls to ``forward_fluid_encoder``,
        ``forward_fluid_decoder`` and ``forward_fluid_text_decoder`` below omit
        the ``current`` / ``cache_dic`` / ``mask_to_pred`` arguments that the
        signatures of those methods in this class require, so this path raises
        ``TypeError`` as written.
        """
        assert height % (self.vae_stride * self.patch_size) == 0 and width % (self.vae_stride * self.patch_size) == 0
        self.seq_h, self.seq_w = height // self.vae_stride // self.patch_size, width // self.vae_stride // self.patch_size
        self.seq_len = self.seq_h * self.seq_w
        # patchify and mask (drop) tokens
        x = self.patchify(imgs)
        # ground-truth targets for the diffusion loss
        gt_latents = x.clone().detach()
        orders = self.sample_orders(bsz=x.size(0))

        mask = self.random_masking(x, orders)

        # mae encoder
        x = self.forward_fluid_encoder(x, mask)

        # mae decoder
        x = self.forward_fluid_decoder(x, mask)

        # Tokenize and encode the input text
        text_embeddings = self.text_emb(texts)
        z = self.forward_fluid_text_decoder(x, text_embeddings)

        # diffloss
        loss = self.forward_loss(z=z, target=gt_latents, mask=mask)

        return loss

    def get_cache_dic(self, bsz, depth, num_heads, embed_dim, device, cfg_cache):
        assert embed_dim % num_heads == 0, 'dim should be divisible by num_heads'
        head_dim = embed_dim // num_heads
        cache_dic = {}
        cache = {}
        enco_cache = {}

        if device.type != 'cpu':
            dtype = torch.float16
        else:
            dtype = torch.float32
        if not cfg_cache:
            bsz = int(bsz * 2)
        for j in range(depth):
            cache[j] = {}
            enco_cache[j] = {}
            enco_cache[j]['k'] = torch.zeros(
                (
                    bsz,
                    num_heads,
                    1024,
                    head_dim,
                ), dtype=dtype, device=device)

            enco_cache[j]['v'] = torch.zeros(
                (
                    bsz,
                    num_heads,
                    1024,
                    head_dim,
                ), dtype=dtype, device=device)
            enco_cache[j]['cur_kv_len'] = 0

        cache_dic['cache'] = cache
        cache_dic['enco_cache'] = enco_cache
        current = {}
        return cache_dic, current

    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", texts=None, temperature=1.0, height=512,
                      width=512, progress=False, args=None):
        """Generate token latents iteratively and return them unpatchified.

        Each iteration encodes the tokens generated so far, decodes the full
        sequence, conditions the positions scheduled for this iteration on the
        text and samples their latents with ``self.diffloss.sample``.

        Args:
            bsz: Number of samples to generate.
            num_iter: Number of generation iterations.
            cfg: Classifier-free guidance scale; ``1.0`` disables the extra
                unconditional half of the batch.
            cfg_schedule: ``"linear"`` or ``"constant"``; any other value
                raises ``NotImplementedError``.
            texts: Input for ``self.text_emb``. When ``None`` the learned
                ``fake_latent`` is used as text embedding.
            temperature: Forwarded to ``self.diffloss.sample``.
            height, width: Target size; each must be divisible by
                ``vae_stride * patch_size``. They overwrite ``self.seq_h``,
                ``self.seq_w`` and ``self.seq_len``.
            progress: Wrap the iteration in ``tqdm`` when true.
            args: Object providing ``cfg_cache``, ``token_cache``,
                ``cal_flops``, ``start_step`` and ``fresh_t``. Despite the
                ``None`` default it is dereferenced unconditionally.

        Returns:
            Result of ``self.unpatchify`` applied to the generated tokens.
        """

        # NOTE: sampling is hard-wired to CUDA here and in the ``.cuda()`` calls below.
        device = torch.device("cuda")
        cache_dic, current = self.get_cache_dic(bsz,  self.encoder_depth, self.encoder_num_heads, self.encoder_embed_dim, device, args.cfg_cache)
        current['depth'] = self.encoder_depth
        current['token_cache'] = args.token_cache
        current['cfg_cache'] = args.cfg_cache
        current['cal_flops'] = args.cal_flops
        cache_dic['start_step'] = args.start_step
        cache_dic['fresh_t'] = args.fresh_t
        cache_dic['prev_pred'] = None
        cache_dic['mask_to_pred'] = None
        current['flops'], current['macs'], current['params'] = 0, 0, 0

        # Set piecewise schedule dynamically based on num_iter
        # NOTE(review): the start value hard-codes a sequence length of 1024.
        self.piecewise_schedule = {
            0: 1024 - 1,  # initial value: the full sequence length (minus one)
            num_iter - 3: 512,          # decay to 512 at step num_iter - 3
            num_iter - 2: 256,          # decay to 256 at step num_iter - 2
            num_iter - 1: 1             # decay to 1 at step num_iter - 1
        }

        assert height % (self.vae_stride * self.patch_size) == 0 and width % (self.vae_stride * self.patch_size) == 0
        self.seq_h, self.seq_w = height // self.vae_stride // self.patch_size, width // self.vae_stride // self.patch_size
        self.seq_len = self.seq_h * self.seq_w

        # init and sample generation orders
        mask = torch.ones(bsz, self.seq_len).cuda()
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()

        # Generate orders based on strategy
        if self.order_strategy == "autoregressive":
            # helpers come from models.sampler_util; the 32, 32 arguments are
            # hard-coded -- presumably the token grid size, TODO confirm
            original_order = generate_full_autoregressive_order(bsz)
            orders = convert_order(original_order, 32, 32).cuda().long()
        else:  # random
            orders = self.sample_orders(bsz)

        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)
        if texts is not None:
            # the text is embedded once and reused in every iteration
            text_embeddings_temp = self.text_emb(texts)
        mask_lens=[]
        # generate latents
        for step in indices:
            current['step'] = step
            cur_tokens = tokens.clone()
            # text_embedding and CFG
            if texts is None:
                text_embeddings = self.fake_latent.expand(bsz, self.max_length, self.cross_embed_dim)
            else:
                text_embeddings = text_embeddings_temp
            if not cfg == 1.0:
                # double the batch: first half uses the text embedding, second
                # half uses the fake (unconditional) embedding
                tokens = torch.cat([tokens, tokens], dim=0)
                text_embeddings = torch.cat(
                    [text_embeddings, self.fake_latent.expand(bsz, self.max_length, self.cross_embed_dim)], dim=0)
                mask = torch.cat([mask, mask], dim=0)



            if self.mask_strategy == "cosine":
                # mask ratio for the next round, following MaskGIT and MAGE.
                mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
                mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()
                # masks out at least one for the next iteration
                mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                         torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))

            elif self.mask_strategy == "fixed":
                # mask_lens = [1022.0, 1019.0, 1012.0, 1004.0, 993.0, 979.0, 964.0, 946.0, 925.0, 903.0,
                #              878.0, 851.0, 822.0, 791.0, 758.0, 724.0, 687.0, 649.0, 609.0, 568.0,
                #              526.0, 482.0, 437.0, 391.0, 344.0, 297.0, 248.0, 199.0, 150.0, 100.0,
                #              50.0, 1.0]
                # mask_lens = [1021.0, 1015.0, 1004.0, 989.0, 969.0, 946.0, 918.0, 886.0, 851.0, 812.0,
                #              769.0, 724.0, 675.0, 623.0, 568.0, 512.0, 452.0, 391.0, 329.0, 265.0,
                #              199.0, 133.0, 66.0, 1.0]
                # mask_lens = [1019.0, 1004.0, 979.0, 946.0, 903.0, 851.0, 791.0, 724.0, 649.0, 568.0,
                #              482.0, 391.0, 297.0, 199.0, 100.0, 1.0]
                # mask_lens = [1004.0, 946.0, 851.0, 724.0, 568.0, 391.0, 199.0, 1.0]
                ##############################################################

                # mask_lens = [1022.0, 1019.0, 1012.0, 1004.0, 993.0, 979.0, 964.0, 946.0, 925.0, 903.0,
                #              878.0, 851.0, 822.0, 791.0, 758.0, 724.0, 687.0, 649.0, 609.0, 568.0,
                #              526.0, 482.0, 437.0, 391.0, 344.0, 297.0, 248.0, 199.0, 150.0, 100.0,
                #              50.0, 1.0]

                # NOTE: this rebinds ``mask_lens`` (the log list created before
                # the loop) on every step. The table has 19 entries and is
                # indexed by ``step``, so this strategy needs num_iter <= 19.
                mask_lens = [1022.0, 1019.0, 1012.0, 1004.0, 993.0, 979.0, 964.0, 946.0, 925.0, 903.0,
                             878.0, 851.0, 791.0, 724.0, 649.0, 568.0, 512.0, 256.0, 1.0]

                # mask_lens = [1019.0, 1004.0, 979.0, 946.0, 880.0, 821.0, 724.0, 629.0, 512.0, 256.0, 1.0]

                mask_len = mask_lens[step]
                mask_len = torch.tensor([mask_len], dtype=torch.float32).cuda()
            else:  # piecewise_cosine
                # mask ratio for the next round, using piecewise cosine decay
                mask_len_value = self.compute_piecewise_cosine_decay(step)
                mask_len = torch.Tensor([np.floor(mask_len_value)]).cuda()

                mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                         torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
                # NOTE(review): debug output emitted on every step of this strategy
                print(mask_len)
            mask_lens.append(mask_len[0].item())
            # get masking for next iteration and locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)

            if step >= num_iter - 1:
                # last iteration: predict everything that is still masked
                mask_to_pred = mask[:bsz].bool()
            else:
                # tokens masked now but not in the next round
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            if not cfg == 1.0:
                mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)

            cache_dic['mask_to_pred'] = mask_to_pred
            current['is_force_fresh'] = global_force_fresh(cache_dic, current)
            current['use_cache'] = not current['is_force_fresh']
            current['mask'] = mask.bool()

            # mae encoder
            x = self.forward_fluid_encoder(tokens, mask, current, cache_dic)

            # mae decoder
            x = self.forward_fluid_decoder(x, mask, current, cache_dic)

            z_pred = self.forward_fluid_text_decoder(x, text_embeddings, mask_to_pred, current)
            # flatten (batch, selected tokens) into one row per token
            z_pred = z_pred.reshape(-1, z_pred.size(2))

            # remember this step's predictions and advance the mask
            cache_dic['prev_pred'] = mask_to_pred
            mask = mask_next

            # z_pred = z[mask_to_pred.nonzero(as_tuple=True)]
            # cfg schedule follow Muse
            if cfg_schedule == "linear":
                # guidance grows from 1 towards ``cfg`` as fewer tokens stay masked
                cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError

            # sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
            sampled_token_latent = self.diffloss.sample(z_pred, temperature=temperature, cfg=cfg_iter,
                                                        step=step, ar_num_iter=num_iter, bsz=bsz)
            if current['cal_flops']:
                cumulate_flops(self.diffloss, flops_dic=current, z=z_pred, temperature=temperature, cfg=cfg_iter,
                                    step=step, ar_num_iter=num_iter, bsz=bsz)

            if not cfg == 1.0:
                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples
                mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)

            # write the sampled latents into the (un-doubled) token buffer
            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            tokens = cur_tokens.clone()
        # print(mask_lens)
        # unpatchify
        if current['cal_flops']:
            print_flops(current)
        tokens = self.unpatchify(tokens)
        return tokens


def fluid_base(**kwargs):
    """Construct the base-size FLUID model.

    The ``encoder_*``, ``decoder_*`` and ``cross_*`` settings all share an
    embedding dim of 768, a depth of 12 and 12 heads; ``mlp_ratio`` is 4 and
    the norm layer is ``nn.LayerNorm`` with ``eps=1e-6``.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to ``FLUID``.

    Returns:
        The constructed ``FLUID`` instance.
    """
    width, depth, heads = 768, 12, 12
    arch = {
        f"{stage}_{field}": value
        for stage in ("encoder", "decoder", "cross")
        for field, value in (("embed_dim", width), ("depth", depth), ("num_heads", heads))
    }
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    # Unpacking (rather than merging dicts) keeps the TypeError that a
    # duplicated keyword from the caller would raise.
    return FLUID(**arch, mlp_ratio=4, norm_layer=layer_norm, **kwargs)


def fluid_large(**kwargs):
    """Construct the large-size FLUID model.

    The ``encoder_*``, ``decoder_*`` and ``cross_*`` settings all share an
    embedding dim of 1024, a depth of 16 and 16 heads; ``mlp_ratio`` is 4 and
    the norm layer is ``nn.LayerNorm`` with ``eps=1e-6``.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to ``FLUID``.

    Returns:
        The constructed ``FLUID`` instance.
    """
    width, depth, heads = 1024, 16, 16
    arch = {
        f"{stage}_{field}": value
        for stage in ("encoder", "decoder", "cross")
        for field, value in (("embed_dim", width), ("depth", depth), ("num_heads", heads))
    }
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    # Unpacking (rather than merging dicts) keeps the TypeError that a
    # duplicated keyword from the caller would raise.
    return FLUID(**arch, mlp_ratio=4, norm_layer=layer_norm, **kwargs)


def fluid_huge(**kwargs):
    """Construct the huge-size FLUID model.

    The ``encoder_*``, ``decoder_*`` and ``cross_*`` settings all share an
    embedding dim of 1280, a depth of 20 and 16 heads; ``mlp_ratio`` is 4 and
    the norm layer is ``nn.LayerNorm`` with ``eps=1e-6``.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to ``FLUID``.

    Returns:
        The constructed ``FLUID`` instance.
    """
    width, depth, heads = 1280, 20, 16
    arch = {
        f"{stage}_{field}": value
        for stage in ("encoder", "decoder", "cross")
        for field, value in (("embed_dim", width), ("depth", depth), ("num_heads", heads))
    }
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    # Unpacking (rather than merging dicts) keeps the TypeError that a
    # duplicated keyword from the caller would raise.
    return FLUID(**arch, mlp_ratio=4, norm_layer=layer_norm, **kwargs)


if __name__ == "__main__":
    # Smoke test: build FLUID-large, run one call on random latents plus T5
    # text embeddings, then print the elapsed time and the returned loss.
    # Needs a CUDA device and the local checkpoints referenced below.
    # `os`/`sys` and the sys.path tweaks are already set up at module top,
    # so they are not repeated here.
    import time

    from diffusers.models import AutoencoderKL

    from models.utils import T5_Embedding

    mt5_cache_dir = '/data/xianfeng/code/model/google/flan-t5-xxl'
    mt5_model_name = mt5_cache_dir
    patch_size = 2
    vae_embed_dim = 16
    vae_stride = 8
    diffloss_d = 8
    diffloss_w = 1280

    model = fluid_large(vae_embed_dim=vae_embed_dim, vae_stride=vae_stride, patch_size=patch_size,
                        diffloss_d=diffloss_d, diffloss_w=diffloss_w)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of trainable parameters: {}M".format(n_params / 1e6))
    model = model.cuda()

    # Random stand-in for the model input, laid out as
    # (batch_size, channels, height, width) = (4, 16, 32, 32).
    # NOTE(review): presumably this mimics VAE latents (vae_embed_dim=16
    # channels, 256 / vae_stride = 32 per side) - confirm.
    fake_image = torch.randn(4, 16, 32, 32).cuda()

    # Pseudo text data, one caption per batch element.
    fake_texts = ["This is a test sentence.", "Another example sentence.", "Text embedding test.",
                  "Vision and text fusion."]

    # Frozen VAE.
    # NOTE(review): `vae` is loaded but not used by the call below, which
    # consumes the random tensor directly; drop this load, or encode a real
    # image with it, once the intended pipeline is settled.
    vae_path = '/data/xianfeng/code/model/stabilityai/stable-diffusion-3.5-large'
    vae = AutoencoderKL.from_pretrained(os.path.join(vae_path, "vae")).cuda().eval()
    for param in vae.parameters():
        param.requires_grad = False

    # T5 text embedding module.
    max_length = 512
    t5_infer = T5_Embedding(mt5_model_name, mt5_cache_dir, max_length).cuda()

    # CUDA work is queued asynchronously: synchronize so pending transfers
    # are not billed to the timed region ...
    torch.cuda.synchronize()
    start_time = time.time()

    text_emb = t5_infer(fake_texts)

    x = fake_image

    with torch.cuda.amp.autocast():
        loss = model(x, text_emb, height=256, width=256)

    # ... and again so the clock is read only after the kernels finish.
    torch.cuda.synchronize()
    end_time = time.time()
    print(f'cost time:{end_time - start_time}')
    # The leftover pdb.set_trace() breakpoint was removed: it stopped the
    # script before the loss was reported and hangs non-interactive runs.
    print(f"Loss: {loss.item()}")