# 项目创建时间：2024/9/12 21:13
# This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
# All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import pickle
import random
import logging
import os
import numpy as np
import cv2
import torch
from einops import rearrange,repeat
from torch.optim.lr_scheduler import LambdaLR
from torchvision.transforms.functional import resize
from torchvision.utils import make_grid
from torch.nn import functional as F

from torch.cuda.amp import autocast

from os.path import join as opj
#from ldm.models.model.utils import tensor2img
from torch import nn
from transformers.activations import QuickGELUActivation as QuickGELU
from contextlib import nullcontext
from omegaconf import ListConfig
from torchvision.transforms import Resize

from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL

from ldm.util import log_txt_as_img ,instantiate_from_config,default,ismap, isimage,tensor2img

from ldm.models.diffusion.ddpm import LatentDiffusion
#from ldm.models.model import  model
#from ldm.models.model.unet import register_unet_output2
from ldm.models.diffusion.ddim import DDIMSampler
#from ldm.models.util import dice_loss

from ldm.models.QFormer.QFormer import build_QFormer_encoding
from ldm.models.TaskFormer.TaskFormer import build_TaskFormer_encoding
from ldm.models.vision_backbone.swin_transformer import build_backbone
from ldm.models.vision_backbone.swin_transformer_v2 import build_pose_encoder

from ldm.models.decoder.appearance_encoder import build_Decoder

__conditioning_keys__ = {'concat': 'c_concat',
                         'crossattn': 'c_crossattn',
                         'adm': 'y'}

def dice_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        smooth: float = 1.0,
    ):
    """
    Compute the soft DICE loss (a mask analogue of generalized IoU).

    Both tensors are flattened from dim 2 onward, so either argument may be given
    as ``(B, C, H, W)`` or ``(B, C, H*W)``; after flattening they only need to be
    broadcast-compatible. No sigmoid is applied: ``inputs`` must already be
    probabilities / soft mask values.

    Args:
        inputs: float tensor of predictions, shape ``(B, C, *spatial)``.
        targets: binary labels (0 = negative, 1 = positive), shape
            ``(B, C, *spatial)`` (or any shape broadcastable to ``inputs`` after
            flattening). Tensors with 2 or fewer dims are used as given.
        smooth: additive smoothing applied to numerator and denominator; keeps
            the loss finite (and 0) when prediction and target are both empty.

    Returns:
        Scalar tensor: mean over batch and channels of
        ``1 - (2 * sum(x * y) + smooth) / (sum(x) + sum(y) + smooth)``.
    """
    inputs = inputs.flatten(2)
    # Flatten the targets the same way so a target with the same (B, C, H, W)
    # shape as the inputs lines up with the flattened predictions.
    if targets.dim() > 2:
        targets = targets.flatten(2)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + smooth) / (denominator + smooth)
    return loss.mean()
