# Top-level switch for resuming a previous run (its consumer is not in this
# file). Canonical lowercase boolean per yamllint `truthy`.
resume_training: false

# Naming of the run. experiments_base_dir is also interpolated as the
# checkpoint directory (train.model_checkpoint.dirpath).
experiment:
  experiments_base_dir: 'experiments'
  project_name: 'cpr'
  session_name: 'default_gpt2m'
  experiment_name: 'test_cpr_1'

# Keyword arguments for the PyTorch Lightning Trainer (keys match the
# Trainer's parameter names).
trainer:
  num_nodes: 1
  # Validation is step-based: with check_val_every_n_epoch null, Lightning
  # counts val_check_interval (below) in training batches across epochs.
  check_val_every_n_epoch: null
  enable_checkpointing: true
  enable_model_summary: true
  enable_progress_bar: true
  accelerator: 'gpu'
  devices: 8
  gradient_clip_val: 1
  limit_train_batches: null
  limit_val_batches: null
  log_every_n_steps: 100
  # Run length is set by max_steps only; max_epochs is left unbounded.
  max_epochs: null
  max_steps: 200000
  num_sanity_val_steps: 2
  accumulate_grad_batches: 1
  precision: 'bf16-mixed'
  reload_dataloaders_every_n_epochs: 1
  val_check_interval: 5000
  deterministic: true

# Optimization settings: global seed, optimizer choice and its arguments.
train:
  seed: 1234
  optimizer_name: "adamcpr"  # adamcpr or adamw
  optimizer:
    lr: 0.002
    weight_decay: 0.1
    betas:
      - 0.9
      - 0.98
    eps: 1.0e-09
    adam_w_mode: true
    # Interpolated from train.seed (as lm_data.seed is) so the optimizer
    # seed cannot silently drift from the global seed.
    seed: ${train.seed}
    scheduler_mult_factor: null


  # AdamCPR-specific settings; presumably only read when
  # train.optimizer_name is "adamcpr" — confirm against the optimizer factory.
  adamcpr:
    mode: "l2_constrain"  # l2_constrain or std_constrain
    kappa: 10
    kappa_factor: false
    lagmul_rate: 1.0
    bias_reg: false
    normalization_reg: false
    kappa_adapt: false
    kappa_init_dependent: false
    kappa_init_warm_start: 1000


  # Flags controlling whether bias and normalization parameters get
  # weight decay (both disabled here).
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false
  # Learning-rate schedule. num_training_steps is interpolated from
  # trainer.max_steps so it follows any change to the run length.
  scheduler:
    num_warmup_steps: 5000
    num_training_steps: ${trainer.max_steps}
    # NOTE(review): presumably the final LR as a fraction of lr — confirm.
    decay_factor: 0.1
    schedule: "cosine"
  loss_fn:
    inplace_backward: true


  # Lightning ModelCheckpoint callback, built from _target_ (presumably via
  # Hydra instantiate). Monitoring "step" in mode "max" with save_top_k 1
  # keeps only the most recent periodic checkpoint; save_last also writes
  # last.ckpt.
  # NOTE(review): dirpath is the shared base dir, not a per-experiment dir —
  # confirm the training script makes it unique per run.
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    dirpath: ${experiment.experiments_base_dir}
    auto_insert_metric_name: false
    every_n_train_steps: 100
    every_n_epochs: null
    save_top_k: 1
    monitor: "step"
    mode: "max"
    # ModelCheckpoint always supplies {epoch} and {step}; an unknown key such
    # as {global_step} would silently be filled with 0.
    filename: "checkpoint-{epoch:02d}-{step}"
    save_last: true


# Model architecture: 24 layers, model_dim 1024, 16 heads (GPT-2 medium
# sized, matching session_name default_gpt2m).
transformer:
  precision: ${trainer.precision}
  # NOTE(review): 0 looks like a placeholder, presumably filled in from the
  # tokenizer/data at runtime — confirm.
  seq_vocab_size: 0
  trg_vocab_size: 0
  max_len: 0

  model_dim: 1024
  n_layers: 24
  num_head: 16

  ff_factor: 4

  attn_dropout: 0.1
  resi_dropout: 0.1
  embed_dropout: 0.1

  scale_attn_weights: true
  scale_attn_by_inverse_layer_idx: true

  pos_embedding: false
  rel_pos_enc: false
  rotary_emb_fraction: 0.5
  # Written with a decimal point: YAML 1.1 loaders (e.g. PyYAML) resolve a
  # bare "1e-5" as a string rather than a float.
  ln_eps: 1.0e-05
  checkpoint_lvl: 0
  unpadded: true
  use_bias: true
  flash_attn: true
  # NOTE(review): boolean where an init std (float) might be expected —
  # confirm how the model treats false here.
  initializer_range: false



# Language-model data pipeline settings.
lm_data:
  dataset_name: 'openwebtext'
  num_cpu_worker: 1
  # One GPU worker per trainer device (interpolated).
  num_gpu_worker: ${trainer.devices}
  max_sample_len: 1024
  # Shares the global training seed.
  seed: ${train.seed}
  batch_size: 32
  # Fraction of the dataset held out for validation, split with its own seed.
  val_ratio: 0.0005
  val_split_seed: 2357
  data_dir: 'data'
  cache_dir: 'cache'