# 项目创建时间：2024/9/12 21:13
# This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
# All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import pickle
import random
import logging
import os
import numpy as np
import cv2 
import torch
from einops import rearrange,repeat
from torch.optim.lr_scheduler import LambdaLR
from torchvision.transforms.functional import resize
from torchvision.utils import make_grid
from torch.nn import functional as F

from torch.cuda.amp import autocast

from os.path import join as opj
#from ldm.models.model.utils import tensor2img
from torch import nn
from transformers.activations import QuickGELUActivation as QuickGELU
from contextlib import nullcontext
from omegaconf import ListConfig
from torchvision.transforms import Resize,InterpolationMode


from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL

from ldm.util import log_txt_as_img ,instantiate_from_config,default,ismap, isimage,tensor2img

from ldm.models.diffusion.ddpm_3 import LatentDiffusion
#from ldm.models.model import  model
#from ldm.models.model.unet import register_unet_output2
from ldm.models.diffusion.ddim import DDIMSampler
#from ldm.models.util import dice_loss

from ldm.models.QFormer.QFormer import build_QFormer_encoding
from ldm.models.TaskFormer.TaskFormer_revise_3 import build_TaskFormer_encoding
from ldm.models.vision_backbone.swin_transformer import build_backbone
from ldm.models.vision_backbone.swin_transformer_v2 import build_pose_encoder

from ldm.models.decoder.appearance_encoder_revise3 import build_Decoder
import kornia
# Lookup table from a conditioning mode name to the key name used for it.
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably kept for parity with the upstream latent-diffusion code -- confirm.
__conditioning_keys__ = {'concat': 'c_concat',
                         'crossattn': 'c_crossattn',
                         'adm': 'y'}

def dice_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
    """
    Soft DICE loss between predicted and reference masks.

    Both tensors are flattened from dimension 2 onward, so they need at least
    three dimensions (for example ``(batch, num_classes, H, W)`` or
    ``(batch, num_classes, H*W)``).

    Args:
        inputs: Float tensor of per-element predictions.
        targets: Tensor broadcast-compatible with ``inputs`` holding the
                 reference labels (0 for background, 1 for foreground).

    Returns:
        A scalar tensor: ``1 - dice`` averaged over the leading two dimensions.
        The ``+ 1`` terms smooth the ratio so empty masks do not divide by zero.
    """
    flat_pred = inputs.flatten(2)
    flat_ref = targets.flatten(2)

    overlap = (flat_pred * flat_ref).sum(-1)
    total = flat_pred.sum(-1) + flat_ref.sum(-1)

    return (1 - (2 * overlap + 1) / (total + 1)).mean()
# human_loss, human_semantic_feature, agn_mask, down_self_residual = 
#                           self.semantic_spatical_model(c["ref_human"], c["caption"],c['pose_img'],c['z_agn_mask'])

