# @package _global_

# To execute this experiment, run:
# python train.py experiment=<name of this file, without .yaml>

# Config-group selections for this experiment. Each `override` entry replaces
# the choice made for that group in the primary config; entry order matters.
defaults:
  - override /datamodule: syn2D.yaml
  - override /model: simple2Dnet.yaml
  - override /callbacks: default.yaml
  - override /logger: wandb.yaml
  - override /trainer: default.yaml

# All parameters below are merged with the default configurations selected above,
# so only the parameters you want to override need to be listed here.

# Run name; it determines the folder name used for this run under the logs directory.
name: 'syn-2D-EMDapprox'

# Random seed for this experiment.
seed: 123

# Overrides for the callbacks selected via callbacks/default.yaml
# (eval_callback is presumably defined there — confirm).
callbacks:
  eval_callback:
    # Batch size the evaluation callback uses for validation data.
    val_batch_size: 1024
    # NOTE(review): logging frequency during training; unit (steps vs. batches) is not visible here — confirm.
    train_log_freq: 250

# Overrides for the syn2D datamodule selected in the defaults list.
datamodule:
  batch_size: 1024
  # NOTE(review): presumably the number of sample pairs generated — confirm in the datamodule.
  npairs: 25000
  categories: ['square', 'circle']
  augment: true

# Trainer overrides (keys presumably forwarded to the PyTorch Lightning Trainer — confirm in trainer/default.yaml).
trainer:
  min_epochs: 1
  max_epochs: 15
  # gradient_clip_val: 5
  # Validate every 2nd epoch.
  # NOTE(review): with max_epochs: 15 (odd), the final epoch may not be followed by validation — confirm intended.
  check_val_every_n_epoch: 2
  # Presumably disables the pre-training sanity validation steps.
  num_sanity_val_steps: 0
  # Debug-only limits, kept commented out:
  # limit_train_batches: 10
  # limit_val_batches: 4

# Overrides for the simple2Dnet model selected in the defaults list.
model:
  # Written as 1.0e-4 rather than 1e-4: YAML 1.1 loaders (e.g. PyYAML) only
  # recognise floats with a '.' and a signed exponent, and would otherwise
  # load this value as the string "1e-4".
  lr: 1.0e-4
  weight_decay: 0.0
  latent_dim: 128
  mlp_dims: [256, 128, 64, 16]
  # NOTE(review): presumably matches the dimensionality of the 2D synthetic data — confirm.
  input_dim: 2

# Weights & Biases logger overrides.
logger:
  wandb:
    # Run name displayed in W&B, interpolated from the top-level run_name key.
    name: "${run_name}"
    tags:
      - 'synthetic'
      - '2D'
      - 'mlp'

# NOTE(review): presumably enables evaluation on the test set after training — confirm in train.py.
test: true
# Human-readable run name; consumed by logger.wandb.name through the ${run_name} interpolation.
run_name: mlp-big-mixed-aug-25K