from functools import partial
import numpy as np
from tqdm import tqdm
import scipy.stats as stats
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import copy
import torch.nn.functional as F
from models.vision_transformer_v10v8 import Block as ViTBlock
from models.vision_transformer_v10v8_enco import Block as encoViTBlock
from models.diffloss_GtR import DiffLoss
import os
from calflops import calculate_flops
from calflops.utils import flops_to_string, macs_to_string, params_to_string
from models.sampler_util import *

def mask_by_order(mask_len, order, bsz, seq_len, device):
    """Build a boolean mask marking the first ``mask_len`` entries of each order.

    Args:
        mask_len: tensor (presumably 0-d) with the number of leading entries of
            ``order`` to mark; converted with ``.long()`` before slicing.
        order: (bsz, seq_len) integer tensor of token indices per sample.
        bsz: batch size.
        seq_len: sequence length.
        device: device of the returned mask.

    Returns:
        (bsz, seq_len) bool tensor, True at ``order[b, :mask_len]`` for every b.
    """
    selected = order[:, :mask_len.long()]
    empty = torch.zeros(bsz, seq_len).to(device)
    ones_src = torch.ones(bsz, seq_len).to(device)
    return torch.scatter(empty, dim=-1, index=selected, src=ones_src).bool()

def get_model_device(model):
    """Return the device holding ``model``'s first parameter.

    Raises StopIteration if the model has no parameters.
    """
    first_param = next(iter(model.parameters()))
    return first_param.device

def update_x(self, current, cache_dic, x, buffer_size=64):
    """Refresh the least-similar cached decoder tokens from ``x`` and return the result.

    Compares ``x`` with ``cache_dic['cache']['de_ou']`` token by token (cosine
    similarity over the last dim). Positions flagged in
    ``current['buffer_tokens_mask']`` get similarity 0. The
    ``current['prev_mask_to_pred_len'] + buffer_size`` tokens with the lowest
    similarity are overwritten in the cache with their values from ``x``. All
    other tokens keep their cached values.

    Args:
        self: unused; kept so existing call sites keep working.
        current: dict with 'buffer_tokens_mask' ((B, N) bool) and
            'prev_mask_to_pred_len' (int or 0-d integer tensor).
        cache_dic: dict whose ['cache']['de_ou'] is a (B, N, C) tensor. It is
            replaced with the refreshed tensor.
        x: (B, N, C) tensor of freshly computed tokens.
        buffer_size: extra tokens refreshed on top of prev_mask_to_pred_len
            (previously hard-coded to 64).

    Returns:
        An independent copy of the refreshed cache tensor. The original
        computed this value but discarded it.
    """
    B, N, C = x.shape
    cached = cache_dic['cache']['de_ou']
    similarity = F.cosine_similarity(x, cached, dim=-1).reshape(B, -1)
    similarity[current['buffer_tokens_mask']] = 0
    # Ascending order: the most-changed (least similar) tokens come first.
    _, inds = torch.sort(similarity, dim=-1, descending=False)
    fresh_num = current['prev_mask_to_pred_len'] + buffer_size
    imp_inds = inds[:, :fresh_num]
    update_mask = torch.zeros((B, N), device=x.device)
    update_mask = update_mask.scatter_(1, imp_inds, 1)
    refreshed = torch.where(update_mask.bool().unsqueeze(-1).expand(-1, -1, C), x, cached)
    cache_dic['cache']['de_ou'] = refreshed
    # Return a copy so later in-place edits by the caller cannot corrupt the cache.
    return refreshed.clone()

