# Training configuration for the soft-pose model on the HumanArt dataset.
name: soft_pose_HumanArt
model:
  # Written with an explicit mantissa dot on purpose: YAML 1.1 loaders such as
  # PyYAML resolve a bare `1e-5` to the string "1e-5" rather than to a float.
  learning_rate: 1.0e-5
  target: cldm.soft_pose_cldm_train.SoftPose
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: false
    # Gaussian kernels used to generate pose masks.
    gaussian_kernels: [23, 13]
    # Coefficient in the proposed pose-mask guided loss.
    alpha: 5

    # Adapter entry: `target` holds the dotted class path; `params` are
    # presumably forwarded as constructor keyword arguments (TODO confirm
    # against the config loader).
    control_config:
      target: cldm.soft_pose_cldm_train.SoftPoseAdapter
      params:
        in_channels: 4
        model_channels: 320
        cond_channels: 3
        # A single injection point is enabled. The wider alternative kept from
        # the original file is inject_channels [192, 256, 384, 512] together
        # with inject_layers [1, 4, 7, 10].
        inject_channels: [192]
        inject_layers: [1]
        num_res_blocks: 2
        attention_resolutions: [4, 2, 1]
        channel_mult: [1, 2, 4, 4]
        use_checkpoint: true
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 768
        legacy: false

    # UNet entry: `target` holds the dotted class path; `params` are presumably
    # forwarded as constructor keyword arguments (TODO confirm against the
    # config loader).
    unet_config:
      target: cldm.soft_pose_cldm_train.ControlUNetModel
      params:
        # NOTE(review): differs from the 64 set at model.params.image_size;
        # confirm whether this value is actually read by the UNet class.
        image_size: 32
        in_channels: 4
        model_channels: 320
        out_channels: 4
        num_res_blocks: 2
        attention_resolutions: [4, 2, 1]
        channel_mult: [1, 2, 4, 4]
        use_checkpoint: true
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 768
        legacy: false

    # First-stage model entry (class path in `target`, settings in `params`).
    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        # The loss slot points at torch.nn.Identity, a pass-through module.
        lossconfig:
          target: torch.nn.Identity

    # Conditioning-stage entry; only a class path is given, no `params`.
    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

# Data section: loader-level settings followed by the per-split entries.
# NOTE(review): `data.target` is the same class path used by the train/test
# entries below; confirm it is meant to be the dataset class itself rather
# than a data-module wrapper.
data:
  target: ldm.data.dataset_pose.PoseDataset
  params:
    batch_size: 1
    num_workers: 1
    wrap: false
    # Training split, read from the HumanArt training mapping file.
    train:
      target: ldm.data.dataset_pose.PoseDataset
      params:
        phase: train
        image_size: 512
        map_file: dataset/HumanArt/mapping_file_training.json
        base_path: dataset/HumanArt
        max_person_num: 10
        keypoint_num: 17
        keypoint_dim: 3
        skeleton_width: 10
        keypoint_thresh: 0.02
        # Sixteen integer triples. Each presumably reads
        # [limb_index, keypoint_a, keypoint_b] with keypoint indices in 0..16
        # (TODO confirm against PoseDataset).
        pose_skeleton:
        - [0, 0, 1]
        - [1, 0, 2]
        - [2, 1, 3]
        - [3, 2, 4]
        - [4, 3, 5]
        - [5, 4, 6]
        - [6, 5, 7]
        - [7, 6, 8]
        - [8, 7, 9]
        - [9, 8, 10]
        - [10, 5, 11]
        - [11, 6, 12]
        - [12, 11, 13]
        - [13, 12, 14]
        - [14, 13, 15]
        - [15, 14, 16]
    # Evaluation split, read from the HumanArt evaluation mapping file.
    test:
      target: ldm.data.dataset_pose.PoseDataset
      params:
        phase: eval
        image_size: 512
        map_file: dataset/HumanArt/mapping_file_evaluation.json
        base_path: dataset/HumanArt
        max_person_num: 10
        keypoint_num: 17
        keypoint_dim: 3
        skeleton_width: 10
        keypoint_thresh: 0.02
        # Sixteen integer triples. Each presumably reads
        # [limb_index, keypoint_a, keypoint_b] with keypoint indices in 0..16
        # (TODO confirm against PoseDataset).
        pose_skeleton:
        - [0, 0, 1]
        - [1, 0, 2]
        - [2, 1, 3]
        - [3, 2, 4]
        - [4, 3, 5]
        - [5, 4, 6]
        - [6, 5, 7]
        - [7, 6, 8]
        - [8, 7, 9]
        - [9, 8, 10]
        - [10, 5, 11]
        - [11, 6, 12]
        - [12, 11, 13]
        - [13, 12, 14]
        - [14, 13, 15]
        - [15, 14, 16]


# Trainer, checkpoint and callback settings.
lightning:
  find_unused_parameters: true
  modelcheckpoint:
    params:
      # Interval set very large so that per-n-step checkpoints are not saved.
      every_n_train_steps: 1000000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        # Same very large interval: no per-n-step checkpoints here either.
        every_n_train_steps: 1000000

  trainer:
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 4
    precision: 16