# @package _global_

# Run-level settings shared across sections; other sections read them through
# ${global_cfg.*} interpolation (horizon is consumed that way in this file).
# Original note: would be passed to actor, critic1, critic2, policy, env.
global_cfg:
  n_train_steps: 200000
  # n_epochs = n_train_steps / n_steps_per_epoch (20 with these values).
  n_steps_per_epoch: 10000
  # The commented-out reference defaults kept in this file list 256 here.
  horizon: 384

# Runner for value training. `_target_` / `_partial_` follow the Hydra
# instantiate convention; with `_partial_: true` the target is presumably
# handed back as a partial for the caller to complete rather than constructed
# directly — confirm against the code that consumes this section.
runner:
  _target_: src.runner.TrainValuesRunner
  _partial_: true

### source: original Python defaults for the 'values' entry (kept as reference)
    # 'values': {
    #     'model': 'models.ValueFunction',
    #     'diffusion': 'models.ValueDiffusion',
    #     'horizon': 256,
    #     'n_diffusion_steps': 256,
    #     'dim_mults': (1, 4, 8),
    #     'renderer': 'utils.Maze2dRenderer',

    #     ## value-specific kwargs
    #     'discount': 0.99,
    #     'termination_penalty': -100,
    #     'normed': False,

    #     ## dataset
    #     'loader': 'datasets.ValueDataset',
    #     'termination_penalty': None,
    #     'normalizer': 'LimitsNormalizer',
    #     'preprocess_fns': ['maze2d_set_terminals'],
    #     'clip_denoised': True,
    #     'use_padding': False,
    #     'max_path_length': 40000,

    #     ## serialization
    #     'logbase': 'logs',
    #     'prefix': 'values/defaults',
    #     'exp_name': watch(args_to_watch), # TODO should be values_to_watch

    #     ## training
    #     'n_steps_per_epoch': 10000,
    #     'loss_type': 'value_l2',
    #     'n_train_steps': 200e3,
    #     'batch_size': 32,
    #     'learning_rate': 2e-4,
    #     'gradient_accumulate_every': 2,
    #     'ema_decay': 0.995,
    #     'save_freq': 1000,
    #     'sample_freq': 0,
    #     'n_saves': 50,
    #     'save_parallel': False,
    #     'n_reference': 50,
    #     'bucket': None,
    #     'device': 'cuda',
    #     'seed': None,
    #     'n_render_samples': 10,
    # },

### config for trainer_diffuser (as reference)
# dataset:
#   _target_: diffuser.datasets.GoalDataset
#   _partial_: true
#   env: "maze2d-large-v1"
#   horizon: ${global_cfg.horizon}
#   normalizer: "LimitsNormalizer"
#   preprocess_fns: ["maze2d_set_terminals"]
#   use_padding: false
#   max_path_length: 40000

### in py: how the original Python code builds the dataset config (reference)
# dataset_config = utils.Config(
#     args.loader,
#     savepath=(args.savepath, 'dataset_config.pkl'),
#     env=args.dataset,
#     horizon=args.horizon,
#     normalizer=args.normalizer,
#     preprocess_fns=args.preprocess_fns,
#     use_padding=args.use_padding,
#     max_path_length=args.max_path_length,
#     ## value-specific kwargs
#     discount=args.discount,
#     termination_penalty=args.termination_penalty,
#     normed=args.normed,
# )

# Dataset used for value training. env, horizon, normalizer, preprocess_fns,
# use_padding and max_path_length carry the same values as the commented-out
# GoalDataset reference block kept in this file.
dataset:
  _target_: diffuser.datasets.ValueDataset
  _partial_: true
  custom_ds_path: null
  env: "maze2d-large-v1"
  horizon: ${global_cfg.horizon}
  normalizer: "LimitsNormalizer"
  preprocess_fns:
    - "maze2d_set_terminals"
  use_padding: false
  max_path_length: 40000
  # value-specific kwargs
  discount: 0.99
  termination_penalty: -100
  normed: false
  



# Maze2d renderer. env is interpolated from dataset.env, so the renderer
# always receives the same environment name as the dataset section.
render:
  _target_: diffuser.utils.Maze2dRenderer
  _partial_: true
  env: ${dataset.env}


# Value-function network; horizon is taken from global_cfg so it stays in
# step with the dataset and model sections.
net:
  _target_: diffuser.models.ValueFunction
  _partial_: true
  horizon: ${global_cfg.horizon}
  dim_mults:
    - 1
    - 4
    - 8

# Value diffusion model settings; horizon is taken from global_cfg.
model:
  _target_: diffuser.models.ValueDiffusion
  _partial_: true
  horizon: ${global_cfg.horizon}
  # Named n_diffusion_steps in the source code.
  n_timesteps: 256
  loss_type: "value_l2"
  # TODO: add config for large0maze 2 in py file
  # NOTE(review): "large0maze 2" looks like a typo, possibly for a large
  # maze2d variant — confirm with the author.



# Trainer settings for the train_values task.
trainer:
  _target_: diffuser.utils.Trainer
  _partial_: true
  task: train_values
  train_batch_size: 32
  # Written with an explicit decimal point on purpose: YAML 1.1 resolvers
  # (e.g. stock PyYAML) read a bare `2e-4` as the string "2e-4", not a float.
  train_lr: 2.0e-4
  gradient_accumulate_every: 2
  ema_decay: 0.995
  # Render samples every `sample_freq` steps (0 presumably disables
  # rendering — confirm).
  sample_freq: 0
  # Save the model every `save_freq` steps.
  save_freq: 1000
  # Not important; only used for naming (author's example: 12234 -> 12000).
  label_freq: 10000
  save_parallel: false
  # TODO: output_dir is not defined in this file; it has to be supplied by
  # whatever config composes this one.
  results_folder: ${output_dir}
  # TODO: purpose unclear; the commented-out reference defaults also use None.
  bucket: null
  # TODO: purpose unclear; the commented-out reference defaults list 50.
  n_reference: 25
  # The commented-out reference defaults list 10.
  n_render_samples: 5


# Common fields shared by all tasks (task_name, tags, output_dir, device).
algorithm_name: "DefaultAlgName"
task_name: "RL_Diffuser_TrainValues"
tags:
  - "debug"