# @package _global_
# Hydra defaults list: composes the pipeline and model config groups, and
# overrides the scheduler group choice made by the composed configs.
defaults:
  - /pipeline: hmdb51_convnext
  - /model: timm/convnext_tiny_3d
  - override /scheduler: timm_cosine  # timm_cosine or plateau

# Dataset settings: where the splits/videos live, clip geometry, augmentation.
dataset:
  split_dir: testTrainMulti_7030_splits
  video_dir: videos
  clip_duration: 2
  num_frames: 24
  frame_size: 224
  # Disabled alternative: a single list built from num_frames and frame_size.
  # __l_max:
  #   - ${.num_frames}
  #   - ${.frame_size}
  #   - ${.frame_size}
  use_ddp: false  # handled automatically in PTL
  augment: default
  randaug: {num_layers: 2}
  augmix: {width: 3}
  # Mirrors trainer.devices; will control if using distributed sampler.
  num_gpus: '${trainer.devices}'

# Task settings: the loss is selected by its registry name.
task:
  loss: {_name_: cross_entropy}
  # Disabled: an explicit loss_val entry with the same registry name.
  # loss_val:
  #   _name_: cross_entropy

# Data loader settings.
loader:
  batch_size: 2
  num_workers: 12
  # Derived from num_workers via the expression `num_workers != 0`;
  # set False when using num_workers = 0.
  persistent_workers: '${eval:"${loader.num_workers} != 0"}'

# Trainer settings.
trainer:
  max_epochs: 100
  precision: 16
  devices: 1
  # train.global_batch_size // trainer.devices // loader.batch_size
  accumulate_grad_batches: '${eval:${train.global_batch_size} // ${.devices} // ${loader.batch_size}}'

# Training-run options: initial checkpoint, seeding, EMA, weight-decay
# grouping, and batch-size bookkeeping.
train:
  # NOTE(review): machine-specific absolute path; override it when running
  # on another machine.
  pretrained_model_path: /home/eric/hippo/outputs/2022-04-17/10-27-08/checkpoints/val/accuracy.ckpt
  pretrained_model_strict_load: false
  pretrained_model_state_hook:
    _name_: convnext_timm_tiny_s4nd_2d_to_3d
  seed: 1112
  ema: 0.0  # if using, 0.99996
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false
  remove_test_loader_in_eval: true  # null means we do use test loader
  # effective batch size (handled with multiple gpus, and accumulate_grad_batches)
  global_batch_size: '${loader.batch_size}'
  replace_sampler_ddp: false  # alternative: ${eval:"${trainer.devices} > 1"}

# Optimizer hyperparameters.
optimizer:
  # Written with an explicit decimal point: YAML 1.1 float resolution (e.g.
  # stock PyYAML) requires a "." and would otherwise read the value as the
  # string "2e-4" instead of the float 0.0002.
  lr: 2.0e-4
  weight_decay: 0  # maybe 1e-8

# Parameters for the cosine-decay scheduler selected in the defaults list.
scheduler:
  t_initial: 100  # same value as trainer.max_epochs in this file
  warmup_t: 0
  lr_min: 0

# NOTE(review): "id" presumably selects identity (pass-through) encoder and
# decoder modules, leaving input/output handling to the model — confirm
# against the encoder/decoder registries.
encoder: id
decoder: id

# Model overrides applied on top of the /model group from the defaults list.
model:
  num_classes: 51
  # video_size: ${dataset.__l_max}
  # [num_frames, frame_size, frame_size], taken from the dataset settings.
  video_size:
    - ${dataset.num_frames}
    - ${dataset.frame_size}
    - ${dataset.frame_size}
  drop_path_rate: 0.0
  drop_head: 0.0
  drop_mlp: 0.0
  tempor_patch_size: 2
  # 1st stride handled by stem (most likely). The first entry is a real YAML
  # null: the Python spelling `None` would be parsed as the string "None".
  temporal_stage_strides: [null, 1, 1, 1]
  stem_channels: 32  # only used for s4nd stem currently
  stem_type: new_s4nd_patch  # options: patch (regular convnext), s4nd_patch, new_s4nd_patch (best), s4nd
  stem_l_max: [4, 16, 16]  # stem_l_max=None,  # len of l_max in stem (if using s4)
  downsample_type: s4nd  # eg, s4nd, null (for regular strided conv)
  downsample_act: false
  downsample_glu: true
  conv_mlp: false
  # only used if conv_mlp is true; should benchmark to make sure this is
  # faster/more mem efficient, also need to turn off weight decay
  custom_ln: false
  # Base layer configuration; null means use regular conv2d in convnext.
  layer:
    _name_: s4nd
    d_state: 64
    channels: 1
    bidirectional: true
    activation: null  # mimics convnext style
    final_act: none
    initializer: null
    weight_norm: false
    hyper_act: null
    dropout: 0
    init: fourier
    rank: 1
    trank: 1
    dt_min: 0.01
    dt_max: 1.0
    lr: ${optimizer.lr}  # follows the optimizer learning rate
    n_ssm: 1
    deterministic: false  # Special C init
    l_max: null
    verbose: true
    linear: true
    return_state: false
    contract_version: 1
    bandlimit: null
  # NOTE(review): presumably overrides of `layer` used for the stem — confirm
  # against the model constructor.
  stem_layer:
    dt_min: 0.1
    dt_max: 1.0
    init: fourier
  # One entry per stage (four in total).
  # NOTE(review): presumably per-stage overrides of `layer` — confirm.
  stage_layers:
    - dt_min: 0.1
      dt_max: 1.0
    - dt_min: 0.1
      dt_max: 1.0
    - dt_min: 0.1
      dt_max: 1.0
    - dt_min: 0.1
      dt_max: 1.0
