# Agent configuration. `_target_` is Hydra's instantiation key and names the
# class to build. `_recursive_: false` turns off Hydra's automatic
# instantiation of nested `_target_` nodes; presumably BetMLP_Agent builds the
# sub-configs below itself -- TODO confirm against the agent constructor.
_target_: agents.bet_mlp_agent.BetMLP_Agent
_recursive_: false

# Policy configuration (BeTPolicy). `_recursive_: false` means Hydra does not
# auto-instantiate the nested backbone node below.
model:
  _target_: agents.bet_mlp_agent.BeTPolicy
  _recursive_: false

  obs_dim: ${obs_dim}
  act_dim: ${action_dim}
  vocab_size: 64
  # NOTE(review): action_ae.predict_offsets in this file carries the same
  # value; presumably the two must agree -- confirm.
  predict_offsets: true
  device: ${device}

  offset_loss_scale: 1.0
  focal_loss_gamma: 2.0

  # Backbone network configuration nested under the policy.
  model:
    _target_: agents.models.common.mlp.ResidualMLPNetwork
    input_dim: ${obs_dim}
    hidden_dim: ${hidden_dim}
    num_hidden_layers: ${num_hidden_layers}
    # Disabled setting, kept for reference; presumably the output size is
    # determined elsewhere -- TODO confirm.
    # output_dim: ${action_dim}
    dropout: 0
    activation: 'Mish'
    use_spectral_norm: false
    # NOTE(review): norm_style is presumably ignored while use_norm is
    # false -- confirm in ResidualMLPNetwork.
    use_norm: false
    norm_style: 'BatchNorm'
    device: ${device}

# Observation encoder: an identity module, so observations pass through
# unchanged. Per the PyTorch docs, torch.nn.Identity accepts and ignores its
# constructor arguments; `output_dim` is presumably read from this config by
# the agent rather than by the module -- TODO confirm.
obs_encoding_net:
  _target_: torch.nn.Identity
  output_dim: ${obs_dim}

# Action discretizer configuration (KMeansDiscretizer).
action_ae:
  _target_: agents.models.bet.action_ae.discretizers.k_means.KMeansDiscretizer
  # Disabled setting, kept for reference:
  # num_bins: 14
  action_dim: ${action_dim}
  device: ${device}
  # NOTE(review): model.predict_offsets in this file carries the same value;
  # presumably the two must agree -- confirm.
  predict_offsets: true

# Optimizer configuration targeting torch.optim.AdamW; the keys below use the
# names of AdamW's constructor arguments. The optimizer's `params` argument is
# not given here, so it is presumably supplied by the agent at build time --
# TODO confirm.
optimization:
  _target_: torch.optim.AdamW
  lr: 1.0e-3
  betas: [0.9, 0.995]
  eps: 1.0e-8
  weight_decay: 0.1

# Presumably the maximum gradient norm used for clipping during training --
# TODO confirm how the agent consumes this value.
grad_norm_clip: 1.0

# Run-level settings passed through to the agent. Each value is an OmegaConf
# interpolation of a key with the same name. OmegaConf resolves `${name}`
# from the config root, so this file is presumably composed under a config
# group (not at the root), where these keys are defined -- TODO confirm.
trainset: ${trainset}
valset: ${valset}

train_batch_size: ${train_batch_size}
val_batch_size: ${val_batch_size}
num_workers: ${num_workers}
epoch: ${epoch}
device: ${device}
scale_data: ${scale_data}
window_size: ${window_size}
eval_every_n_epochs: ${eval_every_n_epochs}

