# Global random seed for the run (presumably consumed by LightningCLI's
# `seed_everything` option — confirm against the entry-point script).
seed_everything: 42
# Model hyperparameters. Key names must match what the consuming code
# expects (presumably the model class's constructor arguments — confirm),
# so they are left exactly as spelled, including `optimazer`.
model:
  # `base_momentum` was previously declared twice in this mapping with the
  # same value; duplicate keys are invalid YAML and most loaders silently
  # keep only the last one, so a single declaration is kept.
  base_momentum: 0.99
  v_latent: 2.0
  augNearRate: 100000
  sigmaP: 2.0
  hidden_dim: 4096
  proj_dim: 256
  # NOTE(review): key is spelled `optimazer` (sic). Renaming it here would
  # break the consumer unless the code is changed too.
  optimazer: adamw
  # NOTE(review): same value as `trainer.max_epochs`; presumably the two
  # should be kept in sync — confirm.
  max_epochs: 2000
  dmtlosstype: latent
  linear_loss_weight: 0.03
  loss_recons_weight: 0.05
  vq_loss_weight: 0.2
  weight_decay: 0.0002
  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. PyYAML)
  # resolve a bare `1e-3` as a string rather than a float.
  diff_lr: 1.0e-3
  dataset: mnist
  param_num: 1066
  channel: 6
  latent_epoch: 2000
  # NOTE(review): `data.target_layer` is set to 'classifier' while this is
  # 'conv1' — confirm the difference is intentional.
  target_layer: 'conv1'

  # test: jlj
# Data-loading options (presumably passed to the project's data module —
# confirm against the consuming class).
data:
  # Source of the stored parameters.
  # NOTE(review): the path mentions `cifar_conv3` while `model.dataset` is
  # `mnist` — confirm this is the intended directory.
  data_dir: 'parameters/cifar_conv3/'
  size: conv3
  target_layer: classifier
  num_model: 1

  # Loader settings.
  batch_size: 128
  num_workers: 1
# trainer.logger: 
# Lightning Trainer configuration: run options, experiment logger, and
# checkpoint callback.
trainer:
  max_epochs: 2000
  accelerator: gpu
  check_val_every_n_epoch: 1
  enable_checkpointing: true
  # strategy: ddp_find_unused_parameters_true

  # Experiment tracking through Weights & Biases.
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      name: mnist_conv5
      project: partial_convfc
      # Experiment2_small_parameter_withLatentAE
      save_dir: wandb
  # logger.init_args.save_code: True

  # Checkpointing: keeps the 2 best checkpoints ranked by the lowest
  # value of the monitored metric.
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        # NOTE(review): ranking uses `loss_diff`, but the filename template
        # embeds `loss_all` — confirm both metrics are logged and that the
        # mismatch is intentional.
        monitor: loss_diff
        mode: min
        save_top_k: 2
        dirpath: ./checkpoints
        filename: 'best-{epoch:05d}-{loss_all:.7f}'
# trainer.strategy: ddp_find_unused_parameters_true