class build_model(nn.Module):
    """Conditioning network combining a Swin backbone, a Q-Former and a TaskFormer.

    ``forward`` produces the multimodal Q-Former embeddings, a predicted cloth
    mask and (optionally) an auxiliary mask loss. The decoder / pose-encoder
    branch of the original design is disabled in this stage, so the two
    residual outputs are always ``None``.
    """

    def __init__(self, cfg):
        """Build the sub-networks from ``cfg``.

        Args:
            cfg: Mutable mapping. ``'crossattn_mask_threshold'`` and
                ``'use_crossattn_and_feature'`` are popped from it (the caller's
                mapping is modified) and injected into ``cfg['TaskFormer']``.
                It must also provide ``'use_auxiliary_loss'``,
                ``'SwinTransformer_backbone'``, ``'QFormer'`` (with
                ``'use_vit_out'``) and ``'TaskFormer'``.
        """
        super().__init__()
        self.crossattn_mask_threshold = cfg.pop('crossattn_mask_threshold')
        self.use_crossattn_and_feature = cfg.pop('use_crossattn_and_feature')
        self.use_auxiliary_loss = cfg["use_auxiliary_loss"]
        self.backbone = build_backbone(cfg["SwinTransformer_backbone"])

        # Second-stage components (Decoder, pose encoder) are intentionally not
        # built in this stage.
        self.qformer_use_clip = cfg['QFormer']['use_vit_out']
        self.QFormer = build_QFormer_encoding(cfg["QFormer"])

        # The TaskFormer receives the two mask-related switches through its own config.
        cfg["TaskFormer"].update({'crossattn_mask_threshold': self.crossattn_mask_threshold})
        cfg["TaskFormer"].update({'use_crossattn_and_feature': self.use_crossattn_and_feature})
        self.TaskFormer = build_TaskFormer_encoding(cfg["TaskFormer"])

        print("cond model init successfully!")

    def forward(self, img_cond, text, clip_featrue, residual_img_mask, residual_img=None, z_img=None,
                use_raw_mask=False, use_dice_loss=True):
        """Run the conditioning branch.

        The batch is treated as two halves; the mask supervision and the
        predicted mask use the second half only (``[bs // 2:]``).

        Args:
            img_cond: Image batch fed to the backbone.
            text: Text input forwarded to the Q-Former as ``"text_input"``.
            clip_featrue: Pre-computed image embeddings passed to the Q-Former
                as ``image_embeds_frozen``.
            residual_img_mask: Mask batch; its second half is resized to 64x64
                (nearest) and used as the reference cloth mask.
                # assumes layout (2*b, 1, H, W) -- TODO confirm against caller
            residual_img: Unused in this stage (kept for interface compatibility).
            z_img: Unused in this stage (kept for interface compatibility).
            use_raw_mask: If True, return the reference cloth mask instead of
                the predicted one.
            use_dice_loss: If True the auxiliary loss is ``0.8 * dice + 0.2 * BCE``;
                otherwise it is the BCE term alone.

        Returns:
            Tuple ``(loss, multimodal_embeds, down_sample_residual, pose_residual,
            return_mask)`` where ``loss`` is ``{'mask_loss': ..., 'pose_loss': ...}``
            and both residuals are ``None``.
        """
        x, features = self.backbone(img_cond)

        bs = x.shape[0]  # full batch size (both halves)
        human_pose_bs = bs // 2

        # The reference mask is supervision only, so it never needs gradients.
        with torch.no_grad():
            cloth_mask = residual_img_mask[human_pose_bs:]
            cloth_mask = Resize([64, 64], interpolation=InterpolationMode.NEAREST)(cloth_mask)

        qformer_inputs = {"image": None, "text_input": text}
        QFormer_feature = self.QFormer.extract_features(qformer_inputs, image_embeds_frozen=clip_featrue)

        # TaskFormer consumes the Q-Former embeddings plus two backbone levels,
        # and uses the first backbone level as its second argument.
        Task_list = [QFormer_feature.multimodal_embeds, features[2], features[1]]
        attn_masks, Task_features, pred_pose_features, pred_masks = self.TaskFormer(Task_list, features[0])

        if self.use_auxiliary_loss:
            # Supervise the last TaskFormer layer, second half of the batch.
            mask_logits = pred_masks[-1][human_pose_bs:]
            h = w = int(mask_logits.shape[-1] ** 0.5)
            mask_logits = mask_logits.unsqueeze(1).reshape(human_pose_bs, 1, h, w)
            pred_mask = mask_logits.sigmoid()
            cloth_mask = cloth_mask.reshape(human_pose_bs, 1, h, w)

            weight = 0.8
            loss_2 = F.binary_cross_entropy_with_logits(mask_logits, cloth_mask)
            if use_dice_loss:
                loss_1 = dice_loss(pred_mask, cloth_mask)
                mask_loss = weight * loss_1 + (1 - weight) * loss_2
                print(f"dice loss:{loss_1}, bce loss:{loss_2}")
            else:
                # Previously this path left mask_loss unassigned and crashed when
                # the loss dict was built; fall back to the BCE term alone.
                mask_loss = loss_2

            # The first half of the batch gets an all-zero mask.
            return_mask = torch.cat((torch.zeros_like(pred_mask), pred_mask))
        else:
            mask_loss = torch.tensor(0, requires_grad=False)
            return_mask = torch.ones_like(pred_masks[-1]).reshape(bs, 1, 64, 64)

        pose_loss = torch.tensor(0, requires_grad=False)

        if use_raw_mask:
            # The original code read an undefined name `mask` here. Use the
            # resized reference cloth mask, laid out like the predicted mask
            # (zeros for the first half of the batch).
            # NOTE(review): confirm this is the intended "raw" mask.
            raw_mask = cloth_mask if cloth_mask.dim() == 4 else cloth_mask.unsqueeze(1)
            raw_mask = raw_mask.to(dtype=return_mask.dtype)
            return_mask = torch.cat((torch.zeros_like(raw_mask), raw_mask))
            print(f"return mask 1 number:{torch.sum(return_mask==1, dim=(1,2,3))}")

        # Diagnostic only: compare foreground pixel counts of returned vs. reference mask.
        count_pred_mask = torch.sum(return_mask >= 0.5, dim=(1, 2, 3))
        count_mask = torch.sum(cloth_mask == 1, dim=(1, 2, 3))
        print(f"pred_mask_1_num: {count_pred_mask}, mask_1_num:{count_mask}")

        # Second-stage outputs (decoder residuals) are disabled in this stage.
        down_sample_residual = None
        pose_residual = None

        loss = {'mask_loss': mask_loss, 'pose_loss': pose_loss}

        return loss, QFormer_feature.multimodal_embeds, down_sample_residual, pose_residual, return_mask






