# Model architecture hyperparameters.
# NOTE(review): the code that consumes this file is not visible from here.
# The per-key notes below are inferred from key names and values only and
# should be confirmed against the model implementation.
# Number of stacked layers in the network (presumably) - TODO confirm.
n_layers: 5
# NOTE(review): the meaning of `length` is not evident from this file
# (e.g. a sequence/path length?) - confirm with the consuming code.
length: 5
# Hidden widths of the MLPs, keyed by feature kind.
# Presumably X = node features and E = edge features - TODO confirm.
# The values match hidden_dims.dx and hidden_dims.de below.
hidden_mlp_dims:
  X: 256
  E: 64
# Hidden sizes of the main layers.
# Presumably dx/de are the X/E embedding widths, n_head is the number of
# attention heads, and dim_ffX/dim_ffE are feed-forward widths - TODO confirm.
# If n_head is an attention head count, dx and de likely need to stay
# divisible by it (256 / 8 and 64 / 8 both are) - verify before changing.
hidden_dims:
  dx: 256
  de: 64
  n_head: 8
  dim_ffX: 256
  dim_ffE: 128
cond_dim: 64  # must be the same as input_dims['y']
# Activation names for the input and output sides (presumably).
# NOTE(review): likely resolved by name (e.g. a torch.nn class lookup), so
# the exact spelling and capitalisation may matter - confirm.
act_fn_in: "ReLU"
act_fn_out: "ReLU"
