---
# Experiment collection comparing base distributions (Laplace vs. Normal)
# for a NiceFlow model trained on the FashionMnistSplit dataset.
__object__: src.explib.base.ExperimentCollection
name: fashion_basedist_comparison
# Each list entry below describes one experiment of the collection.
experiments:
    - &exp_laplace
      # Baseline hyperparameter search: NiceFlow with a Laplace base
      # distribution. Anchored so that sibling experiments can reuse it.
      __object__: src.explib.hyperopt.HyperoptExperiment
      name: mnist_laplace
      scheduler: &scheduler
        __object__: ray.tune.schedulers.ASHAScheduler
        max_t: 10000
        # NOTE(review): grace_period equals max_t, which appears to leave
        # ASHA no room to stop trials early -- confirm this is intended.
        grace_period: 10000
        reduction_factor: 2
      num_hyperopt_samples: &num_hyperopt_samples 25
      gpus_per_trial: &gpus_per_trial 0
      cpus_per_trial: &cpus_per_trial 1
      tuner_params: &tuner_params
        metric: val_loss
        mode: min
      trial_config:
        dataset: &dataset
          __object__: src.explib.datasets.FashionMnistSplit
          label: 0
        epochs: &epochs 10000
        patience: &patience 500
        batch_size: &batch_size
          __eval__: tune.choice([4, 8, 16, 32, 64])
        optim_cfg: &optim
          optimizer:
            __class__: torch.optim.Adam
          params:
            lr:
              # tune.loguniform takes (lower, upper): bounds go low to high.
              __eval__: tune.loguniform(5e-4, 1e-2)
            weight_decay: 0.0

        model_cfg:
          type:
            __class__: &model src.veriflow.flows.NiceFlow
          params:
            coupling_layers: &coupling_layers
              __eval__: tune.choice([2, 3, 4, 6, 8])
            coupling_nn_layers: &coupling_nn_layers
              __eval__: tune.choice([[100, 100], [100], [200, 200], [200], [200, 100], [100, 200], [150], [150, 150], [150, 100], [100, 150]])
            nonlinearity: &nonlinearity
              __eval__: tune.choice([torch.nn.ReLU()])
            split_dim: &split_dim 50
            # The distribution spec must be nested below base_distribution;
            # at the same indent these keys would become entries of params
            # and base_distribution itself would be null.
            base_distribution:
              __object__: pyro.distributions.Laplace
              loc:
                __eval__: torch.zeros(100)
              scale:
                __eval__: torch.ones(100)
            permutation: &permutation half
    - &exp_normal
      # Same search as exp_laplace, swapping in a Normal base distribution.
      # NOTE(review): assumes __overwrites__ merges the keys below over the
      # aliased experiment -- confirm against the config parser.
      __overwrites__: *exp_laplace
      name: mnist_normal
      # model_cfg lives under trial_config in exp_laplace, so the override
      # has to follow the same key path to reach base_distribution.
      trial_config:
        model_cfg:
          params:
            base_distribution:
              __object__: pyro.distributions.Normal
              loc:
                __eval__: torch.zeros(100)
              scale:
                __eval__: torch.ones(100)