def convert_order(original_order: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """
    Flatten (x, y) coordinate orders into row-major linear indices.

    Args:
        original_order (torch.Tensor): tensor of shape (bsz, seq_len, 2). Index 0
            of the last dim is x (column) and index 1 is y (row).
        w (int): grid width (number of columns); used as the row stride.
        h (int): grid height (number of rows); not used in the computation.

    Returns:
        torch.Tensor: tensor of shape (bsz, seq_len) holding ``y * w + x``, with
        the same dtype and device as ``original_order``.
    """
    cols = original_order.select(-1, 0)
    rows = original_order.select(-1, 1)
    flat = rows * w + cols
    return flat.to(dtype=original_order.dtype, device=original_order.device)

class MAR(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=256, vae_stride=16, patch_size=1,
                 encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 diff_upper_steps=50,
                 diff_lower_steps=10,
                 diff_annealing_strategy="linear",
                 diff_sampler='default',
                 token_cache = False,
                 cfg_cache = False,
                 pivot_step_threshold=15,
                 pivot_diffusion_steps=50,
                 token_selection_strategy="pivotal",
                 pivot_token_percentage=0.1,
                 **kwargs
                 ):
        """Build the MAR encoder/decoder stacks and the diffusion-loss head.

        Args (selected):
            img_size, vae_stride, patch_size: the token grid is
                seq_h = seq_w = img_size // vae_stride // patch_size, and
                seq_len = seq_h * seq_w.
            vae_embed_dim: latent channels. Each token holds
                vae_embed_dim * patch_size**2 values (token_embed_dim).
            mask_ratio_min: lower bound of the truncated-normal masking ratio
                sampled during training.
            label_drop_prob: probability of swapping the class embedding for
                fake_latent during training.
            buffer_size: number of extra slots prepended to the token sequence.
                Both learned position embeddings cover seq_len + buffer_size slots.
            token_cache, cfg_cache: caching-variant flags stored on the instance.
            diffloss_d, diffloss_w, grad_checkpointing, diff_*, pivot_*,
            token_selection_strategy: forwarded to DiffLoss.
            diffusion_batch_mul: how many times tokens are tiled in forward_loss.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE and patchify specifics
        self.vae_embed_dim = vae_embed_dim

        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.token_embed_dim = vae_embed_dim * patch_size**2
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.token_cache = token_cache
        self.cfg_cache = cfg_cache
        self.diffloss_d = diffloss_d
        self.grad_checkpointing = grad_checkpointing
        # Piecewise cosine decay schedule configuration (disabled)
        # self.piecewise_schedule = {
        #     0: self.seq_len - 1,  # initial value: full sequence length
        #     23: 192,          # decays to 192 at step 23
        #     30: 128,          # decays to 128
        #     31: 64,
        #     32: 1             # decays to 1
        # }
        # --------------------------------------------------------------------------
        # Class Embedding
        self.num_classes = class_num
        self.class_emb = nn.Embedding(class_num, encoder_embed_dim)
        self.label_drop_prob = label_drop_prob
        # Fake class embedding for CFG's unconditional generation
        self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))

        # --------------------------------------------------------------------------
        # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
        self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

        # --------------------------------------------------------------------------
        # MAR encoder specifics
        self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
        self.buffer_size = buffer_size
        self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))

        self.encoder_blocks = nn.ModuleList([
            encoViTBlock(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
        self.encoder_norm = norm_layer(encoder_embed_dim)

        # --------------------------------------------------------------------------
        # MAR decoder specifics
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            ViTBlock(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        # Called before self.diffloss is created, so self.apply(self._init_weights)
        # does not re-initialise the DiffLoss layers.
        self.initialize_weights()

        # --------------------------------------------------------------------------
        self.diffloss = DiffLoss(
            target_channels=self.token_embed_dim,
            z_channels=decoder_embed_dim,
            width=diffloss_w,
            depth=diffloss_d,
            grad_checkpointing=grad_checkpointing,
            diff_upper_steps=diff_upper_steps,
            diff_lower_steps=diff_lower_steps,
            diff_annealing_strategy=diff_annealing_strategy,
            diff_sampler=diff_sampler,
            pivot_step_threshold=pivot_step_threshold,
            pivot_diffusion_steps=pivot_diffusion_steps,
            token_selection_strategy=token_selection_strategy,
            pivot_token_percentage=pivot_token_percentage
        )
        self.diffusion_batch_mul = diffusion_batch_mul
        # Running FLOP/MAC totals accumulated by cumulate_flops.
        # NOTE(review): self.params is only assigned in cumulate_flops, not here.
        self.flops = 0
        self.macs = 0

    def initialize_weights(self):
        # parameters
        torch.nn.init.normal_(self.class_emb.weight, std=.02)
        torch.nn.init.normal_(self.fake_latent, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, x):
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x  # [n, l, d]

    def unpatchify(self, x):
        bsz = x.shape[0]
        p = self.patch_size
        c = self.vae_embed_dim
        h_, w_ = self.seq_h, self.seq_w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x  # [n, c, h, w]

    def sample_orders(self, bsz, device = None):
        # generate a batch of random generation orders
        orders = []
        for _ in range(bsz):
            order = np.array(list(range(self.seq_len)))
            np.random.shuffle(order)
            orders.append(order)
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        orders = torch.Tensor(np.array(orders)).to(device).long()
        return orders


    def sample_orders_checkoard(self, bsz, device):
        height, width = 16, 16
        order_black = []
        order_white = []
        for i in range(height):
            for j in range(width):
                if (i + j) % 2 == 1:
                    order_black.append(i * width + j)
                else:
                    order_white.append(i * width + j)
        orders = []
        for _ in range(bsz):
            order_black_np = np.array(order_black)
            np.random.shuffle(order_black_np)
            order_white_np = np.array(order_white)
            np.random.shuffle(order_white_np)
            order_np = np.concatenate([order_black_np, order_white_np])
            orders.append(order_np)

        orders = torch.Tensor(np.array(orders)).to(device).long()
        return orders

    def sample_orders_diagonal(self, bsz, device=None):
        height, width = 16, 16  # 16x16网格，总共256个tokens
        diagonal_order = []
        
        # 遍历所有可能的反对角线
        # 反对角线由 i+j 的值确定，范围从0到(height-1)+(width-1)
        for diagonal_sum in range(height + width - 1):
            diagonal_tokens = []
            
            # 在当前反对角线上找到所有的位置
            for i in range(height):
                j = diagonal_sum - i
                # 检查j是否在有效范围内
                if 0 <= j < width:
                    # 将2D坐标转换为1D索引
                    token_index = i * width + j
                    diagonal_tokens.append(token_index)
            
            # 将当前反对角线的tokens添加到总序列中
            diagonal_order.extend(diagonal_tokens)
        
        # 生成批次的orders
        orders = []
        for _ in range(bsz):
            order = np.array(diagonal_order)
            # 如果需要随机化，可以在这里添加 np.random.shuffle(order)
            orders.append(order)
        
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        orders = torch.tensor(np.array(orders), device=device, dtype=torch.long)
        return orders

    # def compute_piecewise_cosine_decay(self, step):
    #     """
    #     计算分段余弦衰减的掩码数量
    #
    #     Args:
    #         step: 当前迭代步数
    #
    #     Returns:
    #         mask_length: 当前步数对应的掩码数量
    #     """
    #     # 获取排序的步数端点
    #     steps = sorted(self.piecewise_schedule.keys())
    #
    #     # 找到当前步数所在的区间
    #     start_step = 0
    #     end_step = steps[-1]
    #     start_value = self.piecewise_schedule[0]
    #     end_value = self.piecewise_schedule[steps[-1]]
    #
    #     for i in range(len(steps) - 1):
    #         if step <= steps[i + 1]:
    #             start_step = steps[i]
    #             end_step = steps[i + 1]
    #             start_value = self.piecewise_schedule[start_step]
    #             end_value = self.piecewise_schedule[end_step]
    #             break
    #
    #     # 如果步数超过最后一个端点，直接返回最终值
    #     if step >= steps[-1]:
    #         return self.piecewise_schedule[steps[-1]]
    #
    #     # 在当前区间内计算余弦衰减
    #     if start_step == end_step:
    #         return start_value
    #
    #     # 计算在当前区间内的进度（0到1）
    #     progress = (step - start_step) / (end_step - start_step)
    #
    #     # 应用余弦衰减：cos(0) = 1, cos(π/2) = 0
    #     # cosine_decay = np.cos(progress * math.pi / 2)
    #     cosine_decay = 1 - (progress) ** 2
    #
    #     # 线性插值：从start_value衰减到end_value
    #     mask_length = end_value + (start_value - end_value) * cosine_decay
    #
    #     return mask_length

    def random_masking(self, x, orders):
        # generate token mask
        bsz, seq_len, embed_dim = x.shape
        mask_rate = self.mask_ratio_generator.rvs(1)[0]
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask = torch.zeros(bsz, seq_len, device=x.device)
        mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
                             src=torch.ones(bsz, seq_len, device=x.device))
        return mask


    def obtain_rela_mask(self, current, cache_dic, bsz, x):
        # bsz ,seq_len, dim  相对seq_len， 上一step预测的token
        if cache_dic['prev_mask_to_pred'] is None:
            prev_mask_to_pred = torch.zeros_like(current['mask']).to(device).bool()
        else:
            prev_mask_to_pred = cache_dic['prev_mask_to_pred']
        # bsz, seq_len, dim  相对于seq_len, 截至本step已经预测的token
        cur_pred_mask = ~current['mask']
        # bsz, seq_len, dim  相对于seq_len, 除了上一step外，已经预测的token
        prev_pred_mask = cur_pred_mask & (~prev_mask_to_pred)

        # bsz, pred_len, dim 相对于所有已经预测的token， 除了上一step已经预测的token
        current["prev_pred_rela"] = torch.cat([torch.zeros(bsz, self.buffer_size, device=x.device),
                                               torch.masked_select(prev_pred_mask, cur_pred_mask).reshape(bsz, -1)],
                                              dim=1)

        current["prev_pred_rela_with_buffer"] = torch.cat([torch.ones(bsz, self.buffer_size, device=x.device),
                                                           torch.masked_select(prev_pred_mask, cur_pred_mask).reshape(
                                                               bsz, -1)],
                                                          dim=1)
        # bsz, pred_len, dim 相对于所有已经预测的token， 上一step已经预测的token
        current["prev_mtp_rela"] = torch.cat([torch.zeros(bsz, self.buffer_size, device=x.device),
                                              torch.masked_select(prev_mask_to_pred, cur_pred_mask).reshape(bsz, -1)],
                                             dim=1)

        current["prev_mtp_rela_with_buffer"] = torch.cat([torch.ones(bsz, self.buffer_size, device=x.device),
                                                          torch.masked_select(prev_mask_to_pred, cur_pred_mask).reshape(
                                                              bsz, -1)],
                                                         dim=1)
        current["enco_update_mask"] = torch.ones_like(current["prev_mtp_rela"], device=x.device)

    def forward_mae_encoder(self, x, mask, class_embedding, current, cache_dic):
        # orig_token_cache = current['token_cache']
        # current['token_cache'] = False

        # bsz, seq_len, dim
        x = self.z_proj(x)
        bsz, seq_len, embed_dim = x.shape

        # concat buffer
        # bsz, 64+seq_len, dim
        x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=x.device), x], dim=1)
        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        # random drop class embedding during training
        if self.training:
            drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
            drop_latent_mask = drop_latent_mask.unsqueeze(-1).to(device).to(x.dtype)
            class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding

        x[:, :self.buffer_size] = class_embedding.unsqueeze(1)

        # encoder position embedding
        x = x + self.encoder_pos_embed_learned
        x = self.z_proj_ln(x)

        # dropping
        #bsz, 64+predicted_len, dim
        x = x[(1-mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
        _, cur_pred_len, _ = x.shape

        # apply Transformer blocks
        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.encoder_blocks:
                x = checkpoint(block, x)
        else:
            if current['token_cache'] and not current['is_force_fresh']:
                self.obtain_rela_mask(current, cache_dic, bsz, x)
                x = torch.masked_select(x, current["prev_mtp_rela_with_buffer"].
                                                unsqueeze(-1).expand(-1, -1, embed_dim).bool()).reshape(bsz, -1, embed_dim)
            if current['cfg_cache'] and not current['is_force_fresh']:
                x = x[:int(bsz / 2)].detach().clone()

            for i, block in enumerate(self.encoder_blocks):
                current['enco_layer_idx'] = i
                if current['cal_flops']:
                    self.cumulate_flops(block, x=x, current=current, cache_dic=cache_dic)
                x = block(x, current, cache_dic)

            if current['cfg_cache'] and not current['is_force_fresh']:
                x = torch.cat([x, x], dim = 0)

            if current['token_cache'] and not current['is_force_fresh']:
                new_x_full = torch.zeros(bsz, cur_pred_len, embed_dim).to(x)
                new_x_full[current["prev_mtp_rela_with_buffer"].bool()] = x.view(-1, x.size(-1))
                x = new_x_full
                # x = torch.where(current["prev_mtp_rela_with_buffer"].unsqueeze(-1).expand(-1, -1, 768).bool(), x, new_x_full)

        x = self.encoder_norm(x)
        # current['token_cache'] = orig_token_cache
        return x

    def obtain_update_mask(self, mask, mask_with_buffer, prev_mask_to_pred, mask_to_pred, current):
        '''
        1表示需要更新的 token，0表示使用缓存的 token
        '''
        device = mask.device
        bsz = mask.size(0)
        buffer_zeros = torch.zeros(bsz, self.buffer_size, device=device)
        mask_zeros = torch.zeros_like(mask, device=device)
        #没有diffusion的token进行更新
        unpredicted_tokens_mask = mask_with_buffer.bool()
        #buffer中的token进行更新
        buffer_tokens_mask = torch.cat([torch.ones_like(buffer_zeros), mask_zeros], dim=1).bool()
        # 已经diffusion的 token 不需要更新
        predicted_tokens_mask = torch.cat([buffer_zeros, torch.logical_not(mask)], dim=1).bool()  # 确保为布尔类型
        if prev_mask_to_pred is None:
            prev_mask_to_pred = torch.zeros_like(mask_to_pred, device=device).bool()
        # 上一轮预测的 token 需要更新
        prev_mask_to_pred = torch.cat([buffer_zeros, prev_mask_to_pred], dim=1).bool()
        # 当前轮预测的 token 需要更新
        mask_to_pred_mask = torch.cat([buffer_zeros, mask_to_pred], dim=1).bool()
        # 所有buffer都不更新
        all_cache_mask = torch.zeros_like(mask_with_buffer, device=device).bool()
        none_cache_mask = torch.ones_like(mask_with_buffer, device=device).bool()

        # 验证三个掩码逐元素与的结果是否为全 True
        combined_mask = buffer_tokens_mask | unpredicted_tokens_mask | predicted_tokens_mask
        assert combined_mask.all(), "Combined mask is not fully True!"
        ''''''
        unpredicted_with_buffer_pt = unpredicted_tokens_mask | buffer_tokens_mask | prev_mask_to_pred | mask_to_pred_mask
        unpredicted_with_pt = unpredicted_tokens_mask | prev_mask_to_pred | mask_to_pred_mask
        predicted_with_buffer_pt = predicted_tokens_mask | buffer_tokens_mask | prev_mask_to_pred | mask_to_pred_mask
        predicted_with_pt = predicted_tokens_mask | prev_mask_to_pred | mask_to_pred_mask
        prev_mask_to_pred_with_buffer = prev_mask_to_pred | buffer_tokens_mask
        ''''''
        # current["update_mask"] = unpredicted_with_buffer_pt
        current["update_mask"] = none_cache_mask
        current['mask_to_pred_mask'] = mask_to_pred_mask
        current['prev_mask_to_pred_mask'] = prev_mask_to_pred
        current['prev_mask_to_pred_with_buffer'] = prev_mask_to_pred_with_buffer
        current['predicted_mask'] = predicted_tokens_mask


    def global_force_fresh(self, cache_dic, current):
        """Return True when this step must bypass the caches and recompute everything.

        This holds for every step before ``cache_dic['start_step']``. From then on,
        it holds every ``cache_dic['fresh_t']``-th step, counting start_step itself.
        """
        step = current['step']
        start = cache_dic['start_step']
        if step < start:
            return True
        return (step - start) % cache_dic['fresh_t'] == 0

    def cumulate_flops(self, block, **kwargs):
        """Measure one forward pass of ``block`` with calflops and add it to the running totals.

        kwargs are deep-copied first so the measuring run does not modify the
        caller's objects (e.g. ``current`` / ``cache_dic``). Adds to self.flops and
        self.macs, and overwrites self.params with the block's parameter count.
        """
        kwargs_flops = {key: copy.deepcopy(value) for key, value in kwargs.items()}
        flops, macs, params = calculate_flops(model=block,
                                              kwargs=kwargs_flops,
                                              print_results=False,
                                              output_as_string=False)
        # Drop the reference to the deep copies as soon as they are no longer needed.
        # (The former per-item ``del value`` loop only unbound a loop variable.)
        del kwargs_flops
        self.flops += flops
        self.macs += macs
        self.params = params


    def print_flops(self):
        """Print the accumulated FLOPs/MACs and the last measured parameter count, then reset them."""
        flops = flops_to_string(self.flops, units=None, precision=2)
        macs = macs_to_string(self.macs, units=None, precision=2)
        # self.params is only set by cumulate_flops; default to 0 if nothing was measured
        # (previously this raised AttributeError).
        params = params_to_string(getattr(self, 'params', 0), units=None, precision=2)
        # Label with the actual model class (was a copy-pasted BERT label).
        print("%s FLOPs:%s   MACs:%s   Params:%s \n" % (type(self).__name__, flops, macs, params))
        # Reset the counters.
        self.flops, self.macs, self.params = 0, 0, 0

    def forward_mae_decoder(self, x, mask, current, cache_dic):
        """Decode encoder output into per-token conditions for the diffusion head.

        Args:
            x: encoder output for the kept (buffer + unmasked) positions.
            mask: (bsz, seq_len) mask; 0 = already-predicted token, 1 = not yet predicted.
            current: per-step state dict. Reads 'token_cache', 'cfg_cache',
                'is_force_fresh' and 'cal_flops'. Writes the entries set by
                obtain_update_mask, plus 'mask_to_pred_len', 'prev_mask_to_pred_len'
                and 'layer_idx', and, when caching, 'original_x' and 'orig_*'.
            cache_dic: reads 'prev_mask_to_pred' and 'mask_to_pred'; reads and writes
                cache_dic['cache']['de_ou'] and cache_dic['cache'][layer_idx]['diff'].

        Returns:
            (bsz, seq_len, decoder_embed_dim) tensor with the buffer positions removed
            and the diffusion position embedding added.
        """

        x = self.decoder_embed(x)
        # 0 = already-predicted token, 1 = not-yet-predicted token; buffer slots are 0.
        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
        # Fill every position with the mask token ...
        mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
        x_after_pad = mask_tokens.clone()
        # ... then write the encoder outputs back into the kept (mask == 0) positions.
        x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        # Decoder position embedding
        x = x_after_pad + self.decoder_pos_embed_learned

        self.obtain_update_mask(mask, mask_with_buffer, cache_dic['prev_mask_to_pred'], cache_dic['mask_to_pred'], current)

        # Token counts are taken from the first sample only.
        current['mask_to_pred_len'] = torch.sum(cache_dic['mask_to_pred'][0])
        if cache_dic['prev_mask_to_pred'] is None:
            current['prev_mask_to_pred_len'] = 0
        else:
            current['prev_mask_to_pred_len'] = torch.sum(cache_dic['prev_mask_to_pred'][0])

        if current['token_cache'] and not current['is_force_fresh']:
            B, N, C = x.shape
            # Refresh the cached decoder input at buffer positions and at tokens
            # predicted in the previous step; keep the cached values elsewhere.
            cache_dic['cache']['de_ou'] = torch.where(current['prev_mask_to_pred_with_buffer'].unsqueeze(-1).expand(-1, -1, C), x,
                                                      cache_dic['cache']['de_ou'])
            x = copy.deepcopy(cache_dic['cache']['de_ou'])
            current['original_x'] = x
            # Keep only the positions flagged in update_mask.
            x = x[current["update_mask"].nonzero(as_tuple=True)].reshape(x.shape[0], -1, x.shape[2])
        else:
            cache_dic['cache']['de_ou'] = x

        B, N, C = x.size()
        x_origi = x
        if current['cfg_cache'] and not current['is_force_fresh']:
            # CFG cache: run the blocks on the first half of the batch only, and
            # temporarily shrink the per-token masks to match. The originals are
            # restored after the blocks.
            x = x[:int(B / 2)].detach().clone()
            current['orig_update_mask'] = current['update_mask']
            current['orig_mask_to_pred_mask'] = current['mask_to_pred_mask']
            current['orig_prev_mask_to_pred_mask'] = current['prev_mask_to_pred_mask']
            current['orig_predicted_mask'] = current['predicted_mask']

            current['update_mask'] = current['update_mask'][:int(B / 2)].detach().clone()
            current['mask_to_pred_mask'] = current['mask_to_pred_mask'][:int(B / 2)].detach().clone()
            current['prev_mask_to_pred_mask'] = current['prev_mask_to_pred_mask'][:int(B / 2)].detach().clone()
            current['predicted_mask'] = current['predicted_mask'][:int(B / 2)].detach().clone()

        for i, block in enumerate(self.decoder_blocks):
            current['layer_idx'] = i
            if current['cal_flops']:
                self.cumulate_flops(block, x=x, current=current, cache_dic=cache_dic)
            x = block(x, current, cache_dic)


        if current['cfg_cache'] and not current['is_force_fresh']:
            # Rebuild the second half as first half + cached (second - first) difference.
            # current['layer_idx'] is the last decoder layer index at this point.
            if current['token_cache']:
                diff = cache_dic['cache'][current['layer_idx']]['diff'][current['update_mask'].nonzero(as_tuple=True)].reshape(int(B / 2), -1, C)
            else:
                diff = cache_dic['cache'][current['layer_idx']]['diff']

            x_origi[:int(B / 2)] = x
            x_origi[int(B / 2):] = x + diff
            x = x_origi
            current['update_mask'] = current['orig_update_mask']
            current['mask_to_pred_mask'] = current['orig_mask_to_pred_mask']
            current['prev_mask_to_pred_mask'] = current['orig_prev_mask_to_pred_mask']
            current['predicted_mask'] = current['orig_predicted_mask']
        else:
            # Full pass: store (second half - first half) under the last layer index.
            # NOTE(review): assumes the batch is two stacked CFG halves
            # (cond; uncond); confirm what this means when cfg == 1.0 and the batch
            # is not doubled.
            cache_dic['cache'][current['layer_idx']]['diff'] = x[int(B / 2):] - x[:int(B / 2)]


        if current['token_cache'] and not current['is_force_fresh']:
            # Write the recomputed positions back into the full cached tensor.
            current['original_x'][(current['update_mask']).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
            x = current['original_x']
        x = self.decoder_norm(x)
        # Drop the buffer positions.
        x = x[:, self.buffer_size:]
        x = x + self.diffusion_pos_embed_learned
        return x


    def forward_loss(self, z, target, mask):
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul)
        loss = self.diffloss(z=z, target=target, mask=mask)
        return loss

    def forward(self, imgs, labels):
        """Training forward pass: return the diffusion loss on randomly masked tokens.

        NOTE(review): forward_mae_encoder and forward_mae_decoder require
        ``current`` and ``cache_dic`` arguments, but they are called below without
        them, so this method raises TypeError as written. Confirm how training is
        meant to supply that state.
        """
        # class embed
        class_embedding = self.class_emb(labels)
        # patchify and mask (drop) tokens
        x = self.patchify(imgs)
        gt_latents = x.clone().detach()
        orders = self.sample_orders(bsz=x.size(0))
        mask = self.random_masking(x, orders)
        # mae encoder
        # example shape noted by the author: 4 89 768
        x = self.forward_mae_encoder(x, mask, class_embedding)
        # example shape noted by the author: 4 256 768
        # mae decoder
        z = self.forward_mae_decoder(x, mask)
        # diffloss
        loss = self.forward_loss(z=z, target=gt_latents, mask=mask)
        return loss
    def renew_cache(self, bsz, depth, num_heads, embed_dim, device, cfg_cache, cache_dic):
        assert embed_dim % num_heads == 0, 'dim should be divisible by num_heads'
        head_dim = embed_dim // num_heads
        if device.type != 'cpu':
            dtype = torch.float16
        else:
            dtype = torch.float32
        if not cfg_cache:
            bsz = int(bsz * 2)
        for j in range(depth):
            cache_dic['enco_cache'][j]['k'] = torch.zeros(
                (
                    bsz,
                    num_heads,
                    320,
                    head_dim,
                ), dtype=dtype, device=device)

            cache_dic['enco_cache'][j]['v'] = torch.zeros(
                (
                    bsz,
                    num_heads,
                    320,
                    head_dim,
                ), dtype=dtype, device=device)
            cache_dic['enco_cache'][j]['cur_kv_len'] = 0


    def get_cache_dic(self, bsz, depth, num_heads, embed_dim, device, cfg_cache):
        assert embed_dim % num_heads == 0, 'dim should be divisible by num_heads'
        head_dim = embed_dim // num_heads
        cache_dic = {}
        cache = {}
        enco_cache = {}

        if device.type != 'cpu':
            dtype = torch.float16
        else:
            dtype = torch.float32
        if not cfg_cache:
            bsz = int(bsz*2)
        for j in range(depth):
            cache[j] = {}
            enco_cache[j] = {}
            enco_cache[j]['k'] = torch.zeros(
            (
                bsz,
                num_heads,
                320,
                head_dim,
            ), dtype = dtype, device=device)

            enco_cache[j]['v'] = torch.zeros(
            (
                bsz,
                num_heads,
                320,
                head_dim,
            ), dtype = dtype, device=device)
            enco_cache[j]['cur_kv_len'] = 0

        cache_dic['cache'] = cache
        cache_dic['enco_cache'] = enco_cache
        current = {}
        return cache_dic, current

    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, num_steps = 0, args = None,
                      progress=False):
        device = torch.device(args.device)
        cache_dic, current = self.get_cache_dic(bsz,  self.encoder_depth, self.encoder_num_heads, self.encoder_embed_dim, device, args.cfg_cache)
        current['depth'] = self.encoder_depth
        current['token_cache'] = args.token_cache
        current['cfg_cache'] = args.cfg_cache
        current['cal_flops'] = args.cal_flops
        current['cal_caching_num'] = args.cal_caching_num
        current['device'] = args.device
        current['num_iter'] = num_iter
        if current['num_iter'] == 32:
            cache_dic['start_step'] = 5
            cache_dic['fresh_t'] = 6

        elif current['num_iter'] == 33:
            cache_dic['start_step'] = 5
            cache_dic['fresh_t'] = 6

        else:
            cache_dic['start_step'] = 5
            cache_dic['fresh_t'] = 6

        cache_dic['prev_mask_to_pred'] = None
        cache_dic['mask_to_pred'] = None
        random_number = num_steps
        current['random_number'] = random_number
        # 初始化
        mask = torch.ones(bsz, self.seq_len).to(device)
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).to(device)

        # 定义保存目录
        save_root = 'sample_order_save'
        bsz_dir = os.path.join(save_root, f'bsz_{bsz}num_iter_{num_iter}model_depth_{self.encoder_depth}_ngpu_{args.gpu_num}_inum_{args.num_images}')
        os.makedirs(bsz_dir, exist_ok=True)
        # orders = self.sample_orders(bsz, device)
        # original_order = generate_full_autoregressive_order(bsz)
        # orders = convert_order(original_order, 16, 16).to(device).long()
        
        # 根据args.order_type选择order生成方式
        if args.order_type == 'random':
            orders = self.sample_orders(bsz, device)
        elif args.order_type == 'autoregressive':
            original_order = generate_full_autoregressive_order(bsz)
            orders = convert_order(original_order, 16, 16).to(device).long()
        else:
            raise ValueError(f"Unknown order_type: {args.order_type}")
            
        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)
        if current['cal_caching_num']:
            total_remaining_tokens = 0
            total_pruning_step = 0
        # 生成潜变量
        for step in indices:
            current['step'] = step
            cur_tokens = tokens.clone()
            # 类别嵌入和 CFG
            if labels is not None:
                class_embedding = self.class_emb(labels)
            else:
                class_embedding = self.fake_latent.repeat(bsz, 1)
            if not cfg == 1.0:
                tokens = torch.cat([tokens, tokens], dim=0)
                class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
                mask = torch.cat([mask, mask], dim=0)

            if False:
                #256， 192， 128， 0

                # mask_lens = [255, 254, 253, 252, 251, 250, 249, 248, 247, 246,
                #              245, 244, 243, 241, 238, 236, 234, 231, 228, 225,
                #              222, 219, 216, 212, 209, 205, 201, 197, 193, 189,
                #              185, 181, 176, 171, 167, 162, 157, 152, 147, 142,
                #              136, 131, 126, 120, 115, 109, 103, 97, 92, 86,
                #              80, 74, 68, 62, 56, 49, 43, 37, 31, 25,
                #              18, 12, 6, 1]

                mask_lens = [
                 255.0, 254.0, 253.0, 251.0, 248.0, 245.0, 241.0, 237.0, 232.0, 227.0,
                 221.0, 215.0, 208.0, 201.0, 193.0, 185.0, 176.0, 167.0, 158.0, 148.0,
                 138.0, 128.0, 117.0, 106.0, 95.0, 83.0, 72.0, 60.0, 48.0, 36.0,
                 24.0, 12.0, 1.0]

                # mask_lens = [254, 251, 244, 236, 225, 212, 197, 181, 162, 142,
                #              120, 97, 74, 49, 25, 1]
                # mask_lens = [251.0, 236.0, 212.0, 181.0, 142.0, 97.0, 49.0, 1.0]

                #total33
                # 24，6，2
                # mask_lens = [255, 254, 253, 252, 251, 250, 249, 248, 247, 246,
                #              244, 243, 240, 237, 234, 231, 228, 224, 221, 217,
                #              213, 209, 201, 193, 185, 176, 167, 158, 149, 139,
                #              128, 64, 1]

                # 12, 2, 2
                # mask_lens = [255, 254, 253, 251, 249, 247, 244, 240, 234, 224,
                #              221, 193, 168, 128, 64, 1]

                mask_len = mask_lens[step]
                mask_len = torch.tensor([mask_len], dtype=torch.float32).to(device)

                '''
                # mask ratio for the next round, using piecewise cosine decay
                mask_len_value = self.compute_piecewise_cosine_decay(step)
                mask_len = torch.Tensor([np.floor(mask_len_value)]).cuda()
    
                mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                         torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
                '''
            else:
                # mask ratio for the next round, following MaskGIT and MAGE.
                mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
                mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()
                # masks out at least one for the next iteration
                mask_len = torch.maximum(torch.Tensor([1]).cpu(),
                                         torch.minimum(torch.sum(mask.cpu(), dim=-1, keepdims=True) - 1, mask_len.cpu()))

            # 获取下一次迭代的掩码和本次需要预测的位置
            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len, device)

            if step >= num_iter - 1:
                mask_to_pred = mask[:bsz].bool()
            else:
                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
            if not cfg == 1.0:
                mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)


            cache_dic['mask_to_pred'] = mask_to_pred

            current['is_force_fresh'] = self.global_force_fresh(cache_dic, current)
            current['use_cache'] = not current['is_force_fresh']
            current['mask'] = mask.bool()
            # MAE 编码器
            x = self.forward_mae_encoder(tokens, mask, class_embedding, current, cache_dic)
            z = self.forward_mae_decoder(x, mask, current, cache_dic)
            if current['cal_caching_num'] and current['token_cache'] and not current['is_force_fresh']:
                total_remaining_tokens += current['remainging_token']
                total_pruning_step += 1
            cache_dic['prev_mask_to_pred'] = mask_to_pred
            mask = mask_next

            # 选取需要预测的潜变量
            z_pred = z[mask_to_pred.nonzero(as_tuple=True)]
            # CFG 调度
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            sampled_token_latent = self.diffloss.sample(z_pred, temperature, cfg_iter, step=step, ar_num_iter=num_iter, bsz=bsz)
            if current['cal_flops']:
                self.cumulate_flops(self.diffloss, z=z_pred, temperature=temperature, cfg=cfg_iter, step=step, ar_num_iter=num_iter, bsz=bsz)

            if not cfg == 1.0:
                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # 移除空类别的样本
                mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)

            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            tokens = cur_tokens.clone()
        if current['cal_caching_num']:
            print('total_remaining_tokens', total_remaining_tokens, 'total_pruning_step',
                  total_pruning_step, 'mean_remaining_tokens', int(total_remaining_tokens / total_pruning_step))
        if current['cal_flops']:
            self.print_flops()

        # 还原为图像格式
        tokens = self.unpatchify(tokens)
        return tokens

