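# NOTE: the Hydra defaults list below is disabled. The sections that follow
# (callbacks, model, strategy, noise, lr_scheduler) appear to carry the
# corresponding group settings inline instead.
# defaults:
#   - _self_
#   - configs_gosai/callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
#   - configs_gosai/model: dnaconv
#   - configs_gosai/strategy: ddp
#   - configs_gosai/noise: loglinear
#   - configs_gosai/lr_scheduler: constant_warmup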

callbacks:
  # Periodic snapshots: one checkpoint every `every_n_train_steps` train steps.
  checkpoint_every_n_steps:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    # -1 disables "best model" ranking; this callback only saves on a schedule.
    save_top_k: -1
    # Also write ${save_dir}/checkpoints/last.ckpt.
    save_last: true
    dirpath: ${checkpointing.save_dir}/checkpoints
    verbose: true
    auto_insert_metric_name: false
    every_n_train_steps: 500

  # Best-model tracker: keeps the checkpoint with the best monitored metric.
  checkpoint_monitor:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    # Name of the logged metric that decides whether the model improved.
    monitor: val/nll
    # "min" = lower is better; "max" is the other accepted value.
    mode: min
    # Number of best checkpoints to keep (ranked by the metric above).
    save_top_k: 1
    # true would additionally always save the model from the last epoch.
    save_last: false
    dirpath: ${checkpointing.save_dir}/checkpoints
    filename: best
    auto_insert_metric_name: false
    verbose: true

  # Logs the learning rate once per step.
  learning_rate_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step


model:
  name: dnaconv
  type: cnn
  # Sequence length; 200 is the value used for the gosai data.
  length: 200
  hidden_dim: 128
  num_cnn_stacks: 4
  dropout: 0.0
  clean_data: false

  # `cls_free_*` group (presumably classifier-free guidance -- confirm against
  # the model code). Disabled by default.
  cls_free_guidance: false
  cls_free_threshold: 2.52
  cls_free_prob: 0.3
  # Weight used during sampling.
  cls_free_weight: 0.3

# Distributed training strategy instantiated by Hydra from `_target_`.
strategy:
  _target_: lightning.pytorch.strategies.DDPStrategy
  # TODO(yair): this seems hacky; if things are correct we shouldn't need this setting.
  # NOTE(review): the value is already `false`, so the TODO may be stale -- confirm
  # whether the key can simply be removed.
  find_unused_parameters: false


# Noise schedule settings; `noise.type` is also reused as a wandb tag.
noise:
  type: loglinear
  # Written with an explicit decimal point so that YAML 1.1 loaders (e.g. plain
  # PyYAML) also parse it as a float; a bare `1e-4` is read as a string there.
  sigma_min: 1.0e-4
  sigma_max: 20

# Learning-rate schedule built from `_target_`: constant LR after a warmup phase.
# The optimizer argument is not set here -- presumably supplied by the training
# code at instantiation time (confirm).
lr_scheduler:
  _target_: transformers.get_constant_schedule_with_warmup
  # Number of warmup steps before the LR becomes constant.
  num_warmup_steps: 2500
  


# Top-level run settings.
mode: train
diffusion: absorbing_state
backbone: cnn
parameterization: subs
time_conditioning: false
# 0 selects continuous time; 1000 is the alternative noted by the author.
T: 0
subs_masking: false
debug_mode: false

seed: 1

# Dataset loading options.
data:
  streaming: false

# Data-loader sizing. `div_up` and `eval` are custom resolvers registered
# outside this file.
loader:
  global_batch_size: 512
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are the global sizes divided by
  # `trainer.devices * trainer.num_nodes`, i.e. the share handled by one device.
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  # One worker per CPU in this process's affinity mask. `os.sched_getaffinity`
  # only exists on some Unix platforms, so this expression is not portable.
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: true

# Sample-generation settings.
sampling:
  predictor: ddpm
  steps: 128
  noise_removal: true
  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  num_sample_batches: 2
  num_sample_log: 2
  semi_ar: false
  stride_length: 1
  num_strides: 1

# Training-loop options read by the training code (not visible in this file).
training:
  ema: 0.9999
  antithetic_sampling: true
  importance_sampling: false
  # Explicit decimal point keeps this a float under YAML 1.1 loaders too;
  # a bare `1e-3` is parsed as a string by plain PyYAML.
  sampling_eps: 1.0e-3
  change_of_variables: false

# Evaluation settings.
eval:
  # Used to evaluate a checkpoint after training.
  checkpoint_path: ''
  disable_ema: false
  # Alternative value noted by the author: false.
  compute_generative_perplexity: true
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: false
  # Options noted by the author: gpt2-large, meta-llama/Llama-2-7b-hf.
  # NOTE(review): both are text language models while this config targets DNA
  # sequences -- confirm this setting is actually used for this run.
  gen_ppl_eval_model_name_or_path: gpt2-large
  generate_samples: true
  subset_size: 5000

# Optimizer hyperparameters.
# Exponent-form floats carry an explicit decimal point so YAML 1.1 loaders
# (e.g. plain PyYAML) parse them as floats instead of strings.
optim:
  weight_decay: 0
  lr: 3.0e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1.0e-8

# Arguments for the Lightning Trainer built from `_target_`.
# `device_count` and `div_up` are custom resolvers registered outside this file.
trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: ${device_count:}
  # Global batch size divided (via `div_up`) by
  # devices * per-device batch size * nodes.
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: 'bf16'
  num_sanity_val_steps: 2
  # Noted by the author as corresponding to 100 epochs.
  max_steps: 131500
  log_every_n_steps: 10
  # 1.0 trains on the full dataset; lower it to toggle a quick run.
  limit_train_batches: 1.0
  # 1.0 validates on the full dataset; lower it to toggle a quick run.
  limit_val_batches: 1.0
  val_check_interval: 1000

# Weights & Biases run metadata. Fields left as `null` are unset here.
wandb:
  project: gosai-dna
  notes: null
  group: null
  job_type: null
  name: null
  # `uuid` is a custom resolver registered outside this file.
  id: ${uuid:}
  # The run is tagged with the configured noise schedule type.
  tags:
    - ${noise.type}

# Hydra runtime settings.
hydra:
  run:
    # Output directory per run: one folder per date, one subfolder per start time.
    dir: data_and_model/mdlm/outputs_gosai/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    # Change the process working directory to the run directory above.
    chdir: true

# Checkpoint location and resume behaviour.
checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
  # (`cwd` is a custom resolver registered outside this file).
  save_dir: ${cwd:}
  # Note: the `checkpoints` path should correspond to
  # `callbacks.checkpoint_every_n_steps.dirpath`.
  # NOTE(review): if `cwd` resolves to the timestamped Hydra run directory, a
  # fresh run will not find an earlier last.ckpt -- confirm how a missing
  # resume path is handled by the training code.
  resume_from_ckpt: true
  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt

# Fine-tuning options; consumed by code outside this file.
finetuning:
  # Presumably the temperature of a Gumbel-softmax relaxation -- TODO confirm.
  gumbel_softmax_temp: 1.0
  # NOTE(review): meaning of "truncate" is not visible here -- confirm against the caller.
  truncate_steps: 3
