# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch
import transformers
from peft import PeftModel
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_safetensors_available
from typing_extensions import override

from ..extras import logging
from ..extras.constants import (
    TRAINER_LOG,
    V_HEAD_SAFE_WEIGHTS_NAME,
    V_HEAD_WEIGHTS_NAME,
)
from ..extras.misc import get_peak_memory
from ..utils.memory_profiler import PREFIX, save_memory_snapshot


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import save_file

if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments
    from trl import AutoModelForCausalLMWithValueHead


logger = logging.get_logger(__name__)


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead",
    output_dir: str,
    safe_serialization: Optional[bool] = None,
    *,
    safe_seriazation: Optional[bool] = None,
) -> None:
    r"""
    Split a combined value-head checkpoint into decoder weights and value-head weights.

    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.

    The combined weights file in `output_dir` is deleted and replaced by the decoder
    checkpoint (written by `model.pretrained_model.save_pretrained`) plus a separate
    value-head file (`V_HEAD_SAFE_WEIGHTS_NAME` or `V_HEAD_WEIGHTS_NAME`).

    Args:
        model: value-head wrapper; nothing is done unless its `pretrained_model` is a
            `PreTrainedModel` or `PeftModel`.
        output_dir: directory containing the checkpoint that was just saved.
        safe_serialization: if true, read/write safetensors files; otherwise torch pickles.
        safe_seriazation: deprecated, misspelled keyword alias of `safe_serialization`,
            accepted only so that existing keyword callers keep working.

    Raises:
        TypeError: if the serialization flag is missing or given under both names.
    """
    if safe_seriazation is not None:
        if safe_serialization is not None:
            raise TypeError(
                "fix_valuehead_checkpoint() got values for both 'safe_serialization' and 'safe_seriazation'"
            )
        safe_serialization = safe_seriazation
    if safe_serialization is None:
        raise TypeError("fix_valuehead_checkpoint() missing required argument: 'safe_serialization'")

    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    if safe_serialization:
        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

    # The combined file is superseded by the two files written below.
    os.remove(path_to_checkpoint)
    decoder_state_dict, v_head_state_dict = {}, {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            # Drop the wrapper prefix seen under ZeRO-3 (case 3) so keys match the bare decoder.
            decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param

    # An empty decoder dict (case 2) is passed as None so save_pretrained uses the model's own weights.
    model.pretrained_model.save_pretrained(
        output_dir,
        state_dict=decoder_state_dict or None,
        safe_serialization=safe_serialization,
    )

    if safe_serialization:
        save_file(
            v_head_state_dict,
            os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME),
            metadata={"format": "pt"},
        )
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    logger.info_rank0(f"Value head model saved at: {output_dir}")