# Ad-hoc smoke test: build the conditioning model from a YAML config, run one
# forward pass on random tensors and print how many parameters are trainable.
# NOTE(review): this script does not match the current `build_model.forward`:
#   * `forward` expects (img_cond, text, clip_featrue, residual_img_mask, ...),
#     but the call below passes `mask` as `clip_featrue` and `pose_imag` as
#     `residual_img_mask`;
#   * `forward` returns (loss, embeds, down_sample_residual, pose_residual,
#     return_mask), but the result is unpacked in a different order;
#   * `mask` becomes a module-level global here, which `forward` may pick up
#     by accident when `use_raw_mask=True`.
# The intended inputs cannot be determined from this file -- confirm before use.
if __name__ == "__main__":
    import yaml
    with open('./configs/configs.yaml', 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    img = torch.rand([2,3,512,512],dtype=torch.float32)
    # NOTE(review): this first `pose_imag` is overwritten two lines below.
    pose_imag = torch.rand([2,3,256,256],dtype=torch.float32)
    mask = (torch.rand([2,64,64])<0.5).bool()
    pose_imag = torch.rand([2,3,64,64],dtype=torch.float32)

    text = '1'
    model =  build_model(config)
    loss, Qformer_featrue, pred_mask, down_sample_residual,pose_residual = model(img, text, mask,pose_imag)

    #print(f"loss:{loss},grad:{loss[0].requires_grad,loss[1].requires_grad}")
    #print(f"Qformer feature shape: {Qformer_featrue.shape}, grad:{Qformer_featrue.requires_grad}")
    #print(f"pred_mask:{ pred_mask},shape: {pred_mask.shape}, grad :{pred_mask.requires_grad}")
    #print("down_sample_residual:", {k: [v.shape, v.requires_grad] for k,v in down_sample_residual.items()} )
    # Count parameter tensors (trainable vs. all), then trainable scalar elements.
    grad_count = 0
    count = 0
    for i in model.parameters():
        if i.requires_grad: grad_count += 1
        count += 1
    print(f"grad_count:{grad_count}, count:{count}")
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total number of parameters: {total_params}")



class QTaskDiffusion(LatentDiffusion):
    def __init__(
            self,
            validation_config = None,
            img_H = 512,
            img_W = 384,
            p_weight = 0.1,
            m_weight = 0.1,
            semantic_spatical_model_train = True,
            dataloader_prob = None,
            diffusers_pretrained_path=None,
            sd_train_text_encoder=False,
            cond_stage_trainable = False,
            cross_attention_train = True,
            down_self_attention_train = True,
            use_bf16 = False,
            *args, **kwargs
    ):
        """Set up the diffusion model plus the semantic/spatial conditioning model.

        Args:
            validation_config: Stored as-is on the instance.
            img_H, img_W: Image height / width, stored on the instance.
            p_weight, m_weight: Loss weights, stored on the instance.
            semantic_spatical_model_train: When False the conditioning model is
                frozen by ``freeze_modules`` and excluded from the optimizer.
            dataloader_prob: Stored as-is on the instance.
            diffusers_pretrained_path: Currently unused.
            sd_train_text_encoder: Stored on the instance.
            cond_stage_trainable: Stored on the instance (set after the parent
                constructor has run).
            cross_attention_train, down_self_attention_train: Stored on the
                instance; read by ``configure_optimizers``.
            use_bf16: Currently unused.
            *args, **kwargs: Forwarded to the parent constructor after
                ``ckpt_path`` and ``ignore_keys`` have been removed. The same
                ``kwargs`` mapping is then handed to ``build_model`` and must
                contain ``'use_auxiliary_loss'``.
        """
        # The checkpoint is restored by this class at the very end, so these two
        # entries must not reach the parent constructor.
        ignore_keys = kwargs.pop("ignore_keys", [])
        ckpt_path = kwargs.pop("ckpt_path", None)

        super().__init__(*args, **kwargs)
        print("LDM 初始化完成！")

        self.img_H = img_H
        self.img_W = img_W
        self.validation_config = validation_config

        self.cross_attention_train = cross_attention_train
        self.down_self_attention_train = down_self_attention_train

        # Conditioning network, built from the same keyword mapping.
        self.semantic_spatical_model = build_model(kwargs)
        self.learnable_vector = nn.Parameter(torch.zeros((1, 1, 768)), requires_grad=True)

        self.use_auxiliary_loss = kwargs['use_auxiliary_loss']
        self.p_weight = p_weight
        self.m_weight = m_weight
        self.semantic_spatical_model_train = semantic_spatical_model_train
        self.dataloader_prob = dataloader_prob

        self.cond_stage_trainable = cond_stage_trainable
        self.sd_train_text_encoder = sd_train_text_encoder

        # Freezing relies on the flags and sub-modules assigned above.
        self.freeze_modules()

        # Restore weights last so every sub-module already exists.
        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True


    @torch.no_grad() # todo-本函数就是初次加载unet参数，将增加的cross、self模块东西进行初始化，已完成
    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):  # todo-对DDPM中init_from_ckpt进行了修改       # if self.vision_backbone_name == "unet":
        #"""
        state_dict = torch.load(path, map_location="cpu")['state_dict']
        if "state_dict" in list(state_dict.keys()):
            state_dict = state_dict["state_dict"]
        if not ignore_keys:
            keys = list(state_dict.keys())
            for k in keys:
                for ik in ignore_keys:
                    if k.startswith(ik):
                        print("Deleting key {} from state_dict.".format(k))
                        del state_dict[k]
        
        #----------------------第一次加载unet、vae模型时此处的初始化
        
        print("利用原有inpaint model加载权重")
        """
        for name in list(state_dict.keys()): # 初始化自注意力的i_to_k 或 v
            
            if "attn2.i_to_k" in name:
                state_dict[name.replace('attn1.i_to_k', 'attn1.to_k')] = state_dict[name]
                #print(name.replace('attn2.i_to_k', 'attn1.i_to_k'))
            elif "attn2.i_to_v" in name:
                state_dict[name.replace('attn1.i_to_v', 'attn1.to_v')] = state_dict[name]
            elif "input_blocks.0.0.weight" in name:
                new_weight = torch.randn(320, 13, 3, 3)
                new_weight[:, :9, :, :] = state_dict[name]
                new_weight[:,9:,:,:] = state_dict[name][:,4:8,:,:]
                state_dict[name] = new_weight
        """
            #if "input_blocks.0.0.weight" in name:
            #    state_dict[name] = state_dict[name][:,0:9,:,:]
                #print(name.replace('attn2.i_to_k', 'attn1.i_to_k'))
            #elif "learnable_vector" in name:
            #    state_dict["learnable_vector"] = state_dict["learnable_vector"]#.expand(1,16,768)
            #print(f"{name}:{state_dict['state_dict'][name].shape}")
        
        # 注意原来unet的初始化定义就是self.model
        missing, unexpected = self.load_state_dict(state_dict, strict=False) if not only_model else self.model.load_state_dict(state_dict, strict=False) # 不只有模型：vae、text_encoder、unet，只有模型那只有unet

        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        #if  self.first_init_from_ckpt:
            #if len(missing) > 0:
            #    print(f"Missing Keys:\n {missing}")
            #if len(unexpected) > 0:
            #    print(f"\nUnexpected Keys:\n {unexpected}")
        print(f"diffusion model missing key:{missing}")
        #for i in missing:
        #    if "model" in i:
        #        print(i)
        print(f"diffusion model unexpected key:{unexpected}")
        #for i in unexpected:
        #    if "model" in i:
        #        print(i)
        #"""
        print("successfully achieve load unet、text_encoder、vae weight!")
    
    @torch.no_grad() # 这里应该不需要计算梯度吧？
    def freeze_modules(self): # todo-本函数就是判断对个别模块冻住时的情况，cross、self的变化还没完成，clip_vit没有考虑在这个模块目前
        #self = self.half()

        to_freeze = [self.first_stage_model]  # self.first_stage_model就是VAE
    
        ##-to_freeze.append(self.semantic_spatical_model.pose_encoder)
        #-------------------------
