dataset:
  name: D4_half_arrow
  group: D_4
  # Train and test sample from the same object-transform distribution. It is
  # defined once under the `sample_dist` anchor and reused by
  # `dataset.test.dist` through an alias, so the two splits cannot drift
  # apart. The alias is expanded by the YAML parser, and each split still
  # gets its own config node.
  train:
    _target_: group_discovery.data.sampling_dataset.DistributionSamplingDataset
    dist: &sample_dist
      _target_: group_discovery.distributions.ObjectTransformDistribution
      base_dist:
        _target_: group_discovery.distributions.DiscreteDeltaMixture
        group: ${dataset.group}
        locs:
          _target_: group_discovery.distributions.Dn_elements_on_R2
          # NOTE(review): confirm whether this is n in D_n or the group's
          # order |D_n|; the dataset is labelled D_4.
          group_order: 4
      base_object:
        _target_: group_discovery.data.objects.Object
        name: half_arrow
    num_samples: 20000
    return_transform: false
  test:
    _target_: group_discovery.data.sampling_dataset.DistributionSamplingDataset
    # Same distribution as the training split (see the anchor above).
    dist: *sample_dist
    num_samples: 5000
    # Only the test split sets return_transform.
    return_transform: true
  batch_size: 1024
  num_workers: 4
model:
  _target_: group_discovery.flow_matching.arrow.model.ComplexFlow
  # Resolved from the top-level `device` key.
  device: ${device}
  net:
    _target_: group_discovery.flow_matching.arrow.net.ComplexMLPTimeConcat
    # NOTE(review): in_dim is not derivable from anything in this file;
    # confirm it matches the width of the network's actual input.
    in_dim: 28
    # Presumably one output per entry of the 4 x 2 coefficient tensor that
    # prior_dist samples (4 * 2 = 8); confirm against ComplexFlow.
    out_dim: 8
    hidden_dim: 128
  prior_dist:
    _target_: group_discovery.distributions.GL2ComplexPushforwardDistribution
    # Coefficients are drawn element-wise from Uniform[low, high) with
    # low = -1.570796 and high = 1.570796, which is pi/2 truncated to six
    # decimals and so lies strictly inside +/- pi/2. Wrapping the Uniform in
    # Independent(reinterpreted_batch_ndims=2) turns its 4 x 2 batch shape
    # into a single 4 x 2 event.
    # NOTE(review): the truncation may be deliberate (to stay away from
    # +/- pi/2); confirm before switching to full-precision pi/2.
    coeff_dist:
      _target_: torch.distributions.Independent
      reinterpreted_batch_ndims: 2
      base_distribution:
        _target_: torch.distributions.Uniform
        low:
          _target_: torch.tensor
          data:
            - [-1.570796, -1.570796]
            - [-1.570796, -1.570796]
            - [-1.570796, -1.570796]
            - [-1.570796, -1.570796]
        high:
          _target_: torch.tensor
          data:
            - [1.570796, 1.570796]
            - [1.570796, 1.570796]
            - [1.570796, 1.570796]
            - [1.570796, 1.570796]
# Adam optimizer with a fixed learning rate.
optimizer:
  _target_: torch.optim.Adam
  lr: 0.003
# Training and evaluation schedule (small leaf mappings kept on one line).
train: {epochs: 1}
test: {epoch_interval: 1, n_steps: 20}
seed: 1001
# Referenced by model.device via ${device}.
device: cuda
# Hydra's runtime run directory.
save_dir: ${hydra:run.dir}
logger:
  _target_: group_discovery.logger.WandBLogger
  project: scratch
  dir: '.'
  # Explicit nulls are passed to WandBLogger as None rather than omitted.
  entity: null
  name: null
  tags: null
  id: null
