# Dotted import path of the workspace class for this run
# (presumably instantiated through Hydra's `_target_` convention — confirm).
_target_: consistency_policy.rollout_d4rl.rollout_student_d4rl_workspace.RolloutD4RLWorkspace
exp_name: default
# NOTE(review): top-level horizon is 1 while policy.horizon is 384 —
# confirm which of the two the workspace actually reads.
horizon: 1
keypoint_visible_rate: 1.0
# Experiment-tracking settings (presumably forwarded to wandb — the sibling
# multi_run.wandb_name_base key suggests so; confirm against the workspace).
logging:
  group: null
  id: null
  mode: online
  # NOTE(review): this run name mentions
  # "train_diffusion_unet_hybrid_square_image", whereas this config's own
  # `name` is rollout_student_d4rl_maze — it looks inherited from another
  # experiment; confirm it is intended.
  name: 2022.12.29-22.31.30_train_diffusion_unet_hybrid_square_image
  project: diffusion_policy_debug
  resume: false
  tags:
  - rollout_diffusion_unet
  - d4rl
  - default
# Multi-run bookkeeping. Both values carry the same timestamped
# "train_diffusion_unet_hybrid_square_image" label as logging.name.
# NOTE(review): the label does not match this config's `name`
# (rollout_student_d4rl_maze) — confirm it is intended.
multi_run:
  run_dir: data/outputs/2022.12.29/22.31.30_train_diffusion_unet_hybrid_square_image
  wandb_name_base: 2022.12.29-22.31.30_train_diffusion_unet_hybrid_square_image
# Number of action steps. 128 and 256 are kept as commented-out alternatives.
# The active value (384) is also set in policy.n_action_steps and
# task.env_runner.n_action_steps — presumably all three must be changed
# together; verify against the workspace.
# n_action_steps: 128
# n_action_steps: 256
n_action_steps: 384

n_latency_steps: 0
# Same value as policy.n_obs_steps and task.env_runner.n_obs_steps.
n_obs_steps: 2
name: rollout_student_d4rl_maze
obs_as_global_cond: true
past_action_visible: false
# Student policy configuration. The active checkpoint paths below
# (teacher_path, edm, reward_path) all point at `d4rl_maze_lg` outputs; the
# commented-out lines are alternative checkpoints kept for quick switching.
policy:
  _target_: consistency_policy.reward_guided_student_maze.guided_ctm_policy_maze.GuidedCTMPUnetMazePolicy
  # This will be populated automatically from training.inference_mode;
  # do not set it here. Note that YAML reads `none` as the string "none",
  # not as null.
  inference_mode: none
  cond_predict_scale: true

  # Alternative down_dims settings (commented out):
  # diffusion_step_embed_dim: 256
  # down_dims:
  # - 256
  # - 512
  # - 1024

  # diffusion_step_embed_dim: 256
  # down_dims:
  # - 512
  # - 1024
  # - 2048

  diffusion_step_embed_dim: 256
  down_dims:
  - 256
  - 512
  - 1024
  - 2048

  reward_step_embed_dim: 32
  reward_down_dims:
  - 64
  - 128
  - 256

  # dropout_rate: 0.2
  # horizon: 128
  # horizon: 256
  horizon: 384
  kernel_size: 5
  # n_action_steps: 128
  # n_action_steps: 256
  n_action_steps: 384

  n_groups: 8
  n_obs_steps: 2

  # extra args
  initial_ema_decay: 0.0
  # Alternative value: 0.0068305197 (sqrt(160) * .00054)
  delta: -1
  special_skip: true
  chaining_times: ['D', 27, 54]

  # teacher
  # teacher_path: ./Diffusion/outputs/edm/d4rl_maze_sm/checkpoints/epoch=0340-train_mse_error=0.012.ckpt
  # teacher_path: ./Diffusion/outputs/edm/d4rl_maze_med/checkpoints/epoch=0060-train_mse_error=0.595.ckpt
  teacher_path: ./Diffusion/outputs/edm/d4rl_maze_lg/checkpoints/epoch=0020-train_mse_error=1.762.ckpt

  # teacher_path: ./Diffusion/outputs/edm/d4rl_maze_large/checkpoints/epoch=0140-train_mse_error=0.121.ckpt
  # teacher_path: ./Diffusion/outputs/edm/d4rl_maze_med/checkpoints/epoch=0060-train_mse_error=0.052.ckpt

  # KDE
  use_kde: false
  kde_samples: 0

  # warm start
  # edm: ./Diffusion/outputs/edm/d4rl_maze_sm/checkpoints/epoch=0340-train_mse_error=0.012.ckpt
  # edm: ./Diffusion/outputs/edm/d4rl_maze_med/checkpoints/epoch=0060-train_mse_error=0.595.ckpt
  edm: ./Diffusion/outputs/edm/d4rl_maze_lg/checkpoints/epoch=0020-train_mse_error=1.762.ckpt

  # edm: ./Diffusion/outputs/edm/d4rl_maze_large/checkpoints/epoch=0140-train_mse_error=0.121.ckpt
  # edm: ./Diffusion/outputs/edm/d4rl_maze_med/checkpoints/epoch=0060-train_mse_error=0.052.ckpt
  # edm: ./Diffusion/outputs/edm/d4rl_maze_small/checkpoints/epoch=0080-train_mse_error=0.055.ckpt

  # reward model checkpoint
  # reward_path: ./Diffusion/outputs/reward/d4rl_maze_med/checkpoints/epoch=0075-val_loss=0.000.ckpt
  # reward_path: ./Diffusion/outputs/reward/d4rl_maze_sm/checkpoints/epoch=0085-val_loss=0.000.ckpt
  reward_path: ./Diffusion/outputs/reward/d4rl_maze_lg/checkpoints/epoch=0060-val_loss=0.000.ckpt

  # Two parallel lists: loss names and their per-loss numbers
  # (presumably weights — confirm against the policy implementation).
  losses: [["ctm", "dsm"], [1, 1]]
  ctm_sampler: ctm
  dsm_weights: "karras"
  noise_scheduler:
    _target_: consistency_policy.diffusion.CTM_Scheduler
    time_min: 0.02
    time_max: 80.0
    rho: 7.0
    bins: 80
    solver: heun
    scaling: boundary
    use_c_in: true
    data_std: 0.5
    time_sampler: ctm
    clamp: true
    ode_steps_max: 1
  obs_as_global_cond: true
  # Same shapes as the top-level shape_meta and task.shape_meta.
  shape_meta:
    action:
      shape:
      - 2
    observation:
      shape:
      - 4
