# Compose the shared `_preprocess` config first. `_self_` is listed last so
# that the values defined in this file are merged after it.
defaults:
  - _preprocess
  - _self_
# Both directories are resolved from environment variables through the
# `oc.env` resolver: input_dir from MEDS_DIR, cohort_dir from MODEL_DIR.
input_dir: "${oc.env:MEDS_DIR}"
cohort_dir: "${oc.env:MODEL_DIR}"
# Names of the pipeline stages (presumably executed in the order listed).
# Each `fit_*` stage is listed before its same-named counterpart
# (fit_quantile_binning -> quantile_binning, fit_normalization ->
# normalization). Stages that need settings have a same-named entry under
# `stage_configs`; the others (e.g. fit_vocabulary_indices, normalization,
# tensorization) have no entry in this file.
stages:
  - custom_time_token
  - count_codes
  - custom_filter_measurements
  - filter_subjects
  - fit_quantile_binning
  - quantile_binning
  # NOTE(review): "measurments" is missing an "e". The same spelling is used
  # for the matching `stage_configs` key, so rename both together or neither.
  - reorder_measurments
  - fit_normalization
  - fit_vocabulary_indices
  - normalization
  - tokenization
  - tensorization

stage_configs:
  # Runs the meds_torch `custom_time_token` module as this stage's script.
  custom_time_token:
    _script: "python -m meds_torch.utils.custom_time_token"
    time_delta:
      # Time deltas use years as their unit. The TIME//DELTA//TOKEN bin edges
      # under `quantile_binning.custom_quantiles` are annotated in that unit.
      time_unit: years
  # Thresholds for the filter_subjects stage. No `_script` override is set
  # here, unlike most other stages configured in this file.
  # NOTE(review): presumably subjects with fewer than 10 events or fewer than
  # 10 measurements are dropped - confirm against the stage implementation.
  filter_subjects:
    min_events_per_subject: 10
    min_measurements_per_subject: 10
  # Per-code counts computed with the MEDS_transform-aggregate_code_metadata
  # script: occurrences of each code and number of subjects per code.
  count_codes:
    _script: "MEDS_transform-aggregate_code_metadata"
    aggregations:
      - code/n_occurrences
      - code/n_subjects
  # Measurement filter, run through the meds_torch
  # `custom_filter_measurements` module.
  custom_filter_measurements:
    _script: "python -m meds_torch.utils.custom_filter_measurements"
    # Code patterns written as regular expressions.
    # NOTE(review): presumably codes matching these patterns are kept
    # regardless of `min_subjects_per_code` - confirm against the module.
    additional_codes:
      # Death and birth markers.
      - "^MEDS_DEATH$"
      - "^MEDS_BIRTH$"
      # Hospital, ED and ICU entry/exit events (any suffix after `//`).
      - "^HOSPITAL_ADMISSION//.*"
      - "^HOSPITAL_DISCHARGE//.*"
      - "^ED_REGISTRATION//.*"
      - "^ED_OUT//.*"
      - "^ICU_ADMISSION//.*"
      - "^ICU_DISCHARGE//.*"
      # Time tokens and gender codes.
      - "^TIME.*$"
      - "^GENDER//.*$"
      # Lab codes of the form LAB//<id>//<unit>. Most entries join several
      # ids with `|` alternation. Unlike the entries above, these patterns
      # carry no ^/$ anchors.
      - "LAB//52546//mg/dL|LAB//50912//mg/dL|LAB//51081//mg/dL|LAB//52024//mg/dL"
      - "LAB//51645//g/dL|LAB//51222//g/dL|LAB//50811//g/dL|LAB//50855//UNK"
      - "LAB//51221//%|LAB//51638//%|LAB//50810//%"
      - "LAB//51300//K/uL|LAB//51301//K/uL|LAB//51755//K/uL"
      - "LAB//51265//K/uL"
    # NOTE(review): a floor of one million subjects per code looks deliberately
    # unreachable, so that only `additional_codes` matches survive - confirm.
    min_subjects_per_code: 1000000
  # Runs MEDS_transform-reorder_measurements.
  # NOTE(review): the key is spelled "measurments" (missing an "e") to match
  # the entry in the top-level `stages` list; rename both together or neither.
  reorder_measurments:
    _script: "MEDS_transform-reorder_measurements"
    # The value is a single-line string (the folded scalar below joins its two
    # lines with one space). It is written as a shell command substitution,
    # $(...), around the meds_torch `get_all_measurements` module, which is
    # given the codes.parquet metadata file under MODEL_DIR.
    # NOTE(review): presumably expanded by the shell that launches the stage,
    # so the stage receives the command's output - confirm.
    order: >-
      $(python -m meds_torch.utils.get_all_measurements
      metadata_fp=${oc.env:MODEL_DIR}/metadata/codes.parquet)
  # Per-code statistics computed with MEDS_transform-aggregate_code_metadata:
  # code occurrence count plus, for numeric values, min, max, sum, sum of
  # squares and value count.
  # NOTE(review): presumably these feed the later `normalization` stage
  # (mean/std can be derived from sum, sum_sqd and n_occurrences) - confirm.
  fit_normalization:
    _script: "MEDS_transform-aggregate_code_metadata"
    aggregations:
      - code/n_occurrences
      - values/min
      - values/max
      - values/sum
      - values/sum_sqd
      - values/n_occurrences
  # Per-code counts and value quantiles computed with
  # MEDS_transform-aggregate_code_metadata.
  fit_quantile_binning:
    _script: "MEDS_transform-aggregate_code_metadata"
    aggregations:
      - code/n_occurrences
      - code/n_subjects
      # Nine cut points at the deciles 0.1 through 0.9.
      - name: values/quantiles
        quantiles: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
  # Runs the meds_torch `quantile_binning` module.
  quantile_binning:
    _script: "python -m meds_torch.utils.quantile_binning"
    # Hand-picked bin edges for specific codes.
    # NOTE(review): presumably these replace the fitted quantiles for the
    # listed code - confirm against the module.
    custom_quantiles:
      # Edges for the time-delta token, expressed in years on a 365-day year
      # (e.g. 1 day = 1/365 = 0.00273972602), matching `time_unit: years`
      # under `custom_time_token`. The trailing comment on each line gives the
      # human-readable duration.
      TIME//DELTA//TOKEN:
        values/quantile/1: 0.00000190258  # 1 minute
        values/quantile/2: 0.00000951293  # 5 minutes
        values/quantile/3: 0.00001902587  # 10 minutes
        values/quantile/4: 0.00005707762  # 30 minutes
        values/quantile/5: 0.00011415525  # 1 hour
        values/quantile/6: 0.00034246575  # 3 hours
        values/quantile/7: 0.0006849315  # 6 hours
        values/quantile/8: 0.00136986301  # 12 hours
        values/quantile/9: 0.00273972602  # 1 day
        values/quantile/10: 0.00547945205  # 2 days
        values/quantile/11: 0.0109589041  # 4 days
        values/quantile/12: 0.01917808219  # 7 days
        values/quantile/13: 0.03835616438  # 14 days
        values/quantile/14: 0.08219178082  # 30 days
        values/quantile/15: 0.16438356164  # 60 days
        values/quantile/16: 0.32876712328  # 120 days
        values/quantile/17: 1  # 1 year
        values/quantile/18: 2  # 2 years
        values/quantile/19: 5  # 5 years
        values/quantile/20: 10  # 10 years
        values/quantile/21: 20  # 20 years
        values/quantile/22: 40  # 40 years
  # Replaces the tokenization stage's script with the meds_torch
  # `custom_tokenization` module; no other settings are overridden here.
  tokenization:
    _script: "python -m meds_torch.utils.custom_tokenization"