def mar_base(**kwargs):
    """Construct the base-size MAR model.

    Encoder and decoder are symmetric: 12 blocks, 12 attention heads and an
    embedding width of 768 each, with an MLP ratio of 4 and LayerNorm
    (eps=1e-6) as the normalisation layer.

    Args:
        **kwargs: Extra keyword arguments forwarded verbatim to ``MAR``.
            Passing any of the architecture keys fixed here raises the usual
            "multiple values for keyword argument" TypeError.

    Returns:
        The constructed ``MAR`` instance.
    """
    width, depth, heads = 768, 12, 12
    return MAR(
        encoder_embed_dim=width,
        encoder_depth=depth,
        encoder_num_heads=heads,
        decoder_embed_dim=width,
        decoder_depth=depth,
        decoder_num_heads=heads,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


def mar_large(**kwargs):
    """Construct the large-size MAR model.

    Encoder and decoder are symmetric: 16 blocks, 16 attention heads and an
    embedding width of 1024 each, with an MLP ratio of 4 and LayerNorm
    (eps=1e-6) as the normalisation layer.

    Args:
        **kwargs: Extra keyword arguments forwarded verbatim to ``MAR``.
            Passing any of the architecture keys fixed here raises the usual
            "multiple values for keyword argument" TypeError.

    Returns:
        The constructed ``MAR`` instance.
    """
    width, depth, heads = 1024, 16, 16
    return MAR(
        encoder_embed_dim=width,
        encoder_depth=depth,
        encoder_num_heads=heads,
        decoder_embed_dim=width,
        decoder_depth=depth,
        decoder_num_heads=heads,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )


def mar_huge(**kwargs):
    """Construct the huge-size MAR model.

    Encoder and decoder are symmetric: 20 blocks, 16 attention heads and an
    embedding width of 1280 each, with an MLP ratio of 4 and LayerNorm
    (eps=1e-6) as the normalisation layer.

    Args:
        **kwargs: Extra keyword arguments forwarded verbatim to ``MAR``.
            Passing any of the architecture keys fixed here raises the usual
            "multiple values for keyword argument" TypeError.

    Returns:
        The constructed ``MAR`` instance.
    """
    width, depth, heads = 1280, 20, 16
    return MAR(
        encoder_embed_dim=width,
        encoder_depth=depth,
        encoder_num_heads=heads,
        decoder_embed_dim=width,
        decoder_depth=depth,
        decoder_num_heads=heads,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
