---
# Azure ML CLI v2 pipeline: preprocess and tokenize the data, fine-tune the base
# model, then run the synthetic-data generator on top of the fine-tuned output.
$schema: https://azuremlschemas.azureedge.net/latest/pipelineJob.schema.json
name: fine_tune_synthetic
display_name: 'Fine tune LLM with synthetic'
type: pipeline

inputs:
  model_path:
    type: uri_folder
    description: "Base model folder; also used as the tokenizer, and as the base model for fine-tuning and generation."
  train_data:
    type: uri_folder
    description: "Training data folder; preprocessed with templated_prompt, used for fine-tuning and as input to generation."
  val_data:
    type: uri_folder
    description: "Validation data folder; preprocessed with templated_prompt and used for evaluation during fine-tuning."
  templated_prompt:
    type: string
    description: "Template prompt for the synthetic data generation."
  label_column:
    type: string
    description: "Column name for the label."
  text_column:
    type: string
    description: "Column name for the text."
  sequence_len:
    type: integer
    description: "Maximum token sequence length; also passed to the generator as max_new_tokens."
  learning_rate:
    type: number
    description: "Learning rate for the fine-tuning."
  # `number` (not integer) so fractional epoch counts can be passed through.
  num_train_epochs:
    type: number
    description: "Number of training epochs."
  per_device_train_batch_size:
    type: integer
    description: "Batch size per device."
  gradient_accumulation_steps:
    type: integer
    description: "Number of gradient accumulation steps."
  enable_lora:
    type: boolean
    description: "Enable LoRA."
  lora_dim:
    type: integer
    description: "LoRA dimension."
  target_modules:
    type: string
    description: "Target modules for LoRA."
  gradient_checkpointing:
    type: boolean
    description: "Enable gradient checkpointing."
  torch_dtype:
    type: string
    description: "Torch dtype for model."
  quantization_4bit:
    type: boolean
    description: "Enable 4-bit quantization."
  seed:
    type: integer
    description: "Random seed for reproducibility (used by both fine-tuning and generation)."
  synthetic_multiple:
    type: integer
    description: "Number of synthetic samples to generate per real sample."
# Pipeline-level output. Despite its name, `model` is bound to the generate
# job's output_dir (jobs.generate.outputs.output_dir), not to fine_tune's.
outputs:
  model:
    type: uri_folder

jobs:
  # Preprocess the training split: hands the raw folder, the prompt template and
  # the label/text column names to the preprocess_for_synthesizer component.
  # Its output feeds tokenize_train and generate.
  preprocess_train:
    type: command
    component: ../components/preprocess_for_synthesizer/component_spec.yml
    inputs:
      input_data_path: ${{parent.inputs.train_data}}
      templated_prompt: ${{parent.inputs.templated_prompt}}
      label_column: ${{parent.inputs.label_column}}
      text_column: ${{parent.inputs.text_column}}
    outputs:
      output_data_path:
        mode: rw_mount

  # Same preprocessing for the validation split; feeds tokenize_eval.
  preprocess_val:
    type: command
    component: ../components/preprocess_for_synthesizer/component_spec.yml
    inputs:
      input_data_path: ${{parent.inputs.val_data}}
      templated_prompt: ${{parent.inputs.templated_prompt}}
      label_column: ${{parent.inputs.label_column}}
      text_column: ${{parent.inputs.text_column}}
    outputs:
      output_data_path:
        mode: rw_mount
  # Tokenize the preprocessed training split with the base model's tokenizer.
  # Packing and chat formatting are disabled; booleans use canonical lowercase.
  tokenize_train:
    type: command
    component: ../../../components/tokenize/component_spec.yml
    inputs:
      tokenizer_path: ${{parent.inputs.model_path}}
      data: ${{parent.jobs.preprocess_train.outputs.output_data_path}}
      packed_dataset: false
      chat_format: false
      sequence_len: ${{parent.inputs.sequence_len}}
    outputs:
      tokenized_data:
        mode: "rw_mount"
  # Identical tokenization for the preprocessed validation split.
  tokenize_eval:
    type: command
    component: ../../../components/tokenize/component_spec.yml
    inputs:
      tokenizer_path: ${{parent.inputs.model_path}}
      data: ${{parent.jobs.preprocess_val.outputs.output_data_path}}
      packed_dataset: false
      chat_format: false
      sequence_len: ${{parent.inputs.sequence_len}}
    outputs:
      tokenized_data:
        mode: "rw_mount"
  # Fine-tune the base model on the tokenized training split, evaluating on the
  # tokenized validation split every `eval_steps` steps.
  fine_tune:
    type: command
    component: ../../../components/train/component_spec.yml
    inputs:
      tokenized_train_data: ${{parent.jobs.tokenize_train.outputs.tokenized_data}}
      tokenized_validation_data: ${{parent.jobs.tokenize_eval.outputs.tokenized_data}}
      evaluation_strategy: "steps"
      eval_steps: 50
      label_names: "labels"
      model_path: ${{parent.inputs.model_path}}
      # Must match distribution.process_count_per_instance below.
      nproc_per_node: 8
      per_device_train_batch_size: ${{parent.inputs.per_device_train_batch_size}}
      gradient_accumulation_steps: ${{parent.inputs.gradient_accumulation_steps}}
      seed: ${{parent.inputs.seed}}
      chat_format: false
      num_train_epochs: ${{parent.inputs.num_train_epochs}}
      learning_rate: ${{parent.inputs.learning_rate}}
      lr_scheduler_type: constant
      quantization_4bit: ${{parent.inputs.quantization_4bit}}
      torch_dtype: ${{parent.inputs.torch_dtype}}
      bf16: false
      fp16: false
      enable_lora: ${{parent.inputs.enable_lora}}
      lora_dim: ${{parent.inputs.lora_dim}}
      target_modules: ${{parent.inputs.target_modules}}
      use_flash_attention: false
      gradient_checkpointing: ${{parent.inputs.gradient_checkpointing}}
    resources:
      instance_count: 1
    distribution:
      type: pytorch
      # Must match inputs.nproc_per_node above.
      process_count_per_instance: 8
    # Declare `outputs` exactly once: duplicate mapping keys are invalid YAML,
    # and most loaders silently keep only the last one.
    # output_dir is consumed by the generate job as lora_path.
    outputs:
      output_dir:
        mode: "upload"
  # Run the synthetic-data generator: base model plus fine_tune's output (as
  # lora_path), seeded from the preprocessed training split. Its output is
  # published as the pipeline-level `model` output.
  # NOTE(review): nproc_per_node is 8 but this job declares no `distribution`
  # section (unlike fine_tune); confirm the component launches its own workers.
  generate:
    type: command
    component: ../../../components/llm-synthetic-data-generation/component_spec.yml
    inputs:
      lora_path: ${{parent.jobs.fine_tune.outputs.output_dir}}
      model_path: ${{parent.inputs.model_path}}
      train_data_path: ${{parent.jobs.preprocess_train.outputs.output_data_path}}
      nproc_per_node: 8
      batch_size: 32
      seed: ${{parent.inputs.seed}}
      # Generation length reuses the tokenization sequence length.
      max_new_tokens: ${{parent.inputs.sequence_len}}
      # Must stay quoted: a bare `no` is parsed as boolean false in YAML 1.1.
      mixed_precision: 'no'
      synthetic_multiple: ${{parent.inputs.synthetic_multiple}}
    outputs:
      output_dir: ${{parent.outputs.model}}
