---
# Root object of this file. `__object__` holds the dotted path of the class
# this mapping describes (presumably instantiated by the project's config
# loader with the sibling keys as arguments — confirm).
__object__: src.explib.base.ExperimentCollection
name: mnist_basedist_comparison
# Each list entry below configures one experiment of the collection.
experiments:
    # Hyperparameter-search experiment for a NiceFlow model on MnistSplit data.
    # Anchored as `exp_nice` so other list entries can reference it through
    # the `*exp_nice` alias.
    - &exp_nice
      __object__: src.explib.hyperopt.HyperoptExperiment
      name: mnist_nice
      # NOTE(review): grace_period equals max_t. Per the Ray Tune ASHA docs a
      # trial is only eligible for early stopping after grace_period, so the
      # scheduler cannot stop anything before max_t — confirm this is intended.
      scheduler: &scheduler
        __object__: ray.tune.schedulers.ASHAScheduler
        max_t: 10000
        grace_period: 10000
        reduction_factor: 2
      num_hyperopt_samples: &num_hyperopt_samples 1
      gpus_per_trial: &gpus_per_trial 0
      cpus_per_trial: &cpus_per_trial 1
      # Objective used to rank trials: minimise `val_loss` (presumably
      # forwarded to Ray Tune — confirm against HyperoptExperiment).
      tuner_params: &tuner_params
        metric: val_loss
        mode: min
      trial_config:
        logging:
          images: true
          image_shape: [10, 10]
        dataset: &dataset
          __object__: src.explib.datasets.MnistSplit
          scale: true
          digit: 0
        epochs: &epochs 20
        patience: &patience 5
        # `__eval__` values are Python expressions (presumably evaluated by the
        # config loader — confirm). A single-element `tune.choice` pins the
        # value while keeping it declared as a search dimension.
        batch_size: &batch_size
          __eval__: tune.choice([32])
        optim_cfg: &optim
          optimizer:
            __class__: torch.optim.Adam
          params:
            lr:
              # tune.loguniform takes (lower, upper); the bounds are listed in
              # ascending order so the sampling interval is well defined.
              __eval__: tune.loguniform(5e-4, 1e-2)
            weight_decay: 0.0

        model_cfg:
          type:
            __class__: &model src.veriflow.flows.NiceFlow
          params:
            use_lu: false
            coupling_layers: &coupling_layers
              __eval__: tune.choice([3])
            coupling_nn_layers: &coupling_nn_layers
              __eval__: tune.choice([[100]*2])
            nonlinearity: &nonlinearity
              __eval__: tune.choice([torch.nn.ReLU()])
            # NOTE(review): 50 is half of the 100 entries used for the base
            # distribution below, and 100 == 10 * 10 from `image_shape` above.
            # These sizes look coupled; keep them in sync — confirm.
            split_dim: &split_dim 50
            base_distribution:
              __object__: pyro.distributions.Laplace
              loc:
                __eval__: torch.zeros(100)
              scale:
                __eval__: torch.ones(100)
            masktype: &permutation half
    # Variant of `exp_nice`: starts from the base experiment via
    # `__overwrites__` and changes only the keys listed here.
    # The override keys mirror the base layout: `name` is an experiment-level
    # key and `model_cfg` lives under `trial_config` (assumes `__overwrites__`
    # merges by matching key paths — confirm against the config loader).
    - &exp_lunice
      __overwrites__: *exp_nice
      name: mnist_lunice
      trial_config:
        model_cfg:
          params:
            # NOTE(review): `use_lu` is inherited as `false` from the base
            # experiment; given the name "lunice", check whether it should be
            # overridden to `true` here as well.
            soft_training: true

