CONFIG
├── mode
│   └── sample_eval                                                                                                                                                                
├── seed
│   └── 1                                                                                                                                                                          
├── loader
│   └── global_batch_size: 512                                                                                                                                                     
│       eval_global_batch_size: 512                                                                                                                                                
│       batch_size: 1                                                                                                                                                              
│       eval_batch_size: 1                                                                                                                                                         
│       num_workers: 255                                                                                                                                                           
│       pin_memory: true                                                                                                                                                           
│                                                                                                                                                                                  
├── sampling
│   └── predictor: ddpm_cache                                                                                                                                                      
│       steps: 128                                                                                                                                                                 
│       use_float64: true                                                                                                                                                          
│       noise_removal: true                                                                                                                                                        
│       num_sample_batches: 5000                                                                                                                                                   
│       num_sample_log: 2                                                                                                                                                          
│       semi_ar: false                                                                                                                                                             
│       stride_length: 1                                                                                                                                                           
│       num_strides: 1                                                                                                                                                             
│       generated_seqs_path: ./outputs/sm_mdlm_pretraining_cont_1/generations/remdm-loop_5-100000_T-128_eta-0.02_ton-0.55_toff-0.05_alphaon-0.9_topp-0.9.json                      
│       p_nucleus: 0.9                                                                                                                                                             
│       eta: 0.02                                                                                                                                                                  
│       sampler: remdm-loop                                                                                                                                                        
│       t_on: 0.55                                                                                                                                                                 
│       t_off: 0.05                                                                                                                                                                
│       alpha_on: 0.9                                                                                                                                                              
│       dfm: false                                                                                                                                                                 
│                                                                                                                                                                                  
├── training
│   └── ema: 0.9999                                                                                                                                                                
│       antithetic_sampling: true                                                                                                                                                  
│       importance_sampling: false                                                                                                                                                 
│       sampling_eps: 0.001                                                                                                                                                        
│       change_of_variables: false                                                                                                                                                 
│       loss_precision: bf16                                                                                                                                                       
│       finetune_path: ''                                                                                                                                                          
│                                                                                                                                                                                  
├── eval
│   └── checkpoint_path: ./outputs/sm_mdlm_pretraining_cont_1/checkpoints/5-100000.ckpt                                                                                            
│       disable_ema: false                                                                                                                                                         
│       compute_generative_perplexity: false                                                                                                                                       
│       perplexity_batch_size: 1                                                                                                                                                   
│       compute_perplexity_on_sanity: false                                                                                                                                        
│       gen_ppl_eval_model_name_or_path: gpt2-large                                                                                                                                
│       generate_samples: true                                                                                                                                                     
│       generated_samples_path: /u/her/001_code/duo/samples.json                                                                                                                   
│                                                                                                                                                                                  
├── optim
│   └── weight_decay: 0                                                                                                                                                            
│       lr: 0.0003                                                                                                                                                                 
│       tran_head_lr: 0.001                                                                                                                                                        
│       sm_prob: 0.5                                                                                                                                                               
│       beta1: 0.9                                                                                                                                                                 
│       beta2: 0.999                                                                                                                                                               
│       eps: 1.0e-08                                                                                                                                                               
│                                                                                                                                                                                  
├── trainer
│   └── _target_: lightning.Trainer                                                                                                                                                
│       accelerator: cuda                                                                                                                                                          
│       num_nodes: 1                                                                                                                                                               
│       devices: 1                                                                                                                                                                 
│       accumulate_grad_batches: 512                                                                                                                                               
│       gradient_clip_val: 1.0                                                                                                                                                     
│       precision: bf16                                                                                                                                                            
│       num_sanity_val_steps: 2                                                                                                                                                    
│       max_steps: 1000000                                                                                                                                                         
│       log_every_n_steps: 100                                                                                                                                                     
│       limit_train_batches: 1.0                                                                                                                                                   
│       limit_val_batches: 1.0                                                                                                                                                     
│       val_check_interval: 5000                                                                                                                                                   
│                                                                                                                                                                                  
├── wandb
│   └── project: sm-mdlm                                                                                                                                                           
│       notes: Soft Masking MDLM                                                                                                                                                   
│       group: null                                                                                                                                                                
│       job_type: null                                                                                                                                                             
│       name: null                                                                                                                                                                 
│       id: None_1                                                                                                                                                                 
│       tags:                                                                                                                                                                      
│       - log-linear                                                                                                                                                               
│       - openwebtext-train                                                                                                                                                        
│       - openwebtext-valid                                                                                                                                                        
│       - mdlm_sm                                                                                                                                                                  
│       offline: true                                                                                                                                                              
│                                                                                                                                                                                  
├── checkpointing
│   └── save_dir: /u/her/001_code/duo                                                                                                                                              
│       resume_from_ckpt: true                                                                                                                                                     
│       resume_ckpt_path: /u/her/001_code/duo/checkpoints/best.ckpt                                                                                                                
│                                                                                                                                                                                  
├── callbacks
│   └── checkpoint_every_n_steps:                                                                                                                                                  
│         _target_: lightning.pytorch.callbacks.ModelCheckpoint                                                                                                                    
│         save_top_k: -1                                                                                                                                                           
│         save_last: true                                                                                                                                                          
│         dirpath: /u/her/001_code/duo/checkpoints                                                                                                                                 
│         verbose: true                                                                                                                                                            
│         auto_insert_metric_name: false                                                                                                                                           
│         every_n_train_steps: 5000                                                                                                                                                
│       checkpoint_monitor:                                                                                                                                                        
│         _target_: lightning.pytorch.callbacks.ModelCheckpoint                                                                                                                    
│         monitor: val/nll                                                                                                                                                         
│         mode: min                                                                                                                                                                
│         save_top_k: 1                                                                                                                                                            
│         save_last: false                                                                                                                                                         
│         dirpath: /u/her/001_code/duo/checkpoints                                                                                                                                 
│         filename: best                                                                                                                                                           
│         auto_insert_metric_name: false                                                                                                                                           
│         verbose: true                                                                                                                                                            
│       learning_rate_monitor:                                                                                                                                                     
│         _target_: lightning.pytorch.callbacks.LearningRateMonitor                                                                                                                
│         logging_interval: step                                                                                                                                                   
│                                                                                                                                                                                  
├── data
│   └── train: openwebtext-train                                                                                                                                                   
│       valid: openwebtext-valid                                                                                                                                                   
│       tokenizer_name_or_path: gpt2                                                                                                                                               
│       cache_dir: /dccstor/saentis/data/openwebtext                                                                                                                               
│       wrap: true                                                                                                                                                                 
│       streaming: false                                                                                                                                                           
│       insert_train_eos: true                                                                                                                                                     
│       insert_valid_eos: true                                                                                                                                                     
│                                                                                                                                                                                  
├── model
│   └── name: small                                                                                                                                                                
│       type: ddit                                                                                                                                                                 
│       hidden_size: 768                                                                                                                                                           
│       cond_dim: 128                                                                                                                                                              
│       length: 1024                                                                                                                                                               
│       n_blocks: 12                                                                                                                                                               
│       n_heads: 12                                                                                                                                                                
│       scale_by_sigma: true                                                                                                                                                       
│       dropout: 0.1                                                                                                                                                               
│       tie_word_embeddings: false                                                                                                                                                 
│       vocab_lookup: true                                                                                                                                                         
│                                                                                                                                                                                  
├── strategy
│   └── _target_: lightning.pytorch.strategies.DDPStrategy                                                                                                                         
│       find_unused_parameters: false                                                                                                                                              
│                                                                                                                                                                                  
├── noise
│   └── type: log-linear                                                                                                                                                           
│       parameterization: log-linear                                                                                                                                               
│       eps: 0                                                                                                                                                                     
│       denoiser_latent_conditioning: -1                                                                                                                                           
│       freeze_encoder: false                                                                                                                                                      
│       freeze_decoder: false                                                                                                                                                      
│                                                                                                                                                                                  
├── lr_scheduler
│   └── _target_: transformers.get_constant_schedule_with_warmup                                                                                                                   
│       num_warmup_steps: 2500                                                                                                                                                     
│                                                                                                                                                                                  
├── prior
│   └── type: none                                                                                                                                                                 
│       latent_width: 0                                                                                                                                                            
│       latent_height: 0                                                                                                                                                           
│                                                                                                                                                                                  
└── algo
    └── name: mdlm_sm                                                                                                                                                              
        backbone: dit                                                                                                                                                              
        parameterization: subs                                                                                                                                                     
        time_conditioning: false                                                                                                                                                   
        T: 0                                                                                                                                                                       
        subs_masking: false                                                                                                                                                        
        causal_attention: false                                                                                                                                                    
        ignore_bos: false                                                                                                                                                          
        loss_type: elbo                                                                                                                                                            
        tran_head:                                                                                                                                                                 
          init_scale: 0.0                                                                                                                                                          
          init_centre: -0.75                                                                                                                                                       
          init_steep: 6.66                                                                                                                                                         
          init_temperature: 1.0                                                                                                                                                    
          mixinputs_k: 3                                                                                                                                                           
          transparency_alg: mixinputs_with_topk                                                                                                                                    
                                                                                                                                                                                   
