# ========== TASK CONFIG ==========
env_name: mimicgen_stack_three
device: cuda
num_envs: 25
max_episode_steps: 800
# Modify this to your dataset directory.
dataset_dir: /path/to/dataset
ori_dataset_path: '${dataset_dir}/stack_three_d0.hdf5'
# Directory to save the precomputed distances, policies, etc.
# The folder name embeds env_name, action_space, dist_metric and dist_horizon.
base_dir: '${dataset_dir}/${env_name}_${action_space}_${dist_metric}_${dist_horizon}'
save_model: true
# Set to true if you want dynamic camera placements/movements.
dynamic_mode: true

# ========== CLASS-SPECIFIC CONFIG ==========
# NOTE(review): the meaning of each key below is inferred from its name only;
# confirm against the code that consumes this section.
dist_metric: dtw
dist_horizon: 16
# Stored under base_dir, whose folder name embeds dist_metric and dist_horizon.
dtw_file_path: ${base_dir}/dists.pth
dist_quantile: 0.01
use_sparse: true # Set to true if your dataset is large
# Separate from eval.temperature, which is configured in the eval section.
temperature: 0.05

# ========== OBS / ACTION STRUCTURE ==========
# Observation entries shared by the dataset and policy sections. The number
# after each low-dimensional entry is its contribution to proprio_dim.
obs_keys:
  - robot0_eef_pos  # 3
  - robot0_eef_quat  # 4
  - robot0_joint_pos  # 7
  - robot0_gripper_qpos  # 2
  - agentview_image  # image entry, not counted in proprio_dim
# 3 + 4 + 7 + 2 from the low-dimensional entries above.
proprio_dim: 16
latent_dim: 0
obs_horizon: 1
pred_horizon: 16
action_horizon: 8
action_space: abs_ee_pose
# Presumably 8 when action_space contains 'joint_pos', otherwise 10
# (assumes the custom `if` resolver takes cond, then, else) -- confirm.
action_dim: ${if:"'joint_pos' in '${action_space}'", 8, 10}

# ========== DATASET ==========
# Dataset settings. `_target_` names the class to build; Hydra-style
# instantiation is assumed -- confirm against the training entry point.
dataset:
  _target_: CLASS.dataset.robomimic.BCRobomimicDataset
  dataset_path: ${dataset_dir}/dataset.zarr
  num_demo: 1000
  # The next four values mirror the top-level keys of the same name.
  obs_keys: ${obs_keys}
  action_space: ${action_space}
  obs_horizon: ${obs_horizon}
  pred_horizon: ${pred_horizon}
  # Fixed to cpu here; this does not follow the top-level `device` key.
  device: cpu

# ========== POLICY ==========
# Policy settings. The top-level scalars of this section are re-used by the
# nested `model` entries through ${policy.*} interpolations.
policy:
  _target_: CLASS.agents.diffusion_policy.DiffusionPolicy
  # These six values mirror the top-level keys of the same name.
  obs_keys: ${obs_keys}
  proprio_dim: ${proprio_dim}
  latent_dim: ${latent_dim}
  action_dim: ${action_dim}
  obs_horizon: ${obs_horizon}
  pred_horizon: ${pred_horizon}
  # The next four values are forwarded to model.image_encoder below.
  vision_model: imn
  frozen_encoder: false
  spatial_softmax: true
  num_kp: 0
  device: ${device}
  # Same value as noise_scheduler.num_train_timesteps below (both 16).
  num_inference_steps: 16

  # Sub-modules of the policy; each entry that has one names its class
  # via `_target_`.
  model:
    image_encoder:
      _target_: CLASS.model.vision.VisionEncoders
      vision_model: ${policy.vision_model}
      # Selects the obs_keys whose name contains 'image' (assumes the custom
      # `eval` resolver evaluates the quoted Python expression -- confirm).
      views: ${eval:"[key for key in ${policy.obs_keys} if 'image' in key]"}
      replace_norm: true
      spatial_softmax: ${policy.spatial_softmax}
      num_kp: ${policy.num_kp}
      noise_std: 0.001
      frozen: ${policy.frozen_encoder}

    proprio_encoder:
      _target_: CLASS.model.MLP.MLP # torch.nn.Linear
      in_features: ${policy.proprio_dim}
      hidden_features: 128
      num_layers: 0
      # NOTE(review): latent_dim is 0 in this file, so this resolves to 0;
      # confirm how the MLP class treats a zero output width.
      out_features: ${policy.latent_dim}

    fusion_fn:
      _target_: CLASS.util.fusion.ConcatFusion

    policy_head:
      _target_: CLASS.model.unet1d.ConditionalUnet1D
      # Reads the top-level action_dim directly; policy.action_dim resolves
      # to the same value.
      input_dim: ${action_dim}
      # NOTE(review): left null here; presumably filled in by the policy
      # code at construction time -- confirm.
      global_cond_dim: null
      diffusion_step_embed_dim: 256
      down_dims: [256, 512, 1024]
      kernel_size: 5
      n_groups: 8

  noise_scheduler:
    _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
    # Same value as policy.num_inference_steps above (both 16).
    num_train_timesteps: 16
    beta_start: 0.0001
    beta_end: 0.02
    beta_schedule: squaredcos_cap_v2
    prediction_type: epsilon

# ========== TRAINING ==========
# Settings for the pre-training stage.
pretrain:
  num_epoch: 10
  train_bs: 160
  train_steps: null
  eval_steps: null

  optim:
    _target_: CLASS.util.optim.LARS
    # NOTE(review): the base lr is 0 while vision_lr is 0.4; presumably only
    # the vision encoder's parameters are updated in this stage -- confirm
    # against the code that builds the optimizer parameter groups.
    lr: 0
    vision_lr: 0.4

  scheduler:
    name: cosine
    num_warmup_steps: null

# Settings for the policy-training stage (optionally starting from
# pre-trained weights).
finetune:
  # Set to true if you want to fine-tune with pretrained weights.
  enabled: true
  # null if you want to use the latest trained model.
  pretrain_policy_path: null
  # CLASS pre-training enables training with fewer epochs: presumably 20 when
  # `enabled` is true, otherwise 100 (assumes the custom `if` resolver takes
  # cond, then, else -- confirm).
  num_epoch: ${if:"${finetune.enabled}", 20, 100}
  train_bs: 64
  train_steps: null
  eval_steps: null

  optim:
    _target_: torch.optim.AdamW
    # The mantissas carry an explicit ".0" so that YAML 1.1 loaders (such as
    # plain PyYAML) resolve these as floats; without the dot they would be
    # read as strings. The numeric values are unchanged.
    lr: 1.0e-4
    vision_lr: 1.0e-5
    fused: true

  scheduler:
    name: cosine
    num_warmup_steps: 100

# ========== EVALUATION ==========
# Evaluation settings. NOTE(review): keys without a comment are not explained
# anywhere in this file; see the evaluation code for their meaning.
eval:
  num_ep: 50 # Must be divisible by num_envs
  seed: 100000
  temperature: ${if:"${dynamic_mode}", 0.02, 0.01}  # presumably 0.02 if dynamic_mode else 0.01 -- confirm `if` argument order
  nnn: 64
  use_cossim: true
  nonparam_horizon: 16
  param_horizon: 12
  render: true
  option: 0

# ========== LOGGING ==========
# Weights & Biases logging settings.
wandb:
  # Run name: the environment name followed by the output of the `now`
  # resolver for the pattern %Y%m%d_%H%M%S.
  name: '${env_name}_${now:%Y%m%d_%H%M%S}'
  project: CLASS