# @package _global_
defaults:
  - /callbacks: default.yaml
  - /logger: wandb.yaml
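# The entries above are Hydra config groups composed at launch; any of them
# can be swapped from the command line. A sketch, assuming a `python main.py`
# entry point and that a csv.yaml config exists under /logger:
#   python main.py logger=csv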

### shared by modelmodule, datamodule, runner, trainer ...
global_cfg:
  horizon: 64 # ! DEBUG: reduced to 64 for efficiency
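  # Fields here are consumed elsewhere via OmegaConf interpolation, e.g.
  # `horizon: ${global_cfg.horizon}` in the diffusion, UNet, and dataset
  # configs below, so one override propagates everywhere (same assumed
  # entry point as above):
  #   python main.py global_cfg.horizon=256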

### main

modelmodule:
  _target_: src.modelmodule.DiffuserModule
  _partial_: true
  data_noise: 0.00
  net:
    _target_: src.modelmodule.DiffusionWrapper
    _partial_: true
    diffusion:
      _target_: diffuser.models.GaussianDiffusion
      _partial_: true
      horizon: ${global_cfg.horizon}
      n_timesteps: 256 # called n_diffusion_steps in the original Diffuser source
      loss_type: "l2"
      clip_denoised: true # ! TODO: confirm; clamps denoised samples to [-1, 1] during sampling
      predict_epsilon: false
      action_weight: 1 # ! only weights the first-step action loss; prefer ignore_action instead
      ignore_action: true # if true, apply the loss to observations only
      loss_weights: null # optional manual per-dimension weights
      loss_discount: 1 # values < 1 down-weight losses at later horizon steps
      loss_beta_weight: true
      # observation_dim: {dynamic}
      # action_dim: {dynamic}
    net: 
      _target_: diffuser.models.temporal.TemporalUnet
      _partial_: true
      horizon: ${global_cfg.horizon}
      # transition_dim: {dynamic}
      # cond_dim: {dynamic}
      dim: 32
      dim_mults: [1, 4, 8]
      attention: false
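  # With `_partial_: true`, hydra.utils.instantiate returns functools.partial
  # objects, letting the runner fill in the `{dynamic}` dims later. A minimal
  # sketch (how transition_dim/cond_dim are derived is an assumption here,
  # not taken from this repo):
  #   unet_fn = hydra.utils.instantiate(cfg.modelmodule.net.net)
  #   unet = unet_fn(transition_dim=obs_dim + act_dim, cond_dim=obs_dim)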
  metric_func:
    _target_: src.modelmodule.L1DistanceMetric
    _partial_: true
  loss_func:
    _target_: torch.nn.MSELoss
    _partial_: true
  optimizations:
    - param_target: all
      optimizer: 
        _target_: torch.optim.Adam
        _partial_: true
        lr: 2e-4
        weight_decay: 0.0
        betas: [0.9, 0.999] 
      lr_scheduler_config:
        scheduler:
          _target_: torch.optim.lr_scheduler.CosineAnnealingLR
          _partial_: true
          T_max: ${trainer.max_steps}
          eta_min: 1e-5
        interval: step
        frequency: 1
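    # `optimizations` is a list, so further parameter groups can be appended.
    # A hypothetical second entry (the `net` param_target value is an
    # assumption, not a name known to exist in this repo):
    # - param_target: net
    #   optimizer:
    #     _target_: torch.optim.AdamW
    #     _partial_: true
    #     lr: 1e-4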
  optimization_first: ${modelmodule.optimizations.0} # alias for the first optimization entry, used for wandb logging
  controller:
    turn_on: false
    
datamodule: 
  _target_: src.datamodule.EnvDatamodule
  _partial_: true
  batch_size: 32
  pin_memory: false
  num_workers: ${oc.decode:${oc.env:NUM_WORKERS}} # read NUM_WORKERS from the environment; oc.decode casts the string to an int
  train_val_test_split: [0.95, 1e-9, 1e-9] # effectively train-only; val/test get near-zero placeholder fractions
  dataset:
    _target_: src.datamodule.EnvEpisodeDataset
    _partial_: true
    env: maze2d-large-v1
    horizon: ${global_cfg.horizon}
    custom_ds_path: null
    preprocess_fns: by_env
    normalizer: by_env
    gpu: true
    seed: ${seed}
    clip_denoised: false # ! removed from the released code but reported in the paper (maze: true, mujoco: true)
    use_padding: true # ! removed from the released code but reported in the paper (maze: false, mujoco: true)
    mode: valid_multi_step%1 # ! options: default, multi_step%{step_num}, valid_multi_step%{step_num}, interpolation%1
    lazyload: true # build dataset indices lazily on first access
    forcesave: false # ! if false, never overwrites an existing file (even on a remake), avoiding write conflicts
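  # The `%` in `mode` separates the mode name from its step_num argument, per
  # the option list above; e.g. a 4-step validation variant (assuming the
  # same `python main.py` entry point as above):
  #   python main.py datamodule.dataset.mode=valid_multi_step%4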

runner:
  _target_: src.runner.TrainDiffuserRunner
  _partial_: true

trainer:
  _target_: pytorch_lightning.Trainer
  _partial_: true
  default_root_dir: ${output_dir}
  # min_epochs: null
  # max_epochs: 100
  # check_val_every_n_epoch: 1
  accelerator: "gpu"
  # devices: 1
  # move_metrics_to_cpu: true
  deterministic: false
  max_steps: 40000
  val_check_interval: 1000
  check_val_every_n_epoch: null
  log_every_n_steps: 100
  num_sanity_val_steps: 2
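  # Validation cadence: with check_val_every_n_epoch null, Lightning counts
  # val_check_interval in training batches rather than epochs, so validation
  # runs about every 1000 steps (~40 times over max_steps=40000).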


### common - shared across all tasks (task_name, tags, output_dir, device)
algorithm_name: "DefaultAlgName"
task_name: "RL_Diffuser"
tags: ["debug"]