model:
  base_learning_rate: 5.0e-06
  target: ldm.models.diffusion.Ss_Diffusion2_5.QTaskDiffusion #----注意在这里更改Ss_Diffusion2
  params:
    #first_stage_key_cond: [ "agn_mask","warped_cloth_mask",'openpose_map',"parse_img","densepose_img", 'openpose_img'] #["cloth_mask",,,"agn", "agn_mask", "image_densepose"] # 如果需要计算额外的loss就需要将mask、densepose等转换到潜空间
    first_stage_key: "inpaint" # 目标要生成的图片
    cond_stage_key: "caption" # 空间文本信息
    cond_stage_trainable: True # Note: different from the one we trained before
    sd_train_text_encoder: False
    #cond_image_key1: "cloth"
    #cond_image_key2: 'human'
    #cond_caption_key1: 'cloth_captions'
    #cond_caption_key2: 'human_captions'

    semantic_spatical_model_train: true
    use_auxiliary_loss: True #---------第二阶段不考虑

    cross_attention_train: True # 1.交叉注意力部分
    ## ------------第一阶段不考虑
    down_self_attention_train: False #2.自注意力部分


    p_weight : 0.03
    m_weight : 0.0004
    dataloader_prob: [1,1]

    crossattn_mask_threshold: 0.4
    use_crossattn_and_feature: False

    u_cond_percent: 0.2


    #first_init_from_ckpt: True # unet第一次从别的预训练模型加载需要一些特殊的初始化


    #drop_proportion: # 三种数据格式训练，每种条件丢弃的概率
    #  - 0.15
    #  - 0.15
    #  - 0.15
    #multi_dataloader_prob: [1,1,1] # 三种训练模式概率
    #lamda_weigt: [0.01,0.01,0.01,0.01,0.01,0.01,0.01]

    #use_spatial_semantic_loss: False  # 是否采用辅助损失，使用时确保QTmodel返回semantic、spatial feature即


    #qformer_pretrained_path: './checkpoint/QFormer.bin'
    #TaskFormer_pretrained_path: './checkpoint/Mask2Former.pkl'
    #unet_path: './checkpoint/unet.bin'
    #---------从50继续训练
    ckpt_path:  logs/20250103_1-5/models/[Train]_[epoch=4]_[train_loss_epoch=0.00000].ckpt # 'checkpoints/model.ckpt' #logs/20241231_dummy/models/[Train]_[epoch=49]_[train_loss_epoch=0.00000].ckpt 
    #first_init_from_ckpt: True
    #use_bf16: True


    
    image_size: 64
    channels: 4
    img_H : 512
    img_W : 512 # 384


    # 以下都是默认的
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

 

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64 # unused
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: False  #----
        legacy: False  
        add_conv_in_front_of_unet: False

        #---------------第一阶段，不考虑spatial对扩散的影响，只训练spatial的两个损失
        add_cond2selfattn: True # 增加外部信息到自注意力模块。这两个同时为true决定下采样外部信息注入自注意力模块
        disable_self_attn: True # 让最原始的自注意力失效
        add_text_weigh_to_selfattn: False # 自注意力是否由文本控制注意力矩阵，这时上采样阶段使用，只有三个同时为True，这个才生效
        #content_layer: [ 4,5,6,7,8 ]
        # content_layer: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
      sequential_crossattn: True




    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1000 ] # 10000
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]


    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

    #QTModel:
      #num_classes: 255
      #aux_loss: False
      #contrastive_align_loss: True # 扩散模型中确定使用辅助损失则需要返回
      #contrastive_hdim: 64  #args.contrastive_loss_hdim
      #text_encoder_type: 'roberta-base' #args.text_encoder_type,
      #freeze_text_encoder: True #args.freeze_text_encoder
      #unet_feature: "up1" #args.unet_feature


      #semantic_matching: True # todo-是否提供用来做额外损失
      #spatial_matching:  True
      #require_QFormer_text_embedding: True

      #TaskFormer_cloth_mask_q_pos: 0 # 记录的每一部分的结束位置
      #TaskFormer_agn_mask_q_pos: 1
      #TaskFormer_openpose_map_q_pos: 19
      #TaskFormer_parse_q_pos: 23
      #TaskFormer_openpose_img_q_pos: 27 
      #TaskFormer_densepose_q_pos : 31 
      #TaskFormer_spatial_q_pos: 61

    SwinTransformer_backbone:
      img_size: [512,512]
      embed_dim: 128
      depths: [2, 2, 18, 2]
      num_heads: [4, 8, 16, 32]
      window_size: 16
      drop_path_rate: 0.2
      last_norm: False
      #mask: ''
      #------
      #pretrained_path: 'checkpoints/swin_base_patch4_window12_384_22kto1k.pth'
    SwinTransformerV2:
      img_size: 256
      embed_dim: 128
      depths: [2, 2, 18, 2]
      num_heads: [4, 8, 16, 32]
      window_size: 16
      drop_path_rate: 0.5
      #pretrained_path: 'checkpoints/swinv2_base_patch4_window16_256.pth'


    TaskFormer:
      in_channels: [768,512,256,128] #[768,2560,1280] # todo-后续得改，具体取决又输入unet的哪三层或者concat成的三层
      #mask_classification: true
      #num_classes: 133
      hidden_dim: 256
      num_queries: 100 #
      mask_q_num: 1
      pose_q_num: 35
      nheads: 8
      dim_feedforward: 2048
      dec_layers: 9
      pre_norm: false
      mask_dim: 256  # 这里表示的是最大图片的channel，好让decoder query提取的mask_emb 维度转换到这个channel，然后做乘法的操作
                      # 原文这里是256，可能原文的视觉提取backbone最大特征图的通道数是256
      enforce_input_project: True # 原false
      QFormer_num_queries: 16
      #use_crossattn_and_feature: true
      predict_attention_mask: False
      #crossattn_mask_threshold: 0.2
      
      QFormer_detach: False #--------第一阶段考虑一下QFormer连贯性
      #-----------
      #pretrained_path: "checkpoints/Mask2Former.pkl"
    QFormer:
      vit_model: "eva_clip_g"
      num_query_token: 16
      cross_attention_freq: 1

      img_size: 32 # 这里没改，应该改成32，原本默认为224

      use_vit_out: false
      vit_encoder_dim: 1024  # 原clip好像就是1024
      #--------------
      #pretrained_path: 'checkpoints/QFormer.bin' 
      
    Decoder:
        #convin_kernel_size: [1, 1, 1, 1, 1, 1]
        #convin_stride: [1, 1, 1, 1, 1, 1]
        #convin_padding: [0, 0, 0, 0, 0, 0]
        #attn_residual_block_idx: [1,2,4,5,7,8] # [320, 64, 64],[ 320, 64, 64],[ 640, 32, 32],[ 640, 32, 32],[1280, 16, 16],[1280, 16, 16]
                                                #即[320, 4096],[320,4096],[640,1024],[640,1024],[1280,256],[1280,156]
        #inner_dims: [64, 64, 64, 64, 64, 64]
        #ctx_dims: [320, 320,640, 640, 1280, 1280] # 最终输出到unet中的维度
        #embed_dims: [64, 64, 128, 128,  256, 256] # 此模型的中间维度
        #heads: [8, 8, 8, 8,  8, 8]
        #depth: 4
        #to_self_attn: True
        #to_queries: True
        #aspect_ratio : 1
        #to_keys : False
        #to_values: False
        #detach_input: False 
        #pose_dim: 35
      
        convin_kernel_size: [1, 1, 1, 1, 1, 1,1,1,1]
        convin_stride: [1, 1, 1, 1, 1, 1,1,1,1]
        convin_padding: [0, 0, 0, 0, 0, 0,0,0,0]
        attn_residual_block_idx: [11,10,9,8,7,6,5,4,3] # [320, 64, 64],[ 320, 64, 64],[ 640, 32, 32],[ 640, 32, 32],[1280, 16, 16],[1280, 16, 16]
                                                #即[320, 4096],[320,4096],[640,1024],[640,1024],[1280,256],[1280,156]
        #--------------
        pose_attn_residual_block_idx: [1,2,4,5,7,8]
        pose_ctx_dims: [320, 320,640, 640, 1280, 1280]
        pose_dim: 35
        
        #----------------
        inner_dims: [64, 64, 64, 64, 64, 64,64,64,64]
        ctx_dims: [320, 320,320,640, 640, 640,1280, 1280,1280] # 最终输出到unet中的维度
        embed_dims: [64, 64, 64,128, 128,128,  256, 256,256] # 此模型的中间维度
        heads: [2, 2, 2,4,4,4, 8,  8, 8]
        depth: 2
        to_self_attn: True
        to_queries: True
        aspect_ratio : 1
        to_keys : False
        to_values: False
        detach_input: False 
    #Cond:
    #  hidden_dim: 768 # 这里是pos编码的维度
    #  position_embedding: 'sine'


    validation_config:
      ddim_steps: 200 #50
      eta: 0.0
      #scale: 5.0

dataset_name: cp_dataset #VITONHDDataset

resume_path: null # 想从断点重新加载，可以让其他模型path为空，只加载这里的总模型参数即可
#log_images_kwargs: #这会对log_images函数进行传递一些参数，因此
#  unconditional_guidance_scale: 5.0

#  ddim_steps: 2 # 新加的，可删
#  eta: 0.0
  
  # 这里不太确定，按照推理的代码来说的话就是这个
#  unconditional_guidance_label: [""]#["over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"]


