# Hydra defaults list. Leading `/` entries are absolute config-group paths.
# `_self_` is listed first, so the group configs below are merged after this
# file and their values take precedence over values defined here.
defaults:
  - _self_
  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: constant_warmup

# Run-level selection of mode, diffusion type, backbone and parameterization.
mode: train
diffusion: absorbing_state
backbone: dit
parameterization: subs
time_conditioning: false
T: 0  # NOTE(review): presumably 0 selects continuous time — confirm in code
subs_masking: false

# Random seed.
seed: 1

# Image conditioning.
# tracking_interval: number of training steps between tracking logs.
tracking_interval: 500
use_image_conditioning: true

loader:
  # Total batch size summed over all devices and nodes.
  global_batch_size: 128
  eval_global_batch_size: ${.global_batch_size}
  # batch_size and eval_batch_size are **per device**: the global sizes are
  # divided by trainer.devices * trainer.num_nodes via the `div_up` resolver
  # (presumably ceiling division).
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  # Number of DataLoader worker processes (1 = a single worker). To use every
  # CPU core available to the process (Linux only), set it to:
  #   ${eval:"len(__import__('os').sched_getaffinity(0))"}
  num_workers: 1
  pin_memory: true

sampling:
  predictor: analytic  # one of: analytic, ddpm, ddpm_cache
  steps: 50
  noise_removal: true
  # Total samples = number of GPUs * loader.eval_batch_size * num_sample_batches
  num_sample_batches: 2
  num_sample_log: 2
  semi_ar: false
  stride_length: 1
  num_strides: 1

training:
  ema: 0.9999
  antithetic_sampling: true
  importance_sampling: false
  # Written with an explicit mantissa (1.0e-3, not 1e-3) so YAML 1.1 loaders
  # such as plain PyYAML also read it as a float rather than a string.
  sampling_eps: 1.0e-3
  change_of_variables: false

eval:
  checkpoint_path: ''  # Checkpoint to evaluate after training.
  disable_ema: false
  compute_generative_perplexity: false
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: false
  generate_samples: true
  sample_frequency: 5   # Generate samples only once every 5 epochs.
  validation_steps: 25  # Kept small for fast validation sampling.

optim:
  # Floats use an explicit mantissa (1.0e-4, not 1e-4) so YAML 1.1 loaders
  # such as plain PyYAML also read them as floats rather than strings.
  weight_decay: 0.1
  lr: 1.0e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1.0e-8
  name: "adam"
  fused: false

trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: 4
  # Chosen so devices * loader.batch_size * num_nodes * accumulate_grad_batches
  # reaches loader.global_batch_size.
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 0.5
  precision: 'bf16'
  num_sanity_val_steps: 2
  max_steps: 500000
  log_every_n_steps: 10
  limit_train_batches: 0.05  # Fraction of training batches per epoch (1.0 = full dataset).
  limit_val_batches: 0.1     # Fraction of validation batches (1.0 = full dataset).
  # Validate every 3 epochs. An int val_check_interval counts training
  # *batches*, not epochs, so epoch-based validation is set with
  # check_val_every_n_epoch and val_check_interval stays at 1.0 (end of epoch).
  val_check_interval: 1.0
  check_val_every_n_epoch: 3

hydra:
  run:
    # One timestamped output directory per run.
    dir: ./outputs/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    chdir: true  # Run the job with the output directory as working directory.
  job_logging:
    root:
      handlers:
        - console
      
# Generation-example logging.
logging:
  save_generation_examples: true
  num_tracking_images: 4

checkpointing:
  save_dir: ${cwd:}
  resume_from_ckpt: true
  # Set to a checkpoint path to resume from it. Use `null`, not `None`:
  # YAML reads a bare `None` as the string "None".
  resume_ckpt_path: null
  
  
# Image encoding.
image_encoder:
  type: 'clip'
  weights: null  # `null`, not `None` (which YAML reads as the string "None").
  load_custom_weights: true
  # Path to the custom CLIP weights. `???` is OmegaConf's marker for a
  # mandatory value: reading it before it is set raises an error. Provide it
  # on the command line, e.g. image_encoder.custom_weight_path=/path/to/weights
  # (a bare `*` here is an empty YAML alias and fails to parse).
  custom_weight_path: ???
  use_patch_tokens: true

model:
  vocab_size: 49408     # CLIP tokenizer vocabulary size.
  image_embed_dim: 512  # CLIP output dimension.
  # NOTE(review): with image_encoder.use_patch_tokens enabled, confirm that the
  # patch-token width actually fed to the model is 512 and not the encoder's
  # internal hidden width.
  hidden_size: 768
  name: small
  type: ddit
  cond_dim: 128
  length: 40
  n_blocks: 12
  n_heads: 12
  scale_by_sigma: true
  dropout: 0.1
  tie_word_embeddings: false
  # NOTE(review): num_layers / num_heads / dropout_rate differ from
  # n_blocks / n_heads / dropout above and presumably configure a separate
  # sub-module; dropout_rate 0.8 is unusually high — confirm where they are used.
  num_layers: 2
  num_heads: 2
  dropout_rate: 0.8
  init_from_hf_dit: false
  hf_dit_path: null  # Has to be specified (presumably when init_from_hf_dit is true).
  # NOTE(review): init_from_pretrained_backbone is true while the path below is
  # null, so the path must be supplied (e.g. via a command-line override).
  init_from_pretrained_backbone: true
  pretrained_backbone_path: null  # Has to be specified.
  tokenizer: clip  # Options: clip, gpt2.
  

train_ds:
  img_size: 224
  augmentation: false
  wrap: false

val_ds:
  img_size: 224
  augmentation: false
  wrap: false
  
compute:
  # Interpolated from trainer.devices so the two GPU counts cannot drift apart
  # when either is overridden from the command line.
  ngpus: ${trainer.devices}