#        to_freeze.append(self.semantic_spatical_model.QFormer.visual_encoder)
        #to_freeze.append(self.semantic_spatical_model)
        #to_freeze.append(self.model)


        ## to_freeze.append(self.semantic_spatical_model.backbone)

        #if True:
        #print("eval: first_stage_model, FrozenCLIPEmbedder")
            #to_freeze.append(self.model) # 这里的model应该是改写的unet
        """
        if not self.sd_train_text_encoder:
            to_freeze.append(self.cond_stage_model)
        """
        if not self.semantic_spatical_model_train:
            to_freeze.append(self.semantic_spatical_model)


        
        for module in to_freeze:
            module.eval()
            module.train = self.disabled_train
            module.requires_grad_(False)
    
    def configure_optimizers(self):  # todo-本文这里重写了DDPM的configure_optimizers方法，没有加入params的参数都不会更新
        # 设置优化器需要优化的参数，这里Blip-diffusion没有找到
        lr = self.learning_rate
        params = []
        params_name =[]
        
        # 只训练image2(k,v)这里的参数，不对原来的text2(k,v)修改. 冻结无关diffusion层

         # 44.8 M    Trainable params
        ####------------先不考虑
        #params += list(self.learnable_vector_dict.parameters())


        count_cross_and_norm = 0
        count_down_self = 0
        params_cross = []
        params_down_self = []
        
        transformer_param = []
        
        for name, param in self.model.named_parameters():
            """
            if "input_blocks" in name and "attn1" in name:
                print(name)
                param.requires_grad = True
                params_down_self.append(param) 
                count_down_self += 1
            ## ----------------------
            elif "attn2" in name or "transformer_blocks.0.norm" in name:
                print(name)
                param.requires_grad = True
                params_cross.append(param)
                count_cross_and_norm += 1
            else:
                param.requires_grad = False # 冻结无关的stable diffusion1 参数
            """
            if "transformer_blocks" in name :#or "input_blocks.0.0" in name : #("transformer_blocks" in name and "output_blocks" in name) or "attn2" in name:
                param.requires_grad = True
                transformer_param.append(param)
                params_name.append(name)
            else:
                param.requires_grad = False
        params += transformer_param

        if self.cross_attention_train:
            params += params_cross
            print(f"cross_attention_and_norm_train_nums:{count_cross_and_norm}")
        if self.down_self_attention_train: 
            params += params_down_self
            print(f"down_self_attention_train_nums:{count_down_self}")


        
        if self.semantic_spatical_model_train:
            #params += list(self.semantic_spatical_model.parameters())
            count = 0
            cond_model = []
            
            # backbone 只有position_table 为true
            #for name, param in self.semantic_spatical_model.backbone.named_parameters():
                #if "relative_position_bias_table"  in name:
            #    param.requires_grad = True
                #    print(f"traing param in backbone:{name, param.shape}")
                #else:
                #    param.requires_grad = False
            
            #for name, param in self.semantic_spatical_model.backbone.named_parameters():
                #param.requires_grad = True

            for name, param in self.semantic_spatical_model.named_parameters():
                if param.requires_grad:
                    count += 1
                    cond_model.append(param)
                    params_name.append(name)
                else:
                    print(f"{name}, {param.shape}")
            print(f"semantic_spatical_model learning param number: {count}")
            params += cond_model


        #if self.sd_train_text_encoder:
        #    params += list(self.cond_stage_model.parameters())
        clip_model = []
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            clip_model += list(self.cond_stage_model.final_ln.parameters())+list(self.cond_stage_model.mapper.parameters())+list(self.proj_out.parameters())
        print(f"clip_model learned params num:{len(clip_model)}")
        
        params += clip_model
        
        if self.learn_logvar: # todo-好像是学习方差
            print('Diffusion model optimizing logvar')
            params.append(self.logvar)
        print(f"all training params:{params_name}、learnable_vector, num: {len(params) + 1}")

        self.params = params
        self.params_with_white = params + list(self.learnable_vector)
        opt = torch.optim.AdamW(params, lr=lr)
        self.opt = opt
        if self.use_scheduler: # 这一部分原LDM就有
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt

    def disabled_train(self, mode=True):
        """Overwrite model.train with this function to make sure train/eval mode
        does not change anymore."""
        return self

    def get_learned_conditioning(self, c):
        """Turn a prepared batch into the conditioning inputs of the denoiser.

        Unlike the text-only conditioning of the base latent-diffusion model,
        image features are injected here as well.

        Args:
            c: dict that provides ``'ref_img'``, ``'img'``, ``'caption'``,
                ``'residual_img_mask'`` and ``'z'``. The first half of the
                batch axis is treated as the cloth samples and the second half
                as the human samples.

        Returns:
            A 6-tuple ``(cond, down_self_residual, pose_residual, loss, z,
            pred_cloth_mask)``. ``cond`` is ``[semantic_tokens, None, None,
            None]``; the third slot is the text encoding, which is disabled
            (``None``) at this stage. The remaining entries are forwarded
            unchanged from ``self.semantic_spatical_model`` and ``c['z']``.
        """
        half = int(c['ref_img'].shape[0] / 2)

        # Per the original author's note: clip_proj is [2*bs, 1, 768] and the
        # last hidden state is [2*bs, 257, 1024] -- TODO confirm.
        clip_z, clip_last_hidden = self.cond_stage_model(c['ref_img'])
        clip_proj = self.proj_out(clip_z)

        # The captions are duplicated so they cover both halves of the batch.
        # No residual image is supplied (residual_img=None).
        (loss,
         semantic_feature,
         down_self_residual,
         pose_residual,
         pred_cloth_mask) = self.semantic_spatical_model(
            c['img'],
            c["caption"] * 2,
            clip_last_hidden,
            c['residual_img_mask'],
            residual_img=None,
        )

        # Prepend the projected CLIP token to the semantic tokens of each half.
        cloth_tokens = torch.cat((clip_proj[:half], semantic_feature[:half]), dim=1)
        human_tokens = torch.cat((clip_proj[half:], semantic_feature[half:]), dim=1)

        # The halves are emitted human-first, i.e. swapped relative to the
        # input order.
        fused_tokens = torch.cat([human_tokens, cloth_tokens])

        # Stage one ignores text, so the text-encoder slot stays empty.
        text_encoder = None
        cond = [fused_tokens, None, text_encoder, None]

        return cond, down_self_residual, pose_residual, loss, c['z'], pred_cloth_mask

        









    @torch.no_grad()
    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False, force_c_encode=False,
                  *args, **kwargs):
        """Prepare the model inputs for one dataloader batch.

        Overrides the base latent-diffusion ``get_input``. Two modes exist:

        * ``k != "inpaint"``: ``batch[k]`` is encoded with the first-stage
          model and the latent tensor is returned directly.
        * ``k == "inpaint"``: the cloth/human target images and their inpaint
          inputs are encoded, stacked cloth-first along the batch axis, and
          returned in a dict together with the raw conditioning images.

        Args:
            batch: dataloader batch. In inpaint mode it must provide
                ``ref_cloth``, ``ref_human``, ``human``, ``cloth``,
                ``inpaint_human``, ``inpaint_cloth``, ``inpaint_mask``,
                ``inpaint_cloth_mask``, ``gt_cloth`` and ``gt_human``.
            k: first-stage key selecting the mode.
            cond_key: conditioning key; defaults to ``self.cond_stage_key``.
            bs: optional limit on the number of samples taken from the batch.
                Any value accepted by ``int()`` works (callers pass floats).
            return_first_stage_outputs: unused, kept for interface
                compatibility.
            force_c_encode: encode the conditioning even when the conditioning
                stage is trainable.

        Returns:
            The latent tensor (non-inpaint mode), or a dict with keys ``'z'``,
            ``'caption'``, ``'ref_img'``, ``'img'``, ``'residual_img'``,
            ``'residual_img_mask'`` and ``'gt'`` (inpaint mode). ``'z'`` is the
            channel-wise concatenation ``[target latent, inpaint latent,
            1 - mask]``; the last three dict entries are currently ``None``.

        Raises:
            NotImplementedError: in inpaint mode when ``cond_key`` equals
                ``self.first_stage_key`` (there is no single raw input image
                to use as conditioning).
        """
        if bs is not None:
            # Normalise once for both modes: slicing with a float raises.
            bs = int(bs)

        def _to_latent(image):
            # Encode with the first-stage model and detach from the graph.
            posterior = self.encode_first_stage(image.to(self.device))
            return self.get_first_stage_encoding(posterior).detach()

        if k != "inpaint":
            x = batch[k]
            if len(x.shape) == 3:
                x = x[..., None]
            if bs is not None:
                x = x[:bs]
            return _to_latent(x)

        image_keys = (
            'ref_cloth', 'ref_human', 'human', 'cloth',
            'inpaint_human', 'inpaint_cloth', 'inpaint_mask', 'inpaint_cloth_mask',
            'gt_cloth', 'gt_human',
        )
        tensors = {}
        for name in image_keys:
            value = batch[name].to(memory_format=torch.contiguous_format).float()
            tensors[name] = value if bs is None else value[:bs]

        ref_cloth = tensors['ref_cloth']
        ref_human = tensors['ref_human']
        cloth = tensors['cloth']
        human = tensors['human']

        # Encoding order is kept as cloth, human, inpaint-human, inpaint-cloth
        # so that any sampling inside the first stage consumes the RNG in the
        # same sequence as before.
        z_cloth = _to_latent(tensors['gt_cloth'])
        z_human = _to_latent(tensors['gt_human'])
        z_inpaint_human = _to_latent(tensors['inpaint_human'])
        z_inpaint_cloth = _to_latent(tensors['inpaint_cloth'])

        # Bring the pixel-space masks down to the latent resolution.
        to_latent_size = Resize([z_inpaint_cloth.shape[-2], z_inpaint_cloth.shape[-1]],
                                interpolation=InterpolationMode.NEAREST)
        z_human_mask = to_latent_size(tensors['inpaint_mask'])
        z_cloth_mask = to_latent_size(tensors['inpaint_cloth_mask'])

        # Batch axis: cloth samples first, human samples second, so cloth and
        # person generation are trained together.
        z = torch.cat((z_cloth, z_human))
        z_inpaint = torch.cat((z_inpaint_cloth, z_inpaint_human))
        z_mask = torch.cat((z_cloth_mask, z_human_mask))

        # Channel axis: target latent, inpaint latent, inverted mask
        # (presumably 4 + 4 + 1 = 9 channels -- TODO confirm the latent width).
        z = torch.cat((z, z_inpaint, 1 - z_mask), dim=1)

        # For this experiment the section below effectively just fetches the
        # caption text from the batch.
        if self.model.conditioning_key is not None:
            if cond_key is None:
                cond_key = self.cond_stage_key
            if cond_key != self.first_stage_key:
                if cond_key in ['txt', 'caption', 'coordinates_bbox']:
                    xc = batch[cond_key]
                elif cond_key == 'image':
                    xc = ref_cloth
                elif cond_key == 'class_label':
                    xc = batch
                else:
                    xc = super().get_input(batch, cond_key).to(self.device)
            else:
                # The previous code read an unbound local ``x`` here, which
                # failed with UnboundLocalError; fail with a clear message.
                raise NotImplementedError(
                    "cond_key equal to first_stage_key is not supported when k == 'inpaint'")
            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, dict) or isinstance(xc, list):
                    c = super().get_learned_conditioning(xc)
                else:
                    c = super().get_learned_conditioning(xc.to(self.device))
                    # NOTE(review): ``proj_out`` is reached through ``super()``;
                    # if it is a submodule attribute rather than a parent-class
                    # method this lookup fails -- confirm.
                    c = super().proj_out(c)
                    c = c.float()
            else:
                c = xc
            if bs is not None:
                c = c[:bs]

            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                ckey = __conditioning_keys__[self.model.conditioning_key]
                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
        else:
            c = None
            xc = None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {'pos_x': pos_x, 'pos_y': pos_y}

        ref_img = torch.cat((ref_cloth, ref_human))
        img = torch.cat((cloth, human))

        # Residual inputs and the pixel-space ground truth are disabled.
        # ``residual_img_mask`` used to be written as ``None,`` -- the stray
        # trailing comma produced the tuple ``(None,)`` instead of ``None``.
        # NOTE(review): confirm the semantic/spatial model expects None here.
        residual_img_mask = None
        residual_img = None
        gt = None

        out = {
            'z': z,
            'caption': c,
            'ref_img': ref_img,
            'img': img,
            'residual_img': residual_img,
            'residual_img_mask': residual_img_mask,
            'gt': gt,  # only needed by the (disabled) L1 loss
        }
        return out


    def apply_model(self, x_noisy, t, cond, down_sample_residual=None,pose_residual=None,selfattn_img_weight=1.0,
                    return_ids=False):
        """Run the wrapped diffusion model on a noisy input.

        Replaces the base implementation: the patch-splitting strategy is
        gone, and residual inputs plus ``selfattn_img_weight`` are forwarded
        to the model.

        Args:
            x_noisy: noisy model input.
            t: timesteps.
            cond: conditioning. A dict is passed through as keyword
                arguments; anything else is wrapped in a list (if it is not
                one already) and keyed by ``'c_concat'`` when the model's
                conditioning key is ``'concat'``, otherwise ``'c_crossattn'``.
            down_sample_residual: forwarded as ``down_residual``.
            pose_residual: forwarded as ``pose_residual``.
            selfattn_img_weight: image weight forwarded to the model, used to
                balance the influence of the different conditioning inputs.
            return_ids: when true, a tuple result is returned whole.

        Returns:
            The model output; if it is a tuple and ``return_ids`` is false,
            only its first element.
        """
        if not isinstance(cond, dict):
            # Normalise list / single-object conditioning into the dict form.
            cond_items = cond if isinstance(cond, list) else [cond]
            if self.model.conditioning_key == 'concat':
                cond = {'c_concat': cond_items}
            else:
                cond = {'c_crossattn': cond_items}

        x_recon = self.model(
            x_noisy,
            t,
            down_residual=down_sample_residual,
            pose_residual=pose_residual,
            selfattn_img_weight=selfattn_img_weight,
            **cond,
        )

        only_first = isinstance(x_recon, tuple) and not return_ids
        return x_recon[0] if only_first else x_recon

    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=50, ddim_eta=1., return_keys=None,
                   quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
                   use_ema_scope=True,
                   **kwargs):
        """Collect images for logging (overrides the base ``log_images``).

        Args:
            batch: dataloader batch, forwarded to ``get_input``.
            N: number of latent rows to sample; half of it is taken from the
                batch because each item yields a cloth and a human row.
            n_row: number of rows used for the forward-diffusion grid.
            sample: whether to run the sampler.
            ddim_steps: DDIM steps; ``None`` disables DDIM.
            ddim_eta: DDIM eta.
            return_keys: if given and it intersects the produced keys, only
                these keys are returned.
            quantize_denoised: also sample with quantised x0 when the first
                stage is neither ``AutoencoderKL`` nor ``IdentityFirstStage``.
            inpaint: additionally log centre-square in/outpainting samples.
            plot_denoise_rows: log the intermediate denoising grid.
            plot_progressive_rows: log progressive denoising.
            plot_diffusion_rows: log the forward diffusion grid.
            unconditional_guidance_scale, unconditional_guidance_label,
            use_ema_scope: accepted for interface compatibility; unused.

        Returns:
            dict mapping log names to tensors.
        """
        use_ddim = ddim_steps is not None
        log = dict()

        # Integer division: ``bs`` is used for slicing downstream.
        out = self.get_input(batch, self.first_stage_key, bs=N // 2)
        c, down_sample_residual, pose_residual, auxiliary_loss, z, pred_mask = self.get_learned_conditioning(out)

        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)

        _, _, h, w = out['ref_img'].shape

        # Only the first four channels of z are the latent to decode; the rest
        # carries the inpaint latent and mask.
        log["reconstruction"] = self.decode_first_stage(z[:, 0:4, :, :])
        log["pred_mask"] = Resize((h, w), interpolation=InterpolationMode.NEAREST)(pred_mask)

        if self.model.conditioning_key is not None:
            # No raw conditioning input is available from get_input here, so
            # start unset; the old fall-through read ``xc`` before assignment.
            xc = None
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                xc = log_txt_as_img((h, w), batch[self.cond_stage_key])
                log[f"{self.cond_stage_key}"] = xc
            elif self.cond_stage_key == 'class_label':
                xc = log_txt_as_img((z.shape[2], z.shape[3]), batch["human_label"])
                log['conditioning'] = xc
            if xc is not None and ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # Forward-diffusion grid. Noise and decode the four latent channels
            # only, matching p_losses and the reconstruction above.
            diffusion_row = list()
            z_start = z[:n_row, :4]

            for step in range(self.num_timesteps):
                if step % self.log_every_t == 0 or step == self.num_timesteps - 1:
                    t = repeat(torch.tensor([step]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                if self.first_stage_key == 'inpaint':
                    # The non-noised channels are handed to the sampler as ``rest``.
                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                             down_sample_residual=down_sample_residual,
                                                             pose_residual=pose_residual,
                                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                                             rest=z[:, 4:, :, :])
                else:
                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                             down_sample_residual=down_sample_residual,
                                                             pose_residual=pose_residual,
                                                             ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples

            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                    self.first_stage_model, IdentityFirstStage):
                # also display when quantizing x0 while sampling
                with self.ema_scope("Plotting Quantized Denoised"):
                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                                             quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

            if inpaint:
                # make a simple center square
                b, h, w = z.shape[0], z.shape[2], z.shape[3]
                mask = torch.ones(N, h, w).to(self.device)
                # zeros will be filled in
                mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
                mask = mask[:, None, ...]
                with self.ema_scope("Plotting Inpaint"):
                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                                 ddim_steps=ddim_steps, x0=z[:N, :4], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_inpainting"] = x_samples
                log["mask"] = mask

                # outpaint
                # NOTE(review): unlike the inpaint call this passes all of
                # z's channels as x0 and reuses the same mask -- confirm intent.
                with self.ema_scope("Plotting Outpaint"):
                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                                 ddim_steps=ddim_steps, x0=z[:N], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            with self.ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        if return_keys:
            # Fall back to the full log when none of the requested keys exist.
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}

        return log

       
    """
    @torch.no_grad()
    def on_validation_epoch_start(self): # 验证集每一epoch开始动作
        self.ddim_sampler = DDIMSampler(self)
        self.validation_gene_dirs = []
        for data_type in ["pair", "unpair"]:
            to_dir = opj(self.valid_config.img_save_dir, f"{data_type}_{self.current_epoch}")
            self.validation_gene_dirs.append(to_dir)
    """
    @torch.no_grad()
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Sample a few validation items and save comparison strips to disk.

        Only the first three batches (``batch_idx`` 0..2) are processed. At
        most four items are taken from the batch; each yields a cloth row and
        a human row. For every row one image is written, concatenating
        (left to right) the sample, the paired input image, the first-stage
        reconstruction, the decoded inpaint latent, the inpaint-mask channel
        and the predicted mask.

        Args:
            batch: dataloader batch; besides the keys read by ``get_input`` it
                must provide ``'human'``, ``'cloth'`` and ``'im_name'``.
            batch_idx: index of the batch within the validation loader.
            dataloader_idx: 0 selects the ``"pair"`` output directory, any
                other value ``"unpair"``. Defaults to 0 so the step also works
                when it is called without a dataloader index.
        """
        if batch_idx > 2:
            return

        was_training = self.training
        if was_training:
            self.eval()
        try:
            data_type = "pair" if dataloader_idx == 0 else "unpair"

            out = self.get_input(batch, self.first_stage_key, bs=4)
            c, down_self_residual, pose_residual, auxiliary_loss, z, pred_mask = self.get_learned_conditioning(out)
            bs = z.shape[0]
            # z stacks the cloth rows and the human rows, so it holds twice as
            # many rows as batch items were used.
            half = bs // 2
            print(f"validation bs:{bs}")

            x_recon = self.decode_first_stage(z[:, 0:4, :, :])
            shape = (4, self.img_H // 8, self.img_W // 8)

            ddim_sampler = DDIMSampler(self)
            samples, intermediates = ddim_sampler.sample(
                self.validation_config.ddim_steps,
                bs,
                shape,
                c,
                down_sample_residual=down_self_residual,
                pose_residual=pose_residual,
                x_T=None,
                verbose=False,
                eta=self.validation_config.eta,
                mask=None,
                x0=None,
                rest=z[:, 4:, :, :],
            )
            x_samples = self.decode_first_stage(samples)
            inpaint_feature = self.decode_first_stage(z[:, 4:8, :, :])

            # Channel 8 of z is the mask channel; upsample it and the
            # predicted mask to the sample resolution for display.
            out_size = (x_samples.shape[2], x_samples.shape[3])
            inpaint_mask = Resize(out_size, interpolation=InterpolationMode.NEAREST)(z[:, 8, :, :]).unsqueeze(1)
            pred_mask = Resize(out_size, interpolation=InterpolationMode.NEAREST)(pred_mask)

            to_dir = opj(self.validation_config.img_save_dir, f"{data_type}_{self.current_epoch}")
            os.makedirs(to_dir, exist_ok=True)

            paired_imgs = torch.cat((batch['human'][:half], batch['cloth'][:half]))
            # Repeat the names of the items actually used so both halves get
            # the matching file name. Slicing by ``bs`` (the doubled row
            # count) mislabelled the second half whenever the batch held more
            # than four items.
            names = 2 * list(batch["im_name"][:half])

            rows = zip(x_samples, paired_imgs, names, x_recon, inpaint_feature, inpaint_mask, pred_mask)
            for idx, (x_sample, paired, fn, recon, feat, mask_row, pred_row) in enumerate(rows):
                panels = [
                    tensor2img(x_sample),
                    tensor2img(paired),
                    tensor2img(recon),
                    tensor2img(feat),
                    tensor2img(mask_row),
                    tensor2img(pred_row),
                ]
                strip = np.concatenate(panels, axis=1)

                # First half of the rows are cloth rows, second half human rows.
                prefix = "cloth_" if idx < half else "human_"
                to_path = opj(to_dir, prefix + fn)
                # Images are RGB here; cv2 expects BGR.
                cv2.imwrite(to_path, strip[:, :, ::-1])
        finally:
            # Restore training mode even if sampling or saving fails.
            if was_training:
                self.train()

    def p_losses(self, x_start, cond, t, down_self_residual,pose_residual,auxiliary_loss,gt=None,noise=None):
        """Compute the diffusion training loss for one batch.

        Args:
            x_start: clean model input. When ``self.first_stage_key`` is
                ``'inpaint'`` only the first four channels are noised; the
                remaining channels are appended to the noisy latent unchanged.
            cond: conditioning forwarded to ``apply_model``.
            t: timesteps, one per sample.
            down_self_residual: forwarded to ``apply_model``.
            pose_residual: forwarded to ``apply_model``.
            auxiliary_loss: mapping that provides ``'mask_loss'``; only read
                when ``self.use_auxiliary_loss`` is set.
            gt: unused; kept for interface compatibility (the pixel-space L1
                loss that consumed it is disabled).
            noise: optional noise; sampled with the shape of the noised
                channels when omitted.

        Returns:
            ``(loss, loss_dict)``: the scalar loss and a dict of logged terms
            prefixed with ``train/`` or ``val/``.

        Raises:
            NotImplementedError: for parameterizations other than ``"x0"``
                and ``"eps"``.
        """
        if self.first_stage_key == 'inpaint':
            # Diffuse only the target latent; keep the inpaint latent and mask
            # channels clean and re-attach them afterwards.
            latent = x_start[:, :4, :, :]
            if noise is None:
                noise = torch.randn_like(latent)
            x_noisy = self.q_sample(x_start=latent, t=t, noise=noise)
            x_noisy = torch.cat((x_noisy, x_start[:, 4:, :, :]), dim=1)
        else:
            latent = x_start
            if noise is None:
                noise = torch.randn_like(latent)
            x_noisy = self.q_sample(x_start=latent, t=t, noise=noise)

        model_output = self.apply_model(x_noisy, t, cond, down_self_residual, pose_residual)

        loss_dict = {}
        prefix = 'train' if self.training else 'val'

        if self.parameterization == "x0":
            # The target must be the noised channels only; the full x_start
            # carries extra channels in inpaint mode.
            target = latent
        elif self.parameterization == "eps":
            target = noise
        else:
            raise NotImplementedError()

        # Per-sample loss, shared by the simple and the VLB terms.
        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

        # Move logvar to the model device before indexing with t; the two can
        # live on different devices.
        logvar_t = self.logvar.to(self.device)[t]

        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        loss_vlb = (self.lvlb_weights[t] * loss_simple).mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss = loss + (self.original_elbo_weight * loss_vlb)
        # NOTE(review): per-step debug print; consider routing through logging.
        print(f"diffusion loss：{loss,type(loss)}")

        if self.use_auxiliary_loss:
            # Auxiliary mask loss.
            mask_loss = self.m_weight * auxiliary_loss['mask_loss']
            loss_dict.update({f'{prefix}/mask_loss': mask_loss})
            loss = loss + mask_loss

        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict
    
    def forward(self, input_out, *args, **kwargs):  # added z_src, gt (carried in input_out)
        """Compute the training loss for one batch.

        Encodes ``input_out`` with ``get_learned_conditioning``, draws one random
        timestep per sample, optionally swaps the first conditioning entry for
        ``self.learnable_vector``, and delegates to ``p_losses``.

        Args:
            input_out: Mapping passed to ``get_learned_conditioning``; it must
                also provide the key ``'gt'``, which is forwarded to ``p_losses``.
            *args: Extra positional arguments forwarded to ``p_losses``.
            **kwargs: Extra keyword arguments forwarded to ``p_losses``.

        Returns:
            The ``(loss, loss_dict)`` pair returned by ``p_losses``.

        Note:
            The conditioning, latent and timestep locals are only bound when
            ``self.model.conditioning_key`` is not None and
            ``self.cond_stage_trainable`` is true; other configurations fail
            with an unbound-local error, as they did before.
        """
        self.opt.params = self.params
        if self.model.conditioning_key is not None:
            if self.cond_stage_trainable:
                (c, down_self_residual, pose_residual,
                 auxiliary_loss, z, _pred_mask) = self.get_learned_conditioning(input_out)
            # Timesteps must exist before the shorten_cond_schedule branch, which
            # indexes cond_ids with t (t used to be sampled after that branch,
            # so enabling the option raised UnboundLocalError).
            # z.shape[0] is the batch size (2 * bs during training, per the original note).
            t = torch.randint(0, self.num_timesteps, (z.shape[0],), device=self.device).long()
            if self.shorten_cond_schedule:  # TODO: drop this option
                tc = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))

        # With probability u_cond_percent, switch the optimizer parameter set and
        # replace the first conditioning entry by the learnable vector repeated
        # over the batch (presumably unconditional training for classifier-free
        # guidance -- TODO confirm).
        self.u_cond_prop = random.uniform(0, 1)
        if self.u_cond_prop < self.u_cond_percent:
            self.opt.params = self.params_with_white
            c[0] = self.learnable_vector.repeat(c[0].shape[0], 1, 1)

        loss, loss_dict = self.p_losses(
            z, c, t, down_self_residual, pose_residual, auxiliary_loss,
            *args, gt=input_out['gt'], **kwargs,
        )
        # NOTE: an image-space L1 loss (sample_hijack on input_out['z_src'] followed by
        # differentiable_decode_first_stage, weighted by 5e-1) was tried here and is disabled.
        return loss, loss_dict
    

    def sample_hijack(self,x_start,cond,t,down_self_residual,pose_residual,noise=None):
        """Noise ``x_start`` at step ``t`` and return the model's start-sample estimate.

        Only the first four channels of ``x_start`` are noised via ``q_sample``;
        any remaining channels are concatenated back unchanged before the model
        is applied. The model output is passed as ``noise`` to
        ``predict_start_from_noise`` together with the noised four channels.

        Args:
            x_start: Input tensor, sliced along dim 1 into ``[:4]`` and ``[4:]``.
            cond: Conditioning forwarded to ``apply_model``.
            t: Timesteps forwarded to ``q_sample``, ``apply_model`` and
                ``predict_start_from_noise``.
            down_self_residual: Forwarded to ``apply_model``.
            pose_residual: Forwarded to ``apply_model``.
            noise: Optional noise for the first four channels; drawn with
                ``torch.randn_like`` when omitted.

        Returns:
            The result of ``predict_start_from_noise``.
        """
        head = x_start[:, :4, :, :]
        tail = x_start[:, 4:, :, :]

        noise = default(noise, lambda: torch.randn_like(head))
        noised_head = self.q_sample(x_start=head, t=t, noise=noise)
        # Re-attach the channels that were left untouched by the noising step.
        net_input = torch.cat((noised_head, tail), dim=1)

        prediction = self.apply_model(net_input, t, cond, down_self_residual, pose_residual)
        return self.predict_start_from_noise(net_input[:, :4, :, :], t=t, noise=prediction)


    def shared_step(self, batch, **kwargs):
        """Fetch the model inputs for ``batch`` and run the forward pass.

        Args:
            batch: Raw batch passed to ``get_input`` with ``self.first_stage_key``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The ``(loss, loss_dict)`` pair produced by calling the model.
        """
        model_inputs = self.get_input(batch, self.first_stage_key)
        total_loss, metrics = self(model_inputs)
        return total_loss, metrics
    
    @torch.no_grad()
    def sample_log(self,cond,batch_size,ddim, ddim_steps,down_sample_residual,pose_residual,**kwargs):
        """Generate samples under the given conditioning and return intermediates.

        Args:
            cond: Conditioning passed to the sampler.
            batch_size: Number of samples to draw.
            ddim: When truthy, sample with ``DDIMSampler``; otherwise use
                ``self.sample`` with ``return_intermediates=True``.
            ddim_steps: Step count passed to the DDIM sampler (unused otherwise).
            down_sample_residual: Forwarded to the sampler.
            pose_residual: Forwarded to the sampler.
            **kwargs: Extra keyword arguments forwarded to the sampler.

        Returns:
            A ``(samples, intermediates)`` tuple from the chosen sampler.
        """
        if not ddim:
            samples, intermediates = self.sample(
                cond=cond,
                batch_size=batch_size,
                down_sample_residual=down_sample_residual,
                pose_residual=pose_residual,
                return_intermediates=True,
                **kwargs,
            )
            return samples, intermediates

        sampler = DDIMSampler(self)
        # Sampling shape is (channels, img_H // 8, img_W // 8).
        latent_shape = (self.channels, self.img_H // 8, self.img_W // 8)
        samples, intermediates = sampler.sample(
            ddim_steps,
            batch_size,
            latent_shape,
            conditioning=cond,
            down_sample_residual=down_sample_residual,
            pose_residual=pose_residual,
            verbose=False,
            **kwargs,
        )
        return samples, intermediates
    
    