---
# Hyperparameter-search experiments for the base-distribution comparison on
# synthetic 2D datasets (blobs, moons, circles, transformed).
__object__: src.explib.base.ExperimentCollection
name: synthetic_basedist_comparison
experiments:
    # Base experiment group. The sibling groups in this file reuse it through
    # `__overwrites__: *main` and override only name, skip flag and dataset.
    # NOTE(review): this group is named `blobs` but its dataset generator is
    # `circles`, while the sibling group named `circles` selects the `blobs`
    # generator — confirm whether the two names are swapped.
    - &main
      __object__: src.explib.base.ExperimentCollection
      name: blobs
      experiments:
        - &main_normal
          __object__: src.explib.hyperopt.HyperoptExperiment
          name: normal
          device: cpu
          # Presumably excludes this experiment from the run; each sibling
          # group sets its own value — verify against the experiment runner.
          skip: true
          # ASHA trial scheduler. grace_period equals max_t, which leaves no
          # room to stop a trial before it reaches max_t.
          # NOTE(review): confirm that disabling early stopping is intended.
          scheduler: &scheduler
            __object__: ray.tune.schedulers.ASHAScheduler
            max_t: 10000
            grace_period: 10000
            reduction_factor: 2
          # Search budget and per-trial resources (CPU only, no GPU).
          num_hyperopt_samples: &num_hyperopt_samples 20
          gpus_per_trial: &gpus_per_trial 0
          cpus_per_trial: &cpus_per_trial 1
          # Trials are ranked by the `val_loss` metric; lower is better.
          tuner_params: &tuner_params
            metric: val_loss
            mode: min
          # Per-trial configuration. `__eval__` values are presumably evaluated
          # as Python expressions by the config loader — verify.
          trial_config:
            logging:
              images: false
              scatter: true
            # Synthetic dataset with separate train / test / validation sizes.
            dataset: &dataset
              __object__: src.explib.datasets.SyntheticSplit
              generator: circles
              params_train: &params_train
                n_samples: 100000
              params_test: &params_test
                n_samples: 20000
              params_val: &params_val
                n_samples: 20000
            # NOTE(review): the epoch cap is very large; presumably training
            # is expected to end through `patience` instead — confirm.
            epochs: &epochs 10000000000
            patience: &patience 3
            # Only one candidate is offered, so the batch size is always 64.
            batch_size: &batch_size
              __eval__: tune.choice([64])
            optim_cfg: &optim
              optimizer:
                __class__: torch.optim.Adam
              params:
                # Learning rate is sampled log-uniformly. The signature is
                # tune.loguniform(lower, upper), so the bounds are given in
                # ascending order.
                lr:
                  __eval__: tune.loguniform(1e-6, 1e-3)
                weight_decay: 0.0

            model_cfg:
              type:
                __class__: &model src.veriflow.flows.NiceFlow
              params:
                soft_training: true
                training_noise_prior:
                  __object__: pyro.distributions.Uniform
                  low: 0.0
                  high: 0.1
                prior_scale: 1.0
                use_lu: true
                # One of the integers 2..20.
                coupling_layers: &coupling_layers
                  __eval__: tune.choice([i for i in range(2, 21)])
                # Lists of 1 to 3 equal entries w, with w in 10, 15, ..., 245
                # (presumably hidden-layer widths of the coupling networks).
                coupling_nn_layers: &coupling_nn_layers
                  __eval__: tune.choice([[w]*l for w in range(10, 250, 5) for l in range(1, 4)])
                nonlinearity: &nonlinearity
                  __eval__: tune.choice([torch.nn.ReLU()])
                split_dim: &split_dim 1
                # Base distribution of the flow: a radial distribution whose
                # norm distribution is a LogNormal; all tensors are placed on
                # the CPU.
                base_distribution:
                  __object__: src.veriflow.distributions.RadialDistribution
                  device: cpu
                  p: 1.0
                  loc:
                    __eval__: torch.zeros(2).to("cpu")
                  norm_distribution:
                    __object__: pyro.distributions.LogNormal
                    loc:
                      __eval__: torch.zeros(1).to("cpu")
                    scale:
                      __eval__: (.5 * torch.ones(1)).to("cpu")
    # Variant of the `*main` group that switches the dataset generator to
    # `moons`. Only the keys listed here are overridden; how the `experiments`
    # list is merged with the base list is decided by the loader's
    # `__overwrites__` handling (presumably element-wise — verify).
    - &moons
      __overwrites__: *main
      name: moons
      experiments:
        # Carries the same skip value as the base experiment.
        - skip: true
          trial_config:
            dataset:
              generator: moons
        
    # Variant of the `*main` group and the only group visible in this file
    # with `skip: false`.
    # NOTE(review): the group is named `circles`, yet the generator selected
    # below is `blobs` (three centers on the diagonal), while the base group
    # named `blobs` uses the `circles` generator — confirm whether the names
    # are swapped.
    - &circles
      __overwrites__: *main
      name: circles
      experiments:
        - skip: false
          trial_config:
            dataset:
              generator: blobs
              # Every split uses the same three blob centers.
              params_train:
                n_samples: 100000
                centers:
                  - [-5.0, -5.0]
                  - [0.0, 0.0]
                  - [5.0, 5.0]
              params_test:
                n_samples: 20000
                centers:
                  - [-5.0, -5.0]
                  - [0.0, 0.0]
                  - [5.0, 5.0]
              params_val:
                n_samples: 20000
                centers:
                  - [-5.0, -5.0]
                  - [0.0, 0.0]
                  - [5.0, 5.0]
   
    # Variant of the `*main` group that switches the dataset generator to
    # `transformed`. All three splits receive the same parameters: `dim: 2`
    # and an identical 2x2 `transform` tensor.
    # NOTE(review): `n_samples` is not overridden here; presumably the values
    # of the base group still apply after the `__overwrites__` merge — verify.
    - &transformed
      __overwrites__: *main
      name: transformed
      experiments:
        - skip: true
          trial_config:
            dataset:
              generator: transformed
              params_train: 
                dim: 2
                transform:
                  __eval__: torch.Tensor([[1.0, 0.5], [-0.5, 1.0]])
              params_test: 
                dim: 2
                transform:
                  __eval__: torch.Tensor([[1.0, 0.5], [-0.5, 1.0]])
              params_val: 
                dim: 2
                transform:
                  __eval__: torch.Tensor([[1.0, 0.5], [-0.5, 1.0]])
