---
# Experiment collection `mnist_digit_basedist_comparison`: one nested collection
# per digit entry (anchors digit_0 .. digit_9). digit_0 is written out in full;
# the remaining digit entries reference it via `__overwrites__: *digit_0`.
# NOTE(review): keys of the form `__name__` (`__object__`, `__overwrites__`,
# `__eval__`, `__class__`, `__exact__`) are directives for the project's config
# loader, not YAML features; their semantics are defined outside this file.
__object__: src.explib.base.ExperimentCollection
name: mnist_digit_basedist_comparison
experiments:
  # Fully specified entry (its base experiment uses dataset `digit: 0`).
  - &digit_0
    __object__: src.explib.base.ExperimentCollection
    name: mnist_basedist_comparison
    # Four experiments: {use_lu true / false with random mask} x {Laplace / Normal}.
    experiments:
        # Base experiment of this list: a HyperoptExperiment over a NiceFlow model
        # with `use_lu: true` and a Laplace base distribution. The sibling entries
        # reference it via `__overwrites__: *exp_nice_lu_laplace`.
        - &exp_nice_lu_laplace
          __object__: src.explib.hyperopt.HyperoptExperiment
          name: mnist_nice_lu_laplace
          # NOTE(review): grace_period equals max_t. Per Ray Tune's ASHAScheduler
          # docs, trials younger than grace_period are not stopped, so this
          # presumably disables early termination by the scheduler; confirm intent.
          scheduler: &scheduler 
            __object__: ray.tune.schedulers.ASHAScheduler
            max_t: 1000000
            grace_period: 1000000
            reduction_factor: 2
          num_hyperopt_samples: &num_hyperopt_samples 20
          gpus_per_trial: &gpus_per_trial 0
          cpus_per_trial: &cpus_per_trial 1
          # Metric `val_loss`, to be minimised (`mode: min`).
          tuner_params: &tuner_params
            metric: val_loss
            mode: min
          trial_config:
            dataset: &dataset
              __object__: src.explib.datasets.MnistSplit
              digit: 0
              scale: true
            epochs: &epochs 200000
            # NOTE(review): presumably an early-stopping patience; confirm in the trainer.
            patience: &patience 50
            # Single-element choice, so the sampled batch size is always 32.
            batch_size: &batch_size 
              __eval__: tune.choice([32])
            optim_cfg: &optim 
              optimizer:
                __class__: torch.optim.Adam 
              params:
                # Learning rate drawn log-uniformly between 1e-4 and 1e-2.
                lr: 
                  __eval__: tune.loguniform(1e-4, 1e-2)
                weight_decay: 0.0
            
            model_cfg: 
              type:
                __class__: &model src.veriflow.flows.NiceFlow
              params:
                soft_training: true
                training_noise_prior:
                  __object__: pyro.distributions.Uniform
                  low: 0.0
                  high: 0.001
                prior_scale: 1.0
                use_lu: true
                # One of 2, 3, 4 or 5.
                coupling_layers: &coupling_layers 
                  __eval__: tune.choice([2, 3, 4, 5])
                # Ten candidates: lists of length 1 or 2 holding a repeated width w,
                # e.g. [10], [200], [50, 50].
                coupling_nn_layers: &coupling_nn_layers 
                  __eval__: tune.choice([[w]*l for l in [1, 2] for w in [10, 20, 50, 100, 200]])
                # Single-element choice: always a torch.nn.ReLU() instance.
                nonlinearity: &nonlinearity 
                  __eval__: tune.choice([torch.nn.ReLU()])
                # Integer from 1 to 50 inclusive.
                split_dim: 
                  __eval__: tune.choice([i for i in range(1, 51)])
                # Laplace with a 100-element zero `loc` and a 100-element unit `scale`.
                # NOTE(review): implies 100-dimensional samples; confirm that MnistSplit
                # with `scale: true` produces 100 features.
                base_distribution: 
                  __object__: pyro.distributions.Laplace
                  loc: 
                    __eval__: torch.zeros(100)
                  scale: 
                    __eval__: torch.ones(100)
        # Variant of exp_nice_lu_laplace that swaps the base distribution for a
        # Normal with a 100-element zero `loc` and unit `scale`. Only the keys
        # listed here differ from the referenced entry.
        - &exp_nice_lu_normal
          __overwrites__: *exp_nice_lu_laplace
          name: mnist_nice_lu_normal
          trial_config:
            model_cfg:
              params:
                base_distribution:
                  # NOTE(review): `__exact__` presumably replaces the inherited value
                  # wholesale instead of merging into it; confirm in the config loader.
                  __exact__:
                    __object__: pyro.distributions.Normal
                    loc:
                      __eval__: torch.zeros(100)
                    scale:
                      __eval__: torch.ones(100)
        # Variant of exp_nice_lu_laplace that sets `use_lu: false` and adds
        # `masktype: random`; the base distribution is not overridden here.
        - &exp_nice_rand_laplace
          __overwrites__: *exp_nice_lu_laplace
          name: mnist_nice_rand_laplace
          trial_config:
            model_cfg: 
              params:
                use_lu: false
                masktype: random
        # Variant of exp_nice_lu_laplace that sets `use_lu: false`, adds
        # `masktype: random`, and swaps the base distribution for a Normal with a
        # 100-element zero `loc` and unit `scale`.
        - &exp_nice_rand_normal
          __overwrites__: *exp_nice_lu_laplace
          name: mnist_nice_rand_normal
          trial_config:
            model_cfg: 
              params:
                use_lu: false
                masktype: random
                base_distribution: 
                  # NOTE(review): `__exact__` presumably replaces the inherited value
                  # wholesale instead of merging into it; confirm in the config loader.
                  __exact__: 
                    __object__: pyro.distributions.Normal
                    loc: 
                      __eval__: torch.zeros(100)
                    scale: 
                      __eval__: torch.ones(100)
  # Digit-1 entry: references digit_0 via `__overwrites__` and changes only
  # `trial_config.dataset.digit`.
  # NOTE(review): `experiments` is a mapping here but a list in digit_0;
  # presumably the loader applies it to every experiment in that list. `name`
  # is not overridden either; confirm that results of the digit entries do not collide.
  - &digit_1
    __overwrites__: *digit_0
    experiments:
        trial_config:
          dataset:
            digit: 1
  # Digit-2 entry: references digit_0 via `__overwrites__` and changes only
  # `trial_config.dataset.digit`.
  # NOTE(review): `experiments` is a mapping here but a list in digit_0;
  # presumably the loader applies it to every experiment in that list. `name`
  # is not overridden either; confirm that results of the digit entries do not collide.
  - &digit_2
    __overwrites__: *digit_0
    experiments:
        trial_config:
          dataset: 
            digit: 2
  # Digit-3 entry: references digit_0 via `__overwrites__` and changes only
  # `trial_config.dataset.digit`.
  # NOTE(review): `experiments` is a mapping here but a list in digit_0;
  # presumably the loader applies it to every experiment in that list. `name`
  # is not overridden either; confirm that results of the digit entries do not collide.
  - &digit_3
    __overwrites__: *digit_0
    experiments:
        trial_config:
          dataset: 
            digit: 3
  # Digit-4 entry: references digit_0 via `__overwrites__` and changes only
  # `trial_config.dataset.digit`.
  # NOTE(review): `experiments` is a mapping here but a list in digit_0;
  # presumably the loader applies it to every experiment in that list. `name`
  # is not overridden either; confirm that results of the digit entries do not collide.
  - &digit_4
    __overwrites__: *digit_0
    experiments:
        trial_config:
          dataset: 
            digit: 4
  # Digit-5 entry: references digit_0 via `__overwrites__` and changes only
  # `trial_config.dataset.digit`.
  # NOTE(review): `experiments` is a mapping here but a list in digit_0;
  # presumably the loader applies it to every experiment in that list. `name`
  # is not overridden either; confirm that results of the digit entries do not collide.
  - &digit_5
    __overwrites__: *digit_0
    experiments:
        trial_config:
          dataset: 
            digit: 5
  # Digit-6 entry: references digit_0 via `__overwrites__` and changes only
  # `trial_config.dataset.digit`.
  # NOTE(review): `experiments` is a mapping here but a list in digit_0;
  # presumably the loader applies it to every experiment in that list. `name`
  # is not overridden either; confirm that results of the digit entries do not collide.
  - &digit_6
    __overwrites__: *digit_0
    experiments:
        trial_config:
          dataset: 
            digit: 6
  # Digit-7 entry: references digit_0 via `__overwrites__` and changes only
  # `trial_config.dataset.digit`.
  # NOTE(review): `experiments` is a mapping here but a list in digit_0;
  # presumably the loader applies it to every experiment in that list. `name`
  # is not overridden either; confirm that results of the digit entries do not collide.
  - &digit_7
    __overwrites__: *digit_0
    experiments:
        trial_config:
          dataset: 
            digit: 7
  # Digit-8 entry: references digit_0 via `__overwrites__` and changes only
  # `trial_config.dataset.digit`.
  # NOTE(review): `experiments` is a mapping here but a list in digit_0;
  # presumably the loader applies it to every experiment in that list. `name`
  # is not overridden either; confirm that results of the digit entries do not collide.
  - &digit_8
    __overwrites__: *digit_0
    experiments:
        trial_config:
          dataset: 
            digit: 8
  # Digit-9 entry: references digit_0 via `__overwrites__` and changes only
  # `trial_config.dataset.digit`.
  # NOTE(review): `experiments` is a mapping here but a list in digit_0;
  # presumably the loader applies it to every experiment in that list. `name`
  # is not overridden either; confirm that results of the digit entries do not collide.
  - &digit_9
    __overwrites__: *digit_0
    experiments:
        trial_config:
          dataset: 
            digit: 9
        