$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
type: command

# Component identity as registered in Azure ML.
name: run_inference_llm
display_name: Run inference on LLM
# Quoted so the version is always read as a string, never an integer.
version: "11"

description: "Run inference on an LLM"

inputs:
  # Model weights: a base model plus optional PEFT adapter weights.
  base_model:
    type: uri_folder
    description: "Path to the model to use in HF format"
  peft:
    type: uri_folder
    description: "Path to PEFT weights"
    optional: true
  # Passed to inference.py as --tokenized_data_path.
  tokenized_data:
    type: uri_folder
    description: "Path to the tokenized data to run inference on, in HF format."
  # Required: no default is declared.
  per_device_batch_size:
    type: integer
    description: "Batch size per device"
  log_level:
    type: string
    description: "Log level."
    default: info
  # Membership-inference (MI) signal options; each is forwarded to
  # inference.py only when supplied.
  mi_signal_method:
    type: string
    description: "Choose which membership inference signal method"
    optional: true
  mi_signal_extra_args:
    type: string
    description: "Extra MI signal args"
    optional: true
  mi_signal_aggregation:
    type: string
    description: "Aggregation method for MI signal over sequence"
    optional: true
  disable_distributed:
    type: boolean
    description: "Disable distributed execution"
    default: true
outputs:
  # Folder handed to inference.py as --predictions_path.
  predictions:
    description: "Output directory"
    type: uri_folder

# Directory uploaded as the job's code; the command below runs inference.py
# from it.
code: ./

# In a folded (>-) scalar, line breaks next to more-indented lines are kept,
# so the trailing backslashes below act as shell line continuations.
# Each `$[[ ... ]]` segment wraps an optional input and is dropped when that
# input is not provided.
# NOTE(review): mi_signal_extra_args is wrapped in single quotes, so a value
# that itself contains a single quote will break the shell command — confirm
# callers never pass one.
# NOTE(review): disable_distributed is passed as a value; confirm how
# inference.py parses the rendered boolean.
command: >-
  python inference.py \
    --base_model_path ${{inputs.base_model}} \
    $[[ --peft_path ${{inputs.peft}} ]] \
    --per_device_batch_size ${{inputs.per_device_batch_size}} \
    --tokenized_data_path ${{inputs.tokenized_data}} \
    --predictions_path ${{outputs.predictions}} \
    $[[ --mi_signal_method ${{inputs.mi_signal_method}} ]] \
    $[[ --mi_signal_extra_args '${{inputs.mi_signal_extra_args}}' ]] \
    $[[ --mi_signal_aggregation ${{inputs.mi_signal_aggregation}} ]] \
    --log_level ${{inputs.log_level}} \
    --disable_distributed ${{inputs.disable_distributed}}

environment:
  # Conda dependency spec used together with the base Docker image below.
  conda_file: "./environment.yml"
  image: "mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.6-cudnn8-ubuntu20.04"