class FixValueHeadModelCallback(TrainerCallback):
    r"""
    A callback for fixing the checkpoint for valuehead models.
    """

    @override
    def on_save(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""
        Event called after a checkpoint save.

        On processes with `args.should_save`, rewrites the `checkpoint-<global_step>`
        directory with `fix_valuehead_checkpoint`.
        """
        if args.should_save:
            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"),
                output_dir=output_dir,
                safe_serialization=args.save_safetensors,
            )


class SaveProcessorCallback(TrainerCallback):
    r"""
    A callback for saving the processor.
    """

    def __init__(self, processor: "ProcessorMixin") -> None:
        # Processor written next to every checkpoint and to the final output directory.
        self.processor = processor

    @override
    def on_save(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""
        Save the processor into the `checkpoint-<global_step>` directory just written.
        """
        if args.should_save:
            output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            self.processor.save_pretrained(output_dir)

    @override
    def on_train_end(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""
        Save the processor into `args.output_dir` once training finishes.
        """
        if args.should_save:
            self.processor.save_pretrained(args.output_dir)


class PissaConvertCallback(TrainerCallback):
    r"""
    A callback for converting the PiSSA adapter to a normal one.
    """

    @staticmethod
    def _save_with_init_lora_weights_true(model: "PeftModel", save_dir: str, safe_serialization: bool) -> None:
        r"""
        Save the "default" adapter with `init_lora_weights=True`, restoring the previous value afterwards.
        """
        config = model.peft_config["default"]
        init_lora_weights = getattr(config, "init_lora_weights")
        setattr(config, "init_lora_weights", True)
        try:
            model.save_pretrained(save_dir, safe_serialization=safe_serialization)
        finally:
            # Restore even if saving fails, so the live config is never left modified.
            setattr(config, "init_lora_weights", init_lora_weights)

    @override
    def on_train_begin(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""
        Event called at the beginning of training.

        Saves the initial PiSSA adapter to `<output_dir>/pissa_init`; `on_train_end`
        uses that directory as the reference for the conversion.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
            if isinstance(model, PeftModel):
                self._save_with_init_lora_weights_true(model, pissa_init_dir, args.save_safetensors)

    @override
    def on_train_end(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""
        Convert the trained PiSSA adapter into a LoRA adapter saved at `<output_dir>/pissa_converted`.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
            logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
            # 1. save a pissa backup with init_lora_weights: True
            # 2. save a converted lora with init_lora_weights: pissa
            # 3. load the pissa backup with init_lora_weights: True
            # 4. delete the initial adapter and change init_lora_weights to pissa
            if isinstance(model, PeftModel):
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                # Step 1.
                self._save_with_init_lora_weights_true(model, pissa_backup_dir, args.save_safetensors)
                # Step 2.
                model.save_pretrained(
                    pissa_convert_dir,
                    safe_serialization=args.save_safetensors,
                    convert_pissa_to_lora=pissa_init_dir,
                )  # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
                # Step 3: reloading the backup also replaces the "default" adapter config.
                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                model.set_adapter("default")
                # Step 4.
                if "pissa_init" in model.peft_config.keys():  # backward compatibility (peft<0.12.0)
                    model.delete_adapter("pissa_init")
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)


class LogCallback(TrainerCallback):
    r"""
    A callback for logging training and evaluation status.

    Progress records are appended as JSON lines to `<output_dir>/TRAINER_LOG` on a
    background thread.
    """

    def __init__(self) -> None:
        # Progress
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        self.thread_pool: Optional["ThreadPoolExecutor"] = None
        # Status
        self.aborted = False
        self.do_train = False
        # Web UI: enabled when LLAMABOARD_ENABLED is "true" or "1" (case-insensitive).
        self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
        if self.webui_mode:
            # SIGABRT only flags the run as aborted; the step/prediction hooks react to the flag.
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
            logging.add_handler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        r"""Signal handler: mark the run as aborted."""
        self.aborted = True

    def _stop_if_aborted(self, control: "TrainerControl") -> None:
        r"""Ask the Trainer to stop the current epoch and training once an abort was requested."""
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def _reset(self, max_steps: int = 0) -> None:
        r"""Restart the progress clock for a run of `max_steps` steps."""
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        r"""Update elapsed/remaining time strings, extrapolating from the mean time per step so far."""
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        r"""Append `logs` as one JSON line to the trainer log file in `output_dir`."""
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        r"""Create `output_dir` and a single-worker pool for log writes."""
        os.makedirs(output_dir, exist_ok=True)
        # One worker keeps the appended records in submission order.
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        r"""Wait for pending log writes, then drop the pool (no-op if none exists)."""
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    @override
    def on_init_end(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""Delete a previous trainer log when `overwrite_output_dir` is set."""
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning_once("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    @override
    def on_train_begin(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""Start progress tracking for training and open the log-writer pool."""
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    @override
    def on_train_end(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""Flush and close the log-writer pool."""
        self._close_thread_pool()

    @override
    def on_substep_end(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""Honor an abort request between gradient-accumulation substeps."""
        self._stop_if_aborted(control)

    @override
    def on_step_end(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""Honor an abort request after an optimizer step."""
        self._stop_if_aborted(control)

    @override
    def on_evaluate(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""Close the pool after a standalone (non-training) evaluation."""
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_predict(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""Close the pool after a standalone (non-training) prediction."""
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_log(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""
        Build a progress record from the latest `state.log_history` entry and queue it for writing.

        Keys whose value is None are dropped. In web UI mode a summary line is also logged.
        """
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        last_log = state.log_history[-1]
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=last_log.get("loss"),
            eval_loss=last_log.get("eval_loss"),
            predict_loss=last_log.get("predict_loss"),
            reward=last_log.get("reward"),
            accuracy=last_log.get("rewards/accuracies"),
            lr=last_log.get("learning_rate"),
            epoch=last_log.get("epoch"),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        if state.num_input_tokens_seen:
            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
            logs["total_tokens"] = state.num_input_tokens_seen

        if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
            vram_allocated, vram_reserved = get_peak_memory()
            # Reported in GiB.
            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
            for extra_key in ("reward", "accuracy", "throughput"):
                if logs.get(extra_key):
                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

            logger.info_rank0("{" + log_str + "}")

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    @override
    def on_prediction_step(
        self,
        args: "TrainingArguments",
        state: "TrainerState",
        control: "TrainerControl",
        **kwargs,
    ):
        r"""
        Track progress of standalone evaluation/prediction; ignored while training.

        An aborted run exits the process. A progress record is queued every 5 steps.
        """
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            # First prediction step: size the progress bar from the dataloader and open the pool.
            if self.max_steps == 0:
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)


class MemoryProfileCallback(TrainerCallback):
    """Callback to save memory snapshots periodically during training."""

    def __init__(self, output_dir: str, save_steps: int = 1, profiler=None) -> None:
        r"""
        Args:
            output_dir: directory passed to `save_memory_snapshot`.
            save_steps: take a snapshot every `save_steps` global steps; must be >= 1.
            profiler: optional object with `start`/`step`/`stop` methods
                (presumably a `torch.profiler.profile` — confirm at the call site).

        Raises:
            ValueError: if `save_steps` is smaller than 1 (it is used as a modulus).
        """
        if save_steps < 1:
            raise ValueError(f"save_steps must be a positive integer, got {save_steps}.")
        self.output_dir = output_dir
        self.save_steps = save_steps
        self.profiler = profiler
        print(f"{PREFIX}: Initialized callback with profiler: {profiler is not None}")
        print(f"{PREFIX}: Will save snapshots every {save_steps} steps to {output_dir}")

    @override
    def on_train_begin(self, args, state, control, **kwargs):
        """Called at the beginning of training; starts the profiler if one was given."""
        print(f"{PREFIX}: Training beginning")
        if self.profiler:
            print(f"{PREFIX}: Starting profiler")
            self.profiler.start()

    @override
    def on_step_begin(self, args, state, control, **kwargs):
        """Called before each training step."""
        print(f"{PREFIX}: Starting step {state.global_step}")

    @override
    def on_step_end(self, args, state, control, **kwargs):
        """Called after each training step; steps the profiler and periodically snapshots memory."""
        print(f"{PREFIX}: Ending step {state.global_step}")

        # Step the profiler
        if self.profiler:
            print(f"{PREFIX}: Stepping profiler at step {state.global_step}")
            self.profiler.step()

        # Take memory snapshot if needed
        if state.global_step > 0 and state.global_step % self.save_steps == 0:
            print(f"{PREFIX}: Taking snapshot at step {state.global_step}")
            save_memory_snapshot(self.output_dir)

    @override
    def on_train_end(self, args, state, control, **kwargs):
        """Called at the end of training; stops the profiler if one was given."""
        print(f"{PREFIX}: Training ending")
        if self.profiler:
            print(f"{PREFIX}: Stopping profiler")
            self.profiler.stop()