# Dataset registry: a single equation-tree dataset exposed as train/val/test splits.
# The `!join` tag is a custom tag that must be registered by the config loader
# (presumably it concatenates the listed items into one string; confirm in the loader).
datasets:
  random_eq:
    # Base name, anchored so the split IDs below are derived from it.
    name: &dataset_name 'random_set'
    type: 'EquationTreeDataset'
    splits:
      train:
        # Anchored ID, so other sections can alias this split.
        dataset_id: &train_split !join [*dataset_name, '/train']
        params:
          tabular_data_file_paths: './resource/datasets/random_set/train/'
          # Anchored once; the val and test splits alias the same directory.
          true_eq_file_paths: &random_true_eq_file_paths './resource/datasets/random_set/true_eq/'
          num_samples_per_eq: 1000
          # Anchored so every alias of this value stays in sync.
          max_num_variables: &max_num_variables 8
          # NOTE(review): unlike val/test, this split sets no `uses_sympy_eq`;
          # confirm that relying on the dataset's default is intentional.
      val:
        dataset_id: &val_split !join [*dataset_name, '/val']
        params:
          tabular_data_file_paths: './resource/datasets/random_set/val/'
          true_eq_file_paths: *random_true_eq_file_paths
          num_samples_per_eq: 1000
          max_num_variables: *max_num_variables
          uses_sympy_eq: true
      test:
        dataset_id: &test_split !join [*dataset_name, '/test']
        params:
          tabular_data_file_paths: './resource/datasets/random_set/test/'
          true_eq_file_paths: *random_true_eq_file_paths
          num_samples_per_eq: 1000
          max_num_variables: *max_num_variables
          uses_sympy_eq: true

# Model section: one 'SymbolicTransformer' entry, its nested architecture config,
# and the experiment name / checkpoint path derived from the dataset name.
models:
  model:
    name: 'SymbolicTransformer'
    params:
      model_config:
        vocabulary:
          # Aliases the value anchored in the train split params (8).
          max_num_variables: *max_num_variables
          # Anchored; reused below as the criterion's `ignore_index`.
          ignored_index: &ignored_index -255
        encoder:
          # Two blocks, each made of two lists of Linear layers (mlp1 then mlp2).
          # NOTE(review): in each block mlp1 ends at 256 features while mlp2 starts
          # at 512, and the second block's mlp1 also starts at 512 although the first
          # block ends at 256 — presumably the model concatenates features between
          # stages; confirm against the encoder implementation.
          block_configs:
            - mlp1_configs:
                - type: 'Linear'
                  kwargs:
                    in_features: 1
                    out_features: 256
                    bias: True
                - type: 'Linear'
                  kwargs:
                    in_features: 256
                    out_features: 256
                    bias: True
              mlp2_configs:
                - type: 'Linear'
                  kwargs:
                    in_features: 512
                    out_features: 512
                    bias: True
                - type: 'Linear'
                  kwargs:
                    in_features: 512
                    out_features: 256
                    bias: True
            - mlp1_configs:
                - type: 'Linear'
                  kwargs:
                    in_features: 512
                    out_features: 256
                    bias: True
                - type: 'Linear'
                  kwargs:
                    in_features: 256
                    out_features: 256
                    bias: True
              mlp2_configs:
                - type: 'Linear'
                  kwargs:
                    in_features: 512
                    out_features: 512
                    bias: True
                - type: 'Linear'
                  kwargs:
                    in_features: 512
                    out_features: 256
                    bias: True
        symbol_embedding:
          # Left empty, so it parses as null.
          # NOTE(review): presumably filled in at runtime — confirm.
          num_embeddings:
          # Anchored; the decoder layer and the first MLP layer alias this width.
          embedding_dim: &embedding_dim 512
        decoder:
          layer:
            d_model: *embedding_dim
            nhead: 4
            # Feed-forward width is tied to the embedding width (both 512).
            dim_feedforward: *embedding_dim
            batch_first: True
          # Left empty, so it parses as null.
          norm:
          kwargs:
            num_layers: 4
        # Linear stack after the decoder: 512 -> 256 -> 256.
        mlp:
          - type: 'Linear'
            kwargs:
              in_features: *embedding_dim
              out_features: 256
              bias: True
          - type: 'Linear'
            kwargs:
              in_features: 256
              # Anchored so the classifier input width follows this layer's output.
              out_features: &out_features 256
              bias: True
        classifier:
          in_features: *out_features
        predict_func: 'default_predict'
        default_max_length: 20
        criterion:
          type: 'CrossEntropyLoss'
          params:
            reduction: 'mean'
            # Same value as vocabulary.ignored_index (-255).
            ignore_index: *ignored_index
    # Experiment name built from the dataset name via the custom `!join` tag;
    # reused for the checkpoint path below.
    experiment: &experiment !join [*dataset_name, '-symbolic_transformer']
    ckpt: !join ['./resource/ckpt/pretraining/', *experiment, '-pretraining.pt']

# Training stage: data loaders, model wrapping, mixed precision, optimizer,
# learning-rate schedule and loss composition.
train:
  log_freq: 10
  num_epochs: &num_epochs 3
  train_data_loader:
    # Aliases the train split ID anchored in the datasets section.
    dataset_id: *train_split
    random_sample: true
    batch_size: 1
    num_workers: 8
    requires_supp: false
    collate_fn: 'default_collate_w_sympy'
    # Explicit null (previously a bare empty value, which parses identically).
    cache_output: null
  val_data_loader:
    dataset_id: *val_split
    random_sample: false
    batch_size: 1
    num_workers: 8
    collate_fn: 'default_collate_w_sympy'
  model:
    forward_proc: 'forward_batch_target'
    # No forward hooks are registered on module inputs or outputs.
    forward_hook:
      input: []
      output: []
    wrapper: 'DistributedDataParallel'
    requires_grad: true
  apex:
    # Mixed precision is disabled; `opt_level` only matters if this is enabled.
    requires: false
    # Apex AMP optimization levels are 'O0'..'O3', spelled with the letter O.
    # The previous value '01' (digit zero) is not a valid level.
    opt_level: 'O1'
  optimizer:
    type: 'SGD'
    params:
      lr: 0.1
      weight_decay: 0.0001
    max_grad_norm: 1.0
    grad_accum_step: 1
  scheduler:
    type: 'get_linear_schedule_with_warmup'
    params:
      num_warmup_steps: 1000
      # Explicit null (previously a bare empty value, which parses identically).
      # NOTE(review): presumably computed at runtime by the training script,
      # since a linear schedule needs the total step count — confirm.
      num_training_steps: null
    scheduling_step: 1
  criterion:
    type: 'GeneralizedCustomLoss'
    func2extract_org_loss: 'extract_internal_org_loss'
    org_term:
      factor: 1.0
    # Explicit null: no auxiliary loss terms are configured.
    sub_terms: null

# Evaluation stage: test data loader and output locations for equation trees.
test:
  test_data_loader:
    # Aliases the test split ID anchored in the datasets section.
    dataset_id: *test_split
    random_sample: false
    batch_size: 1
    num_workers: 8
    collate_fn: 'default_collate_w_sympy'
  # Predicted-tree path is built from the experiment name via the custom `!join` tag.
  pred_tree_output: !join ['./resource/trees/test/', *experiment]
  true_tree_output: './resource/trees/test/true_eq/'
