# Hydra defaults list: `_preprocess` is composed first and `_self_` last, so
# values defined in this file take precedence over those from `_preprocess`.
defaults:
  - _preprocess
  - _self_
# Both directories are resolved from environment variables through OmegaConf's
# `oc.env` resolver. No fallback value is given, so MEDS_DIR and MODEL_DIR must
# be set before this config is resolved.
# NOTE(review): presumably `input_dir` is the source MEDS dataset and
# `cohort_dir` is where stage outputs are written - confirm.
input_dir: ${oc.env:MEDS_DIR}
cohort_dir: ${oc.env:MODEL_DIR}
# Names of the pipeline stages.
# NOTE(review): presumably executed top to bottom by the pipeline runner -
# confirm. Stages with no entry under `stage_configs` in the visible part of
# this file (fit_vocabulary_indices, normalization, tensorization) presumably
# take their settings from `_preprocess`.
stages:
  - custom_time_token
  - count_codes
  - custom_filter_measurements
  - filter_subjects
  - fit_quantile_binning
  - quantile_binning
  # Spelled "measurments" here and in the matching `stage_configs` key; the
  # two names must stay identical.
  - reorder_measurments
  - cache_decoding_metadata
  - split_quantiles
  - fit_normalization
  - fit_vocabulary_indices
  - normalization
  - tokenization
  - tensorization