# Top-level action/observation shapes; the same values are repeated in
# policy.shape_meta and task.shape_meta.
shape_meta:
  action:
    shape:
    - 2
  observation:
    shape:
    - 4
# Task / environment configuration. The active env_name is the large maze;
# the umaze and medium variants are kept as commented-out alternatives.
task:
  abs_action: true
  env_runner:
    _target_: diffusion_policy.env_runner.maze2d_lowdim_state_runner.Maze2dLowdimStateRunner
    # env_name: 'maze2d-umaze-v1'
    # env_name: 'maze2d-medium-v1'
    env_name: 'maze2d-large-v1'
    crf: 22
    fps: 80
    max_steps: 1000
    # Same value as the top-level n_action_steps and policy.n_action_steps.
    # n_action_steps: 128
    # n_action_steps: 256
    n_action_steps: 384
    n_envs: 1
    n_obs_steps: 2
    n_test: 100
    n_test_vis: 20
    n_train: 2
    n_train_vis: 2
    past_action: false
    test_start_seed: 100000
    tqdm_interval_sec: 1.0
  name: d4rl_maze_lg
  # Same shapes as the top-level shape_meta and policy.shape_meta.
  shape_meta:
    action:
      shape:
      - 2
    observation:
      shape:
      - 4
  task_name: d4rl_maze_lg
# Top-level task identifier; same value as task.name and task.task_name.
task_name: d4rl_maze_lg
# Runtime settings for the rollout: device, checkpoint to load, seed and
# output location.
training:
  # The comment on policy.inference_mode states that it is populated from
  # this value.
  inference_mode: true
  online_rollouts: true
  device: cuda:0
  # Checkpoint to load. The commented-out lines are alternative checkpoints
  # (ctmp and guided_ctmp runs for the sm/med/lg mazes).
  # load_path: ./Diffusion/outputs/ctmp/d4rl_maze_sm/checkpoints/epoch=0000-test_mean_scores=1.264.ckpt
  # load_path: ./Diffusion/outputs/ctmp/d4rl_maze_med/checkpoints/epoch=0040-test_mean_scores=1.124.ckpt
  # load_path: ./Diffusion/outputs/ctmp/d4rl_maze_lg/checkpoints/epoch=0020-test_mean_scores=1.579.ckpt
  # load_path: ./Diffusion/outputs/ctmp/d4rl_maze_lg/checkpoints/epoch=0005-test_mean_scores=1.331.ckpt



  # load_path: ./Diffusion/outputs/guided_ctmp/d4rl_d4rl_maze_med/checkpoints/epoch=0000-test_mean_scores=1.315.ckpt
  # load_path: ./Diffusion/outputs/guided_ctmp/d4rl_d4rl_maze_sm/checkpoints/epoch=0030-test_mean_scores=1.256.ckpt
  load_path: ./Diffusion/outputs/guided_ctmp/d4rl_d4rl_maze_lg/checkpoints/epoch=0085-test_mean_scores=1.518.ckpt

  seed: 42
  use_ema: true
  val_chaining_steps: 1
  # NOTE(review): the active load_path lives under outputs/guided_ctmp/
  # while output_dir points under outputs/ctmp/ — confirm this is intended.
  # output_dir: ./Diffusion/outputs/ctmp/d4rl_maze_sm
  # output_dir: ./Diffusion/outputs/ctmp/d4rl_maze_med
  output_dir: ./Diffusion/outputs/ctmp/d4rl_maze_lg