# @package __global__

# Hydra defaults list. Entries are composed in the order listed, so placing
# `_self_` last means the values defined in this file are merged after the
# entries above it.
defaults:
  - ../default
  # `override` replaces the `dset` group choice with fMRI/default.
  - override /dset: fMRI/default
  - _self_

# NOTE(review): presumably the name of the solver implementation to run;
# confirm against the code that consumes `solver`.
solver: compression_fMRI

# Loss balancing: one numeric value per loss term (adv, feat, l1, l2).
# NOTE(review): presumably relative weights, in which case l2 at 0.0 is
# effectively switched off - confirm in the solver.
losses:
  adv: 4.0
  feat: 4.0
  l1: 1.0
  l2: 0.0
# Balancer options that accompany the loss values above.
balancer:
  balance_grads: true
  ema_decay: 0.999
  per_batch_item: true
  total_norm: 1.0

# Adversarial setup.
adversarial:
  every: 1
  # Each adversary named here has a same-named hyperparameter section in
  # this file (see the top-level `msdfMRI` key).
  adversaries:
    - msdfMRI
  adv_loss: hinge
  feat_loss: l1

# loss hyperparameters (none set here: both mappings are empty)
l1: {}
l2: {}

# metrics (none set here: empty mapping)
metrics: {}

# Hyperparameters for the `msdfMRI` adversary listed under
# `adversarial.adversaries`.
msdfMRI:
  in_channels: 1024
  out_channels: 1
  scale_norms:
    - spectral_norm
    - weight_norm
    - weight_norm
  kernel_sizes: [5, 3]
  filters: 2048
  min_filters: 256
  downsample_scales: [2, 2, 2, 2]
  inner_kernel_sizes: null
  groups: [2, 2, 2, 2]
  strides: null
  paddings: null
  activation: LeakyReLU
  activation_params:
    negative_slope: 0.3

# data hyperparameters
# NOTE(review): the units of segment_duration, min/max_fMRI_duration,
# tr_precision and max_tr are not stated in this file (TRs vs. seconds) -
# confirm against the dataset code and document them here.
dataset:
  batch_size: 64
  num_workers: 10
  segment_duration: 64
  min_fMRI_duration: 16
  max_fMRI_duration: null
  tr_precision: 0.2
  max_read_retry: 1
  max_tr: 256.0
  return_info: true
  normalize: false
  # Per-split sections. Each sets its own num_samples; evaluate and generate
  # also set batch_size.
  # NOTE(review): presumably these take precedence over the top-level values
  # above for that split - confirm in the dataset builder.
  train:
    num_samples: 569650  # raw num: 11393 * 50
  valid:
    num_samples: 4740  # raw num: 474 * 10
  evaluate:
    batch_size: 32
    num_samples: 1130  # raw num: 113 * 10
  generate:
    batch_size: 16
    num_samples: 16

# solver hyperparameters
# Both stages use the same `every` value (25) and the same worker count (5).
# NOTE(review): `every` is presumably an interval in epochs - confirm in the
# solver.
evaluate:
  every: 25
  num_workers: 5
  # Only the l1 metric is switched on for evaluation in this config.
  metrics:
    l1: true
generate:
  every: 25
  num_workers: 5
  block_num: 2

# checkpointing schedule
# NOTE(review): save_every (25) equals the evaluate/generate `every` value in
# this file; keep them aligned if that is intentional.
checkpoint:
  save_last: true
  save_every: 25
  keep_last: 10
  keep_every_states: null

# optimization hyperparameters
optim:
  epochs: 200
  updates_per_epoch: 2000
  # Written with an explicit mantissa dot: YAML 1.1 loaders (e.g. plain
  # PyYAML) resolve a dot-less `3e-4` as a string rather than a float.
  lr: 3.0e-4
  max_norm: 0.0
  optimizer: adam
  adam:
    betas: [0.5, 0.9]
    weight_decay: 0.0
  ema:
    use: true          # whether to use EMA or not
    updates: 1         # update at every step
    device: ${device}  # device for EMA, can be put on GPU if more frequent updates
    decay: 0.99        # EMA decay value, if null, no EMA is used