# human_loss, human_semantic_feature, agn_mask, down_self_residual = 
#                           self.semantic_spatical_model(c["ref_human"], c["caption"],c['pose_img'],c['z_agn_mask'])
class build_model(nn.Module):
    """Conditioning network: Swin backbone -> QFormer -> TaskFormer, plus auxiliary losses.

    ``forward`` turns a batch of conditioning images and captions into QFormer
    multimodal embeddings. When ``use_auxiliary_loss`` is set it also computes:

    * a pose loss: ``1 - cosine similarity`` between TaskFormer pose predictions
      (projected by ``pose_feature_proj``) and features of a frozen pose encoder,
      evaluated on the second half of the batch only;
    * a mask loss: DICE loss between the TaskFormer's last-layer mask prediction
      and the target mask.

    ``cfg`` keys read here: ``crossattn_mask_threshold`` and
    ``use_crossattn_and_feature`` (both *popped* from ``cfg``),
    ``use_auxiliary_loss``, ``SwinTransformer_backbone``, ``SwinTransformerV2``,
    ``QFormer`` and ``TaskFormer``. ``cfg["TaskFormer"]`` is updated in place
    with the two popped values.
    """

    def __init__(self, cfg):
        super().__init__()
        self.crossattn_mask_threshold = cfg.pop('crossattn_mask_threshold')
        self.use_crossattn_and_feature = cfg.pop('use_crossattn_and_feature')
        self.use_auxiliary_loss = cfg["use_auxiliary_loss"]

        self.backbone = build_backbone(cfg["SwinTransformer_backbone"])
        self.pose_encoder = build_pose_encoder(cfg["SwinTransformerV2"])
        self.QFormer = build_QFormer_encoding(cfg["QFormer"])

        # The TaskFormer needs the two flags popped above.
        cfg["TaskFormer"].update({'crossattn_mask_threshold': self.crossattn_mask_threshold})
        cfg["TaskFormer"].update({'use_crossattn_and_feature': self.use_crossattn_and_feature})
        self.TaskFormer = build_TaskFormer_encoding(cfg["TaskFormer"])

        # 1x1 convs that map the 35-channel TaskFormer pose predictions onto the
        # channel widths of pose-encoder levels 0..2 (128/256/512).
        # They must be registered through nn.ModuleList: a plain Python list is
        # invisible to .parameters(), .to() and state_dict(), so the projections
        # were never optimised nor saved. Checkpoints written before this change
        # simply lack these keys (load them non-strictly).
        proj_in_channel = 35
        self.pose_feature_proj = nn.ModuleList(
            nn.Conv2d(proj_in_channel, out_channels, kernel_size=1, stride=1, padding=0)
            for out_channels in (128, 256, 512)
        )

        print("cond model init successfully!")

    def forward(self, img_cond, text, mask, pose_img):
        """Encode conditioning images/captions and compute the auxiliary losses.

        Args:
            img_cond: batch of conditioning images for the backbone. For the pose
                loss only the second half of the batch is used (the caller in this
                file concatenates ``[cloth, human]``).
            text: captions, passed to the QFormer as ``text_input``.
            mask: target mask with ``h*w`` elements per sample, where ``h*w`` is
                the token count of the TaskFormer mask prediction.
            pose_img: input of the frozen pose encoder; only used when
                ``use_auxiliary_loss`` is set.

        Returns:
            ``(loss, qformer_embeds, target_mask, None)``.
            ``loss`` is ``{'mask_loss': ..., 'pose_loss': ...}``; both values are
            zero tensors when the auxiliary loss is disabled.
            ``qformer_embeds`` are the QFormer multimodal embeddings.
            ``target_mask`` is ``mask`` reshaped to ``(bs, 1, h, w)`` in the
            dtype of the mask prediction.
        """
        x, features = self.backbone(img_cond)
        pose_features = None
        if self.use_auxiliary_loss:
            # The pose encoder only supplies regression targets: no gradients.
            with torch.no_grad():
                _, pose_features = self.pose_encoder(pose_img)

        bs = x.shape[0]

        # The deepest backbone level goes to the QFormer as frozen image embeddings.
        qformer_out = self.QFormer.extract_features(
            {"image": None, "text_input": text}, image_embeds_frozen=features.pop())
        task_inputs = [qformer_out.multimodal_embeds, features[2], features[1]]
        _, _, pred_pose_features, pred_masks, pred_mask_feature = self.TaskFormer(task_inputs, features[0])

        # Masks are predicted on a square h x w grid flattened to h*w tokens.
        h = w = int(pred_masks[-1].shape[-1] ** 0.5)
        mask_dtype = pred_masks[-1].dtype

        if not self.use_auxiliary_loss:
            # No supervision: report zero losses so callers can still weight/sum them.
            zero = pred_masks[-1].new_zeros(())
            return_mask = mask.reshape(bs, 1, h, w).to(dtype=mask_dtype)
            return {'mask_loss': zero, 'pose_loss': zero}, qformer_out.multimodal_embeds, return_mask, None

        pose_loss = self._pose_loss(pose_features, pred_pose_features, bs // 2, features[0])
        mask_loss, flat_mask = self._mask_loss(pred_masks[-1], pred_mask_feature, mask, bs, h, w)
        return_mask = flat_mask.reshape(bs, 1, h, w).to(dtype=mask_dtype)
        loss = {'mask_loss': mask_loss, 'pose_loss': pose_loss}
        return loss, qformer_out.multimodal_embeds, return_mask, None

    def _pose_loss(self, pose_features, pred_pose_features, half_bs, like, eps=1e-6):
        """Mean over three levels of ``1 - cos_sim(proj(pred), target)``, on samples ``half_bs:``.

        Pose-encoder level ``i`` (token layout ``[n, L, C]`` of a square map) is
        paired with TaskFormer prediction ``-(i + 1)``. ``like`` supplies the
        device/dtype for the projection convs. Norms are clamped at ``eps`` so an
        all-zero map yields a finite loss instead of NaN.
        """
        total = 0
        for level in range(3):
            target = pose_features[level][half_bs:]
            pred = pred_pose_features[-level - 1][half_bs:]
            n = target.shape[0]
            size = int(target.shape[-2] ** 0.5)  # token count of a square feature map
            # Keep the projection on the features' (possibly half-precision) device/dtype.
            proj = self.pose_feature_proj[level].to(like.device, dtype=like.dtype)
            pred = proj(pred.permute(0, 2, 1).reshape(n, -1, size, size))
            target = target.permute(0, 2, 1).reshape(n, -1, size, size)
            dot = torch.sum(target * pred, dim=(1, 2, 3))
            target_norm = torch.sqrt(torch.sum(target ** 2, dim=(1, 2, 3))).clamp_min(eps)
            pred_norm = torch.sqrt(torch.sum(pred ** 2, dim=(1, 2, 3))).clamp_min(eps)
            total = total + (1 - (dot / target_norm / pred_norm).mean())
        return total / 3

    def _mask_loss(self, pred_mask, pred_mask_feature, mask, bs, h, w):
        """DICE loss between the last-layer mask prediction and ``mask``.

        Returns ``(loss, flat_mask)`` where ``flat_mask`` is ``mask`` as ``(bs, 1, h*w)``.
        """
        flat_mask = mask.flatten(1).unsqueeze(1)  # [bs, 1, hw]
        if self.use_crossattn_and_feature:
            # pred_mask holds one map per query ([bs, Q, hw]). For each sample keep
            # the query whose mask feature has the smallest mean squared value where
            # flat_mask is 0 (assumes a binary target mask).
            outside = torch.mean((pred_mask_feature * (1 - flat_mask)) ** 2, dim=2)  # [bs, Q]
            best = torch.argmin(outside, dim=1)  # [bs]
            pred_mask = pred_mask[torch.arange(bs, device=pred_mask.device), best, :]  # [bs, hw]
        else:
            pred_mask = pred_mask.sigmoid()
        pred_mask = pred_mask.unsqueeze(1).reshape(bs, 1, h, w)
        loss = dice_loss(pred_mask, flat_mask)

        # Temporary diagnostics: foreground pixel counts of prediction vs. target.
        count_pred_mask = torch.sum(pred_mask > 0.5, dim=(1, 2, 3))  # [bs]
        count_mask = torch.sum(flat_mask == 1, dim=(1, 2))  # [bs]
        print(f"pred_mask_1_num: {count_pred_mask}, mask_1_num:{count_mask}")
        return loss, flat_mask

if __name__ == "__main__":
    # Smoke test: build the conditioning model from the YAML config and report
    # how many parameter tensors (and scalar parameters) are trainable.
    import yaml

    with open('./configs/configs.yaml', 'r', encoding='utf-8') as cfg_file:
        config = yaml.safe_load(cfg_file)

    # Dummy inputs for a forward pass; the forward call itself is currently disabled.
    img = torch.rand([1, 3, 512, 512], dtype=torch.float32)
    pose_imag = torch.rand([1, 3, 256, 256], dtype=torch.float32)
    mask = (torch.rand([1, 64, 64]) < 0.5).bool()
    text = '1'

    model = build_model(config)

    all_params = list(model.parameters())
    grad_count = sum(1 for p in all_params if p.requires_grad)
    count = len(all_params)
    print(f"grad_count:{grad_count}, count:{count}")
    total_params = sum(p.numel() for p in all_params if p.requires_grad)
    print(f"Total number of parameters: {total_params}")



class QTaskDiffusion(LatentDiffusion):
    def __init__(
            self,
            validation_config=None,
            img_H=512,
            img_W=384,
            p_weight=0.1,
            m_weight=0.1,
            semantic_spatical_model_train=True,
            dataloader_prob=None,
            diffusers_pretrained_path=None,
            sd_train_text_encoder=False,
            cond_stage_trainable=False,
            cross_attention_train=True,
            down_self_attention_train=True,
            use_bf16=False,
            *args, **kwargs
    ):
        """Latent diffusion model conditioned on QFormer/TaskFormer features.

        Args:
            validation_config: stored as ``self.validation_config``.
            img_H, img_W: image height / width, stored on the instance.
            p_weight, m_weight: stored weights, presumably for the pose and mask
                auxiliary losses (their use is not visible here; confirm).
            semantic_spatical_model_train: if False, ``freeze_modules`` freezes
                the whole conditioning network.
            dataloader_prob: stored as-is.
            diffusers_pretrained_path, use_bf16: accepted but currently unused.
            sd_train_text_encoder, cond_stage_trainable, cross_attention_train,
                down_self_attention_train: flags stored for optimizer setup and
                conditioning.
            *args, **kwargs: forwarded to ``LatentDiffusion``. ``kwargs`` is then
                also handed to ``build_model``, which pops
                ``crossattn_mask_threshold`` and ``use_crossattn_and_feature``;
                it must contain ``use_auxiliary_loss``.
                ``ckpt_path`` / ``ignore_keys`` are removed beforehand and applied
                with ``init_from_ckpt`` at the very end.
        """
        # Removed before the parent constructor runs, so the checkpoint is loaded
        # by this class's init_from_ckpt only after all extra modules exist.
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(*args, **kwargs)
        print("LDM 初始化完成！")

        self.img_H, self.img_W = img_H, img_W
        self.validation_config = validation_config
        self.cross_attention_train = cross_attention_train
        self.down_self_attention_train = down_self_attention_train

        # Conditioning network (note: this pops keys from ``kwargs``).
        self.semantic_spatical_model = build_model(kwargs)
        # Unconditional embedding that replaces the semantic features with
        # probability ``u_cond_percent`` during training.
        self.learnable_vector = nn.Parameter(torch.randn((1, 16, 768)), requires_grad=True)

        self.use_auxiliary_loss = kwargs['use_auxiliary_loss']
        self.p_weight, self.m_weight = p_weight, m_weight
        self.semantic_spatical_model_train = semantic_spatical_model_train
        self.dataloader_prob = dataloader_prob
        self.cond_stage_trainable = cond_stage_trainable
        self.sd_train_text_encoder = sd_train_text_encoder

        self.freeze_modules()

        self.restarted_from_ckpt = ckpt_path is not None
        if self.restarted_from_ckpt:
            self.init_from_ckpt(ckpt_path, ignore_keys)


    @torch.no_grad()
    def init_from_ckpt(self, path, ignore_keys=(), only_model=False):
        """Load pretrained weights (UNet, VAE, text encoder, ...) from a checkpoint file.

        Overrides the DDPM implementation so a few keys can be adapted to this
        model before loading.

        Args:
            path: checkpoint readable by ``torch.load``. A raw state dict or a
                dict wrapping it under ``"state_dict"`` (possibly nested once
                more) is accepted.
            ignore_keys: key prefixes; every checkpoint key starting with any of
                them is dropped before loading.
            only_model: if True load into ``self.model`` only, otherwise into the
                whole module.

        Loading is non-strict; missing and unexpected keys are printed.
        """
        state_dict = torch.load(path, map_location="cpu")
        # Unwrap Lightning-style {"state_dict": {...}} containers (at most twice).
        for _ in range(2):
            if isinstance(state_dict, dict) and "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]

        # The original check was inverted (`if not ignore_keys`), so ignore_keys
        # never had any effect. startswith(tuple) also deletes each key at most
        # once, even if it matches several prefixes.
        if ignore_keys:
            prefixes = tuple(ignore_keys)
            for key in list(state_dict.keys()):
                if key.startswith(prefixes):
                    print("Deleting key {} from state_dict.".format(key))
                    del state_dict[key]

        print("利用原有inpaint model加载权重")
        for name in list(state_dict.keys()):
            # NOTE(review): the two attn2 branches are no-ops as written. ``name``
            # contains "attn2.i_to_*" but replace() searches for "attn1.i_to_*", so
            # the key is assigned to itself. The intended source/target mapping
            # cannot be determined from this file; confirm before relying on it.
            if "attn2.i_to_k" in name:
                state_dict[name.replace('attn1.i_to_k', 'attn1.to_k')] = state_dict[name]
            elif "attn2.i_to_v" in name:
                state_dict[name.replace('attn1.i_to_v', 'attn1.to_v')] = state_dict[name]
            elif name == "learnable_vector":
                # Broadcast the stored unconditional vector (presumably (1, 1, 768)
                # in the source checkpoint) to this model's (1, 16, 768) parameter.
                state_dict[name] = state_dict[name].expand(1, 16, 768)

        # self.model is the UNet; otherwise load VAE, text encoder, UNet, ... together.
        target = self.model if only_model else self
        missing, unexpected = target.load_state_dict(state_dict, strict=False)

        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        print(f"diffusion model missing key:{missing}")
        print(f"diffusion model unexpected key:{unexpected}")
        print("successfully achieve load unet、text_encoder、vae weight!")
    
    @torch.no_grad()
    def freeze_modules(self):
        """Permanently switch fixed sub-networks to eval mode and stop their gradients.

        Always frozen: the first-stage autoencoder (``first_stage_model``) and
        the pose encoder inside ``semantic_spatical_model``. The entire
        ``semantic_spatical_model`` is frozen as well when
        ``semantic_spatical_model_train`` is False. Each frozen module has its
        ``train`` replaced by ``disabled_train``, so later ``.train()`` calls
        cannot switch it back to training mode.
        """
        frozen = [self.first_stage_model, self.semantic_spatical_model.pose_encoder]
        # NOTE: the message mentions FrozenCLIPEmbedder, but freezing of
        # cond_stage_model is currently disabled in this method.
        print("eval: first_stage_model, FrozenCLIPEmbedder")
        if not self.semantic_spatical_model_train:
            frozen.append(self.semantic_spatical_model)

        for submodule in frozen:
            submodule.eval()
            submodule.train = self.disabled_train
            submodule.requires_grad_(False)
    
    def configure_optimizers(self):
        """Create the AdamW optimizer (plus optional LambdaLR schedule) over the trainable subset.

        Parameters are handed to the optimizer in this order:
          1. ``self.learnable_vector``;
          2. every UNet (``self.model``) parameter whose name contains
             ``"transformer_blocks"``. All other UNet parameters get
             ``requires_grad = False`` here;
          3. if ``semantic_spatical_model_train``: every parameter of
             ``semantic_spatical_model`` that requires grad, after the backbone
             has been fully unfrozen (names of the skipped ones are printed);
          4. if ``sd_train_text_encoder``: all ``cond_stage_model`` parameters;
          5. if ``learn_logvar``: ``self.logvar``.

        Returns:
            The optimizer, or ``([optimizer], [scheduler_dict])`` when
            ``self.use_scheduler`` is set.
        """
        lr = self.learning_rate
        trainable = [self.learnable_vector]
        trainable_names = []

        # Dedicated cross-attention / down-block self-attention selections are
        # currently disabled, so these lists stay empty and their counters 0.
        cross_attn_params, down_self_attn_params = [], []
        n_cross_and_norm = n_down_self = 0

        unet_params = []
        for pname, p in self.model.named_parameters():
            keep = "transformer_blocks" in pname
            p.requires_grad = keep
            if keep:
                unet_params.append(p)
                trainable_names.append(pname)
        trainable.extend(unet_params)

        if self.cross_attention_train:
            trainable.extend(cross_attn_params)
            print(f"cross_attention_and_norm_train_nums:{n_cross_and_norm}")
        if self.down_self_attention_train:
            trainable.extend(down_self_attn_params)
            print(f"down_self_attention_train_nums:{n_down_self}")

        if self.semantic_spatical_model_train:
            # Unfreeze the complete image backbone.
            for p in self.semantic_spatical_model.backbone.parameters():
                p.requires_grad = True

            cond_params = []
            for pname, p in self.semantic_spatical_model.named_parameters():
                if not p.requires_grad:
                    print(f"{pname}, {p.shape}")
                    continue
                cond_params.append(p)
                trainable_names.append(pname)
            print(f"semantic_spatical_model learning param number: {len(cond_params)}")
            trainable.extend(cond_params)

        if self.sd_train_text_encoder:
            trainable.extend(self.cond_stage_model.parameters())

        if self.learn_logvar:
            print('Diffusion model optimizing logvar')
            trainable.append(self.logvar)
        print(f"all training params:{trainable_names}、learnable_vector, num: {len(trainable_names) + 1}")

        optimizer = torch.optim.AdamW(trainable, lr=lr)
        if not self.use_scheduler:
            return optimizer

        assert 'target' in self.scheduler_config
        schedule = instantiate_from_config(self.scheduler_config)
        print("Setting up LambdaLR scheduler...")
        scheduler_dict = {
            'scheduler': LambdaLR(optimizer, lr_lambda=schedule.schedule),
            'interval': 'step',
            'frequency': 1,
        }
        return [optimizer], [scheduler_dict]

    def disabled_train(self, mode=True):
        """Stand-in for ``nn.Module.train`` on frozen sub-modules.

        ``mode`` is ignored, so the train/eval state can no longer change. The
        return value is the object this method is bound to (the owning model when
        assigned as ``self.disabled_train``), not the frozen sub-module.
        """
        owner = self
        return owner

    def get_learned_conditioning(self, c):
        """Build UNet conditioning and UNet input from a ``get_input`` dict.

        Overrides the LDM version, which only injected text-encoder features.
        Here the cloth and human references are encoded jointly, in a single
        batch, by ``self.semantic_spatical_model``, and the result is split back
        into a cloth half (first ``bs`` rows) and a human half (last ``bs`` rows).
        Text conditioning and down-sample residual injection are disabled in this
        training stage (both are ``None``).

        Args:
            c: dict produced by ``get_input``. Reads ``'ref_cloth'``,
                ``'ref_human'``, ``'caption'``, ``'z_cloth_mask'``,
                ``'z_agn_mask'``, ``'pose_img'`` and ``'z'``. ``c['z']`` is
                indexed as ``[cloth latents; human latents]`` along dim 0
                (``c['z'][bs:]`` is the human half).

        Returns:
            Tuple ``(cond, down_self_residual, loss, z)``:

            - ``cond``: ``[semantic_feature, None, None]``. In training the cloth
              semantic feature is duplicated for both halves (``2 * bs`` rows),
              and each row is replaced by ``self.learnable_vector`` with
              probability ``self.u_cond_percent`` (classifier-free-guidance
              dropout). In eval it is the cloth feature only (``bs`` rows).
            - ``down_self_residual``: always ``None`` in this stage.
            - ``loss``: the auxiliary loss returned by the semantic/spatial model.
            - ``z``: UNet input ``cat([latent, latent * keep, keep], dim=1)``,
              where ``keep = 1 - predicted_mask``. Training covers both halves
              (cloth and agnostic masks); eval covers only the human half.
        """
        bs = c['ref_cloth'].shape[0]

        # Encode cloth and human references in one pass. The cloth half gets an
        # all-zero pose image so both halves share the same input signature.
        # NOTE(review): ``c['caption'] * 2`` relies on the caption being a list
        # (list repetition); it fails if the caption is None.
        loss, semantic_feature_2bs, mask_2bs, _ = self.semantic_spatical_model(
            torch.cat([c['ref_cloth'], c['ref_human']]),
            c['caption'] * 2,
            torch.cat([c['z_cloth_mask'], c['z_agn_mask']]),
            torch.cat([torch.zeros_like(c['pose_img']), c['pose_img']]),
        )
        cloth_semantic_feature = semantic_feature_2bs[:bs]
        mask = mask_2bs[:bs]
        agn_mask = mask_2bs[bs:]

        # Down-sample residual injection and text conditioning are both disabled
        # in this training stage.
        down_self_residual = None
        text_encoder = None

        if self.training:
            # Both halves (cloth reconstruction and human generation) are
            # conditioned on the cloth semantic feature.
            semantic_feature = torch.cat([cloth_semantic_feature, cloth_semantic_feature])

            pred_mask = 1 - torch.cat([mask, agn_mask])
            inpaint = pred_mask * c['z']
            z = torch.cat([c['z'], inpaint, pred_mask], dim=1)

            # Classifier-free guidance: swap each row for the learnable
            # unconditional vector with probability ``u_cond_percent``.
            u_cond_prop = torch.rand(bs * 2, 1, 1).to(self.device)
            u_cond_prop = (u_cond_prop < self.u_cond_percent).to(
                dtype=semantic_feature.dtype, device=semantic_feature.device)
            semantic_feature = (self.learnable_vector.expand(bs * 2, -1, -1) * u_cond_prop
                                + semantic_feature * (1 - u_cond_prop))
        else:
            # Inference: only the human half is denoised.
            semantic_feature = cloth_semantic_feature
            human_z = c['z'][bs:]
            keep_mask = 1 - agn_mask
            z = torch.cat([human_z, keep_mask * human_z, keep_mask], dim=1)

        cond = [semantic_feature, None, text_encoder]
        return cond, down_self_residual, loss, z


    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False, force_c_encode=False,
                  *args, **kwargs):
        """Prepare a batch for training/logging (overrides the LDM ``get_input``).

        Two modes, selected by ``k``:

        * ``k != "inpaint"``: VAE-encode ``batch[k]`` and return the latent.
        * ``k == "inpaint"``: VAE-encode ``batch['ref_cloth']`` and
          ``batch['ref_human']``, stack the latents along the batch dimension as
          ``[cloth; human]``, resize the cloth/agnostic masks to the latent
          resolution and gather the conditioning.

        Args:
            batch: dict of input tensors (plus the conditioning under ``cond_key``).
            k: key of the first-stage input, or ``"inpaint"``.
            cond_key: conditioning key; defaults to ``self.cond_stage_key``.
            bs: optional batch-size cap. May be a float (e.g. ``N / 2``); it is
                truncated with ``int`` on both paths.
            return_first_stage_outputs: accepted for interface compatibility;
                unused.
            force_c_encode: encode the conditioning even when
                ``self.cond_stage_trainable`` is set.

        Returns:
            The latent tensor when ``k != "inpaint"``. Otherwise a dict with keys
            ``'z'``, ``'caption'`` (the conditioning: possibly a dict holding
            positional shifts, or ``None`` for an unconditional model),
            ``'ref_cloth'``, ``'ref_human'``, ``'z_cloth_mask'``,
            ``'z_agn_mask'`` and ``'pose_img'``.

        Raises:
            ValueError: on the inpaint path when ``cond_key`` equals
                ``self.first_stage_key``. There is no single first-stage tensor
                to use as conditioning there; previously this hit an unbound
                local.
        """
        if bs is not None:
            # Slicing needs an int; callers may pass a float such as N / 2.
            bs = int(bs)

        if k != "inpaint":
            x = batch[k]
            if len(x.shape) == 3:
                x = x[..., None]
            if bs is not None:
                x = x[:bs]
            x = x.to(self.device)
            encoder_posterior = self.encode_first_stage(x)
            return self.get_first_stage_encoding(encoder_posterior).detach()

        # Contiguous float copies of the reference inputs, truncated to ``bs``
        # (``[:None]`` keeps the whole batch).
        ref_cloth, ref_human, pose_img, cloth_mask, agn_mask = (
            batch[key].to(memory_format=torch.contiguous_format).float()[:bs]
            for key in ('ref_cloth', 'ref_human', 'pose_img', 'cloth_mask', 'agn_mask')
        )

        z_cloth = self.get_first_stage_encoding(self.encode_first_stage(ref_cloth.to(self.device))).detach()
        z_human = self.get_first_stage_encoding(self.encode_first_stage(ref_human.to(self.device))).detach()
        # Stack along the batch dimension (dim 0), cloth first, so the cloth
        # and the target image are denoised together.
        z = torch.cat((z_cloth, z_human))
        latent_hw = [z.shape[-2], z.shape[-1]]
        z_cloth_mask = Resize(latent_hw)(cloth_mask)
        z_agn_mask = Resize(latent_hw)(agn_mask)

        # For the current setup this essentially yields batch[cond_key]
        # (the caption).
        if self.model.conditioning_key is not None:
            if cond_key is None:
                cond_key = self.cond_stage_key
            if cond_key == self.first_stage_key:
                raise ValueError(
                    f"get_input: cond_key {cond_key!r} equals first_stage_key; "
                    "conditioning on the first-stage input is not supported when k='inpaint'"
                )
            if cond_key in ['txt', 'caption', 'coordinates_bbox']:
                xc = batch[cond_key]
            elif cond_key == 'image':
                xc = ref_cloth
            elif cond_key == 'class_label':
                xc = batch
            else:
                xc = super().get_input(batch, cond_key).to(self.device)

            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, (dict, list)):
                    c = super().get_learned_conditioning(xc)
                else:
                    c = super().get_learned_conditioning(xc.to(self.device))
                    c = super().proj_out(c)
                    c = c.float()
            else:
                # Trainable conditioning: pass the raw conditioning through;
                # it is encoded later.
                c = xc
            if bs is not None:
                c = c[:bs]

            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                ckey = __conditioning_keys__[self.model.conditioning_key]
                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
        else:
            c = None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {'pos_x': pos_x, 'pos_y': pos_y}

        return {
            'z': z, 'caption': c, 'ref_cloth': ref_cloth, 'ref_human': ref_human,
            'z_cloth_mask': z_cloth_mask, 'z_agn_mask': z_agn_mask, 'pose_img': pose_img,
        }


    def apply_model(self, x_noisy, t, cond, down_sample_residual=None, selfattn_img_weight=1.0,
                    return_ids=False):
        """Run the diffusion model on ``x_noisy`` at timestep ``t``.

        Differs from the LDM version: there is no patch splitting/re-assembly,
        and two extra arguments are forwarded to ``self.model``:
        ``down_sample_residual`` (as ``down_residual``) and
        ``selfattn_img_weight``.

        Args:
            cond: a dict of model keyword arguments (used as-is), or a
                list/single value. A non-dict value is wrapped in a list under
                ``'c_concat'`` when ``self.model.conditioning_key == 'concat'``,
                otherwise under ``'c_crossattn'``.
            return_ids: when False and the model returns a tuple, only its
                first element is returned.
        """
        if not isinstance(cond, dict):
            wrapped = cond if isinstance(cond, list) else [cond]
            cond_kwarg = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
            cond = {cond_kwarg: wrapped}

        model_out = self.model(x_noisy, t, down_residual=down_sample_residual,
                               selfattn_img_weight=selfattn_img_weight, **cond)
        if return_ids or not isinstance(model_out, tuple):
            return model_out
        return model_out[0]

    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=100, ddim_eta=1., return_keys=None,
                   quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        """Build a dict of tensors for image logging (overrides the LDM version).

        Logs the following:

        * the cloth/human references together with the decoded masked latent;
        * the VAE reconstruction of ``out['z']``;
        * the pose image, plus the cloth/agnostic/predicted masks at input and
          latent resolution;
        * optional forward-diffusion rows, and samples.

        Args:
            batch: raw batch forwarded to ``get_input``.
            N: requested sample count. ``get_input`` is asked for ``N // 2``
                items, and ``N`` is then capped at ``z.shape[0]``.
            n_row: rows for the diffusion grid, capped at ``z.shape[0]``.
            return_keys: optional subset of log keys to return.
            unconditional_guidance_scale, unconditional_guidance_label,
            use_ema_scope, kwargs: accepted for interface compatibility; not
                used by this implementation.

        Returns:
            dict mapping log names to tensors. If ``return_keys`` names at least
            one produced entry, only the produced entries among ``return_keys``
            are returned.
        """
        use_ddim = ddim_steps is not None
        log = dict()

        # Integer division so get_input can slice on any path.
        out = self.get_input(batch, self.first_stage_key, bs=int(N) // 2)
        c, down_sample_residual, _aux_loss, z = self.get_learned_conditioning(out)
        # z = cat([latent, masked latent, predicted keep-mask], dim=1) as built by
        # get_learned_conditioning. The slices below assume a 4-channel latent
        # (channels 0-3), so 4:-1 is the masked latent and -1 the mask.

        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)

        pred_agn_feature = self.decode_first_stage(z[:, 4:-1, :, :])
        log["cloth_human_predAgnFeature"] = torch.cat((out['ref_cloth'], out['ref_human'], pred_agn_feature))
        _, _, h, w = out['ref_cloth'].shape
        log["reconstruction"] = self.decode_first_stage(out['z'])

        resize_to_input = Resize((h, w), interpolation=2)
        pred_agn_mask = z[:, -1, :, :]
        log["pose_img2input_size"] = resize_to_input(out['pose_img'])
        log['z_resize_to_512_cloth_agn_predAgnMask'] = torch.cat((
            resize_to_input(out['z_cloth_mask']),
            resize_to_input(out['z_agn_mask']),
            resize_to_input(pred_agn_mask).unsqueeze(1),
        ))
        log['z_64_cloth_agn_predAgnMask'] = torch.cat(
            (out['z_cloth_mask'], out['z_agn_mask'], pred_agn_mask.unsqueeze(1)))

        if self.model.conditioning_key is not None:
            # Initialised so the map check below never sees an unbound name.
            # The previous ``elif isimage(xc)`` branch could never run with a
            # bound xc, so it was dropped.
            xc = None
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((h, w), batch[self.cond_stage_key])
                log[f"{self.cond_stage_key}"] = xc
            elif self.cond_stage_key == 'class_label':
                xc = log_txt_as_img((z.shape[2], z.shape[3]), batch["human_label"])
                log['conditioning'] = xc
            if xc is not None and ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # Forward-diffuse only the clean latent channels. z also carries the
            # masked latent and the mask, which decode_first_stage cannot decode.
            diffusion_row = list()
            z_start = z[:n_row, :4]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            with self.ema_scope("Plotting"):
                if self.first_stage_key == 'inpaint':
                    # The masked latent + mask channels are passed as ``rest``.
                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                             down_sample_residual=down_sample_residual,
                                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                                             rest=z[:, 4:, :, :])
                else:
                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                             down_sample_residual=down_sample_residual,
                                                             ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples

            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                    self.first_stage_model, IdentityFirstStage):
                # Also display when quantizing x0 while sampling.
                with self.ema_scope("Plotting Quantized Denoised"):
                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                                             quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

            if inpaint:
                # Simple centre-square mask; zeros are the region to fill in.
                mask_h, mask_w = z.shape[2], z.shape[3]
                mask = torch.ones(N, mask_h, mask_w).to(self.device)
                mask[:, mask_h // 4:3 * mask_h // 4, mask_w // 4:3 * mask_w // 4] = 0.
                mask = mask[:, None, ...]
                with self.ema_scope("Plotting Inpaint"):
                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                                 ddim_steps=ddim_steps, x0=z[:N, :4], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_inpainting"] = x_samples
                log["mask"] = mask

                # Outpaint: same latent channels as the inpaint call above.
                with self.ema_scope("Plotting Outpaint"):
                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                                 ddim_steps=ddim_steps, x0=z[:N, :4], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            with self.ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            return {key: log[key] for key in return_keys if key in log}

        return log
    """
    @torch.no_grad()
    def on_validation_epoch_start(self): # 验证集每一epoch开始动作
        self.ddim_sampler = DDIMSampler(self)
        self.validation_gene_dirs = []
        for data_type in ["pair", "unpair"]:
            to_dir = opj(self.valid_config.img_save_dir, f"{data_type}_{self.current_epoch}")
            self.validation_gene_dirs.append(to_dir)
    """
    @torch.no_grad()
    def validation_step(self, batch, batch_idx, dataloader_idx=0): # stableVITON
        """Sample try-on results for the first few validation batches and save them as images.

        Only ``batch_idx`` 0, 1 and 2 of each validation dataloader are processed; any later
        batch returns ``None`` immediately. If the model is in training mode it is switched
        to eval mode for the work, and training mode is restored afterwards even if sampling
        or saving raises.

        Args:
            batch: validation batch. ``batch['ref_cloth']`` and ``batch['im_name']`` are read
                directly; the rest is consumed through ``self.get_input``.
            batch_idx: index of the batch inside its dataloader.
            dataloader_idx: 0 writes into the "pair" directory, any other value into the
                "unpair" one. Defaults to 0 so the hook also works with a single validation
                dataloader, in which case Lightning does not pass this argument.

        Returns:
            None.
        """
        if batch_idx > 2:
            return
        was_training = self.training
        if was_training:
            self.eval()
        try:
            self._save_validation_samples(batch, dataloader_idx)
        finally:
            # Restore train mode even when sampling/saving fails, so an error during
            # validation cannot leave the module stuck in eval mode.
            if was_training:
                self.train()

    @torch.no_grad()
    def _save_validation_samples(self, batch, dataloader_idx):
        """DDIM-sample one validation batch and write one comparison strip per sample.

        Each image is written to
        ``<validation_config.img_save_dir>/<pair|unpair>_<current_epoch>/<im_name>`` and is
        the horizontal concatenation of: generated sample, reference cloth, first-stage
        reconstruction, decoded agnostic feature and the resized agnostic mask.
        """
        data_type = "pair" if dataloader_idx == 0 else "unpair"

        out = self.get_input(batch, self.first_stage_key)
        c, down_self_residual, _, z = self.get_learned_conditioning(out)
        bs = z.shape[0]

        # NOTE(review): decodes out['z'][bs:], i.e. the entries after the first `bs`.
        # This assumes out['z'] stacks more than `bs` latents; if it holds exactly `bs`,
        # x_recon is empty and the zip() below writes nothing. Confirm against get_input.
        x_recon = self.decode_first_stage(out['z'][bs:])
        # The sampler generates only the first 4 latent channels, at 1/8 image resolution.
        shape = (4, self.img_H // 8, self.img_W // 8)
        """
        unconditional_guidance_label=["over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"]
        uc = self.get_learned_conditioning({
                                            'text_input': [bs * unconditional_guidance_label], # 注意text_input 和 image这两个值都是list，因为可能有两种图片或者两种文本情况。而正常的text即空间文本信息就一个
                                            'text': bs * unconditional_guidance_label
                                                       })  # 无条件embedding生成，这里对应健target_text的值就是N个unconditional_guidance_label
        if self.train and self.use_spatial_semantic_loss:
            c, _ = self.get_learned_conditioning(c)
        else:
            c = self.get_learned_conditioning(c)
        """
        ddim_sampler = DDIMSampler(self)
        samples, _ = ddim_sampler.sample(
            self.validation_config.ddim_steps,
            bs,
            shape,
            c,
            down_sample_residual=down_self_residual,
            x_T=None,
            verbose=False,
            eta=self.validation_config.eta,
            mask=None,
            x0=None,
            # Channels 4 onwards of z are handed to the sampler unchanged as `rest`.
            rest=z[:, 4:, :, :],
            #unconditional_guidance_scale=self.validation_config.scale,
            #unconditional_conditioning= uc
        )
        x_samples = self.decode_first_stage(samples)
        # Channels 4..-2 of z are decoded as the agnostic feature image.
        pred_agn_feature = self.decode_first_stage(z[:, 4:-1, :, :])
        # z (b,9,64,64): the last channel is upsampled to the decoded image size
        # (interpolation=2 is bilinear) and given a channel axis -> (b, 1, H, W).
        pred_agn_mask = Resize((x_samples.shape[2], x_samples.shape[3]), interpolation=2)(z[:, -1, :, :]).unsqueeze(1)

        to_dir = opj(self.validation_config.img_save_dir, f"{data_type}_{self.current_epoch}")
        os.makedirs(to_dir, exist_ok=True)
        per_sample = zip(x_samples, batch['ref_cloth'], batch["im_name"], x_recon, pred_agn_feature, pred_agn_mask)
        for x_sample, gt_cloth, fn, recon, agn_feature, agn_mask in per_sample:
            x_sample_img = tensor2img(x_sample)
            x_recon_img = tensor2img(recon)
            gt_cloth_img = tensor2img(gt_cloth)
            agn_feature_img = tensor2img(agn_feature)
            agn_mask_img = tensor2img(agn_mask)
            cloth_save = np.concatenate(
                [x_sample_img, gt_cloth_img, x_recon_img, agn_feature_img, agn_mask_img], axis=1)
            to_path = opj(to_dir, fn)
            # tensor2img yields RGB; OpenCV writes BGR, so reverse the channel axis.
            if not cv2.imwrite(to_path, cloth_save[:, :, ::-1]):
                logging.getLogger(__name__).warning("Could not write validation image to %s", to_path)

    def get_auxiliary_loss(self, auxiliary):
        """Return the weighted sum of the seven auxiliary (spatial + semantic) losses.

        Args:
            auxiliary: indexable with two dict-like entries. ``auxiliary[0]`` supplies the
                ``latent_*`` targets; ``auxiliary[1]`` supplies the ``*_feature_matching``
                predictions and, for the semantic term, ``QFormer_text_out``.

        The loss list always has seven entries, weighted by ``self.lamda_weigt[0..6]``:
            0-2: dice losses for warped-cloth mask, agnostic mask and openpose map;
            3-5: ``self.get_loss`` between vision-backbone features of the densepose,
                 openpose-image and parse predictions and their targets;
            6:   mean over the batch of the best query/text similarity.
        A group whose trigger key is absent in ``auxiliary[1]`` contributes zero-valued
        placeholders, created as 0-dim tensors on ``self.device``.

        Returns:
            The weighted sum as a tensor.
        """
        preds, targets = auxiliary[1], auxiliary[0]
        loss = []
        # spatial loss
        if preds.get("mask_feature_matching", None) is not None:
            loss.append(dice_loss(preds["cloth_mask_feature_matching"], targets["latent_warped_cloth_mask"].flatten(1), 1))
            loss.append(dice_loss(preds["agn_mask_feature_matching"], targets["latent_agn_mask"].flatten(1), 1))
            loss.append(dice_loss(preds["openpose_map_feature_matching"], targets["latent_openpose_map"].flatten(1), 18))

            def backbone(x):
                # First output of the vision backbone for a fixed timestep / hidden state.
                return self.vision_backbone(x, self.vision_backbone_timestemp,
                                            self.vision_backbone_encoder_hidden_state)[0]

            feature_pairs = (
                ("densepose_feature_matching", "latent_densepose_img"),
                ("openpose_img_feature_matching", "latent_openpose_img"),
                ("parse_feature_matching", "latent_parse_img"),
            )
            for pred_key, target_key in feature_pairs:
                pred_feat = backbone(preds[pred_key])
                target_feat = backbone(targets[target_key])
                loss.append(self.get_loss(pred_feat, target_feat))
        else:
            # 0-dim placeholders on the model device. CPU tensors of shape [1] cannot be
            # combined with CUDA losses and would change the shape of the sum.
            loss.extend(torch.zeros((), device=self.device) for _ in range(6))
        # semantic loss
        if preds.get("semantic_feature_matching", None) is not None:
            # Per the shape note [b, q, dim] @ [b, dim, 1] -> [b, q, 1]. Squeeze only the last
            # axis so that a batch or query size of 1 is not collapsed as well, which would
            # make max(-1) reduce over the wrong axis.
            sim_q2t = torch.matmul(
                preds["semantic_feature_matching"], preds["QFormer_text_out"].permute(0, 2, 1)
            ).squeeze(-1)
            sim_i2t, _ = sim_q2t.max(-1)
            loss.append(sim_i2t.mean())
        else:
            loss.append(torch.zeros((), device=self.device))

        total = 0
        for i in range(7):
            # Out-of-place add: avoids in-place broadcasting/shape errors on the accumulator.
            total = total + loss[i] * self.lamda_weigt[i]
            print(f"辅助损失{i}:{loss[i]}")
        return total

    def p_losses(self, x_start, cond, t, down_self_residual, auxiliary_loss, noise=None):
        """Compute the diffusion training loss for a batch of latents.

        Args:
            x_start: clean latents. When ``self.first_stage_key == 'inpaint'`` only channels
                ``[:4]`` are noised and predicted; channels ``[4:]`` are concatenated back
                onto the noisy latent unchanged as extra model input.
            cond: conditioning forwarded to ``self.apply_model``.
            t: per-sample diffusion timesteps, used to index ``self.logvar`` and
                ``self.lvlb_weights``.
            down_self_residual: extra residuals forwarded to ``self.apply_model``.
            auxiliary_loss: dict with ``'mask_loss'`` and ``'pose_loss'``; only read when
                ``self.use_auxiliary_loss`` is set.
            noise: optional noise for ``q_sample``. When ``None`` it is drawn with
                ``torch.randn_like`` over the noised channels.

        Returns:
            ``(loss, loss_dict)``: the training loss and a dict of logged terms keyed
            ``'<train|val>/...'`` (plus ``'logvar'`` when ``self.learn_logvar``).

        Raises:
            NotImplementedError: for a ``parameterization`` other than ``"x0"``/``"eps"``.
        """
        if self.first_stage_key == 'inpaint':
            # Only the first 4 channels are diffused; the rest is passed through as context.
            latent = x_start[:, :4, :, :]
            if noise is None:
                noise = torch.randn_like(latent)
            x_noisy = self.q_sample(x_start=latent, t=t, noise=noise)
            x_noisy = torch.cat((x_noisy, x_start[:, 4:, :, :]), dim=1)
        else:
            latent = x_start
            if noise is None:
                noise = torch.randn_like(latent)
            x_noisy = self.q_sample(x_start=latent, t=t, noise=noise)

        model_output = self.apply_model(x_noisy, t, cond, down_self_residual)

        loss_dict = {}
        prefix = 'train' if self.training else 'val'

        if self.parameterization == "x0":
            # Target only the diffused channels. The eps target (noise) covers just these
            # channels too, so the model output is expected to match them, not the full
            # x_start with the inpaint context concatenated.
            target = latent
        elif self.parameterization == "eps":
            target = noise
        else:
            raise NotImplementedError()

        # Per-sample loss reduced over (C, H, W); also reused for the VLB term below.
        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

        # self.logvar may not live on self.device; move it before indexing with t.
        logvar_t = self.logvar.to(self.device)[t]

        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        loss_vlb = (self.lvlb_weights[t] * loss_simple).mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
        print(f"diffusion loss：{loss,type(loss)}")
        if self.use_auxiliary_loss:  # auxiliary losses
            mask_loss = self.m_weight*auxiliary_loss['mask_loss']
            pose_loss = self.p_weight*auxiliary_loss['pose_loss']
            loss_dict.update({f'{prefix}/pose_loss': pose_loss})
            loss_dict.update({f'{prefix}/mask_loss': mask_loss})
            print(f"pose loss：{pose_loss},grad:{pose_loss.requires_grad},mask loss：{mask_loss},grad:{mask_loss.requires_grad}")
            loss += mask_loss + pose_loss
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict
    
    def forward(self, c, *args, **kwargs):
        """Compute the diffusion training loss for one batch of model inputs.

        ``c`` is turned into the conditioning, UNet residuals, auxiliary losses and latent
        ``z`` by ``self.get_learned_conditioning``. One random timestep is drawn per row
        of ``z`` and everything is forwarded to ``self.p_losses``.

        Args:
            c: model inputs (as produced by ``get_input``); must not be ``None``.
            *args, **kwargs: forwarded to ``p_losses``.

        Returns:
            Whatever ``p_losses`` returns, i.e. ``(loss, loss_dict)``.

        Raises:
            NotImplementedError: if ``self.model.conditioning_key`` is ``None`` or
                ``self.cond_stage_trainable`` is false. The latent and residuals are only
                produced by ``get_learned_conditioning``, so there is nothing to train on.
        """
        if self.model.conditioning_key is None or not self.cond_stage_trainable:
            raise NotImplementedError(
                "forward() needs model.conditioning_key set and cond_stage_trainable=True: "
                "the latent z and UNet residuals come from get_learned_conditioning()")
        assert c is not None
        c, down_self_residual, auxiliary_loss, z = self.get_learned_conditioning(c)
        # One timestep per latent row (z.shape[0] may be 2*bs during training). It must be
        # drawn before the shorten_cond_schedule branch, which indexes cond_ids with it.
        t = torch.randint(0, self.num_timesteps, (z.shape[0],), device=self.device).long()
        if self.shorten_cond_schedule:  # TODO: drop this option
            tc = self.cond_ids[t].to(self.device)
            c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
        return self.p_losses(z, c, t, down_self_residual, auxiliary_loss, *args, **kwargs)
    
    def shared_step(self, batch, **kwargs):
        """Prepare *batch* with ``get_input`` and run the model's forward pass on it.

        Extra keyword arguments are accepted for interface compatibility and ignored.

        Returns:
            ``(loss, loss_dict)`` as produced by calling the module.
        """
        model_inputs = self.get_input(batch, self.first_stage_key)
        step_loss, step_logs = self(model_inputs)
        return step_loss, step_logs
    
    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, down_sample_residual, **kwargs):
        """Draw a batch of latent samples and return them with the sampler's intermediates.

        Args:
            cond: conditioning passed to the sampler.
            batch_size: number of samples to draw.
            ddim: if true, sample with ``DDIMSampler`` for ``ddim_steps`` steps in a latent of
                shape ``(self.channels, self.img_H // 8, self.img_W // 8)``; otherwise use
                ``self.sample`` (``ddim_steps`` is then unused).
            ddim_steps: number of DDIM steps.
            down_sample_residual: residuals forwarded to the sampler.
            **kwargs: extra sampler options, forwarded unchanged.

        Returns:
            ``(samples, intermediates)``.
        """
        if not ddim:
            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                                 down_sample_residual=down_sample_residual,
                                                 return_intermediates=True, **kwargs)
            return samples, intermediates

        sampler = DDIMSampler(self)
        latent_shape = (self.channels, self.img_H // 8, self.img_W // 8)
        samples, intermediates = sampler.sample(ddim_steps, batch_size, latent_shape,
                                                conditioning=cond,
                                                down_sample_residual=down_sample_residual,
                                                verbose=False, **kwargs)
        return samples, intermediates
    
    