# @package _global_
# Hydra defaults list: each `override /<group>: <option>` entry selects the
# option(s) to use for that config group in this experiment.
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  - override /model: t1_2d
  - override /datamodule: scalarflow
  - override /optimizer: adamw
  # NOTE(review): the `step` scheduler group is selected here, while
  # `train.scheduler._target_` below points at CosineAnnealingWarmRestarts.
  # Confirm how the two compose and that no step-scheduler-only keys leak
  # into the instantiated scheduler.
  - override /scheduler: step
  - override /callbacks: [default, logimg_recons]
  - override /metrics: [relativel1, smape, psnr]
  - override /logger: wandb

# Experiment seed. Presumably used by the entry point to seed RNGs for
# reproducibility — confirm against the training script.
seed: 69420

# Overrides applied on top of the `t1_2d` model config chosen in `defaults`.
# NOTE(review): `in_channels` (2) equals `datamodule.context_steps` (2) with
# `stack_on_channels` enabled — presumably these must stay in sync; confirm.
model:
  modes: 224
  in_channels: 2
  out_channels: 1
  keep_high: false
  perform_inverse: false
  residual: true
  weight_init: 1
  act: 'gelu'
  use_operator_layer: true
  signal_resolution:
    - 1062
    - 600
  transform: 'dct'

# Trainer overrides: a single GPU device on one node, at most 20 epochs.
trainer:
  accelerator: gpu
  devices: 1
  num_nodes: 1
  max_epochs: 20

# Overrides applied on top of the `scalarflow` datamodule chosen in `defaults`.
datamodule:
  data_dir: /datasets/scalarflow/scalarflow_full_cam3
  batch_size: 1
  context_steps: 2
  target_steps: 3
  target_steps_val_test: 10
  # Stack all context steps and camera views on the channel dimension.
  stack_on_channels: true
  # Set to skip the online preprocessing steps.
  is_preprocessed: true
  # Keep false to decrease RAM usage drastically.
  save_cache: false
  # Cache size in GB.
  max_cache_size: 10
  num_workers: 10

# Training settings: integration order, optimizer hyperparameters, learning
# rate schedule, and loss function.
train:
  integration_order: 2  # 0: state, 1: velocity, 2: acceleration
  optimizer:
    # Written with an explicit decimal point: a bare `1e-2` is not matched by
    # the YAML 1.1 float resolver (e.g. plain PyYAML) and would load as the
    # string "1e-2" in any tool that does not use OmegaConf's patched loader.
    lr: 1.0e-2
    weight_decay: 0
  scheduler:  # new scheduler based on sweeps
    _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
    T_0: 32  # cosine annealing period in number of scheduler steps
  scheduler_interval: step  # epoch for longer
  loss_fn:
    _target_: hollow.losses.relative_l2.RelativeL2
# Callback overrides for the `default` and `logimg_recons` callback groups
# selected in `defaults`.
callbacks:
  # model_checkpoint: null # leave default checkpoints
  # Null replaces the `early_stopping` entry — presumably this disables early
  # stopping for the run; confirm against the `default` callbacks config.
  early_stopping: null
  logimg_recons:
    _target_: hollow.callbacks.wandb_callbacks.LogReconstructionsScalarFlow
    num_samples: 1 # how many reconstructions to log

# Settings for the `wandb` logger selected in `defaults`: the project and the
# name given to this run.
logger:
  wandb:
    project: sflowhighres
    name: T1-dct

# Task class for this experiment, referenced by dotted import path.
task:
  _target_: hollow.tasks.sequence_default.SequenceDefaultModel

eval:  # per-horizon evaluation metrics
  metrics:
    # Each key is named after a rollout horizon and its `index` is that
    # horizon minus one (one_step -> 0, 3_steps -> 2, 5_steps -> 4,
    # 10_steps -> 9).
    # NOTE(review): presumably `index` is a zero-based position in the
    # predicted steps, which would require at least 10 target steps
    # (`datamodule.target_steps_val_test` is 10) — confirm.

    # MAE (mean absolute error)
    one_step/mae:
      _target_: hollow.metrics.timeseries_metrics.MAEatIndex
      index: 0
    3_steps/mae:
      _target_: hollow.metrics.timeseries_metrics.MAEatIndex
      index: 2
    5_steps/mae:
      _target_: hollow.metrics.timeseries_metrics.MAEatIndex
      index: 4
    10_steps/mae:
      _target_: hollow.metrics.timeseries_metrics.MAEatIndex
      index: 9
    # SMAPE (symmetric mean absolute percentage error)
    one_step/smape:
      _target_: hollow.metrics.timeseries_metrics.SMAPEatIndex
      index: 0
    3_steps/smape:
      _target_: hollow.metrics.timeseries_metrics.SMAPEatIndex
      index: 2
    5_steps/smape:
      _target_: hollow.metrics.timeseries_metrics.SMAPEatIndex
      index: 4
    10_steps/smape:
      _target_: hollow.metrics.timeseries_metrics.SMAPEatIndex
      index: 9
    # PSNR (peak signal-to-noise ratio)
    one_step/psnr:
      _target_: hollow.metrics.timeseries_metrics.PSNRatIndex
      index: 0
    3_steps/psnr:
      _target_: hollow.metrics.timeseries_metrics.PSNRatIndex
      index: 2
    5_steps/psnr:
      _target_: hollow.metrics.timeseries_metrics.PSNRatIndex
      index: 4
    10_steps/psnr:
      _target_: hollow.metrics.timeseries_metrics.PSNRatIndex
      index: 9
