# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Optional

import torch
from transformers.utils import cached_file

from ...extras import logging
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME


if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ...hparams import ModelArguments


logger = logging.get_logger(__name__)


def load_valuehead_params(
    path_or_repo_id: str, model_args: "ModelArguments"
) -> Optional[Dict[str, torch.Tensor]]:
    r"""
    Loads value head parameters from Hugging Face Hub or local disk.

    The safetensors file (``V_HEAD_SAFE_WEIGHTS_NAME``) is tried first. If that fails
    for any reason, including ``safetensors`` not being installed, the function falls
    back to the ``torch.load``-able file (``V_HEAD_WEIGHTS_NAME``).

    Args:
        path_or_repo_id: local directory or Hub repository id to resolve files from.
        model_args: supplies ``cache_dir`` and ``hf_hub_token`` for ``cached_file``.

    Returns:
        dict with keys `v_head.summary.weight` and `v_head.summary.bias`, or ``None``
        when neither file can be resolved and loaded. In the ``None`` case, the error
        from the last attempt is logged.
    """
    kwargs = {
        "path_or_repo_id": path_or_repo_id,
        "cache_dir": model_args.cache_dir,
        "token": model_args.hf_hub_token,
    }
    err_text = ""

    try:
        from safetensors import safe_open

        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
        with safe_open(vhead_file, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    except Exception as err:  # best effort: any failure falls through to the legacy format
        err_text = str(err)

    try:
        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
        # NOTE(review): torch.load unpickles the file, and it may come from a remote Hub repo.
        # Consider weights_only=True (torch>=1.13). Left unchanged here for compatibility.
        return torch.load(vhead_file, map_location="cpu")
    except Exception as err:  # best effort: a missing value head is not fatal
        err_text = str(err)

    logger.info_rank0(
        f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}."
    )
    logger.info_rank0(
        "Ignore the above message if you are not resuming the training of a value head model."
    )
    return None


def prepare_valuehead_model(model: "PreTrainedModel") -> None:
    r"""
    Aliases the output layer of LLaVA, ChatGLM and InternLM2 models as ``model.lm_head``.

    For those model types, this also sets ``model._keys_to_ignore_on_save`` to
    ``["lm_head.weight"]``. The alias presumably refers to a module whose weights are
    already saved under its original name (TODO confirm). Other model types are left
    untouched.
    """
    model_type = getattr(model.config, "model_type", None)
    if model_type == "llava":
        lm_head = model.language_model.get_output_embeddings()
    elif model_type == "chatglm":
        lm_head = model.transformer.output_layer
    elif model_type == "internlm2":
        lm_head = model.output
    else:
        return

    setattr(model, "lm_head", lm_head)
    setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])