---
# Experiment bookkeeping: name, entry script, worker count and run naming.
meta_data:
  exp_name: "mad_nba"
  script_path: "run_scripts/train.py"
  num_workers: 1
  # Brace placeholders name keys defined under `variables` / `constants`
  # below; presumably substituted per run by the launcher — confirm.
  job_name: "{dataset}/h_{horizon}-{model}-guidew_{condition_guidance_w}-dl_{loader}-retcond_{returns_condition}"

variables:
  # Swept settings: each key maps to a list of candidate values.
  # NOTE(review): presumably the launcher expands every combination and the
  # chosen value overrides a same-named key under `constants` — confirm.
  seed:
    - 100
    - 200
    - 300

  horizon:
    - 20
  dataset:
    - "train"
  condition_guidance_w:
    - 1.2

constants:
  # Fixed settings shared by every run of the sweep.
  # NOTE(review): `seed` and `condition_guidance_w` also appear under
  # `variables`; presumably the swept values take precedence over these —
  # confirm against the experiment launcher.

  # misc
  seed: 200
  env_type: "nba"
  n_agents: 10
  # nba_hz: 2.5  # corresponds to 0.4s per step
  nba_hz: 5  # corresponds to 0.2s per step
  use_action: false
  discrete_action: true

  # model
  model: "models.PlayerSharedConvAttentionDeconv"
  diffusion: "models.GaussianDiffusion"
  ego_only_inv: false
  share_inv: true
  # horizon: 4  (horizon is swept under `variables`)
  history_horizon: 0
  n_diffusion_steps: 200
  action_weight: 10
  loss_weights: null
  loss_discount: 1
  dim_mults: [1, 4, 8]
  returns_condition: false
  predict_epsilon: true
  calc_energy: false
  dim: 128
  hidden_dim: 256
  condition_dropout: 0.25
  condition_guidance_w: 1.2
  ar_inv: false
  train_only_inv: false
  clip_denoised: true
  test_ret: 0.9
  agent_share_noise: false
  # NOTE(review): an SMAC renderer in an NBA config looks like a carry-over
  # from an SMAC config — confirm this is intended.
  renderer: "utils.SMACRenderer"

  # dataset
  loader: "datasets.NBASequenceDataset"
  normalizer: "CDFNormalizer"
  # normalizer: "LimitsNormalizer"
  max_n_episodes: 600
  preprocess_fns: []
  use_padding: true
  include_returns: true
  discount: 0.99
  max_path_length: 20000
  termination_penalty: 0.0

  # training
  n_steps_per_epoch: 10000
  loss_type: "l2"
  n_train_steps: 1000000
  batch_size: 32
  learning_rate: 0.0002
  gradient_accumulate_every: 2
  ema_decay: 0.995
  log_freq: 1000
  save_freq: 100000
  sample_freq: 10000
  n_saves: 5
  save_parallel: false
  n_reference: 3
  save_checkpoints: true

  # eval
  evaluator: "utils.NBAEvaluator"
  num_eval: 10
  eval_freq: 0  # NOTE(review): 0 presumably disables periodic eval — confirm
  eval_batch_size: 512  # max eval batch size
  eval_sample_times: 16
  nba_eval_valid_samples: 1000

  # load checkpoint
  continue_training: true
