# Azure ML pipeline job. It chains the convert_hf_dataset_to_jsonl, tokenize
# and train components (see `jobs` below) for a train and a validation split.
$schema: https://azuremlschemas.azureedge.net/latest/pipelineJob.schema.json
type: pipeline
name: dp_transformers_privacy_auditing_train
display_name: Train LLM

# Pipeline-level inputs. Each one is forwarded to one or more jobs through
# ${{parent.inputs.<name>}} bindings.
inputs:
  # Passed as `tokenizer_path` to both tokenize jobs and as `model_path` to
  # the train job.
  base_model:
    type: uri_folder
  # Input of convert_to_jsonl_train.
  train_data:
    type: uri_file
  # Input of convert_to_jsonl_val.
  val_data:
    type: uri_file
  # Forwarded to the train job.
  learning_rate:
    type: number
    # Written with an explicit decimal point: YAML 1.1 loaders such as PyYAML
    # resolve a bare `5e-4` as the string "5e-4" rather than a float.
    default: 5.0e-4
  # Forwarded to the train job.
  num_train_epochs:
    type: number
    default: 3.0
  # Forwarded to the train job. No default is declared here.
  seed:
    type: integer
  # Forwarded to both tokenize jobs and to the train job.
  trust_remote_code:
    type: boolean
    default: false
  # Forwarded to both tokenize jobs.
  sequence_len:
    type: integer
    default: 1024
  # Forwarded to both tokenize jobs as `packed_dataset`.
  pack_dataset:
    type: boolean
    default: false
  # Forwarded to both convert_to_jsonl jobs.
  completion_column:
    type: string
    default: "sentence"
  # Forwarded to the train job.
  lr_scheduler_type:
    type: string
    default: "constant"
  # Forwarded to the train job.
  per_device_train_batch_size:
    type: integer
    default: 8
  # Forwarded to the train job.
  gradient_accumulation_steps:
    type: integer
    default: 1

# Pipeline-level outputs.
outputs:
  # Bound to the `output_dir` output of the train job.
  model:
    type: uri_folder

jobs:
  # Runs the convert_hf_dataset_to_jsonl component on `train_data`; its
  # `output` is consumed by tokenize_train.
  # NOTE(review): no `compute` is set on this job -- presumably it relies on a
  # default compute supplied at submission time; confirm.
  convert_to_jsonl_train:
    type: command
    component: ../../../components/convert_hf_dataset_to_jsonl/component_spec.yml
    inputs:
      input: ${{parent.inputs.train_data}}
      completion_column: ${{parent.inputs.completion_column}}
    outputs:
      output:
        mode: rw_mount
  # Runs the convert_hf_dataset_to_jsonl component on `val_data`; its
  # `output` is consumed by tokenize_val.
  # NOTE(review): no `compute` is set on this job -- presumably it relies on a
  # default compute supplied at submission time; confirm.
  convert_to_jsonl_val:
    type: command
    component: ../../../components/convert_hf_dataset_to_jsonl/component_spec.yml
    inputs:
      input: ${{parent.inputs.val_data}}
      completion_column: ${{parent.inputs.completion_column}}
    outputs:
      output:
        mode: rw_mount
  # Runs the tokenize component on the output of convert_to_jsonl_train, using
  # `base_model` as the tokenizer path. Its `tokenized_data` output is consumed
  # by the train job as `tokenized_train_data`.
  tokenize_train:
    type: command
    component: ../../../components/tokenize/component_spec.yml
    # Pinned to the D64sv3 compute target.
    compute: azureml:D64sv3
    inputs:
      tokenizer_path: ${{parent.inputs.base_model}}
      trust_remote_code: ${{parent.inputs.trust_remote_code}}
      data: ${{parent.jobs.convert_to_jsonl_train.outputs.output}}
      packed_dataset: ${{parent.inputs.pack_dataset}}
      sequence_len: ${{parent.inputs.sequence_len}}
      chat_format: false
    outputs:
      tokenized_data:
        mode: rw_mount
  # Runs the tokenize component on the output of convert_to_jsonl_val, using
  # `base_model` as the tokenizer path. Its `tokenized_data` output is consumed
  # by the train job as `tokenized_validation_data`.
  tokenize_val:
    type: command
    component: ../../../components/tokenize/component_spec.yml
    # Pinned to the D64sv3 compute target.
    compute: azureml:D64sv3
    inputs:
      tokenizer_path: ${{parent.inputs.base_model}}
      trust_remote_code: ${{parent.inputs.trust_remote_code}}
      data: ${{parent.jobs.convert_to_jsonl_val.outputs.output}}
      packed_dataset: ${{parent.inputs.pack_dataset}}
      sequence_len: ${{parent.inputs.sequence_len}}
      chat_format: false
    outputs:
      tokenized_data:
        mode: rw_mount
  # Runs the train component on the two tokenized splits, starting from
  # `base_model`, and writes its `output_dir` to the pipeline's `model` output.
  # NOTE(review): no `compute` is set on this job -- presumably it relies on a
  # default compute supplied at submission time; confirm.
  train:
    type: command
    component: ../../../components/train/component_spec.yml
    inputs:
      # Data and model wiring.
      model_path: ${{parent.inputs.base_model}}
      tokenized_train_data: ${{parent.jobs.tokenize_train.outputs.tokenized_data}}
      tokenized_validation_data: ${{parent.jobs.tokenize_val.outputs.tokenized_data}}
      trust_remote_code: ${{parent.inputs.trust_remote_code}}
      # NOTE(review): the meaning of 0 is defined by the train component, which
      # is not visible here -- confirm.
      num_samples: 0
      # Hyperparameters forwarded from the pipeline inputs.
      per_device_train_batch_size: ${{parent.inputs.per_device_train_batch_size}}
      gradient_accumulation_steps: ${{parent.inputs.gradient_accumulation_steps}}
      evaluation_strategy: "epoch"
      # Must stay quoted: a bare `no` is a boolean under YAML 1.1.
      save_strategy: "no"
      log_level: "info"
      seed: ${{parent.inputs.seed}}
      weight_decay: 0.1
      remove_unused_columns: false
      num_train_epochs: ${{parent.inputs.num_train_epochs}}
      logging_steps: 1
      # NOTE(review): presumably 0.0 turns gradient clipping off in the train
      # component -- confirm against its spec.
      max_grad_norm: 0.0
      lr_scheduler_type: ${{parent.inputs.lr_scheduler_type}}
      learning_rate: ${{parent.inputs.learning_rate}}
      dataloader_num_workers: 2
      enable_lora: false
      use_flash_attention: false
    resources:
      instance_count: 1
    # PyTorch distribution: 8 processes on the single instance requested above.
    # NOTE(review): presumably one process per GPU -- confirm the target
    # compute exposes 8 devices.
    distribution:
      type: pytorch
      process_count_per_instance: 8
    outputs:
      output_dir: ${{parent.outputs.model}}