name: megatron_gpt
# restore_from_path is used when starting from a .nemo file.
restore_from_path: null

trainer:
  devices: 8
  num_nodes: 4
  accelerator: gpu
  precision: 16
  # The logger is provided by exp_manager.
  logger: false
  enable_checkpointing: false
  replace_sampler_ddp: false
  # -1 is the PTL default. In practice, max_steps will be reached first.
  max_epochs: -1
  # consumed_samples =
  #   global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  max_steps: 200
  log_every_n_steps: 1
  val_check_interval: 20
  # check_val_every_n_epoch: null
  limit_val_batches: 2
  limit_test_batches: 0
  # Do not modify: grad accumulation is automatic for training megatron models.
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  benchmark: false

exp_manager:
  # Set this to save checkpoints.
  explicit_log_dir: ilql_sentiments_logs
  exp_dir: null
  name: megatron_gpt_20b_ilql_sentiments
  create_tensorboard_logger: false
  create_wandb_logger: true
  wandb_logger_kwargs:
    project: trlxnemo
    name: megatron_gpt_20b_ilql_sentiments
  resume_if_exists: false
  resume_ignore_no_checkpoint: true
  # Set this to save checkpoints.
  create_checkpoint_callback: true
  checkpoint_callback_params:
    monitor: reduced_train_loss
    save_top_k: 1
    mode: min
    # Saves a nemo file during validation; not implemented for model parallel.
    always_save_nemo: false
    # Not recommended when training large models on clusters with short
    # time limits.
    save_nemo_on_train_end: true
    filename: 'megatron_gpt-{reduced_train_loss:.2f}-{step}-{consumed_samples}'
    # Product of the tensor- and pipeline-parallel sizes declared under `model`.
    model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}
  log_step_timing: true
  step_timing_kwargs:
    sync_cuda: true
    buffer_size: 5

model:
  micro_batch_size: 4
  global_batch_size: 512
  tensor_model_parallel_size: 4
  pipeline_model_parallel_size: 1
  # Manually set the checkpoint file to load from.
  resume_from_checkpoint: null

  # Model architecture
  encoder_seq_length: 1024
  max_position_embeddings: 2048
  num_layers: 44
  hidden_size: 6144
  # Transformer FFN hidden size: 4 * hidden_size.
  ffn_hidden_size: ${multiply:4, ${.hidden_size}}
  num_attention_heads: 48
  # Standard deviation of the zero-mean normal distribution used for
  # weight initialization.
  init_method_std: 0.007
  # Dropout probability for hidden state transformer.
  hidden_dropout: 0.1
  # Projection weights dimension in multi-head attention.
  # Set to hidden_size // num_attention_heads if null.
  kv_channels: null
  # Scale Q * K^T by 1 / layer-number.
  apply_query_key_layer_scaling: true
  # Written with an explicit decimal point ("1.0e-5", not "1e-5"): YAML 1.1
  # loaders such as PyYAML resolve an exponent without a decimal point as a
  # string instead of a float. Same notation as optim.lr / optim.weight_decay.
  layernorm_epsilon: 1.0e-5
  # Pad the vocab size to be divisible by this value for computation efficiency.
  make_vocab_size_divisible_by: 128
  # Add embedding.
  pre_process: true
  # Add pooler.
  post_process: true
  # Use of persistent fused layer norm kernel.
  persist_layer_norm: true
  # Fuse grad division into torch.distributed.all_reduce.
  grad_div_ar_fusion: true
  # Fuse weight gradient accumulation to GEMMs.
  gradient_accumulation_fusion: true


  ## Activation Checkpointing
  # Granularity: 'selective' or 'full'.
  activations_checkpoint_granularity: 'selective'
  # Method: 'uniform' or 'block'; not used with 'selective'.
  activations_checkpoint_method: 'uniform'
  # Not used with 'selective'.
  activations_checkpoint_num_layers: null

  ## Sequence Parallelism
  sequence_parallel: true

  tokenizer:
    library: megatron
    type: GPT2BPETokenizer
    model: null
    vocab_file: null
    merge_file: null
    # Only used for tabular tokenizer.
    delimiter: null
    # Legacy=True allows you to add special tokens to sentencepiece tokenizers.
    sentencepiece_legacy: false

  # Precision
  # Initial scale: 2 ** 32.
  native_amp_init_scale: 4294967296
  native_amp_growth_interval: 1000
  # Gradient scale hysteresis.
  hysteresis: 2
  # Move residual connections to fp32.
  fp32_residual_connection: false
  # Move the cross entropy unreduced loss calculation for lm head to fp16.
  fp16_lm_cross_entropy: false

  # Megatron O2-style half-precision
  # TODO: this causes hangs for some reason
  # Enable O2-level automatic mixed precision using main parameters.
  megatron_amp_O2: true
  grad_allreduce_chunk_size_mb: 125
  sync_batch_comm: false

  # Miscellaneous
  seed: 1234
  # Init weights on the CPU (slow for large models).
  use_cpu_initialization: false
  # Use work-arounds for known problems with Torch ONNX exporter.
  onnx_safe: false
  # Python logging level: displays logs with severity greater than or equal
  # to this.
  apex_transformer_log_level: 30
  # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save
  # memory (less fragmentation and buffer memory).
  gradient_as_bucket_view: true

  data:
    data_prefix:
      - dataset: hh
    # Path to save index mapping .npy files; by default they are saved in the
    # same location as data_prefix.
    index_mapping_dir: null
    data_impl: mmap
    # Quoted so the comma-separated value is unambiguously a string.
    splits_string: '900,50,50'
    seq_length: ${model.encoder_seq_length}
    skip_warmup: true
    num_workers: 2
    dataloader_type: cyclic
    # Reset position ids after end-of-document token.
    reset_position_ids: false
    # Reset attention mask after end-of-document token.
    reset_attention_mask: false
    # Mask loss for the end of document tokens.
    eod_mask_loss: false

  # Nsys profiling options
  nsys_profile:
    enabled: false
    # Global batch to start profiling.
    start_step: 10
    # Global batch to end profiling.
    end_step: 10
    # Global rank IDs to profile.
    ranks: [0, 4, 8, 12]
    # Generate model and kernel details including input shapes.
    gen_shape: false

  optim:
    name: distributed_fused_adam
    lr: 5.0e-5
    weight_decay: 1.0e-6
    betas:
      - 0.9
      - 0.95
    sched:
      name: CosineAnnealing
      # Same value as trainer.max_steps (200).
      max_steps: 200
      # NOTE(review): min_lr equals lr above, so the annealing range is zero
      # as configured; confirm a constant learning rate is intended.
      min_lr: 5.0e-5
