# Hydra-style defaults list: the config groups composed into this experiment.
defaults:
  - default
  - backbone: poly
  - algorithm: discrete_flow
  - env: halfcheetah
  # `_self_` is listed last, so under Hydra's composition order the keys
  # defined in this file are merged after the entries above.
  - _self_


# Dataset section. `_target_` names the class this section configures; the
# `${...}` values are interpolated from top-level keys defined in this file.
dataset:
  _target_: src.datasets.constrained_dataset.ConstrainedMinariDataset
  # Constraint-related settings.
  # Note: the key here is singular (`full_constrained_idx`) while the
  # top-level key it reads from is plural (`full_constrained_idxs`).
  full_constrained_idx: ${full_constrained_idxs}
  single_A: ${single_A}
  single_b: ${single_b}

  env: ${dataset_minari_name}
  horizon: ${horizon}
  normalizer: 'GaussianNormalizer'
  preprocess_fns: []
  max_path_length: ${max_seq}
  max_n_episodes: 10000
  termination_penalty: 0
  use_padding: false

# Algorithm settings defined by this file (the defaults list selects
# `algorithm: discrete_flow`).
algorithm:
  action_weight: 10.0
  loss_discount: 0.99
  use_ot_batch: true


# Backbone settings defined by this file (the defaults list selects
# `backbone: poly`).
backbone:
  cons_begin_seq_idx: 0 # HalfCheetah constrains the action trajectory, so horizon step 0 must be included as well
  share_traj_encoder: true
  time_invariance_cons: true # use the simplified constraint encoder

# Trainer section. `_target_` names the class this section configures.
trainer:
  _target_: src.utils.training.Trainer
  # diffusion_model: ??
  # dataset: ??
  # renderer: null
  device: ${device}
  ema_decay: 0.995
  train_batch_size: 64
  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. stock
  # PyYAML) read a bare `1e-4` as a string, whereas `1.0e-4` is a float
  # under every loader.
  train_lr: 1.0e-4
  gradient_accumulate_every: 1
  step_start_ema: 2000
  update_ema_every: 100
  log_freq: 200
  save_freq: 200000
  # Kept equal to `save_freq` via interpolation.
  label_freq: ${trainer.save_freq}
  save_parallel: false
  
# Policy config without guidance.
policy:
  _target_: src.sampling.policies.PolyFlowPolicy
  # diffusion_model: null
  # normalizer: null
  preprocess_fns: []


# Validation data loader section (`_target_` is torch's DataLoader).
val_dataloader:
  _target_: torch.utils.data.DataLoader
  # One batch holds `eval.samples` items.
  batch_size: ${eval.samples}
  shuffle: true

# Evaluation settings.
eval:
  samples: 200   # number of horizons to sample when computing the distribution discrepancy
  # rollout
  n_episodes: 10  # number of trajectories to roll out
  seed: ${seed}
  is_video: true
  video_episodes: 2

  # check_index_list: [0, 1, 2, 3, 4, 5, 6, 7] # which dimensions to consider when computing smoothness



# Identifiers for this experiment (presumably used to name outputs — TODO confirm).
env_name: halfcheetah
algo_name: polyflow
run_name: train_ot


# Top-level scalars; several are referenced elsewhere in this file through
# `${...}` interpolation (device, seed, dataset_minari_name, horizon, max_seq).
device: "cuda:0"
seed: 42
dataset_minari_name: "mujoco/halfcheetah/medium-v0"
# NOTE: 100 is not divisible by 8; this file selects the `poly` backbone.
horizon: 100  # when using a U-Net, make sure horizon is divisible by 8
max_seq: 1000
obs_dim: 17
act_dim: 6
transition_dim: 23  # 23 = 17 + 6 (presumably obs_dim + act_dim — TODO confirm)
cond_dim: 17
steps: 10  # number of sampling steps

iteration: 1000000  # number of training steps

# Constraint definition.
# NOTE(review): `leg_limit` and `torsion_limit` are not referenced by any
# interpolation visible in this file; the first four entries of `single_b`
# repeat their values (1.2, 1.2, 0.8, 0.8) as literals.
leg_limit: 1.2
torsion_limit: 0.8
full_constrained_idxs: [0, 1, 3, 4] # correspond to HalfCheetah's back thigh, back shin, front thigh, front shin
# 12 rows x 4 columns: one column per entry of `full_constrained_idxs`.
# Presumably encodes linear constraints of the form A x <= b — TODO confirm.
single_A: [[1, 1, 0, 0], [0, 0, 1, 1], [1, 0, -1, 0], [-1, 0, 1, 0], [1, 0, 0, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, -1, 0, 0], [0, 0, 1, 0], [0, 0, -1, 0], [0, 0, 0, 1], [0, 0, 0, -1]]
# 12 entries, one per row of `single_A`.
single_b: [1.2, 1.2, 0.8, 0.8, 1, 1, 1, 1, 1, 1, 1, 1]
# Matches the number of rows in `single_A` / entries in `single_b`.
num_cons: 12



