# Hyperparameter block; the key names (n_layers, n_head, dim_ff*, act_fn_*)
# suggest a neural-network architecture config.
# NOTE(review): the code that consumes this file is not visible from here, so
# every meaning stated below beyond the literal key/value is an assumption.
# Layer count; presumably the number of stacked attention layers, given the
# n_head / dim_ff* keys further down — TODO confirm.
n_layers: 12
# NOTE(review): what "length" measures (sequence length? step count?) cannot
# be determined from this file — confirm against the reader of this key.
length: 5
# Widths keyed by X / E / C. Presumably hidden sizes of per-feature MLPs
# (one per feature kind) — TODO confirm what X, E and C denote.
hidden_mlp_dims:
  X: 256
  E: 128
  C: 256
# Widths keyed dx / de / dc, plus attention-style settings. The d* suffixes
# appear to pair with X / E / C above — TODO confirm.
hidden_dims:
  dx: 256
  de: 64
  dc: 256
  # dx (256) is evenly divisible by n_head (8); if heads split dx, keep it so
  # when editing either value — verify against the model code.
  n_head: 8
  # Presumably feed-forward widths for the X and E branches — TODO confirm.
  dim_ffX: 256
  dim_ffE: 64
# input_dims is not defined in this file; it lives wherever the consumer
# builds it, so this value has to be kept in sync by hand.
cond_dim: 64  # must match input_dims['y']
# Activation names for the input and output sides. NOTE(review): presumably
# resolved to an activation class by name (e.g. torch.nn.ReLU) — confirm the
# set of accepted strings in the loader.
act_fn_in: "ReLU"
act_fn_out: "ReLU"
