---
# Experiment collection. The entries visible in this file are named mnist_0
# through mnist_9; mnist_0 carries the anchor that the other entries reference.
# NOTE(review): `__object__` is presumably the dotted path of the class the
# config loader instantiates from the sibling keys - confirm against the loader.
__object__: src.explib.base.ExperimentCollection
name: mnist_digit_scaled_small_models
experiments:
  # Base experiment (anchor: exp_nice_lu_laplace). Hyperparameter search for a
  # NiceFlow model on the MnistSplit dataset restricted to digit 0.
  # Values under `__eval__` are Python expressions written as strings;
  # presumably the config loader evaluates them - confirm against the loader.
  - &exp_nice_lu_laplace
    __object__: src.explib.hyperopt.HyperoptExperiment
    name: mnist_0
    # NOTE(review): grace_period equals max_t, so ASHA presumably never stops
    # a trial early - confirm this is intentional.
    scheduler:
      __object__: ray.tune.schedulers.ASHAScheduler
      max_t: 1000000
      grace_period: 1000000
      reduction_factor: 2
    num_hyperopt_samples: 8
    gpus_per_trial: 0
    cpus_per_trial: 1
    # Trials are compared on val_loss, lower is better.
    tuner_params:
      metric: val_loss
      mode: min
    trial_config:
      logging:
        images: true
        # 10 x 10 = 100, the same length as the base_distribution vectors
        # below; presumably the flattened image size - confirm.
        image_shape: [10, 10]
      dataset:
        __object__: src.explib.datasets.MnistSplit
        digit: 0
        scale: true
      epochs: 200000
      patience: 20
      # Single-element choice, i.e. the batch size is always 32.
      batch_size:
        __eval__: 'tune.choice([32])'
      optim_cfg:
        optimizer:
          __class__: torch.optim.Adam
        params:
          # Learning rate sampled log-uniformly between 1e-5 and 1e-3.
          lr:
            __eval__: 'tune.loguniform(1e-5, 1e-3)'
          weight_decay: 0.0
      model_cfg:
        type:
          __class__: src.veriflow.flows.NiceFlow
        params:
          masktype: alternate
          soft_training: false
          training_noise_prior:
            __object__: pyro.distributions.Laplace
            loc: 0.0
            scale: 0.001
          prior_scale: 1.0
          coupling_layers:
            __eval__: 'tune.choice([2, 3])'
          # Each candidate is a list of layer widths.
          coupling_nn_layers:
            __eval__: 'tune.choice([[50],[100], [50, 50], [100, 50], [50, 100], [100, 100]])'
          # Single-element choice: ReLU is the only candidate.
          nonlinearity:
            __eval__: 'tune.choice([torch.nn.ReLU()])'
          # Candidates are the integers 10 through 50 inclusive.
          split_dim:
            __eval__: 'tune.choice([10 + i for i in range(41)])'
          # Laplace base distribution with 100-element loc and scale tensors.
          base_distribution:
            __object__: pyro.distributions.Laplace
            loc:
              __eval__: 'torch.zeros(100).to("cpu")'
            scale:
              __eval__: 'torch.ones(100).to("cpu")'
          use_lu: true
  # mnist_1: references the exp_nice_lu_laplace anchor through __overwrites__
  # and supplies its own name and dataset (MnistSplit, digit 1).
  # NOTE(review): presumably the loader merges these keys over the referenced
  # experiment; the merge depth is not visible here - confirm.
  - __overwrites__: *exp_nice_lu_laplace
    name: mnist_1
    trial_config:
      dataset:
        __object__: src.explib.datasets.MnistSplit
        digit: 1
        scale: true
  # mnist_2: references the exp_nice_lu_laplace anchor through __overwrites__
  # and supplies its own name and dataset (MnistSplit, digit 2).
  # NOTE(review): presumably the loader merges these keys over the referenced
  # experiment; the merge depth is not visible here - confirm.
  - __overwrites__: *exp_nice_lu_laplace
    name: mnist_2
    trial_config:
      dataset:
        __object__: src.explib.datasets.MnistSplit
        digit: 2
        scale: true
  # mnist_3: references the exp_nice_lu_laplace anchor through __overwrites__
  # and supplies its own name and dataset (MnistSplit, digit 3).
  # NOTE(review): presumably the loader merges these keys over the referenced
  # experiment; the merge depth is not visible here - confirm.
  - __overwrites__: *exp_nice_lu_laplace
    name: mnist_3
    trial_config:
      dataset:
        __object__: src.explib.datasets.MnistSplit
        digit: 3
        scale: true
  # mnist_4: references the exp_nice_lu_laplace anchor through __overwrites__
  # and supplies its own name and dataset (MnistSplit, digit 4).
  # NOTE(review): presumably the loader merges these keys over the referenced
  # experiment; the merge depth is not visible here - confirm.
  - __overwrites__: *exp_nice_lu_laplace
    name: mnist_4
    trial_config:
      dataset:
        __object__: src.explib.datasets.MnistSplit
        digit: 4
        scale: true
  # mnist_5: references the exp_nice_lu_laplace anchor through __overwrites__
  # and supplies its own name and dataset (MnistSplit, digit 5).
  # NOTE(review): presumably the loader merges these keys over the referenced
  # experiment; the merge depth is not visible here - confirm.
  - __overwrites__: *exp_nice_lu_laplace
    name: mnist_5
    trial_config:
      dataset:
        __object__: src.explib.datasets.MnistSplit
        digit: 5
        scale: true
  # mnist_6: references the exp_nice_lu_laplace anchor through __overwrites__
  # and supplies its own name and dataset (MnistSplit, digit 6).
  # NOTE(review): presumably the loader merges these keys over the referenced
  # experiment; the merge depth is not visible here - confirm.
  - __overwrites__: *exp_nice_lu_laplace
    name: mnist_6
    trial_config:
      dataset:
        __object__: src.explib.datasets.MnistSplit
        digit: 6
        scale: true
  # mnist_7: references the exp_nice_lu_laplace anchor through __overwrites__
  # and supplies its own name and dataset (MnistSplit, digit 7).
  # NOTE(review): presumably the loader merges these keys over the referenced
  # experiment; the merge depth is not visible here - confirm.
  - __overwrites__: *exp_nice_lu_laplace
    name: mnist_7
    trial_config:
      dataset:
        __object__: src.explib.datasets.MnistSplit
        digit: 7
        scale: true
  # mnist_8: references the exp_nice_lu_laplace anchor through __overwrites__
  # and supplies its own name and dataset (MnistSplit, digit 8).
  # NOTE(review): presumably the loader merges these keys over the referenced
  # experiment; the merge depth is not visible here - confirm.
  - __overwrites__: *exp_nice_lu_laplace
    name: mnist_8
    trial_config:
      dataset:
        __object__: src.explib.datasets.MnistSplit
        digit: 8
        scale: true
  # mnist_9: references the exp_nice_lu_laplace anchor through __overwrites__
  # and supplies its own name and dataset (MnistSplit, digit 9).
  # NOTE(review): presumably the loader merges these keys over the referenced
  # experiment; the merge depth is not visible here - confirm.
  - __overwrites__: *exp_nice_lu_laplace
    name: mnist_9
    trial_config:
      dataset:
        __object__: src.explib.datasets.MnistSplit
        digit: 9
        scale: true