# Configuration instantiating src.datamodules.set.SetDataModule.
# Other keys in this file address these values as `${datamodule.<key>}`.
_target_: src.datamodules.set.SetDataModule
name: softmax_attention
key: softmax_attention
# Class path interpolated as the `_target_` of every split under `datasets`.
dataset_target_: src.datamodules.set.SetDataset
seed: 42
x_dim: 2
# phi_dim and y_dim currently mirror x_dim; the `mult_int` alternatives are
# kept in the trailing comments.
phi_dim: ${datamodule.x_dim}  # ${mult_int:${datamodule.x_dim},2}
y_dim: ${datamodule.x_dim}  # ${mult_int:${datamodule.x_dim},3}

# Shared loader/dimension settings; the three *_dim entries mirror the
# top-level x_dim / phi_dim / y_dim keys of this file via interpolation.
dataset_parameters:
  # `${seed}` carries no `datamodule.` prefix, so it is looked up at the root
  # of the composed config, not on the `seed` key defined in this file
  # (assuming this file is mounted at `datamodule`).
  # NOTE(review): confirm a root-level `seed` exists.
  seed: ${seed}
  batch_size: 64
  num_workers: 1
  x_dim: ${datamodule.x_dim}
  phi_dim: ${datamodule.phi_dim}
  y_dim: ${datamodule.y_dim}
  # max_length: ${model.model_params.max_xy_length}

