$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
# This file defines a component of type `command`; the schema URL above is the
# one the definition is validated against.
type: command

# Component identity and human-readable labels.
name: compute-utility
display_name: compute utility
description: 'Compute utility'

# Inputs of the component. Inputs without a `default` and without
# `optional: true` carry no fallback value in this file.
inputs:
  # --- Launch and model selection -------------------------------------------
  # Passed to `torch.distributed.run --nproc_per_node` by the command.
  nproc_per_node:
    type: integer
    description: "Number of processes per node. Usually this should be the number of GPUs per node"
  # The only optional input; the command wraps its flag in `$[[ ... ]]`.
  # NOTE(review): the description refers to a 'model_path' option that is not
  # declared in this component - confirm it exists in compute_utility.py.
  model_name:
    type: string
    description: "Name of the HF model to use for sequence classification. Mutually exclusive with the option 'model_path'"
    default: "roberta-base"
    optional: true

  # --- Training data ----------------------------------------------------------
  is_synthetic:
    type: boolean
    description: "Whether the training data is synthetic or not."
  utility_train_data_path:
    type: uri_folder
    description: "Path to the training dataset in Huggingface dataset format."
  train_label_name:
    type: string
    description: "Name of the label column in the training dataset"
    default: "Prompt"
  train_text_name:
    type: string
    description: "Name of the text column in the training dataset"
    default: "Generation"

  # --- Evaluation data --------------------------------------------------------
  utility_eval_data_path:
    type: uri_folder
    description: "Path to the eval dataset in Huggingface dataset format."
  eval_label_name:
    type: string
    description: "Name of the label column in the eval dataset"
    default: "label"
  eval_text_name:
    type: string
    description: "Name of the text column in the eval dataset"
    default: "sentence"

  # --- Tokenization and batching ---------------------------------------------
  sequence_len:
    type: integer
    description: "Maximum token sequence length."
  per_device_train_batch_size:
    type: integer
    description: "Batch size per device for training."
  gradient_accumulation_steps:
    type: integer
    description: "Number of gradient accumulation steps."

  # --- Evaluation, checkpointing and logging ---------------------------------
  evaluation_strategy:
    type: string
    description: "Evaluation strategy."
    default: "steps"
  evaluation_steps:
    type: integer
    description: "Evaluation steps."
    default: 500
  # Quoted on purpose: a bare `no` is read as a boolean by YAML 1.1 parsers.
  save_strategy:
    type: string
    description: "Save strategy."
    default: "no"
  log_level:
    type: string
    description: "Log level."
    default: "info"

  # --- Optimization -----------------------------------------------------------
  seed:
    type: integer
    description: "Random seed."
  weight_decay:
    type: number
    description: "Weight decay."
    default: 0.01
  num_train_epochs:
    type: integer
    description: "Number of training epochs."
  logging_steps:
    type: integer
    description: "Log every x steps."
    default: 5
  max_grad_norm:
    type: number
    description: "Max batch gradient norm."
    default: 0
  lr_scheduler_type:
    type: string
    description: "Learning rate scheduler type."
    default: "constant"
  learning_rate:
    type: number
    description: "Learning rate."

  # --- Data loading -----------------------------------------------------------
  dataloader_num_workers:
    type: integer
    description: "Number of workers for data loader."
    default: 2
# Outputs of the component; the command passes both locations to
# compute_utility.py (`--synthetic_data_prep` and `--output_dir`).
outputs:
  synthetic_data_prep:
    type: uri_file
    description: "Path to the synthetic data preparation output."
  output_dir:
    type: uri_folder
    description: "Output directory"

# Source for the component: this directory, plus three paths that live four
# directory levels above it.
# NOTE(review): presumably setup.py/src/README.md are included so that the
# `pip install -e .` step of the command can run - confirm.
code: ./
additional_includes:
  - '../../../../setup.py'
  - '../../../../src'
  - '../../../../README.md'

# Installs the package in editable mode, then launches compute_utility.py via
# torch.distributed.run with `--nproc_per_node` processes per node.
#
# Layout: the continuation lines are indented deeper than the first line, so
# the folded scalar (`>-`) keeps their line breaks and every trailing `\` acts
# as a shell line continuation. Keep that indentation, and do not place `#`
# comments inside the scalar - they would become part of the command text.
#
# `$[[ ... ]]` wraps the flag of the optional `model_name` input.
#
# `evaluation_steps` is forwarded as `--eval_steps`; without it the input is
# declared but has no effect on the run.
# NOTE(review): assumes compute_utility.py parses Hugging Face
# TrainingArguments (it already receives --evaluation_strategy,
# --save_strategy, --logging_steps, ...) - confirm the flag name.
command: >-
  python -m pip install -e . && python -m torch.distributed.run --nproc_per_node ${{inputs.nproc_per_node}} compute_utility.py \
    --output_dir ${{outputs.output_dir}} \
    --synthetic_data_prep ${{outputs.synthetic_data_prep}} \
    $[[ --model_name ${{inputs.model_name}} ]] \
    --utility_train_data_path ${{inputs.utility_train_data_path}} \
    --train_label_name ${{inputs.train_label_name}} \
    --train_text_name ${{inputs.train_text_name}} \
    --is_synthetic ${{inputs.is_synthetic}} \
    --utility_eval_data_path ${{inputs.utility_eval_data_path}} \
    --eval_label_name ${{inputs.eval_label_name}} \
    --eval_text_name ${{inputs.eval_text_name}} \
    --sequence_len ${{inputs.sequence_len}} \
    --per_device_train_batch_size ${{inputs.per_device_train_batch_size}} \
    --gradient_accumulation_steps ${{inputs.gradient_accumulation_steps}} \
    --evaluation_strategy ${{inputs.evaluation_strategy}} \
    --eval_steps ${{inputs.evaluation_steps}} \
    --save_strategy ${{inputs.save_strategy}} \
    --log_level ${{inputs.log_level}} \
    --seed ${{inputs.seed}} \
    --weight_decay ${{inputs.weight_decay}} \
    --num_train_epochs ${{inputs.num_train_epochs}} \
    --logging_steps ${{inputs.logging_steps}} \
    --max_grad_norm ${{inputs.max_grad_norm}} \
    --lr_scheduler_type ${{inputs.lr_scheduler_type}} \
    --learning_rate ${{inputs.learning_rate}} \
    --disable_tqdm True \
    --dataloader_num_workers ${{inputs.dataloader_num_workers}} \
    --gradient_checkpointing

# Execution environment: a base container image plus a conda specification
# file located next to this component definition.
# NOTE(review): the image reference carries no explicit tag - consider pinning
# one so that runs are reproducible.
environment:
  image: 'mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.6-cudnn8-ubuntu20.04'
  conda_file: './environment.yml'

