# Dataset definitions. Anchors declared in this section (dataset_name,
# max_num_variables, test_split, ...) are aliased by later sections.
datasets:
  # NOTE(review): this key says "easy" while the name and every path below
  # point at 'srsd-feynman_medium' -- looks like a copy-paste leftover.
  # Confirm whether the key is read anywhere before renaming it.
  feynman_easy:
    name: &dataset_name 'srsd-feynman_medium'
    type: 'EquationTreeDataset'
    splits:
      # Unlike val/test, the train split does not set uses_sympy_eq.
      train:
        # !join is a custom tag; presumably it concatenates the listed items
        # into a single string -- verify against the config loader.
        dataset_id: &train_split !join [*dataset_name, '/train']
        params:
          tabular_data_file_paths: './resource/datasets/srsd-feynman_medium/train/'
          # The true-equation directory is shared by all three splits.
          true_eq_file_paths: &true_eq_file_paths './resource/datasets/srsd-feynman_medium/true_eq/'
          num_samples_per_eq: 1000
          # Also aliased by the model's vocabulary settings.
          max_num_variables: &max_num_variables 8
      val:
        dataset_id: &val_split !join [*dataset_name, '/val']
        params:
          tabular_data_file_paths: './resource/datasets/srsd-feynman_medium/val/'
          true_eq_file_paths: *true_eq_file_paths
          num_samples_per_eq: 1000
          max_num_variables: *max_num_variables
          uses_sympy_eq: true
      test:
        # Aliased by test.test_data_loader.dataset_id.
        dataset_id: &test_split !join [*dataset_name, '/test']
        params:
          tabular_data_file_paths: './resource/datasets/srsd-feynman_medium/test/'
          true_eq_file_paths: *true_eq_file_paths
          num_samples_per_eq: 1000
          max_num_variables: *max_num_variables
          uses_sympy_eq: true

# Model definition: architecture hyper-parameters plus the checkpoint path.
models:
  model:
    name: 'SymbolicTransformer'
    params:
      model_config:
        vocabulary:
          max_num_variables: *max_num_variables
          # The same value is passed to the criterion as ignore_index below.
          ignored_index: &ignored_index -255
        encoder:
          # Two encoder blocks; each lists an mlp1 stack and an mlp2 stack of
          # Linear layers.
          block_configs:
            # Block 1: mlp1 starts from in_features 1.
            - mlp1_configs:
                - type: 'Linear'
                  kwargs:
                    in_features: 1
                    out_features: 256
                    bias: true
                - type: 'Linear'
                  kwargs:
                    in_features: 256
                    out_features: 256
                    bias: true
              # NOTE(review): mlp2 takes 512 features although mlp1 emits
              # 256; presumably the model widens the features in between --
              # confirm against the encoder implementation.
              mlp2_configs:
                - type: 'Linear'
                  kwargs:
                    in_features: 512
                    out_features: 512
                    bias: true
                - type: 'Linear'
                  kwargs:
                    in_features: 512
                    out_features: 256
                    bias: true
            # Block 2: same layout, but mlp1 starts from in_features 512.
            - mlp1_configs:
                - type: 'Linear'
                  kwargs:
                    in_features: 512
                    out_features: 256
                    bias: true
                - type: 'Linear'
                  kwargs:
                    in_features: 256
                    out_features: 256
                    bias: true
              mlp2_configs:
                - type: 'Linear'
                  kwargs:
                    in_features: 512
                    out_features: 512
                    bias: true
                - type: 'Linear'
                  kwargs:
                    in_features: 512
                    out_features: 256
                    bias: true
        symbol_embedding:
          # Deliberately null in this file; presumably filled in by the model
          # code (e.g. from the vocabulary size) -- TODO confirm.
          num_embeddings: null
          # Reused as the decoder's d_model / dim_feedforward and as the
          # input width of the mlp head.
          embedding_dim: &embedding_dim 512
        decoder:
          # NOTE(review): these keys look like Transformer decoder-layer
          # keyword arguments -- confirm against the model code.
          layer:
            d_model: *embedding_dim
            nhead: 4
            dim_feedforward: *embedding_dim
            batch_first: true
          # No norm configuration is given (null).
          norm: null
          kwargs:
            num_layers: 4
        mlp:
          - type: 'Linear'
            kwargs:
              in_features: *embedding_dim
              out_features: 256
              bias: true
          - type: 'Linear'
            kwargs:
              in_features: 256
              # Anchored so the classifier input width follows this value.
              out_features: &out_features 256
              bias: true
        classifier:
          in_features: *out_features
        predict_func: 'default_predict'
        default_max_length: 20
        criterion:
          type: 'CrossEntropyLoss'
          params:
            reduction: 'mean'
            ignore_index: *ignored_index
    # The experiment name is also aliased by test.pred_tree_output.
    experiment: &experiment 'random_set-symbolic_transformer'
    # !join is a custom tag; presumably it concatenates the items into one
    # checkpoint path -- verify against the config loader.
    ckpt: !join ['./resource/ckpt/pretraining/', *experiment, '-pretraining.pt']

# Optimizer, scheduler and loss settings grouped under coeff_estimate.
coeff_estimate:
  num_epochs: 50
  optimizer:
    type: 'Adam'
    params:
      lr: 0.1
  scheduler:
    type: 'MultiStepLR'
    params:
      # NOTE(review): presumably epoch indices at which lr is multiplied by
      # gamma -- confirm against the scheduler implementation.
      milestones:
        - 35
        - 45
      gamma: 0.1
    scheduling_step: 1
  criterion:
    type: 'RelativeSquaredError'
    # No extra criterion parameters (explicit empty mapping).
    params: {}
  # Output location built from the dataset name declared under datasets.
  pred_eq_output: !join ['./resource/pred_pickles/', *dataset_name]

# Test-time data loading and output locations.
test:
  test_data_loader:
    # Points at the test split id declared under datasets.
    dataset_id: *test_split
    random_sample: false
    batch_size: 1
    num_workers: 8
    collate_fn: 'default_collate_w_sympy'
  # Predicted trees are written under a directory named after the experiment.
  pred_tree_output: !join ['./resource/trees/test/', *experiment]
  true_tree_output: './resource/trees/test/true_eq/'