# Per-split dataset definitions: train, val and test.
datasets:
  # Root-level lookup (no `datamodule.` prefix).
  # NOTE(review): confirm a root-level `seed` exists in the composed config.
  seed: ${seed}
  # Training split: fixed-length sequences (seq_min_length is tied to
  # seq_max_length), 1000 batches, uniform distribution with a=0, b=1.
  train:
    _target_: ${datamodule.dataset_target_}
    split: train
    seq_max_length: 4
    seq_min_length: ${datamodule.datasets.train.seq_max_length}
    num_batches: 1000
    mixing_type: ${datamodule.mixing_architecture.name}
    mixing_architecture: ${datamodule.mixing_architecture}
    distribution_config:
      name: uniform
      parameters:
        a: 0
        b: 1
    # The val and test splits interpolate this flag for their own
    # `use_constraints`.
    use_constraints: false
    constraints:
      rejection_sampling: true
      fraction: 0.5
      use_fraction: true
      # Sequence length combined with distribution parameters a / b through
      # the `mult` resolver (defined outside this file; presumably a product).
      default_low: ${mult:${datamodule.datasets.train.seq_max_length}, ${datamodule.datasets.train.distribution_config.parameters.a}}
      default_high: ${mult:${datamodule.datasets.train.seq_max_length}, ${datamodule.datasets.train.distribution_config.parameters.b}}
      # sum of all positions along each component should be higher than lb, and lower than hb
      # Disabled alternative with entries keyed by '0':
      # lb:
      #   - '0': ${add:${mult:0.75,${datamodule.datasets.train.constraints.default_low}},${mult:0.25,${datamodule.datasets.train.constraints.default_high}}}
      # hb:
      #   - '0': ${add:${mult:0.25,${datamodule.datasets.train.constraints.default_low}},${mult:0.75,${datamodule.datasets.train.constraints.default_high}}}
      # Active bounds, single entry each:
      #   lb = 0.75 * default_low + 0.25 * default_high
      #   hb = 0.25 * default_low + 0.75 * default_high
      # (assuming `add` / `mult` are plain arithmetic resolvers)
      lb:
        - ${add:${mult:0.75,${datamodule.datasets.train.constraints.default_low}},${mult:0.25,${datamodule.datasets.train.constraints.default_high}}}
      hb:
        - ${add:${mult:0.25,${datamodule.datasets.train.constraints.default_low}},${mult:0.75,${datamodule.datasets.train.constraints.default_high}}}

  # Validation split: sequence length pinned to the train split's
  # seq_max_length, 100 batches, uniform distribution with a=0, b=1.
  val:
    _target_: ${datamodule.dataset_target_}
    split: val
    seq_min_length: ${datamodule.datasets.train.seq_max_length}
    seq_max_length: ${datamodule.datasets.val.seq_min_length}
    # seq_max_length: ${add_int:${datamodule.datasets.train.seq_max_length},10}
    num_batches: 100
    mixing_type: ${datamodule.mixing_architecture.name}
    mixing_architecture: ${datamodule.mixing_architecture}
    distribution_config:
      name: uniform
      parameters:
        a: 0
        b: 1
    # Follows the train split's switch.
    use_constraints: ${datamodule.datasets.train.use_constraints}
    constraints:
      rejection_sampling: false
      fraction: 1.0
      use_fraction: true
      # Taken directly from the train split's distribution parameters a / b.
      default_low: ${datamodule.datasets.train.distribution_config.parameters.a}
      default_high: ${datamodule.datasets.train.distribution_config.parameters.b}
      # Two entries per bound. lb[1] and hb[0] reuse the train split's hb and
      # lb expressions, passed through `float_div` with the train
      # seq_max_length (resolver defined outside this file; presumably a
      # division). NOTE(review): lb[i] presumably pairs with hb[i] — confirm
      # against the dataset implementation.
      lb:
        - ${datamodule.datasets.val.constraints.default_low}
        - ${float_div:${add:${mult:0.25,${datamodule.datasets.train.constraints.default_low}},${mult:0.75,${datamodule.datasets.train.constraints.default_high}}},${datamodule.datasets.train.seq_max_length}}
      hb:
        - ${float_div:${add:${mult:0.75,${datamodule.datasets.train.constraints.default_low}},${mult:0.25,${datamodule.datasets.train.constraints.default_high}}},${datamodule.datasets.train.seq_max_length}}
        - ${datamodule.datasets.val.constraints.default_high}
      # Disabled alternatives kept for reference:
      # lb:
      #   - '0': ${datamodule.datasets.val.constraints.default_low}
      #   - '0': ${float_div:${add:${mult:0.25,${datamodule.datasets.train.constraints.default_low}},${mult:0.75,${datamodule.datasets.train.constraints.default_high}}},${datamodule.datasets.train.seq_max_length}}
      # hb:
      #   - '0': ${float_div:${add:${mult:0.75,${datamodule.datasets.train.constraints.default_low}},${mult:0.25,${datamodule.datasets.train.constraints.default_high}}},${datamodule.datasets.train.seq_max_length}}
      #   - '0': ${datamodule.datasets.val.constraints.default_high}
      # lb:
      #   - '0': ${mult:${datamodule.datasets.train.seq_max_length}, ${datamodule.datasets.train.distribution_config.parameters.a}}
      #   - '0': ${add:${mult:0.25,${datamodule.datasets.train.constraints.default_low}},${mult:0.75,${datamodule.datasets.train.constraints.default_high}}}
      # hb:
      #   - '0': ${add:${mult:0.75,${datamodule.datasets.train.constraints.default_low}},${mult:0.25,${datamodule.datasets.train.constraints.default_high}}}
      #   - '0': ${mult:${datamodule.datasets.train.seq_max_length}, ${datamodule.datasets.train.distribution_config.parameters.b}}

  # Test split: sequence length pinned to the train split's seq_max_length,
  # 100 batches, uniform distribution with a=0, b=1.
  test:
    _target_: ${datamodule.dataset_target_}
    split: test
    seq_min_length: ${datamodule.datasets.train.seq_max_length}
    seq_max_length: ${datamodule.datasets.test.seq_min_length}
    # seq_max_length: ${add_int:${datamodule.datasets.train.seq_max_length},10}
    num_batches: 100
    mixing_type: ${datamodule.mixing_architecture.name}
    mixing_architecture: ${datamodule.mixing_architecture}
    distribution_config:
      name: uniform
      parameters:
        a: 0
        b: 1
    # Follows the train split's switch.
    use_constraints: ${datamodule.datasets.train.use_constraints}
    constraints:
      rejection_sampling: false
      fraction: 1.0
      use_fraction: true
      # Taken directly from the train split's distribution parameters a / b.
      default_low: ${datamodule.datasets.train.distribution_config.parameters.a}
      default_high: ${datamodule.datasets.train.distribution_config.parameters.b}
      # sum of all positions along each component should be higher than lb, and lower than hb
      # Two entries per bound. The outer bounds (lb[0], hb[1]) read this
      # split's own default_low / default_high, so overriding the val split's
      # defaults cannot alter the test bounds. lb[1] and hb[0] reuse the train
      # split's hb and lb expressions, passed through `float_div` with the
      # train seq_max_length (resolver defined outside this file; presumably
      # a division). NOTE(review): lb[i] presumably pairs with hb[i] —
      # confirm against the dataset implementation.
      lb:
        - ${datamodule.datasets.test.constraints.default_low}
        - ${float_div:${add:${mult:0.25,${datamodule.datasets.train.constraints.default_low}},${mult:0.75,${datamodule.datasets.train.constraints.default_high}}},${datamodule.datasets.train.seq_max_length}}
      hb:
        - ${float_div:${add:${mult:0.75,${datamodule.datasets.train.constraints.default_low}},${mult:0.25,${datamodule.datasets.train.constraints.default_high}}},${datamodule.datasets.train.seq_max_length}}
        - ${datamodule.datasets.test.constraints.default_high}
      # Disabled alternatives kept for reference:
      # lb:
      #   - '0': ${datamodule.datasets.test.constraints.default_low}
      #   - '0': ${float_div:${add:${mult:0.25,${datamodule.datasets.train.constraints.default_low}},${mult:0.75,${datamodule.datasets.train.constraints.default_high}}},${datamodule.datasets.train.seq_max_length}}
      # hb:
      #   - '0': ${float_div:${add:${mult:0.75,${datamodule.datasets.train.constraints.default_low}},${mult:0.25,${datamodule.datasets.train.constraints.default_high}}},${datamodule.datasets.train.seq_max_length}}
      #   - '0': ${datamodule.datasets.test.constraints.default_high}
      # lb:
      #   - '0': ${mult:${datamodule.datasets.train.seq_max_length}, ${datamodule.datasets.train.distribution_config.parameters.a}}
      #   - '0': ${add:${mult:0.25,${datamodule.datasets.train.constraints.default_low}},${mult:0.75,${datamodule.datasets.train.constraints.default_high}}}
      # hb:
      #   - '0': ${add:${mult:0.75,${datamodule.datasets.train.constraints.default_low}},${mult:0.25,${datamodule.datasets.train.constraints.default_high}}}
      #   - '0': ${mult:${datamodule.datasets.train.seq_max_length}, ${datamodule.datasets.train.distribution_config.parameters.b}}



