# Model section of the forgetting run. Every `target` is a dotted import path
# and the sibling `params` mapping carries its settings (presumably passed as
# constructor keyword arguments; confirm against the config loader).
model:
  base_learning_rate: 1.0e-05
  target: ldm.models.diffusion.ddpm_forget.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    # Batch keys for the image input and the text conditioning.
    first_stage_key: "image"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    # The conditioning stage is not trained.
    # NOTE: this differs from the config we trained with before.
    cond_stage_trainable: false
    conditioning_key: crossattn
    scale_factor: 0.18215
    gamma: 1  # keep this at 1
    lmbda: 50  # change FIM weighting term here
    train_method: 'full'  # choices: ['full', 'xattn', 'noxattn']

    # Learning-rate multiplier schedule. warm_up_steps is 1 because this run
    # resumes from an existing checkpoint; use 10000 when starting from scratch.
    scheduler_config:
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ]
        # incredibly large number to prevent corner cases
        cycle_lengths: [ 10000000000000 ]
        # f_min equals f_max, so the multiplier presumably stays flat after
        # warm-up; confirm against LambdaLinearScheduler.
        f_start: [ 1.0e-6 ]
        f_max: [ 1.0 ]
        f_min: [ 1.0 ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32  # unused
        # in/out channels match `channels: 4` above.
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: true
        legacy: false

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      # NOTE(review): ckpt_path sits beside `params`, not inside it; confirm
      # the loader actually reads it at this level.
      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        # Identity module as the loss config; presumably a placeholder because
        # the autoencoder is not trained in this run (confirm).
        lossconfig:
          target: torch.nn.Identity

    # No `params` given: the embedder is built with its defaults.
    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


# Data section: training data (ForgettingDataset) and the fixed validation
# prompts (VisualizationDataset) that are logged during training.
data:
  target: main_forget.DataModuleFromConfig
  params:
    train_batch_size: 4
    # Equals the number of validation captions listed below (6).
    val_batch_size: 6
    num_workers: 4
    num_val_workers: 0 # Avoid a weird val dataloader issue (keep unchanged)
    train:
      target: ldm.data.ForgettingDataset
      params:
        # Prompt for the concept being forgotten; it also appears in the first
        # three validation captions.
        forget_prompt: brad pitt
        # NOTE(review): presumably the image set used in place of the forgotten
        # concept (directory name suggests "middle aged man") - confirm
        # against ForgettingDataset.
        forget_dataset_path: ./q_dist/middle_aged_man_dataset
    validation:
      target: ldm.data.VisualizationDataset
      params:
        captions: # the 6 prompts that are logged throughout training; feel free to change them to log whatever you wish
        # First three captions mention the forget prompt.
        - a photo of brad pitt
        - brad pitt holding an apple dancing under the tree
        - brad pitt dancing in a ball room
        # Last three captions are unrelated to the forget prompt.
        - A pack of huskies racing across a frozen lake, their panting breath and synchronized strides creating a rhythmic and awe-inspiring spectacle.
        - A majestic dolphin leaping out of the water in a dramatic display of acrobatics, the sun glinting off its wet skin and creating a dazzling spray of droplets.
        - A highway full of vehicles, from big to small, bustling, busy.
        output_size: 512
        n_gpus: 4 # CHANGE THIS TO NUMBER OF GPUS! small hack to make sure we see all our logging samples

# Lightning section: checkpointing, the image-logging callback and trainer
# options for the run.
lightning:
  find_unused_parameters: false

  # every_n_epochs and save_top_k are both 0 and no metric is monitored;
  # presumably this turns off checkpoint saving by this callback (confirm
  # against the ModelCheckpoint documentation for the installed version).
  modelcheckpoint:
    params:
      every_n_epochs: 0
      save_top_k: 0
      monitor: null

  callbacks:
    image_logger:
      target: main_forget.ImageLogger
      params:
        batch_frequency: 1
        max_images: 999
        increase_log_steps: false
        log_first_step: false
        log_all_val: true
        clamp: true
        # Sampling options for the logged images (keys suggest DDIM sampling
        # with classifier-free guidance; confirm against ImageLogger).
        log_images_kwargs:
          ddim_eta: 0
          ddim_steps: 50
          use_ema_scope: true
          inpaint: false
          plot_progressive_rows: false
          plot_diffusion_rows: false
          N: 6  # keep this the same as number of validation prompts!
          unconditional_guidance_scale: 7.5
          # A single empty string as the unconditional guidance label.
          unconditional_guidance_label: [""]

  trainer:
    benchmark: true
    num_sanity_val_steps: 0
    max_epochs: 200  # modify epochs here!
    # Validation runs once every 10 epochs.
    check_val_every_n_epoch: 10