# ========== TASK CONFIG ==========
env_name: robomimic_square
device: cuda
num_envs: 25
max_episode_steps: 600
# Modify this to point at your dataset directory.
dataset_dir: /path/to/dataset
ori_dataset_path: ${dataset_dir}/low_dim_v141.hdf5
# Directory for saving the precomputed distances, policies, etc. Its name
# encodes env_name, action_space, dist_metric and dist_horizon, so changing
# any of those resolves to a different directory.
base_dir: ${dataset_dir}/${env_name}_${action_space}_${dist_metric}_${dist_horizon}
save_model: true
# Set to true for dynamic camera placements/movements (also selects
# eval.temperature).
dynamic_mode: true

# ========== CLASS-SPECIFIC CONFIG ==========
# Settings for the precomputed distances stored under base_dir.
# Name of the distance metric (presumably dynamic time warping for `dtw` —
# confirm). Also part of base_dir's name.
dist_metric: dtw
# Horizon used for the distance computation (units presumably env steps —
# TODO confirm). Also part of base_dir's name.
dist_horizon: 16
# File holding the precomputed distances, inside base_dir.
dtw_file_path: ${base_dir}/dists.pth
# NOTE(review): presumably the distance quantile used to select neighbour /
# positive pairs — confirm against the CLASS objective.
dist_quantile: 0.025
use_sparse: false
# Temperature for the CLASS objective (presumably); separate from
# eval.temperature.
temperature: 0.05

# ========== OBS / ACTION STRUCTURE ==========
# Observation keys, in order. Keys whose name contains 'image' become camera
# views for policy.model.image_encoder; the remaining keys make up proprio_dim.
obs_keys:
  - robot0_eef_pos
  - robot0_eef_quat
  - robot0_joint_pos
  - robot0_gripper_qpos
  - agentview_image
# Hand-computed from the non-image obs_keys; keep in sync when they change:
#   3 (robot0_eef_pos) + 4 (robot0_eef_quat) + 7 (robot0_joint_pos)
#   + 2 (robot0_gripper_qpos) = 16
proprio_dim: 16
latent_dim: 0
obs_horizon: 1
pred_horizon: 16
action_horizon: 8
action_space: abs_ee_pose
# 8 when action_space contains 'joint_pos', otherwise 10.
action_dim: ${if:"'joint_pos' in '${action_space}'", 8, 10}

# ========== DATASET ==========
# Training dataset, instantiated from `_target_`.
dataset:
  _target_: CLASS.dataset.robomimic.BCRobomimicDataset
  # Zarr store under dataset_dir; distinct from the raw ori_dataset_path HDF5
  # (presumably converted from it — TODO confirm).
  dataset_path: ${dataset_dir}/dataset.zarr
  num_demo: 200
  obs_keys: ${obs_keys}
  action_space: ${action_space}
  obs_horizon: ${obs_horizon}
  pred_horizon: ${pred_horizon}
  # Fixed to cpu here rather than following the top-level `device`.
  device: cpu

# ========== POLICY ==========
policy:
  _target_: CLASS.agents.base_policy.Policy
  obs_keys: ${obs_keys}
  proprio_dim: ${proprio_dim}
  latent_dim: ${latent_dim}
  action_dim: ${action_dim}
  obs_horizon: ${obs_horizon}
  pred_horizon: ${pred_horizon}
  vision_model: imn
  frozen_encoder: false
  spatial_softmax: true
  num_kp: 0
  device: ${device}

  # Sub-module configs. They read shared settings from `policy.*`, so a single
  # override (e.g. policy.action_dim=...) stays consistent across modules.
  model:
    image_encoder:
      _target_: CLASS.model.vision.VisionEncoders
      vision_model: ${policy.vision_model}
      # Camera views = every obs key whose name contains 'image'.
      views: ${eval:"[key for key in ${policy.obs_keys} if 'image' in key]"}
      replace_norm: true
      spatial_softmax: ${policy.spatial_softmax}
      num_kp: ${policy.num_kp}
      noise_std: 0.001
      frozen: ${policy.frozen_encoder}

    proprio_encoder:
      _target_: CLASS.model.MLP.MLP  # torch.nn.Linear
      in_features: ${policy.proprio_dim}
      hidden_features: 128
      # NOTE(review): num_layers -1 with out_features = latent_dim (0 by
      # default); semantics are defined by CLASS.model.MLP.MLP — confirm.
      num_layers: -1
      out_features: ${policy.latent_dim}

    fusion_fn:
      _target_: CLASS.util.fusion.ConcatFusion

    policy_head:
      _target_: CLASS.model.unet1d.Unet1D
      input_dim: ${policy.action_dim}
      pred_horizon: ${policy.pred_horizon}
      # NOTE(review): left null here; presumably filled in at runtime from the
      # fused encoder output width — confirm in Policy.
      global_cond_dim: null
      down_dims: [256, 512, 1024]
      kernel_size: 5
      n_groups: 8

# ========== TRAINING ==========
pretrain:
  num_epoch: 50
  train_bs: 160
  train_steps: null
  eval_steps: null

  optim:
    _target_: CLASS.util.optim.LARS
    # NOTE(review): base lr is 0 and only vision_lr is non-zero, which
    # presumably means only the vision encoder is updated during pre-training;
    # confirm against the optimizer's param-group setup.
    lr: 0
    vision_lr: 0.4

  scheduler:
    name: cosine
    num_warmup_steps: null

finetune:
  # Set to true to fine-tune from pretrained weights (also selects num_epoch).
  enabled: true
  # null uses the latest trained model.
  pretrain_policy_path: null
  # 20 epochs when fine-tuning, 100 otherwise: CLASS pre-training enables
  # training with fewer epochs.
  num_epoch: ${if:"${finetune.enabled}", 20, 100}
  train_bs: 64
  train_steps: null
  eval_steps: null

  optim:
    _target_: torch.optim.AdamW
    # Written as 1.0e-N (with a decimal point) so YAML 1.1 loaders such as
    # plain PyYAML also parse these as floats rather than strings.
    lr: 1.0e-4
    vision_lr: 1.0e-5
    weight_decay: 1.0e-6
    fused: true

  scheduler:
    name: cosine
    num_warmup_steps: 500

# ========== EVALUATION ==========
eval:
  num_ep: 50
  seed: 100000
  # 0.02 when dynamic_mode is enabled, 0.01 otherwise.
  temperature: ${if:"${dynamic_mode}", 0.02, 0.01}
  # NOTE(review): presumably the number of nearest neighbours retrieved by the
  # non-parametric policy — confirm.
  nnn: 64
  # NOTE(review): presumably cosine similarity for retrieval — confirm.
  use_cossim: true
  nonparam_horizon: 16
  param_horizon: 12
  render: true
  # NOTE(review): meaning of each option value is defined by the eval code.
  option: 0

# ========== LOGGING ==========
wandb:
  # Run name: env_name plus a launch timestamp from the `now` resolver
  # (registered by Hydra).
  name: ${env_name}_${now:%Y%m%d_%H%M%S}
  project: CLASS