# Activation class path; interpolated as the `_target_` of
# mixing_architecture.phi_aggregate.activation. The trailing comment lists
# the candidate values.
activation: torch.nn.Sigmoid # torch.nn.ReLU, torch.nn.LeakyReLU, torch.nn.Sigmoid, torch.nn.Tanh
# Mixing architecture. The whole node is interpolated into each split as
# `mixing_architecture`, and its `name` as `mixing_type`.
mixing_architecture:
  name: softmax_attention
  init_mean: 0.0
  init_std: 0.6
  n_layers: 1
  load: false
  # Targets SoftmaxAttention, configured with x_dim and phi_dim.
  phi_individual:
    _target_: src.models.modules.softmax_attention.SoftmaxAttention
    x_dim: ${datamodule.x_dim}
    phi_dim: ${datamodule.phi_dim}
  # Targets an MLP whose `x_dim` and `hid_dim` are both set to phi_dim and
  # whose `y_dim` is the top-level y_dim.
  phi_aggregate:
    _target_: src.models.modules.mlp.MLP
    x_dim: ${datamodule.phi_dim}
    hid_dim: ${datamodule.phi_dim}
    y_dim: ${datamodule.y_dim}
    n_hidden_layers: 1
    # Class path comes from the top-level `activation` key.
    activation:
      _target_: ${datamodule.activation}
