# @package _global_
dataset:
  # NOTE(review): presumably data-loading settings (number of loader worker
  # processes and samples per batch); confirm against the consumer of
  # `cfg.dataset`.
  num_workers: 4
  batch_size: 1024

model:
  _target_: group_discovery.flow_matching.arrow.model.ComplexFlow
  # Interpolated from the top-level `device` key, which is not defined in
  # this file (this config is merged at the global package).
  device: ${device}
  net:
    _target_: group_discovery.flow_matching.arrow.net.ComplexMLPTimeConcat
    in_dim: 28  # complex: [7, 2] * 2 vectors -> 7 * 2 * 2 = 28
    out_dim: 8  # complex: [2, 2] * 2 matrix -> 2 * 2 * 2 = 8
    hidden_dim: 128
  prior_dist:
    _target_: group_discovery.distributions.GL2ComplexPushforwardDistribution
    # det_range: [0.5, 2]  # disabled; kept for reference
    coeff_dist:
      # Uniform over [-1.570796, 1.570796] (approx. +/- pi/2) per entry of a
      # [4, 2] tensor. Independent with reinterpreted_batch_ndims=2 treats
      # the whole [4, 2] block as a single event, so log_prob sums over it.
      _target_: torch.distributions.Independent
      reinterpreted_batch_ndims: 2
      base_distribution:
        _target_: torch.distributions.Uniform
        low:
          _target_: torch.tensor
          data:
            - [-1.570796, -1.570796]
            - [-1.570796, -1.570796]
            - [-1.570796, -1.570796]
            - [-1.570796, -1.570796]
        high:
          _target_: torch.tensor
          data:
            - [1.570796, 1.570796]
            - [1.570796, 1.570796]
            - [1.570796, 1.570796]
            - [1.570796, 1.570796]

optimizer:
  _target_: torch.optim.Adam
  # Written with a decimal point so YAML 1.1 loaders (e.g. plain PyYAML) also
  # resolve it as a float; `3e-3` is a float only under OmegaConf's loader.
  lr: 3.0e-3

train:
  # NOTE(review): presumably the total number of training epochs; confirm in
  # the training script that reads `cfg.train`.
  epochs: 1000

test:
  # NOTE(review): presumably the number of steps used when evaluating the
  # flow model; confirm against the evaluation code.
  n_steps: 20
  # NOTE(review): presumably evaluation runs every 10 training epochs; confirm.
  epoch_interval: 10
