# @package _global_

# Hydra defaults list: composes the `default.yaml` option of the `callbacks`
# config group and the `wandb.yaml` option of the `logger` config group into
# this config. Entries are order-sensitive.
defaults:
  - /callbacks: default.yaml
  - /logger: wandb.yaml

### Settings shared by modelmodule, datamodule, runner, trainer, ...
# Left empty (null) in this config.
global_cfg: null

### main

# Runner entry point. With `_partial_: true`, Hydra's instantiate returns a
# functools.partial of the target instead of an instance, so the remaining
# constructor arguments are supplied by the calling code.
runner:
  _target_: src.runner.TrainControllerRunner
  _partial_: true

# Model module: network, metric, loss and optimization setup.
# Every `_partial_: true` node is turned into a functools.partial by Hydra's
# instantiate, i.e. the consuming code provides the missing arguments.
modelmodule:
  _target_: src.modelmodule.FillActModelModule
  _partial_: true
  net:
    _target_: src.modelmodule.FillActWrapper
    _partial_: true
    # NOTE(review): the key is spelled `tahn`, not `tanh`. It has to match the
    # FillActWrapper argument name, so confirm there before renaming it.
    tahn: true  # only for kuka where action must be in -1,1
    # Layer stack, in order:
    #   Linear(?, 1024) -> BatchNorm1d -> ReLU -> Dropout(0.2)
    #   -> Linear(1024, 512) -> BatchNorm1d -> ReLU -> Dropout(0.1)
    #   -> Linear(512, ?)
    net:
      # First layer: `in_features` is not given and the node is partial --
      # presumably FillActWrapper fills it in; TODO confirm.
      - _target_: torch.nn.Linear
        _partial_: true
        out_features: 1024
      - _target_: torch.nn.BatchNorm1d
        num_features: 1024
      - _target_: torch.nn.ReLU
      - _target_: torch.nn.Dropout
        p: 0.2
      - _target_: torch.nn.Linear
        in_features: 1024
        out_features: 512
      - _target_: torch.nn.BatchNorm1d
        num_features: 512
      - _target_: torch.nn.ReLU
      - _target_: torch.nn.Dropout
        p: 0.1
      # Last layer: `out_features` is not given and the node is partial --
      # presumably FillActWrapper fills it in; TODO confirm.
      - _target_: torch.nn.Linear
        _partial_: true
        in_features: 512
  metric_func:
    _target_: src.modelmodule.L1DistanceMetric
    _partial_: true
  loss_func:
    _target_: torch.nn.L1Loss
    _partial_: true
  # List of optimizer / LR-scheduler groups.
  optimizations:
    # NOTE(review): `all` presumably selects every model parameter; verify
    # against FillActModelModule.
    - param_target: all
      optimizer:
        _target_: torch.optim.Adam
        _partial_: true
        lr: 0.001
        weight_decay: 0.0
        betas: [0.9, 0.999]
      lr_scheduler_config:
        scheduler:
          _target_: torch.optim.lr_scheduler.CosineAnnealingLR
          _partial_: true
          # Tied to the trainer's step budget so both always agree.
          T_max: ${trainer.max_steps}
          # Written with a decimal point so that YAML 1.1 loaders (e.g.
          # PyYAML) also resolve it as a float rather than the string "1e-5".
          eta_min: 1.0e-5
        interval: step
        frequency: 1
  # Alias of the first optimization entry (index 0 of the list above).
  optimization_first: ${modelmodule.optimizations.0}  # for wandb log

# Data module and the dataset it wraps; both nodes are partial, so the
# consuming code provides the remaining constructor arguments.
datamodule:
  _target_: src.datamodule.EnvDatamodule
  _partial_: true
  batch_size: 32
  pin_memory: false
  # Read from the NUM_WORKERS environment variable. No fallback is passed to
  # `oc.env`, so the variable must be set; `oc.decode` turns the string into a
  # typed value.
  num_workers: ${oc.decode:${oc.env:NUM_WORKERS}}
  # Fractions for train / val / test (sum to 1.0).
  train_val_test_split: [0.95, 0.025, 0.025]
  dataset:
    _target_: src.datamodule.EnvTransitionDataset
    _partial_: true
    env: halfcheetah-mixed
    multi_step: 5  # ! only for transition dataset
    custom_ds_path: null
    # NOTE(review): `by_env` presumably means "pick according to `env`";
    # confirm in EnvTransitionDataset.
    preprocess_fns: by_env
    normalizer: by_env
    gpu: true
    # `seed` is not defined in this file; it is resolved from the composed
    # top-level config.
    seed: ${seed}
    lazyload: true
    forcesave: false

# PyTorch Lightning trainer arguments (partial: the consuming code supplies
# the rest). The run length is bounded by `max_steps`; `modelmodule`'s
# scheduler `T_max` interpolates this value.
trainer:
  _target_: pytorch_lightning.Trainer
  _partial_: true
  # `output_dir` is not defined in this file; it is resolved from the composed
  # top-level config.
  default_root_dir: ${output_dir}
  accelerator: gpu
  deterministic: false
  log_every_n_steps: 100
  num_sanity_val_steps: 2
  max_steps: 40000
  # NOTE(review): with `check_val_every_n_epoch: null` the interval below is
  # presumably counted in global training steps rather than per epoch --
  # confirm against the installed Lightning version.
  val_check_interval: 1000
  check_val_every_n_epoch: null

### common - for all tasks (task_name, tags, output_dir, device)
# Run identification: algorithm label, task label and free-form tags.
algorithm_name: DefaultAlgName
task_name: RL_Diffuser
tags:
  - debug