# Azure ML command component; $schema names the schema this file follows.
$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
type: command

# Component identity: name, display name, and version.
name: llm_synthetic_data_generation
display_name: LLM Synthetic Data Generation
version: 4

description: 'Generate synthetic data samples from an LLM fine-tuned with QLoRA.'

# Component inputs. Entries marked `optional: true` may be left unset; the
# command only forwards them when they are provided.
inputs:
  # Launch configuration.
  nproc_per_node:
    type: integer
    description: 'Number of processes per node. Usually this should be the number of GPUs per node'

  # Model selection: model_name and model_path are mutually exclusive.
  model_name:
    type: string
    description: "Name of the HF model to use. Mutually exclusive with the option 'model_path'"
    optional: true
  model_path:
    type: uri_folder
    description: "Path to the model to use in HF format. Mutually exclusive with the option 'model_name'"
    optional: true
  lora_path:
    type: uri_folder
    description: 'Path to the lora weights to use in HF format.'
    optional: true

  # Data.
  train_data_path:
    type: uri_file
    description: 'Path to the training data in jsonl format'

  # Generation settings.
  batch_size:
    type: integer
    description: 'Batch size per device for generations.'
  seed:
    type: integer
    description: 'Random seed.'
  load_dtype:
    type: string
    description: 'Non-quantized model parameters load dtype (fp32, fp16, bf16)'
    default: 'fp32'
  max_new_tokens:
    type: integer
    description: 'Maximum number of new tokens to generate.'
  synthetic_multiple:
    type: integer
    description: 'Number of synthetic samples to generate per real sample.'
# Component outputs; output_dir is passed to generate.py as --output_dir.
outputs:
  output_dir:
    type: uri_folder
    description: 'Output directory'
  
# Source for the component: this directory plus the extra paths listed below.
# NOTE(review): setup.py and src are presumably included so that the
# `pip install -e .` step in the command can install the package — confirm.
code: './'
additional_includes:
  - '../../setup.py'
  - '../../src'
  - '../../README.md'


# Shell command run by the component: installs the package in editable mode,
# then launches generate.py through `accelerate launch`.
# - `>-` folds the block into a single string; the more-indented continuation
#   lines keep their line breaks, so the trailing backslashes act as shell
#   line continuations.
# - `$[[ ... ]]` wraps arguments built from optional inputs; Azure ML drops the
#   whole bracketed segment when that input is not provided.
# - model_name and model_path are mutually exclusive inputs, so each one is
#   forwarded under its own flag.
# NOTE(review): --enable_lora is always passed even though lora_path is
# optional — confirm generate.py copes with a missing --lora_path.
command: >-
  python -m pip install -e . && accelerate launch --multi_gpu --num_processes ${{inputs.nproc_per_node}} generate.py \
    --output_dir ${{outputs.output_dir}} \
    $[[ --model_name ${{inputs.model_name}} ]] \
    $[[ --model_path ${{inputs.model_path}} ]] \
    $[[ --lora_path ${{inputs.lora_path}} ]] \
    --train_data_path ${{inputs.train_data_path}} \
    --batch_size ${{inputs.batch_size}} \
    --seed ${{inputs.seed}} \
    --load_dtype ${{inputs.load_dtype}} \
    --max_new_tokens ${{inputs.max_new_tokens}} \
    --synthetic_multiple ${{inputs.synthetic_multiple}} \
    --enable_lora

# Execution environment: base Docker image plus a conda environment file.
environment:
  image: 'mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.6-cudnn8-ubuntu20.04'
  conda_file: './environment.yml'
