# Base: dataset selection and the shared train/eval defaults that the presets below merge
# Dataset selection. Anchored as &wt103 and merged into both the shared
# `train` and `eval` mappings below.
wt103: &wt103
   dataset: wt103
   # Relative path; NOTE(review): presumably resolved against the launch
   # directory of the consuming script — confirm.
   data: ../data/wikitext-103/

# Shared training defaults. Anchored as &train; the `train:` sections of the
# presets below merge it and add only their own overrides (fp16, batch_chunk,
# benchmark step counts).
train: &train
   <<: *wt103

   # Model dimensions and dropout (the same keys appear in the JSON passed
   # as manual_config under manual_eval)
   n_layer: 16
   n_head: 8
   d_head: 64
   d_model: 512
   d_inner: 2048
   dropout: 0.1
   dropatt: 0.0

   # Sequence and memory lengths
   tgt_len: 192
   mem_len: 192
   eval_tgt_len: 192

   # Optimizer, learning rate and step budget
   # NOTE(review): lambda_ is presumably the ProxAdamW penalty weight and
   # eta_min the scheduler floor — confirm against the training script.
   optim: ProxAdamW
   lr: 0.00025
   lambda_: 1.75e-7
   eta_min: 0.0
   warmup_step: 0
   max_step: 400000

   # Batching and devices
   cuda: true
   multi_gpu: ddp
   batch_size: 64

   # Data iteration and vocabulary options
   roll: true
   vocab: word
   adaptive: true

   # Logging / evaluation cadence
   log_interval: 10
   eval_interval: 5000

# Shared evaluation defaults. Anchored as &eval; the `eval:` sections of the
# presets below merge it and optionally add fp16 or manual_config.
eval: &eval
   <<: *wt103
   cuda: true
   # Evaluation uses a shorter target (64) and longer memory (640) than the
   # 192/192 set in the shared train mapping.
   tgt_len: 64
   mem_len: 640
   clamp_len: 400
   same_length: true
   # Evaluate on the test split.
   split: test

# Default preset: the shared train and eval mappings with no overrides.
default:
   train:
      <<: *train
   eval:
      <<: *eval

# Preset whose eval section additionally carries an explicit model
# description as a JSON string in manual_config.
manual_eval:
   train:
      <<: *train
   eval:
      <<: *eval
      # Folded block scalar: the lines below are joined with single spaces and
      # the trailing newline is stripped (`>-`), yielding one single-line JSON
      # string. Keep every continuation line at the same indentation, with no
      # trailing whitespace, or the folded value will change.
      manual_config: >-
         {"n_token": 267735, "n_layer": 16, "n_head": 8, "d_model": 512,
         "d_head": 64, "d_inner": 2048, "dropout": 0.1, "dropatt": 0.0,
         "dtype": null, "tie_weight": true, "d_embed": 512, "div_val": 1,
         "tie_projs": [false, true, true, true], "pre_lnorm": false,
         "tgt_len": 192, "ext_len": 0, "mem_len": 192,
         "cutoffs": [19997, 39997, 199997], "same_length": false,
         "attn_type": 0, "clamp_len": -1, "sample_softmax": -1}

# Full training configs for NVIDIA DGX-1 (8x NVIDIA V100 16GB GPU)
# DGX-1, 8 GPUs, FP16: shared defaults with fp16 enabled for both train and
# eval; no batch_chunk override. The GPU count is only in the preset name —
# no key here sets it.
dgx1_8gpu_fp16: &dgx1_8gpu_fp16
   train:
      <<: *train
      fp16: true
   eval:
      <<: *eval
      fp16: true

# DGX-1, 8 GPUs, FP32: shared defaults with batch_chunk: 2 for training
# (the FP16 preset at this GPU count sets none); eval is unmodified.
dgx1_8gpu_fp32: &dgx1_8gpu_fp32
   train:
      <<: *train
      batch_chunk: 2
   eval:
      <<: *eval

# DGX-1, 4 GPUs, FP16: fp16 for train and eval, batch_chunk: 2 for training.
dgx1_4gpu_fp16: &dgx1_4gpu_fp16
   train:
      <<: *train
      fp16: true
      batch_chunk: 2
   eval:
      <<: *eval
      fp16: true

