# Hydra Defaults List: the config groups composed into this config.
# `_self_` is listed first, so the group entries below are merged on top of
# (and can override) the values defined in this file.
defaults:
  - _self_
  - /model: small
  - /noise: loglinear
  - /lr_scheduler: constant_warmup_cosine_decay
  # Empty list: no experiment config is selected by default.
  - /experiments: []
  # - override hydra/launcher: submitit_slurm

# Top-level run switches. Booleans are written in canonical lowercase form.
slurm: false
debug: false
mode: train  # one of: train, eval
diffusion: absorbing_state
backbone: dit  # one of: dit, dimamba, ar
parameterization: subs  # one of: subs, d3pm, sedd
time_conditioning: true
T: 0  # 0 (continuous time) or 1000
subs_masking: false
seed: 42
profile: false

# The keys below conceptually belong in trainer.* and hydra.launcher.* but
# live at the top level for CLI convenience. `trainer.devices` and
# `trainer.num_nodes` interpolate `devices` and `nodes`, and `timeout_min` is
# derived from `partition`. The `device_count`, `find_partition` and
# `find_constraint` resolvers are not defined in this file.
devices: ${device_count:}
nodes: 1
partition: ${find_partition:}
constraint: ${find_constraint:}
ckpt: null

# Dataloader sizing and worker settings.
loader:
  desired_global_batch_size: 512
  global_batch_size: null
  # Mirrors `loader.global_batch_size` (null unless overridden).
  eval_global_batch_size: ${.global_batch_size}
  # `div_up` of the desired global batch size by the total device count
  # (`trainer.devices` * `trainer.num_nodes`); `div_up` is presumably a
  # ceiling division - confirm against the resolver registration.
  batch_size: ${div_up:${.desired_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: 1
  # Number of CPUs in this process's scheduler affinity mask, floor-divided
  # by 16, with a lower bound of 4.
  num_workers: ${eval:"max(len(__import__('os').sched_getaffinity(0)) // 16, 4)"}
  pin_memory: true
  persistent_workers: true

# Sampler settings used when generating from the model.
sampling:
  predictor: ddpm_cache  # one of: analytic, ddpm, ddpm_cache
  steps: 1000
  max_sampling_steps: 500  # The highest level we use for sampling
  noise_removal: true
  num_sample_log: 2
  semi_ar: false
  stride_length: 1
  num_strides: 1

# Evaluation switches and metric settings. Note that this section shares its
# name with the `eval` resolver used inside `${eval:...}` interpolations.
eval:
  checkpoint_path: ''  # Used to evaluate a checkpoint after training.
  disable_ema: false
  compute_generative_perplexity: false
  perplexity_batch_size: 8
  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large, meta-llama/Llama-2-7b-hf
  generate_samples: true
  cfg: null
  num_masking_viz_batches: 1
  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  test_eval_speed: false
  standalone_fid: false
  visualize_data_only: false
  val_with_train_data: false
  max_num_fid_batches_per_device: null
  class_conditional_fid: false
  compute_entropy: false
  compute_standalone_mauve: false
  compute_standalone_entropy: false
  compute_img_to_txt_mauve_clip: false
  compute_img_to_txt_mauve_during_unconditional_fid: false
  mauve_num_samples: 5000
  mauve_divergence_curve_discretization_size: 25  # default in mauve repo
  mauve_average_over_seeds: 3
  mauve_scaling_factor: 5  # default in mauve repo
  txt_conditional_fid: false
  unconditional_fid: false
  fid_mode: inline
  calculate_clip_score: false
  clean_fid_use_precomputed_stats: false
  clean_fid_precomputed_name: null
  clean_fid_precomputed_split: null
  clean_fid_precomputed_res: null
  attention_caching: false
  set_random_gen_seed: false
  compute_val_metrics_standalone: false
  # Requested sample count floor-divided by (`trainer.devices` *
  # `loader.eval_batch_size`), with a lower bound of 1. With the default of
  # -1 samples below, the floor division is negative, so this resolves to 1.
  num_val_metrics_standalone_batches_per_device: ${eval:'max(${eval.num_val_metrics_standalone_samples} // (${trainer.devices} * ${loader.eval_batch_size}), 1)'}
  num_val_metrics_standalone_samples: -1
  return_unweighed_sim: false
  compute_chameleon_perplexity: false
  global_disable_mauve: false
  bypass_normal_validation: false
  auto_enhance: false
  num_auto_enhance_iter: 2
  ar_inpainting_min_val: 0.5
  ar_inpainting_max_val: 1.0
  ar_inpainting_force_val: null

# Optimizer hyperparameters.
# Exponent-form floats carry an explicit decimal point (`1.0e-6`, not `1e-6`):
# YAML 1.1 parsers such as PyYAML only recognise a float when a `.` is
# present and would otherwise load the value as the string "1e-6".
optim:
  weight_decay: 0
  lr: 1.0e-6
  beta1: 0.9
  beta2: 0.999
  eps: 1.0e-8
  fused: true

# Model options defined in this file. `/model: small` comes after `_self_` in
# the defaults list, so that group can override or extend these keys (for
# example, `model.text_vocab_size`, which `data.img_token_shift` references,
# is not defined here).
model:
  # VAE / VQGAN configuration and pretrained checkpoint paths.
  use_custom_vae_config: false
  use_custom_vae_ckpt: null
  downscale_ratio: null
  image_vocab_size: null
  vae_type: null
  use_attention_mask: false
  vqgan_config: null
  vqgan_ckpt: null
  llama_ckpt: null
  liquid_ckpt: null

  # `cond_`-prefixed counterparts of the VAE settings above.
  cond_use_custom_vae_config: false
  cond_use_custom_vae_ckpt: null
  cond_downscale_ratio: null
  cond_image_vocab_size: null
  cond_vae_type: null
  text_model: true

  # Attention implementation and architecture switches.
  attn_type: flash
  force_varlen_attn: false
  force_cast_bf16: false
  norm_type: layernorm
  mup: false
  qk_norm: false
  distillation: false
  force_argmax_valid_indices: false
  use_flash_attn_3: false
  use_spda_attn: false  # NOTE: key name is misspelled (presumably meant `sdpa`); kept as-is so existing overrides and code lookups still resolve.
  rope_2d: false
  modality_embed: false
  zero_linear_init: true
  full_attention: true
  use_lora: false
  use_kv_cache: false
  force_optimized_native_attn: false
  use_pretrained_img_emb: true
  use_flex_attention: false
  add_labels: null
  flex_attention_txt_masking_prob: null
  flex_attention_img_masking_prob: null

# Trainer settings. `_target_` names `lightning.Trainer`; the section also
# carries many project-specific keys alongside it.
trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  # Both values come from the top-level `nodes` / `devices` keys.
  num_nodes: ${nodes}
  devices: ${devices}

  # Given a desired global batch size (i.e., how many samples contribute to
  # one optim.step, summed over all nodes/GPUs/accumulation steps), find the
  # number of gradient-accumulation steps that gets closest to it under the
  # current configuration. This assumes loader.batch_size is the largest
  # batch that fits in a single forward/backward pass.
  # NOTE(review): the per-step size passed here is
  # `trainer.devices * loader.batch_size` and does not include
  # `trainer.num_nodes`, whereas `loader.batch_size` divides by
  # `devices * num_nodes` - confirm this is intended for multi-node runs.
  accumulate_grad_batches: ${find_grad_accum:${loader.desired_global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size}}}
  gradient_clip_val: 1.0
  precision: 'bf16'
  max_steps: 214_430_000

  # Value is 10000. NOTE(review): this was previously grouped as `1_000_0`;
  # confirm that 10k (and not 1M) is the intended number of epochs.
  num_epochs: 10_000
  optimizer_cls: adamw
  set_grads_to_none: true
  eval_on_start: true
  eval_decay_steps: false
  eval_epochs: null
  ckpt_steps: 10000
  fsdp: false
  force_enable_checkpointing: false
  limit_val_batches: null
  ckpt_every_n_minutes: 60
  ckpt_recent_timeout_minutes: 10
  checkpoint_all_ranks: true
  force_null_sigma: false

  log_every_n_steps: 10
  limit_train_batches: 1.0  # train on full dataset, can be used to toggle quick run
  val_check_interval: 3000

  ema: 0.9999
  antithetic_sampling: true
  importance_sampling: false
  # Written with an explicit decimal point so that YAML 1.1 parsers (e.g.
  # PyYAML) also load it as a float rather than the string "1e-3".
  sampling_eps: 1.0e-3
  change_of_variables: false
  benchmark: true
  backward_pass: true
  forward_pass: true
  profile_memory: false
  pytorch_profile: false
  nvtx_profile: false
  custom_ddp_bf16: true
  # Key name is misspelled ("seperate"); kept as-is so existing overrides
  # and code lookups still resolve.
  log_seperate_modal_losses: true
  use_gradient_checkpointing: false
  text_loss_weight: null
  img_loss_weight: null
  disable_strict_load: false
  attach_oom_observer_eval: false
  find_unused_parameters: false
  restart_on_failure: false
  skip_early_checkpointing: true
  log_flops: true
  sync_timing: false
  use_custom_ema: false
  scale_lr_by_batch_size: false
  tpu_eager: false
  allow_dynamic_nodes: false
  force_disable_signal_handler: false
  tpu_profile: false
  tpu_cache: false
  enable_jax_smi: false
  tpu_compile_debug: false
  xla_spmd: false
  log_grad_norm: true
  tpu_profile_markers: true
  compile: false
  disable_all_checkpointing: false
  tpu_force_mark_step: false
  ar_shift: false
  ar_llm_loss: false
  ar_print_loss: false
  chameleon_z_loss: null
  image_mode: discrete  # one of: continuous, discrete
  chameleon_use_ce_loss: false
  low_precision_loss: false
  low_precision_params: false
  scratch: false
  use_spmd_distributed_checkpointing: null
  use_simple_spmd_distributed_checkpointing: false
  load_from_state_dict: null
  load_from_optimizer_state_dict: null
  multimodal_batches: false
  sync_dataloader_timing: false
  compile_flag_pos_emb: false
  compile_fullgraph: false
  compile_mode: max-autotune-no-cudagraphs
  joint_ar_nar_prob: null
  joint_ar_nar_prob_warmup_steps: null
  joint_ar_nar_timestep_warmup_steps: null
  spmd_mesh: null
  detect_anomaly: false
  freeze_chameleon_embeddings: false
  ckpt_model_only: false
  use_orig_params: null
  disable_adjust_num_warmup_steps: false
  mask_entire_modality: null
  iterate_dataloader_only: false
  force_bf16_eval: false
  disable_all_eval_generation: false
  debug_xla_sept: false
  ignore_text_in_unified: false
  allow_null_sigma: false
  disable_forward_autocast_during_eval: false
  viz_images_only: false
  add_label: false
  first_token_dropout: null
  disable_ddp_optimizer: false
  rand_flip_ar_prob: null
  rand_ar_modality_dropout: null
  use_linear_warmup_cosine_annealing: false
  no_ce_weighting: false
  interleaved: false
  interleaved_training_flex_attention: false
  awr: false
  ar_inpainting: false

# Weights & Biases logging settings.
wandb:
  entity: grads
  # Resolves to "medim-debug" when the top-level `debug` flag is true,
  # otherwise "medim".
  project: ${eval:'"medim-debug" if ${debug} else "medim"'}
  # Resolves to "allow" when the top-level `slurm` flag is true, otherwise None.
  resume: ${eval:'"allow" if ${slurm} else None'}
  id: null
  group: null
  job_type: null
  name: null
  tags:
    # The value of `data.train` is attached as a tag.
    - ${data.train}

# Root directories, each read from an environment variable with a fallback:
# UNIDISC_CHECKPOINTING_ROOT_DIR (fallback: null) and UNIDISC_ROOT_OUTPUT_DIR
# (fallback: `outputs`).
checkpointing_root_dir: ${oc.env:UNIDISC_CHECKPOINTING_ROOT_DIR,null}
root_output_dir: ${oc.env:UNIDISC_ROOT_OUTPUT_DIR,outputs}
# Literal block scalar holding an `accelerate launch` command prefix that
# references SLURM/rendezvous environment variables. Every line, including
# the last, ends with a shell line-continuation backslash, so the text is
# presumably completed by appending further arguments - confirm against the
# code that reads `python_orig`. Do not add comments inside the block: they
# would become part of the string.
python_orig: |
              accelerate launch \
              --num_machines $SLURM_NNODES \
              --num_processes $NUM_PROCESSES \
              --rdzv_backend c10d \
              --main_process_ip $MASTER_ADDR \
              --main_process_port $MASTER_PORT \
              --machine_rank $SLURM_PROCID \
              --mixed_precision bf16 \
              --dynamo_backend no \
              --enable_cpu_affinity \
              --max_restarts 0 \

# Per-GPU resource values and job name; presumably consumed by the SLURM
# launcher - confirm against the launcher config.
mem_per_gpu: 40
cpus_per_gpu: 8
slurm_name: null
# Derived from the selected top-level `partition` via the `partition_limit`
# resolver (not defined in this file).
timeout_min: ${partition_limit:${partition}}
# Hydra runtime settings (output directories and working-directory behavior).
hydra:
  run:
    # Uses the HYDRA_RUN_DIR environment variable when set; otherwise
    # <root_output_dir>/outputs/<get_dir_name>/<HYDRA_RUN_DIR_NAME>, where
    # HYDRA_RUN_DIR_NAME falls back to <YYYY_MM_DD>/<HH_MM_SS>.
    # NOTE(review): with the default `root_output_dir` of `outputs`, this
    # resolves under `outputs/outputs/...` - confirm that is intended.
    dir: ${oc.env:HYDRA_RUN_DIR,${root_output_dir}/outputs/${get_dir_name:}/${oc.env:HYDRA_RUN_DIR_NAME,${now:%Y_%m_%d}/${now:%H_%M_%S}}}
  sweep:
    # Same expression as `hydra.run.dir`; each sweep job then gets a
    # subdirectory named after its job id.
    dir: ${oc.env:HYDRA_RUN_DIR,${root_output_dir}/outputs/${get_dir_name:}/${oc.env:HYDRA_RUN_DIR_NAME,${now:%Y_%m_%d}/${now:%H_%M_%S}}}
    subdir: ${hydra.job.id}
  job:
    # Hydra changes the process working directory to the run directory.
    chdir: true

# Checkpoint saving and resuming. Both paths below are built from the `cwd`
# resolver (not defined in this file); since `hydra.job.chdir` is true, this
# is presumably the Hydra run directory - confirm against the resolver.
checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
  save_dir: ${cwd:}/checkpoints
  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
  resume_from_ckpt: true
  # Defaults to the same `checkpoints` directory that lives under `save_dir`.
  resume_ckpt_path: ${cwd:}/checkpoints
  initial_resume_ckpt_path: null
  resume_wandb: true
  checkpoints_total_limit: 2
  use_automatic_naming: false


# Dataset and tokenization settings.
data:
  # Uses the HF_DATASETS_CACHE environment variable when set.
  # NOTE(review): the fallback is a hard-coded, user-specific absolute path;
  # set HF_DATASETS_CACHE (or override this key) on other machines.
  cache_dir: ${oc.env:HF_DATASETS_CACHE,/grogu/user/mprabhud/aswerdlo/huggingface/datasets}
  # Number of CPUs in this process's scheduler affinity mask, floor-divided
  # by 4, with a lower bound of 16.
  num_proc: ${eval:"max(len(__import__('os').sched_getaffinity(0)) // 4, 16)"}
  cond_resolution: null
  iterable: false
  force_disable_shuffle: false
  pin_dataset_to_gpu: false
  webdataset_iterable: false
  webdataset_train_data: null
  webdataset_val_data: null
  webdataset_train_num_samples: null
  webdataset_val_num_samples: null
  webdataset_indexed: false
  dataset_type: null
  keep_tensordict_on_disk: false
  use_token_dataset: false
  use_custom_tensordict_collate: false
  use_weighted_tensordict_sampler: false
  enable_cuda_in_tensordict_collate: true
  data_dir_train: null
  data_dir_val: null
  token_output_dir: null
  wrap_dataloaders: true
  force_shuffle_train: false
  move_tensordict_to_shm: false
  keep_hf_dataset_in_memory: false
  use_chameleon: false
  tokenize_vqvae_in_dataloader: false
  force_mp_spawn: false
  force_raw_images_in_multiple_tensordict: false
  disable_text_modality: false
  txt_only: false
  disable_mask_after_eos: false
  allow_label: false
  split_dataset: false
  # Interpolates `model.text_vocab_size`, which is not defined in this file's
  # `model` section; presumably supplied by the selected `/model` group.
  img_token_shift: ${model.text_vocab_size}
  zero_shot_eval_dataset: null
  require_sample_ids: false
  use_packing_collate: false
  dynamic_packing_lengths: false
  remove_txt_img_padding: false
  add_image_gen_tokens: false
  use_slow_tokenizer: false
  add_image_token: false
  # Placeholder dataset names, expected to be overridden. `data.train` is
  # also interpolated into `wandb.tags`.
  train: "unset_dataset"
  val: "unset_dataset"
  tokenizer_name_or_path: "NousResearch/Llama-2-7b-hf"

# NOTE(review): nothing else in this file references this key; presumably a
# placeholder override target - confirm before removing.
dummyarg: null
