---
# Collection of density-estimation experiments on Gaussian mixtures.
# Each entry under `experiments` describes one experiment; the entries in this
# file differ in the dimensionality of the mixture (2, 8, 32 and 128).
# NOTE(review): `__object__` presumably names the class the config loader
# instantiates with the sibling keys as arguments - confirm against the loader.
__object__: src.usflows.explib.base.ExperimentCollection
name: gaussian_mixture_experiments
experiments:
  # Base experiment: hyperparameter search for a USFlow on a 2-D Gaussian
  # mixture. Anchored as `exp2d`; the other entries of this list refer to it
  # through `__overwrites__` and replace only the dimension-dependent values.
  - &exp2d
    __object__: src.usflows.explib.hyperopt.HyperoptExperiment
    name: gaussian_mixture_2D
    device: cpu
    # NOTE(review): grace_period equals max_t, which leaves ASHA no earlier
    # rung at which to stop a trial - presumably scheduler-side early stopping
    # is disabled on purpose; confirm.
    scheduler: &scheduler
      __object__: ray.tune.schedulers.ASHAScheduler
      max_t: 1000000
      grace_period: 1000000
      reduction_factor: 2
    # Search budget and per-trial resources (CPU only: gpus_per_trial is 0).
    num_hyperopt_samples: &num_hyperopt_samples 10
    gpus_per_trial: &gpus_per_trial 0
    cpus_per_trial: &cpus_per_trial 16
    # Trials are compared on val_loss, lower is better.
    tuner_params: &tuner_params
      metric: val_loss
      mode: min
    trial_config:
      # Image logging is switched off.
      # NOTE(review): image_shape [28, 28] does not match 2-D samples and is
      # presumably unused while images is false - confirm.
      logging:
        images: false
        image_shape: [28, 28]
      # Data: two equally weighted Gaussian components centred at (-1, -1) and
      # (1, 1), each with identity covariance, split into train/val/test sets
      # of the sizes given below.
      dataset: &dataset
        class:
          __class__: src.usflows.explib.datasets.DistributionSplit
        params:
          distribution:
            __object__: src.usflows.distributions.GMM
            loc:
              __eval__: torch.tensor([[-1.0, -1.0], [1.0, 1.0]])
            covariance_matrix:
              __eval__: torch.stack([torch.eye(2)] * 2)
            mixture_weights:
              __eval__: torch.tensor([0.5, 0.5])
          num_train: 50000
          num_val: 2000
          num_test: 5000
      epochs: &epochs 200000
      patience: &patience 5
      # Ray Tune search space: one of four batch sizes.
      batch_size: &batch_size
        __eval__: tune.choice([16, 32, 64, 128])
      # SophiaG optimizer; learning rate drawn log-uniformly from [1e-4, 1e-2],
      # weight decay fixed at zero.
      optim_cfg: &optim
        optimizer:
          __class__: src.usflows.sophia.SophiaG
        params:
          lr:
            __eval__: tune.loguniform(1e-4, 1e-2)
          weight_decay: 0.0
      model_cfg:
        type:
          __class__: src.usflows.flows.USFlow
        params:
          soft_training:
            __eval__: False
          training_noise_prior:
            __object__: pyro.distributions.Uniform
            # NOTE(review): probably wrapped in __eval__ because YAML 1.1
            # loaders such as PyYAML read a bare 1e-20 (no decimal point) as a
            # string rather than a float - confirm before simplifying.
            low:
              __eval__: 1e-20
            high: 0.1
          prior_scale: 1.0
          # Searched: number of coupling blocks.
          coupling_blocks:
            __eval__: tune.choice([2, 4, 6, 8, 10])
          lu_transform: 1
          householder: 0
          conditioner_cls:
            __class__: pyro.nn.DenseNN
          # input_dim, param_dims, in_dims and the base distribution's loc all
          # carry the data dimensionality (2 here); these are the values the
          # higher-dimensional variants in this file replace.
          conditioner_args:
            input_dim: 2
            # Searched: one or two hidden layers of width 32 to 256.
            hidden_dims:
              __eval__: tune.choice([[32], [64], [128], [256], [32, 32], [64, 64], [128, 128]])
            param_dims: [2]
          in_dims: [2]
          affine_conjugation: true
          # Single-element choice, so the nonlinearity is always ReLU.
          nonlinearity:
            __eval__: tune.choice([torch.nn.ReLU()])
          # Base distribution: zero loc of length 2 and scalar scale 1.0, both
          # placed on the CPU.
          base_distribution:
            __object__: src.usflows.distributions.Normal
            loc:
              __eval__: torch.zeros([2]).to("cpu")
            scale:
              __eval__: torch.tensor(1.0).to("cpu")
            device: cpu
  # 8-D variant of the `exp2d` base experiment. Only the keys listed here are
  # given new values.
  # NOTE(review): everything not listed (scheduler, optimizer, search spaces,
  # and base_distribution's scale/device) is presumably inherited from `exp2d`
  # through a deep merge performed by the loader's `__overwrites__` handling -
  # confirm the merge semantics.
  - __overwrites__: *exp2d
    name: gaussian_mixture_8D
    trial_config:
      dataset:
        params:
          # Two equally weighted components centred at -1 and +1 in every
          # coordinate, each with 8x8 identity covariance.
          distribution:
            __object__: src.usflows.distributions.GMM
            # Written with list repetition, the same form the 32-D and 128-D
            # variants use; evaluates to the same 2x8 tensor as spelling out
            # all sixteen entries.
            loc:
              __eval__: torch.tensor([[-1.0] * 8, [1.0] * 8])
            covariance_matrix:
              __eval__: torch.stack([torch.eye(8)] * 2)
            mixture_weights:
              __eval__: torch.tensor([0.5, 0.5])
      model_cfg:
        params:
          # Dimension-dependent model settings, all set to 8.
          in_dims: [8]
          conditioner_args:
            input_dim: 8
            param_dims: [8]
          base_distribution:
            loc:
              __eval__: torch.zeros([8]).to("cpu")
  # 32-D variant of the `exp2d` base experiment. Only the keys listed here are
  # given new values.
  # NOTE(review): everything not listed (including base_distribution's scale
  # and device) is presumably inherited from `exp2d` by the loader's
  # `__overwrites__` handling - confirm the merge semantics.
  - __overwrites__: *exp2d
    name: gaussian_mixture_32D
    trial_config:
      dataset:
        params:
          # Two equally weighted components centred at -1 and +1 in every
          # coordinate, each with 32x32 identity covariance.
          distribution:
            __object__: src.usflows.distributions.GMM
            loc:
              __eval__: torch.tensor([[-1.0] * 32, [1.0] * 32])
            covariance_matrix:
              __eval__: torch.stack([torch.eye(32)] * 2)
            mixture_weights:
              __eval__: torch.tensor([0.5, 0.5])
      model_cfg:
        params:
          # Dimension-dependent model settings, all set to 32.
          in_dims: [32]
          conditioner_args:
            input_dim: 32
            param_dims: [32]
          base_distribution:
            loc:
              __eval__: torch.zeros([32]).to("cpu")
  # 128-D variant of the `exp2d` base experiment. Only the keys listed here
  # are given new values.
  # NOTE(review): everything not listed (including base_distribution's scale
  # and device) is presumably inherited from `exp2d` by the loader's
  # `__overwrites__` handling - confirm the merge semantics.
  - __overwrites__: *exp2d
    name: gaussian_mixture_128D
    trial_config:
      dataset:
        params:
          # Two equally weighted components centred at -1 and +1 in every
          # coordinate, each with 128x128 identity covariance.
          distribution:
            __object__: src.usflows.distributions.GMM
            loc:
              __eval__: torch.tensor([[-1.0] * 128, [1.0] * 128])
            covariance_matrix:
              __eval__: torch.stack([torch.eye(128)] * 2)
            mixture_weights:
              __eval__: torch.tensor([0.5, 0.5])
      model_cfg:
        params:
          # Dimension-dependent model settings, all set to 128.
          in_dims: [128]
          conditioner_args:
            input_dim: 128
            param_dims: [128]
          base_distribution:
            loc:
              __eval__: torch.zeros([128]).to("cpu")