dataset:
  name: 'C4_arrow'
  group: 'C_4'
  # Both splits sample from the same ObjectTransformDistribution. It is
  # defined once here (anchor &c4_arrow_dist) and aliased by the test split,
  # so the two splits cannot drift apart.
  train:
    _target_: group_discovery.data.sampling_dataset.DistributionSamplingDataset
    dist: &c4_arrow_dist
      _target_: group_discovery.distributions.ObjectTransformDistribution
      # DiscreteDeltaMixture over the elements produced by Cn_elements_on_R2
      # with group_order 4, in matrix representation. The group name is
      # interpolated from dataset.group.
      base_dist:
        _target_: group_discovery.distributions.DiscreteDeltaMixture
        group: ${dataset.group}
        locs:
          _target_: group_discovery.distributions.Cn_elements_on_R2
          group_order: 4
          representation: 'matrix'
      # The object the sampled transforms are applied to (per the class
      # name; confirm).
      base_object:
        _target_: group_discovery.data.objects.Object
        name: 'arrow'
    num_samples: 20000
    return_transform: false
  test:
    _target_: group_discovery.data.sampling_dataset.DistributionSamplingDataset
    # Identical distribution to the train split (YAML alias).
    dist: *c4_arrow_dist
    num_samples: 5000
    # NOTE(review): enabled only for the test split, presumably so evaluation
    # can see the sampled transform. Confirm against
    # DistributionSamplingDataset.
    return_transform: true
  batch_size: 1024
  num_workers: 4
model:
  _target_: group_discovery.flow_matching.arrow.model.Flow
  # Interpolated from the top-level `device` key.
  device: ${device}
  # NOTE(review): out_dim (4) equals the number of prior coefficients below.
  # in_dim (26) cannot be derived from this file. Confirm against
  # MLPTimeConcat and the arrow object's encoding.
  net:
    _target_: group_discovery.flow_matching.arrow.net.MLPTimeConcat
    in_dim: 26
    out_dim: 4
    hidden_dim: 128
  prior_dist:
    _target_: group_discovery.distributions.GL2PlusPushforwardDistribution
    coeff_dist:
      # reinterpreted_batch_ndims: 1 folds the 4 independent uniforms into a
      # single event of size 4.
      _target_: torch.distributions.Independent
      reinterpreted_batch_ndims: 1
      base_distribution:
        # Each coefficient ~ Uniform(-1.570796, 1.570796). 1.570796 is pi/2
        # rounded to 6 decimal places.
        _target_: torch.distributions.Uniform
        low:
          _target_: torch.tensor
          data: [-1.570796, -1.570796, -1.570796, -1.570796]
        high:
          _target_: torch.tensor
          data: [1.570796, 1.570796, 1.570796, 1.570796]
# Adam with a fixed learning rate of 0.003. Model params are presumably
# supplied at instantiation time by the training script (confirm).
optimizer:
  _target_: torch.optim.Adam
  lr: 3.0e-3
train:
  epochs: 1
# NOTE(review): semantics are inferred from key names. epoch_interval looks
# like "evaluate every N epochs" and n_steps like a sampling/integration step
# count. Confirm in the training loop.
test:
  epoch_interval: 1
  n_steps: 20
seed: 1004
device: 'cuda'
# Resolved at runtime by Hydra's `hydra:` resolver to the run directory.
save_dir: '${hydra:run.dir}'
logger:
  _target_: group_discovery.logger.WandBLogger
  project: 'scratch'
  # Relative path, interpreted against the process working directory.
  dir: '.'
  # Left unset (null). Presumably WandBLogger falls back to its own or
  # W&B's defaults for these (confirm).
  entity: null
  name: null
  tags: null
  id: null
