# Hydra defaults list. Entries are composed in list order (later entries win
# on conflicts); `_self_` is first, so this file is merged before the task
# config. Select a different task from the CLI with `task=<name>`.
defaults:
  - _self_
  - task: halfcheetah_expert

# Run name; interpolated into the logging name/tags and the output dirs.
name: train_bet_lowdim
# Workspace class that Hydra instantiates for this run.
_target_: diffusion_policy.workspace.pbrl_reward_bet_lowdim_online_workspace.PbrlBETLowdimWorkspace
# Optional pretrained checkpoint; uncomment to use.
# NOTE(review): confirm the workspace actually reads `checkpoint_dir`.
# checkpoint_dir: 'data/experiments/low_dim/d4rl/halfcheetah/bet/base/epoch=0200-test_cumulative_rewards=3560.400.ckpt'

# Dimensions and task name, interpolated from the selected task config.
obs_dim: ${task.obs_dim}
action_dim: ${task.action_dim}
keypoint_dim: ${task.keypoint_dim}
task_name: ${task.name}
# Free-form experiment label; also added to the logging tags.
exp_name: default

# Shared hyperparameters. gamma/horizon/n_obs_steps/n_action_steps are
# forwarded into `policy` via interpolation; the remaining keys are
# presumably read by the task config or workspace -- verify before removing.
# NOTE(review): gamma is presumably the RL discount factor.
gamma: 1
horizon: 1
n_obs_steps: 1
n_action_steps: 1
n_latency_steps: 0
past_action_visible: false
keypoint_visible_rate: 1.0
obs_as_local_cond: false
obs_as_global_cond: false
pred_action_steps_only: false

policy:
  _target_: diffusion_policy.policy.pbrl_bet_lowdim_policy.CQLBETLowdimPolicy

  # k-means action discretizer with `num_bins` clusters.
  action_ae:
    _target_: diffusion_policy.model.bet.action_ae.discretizers.k_means.KMeansDiscretizer
    num_bins: 24
    action_dim: ${action_dim}
    predict_offsets: true

  # Identity encoder: observations pass through unchanged. torch.nn.Identity
  # ignores constructor arguments, so `output_dim` is only metadata
  # (presumably read by the policy -- verify).
  obs_encoding_net:
    _target_: torch.nn.Identity
    output_dim: ${obs_dim}

  # MinGPT state prior; its vocabulary is tied to the discretizer's bins.
  state_prior:
    _target_: diffusion_policy.model.bet.latent_generators.mingpt.MinGPT

    discrete_input: false
    input_dim: ${obs_dim}

    vocab_size: ${policy.action_ae.num_bins}

    # Architecture details
    n_layer: 4
    n_head: 4
    n_embd: 72

    block_size: ${horizon}  # Length of history/context
    predict_offsets: true
    # NOTE(review): confirm this scale suits the selected task's action range.
    offset_loss_scale: 1000.0  # actions are very small
    focal_loss_gamma: 2.0
    action_dim: ${action_dim}

  gamma: ${gamma}
  horizon: ${horizon}
  n_obs_steps: ${n_obs_steps}
  n_action_steps: ${n_action_steps}
  action_dim: ${action_dim}
  obs_dim: ${obs_dim}
  num_samples: 25
  # NOTE(review): presumably the soft target-network update rate -- verify.
  tau: 0.001
  # NOTE(review): presumably the behavior-cloning loss weight -- verify.
  bc_alpha: 0.01

# Reward model, instantiated by Hydra from `_target_` with these kwargs.
reward_model:
  _target_: diffusion_policy.common.reward_model.TransformerRewardModel
  observation_dim: ${obs_dim}
  action_dim: ${action_dim}
  structure_type: "transformer1"
  ensemble_size: 3
  d_model: 256
  num_layers: 2
  max_seq_len: 100
  activation: "tanh"
  device: ${training.device_gpu}
  # Must be `null`: YAML has no `None` literal, so `None` would be passed to
  # the constructor as the string "None" rather than as Python None.
  logger: null
  # NOTE(review): hard-coded and not tied to the selected `task` config;
  # update it by hand when switching environments.
  task: 'halfcheetah'

# Reward-model training hyperparameters.
reward_training:
  lr: 2.0e-4
  batch_size: 30
  n_epochs: 200
  warm_up_epochs: 10
  # Taken from the selected task's preference dataset (`pref_dataset.N`).
  data_size: ${task.pref_dataset.N}

# Training data-loader settings (presumably DataLoader kwargs -- verify).
dataloader:
  batch_size: 50
  num_workers: 1
  shuffle: true
  pin_memory: true
  persistent_workers: false

# Validation data-loader settings; same shape as `dataloader`, no shuffling.
val_dataloader:
  batch_size: 50
  num_workers: 1
  shuffle: false
  pin_memory: true
  persistent_workers: false

# Optimizer settings for the actor and the Q-functions (qf1/qf2).
# NOTE(review): the actor uses `learning_rate` while qf1/qf2 use `lr`. The
# key names are presumably dictated by the consuming code (a policy-specific
# optimizer factory vs. a torch optimizer's `lr=` kwarg), so do not
# normalize them without checking the workspace.
optimizer:
  actor:
    weight_decay: 0.1
    learning_rate: 1.0e-4
    betas: [0.9, 0.999]
  qf1:
    weight_decay: 1.0e-5
    lr: 1.0e-4
    betas: [0.9, 0.999]
  qf2:
    weight_decay: 1.0e-5
    lr: 1.0e-4
    betas: [0.9, 0.999]



training:
  device_cpu: "cpu"
  # also used as the reward model's device
  device_gpu: "cuda:0"
  seed: 42
  debug: false
  resume: false  # set to true to resume (semantics defined by the workspace)
  # optimization
  lr_scheduler: cosine
  lr_warmup_steps: 0
  num_epochs: 200
  gradient_accumulate_every: 1
  grad_norm_clip: 1.0
  enable_normalizer: false
  # training loop control
  # in epochs
  rollout_every: 20
  checkpoint_every: 20
  val_every: 1
  sample_every: 5
  # steps per epoch
  max_train_steps: null
  max_val_steps: null
  # misc
  tqdm_interval_sec: 1.0
  # online-phase settings; semantics are defined in the workspace
  online:
    num_groups: 4
    all_votes: 100
    reverse_ratio: 0.2
    reverse_rate: 0.5
    reverse_freq: 1
    update_history: false
  map:
    use_map: false  # enable only when running our method
    map_ratio: 0


# Experiment-logging settings (project/mode/id/group match wandb run fields).
logging:
  project: diffusion_policy_pbrl
  resume: true
  mode: online
  name: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
  tags: ["${name}", "${task_name}", "${exp_name}"]
  id: null
  group: null

checkpoint:
  # Top-k retention: presumably keeps the `k` checkpoints with the highest
  # (`mode: max`) value of `monitor_key`, named with `format_str`.
  topk:
    monitor_key: test_cumulative_rewards
    mode: max
    k: 5
    format_str: 'epoch={epoch:04d}-test_cumulative_rewards={test_cumulative_rewards:.3f}.ckpt'
  save_last_ckpt: true
  save_last_snapshot: false

# Output dir and wandb name base for multi-run launches; these use the same
# patterns as `hydra.run.dir` and `logging.name`.
multi_run:
  run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
  wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}

# Hydra runtime settings: timestamped output dirs for single runs and sweeps;
# each sweep job gets a numbered subdirectory (`hydra.job.num`).
hydra:
  job:
    override_dirname: ${name}
  run:
    dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
  sweep:
    dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
    subdir: ${hydra.job.num}