# DGX-1, 4 GPUs, FP32: batch_chunk: 4 for training (twice the FP16 preset's
# value at this GPU count); eval is unmodified.
dgx1_4gpu_fp32: &dgx1_4gpu_fp32
   train:
      <<: *train
      batch_chunk: 4
   eval:
      <<: *eval

# DGX-1, 2 GPUs, FP16: fp16 for train and eval, batch_chunk: 4 for training.
dgx1_2gpu_fp16: &dgx1_2gpu_fp16
   train:
      <<: *train
      fp16: true
      batch_chunk: 4
   eval:
      <<: *eval
      fp16: true

# DGX-1, 2 GPUs, FP32: batch_chunk: 8 for training (twice the FP16 preset's
# value at this GPU count); eval is unmodified.
dgx1_2gpu_fp32: &dgx1_2gpu_fp32
   train:
      <<: *train
      batch_chunk: 8
   eval:
      <<: *eval

# DGX-1, 1 GPU, FP16: fp16 for train and eval, batch_chunk: 8 for training.
dgx1_1gpu_fp16: &dgx1_1gpu_fp16
   train:
      <<: *train
      fp16: true
      batch_chunk: 8
   eval:
      <<: *eval
      fp16: true

# DGX-1, 1 GPU, FP32: batch_chunk: 16 for training (twice the FP16 preset's
# value at this GPU count); eval is unmodified.
dgx1_1gpu_fp32: &dgx1_1gpu_fp32
   train:
      <<: *train
      batch_chunk: 16
   eval:
      <<: *eval

# Full training configs for NVIDIA DGX-2H (16x NVIDIA V100 32GB GPU)
# DGX-2H, 16 GPUs, FP16: shared defaults with fp16 enabled for both train and
# eval; no batch_chunk override. The GPU count is only in the preset name —
# no key here sets it.
dgx2_16gpu_fp16: &dgx2_16gpu_fp16
   train:
      <<: *train
      fp16: true
   eval:
      <<: *eval
      fp16: true

# DGX-2H, 16 GPUs, FP32: the shared train and eval defaults with no
# overrides (content-wise identical to the `default` preset).
dgx2_16gpu_fp32: &dgx2_16gpu_fp32
   train:
      <<: *train
   eval:
      <<: *eval

# DGX-2H, 8 GPUs, FP16: fp16 for train and eval, batch_chunk: 2 for training.
dgx2_8gpu_fp16: &dgx2_8gpu_fp16
   train:
      <<: *train
      fp16: true
      batch_chunk: 2
   eval:
      <<: *eval
      fp16: true

# DGX-2H, 8 GPUs, FP32: batch_chunk: 2 for training (same value as the FP16
# preset at this GPU count); eval is unmodified.
dgx2_8gpu_fp32: &dgx2_8gpu_fp32
   train:
      <<: *train
      batch_chunk: 2
   eval:
      <<: *eval

# DGX-2H, 4 GPUs, FP16: fp16 for train and eval, batch_chunk: 4 for training.
dgx2_4gpu_fp16: &dgx2_4gpu_fp16
   train:
      <<: *train
      fp16: true
      batch_chunk: 4
   eval:
      <<: *eval
      fp16: true

# DGX-2H, 4 GPUs, FP32: batch_chunk: 4 for training (same value as the FP16
# preset at this GPU count); eval is unmodified.
dgx2_4gpu_fp32: &dgx2_4gpu_fp32
   train:
      <<: *train
      batch_chunk: 4
   eval:
      <<: *eval

# DGX-2H, 2 GPUs, FP16: fp16 for train and eval, batch_chunk: 8 for training.
dgx2_2gpu_fp16: &dgx2_2gpu_fp16
   train:
      <<: *train
      fp16: true
      batch_chunk: 8
   eval:
      <<: *eval
      fp16: true

# DGX-2H, 2 GPUs, FP32: batch_chunk: 8 for training (same value as the FP16
# preset at this GPU count); eval is unmodified.
dgx2_2gpu_fp32: &dgx2_2gpu_fp32
   train:
      <<: *train
      batch_chunk: 8
   eval:
      <<: *eval

