---
# Experiment collection "mnist_ablation". The entries under `experiments` in
# this file differ in the model's `base_distribution`.
# NOTE(review): `__object__` presumably names the class that the project's
# config loader instantiates from the sibling keys — confirm in src.explib.
__object__: src.explib.base.ExperimentCollection
name: mnist_ablation
experiments:
  # Baseline entry: NiceFlow on MnistSplit with a radial base distribution
  # whose radial component is log-normal. Anchored as `exp_rad_logN` so other
  # entries in this file can reference it through `__overwrites__`.
  - &exp_rad_logN
    __object__: src.explib.hyperopt.HyperoptExperiment
    name: mnist_full_radial_logN
    # grace_period equals max_t. Per Ray Tune's ASHA documentation,
    # grace_period is the minimum time a trial runs before it may be stopped,
    # so early stopping is effectively disabled — presumably intentional.
    scheduler: &scheduler
      __object__: ray.tune.schedulers.ASHAScheduler
      max_t: 1000000
      grace_period: 1000000
      reduction_factor: 2
    # A single sample, with one GPU and one CPU per trial.
    num_hyperopt_samples: &num_hyperopt_samples 1
    gpus_per_trial: &gpus_per_trial 1
    cpus_per_trial: &cpus_per_trial 1
    # Trials are compared on val_loss; lower is better.
    tuner_params: &tuner_params
      metric: val_loss
      mode: min
    trial_config:
      logging:
        images: true
        image_shape: [28, 28]
      dataset: &dataset
        __object__: src.explib.datasets.MnistSplit
      epochs: &epochs 200000
      patience: &patience 2
      # NOTE(review): `__eval__` presumably has the loader evaluate the string
      # as a Python expression — confirm against the loader in src.explib.
      batch_size: &batch_size
        __eval__: 'tune.choice([32])'
      optim_cfg: &optim
        optimizer:
          __class__: torch.optim.Adam
        params:
          # Left as a plain scalar on purpose: `1e-4` has no decimal point, so
          # its YAML type depends on the loader; quoting would force a string.
          lr:
            __eval__: 1e-4
          weight_decay: 0.0

      model_cfg:
        type:
          __class__: &model src.veriflow.flows.NiceFlow
        params:
          soft_training: true
          training_noise_prior:
            __object__: pyro.distributions.Uniform
            # Plain scalar for the same reason as `lr` above.
            low:
              __eval__: 1e-20
            high: 0.01
          prior_scale: 1.0
          coupling_layers: 10
          coupling_nn_layers: [300, 300, 300]
          nonlinearity: &nonlinearity
            __eval__: 'tune.choice([torch.nn.ReLU()])'
          # 392 = 784 / 2, where 784 = 28 * 28 (see image_shape above).
          split_dim: 392
          base_distribution:
            __object__: src.veriflow.distributions.RadialDistribution
            device: cuda
            p: 1.0
            loc:
              __eval__: 'torch.zeros(784).to("cuda")'
            radial_distribution:
              __object__: pyro.distributions.LogNormal
              loc:
                __eval__: 'torch.zeros(1).to("cuda")'
              scale:
                __eval__: '(.5 * torch.ones(1)).to("cuda")'
          use_lu: true
  # Variant of the `exp_rad_logN` entry whose base distribution is a Laplace
  # with loc zeros(784) and scale ones(784).
  # NOTE(review): `__overwrites__` presumably merges this mapping over the
  # referenced entry, and `__exact__` presumably replaces the target value
  # wholesale instead of merging into it — confirm in src.explib.
  - &exp_laplace
    __overwrites__: *exp_rad_logN
    name: mnist_full_laplace
    trial_config:
      model_cfg:
        params:
          base_distribution:
            __exact__:
              __object__: pyro.distributions.Laplace
              loc:
                __eval__: 'torch.zeros(784).to("cuda")'
              scale:
                __eval__: 'torch.ones(784).to("cuda")'
  # Variant of the `exp_rad_logN` entry whose base distribution is a Normal
  # with loc zeros(784) and scale ones(784).
  # NOTE(review): `__overwrites__` presumably merges this mapping over the
  # referenced entry, and `__exact__` presumably replaces the target value
  # wholesale instead of merging into it — confirm in src.explib.
  - &exp_normal
    __overwrites__: *exp_rad_logN
    # The name has to differ from the Laplace entry's `mnist_full_laplace`;
    # it was previously a copy-paste duplicate of that name.
    name: mnist_full_normal
    trial_config:
      model_cfg:
        params:
          base_distribution:
            __exact__:
              __object__: pyro.distributions.Normal
              loc:
                __eval__: 'torch.zeros(784).to("cuda")'
              scale:
                __eval__: 'torch.ones(784).to("cuda")'
