# @package _global_
# Hydra defaults list: replace the already-selected `env` and `evaluator`
# config-group choices with their `pointmass` options.
defaults:
  - override /env: pointmass
  - override /evaluator: pointmass
# Run-level settings. `num_envs` and `eval_interval` are also referenced
# further down in this file through ${...} interpolation.
# NOTE(review): units and exact semantics are defined by the training loop,
# which is outside this file — confirm there.
num_steps: 5
num_envs: 256
log_interval: 10
# Very large intervals — presumably chosen so that checkpointing and
# evaluation never trigger during a run; confirm against the training loop.
save_interval: 10000000000
eval_interval: 1000000000
# Written with an explicit decimal point and signed exponent so that every
# YAML loader resolves it to the float 5000000.0 (a bare `5e6` is read as a
# string by YAML 1.1 loaders such as stock PyYAML).
num_env_steps: 5.0e+6

# Policy updater, built by Hydra from `_target_`.
policy_updater:
  _target_: imitation_learning.meta_irl.updater.MetaIRL
  # Nested `_target_` nodes below are not instantiated automatically; they
  # reach MetaIRL as plain configs.
  _recursive_: false

  batch_size: 256
  # Tied to the top-level evaluation interval.
  plot_interval: ${eval_interval}
  norm_expert_actions: false
  n_inner_iters: 1
  reward_update_freq: 1
  # Interpolated from top-level keys; `storage`, `device` and
  # `total_num_updates` are not defined in this file.
  storage_cfg: ${storage}
  device: ${device}
  num_envs: ${num_envs}
  total_num_updates: ${total_num_updates}
  use_lr_decay: true

  # Target that loads the transition dataset stored at `dataset_path`.
  # NOTE(review): presumably the expert demonstrations used by MetaIRL —
  # confirm in the updater.
  get_dataset_fn:
    _target_: imitation_learning.common.utils.get_transition_dataset
    # Relative path — presumably resolved against the working directory;
    # TODO confirm.
    dataset_path: data/traj/pm_100.pth
    # `env` is not defined in this file; presumably supplied by the selected
    # `env` config group.
    env_name: ${env.env_name}


  # Policy initialisation target; receives the top-level `policy` config
  # (defined outside this file).
  policy_init_fn:
    _target_: imitation_learning.meta_irl.rewards.reg_init
    # Pass `policy_cfg` through as an un-instantiated config.
    _recursive_: false
    policy_cfg: ${policy}

  # Reward model configuration.
  reward:
    _target_: imitation_learning.meta_irl.rewards.NeuralReward
    # Both shapes are interpolated from top-level keys that are not defined
    # in this file.
    obs_shape: ${obs_shape}
    action_dim: ${action_dim}
    reward_hidden_dim: 128
    reward_type: 'NEXT_STATE'
    # NOTE(review): -1 looks like a sentinel value — confirm its meaning in
    # NeuralReward.
    cost_take_dim: -1
    include_tanh: false
    n_hidden_layers: 2

  # Settings for the inner updater (DifferentiablePPO).
  inner_updater:
    _target_: imitation_learning.meta_irl.differentiable_ppo.DifferentiablePPO
    _recursive_: false

    use_clipped_value_loss: true
    # NOTE(review): -1 presumably disables gradient-norm clipping — confirm
    # in DifferentiablePPO.
    max_grad_norm: -1
    value_loss_coef: 0.5
    clip_param: 0.2
    entropy_coef: 0.001
    num_epochs: 2
    num_mini_batch: 4

    # Parameters for computing returns
    gae_lambda: 0.95
    use_gae: true
    gamma: 0.99

  # Adam optimizer config with learning rate 1e-4.
  # NOTE(review): presumably optimizes the policy in the inner loop — confirm
  # in MetaIRL.
  inner_opt:
    _target_: torch.optim.Adam
    lr: 0.0001

  # Adam optimizer config with learning rate 1e-3 (10x the `inner_opt` rate).
  # NOTE(review): presumably optimizes the reward model — confirm in MetaIRL.
  reward_opt:
    _target_: torch.optim.Adam
    lr: 0.001

  # Loss config: mean-reduced squared error (torch.nn.MSELoss).
  # NOTE(review): which tensors it compares is decided by MetaIRL — confirm.
  irl_loss:
    _target_: torch.nn.MSELoss
    reduction: 'mean'

# Settings used for evaluation runs.
# NOTE(review): how these are merged over the main config is decided by code
# outside this file — confirm.
eval_args:
  policy_updater:
    # NOTE(review): -1 presumably turns reward updates off during
    # evaluation — confirm in MetaIRL.
    reward_update_freq: -1
    n_inner_iters: 1
  load_policy: false
  # Explicit decimal point and signed exponent: parsed as the float
  # 5000000.0 by every YAML loader (a bare `5e6` is a string under YAML 1.1
  # loaders such as stock PyYAML).
  num_env_steps: 5.0e+6



# Use the launch directory as Hydra's run directory instead of the default
# per-run output folder.
hydra:
  run:
    dir: ./
