_base_ = [
    # Parent config(s) this file builds on. NOTE(review): assumes an
    # mmcv/mmengine-style loader that merges `_base_` files — confirm.
    "../schedules/mxlr=1e-3_mnlr=1e-4_it=20000_clip.py",
]



# Keyword arguments for the "hyper_v2" component. NOTE(review): presumably
# consumed as ``**hyper_v2_args`` by whatever config/builder imports this
# file — confirm against the consumer. Key names, order, values and value
# types (tuples, floats, None) are significant; keep them exact.
hyper_v2_args = {
    # Network widths / depths.
    "lstm_hidden_size": 64,
    "ff_hidden_size": 4,
    "ff_hidden_layers": 2,
    # Initial decay values; each is a tuple, even when it holds one entry.
    "initial_momentum_decays": (0.9, 0.99, 0.999),
    "initial_rms_decays": (0.999,),
    "initial_adafactor_decays": (0.9, 0.99, 0.999),
    "param_inits": 32,
    "mix_layers": True,
    # Output scaling multipliers.
    "exp_mult": 0.001,
    "step_mult": 0.01,
    "validation_mode": False,
    "with_validation_feature_dim": False,
    # Ablation flags: boolean switches, all enabled in this config.
    "with_g": True,
    "with_m": True,
    "with_m_feat": True,
    "with_rms": True,
    "with_rms_feat": True,
    "with_rms_norm_g": True,
    "with_rsqrt_rms": True,
    "with_p": True,
    "with_fac_norm_g": True,
    "with_fac_rms": True,
    "with_fac_rsqrt": True,
    "with_grad_clip_feat": True,
    "with_fac_mom_mult": True,
    "with_rms_only_norm_g": True,
    "adafactor_accumulator": True,
    "param_scale_mult": True,
    # Every `use_bugged_*` switch is turned off here.
    "use_bugged_next_lstm_state": False,
    "use_bugged_loss_features": False,
    "precondition_output": False,
    "reparam_decay": 10.0,
    "rnn_state_decay": 0.0,
    # Extra summaries (both disabled).
    "summarize_each_layer": False,
    "summarize_all_control": False,
    # Overrides that modify the lopt (presumably "learned optimizer") to
    # probe its behavior; inactive with these values.
    "constant_loss": False,
    "clip_param_scale_amount": None,
}