# Base mappings: anchored defaults that the presets further down pull in with `<<` merge keys
# Dataset selection, anchored as *wt103 so other mappings in this file can merge it.
wt103: &wt103
   data: ../data/wikitext-103/
   dataset: wt103

# Shared training defaults, anchored as *train.
# The headings below group keys by name only; what each key means is decided by
# whatever program consumes this file.
train: &train
   <<: *wt103

   # Device / distribution keys
   cuda: true
   multi_gpu: ddp

   # n_* and d_* keys, plus init_std
   n_layer: 18
   n_head: 16
   d_model: 1024
   d_head: 64
   d_inner: 4096
   init_std: 0.005

   # drop* keys
   dropout: 0.2
   dropatt: 0.2

   # Optimizer, learning-rate and step-count keys
   optim: jitlamb
   lr: 0.01
   eta_min: 0.0001
   warmup_step: 16000
   max_step: 100000

   # *_len keys, roll and batch_size
   tgt_len: 384
   mem_len: 384
   eval_tgt_len: 128
   roll: true
   batch_size: 128

   # Vocabulary-related keys
   vocab: word
   adaptive: true
   div_val: 4

   # *_interval keys
   log_interval: 100
   eval_interval: 5000

# Multi-node training defaults, anchored as *train_multinode: everything from
# *train with the five overrides listed below.
# The dataset keys arrive through *train, which itself merges *wt103. Keep a
# single `<<` entry here: YAML mapping keys must be unique, and loaders that
# enforce this reject a mapping that repeats the merge key.
train_multinode: &train_multinode
   <<: *train
   lr: 0.02
   max_step: 25000
   batch_size: 512
   eval_batch_size: 128
   eval_interval: 1000

# Shared evaluation defaults, anchored as *eval and merged by the `eval:` mapping of each preset.
eval: &eval
   <<: *wt103
   cuda: true
   split: test
   same_length: true
   tgt_len: 128
   mem_len: 1600
   clamp_len: 1000

# Preset with no overrides: plain copies of the *train and *eval defaults.
default:
   train: {<<: *train}
   eval: {<<: *eval}

# Same as `default`, except that `eval` additionally carries `manual_config`:
# a single-quoted YAML string whose content is a JSON object.
manual_eval:
   train:
      <<: *train
   eval:
      <<: *eval
      # The JSON repeats several values also set in *train (n_layer, n_head, d_model,
      # d_head, d_inner, dropout, dropatt, div_val, tgt_len 384, mem_len 384).
      # NOTE(review): it also sets "same_length": false, "clamp_len": -1, "tgt_len": 384 and
      # "mem_len": 384, while the keys merged from *eval say same_length: true,
      # clamp_len: 1000, tgt_len: 128 and mem_len: 1600. Which one wins is up to the
      # consumer of this file — confirm before changing either side.
      manual_config: '{"n_token": 267735, "n_layer": 18, "n_head": 16, "d_model": 1024, "d_head": 64, "d_inner": 4096, "dropout": 0.2, "dropatt": 0.2, "dtype": null, "tie_weight": true, "d_embed": 1024, "div_val": 4, "tie_projs": [false, true, true, true], "pre_lnorm": false, "tgt_len": 384, "ext_len": 0, "mem_len": 384, "cutoffs": [19997, 39997, 199997], "same_length": false, "attn_type": 0, "clamp_len": -1, "sample_softmax": -1}'

# Full training configs for NVIDIA DGX-1 (8x NVIDIA V100 16GB GPU)
# DGX-1 8-GPU fp16 preset: *train plus fp16 and batch_chunk 4; *eval plus fp16.
dgx1_8gpu_fp16: &dgx1_8gpu_fp16
   train: {<<: *train, fp16: true, batch_chunk: 4}
   eval: {<<: *eval, fp16: true}

# DGX-1 8-GPU fp32 preset: *train plus batch_chunk 8; *eval without overrides.
dgx1_8gpu_fp32: &dgx1_8gpu_fp32
   train: {<<: *train, batch_chunk: 8}
   eval: {<<: *eval}

# DGX-1 4-GPU fp16 preset: *train plus fp16 and batch_chunk 8; *eval plus fp16.
dgx1_4gpu_fp16: &dgx1_4gpu_fp16
   train: {<<: *train, fp16: true, batch_chunk: 8}
   eval: {<<: *eval, fp16: true}

# DGX-1 4-GPU fp32 preset: *train plus batch_chunk 16; *eval without overrides.
dgx1_4gpu_fp32: &dgx1_4gpu_fp32
   train: {<<: *train, batch_chunk: 16}
   eval: {<<: *eval}

# DGX-1 2-GPU fp16 preset: *train plus fp16 and batch_chunk 16; *eval plus fp16.
dgx1_2gpu_fp16: &dgx1_2gpu_fp16
   train: {<<: *train, fp16: true, batch_chunk: 16}
   eval: {<<: *eval, fp16: true}

