# @package _global_

# Hydra defaults list for this experiment. Entries are composed in the order
# listed, so keep the ordering as-is.
defaults:
  # Absolute-path entry: selects the so3 option of the /model/action_list group.
  - /model/action_list: so3.yaml
  # Replace the data, model and network options chosen by the primary config.
  - override /data: modelnet.yaml
  - override /model: equiv_ae.yaml
  - override /model/network: vit_64x64.yaml

# Experiment tag; interpolated (as ${tag}) into the first path component of the
# TensorBoard run name in trainer.logger.name.
tag: vit64x64

# NOTE(review): read by code outside this file; presumably turns off the
# out-of-distribution evaluation path — confirm against the consumer.
no_ood: true

# Overrides on top of the selected /data option. batch_size is also embedded
# in the run name (bs...). The remaining keys are presumably forwarded to the
# PyTorch DataLoader — verify in the datamodule.
data:
  batch_size: 48
  num_workers: 8
  pin_memory: true
  persistent_workers: true

model:
  # Embedded in the run name as rratio...; its meaning is defined by the model
  # code, not here.
  rank_ratio: 4

# Instantiated from _target_; parameters are presumably passed in by the model
# code when the optimizer is built — confirm.
optimizer:
  _target_: torch.optim.AdamW
  # Written as 1.0e-4 rather than 1e-4. Plain YAML 1.1 loaders (e.g. PyYAML)
  # only resolve exponent floats that contain a dot and would read 1e-4 as a
  # string. This form is a float under YAML 1.1, YAML 1.2 and OmegaConf alike.
  lr: 1.0e-4
  weight_decay: 0.05

# Loss-term switches: only `pred` is enabled in this experiment.
loss:
  pred: true
  pred_l1: false
  reconst: false
  alignl2: false
  nalignl2: false
  angle_variance: false
  # NOTE(review): presumably the weight of the equivariance term — confirm in
  # the loss implementation.
  equiv_coef: 1

trainer:
  _target_: lightning.Trainer
  logger:
    # Take the logger from the same `lightning` distribution as the Trainer.
    # Lightning does not support mixing `pytorch_lightning` objects with
    # `lightning.pytorch` ones.
    _target_: lightning.pytorch.loggers.TensorBoardLogger
    save_dir: log/modelnet
    # Run name: <tag>/<data+model hparams>/<optimizer hparams>.
    # Each trailing `\` escapes the line break, and the continuation line's
    # leading indentation is discarded. The lines therefore join into a single
    # string with no added spaces.
    name: "${tag}\
      /nview${data.dataset.num_views}_patch${model.encoder.patch_size}\
      _nbasis${model.action_list.action_list.num_basis}\
      _rratio${model.rank_ratio}_ldim${model.latent_dim}\
      /${optimizer._target_}_bs${data.batch_size}_lr${optimizer.lr}\
      _wd${optimizer.weight_decay}"
    default_hp_metric: false
  max_epochs: 400
  accelerator: gpu
  strategy: ddp
  # NOTE(review): Lightning >= 2.0 prefers the spelling "16-mixed" for this
  # setting. Plain 16 is kept so older Lightning versions still accept it.
  precision: 16