# DGX-2H, 1 GPU, FP16: fp16 for train and eval, batch_chunk: 16 for training.
dgx2_1gpu_fp16: &dgx2_1gpu_fp16
   train:
      <<: *train
      fp16: true
      batch_chunk: 16
   eval:
      <<: *eval
      fp16: true

# DGX-2H, 1 GPU, FP32: batch_chunk: 16 for training (same value as the FP16
# preset at this GPU count); eval is unmodified.
dgx2_1gpu_fp32: &dgx2_1gpu_fp32
   train:
      <<: *train
      batch_chunk: 16
   eval:
      <<: *eval

# Full training configs for NVIDIA DGX A100 (8x NVIDIA A100 40GB GPU)
# DGX A100, 8 GPUs, FP16: shared defaults with fp16 enabled for both train
# and eval; no batch_chunk override. The GPU count is only in the preset
# name — no key here sets it.
dgxa100_8gpu_fp16: &dgxa100_8gpu_fp16
   train:
      <<: *train
      fp16: true
   eval:
      <<: *eval
      fp16: true

# DGX A100, 8 GPUs, TF32: the shared train and eval defaults with no
# overrides. No key in this block selects TF32; it differs from the FP16
# preset only by the absence of fp16.
dgxa100_8gpu_tf32: &dgxa100_8gpu_tf32
   train:
      <<: *train
   eval:
      <<: *eval

# DGX A100, 4 GPUs, FP16: fp16 for train and eval, batch_chunk: 2 for
# training.
dgxa100_4gpu_fp16: &dgxa100_4gpu_fp16
   train:
      <<: *train
      fp16: true
      batch_chunk: 2
   eval:
      <<: *eval
      fp16: true

# DGX A100, 4 GPUs, TF32: batch_chunk: 2 for training (same value as the
# FP16 preset at this GPU count); eval is unmodified. No key here selects
# TF32; fp16 is simply not set.
dgxa100_4gpu_tf32: &dgxa100_4gpu_tf32
   train:
      <<: *train
      batch_chunk: 2
   eval:
      <<: *eval

# DGX A100, 2 GPUs, FP16: fp16 for train and eval, batch_chunk: 4 for
# training.
dgxa100_2gpu_fp16: &dgxa100_2gpu_fp16
   train:
      <<: *train
      fp16: true
      batch_chunk: 4
   eval:
      <<: *eval
      fp16: true

# DGX A100, 2 GPUs, TF32: batch_chunk: 4 for training (same value as the
# FP16 preset at this GPU count); eval is unmodified. No key here selects
# TF32; fp16 is simply not set.
dgxa100_2gpu_tf32: &dgxa100_2gpu_tf32
   train:
      <<: *train
      batch_chunk: 4
   eval:
      <<: *eval

# DGX A100, 1 GPU, FP16: fp16 for train and eval, batch_chunk: 8 for
# training.
dgxa100_1gpu_fp16: &dgxa100_1gpu_fp16
   train:
      <<: *train
      fp16: true
      batch_chunk: 8
   eval:
      <<: *eval
      fp16: true

# DGX A100, 1 GPU, TF32: batch_chunk: 8 for training (same value as the
# FP16 preset at this GPU count); eval is unmodified. No key here selects
# TF32; fp16 is simply not set.
dgxa100_1gpu_tf32: &dgxa100_1gpu_tf32
   train:
      <<: *train
      batch_chunk: 8
   eval:
      <<: *eval

# Training benchmarks
# Training benchmark preset: the shared train defaults with per-step logging
# and a 500-step run. Unlike the other presets visible in this file, it
# defines no `eval:` section.
trainbench: &trainbench
   train:
      <<: *train
      log_interval: 1
      max_step: 500
      # NOTE(review): the shared train mapping sets max_step: 400000, while
      # the scheduler horizon here is 40000. If this key is meant to mirror
      # the full-run schedule during a short benchmark, the two values
      # disagree by 10x — confirm which is intended.
      max_step_scheduler: 40000
