---
# Azure ML pipeline definition. The jobs below chain three command
# components: preprocess -> tokenize -> inference.
$schema: 'https://azuremlschemas.azureedge.net/latest/pipelineJob.schema.json'
type: pipeline
name: dp_transformers_privacy_auditing_inference
display_name: 'Run inference on LLM'

# Pipeline-level inputs. Each one is forwarded to a job via
# ${{parent.inputs.<name>}}; the comments note which job consumes it.
inputs:
  # Model folder; used as `tokenizer_path` by the tokenize job and as
  # `base_model` by the inference job.
  base_model:
    type: uri_folder
  # Optional folder forwarded to the inference job's `peft` input.
  peft:
    type: uri_folder
    optional: true
  # Input file handed to the preprocess job as `input_data_path`.
  data:
    type: uri_file
  # Forwarded to the inference job.
  per_device_batch_size:
    type: integer
  # Forwarded to the tokenize job.
  sequence_len:
    type: integer
  # The next three inputs are forwarded to the preprocess job.
  text_column:
    type: string
  label_column:
    type: string
  templated_prompt:
    type: string
  # The three optional `mi_signal_*` inputs are forwarded unchanged to the
  # inference job.
  mi_signal_method:
    type: string
    optional: true
  mi_signal_extra_args:
    type: string
    optional: true
  mi_signal_aggregation:
    type: string
    optional: true
  # Forwarded to the tokenize job's `packed_dataset` input.
  pack_dataset:
    type: boolean
    # Canonical lowercase boolean, matching `disable_distributed` below.
    default: false
  # Forwarded to the inference job.
  disable_distributed:
    type: boolean
    default: true

# Pipeline-level outputs.
outputs:
  # Folder output; the inference job's `predictions` output is bound to it
  # via ${{parent.outputs.predictions}}.
  predictions:
    type: uri_folder

# Job graph: preprocess -> tokenize -> inference, linked through job outputs.
jobs:
  # Feeds the pipeline `data` file and the column/prompt settings to the
  # preprocess_for_synthesizer component.
  # NOTE(review): this component path climbs one directory (`../components`)
  # while the other jobs climb three (`../../../components`) - confirm the
  # difference is intentional.
  # NOTE(review): no `compute` is set on this job; presumably a default
  # compute is supplied when the pipeline is submitted - confirm.
  preprocess:
    type: command
    component: ../components/preprocess_for_synthesizer/component_spec.yml
    inputs:
      input_data_path: ${{parent.inputs.data}}
      text_column: ${{parent.inputs.text_column}}
      label_column: ${{parent.inputs.label_column}}
      templated_prompt: ${{parent.inputs.templated_prompt}}
    outputs:
      # Consumed by the tokenize job as its `data` input.
      output_data_path:
        mode: rw_mount
  # Tokenizes the output of the preprocess job. `tokenizer_path` is bound to
  # the pipeline's `base_model` folder.
  tokenize:
    type: command
    component: ../../../components/tokenize/component_spec.yml
    # The only job in this pipeline that pins an explicit compute target.
    compute: azureml:D64sv3
    inputs:
      tokenizer_path: ${{parent.inputs.base_model}}
      data: ${{parent.jobs.preprocess.outputs.output_data_path}}
      sequence_len: ${{parent.inputs.sequence_len}}
      # Pipeline input `pack_dataset` maps onto the component's
      # `packed_dataset` input (the names differ).
      packed_dataset: ${{parent.inputs.pack_dataset}}
      # Fixed literal, not exposed as a pipeline input. Written as a
      # canonical lowercase boolean.
      chat_format: false
    outputs:
      # Consumed by the inference job as its `tokenized_data` input.
      tokenized_data:
        mode: rw_mount
  # Runs the inference component on the tokenized data and writes to the
  # pipeline-level `predictions` output.
  # NOTE(review): no `compute` is set on this job; presumably a default
  # compute is supplied when the pipeline is submitted - confirm.
  inference:
    type: command
    component: ../../../components/inference/component_spec.yml
    inputs:
      # Required pipeline inputs and the upstream tokenize output.
      base_model: ${{parent.inputs.base_model}}
      tokenized_data: ${{parent.jobs.tokenize.outputs.tokenized_data}}
      per_device_batch_size: ${{parent.inputs.per_device_batch_size}}
      disable_distributed: ${{parent.inputs.disable_distributed}}
      # Declared `optional: true` at pipeline level.
      # NOTE(review): the component presumably declares these optional as
      # well - confirm.
      peft: ${{parent.inputs.peft}}
      mi_signal_method: ${{parent.inputs.mi_signal_method}}
      mi_signal_extra_args: ${{parent.inputs.mi_signal_extra_args}}
      mi_signal_aggregation: ${{parent.inputs.mi_signal_aggregation}}
    outputs:
      predictions: ${{parent.outputs.predictions}}
