# Run identity and output location.
name: "ig2mv"
tag: jigsaw-3d-ckpts
exp_root_dir: outputs
# alternative output key, currently disabled:
# root_dir: "outputs"

# Top-level seed (system.eval_seed is configured separately).
seed: 42
data_cls: jigsaw3D.data.multiview.MultiviewDataModule
data:
  # Dataset roots and scene list
  rgb_root_dir: /path/to/texture_rand_easylight_objaverse
  pbr_root_dir: /path/to/texture_ortho10view_pbr_gt
  scene_list: /path/to/objaverse_list_6w.json
  background_color: gray

  # Target views as zero-padded view IDs (indices 6, 0, 2, 4, 8, 9)
  image_names:
    - "0006"
    - "0000"
    - "0002"
    - "0004"
    - "0008"
    - "0009"
  image_modality: color
  num_views: 6  # matches the 6 entries in image_names

  # Text prompts
  prompt_db_path: /path/to/objaverse_full_captions.json
  return_prompt: true

  projection_type: ORTHO

  # Conditioning modalities; presumably 3 channels each, matching
  # system.init_adapter_kwargs.cond_in_channels: 6 — confirm.
  source_image_modality:
    - position
    - normal
  position_offset: 0.5
  position_scale: 1.0

  # Reference-image inputs (presumably one list entry per reference source)
  reference_root_dir:
    - /path/to/texture_rand_easylight_objaverse
  reference_scene_list:
    - /path/to/objaverse_list_6w.json
  reference_image_modality: color
  reference_image_names:
    - "0000"
    - "0001"
    - "0002"
    - "0003"
    - "0004"

  # Split bounds over the scene list; presumably Python-style slice
  # [start, stop) with null meaning "to the end" — confirm in the data module.
  train_indices: [0, -16]
  val_indices: [-8, null]
  test_indices: [-16, null]

  # Image size (also reused by system.eval_height / system.eval_width)
  height: 512
  width: 512

  # Loader settings
  batch_size: 1
  num_workers: 1

system_cls: jigsaw3D.systems.jigsaw3D_image_sdxl.jigsaw3DImageSDXLSystem
system:
  check_train_every_n_steps: 10000
  cleanup_after_validation_step: true
  cleanup_after_test_step: true

  # Model / Adapter
  pretrained_model_name_or_path: "/path/to/stabilityai/stable-diffusion-xl-base-1.0"
  pretrained_vae_name_or_path: "/path/to/madebyollin/sdxl-vae-fp16-fix"
  pretrained_adapter_name_or_path: null
  init_adapter_kwargs:
    self_attn_processor: "jigsaw3D.models.attention_processor.jigsaw3DAttnProcessor"
    # Condition encoder
    cond_in_channels: 6
    # For training
    copy_attn_weights: true
    zero_init_module_keys: ["to_out_mv", "to_out_ref"]

  # Training
  train_cond_encoder: true
  trainable_modules: ["_mv", "_ref"]
  prompt_drop_prob: 0.1
  image_drop_prob: 0.1
  cond_drop_prob: 0.1

  # Noise sampler
  shift_noise: true
  shift_noise_mode: interpolated
  shift_noise_scale: 8

  # Evaluation
  eval_seed: 42
  eval_num_inference_steps: 30
  eval_guidance_scale: 3.0
  eval_height: ${data.height}
  eval_width: ${data.width}

  # Optimizer definition.
  # Learning rates are written with a decimal point (5.0e-5 rather than 5e-5)
  # so that YAML 1.1 loaders such as plain PyYAML also read them as floats;
  # exponent-only forms like 5e-5 load as strings there.
  optimizer:
    name: AdamW
    args:
      lr: 5.0e-5
      betas: [0.9, 0.999]
      weight_decay: 0.01
    params:
      unet:
        # Must be a complete float literal: a truncated value such as `5e-`
        # loads as a string, which an optimizer cannot use as a learning rate.
        lr: 5.0e-5

  # Warmup then constant, assuming these names map to
  # torch.optim.lr_scheduler classes: LinearLR scales lr from 1.0e-6 * lr up
  # to lr over the first 2000 steps, then ConstantLR (factor 1.0) holds lr.
  scheduler:
    name: SequentialLR
    interval: step
    schedulers:
      - name: LinearLR
        interval: step
        args:
          start_factor: 1.0e-6
          end_factor: 1.0
          total_iters: 2000
      - name: ConstantLR
        interval: step
        args:
          factor: 1.0
          total_iters: 9999999
    milestones: [2000]  # switch point; equals LinearLR total_iters

# Trainer settings; training uses DeepSpeed (see trainer.strategy)
trainer:
  # Run length, parallelism and numerics
  max_epochs: 10
  strategy: deepspeed
  precision: 16
  gradient_clip_val: 1.0

  # Logging and validation cadence
  log_every_n_steps: 10
  val_check_interval: 5000
  num_sanity_val_steps: 1
  enable_progress_bar: true

checkpoint:
  save_last: true  # whether to save at each validation time
  # -1 keeps every checkpoint that gets saved (assuming these keys are passed
  # to Lightning's ModelCheckpoint — confirm).
  save_top_k: -1
  # Set far above trainer.max_epochs (10), so no epoch-interval checkpoints
  # are written in this run; presumably only the save_last checkpoint is
  # produced — confirm against the checkpoint callback.
  every_n_epochs: 9999
