# @package _global_

# Hydra defaults list: config-group selections composed into this config.
# Entries prefixed with `override` replace a selection made elsewhere
# (presumably by the primary config, which is not visible from this file).
defaults:
  # Plain entry: adds a selection for the absolute group /model/action_list.
  - /model/action_list: so2.yaml
  # Override entries: swap the options chosen for the data, model and
  # model/network groups. Keep these after the plain entry above.
  - override /data: complex_brdfs.yaml
  - override /model: equiv_ae.yaml
  - override /model/network: vit_b.yaml

# Experiment tag; forms the first path component of the logger run name
# through the ${tag} interpolation.
tag: "vit_b"

# Data-loading settings layered on top of the selected /data config.
data:
  # Also embedded in the logger run name via ${data.batch_size}.
  batch_size: 48
  # The three loader options below are presumably forwarded to a torch
  # DataLoader — confirm against the data module.
  num_workers: 8
  pin_memory: true
  persistent_workers: true

# Optimizer described by its `_target_` class path. The target, lr and
# weight_decay values are all interpolated into the logger run name.
optimizer:
  _target_: torch.optim.AdamW
  # Written with a decimal point so that YAML 1.1 loaders (e.g. plain PyYAML),
  # which resolve a bare `3e-4` as a string, also read this as a float.
  # The numeric value is unchanged.
  lr: 3.0e-4
  weight_decay: 0.05

# Loss configuration. The boolean entries look like on/off switches for
# individual loss terms — confirm their exact meaning against the loss code.
loss:
  pred: true
  pred_l1: false
  reconst: false
  alignl2: true
  nalignl2: false
  angle_variance: false
  # Presumably a weighting coefficient for an equivariance term — TODO confirm.
  equiv_coef: 1
# Trainer described by its `_target_` class path, with a nested TensorBoard
# logger.
trainer:
  _target_: lightning.Trainer
  logger:
    # Taken from the same `lightning` distribution as the Trainer above rather
    # than the standalone `pytorch_lightning` package: the two ship separate
    # class hierarchies, so mixing them defeats isinstance checks on the
    # logger and requires both packages to be installed.
    _target_: lightning.pytorch.loggers.TensorBoardLogger
    save_dir: log/complex_brdfs
    # Run name layout:
    #   <tag>/nview<N>_nbasis<M>_<adapter>/<optimizer>_bs<B>_lr<LR>_wd<WD>
    # Each trailing backslash escapes the line break inside the double-quoted
    # scalar, so the three lines are joined without any inserted whitespace.
    name: "${tag}\
    /nview${data.dataset.num_views}_nbasis${model.action_list.action_list.num_basis}_${model.adapter}\
    /${optimizer._target_}_bs${data.batch_size}_lr${optimizer.lr}_wd${optimizer.weight_decay}"
    default_hp_metric: false
  max_epochs: 100
  accelerator: gpu
  strategy: ddp
  # NOTE(review): newer Lightning releases spell mixed precision as
  # "16-mixed"; the integer form is kept unchanged here — confirm against the
  # installed version before switching.
  precision: 16