$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
type: command

# Component identity as registered in Azure ML.
name: train_llm
display_name: Train an LLM
version: 11

description: Train an LLM

inputs:
  # --- Model and training data ---
  # model_name and model_path are documented as mutually exclusive; the command
  # forwards whichever one is set through the same `--model_name` flag.
  model_name:
    type: string
    optional: true
    description: "Name of the HF model to use. Mutually exclusive with the option 'model_path'"
  model_path:
    type: uri_folder
    optional: true
    description: "Path to the model to use in HF format. Mutually exclusive with the option 'model_name'"
  tokenized_train_data:
    type: uri_folder
    description: "Path to the tokenized training data in HF format."
  tokenized_validation_data:
    type: uri_folder
    optional: true
    description: "Path to the tokenized validation data in HF format."
  # --- Model loading ---
  quantization_4bit:
    type: boolean
    default: false
    description: "Whether to apply 4bit quantization for the base model."
  torch_dtype:
    type: string
    default: "fp32"
    description: "Torch dtype to use for the model."
  trust_remote_code:
    type: boolean
    default: false
    description: "Whether to trust remote code when loading model from HuggingFace."
  # --- Data sampling and batching ---
  num_samples:
    type: integer
    default: 0
    description: "Number of samples to use from the dataset. 0 means using all samples."
  per_device_train_batch_size:
    type: integer
    description: "Batch size per device for training."
  gradient_accumulation_steps:
    type: integer
    default: 1
    description: "Number of gradient accumulation steps."
  # --- Evaluation ---
  evaluation_strategy:
    type: string
    # Quoted on purpose: a bare `no` is parsed as boolean false by YAML 1.1 loaders.
    default: "no"
    description: "Evaluation strategy."
  eval_steps:
    type: integer
    description: "Number of update steps between two evaluations (used when evaluation_strategy is 'steps')."
    default: 500
  label_names:
    type: string
    description: "Label names."
    # Quoted on purpose: this is the literal string "None" (forwarded verbatim
    # as `--label_names None` by the command), not a YAML null.
    default: "None"
  save_strategy:
    type: string
    description: "Checkpoint save strategy."
    # Quoted on purpose: a bare `no` is parsed as boolean false by YAML 1.1 loaders.
    default: "no"
  # --- Logging ---
  log_level:
    type: string
    default: info
    description: "Log level."
  log_tokens_per_second:
    type: boolean
    default: false
    description: "Whether to log tokens per second. (May slow down distributed training.)"
  log_num_input_tokens_seen:
    type: boolean
    default: false
    description: "Whether to log number of input tokens seen. (May slow down distributed training.)"
  # --- Reproducibility ---
  seed:
    type: integer
    description: "Random seed."
  # --- Optimizer ---
  weight_decay:
    type: number
    default: 0.01
    description: "Weight decay."
  adam_beta1:
    type: number
    default: 0.9
    description: "Adam beta1."
  adam_beta2:
    type: number
    default: 0.999
    description: "Adam beta2."
  # --- Data collation ---
  remove_unused_columns:
    type: boolean
    default: false
    description: "Remove unused columns before passing them to the model."
  # --- Training schedule ---
  num_train_epochs:
    type: number
    description: "Number of training epochs."
  logging_steps:
    type: integer
    description: "Log every x steps."
    default: 20
  max_grad_norm:
    type: number
    # The previous description ("Max batch gradient norm.") misstated the meaning:
    # this is the gradient-clipping threshold, not a per-batch statistic.
    description: "Maximum gradient norm used for gradient clipping. With the HF Trainer, a value of 0 disables clipping."
    default: 0
  lr_scheduler_type:
    type: string
    description: "Learning rate scheduler type."
    default: constant
  warmup_steps:
    type: integer
    description: "Number of warmup steps."
    default: 0
  learning_rate:
    type: number
    description: "Learning rate."
  dataloader_num_workers:
    type: integer
    description: "Number of workers for data loader."
    default: 2
  # --- LoRA ---
  enable_lora:
    type: boolean
    description: "Whether to enable LoRA."
  lora_dim:
    type: integer
    default: 4
    description: "LoRA dimension. (For more details see `r` in https://huggingface.co/docs/peft/main/en/package_reference/tuners#peft.LoraConfig)"
  lora_alpha:
    type: integer
    default: 32
    description: "LoRA alpha. (For more details see `lora_alpha` in https://huggingface.co/docs/peft/main/en/package_reference/tuners#peft.LoraConfig)"
  lora_dropout:
    type: number
    default: 0.0
    description: "LoRA dropout. (For more details see `lora_dropout` in https://huggingface.co/docs/peft/main/en/package_reference/tuners#peft.LoraConfig)"
  target_modules:
    type: string
    # A Python-list literal handed to train.py as one quoted CLI argument. It must
    # stay quoted here, because a leading `[` would otherwise start a YAML flow sequence.
    default: "['q_proj', 'v_proj']"
    description: "Target modules to apply LoRA to."
  # --- Precision and memory ---
  bf16:
    type: boolean
    default: false
    description: "Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training."
  fp16:
    type: boolean
    default: false
    description: "Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training."
  gradient_checkpointing:
    type: boolean
    default: false
    description: "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
  # --- Distributed training and attention ---
  ddp_find_unused_parameters:
    type: boolean
    default: false
    description: "If True, use DDP's find_unused_parameters flag."
  use_flash_attention:
    type: boolean
    default: true
    description: "Whether to use FlashAttention."
  # --- FSDP ---
  fsdp:
    type: string
    optional: true
    # The command passes this as a single quoted argument (`--fsdp "<value>"`),
    # because several accepted values contain a space.
    description: |
      Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training
      only). The base option should be `full_shard`, `shard_grad_op`, `hybrid_shard` or
      `hybrid_shard_zero2`. You can add CPU-offload to `full_shard` or `shard_grad_op` like this:
      `full_shard offload` or `shard_grad_op offload`. You can add auto-wrap to `full_shard` or
      `shard_grad_op` with the same syntax: `full_shard auto_wrap` or `shard_grad_op auto_wrap`.
      Allowed parameters are:
        - `"full_shard"`: Shard parameters, gradients and optimizer states.
        - `"shard_grad_op"`: Shard optimizer states and gradients.
        - `"hybrid_shard"`: Apply `FULL_SHARD` within a node, and replicate parameters across nodes.
        - `"hybrid_shard_zero2"`: Apply `SHARD_GRAD_OP` within a node, and replicate parameters across nodes.
        - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
          `"shard_grad_op"`).
        - `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
    # Every combination named in the description above is accepted. "shard_grad_op offload"
    # was documented but previously missing from the enum.
    enum:
      - "full_shard"
      - "full_shard offload"
      - "full_shard auto_wrap"
      - "shard_grad_op"
      - "shard_grad_op offload"
      - "shard_grad_op auto_wrap"
      - "hybrid_shard"
      - "hybrid_shard_zero2"
  fsdp_activation_checkpointing:
    type: boolean
    optional: true
    description: |
      Activation checkpointing is a technique to reduce memory usage by clearing activations of
      certain layers and recomputing them during a backward pass. Effectively, this trades extra
      computation time for reduced memory usage. This can only be used with FSDP training.
outputs:
  # Handed to train.py as --output_dir.
  output_dir:
    type: uri_folder
    description: "Output directory"

# Both of the following settings (resources and distribution) can be overridden by the user by setting them in the pipeline
resources:
  instance_count: 1
distribution:
  type: pytorch
  process_count_per_instance: 1

# Component source folder plus extra paths (relative to this file) that are
# bundled with it. setup.py is what `pip-install-local.sh -e .` installs from.
code: ./
additional_includes:
  - '../../setup.py'
  - '../../src'
  - '../../README.md'


# Installs the local package, then launches train.py with every input mapped to a CLI flag.
# Written as a folded (`>-`) scalar. The continuation lines are indented deeper
# than the first line, so YAML keeps their line breaks, and each trailing `\` acts
# as a shell line continuation. Do not add comments inside the scalar: they would
# become part of the command.
# `$[[ ... ]]` segments are emitted only when the corresponding optional input is set.
# NOTE(review): model_path is forwarded through `--model_name`, not a separate
# `--model_path` flag. Presumably train.py accepts either a hub name or a local
# folder in that argument; confirm.
# log_tokens_per_second / log_num_input_tokens_seen are forwarded as
# --include_tokens_per_second / --include_num_input_tokens_seen.
command: >-
  ./pip-install-local.sh -e . && python train.py \
    --output_dir ${{outputs.output_dir}} \
    $[[ --model_name ${{inputs.model_name}} ]] \
    $[[ --model_name ${{inputs.model_path}} ]] \
    --tokenized_train_data_path ${{inputs.tokenized_train_data}} \
    $[[ --tokenized_validation_data_path ${{inputs.tokenized_validation_data}} ]] \
    --quantization_4bit ${{inputs.quantization_4bit}} \
    --torch_dtype ${{inputs.torch_dtype}} \
    --trust_remote_code ${{inputs.trust_remote_code}} \
    --num_samples ${{inputs.num_samples}} \
    --per_device_train_batch_size ${{inputs.per_device_train_batch_size}} \
    --gradient_accumulation_steps ${{inputs.gradient_accumulation_steps}} \
    --evaluation_strategy ${{inputs.evaluation_strategy}} \
    --eval_steps ${{inputs.eval_steps}} \
    --label_names ${{inputs.label_names}} \
    --save_strategy ${{inputs.save_strategy}} \
    --log_level ${{inputs.log_level}} \
    --seed ${{inputs.seed}} \
    --weight_decay ${{inputs.weight_decay}} \
    --remove_unused_columns ${{inputs.remove_unused_columns}} \
    --num_train_epochs ${{inputs.num_train_epochs}} \
    --logging_steps ${{inputs.logging_steps}} \
    --max_grad_norm ${{inputs.max_grad_norm}} \
    --lr_scheduler_type ${{inputs.lr_scheduler_type}} \
    --warmup_steps ${{inputs.warmup_steps}} \
    --learning_rate ${{inputs.learning_rate}} \
    --adam_beta1 ${{inputs.adam_beta1}} \
    --adam_beta2 ${{inputs.adam_beta2}} \
    --disable_tqdm True \
    --dataloader_num_workers ${{inputs.dataloader_num_workers}} \
    --enable_lora ${{inputs.enable_lora}} \
    --lora_dim ${{inputs.lora_dim}} \
    --lora_alpha ${{inputs.lora_alpha}} \
    --lora_dropout ${{inputs.lora_dropout}} \
    --target_modules "${{inputs.target_modules}}" \
    --bf16 ${{inputs.bf16}} \
    --fp16 ${{inputs.fp16}} \
    --gradient_checkpointing ${{inputs.gradient_checkpointing}} \
    --ddp_find_unused_parameters ${{inputs.ddp_find_unused_parameters}} \
    --include_tokens_per_second ${{inputs.log_tokens_per_second}} \
    --include_num_input_tokens_seen ${{inputs.log_num_input_tokens_seen}} \
    --use_flash_attention ${{inputs.use_flash_attention}} \
    $[[ --fsdp "${{inputs.fsdp}}" ]] \
    $[[ --fsdp_activation_checkpointing ${{inputs.fsdp_activation_checkpointing}} ]] \
    --report_to mlflow

environment:
  # Python dependencies come from environment.yml, layered on a CUDA 11.8 / cuDNN 8 /
  # Ubuntu 22.04 (OpenMPI 4.1.0) base image.
  conda_file: ./environment.yml
  image: mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.8-cudnn8-ubuntu22.04
