# Hydra defaults list: this config is composed from the entries below, merged
# top to bottom. `_self_` is listed last, so the values defined in this file
# override any values coming from the configs listed before it.
defaults:
  - default  # presumably default.yaml next to this file — confirm
  - backbone: dit  # option `dit` of the `backbone` config group
  - algorithm: gauge_flow  # option `gauge_flow` of the `algorithm` config group
  - env: walker2d  # option `walker2d` of the `env` config group
  - _self_

dataset:
  _target_: src.datasets.constrained_dataset.ConstrainedMinariDataset

  # Constraint definition, forwarded from the top-level constraint settings.
  # Note the key here is `full_constrained_idx` (singular) while the source
  # key is `full_constrained_idxs` (plural).
  full_constrained_idx: ${full_constrained_idxs}
  single_A: ${single_A}
  single_b: ${single_b}

  # Minari dataset and sequence-length settings.
  env: ${dataset_minari_name}
  horizon: ${horizon}
  max_path_length: ${max_seq}
  max_n_episodes: 10000
  use_padding: false

  # Normalization / preprocessing.
  normalizer: GaussianNormalizer
  preprocess_fns: []
  termination_penalty: 0

env:
  # Environment overrides, all forwarded from the top-level settings.
  # Presumably merged over the `env: walker2d` defaults entry — confirm.
  use_cpx: ${use_cpx}
  vel_scale: ${vel_scale}
  # Height bounds.
  height_min: ${height_min}
  height_limit: ${height_limit}
  # Velocity bounds.
  v_min: ${v_min}
  v_max: ${v_max}

# Policy config: no guide config is added here.
policy:
  _target_: 'src.sampling.policies.GuidedPolicy'
  preprocess_fns: []


val_dataloader:
  _target_: torch.utils.data.DataLoader
  shuffle: true
  # Batch size is tied to eval.samples, the number of horizon-length samples
  # drawn when computing the distribution discrepancy.
  batch_size: ${eval.samples}

eval:
  # Checkpoint to evaluate.
  load_model_path: 'outputs/walker2dcpx2/gaugeflow_train_step100/42_2026-01-09_15-58-19/state_final.pt'
  seed: ${seed}

  # Number of horizon-length samples drawn when computing the distribution
  # discrepancy.
  samples: 200
  # Dimensions considered when computing smoothness.
  check_index_list:
    - 0
    - 1
    - 2
    - 3
    - 4
    - 5
    - 6
    - 7

  # Rollout settings.
  skip_rollout: true
  n_episodes: 10  # number of trajectories to roll out
  is_video: true
  video_episodes: 2


# Run identifiers (presumably used to build output paths / run names — confirm).
env_name: 'walker2dcpx2'
algo_name: 'gaugeflow'
run_name: 'time_step100'


# Runtime.
device: 'cuda:0'
seed: 42

# Data source, sequence lengths and sampling.
dataset_minari_name: 'mujoco/walker2d/medium-v0'
# A UNet backbone requires horizon to be divisible by 8. 100 is not, which is
# acceptable here because the defaults list selects the `dit` backbone.
horizon: 100
max_seq: 1000
steps: 100  # number of sampling steps

# Dimensions. transition_dim is presumably obs_dim + act_dim (17 + 6 = 23);
# keep these in sync by hand.
obs_dim: 17
act_dim: 6
transition_dim: 23
cond_dim: 17


# Constraint settings for Walker2d.
height_limit: 1.35
height_min: 0.9
v_max: 1.4
v_min: -1.4
vel_scale: 0.01
use_cpx: 2  # constraint variant selector; meaning is defined by the consumer

# Transition indices the constraints act on (Walker2d, not Hopper).
# Assuming transitions are laid out as [action (act_dim=6), observation
# (obs_dim=17)], index 6 is obs[0] (torso z height) and index 15 is obs[9]
# (torso z velocity) — TODO confirm against ConstrainedMinariDataset.
full_constrained_idxs: [6, 15]

# Linear constraints on x = [z_height, z_vel] (the two indices above),
# presumably read as single_A @ x <= single_b — confirm in the consumer.
# These literals duplicate the scalar bounds above: plain interpolation
# cannot express the negated entries, so they are NOT derived automatically.
# When overriding height_limit / height_min / v_max / v_min / vel_scale,
# update single_A / single_b as well.
single_A:
  - [1, 0.01]  # z_height + 0.01 * z_vel <= 1.35  (0.01 equals vel_scale)
  - [1, 0]     # z_height <= 1.35                 (height_limit)
  - [-1, 0]    # z_height >= 0.9                  (height_min)
  - [0, 1]     # z_vel <= 1.4                     (v_max)
  - [0, -1]    # z_vel >= -1.4                    (v_min)
single_b:
  - 1.35  # height_limit
  - 1.35  # height_limit
  - -0.9  # -height_min
  - 1.4   # v_max
  - 1.4   # -v_min
num_cons: 5  # number of rows in single_A / entries in single_b