# Hydra-style defaults list: start from the `_preprocess` config, then apply
# this file (`_self_` is listed last, so the keys below take precedence).
defaults:
  - _preprocess
  - _self_

# Both directories are read from environment variables via the OmegaConf
# `oc.env` resolver; no fallback default is given for either one.
input_dir: '${oc.env:MEDS_DIR}'
cohort_dir: '${oc.env:MODEL_DIR}'
# Names of the preprocessing stages, presumably run top to bottom (confirm
# against the pipeline runner). Per-stage settings, where this file provides
# any, live under the same name in `stage_configs`.
stages:
  - custom_time_token
  - count_codes
  - filter_measurements
  - filter_subjects
  - fit_quantile_binning
  - quantile_binning
  # NOTE(review): spelled `reorder_measurments` (missing an "e"). The
  # `stage_configs` key uses the same spelling, so the two must be renamed
  # together, if at all.
  - reorder_measurments
  - cache_decoding_metadata
  - split_quantiles
  - fit_normalization
  - fit_vocabulary_indices
  - normalization
  - tokenization
  - tensorization

# Per-stage settings keyed by stage name, listed in the same order as the
# `stages` sequence. Stages without an entry here receive no overrides from
# this file.
stage_configs:
  custom_time_token:
    _script: 'python -m meds_torch.utils.custom_time_token'
    time_delta:
      time_unit: years

  count_codes:
    _script: 'MEDS_transform-aggregate_code_metadata'
    aggregations:
      - code/n_occurrences
      - code/n_subjects

  fit_quantile_binning:
    _script: 'MEDS_transform-aggregate_code_metadata'
    aggregations:
      - code/n_occurrences
      - code/n_subjects
      - name: values/quantiles
        quantiles:
          - 0.25
          - 0.5
          - 0.75

  quantile_binning:
    _script: 'python -m meds_torch.utils.quantile_binning'
    # Explicit bin boundaries for the time-delta token code.
    # NOTE(review): four boundaries are given here while fit_quantile_binning
    # requests three quantiles, and custom_time_token sets time_unit to years
    # while 30 / 365 look like day counts -- confirm both are intended.
    custom_quantiles:
      TIME//DELTA//TOKEN:
        values/quantile/1: 0.1
        values/quantile/2: 1
        values/quantile/3: 30
        values/quantile/4: 365

  # The key spelling must stay identical to the entry in `stages`.
  reorder_measurments:
    _script: 'MEDS_transform-reorder_measurements'
    # Shell-style command substitution wrapped around an OmegaConf env
    # interpolation; presumably expanded by the shell that launches the stage.
    order: '$(python -m meds_torch.utils.get_all_measurements metadata_fp=${oc.env:MODEL_DIR}/metadata/codes.parquet)'

  cache_decoding_metadata:
    _script: 'MEDS_transform-aggregate_code_metadata'
    aggregations:
      - values/sum
      - values/sum_sqd
      - values/n_occurrences

  split_quantiles:
    _script: 'python -m meds_torch.utils.split_quantiles'
    # Pattern matches any code beginning with the time-delta token prefix.
    ignore_regex: '^TIME//DELTA//TOKEN.*$'

  fit_normalization:
    _script: 'MEDS_transform-aggregate_code_metadata'
    aggregations:
      - code/n_occurrences
      - values/min
      - values/max
      - values/sum
      - values/sum_sqd
      - values/n_occurrences

  tokenization:
    _script: 'python -m meds_torch.utils.custom_tokenization'