# DGX-1 2-GPU fp32 preset: *train plus batch_chunk 32; *eval without overrides.
dgx1_2gpu_fp32: &dgx1_2gpu_fp32
   train: {<<: *train, batch_chunk: 32}
   eval: {<<: *eval}

# DGX-1 1-GPU fp16 preset: *train plus fp16 and batch_chunk 32; *eval plus fp16.
dgx1_1gpu_fp16: &dgx1_1gpu_fp16
   train: {<<: *train, fp16: true, batch_chunk: 32}
   eval: {<<: *eval, fp16: true}

# DGX-1 1-GPU fp32 preset: *train plus batch_chunk 64 and swap_mem; *eval without overrides.
dgx1_1gpu_fp32: &dgx1_1gpu_fp32
   train: {<<: *train, batch_chunk: 64, swap_mem: true}
   eval: {<<: *eval}


# Full training configs for NVIDIA DGX-2H (16x NVIDIA V100 32GB GPU)
# DGX-2 16-GPU fp16 preset: *train and *eval, each with fp16 switched on; no batch_chunk override.
dgx2_16gpu_fp16: &dgx2_16gpu_fp16
   train: {<<: *train, fp16: true}
   eval: {<<: *eval, fp16: true}

# DGX-2 16-GPU fp32 preset: *train and *eval without any overrides.
dgx2_16gpu_fp32: &dgx2_16gpu_fp32
   train: {<<: *train}
   eval: {<<: *eval}

# DGX-2 8-GPU fp16 preset: *train plus fp16 and batch_chunk 2; *eval plus fp16.
dgx2_8gpu_fp16: &dgx2_8gpu_fp16
   train: {<<: *train, fp16: true, batch_chunk: 2}
   eval: {<<: *eval, fp16: true}

# DGX-2 8-GPU fp32 preset: *train plus batch_chunk 2; *eval without overrides.
dgx2_8gpu_fp32: &dgx2_8gpu_fp32
   train: {<<: *train, batch_chunk: 2}
   eval: {<<: *eval}

# DGX-2 4-GPU fp16 preset: *train plus fp16 and batch_chunk 4; *eval plus fp16.
dgx2_4gpu_fp16: &dgx2_4gpu_fp16
   train: {<<: *train, fp16: true, batch_chunk: 4}
   eval: {<<: *eval, fp16: true}

# DGX-2 4-GPU fp32 preset: *train plus batch_chunk 4; *eval without overrides.
dgx2_4gpu_fp32: &dgx2_4gpu_fp32
   train: {<<: *train, batch_chunk: 4}
   eval: {<<: *eval}

# DGX-2 2-GPU fp16 preset: *train plus fp16 and batch_chunk 8; *eval plus fp16.
dgx2_2gpu_fp16: &dgx2_2gpu_fp16
   train: {<<: *train, fp16: true, batch_chunk: 8}
   eval: {<<: *eval, fp16: true}

# DGX-2 2-GPU fp32 preset: *train plus batch_chunk 8; *eval without overrides.
dgx2_2gpu_fp32: &dgx2_2gpu_fp32
   train: {<<: *train, batch_chunk: 8}
   eval: {<<: *eval}

# DGX-2 1-GPU fp16 preset: *train plus fp16 and batch_chunk 16; *eval plus fp16.
dgx2_1gpu_fp16: &dgx2_1gpu_fp16
   train: {<<: *train, fp16: true, batch_chunk: 16}
   eval: {<<: *eval, fp16: true}

# DGX-2 1-GPU fp32 preset: *train plus batch_chunk 16; *eval without overrides.
dgx2_1gpu_fp32: &dgx2_1gpu_fp32
   train: {<<: *train, batch_chunk: 16}
   eval: {<<: *eval}

# Full training configs for NVIDIA DGX A100 (8x NVIDIA A100 40GB GPU)
# DGX A100 8-GPU fp16 preset: *train and *eval, each with fp16 switched on; no batch_chunk override.
dgxa100_8gpu_fp16: &dgxa100_8gpu_fp16
   train: {<<: *train, fp16: true}
   eval: {<<: *eval, fp16: true}

# DGX A100 8-GPU tf32 preset: *train plus batch_chunk 2 (no fp16 key); *eval without overrides.
dgxa100_8gpu_tf32: &dgxa100_8gpu_tf32
   train: {<<: *train, batch_chunk: 2}
   eval: {<<: *eval}

# DGX A100 4-GPU fp16 preset: *train plus fp16 and batch_chunk 2; *eval plus fp16.
dgxa100_4gpu_fp16: &dgxa100_4gpu_fp16
   train: {<<: *train, fp16: true, batch_chunk: 2}
   eval: {<<: *eval, fp16: true}

