---
# Top-level settings for the MPS training run (entry script: main/train_mps.py).
# NOTE(review): meanings below are inferred from key names; confirm against
# the entry script before relying on them.
pattern: MPS
main: main/train_mps.py
batchsize: 64
batchsize_p: 64   # presumably batch size for pseudo samples (loss_pseudo)
batchsize_v: 128  # presumably validation batch size — TODO confirm
epoch: 200
ubatch_ratio: 1
snapshot_interval: 1
experiment_iterations: 3
log_interval: 1
train_val_split_ratio: 0.9  # presumably fraction of data kept for training
pretrained_path: pretrained/stylegan2ada/dtd
log_metrics: ["loss_pseudo", "loss_mps"]


models:
  pattern: RN18
  # Classifier network. num_classes should agree with generator.args.num_classes
  # and with the dataset (DTD has 47 texture categories).
  classifier:
    func: model/meta_resnet.py
    name: ResNet18Feat
    args:
      num_classes: 47
      pretrained: true
      finetune: true
  # StyleGAN2-ADA generator (see func path); dim_z is its latent size.
  generator:
    func: model/stylegan_xl/networks_stylegan2ada.py
    name: Generator
    args:
      dim_z: 512
      num_classes: 47
  # NOTE(review): z_dim presumably must equal generator.args.dim_z — confirm.
  finder:
    func: model/finder.py
    name: ResidualMLPFinder
    args:
      z_dim: 512

dataset:
  dataset_func: data/generic.py
  dataset_name: DTD
  args:
    root: /dataset/DTD
    test: false         # presumably false selects the training split
    size: 224           # presumably image size; should match updater resolution
    gan_mean_std: true  # presumably GAN-style normalization — TODO confirm

optimizer:
  algorithm: SGD
  # Presumably a step schedule: lr is multiplied by lr_drop_rate at each
  # milestone epoch (all milestones lie below epoch: 200).
  lr_milestone: [60, 120, 160]
  lr_drop_rate: 0.1
  args:
    lr: 0.01
    momentum: 0.9
    # If args are forwarded to torch.optim.SGD, nesterov needs momentum > 0.
    nesterov: true

optimizer_finder:
  # Separate optimizer, presumably for models.finder.
  algorithm: Adam
  args:
    # Same float as 1.0e-4. Avoid writing `1e-4`: YAML 1.1 loaders such as
    # PyYAML require a dot in floats and would load that as a string.
    lr: 0.0001

updater:
  func: 'updater/mps.py'
  name: 'MPSClassifierUpdater'
  # Updater hyperparameters; meanings inferred from names — confirm against
  # updater/mps.py.
  args:
    lambda_p: 1.0          # presumably weight of the pseudo loss
    warmup_epoch: 0
    meta_learning_freq: 1
    resolution: 224        # presumably should match dataset.args.size (224)
    latent_norm: 'kl'
    lambda_latent: 0.01    # presumably weight of the latent_norm penalty
    lambda_inner_lr: 100.0
