# @package _global_

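# To run this experiment, select this file through the `experiment` config
# group (replace example_simple.yaml with this file's path within that group):
# python run.py experiment=example_simple.yaml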

# Hydra defaults list: compose the base experiment config found at
# /experiment/cifar10/vit/vit-s.yaml, then override the /model/vitmlp_cfg
# config group with its `fouriermask` option.
defaults:
  - /experiment/cifar10/vit/vit-s.yaml
  - override /model/vitmlp_cfg: fouriermask

# Model settings for this experiment. The keys below are consumed by the model
# config composed via the defaults list; their exact semantics are defined by
# the model code, which is not visible from this file.
model:
  drop_path_rate: 0.1
  drop_rate: 0.0
  mlp_cfg:
    # Configuration for `linear1` of the MLP block
    # (presumably the first linear layer — TODO confirm against the model).
    linear1_cfg:
      rank_per_component: 1
      total_rank_ratio: 1.0
      # Neither widths nor locations are fixed, so the width/location
      # optimisation settings below are presumably in effect — verify.
      fixed_width: false
      fixed_location: false
      # Two-element lower/upper bounds for the widths.
      min_widths: [0.0, 0.0]
      max_widths: [1.0, 1.0]
      width_init: 'splr'
      location_init: 0.0
      width_learning_rate: 0.005
      width_weight_decay: 0.0
      # Relative interpolation: resolves to the sibling `width_learning_rate`
      # (0.005 here), so both learning rates stay in sync.
      location_learning_rate: ${.width_learning_rate}
      adam_betas: [0.0, 0.999]
      compute_mode: 'dense'
  # Absolute interpolation: reuses the whole `linear1_cfg` node above, so the
  # attention linear config is identical to the MLP `linear1` config.
  attnlinear_cfg: ${model.mlp_cfg.linear1_cfg}

# Training callbacks. Each entry names its class via `_target_`; the remaining
# keys are that entry's settings (presumably passed as constructor arguments
# when the config is instantiated — confirm against the training entry point).
callbacks:
  # Sigma schedule endpoints: 1.0 -> 100.0 over epochs 5..300. The exact
  # schedule shape is defined by SigmaAnnealing, not by this file.
  sigma_annealing:
    _target_: src.callbacks.fm_sigma_annealing.SigmaAnnealing
    sigma_init: 1.0
    sigma_final: 100.0
    start_epoch: 5
    end_epoch: 300
  width_monitor:
    _target_: src.callbacks.width_monitor.WidthMonitor
  rich_progress_bar:
    _target_: pytorch_lightning.callbacks.RichProgressBar
  flop_count:
    _target_: src.callbacks.flop_count.FlopCount
    profilers: ['fvcore']
    # Presumably (channels, height, width) of one input sample — TODO confirm.
    input_size: [3, 32, 32]
    sinc_gaussian: true
    # NOTE(review): reference complexity used by FlopCount; units and origin
    # of 3.10e9 are not visible here — confirm.
    baseline_complexity: 3.10e9
  shrink:
    _target_: src.callbacks.shrink.AdaptiveSoftshrink
    thres: 0.02
    target_width: 0.0
    # Relative interpolation: the initial threshold equals `thres` (0.02).
    init_thres: ${.thres}
    # `train.scheduler.warmup_t` is not defined in this file; it must be
    # provided by a config composed through the defaults list.
    start_epoch: ${train.scheduler.warmup_t}
    # Relative interpolation: `end_epoch` resolves to the same value as
    # `start_epoch`.
    end_epoch: ${.start_epoch}


# Batch size used by the datamodule for this experiment.
datamodule:
  batch_size: 1024

# Optimizer learning rate for this experiment; the remaining optimizer and
# scheduler settings are not defined in this file.
train:
  optimizer:
    lr: 5e-4