# DGX A100 4-GPU tf32 preset: *train plus batch_chunk 4 (no fp16 key); *eval without overrides.
dgxa100_4gpu_tf32: &dgxa100_4gpu_tf32
   train: {<<: *train, batch_chunk: 4}
   eval: {<<: *eval}

# DGX A100 2-GPU fp16 preset: *train plus fp16 and batch_chunk 4; *eval plus fp16.
dgxa100_2gpu_fp16: &dgxa100_2gpu_fp16
   train: {<<: *train, fp16: true, batch_chunk: 4}
   eval: {<<: *eval, fp16: true}

# DGX A100 2-GPU tf32 preset: *train plus batch_chunk 8 (no fp16 key); *eval without overrides.
dgxa100_2gpu_tf32: &dgxa100_2gpu_tf32
   train: {<<: *train, batch_chunk: 8}
   eval: {<<: *eval}

# DGX A100 1-GPU fp16 preset: *train plus fp16 and batch_chunk 8; *eval plus fp16.
dgxa100_1gpu_fp16: &dgxa100_1gpu_fp16
   train: {<<: *train, fp16: true, batch_chunk: 8}
   eval: {<<: *eval, fp16: true}

# DGX A100 1-GPU tf32 preset: *train plus batch_chunk 16 (no fp16 key); *eval without overrides.
dgxa100_1gpu_tf32: &dgxa100_1gpu_tf32
   train: {<<: *train, batch_chunk: 16}
   eval: {<<: *eval}

# Full training configs for multi-node NVIDIA DGX-2 (16x NVIDIA V100 32GB GPU)
# Multi-node preset "8dgx2", fp16: *train_multinode plus fp16; *eval plus fp16 and batch_size 128.
8dgx2_16gpu_fp16: &8dgx2_16gpu_fp16
   train: {<<: *train_multinode, fp16: true}
   eval: {<<: *eval, fp16: true, batch_size: 128}

# Multi-node preset "8dgx2", fp32: *train_multinode without overrides; *eval plus batch_size 128.
8dgx2_16gpu_fp32: &8dgx2_16gpu_fp32
   train: {<<: *train_multinode}
   eval: {<<: *eval, batch_size: 128}

# Multi-node preset "4dgx2", fp16: *train_multinode plus fp16; *eval plus fp16 and batch_size 64.
4dgx2_16gpu_fp16: &4dgx2_16gpu_fp16
   train: {<<: *train_multinode, fp16: true}
   eval: {<<: *eval, fp16: true, batch_size: 64}

# Multi-node preset "4dgx2", fp32: *train_multinode without overrides; *eval plus batch_size 64.
4dgx2_16gpu_fp32: &4dgx2_16gpu_fp32
   train: {<<: *train_multinode}
   eval: {<<: *eval, batch_size: 64}

# Multi-node preset "2dgx2", fp16: *train_multinode plus fp16; *eval plus fp16 and batch_size 32.
2dgx2_16gpu_fp16: &2dgx2_16gpu_fp16
   train: {<<: *train_multinode, fp16: true}
   eval: {<<: *eval, fp16: true, batch_size: 32}

# Multi-node preset "2dgx2", fp32: *train_multinode plus batch_chunk 2; *eval plus batch_size 32.
2dgx2_16gpu_fp32: &2dgx2_16gpu_fp32
   train: {<<: *train_multinode, batch_chunk: 2}
   eval: {<<: *eval, batch_size: 32}

# Preset "1dgx2", fp16: *train_multinode plus fp16 and batch_chunk 2; *eval plus fp16 and batch_size 16.
1dgx2_16gpu_fp16: &1dgx2_16gpu_fp16
   train: {<<: *train_multinode, fp16: true, batch_chunk: 2}
   eval: {<<: *eval, fp16: true, batch_size: 16}

# Preset "1dgx2", fp32: *train_multinode plus batch_chunk 4; *eval plus batch_size 16.
1dgx2_16gpu_fp32: &1dgx2_16gpu_fp32
   train: {<<: *train_multinode, batch_chunk: 4}
   eval: {<<: *eval, batch_size: 16}

# Training benchmarks
# Training benchmark preset (train only, no eval mapping): *train with log_interval 1
# and max_step cut to 500. max_step_scheduler carries 100000, the max_step value set in *train.
trainbench: &trainbench
   train: {<<: *train, log_interval: 1, max_step: 500, max_step_scheduler: 100000}

# Multi-node training benchmark preset (train only, no eval mapping): *train_multinode with
# log_interval 1 and max_step cut to 500. max_step_scheduler carries 25000, the max_step
# value set in *train_multinode.
trainbench_multinode: &trainbench_multinode
   train: {<<: *train_multinode, log_interval: 1, max_step: 500, max_step_scheduler: 25000}
