# @package _global_

# Hydra experiment config for a unified text+image token model trained on
# MagViT-tokenized datasets. Builds on the `vq16_magvit` base config and
# overrides the model size and LR schedule.
defaults:
  - vq16_magvit
  - override /model: small
  - override /lr_scheduler: constant_warmup_cosine_decay

model:
  # Sequence lengths are derived: image tokens from resolution / downscale,
  # text tokens only when running as a unified (text+image) model.
  img_length: ${eval:'(${data.resolution} // ${model.downscale_ratio})**2'}
  txt_length: ${eval:'${data.block_size} if ${.unified_model} else 0'}
  length: ${eval:'${.txt_length} + ${.img_length}'}
  image_model: true
  text_model: true
  unified_model: true
  image_model_fid_eval: false
  force_argmax_valid_indices: true
  use_pretrained_img_emb: false
  codebook_embed_dim: 256
  qk_norm: true
  norm_type: rms
  sandwich_normalization: true
  zero_linear_init: false
  modality_embed: true
  rope_2d: false
  use_spda_attn: true
  force_optimized_native_attn: true
  freeze_txt_emb: false
  add_labels: null
  txt_dropout: null
  text_vocab_size: 32001

data:
  train: combined_tokens
  valid: ${.train}
  n_duplicate_train: null
  wrap: true
  streaming: false
  precache: false
  tokenizer_name_or_path: NousResearch/Llama-2-7b-hf
  resolution: 256
  block_size: 128
  n_val_samples: null
  unpaired: false
  n_duplicate_val: null
  save_train_dataloader: true
  save_validation_dataloader: true
  iterable: false
  webdataset_iterable: false
  webdataset_indexed: false
  dataset_type: null
  tokens_flip_collate: false
  n_train_samples: null
  raw_data_dir: null
  tokenizers_parallelism: false
  token_data_dir: null
  force_disable_shuffle: false
  keep_tensordict_on_disk: true
  use_custom_tensordict_collate: true
  force_mp_spawn: false
  enable_cuda_in_tensordict_collate: false
  use_weighted_tensordict_sampler: true
  fraction_txt_data: 0.0
  tokenize_vqvae_in_dataloader: false
  use_token_dataset: true
  image_dataset: tglcourse/lsun_church_train
  image_data_train: null
  image_data_val: null
  keep_hf_dataset_in_memory: true
  allow_label: false
  disable_text_modality: true
  force_raw_train_images: false
  aggressive_aug: true
  allow_aug_vqvae_dataloader: true
  move_tensordict_to_shm: false
  # Token shard directories; weight -1 presumably means "sample by dataset
  # size" in the weighted sampler — NOTE(review): confirm against the loader.
  data_dir_train:
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit
      weight: -1
      name: datacomp1b_8_magvit_train
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_train_256
      weight: -1
      name: cc12m_tokens_train_256
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/HPDv2_image_reward_v1_v2_v3_magvit
      weight: -1
      name: HPDv2_image_reward_v1_v2_v3_magvit
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/pick_score_sac_prompts_v1_v2_v3_magvit
      weight: -1
      name: pick_score_sac_prompts_v1_v2_v3_magvit
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/datacomp1b_0_1_6_magvit
      weight: -1
      name: datacomp1b_0_1_6_magvit
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_0
      weight: -1
      name: laion400m_magvit_part_0
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/grogu/laion400m_magvit_part_1
      weight: -1
      name: laion400m_magvit_part_1
  data_dir_val:
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/datacomp1b_8_magvit_val
      weight: 1
      name: datacomp1b_8_magvit_val
    - dir: ${oc.env:DIFFUSION_DATA_DIR}/tokens/scratch_ssd_tokens/cc12m_tokens_val_256
      weight: 1
      name: cc12m_tokens_val_256

eval:
  generate_samples: true
  compute_generative_perplexity: true
  log_every_n_evals: 10
  log_every_n_fid: 20
  limit_val_batches_manual: 16
  perplexity_batch_size: ${loader.eval_batch_size}
  num_masking_viz_batches: -1
  cfg: null
  class_conditional_fid: false
  force_cfg_value: true
  split_cfg_batches: true
  # Cap total FID samples at 8192 across all devices.
  max_num_fid_batches_per_device: ${eval:'8192 // (${trainer.devices} * ${loader.eval_batch_size})'}
  fid_mode: clean
  clean_fid_precomputed_name: lsun_church
  clean_fid_precomputed_split: trainfull
  clean_fid_precomputed_res: 256

trainer:
  log_every_n_steps: 10
  val_check_interval: 1000
  custom_ddp_bf16: true
  scale_lr_by_batch_size: false
  limit_val_batches: 16
  use_gradient_checkpointing: false
  # NOTE(review): "seperate" is misspelled but is the key the code reads —
  # do not rename without updating the consumer.
  log_seperate_modal_losses: true
  softmin_snr: 5
  text_loss_weight: 1.0
  img_loss_weight: null
  low_precision_loss: false
  compile: true
  multimodal_batches: true
  compile_fullgraph: false
  log_grad_norm_every_n_steps: 10
  mask_entire_modality: 0.1
  force_shift_image_batches: false
  # Checkpoint by step count only; minute-based checkpointing disabled (-1).
  ckpt_steps: 10000
  ckpt_every_n_minutes: -1
  ignore_text_in_unified: false
  disable_all_eval_generation: true
  eval_on_start: false
  ckpt_model_only: false
  # EMA disabled (decay 0.0).
  ema: 0.0
  use_custom_ema: false
  log_flops: false
  disable_distributed_torchmetrics: true
  restart_on_failure: true
  force_null_sigma: true
  allow_null_sigma: true
  compile_flag_pos_emb: true
  add_label: false
  first_token_dropout: null
  force_shift_raw_image_batches: true
  txt_dropout: 0.1
  force_full_attention_mask_loss_only: true

optim:
  lr: 0.0003
  weight_decay: 0.05

loader:
  batch_size: 64
  eval_batch_size: ${loader.batch_size}
  num_workers: 4
  desired_global_batch_size: 512
  persistent_workers: true
  pin_memory: true
  num_eval_workers: 1

sampling:
  # One sampling step per token position by default.
  steps: ${model.length}
  num_sample_batches: 2
  max_sampling_steps: ${model.length}

wandb:
  mode: online

lr_scheduler:
  num_warmup_steps: 5000
  num_training_steps: ${trainer.max_steps}

checkpointing:
  checkpoints_total_limit: 10