# Per-stage settings, keyed by the stage names listed under `stages`.
stage_configs:
  # No `_script` is given for this stage, so it is resolved by name.
  # NOTE(review): presumably subjects with fewer than 10 events or fewer than
  # 10 measurements are dropped - confirm against the stage implementation.
  filter_subjects:
    min_events_per_subject: 10
    min_measurements_per_subject: 10
  # Measurement filter run through a meds_torch module (see `_script`).
  custom_filter_measurements:
    _script: "python -m meds_torch.utils.custom_filter_measurements"
    # Regular expressions over measurement codes.
    # NOTE(review): presumably codes matching any of these patterns are kept
    # regardless of `min_subjects_per_code` - confirm in
    # meds_torch.utils.custom_filter_measurements.
    additional_codes:
      # Birth / death markers.
      - '^MEDS_DEATH$'
      - '^MEDS_BIRTH$'
      # Hospital, emergency department and ICU stay boundaries.
      - '^HOSPITAL_ADMISSION//.*'
      - '^HOSPITAL_DISCHARGE//.*'
      - '^ED_REGISTRATION//.*'
      - '^ED_OUT//.*'
      - '^ICU_ADMISSION//.*'
      - '^ICU_DISCHARGE//.*'
      # Time tokens and gender.
      - '^TIME.*$'
      - '^GENDER//.*$'
      # Lab codes, written as LAB//<item id>//<unit>.
      - '^LAB//51221//%$'
      - '^LAB//51265//K/uL$'
      - '^LAB//51222//g/dL$'
      - '^LAB//51301//K/uL$'
      - '^LAB//51248//pg$'
      - '^LAB//51250//fL$'
      - '^LAB//51279//m/uL$'
      - '^LAB//51277//%$'
      - '^LAB//50912//mg/dL$'
      - '^LAB//50920//UNK$'
    # NOTE(review): a threshold this high presumably excludes every code that
    # is not covered by `additional_codes` - confirm this is intended.
    min_subjects_per_code: 1000000
  # Time-token stage run through a meds_torch module (see `_script`).
  custom_time_token:
    _script: "python -m meds_torch.utils.custom_time_token"
    time_delta:
      # The TIME//DELTA//TOKEN thresholds under
      # `quantile_binning.custom_quantiles` are annotated as fractions of a
      # year, so they need to be revisited if this unit changes.
      time_unit: years
  # Code-metadata aggregation pass; listed in `stages` directly before
  # `custom_filter_measurements`.
  # NOTE(review): presumably the per-code subject counts gathered here feed
  # that filter's `min_subjects_per_code` check - confirm.
  count_codes:
    _script: "MEDS_transform-aggregate_code_metadata"
    aggregations:
      - code/n_occurrences
      - code/n_subjects
  # Code-metadata aggregation pass that additionally requests value quantiles;
  # listed in `stages` directly before `quantile_binning`.
  fit_quantile_binning:
    _script: "MEDS_transform-aggregate_code_metadata"
    aggregations:
      - code/n_occurrences
      - code/n_subjects
      # Nine cut points at the deciles 0.1 through 0.9.
      - name: values/quantiles
        quantiles:
          - 0.1
          - 0.2
          - 0.3
          - 0.4
          - 0.5
          - 0.6
          - 0.7
          - 0.8
          - 0.9
  # Code-metadata aggregation pass collecting the sum, sum of squares and
  # occurrence count of values.
  # NOTE(review): presumably retained so per-code mean / standard deviation can
  # be recovered when decoding - confirm with the consumer of this metadata.
  cache_decoding_metadata:
    _script: "MEDS_transform-aggregate_code_metadata"
    aggregations:
      - values/sum
      - values/sum_sqd
      - values/n_occurrences
  # Code-metadata aggregation pass listed in `stages` after `split_quantiles`
  # and before `normalization`: code occurrence counts plus min, max, sum, sum
  # of squares and occurrence count of values.
  fit_normalization:
    _script: "MEDS_transform-aggregate_code_metadata"
    aggregations:
      - code/n_occurrences
      - values/min
      - values/max
      - values/sum
      - values/sum_sqd
      - values/n_occurrences
  # Key is spelled "measurments" to match the entry in `stages`; the script it
  # runs is the correctly spelled MEDS_transform-reorder_measurements.
  reorder_measurments:
    _script: "MEDS_transform-reorder_measurements"
    # `${oc.env:MODEL_DIR}` is interpolated by OmegaConf (the same variable
    # backs `cohort_dir`); the surrounding `$( ... )` is shell command
    # substitution syntax and is an ordinary string as far as YAML is concerned.
    # NOTE(review): this only yields a list of measurements if the value is
    # later passed through a shell - confirm how the runner launches the stage.
    order: $(python -m meds_torch.utils.get_all_measurements metadata_fp=${oc.env:MODEL_DIR}/metadata/codes.parquet)
  # Quantile-splitting stage run through a meds_torch module (see `_script`).
  split_quantiles:
    _script: "python -m meds_torch.utils.split_quantiles"
    # Pattern matches every code that starts with TIME//DELTA//TOKEN.
    # NOTE(review): presumably matching codes are left untouched by this
    # stage - confirm in meds_torch.utils.split_quantiles.
    ignore_regex: "^TIME//DELTA//TOKEN.*$"
  # Quantile-binning stage run through a meds_torch module (see `_script`).
  quantile_binning:
    _script: "python -m meds_torch.utils.quantile_binning"
    # Hand-specified thresholds for the TIME//DELTA//TOKEN code.
    # NOTE(review): presumably used in place of fitted quantiles for this
    # code - confirm in meds_torch.utils.quantile_binning.
    custom_quantiles:
      # Values are in years (see `custom_time_token.time_delta.time_unit`):
      # each sub-year threshold equals the annotated duration divided by a
      # 365-day year, e.g. 1 day = 1 / 365 = 0.00273972602.
      # There is no threshold between 120 days and 1 year.
      TIME//DELTA//TOKEN:
        values/quantile/1: 0.00000190258 # 1 minute
        values/quantile/2: 0.00000951293 # 5 minutes
        values/quantile/3: 0.00001902587 # 10 minutes
        values/quantile/4: 0.00005707762 # 30 minutes
        values/quantile/5: 0.00011415525 # 1 hour
        values/quantile/6: 0.00034246575 # 3 hours
        values/quantile/7: 0.0006849315 # 6 hours
        values/quantile/8: 0.00136986301 # 12 hours
        values/quantile/9: 0.00273972602 # 1 day
        values/quantile/10: 0.00547945205 # 2 days
        values/quantile/11: 0.0109589041 # 4 days
        values/quantile/12: 0.01917808219 # 7 days
        values/quantile/13: 0.03835616438 # 14 days
        values/quantile/14: 0.08219178082 # 30 days
        values/quantile/15: 0.16438356164 # 60 days
        values/quantile/16: 0.32876712328 # 120 days
        # From here on the values are whole years and parse as integers,
        # whereas the entries above parse as floats.
        values/quantile/17: 1 # 1 year
        values/quantile/18: 2 # 2 years
        values/quantile/19: 5 # 5 years
        values/quantile/20: 10 # 10 years
        values/quantile/21: 20 # 20 years
        values/quantile/22: 40 # 40 years
  # Runs the meds_torch custom tokenization module for the `tokenization`
  # stage; no other settings are given for it here.
  tokenization:
    _script: "python -m meds_torch.utils.custom_tokenization"
